├── .coveragerc ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── continuous_integration.yml │ └── test_pr.yml ├── .gitignore ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── TODO.txt ├── docs ├── .gitignore ├── Makefile ├── __init__.py ├── _static │ └── theme_overrides.css ├── _templates │ ├── latex.tex_t │ ├── tabular.tex_t │ └── tabulary.tex_t ├── conf.py ├── index.rst ├── requirements.txt └── source │ ├── features │ ├── mushroom_rl.features.basis.rst │ ├── mushroom_rl.features.tensors.rst │ └── mushroom_rl.features.tiles.rst │ ├── mushroom_rl.agent_environment_interface.rst │ ├── mushroom_rl.algorithms.actor_critic.rst │ ├── mushroom_rl.algorithms.policy_search.rst │ ├── mushroom_rl.algorithms.value.rst │ ├── mushroom_rl.approximators.rst │ ├── mushroom_rl.distributions.rst │ ├── mushroom_rl.environments.rst │ ├── mushroom_rl.features.rst │ ├── mushroom_rl.policy.rst │ ├── mushroom_rl.rl_utils.rst │ ├── mushroom_rl.solvers.rst │ ├── mushroom_rl.utils.rst │ └── tutorials │ ├── code │ ├── advanced_experiment.py │ ├── approximator.py │ ├── ddpg.py │ ├── dqn.py │ ├── generic_regressor.py │ ├── logger.py │ ├── room_env.py │ ├── serialization.py │ └── simple_experiment.py │ ├── tutorials.0_experiments.rst │ ├── tutorials.1_advanced.rst │ ├── tutorials.2_approximator.rst │ ├── tutorials.3_deep.rst │ ├── tutorials.4_logger.rst │ ├── tutorials.5_environments.rst │ ├── tutorials.6_serializable.rst │ └── tutorials.99_examples.rst ├── examples ├── __init__.py ├── acrobot_a2c.py ├── acrobot_dqn.py ├── atari_dqn.py ├── car_on_hill_fqi.py ├── cartpole_lspi.py ├── double_chain_q_learning │ ├── __init__.py │ ├── chain_structure │ │ ├── p.npy │ │ └── rew.npy │ └── double_chain.py ├── grid_world_td.py ├── gym_recurrent_ppo.py ├── habitat │ ├── __init__.py │ ├── habitat_nav_dqn.py │ ├── habitat_rearrange_sac.py │ ├── pointnav_apartment-0.yaml │ ├── replica_train_apartment-0.json │ └── replica_train_apartment-0.json.gz ├── igibson_dqn.py ├── isaacsim │ ├── a1_rudin_ppo.py │ ├── cartpole_ppo.py │ ├── honey_badger_ppo.py │ └── silver_badger_ppo.py ├── lqr_bbo.py ├── lqr_pg.py ├── minigrid_dqn.py ├── mountain_car_sarsa.py ├── mujoco_air_hockey_sac.py ├── mujoco_locomotion_ppo.py ├── mujoco_manipulation_ppo.py ├── omni_isaac_gym_example.py ├── pendulum_a2c.py ├── pendulum_ac.py ├── pendulum_ddpg.py ├── pendulum_dpg.py ├── pendulum_sac.py ├── pendulum_trust_region.py ├── plotting_and_normalization.py ├── puddle_world_sarsa.py ├── segway_bbo.py ├── segway_eppo.py ├── ship_steering_bbo.py ├── simple_chain_qlearning.py ├── taxi_mellow_sarsa │ ├── grid.txt │ └── taxi_mellow.py ├── vectorized_core │ ├── __init__.py │ ├── pendulum_trust_region.py │ └── segway_bbo.py ├── walker_stand_ddpg.py └── walker_stand_ddpg_shared_net.py ├── mushroom_rl ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── actor_critic │ │ ├── __init__.py │ │ ├── classic_actor_critic │ │ │ ├── __init__.py │ │ │ ├── copdac_q.py │ │ │ └── stochastic_ac.py │ │ └── deep_actor_critic │ │ │ ├── __init__.py │ │ │ ├── a2c.py │ │ │ ├── ddpg.py │ │ │ ├── deep_actor_critic.py │ │ │ ├── ppo.py │ │ │ ├── ppo_bptt.py │ │ │ ├── ppo_rudin.py │ │ │ ├── sac.py │ │ │ ├── td3.py │ │ │ └── trpo.py │ ├── policy_search │ │ ├── __init__.py │ │ ├── black_box_optimization │ │ │ ├── __init__.py │ │ │ ├── black_box_optimization.py │ │ │ ├── constrained_reps.py │ │ │ ├── context_builder.py │ │ │ ├── eppo.py │ │ │ ├── more.py │ │ │ ├── pgpe.py │ │ │ ├── reps.py │ │ │ └── 
rwr.py │ │ └── policy_gradient │ │ │ ├── __init__.py │ │ │ ├── enac.py │ │ │ ├── gpomdp.py │ │ │ ├── policy_gradient.py │ │ │ └── reinforce.py │ └── value │ │ ├── __init__.py │ │ ├── batch_td │ │ ├── __init__.py │ │ ├── batch_td.py │ │ ├── boosted_fqi.py │ │ ├── double_fqi.py │ │ ├── fqi.py │ │ └── lspi.py │ │ ├── dqn │ │ ├── __init__.py │ │ ├── abstract_dqn.py │ │ ├── averaged_dqn.py │ │ ├── categorical_dqn.py │ │ ├── double_dqn.py │ │ ├── dqn.py │ │ ├── dueling_dqn.py │ │ ├── maxmin_dqn.py │ │ ├── noisy_dqn.py │ │ ├── quantile_dqn.py │ │ └── rainbow.py │ │ └── td │ │ ├── __init__.py │ │ ├── double_q_learning.py │ │ ├── expected_sarsa.py │ │ ├── maxmin_q_learning.py │ │ ├── q_lambda.py │ │ ├── q_learning.py │ │ ├── r_learning.py │ │ ├── rq_learning.py │ │ ├── sarsa.py │ │ ├── sarsa_lambda.py │ │ ├── sarsa_lambda_continuous.py │ │ ├── speedy_q_learning.py │ │ ├── td.py │ │ ├── true_online_sarsa_lambda.py │ │ └── weighted_q_learning.py ├── approximators │ ├── __init__.py │ ├── _implementations │ │ ├── __init__.py │ │ ├── action_regressor.py │ │ ├── generic_regressor.py │ │ └── q_regressor.py │ ├── ensemble.py │ ├── ensemble_table.py │ ├── parametric │ │ ├── __init__.py │ │ ├── cmac.py │ │ ├── linear.py │ │ ├── networks │ │ │ ├── __init__.py │ │ │ └── linear_network.py │ │ └── torch_approximator.py │ ├── regressor.py │ └── table.py ├── core │ ├── __init__.py │ ├── _impl │ │ ├── __init__.py │ │ ├── core_logic.py │ │ ├── list_dataset.py │ │ ├── numpy_dataset.py │ │ ├── torch_dataset.py │ │ └── vectorized_core_logic.py │ ├── agent.py │ ├── array_backend.py │ ├── core.py │ ├── dataset.py │ ├── environment.py │ ├── extra_info.py │ ├── logger │ │ ├── __init__.py │ │ ├── console_logger.py │ │ ├── data_logger.py │ │ └── logger.py │ ├── multiprocess_environment.py │ ├── serialization.py │ ├── vectorized_core.py │ └── vectorized_env.py ├── distributions │ ├── __init__.py │ ├── distribution.py │ ├── gaussian.py │ └── torch_distribution.py ├── environments │ ├── __init__.py │ ├── atari.py │ ├── car_on_hill.py │ ├── cart_pole.py │ ├── dm_control_env.py │ ├── finite_mdp.py │ ├── generators │ │ ├── __init__.py │ │ ├── grid_world.py │ │ ├── simple_chain.py │ │ └── taxi.py │ ├── grid_world.py │ ├── gymnasium_env.py │ ├── habitat_env.py │ ├── igibson_env.py │ ├── inverted_pendulum.py │ ├── isaacsim_env.py │ ├── isaacsim_envs │ │ ├── __init__.py │ │ ├── a1_walking.py │ │ ├── cartpole.py │ │ ├── honey_badger_walking.py │ │ ├── robots_usds │ │ │ ├── a1 │ │ │ │ ├── .thumbs │ │ │ │ │ └── 256x256 │ │ │ │ │ │ ├── a1.usd.png │ │ │ │ │ │ └── instanceable_meshes.usd.png │ │ │ │ ├── a1.usd │ │ │ │ └── instanceable_meshes.usd │ │ │ ├── cartpole │ │ │ │ └── cartpole.usd │ │ │ ├── honey_badger │ │ │ │ ├── .thumbs │ │ │ │ │ └── 256x256 │ │ │ │ │ │ ├── honey_badger.usd.png │ │ │ │ │ │ └── instanceable_meshes.usd.png │ │ │ │ ├── honey_badger.usd │ │ │ │ └── instanceable_meshes.usd │ │ │ └── silver_badger │ │ │ │ ├── .thumbs │ │ │ │ └── 256x256 │ │ │ │ │ ├── instanceable_meshes.usd.png │ │ │ │ │ └── silver_badger.usd.png │ │ │ │ ├── instanceable_meshes.usd │ │ │ │ └── silver_badger.usd │ │ └── silver_badger_walking.py │ ├── lqr.py │ ├── minigrid_env.py │ ├── mujoco.py │ ├── mujoco_envs │ │ ├── __init__.py │ │ ├── air_hockey │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── defend.py │ │ │ ├── double.py │ │ │ ├── hit.py │ │ │ ├── prepare.py │ │ │ ├── repel.py │ │ │ └── single.py │ │ ├── ant.py │ │ ├── ball_in_a_cup.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── air_hockey │ │ │ │ ├── __init__.py │ │ │ │ ├── double.xml │ │ │ │ 
├── planar_robot_1.xml │ │ │ │ ├── planar_robot_2.xml │ │ │ │ ├── single.xml │ │ │ │ └── table.xml │ │ │ ├── ant │ │ │ │ ├── __init__.py │ │ │ │ └── model.xml │ │ │ ├── ball_in_a_cup │ │ │ │ ├── __init__.py │ │ │ │ ├── meshes │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base_link_convex.stl │ │ │ │ │ ├── base_link_fine.stl │ │ │ │ │ ├── cup_split1.stl │ │ │ │ │ ├── cup_split10.stl │ │ │ │ │ ├── cup_split11.stl │ │ │ │ │ ├── cup_split12.stl │ │ │ │ │ ├── cup_split13.stl │ │ │ │ │ ├── cup_split14.stl │ │ │ │ │ ├── cup_split15.stl │ │ │ │ │ ├── cup_split16.stl │ │ │ │ │ ├── cup_split17.stl │ │ │ │ │ ├── cup_split18.stl │ │ │ │ │ ├── cup_split2.stl │ │ │ │ │ ├── cup_split3.stl │ │ │ │ │ ├── cup_split4.stl │ │ │ │ │ ├── cup_split5.stl │ │ │ │ │ ├── cup_split6.stl │ │ │ │ │ ├── cup_split7.stl │ │ │ │ │ ├── cup_split8.stl │ │ │ │ │ ├── cup_split9.stl │ │ │ │ │ ├── elbow_link_convex.stl │ │ │ │ │ ├── elbow_link_fine.stl │ │ │ │ │ ├── forearm_link_convex_decomposition_p1.stl │ │ │ │ │ ├── forearm_link_convex_decomposition_p2.stl │ │ │ │ │ ├── forearm_link_fine.stl │ │ │ │ │ ├── shoulder_link_convex_decomposition_p1.stl │ │ │ │ │ ├── shoulder_link_convex_decomposition_p2.stl │ │ │ │ │ ├── shoulder_link_convex_decomposition_p3.stl │ │ │ │ │ ├── shoulder_link_fine.stl │ │ │ │ │ ├── shoulder_pitch_link_convex.stl │ │ │ │ │ ├── shoulder_pitch_link_fine.stl │ │ │ │ │ ├── upper_arm_link_convex_decomposition_p1.stl │ │ │ │ │ ├── upper_arm_link_convex_decomposition_p2.stl │ │ │ │ │ ├── upper_arm_link_fine.stl │ │ │ │ │ ├── wrist_palm_link_convex.stl │ │ │ │ │ ├── wrist_palm_link_fine.stl │ │ │ │ │ ├── wrist_pitch_link_convex_decomposition_p1.stl │ │ │ │ │ ├── wrist_pitch_link_convex_decomposition_p2.stl │ │ │ │ │ ├── wrist_pitch_link_convex_decomposition_p3.stl │ │ │ │ │ ├── wrist_pitch_link_fine.stl │ │ │ │ │ ├── wrist_yaw_link_convex_decomposition_p1.stl │ │ │ │ │ ├── wrist_yaw_link_convex_decomposition_p2.stl │ │ │ │ │ └── wrist_yaw_link_fine.stl │ │ │ │ └── model.xml │ │ │ ├── half_cheetah │ │ │ │ ├── __init__.py │ │ │ │ └── model.xml │ │ │ ├── hopper │ │ │ │ ├── __init__.py │ │ │ │ └── model.xml │ │ │ ├── panda │ │ │ │ ├── assets │ │ │ │ │ ├── cube.xml │ │ │ │ │ ├── finger_0.obj │ │ │ │ │ ├── finger_1.obj │ │ │ │ │ ├── hand.stl │ │ │ │ │ ├── hand_0.obj │ │ │ │ │ ├── hand_1.obj │ │ │ │ │ ├── hand_2.obj │ │ │ │ │ ├── hand_3.obj │ │ │ │ │ ├── hand_4.obj │ │ │ │ │ ├── link0.stl │ │ │ │ │ ├── link0_0.obj │ │ │ │ │ ├── link0_1.obj │ │ │ │ │ ├── link0_10.obj │ │ │ │ │ ├── link0_11.obj │ │ │ │ │ ├── link0_2.obj │ │ │ │ │ ├── link0_3.obj │ │ │ │ │ ├── link0_4.obj │ │ │ │ │ ├── link0_5.obj │ │ │ │ │ ├── link0_7.obj │ │ │ │ │ ├── link0_8.obj │ │ │ │ │ ├── link0_9.obj │ │ │ │ │ ├── link1.obj │ │ │ │ │ ├── link1.stl │ │ │ │ │ ├── link2.obj │ │ │ │ │ ├── link2.stl │ │ │ │ │ ├── link3.stl │ │ │ │ │ ├── link3_0.obj │ │ │ │ │ ├── link3_1.obj │ │ │ │ │ ├── link3_2.obj │ │ │ │ │ ├── link3_3.obj │ │ │ │ │ ├── link4.stl │ │ │ │ │ ├── link4_0.obj │ │ │ │ │ ├── link4_1.obj │ │ │ │ │ ├── link4_2.obj │ │ │ │ │ ├── link4_3.obj │ │ │ │ │ ├── link5_0.obj │ │ │ │ │ ├── link5_1.obj │ │ │ │ │ ├── link5_2.obj │ │ │ │ │ ├── link5_collision_0.obj │ │ │ │ │ ├── link5_collision_1.obj │ │ │ │ │ ├── link5_collision_2.obj │ │ │ │ │ ├── link6.stl │ │ │ │ │ ├── link6_0.obj │ │ │ │ │ ├── link6_1.obj │ │ │ │ │ ├── link6_10.obj │ │ │ │ │ ├── link6_11.obj │ │ │ │ │ ├── link6_12.obj │ │ │ │ │ ├── link6_13.obj │ │ │ │ │ ├── link6_14.obj │ │ │ │ │ ├── link6_15.obj │ │ │ │ │ ├── link6_16.obj │ │ │ │ │ ├── link6_2.obj │ │ │ │ │ ├── link6_3.obj │ │ │ │ │ 
├── link6_4.obj │ │ │ │ │ ├── link6_5.obj │ │ │ │ │ ├── link6_6.obj │ │ │ │ │ ├── link6_7.obj │ │ │ │ │ ├── link6_8.obj │ │ │ │ │ ├── link6_9.obj │ │ │ │ │ ├── link7.stl │ │ │ │ │ ├── link7_0.obj │ │ │ │ │ ├── link7_1.obj │ │ │ │ │ ├── link7_2.obj │ │ │ │ │ ├── link7_3.obj │ │ │ │ │ ├── link7_4.obj │ │ │ │ │ ├── link7_5.obj │ │ │ │ │ ├── link7_6.obj │ │ │ │ │ ├── link7_7.obj │ │ │ │ │ └── table.xml │ │ │ │ ├── panda.xml │ │ │ │ ├── peg_insertion.xml │ │ │ │ ├── pick.xml │ │ │ │ ├── push.xml │ │ │ │ └── reach.xml │ │ │ └── walker_2d │ │ │ │ ├── __init__.py │ │ │ │ └── model.xml │ │ ├── half_cheetah.py │ │ ├── hopper.py │ │ ├── panda.py │ │ ├── peg_insertion.py │ │ ├── pick.py │ │ ├── push.py │ │ ├── reach.py │ │ └── walker_2d.py │ ├── omni_isaac_gym_env.py │ ├── puddle_world.py │ ├── pybullet.py │ ├── pybullet_envs │ │ ├── __init__.py │ │ ├── air_hockey │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── defend.py │ │ │ ├── double.py │ │ │ ├── hit.py │ │ │ ├── prepare.py │ │ │ ├── repel.py │ │ │ └── single.py │ │ └── data │ │ │ ├── __init__.py │ │ │ └── air_hockey │ │ │ ├── air_hockey_table.urdf │ │ │ ├── planar │ │ │ ├── planar_robot_1.urdf │ │ │ └── planar_robot_2.urdf │ │ │ └── puck.urdf │ ├── segway.py │ └── ship_steering.py ├── features │ ├── __init__.py │ ├── _implementations │ │ ├── __init__.py │ │ ├── basis_features.py │ │ ├── features_implementation.py │ │ ├── functional_features.py │ │ ├── tiles_features.py │ │ └── torch_features.py │ ├── basis │ │ ├── __init__.py │ │ ├── fourier.py │ │ ├── gaussian_rbf.py │ │ └── polynomial.py │ ├── features.py │ ├── tensors │ │ ├── __init__.py │ │ ├── basis_tensor.py │ │ ├── constant_tensor.py │ │ └── random_fourier_tensor.py │ └── tiles │ │ ├── __init__.py │ │ ├── tiles.py │ │ └── voronoi.py ├── policy │ ├── __init__.py │ ├── deterministic_policy.py │ ├── dmp.py │ ├── gaussian_policy.py │ ├── noise_policy.py │ ├── policy.py │ ├── promps.py │ ├── recurrent_torch_policy.py │ ├── td_policy.py │ ├── torch_policy.py │ └── vector_policy.py ├── rl_utils │ ├── __init__.py │ ├── eligibility_trace.py │ ├── optimizers.py │ ├── parameters.py │ ├── preprocessors.py │ ├── replay_memory.py │ ├── running_stats.py │ ├── spaces.py │ ├── value_functions.py │ └── variance_parameters.py ├── solvers │ ├── __init__.py │ ├── car_on_hill.py │ ├── dynamic_programming.py │ └── lqr.py └── utils │ ├── __init__.py │ ├── angles.py │ ├── callbacks │ ├── __init__.py │ ├── callback.py │ ├── collect_dataset.py │ ├── collect_max_q.py │ ├── collect_parameters.py │ ├── collect_q.py │ └── plot_dataset.py │ ├── episodes.py │ ├── features.py │ ├── frames.py │ ├── isaac_sim │ ├── __init__.py │ ├── action_helper.py │ ├── collision_helper.py │ ├── general_task.py │ └── observation_helper.py │ ├── isaac_utils.py │ ├── minibatches.py │ ├── mujoco │ ├── __init__.py │ ├── kinematics.py │ ├── observation_helper.py │ └── viewer.py │ ├── numerical_gradient.py │ ├── plot.py │ ├── plots │ ├── __init__.py │ ├── common_plots.py │ ├── databuffer.py │ ├── plot_item_buffer.py │ └── window.py │ ├── pybullet │ ├── __init__.py │ ├── contacts.py │ ├── index_map.py │ ├── joints_helper.py │ ├── observation.py │ └── viewer.py │ ├── quaternions.py │ ├── record.py │ ├── torch.py │ └── viewer.py ├── pyproject.toml ├── setup.py └── tests ├── algorithms ├── helper │ └── utils.py ├── test_a2c.py ├── test_black_box.py ├── test_ddpg.py ├── test_dpg.py ├── test_dqn.py ├── test_fqi.py ├── test_lspi.py ├── test_policy_gradient.py ├── test_sac.py ├── test_stochastic_ac.py ├── test_td.py └── test_trust_region.py ├── 
approximators ├── test_cmac_approximator.py ├── test_linear_approximator.py └── test_torch_approximator.py ├── core ├── test_array_backend.py ├── test_core.py ├── test_dataset.py ├── test_extra_info.py ├── test_logger.py ├── test_serialization.py └── test_vectorized_core.py ├── distributions ├── test_distribution_interface.py └── test_gaussian_distribution.py ├── environments ├── grid.txt ├── isaacsim_envs │ └── test_envs.py ├── mujoco_envs │ ├── air_hockey_defend_data.npy │ ├── air_hockey_hit_data.npy │ ├── air_hockey_prepare_data.npy │ ├── air_hockey_repel_data.npy │ ├── test_air_hockey.py │ ├── test_ball_in_a_cup.py │ └── test_locomotion.py ├── pybullet_envs │ ├── air_hockey_defend_data.npy │ ├── air_hockey_hit_data.npy │ ├── air_hockey_prepare_data.npy │ ├── air_hockey_repel_data.npy │ └── test_air_hockey_bullet.py ├── taxi.txt ├── test_all_envs.py ├── test_atari_1.npy └── test_mujoco.py ├── features └── test_features.py ├── policy ├── test_deterministic_policy.py ├── test_gaussian_policy.py ├── test_noise_policy.py ├── test_policy_interface.py ├── test_td_policy.py └── test_torch_policy.py ├── rl_utils └── test_value_functions.py ├── solvers ├── test_car_on_hill.py ├── test_dynamic_programming.py └── test_lqr.py ├── test_imports.py └── utils ├── test_callbacks.py ├── test_episodes.py └── test_preprocessors.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | mushroom_rl/environments/mujoco.py 4 | mushroom_rl/environments/mujoco_envs/* 5 | 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Provide a snippet of code, or a Python file. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **System information (please complete the following information):** 20 | - OS: [e.g. Ubuntu 18.04] 21 | - Python version [e.g. Python3.6] 22 | - Torch version [e.g. Pytorch 1.3] 23 | - Mushroom version [e.g. 1.2.0, master] 24 | 25 | **Additional context** 26 | Add any other context about the problem here. 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/continuous_integration.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python, upload the coverage results to code climate 2 | 3 | name: Continuous Integration 4 | 5 | on: 6 | push: 7 | branches: [ dev, dev-v1, master ] 8 | 9 | jobs: 10 | build: 11 | if: github.repository == 'MushroomRL/mushroom-rl' 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python 3.8 17 | uses: actions/setup-python@v3 18 | with: 19 | python-version: 3.8 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install flake8 pytest pytest-cov 24 | pip install -e .[all] 25 | - name: Install Atari ROMs 26 | run: | 27 | pip install "autorom[accept-rom-license]" 28 | - name: Lint with flake8 29 | run: | 30 | # stop the build if there are Python syntax errors or undefined names 31 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 32 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 33 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 34 | - name: Test with pytest 35 | run: | 36 | pytest --cov=mushroom_rl --cov-report=xml 37 | - name: Publish code coverage to CodeClimate 38 | uses: paambaati/codeclimate-action@v2.7.5 39 | env: 40 | CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }} 41 | -------------------------------------------------------------------------------- /.github/workflows/test_pr.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python, upload the coverage results to code climate 2 | 3 | name: Test Pull Request 4 | 5 | on: 6 | pull_request: 7 | branches: [ dev ] 8 | 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 3.8 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: 3.8 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install flake8 pytest pytest-cov 25 | pip install -e .[all] 26 | - name: Install Atari ROMs 27 | run: | 28 | pip install "autorom[accept-rom-license]" 29 | - name: Lint with flake8 30 | run: | 31 | # stop the build if there are Python syntax errors or undefined names 32 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 33 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 34 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 35 | - name: Test with pytest 36 | run: | 37 | pytest 38 | 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_store 2 | build/ 3 | dist/ 4 | examples/mushroom_rl_recordings/ 5 | examples/habitat/Replica-Dataset 6 | examples/habitat/data 7 | mushroom_rl.egg-info/ 8 | mushroom_rl_recordings/ 9 | .idea/ 10 | *.pyc 11 | *.pyd 12 | *.xml 13 | logs/ 14 | *.h5 15 | .pytest_cache 16 | .coverage* 17 | *.c 18 | *.so 19 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-20.04 10 | tools: 11 | python: "3.8" 12 | 13 | sphinx: 14 | configuration: docs/conf.py 15 | 16 | python: 17 | install: 18 | - requirements: docs/requirements.txt 19 | - method: pip 20 | path: . 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018, 2019, 2020 Carlo D'Eramo, Davide Tateo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include setup.py 3 | prune tests 4 | prune examples 5 | prune .github 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | package: 2 | python3 -m build 3 | 4 | install: 5 | pip install $(shell ls dist/*.tar.gz) 6 | 7 | all: clean package install 8 | 9 | upload: 10 | python3 -m twine upload dist/* 11 | 12 | clean: 13 | rm -rf dist 14 | rm -rf build 15 | 16 | .NOTPARALLEL: 17 | -------------------------------------------------------------------------------- /TODO.txt: -------------------------------------------------------------------------------- 1 | Environments: 2 | * implement Multirobot PyBullet Interface to solve issues with joint name clash 3 | 4 | Algorithms: 5 | * Conservative Q-Learning 6 | * Policy Search: 7 | - Natural gradient 8 | - NES 9 | - PAPI 10 | 11 | Policy: 12 | * Add Boltzmann from logits for traditional policy gradient methods 13 | 14 | Approximator: 15 | * support for LSTM 16 | * Generalize LazyFrame to LazyState 17 | * add neural network generator 18 | 19 | For Mushroom 2.0: 20 | * Simplify Regressor interface: drop GenericRegressor, remove facade pattern 21 | * vectorize basis functions and simplify interface, simplify facade pattern 22 | * remove custom save for plotting, use Serializable 23 | * support multi-objective RL 24 | * support model-based RL 25 | * Improve replay memory, allowing to store arbitrary information into replay buffer 26 | 27 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build/ 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = MushroomRL 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
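# any other target (e.g. `make html` or `make latexpdf`) falls through to the catch-all pattern rule below, which forwards the target name to sphinx-build in "make mode"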
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/docs/__init__.py -------------------------------------------------------------------------------- /docs/_static/theme_overrides.css: -------------------------------------------------------------------------------- 1 | .wy-nav-content { 2 | max-width: 1000px !important; 3 | } -------------------------------------------------------------------------------- /docs/_templates/tabular.tex_t: -------------------------------------------------------------------------------- 1 | \begin{savenotes}\sphinxattablestart 2 | \small 3 | <% if table.align -%> 4 | <%- if table.align in ('center', 'default') -%> 5 | \centering 6 | <%- elif table.align == 'left' -%> 7 | \raggedright 8 | <%- else -%> 9 | \raggedleft 10 | <%- endif %> 11 | <%- else -%> 12 | \centering 13 | <%- endif %> 14 | <% if table.caption -%> 15 | \sphinxcapstartof{table} 16 | \sphinxthecaptionisattop 17 | \sphinxcaption{<%= ''.join(table.caption) %>}<%= labels %> 18 | \sphinxaftertopcaption 19 | <% elif labels -%> 20 | \phantomsection<%= labels %>\nobreak 21 | <% endif -%> 22 | \begin{tabular}[t]<%= table.get_colspec() -%> 23 | \hline 24 | <%= ''.join(table.header) %> 25 | <%=- ''.join(table.body) %> 26 | \end{tabular} 27 | \par 28 | \sphinxattableend\end{savenotes} 29 | -------------------------------------------------------------------------------- /docs/_templates/tabulary.tex_t: -------------------------------------------------------------------------------- 1 | \begin{savenotes}\sphinxattablestart 2 | \footnotesize 3 | <% if table.align -%> 4 | <%- if table.align in ('center', 'default') -%> 5 | \centering 6 | <%- elif table.align == 'left' -%> 7 | \raggedright 8 | <%- else -%> 9 | \raggedleft 10 | <%- endif %> 11 | <%- else -%> 12 | \centering 13 | <%- endif %> 14 | <% if table.caption -%> 15 | \sphinxcapstartof{table} 16 | \sphinxthecaptionisattop 17 | \sphinxcaption{<%= ''.join(table.caption) %>}<%= labels %> 18 | \sphinxaftertopcaption 19 | <% elif labels -%> 20 | \phantomsection<%= labels %>\nobreak 21 | <% endif -%> 22 | \begin{tabulary}{\linewidth}[t]<%= table.get_colspec() -%> 23 | \hline 24 | <%= ''.join(table.header) %> 25 | <%=- ''.join(table.body) %> 26 | \end{tabulary} 27 | \par 28 | \sphinxattableend\end{savenotes} -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | numpy 3 | scipy 4 | gym 5 | scikit-learn 6 | matplotlib 7 | joblib 8 | tqdm 9 | pygame 10 | sphinx-rtd-theme 11 | -------------------------------------------------------------------------------- /docs/source/features/mushroom_rl.features.basis.rst: -------------------------------------------------------------------------------- 1 | Basis 2 | ===== 3 | 4 | Fourier 5 | ------- 6 | 7 | .. automodule:: mushroom_rl.features.basis.fourier 8 | :members: 9 | :private-members: 10 | :inherited-members: 11 | :show-inheritance: 12 | 13 | Gaussian RBF 14 | ------------ 15 | 16 | .. 
automodule:: mushroom_rl.features.basis.gaussian_rbf 17 | :members: 18 | :private-members: 19 | :inherited-members: 20 | :show-inheritance: 21 | 22 | Polynomial 23 | ---------- 24 | 25 | .. automodule:: mushroom_rl.features.basis.polynomial 26 | :members: 27 | :private-members: 28 | :inherited-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/features/mushroom_rl.features.tensors.rst: -------------------------------------------------------------------------------- 1 | Tensors 2 | ======= 3 | 4 | .. automodule:: mushroom_rl.features.tensors.constant_tensor 5 | :members: 6 | :private-members: 7 | :show-inheritance: 8 | 9 | .. automodule:: mushroom_rl.features.tensors.basis_tensor 10 | :members: 11 | :private-members: 12 | :show-inheritance: 13 | 14 | .. automodule:: mushroom_rl.features.tensors.random_fourier_tensor 15 | :members: 16 | :private-members: 17 | :show-inheritance: 18 | -------------------------------------------------------------------------------- /docs/source/features/mushroom_rl.features.tiles.rst: -------------------------------------------------------------------------------- 1 | Tiles 2 | ===== 3 | 4 | Rectangular Tiles 5 | ----------------- 6 | 7 | .. automodule:: mushroom_rl.features.tiles.tiles 8 | :members: 9 | :private-members: 10 | :inherited-members: 11 | :show-inheritance: 12 | 13 | Voronoi Tiles 14 | ------------- 15 | 16 | .. automodule:: mushroom_rl.features.tiles.voronoi 17 | :members: 18 | :private-members: 19 | :inherited-members: 20 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/mushroom_rl.agent_environment_interface.rst: -------------------------------------------------------------------------------- 1 | Agent-Environment Interface 2 | =========================== 3 | 4 | The three basic interface of mushroom_rl are the Agent, the Environment and the Core interface. 5 | 6 | - The ``Agent`` is the basic interface for any Reinforcement Learning algorithm. 7 | - The ``Environment`` is the basic interface for every problem/task that the agent should solve. 8 | - The ``Core`` is a class used to control the interaction between an agent and an environment. 9 | 10 | To implement serialization of MushroomRL data on the disk (load/save functionality) we also provide the ``Serializable`` 11 | interface. Finally, we provide the logging functionality with the ``Logger`` class. 12 | 13 | 14 | Agent 15 | ----- 16 | 17 | MushroomRL provides the implementations of several algorithms belonging to all 18 | categories of RL: 19 | 20 | - value-based; 21 | - policy-search; 22 | - actor-critic. 23 | 24 | One can easily implement customized algorithms following the structure of the 25 | already available ones, by extending the following interface: 26 | 27 | .. automodule:: mushroom_rl.core.agent 28 | :members: 29 | :private-members: 30 | :inherited-members: 31 | :show-inheritance: 32 | 33 | Environment 34 | ----------- 35 | 36 | MushroomRL provides several implementation of well known benchmarks with both 37 | continuous and discrete action spaces. 38 | 39 | To implement a new environment, it is mandatory to use the following interface: 40 | 41 | .. automodule:: mushroom_rl.core.environment 42 | :members: 43 | :private-members: 44 | :inherited-members: 45 | :show-inheritance: 46 | 47 | 48 | Core 49 | ---- 50 | 51 | .. 
automodule:: mushroom_rl.core.core 52 | :members: 53 | :private-members: 54 | :inherited-members: 55 | :show-inheritance: 56 | 57 | Serialization 58 | ------------- 59 | 60 | .. automodule:: mushroom_rl.core.serialization 61 | :members: 62 | :private-members: 63 | :inherited-members: 64 | :show-inheritance: 65 | 66 | Logger 67 | ------ 68 | 69 | .. automodule:: mushroom_rl.core.logger 70 | :members: 71 | :private-members: 72 | :inherited-members: 73 | :show-inheritance: 74 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.algorithms.actor_critic.rst: -------------------------------------------------------------------------------- 1 | Actor-Critic 2 | ============ 3 | 4 | Classical Actor-Critic Methods 5 | ------------------------------ 6 | 7 | .. automodule:: mushroom_rl.algorithms.actor_critic.classic_actor_critic 8 | :members: 9 | :private-members: 10 | :inherited-members: 11 | :show-inheritance: 12 | 13 | Deep Actor-Critic Methods 14 | ------------------------- 15 | 16 | .. automodule:: mushroom_rl.algorithms.actor_critic.deep_actor_critic 17 | :members: 18 | :private-members: 19 | :inherited-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.algorithms.policy_search.rst: -------------------------------------------------------------------------------- 1 | Policy search 2 | ============= 3 | 4 | Policy gradient 5 | --------------- 6 | 7 | .. automodule:: mushroom_rl.algorithms.policy_search.policy_gradient 8 | :members: 9 | :private-members: 10 | :inherited-members: 11 | :show-inheritance: 12 | 13 | Black-Box optimization 14 | ---------------------- 15 | 16 | .. automodule:: mushroom_rl.algorithms.policy_search.black_box_optimization 17 | :members: 18 | :private-members: 19 | :inherited-members: 20 | :show-inheritance: 21 | 22 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.algorithms.value.rst: -------------------------------------------------------------------------------- 1 | Value-Based 2 | =========== 3 | 4 | TD 5 | -- 6 | 7 | .. automodule:: mushroom_rl.algorithms.value.td 8 | :members: 9 | :private-members: 10 | :inherited-members: 11 | :show-inheritance: 12 | 13 | Batch TD 14 | -------- 15 | 16 | .. automodule:: mushroom_rl.algorithms.value.batch_td 17 | :members: 18 | :private-members: 19 | :inherited-members: 20 | :show-inheritance: 21 | 22 | DQN 23 | --- 24 | 25 | .. automodule:: mushroom_rl.algorithms.value.dqn 26 | :members: 27 | :private-members: 28 | :inherited-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.approximators.rst: -------------------------------------------------------------------------------- 1 | Approximators 2 | ============= 3 | 4 | MushroomRL exposes the high-level class ``Regressor`` that can manage any type of 5 | function regressor. This class is a wrapper for any kind of function 6 | approximator, e.g. a scikit-learn approximator or a pytorch neural network. 7 | 8 | Regressor 9 | --------- 10 | 11 | .. automodule:: mushroom_rl.approximators.regressor 12 | :members: 13 | :private-members: 14 | :inherited-members: 15 | :show-inheritance: 16 | 17 | 18 | Approximators 19 | ------------- 20 | 21 | Tabular 22 | ~~~~~~~ 23 | 24 | .. 
automodule:: mushroom_rl.approximators.table 25 | :members: 26 | :private-members: 27 | :inherited-members: 28 | :show-inheritance: 29 | 30 | 31 | 32 | Linear 33 | ~~~~~~ 34 | 35 | .. automodule:: mushroom_rl.approximators.parametric.linear 36 | :members: 37 | :private-members: 38 | :inherited-members: 39 | :show-inheritance: 40 | 41 | CMAC 42 | ~~~~ 43 | 44 | .. automodule:: mushroom_rl.approximators.parametric.cmac 45 | :members: 46 | :private-members: 47 | :inherited-members: 48 | :show-inheritance: 49 | 50 | Torch Approximator 51 | ~~~~~~~~~~~~~~~~~~ 52 | 53 | .. automodule:: mushroom_rl.approximators.parametric.torch_approximator 54 | :members: 55 | :private-members: 56 | :inherited-members: 57 | :show-inheritance: 58 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.distributions.rst: -------------------------------------------------------------------------------- 1 | Distributions 2 | ============= 3 | 4 | .. automodule:: mushroom_rl.distributions.distribution 5 | :members: 6 | :private-members: 7 | :inherited-members: 8 | :show-inheritance: 9 | 10 | Gaussian 11 | -------- 12 | 13 | .. automodule:: mushroom_rl.distributions.gaussian 14 | :members: 15 | :private-members: 16 | :inherited-members: 17 | :show-inheritance: 18 | 19 | 20 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.features.rst: -------------------------------------------------------------------------------- 1 | Features 2 | ======== 3 | 4 | The features in MushroomRL are 1-D arrays computed applying a specified function 5 | to a raw input, e.g. polynomial features of the state of an MDP. 6 | MushroomRL supports three types of features: 7 | 8 | * basis functions; 9 | * tensor basis functions; 10 | * tiles. 11 | 12 | The tensor basis functions are a PyTorch implementation of the standard 13 | basis functions. They are less straightforward than the standard ones, but they 14 | are faster to compute as they can exploit parallel computing, e.g. GPU-acceleration 15 | and multi-core systems. 16 | 17 | All the types of features are exposed by a single factory method ``Features`` 18 | that builds the one requested by the user. 19 | 20 | .. automodule:: mushroom_rl.features.features 21 | :members: 22 | :private-members: 23 | :inherited-members: 24 | :show-inheritance: 25 | 26 | The factory method returns a class that extends the abstract class 27 | ``FeatureImplementation``. 28 | 29 | .. automodule:: mushroom_rl.features._implementations.features_implementation 30 | :members: 31 | :private-members: 32 | :inherited-members: 33 | :show-inheritance: 34 | 35 | The documentation for every feature type can be found here: 36 | 37 | .. toctree:: 38 | 39 | features/mushroom_rl.features.basis 40 | features/mushroom_rl.features.tensors 41 | features/mushroom_rl.features.tiles 42 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.policy.rst: -------------------------------------------------------------------------------- 1 | Policy 2 | ====== 3 | 4 | .. automodule:: mushroom_rl.policy.policy 5 | :members: 6 | :private-members: 7 | :inherited-members: 8 | :show-inheritance: 9 | 10 | Deterministic policy 11 | -------------------- 12 | 13 | .. automodule:: mushroom_rl.policy.deterministic_policy 14 | :members: 15 | :private-members: 16 | :inherited-members: 17 | :show-inheritance: 18 | 19 | Gaussian policy 20 | --------------- 21 | 22 | .. 
automodule:: mushroom_rl.policy.gaussian_policy 23 | :members: 24 | :private-members: 25 | :inherited-members: 26 | :show-inheritance: 27 | 28 | Noise policy 29 | ------------ 30 | 31 | .. automodule:: mushroom_rl.policy.noise_policy 32 | :members: 33 | :private-members: 34 | :inherited-members: 35 | :show-inheritance: 36 | 37 | TD policy 38 | --------- 39 | 40 | .. automodule:: mushroom_rl.policy.td_policy 41 | :members: 42 | :private-members: 43 | :inherited-members: 44 | :show-inheritance: 45 | 46 | Torch policy 47 | ------------ 48 | 49 | .. automodule:: mushroom_rl.policy.torch_policy 50 | :members: 51 | :private-members: 52 | :inherited-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.rl_utils.rst: -------------------------------------------------------------------------------- 1 | Reinforcement Learning utils 2 | ============================ 3 | 4 | Eligibility trace 5 | ----------------- 6 | 7 | .. automodule:: mushroom_rl.rl_utils.eligibility_trace 8 | :members: 9 | :private-members: 10 | :inherited-members: 11 | :undoc-members: 12 | :show-inheritance: 13 | 14 | Optimizers 15 | ---------- 16 | 17 | .. automodule:: mushroom_rl.rl_utils.optimizers 18 | :members: 19 | :private-members: 20 | :inherited-members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | Parameters 25 | ---------- 26 | 27 | .. automodule:: mushroom_rl.rl_utils.parameters 28 | :members: 29 | :private-members: 30 | :inherited-members: 31 | :show-inheritance: 32 | 33 | 34 | Preprocessors 35 | ------------- 36 | 37 | .. automodule:: mushroom_rl.rl_utils.preprocessors 38 | :members: 39 | :private-members: 40 | :inherited-members: 41 | :show-inheritance: 42 | 43 | Replay memory 44 | ------------- 45 | 46 | .. automodule:: mushroom_rl.rl_utils.replay_memory 47 | :members: 48 | :private-members: 49 | :inherited-members: 50 | :show-inheritance: 51 | 52 | Running Statistics 53 | ------------------ 54 | 55 | .. automodule:: mushroom_rl.rl_utils.running_stats 56 | :members: 57 | :private-members: 58 | :inherited-members: 59 | :show-inheritance: 60 | 61 | Spaces 62 | ------ 63 | 64 | .. automodule:: mushroom_rl.rl_utils.spaces 65 | :members: 66 | :show-inheritance: 67 | 68 | 69 | Value Functions 70 | --------------- 71 | 72 | .. automodule:: mushroom_rl.rl_utils.value_functions 73 | :members: 74 | :private-members: 75 | :inherited-members: 76 | :show-inheritance: 77 | 78 | Variance parameters 79 | ------------------- 80 | 81 | .. automodule:: mushroom_rl.rl_utils.variance_parameters 82 | :members: 83 | :private-members: 84 | :inherited-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.solvers.rst: -------------------------------------------------------------------------------- 1 | Solvers 2 | ======= 3 | 4 | Dynamic programming 5 | ------------------- 6 | 7 | .. automodule:: mushroom_rl.solvers.dynamic_programming 8 | :members: 9 | :private-members: 10 | :inherited-members: 11 | :show-inheritance: 12 | 13 | Car-On-Hill brute-force solver 14 | ------------------------------ 15 | 16 | .. automodule:: mushroom_rl.solvers.car_on_hill 17 | :members: 18 | :private-members: 19 | :inherited-members: 20 | :show-inheritance: 21 | 22 | 23 | LQR solver 24 | ---------- 25 | 26 | .. 
automodule:: mushroom_rl.solvers.lqr 27 | :members: 28 | :private-members: 29 | :inherited-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/mushroom_rl.utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | 4 | Angles 5 | ------ 6 | 7 | .. automodule:: mushroom_rl.utils.angles 8 | :members: 9 | :show-inheritance: 10 | 11 | Features 12 | -------- 13 | 14 | .. automodule:: mushroom_rl.utils.features 15 | :members: 16 | :private-members: 17 | :inherited-members: 18 | :show-inheritance: 19 | 20 | 21 | Frames 22 | ------ 23 | 24 | .. automodule:: mushroom_rl.utils.frames 25 | :members: 26 | :private-members: 27 | :inherited-members: 28 | :show-inheritance: 29 | 30 | 31 | Minibatches 32 | ----------- 33 | 34 | .. automodule:: mushroom_rl.utils.minibatches 35 | :members: 36 | :private-members: 37 | :inherited-members: 38 | :undoc-members: 39 | :show-inheritance: 40 | 41 | Numerical gradient 42 | ------------------ 43 | 44 | .. automodule:: mushroom_rl.utils.numerical_gradient 45 | :members: 46 | :private-members: 47 | :inherited-members: 48 | :show-inheritance: 49 | 50 | 51 | Plots 52 | ----- 53 | 54 | .. automodule:: mushroom_rl.utils.plot 55 | :members: 56 | :private-members: 57 | :inherited-members: 58 | :show-inheritance: 59 | 60 | 61 | Record 62 | ------ 63 | 64 | .. automodule:: mushroom_rl.utils.record 65 | :members: 66 | :private-members: 67 | :inherited-members: 68 | :show-inheritance: 69 | 70 | 71 | Torch 72 | ----- 73 | 74 | .. automodule:: mushroom_rl.utils.torch 75 | :members: 76 | :private-members: 77 | :inherited-members: 78 | :show-inheritance: 79 | 80 | 81 | Viewer 82 | ------ 83 | 84 | .. automodule:: mushroom_rl.utils.viewer 85 | :members: 86 | :private-members: 87 | :inherited-members: 88 | :show-inheritance: 89 | -------------------------------------------------------------------------------- /docs/source/tutorials/code/advanced_experiment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value import SARSALambdaContinuous 4 | from mushroom_rl.approximators.parametric import LinearApproximator 5 | from mushroom_rl.core import Core 6 | from mushroom_rl.features import Features 7 | from mushroom_rl.features.tiles import Tiles 8 | from mushroom_rl.policy import EpsGreedy 9 | from mushroom_rl.utils.callbacks import CollectDataset 10 | from mushroom_rl.rl_utils.parameters import Parameter 11 | from mushroom_rl.environments import Gymnasium 12 | 13 | # MDP 14 | mdp = Gymnasium(name='MountainCar-v0', horizon=np.inf, gamma=1.) 15 | 16 | # Policy 17 | epsilon = Parameter(value=0.) 
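# note: epsilon stays at 0, so the policy acts greedily throughout; exploration presumably relies on the optimistic zero initialisation of the linear Q approximator (MountainCar rewards are never positive)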
18 | pi = EpsGreedy(epsilon=epsilon) 19 | 20 | # Q-function approximator 21 | n_tilings = 10 22 | tilings = Tiles.generate(n_tilings, [10, 10], 23 | mdp.info.observation_space.low, 24 | mdp.info.observation_space.high) 25 | features = Features(tilings=tilings) 26 | 27 | approximator_params = dict(input_shape=(features.size,), 28 | output_shape=(mdp.info.action_space.n,), 29 | n_actions=mdp.info.action_space.n) 30 | 31 | # Agent 32 | learning_rate = Parameter(.1 / n_tilings) 33 | 34 | agent = SARSALambdaContinuous(mdp.info, pi, LinearApproximator, 35 | approximator_params=approximator_params, 36 | learning_rate=learning_rate, 37 | lambda_coeff=.9, features=features) 38 | 39 | # Algorithm 40 | collect_dataset = CollectDataset() 41 | callbacks = [collect_dataset] 42 | core = Core(agent, mdp, callbacks_fit=callbacks) 43 | 44 | # Train 45 | core.learn(n_episodes=100, n_steps_per_fit=1) 46 | 47 | # Evaluate 48 | core.evaluate(n_episodes=1, render=True) 49 | -------------------------------------------------------------------------------- /docs/source/tutorials/code/approximator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value import SARSALambdaContinuous 4 | from mushroom_rl.approximators.parametric import LinearApproximator 5 | from mushroom_rl.core import Core 6 | from mushroom_rl.environments import * 7 | from mushroom_rl.features import Features 8 | from mushroom_rl.features.tiles import Tiles 9 | from mushroom_rl.policy import EpsGreedy 10 | from mushroom_rl.utils.callbacks import CollectDataset 11 | from mushroom_rl.rl_utils.parameters import Parameter 12 | 13 | 14 | # MDP 15 | mdp = Gymnasium(name='MountainCar-v0', horizon=np.inf, gamma=1.) 16 | 17 | # Policy 18 | epsilon = Parameter(value=0.) 
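# same greedy MountainCar setup as advanced_experiment.py above; this script focuses on how the tile-coding features and the linear Q approximator are configured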
19 | pi = EpsGreedy(epsilon=epsilon) 20 | 21 | # Q-function approximator 22 | n_tilings = 10 23 | tilings = Tiles.generate(n_tilings, [10, 10], 24 | mdp.info.observation_space.low, 25 | mdp.info.observation_space.high) 26 | features = Features(tilings=tilings) 27 | 28 | # Agent 29 | learning_rate = Parameter(.1 / n_tilings) 30 | approximator_params = dict(input_shape=(features.size,), 31 | output_shape=(mdp.info.action_space.n,), 32 | n_actions=mdp.info.action_space.n) 33 | agent = SARSALambdaContinuous(mdp.info, pi, LinearApproximator, 34 | approximator_params=approximator_params, 35 | learning_rate=learning_rate, 36 | lambda_coeff=.9, features=features) 37 | 38 | # Algorithm 39 | collect_dataset = CollectDataset() 40 | callbacks = [collect_dataset] 41 | core = Core(agent, mdp, callbacks_fit=callbacks) 42 | 43 | # Train 44 | core.learn(n_episodes=100, n_steps_per_fit=1) 45 | 46 | # Evaluate 47 | core.evaluate(n_episodes=1, render=True) 48 | -------------------------------------------------------------------------------- /docs/source/tutorials/code/generic_regressor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | 4 | from mushroom_rl.approximators import Regressor 5 | from mushroom_rl.approximators.parametric import LinearApproximator 6 | 7 | 8 | x = np.arange(10).reshape(-1, 1) 9 | 10 | intercept = 10 11 | noise = np.random.randn(10, 1) * 1 12 | y = 2 * x + intercept + noise 13 | 14 | phi = np.concatenate((np.ones(10).reshape(-1, 1), x), axis=1) 15 | 16 | regressor = Regressor(LinearApproximator, 17 | input_shape=(2,), 18 | output_shape=(1,)) 19 | 20 | regressor.fit(phi, y) 21 | 22 | print('Weights: ' + str(regressor.get_weights())) 23 | print('Gradient: ' + str(regressor.diff(np.array([[5.]])))) 24 | 25 | plt.scatter(x, y) 26 | plt.plot(x, regressor.predict(phi)) 27 | plt.show() 28 | -------------------------------------------------------------------------------- /docs/source/tutorials/code/simple_experiment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.ensemble import ExtraTreesRegressor 3 | 4 | from mushroom_rl.algorithms.value import FQI 5 | from mushroom_rl.core import Core 6 | from mushroom_rl.environments import CarOnHill 7 | from mushroom_rl.policy import EpsGreedy 8 | from mushroom_rl.rl_utils.parameters import Parameter 9 | 10 | mdp = CarOnHill() 11 | 12 | # Policy 13 | epsilon = Parameter(value=1.) 
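# epsilon = 1: every action is chosen uniformly at random, so the training episodes are pure exploration; FQI then fits its Q-function offline on this batch of data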
14 | pi = EpsGreedy(epsilon=epsilon) 15 | 16 | # Approximator 17 | approximator_params = dict(input_shape=mdp.info.observation_space.shape, 18 | n_actions=mdp.info.action_space.n, 19 | n_estimators=50, 20 | min_samples_split=5, 21 | min_samples_leaf=2) 22 | approximator = ExtraTreesRegressor 23 | 24 | # Agent 25 | agent = FQI(mdp.info, pi, approximator, n_iterations=20, 26 | approximator_params=approximator_params) 27 | 28 | core = Core(agent, mdp) 29 | 30 | core.learn(n_episodes=1000, n_episodes_per_fit=1000) 31 | 32 | pi.set_epsilon(Parameter(0.)) 33 | initial_state = np.array([[-.5, 0.]]) 34 | dataset = core.evaluate(initial_states=initial_state) 35 | 36 | print(dataset.discounted_return) 37 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/__init__.py -------------------------------------------------------------------------------- /examples/car_on_hill_fqi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from joblib import Parallel, delayed 3 | from sklearn.ensemble import ExtraTreesRegressor 4 | 5 | from mushroom_rl.algorithms.value import FQI 6 | from mushroom_rl.core import Core, Logger 7 | from mushroom_rl.environments import * 8 | from mushroom_rl.policy import EpsGreedy 9 | from mushroom_rl.rl_utils.parameters import Parameter 10 | 11 | """ 12 | This script aims to replicate the experiments on the Car on Hill MDP as 13 | presented in: 14 | "Tree-Based Batch Mode Reinforcement Learning", Ernst D. et al.. 2005. 15 | 16 | """ 17 | 18 | 19 | def experiment(): 20 | np.random.seed() 21 | 22 | # MDP 23 | mdp = CarOnHill() 24 | 25 | # Policy 26 | epsilon = Parameter(value=1.) 27 | pi = EpsGreedy(epsilon=epsilon) 28 | 29 | # Approximator 30 | approximator_params = dict(input_shape=mdp.info.observation_space.shape, 31 | n_actions=mdp.info.action_space.n, 32 | n_estimators=50, 33 | min_samples_split=5, 34 | min_samples_leaf=2) 35 | approximator = ExtraTreesRegressor 36 | 37 | # Agent 38 | algorithm_params = dict(n_iterations=20) 39 | agent = FQI(mdp.info, pi, approximator, 40 | approximator_params=approximator_params, **algorithm_params) 41 | 42 | # Algorithm 43 | core = Core(agent, mdp) 44 | 45 | # Render 46 | core.evaluate(n_episodes=1, render=True) 47 | 48 | # Train 49 | core.learn(n_episodes=1000, n_episodes_per_fit=1000) 50 | 51 | # Test 52 | test_epsilon = Parameter(0.) 
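# switch to a greedy policy (epsilon = 0) for evaluation, so the reported discounted return measures the learned Q-function without exploration noise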
53 | agent.policy.set_epsilon(test_epsilon) 54 | 55 | initial_states = np.zeros((289, 2)) 56 | cont = 0 57 | for i in range(-8, 9): 58 | for j in range(-8, 9): 59 | initial_states[cont, :] = [0.125 * i, 0.375 * j] 60 | cont += 1 61 | 62 | dataset = core.evaluate(initial_states=initial_states) 63 | 64 | # Render 65 | core.evaluate(n_episodes=3, render=True) 66 | 67 | return np.mean(dataset.discounted_return) 68 | 69 | 70 | if __name__ == '__main__': 71 | n_experiment = 1 72 | 73 | logger = Logger(FQI.__name__, results_dir=None) 74 | logger.strong_line() 75 | logger.info('Experiment Algorithm: ' + FQI.__name__) 76 | 77 | Js = Parallel(n_jobs=None)(delayed(experiment)() for _ in range(n_experiment)) 78 | logger.info((np.mean(Js))) 79 | -------------------------------------------------------------------------------- /examples/cartpole_lspi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value import LSPI 4 | from mushroom_rl.core import Core, Logger 5 | from mushroom_rl.environments import * 6 | from mushroom_rl.features import Features 7 | from mushroom_rl.features.basis import PolynomialBasis, GaussianRBF 8 | from mushroom_rl.policy import EpsGreedy 9 | from mushroom_rl.rl_utils.parameters import Parameter 10 | 11 | 12 | """ 13 | This script aims to replicate the experiments on the Inverted Pendulum MDP as 14 | presented in: 15 | "Least-Squares Policy Iteration". Lagoudakis M. G. and Parr R.. 2003. 16 | 17 | """ 18 | 19 | 20 | def experiment(): 21 | np.random.seed() 22 | 23 | # MDP 24 | mdp = CartPole() 25 | 26 | # Policy 27 | epsilon = Parameter(value=1.) 28 | pi = EpsGreedy(epsilon=epsilon) 29 | 30 | # Agent 31 | basis = [PolynomialBasis()] 32 | 33 | s1 = np.array([-np.pi, 0, np.pi]) * .25 34 | s2 = np.array([-1, 0, 1]) 35 | for i in s1: 36 | for j in s2: 37 | basis.append(GaussianRBF(np.array([i, j]), np.array([1.]))) 38 | features = Features(basis_list=basis) 39 | 40 | fit_params = dict() 41 | approximator_params = dict(input_shape=(features.size,), 42 | output_shape=(mdp.info.action_space.n,), 43 | n_actions=mdp.info.action_space.n, 44 | phi=features) 45 | agent = LSPI(mdp.info, pi, approximator_params=approximator_params, fit_params=fit_params) 46 | 47 | # Algorithm 48 | core = Core(agent, mdp) 49 | core.evaluate(n_episodes=3, render=True) 50 | 51 | # Train 52 | core.learn(n_episodes=500, n_episodes_per_fit=500) 53 | 54 | # Test 55 | test_epsilon = Parameter(0.) 
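# greedy evaluation: with epsilon = 0 the mean episode length reported below reflects how long the learned LSPI policy keeps the pole balanced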
56 | agent.policy.set_epsilon(test_epsilon) 57 | 58 | dataset = core.evaluate(n_episodes=1, quiet=True) 59 | 60 | core.evaluate(n_steps=100, render=True) 61 | 62 | return np.mean(dataset.episodes_length) 63 | 64 | 65 | if __name__ == '__main__': 66 | n_experiment = 1 67 | 68 | logger = Logger(LSPI.__name__, results_dir=None) 69 | logger.strong_line() 70 | logger.info('Experiment Algorithm: ' + LSPI.__name__) 71 | 72 | steps = experiment() 73 | logger.info('Final episode length: %d' % steps) 74 | -------------------------------------------------------------------------------- /examples/double_chain_q_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/double_chain_q_learning/__init__.py -------------------------------------------------------------------------------- /examples/double_chain_q_learning/chain_structure/p.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/double_chain_q_learning/chain_structure/p.npy -------------------------------------------------------------------------------- /examples/double_chain_q_learning/chain_structure/rew.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/double_chain_q_learning/chain_structure/rew.npy -------------------------------------------------------------------------------- /examples/double_chain_q_learning/double_chain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from joblib import Parallel, delayed 3 | from pathlib import Path 4 | 5 | from mushroom_rl.algorithms.value import QLearning, DoubleQLearning,\ 6 | WeightedQLearning, SpeedyQLearning 7 | from mushroom_rl.core import Core 8 | from mushroom_rl.environments import * 9 | from mushroom_rl.policy import EpsGreedy 10 | from mushroom_rl.utils.callbacks import CollectQ 11 | from mushroom_rl.rl_utils.parameters import Parameter, DecayParameter 12 | 13 | 14 | """ 15 | Simple script to solve a double chain with Q-Learning and some of its variants. 16 | The considered double chain is the one presented in: 17 | "Relative Entropy Policy Search". Peters J. et al.. 2010. 18 | 19 | """ 20 | 21 | 22 | def experiment(algorithm_class, exp): 23 | np.random.seed() 24 | 25 | # MDP 26 | path = Path(__file__).resolve().parent / 'chain_structure' 27 | p = np.load(path / 'p.npy') 28 | rew = np.load(path / 'rew.npy') 29 | mdp = FiniteMDP(p, rew, gamma=.9) 30 | 31 | # Policy 32 | epsilon = Parameter(value=1.) 
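# epsilon = 1 makes the behaviour policy uniformly random; Q-Learning and the other variants compared here are off-policy, so they can still learn the Q-values from purely exploratory data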
33 | pi = EpsGreedy(epsilon=epsilon) 34 | 35 | # Agent 36 | learning_rate = DecayParameter(value=1., exp=exp, size=mdp.info.size) 37 | algorithm_params = dict(learning_rate=learning_rate) 38 | agent = algorithm_class(mdp.info, pi, **algorithm_params) 39 | 40 | # Algorithm 41 | collect_Q = CollectQ(agent.Q) 42 | callbacks = [collect_Q] 43 | core = Core(agent, mdp, callbacks) 44 | 45 | # Train 46 | core.learn(n_steps=20000, n_steps_per_fit=1, quiet=True) 47 | 48 | Qs = collect_Q.get() 49 | 50 | return Qs 51 | 52 | 53 | if __name__ == '__main__': 54 | n_experiment = 5 55 | 56 | names = {1: '1', .51: '51', QLearning: 'Q', DoubleQLearning: 'DQ', 57 | WeightedQLearning: 'WQ', SpeedyQLearning: 'SPQ'} 58 | 59 | log_path = Path(__file__).resolve().parent / 'logs' 60 | 61 | log_path.mkdir(parents=True, exist_ok=True) 62 | 63 | for e in [1, .51]: 64 | for a in [QLearning, DoubleQLearning, WeightedQLearning, 65 | SpeedyQLearning]: 66 | out = Parallel(n_jobs=1)( 67 | delayed(experiment)(a, e) for _ in range(n_experiment)) 68 | Qs = np.array([o for o in out]) 69 | 70 | Qs = np.mean(Qs, 0) 71 | 72 | filename = names[a] + names[e] + '.npy' 73 | np.save(log_path / filename, Qs[:, 0, 0]) 74 | -------------------------------------------------------------------------------- /examples/habitat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/habitat/__init__.py -------------------------------------------------------------------------------- /examples/habitat/pointnav_apartment-0.yaml: -------------------------------------------------------------------------------- 1 | ENV_NAME: "NavRLEnv" 2 | 3 | ENVIRONMENT: 4 | MAX_EPISODE_STEPS: 500 5 | 6 | SIMULATOR: 7 | HABITAT_SIM_V0: 8 | GPU_DEVICE_ID: 0 9 | 10 | RGB_SENSOR: # Used for observations 11 | WIDTH: 64 12 | HEIGHT: 64 13 | HFOV: 79 14 | POSITION: [0, 0.88, 0] 15 | 16 | ACTION_SPACE_CONFIG: "v0" 17 | FORWARD_STEP_SIZE: 0.25 # How much the agent moves with 'forward' 18 | TURN_ANGLE: 10 # How much the agent turns with 'left' / 'right' actions 19 | 20 | TASK: 21 | TYPE: Nav-v0 22 | 23 | # Set both to the same value 24 | SUCCESS_DISTANCE: 0.2 25 | SUCCESS: 26 | SUCCESS_DISTANCE: 0.2 27 | 28 | SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR'] 29 | POINTGOAL_WITH_GPS_COMPASS_SENSOR: 30 | GOAL_FORMAT: "POLAR" 31 | DIMENSIONALITY: 2 32 | GOAL_SENSOR_UUID: pointgoal_with_gps_compass 33 | 34 | MEASUREMENTS: ['DISTANCE_TO_GOAL', 'SUCCESS', 'SPL'] 35 | 36 | DATASET: # Replica scene 37 | CONTENT_SCENES: ['*'] 38 | DATA_PATH: "replica_{split}_apartment-0.json.gz" 39 | SCENES_DIR: "Replica-Dataset/replica-path/apartment_0" 40 | TYPE: PointNav-v1 41 | SPLIT: train 42 | -------------------------------------------------------------------------------- /examples/habitat/replica_train_apartment-0.json: -------------------------------------------------------------------------------- 1 | { 2 | "episodes": [{"episode_id": "0", 3 | "scene_id": "habitat/mesh_semantic.ply", 4 | "start_position": [-0.716670036315918, -1.374765157699585, 0.7762265205383301], 5 | "start_rotation": [0.0, 0.0, 0.0, 1.0], 6 | "goals": [{"position": [4.170074462890625, -1.374765157699585, 1.8612048625946045], "radius": null}], 7 | "shortest_paths": null, 8 | "start_room": null}] 9 | } 10 | -------------------------------------------------------------------------------- /examples/habitat/replica_train_apartment-0.json.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/habitat/replica_train_apartment-0.json.gz -------------------------------------------------------------------------------- /examples/simple_chain_qlearning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value import QLearning 4 | from mushroom_rl.core import Core, Logger 5 | from mushroom_rl.environments import * 6 | from mushroom_rl.policy import EpsGreedy 7 | from mushroom_rl.rl_utils.parameters import Parameter 8 | 9 | 10 | """ 11 | Simple script to solve a simple chain with Q-Learning. 12 | 13 | """ 14 | 15 | 16 | def experiment(): 17 | np.random.seed() 18 | 19 | logger = Logger(QLearning.__name__, results_dir=None) 20 | logger.strong_line() 21 | logger.info('Experiment Algorithm: ' + QLearning.__name__) 22 | 23 | # MDP 24 | mdp = generate_simple_chain(state_n=5, goal_states=[2], prob=.8, rew=1, 25 | gamma=.9) 26 | 27 | # Policy 28 | epsilon = Parameter(value=.15) 29 | pi = EpsGreedy(epsilon=epsilon) 30 | 31 | # Agent 32 | learning_rate = Parameter(value=.2) 33 | algorithm_params = dict(learning_rate=learning_rate) 34 | agent = QLearning(mdp.info, pi, **algorithm_params) 35 | 36 | # Core 37 | core = Core(agent, mdp) 38 | 39 | # Initial policy Evaluation 40 | dataset = core.evaluate(n_steps=1000) 41 | J = np.mean(dataset.discounted_return) 42 | logger.info(f'J start: {J}') 43 | 44 | # Train 45 | core.learn(n_steps=10000, n_steps_per_fit=1) 46 | 47 | # Final Policy Evaluation 48 | dataset = core.evaluate(n_steps=1000) 49 | J = np.mean(dataset.discounted_return) 50 | logger.info(f'J final: {J}') 51 | 52 | 53 | if __name__ == '__main__': 54 | experiment() 55 | -------------------------------------------------------------------------------- /examples/taxi_mellow_sarsa/grid.txt: -------------------------------------------------------------------------------- 1 | S#F.#.G 2 | .#..#.. 3 | ....... 4 | ##...## 5 | ......F 6 | F.....# 7 | -------------------------------------------------------------------------------- /examples/taxi_mellow_sarsa/taxi_mellow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from joblib import Parallel, delayed 3 | 4 | from mushroom_rl.algorithms.value import SARSA 5 | from mushroom_rl.core import Core 6 | from mushroom_rl.environments.generators.taxi import generate_taxi 7 | from mushroom_rl.policy import Boltzmann, EpsGreedy, Mellowmax 8 | from mushroom_rl.utils.callbacks import CollectDataset 9 | from mushroom_rl.rl_utils.parameters import Parameter 10 | 11 | 12 | """ 13 | This script aims to replicate the experiments on the Taxi MDP as presented in: 14 | "An Alternative Softmax Operator for Reinforcement Learning", Asadi K. et al.. 15 | 2017. 
16 | 17 | """ 18 | 19 | 20 | def experiment(policy, value): 21 | np.random.seed() 22 | 23 | # MDP 24 | mdp = generate_taxi('grid.txt') 25 | 26 | # Policy 27 | pi = policy(Parameter(value=value)) 28 | 29 | # Agent 30 | learning_rate = Parameter(value=.15) 31 | algorithm_params = dict(learning_rate=learning_rate) 32 | agent = SARSA(mdp.info, pi, **algorithm_params) 33 | 34 | # Algorithm 35 | collect_dataset = CollectDataset() 36 | callbacks = [collect_dataset] 37 | core = Core(agent, mdp, callbacks) 38 | 39 | # Train 40 | n_steps = 300000 41 | core.learn(n_steps=n_steps, n_steps_per_fit=1, quiet=True) 42 | 43 | return np.sum(np.array(collect_dataset.get())[:, 2]) / float(n_steps) 44 | 45 | 46 | if __name__ == '__main__': 47 | n_experiment = 25 48 | 49 | algs = {EpsGreedy: 'epsilon', Boltzmann: 'boltzmann', Mellowmax: 'mellow'} 50 | ranges = {EpsGreedy: np.linspace(.05, .5, 10), 51 | Boltzmann: np.linspace(.5, 10, 10), 52 | Mellowmax: np.linspace(.5, 10, 10)} 53 | 54 | for p in [EpsGreedy, Boltzmann, Mellowmax]: 55 | print('Policy: ', algs[p]) 56 | Js = list() 57 | for v in ranges[p]: 58 | out = Parallel(n_jobs=-1)( 59 | delayed(experiment)(p, v) for _ in range(n_experiment)) 60 | J = [np.mean(o) for o in out] 61 | Js.append(np.mean(J)) 62 | 63 | np.save('r_%s.npy' % algs[p], Js) 64 | -------------------------------------------------------------------------------- /examples/vectorized_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/vectorized_core/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.0.0-rc1' -------------------------------------------------------------------------------- /mushroom_rl/algorithms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/algorithms/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/algorithms/actor_critic/__init__.py: -------------------------------------------------------------------------------- 1 | from .classic_actor_critic import StochasticAC, StochasticAC_AVG, COPDAC_Q 2 | from .deep_actor_critic import DeepAC, A2C, DDPG, TD3, SAC, TRPO, PPO, PPO_BPTT, RudinPPO 3 | 4 | __all__ = ['COPDAC_Q', 'StochasticAC', 'StochasticAC_AVG', 5 | 'DeepAC', 'A2C', 'DDPG', 'TD3', 'SAC', 'TRPO', 'PPO', 'PPO_BPTT', 6 | 'RudinPPO'] 7 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/actor_critic/classic_actor_critic/__init__.py: -------------------------------------------------------------------------------- 1 | from .copdac_q import COPDAC_Q 2 | from .stochastic_ac import StochasticAC, StochasticAC_AVG 3 | 4 | __all__ = ['COPDAC_Q', 'StochasticAC', 'StochasticAC_AVG'] -------------------------------------------------------------------------------- /mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py: -------------------------------------------------------------------------------- 1 | from .deep_actor_critic import OnPolicyDeepAC, DeepAC 2 | from .a2c import A2C 3 | from .ddpg import DDPG 4 | from .td3 import TD3 5 | from .sac import SAC 6 | from .trpo import TRPO 7 | from 
.ppo import PPO 8 | from .ppo_bptt import PPO_BPTT 9 | from .ppo_rudin import RudinPPO 10 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/__init__.py: -------------------------------------------------------------------------------- 1 | from .policy_gradient import REINFORCE, GPOMDP, eNAC 2 | from .black_box_optimization import RWR, PGPE, REPS, ConstrainedREPS, MORE, ePPO 3 | 4 | 5 | __all__ = ['REINFORCE', 'GPOMDP', 'eNAC', 'RWR', 'PGPE', 'REPS', 'ConstrainedREPS', 'MORE', 'ePPO'] 6 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/black_box_optimization/__init__.py: -------------------------------------------------------------------------------- 1 | from .context_builder import ContextBuilder 2 | from .black_box_optimization import BlackBoxOptimization 3 | from .rwr import RWR 4 | from .reps import REPS 5 | from .pgpe import PGPE 6 | from .constrained_reps import ConstrainedREPS 7 | from .more import MORE 8 | from .eppo import ePPO 9 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/black_box_optimization/context_builder.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.core import Serializable 2 | 3 | 4 | class ContextBuilder(Serializable): 5 | def __init__(self, context_shape=None): 6 | self._context_shape = context_shape 7 | 8 | super().__init__() 9 | 10 | self._add_save_attr(_context_shape='primitive') 11 | 12 | def __call__(self, initial_state, **episode_info): 13 | return None 14 | 15 | @property 16 | def context_shape(self): 17 | return self._context_shape 18 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/black_box_optimization/pgpe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.policy_search.black_box_optimization import BlackBoxOptimization 4 | 5 | 6 | class PGPE(BlackBoxOptimization): 7 | """ 8 | Policy Gradient with Parameter Exploration algorithm. 9 | "A Survey on Policy Search for Robotics", Deisenroth M. P., Neumann G., 10 | Peters J.. 2013. 11 | 12 | """ 13 | def __init__(self, mdp_info, distribution, policy, optimizer, context_builder=None): 14 | """ 15 | Constructor. 16 | 17 | Args: 18 | optimizer: the gradient step optimizer. 19 | 20 | """ 21 | self.optimizer = optimizer 22 | 23 | super().__init__(mdp_info, distribution, policy, context_builder=context_builder) 24 | 25 | self._add_save_attr(optimizer='mushroom') 26 | 27 | def _update(self, Jep, theta, context): 28 | baseline_num_list = list() 29 | baseline_den_list = list() 30 | diff_log_dist_list = list() 31 | 32 | # Compute derivatives of distribution and baseline components 33 | for i in range(len(Jep)): 34 | J_i = Jep[i] 35 | theta_i = theta[i] 36 | 37 | diff_log_dist = self.distribution.diff_log(theta_i, context) 38 | diff_log_dist2 = diff_log_dist**2 39 | 40 | diff_log_dist_list.append(diff_log_dist) 41 | baseline_num_list.append(J_i * diff_log_dist2) 42 | baseline_den_list.append(diff_log_dist2) 43 | 44 | # Compute baseline 45 | baseline = np.mean(baseline_num_list, axis=0) / np.mean(baseline_den_list, axis=0) 46 | baseline[np.logical_not(np.isfinite(baseline))] = 0. 
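        # The baseline computed above is the component-wise variance-reducing
        # choice b = E[J * (d log p(theta))^2] / E[(d log p(theta))^2], with
        # non-finite entries (zero denominators) replaced by 0.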
47 | 48 | # Compute gradient 49 | grad_J_list = list() 50 | for i in range(len(Jep)): 51 | diff_log_dist = diff_log_dist_list[i] 52 | J_i = Jep[i] 53 | 54 | grad_J_list.append(diff_log_dist * (J_i - baseline)) 55 | 56 | grad_J = np.mean(grad_J_list, axis=0) 57 | 58 | omega_old = self.distribution.get_parameters() 59 | omega_new = self.optimizer(omega_old, grad_J) 60 | self.distribution.set_parameters(omega_new) 61 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/black_box_optimization/reps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from scipy.optimize import minimize 4 | 5 | from mushroom_rl.algorithms.policy_search.black_box_optimization import BlackBoxOptimization 6 | from mushroom_rl.rl_utils.parameters import to_parameter 7 | 8 | 9 | class REPS(BlackBoxOptimization): 10 | """ 11 | Episodic Relative Entropy Policy Search algorithm. 12 | "A Survey on Policy Search for Robotics", Deisenroth M. P., Neumann G., 13 | Peters J.. 2013. 14 | 15 | """ 16 | def __init__(self, mdp_info, distribution, policy, eps): 17 | """ 18 | Constructor. 19 | 20 | Args: 21 | eps ([float, Parameter]): the maximum admissible value for the Kullback-Leibler 22 | divergence between the new distribution and the 23 | previous one at each update step. 24 | 25 | """ 26 | assert not distribution.is_contextual 27 | 28 | self._eps = to_parameter(eps) 29 | 30 | super().__init__(mdp_info, distribution, policy) 31 | 32 | self._add_save_attr(_eps='mushroom') 33 | 34 | def _update(self, Jep, theta, context): 35 | eta_start = np.ones(1) 36 | 37 | res = minimize(REPS._dual_function, eta_start, 38 | jac=REPS._dual_function_diff, 39 | bounds=((np.finfo(np.float32).eps, np.inf),), 40 | args=(self._eps(), Jep, theta)) 41 | 42 | eta_opt = res.x.item() 43 | 44 | Jep -= np.max(Jep) 45 | 46 | d = np.exp(Jep / eta_opt) 47 | 48 | self.distribution.mle(theta, d) 49 | 50 | @staticmethod 51 | def _dual_function(eta_array, *args): 52 | eta = eta_array.item() 53 | eps, Jep, theta = args 54 | 55 | max_J = np.max(Jep) 56 | 57 | r = Jep - max_J 58 | sum1 = np.mean(np.exp(r / eta)) 59 | 60 | return eta * eps + eta * np.log(sum1) + max_J 61 | 62 | @staticmethod 63 | def _dual_function_diff(eta_array, *args): 64 | eta = eta_array.item() 65 | eps, Jep, theta = args 66 | 67 | max_J = np.max(Jep) 68 | 69 | r = Jep - max_J 70 | 71 | sum1 = np.mean(np.exp(r / eta)) 72 | sum2 = np.mean(np.exp(r / eta) * r) 73 | 74 | gradient = eps + np.log(sum1) - sum2 / (eta * sum1) 75 | 76 | return np.array([gradient]) 77 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/black_box_optimization/rwr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.policy_search.black_box_optimization import BlackBoxOptimization 4 | from mushroom_rl.rl_utils.parameters import to_parameter 5 | 6 | 7 | class RWR(BlackBoxOptimization): 8 | """ 9 | Reward-Weighted Regression algorithm. 10 | "A Survey on Policy Search for Robotics", Deisenroth M. P., Neumann G., 11 | Peters J.. 2013. 12 | 13 | """ 14 | def __init__(self, mdp_info, distribution, policy, beta): 15 | """ 16 | Constructor. 17 | 18 | Args: 19 | beta ([float, Parameter]): the temperature for the exponential reward 20 | transformation. 
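            The returns are shifted by their maximum and mapped to weights
            d_i = exp(beta * (J_i - max_j J_j)), which are then used for a
            weighted maximum-likelihood update of the search distribution.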
21 | 22 | """ 23 | assert not distribution.is_contextual 24 | 25 | self._beta = to_parameter(beta) 26 | 27 | super().__init__(mdp_info, distribution, policy) 28 | 29 | self._add_save_attr(_beta='mushroom') 30 | 31 | def _update(self, Jep, theta, context): 32 | Jep -= np.max(Jep) 33 | 34 | d = np.exp(self._beta() * Jep) 35 | 36 | self.distribution.mle(theta, d) 37 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/policy_gradient/__init__.py: -------------------------------------------------------------------------------- 1 | from .policy_gradient import PolicyGradient 2 | from .reinforce import REINFORCE 3 | from .gpomdp import GPOMDP 4 | from .enac import eNAC 5 | 6 | __all__ = ['REINFORCE', 'GPOMDP', 'eNAC'] -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/policy_gradient/enac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.policy_search.policy_gradient import PolicyGradient 4 | 5 | 6 | class eNAC(PolicyGradient): 7 | """ 8 | Episodic Natural Actor Critic algorithm. 9 | "A Survey on Policy Search for Robotics", Deisenroth M. P., Neumann G., 10 | Peters J. 2013. 11 | 12 | """ 13 | def __init__(self, mdp_info, policy, optimizer, critic_features=None): 14 | """ 15 | Constructor. 16 | 17 | Args: 18 | critic_features (Features, None): features used by the critic. 19 | 20 | """ 21 | super().__init__(mdp_info, policy, optimizer) 22 | self.phi_c = critic_features 23 | 24 | self.sum_grad_log = None 25 | self.psi_ext = None 26 | self.sum_grad_log_list = list() 27 | 28 | self._add_save_attr( 29 | phi_c='pickle', 30 | sum_grad_log='numpy', 31 | psi_ext='pickle', 32 | sum_grad_log_list='pickle' 33 | ) 34 | 35 | def _compute_gradient(self, J): 36 | R = np.array(J) 37 | PSI = np.array(self.sum_grad_log_list) 38 | 39 | w_and_v = np.linalg.pinv(PSI).dot(R) 40 | nat_grad = w_and_v[:self.policy.weights_size] 41 | 42 | self.sum_grad_log_list = list() 43 | 44 | return nat_grad 45 | 46 | def _step_update(self, x, u, r): 47 | self.sum_grad_log += self.df*self.policy.diff_log(x, u) 48 | 49 | if self.psi_ext is None: 50 | if self.phi_c is None: 51 | self.psi_ext = np.ones(1) 52 | else: 53 | self.psi_ext = self.phi_c(x) 54 | 55 | def _episode_end_update(self): 56 | psi = np.concatenate((self.sum_grad_log, self.psi_ext)) 57 | self.sum_grad_log_list.append(psi) 58 | 59 | def _init_update(self): 60 | self.psi_ext = None 61 | self.sum_grad_log = np.zeros(self.policy.weights_size) 62 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.policy_search.policy_gradient import PolicyGradient 4 | 5 | 6 | class REINFORCE(PolicyGradient): 7 | """ 8 | REINFORCE algorithm. 9 | "Simple Statistical Gradient-Following Algorithms for Connectionist 10 | Reinforcement Learning", Williams R. J.. 1992. 
11 | 12 | """ 13 | def __init__(self, mdp_info, policy, optimizer): 14 | super().__init__(mdp_info, policy, optimizer) 15 | self.sum_d_log_pi = None 16 | self.list_sum_d_log_pi = list() 17 | self.baseline_num = list() 18 | self.baseline_den = list() 19 | 20 | self._add_save_attr( 21 | sum_d_log_pi='numpy', 22 | list_sum_d_log_pi='pickle', 23 | baseline_num='pickle', 24 | baseline_den='pickle' 25 | ) 26 | 27 | # Ignore divide by zero 28 | np.seterr(divide='ignore', invalid='ignore') 29 | 30 | def _compute_gradient(self, J): 31 | baseline = np.mean(self.baseline_num, axis=0) / np.mean(self.baseline_den, axis=0) 32 | baseline[np.logical_not(np.isfinite(baseline))] = 0. 33 | grad_J_episode = list() 34 | for i, J_episode in enumerate(J): 35 | sum_d_log_pi = self.list_sum_d_log_pi[i] 36 | grad_J_episode.append(sum_d_log_pi * (J_episode - baseline)) 37 | 38 | grad_J = np.mean(grad_J_episode, axis=0) 39 | self.list_sum_d_log_pi = list() 40 | self.baseline_den = list() 41 | self.baseline_num = list() 42 | 43 | return grad_J 44 | 45 | def _step_update(self, x, u, r): 46 | d_log_pi = self.policy.diff_log(x, u) 47 | self.sum_d_log_pi += d_log_pi 48 | 49 | def _episode_end_update(self): 50 | self.list_sum_d_log_pi.append(self.sum_d_log_pi) 51 | squared_sum_d_log_pi = np.square(self.sum_d_log_pi) 52 | self.baseline_num.append(squared_sum_d_log_pi * self.J_episode) 53 | self.baseline_den.append(squared_sum_d_log_pi) 54 | 55 | def _init_update(self): 56 | self.sum_d_log_pi = np.zeros(self.policy.weights_size) 57 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch_td import * 2 | from .dqn import * 3 | from .td import * 4 | 5 | __all__ = ['FQI', 'DoubleFQI', 'BoostedFQI', 'LSPI', 'AbstractDQN', 'DQN', 'DoubleDQN', 6 | 'AveragedDQN', 'CategoricalDQN', 'DuelingDQN', 'NoisyDQN', 'QuantileDQN', 7 | 'MaxminDQN', 'Rainbow', 'QLearning', 'QLambda', 'DoubleQLearning', 'WeightedQLearning', 8 | 'MaxminQLearning', 'SpeedyQLearning', 'RLearning', 'RQLearning', 9 | 'SARSA', 'SARSALambda', 'SARSALambdaContinuous', 'ExpectedSARSA', 10 | 'TrueOnlineSARSALambda'] 11 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/batch_td/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch_td import BatchTD 2 | from .fqi import FQI 3 | from .double_fqi import DoubleFQI 4 | from .boosted_fqi import BoostedFQI 5 | from .lspi import LSPI 6 | 7 | __all__ = ['FQI', 'DoubleFQI', 'BoostedFQI', 'LSPI'] 8 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/batch_td/batch_td.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.core import Agent 2 | from mushroom_rl.approximators import Regressor 3 | 4 | 5 | class BatchTD(Agent): 6 | """ 7 | Abstract class to implement a generic Batch TD algorithm. 8 | 9 | """ 10 | def __init__(self, mdp_info, policy, approximator, approximator_params=None, fit_params=None): 11 | """ 12 | Constructor. 13 | 14 | Args: 15 | approximator (object): approximator used by the algorithm and the 16 | policy. 
17 | approximator_params (dict, None): parameters of the approximator to 18 | build; 19 | fit_params (dict, None): parameters of the fitting algorithm of the 20 | approximator; 21 | 22 | """ 23 | approximator_params = dict() if approximator_params is None else\ 24 | approximator_params 25 | self._fit_params = dict() if fit_params is None else fit_params 26 | 27 | self.approximator = Regressor(approximator, **approximator_params) 28 | policy.set_q(self.approximator) 29 | 30 | self._add_save_attr( 31 | approximator='mushroom', 32 | _fit_params='pickle' 33 | ) 34 | 35 | super().__init__(mdp_info, policy) 36 | 37 | def _post_load(self): 38 | self.policy.set_q(self.approximator) 39 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/batch_td/boosted_fqi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import trange 3 | 4 | from .fqi import FQI 5 | 6 | 7 | class BoostedFQI(FQI): 8 | """ 9 | Boosted Fitted Q-Iteration algorithm. 10 | "Boosted Fitted Q-Iteration". Tosatto S. et al.. 2017. 11 | 12 | """ 13 | def __init__(self, mdp_info, policy, approximator, n_iterations, 14 | approximator_params=None, fit_params=None, quiet=False): 15 | self._prediction = 0. 16 | self._next_q = 0. 17 | self._idx = 0 18 | 19 | assert approximator_params['n_models'] == n_iterations 20 | 21 | self._add_save_attr( 22 | _n_iterations='primitive', 23 | _quiet='primitive', 24 | _prediction='primitive', 25 | _next_q='numpy', 26 | _idx='primitive', 27 | _target='pickle' 28 | ) 29 | 30 | super().__init__(mdp_info, policy, approximator, n_iterations, approximator_params, fit_params, quiet) 31 | 32 | def fit(self, dataset): 33 | state, action, reward, next_state, absorbing, _ = dataset.parse(to='numpy') 34 | for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False): 35 | if self._target is None: 36 | self._target = reward 37 | else: 38 | self._next_q += self.approximator.predict(next_state, 39 | idx=self._idx - 1) 40 | if np.any(absorbing): 41 | self._next_q *= 1 - absorbing.reshape(-1, 1) 42 | 43 | max_q = np.max(self._next_q, axis=1) 44 | self._target = reward + self.mdp_info.gamma * max_q 45 | 46 | self._target -= self._prediction 47 | self._prediction += self._target 48 | 49 | self.approximator.fit(state, action, self._target, idx=self._idx, 50 | **self._fit_params) 51 | 52 | self._idx += 1 53 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/batch_td/double_fqi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import trange 3 | 4 | from .fqi import FQI 5 | 6 | 7 | class DoubleFQI(FQI): 8 | """ 9 | Double Fitted Q-Iteration algorithm. 10 | "Estimating the Maximum Expected Value in Continuous Reinforcement Learning 11 | Problems". D'Eramo C. et al.. 2017. 
12 | 13 | """ 14 | def __init__(self, mdp_info, policy, approximator, n_iterations, 15 | approximator_params=None, fit_params=None, quiet=False): 16 | approximator_params['n_models'] = 2 17 | 18 | super().__init__(mdp_info, policy, approximator, n_iterations, 19 | approximator_params, fit_params, quiet) 20 | 21 | def fit(self, dataset): 22 | for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False): 23 | state = list() 24 | action = list() 25 | reward = list() 26 | next_state = list() 27 | absorbing = list() 28 | 29 | half = len(dataset) // 2 30 | for i in range(2): 31 | s, a, r, ss, ab, _ = dataset[i * half:(i + 1) * half].parse(to='numpy') 32 | state.append(s) 33 | action.append(a) 34 | reward.append(r) 35 | next_state.append(ss) 36 | absorbing.append(ab) 37 | 38 | if self._target is None: 39 | self._target = reward 40 | else: 41 | for i in range(2): 42 | q_i = self.approximator.predict(next_state[i], idx=i) 43 | 44 | amax_q = np.expand_dims(np.argmax(q_i, axis=1), axis=1) 45 | max_q = self.approximator.predict(next_state[i], amax_q, 46 | idx=1 - i) 47 | if np.any(absorbing[i]): 48 | max_q *= 1 - absorbing[i] 49 | self._target[i] = reward[i] + self.mdp_info.gamma * max_q 50 | 51 | for i in range(2): 52 | self.approximator.fit(state[i], action[i], self._target[i], idx=i, 53 | **self._fit_params) 54 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/batch_td/fqi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import trange 3 | 4 | from mushroom_rl.algorithms.value.batch_td import BatchTD 5 | from mushroom_rl.rl_utils.parameters import to_parameter 6 | 7 | 8 | class FQI(BatchTD): 9 | """ 10 | Fitted Q-Iteration algorithm. 11 | "Tree-Based Batch Mode Reinforcement Learning", Ernst D. et al.. 2005. 12 | 13 | """ 14 | def __init__(self, mdp_info, policy, approximator, n_iterations, 15 | approximator_params=None, fit_params=None, quiet=False): 16 | """ 17 | Constructor. 18 | 19 | Args: 20 | n_iterations ([int, Parameter]): number of iterations to perform for training; 21 | quiet (bool, False): whether to show the progress bar or not. 
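            At each iteration the approximator is refitted on the bootstrapped
            targets r + gamma * max_a Q(s', a); on the first iteration, and for
            absorbing next states, the target reduces to the reward r.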
22 | 23 | """ 24 | self._n_iterations = to_parameter(n_iterations) 25 | self._quiet = quiet 26 | self._target = None 27 | 28 | self._add_save_attr( 29 | _n_iterations='mushroom', 30 | _quiet='primitive', 31 | _target='pickle' 32 | ) 33 | 34 | super().__init__(mdp_info, policy, approximator, approximator_params, fit_params) 35 | 36 | def fit(self, dataset): 37 | state, action, reward, next_state, absorbing, _ = dataset.parse(to='numpy') 38 | for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False): 39 | if self._target is None: 40 | self._target = reward 41 | else: 42 | q = self.approximator.predict(next_state) 43 | if np.any(absorbing): 44 | q *= 1 - absorbing.reshape(-1, 1) 45 | 46 | max_q = np.max(q, axis=1) 47 | self._target = reward + self.mdp_info.gamma * max_q 48 | 49 | self.approximator.fit(state, action, self._target, **self._fit_params) 50 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/batch_td/lspi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value.batch_td import BatchTD 4 | from mushroom_rl.approximators.parametric import LinearApproximator 5 | from mushroom_rl.features import get_action_features 6 | from mushroom_rl.rl_utils.parameters import to_parameter 7 | 8 | 9 | class LSPI(BatchTD): 10 | """ 11 | Least-Squares Policy Iteration algorithm. 12 | "Least-Squares Policy Iteration". Lagoudakis M. G. and Parr R.. 2003. 13 | 14 | """ 15 | def __init__(self, mdp_info, policy, approximator_params=None, epsilon=1e-2, fit_params=None): 16 | """ 17 | Constructor. 18 | 19 | Args: 20 | epsilon ([float, Parameter], 1e-2): termination coefficient. 21 | 22 | """ 23 | self._epsilon = to_parameter(epsilon) 24 | 25 | self._add_save_attr(_epsilon='mushroom') 26 | 27 | super().__init__(mdp_info, policy, LinearApproximator, approximator_params, fit_params) 28 | 29 | def fit(self, dataset): 30 | state, action, reward, next_state, absorbing, _ = dataset.parse(to='numpy') 31 | 32 | phi_state = self.approximator.model.phi(state) 33 | phi_next_state = self.approximator.model.phi(next_state) 34 | 35 | phi_state_action = get_action_features(phi_state, action, self.mdp_info.action_space.n) 36 | 37 | norm = np.inf 38 | while norm > self._epsilon(): 39 | q = self.approximator.predict(next_state) 40 | if np.any(absorbing): 41 | q *= 1 - absorbing.reshape(-1, 1) 42 | 43 | next_action = np.argmax(q, axis=1).reshape(-1, 1) 44 | phi_next_state_next_action = get_action_features(phi_next_state, next_action, self.mdp_info.action_space.n) 45 | 46 | tmp = phi_state_action - self.mdp_info.gamma * phi_next_state_next_action 47 | A = phi_state_action.T.dot(tmp) 48 | b = (phi_state_action.T.dot(reward)).reshape(-1, 1) 49 | 50 | old_w = self.approximator.get_weights() 51 | if np.linalg.matrix_rank(A) == A.shape[1]: 52 | w = np.linalg.solve(A, b).ravel() 53 | else: 54 | w = np.linalg.pinv(A).dot(b).ravel() 55 | self.approximator.set_weights(w) 56 | 57 | norm = np.linalg.norm(w - old_w) 58 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from .abstract_dqn import AbstractDQN 2 | from .dqn import DQN 3 | from .double_dqn import DoubleDQN 4 | from .averaged_dqn import AveragedDQN 5 | from .maxmin_dqn import MaxminDQN 6 | from .dueling_dqn import DuelingDQN 7 | from 
.categorical_dqn import CategoricalDQN 8 | from .noisy_dqn import NoisyDQN 9 | from .quantile_dqn import QuantileDQN 10 | from .rainbow import Rainbow 11 | 12 | 13 | __all__ = ['AbstractDQN', 'DQN', 'DoubleDQN', 'AveragedDQN', 'MaxminDQN', 14 | 'DuelingDQN', 'CategoricalDQN', 'NoisyDQN', 'QuantileDQN', 'Rainbow'] 15 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/dqn/averaged_dqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value.dqn import AbstractDQN 4 | from mushroom_rl.approximators.regressor import Regressor 5 | 6 | 7 | class AveragedDQN(AbstractDQN): 8 | """ 9 | Averaged-DQN algorithm. 10 | "Averaged-DQN: Variance Reduction and Stabilization for Deep Reinforcement 11 | Learning". Anschel O. et al.. 2017. 12 | 13 | """ 14 | def __init__(self, mdp_info, policy, approximator, n_approximators, 15 | **params): 16 | """ 17 | Constructor. 18 | 19 | Args: 20 | n_approximators (int): the number of target approximators to store. 21 | 22 | """ 23 | assert n_approximators > 1 24 | 25 | self._n_approximators = n_approximators 26 | 27 | super().__init__(mdp_info, policy, approximator, **params) 28 | 29 | self._n_fitted_target_models = 1 30 | 31 | self._add_save_attr(_n_fitted_target_models='primitive') 32 | 33 | def _initialize_regressors(self, approximator, apprx_params_train, 34 | apprx_params_target): 35 | self.approximator = Regressor(approximator, **apprx_params_train) 36 | self.target_approximator = Regressor(approximator, 37 | n_models=self._n_approximators, 38 | **apprx_params_target) 39 | for i in range(len(self.target_approximator)): 40 | self.target_approximator[i].set_weights( 41 | self.approximator.get_weights() 42 | ) 43 | 44 | def _update_target(self): 45 | idx = self._n_updates // self._target_update_frequency\ 46 | % self._n_approximators 47 | self.target_approximator[idx].set_weights( 48 | self.approximator.get_weights()) 49 | 50 | if self._n_fitted_target_models < self._n_approximators: 51 | self._n_fitted_target_models += 1 52 | 53 | def _next_q(self, next_state, absorbing): 54 | q = list() 55 | for idx in range(self._n_fitted_target_models): 56 | q_target_idx = self.target_approximator.predict(next_state, idx=idx, **self._predict_params) 57 | q.append(q_target_idx) 58 | q = np.mean(q, axis=0) 59 | if np.any(absorbing): 60 | q *= 1 - absorbing.reshape(-1, 1) 61 | 62 | return np.max(q, axis=1) 63 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/dqn/double_dqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value.dqn import DQN 4 | 5 | 6 | class DoubleDQN(DQN): 7 | """ 8 | Double DQN algorithm. 9 | "Deep Reinforcement Learning with Double Q-Learning". 10 | Hasselt H. V. et al.. 2016. 
11 | 12 | """ 13 | def _next_q(self, next_state, absorbing): 14 | q = self.approximator.predict(next_state, **self._predict_params) 15 | max_a = np.argmax(q, axis=1) 16 | 17 | double_q = self.target_approximator.predict(next_state, max_a, **self._predict_params) 18 | if np.any(absorbing): 19 | double_q *= 1 - absorbing 20 | 21 | return double_q 22 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/dqn/dqn.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.algorithms.value.dqn import AbstractDQN 2 | 3 | 4 | class DQN(AbstractDQN): 5 | """ 6 | Deep Q-Network algorithm. 7 | "Human-Level Control Through Deep Reinforcement Learning". 8 | Mnih V. et al.. 2015. 9 | 10 | """ 11 | def _next_q(self, next_state, absorbing): 12 | q = self.target_approximator.predict(next_state, **self._predict_params) 13 | if absorbing.any(): 14 | q *= 1 - absorbing.reshape(-1, 1) 15 | 16 | return q.max(1) 17 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/dqn/maxmin_dqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value.dqn import DQN 4 | from mushroom_rl.approximators.regressor import Regressor 5 | 6 | 7 | class MaxminDQN(DQN): 8 | """ 9 | MaxminDQN algorithm. 10 | "Maxmin Q-learning: Controlling the Estimation Bias of Q-learning". 11 | Lan Q. et al.. 2020. 12 | 13 | """ 14 | def __init__(self, mdp_info, policy, approximator, n_approximators, **params): 15 | """ 16 | Constructor. 17 | 18 | Args: 19 | n_approximators (int): the number of approximators in the ensemble. 20 | 21 | """ 22 | assert n_approximators > 1 23 | 24 | self._n_approximators = n_approximators 25 | 26 | super().__init__(mdp_info, policy, approximator, **params) 27 | 28 | def fit(self, dataset): 29 | self._fit_params['idx'] = np.random.randint(self._n_approximators) 30 | 31 | super().fit(dataset) 32 | 33 | def _initialize_regressors(self, approximator, apprx_params_train, apprx_params_target): 34 | self.approximator = Regressor(approximator, 35 | n_models=self._n_approximators, 36 | prediction='min', **apprx_params_train) 37 | self.target_approximator = Regressor(approximator, 38 | n_models=self._n_approximators, 39 | prediction='min', 40 | **apprx_params_target) 41 | self._update_target() 42 | 43 | def _update_target(self): 44 | for i in range(len(self.target_approximator)): 45 | self.target_approximator[i].set_weights(self.approximator[i].get_weights()) 46 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/__init__.py: -------------------------------------------------------------------------------- 1 | from .td import TD 2 | from .sarsa import SARSA 3 | from .sarsa_lambda import SARSALambda 4 | from .expected_sarsa import ExpectedSARSA 5 | from .q_learning import QLearning 6 | from .q_lambda import QLambda 7 | from .double_q_learning import DoubleQLearning 8 | from .speedy_q_learning import SpeedyQLearning 9 | from .r_learning import RLearning 10 | from .weighted_q_learning import WeightedQLearning 11 | from .maxmin_q_learning import MaxminQLearning 12 | from .rq_learning import RQLearning 13 | from .sarsa_lambda_continuous import SARSALambdaContinuous 14 | from .true_online_sarsa_lambda import TrueOnlineSARSALambda 15 | 16 | __all__ = ['SARSA', 'SARSALambda', 'ExpectedSARSA', 'QLearning', 17 | 'QLambda', 
'DoubleQLearning', 'SpeedyQLearning', 18 | 'RLearning', 'WeightedQLearning', 'MaxminQLearning', 19 | 'RQLearning', 'SARSALambdaContinuous', 'TrueOnlineSARSALambda'] 20 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/double_q_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | 4 | from mushroom_rl.algorithms.value.td import TD 5 | from mushroom_rl.approximators.ensemble_table import EnsembleTable 6 | 7 | 8 | class DoubleQLearning(TD): 9 | """ 10 | Double Q-Learning algorithm. 11 | "Double Q-Learning". Hasselt H. V.. 2010. 12 | 13 | """ 14 | def __init__(self, mdp_info, policy, learning_rate): 15 | Q = EnsembleTable(2, mdp_info.size) 16 | 17 | super().__init__(mdp_info, policy, Q, learning_rate) 18 | 19 | self._alpha_double = [deepcopy(self._alpha), deepcopy(self._alpha)] 20 | 21 | self._add_save_attr( 22 | _alpha_double='primitive' 23 | ) 24 | 25 | assert len(self.Q) == 2, 'The regressor ensemble must' \ 26 | ' have exactly 2 models.' 27 | 28 | def _update(self, state, action, reward, next_state, absorbing): 29 | approximator_idx = 0 if np.random.uniform() < .5 else 1 30 | 31 | q_current = self.Q[approximator_idx][state, action] 32 | 33 | if not absorbing: 34 | q_ss = self.Q[approximator_idx][next_state, :] 35 | max_q = np.max(q_ss) 36 | a_n = np.array( 37 | [np.random.choice(np.argwhere(q_ss == max_q).ravel())]) 38 | q_next = self.Q[1 - approximator_idx][next_state, a_n] 39 | else: 40 | q_next = 0. 41 | 42 | q = q_current + self._alpha_double[approximator_idx](state, action) * ( 43 | reward + self.mdp_info.gamma * q_next - q_current) 44 | 45 | self.Q[approximator_idx][state, action] = q 46 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/expected_sarsa.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.algorithms.value.td import TD 2 | from mushroom_rl.approximators.table import Table 3 | 4 | 5 | class ExpectedSARSA(TD): 6 | """ 7 | Expected SARSA algorithm. 8 | "A theoretical and empirical analysis of Expected Sarsa". Seijen H. V. et 9 | al.. 2009. 10 | 11 | """ 12 | def __init__(self, mdp_info, policy, learning_rate): 13 | Q = Table(mdp_info.size) 14 | 15 | super().__init__(mdp_info, policy, Q, learning_rate) 16 | 17 | def _update(self, state, action, reward, next_state, absorbing): 18 | q_current = self.Q[state, action] 19 | 20 | if not absorbing: 21 | q_next = self.Q[next_state, :].dot(self.policy(next_state)) 22 | else: 23 | q_next = 0. 24 | 25 | self.Q[state, action] = q_current + self._alpha(state, action) * ( 26 | reward + self.mdp_info.gamma * q_next - q_current) 27 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/maxmin_q_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | 4 | from mushroom_rl.algorithms.value.td import TD 5 | from mushroom_rl.approximators.ensemble_table import EnsembleTable 6 | 7 | 8 | class MaxminQLearning(TD): 9 | """ 10 | Maxmin Q-Learning algorithm without replay memory. 11 | "Maxmin Q-learning: Controlling the Estimation Bias of Q-learning". 12 | Lan Q. et al. 2019. 13 | 14 | """ 15 | def __init__(self, mdp_info, policy, learning_rate, n_tables): 16 | """ 17 | Constructor. 
18 | 19 | Args: 20 | n_tables (int): number of tables in the ensemble. 21 | 22 | """ 23 | self._n_tables = n_tables 24 | Q = EnsembleTable(n_tables, mdp_info.size, prediction='min') 25 | 26 | super().__init__(mdp_info, policy, Q, learning_rate) 27 | 28 | self._alpha_mm = [deepcopy(self._alpha) for _ in range(n_tables)] 29 | 30 | self._add_save_attr(_n_tables='primitive', _alpha_mm='primitive') 31 | 32 | def _update(self, state, action, reward, next_state, absorbing): 33 | approximator_idx = np.random.choice(self._n_tables) 34 | 35 | q_current = self.Q[approximator_idx][state, action] 36 | 37 | if not absorbing: 38 | q_ss = self.Q.predict(next_state) 39 | q_next = np.max(q_ss) 40 | else: 41 | q_next = 0. 42 | 43 | q = q_current + self._alpha_mm[approximator_idx](state, action) * ( 44 | reward + self.mdp_info.gamma * q_next - q_current) 45 | 46 | self.Q[approximator_idx][state, action] = q 47 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/q_lambda.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value.td import TD 4 | from mushroom_rl.rl_utils.eligibility_trace import EligibilityTrace 5 | from mushroom_rl.approximators.table import Table 6 | from mushroom_rl.rl_utils.parameters import to_parameter 7 | 8 | 9 | class QLambda(TD): 10 | """ 11 | Q(Lambda) algorithm. 12 | "Learning from Delayed Rewards". Watkins C.J.C.H.. 1989. 13 | 14 | """ 15 | def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, 16 | trace='replacing'): 17 | """ 18 | Constructor. 19 | 20 | Args: 21 | lambda_coeff ([float, Parameter]): eligibility trace coefficient; 22 | trace (str, 'replacing'): type of eligibility trace to use. 23 | 24 | """ 25 | Q = Table(mdp_info.size) 26 | self._lambda = to_parameter(lambda_coeff) 27 | 28 | self.e = EligibilityTrace(Q.shape, trace) 29 | self._add_save_attr( 30 | _lambda='mushroom', 31 | e='mushroom' 32 | ) 33 | 34 | super().__init__(mdp_info, policy, Q, learning_rate) 35 | 36 | def _update(self, state, action, reward, next_state, absorbing): 37 | q_current = self.Q[state, action] 38 | 39 | q_next = np.max(self.Q[next_state, :]) if not absorbing else 0. 40 | 41 | delta = reward + self.mdp_info.gamma*q_next - q_current 42 | self.e.update(state, action) 43 | 44 | self.Q.table += self._alpha(state, action) * delta * self.e.table 45 | self.e.table *= self.mdp_info.gamma * self._lambda() 46 | 47 | def episode_start(self, initial_state, episode_info): 48 | self.e.reset() 49 | 50 | return super().episode_start(initial_state, episode_info) 51 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/q_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value.td import TD 4 | from mushroom_rl.approximators.table import Table 5 | 6 | 7 | class QLearning(TD): 8 | """ 9 | Q-Learning algorithm. 10 | "Learning from Delayed Rewards". Watkins C.J.C.H.. 1989. 11 | 12 | """ 13 | def __init__(self, mdp_info, policy, learning_rate): 14 | Q = Table(mdp_info.size) 15 | 16 | super().__init__(mdp_info, policy, Q, learning_rate) 17 | 18 | def _update(self, state, action, reward, next_state, absorbing): 19 | q_current = self.Q[state, action] 20 | 21 | q_next = np.max(self.Q[next_state, :]) if not absorbing else 0. 
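        # Tabular Q-learning update:
        # Q(s, a) <- Q(s, a) + alpha(s, a) * (r + gamma * max_a' Q(s', a') - Q(s, a))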
22 | 23 | self.Q[state, action] = q_current + self._alpha(state, action) * ( 24 | reward + self.mdp_info.gamma * q_next - q_current) 25 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/r_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value.td import TD 4 | from mushroom_rl.approximators.table import Table 5 | 6 | from mushroom_rl.rl_utils.parameters import to_parameter 7 | 8 | 9 | class RLearning(TD): 10 | """ 11 | R-Learning algorithm. 12 | "A Reinforcement Learning Method for Maximizing Undiscounted Rewards". 13 | Schwartz A.. 1993. 14 | 15 | """ 16 | def __init__(self, mdp_info, policy, learning_rate, beta): 17 | """ 18 | Constructor. 19 | 20 | Args: 21 | beta ([float, Parameter]): beta coefficient. 22 | 23 | """ 24 | Q = Table(mdp_info.size) 25 | self._rho = 0. 26 | self._beta = to_parameter(beta) 27 | 28 | self._add_save_attr(_rho='primitive', _beta='mushroom') 29 | 30 | super().__init__(mdp_info, policy, Q, learning_rate) 31 | 32 | def _update(self, state, action, reward, next_state, absorbing): 33 | q_current = self.Q[state, action] 34 | q_next = np.max(self.Q[next_state, :]) if not absorbing else 0. 35 | delta = reward - self._rho + q_next - q_current 36 | q_new = q_current + self._alpha(state, action) * delta 37 | 38 | self.Q[state, action] = q_new 39 | 40 | q_max = np.max(self.Q[state, :]) 41 | if q_new == q_max: 42 | delta = reward + q_next - q_max - self._rho 43 | self._rho += self._beta(state, action) * delta 44 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/sarsa.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.algorithms.value.td import TD 2 | from mushroom_rl.approximators.table import Table 3 | 4 | 5 | class SARSA(TD): 6 | """ 7 | SARSA algorithm. 8 | 9 | """ 10 | def __init__(self, mdp_info, policy, learning_rate): 11 | Q = Table(mdp_info.size) 12 | 13 | super().__init__(mdp_info, policy, Q, learning_rate) 14 | 15 | def _update(self, state, action, reward, next_state, absorbing): 16 | q_current = self.Q[state, action] 17 | 18 | self.next_action, _ = self.draw_action(next_state) 19 | q_next = self.Q[next_state, self.next_action] if not absorbing else 0. 20 | 21 | self.Q[state, action] = q_current + self._alpha(state, action) * ( 22 | reward + self.mdp_info.gamma * q_next - q_current) 23 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/sarsa_lambda.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.algorithms.value.td import TD 2 | from mushroom_rl.rl_utils.eligibility_trace import EligibilityTrace 3 | from mushroom_rl.approximators.table import Table 4 | from mushroom_rl.rl_utils.parameters import to_parameter 5 | 6 | 7 | class SARSALambda(TD): 8 | """ 9 | The SARSA(lambda) algorithm for finite MDPs. 10 | 11 | """ 12 | def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, trace='replacing'): 13 | """ 14 | Constructor. 15 | 16 | Args: 17 | lambda_coeff ([float, Parameter]): eligibility trace coefficient; 18 | trace (str, 'replacing'): type of eligibility trace to use. 
19 | 20 | """ 21 | Q = Table(mdp_info.size) 22 | self._lambda = to_parameter(lambda_coeff) 23 | 24 | self.e = EligibilityTrace(Q.shape, trace) 25 | self._add_save_attr( 26 | _lambda='mushroom', 27 | e='mushroom' 28 | ) 29 | 30 | super().__init__(mdp_info, policy, Q, learning_rate) 31 | 32 | def _update(self, state, action, reward, next_state, absorbing): 33 | q_current = self.Q[state, action] 34 | 35 | self.next_action, _ = self.draw_action(next_state) 36 | q_next = self.Q[next_state, self.next_action] if not absorbing else 0. 37 | 38 | delta = reward + self.mdp_info.gamma * q_next - q_current 39 | self.e.update(state, action) 40 | 41 | self.Q.table += self._alpha(state, action) * delta * self.e.table 42 | self.e.table *= self.mdp_info.gamma * self._lambda() 43 | 44 | def episode_start(self, initial_state, episode_info): 45 | self.e.reset() 46 | 47 | return super().episode_start(initial_state, episode_info) 48 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.algorithms.value.td import TD 4 | from mushroom_rl.approximators import Regressor 5 | from mushroom_rl.rl_utils.parameters import to_parameter 6 | 7 | 8 | class SARSALambdaContinuous(TD): 9 | """ 10 | Continuous version of SARSA(lambda) algorithm. 11 | 12 | """ 13 | def __init__(self, mdp_info, policy, approximator, learning_rate, lambda_coeff, approximator_params=None): 14 | """ 15 | Constructor. 16 | 17 | Args: 18 | lambda_coeff ([float, Parameter]): eligibility trace coefficient. 19 | 20 | """ 21 | approximator_params = dict() if approximator_params is None else approximator_params 22 | 23 | Q = Regressor(approximator, **approximator_params) 24 | self.e = np.zeros(Q.weights_size) 25 | self._lambda = to_parameter(lambda_coeff) 26 | 27 | self._add_save_attr( 28 | _lambda='primitive', 29 | e='numpy' 30 | ) 31 | 32 | super().__init__(mdp_info, policy, Q, learning_rate) 33 | 34 | def _update(self, state, action, reward, next_state, absorbing): 35 | q_current = self.Q.predict(state, action) 36 | 37 | alpha = self._alpha(state, action) 38 | 39 | self.e = self.mdp_info.gamma * self._lambda() * self.e + self.Q.diff(state, action) 40 | 41 | self.next_action, _ = self.draw_action(next_state) 42 | q_next = self.Q.predict(next_state, self.next_action) if not absorbing else 0. 43 | 44 | delta = reward + self.mdp_info.gamma * q_next - q_current 45 | 46 | theta = self.Q.get_weights() 47 | theta += alpha * delta * self.e 48 | self.Q.set_weights(theta) 49 | 50 | def episode_start(self, initial_state, episode_info): 51 | self.e = np.zeros(self.Q.weights_size) 52 | 53 | return super().episode_start(initial_state, episode_info) 54 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/speedy_q_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | 4 | from mushroom_rl.algorithms.value.td import TD 5 | from mushroom_rl.approximators.table import Table 6 | 7 | 8 | class SpeedyQLearning(TD): 9 | """ 10 | Speedy Q-Learning algorithm. 11 | "Speedy Q-Learning". Ghavamzadeh et. al.. 2011. 
12 | 13 | """ 14 | def __init__(self, mdp_info, policy, learning_rate): 15 | Q = Table(mdp_info.size) 16 | self.old_q = deepcopy(Q) 17 | 18 | self._add_save_attr(old_q='mushroom') 19 | 20 | super().__init__(mdp_info, policy, Q, learning_rate) 21 | 22 | def _update(self, state, action, reward, next_state, absorbing): 23 | old_q = deepcopy(self.Q) 24 | 25 | max_q_cur = np.max(self.Q[next_state, :]) if not absorbing else 0. 26 | max_q_old = np.max(self.old_q[next_state, :]) if not absorbing else 0. 27 | 28 | target_cur = reward + self.mdp_info.gamma * max_q_cur 29 | target_old = reward + self.mdp_info.gamma * max_q_old 30 | 31 | alpha = self._alpha(state, action) 32 | q_cur = self.Q[state, action] 33 | self.Q[state, action] = q_cur + alpha * (target_old - q_cur) + ( 34 | 1. - alpha) * (target_cur - target_old) 35 | 36 | self.old_q = old_q 37 | -------------------------------------------------------------------------------- /mushroom_rl/algorithms/value/td/td.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.core import Agent 4 | 5 | 6 | class TD(Agent): 7 | """ 8 | Implements functions to run TD algorithms. 9 | 10 | """ 11 | def __init__(self, mdp_info, policy, approximator, learning_rate): 12 | """ 13 | Constructor. 14 | 15 | Args: 16 | approximator: the approximator to use to fit the Q-function; 17 | learning_rate (Parameter): the learning rate. 18 | 19 | """ 20 | self._alpha = learning_rate 21 | 22 | policy.set_q(approximator) 23 | self.Q = approximator 24 | 25 | self._add_save_attr(_alpha='mushroom', Q='mushroom') 26 | 27 | super().__init__(mdp_info, policy) 28 | 29 | def fit(self, dataset): 30 | assert len(dataset) == 1 31 | 32 | state, action, reward, next_state, absorbing, _ = dataset.item() 33 | self._update(state, action, reward, next_state, absorbing) 34 | 35 | def _update(self, state, action, reward, next_state, absorbing): 36 | """ 37 | Update the Q-table. 38 | 39 | Args: 40 | state (np.ndarray): state; 41 | action (np.ndarray): action; 42 | reward (np.ndarray): reward; 43 | next_state (np.ndarray): next state; 44 | absorbing (np.ndarray): absorbing flag. 45 | 46 | """ 47 | pass 48 | 49 | def _post_load(self): 50 | self.policy.set_q(self.Q) -------------------------------------------------------------------------------- /mushroom_rl/approximators/__init__.py: -------------------------------------------------------------------------------- 1 | from .regressor import Regressor 2 | 3 | -------------------------------------------------------------------------------- /mushroom_rl/approximators/_implementations/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mushroom_rl/approximators/_implementations/generic_regressor.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.core.serialization import Serializable 2 | 3 | 4 | class GenericRegressor(Serializable): 5 | """ 6 | This class is used to create a regressor that approximates a generic 7 | function. An arbitrary number of inputs and outputs is supported. 8 | 9 | """ 10 | def __init__(self, approximator, n_inputs, **params): 11 | """ 12 | Constructor. 
13 | 14 | Args: 15 | approximator (class): the model class to approximate the 16 | a generic function; 17 | n_inputs (int): number of inputs of the regressor; 18 | **params: parameters dictionary to the regressor; 19 | 20 | """ 21 | self._n_inputs = n_inputs 22 | self.model = approximator(**params) 23 | 24 | self._add_save_attr( 25 | _n_inputs='primitive', 26 | model=self._get_serialization_method(approximator) 27 | ) 28 | 29 | def fit(self, *z, **fit_params): 30 | """ 31 | Fit the model. 32 | 33 | Args: 34 | *z: list of inputs and targets; 35 | **fit_params: other parameters used by the fit method of the 36 | regressor. 37 | 38 | """ 39 | self.model.fit(*z, **fit_params) 40 | 41 | def predict(self, *x, **predict_params): 42 | """ 43 | Predict. 44 | 45 | Args: 46 | x (list): list of inputs; 47 | **predict_params: other parameters used by the predict method 48 | the regressor. 49 | 50 | Returns: 51 | The predictions of the model. 52 | 53 | """ 54 | return self.model.predict(*x, **predict_params) 55 | 56 | def reset(self): 57 | """ 58 | Reset the model parameters. 59 | 60 | """ 61 | try: 62 | self.model.reset() 63 | except AttributeError: 64 | raise NotImplementedError('Attempt to reset weights of a' 65 | ' non-parametric regressor.') 66 | 67 | @property 68 | def weights_size(self): 69 | return self.model.weights_size 70 | 71 | def get_weights(self): 72 | return self.model.get_weights() 73 | 74 | def set_weights(self, w): 75 | self.model.set_weights(w) 76 | 77 | def diff(self, *x): 78 | return self.model.diff(*x) 79 | 80 | def __len__(self): 81 | return len(self.model) 82 | -------------------------------------------------------------------------------- /mushroom_rl/approximators/ensemble_table.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.approximators.table import Table 2 | from mushroom_rl.approximators.ensemble import Ensemble 3 | 4 | 5 | class EnsembleTable(Ensemble): 6 | """ 7 | This class implements functions to manage table ensembles. 8 | 9 | """ 10 | def __init__(self, n_models, shape, **params): 11 | """ 12 | Constructor. 13 | 14 | Args: 15 | n_models (int): number of models in the ensemble; 16 | shape (np.ndarray): shape of each table in the ensemble. 17 | **params: parameters dictionary to create each regressor. 
18 | 19 | """ 20 | params['shape'] = shape 21 | super().__init__(Table, n_models, **params) 22 | 23 | @property 24 | def n_actions(self): 25 | return self._model[0].shape[-1] -------------------------------------------------------------------------------- /mushroom_rl/approximators/parametric/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear import LinearApproximator 2 | from .torch_approximator import TorchApproximator, NumpyTorchApproximator 3 | from .cmac import CMAC 4 | 5 | -------------------------------------------------------------------------------- /mushroom_rl/approximators/parametric/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear_network import LinearNetwork -------------------------------------------------------------------------------- /mushroom_rl/approximators/parametric/networks/linear_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class LinearNetwork(nn.Module): 5 | def __init__(self, input_shape, output_shape, use_bias=False, gain=None, **kwargs): 6 | super().__init__() 7 | 8 | n_input = input_shape[-1] 9 | n_output = output_shape[0] 10 | 11 | self._f = nn.Linear(n_input, n_output, bias=use_bias) 12 | 13 | if gain is None: 14 | gain = nn.init.calculate_gain('linear') 15 | 16 | nn.init.xavier_uniform_(self._f.weight, gain=gain) 17 | 18 | def forward(self, state, **kwargs): 19 | return self._f(state) 20 | -------------------------------------------------------------------------------- /mushroom_rl/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .array_backend import ArrayBackend 2 | from .core import Core 3 | from .dataset import DatasetInfo, Dataset, VectorizedDataset 4 | from .environment import Environment, MDPInfo 5 | from .agent import Agent, AgentInfo 6 | from .serialization import Serializable 7 | from .logger import Logger 8 | 9 | from .extra_info import ExtraInfo 10 | 11 | from .vectorized_core import VectorCore 12 | from .vectorized_env import VectorizedEnvironment 13 | from .multiprocess_environment import MultiprocessEnvironment 14 | 15 | import mushroom_rl.environments 16 | 17 | __all__ = ['ArrayBackend', 'Core', 'DatasetInfo', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo', 18 | 'Serializable', 'Logger', 'ExtraInfo', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment'] 19 | -------------------------------------------------------------------------------- /mushroom_rl/core/_impl/__init__.py: -------------------------------------------------------------------------------- 1 | from .numpy_dataset import NumpyDataset 2 | from .torch_dataset import TorchDataset 3 | from .list_dataset import ListDataset 4 | from .core_logic import CoreLogic 5 | from .vectorized_core_logic import VectorizedCoreLogic 6 | -------------------------------------------------------------------------------- /mushroom_rl/core/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .console_logger import ConsoleLogger 2 | from .data_logger import DataLogger 3 | from .logger import Logger 4 | 5 | 6 | __all__ = ['Logger', 'ConsoleLogger', 'DataLogger'] 7 | -------------------------------------------------------------------------------- /mushroom_rl/core/logger/logger.py: 
-------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pathlib import Path 3 | 4 | from .console_logger import ConsoleLogger 5 | from .data_logger import DataLogger 6 | 7 | 8 | class Logger(DataLogger, ConsoleLogger): 9 | """ 10 | This class implements the logging functionality. It can be used to 11 | automatically create a log directory, save numpy data arrays and the current agent. 12 | 13 | """ 14 | def __init__(self, log_name='', results_dir='./logs', log_console=False, 15 | use_timestamp=False, append=False, seed=None, **kwargs): 16 | """ 17 | Constructor. 18 | 19 | Args: 20 | log_name (string, ''): name of the current experiment directory; if not 21 | specified, the current timestamp is used; 22 | results_dir (string, './logs'): name of the base logging directory. 23 | If set to None, no directory is created; 24 | log_console (bool, False): whether or not to log the console output; 25 | use_timestamp (bool, False): If true, adds the current timestamp to 26 | the folder name; 27 | append (bool, False): If true, the logger will append the new 28 | data logged to the one already existing in the directory; 29 | seed (int, None): seed for the current run. It can be optionally 30 | specified to add a seed suffix for each data file logged; 31 | **kwargs: other parameters for the ConsoleLogger class. 32 | 33 | """ 34 | 35 | if log_console: 36 | assert results_dir is not None 37 | 38 | timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 39 | 40 | if not log_name: 41 | log_name = timestamp 42 | elif use_timestamp: 43 | log_name += '_' + timestamp 44 | 45 | if results_dir: 46 | results_dir = Path(results_dir) / log_name 47 | results_dir.mkdir(parents=True, exist_ok=True) 48 | 49 | suffix = '' if seed is None else '-' + str(seed) 50 | 51 | DataLogger.__init__(self, results_dir, suffix=suffix, append=append) 52 | ConsoleLogger.__init__(self, log_name, results_dir if log_console else None, 53 | suffix=suffix, **kwargs) 54 | -------------------------------------------------------------------------------- /mushroom_rl/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .distribution import Distribution 2 | from .gaussian import GaussianDistribution, GaussianDiagonalDistribution, GaussianCholeskyDistribution 3 | from .torch_distribution import AbstractGaussianTorchDistribution, DiagonalGaussianTorchDistribution 4 | from .torch_distribution import CholeskyGaussianTorchDistribution 5 | -------------------------------------------------------------------------------- /mushroom_rl/environments/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | Atari = None 3 | from .atari import Atari 4 | Atari.register() 5 | except ImportError: 6 | pass 7 | 8 | try: 9 | Gymnasium = None 10 | from .gymnasium_env import Gymnasium 11 | Gymnasium.register() 12 | except ImportError: 13 | pass 14 | 15 | try: 16 | DMControl = None 17 | from .dm_control_env import DMControl 18 | DMControl.register() 19 | except ImportError: 20 | pass 21 | 22 | try: 23 | MiniGrid = None 24 | from .minigrid_env import MiniGrid 25 | MiniGrid.register() 26 | except ImportError: 27 | pass 28 | 29 | try: 30 | iGibson = None 31 | from .igibson_env import iGibson 32 | iGibson.register() 33 | except ImportError: 34 | import logging 35 | logging.disable(logging.NOTSET) 36 | 37 | try: 38 | Habitat = None 39 | from .habitat_env import Habitat 40 | Habitat.register() 41 |
except ImportError: 42 | pass 43 | 44 | try: 45 | MuJoCo = None 46 | from .mujoco import MuJoCo, MultiMuJoCo 47 | from .mujoco_envs import * 48 | except ImportError: 49 | pass 50 | 51 | try: 52 | OmniIsaacGymEnv = None 53 | from .omni_isaac_gym_env import OmniIsaacGymEnv 54 | except ImportError: 55 | pass 56 | 57 | try: 58 | PyBullet = None 59 | from .pybullet import PyBullet 60 | from .pybullet_envs import * 61 | except ImportError: 62 | pass 63 | 64 | try: 65 | IsaacSim = None 66 | from .isaacsim_env import IsaacSim 67 | except ImportError: 68 | pass 69 | 70 | from .generators.simple_chain import generate_simple_chain 71 | 72 | from .car_on_hill import CarOnHill 73 | CarOnHill.register() 74 | 75 | from .cart_pole import CartPole 76 | CartPole.register() 77 | 78 | from .finite_mdp import FiniteMDP 79 | FiniteMDP.register() 80 | 81 | from .grid_world import GridWorld, GridWorldVanHasselt 82 | GridWorld.register() 83 | GridWorldVanHasselt.register() 84 | 85 | from .inverted_pendulum import InvertedPendulum 86 | InvertedPendulum.register() 87 | 88 | from .lqr import LQR 89 | LQR.register() 90 | 91 | from .puddle_world import PuddleWorld 92 | PuddleWorld.register() 93 | 94 | from .segway import Segway 95 | Segway.register() 96 | 97 | from .ship_steering import ShipSteering 98 | ShipSteering.register() 99 | -------------------------------------------------------------------------------- /mushroom_rl/environments/finite_mdp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.core import Environment, MDPInfo 4 | from mushroom_rl.rl_utils import spaces 5 | 6 | 7 | class FiniteMDP(Environment): 8 | """ 9 | Finite Markov Decision Process. 10 | 11 | """ 12 | def __init__(self, p, rew, mu=None, gamma=.9, horizon=np.inf, dt=1e-1): 13 | """ 14 | Constructor. 15 | 16 | Args: 17 | p (np.ndarray): transition probability matrix; 18 | rew (np.ndarray): reward matrix; 19 | mu (np.ndarray, None): initial state probability distribution; 20 | gamma (float, .9): discount factor; 21 | horizon (int, np.inf): the horizon; 22 | dt (float, 1e-1): the control timestep of the environment. 
23 | 24 | """ 25 | assert p.shape == rew.shape 26 | assert mu is None or p.shape[0] == mu.size 27 | 28 | # MDP parameters 29 | self.p = p 30 | self.r = rew 31 | self.mu = mu 32 | 33 | # MDP properties 34 | observation_space = spaces.Discrete(p.shape[0]) 35 | action_space = spaces.Discrete(p.shape[1]) 36 | horizon = horizon 37 | gamma = gamma 38 | mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) 39 | 40 | super().__init__(mdp_info) 41 | 42 | def reset(self, state=None): 43 | if state is None: 44 | if self.mu is not None: 45 | self._state = np.array( 46 | [np.random.choice(self.mu.size, p=self.mu)]) 47 | else: 48 | self._state = np.array([np.random.choice(self.p.shape[0])]) 49 | else: 50 | self._state = state 51 | 52 | return self._state, {} 53 | 54 | def step(self, action): 55 | p = self.p[self._state[0], action[0], :] 56 | next_state = np.array([np.random.choice(p.size, p=p)]) 57 | absorbing = not np.any(self.p[next_state[0]]) 58 | reward = self.r[self._state[0], action[0], next_state[0]] 59 | 60 | self._state = next_state 61 | 62 | return self._state, reward, absorbing, {} 63 | -------------------------------------------------------------------------------- /mushroom_rl/environments/generators/__init__.py: -------------------------------------------------------------------------------- 1 | from .simple_chain import generate_simple_chain 2 | from .grid_world import generate_grid_world 3 | from .taxi import generate_taxi 4 | -------------------------------------------------------------------------------- /mushroom_rl/environments/generators/simple_chain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.environments.finite_mdp import FiniteMDP 4 | 5 | 6 | def generate_simple_chain(state_n, goal_states, prob, rew, mu=None, gamma=.9, 7 | horizon=100): 8 | """ 9 | Simple chain generator. 10 | 11 | Args: 12 | state_n (int): number of states; 13 | goal_states (list): list of goal states; 14 | prob (float): probability of success of an action; 15 | rew (float): reward obtained in goal states; 16 | mu (np.ndarray): initial state probability distribution; 17 | gamma (float, .9): discount factor; 18 | horizon (int, 100): the horizon. 19 | 20 | Returns: 21 | A FiniteMDP object built with the provided parameters. 22 | 23 | """ 24 | p = compute_probabilities(state_n, prob) 25 | r = compute_reward(state_n, goal_states, rew) 26 | 27 | assert mu is None or len(mu) == state_n 28 | 29 | return FiniteMDP(p, r, mu, gamma, horizon) 30 | 31 | 32 | def compute_probabilities(state_n, prob): 33 | """ 34 | Compute the transition probability matrix. 35 | 36 | Args: 37 | state_n (int): number of states; 38 | prob (float): probability of success of an action. 39 | 40 | Returns: 41 | The transition probability matrix; 42 | 43 | """ 44 | p = np.zeros((state_n, 2, state_n)) 45 | 46 | for i in range(state_n): 47 | if i == 0: 48 | p[i, 1, i] = 1. 49 | else: 50 | p[i, 1, i] = 1. - prob 51 | p[i, 1, i - 1] = prob 52 | 53 | if i == state_n - 1: 54 | p[i, 0, i] = 1. 55 | else: 56 | p[i, 0, i] = 1. - prob 57 | p[i, 0, i + 1] = prob 58 | 59 | return p 60 | 61 | 62 | def compute_reward(state_n, goal_states, rew): 63 | """ 64 | Compute the reward matrix. 65 | 66 | Args: 67 | state_n (int): number of states; 68 | goal_states (list): list of goal states; 69 | rew (float): reward obtained in goal states. 70 | 71 | Returns: 72 | The reward matrix. 
73 | 74 | """ 75 | r = np.zeros((state_n, 2, state_n)) 76 | 77 | for g in goal_states: 78 | if g != 0: 79 | r[g - 1, 0, g] = rew 80 | 81 | if g != state_n - 1: 82 | r[g + 1, 1, g] = rew 83 | 84 | return r 85 | -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .cartpole import CartPole 2 | from .a1_walking import A1Walking 3 | from .honey_badger_walking import HoneyBadgerWalking 4 | from .silver_badger_walking import SilverBadgerWalking -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/a1/.thumbs/256x256/a1.usd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/.thumbs/256x256/a1.usd.png -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/a1/.thumbs/256x256/instanceable_meshes.usd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/.thumbs/256x256/instanceable_meshes.usd.png -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/a1/a1.usd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/a1.usd -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/a1/instanceable_meshes.usd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/instanceable_meshes.usd -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/cartpole/cartpole.usd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/cartpole/cartpole.usd -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/.thumbs/256x256/honey_badger.usd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/.thumbs/256x256/honey_badger.usd.png -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/.thumbs/256x256/instanceable_meshes.usd.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/.thumbs/256x256/instanceable_meshes.usd.png -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/honey_badger.usd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/honey_badger.usd -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/instanceable_meshes.usd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/instanceable_meshes.usd -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/.thumbs/256x256/instanceable_meshes.usd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/.thumbs/256x256/instanceable_meshes.usd.png -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/.thumbs/256x256/silver_badger.usd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/.thumbs/256x256/silver_badger.usd.png -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/instanceable_meshes.usd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/instanceable_meshes.usd -------------------------------------------------------------------------------- /mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/silver_badger.usd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/silver_badger.usd -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .ball_in_a_cup import BallInACup 2 | from .air_hockey import AirHockeyHit, AirHockeyDefend, AirHockeyPrepare, AirHockeyRepel 3 | from .ant import Ant 4 | from .half_cheetah import HalfCheetah 5 | from .hopper import Hopper 6 | from .walker_2d import Walker2D 7 | 8 | BallInACup.register() 9 | AirHockeyHit.register() 10 | AirHockeyDefend.register() 11 | AirHockeyPrepare.register() 12 | AirHockeyRepel.register() 13 | Ant.register() 14 | 
HalfCheetah.register() 15 | Hopper.register() 16 | Walker2D.register() 17 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/air_hockey/__init__.py: -------------------------------------------------------------------------------- 1 | from .hit import AirHockeyHit 2 | from .defend import AirHockeyDefend 3 | from .prepare import AirHockeyPrepare 4 | from .repel import AirHockeyRepel -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/air_hockey/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/air_hockey/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/air_hockey/double.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 12 | 14 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/air_hockey/single.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 13 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/air_hockey/table.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ant/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ant/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/base_link_convex.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/base_link_convex.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/base_link_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/base_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split1.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split10.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split10.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split11.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split11.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split12.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split12.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split13.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split13.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split14.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split14.stl 
-------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split15.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split15.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split16.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split16.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split17.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split17.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split18.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split18.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split2.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split3.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split4.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split5.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split6.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split6.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split7.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split7.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split8.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split8.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split9.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split9.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/elbow_link_convex.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/elbow_link_convex.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/elbow_link_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/elbow_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_convex_decomposition_p1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_convex_decomposition_p1.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_convex_decomposition_p2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_convex_decomposition_p2.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_fine.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p1.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p2.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p3.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_pitch_link_convex.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_pitch_link_convex.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_pitch_link_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_pitch_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_convex_decomposition_p1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_convex_decomposition_p1.stl 
-------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_convex_decomposition_p2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_convex_decomposition_p2.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_palm_link_convex.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_palm_link_convex.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_palm_link_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_palm_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p1.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p2.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p3.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_fine.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_convex_decomposition_p1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_convex_decomposition_p1.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_convex_decomposition_p2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_convex_decomposition_p2.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_fine.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/half_cheetah/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/half_cheetah/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/hopper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/hopper/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/cube.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/hand.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/hand.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/link0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link0.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/link1.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link1.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/link2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link2.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/link3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link3.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/link4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link4.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/link6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link6.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/link7.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link7.stl -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/assets/table.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/peg_insertion.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/pick.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/push.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 
8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/panda/reach.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /mushroom_rl/environments/mujoco_envs/data/walker_2d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/walker_2d/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/environments/pybullet_envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .air_hockey import * 2 | -------------------------------------------------------------------------------- /mushroom_rl/environments/pybullet_envs/air_hockey/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .hit import AirHockeyHitBullet 3 | from .defend import AirHockeyDefendBullet 4 | from .prepare import AirHockeyPrepareBullet 5 | from .repel import AirHockeyRepelBullet 6 | 7 | 8 | AirHockeyHitBullet.register() 9 | AirHockeyDefendBullet.register() 10 | AirHockeyPrepareBullet.register() 11 | AirHockeyRepelBullet.register() 12 | except ImportError: 13 | pass 14 | -------------------------------------------------------------------------------- /mushroom_rl/environments/pybullet_envs/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/pybullet_envs/data/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/features/__init__.py: -------------------------------------------------------------------------------- 1 | from .features import Features, get_action_features 2 | 3 | __all__ = ['Features', 'get_action_features'] 4 | -------------------------------------------------------------------------------- /mushroom_rl/features/_implementations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/features/_implementations/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/features/_implementations/basis_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .features_implementation import FeaturesImplementation 4 | 5 | 6 | class BasisFeatures(FeaturesImplementation): 7 | def __init__(self, basis): 8 | self._basis = basis 9 | 10 | def __call__(self, *args): 11 | x = self._concatenate(args) 12 | 13 | y = list() 14 | 15 | x = np.atleast_2d(x) 16 | for s in x: 17 | out = np.empty(self.size) 18 | 19 | for i, bf in enumerate(self._basis): 20 | out[i] = bf(s) 21 | 22 | y.append(out) 23 | 24 | if len(y) == 1: 
25 | y = y[0] 26 | else: 27 | y = np.array(y) 28 | 29 | return y 30 | 31 | @property 32 | def size(self): 33 | return len(self._basis) 34 | -------------------------------------------------------------------------------- /mushroom_rl/features/_implementations/features_implementation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class FeaturesImplementation(object): 5 | def __call__(self, *x): 6 | """ 7 | Evaluate the feature vector in the given raw input. If more than one 8 | element is passed, the raw input is concatenated before computing the 9 | features. 10 | 11 | Args: 12 | *x (list): the raw input. 13 | 14 | Returns: 15 | The features vector computed from the raw input. 16 | 17 | """ 18 | pass 19 | 20 | @staticmethod 21 | def _concatenate(args): 22 | if len(args) > 1: 23 | x = np.concatenate(args, axis=-1) 24 | else: 25 | x = args[0] 26 | 27 | return x 28 | 29 | @property 30 | def size(self): 31 | """ 32 | Returns: 33 | The number of elements in the features vector. 34 | 35 | """ 36 | raise NotImplementedError 37 | -------------------------------------------------------------------------------- /mushroom_rl/features/_implementations/functional_features.py: -------------------------------------------------------------------------------- 1 | from .features_implementation import FeaturesImplementation 2 | 3 | 4 | class FunctionalFeatures(FeaturesImplementation): 5 | def __init__(self, n_outputs, function): 6 | self._n_outputs = n_outputs 7 | self._function = function if function is not None else self._identity 8 | 9 | def __call__(self, *args): 10 | x = self._concatenate(args) 11 | 12 | return self._function(x) 13 | 14 | def _identity(self, x): 15 | return x 16 | 17 | @property 18 | def size(self): 19 | return self._n_outputs 20 | -------------------------------------------------------------------------------- /mushroom_rl/features/_implementations/tiles_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .features_implementation import FeaturesImplementation 4 | 5 | 6 | class TilesFeatures(FeaturesImplementation): 7 | def __init__(self, tiles): 8 | 9 | if isinstance(tiles, list): 10 | self._tiles = tiles 11 | else: 12 | self._tiles = [tiles] 13 | self._size = 0 14 | 15 | for tiling in self._tiles: 16 | self._size += tiling.size 17 | 18 | def __call__(self, *args): 19 | x = self._concatenate(args) 20 | 21 | y = list() 22 | 23 | x = np.atleast_2d(x) 24 | for s in x: 25 | out = np.zeros(self._size) 26 | 27 | offset = 0 28 | for tiling in self._tiles: 29 | index = tiling(s) 30 | 31 | if index is not None: 32 | out[index + offset] = 1. 
33 | 34 | offset += tiling.size 35 | 36 | y.append(out) 37 | 38 | if len(y) == 1: 39 | y = y[0] 40 | else: 41 | y = np.array(y) 42 | 43 | return y 44 | 45 | def compute_indexes(self, *args): 46 | x = self._concatenate(args) 47 | 48 | y = list() 49 | 50 | x = np.atleast_2d(x) 51 | for s in x: 52 | out = list() 53 | 54 | offset = 0 55 | for tiling in self._tiles: 56 | index = tiling(s) 57 | 58 | if index is not None: 59 | out.append(index + offset) 60 | 61 | offset += tiling.size 62 | 63 | y.append(out) 64 | 65 | if len(y) == 1: 66 | return y[0] 67 | else: 68 | return y 69 | 70 | @property 71 | def size(self): 72 | return self._size 73 | -------------------------------------------------------------------------------- /mushroom_rl/features/_implementations/torch_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .features_implementation import FeaturesImplementation 5 | from mushroom_rl.utils.torch import TorchUtils 6 | 7 | 8 | class TorchFeatures(FeaturesImplementation): 9 | def __init__(self, tensor_list): 10 | self._phi = tensor_list 11 | 12 | def __call__(self, *args): 13 | x = self._concatenate(args) 14 | 15 | x = TorchUtils.to_float_tensor(np.atleast_2d(x)) 16 | 17 | y_list = [self._phi[i].forward(x) for i in range(len(self._phi))] 18 | y = torch.cat(y_list, 1).squeeze() 19 | 20 | y = y.detach().cpu().numpy() 21 | 22 | if y.shape[0] == 1: 23 | return y[0] 24 | else: 25 | return y 26 | 27 | @property 28 | def size(self): 29 | return np.sum([phi.size for phi in self._phi]) 30 | -------------------------------------------------------------------------------- /mushroom_rl/features/basis/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_rbf import GaussianRBF 2 | from .polynomial import PolynomialBasis 3 | from .fourier import FourierBasis 4 | 5 | __all__ = ['GaussianRBF', 'PolynomialBasis', 'FourierBasis'] 6 | -------------------------------------------------------------------------------- /mushroom_rl/features/tensors/__init__.py: -------------------------------------------------------------------------------- 1 | from .basis_tensor import GenericBasisTensor, GaussianRBFTensor, VonMisesBFTensor 2 | from .constant_tensor import ConstantTensor 3 | from .random_fourier_tensor import RandomFourierBasis 4 | 5 | __all__ = ['GenericBasisTensor', 'GaussianRBFTensor', 'VonMisesBFTensor', 'ConstantTensor', 'RandomFourierBasis'] -------------------------------------------------------------------------------- /mushroom_rl/features/tensors/constant_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from mushroom_rl.utils.torch import TorchUtils 5 | 6 | 7 | class ConstantTensor(nn.Module): 8 | """ 9 | PyTorch module to implement a constant function (always one).
10 | 11 | """ 12 | 13 | def forward(self, x): 14 | return torch.ones(x.shape[0], 1).to(TorchUtils.get_device()) 15 | 16 | @property 17 | def size(self): 18 | return 1 19 | -------------------------------------------------------------------------------- /mushroom_rl/features/tiles/__init__.py: -------------------------------------------------------------------------------- 1 | from .tiles import Tiles 2 | from .voronoi import VoronoiTiles 3 | 4 | __all__ = ['Tiles', 'VoronoiTiles'] 5 | -------------------------------------------------------------------------------- /mushroom_rl/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from .policy import Policy, ParametricPolicy 2 | from .vector_policy import VectorPolicy 3 | from .noise_policy import OrnsteinUhlenbeckPolicy, ClippedGaussianPolicy 4 | from .td_policy import TDPolicy, Boltzmann, EpsGreedy, Mellowmax 5 | from .gaussian_policy import GaussianPolicy, DiagonalGaussianPolicy, \ 6 | StateStdGaussianPolicy, StateLogStdGaussianPolicy 7 | from .deterministic_policy import DeterministicPolicy 8 | from .torch_policy import TorchPolicy, GaussianTorchPolicy, BoltzmannTorchPolicy 9 | from .recurrent_torch_policy import RecurrentGaussianTorchPolicy 10 | from .promps import ProMP 11 | from .dmp import DMP 12 | 13 | 14 | __all_td__ = ['TDPolicy', 'Boltzmann', 'EpsGreedy', 'Mellowmax'] 15 | __all_parametric__ = ['ParametricPolicy', 'GaussianPolicy', 16 | 'DiagonalGaussianPolicy', 'StateStdGaussianPolicy', 17 | 'StateLogStdGaussianPolicy', 'ProMP'] 18 | __all_torch__ = ['TorchPolicy', 'GaussianTorchPolicy', 'BoltzmannTorchPolicy'] 19 | __all_noise__ = ['OrnsteinUhlenbeckPolicy', 'ClippedGaussianPolicy'] 20 | __all_mp__ = ['ProMP', 'DMP'] 21 | 22 | __all__ = ['Policy', 'DeterministicPolicy', ] \ 23 | + __all_td__ + __all_parametric__ + __all_torch__ + __all_noise__ + __all_mp__ 24 | -------------------------------------------------------------------------------- /mushroom_rl/policy/deterministic_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .policy import ParametricPolicy 3 | 4 | 5 | class DeterministicPolicy(ParametricPolicy): 6 | """ 7 | Simple parametric policy representing a deterministic policy. As 8 | deterministic policies are degenerate probability functions where all 9 | the probability mass is on the deterministic action,they are not 10 | differentiable, even if the mean value approximator is differentiable. 11 | 12 | """ 13 | def __init__(self, mu, policy_state_shape=None): 14 | """ 15 | Constructor. 16 | 17 | Args: 18 | mu (Regressor): the regressor representing the action to select 19 | in each state. 20 | 21 | """ 22 | super().__init__(policy_state_shape) 23 | 24 | self._approximator = mu 25 | self._predict_params = dict() 26 | 27 | self._add_save_attr(_approximator='mushroom', 28 | _predict_params='pickle') 29 | 30 | def get_regressor(self): 31 | """ 32 | Getter. 33 | 34 | Returns: 35 | The regressor that is used to map state to actions. 36 | 37 | """ 38 | return self._approximator 39 | 40 | def __call__(self, state, action, policy_state=None): 41 | policy_action = self._approximator.predict(state, **self._predict_params) 42 | 43 | return 1. if np.array_equal(action, policy_action) else 0. 
44 | 45 | def draw_action(self, state, policy_state=None): 46 | return self._approximator.predict(state, **self._predict_params), None 47 | 48 | def set_weights(self, weights): 49 | self._approximator.set_weights(weights) 50 | 51 | def get_weights(self): 52 | return self._approximator.get_weights() 53 | 54 | @property 55 | def weights_size(self): 56 | return self._approximator.weights_size 57 | -------------------------------------------------------------------------------- /mushroom_rl/rl_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .eligibility_trace import EligibilityTrace, ReplacingTrace, AccumulatingTrace 2 | from .optimizers import Optimizer, AdamOptimizer, SGDOptimizer, AdaptiveOptimizer 3 | from .parameters import Parameter, DecayParameter, LinearParameter, to_parameter 4 | from .preprocessors import StandardizationPreprocessor, MinMaxPreprocessor 5 | from .replay_memory import ReplayMemory, PrioritizedReplayMemory 6 | from .running_stats import RunningStandardization, RunningAveragedWindow, RunningExpWeightedAverage 7 | from .spaces import Box, Discrete 8 | from .value_functions import compute_advantage, compute_advantage_montecarlo, compute_gae 9 | from .variance_parameters import VarianceDecreasingParameter, VarianceIncreasingParameter 10 | from .variance_parameters import WindowedVarianceParameter, WindowedVarianceIncreasingParameter -------------------------------------------------------------------------------- /mushroom_rl/rl_utils/eligibility_trace.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.approximators.table import Table 2 | 3 | 4 | def EligibilityTrace(shape, name='replacing'): 5 | """ 6 | Factory method to create an eligibility trace of the provided type. 7 | 8 | Args: 9 | shape (list): shape of the eligibility trace table; 10 | name (str, 'replacing'): type of the eligibility trace. 11 | 12 | Returns: 13 | The eligibility trace table of the provided shape and type. 14 | 15 | """ 16 | if name == 'replacing': 17 | return ReplacingTrace(shape) 18 | elif name == 'accumulating': 19 | return AccumulatingTrace(shape) 20 | else: 21 | raise ValueError('Unknown type of trace.') 22 | 23 | 24 | class ReplacingTrace(Table): 25 | """ 26 | Replacing trace. 27 | 28 | """ 29 | def reset(self): 30 | self.table[:] = 0. 31 | 32 | def update(self, state, action): 33 | self.table[state, action] = 1. 34 | 35 | 36 | class AccumulatingTrace(Table): 37 | """ 38 | Accumulating trace. 39 | 40 | """ 41 | def reset(self): 42 | self.table[:] = 0. 43 | 44 | def update(self, state, action): 45 | self.table[state, action] += 1. 
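A minimal usage sketch of the EligibilityTrace factory shown above (not part of the repository): it assumes a toy problem with 3 states and 2 actions, and the gamma and lambda decay constants are illustrative. Both trace types are Table subclasses, so the raw values live in the table attribute used by reset and update.

    from mushroom_rl.rl_utils import EligibilityTrace

    # Replacing trace over a 3-state, 2-action table
    trace = EligibilityTrace((3, 2), name='replacing')

    trace.update(0, 1)         # mark the visited state-action pair: e(0, 1) = 1
    trace.table *= 0.99 * 0.9  # decay every trace by gamma * lambda between steps
    trace.reset()              # zero all traces at the end of an episode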
46 | -------------------------------------------------------------------------------- /mushroom_rl/solvers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/solvers/__init__.py -------------------------------------------------------------------------------- /mushroom_rl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .angles import normalize_angle_positive, normalize_angle, shortest_angular_distance 2 | from .angles import quat_to_euler, euler_to_quat, euler_to_mat, mat_to_euler 3 | from .features import uniform_grid 4 | from .frames import LazyFrames, preprocess_frame 5 | from .numerical_gradient import numerical_diff_dist, numerical_diff_function, numerical_diff_policy 6 | from .minibatches import minibatch_number, minibatch_generator 7 | from .plot import plot_mean_conf, get_mean_and_confidence 8 | from .record import VideoRecorder 9 | from .torch import TorchUtils, CategoricalWrapper 10 | from .viewer import Viewer, CV2Viewer, ImageViewer 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /mushroom_rl/utils/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | __extras__ = [] 2 | 3 | try: 4 | from mushroom_rl.utils.callbacks.plot_dataset import PlotDataset 5 | __extras__.append('PlotDataset') 6 | except ImportError: 7 | pass 8 | 9 | from .callback import Callback, CallbackList 10 | from .collect_dataset import CollectDataset 11 | from .collect_max_q import CollectMaxQ 12 | from .collect_q import CollectQ 13 | from .collect_parameters import CollectParameters 14 | 15 | __all__ = ['Callback', 'CollectDataset', 'CollectQ', 'CollectMaxQ', 16 | 'CollectParameters'] + __extras__ 17 | -------------------------------------------------------------------------------- /mushroom_rl/utils/callbacks/callback.py: -------------------------------------------------------------------------------- 1 | class Callback(object): 2 | """ 3 | Interface for all basic callbacks. Implements a list in which it is possible 4 | to store data and methods to query and clean the content stored by the 5 | callback. 6 | 7 | """ 8 | def __call__(self, dataset): 9 | """ 10 | Add samples to the samples list. 11 | 12 | Args: 13 | dataset (Dataset): the samples to collect. 14 | 15 | """ 16 | raise NotImplementedError 17 | 18 | def get(self): 19 | """ 20 | Returns: 21 | The current collected data. 22 | 23 | """ 24 | raise NotImplementedError 25 | 26 | def clean(self): 27 | """ 28 | Delete the current stored data 29 | 30 | """ 31 | raise NotImplementedError 32 | 33 | 34 | class CallbackList(Callback): 35 | """ 36 | Simple interface for callbacks storing a single list for data collection 37 | 38 | """ 39 | def __init__(self): 40 | self._data_list = list() 41 | 42 | def get(self): 43 | """ 44 | Returns: 45 | The current collected data. 
46 | 47 | """ 48 | return self._data_list 49 | 50 | def clean(self): 51 | """ 52 | Delete the current stored data 53 | 54 | """ 55 | self._data_list = list() 56 | 57 | -------------------------------------------------------------------------------- /mushroom_rl/utils/callbacks/collect_dataset.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.utils.callbacks.callback import Callback 2 | from mushroom_rl.core.dataset import VectorizedDataset 3 | 4 | 5 | class CollectDataset(Callback): 6 | """ 7 | This callback can be used to collect samples during the learning of the 8 | agent. 9 | 10 | """ 11 | def __init__(self): 12 | """ 13 | Constructor. 14 | """ 15 | self._dataset = None 16 | 17 | def __call__(self, dataset): 18 | if isinstance(dataset, VectorizedDataset): 19 | dataset = dataset.flatten() 20 | 21 | if self._dataset is None: 22 | self._dataset = dataset.copy() 23 | else: 24 | self._dataset += dataset 25 | 26 | def clean(self): 27 | self._dataset.clear() 28 | 29 | def get(self): 30 | return self._dataset 31 | -------------------------------------------------------------------------------- /mushroom_rl/utils/callbacks/collect_max_q.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.utils.callbacks.callback import CallbackList 2 | import numpy as np 3 | 4 | 5 | class CollectMaxQ(CallbackList): 6 | """ 7 | This callback can be used to collect the maximum action value in a given 8 | state at each call. 9 | 10 | """ 11 | def __init__(self, approximator, state): 12 | """ 13 | Constructor. 14 | 15 | Args: 16 | approximator ([Table, EnsembleTable]): the approximator to use; 17 | state (np.ndarray): the state to consider. 18 | 19 | """ 20 | self._approximator = approximator 21 | self._state = state 22 | 23 | super().__init__() 24 | 25 | def __call__(self, dataset): 26 | q = self._approximator.predict(self._state) 27 | max_q = np.max(q) 28 | 29 | self._data_list.append(max_q) 30 | -------------------------------------------------------------------------------- /mushroom_rl/utils/callbacks/collect_parameters.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.utils.callbacks.callback import CallbackList 2 | import numpy as np 3 | 4 | 5 | class CollectParameters(CallbackList): 6 | """ 7 | This callback can be used to collect the values of a parameter 8 | (e.g. learning rate) during a run of the agent. 9 | 10 | """ 11 | def __init__(self, parameter, *idx): 12 | """ 13 | Constructor. 14 | 15 | Args: 16 | parameter (Parameter): the parameter whose values have to be 17 | collected; 18 | *idx (list): index of the parameter when the ``parameter`` is 19 | tabular. 
20 | 21 | """ 22 | self._parameter = parameter 23 | self._idx = idx 24 | 25 | super().__init__() 26 | 27 | def __call__(self, dataset): 28 | value = self._parameter.get_value(*self._idx) 29 | if isinstance(value, np.ndarray): 30 | value = np.array(value) 31 | self._data_list.append(value) 32 | -------------------------------------------------------------------------------- /mushroom_rl/utils/callbacks/collect_q.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | 4 | from mushroom_rl.utils.callbacks.callback import CallbackList 5 | from mushroom_rl.approximators.ensemble_table import EnsembleTable 6 | 7 | 8 | class CollectQ(CallbackList): 9 | """ 10 | This callback can be used to collect the action values in all states at the 11 | current time step. 12 | 13 | """ 14 | def __init__(self, approximator): 15 | """ 16 | Constructor. 17 | 18 | Args: 19 | approximator ([Table, EnsembleTable]): the approximator to use to 20 | predict the action values. 21 | 22 | """ 23 | self._approximator = approximator 24 | 25 | super().__init__() 26 | 27 | def __call__(self, dataset): 28 | if isinstance(self._approximator, EnsembleTable): 29 | qs = list() 30 | for m in self._approximator.model: 31 | qs.append(m.table) 32 | self._data_list.append(deepcopy(np.mean(qs, 0))) 33 | else: 34 | self._data_list.append(deepcopy(self._approximator.table)) 35 | -------------------------------------------------------------------------------- /mushroom_rl/utils/features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def uniform_grid(n_centers, low, high, eta=0.25, cyclic=False): 5 | """ 6 | This function is used to create the parameters of uniformly spaced radial 7 | basis functions with `eta` of overlap. It creates a uniformly spaced grid of 8 | ``n_centers[i]`` points in each dimension i. Also returns a vector 9 | containing the appropriate width of the radial basis functions. 10 | 11 | Args: 12 | n_centers (list): number of centers of each dimension; 13 | low (np.ndarray): lowest value for each dimension; 14 | high (np.ndarray): highest value for each dimension; 15 | eta (float, 0.25): overlap between two radial basis functions; 16 | cyclic (bool, False): whether the state space is a ring or not 17 | 18 | Returns: 19 | The uniformly spaced grid and the width vector. 20 | 21 | """ 22 | assert 0 < eta < 1.0 23 | 24 | n_features = len(low) 25 | w = np.zeros(n_features) 26 | c = list() 27 | tot_points = 1 28 | for i, n in enumerate(n_centers): 29 | start = low[i] 30 | end = high[i] 31 | # m = abs(start - end) / n 32 | if n == 1: 33 | w[i] = abs(end - start) / 2 34 | c_i = (start + end) / 2. 
35 | c.append(np.array([c_i])) 36 | else: 37 | if cyclic: 38 | end_new = end - abs(end-start) / n 39 | else: 40 | end_new = end 41 | w[i] = (1 + eta) * abs(end_new - start) / n 42 | c_i = np.linspace(start, end_new, n) 43 | c.append(c_i) 44 | tot_points *= n 45 | 46 | n_rows = 1 47 | n_cols = 0 48 | 49 | grid = np.zeros((tot_points, n_features)) 50 | 51 | for discrete_values in c: 52 | i1 = 0 53 | dim = len(discrete_values) 54 | 55 | for i in range(dim): 56 | for r in range(n_rows): 57 | idx_r = r + i * n_rows 58 | for c in range(n_cols): 59 | grid[idx_r, c] = grid[r, c] 60 | grid[idx_r, n_cols] = discrete_values[i1] 61 | 62 | i1 += 1 63 | 64 | n_cols += 1 65 | n_rows *= len(discrete_values) 66 | 67 | return grid, w 68 | -------------------------------------------------------------------------------- /mushroom_rl/utils/frames.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | cv2.ocl.setUseOpenCL(False) 5 | 6 | 7 | class LazyFrames(object): 8 | """ 9 | From OpenAI Baseline. 10 | https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py 11 | 12 | This class provides a solution to optimize the use of memory when 13 | concatenating different frames, e.g. Atari frames in DQN. The frames are 14 | individually stored in a list and, when numpy arrays containing them are 15 | created, the reference to each frame is used instead of a copy. 16 | 17 | """ 18 | def __init__(self, frames, history_length): 19 | self._frames = frames 20 | 21 | assert len(self._frames) == history_length 22 | 23 | def __array__(self, dtype=None): 24 | out = np.array(self._frames) 25 | if dtype is not None: 26 | out = out.astype(dtype) 27 | 28 | return out 29 | 30 | def copy(self): 31 | return self 32 | 33 | @property 34 | def shape(self): 35 | return (len(self._frames),) + self._frames[0].shape 36 | 37 | 38 | def preprocess_frame(obs, img_size): 39 | """ 40 | Convert a frame from rgb to grayscale and resize it. 41 | 42 | Args: 43 | obs (np.ndarray): array representing an rgb frame; 44 | img_size (tuple): target size for images. 45 | 46 | Returns: 47 | The transformed frame as 8 bit integer array. 
48 | 49 | """ 50 | image = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY) 51 | image = cv2.resize(image, img_size, interpolation=cv2.INTER_LINEAR) 52 | 53 | return np.array(image, dtype=np.uint8) 54 | -------------------------------------------------------------------------------- /mushroom_rl/utils/isaac_sim/__init__.py: -------------------------------------------------------------------------------- 1 | from .observation_helper import ObservationHelper, ObservationType 2 | from .collision_helper import CollisionHelper 3 | from .action_helper import ActionType -------------------------------------------------------------------------------- /mushroom_rl/utils/isaac_sim/action_helper.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class ActionType(Enum): 4 | EFFORT = "joint_efforts" 5 | POSITION = "joint_positions" 6 | VELOCITY = "joint_velocities" -------------------------------------------------------------------------------- /mushroom_rl/utils/isaac_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def convert_task_observation(observation): 5 | obs_t = observation 6 | for _ in range(5): 7 | if torch.is_tensor(obs_t): 8 | break 9 | obs_t = obs_t[list(obs_t.keys())[0]] 10 | return obs_t -------------------------------------------------------------------------------- /mushroom_rl/utils/minibatches.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def minibatch_number(size, batch_size): 5 | """ 6 | Function to retrieve the number of batches, given a batch size. 7 | 8 | Args: 9 | size (int): size of the dataset; 10 | batch_size (int): size of the batches. 11 | 12 | Returns: 13 | The number of minibatches in the dataset. 14 | 15 | """ 16 | return int(np.ceil(size / batch_size)) 17 | 18 | 19 | def minibatch_generator(batch_size, *dataset): 20 | """ 21 | Generator that creates a minibatch from the full dataset. 22 | 23 | Args: 24 | batch_size (int): the maximum size of each minibatch; 25 | dataset: the dataset to be split. 26 | 27 | Returns: 28 | The current minibatch. 29 | 30 | """ 31 | size = len(dataset[0]) 32 | num_batches = minibatch_number(size, batch_size) 33 | indexes = np.arange(0, size, 1) 34 | np.random.shuffle(indexes) 35 | batches = [(i * batch_size, min(size, (i + 1) * batch_size)) 36 | for i in range(0, num_batches)] 37 | 38 | for (batch_start, batch_end) in batches: 39 | batch = [] 40 | for i in range(len(dataset)): 41 | batch.append(dataset[i][indexes[batch_start:batch_end]]) 42 | yield batch 43 | -------------------------------------------------------------------------------- /mushroom_rl/utils/mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | from .viewer import MujocoViewer 2 | from .observation_helper import ObservationHelper, ObservationType 3 | from .kinematics import forward_kinematics 4 | -------------------------------------------------------------------------------- /mushroom_rl/utils/mujoco/kinematics.py: -------------------------------------------------------------------------------- 1 | import mujoco 2 | 3 | 4 | def forward_kinematics(mj_model, mj_data, q, body_name): 5 | """ 6 | Compute the forward kinematics of the robot.
7 | 8 | Args: 9 | mj_model (mujoco.MjModel): mujoco MjModel of the robot-only model 10 | mj_data (mujoco.MjData): mujoco MjData object generated from the model 11 | q (np.array): joint configuration for which the forward kinematics are computed 12 | body_name (str): name of the body for which the fk is computed 13 | 14 | Returns (np.array(3), np.array(3, 3)): 15 | Position and Orientation of the body with the name body_name 16 | """ 17 | 18 | mj_data.qpos[:len(q)] = q 19 | mujoco.mj_fwdPosition(mj_model, mj_data) 20 | return mj_data.body(body_name).xpos.copy(), mj_data.body(body_name).xmat.reshape(3, 3).copy() 21 | -------------------------------------------------------------------------------- /mushroom_rl/utils/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats as st 3 | 4 | 5 | def get_mean_and_confidence(data): 6 | """ 7 | Compute the mean and 95% confidence interval. 8 | 9 | Args: 10 | data (np.ndarray): Array of experiment data of shape (n_runs, n_epochs). 11 | 12 | Returns: 13 | The mean of the dataset at each epoch along with the confidence interval. 14 | 15 | """ 16 | mean = np.mean(data, axis=0) 17 | se = st.sem(data, axis=0) 18 | n = len(data) 19 | _, interval = st.t.interval(0.95, n-1, scale=se) 20 | return mean, interval 21 | 22 | 23 | def plot_mean_conf(data, ax, color='blue', line='-', facecolor=None, alpha=0.4, label=None): 24 | """ 25 | Method to plot mean and confidence interval for data on matplotlib axes. 26 | 27 | Args: 28 | data (np.ndarray): Array of experiment data of shape (n_runs, n_epochs); 29 | ax (plt.Axes): matplotlib axes where to create the curve; 30 | color (str, 'blue'): matplotlib color identifier for the mean curve; 31 | line (str, '-'): matplotlib line type to be used for the mean curve; 32 | facecolor (str, None): matplotlib color identifier for the confidence interval; 33 | alpha (float, 0.4): transparency of the confidence interval; 34 | label (str, None): legend label for the plotted curve.
35 | 36 | 37 | """ 38 | facecolor = color if facecolor is None else facecolor 39 | 40 | mean, conf = get_mean_and_confidence(np.array(data)) 41 | upper_bound = mean + conf 42 | lower_bound = mean - conf 43 | 44 | ax.plot(mean, color=color, linestyle=line, label=label) 45 | ax.fill_between(np.arange(np.size(mean)), upper_bound, lower_bound, facecolor=facecolor, alpha=alpha) 46 | 47 | 48 | -------------------------------------------------------------------------------- /mushroom_rl/utils/plots/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | 3 | try: 4 | from .plot_item_buffer import PlotItemBuffer 5 | __all__.append('PlotItemBuffer') 6 | 7 | from .databuffer import DataBuffer 8 | __all__.append('DataBuffer') 9 | 10 | from .window import Window 11 | __all__.append('Window') 12 | 13 | from .common_plots import Actions, LenOfEpisodeTraining, Observations,\ 14 | RewardPerEpisode, RewardPerStep 15 | 16 | __all__ += ['Actions', 'LenOfEpisodeTraining', 'Observations', 17 | 'RewardPerEpisode', 'RewardPerStep'] 18 | 19 | except ImportError: 20 | pass 21 | -------------------------------------------------------------------------------- /mushroom_rl/utils/pybullet/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .observation import PyBulletObservationType 3 | from .index_map import IndexMap 4 | from .viewer import PyBulletViewer 5 | from .joints_helper import JointsHelper 6 | except ImportError: 7 | pass 8 | -------------------------------------------------------------------------------- /mushroom_rl/utils/pybullet/joints_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .observation import PyBulletObservationType 4 | 5 | 6 | class JointsHelper(object): 7 | def __init__(self, client, indexer, observation_spec): 8 | self._joint_pos_indexes = list() 9 | self._joint_velocity_indexes = list() 10 | joint_limits_low = list() 11 | joint_limits_high = list() 12 | joint_velocity_limits = list() 13 | for joint_name, obs_type in observation_spec: 14 | joint_idx = indexer.get_index(joint_name, obs_type) 15 | if obs_type == PyBulletObservationType.JOINT_VEL: 16 | self._joint_velocity_indexes.append(joint_idx[0]) 17 | 18 | model_id, joint_id = indexer.joint_map[joint_name] 19 | joint_info = client.getJointInfo(model_id, joint_id) 20 | joint_velocity_limits.append(joint_info[11]) 21 | 22 | elif obs_type == PyBulletObservationType.JOINT_POS: 23 | self._joint_pos_indexes.append(joint_idx[0]) 24 | 25 | model_id, joint_id = indexer.joint_map[joint_name] 26 | joint_info = client.getJointInfo(model_id, joint_id) 27 | joint_limits_low.append(joint_info[8]) 28 | joint_limits_high.append(joint_info[9]) 29 | 30 | self._joint_limits_low = np.array(joint_limits_low) 31 | self._joint_limits_high = np.array(joint_limits_high) 32 | self._joint_velocity_limits = np.array(joint_velocity_limits) 33 | 34 | def positions(self, state): 35 | return state[self._joint_pos_indexes] 36 | 37 | def velocities(self, state): 38 | return state[self._joint_velocity_indexes] 39 | 40 | def limits(self): 41 | return self._joint_limits_low, self._joint_limits_high 42 | 43 | def velocity_limits(self): 44 | return self._joint_velocity_limits 45 | -------------------------------------------------------------------------------- /mushroom_rl/utils/pybullet/observation.py: -------------------------------------------------------------------------------- 1 
| from enum import Enum 2 | 3 | 4 | class PyBulletObservationType(Enum): 5 | """ 6 | An enum indicating the type of data that should be added to the observation 7 | of the environment, can be Joint-/Body-/Site- positions and velocities. 8 | 9 | """ 10 | __order__ = "BODY_POS BODY_LIN_VEL BODY_ANG_VEL JOINT_POS JOINT_VEL LINK_POS LINK_LIN_VEL LINK_ANG_VEL CONTACT_FLAG" 11 | BODY_POS = 0 12 | BODY_LIN_VEL = 1 13 | BODY_ANG_VEL = 2 14 | JOINT_POS = 3 15 | JOINT_VEL = 4 16 | LINK_POS = 5 17 | LINK_LIN_VEL = 6 18 | LINK_ANG_VEL = 7 19 | CONTACT_FLAG = 8 -------------------------------------------------------------------------------- /mushroom_rl/utils/pybullet/viewer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pybullet 3 | from mushroom_rl.utils.viewer import ImageViewer 4 | 5 | 6 | class PyBulletViewer(ImageViewer): 7 | def __init__(self, client, dt, size=(500, 500), distance=4, origin=(0, 0, 1), angles=(0, -45, 60), 8 | fov=60, aspect=1, near_val=0.01, far_val=100): 9 | self._client = client 10 | self._size = size 11 | self._distance = distance 12 | self._origin = origin 13 | self._angles = angles 14 | self._fov = fov 15 | self._aspect = aspect 16 | self._near_val = near_val 17 | self._far_val = far_val 18 | super().__init__(size, dt) 19 | 20 | def display(self): 21 | img = self._get_image() 22 | super().display(img) 23 | 24 | return img 25 | 26 | def _get_image(self): 27 | view_matrix = self._client.computeViewMatrixFromYawPitchRoll(cameraTargetPosition=self._origin, 28 | distance=self._distance, 29 | roll=self._angles[0], 30 | pitch=self._angles[1], 31 | yaw=self._angles[2], 32 | upAxisIndex=2) 33 | proj_matrix = self._client.computeProjectionMatrixFOV(fov=self._fov, aspect=self._aspect, 34 | nearVal=self._near_val, farVal=self._far_val) 35 | (_, _, px, _, _) = self._client.getCameraImage(width=self._size[0], 36 | height=self._size[1], 37 | viewMatrix=view_matrix, 38 | projectionMatrix=proj_matrix, 39 | renderer=pybullet.ER_BULLET_HARDWARE_OPENGL) 40 | 41 | rgb_array = np.reshape(np.array(px), (self._size[0], self._size[1], -1)) 42 | rgb_array = rgb_array[:, :, :3] 43 | return rgb_array 44 | -------------------------------------------------------------------------------- /mushroom_rl/utils/quaternions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.utils.angles import mat_to_euler, euler_to_quat 4 | 5 | 6 | def normalize_quaternion(q): 7 | norm = np.linalg.norm(q) 8 | return q / norm 9 | 10 | 11 | def quaternion_distance(q1, q2): 12 | q1 = normalize_quaternion(q1) 13 | q2 = normalize_quaternion(q2) 14 | 15 | cos_half_angle = np.abs(np.dot(q1, q2)) 16 | 17 | theta = 2 * np.arccos(cos_half_angle) 18 | return theta / 2 19 | 20 | 21 | def mat_to_quat(mat): 22 | euler = mat_to_euler(mat) 23 | quat = euler_to_quat(euler) 24 | return quat 25 | -------------------------------------------------------------------------------- /mushroom_rl/utils/record.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import datetime 3 | from pathlib import Path 4 | 5 | 6 | class VideoRecorder(object): 7 | """ 8 | Simple video record that creates a video from a stream of images. 9 | 10 | """ 11 | 12 | def __init__(self, path="./mushroom_rl_recordings", tag=None, video_name=None, fps=60): 13 | """ 14 | Constructor. 15 | 16 | Args: 17 | path: Path at which videos will be stored. 
18 | tag: Name of the directory at path in which the video will be stored. If None, a timestamp will be created. 19 | video_name: Name of the video without extension. Default is "recording". 20 | fps: Frame rate of the video. 21 | """ 22 | 23 | if tag is None: 24 | date_time = datetime.datetime.now() 25 | tag = date_time.strftime("%d-%m-%Y_%H-%M-%S") 26 | 27 | self._path = Path(path) 28 | self._path = self._path / tag 29 | 30 | self._video_name = video_name if video_name else "recording" 31 | self._counter = 0 32 | 33 | self._fps = fps 34 | 35 | self._video_writer = None 36 | 37 | def __call__(self, frame): 38 | """ 39 | Args: 40 | frame (np.ndarray): Frame to be added to the video (H, W, RGB) 41 | """ 42 | assert frame is not None 43 | 44 | if self._video_writer is None: 45 | height, width = frame.shape[:2] 46 | self._create_video_writer(height, width) 47 | 48 | self._video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) 49 | 50 | def _create_video_writer(self, height, width): 51 | 52 | name = self._video_name 53 | if self._counter > 0: 54 | name += f"-{self._counter}.mp4" 55 | else: 56 | name += ".mp4" 57 | 58 | self._path.mkdir(parents=True, exist_ok=True) 59 | 60 | path = self._path / name 61 | 62 | self._video_writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), 63 | self._fps, (width, height)) 64 | 65 | def stop(self): 66 | cv2.destroyAllWindows() 67 | self._video_writer.release() 68 | self._video_writer = None 69 | self._counter += 1 70 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "numpy"] 3 | 4 | [project] 5 | name = "mushroom-rl" 6 | dependencies = [ 7 | "numpy", 8 | "scipy", 9 | "scikit-learn", 10 | "matplotlib", 11 | "joblib", 12 | "tqdm", 13 | "pygame", 14 | "opencv-python>=4.7", 15 | "torch" 16 | ] 17 | requires-python = ">=3.8" 18 | authors = [ 19 | { name="Carlo D'Eramo", email="carlo.deramo@gmail.com"}, 20 | { name = "Davide Tateo", email = "davide@robot-learning.de" } 21 | ] 22 | maintainers = [ 23 | { name = "Davide Tateo", email = "davide@robot-learning.de" } 24 | ] 25 | description = "A Python library for Reinforcement Learning experiments." 
26 | readme = "README.rst" 27 | license = { file= "LICENSE" } 28 | keywords = ["Reinforcement Learning", "Machine Learning", "Robotics"] 29 | classifiers = [ 30 | "Programming Language :: Python :: 3", 31 | "License :: OSI Approved :: MIT License", 32 | "Operating System :: OS Independent", 33 | ] 34 | 35 | dynamic = ["version", "optional-dependencies"] 36 | 37 | [project.urls] 38 | Homepage = "https://github.com/MushroomRL" 39 | Documentation = "https://mushroomrl.readthedocs.io/en/latest/" 40 | Repository = "https://github.com/MushroomRL/mushroom-rl" 41 | Issues = "https://github.com/MushroomRL/mushroom-rl/issues" 42 | -------------------------------------------------------------------------------- /tests/algorithms/test_dpg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from datetime import datetime 4 | from helper.utils import TestUtils as tu 5 | 6 | from mushroom_rl.core import Agent 7 | from mushroom_rl.algorithms.actor_critic import COPDAC_Q 8 | from mushroom_rl.core import Core 9 | from mushroom_rl.environments import * 10 | from mushroom_rl.features import Features 11 | from mushroom_rl.features.tiles import Tiles 12 | from mushroom_rl.approximators import Regressor 13 | from mushroom_rl.approximators.parametric import LinearApproximator 14 | from mushroom_rl.policy import GaussianPolicy 15 | from mushroom_rl.rl_utils.parameters import Parameter 16 | 17 | 18 | def learn_copdac_q(): 19 | n_steps = 50 20 | mdp = InvertedPendulum(horizon=n_steps) 21 | np.random.seed(1) 22 | torch.manual_seed(1) 23 | torch.cuda.manual_seed(1) 24 | 25 | # Agent 26 | n_tilings = 1 27 | alpha_theta = Parameter(5e-3 / n_tilings) 28 | alpha_omega = Parameter(0.5 / n_tilings) 29 | alpha_v = Parameter(0.5 / n_tilings) 30 | tilings = Tiles.generate(n_tilings, [2, 2], 31 | mdp.info.observation_space.low, 32 | mdp.info.observation_space.high + 1e-3) 33 | 34 | phi = Features(tilings=tilings) 35 | 36 | input_shape = (phi.size,) 37 | 38 | mu = Regressor(LinearApproximator, input_shape=input_shape, output_shape=mdp.info.action_space.shape, phi=phi) 39 | 40 | sigma = 1e-1 * np.eye(1) 41 | policy = GaussianPolicy(mu, sigma) 42 | 43 | agent = COPDAC_Q(mdp.info, policy, mu, alpha_theta, alpha_omega, alpha_v, value_function_features=phi) 44 | 45 | # Train 46 | core = Core(agent, mdp) 47 | 48 | core.learn(n_episodes=2, n_episodes_per_fit=1) 49 | 50 | return agent 51 | 52 | 53 | def test_copdac_q(): 54 | policy = learn_copdac_q().policy 55 | w = policy.get_weights() 56 | w_test = np.array([[0.0, -6.62180045e-07, 0.0, -4.23972882e-02]]) 57 | 58 | assert np.allclose(w, w_test) 59 | 60 | 61 | def test_copdac_q_save(tmpdir): 62 | agent_path = tmpdir / 'agent_{}'.format(datetime.now().strftime("%H%M%S%f")) 63 | 64 | agent_save = learn_copdac_q() 65 | 66 | agent_save.save(agent_path) 67 | agent_load = Agent.load(agent_path) 68 | 69 | for att, method in vars(agent_save).items(): 70 | save_attr = getattr(agent_save, att) 71 | load_attr = getattr(agent_load, att) 72 | 73 | tu.assert_eq(save_attr, load_attr) 74 | -------------------------------------------------------------------------------- /tests/algorithms/test_lspi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from datetime import datetime 4 | from helper.utils import TestUtils as tu 5 | 6 | from mushroom_rl.core import Agent 7 | from mushroom_rl.algorithms.value import LSPI 8 | from mushroom_rl.core import Core 9 | from 
mushroom_rl.environments import * 10 | from mushroom_rl.features import Features 11 | from mushroom_rl.features.basis import PolynomialBasis 12 | from mushroom_rl.policy import EpsGreedy 13 | from mushroom_rl.rl_utils.parameters import Parameter 14 | 15 | 16 | def learn_lspi(): 17 | np.random.seed(1) 18 | 19 | # MDP 20 | mdp = CartPole() 21 | 22 | # Policy 23 | epsilon = Parameter(value=1.) 24 | pi = EpsGreedy(epsilon=epsilon) 25 | 26 | # Agent 27 | basis = [PolynomialBasis()] 28 | features = Features(basis_list=basis) 29 | 30 | fit_params = dict() 31 | approximator_params = dict(input_shape=(features.size,), 32 | output_shape=(mdp.info.action_space.n,), 33 | n_actions=mdp.info.action_space.n, 34 | phi=features) 35 | agent = LSPI(mdp.info, pi, approximator_params=approximator_params, fit_params=fit_params) 36 | 37 | # Algorithm 38 | core = Core(agent, mdp) 39 | 40 | # Train 41 | core.learn(n_episodes=10, n_episodes_per_fit=10) 42 | 43 | return agent 44 | 45 | 46 | def test_lspi(): 47 | 48 | w = learn_lspi().approximator.get_weights() 49 | w_test = np.array([-1.67115903, -1.43755615, -1.67115903]) 50 | 51 | assert np.allclose(w, w_test) 52 | 53 | 54 | def test_lspi_save(tmpdir): 55 | agent_path = tmpdir / 'agent_{}'.format(datetime.now().strftime("%H%M%S%f")) 56 | 57 | agent_save = learn_lspi() 58 | 59 | agent_save.save(agent_path) 60 | agent_load = Agent.load(agent_path) 61 | 62 | for att, method in vars(agent_save).items(): 63 | save_attr = getattr(agent_save, att) 64 | load_attr = getattr(agent_load, att) 65 | 66 | tu.assert_eq(save_attr, load_attr) 67 | -------------------------------------------------------------------------------- /tests/core/test_core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.core import Agent, Core 4 | from mushroom_rl.environments import Atari 5 | 6 | from mushroom_rl.policy import Policy 7 | 8 | 9 | class RandomDiscretePolicy(Policy): 10 | def __init__(self, n): 11 | super().__init__() 12 | self._n = n 13 | 14 | def draw_action(self, state, policy_state=None): 15 | return [np.random.randint(self._n)], None 16 | 17 | 18 | class DummyAgent(Agent): 19 | def __init__(self, mdp_info): 20 | policy = RandomDiscretePolicy(mdp_info.action_space.n) 21 | super().__init__(mdp_info, policy) 22 | 23 | def fit(self, dataset): 24 | pass 25 | 26 | 27 | def test_core(): 28 | mdp = Atari(name='ALE/Breakout-v5', repeat_action_probability=0.0) 29 | 30 | agent = DummyAgent(mdp.info) 31 | 32 | core = Core(agent, mdp) 33 | 34 | np.random.seed(2) 35 | mdp.seed(2) 36 | 37 | core.learn(n_steps=100, n_steps_per_fit=1) 38 | 39 | dataset = core.evaluate(n_steps=20) 40 | 41 | assert 'lives' in dataset.info 42 | assert 'episode_frame_number' in dataset.info 43 | assert 'frame_number' in dataset.info 44 | 45 | info_lives = np.array(dataset.info['lives']) 46 | 47 | print(info_lives) 48 | lives_gt = np.array([5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.]) 49 | assert len(info_lives) == 20 50 | assert np.all(info_lives == lives_gt) 51 | assert len(dataset) == 20 52 | 53 | 54 | -------------------------------------------------------------------------------- /tests/core/test_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mushroom_rl.core import Logger 3 | 4 | 5 | def test_logger(tmpdir): 6 | logger_1 = Logger('test', seed=1, results_dir=tmpdir) 7 | logger_2 = Logger('test', seed=2, results_dir=tmpdir) 8 
| 9 | for i in range(3): 10 | logger_1.log_numpy(a=i, b=2*i+1) 11 | logger_2.log_numpy(a=2*i+1, b=i) 12 | 13 | a_1 = np.load(str(tmpdir / 'test' / 'a-1.npy')) 14 | a_2 = np.load(str(tmpdir / 'test' / 'a-2.npy')) 15 | b_1 = np.load(str(tmpdir / 'test' / 'b-1.npy')) 16 | b_2 = np.load(str(tmpdir / 'test' / 'b-2.npy')) 17 | 18 | assert np.array_equal(a_1, np.arange(3)) 19 | assert np.array_equal(b_2, np.arange(3)) 20 | assert np.array_equal(a_1, b_2) 21 | assert np.array_equal(b_1, a_2) 22 | 23 | logger_1_bis = Logger('test', append=True, seed=1, results_dir=tmpdir) 24 | 25 | logger_1_bis.log_numpy(a=3, b=7) 26 | a_1 = np.load(str(tmpdir / 'test' / 'a-1.npy')) 27 | b_2 = np.load(str(tmpdir / 'test' / 'b-2.npy')) 28 | 29 | assert np.array_equal(a_1, np.arange(4)) 30 | assert np.array_equal(b_2, np.arange(3)) 31 | -------------------------------------------------------------------------------- /tests/core/test_serialization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from mushroom_rl.core import Serializable 5 | from mushroom_rl.utils import TorchUtils 6 | 7 | 8 | class DummyClass(Serializable): 9 | def __init__(self): 10 | self.torch_tensor = torch.randn(2, 2).to(TorchUtils.get_device()) 11 | self.numpy_array = np.random.randn(3, 4) 12 | self.scalar = 1 13 | self.dictionary = {'a': 'test', 'b': 5, 'd': (2, 3)} 14 | self.not_saved = 'test2' 15 | 16 | self._add_save_attr( 17 | torch_tensor='torch', 18 | numpy_array='numpy', 19 | scalar='primitive', 20 | dictionary='pickle', 21 | not_saved='none' 22 | ) 23 | 24 | def __eq__(self, other): 25 | f1 = torch.equal(self.torch_tensor.cpu(), other.torch_tensor.cpu()) 26 | f2 = np.array_equal(self.numpy_array, other.numpy_array) 27 | f3 = self.scalar == other.scalar 28 | f4 = self.dictionary == other.dictionary 29 | 30 | return f1 and f2 and f3 and f4 31 | 32 | 33 | def test_serialization(tmpdir): 34 | TorchUtils.set_default_device('cpu') 35 | 36 | a = DummyClass() 37 | a.save(tmpdir / 'test.msh') 38 | 39 | b = Serializable.load(tmpdir / 'test.msh') 40 | 41 | assert a == b 42 | assert b.not_saved == None 43 | 44 | 45 | def test_serialization_cuda_cpu(tmpdir): 46 | if torch.cuda.is_available(): 47 | TorchUtils.set_default_device('cuda') 48 | 49 | a = DummyClass() 50 | a.save(tmpdir / 'test.msh') 51 | 52 | TorchUtils.set_default_device('cpu') 53 | 54 | assert a.torch_tensor.device.type == 'cuda' 55 | 56 | b = Serializable.load(tmpdir / 'test.msh') 57 | 58 | assert b.torch_tensor.device.type == 'cpu' 59 | 60 | assert a == b 61 | 62 | 63 | def test_serialization_cpu_cuda(tmpdir): 64 | if torch.cuda.is_available(): 65 | TorchUtils.set_default_device('cpu') 66 | 67 | a = DummyClass() 68 | a.save(tmpdir / 'test.msh') 69 | 70 | TorchUtils.set_default_device('cuda') 71 | 72 | assert a.torch_tensor.device.type == 'cpu' 73 | 74 | b = Serializable.load(tmpdir / 'test.msh') 75 | 76 | assert b.torch_tensor.device.type == 'cuda' 77 | 78 | assert a == b 79 | 80 | TorchUtils.set_default_device('cpu') 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /tests/distributions/test_distribution_interface.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.distributions import Distribution 2 | 3 | 4 | def abstract_method_tester(f, *args): 5 | try: 6 | f(*args) 7 | except NotImplementedError: 8 | pass 9 | else: 10 | assert False 11 | 12 | 13 | def 
test_distribution_interface(): 14 | tmp = Distribution() 15 | 16 | abstract_method_tester(tmp.sample) 17 | abstract_method_tester(tmp.log_pdf, None) 18 | abstract_method_tester(tmp.__call__, None) 19 | abstract_method_tester(tmp.entropy) 20 | abstract_method_tester(tmp.mle, None) 21 | abstract_method_tester(tmp.diff_log, None) 22 | abstract_method_tester(tmp.diff, None) 23 | 24 | abstract_method_tester(tmp.get_parameters) 25 | abstract_method_tester(tmp.set_parameters, None) 26 | 27 | try: 28 | tmp.parameters_size 29 | except NotImplementedError: 30 | pass 31 | else: 32 | assert False 33 | -------------------------------------------------------------------------------- /tests/environments/grid.txt: -------------------------------------------------------------------------------- 1 | ##### 2 | #S.*# 3 | #...# 4 | ##..# 5 | #G..# 6 | ##### 7 | -------------------------------------------------------------------------------- /tests/environments/mujoco_envs/air_hockey_defend_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/mujoco_envs/air_hockey_defend_data.npy -------------------------------------------------------------------------------- /tests/environments/mujoco_envs/air_hockey_hit_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/mujoco_envs/air_hockey_hit_data.npy -------------------------------------------------------------------------------- /tests/environments/mujoco_envs/air_hockey_prepare_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/mujoco_envs/air_hockey_prepare_data.npy -------------------------------------------------------------------------------- /tests/environments/mujoco_envs/air_hockey_repel_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/mujoco_envs/air_hockey_repel_data.npy -------------------------------------------------------------------------------- /tests/environments/mujoco_envs/test_ball_in_a_cup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from mushroom_rl.environments.mujoco_envs import BallInACup 3 | import numpy as np 4 | 5 | def linear_movement(start, end, n_steps, i): 6 | t = np.minimum(1., float(i) / float(n_steps)) 7 | return start + (end - start) * t 8 | 9 | 10 | def test_ball_in_a_cup(): 11 | env = BallInACup() 12 | 13 | des_pos = np.array([0.0, -0.58760536, 0.0, 1.36004913, 0.0, -0.32072943, -1.57]) 14 | p_gains = np.array([200, 300, 100, 100, 10, 10, 2.5])/5 15 | d_gains = np.array([7, 15, 5, 2.5, 0.3, 0.3, 0.05])/10 16 | 17 | obs_0, _ = env.reset() 18 | 19 | for _ in [1,2]: 20 | obs, _ = env.reset() 21 | 22 | assert np.array_equal(obs, obs_0) 23 | done = False 24 | i = 0 25 | while not done: 26 | q_cmd = env.linear_movement(env.init_robot_pos, des_pos, 100, i) 27 | q_curr = obs[0:14:2] 28 | qdot_cur = obs[1:14:2] 29 | pos_err = q_cmd - q_curr 30 | 31 | a = env._data.qfrc_bias[:7] + p_gains * pos_err - d_gains * qdot_cur 32 | #a = np.zeros(7) 33 | 34 | # Check the 
observations 35 | assert np.allclose(obs[0:14:2], env._data.qpos[0:7]) 36 | assert np.allclose(obs[1:14:2], env._data.qvel[0:7]) 37 | # assert np.allclose(obs[14:17], env._data.xpos[40]) 38 | # assert np.allclose(obs[17:], env._data.cvel[40]) 39 | obs, reward, done, info = env.step(a) 40 | 41 | i += 1 42 | 43 | except ImportError: 44 | pass 45 | -------------------------------------------------------------------------------- /tests/environments/pybullet_envs/air_hockey_defend_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/pybullet_envs/air_hockey_defend_data.npy -------------------------------------------------------------------------------- /tests/environments/pybullet_envs/air_hockey_hit_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/pybullet_envs/air_hockey_hit_data.npy -------------------------------------------------------------------------------- /tests/environments/pybullet_envs/air_hockey_prepare_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/pybullet_envs/air_hockey_prepare_data.npy -------------------------------------------------------------------------------- /tests/environments/pybullet_envs/air_hockey_repel_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/pybullet_envs/air_hockey_repel_data.npy -------------------------------------------------------------------------------- /tests/environments/taxi.txt: -------------------------------------------------------------------------------- 1 | S#F.#.G 2 | .#..#.. 3 | ....... 
4 | ##...## 5 | ......F 6 | F.....# 7 | -------------------------------------------------------------------------------- /tests/environments/test_atari_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/test_atari_1.npy -------------------------------------------------------------------------------- /tests/environments/test_mujoco.py: -------------------------------------------------------------------------------- 1 | try: 2 | from mushroom_rl.environments.dm_control_env import DMControl 3 | import numpy as np 4 | 5 | def test_dm_control(): 6 | np.random.seed(1) 7 | mdp = DMControl('hopper', 'hop', 1000, .99, task_kwargs={'random': 1}) 8 | mdp.reset() 9 | for i in range(10): 10 | ns, r, ab, _ = mdp.step( 11 | np.random.rand(mdp.info.action_space.shape[0])) 12 | ns_test = np.array([-0.25868173, -2.24011367, 0.45346572, -0.55528368, 13 | 0.51603826, -0.21782316, -0.58708578, -2.04541986, 14 | -17.24931206, 5.42227781, 21.39084468, -2.42071806, 15 | 3.85448837, 0., 0.]) 16 | 17 | assert np.allclose(ns, ns_test) 18 | except ImportError: 19 | pass 20 | -------------------------------------------------------------------------------- /tests/policy/test_deterministic_policy.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.policy import DeterministicPolicy 2 | from mushroom_rl.approximators.regressor import Regressor 3 | from mushroom_rl.approximators.parametric import LinearApproximator 4 | from mushroom_rl.utils.numerical_gradient import numerical_diff_policy 5 | 6 | import numpy as np 7 | 8 | 9 | def test_deterministic_policy(): 10 | np.random.seed(88) 11 | 12 | n_dims = 5 13 | 14 | approximator = Regressor(LinearApproximator, 15 | input_shape=(n_dims,), 16 | output_shape=(2,)) 17 | 18 | pi = DeterministicPolicy(approximator) 19 | 20 | w_new = np.random.rand(pi.weights_size) 21 | 22 | w_old = pi.get_weights() 23 | pi.set_weights(w_new) 24 | 25 | assert np.array_equal(w_new, approximator.get_weights()) 26 | assert not np.array_equal(w_old, w_new) 27 | assert np.array_equal(w_new, pi.get_weights()) 28 | 29 | s_test_1 = np.random.randn(5) 30 | s_test_2 = np.random.randn(5) 31 | a_test = approximator.predict(s_test_1) 32 | 33 | assert pi.get_regressor() == approximator 34 | 35 | assert pi(s_test_1, a_test) == 1 36 | assert pi(s_test_2, a_test) == 0 37 | 38 | a_stored = np.array([-1.86941072, -0.1789696]) 39 | action, _ = pi.draw_action(s_test_1) 40 | assert np.allclose(action, a_stored) 41 | 42 | -------------------------------------------------------------------------------- /tests/policy/test_noise_policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mushroom_rl.approximators import Regressor 4 | from mushroom_rl.approximators.parametric import TorchApproximator 5 | from mushroom_rl.approximators.parametric.networks import LinearNetwork 6 | from mushroom_rl.policy import OrnsteinUhlenbeckPolicy, ClippedGaussianPolicy 7 | 8 | 9 | def test_ornstein_uhlenbeck_policy(): 10 | torch.manual_seed(42) 11 | 12 | mu = Regressor(TorchApproximator, network=LinearNetwork, input_shape=(5,), output_shape=(2,)) 13 | pi = OrnsteinUhlenbeckPolicy(mu, sigma=torch.ones(1) * .2, theta=.15, dt=1e-2) 14 | 15 | w = torch.randn(pi.weights_size) 16 | pi.set_weights(w) 17 | assert torch.equal(pi.get_weights(), w) 18 | 19 | state = torch.randn(5) 20 | 
21 | policy_state = pi.reset() 22 | 23 | action, policy_state = pi.draw_action(state, policy_state) 24 | action_test = torch.tensor([-0.7055691481, 1.1255935431]) 25 | assert torch.allclose(action, action_test) 26 | 27 | policy_state = pi.reset() 28 | action, policy_state = pi.draw_action(state, policy_state) 29 | action_test = torch.tensor([-0.7114595175, 1.1141412258]) 30 | assert torch.allclose(action, action_test) 31 | 32 | try: 33 | pi(state, action) 34 | except NotImplementedError: 35 | pass 36 | else: 37 | assert False 38 | 39 | 40 | def test_clipped_gaussian_policy(): 41 | torch.manual_seed(1) 42 | 43 | low = -torch.ones(2) 44 | high = torch.ones(2) 45 | 46 | mu = Regressor(TorchApproximator, network=LinearNetwork, input_shape=(5,), output_shape=(2,)) 47 | pi = ClippedGaussianPolicy(mu, torch.eye(2), low, high) 48 | 49 | w = torch.randn(pi.weights_size) 50 | pi.set_weights(w) 51 | assert torch.equal(pi.get_weights(), w) 52 | 53 | state = torch.randn(5) 54 | 55 | action, _ = pi.draw_action(state) 56 | action_test = torch.tensor([-1.0, 1.0]) 57 | assert torch.allclose(action, action_test) 58 | 59 | action, _ = pi.draw_action(state) 60 | action_test = torch.tensor([0.4926533699, 1.0]) 61 | assert torch.allclose(action, action_test) 62 | 63 | try: 64 | pi(state, action) 65 | except NotImplementedError: 66 | pass 67 | else: 68 | assert False 69 | 70 | # TODO Missing test for clipped gaussian! 71 | 72 | -------------------------------------------------------------------------------- /tests/policy/test_policy_interface.py: -------------------------------------------------------------------------------- 1 | from mushroom_rl.policy import Policy, ParametricPolicy 2 | 3 | 4 | def abstract_method_tester(f, ex, *args): 5 | try: 6 | f(*args) 7 | except ex: 8 | pass 9 | else: 10 | assert False 11 | 12 | 13 | def test_policy_interface(): 14 | tmp = Policy() 15 | abstract_method_tester(tmp.__call__, NotImplementedError, None, None, None) 16 | abstract_method_tester(tmp.draw_action, NotImplementedError, None, None) 17 | tmp.reset() 18 | 19 | 20 | def test_parametric_policy(): 21 | tmp = ParametricPolicy() 22 | abstract_method_tester(tmp.diff_log, RuntimeError, None, None, None) 23 | abstract_method_tester(tmp.diff, RuntimeError, None, None, None) 24 | abstract_method_tester(tmp.set_weights, NotImplementedError, None) 25 | abstract_method_tester(tmp.get_weights, NotImplementedError) 26 | try: 27 | tmp.weights_size 28 | except NotImplementedError: 29 | pass 30 | else: 31 | assert False 32 | -------------------------------------------------------------------------------- /tests/solvers/test_car_on_hill.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mushroom_rl.environments.car_on_hill import CarOnHill 4 | from mushroom_rl.solvers.car_on_hill import solve_car_on_hill 5 | 6 | 7 | def test_car_on_hill(): 8 | mdp = CarOnHill() 9 | mdp._discrete_actions = np.array([-8., 8.]) 10 | 11 | states = np.array([[-.5, 0], [0., 0.], [.5, 0.]]) 12 | actions = np.array([[0], [1], [0]]) 13 | q = solve_car_on_hill(mdp, states, actions, .95) 14 | q_test = np.array([0.5688000922764597, 0.48767497911552954, 15 | 0.5688000922764597]) 16 | 17 | assert np.allclose(q, q_test) 18 | -------------------------------------------------------------------------------- /tests/test_imports.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def test_imports(): 4 | import mushroom_rl 5 | 6 | import mushroom_rl.algorithms 
7 | import mushroom_rl.algorithms.actor_critic 8 | import mushroom_rl.algorithms.actor_critic.classic_actor_critic 9 | import mushroom_rl.algorithms.actor_critic.deep_actor_critic 10 | import mushroom_rl.algorithms.policy_search 11 | import mushroom_rl.algorithms.policy_search.black_box_optimization 12 | import mushroom_rl.algorithms.policy_search.policy_gradient 13 | import mushroom_rl.algorithms.value 14 | import mushroom_rl.algorithms.value.batch_td 15 | import mushroom_rl.algorithms.value.td 16 | import mushroom_rl.algorithms.value.dqn 17 | 18 | import mushroom_rl.approximators 19 | import mushroom_rl.approximators._implementations 20 | import mushroom_rl.approximators.parametric 21 | 22 | import mushroom_rl.core 23 | 24 | import mushroom_rl.distributions 25 | 26 | import mushroom_rl.environments 27 | import mushroom_rl.environments.generators 28 | 29 | try: 30 | import mujoco 31 | except ImportError: 32 | pass 33 | else: 34 | import mushroom_rl.environments.mujoco_envs 35 | 36 | import mushroom_rl.features 37 | import mushroom_rl.features._implementations 38 | import mushroom_rl.features.basis 39 | import mushroom_rl.features.tensors 40 | import mushroom_rl.features.tiles 41 | 42 | import mushroom_rl.policy 43 | 44 | import mushroom_rl.solvers 45 | 46 | import mushroom_rl.utils 47 | 48 | --------------------------------------------------------------------------------
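A minimal sketch (not part of the repository) of how the tile-coding feature pipeline from mushroom_rl/features/_implementations/tiles_features.py is used in practice, following the same construction as tests/algorithms/test_dpg.py above; the state bounds and tiling sizes here are illustrative only.

    import numpy as np

    from mushroom_rl.features import Features
    from mushroom_rl.features.tiles import Tiles

    # Two tilings of 2x2 tiles over a toy 2D state space (bounds are made up)
    low = np.array([-1., -1.])
    high = np.array([1., 1.])
    tilings = Tiles.generate(2, [2, 2], low, high)

    # Features(tilings=...) dispatches to the TilesFeatures implementation above
    phi = Features(tilings=tilings)

    state = np.array([0.3, -0.7])
    activation = phi(state)           # binary vector with one active tile per tiling

    print(activation.shape)           # (phi.size,) -> (8,) for this configuration
    print(np.nonzero(activation)[0])  # indices of the active tiles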