├── .coveragerc
├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ └── feature_request.md
└── workflows
│ ├── continuous_integration.yml
│ └── test_pr.yml
├── .gitignore
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.rst
├── TODO.txt
├── docs
├── .gitignore
├── Makefile
├── __init__.py
├── _static
│ └── theme_overrides.css
├── _templates
│ ├── latex.tex_t
│ ├── tabular.tex_t
│ └── tabulary.tex_t
├── conf.py
├── index.rst
├── requirements.txt
└── source
│ ├── features
│ ├── mushroom_rl.features.basis.rst
│ ├── mushroom_rl.features.tensors.rst
│ └── mushroom_rl.features.tiles.rst
│ ├── mushroom_rl.agent_environment_interface.rst
│ ├── mushroom_rl.algorithms.actor_critic.rst
│ ├── mushroom_rl.algorithms.policy_search.rst
│ ├── mushroom_rl.algorithms.value.rst
│ ├── mushroom_rl.approximators.rst
│ ├── mushroom_rl.distributions.rst
│ ├── mushroom_rl.environments.rst
│ ├── mushroom_rl.features.rst
│ ├── mushroom_rl.policy.rst
│ ├── mushroom_rl.rl_utils.rst
│ ├── mushroom_rl.solvers.rst
│ ├── mushroom_rl.utils.rst
│ └── tutorials
│ ├── code
│ ├── advanced_experiment.py
│ ├── approximator.py
│ ├── ddpg.py
│ ├── dqn.py
│ ├── generic_regressor.py
│ ├── logger.py
│ ├── room_env.py
│ ├── serialization.py
│ └── simple_experiment.py
│ ├── tutorials.0_experiments.rst
│ ├── tutorials.1_advanced.rst
│ ├── tutorials.2_approximator.rst
│ ├── tutorials.3_deep.rst
│ ├── tutorials.4_logger.rst
│ ├── tutorials.5_environments.rst
│ ├── tutorials.6_serializable.rst
│ └── tutorials.99_examples.rst
├── examples
├── __init__.py
├── acrobot_a2c.py
├── acrobot_dqn.py
├── atari_dqn.py
├── car_on_hill_fqi.py
├── cartpole_lspi.py
├── double_chain_q_learning
│ ├── __init__.py
│ ├── chain_structure
│ │ ├── p.npy
│ │ └── rew.npy
│ └── double_chain.py
├── grid_world_td.py
├── gym_recurrent_ppo.py
├── habitat
│ ├── __init__.py
│ ├── habitat_nav_dqn.py
│ ├── habitat_rearrange_sac.py
│ ├── pointnav_apartment-0.yaml
│ ├── replica_train_apartment-0.json
│ └── replica_train_apartment-0.json.gz
├── igibson_dqn.py
├── isaacsim
│ ├── a1_rudin_ppo.py
│ ├── cartpole_ppo.py
│ ├── honey_badger_ppo.py
│ └── silver_badger_ppo.py
├── lqr_bbo.py
├── lqr_pg.py
├── minigrid_dqn.py
├── mountain_car_sarsa.py
├── mujoco_air_hockey_sac.py
├── mujoco_locomotion_ppo.py
├── mujoco_manipulation_ppo.py
├── omni_isaac_gym_example.py
├── pendulum_a2c.py
├── pendulum_ac.py
├── pendulum_ddpg.py
├── pendulum_dpg.py
├── pendulum_sac.py
├── pendulum_trust_region.py
├── plotting_and_normalization.py
├── puddle_world_sarsa.py
├── segway_bbo.py
├── segway_eppo.py
├── ship_steering_bbo.py
├── simple_chain_qlearning.py
├── taxi_mellow_sarsa
│ ├── grid.txt
│ └── taxi_mellow.py
├── vectorized_core
│ ├── __init__.py
│ ├── pendulum_trust_region.py
│ └── segway_bbo.py
├── walker_stand_ddpg.py
└── walker_stand_ddpg_shared_net.py
├── mushroom_rl
├── __init__.py
├── algorithms
│ ├── __init__.py
│ ├── actor_critic
│ │ ├── __init__.py
│ │ ├── classic_actor_critic
│ │ │ ├── __init__.py
│ │ │ ├── copdac_q.py
│ │ │ └── stochastic_ac.py
│ │ └── deep_actor_critic
│ │ │ ├── __init__.py
│ │ │ ├── a2c.py
│ │ │ ├── ddpg.py
│ │ │ ├── deep_actor_critic.py
│ │ │ ├── ppo.py
│ │ │ ├── ppo_bptt.py
│ │ │ ├── ppo_rudin.py
│ │ │ ├── sac.py
│ │ │ ├── td3.py
│ │ │ └── trpo.py
│ ├── policy_search
│ │ ├── __init__.py
│ │ ├── black_box_optimization
│ │ │ ├── __init__.py
│ │ │ ├── black_box_optimization.py
│ │ │ ├── constrained_reps.py
│ │ │ ├── context_builder.py
│ │ │ ├── eppo.py
│ │ │ ├── more.py
│ │ │ ├── pgpe.py
│ │ │ ├── reps.py
│ │ │ └── rwr.py
│ │ └── policy_gradient
│ │ │ ├── __init__.py
│ │ │ ├── enac.py
│ │ │ ├── gpomdp.py
│ │ │ ├── policy_gradient.py
│ │ │ └── reinforce.py
│ └── value
│ │ ├── __init__.py
│ │ ├── batch_td
│ │ ├── __init__.py
│ │ ├── batch_td.py
│ │ ├── boosted_fqi.py
│ │ ├── double_fqi.py
│ │ ├── fqi.py
│ │ └── lspi.py
│ │ ├── dqn
│ │ ├── __init__.py
│ │ ├── abstract_dqn.py
│ │ ├── averaged_dqn.py
│ │ ├── categorical_dqn.py
│ │ ├── double_dqn.py
│ │ ├── dqn.py
│ │ ├── dueling_dqn.py
│ │ ├── maxmin_dqn.py
│ │ ├── noisy_dqn.py
│ │ ├── quantile_dqn.py
│ │ └── rainbow.py
│ │ └── td
│ │ ├── __init__.py
│ │ ├── double_q_learning.py
│ │ ├── expected_sarsa.py
│ │ ├── maxmin_q_learning.py
│ │ ├── q_lambda.py
│ │ ├── q_learning.py
│ │ ├── r_learning.py
│ │ ├── rq_learning.py
│ │ ├── sarsa.py
│ │ ├── sarsa_lambda.py
│ │ ├── sarsa_lambda_continuous.py
│ │ ├── speedy_q_learning.py
│ │ ├── td.py
│ │ ├── true_online_sarsa_lambda.py
│ │ └── weighted_q_learning.py
├── approximators
│ ├── __init__.py
│ ├── _implementations
│ │ ├── __init__.py
│ │ ├── action_regressor.py
│ │ ├── generic_regressor.py
│ │ └── q_regressor.py
│ ├── ensemble.py
│ ├── ensemble_table.py
│ ├── parametric
│ │ ├── __init__.py
│ │ ├── cmac.py
│ │ ├── linear.py
│ │ ├── networks
│ │ │ ├── __init__.py
│ │ │ └── linear_network.py
│ │ └── torch_approximator.py
│ ├── regressor.py
│ └── table.py
├── core
│ ├── __init__.py
│ ├── _impl
│ │ ├── __init__.py
│ │ ├── core_logic.py
│ │ ├── list_dataset.py
│ │ ├── numpy_dataset.py
│ │ ├── torch_dataset.py
│ │ └── vectorized_core_logic.py
│ ├── agent.py
│ ├── array_backend.py
│ ├── core.py
│ ├── dataset.py
│ ├── environment.py
│ ├── extra_info.py
│ ├── logger
│ │ ├── __init__.py
│ │ ├── console_logger.py
│ │ ├── data_logger.py
│ │ └── logger.py
│ ├── multiprocess_environment.py
│ ├── serialization.py
│ ├── vectorized_core.py
│ └── vectorized_env.py
├── distributions
│ ├── __init__.py
│ ├── distribution.py
│ ├── gaussian.py
│ └── torch_distribution.py
├── environments
│ ├── __init__.py
│ ├── atari.py
│ ├── car_on_hill.py
│ ├── cart_pole.py
│ ├── dm_control_env.py
│ ├── finite_mdp.py
│ ├── generators
│ │ ├── __init__.py
│ │ ├── grid_world.py
│ │ ├── simple_chain.py
│ │ └── taxi.py
│ ├── grid_world.py
│ ├── gymnasium_env.py
│ ├── habitat_env.py
│ ├── igibson_env.py
│ ├── inverted_pendulum.py
│ ├── isaacsim_env.py
│ ├── isaacsim_envs
│ │ ├── __init__.py
│ │ ├── a1_walking.py
│ │ ├── cartpole.py
│ │ ├── honey_badger_walking.py
│ │ ├── robots_usds
│ │ │ ├── a1
│ │ │ │ ├── .thumbs
│ │ │ │ │ └── 256x256
│ │ │ │ │ │ ├── a1.usd.png
│ │ │ │ │ │ └── instanceable_meshes.usd.png
│ │ │ │ ├── a1.usd
│ │ │ │ └── instanceable_meshes.usd
│ │ │ ├── cartpole
│ │ │ │ └── cartpole.usd
│ │ │ ├── honey_badger
│ │ │ │ ├── .thumbs
│ │ │ │ │ └── 256x256
│ │ │ │ │ │ ├── honey_badger.usd.png
│ │ │ │ │ │ └── instanceable_meshes.usd.png
│ │ │ │ ├── honey_badger.usd
│ │ │ │ └── instanceable_meshes.usd
│ │ │ └── silver_badger
│ │ │ │ ├── .thumbs
│ │ │ │ └── 256x256
│ │ │ │ │ ├── instanceable_meshes.usd.png
│ │ │ │ │ └── silver_badger.usd.png
│ │ │ │ ├── instanceable_meshes.usd
│ │ │ │ └── silver_badger.usd
│ │ └── silver_badger_walking.py
│ ├── lqr.py
│ ├── minigrid_env.py
│ ├── mujoco.py
│ ├── mujoco_envs
│ │ ├── __init__.py
│ │ ├── air_hockey
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── defend.py
│ │ │ ├── double.py
│ │ │ ├── hit.py
│ │ │ ├── prepare.py
│ │ │ ├── repel.py
│ │ │ └── single.py
│ │ ├── ant.py
│ │ ├── ball_in_a_cup.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── air_hockey
│ │ │ │ ├── __init__.py
│ │ │ │ ├── double.xml
│ │ │ │ ├── planar_robot_1.xml
│ │ │ │ ├── planar_robot_2.xml
│ │ │ │ ├── single.xml
│ │ │ │ └── table.xml
│ │ │ ├── ant
│ │ │ │ ├── __init__.py
│ │ │ │ └── model.xml
│ │ │ ├── ball_in_a_cup
│ │ │ │ ├── __init__.py
│ │ │ │ ├── meshes
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── base_link_convex.stl
│ │ │ │ │ ├── base_link_fine.stl
│ │ │ │ │ ├── cup_split1.stl
│ │ │ │ │ ├── cup_split10.stl
│ │ │ │ │ ├── cup_split11.stl
│ │ │ │ │ ├── cup_split12.stl
│ │ │ │ │ ├── cup_split13.stl
│ │ │ │ │ ├── cup_split14.stl
│ │ │ │ │ ├── cup_split15.stl
│ │ │ │ │ ├── cup_split16.stl
│ │ │ │ │ ├── cup_split17.stl
│ │ │ │ │ ├── cup_split18.stl
│ │ │ │ │ ├── cup_split2.stl
│ │ │ │ │ ├── cup_split3.stl
│ │ │ │ │ ├── cup_split4.stl
│ │ │ │ │ ├── cup_split5.stl
│ │ │ │ │ ├── cup_split6.stl
│ │ │ │ │ ├── cup_split7.stl
│ │ │ │ │ ├── cup_split8.stl
│ │ │ │ │ ├── cup_split9.stl
│ │ │ │ │ ├── elbow_link_convex.stl
│ │ │ │ │ ├── elbow_link_fine.stl
│ │ │ │ │ ├── forearm_link_convex_decomposition_p1.stl
│ │ │ │ │ ├── forearm_link_convex_decomposition_p2.stl
│ │ │ │ │ ├── forearm_link_fine.stl
│ │ │ │ │ ├── shoulder_link_convex_decomposition_p1.stl
│ │ │ │ │ ├── shoulder_link_convex_decomposition_p2.stl
│ │ │ │ │ ├── shoulder_link_convex_decomposition_p3.stl
│ │ │ │ │ ├── shoulder_link_fine.stl
│ │ │ │ │ ├── shoulder_pitch_link_convex.stl
│ │ │ │ │ ├── shoulder_pitch_link_fine.stl
│ │ │ │ │ ├── upper_arm_link_convex_decomposition_p1.stl
│ │ │ │ │ ├── upper_arm_link_convex_decomposition_p2.stl
│ │ │ │ │ ├── upper_arm_link_fine.stl
│ │ │ │ │ ├── wrist_palm_link_convex.stl
│ │ │ │ │ ├── wrist_palm_link_fine.stl
│ │ │ │ │ ├── wrist_pitch_link_convex_decomposition_p1.stl
│ │ │ │ │ ├── wrist_pitch_link_convex_decomposition_p2.stl
│ │ │ │ │ ├── wrist_pitch_link_convex_decomposition_p3.stl
│ │ │ │ │ ├── wrist_pitch_link_fine.stl
│ │ │ │ │ ├── wrist_yaw_link_convex_decomposition_p1.stl
│ │ │ │ │ ├── wrist_yaw_link_convex_decomposition_p2.stl
│ │ │ │ │ └── wrist_yaw_link_fine.stl
│ │ │ │ └── model.xml
│ │ │ ├── half_cheetah
│ │ │ │ ├── __init__.py
│ │ │ │ └── model.xml
│ │ │ ├── hopper
│ │ │ │ ├── __init__.py
│ │ │ │ └── model.xml
│ │ │ ├── panda
│ │ │ │ ├── assets
│ │ │ │ │ ├── cube.xml
│ │ │ │ │ ├── finger_0.obj
│ │ │ │ │ ├── finger_1.obj
│ │ │ │ │ ├── hand.stl
│ │ │ │ │ ├── hand_0.obj
│ │ │ │ │ ├── hand_1.obj
│ │ │ │ │ ├── hand_2.obj
│ │ │ │ │ ├── hand_3.obj
│ │ │ │ │ ├── hand_4.obj
│ │ │ │ │ ├── link0.stl
│ │ │ │ │ ├── link0_0.obj
│ │ │ │ │ ├── link0_1.obj
│ │ │ │ │ ├── link0_10.obj
│ │ │ │ │ ├── link0_11.obj
│ │ │ │ │ ├── link0_2.obj
│ │ │ │ │ ├── link0_3.obj
│ │ │ │ │ ├── link0_4.obj
│ │ │ │ │ ├── link0_5.obj
│ │ │ │ │ ├── link0_7.obj
│ │ │ │ │ ├── link0_8.obj
│ │ │ │ │ ├── link0_9.obj
│ │ │ │ │ ├── link1.obj
│ │ │ │ │ ├── link1.stl
│ │ │ │ │ ├── link2.obj
│ │ │ │ │ ├── link2.stl
│ │ │ │ │ ├── link3.stl
│ │ │ │ │ ├── link3_0.obj
│ │ │ │ │ ├── link3_1.obj
│ │ │ │ │ ├── link3_2.obj
│ │ │ │ │ ├── link3_3.obj
│ │ │ │ │ ├── link4.stl
│ │ │ │ │ ├── link4_0.obj
│ │ │ │ │ ├── link4_1.obj
│ │ │ │ │ ├── link4_2.obj
│ │ │ │ │ ├── link4_3.obj
│ │ │ │ │ ├── link5_0.obj
│ │ │ │ │ ├── link5_1.obj
│ │ │ │ │ ├── link5_2.obj
│ │ │ │ │ ├── link5_collision_0.obj
│ │ │ │ │ ├── link5_collision_1.obj
│ │ │ │ │ ├── link5_collision_2.obj
│ │ │ │ │ ├── link6.stl
│ │ │ │ │ ├── link6_0.obj
│ │ │ │ │ ├── link6_1.obj
│ │ │ │ │ ├── link6_10.obj
│ │ │ │ │ ├── link6_11.obj
│ │ │ │ │ ├── link6_12.obj
│ │ │ │ │ ├── link6_13.obj
│ │ │ │ │ ├── link6_14.obj
│ │ │ │ │ ├── link6_15.obj
│ │ │ │ │ ├── link6_16.obj
│ │ │ │ │ ├── link6_2.obj
│ │ │ │ │ ├── link6_3.obj
│ │ │ │ │ ├── link6_4.obj
│ │ │ │ │ ├── link6_5.obj
│ │ │ │ │ ├── link6_6.obj
│ │ │ │ │ ├── link6_7.obj
│ │ │ │ │ ├── link6_8.obj
│ │ │ │ │ ├── link6_9.obj
│ │ │ │ │ ├── link7.stl
│ │ │ │ │ ├── link7_0.obj
│ │ │ │ │ ├── link7_1.obj
│ │ │ │ │ ├── link7_2.obj
│ │ │ │ │ ├── link7_3.obj
│ │ │ │ │ ├── link7_4.obj
│ │ │ │ │ ├── link7_5.obj
│ │ │ │ │ ├── link7_6.obj
│ │ │ │ │ ├── link7_7.obj
│ │ │ │ │ └── table.xml
│ │ │ │ ├── panda.xml
│ │ │ │ ├── peg_insertion.xml
│ │ │ │ ├── pick.xml
│ │ │ │ ├── push.xml
│ │ │ │ └── reach.xml
│ │ │ └── walker_2d
│ │ │ │ ├── __init__.py
│ │ │ │ └── model.xml
│ │ ├── half_cheetah.py
│ │ ├── hopper.py
│ │ ├── panda.py
│ │ ├── peg_insertion.py
│ │ ├── pick.py
│ │ ├── push.py
│ │ ├── reach.py
│ │ └── walker_2d.py
│ ├── omni_isaac_gym_env.py
│ ├── puddle_world.py
│ ├── pybullet.py
│ ├── pybullet_envs
│ │ ├── __init__.py
│ │ ├── air_hockey
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── defend.py
│ │ │ ├── double.py
│ │ │ ├── hit.py
│ │ │ ├── prepare.py
│ │ │ ├── repel.py
│ │ │ └── single.py
│ │ └── data
│ │ │ ├── __init__.py
│ │ │ └── air_hockey
│ │ │ ├── air_hockey_table.urdf
│ │ │ ├── planar
│ │ │ ├── planar_robot_1.urdf
│ │ │ └── planar_robot_2.urdf
│ │ │ └── puck.urdf
│ ├── segway.py
│ └── ship_steering.py
├── features
│ ├── __init__.py
│ ├── _implementations
│ │ ├── __init__.py
│ │ ├── basis_features.py
│ │ ├── features_implementation.py
│ │ ├── functional_features.py
│ │ ├── tiles_features.py
│ │ └── torch_features.py
│ ├── basis
│ │ ├── __init__.py
│ │ ├── fourier.py
│ │ ├── gaussian_rbf.py
│ │ └── polynomial.py
│ ├── features.py
│ ├── tensors
│ │ ├── __init__.py
│ │ ├── basis_tensor.py
│ │ ├── constant_tensor.py
│ │ └── random_fourier_tensor.py
│ └── tiles
│ │ ├── __init__.py
│ │ ├── tiles.py
│ │ └── voronoi.py
├── policy
│ ├── __init__.py
│ ├── deterministic_policy.py
│ ├── dmp.py
│ ├── gaussian_policy.py
│ ├── noise_policy.py
│ ├── policy.py
│ ├── promps.py
│ ├── recurrent_torch_policy.py
│ ├── td_policy.py
│ ├── torch_policy.py
│ └── vector_policy.py
├── rl_utils
│ ├── __init__.py
│ ├── eligibility_trace.py
│ ├── optimizers.py
│ ├── parameters.py
│ ├── preprocessors.py
│ ├── replay_memory.py
│ ├── running_stats.py
│ ├── spaces.py
│ ├── value_functions.py
│ └── variance_parameters.py
├── solvers
│ ├── __init__.py
│ ├── car_on_hill.py
│ ├── dynamic_programming.py
│ └── lqr.py
└── utils
│ ├── __init__.py
│ ├── angles.py
│ ├── callbacks
│ ├── __init__.py
│ ├── callback.py
│ ├── collect_dataset.py
│ ├── collect_max_q.py
│ ├── collect_parameters.py
│ ├── collect_q.py
│ └── plot_dataset.py
│ ├── episodes.py
│ ├── features.py
│ ├── frames.py
│ ├── isaac_sim
│ ├── __init__.py
│ ├── action_helper.py
│ ├── collision_helper.py
│ ├── general_task.py
│ └── observation_helper.py
│ ├── isaac_utils.py
│ ├── minibatches.py
│ ├── mujoco
│ ├── __init__.py
│ ├── kinematics.py
│ ├── observation_helper.py
│ └── viewer.py
│ ├── numerical_gradient.py
│ ├── plot.py
│ ├── plots
│ ├── __init__.py
│ ├── common_plots.py
│ ├── databuffer.py
│ ├── plot_item_buffer.py
│ └── window.py
│ ├── pybullet
│ ├── __init__.py
│ ├── contacts.py
│ ├── index_map.py
│ ├── joints_helper.py
│ ├── observation.py
│ └── viewer.py
│ ├── quaternions.py
│ ├── record.py
│ ├── torch.py
│ └── viewer.py
├── pyproject.toml
├── setup.py
└── tests
├── algorithms
├── helper
│ └── utils.py
├── test_a2c.py
├── test_black_box.py
├── test_ddpg.py
├── test_dpg.py
├── test_dqn.py
├── test_fqi.py
├── test_lspi.py
├── test_policy_gradient.py
├── test_sac.py
├── test_stochastic_ac.py
├── test_td.py
└── test_trust_region.py
├── approximators
├── test_cmac_approximator.py
├── test_linear_approximator.py
└── test_torch_approximator.py
├── core
├── test_array_backend.py
├── test_core.py
├── test_dataset.py
├── test_extra_info.py
├── test_logger.py
├── test_serialization.py
└── test_vectorized_core.py
├── distributions
├── test_distribution_interface.py
└── test_gaussian_distribution.py
├── environments
├── grid.txt
├── isaacsim_envs
│ └── test_envs.py
├── mujoco_envs
│ ├── air_hockey_defend_data.npy
│ ├── air_hockey_hit_data.npy
│ ├── air_hockey_prepare_data.npy
│ ├── air_hockey_repel_data.npy
│ ├── test_air_hockey.py
│ ├── test_ball_in_a_cup.py
│ └── test_locomotion.py
├── pybullet_envs
│ ├── air_hockey_defend_data.npy
│ ├── air_hockey_hit_data.npy
│ ├── air_hockey_prepare_data.npy
│ ├── air_hockey_repel_data.npy
│ └── test_air_hockey_bullet.py
├── taxi.txt
├── test_all_envs.py
├── test_atari_1.npy
└── test_mujoco.py
├── features
└── test_features.py
├── policy
├── test_deterministic_policy.py
├── test_gaussian_policy.py
├── test_noise_policy.py
├── test_policy_interface.py
├── test_td_policy.py
└── test_torch_policy.py
├── rl_utils
└── test_value_functions.py
├── solvers
├── test_car_on_hill.py
├── test_dynamic_programming.py
└── test_lqr.py
├── test_imports.py
└── utils
├── test_callbacks.py
├── test_episodes.py
└── test_preprocessors.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit =
3 | mushroom_rl/environments/mujoco.py
4 | mushroom_rl/environments/mujoco_envs/*
5 |
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Provide a snippet of code, or a Python file.
15 |
16 | **Expected behavior**
17 | A clear and concise description of what you expected to happen.
18 |
19 | **System information (please complete the following information):**
20 | - OS: [e.g. Ubuntu 18.04]
21 | - Python version [e.g. Python3.6]
22 | - Torch version [e.g. Pytorch 1.3]
23 | - Mushroom version [e.g. 1.2.0, master]
24 |
25 | **Additional context**
26 | Add any other context about the problem here.
27 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is.
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/workflows/continuous_integration.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run lint and tests with a single version of Python, and upload the coverage results to Code Climate
2 |
3 | name: Continuous Integration
4 |
5 | on:
6 | push:
7 | branches: [ dev, dev-v1, master ]
8 |
9 | jobs:
10 | build:
11 | if: github.repository == 'MushroomRL/mushroom-rl'
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v3
16 | - name: Set up Python 3.8
17 | uses: actions/setup-python@v3
18 | with:
19 | python-version: 3.8
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install flake8 pytest pytest-cov
24 | pip install -e .[all]
25 | - name: Install Atari ROMs
26 | run: |
27 | pip install "autorom[accept-rom-license]"
28 | - name: Lint with flake8
29 | run: |
30 | # stop the build if there are Python syntax errors or undefined names
31 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
32 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
33 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
34 | - name: Test with pytest
35 | run: |
36 | pytest --cov=mushroom_rl --cov-report=xml
37 | - name: Publish code coverage to CodeClimate
38 | uses: paambaati/codeclimate-action@v2.7.5
39 | env:
40 | CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }}
41 |
--------------------------------------------------------------------------------
/.github/workflows/test_pr.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, then run lint and tests with a single version of Python
2 |
3 | name: Test Pull Request
4 |
5 | on:
6 | pull_request:
7 | branches: [ dev ]
8 |
9 |
10 | jobs:
11 | build:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python 3.8
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: 3.8
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install flake8 pytest pytest-cov
25 | pip install -e .[all]
26 | - name: Install Atari ROMs
27 | run: |
28 | pip install "autorom[accept-rom-license]"
29 | - name: Lint with flake8
30 | run: |
31 | # stop the build if there are Python syntax errors or undefined names
32 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
33 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
34 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
35 | - name: Test with pytest
36 | run: |
37 | pytest
38 |
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_store
2 | build/
3 | dist/
4 | examples/mushroom_rl_recordings/
5 | examples/habitat/Replica-Dataset
6 | examples/habitat/data
7 | mushroom_rl.egg-info/
8 | mushroom_rl_recordings/
9 | .idea/
10 | *.pyc
11 | *.pyd
12 | *.xml
13 | logs/
14 | *.h5
15 | .pytest_cache
16 | .coverage*
17 | *.c
18 | *.so
19 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | build:
9 | os: ubuntu-20.04
10 | tools:
11 | python: "3.8"
12 |
13 | sphinx:
14 | configuration: docs/conf.py
15 |
16 | python:
17 | install:
18 | - requirements: docs/requirements.txt
19 | - method: pip
20 | path: .
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2018, 2019, 2020 Carlo D'Eramo, Davide Tateo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include setup.py
3 | prune tests
4 | prune examples
5 | prune .github
6 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | package:
2 | python3 -m build
3 |
4 | install:
5 | pip install $(shell ls dist/*.tar.gz)
6 |
7 | all: clean package install
8 |
9 | upload:
10 | python3 -m twine upload dist/*
11 |
12 | clean:
13 | rm -rf dist
14 | rm -rf build
15 |
16 | .NOTPARALLEL:
17 |
--------------------------------------------------------------------------------
/TODO.txt:
--------------------------------------------------------------------------------
1 | Environments:
2 | * implement Multirobot PyBullet Interface to solve issues with joint name clash
3 |
4 | Algorithms:
5 | * Conservative Q-Learning
6 | * Policy Search:
7 | - Natural gradient
8 | - NES
9 | - PAPI
10 |
11 | Policy:
12 | * Add Boltzmann from logits for traditional policy gradient methods
13 |
14 | Approximator:
15 | * support for LSTM
16 | * Generalize LazyFrame to LazyState
17 | * add neural network generator
18 |
19 | For Mushroom 2.0:
20 | * Simplify Regressor interface: drop GenericRegressor, remove facade pattern
21 | * vectorize basis functions and simplify interface, simplify facade pattern
22 | * remove custom save for plotting, use Serializable
23 | * support multi-objective RL
24 | * support model-based RL
25 | * Improve replay memory, allowing arbitrary information to be stored in the replay buffer
26 |
27 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _build/
2 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = MushroomRL
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/docs/__init__.py
--------------------------------------------------------------------------------
/docs/_static/theme_overrides.css:
--------------------------------------------------------------------------------
1 | .wy-nav-content {
2 | max-width: 1000px !important;
3 | }
--------------------------------------------------------------------------------
/docs/_templates/tabular.tex_t:
--------------------------------------------------------------------------------
1 | \begin{savenotes}\sphinxattablestart
2 | \small
3 | <% if table.align -%>
4 | <%- if table.align in ('center', 'default') -%>
5 | \centering
6 | <%- elif table.align == 'left' -%>
7 | \raggedright
8 | <%- else -%>
9 | \raggedleft
10 | <%- endif %>
11 | <%- else -%>
12 | \centering
13 | <%- endif %>
14 | <% if table.caption -%>
15 | \sphinxcapstartof{table}
16 | \sphinxthecaptionisattop
17 | \sphinxcaption{<%= ''.join(table.caption) %>}<%= labels %>
18 | \sphinxaftertopcaption
19 | <% elif labels -%>
20 | \phantomsection<%= labels %>\nobreak
21 | <% endif -%>
22 | \begin{tabular}[t]<%= table.get_colspec() -%>
23 | \hline
24 | <%= ''.join(table.header) %>
25 | <%=- ''.join(table.body) %>
26 | \end{tabular}
27 | \par
28 | \sphinxattableend\end{savenotes}
29 |
--------------------------------------------------------------------------------
/docs/_templates/tabulary.tex_t:
--------------------------------------------------------------------------------
1 | \begin{savenotes}\sphinxattablestart
2 | \footnotesize
3 | <% if table.align -%>
4 | <%- if table.align in ('center', 'default') -%>
5 | \centering
6 | <%- elif table.align == 'left' -%>
7 | \raggedright
8 | <%- else -%>
9 | \raggedleft
10 | <%- endif %>
11 | <%- else -%>
12 | \centering
13 | <%- endif %>
14 | <% if table.caption -%>
15 | \sphinxcapstartof{table}
16 | \sphinxthecaptionisattop
17 | \sphinxcaption{<%= ''.join(table.caption) %>}<%= labels %>
18 | \sphinxaftertopcaption
19 | <% elif labels -%>
20 | \phantomsection<%= labels %>\nobreak
21 | <% endif -%>
22 | \begin{tabulary}{\linewidth}[t]<%= table.get_colspec() -%>
23 | \hline
24 | <%= ''.join(table.header) %>
25 | <%=- ''.join(table.body) %>
26 | \end{tabulary}
27 | \par
28 | \sphinxattableend\end{savenotes}
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | numpy
3 | scipy
4 | gym
5 | scikit-learn
6 | matplotlib
7 | joblib
8 | tqdm
9 | pygame
10 | sphinx-rtd-theme
11 |
--------------------------------------------------------------------------------
/docs/source/features/mushroom_rl.features.basis.rst:
--------------------------------------------------------------------------------
1 | Basis
2 | =====
3 |
4 | Fourier
5 | -------
6 |
7 | .. automodule:: mushroom_rl.features.basis.fourier
8 | :members:
9 | :private-members:
10 | :inherited-members:
11 | :show-inheritance:
12 |
13 | Gaussian RBF
14 | ------------
15 |
16 | .. automodule:: mushroom_rl.features.basis.gaussian_rbf
17 | :members:
18 | :private-members:
19 | :inherited-members:
20 | :show-inheritance:
21 |
22 | Polynomial
23 | ----------
24 |
25 | .. automodule:: mushroom_rl.features.basis.polynomial
26 | :members:
27 | :private-members:
28 | :inherited-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/docs/source/features/mushroom_rl.features.tensors.rst:
--------------------------------------------------------------------------------
1 | Tensors
2 | =======
3 |
4 | .. automodule:: mushroom_rl.features.tensors.constant_tensor
5 | :members:
6 | :private-members:
7 | :show-inheritance:
8 |
9 | .. automodule:: mushroom_rl.features.tensors.basis_tensor
10 | :members:
11 | :private-members:
12 | :show-inheritance:
13 |
14 | .. automodule:: mushroom_rl.features.tensors.random_fourier_tensor
15 | :members:
16 | :private-members:
17 | :show-inheritance:
18 |
--------------------------------------------------------------------------------
/docs/source/features/mushroom_rl.features.tiles.rst:
--------------------------------------------------------------------------------
1 | Tiles
2 | =====
3 |
4 | Rectangular Tiles
5 | -----------------
6 |
7 | .. automodule:: mushroom_rl.features.tiles.tiles
8 | :members:
9 | :private-members:
10 | :inherited-members:
11 | :show-inheritance:
12 |
13 | Voronoi Tiles
14 | -------------
15 |
16 | .. automodule:: mushroom_rl.features.tiles.voronoi
17 | :members:
18 | :private-members:
19 | :inherited-members:
20 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/mushroom_rl.agent_environment_interface.rst:
--------------------------------------------------------------------------------
1 | Agent-Environment Interface
2 | ===========================
3 |
4 | The three basic interfaces of MushroomRL are the Agent, the Environment, and the Core.
5 |
6 | - The ``Agent`` is the basic interface for any Reinforcement Learning algorithm.
7 | - The ``Environment`` is the basic interface for every problem/task that the agent should solve.
8 | - The ``Core`` is a class used to control the interaction between an agent and an environment.
9 |
10 | To support serialization of MushroomRL objects to disk (load/save functionality), we also provide the ``Serializable``
11 | interface. Finally, logging functionality is provided by the ``Logger`` class.
12 |
13 |
14 | Agent
15 | -----
16 |
17 | MushroomRL provides implementations of several algorithms belonging to all
18 | categories of RL:
19 |
20 | - value-based;
21 | - policy-search;
22 | - actor-critic.
23 |
24 | One can easily implement customized algorithms following the structure of the
25 | already available ones, by extending the following interface:
26 |
27 | .. automodule:: mushroom_rl.core.agent
28 | :members:
29 | :private-members:
30 | :inherited-members:
31 | :show-inheritance:
32 |
33 | Environment
34 | -----------
35 |
36 | MushroomRL provides several implementations of well-known benchmarks with both
37 | continuous and discrete action spaces.
38 |
39 | To implement a new environment, it is mandatory to use the following interface:
40 |
41 | .. automodule:: mushroom_rl.core.environment
42 | :members:
43 | :private-members:
44 | :inherited-members:
45 | :show-inheritance:
46 |
47 |
48 | Core
49 | ----
50 |
51 | .. automodule:: mushroom_rl.core.core
52 | :members:
53 | :private-members:
54 | :inherited-members:
55 | :show-inheritance:
56 |
57 | Serialization
58 | -------------
59 |
60 | .. automodule:: mushroom_rl.core.serialization
61 | :members:
62 | :private-members:
63 | :inherited-members:
64 | :show-inheritance:
65 |
66 | Logger
67 | ------
68 |
69 | .. automodule:: mushroom_rl.core.logger
70 | :members:
71 | :private-members:
72 | :inherited-members:
73 | :show-inheritance:
74 |
--------------------------------------------------------------------------------
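
To see how the Agent, Environment and Core interfaces fit together in practice, here is a minimal training sketch modelled on the tutorial scripts included in this repository (the CarOnHill environment, the FQI agent and the hyperparameters follow docs/source/tutorials/code/simple_experiment.py and are illustrative rather than prescriptive):

    import numpy as np
    from sklearn.ensemble import ExtraTreesRegressor

    from mushroom_rl.algorithms.value import FQI
    from mushroom_rl.core import Core
    from mushroom_rl.environments import CarOnHill
    from mushroom_rl.policy import EpsGreedy
    from mushroom_rl.rl_utils.parameters import Parameter

    # Environment: any class implementing the Environment interface
    mdp = CarOnHill()

    # Agent: an algorithm built from an exploration policy and an approximator
    pi = EpsGreedy(epsilon=Parameter(value=1.))
    approximator_params = dict(input_shape=mdp.info.observation_space.shape,
                               n_actions=mdp.info.action_space.n)
    agent = FQI(mdp.info, pi, ExtraTreesRegressor, n_iterations=20,
                approximator_params=approximator_params)

    # Core: drives the agent-environment interaction for learning and evaluation
    core = Core(agent, mdp)
    core.learn(n_episodes=1000, n_episodes_per_fit=1000)
    pi.set_epsilon(Parameter(0.))
    dataset = core.evaluate(n_episodes=5)
    print(np.mean(dataset.discounted_return))
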
/docs/source/mushroom_rl.algorithms.actor_critic.rst:
--------------------------------------------------------------------------------
1 | Actor-Critic
2 | ============
3 |
4 | Classical Actor-Critic Methods
5 | ------------------------------
6 |
7 | .. automodule:: mushroom_rl.algorithms.actor_critic.classic_actor_critic
8 | :members:
9 | :private-members:
10 | :inherited-members:
11 | :show-inheritance:
12 |
13 | Deep Actor-Critic Methods
14 | -------------------------
15 |
16 | .. automodule:: mushroom_rl.algorithms.actor_critic.deep_actor_critic
17 | :members:
18 | :private-members:
19 | :inherited-members:
20 | :show-inheritance:
21 |
--------------------------------------------------------------------------------
/docs/source/mushroom_rl.algorithms.policy_search.rst:
--------------------------------------------------------------------------------
1 | Policy search
2 | =============
3 |
4 | Policy gradient
5 | ---------------
6 |
7 | .. automodule:: mushroom_rl.algorithms.policy_search.policy_gradient
8 | :members:
9 | :private-members:
10 | :inherited-members:
11 | :show-inheritance:
12 |
13 | Black-Box optimization
14 | ----------------------
15 |
16 | .. automodule:: mushroom_rl.algorithms.policy_search.black_box_optimization
17 | :members:
18 | :private-members:
19 | :inherited-members:
20 | :show-inheritance:
21 |
22 |
--------------------------------------------------------------------------------
/docs/source/mushroom_rl.algorithms.value.rst:
--------------------------------------------------------------------------------
1 | Value-Based
2 | ===========
3 |
4 | TD
5 | --
6 |
7 | .. automodule:: mushroom_rl.algorithms.value.td
8 | :members:
9 | :private-members:
10 | :inherited-members:
11 | :show-inheritance:
12 |
13 | Batch TD
14 | --------
15 |
16 | .. automodule:: mushroom_rl.algorithms.value.batch_td
17 | :members:
18 | :private-members:
19 | :inherited-members:
20 | :show-inheritance:
21 |
22 | DQN
23 | ---
24 |
25 | .. automodule:: mushroom_rl.algorithms.value.dqn
26 | :members:
27 | :private-members:
28 | :inherited-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/docs/source/mushroom_rl.approximators.rst:
--------------------------------------------------------------------------------
1 | Approximators
2 | =============
3 |
4 | MushroomRL exposes the high-level class ``Regressor``, which can manage any type of
5 | function regressor. This class is a wrapper for any kind of function
6 | approximator, e.g. a scikit-learn model or a PyTorch neural network.
7 |
8 | Regressor
9 | ---------
10 |
11 | .. automodule:: mushroom_rl.approximators.regressor
12 | :members:
13 | :private-members:
14 | :inherited-members:
15 | :show-inheritance:
16 |
17 |
18 | Approximators
19 | -------------
20 |
21 | Tabular
22 | ~~~~~~~
23 |
24 | .. automodule:: mushroom_rl.approximators.table
25 | :members:
26 | :private-members:
27 | :inherited-members:
28 | :show-inheritance:
29 |
30 |
31 |
32 | Linear
33 | ~~~~~~
34 |
35 | .. automodule:: mushroom_rl.approximators.parametric.linear
36 | :members:
37 | :private-members:
38 | :inherited-members:
39 | :show-inheritance:
40 |
41 | CMAC
42 | ~~~~
43 |
44 | .. automodule:: mushroom_rl.approximators.parametric.cmac
45 | :members:
46 | :private-members:
47 | :inherited-members:
48 | :show-inheritance:
49 |
50 | Torch Approximator
51 | ~~~~~~~~~~~~~~~~~~
52 |
53 | .. automodule:: mushroom_rl.approximators.parametric.torch_approximator
54 | :members:
55 | :private-members:
56 | :inherited-members:
57 | :show-inheritance:
58 |
--------------------------------------------------------------------------------
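
As a concrete reference, the short sketch below mirrors docs/source/tutorials/code/generic_regressor.py (reproduced later in this dump): a ``LinearApproximator`` is wrapped in a ``Regressor`` and fit on hand-built features; the data and shapes are purely illustrative.

    import numpy as np

    from mushroom_rl.approximators import Regressor
    from mushroom_rl.approximators.parametric import LinearApproximator

    # Noisy samples of y = 2 * x + 10
    x = np.arange(10).reshape(-1, 1)
    y = 2 * x + 10 + np.random.randn(10, 1)

    # Hand-built features: a constant (bias) column plus the raw input
    phi = np.concatenate((np.ones((10, 1)), x), axis=1)

    # The Regressor wraps the approximator class and forwards fit/predict calls
    regressor = Regressor(LinearApproximator, input_shape=(2,), output_shape=(1,))
    regressor.fit(phi, y)

    print(regressor.get_weights())     # roughly [intercept, slope], i.e. [10, 2]
    print(regressor.predict(phi[:3]))  # predictions for the first three inputs
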
/docs/source/mushroom_rl.distributions.rst:
--------------------------------------------------------------------------------
1 | Distributions
2 | =============
3 |
4 | .. automodule:: mushroom_rl.distributions.distribution
5 | :members:
6 | :private-members:
7 | :inherited-members:
8 | :show-inheritance:
9 |
10 | Gaussian
11 | --------
12 |
13 | .. automodule:: mushroom_rl.distributions.gaussian
14 | :members:
15 | :private-members:
16 | :inherited-members:
17 | :show-inheritance:
18 |
19 |
20 |
--------------------------------------------------------------------------------
/docs/source/mushroom_rl.features.rst:
--------------------------------------------------------------------------------
1 | Features
2 | ========
3 |
4 | The features in MushroomRL are 1-D arrays computed by applying a specified function
5 | to a raw input, e.g. polynomial features of the state of an MDP.
6 | MushroomRL supports three types of features:
7 |
8 | * basis functions;
9 | * tensor basis functions;
10 | * tiles.
11 |
12 | The tensor basis functions are a PyTorch implementation of the standard
13 | basis functions. They are less straightforward than the standard ones, but they
14 | are faster to compute, as they can exploit parallel computing, e.g. GPU acceleration
15 | and multi-core systems.
16 |
17 | All the types of features are exposed by a single factory method ``Features``
18 | that builds the one requested by the user.
19 |
20 | .. automodule:: mushroom_rl.features.features
21 | :members:
22 | :private-members:
23 | :inherited-members:
24 | :show-inheritance:
25 |
26 | The factory method returns an instance of a class extending the abstract class
27 | ``FeatureImplementation``.
28 |
29 | .. automodule:: mushroom_rl.features._implementations.features_implementation
30 | :members:
31 | :private-members:
32 | :inherited-members:
33 | :show-inheritance:
34 |
35 | The documentation for every feature type can be found here:
36 |
37 | .. toctree::
38 |
39 | features/mushroom_rl.features.basis
40 | features/mushroom_rl.features.tensors
41 | features/mushroom_rl.features.tiles
42 |
--------------------------------------------------------------------------------
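
Both construction paths of the ``Features`` factory appear in the scripts shipped with this repository; the snippet below condenses them into one sketch (the basis centres, input bounds and tiling counts are illustrative, following examples/cartpole_lspi.py and docs/source/tutorials/code/advanced_experiment.py):

    import numpy as np

    from mushroom_rl.features import Features
    from mushroom_rl.features.basis import GaussianRBF, PolynomialBasis
    from mushroom_rl.features.tiles import Tiles

    # Basis-function features: a constant term plus a few Gaussian RBFs
    basis = [PolynomialBasis()]
    for mu in [-1., 0., 1.]:
        basis.append(GaussianRBF(np.array([mu, 0.]), np.array([1.])))
    phi_basis = Features(basis_list=basis)

    # Tile-coding features: 10 tilings of a 10x10 grid over the input bounds
    low = np.array([-1., -1.])
    high = np.array([1., 1.])
    tilings = Tiles.generate(10, [10, 10], low, high)
    phi_tiles = Features(tilings=tilings)

    print(phi_basis.size, phi_tiles.size)
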
/docs/source/mushroom_rl.policy.rst:
--------------------------------------------------------------------------------
1 | Policy
2 | ======
3 |
4 | .. automodule:: mushroom_rl.policy.policy
5 | :members:
6 | :private-members:
7 | :inherited-members:
8 | :show-inheritance:
9 |
10 | Deterministic policy
11 | --------------------
12 |
13 | .. automodule:: mushroom_rl.policy.deterministic_policy
14 | :members:
15 | :private-members:
16 | :inherited-members:
17 | :show-inheritance:
18 |
19 | Gaussian policy
20 | ---------------
21 |
22 | .. automodule:: mushroom_rl.policy.gaussian_policy
23 | :members:
24 | :private-members:
25 | :inherited-members:
26 | :show-inheritance:
27 |
28 | Noise policy
29 | ------------
30 |
31 | .. automodule:: mushroom_rl.policy.noise_policy
32 | :members:
33 | :private-members:
34 | :inherited-members:
35 | :show-inheritance:
36 |
37 | TD policy
38 | ---------
39 |
40 | .. automodule:: mushroom_rl.policy.td_policy
41 | :members:
42 | :private-members:
43 | :inherited-members:
44 | :show-inheritance:
45 |
46 | Torch policy
47 | ------------
48 |
49 | .. automodule:: mushroom_rl.policy.torch_policy
50 | :members:
51 | :private-members:
52 | :inherited-members:
53 | :show-inheritance:
54 |
--------------------------------------------------------------------------------
/docs/source/mushroom_rl.rl_utils.rst:
--------------------------------------------------------------------------------
1 | Reinforcement Learning utils
2 | ============================
3 |
4 | Eligibility trace
5 | -----------------
6 |
7 | .. automodule:: mushroom_rl.rl_utils.eligibility_trace
8 | :members:
9 | :private-members:
10 | :inherited-members:
11 | :undoc-members:
12 | :show-inheritance:
13 |
14 | Optimizers
15 | ----------
16 |
17 | .. automodule:: mushroom_rl.rl_utils.optimizers
18 | :members:
19 | :private-members:
20 | :inherited-members:
21 | :undoc-members:
22 | :show-inheritance:
23 |
24 | Parameters
25 | ----------
26 |
27 | .. automodule:: mushroom_rl.rl_utils.parameters
28 | :members:
29 | :private-members:
30 | :inherited-members:
31 | :show-inheritance:
32 |
33 |
34 | Preprocessors
35 | -------------
36 |
37 | .. automodule:: mushroom_rl.rl_utils.preprocessors
38 | :members:
39 | :private-members:
40 | :inherited-members:
41 | :show-inheritance:
42 |
43 | Replay memory
44 | -------------
45 |
46 | .. automodule:: mushroom_rl.rl_utils.replay_memory
47 | :members:
48 | :private-members:
49 | :inherited-members:
50 | :show-inheritance:
51 |
52 | Running Statistics
53 | ------------------
54 |
55 | .. automodule:: mushroom_rl.rl_utils.running_stats
56 | :members:
57 | :private-members:
58 | :inherited-members:
59 | :show-inheritance:
60 |
61 | Spaces
62 | ------
63 |
64 | .. automodule:: mushroom_rl.rl_utils.spaces
65 | :members:
66 | :show-inheritance:
67 |
68 |
69 | Value Functions
70 | ---------------
71 |
72 | .. automodule:: mushroom_rl.rl_utils.value_functions
73 | :members:
74 | :private-members:
75 | :inherited-members:
76 | :show-inheritance:
77 |
78 | Variance parameters
79 | -------------------
80 |
81 | .. automodule:: mushroom_rl.rl_utils.variance_parameters
82 | :members:
83 | :private-members:
84 | :inherited-members:
85 | :show-inheritance:
86 |
--------------------------------------------------------------------------------
/docs/source/mushroom_rl.solvers.rst:
--------------------------------------------------------------------------------
1 | Solvers
2 | =======
3 |
4 | Dynamic programming
5 | -------------------
6 |
7 | .. automodule:: mushroom_rl.solvers.dynamic_programming
8 | :members:
9 | :private-members:
10 | :inherited-members:
11 | :show-inheritance:
12 |
13 | Car-On-Hill brute-force solver
14 | ------------------------------
15 |
16 | .. automodule:: mushroom_rl.solvers.car_on_hill
17 | :members:
18 | :private-members:
19 | :inherited-members:
20 | :show-inheritance:
21 |
22 |
23 | LQR solver
24 | ----------
25 |
26 | .. automodule:: mushroom_rl.solvers.lqr
27 | :members:
28 | :private-members:
29 | :inherited-members:
30 | :show-inheritance:
31 |
--------------------------------------------------------------------------------
/docs/source/mushroom_rl.utils.rst:
--------------------------------------------------------------------------------
1 | Utils
2 | =====
3 |
4 | Angles
5 | ------
6 |
7 | .. automodule:: mushroom_rl.utils.angles
8 | :members:
9 | :show-inheritance:
10 |
11 | Features
12 | --------
13 |
14 | .. automodule:: mushroom_rl.utils.features
15 | :members:
16 | :private-members:
17 | :inherited-members:
18 | :show-inheritance:
19 |
20 |
21 | Frames
22 | ------
23 |
24 | .. automodule:: mushroom_rl.utils.frames
25 | :members:
26 | :private-members:
27 | :inherited-members:
28 | :show-inheritance:
29 |
30 |
31 | Minibatches
32 | -----------
33 |
34 | .. automodule:: mushroom_rl.utils.minibatches
35 | :members:
36 | :private-members:
37 | :inherited-members:
38 | :undoc-members:
39 | :show-inheritance:
40 |
41 | Numerical gradient
42 | ------------------
43 |
44 | .. automodule:: mushroom_rl.utils.numerical_gradient
45 | :members:
46 | :private-members:
47 | :inherited-members:
48 | :show-inheritance:
49 |
50 |
51 | Plots
52 | -----
53 |
54 | .. automodule:: mushroom_rl.utils.plot
55 | :members:
56 | :private-members:
57 | :inherited-members:
58 | :show-inheritance:
59 |
60 |
61 | Record
62 | ------
63 |
64 | .. automodule:: mushroom_rl.utils.record
65 | :members:
66 | :private-members:
67 | :inherited-members:
68 | :show-inheritance:
69 |
70 |
71 | Torch
72 | -----
73 |
74 | .. automodule:: mushroom_rl.utils.torch
75 | :members:
76 | :private-members:
77 | :inherited-members:
78 | :show-inheritance:
79 |
80 |
81 | Viewer
82 | ------
83 |
84 | .. automodule:: mushroom_rl.utils.viewer
85 | :members:
86 | :private-members:
87 | :inherited-members:
88 | :show-inheritance:
89 |
--------------------------------------------------------------------------------
/docs/source/tutorials/code/advanced_experiment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value import SARSALambdaContinuous
4 | from mushroom_rl.approximators.parametric import LinearApproximator
5 | from mushroom_rl.core import Core
6 | from mushroom_rl.features import Features
7 | from mushroom_rl.features.tiles import Tiles
8 | from mushroom_rl.policy import EpsGreedy
9 | from mushroom_rl.utils.callbacks import CollectDataset
10 | from mushroom_rl.rl_utils.parameters import Parameter
11 | from mushroom_rl.environments import Gymnasium
12 |
13 | # MDP
14 | mdp = Gymnasium(name='MountainCar-v0', horizon=np.inf, gamma=1.)
15 |
16 | # Policy
17 | epsilon = Parameter(value=0.)
18 | pi = EpsGreedy(epsilon=epsilon)
19 |
20 | # Q-function approximator
21 | n_tilings = 10
22 | tilings = Tiles.generate(n_tilings, [10, 10],
23 | mdp.info.observation_space.low,
24 | mdp.info.observation_space.high)
25 | features = Features(tilings=tilings)
26 |
27 | approximator_params = dict(input_shape=(features.size,),
28 | output_shape=(mdp.info.action_space.n,),
29 | n_actions=mdp.info.action_space.n)
30 |
31 | # Agent
32 | learning_rate = Parameter(.1 / n_tilings)
33 |
34 | agent = SARSALambdaContinuous(mdp.info, pi, LinearApproximator,
35 | approximator_params=approximator_params,
36 | learning_rate=learning_rate,
37 | lambda_coeff=.9, features=features)
38 |
39 | # Algorithm
40 | collect_dataset = CollectDataset()
41 | callbacks = [collect_dataset]
42 | core = Core(agent, mdp, callbacks_fit=callbacks)
43 |
44 | # Train
45 | core.learn(n_episodes=100, n_steps_per_fit=1)
46 |
47 | # Evaluate
48 | core.evaluate(n_episodes=1, render=True)
49 |
--------------------------------------------------------------------------------
/docs/source/tutorials/code/approximator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value import SARSALambdaContinuous
4 | from mushroom_rl.approximators.parametric import LinearApproximator
5 | from mushroom_rl.core import Core
6 | from mushroom_rl.environments import *
7 | from mushroom_rl.features import Features
8 | from mushroom_rl.features.tiles import Tiles
9 | from mushroom_rl.policy import EpsGreedy
10 | from mushroom_rl.utils.callbacks import CollectDataset
11 | from mushroom_rl.rl_utils.parameters import Parameter
12 |
13 |
14 | # MDP
15 | mdp = Gymnasium(name='MountainCar-v0', horizon=np.inf, gamma=1.)
16 |
17 | # Policy
18 | epsilon = Parameter(value=0.)
19 | pi = EpsGreedy(epsilon=epsilon)
20 |
21 | # Q-function approximator
22 | n_tilings = 10
23 | tilings = Tiles.generate(n_tilings, [10, 10],
24 | mdp.info.observation_space.low,
25 | mdp.info.observation_space.high)
26 | features = Features(tilings=tilings)
27 |
28 | # Agent
29 | learning_rate = Parameter(.1 / n_tilings)
30 | approximator_params = dict(input_shape=(features.size,),
31 | output_shape=(mdp.info.action_space.n,),
32 | n_actions=mdp.info.action_space.n)
33 | agent = SARSALambdaContinuous(mdp.info, pi, LinearApproximator,
34 | approximator_params=approximator_params,
35 | learning_rate=learning_rate,
36 | lambda_coeff=.9, features=features)
37 |
38 | # Algorithm
39 | collect_dataset = CollectDataset()
40 | callbacks = [collect_dataset]
41 | core = Core(agent, mdp, callbacks_fit=callbacks)
42 |
43 | # Train
44 | core.learn(n_episodes=100, n_steps_per_fit=1)
45 |
46 | # Evaluate
47 | core.evaluate(n_episodes=1, render=True)
48 |
--------------------------------------------------------------------------------
/docs/source/tutorials/code/generic_regressor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 |
4 | from mushroom_rl.approximators import Regressor
5 | from mushroom_rl.approximators.parametric import LinearApproximator
6 |
7 |
8 | x = np.arange(10).reshape(-1, 1)
9 |
10 | intercept = 10
11 | noise = np.random.randn(10, 1) * 1
12 | y = 2 * x + intercept + noise
13 |
14 | phi = np.concatenate((np.ones(10).reshape(-1, 1), x), axis=1)
15 |
16 | regressor = Regressor(LinearApproximator,
17 | input_shape=(2,),
18 | output_shape=(1,))
19 |
20 | regressor.fit(phi, y)
21 |
22 | print('Weights: ' + str(regressor.get_weights()))
23 | print('Gradient: ' + str(regressor.diff(np.array([[5.]]))))
24 |
25 | plt.scatter(x, y)
26 | plt.plot(x, regressor.predict(phi))
27 | plt.show()
28 |
--------------------------------------------------------------------------------
/docs/source/tutorials/code/simple_experiment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.ensemble import ExtraTreesRegressor
3 |
4 | from mushroom_rl.algorithms.value import FQI
5 | from mushroom_rl.core import Core
6 | from mushroom_rl.environments import CarOnHill
7 | from mushroom_rl.policy import EpsGreedy
8 | from mushroom_rl.rl_utils.parameters import Parameter
9 |
10 | mdp = CarOnHill()
11 |
12 | # Policy
13 | epsilon = Parameter(value=1.)
14 | pi = EpsGreedy(epsilon=epsilon)
15 |
16 | # Approximator
17 | approximator_params = dict(input_shape=mdp.info.observation_space.shape,
18 | n_actions=mdp.info.action_space.n,
19 | n_estimators=50,
20 | min_samples_split=5,
21 | min_samples_leaf=2)
22 | approximator = ExtraTreesRegressor
23 |
24 | # Agent
25 | agent = FQI(mdp.info, pi, approximator, n_iterations=20,
26 | approximator_params=approximator_params)
27 |
28 | core = Core(agent, mdp)
29 |
30 | core.learn(n_episodes=1000, n_episodes_per_fit=1000)
31 |
32 | pi.set_epsilon(Parameter(0.))
33 | initial_state = np.array([[-.5, 0.]])
34 | dataset = core.evaluate(initial_states=initial_state)
35 |
36 | print(dataset.discounted_return)
37 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/__init__.py
--------------------------------------------------------------------------------
/examples/car_on_hill_fqi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from joblib import Parallel, delayed
3 | from sklearn.ensemble import ExtraTreesRegressor
4 |
5 | from mushroom_rl.algorithms.value import FQI
6 | from mushroom_rl.core import Core, Logger
7 | from mushroom_rl.environments import *
8 | from mushroom_rl.policy import EpsGreedy
9 | from mushroom_rl.rl_utils.parameters import Parameter
10 |
11 | """
12 | This script aims to replicate the experiments on the Car on Hill MDP as
13 | presented in:
14 | "Tree-Based Batch Mode Reinforcement Learning", Ernst D. et al.. 2005.
15 |
16 | """
17 |
18 |
19 | def experiment():
20 | np.random.seed()
21 |
22 | # MDP
23 | mdp = CarOnHill()
24 |
25 | # Policy
26 | epsilon = Parameter(value=1.)
27 | pi = EpsGreedy(epsilon=epsilon)
28 |
29 | # Approximator
30 | approximator_params = dict(input_shape=mdp.info.observation_space.shape,
31 | n_actions=mdp.info.action_space.n,
32 | n_estimators=50,
33 | min_samples_split=5,
34 | min_samples_leaf=2)
35 | approximator = ExtraTreesRegressor
36 |
37 | # Agent
38 | algorithm_params = dict(n_iterations=20)
39 | agent = FQI(mdp.info, pi, approximator,
40 | approximator_params=approximator_params, **algorithm_params)
41 |
42 | # Algorithm
43 | core = Core(agent, mdp)
44 |
45 | # Render
46 | core.evaluate(n_episodes=1, render=True)
47 |
48 | # Train
49 | core.learn(n_episodes=1000, n_episodes_per_fit=1000)
50 |
51 | # Test
52 | test_epsilon = Parameter(0.)
53 | agent.policy.set_epsilon(test_epsilon)
54 |
55 | initial_states = np.zeros((289, 2))
56 | cont = 0
57 | for i in range(-8, 9):
58 | for j in range(-8, 9):
59 | initial_states[cont, :] = [0.125 * i, 0.375 * j]
60 | cont += 1
61 |
62 | dataset = core.evaluate(initial_states=initial_states)
63 |
64 | # Render
65 | core.evaluate(n_episodes=3, render=True)
66 |
67 | return np.mean(dataset.discounted_return)
68 |
69 |
70 | if __name__ == '__main__':
71 | n_experiment = 1
72 |
73 | logger = Logger(FQI.__name__, results_dir=None)
74 | logger.strong_line()
75 | logger.info('Experiment Algorithm: ' + FQI.__name__)
76 |
77 | Js = Parallel(n_jobs=None)(delayed(experiment)() for _ in range(n_experiment))
78 | logger.info((np.mean(Js)))
79 |
--------------------------------------------------------------------------------
/examples/cartpole_lspi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value import LSPI
4 | from mushroom_rl.core import Core, Logger
5 | from mushroom_rl.environments import *
6 | from mushroom_rl.features import Features
7 | from mushroom_rl.features.basis import PolynomialBasis, GaussianRBF
8 | from mushroom_rl.policy import EpsGreedy
9 | from mushroom_rl.rl_utils.parameters import Parameter
10 |
11 |
12 | """
13 | This script aims to replicate the experiments on the Inverted Pendulum MDP as
14 | presented in:
15 | "Least-Squares Policy Iteration". Lagoudakis M. G. and Parr R.. 2003.
16 |
17 | """
18 |
19 |
20 | def experiment():
21 | np.random.seed()
22 |
23 | # MDP
24 | mdp = CartPole()
25 |
26 | # Policy
27 | epsilon = Parameter(value=1.)
28 | pi = EpsGreedy(epsilon=epsilon)
29 |
30 | # Agent
31 | basis = [PolynomialBasis()]
32 |
33 | s1 = np.array([-np.pi, 0, np.pi]) * .25
34 | s2 = np.array([-1, 0, 1])
35 | for i in s1:
36 | for j in s2:
37 | basis.append(GaussianRBF(np.array([i, j]), np.array([1.])))
38 | features = Features(basis_list=basis)
39 |
40 | fit_params = dict()
41 | approximator_params = dict(input_shape=(features.size,),
42 | output_shape=(mdp.info.action_space.n,),
43 | n_actions=mdp.info.action_space.n,
44 | phi=features)
45 | agent = LSPI(mdp.info, pi, approximator_params=approximator_params, fit_params=fit_params)
46 |
47 | # Algorithm
48 | core = Core(agent, mdp)
49 | core.evaluate(n_episodes=3, render=True)
50 |
51 | # Train
52 | core.learn(n_episodes=500, n_episodes_per_fit=500)
53 |
54 | # Test
55 | test_epsilon = Parameter(0.)
56 | agent.policy.set_epsilon(test_epsilon)
57 |
58 | dataset = core.evaluate(n_episodes=1, quiet=True)
59 |
60 | core.evaluate(n_steps=100, render=True)
61 |
62 | return np.mean(dataset.episodes_length)
63 |
64 |
65 | if __name__ == '__main__':
66 | n_experiment = 1
67 |
68 | logger = Logger(LSPI.__name__, results_dir=None)
69 | logger.strong_line()
70 | logger.info('Experiment Algorithm: ' + LSPI.__name__)
71 |
72 | steps = experiment()
73 | logger.info('Final episode length: %d' % steps)
74 |
--------------------------------------------------------------------------------
/examples/double_chain_q_learning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/double_chain_q_learning/__init__.py
--------------------------------------------------------------------------------
/examples/double_chain_q_learning/chain_structure/p.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/double_chain_q_learning/chain_structure/p.npy
--------------------------------------------------------------------------------
/examples/double_chain_q_learning/chain_structure/rew.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/double_chain_q_learning/chain_structure/rew.npy
--------------------------------------------------------------------------------
/examples/double_chain_q_learning/double_chain.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from joblib import Parallel, delayed
3 | from pathlib import Path
4 |
5 | from mushroom_rl.algorithms.value import QLearning, DoubleQLearning,\
6 | WeightedQLearning, SpeedyQLearning
7 | from mushroom_rl.core import Core
8 | from mushroom_rl.environments import *
9 | from mushroom_rl.policy import EpsGreedy
10 | from mushroom_rl.utils.callbacks import CollectQ
11 | from mushroom_rl.rl_utils.parameters import Parameter, DecayParameter
12 |
13 |
14 | """
15 | Simple script to solve a double chain with Q-Learning and some of its variants.
16 | The considered double chain is the one presented in:
17 | "Relative Entropy Policy Search". Peters J. et al.. 2010.
18 |
19 | """
20 |
21 |
22 | def experiment(algorithm_class, exp):
23 | np.random.seed()
24 |
25 | # MDP
26 | path = Path(__file__).resolve().parent / 'chain_structure'
27 | p = np.load(path / 'p.npy')
28 | rew = np.load(path / 'rew.npy')
29 | mdp = FiniteMDP(p, rew, gamma=.9)
30 |
31 | # Policy
32 | epsilon = Parameter(value=1.)
33 | pi = EpsGreedy(epsilon=epsilon)
34 |
35 | # Agent
36 | learning_rate = DecayParameter(value=1., exp=exp, size=mdp.info.size)
37 | algorithm_params = dict(learning_rate=learning_rate)
38 | agent = algorithm_class(mdp.info, pi, **algorithm_params)
39 |
40 | # Algorithm
41 | collect_Q = CollectQ(agent.Q)
42 | callbacks = [collect_Q]
43 | core = Core(agent, mdp, callbacks)
44 |
45 | # Train
46 | core.learn(n_steps=20000, n_steps_per_fit=1, quiet=True)
47 |
48 | Qs = collect_Q.get()
49 |
50 | return Qs
51 |
52 |
53 | if __name__ == '__main__':
54 | n_experiment = 5
55 |
56 | names = {1: '1', .51: '51', QLearning: 'Q', DoubleQLearning: 'DQ',
57 | WeightedQLearning: 'WQ', SpeedyQLearning: 'SPQ'}
58 |
59 | log_path = Path(__file__).resolve().parent / 'logs'
60 |
61 | log_path.mkdir(parents=True, exist_ok=True)
62 |
63 | for e in [1, .51]:
64 | for a in [QLearning, DoubleQLearning, WeightedQLearning,
65 | SpeedyQLearning]:
66 | out = Parallel(n_jobs=1)(
67 | delayed(experiment)(a, e) for _ in range(n_experiment))
68 | Qs = np.array([o for o in out])
69 |
70 | Qs = np.mean(Qs, 0)
71 |
72 | filename = names[a] + names[e] + '.npy'
73 | np.save(log_path / filename, Qs[:, 0, 0])
74 |
--------------------------------------------------------------------------------
/examples/habitat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/habitat/__init__.py
--------------------------------------------------------------------------------
/examples/habitat/pointnav_apartment-0.yaml:
--------------------------------------------------------------------------------
1 | ENV_NAME: "NavRLEnv"
2 |
3 | ENVIRONMENT:
4 | MAX_EPISODE_STEPS: 500
5 |
6 | SIMULATOR:
7 | HABITAT_SIM_V0:
8 | GPU_DEVICE_ID: 0
9 |
10 | RGB_SENSOR: # Used for observations
11 | WIDTH: 64
12 | HEIGHT: 64
13 | HFOV: 79
14 | POSITION: [0, 0.88, 0]
15 |
16 | ACTION_SPACE_CONFIG: "v0"
17 | FORWARD_STEP_SIZE: 0.25 # How much the agent moves with 'forward'
18 | TURN_ANGLE: 10 # How much the agent turns with 'left' / 'right' actions
19 |
20 | TASK:
21 | TYPE: Nav-v0
22 |
23 | # Set both to the same value
24 | SUCCESS_DISTANCE: 0.2
25 | SUCCESS:
26 | SUCCESS_DISTANCE: 0.2
27 |
28 | SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR']
29 | POINTGOAL_WITH_GPS_COMPASS_SENSOR:
30 | GOAL_FORMAT: "POLAR"
31 | DIMENSIONALITY: 2
32 | GOAL_SENSOR_UUID: pointgoal_with_gps_compass
33 |
34 | MEASUREMENTS: ['DISTANCE_TO_GOAL', 'SUCCESS', 'SPL']
35 |
36 | DATASET: # Replica scene
37 | CONTENT_SCENES: ['*']
38 | DATA_PATH: "replica_{split}_apartment-0.json.gz"
39 | SCENES_DIR: "Replica-Dataset/replica-path/apartment_0"
40 | TYPE: PointNav-v1
41 | SPLIT: train
42 |
--------------------------------------------------------------------------------
/examples/habitat/replica_train_apartment-0.json:
--------------------------------------------------------------------------------
1 | {
2 | "episodes": [{"episode_id": "0",
3 | "scene_id": "habitat/mesh_semantic.ply",
4 | "start_position": [-0.716670036315918, -1.374765157699585, 0.7762265205383301],
5 | "start_rotation": [0.0, 0.0, 0.0, 1.0],
6 | "goals": [{"position": [4.170074462890625, -1.374765157699585, 1.8612048625946045], "radius": null}],
7 | "shortest_paths": null,
8 | "start_room": null}]
9 | }
10 |
--------------------------------------------------------------------------------
/examples/habitat/replica_train_apartment-0.json.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/habitat/replica_train_apartment-0.json.gz
--------------------------------------------------------------------------------
/examples/simple_chain_qlearning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value import QLearning
4 | from mushroom_rl.core import Core, Logger
5 | from mushroom_rl.environments import *
6 | from mushroom_rl.policy import EpsGreedy
7 | from mushroom_rl.rl_utils.parameters import Parameter
8 |
9 |
10 | """
11 | Simple script to solve a simple chain with Q-Learning.
12 |
13 | """
14 |
15 |
16 | def experiment():
17 | np.random.seed()
18 |
19 | logger = Logger(QLearning.__name__, results_dir=None)
20 | logger.strong_line()
21 | logger.info('Experiment Algorithm: ' + QLearning.__name__)
22 |
23 | # MDP
24 | mdp = generate_simple_chain(state_n=5, goal_states=[2], prob=.8, rew=1,
25 | gamma=.9)
26 |
27 | # Policy
28 | epsilon = Parameter(value=.15)
29 | pi = EpsGreedy(epsilon=epsilon)
30 |
31 | # Agent
32 | learning_rate = Parameter(value=.2)
33 | algorithm_params = dict(learning_rate=learning_rate)
34 | agent = QLearning(mdp.info, pi, **algorithm_params)
35 |
36 | # Core
37 | core = Core(agent, mdp)
38 |
39 | # Initial policy Evaluation
40 | dataset = core.evaluate(n_steps=1000)
41 | J = np.mean(dataset.discounted_return)
42 | logger.info(f'J start: {J}')
43 |
44 | # Train
45 | core.learn(n_steps=10000, n_steps_per_fit=1)
46 |
47 | # Final Policy Evaluation
48 | dataset = core.evaluate(n_steps=1000)
49 | J = np.mean(dataset.discounted_return)
50 | logger.info(f'J final: {J}')
51 |
52 |
53 | if __name__ == '__main__':
54 | experiment()
55 |
--------------------------------------------------------------------------------
/examples/taxi_mellow_sarsa/grid.txt:
--------------------------------------------------------------------------------
1 | S#F.#.G
2 | .#..#..
3 | .......
4 | ##...##
5 | ......F
6 | F.....#
7 |
--------------------------------------------------------------------------------
/examples/taxi_mellow_sarsa/taxi_mellow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from joblib import Parallel, delayed
3 |
4 | from mushroom_rl.algorithms.value import SARSA
5 | from mushroom_rl.core import Core
6 | from mushroom_rl.environments.generators.taxi import generate_taxi
7 | from mushroom_rl.policy import Boltzmann, EpsGreedy, Mellowmax
8 | from mushroom_rl.utils.callbacks import CollectDataset
9 | from mushroom_rl.rl_utils.parameters import Parameter
10 |
11 |
12 | """
13 | This script aims to replicate the experiments on the Taxi MDP as presented in:
14 | "An Alternative Softmax Operator for Reinforcement Learning", Asadi K. et al..
15 | 2017.
16 |
17 | """
18 |
19 |
20 | def experiment(policy, value):
21 | np.random.seed()
22 |
23 | # MDP
24 | mdp = generate_taxi('grid.txt')
25 |
26 | # Policy
27 | pi = policy(Parameter(value=value))
28 |
29 | # Agent
30 | learning_rate = Parameter(value=.15)
31 | algorithm_params = dict(learning_rate=learning_rate)
32 | agent = SARSA(mdp.info, pi, **algorithm_params)
33 |
34 | # Algorithm
35 | collect_dataset = CollectDataset()
36 | callbacks = [collect_dataset]
37 | core = Core(agent, mdp, callbacks)
38 |
39 | # Train
40 | n_steps = 300000
41 | core.learn(n_steps=n_steps, n_steps_per_fit=1, quiet=True)
42 |
43 | return np.sum(np.array(collect_dataset.get())[:, 2]) / float(n_steps)
44 |
45 |
46 | if __name__ == '__main__':
47 | n_experiment = 25
48 |
49 | algs = {EpsGreedy: 'epsilon', Boltzmann: 'boltzmann', Mellowmax: 'mellow'}
50 | ranges = {EpsGreedy: np.linspace(.05, .5, 10),
51 | Boltzmann: np.linspace(.5, 10, 10),
52 | Mellowmax: np.linspace(.5, 10, 10)}
53 |
54 | for p in [EpsGreedy, Boltzmann, Mellowmax]:
55 | print('Policy: ', algs[p])
56 | Js = list()
57 | for v in ranges[p]:
58 | out = Parallel(n_jobs=-1)(
59 | delayed(experiment)(p, v) for _ in range(n_experiment))
60 | J = [np.mean(o) for o in out]
61 | Js.append(np.mean(J))
62 |
63 | np.save('r_%s.npy' % algs[p], Js)
64 |
--------------------------------------------------------------------------------
/examples/vectorized_core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/examples/vectorized_core/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '2.0.0-rc1'
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/algorithms/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/actor_critic/__init__.py:
--------------------------------------------------------------------------------
1 | from .classic_actor_critic import StochasticAC, StochasticAC_AVG, COPDAC_Q
2 | from .deep_actor_critic import DeepAC, A2C, DDPG, TD3, SAC, TRPO, PPO, PPO_BPTT, RudinPPO
3 |
4 | __all__ = ['COPDAC_Q', 'StochasticAC', 'StochasticAC_AVG',
5 | 'DeepAC', 'A2C', 'DDPG', 'TD3', 'SAC', 'TRPO', 'PPO', 'PPO_BPTT',
6 | 'RudinPPO']
7 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/actor_critic/classic_actor_critic/__init__.py:
--------------------------------------------------------------------------------
1 | from .copdac_q import COPDAC_Q
2 | from .stochastic_ac import StochasticAC, StochasticAC_AVG
3 |
4 | __all__ = ['COPDAC_Q', 'StochasticAC', 'StochasticAC_AVG']
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py:
--------------------------------------------------------------------------------
1 | from .deep_actor_critic import OnPolicyDeepAC, DeepAC
2 | from .a2c import A2C
3 | from .ddpg import DDPG
4 | from .td3 import TD3
5 | from .sac import SAC
6 | from .trpo import TRPO
7 | from .ppo import PPO
8 | from .ppo_bptt import PPO_BPTT
9 | from .ppo_rudin import RudinPPO
10 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/__init__.py:
--------------------------------------------------------------------------------
1 | from .policy_gradient import REINFORCE, GPOMDP, eNAC
2 | from .black_box_optimization import RWR, PGPE, REPS, ConstrainedREPS, MORE, ePPO
3 |
4 |
5 | __all__ = ['REINFORCE', 'GPOMDP', 'eNAC', 'RWR', 'PGPE', 'REPS', 'ConstrainedREPS', 'MORE', 'ePPO']
6 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/black_box_optimization/__init__.py:
--------------------------------------------------------------------------------
1 | from .context_builder import ContextBuilder
2 | from .black_box_optimization import BlackBoxOptimization
3 | from .rwr import RWR
4 | from .reps import REPS
5 | from .pgpe import PGPE
6 | from .constrained_reps import ConstrainedREPS
7 | from .more import MORE
8 | from .eppo import ePPO
9 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/black_box_optimization/context_builder.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.core import Serializable
2 |
3 |
4 | class ContextBuilder(Serializable):
5 | def __init__(self, context_shape=None):
6 | self._context_shape = context_shape
7 |
8 | super().__init__()
9 |
10 | self._add_save_attr(_context_shape='primitive')
11 |
12 | def __call__(self, initial_state, **episode_info):
13 | return None
14 |
15 | @property
16 | def context_shape(self):
17 | return self._context_shape
18 |
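The base builder above always returns None, i.e. no context; contextual black-box optimization is obtained by overriding __call__ in a subclass. A minimal, hypothetical sketch of such a subclass (the class name and the choice of using the raw initial state as context are illustrative, not part of the library):

    import numpy as np

    from mushroom_rl.algorithms.policy_search.black_box_optimization import ContextBuilder


    class InitialStateContextBuilder(ContextBuilder):
        """Illustrative builder that uses the episode's initial state as context."""
        def __init__(self, context_shape):
            super().__init__(context_shape=context_shape)

        def __call__(self, initial_state, **episode_info):
            # Forward the initial state unchanged as the context vector
            return np.asarray(initial_state).ravel()

An instance of such a builder can then be passed as context_builder to algorithms like PGPE below.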
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/black_box_optimization/pgpe.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.policy_search.black_box_optimization import BlackBoxOptimization
4 |
5 |
6 | class PGPE(BlackBoxOptimization):
7 | """
8 | Policy Gradients with Parameter-based Exploration (PGPE) algorithm.
9 | "A Survey on Policy Search for Robotics", Deisenroth M. P., Neumann G.,
10 | Peters J.. 2013.
11 |
12 | """
13 | def __init__(self, mdp_info, distribution, policy, optimizer, context_builder=None):
14 | """
15 | Constructor.
16 |
17 | Args:
18 | optimizer: the gradient step optimizer.
19 |
20 | """
21 | self.optimizer = optimizer
22 |
23 | super().__init__(mdp_info, distribution, policy, context_builder=context_builder)
24 |
25 | self._add_save_attr(optimizer='mushroom')
26 |
27 | def _update(self, Jep, theta, context):
28 | baseline_num_list = list()
29 | baseline_den_list = list()
30 | diff_log_dist_list = list()
31 |
32 | # Compute derivatives of distribution and baseline components
33 | for i in range(len(Jep)):
34 | J_i = Jep[i]
35 | theta_i = theta[i]
36 |
37 | diff_log_dist = self.distribution.diff_log(theta_i, context)
38 | diff_log_dist2 = diff_log_dist**2
39 |
40 | diff_log_dist_list.append(diff_log_dist)
41 | baseline_num_list.append(J_i * diff_log_dist2)
42 | baseline_den_list.append(diff_log_dist2)
43 |
44 | # Compute baseline
45 | baseline = np.mean(baseline_num_list, axis=0) / np.mean(baseline_den_list, axis=0)
46 | baseline[np.logical_not(np.isfinite(baseline))] = 0.
47 |
48 | # Compute gradient
49 | grad_J_list = list()
50 | for i in range(len(Jep)):
51 | diff_log_dist = diff_log_dist_list[i]
52 | J_i = Jep[i]
53 |
54 | grad_J_list.append(diff_log_dist * (J_i - baseline))
55 |
56 | grad_J = np.mean(grad_J_list, axis=0)
57 |
58 | omega_old = self.distribution.get_parameters()
59 | omega_new = self.optimizer(omega_old, grad_J)
60 | self.distribution.set_parameters(omega_new)
61 |
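For reference, the update above implements the likelihood-ratio gradient on the distribution parameters with a per-component, variance-reducing baseline (notation introduced here for readability; N is the number of sampled episodes and squaring/division are componentwise):

    g_i = \nabla_\omega \log p(\theta_i \mid \omega), \qquad
    b = \frac{\sum_i J_i\, g_i^{2}}{\sum_i g_i^{2}}, \qquad
    \nabla_\omega \hat{J} = \frac{1}{N} \sum_{i=1}^{N} g_i\, (J_i - b)

The optimizer then takes a gradient step from the current distribution parameters along this estimate.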
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/black_box_optimization/reps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from scipy.optimize import minimize
4 |
5 | from mushroom_rl.algorithms.policy_search.black_box_optimization import BlackBoxOptimization
6 | from mushroom_rl.rl_utils.parameters import to_parameter
7 |
8 |
9 | class REPS(BlackBoxOptimization):
10 | """
11 | Episodic Relative Entropy Policy Search algorithm.
12 | "A Survey on Policy Search for Robotics", Deisenroth M. P., Neumann G.,
13 | Peters J.. 2013.
14 |
15 | """
16 | def __init__(self, mdp_info, distribution, policy, eps):
17 | """
18 | Constructor.
19 |
20 | Args:
21 | eps ([float, Parameter]): the maximum admissible value for the Kullback-Leibler
22 | divergence between the new distribution and the
23 | previous one at each update step.
24 |
25 | """
26 | assert not distribution.is_contextual
27 |
28 | self._eps = to_parameter(eps)
29 |
30 | super().__init__(mdp_info, distribution, policy)
31 |
32 | self._add_save_attr(_eps='mushroom')
33 |
34 | def _update(self, Jep, theta, context):
35 | eta_start = np.ones(1)
36 |
37 | res = minimize(REPS._dual_function, eta_start,
38 | jac=REPS._dual_function_diff,
39 | bounds=((np.finfo(np.float32).eps, np.inf),),
40 | args=(self._eps(), Jep, theta))
41 |
42 | eta_opt = res.x.item()
43 |
44 | Jep -= np.max(Jep)
45 |
46 | d = np.exp(Jep / eta_opt)
47 |
48 | self.distribution.mle(theta, d)
49 |
50 | @staticmethod
51 | def _dual_function(eta_array, *args):
52 | eta = eta_array.item()
53 | eps, Jep, theta = args
54 |
55 | max_J = np.max(Jep)
56 |
57 | r = Jep - max_J
58 | sum1 = np.mean(np.exp(r / eta))
59 |
60 | return eta * eps + eta * np.log(sum1) + max_J
61 |
62 | @staticmethod
63 | def _dual_function_diff(eta_array, *args):
64 | eta = eta_array.item()
65 | eps, Jep, theta = args
66 |
67 | max_J = np.max(Jep)
68 |
69 | r = Jep - max_J
70 |
71 | sum1 = np.mean(np.exp(r / eta))
72 | sum2 = np.mean(np.exp(r / eta) * r)
73 |
74 | gradient = eps + np.log(sum1) - sum2 / (eta * sum1)
75 |
76 | return np.array([gradient])
77 |
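As a reading aid, _dual_function and the weighting below it implement the episodic REPS dual (notation introduced here, with R_i = J_i - max_j J_j and N sampled episodes):

    g(\eta) = \eta\,\varepsilon + \eta \log\Big(\frac{1}{N} \sum_{i=1}^{N} e^{R_i / \eta}\Big) + \max_j J_j, \qquad
    d_i = e^{R_i / \eta^*}

where \eta^* minimizes g over \eta > 0 and the weights d_i are handed to the weighted maximum-likelihood fit of the distribution.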
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/black_box_optimization/rwr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.policy_search.black_box_optimization import BlackBoxOptimization
4 | from mushroom_rl.rl_utils.parameters import to_parameter
5 |
6 |
7 | class RWR(BlackBoxOptimization):
8 | """
9 | Reward-Weighted Regression algorithm.
10 | "A Survey on Policy Search for Robotics", Deisenroth M. P., Neumann G.,
11 | Peters J.. 2013.
12 |
13 | """
14 | def __init__(self, mdp_info, distribution, policy, beta):
15 | """
16 | Constructor.
17 |
18 | Args:
19 | beta ([float, Parameter]): the temperature for the exponential reward
20 | transformation.
21 |
22 | """
23 | assert not distribution.is_contextual
24 |
25 | self._beta = to_parameter(beta)
26 |
27 | super().__init__(mdp_info, distribution, policy)
28 |
29 | self._add_save_attr(_beta='mushroom')
30 |
31 | def _update(self, Jep, theta, context):
32 | Jep -= np.max(Jep)
33 |
34 | d = np.exp(self._beta() * Jep)
35 |
36 | self.distribution.mle(theta, d)
37 |
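In symbols, the update amounts to exponentiating the max-shifted returns and refitting the distribution by weighted maximum likelihood (notation introduced here, assuming distribution.mle performs the weighted fit):

    d_i = \exp\big(\beta\,(J_i - \max_j J_j)\big), \qquad
    \omega \leftarrow \arg\max_\omega \sum_i d_i \log p(\theta_i \mid \omega)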
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/policy_gradient/__init__.py:
--------------------------------------------------------------------------------
1 | from .policy_gradient import PolicyGradient
2 | from .reinforce import REINFORCE
3 | from .gpomdp import GPOMDP
4 | from .enac import eNAC
5 |
6 | __all__ = ['REINFORCE', 'GPOMDP', 'eNAC']
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.policy_search.policy_gradient import PolicyGradient
4 |
5 |
6 | class eNAC(PolicyGradient):
7 | """
8 | Episodic Natural Actor Critic algorithm.
9 | "A Survey on Policy Search for Robotics", Deisenroth M. P., Neumann G.,
10 | Peters J. 2013.
11 |
12 | """
13 | def __init__(self, mdp_info, policy, optimizer, critic_features=None):
14 | """
15 | Constructor.
16 |
17 | Args:
18 | critic_features (Features, None): features used by the critic.
19 |
20 | """
21 | super().__init__(mdp_info, policy, optimizer)
22 | self.phi_c = critic_features
23 |
24 | self.sum_grad_log = None
25 | self.psi_ext = None
26 | self.sum_grad_log_list = list()
27 |
28 | self._add_save_attr(
29 | phi_c='pickle',
30 | sum_grad_log='numpy',
31 | psi_ext='pickle',
32 | sum_grad_log_list='pickle'
33 | )
34 |
35 | def _compute_gradient(self, J):
36 | R = np.array(J)
37 | PSI = np.array(self.sum_grad_log_list)
38 |
39 | w_and_v = np.linalg.pinv(PSI).dot(R)
40 | nat_grad = w_and_v[:self.policy.weights_size]
41 |
42 | self.sum_grad_log_list = list()
43 |
44 | return nat_grad
45 |
46 | def _step_update(self, x, u, r):
47 | self.sum_grad_log += self.df*self.policy.diff_log(x, u)
48 |
49 | if self.psi_ext is None:
50 | if self.phi_c is None:
51 | self.psi_ext = np.ones(1)
52 | else:
53 | self.psi_ext = self.phi_c(x)
54 |
55 | def _episode_end_update(self):
56 | psi = np.concatenate((self.sum_grad_log, self.psi_ext))
57 | self.sum_grad_log_list.append(psi)
58 |
59 | def _init_update(self):
60 | self.psi_ext = None
61 | self.sum_grad_log = np.zeros(self.policy.weights_size)
62 |
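For reference, _compute_gradient solves the episodic natural-gradient least-squares problem (notation introduced here; \psi_i stacks the discounted score accumulated over episode i with the critic features of its first state, assuming self.df holds the running discount \gamma^t maintained by the PolicyGradient base class):

    \psi_i = \Big[ \sum_t \gamma^t \nabla_\theta \log \pi(a_t \mid s_t)\,;\ \phi_c(s_0) \Big], \qquad
    [\,w;\, v\,] = \Psi^{+} R, \qquad
    \widetilde{\nabla}_\theta J = w

where \Psi stacks the \psi_i row-wise, R is the vector of episode returns, and only the first block w (the natural gradient of the policy weights) is returned.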
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.policy_search.policy_gradient import PolicyGradient
4 |
5 |
6 | class REINFORCE(PolicyGradient):
7 | """
8 | REINFORCE algorithm.
9 | "Simple Statistical Gradient-Following Algorithms for Connectionist
10 | Reinforcement Learning", Williams R. J.. 1992.
11 |
12 | """
13 | def __init__(self, mdp_info, policy, optimizer):
14 | super().__init__(mdp_info, policy, optimizer)
15 | self.sum_d_log_pi = None
16 | self.list_sum_d_log_pi = list()
17 | self.baseline_num = list()
18 | self.baseline_den = list()
19 |
20 | self._add_save_attr(
21 | sum_d_log_pi='numpy',
22 | list_sum_d_log_pi='pickle',
23 | baseline_num='pickle',
24 | baseline_den='pickle'
25 | )
26 |
27 | # Ignore divide by zero
28 | np.seterr(divide='ignore', invalid='ignore')
29 |
30 | def _compute_gradient(self, J):
31 | baseline = np.mean(self.baseline_num, axis=0) / np.mean(self.baseline_den, axis=0)
32 | baseline[np.logical_not(np.isfinite(baseline))] = 0.
33 | grad_J_episode = list()
34 | for i, J_episode in enumerate(J):
35 | sum_d_log_pi = self.list_sum_d_log_pi[i]
36 | grad_J_episode.append(sum_d_log_pi * (J_episode - baseline))
37 |
38 | grad_J = np.mean(grad_J_episode, axis=0)
39 | self.list_sum_d_log_pi = list()
40 | self.baseline_den = list()
41 | self.baseline_num = list()
42 |
43 | return grad_J
44 |
45 | def _step_update(self, x, u, r):
46 | d_log_pi = self.policy.diff_log(x, u)
47 | self.sum_d_log_pi += d_log_pi
48 |
49 | def _episode_end_update(self):
50 | self.list_sum_d_log_pi.append(self.sum_d_log_pi)
51 | squared_sum_d_log_pi = np.square(self.sum_d_log_pi)
52 | self.baseline_num.append(squared_sum_d_log_pi * self.J_episode)
53 | self.baseline_den.append(squared_sum_d_log_pi)
54 |
55 | def _init_update(self):
56 | self.sum_d_log_pi = np.zeros(self.policy.weights_size)
57 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/__init__.py:
--------------------------------------------------------------------------------
1 | from .batch_td import *
2 | from .dqn import *
3 | from .td import *
4 |
5 | __all__ = ['FQI', 'DoubleFQI', 'BoostedFQI', 'LSPI', 'AbstractDQN', 'DQN', 'DoubleDQN',
6 | 'AveragedDQN', 'CategoricalDQN', 'DuelingDQN', 'NoisyDQN', 'QuantileDQN',
7 | 'MaxminDQN', 'Rainbow', 'QLearning', 'QLambda', 'DoubleQLearning', 'WeightedQLearning',
8 | 'MaxminQLearning', 'SpeedyQLearning', 'RLearning', 'RQLearning',
9 | 'SARSA', 'SARSALambda', 'SARSALambdaContinuous', 'ExpectedSARSA',
10 | 'TrueOnlineSARSALambda']
11 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/batch_td/__init__.py:
--------------------------------------------------------------------------------
1 | from .batch_td import BatchTD
2 | from .fqi import FQI
3 | from .double_fqi import DoubleFQI
4 | from .boosted_fqi import BoostedFQI
5 | from .lspi import LSPI
6 |
7 | __all__ = ['FQI', 'DoubleFQI', 'BoostedFQI', 'LSPI']
8 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/batch_td/batch_td.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.core import Agent
2 | from mushroom_rl.approximators import Regressor
3 |
4 |
5 | class BatchTD(Agent):
6 | """
7 | Abstract class to implement a generic Batch TD algorithm.
8 |
9 | """
10 | def __init__(self, mdp_info, policy, approximator, approximator_params=None, fit_params=None):
11 | """
12 | Constructor.
13 |
14 | Args:
15 | approximator (object): approximator used by the algorithm and the
16 | policy.
17 | approximator_params (dict, None): parameters of the approximator to
18 | build;
19 | fit_params (dict, None): parameters of the fitting algorithm of the
20 | approximator;
21 |
22 | """
23 | approximator_params = dict() if approximator_params is None else\
24 | approximator_params
25 | self._fit_params = dict() if fit_params is None else fit_params
26 |
27 | self.approximator = Regressor(approximator, **approximator_params)
28 | policy.set_q(self.approximator)
29 |
30 | self._add_save_attr(
31 | approximator='mushroom',
32 | _fit_params='pickle'
33 | )
34 |
35 | super().__init__(mdp_info, policy)
36 |
37 | def _post_load(self):
38 | self.policy.set_q(self.approximator)
39 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/batch_td/boosted_fqi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import trange
3 |
4 | from .fqi import FQI
5 |
6 |
7 | class BoostedFQI(FQI):
8 | """
9 | Boosted Fitted Q-Iteration algorithm.
10 | "Boosted Fitted Q-Iteration". Tosatto S. et al.. 2017.
11 |
12 | """
13 | def __init__(self, mdp_info, policy, approximator, n_iterations,
14 | approximator_params=None, fit_params=None, quiet=False):
15 | self._prediction = 0.
16 | self._next_q = 0.
17 | self._idx = 0
18 |
19 | assert approximator_params['n_models'] == n_iterations
20 |
21 | self._add_save_attr(
22 | _n_iterations='primitive',
23 | _quiet='primitive',
24 | _prediction='primitive',
25 | _next_q='numpy',
26 | _idx='primitive',
27 | _target='pickle'
28 | )
29 |
30 | super().__init__(mdp_info, policy, approximator, n_iterations, approximator_params, fit_params, quiet)
31 |
32 | def fit(self, dataset):
33 | state, action, reward, next_state, absorbing, _ = dataset.parse(to='numpy')
34 | for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False):
35 | if self._target is None:
36 | self._target = reward
37 | else:
38 | self._next_q += self.approximator.predict(next_state,
39 | idx=self._idx - 1)
40 | if np.any(absorbing):
41 | self._next_q *= 1 - absorbing.reshape(-1, 1)
42 |
43 | max_q = np.max(self._next_q, axis=1)
44 | self._target = reward + self.mdp_info.gamma * max_q
45 |
46 | self._target -= self._prediction
47 | self._prediction += self._target
48 |
49 | self.approximator.fit(state, action, self._target, idx=self._idx,
50 | **self._fit_params)
51 |
52 | self._idx += 1
53 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/batch_td/double_fqi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import trange
3 |
4 | from .fqi import FQI
5 |
6 |
7 | class DoubleFQI(FQI):
8 | """
9 | Double Fitted Q-Iteration algorithm.
10 | "Estimating the Maximum Expected Value in Continuous Reinforcement Learning
11 | Problems". D'Eramo C. et al.. 2017.
12 |
13 | """
14 | def __init__(self, mdp_info, policy, approximator, n_iterations,
15 | approximator_params=None, fit_params=None, quiet=False):
16 | approximator_params['n_models'] = 2
17 |
18 | super().__init__(mdp_info, policy, approximator, n_iterations,
19 | approximator_params, fit_params, quiet)
20 |
21 | def fit(self, dataset):
22 | for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False):
23 | state = list()
24 | action = list()
25 | reward = list()
26 | next_state = list()
27 | absorbing = list()
28 |
29 | half = len(dataset) // 2
30 | for i in range(2):
31 | s, a, r, ss, ab, _ = dataset[i * half:(i + 1) * half].parse(to='numpy')
32 | state.append(s)
33 | action.append(a)
34 | reward.append(r)
35 | next_state.append(ss)
36 | absorbing.append(ab)
37 |
38 | if self._target is None:
39 | self._target = reward
40 | else:
41 | for i in range(2):
42 | q_i = self.approximator.predict(next_state[i], idx=i)
43 |
44 | amax_q = np.expand_dims(np.argmax(q_i, axis=1), axis=1)
45 | max_q = self.approximator.predict(next_state[i], amax_q,
46 | idx=1 - i)
47 | if np.any(absorbing[i]):
48 | max_q *= 1 - absorbing[i]
49 | self._target[i] = reward[i] + self.mdp_info.gamma * max_q
50 |
51 | for i in range(2):
52 | self.approximator.fit(state[i], action[i], self._target[i], idx=i,
53 | **self._fit_params)
54 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/batch_td/fqi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import trange
3 |
4 | from mushroom_rl.algorithms.value.batch_td import BatchTD
5 | from mushroom_rl.rl_utils.parameters import to_parameter
6 |
7 |
8 | class FQI(BatchTD):
9 | """
10 | Fitted Q-Iteration algorithm.
11 | "Tree-Based Batch Mode Reinforcement Learning", Ernst D. et al.. 2005.
12 |
13 | """
14 | def __init__(self, mdp_info, policy, approximator, n_iterations,
15 | approximator_params=None, fit_params=None, quiet=False):
16 | """
17 | Constructor.
18 |
19 | Args:
20 | n_iterations ([int, Parameter]): number of iterations to perform for training;
21 | quiet (bool, False): whether to hide the progress bar.
22 |
23 | """
24 | self._n_iterations = to_parameter(n_iterations)
25 | self._quiet = quiet
26 | self._target = None
27 |
28 | self._add_save_attr(
29 | _n_iterations='mushroom',
30 | _quiet='primitive',
31 | _target='pickle'
32 | )
33 |
34 | super().__init__(mdp_info, policy, approximator, approximator_params, fit_params)
35 |
36 | def fit(self, dataset):
37 | state, action, reward, next_state, absorbing, _ = dataset.parse(to='numpy')
38 | for _ in trange(self._n_iterations(), dynamic_ncols=True, disable=self._quiet, leave=False):
39 | if self._target is None:
40 | self._target = reward
41 | else:
42 | q = self.approximator.predict(next_state)
43 | if np.any(absorbing):
44 | q *= 1 - absorbing.reshape(-1, 1)
45 |
46 | max_q = np.max(q, axis=1)
47 | self._target = reward + self.mdp_info.gamma * max_q
48 |
49 | self.approximator.fit(state, action, self._target, **self._fit_params)
50 |
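A minimal usage sketch for this class, loosely patterned after examples/car_on_hill_fqi.py in this repository; the environment, the extra-trees regressor and all hyperparameter values below are illustrative choices rather than the reference setup:

    from sklearn.ensemble import ExtraTreesRegressor

    from mushroom_rl.algorithms.value import FQI
    from mushroom_rl.core import Core
    from mushroom_rl.environments import CarOnHill
    from mushroom_rl.policy import EpsGreedy
    from mushroom_rl.rl_utils.parameters import Parameter

    mdp = CarOnHill()
    pi = EpsGreedy(epsilon=Parameter(value=1.))  # fully exploratory data collection

    approximator_params = dict(input_shape=mdp.info.observation_space.shape,
                               n_actions=mdp.info.action_space.n,
                               n_estimators=50,
                               min_samples_split=5,
                               min_samples_leaf=2)

    agent = FQI(mdp.info, pi, ExtraTreesRegressor, n_iterations=20,
                approximator_params=approximator_params)

    core = Core(agent, mdp)
    core.learn(n_episodes=500, n_episodes_per_fit=500)  # collect a batch, then run the iterations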
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/batch_td/lspi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value.batch_td import BatchTD
4 | from mushroom_rl.approximators.parametric import LinearApproximator
5 | from mushroom_rl.features import get_action_features
6 | from mushroom_rl.rl_utils.parameters import to_parameter
7 |
8 |
9 | class LSPI(BatchTD):
10 | """
11 | Least-Squares Policy Iteration algorithm.
12 | "Least-Squares Policy Iteration". Lagoudakis M. G. and Parr R.. 2003.
13 |
14 | """
15 | def __init__(self, mdp_info, policy, approximator_params=None, epsilon=1e-2, fit_params=None):
16 | """
17 | Constructor.
18 |
19 | Args:
20 | epsilon ([float, Parameter], 1e-2): termination coefficient.
21 |
22 | """
23 | self._epsilon = to_parameter(epsilon)
24 |
25 | self._add_save_attr(_epsilon='mushroom')
26 |
27 | super().__init__(mdp_info, policy, LinearApproximator, approximator_params, fit_params)
28 |
29 | def fit(self, dataset):
30 | state, action, reward, next_state, absorbing, _ = dataset.parse(to='numpy')
31 |
32 | phi_state = self.approximator.model.phi(state)
33 | phi_next_state = self.approximator.model.phi(next_state)
34 |
35 | phi_state_action = get_action_features(phi_state, action, self.mdp_info.action_space.n)
36 |
37 | norm = np.inf
38 | while norm > self._epsilon():
39 | q = self.approximator.predict(next_state)
40 | if np.any(absorbing):
41 | q *= 1 - absorbing.reshape(-1, 1)
42 |
43 | next_action = np.argmax(q, axis=1).reshape(-1, 1)
44 | phi_next_state_next_action = get_action_features(phi_next_state, next_action, self.mdp_info.action_space.n)
45 |
46 | tmp = phi_state_action - self.mdp_info.gamma * phi_next_state_next_action
47 | A = phi_state_action.T.dot(tmp)
48 | b = (phi_state_action.T.dot(reward)).reshape(-1, 1)
49 |
50 | old_w = self.approximator.get_weights()
51 | if np.linalg.matrix_rank(A) == A.shape[1]:
52 | w = np.linalg.solve(A, b).ravel()
53 | else:
54 | w = np.linalg.pinv(A).dot(b).ravel()
55 | self.approximator.set_weights(w)
56 |
57 | norm = np.linalg.norm(w - old_w)
58 |
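For reference, each pass of the while loop is an LSTDQ solve (notation introduced here): with \Phi the state-action features of the collected pairs, \Phi'_\pi the features of the next states paired with the greedy action under the current weights, and r the rewards,

    A = \Phi^\top (\Phi - \gamma\, \Phi'_\pi), \qquad
    b = \Phi^\top r, \qquad
    w = A^{-1} b \ \ (\text{or } A^{+} b \text{ when } A \text{ is rank deficient})

iterated until \lVert w - w_{\text{old}} \rVert drops below the termination coefficient \varepsilon.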
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/dqn/__init__.py:
--------------------------------------------------------------------------------
1 | from .abstract_dqn import AbstractDQN
2 | from .dqn import DQN
3 | from .double_dqn import DoubleDQN
4 | from .averaged_dqn import AveragedDQN
5 | from .maxmin_dqn import MaxminDQN
6 | from .dueling_dqn import DuelingDQN
7 | from .categorical_dqn import CategoricalDQN
8 | from .noisy_dqn import NoisyDQN
9 | from .quantile_dqn import QuantileDQN
10 | from .rainbow import Rainbow
11 |
12 |
13 | __all__ = ['AbstractDQN', 'DQN', 'DoubleDQN', 'AveragedDQN', 'MaxminDQN',
14 | 'DuelingDQN', 'CategoricalDQN', 'NoisyDQN', 'QuantileDQN', 'Rainbow']
15 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/dqn/averaged_dqn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value.dqn import AbstractDQN
4 | from mushroom_rl.approximators.regressor import Regressor
5 |
6 |
7 | class AveragedDQN(AbstractDQN):
8 | """
9 | Averaged-DQN algorithm.
10 | "Averaged-DQN: Variance Reduction and Stabilization for Deep Reinforcement
11 | Learning". Anschel O. et al.. 2017.
12 |
13 | """
14 | def __init__(self, mdp_info, policy, approximator, n_approximators,
15 | **params):
16 | """
17 | Constructor.
18 |
19 | Args:
20 | n_approximators (int): the number of target approximators to store.
21 |
22 | """
23 | assert n_approximators > 1
24 |
25 | self._n_approximators = n_approximators
26 |
27 | super().__init__(mdp_info, policy, approximator, **params)
28 |
29 | self._n_fitted_target_models = 1
30 |
31 | self._add_save_attr(_n_fitted_target_models='primitive')
32 |
33 | def _initialize_regressors(self, approximator, apprx_params_train,
34 | apprx_params_target):
35 | self.approximator = Regressor(approximator, **apprx_params_train)
36 | self.target_approximator = Regressor(approximator,
37 | n_models=self._n_approximators,
38 | **apprx_params_target)
39 | for i in range(len(self.target_approximator)):
40 | self.target_approximator[i].set_weights(
41 | self.approximator.get_weights()
42 | )
43 |
44 | def _update_target(self):
45 | idx = self._n_updates // self._target_update_frequency\
46 | % self._n_approximators
47 | self.target_approximator[idx].set_weights(
48 | self.approximator.get_weights())
49 |
50 | if self._n_fitted_target_models < self._n_approximators:
51 | self._n_fitted_target_models += 1
52 |
53 | def _next_q(self, next_state, absorbing):
54 | q = list()
55 | for idx in range(self._n_fitted_target_models):
56 | q_target_idx = self.target_approximator.predict(next_state, idx=idx, **self._predict_params)
57 | q.append(q_target_idx)
58 | q = np.mean(q, axis=0)
59 | if np.any(absorbing):
60 | q *= 1 - absorbing.reshape(-1, 1)
61 |
62 | return np.max(q, axis=1)
63 |
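In symbols, _next_q averages the K target networks fitted so far before taking the max (notation introduced here; the bootstrap term is zeroed at absorbing states):

    y = r + \gamma \max_{a'} \frac{1}{K} \sum_{k=1}^{K} Q_{\theta_k^-}(s', a')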
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/dqn/double_dqn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value.dqn import DQN
4 |
5 |
6 | class DoubleDQN(DQN):
7 | """
8 | Double DQN algorithm.
9 | "Deep Reinforcement Learning with Double Q-Learning".
10 | Hasselt H. V. et al.. 2016.
11 |
12 | """
13 | def _next_q(self, next_state, absorbing):
14 | q = self.approximator.predict(next_state, **self._predict_params)
15 | max_a = np.argmax(q, axis=1)
16 |
17 | double_q = self.target_approximator.predict(next_state, max_a, **self._predict_params)
18 | if np.any(absorbing):
19 | double_q *= 1 - absorbing
20 |
21 | return double_q
22 |
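The target decouples action selection (online network) from action evaluation (target network); in symbols (notation introduced here, with the bootstrap term zeroed at absorbing states):

    y = r + \gamma\, Q_{\theta^-}\big(s', \arg\max_{a'} Q_{\theta}(s', a')\big)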
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/dqn/dqn.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.algorithms.value.dqn import AbstractDQN
2 |
3 |
4 | class DQN(AbstractDQN):
5 | """
6 | Deep Q-Network algorithm.
7 | "Human-Level Control Through Deep Reinforcement Learning".
8 | Mnih V. et al.. 2015.
9 |
10 | """
11 | def _next_q(self, next_state, absorbing):
12 | q = self.target_approximator.predict(next_state, **self._predict_params)
13 | if absorbing.any():
14 | q *= 1 - absorbing.reshape(-1, 1)
15 |
16 | return q.max(1)
17 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/dqn/maxmin_dqn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value.dqn import DQN
4 | from mushroom_rl.approximators.regressor import Regressor
5 |
6 |
7 | class MaxminDQN(DQN):
8 | """
9 | MaxminDQN algorithm.
10 | "Maxmin Q-learning: Controlling the Estimation Bias of Q-learning".
11 | Lan Q. et al.. 2020.
12 |
13 | """
14 | def __init__(self, mdp_info, policy, approximator, n_approximators, **params):
15 | """
16 | Constructor.
17 |
18 | Args:
19 | n_approximators (int): the number of approximators in the ensemble.
20 |
21 | """
22 | assert n_approximators > 1
23 |
24 | self._n_approximators = n_approximators
25 |
26 | super().__init__(mdp_info, policy, approximator, **params)
27 |
28 | def fit(self, dataset):
29 | self._fit_params['idx'] = np.random.randint(self._n_approximators)
30 |
31 | super().fit(dataset)
32 |
33 | def _initialize_regressors(self, approximator, apprx_params_train, apprx_params_target):
34 | self.approximator = Regressor(approximator,
35 | n_models=self._n_approximators,
36 | prediction='min', **apprx_params_train)
37 | self.target_approximator = Regressor(approximator,
38 | n_models=self._n_approximators,
39 | prediction='min',
40 | **apprx_params_target)
41 | self._update_target()
42 |
43 | def _update_target(self):
44 | for i in range(len(self.target_approximator)):
45 | self.target_approximator[i].set_weights(self.approximator[i].get_weights())
46 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/__init__.py:
--------------------------------------------------------------------------------
1 | from .td import TD
2 | from .sarsa import SARSA
3 | from .sarsa_lambda import SARSALambda
4 | from .expected_sarsa import ExpectedSARSA
5 | from .q_learning import QLearning
6 | from .q_lambda import QLambda
7 | from .double_q_learning import DoubleQLearning
8 | from .speedy_q_learning import SpeedyQLearning
9 | from .r_learning import RLearning
10 | from .weighted_q_learning import WeightedQLearning
11 | from .maxmin_q_learning import MaxminQLearning
12 | from .rq_learning import RQLearning
13 | from .sarsa_lambda_continuous import SARSALambdaContinuous
14 | from .true_online_sarsa_lambda import TrueOnlineSARSALambda
15 |
16 | __all__ = ['SARSA', 'SARSALambda', 'ExpectedSARSA', 'QLearning',
17 | 'QLambda', 'DoubleQLearning', 'SpeedyQLearning',
18 | 'RLearning', 'WeightedQLearning', 'MaxminQLearning',
19 | 'RQLearning', 'SARSALambdaContinuous', 'TrueOnlineSARSALambda']
20 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/double_q_learning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 |
4 | from mushroom_rl.algorithms.value.td import TD
5 | from mushroom_rl.approximators.ensemble_table import EnsembleTable
6 |
7 |
8 | class DoubleQLearning(TD):
9 | """
10 | Double Q-Learning algorithm.
11 | "Double Q-Learning". Hasselt H. V.. 2010.
12 |
13 | """
14 | def __init__(self, mdp_info, policy, learning_rate):
15 | Q = EnsembleTable(2, mdp_info.size)
16 |
17 | super().__init__(mdp_info, policy, Q, learning_rate)
18 |
19 | self._alpha_double = [deepcopy(self._alpha), deepcopy(self._alpha)]
20 |
21 | self._add_save_attr(
22 | _alpha_double='primitive'
23 | )
24 |
25 | assert len(self.Q) == 2, 'The regressor ensemble must' \
26 | ' have exactly 2 models.'
27 |
28 | def _update(self, state, action, reward, next_state, absorbing):
29 | approximator_idx = 0 if np.random.uniform() < .5 else 1
30 |
31 | q_current = self.Q[approximator_idx][state, action]
32 |
33 | if not absorbing:
34 | q_ss = self.Q[approximator_idx][next_state, :]
35 | max_q = np.max(q_ss)
36 | a_n = np.array(
37 | [np.random.choice(np.argwhere(q_ss == max_q).ravel())])
38 | q_next = self.Q[1 - approximator_idx][next_state, a_n]
39 | else:
40 | q_next = 0.
41 |
42 | q = q_current + self._alpha_double[approximator_idx](state, action) * (
43 | reward + self.mdp_info.gamma * q_next - q_current)
44 |
45 | self.Q[approximator_idx][state, action] = q
46 |
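_update above is the tabular double estimator: one of the two tables is picked uniformly at random, and its greedy next action is evaluated with the other table (notation introduced here; ties in the argmax are broken uniformly, and the bootstrap term is zero at absorbing states):

    Q_A(s, a) \leftarrow Q_A(s, a) + \alpha\,\big(r + \gamma\, Q_B(s', a^*) - Q_A(s, a)\big), \qquad
    a^* \in \arg\max_{a'} Q_A(s', a')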
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/expected_sarsa.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.algorithms.value.td import TD
2 | from mushroom_rl.approximators.table import Table
3 |
4 |
5 | class ExpectedSARSA(TD):
6 | """
7 | Expected SARSA algorithm.
8 | "A theoretical and empirical analysis of Expected Sarsa". Seijen H. V. et
9 | al.. 2009.
10 |
11 | """
12 | def __init__(self, mdp_info, policy, learning_rate):
13 | Q = Table(mdp_info.size)
14 |
15 | super().__init__(mdp_info, policy, Q, learning_rate)
16 |
17 | def _update(self, state, action, reward, next_state, absorbing):
18 | q_current = self.Q[state, action]
19 |
20 | if not absorbing:
21 | q_next = self.Q[next_state, :].dot(self.policy(next_state))
22 | else:
23 | q_next = 0.
24 |
25 | self.Q[state, action] = q_current + self._alpha(state, action) * (
26 | reward + self.mdp_info.gamma * q_next - q_current)
27 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/maxmin_q_learning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 |
4 | from mushroom_rl.algorithms.value.td import TD
5 | from mushroom_rl.approximators.ensemble_table import EnsembleTable
6 |
7 |
8 | class MaxminQLearning(TD):
9 | """
10 | Maxmin Q-Learning algorithm without replay memory.
11 | "Maxmin Q-learning: Controlling the Estimation Bias of Q-learning".
12 | Lan Q. et al.. 2020.
13 |
14 | """
15 | def __init__(self, mdp_info, policy, learning_rate, n_tables):
16 | """
17 | Constructor.
18 |
19 | Args:
20 | n_tables (int): number of tables in the ensemble.
21 |
22 | """
23 | self._n_tables = n_tables
24 | Q = EnsembleTable(n_tables, mdp_info.size, prediction='min')
25 |
26 | super().__init__(mdp_info, policy, Q, learning_rate)
27 |
28 | self._alpha_mm = [deepcopy(self._alpha) for _ in range(n_tables)]
29 |
30 | self._add_save_attr(_n_tables='primitive', _alpha_mm='primitive')
31 |
32 | def _update(self, state, action, reward, next_state, absorbing):
33 | approximator_idx = np.random.choice(self._n_tables)
34 |
35 | q_current = self.Q[approximator_idx][state, action]
36 |
37 | if not absorbing:
38 | q_ss = self.Q.predict(next_state)
39 | q_next = np.max(q_ss)
40 | else:
41 | q_next = 0.
42 |
43 | q = q_current + self._alpha_mm[approximator_idx](state, action) * (
44 | reward + self.mdp_info.gamma * q_next - q_current)
45 |
46 | self.Q[approximator_idx][state, action] = q
47 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/q_lambda.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value.td import TD
4 | from mushroom_rl.rl_utils.eligibility_trace import EligibilityTrace
5 | from mushroom_rl.approximators.table import Table
6 | from mushroom_rl.rl_utils.parameters import to_parameter
7 |
8 |
9 | class QLambda(TD):
10 | """
11 | Q(Lambda) algorithm.
12 | "Learning from Delayed Rewards". Watkins C.J.C.H.. 1989.
13 |
14 | """
15 | def __init__(self, mdp_info, policy, learning_rate, lambda_coeff,
16 | trace='replacing'):
17 | """
18 | Constructor.
19 |
20 | Args:
21 | lambda_coeff ([float, Parameter]): eligibility trace coefficient;
22 | trace (str, 'replacing'): type of eligibility trace to use.
23 |
24 | """
25 | Q = Table(mdp_info.size)
26 | self._lambda = to_parameter(lambda_coeff)
27 |
28 | self.e = EligibilityTrace(Q.shape, trace)
29 | self._add_save_attr(
30 | _lambda='mushroom',
31 | e='mushroom'
32 | )
33 |
34 | super().__init__(mdp_info, policy, Q, learning_rate)
35 |
36 | def _update(self, state, action, reward, next_state, absorbing):
37 | q_current = self.Q[state, action]
38 |
39 | q_next = np.max(self.Q[next_state, :]) if not absorbing else 0.
40 |
41 | delta = reward + self.mdp_info.gamma*q_next - q_current
42 | self.e.update(state, action)
43 |
44 | self.Q.table += self._alpha(state, action) * delta * self.e.table
45 | self.e.table *= self.mdp_info.gamma * self._lambda()
46 |
47 | def episode_start(self, initial_state, episode_info):
48 | self.e.reset()
49 |
50 | return super().episode_start(initial_state, episode_info)
51 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/q_learning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value.td import TD
4 | from mushroom_rl.approximators.table import Table
5 |
6 |
7 | class QLearning(TD):
8 | """
9 | Q-Learning algorithm.
10 | "Learning from Delayed Rewards". Watkins C.J.C.H.. 1989.
11 |
12 | """
13 | def __init__(self, mdp_info, policy, learning_rate):
14 | Q = Table(mdp_info.size)
15 |
16 | super().__init__(mdp_info, policy, Q, learning_rate)
17 |
18 | def _update(self, state, action, reward, next_state, absorbing):
19 | q_current = self.Q[state, action]
20 |
21 | q_next = np.max(self.Q[next_state, :]) if not absorbing else 0.
22 |
23 | self.Q[state, action] = q_current + self._alpha(state, action) * (
24 | reward + self.mdp_info.gamma * q_next - q_current)
25 |
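_update is the standard tabular rule (notation introduced here; the bootstrap term is zero at absorbing states):

    Q(s, a) \leftarrow Q(s, a) + \alpha(s, a)\,\big(r + \gamma \max_{a'} Q(s', a') - Q(s, a)\big)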
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/r_learning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value.td import TD
4 | from mushroom_rl.approximators.table import Table
5 |
6 | from mushroom_rl.rl_utils.parameters import to_parameter
7 |
8 |
9 | class RLearning(TD):
10 | """
11 | R-Learning algorithm.
12 | "A Reinforcement Learning Method for Maximizing Undiscounted Rewards".
13 | Schwartz A.. 1993.
14 |
15 | """
16 | def __init__(self, mdp_info, policy, learning_rate, beta):
17 | """
18 | Constructor.
19 |
20 | Args:
21 | beta ([float, Parameter]): beta coefficient.
22 |
23 | """
24 | Q = Table(mdp_info.size)
25 | self._rho = 0.
26 | self._beta = to_parameter(beta)
27 |
28 | self._add_save_attr(_rho='primitive', _beta='mushroom')
29 |
30 | super().__init__(mdp_info, policy, Q, learning_rate)
31 |
32 | def _update(self, state, action, reward, next_state, absorbing):
33 | q_current = self.Q[state, action]
34 | q_next = np.max(self.Q[next_state, :]) if not absorbing else 0.
35 | delta = reward - self._rho + q_next - q_current
36 | q_new = q_current + self._alpha(state, action) * delta
37 |
38 | self.Q[state, action] = q_new
39 |
40 | q_max = np.max(self.Q[state, :])
41 | if q_new == q_max:
42 | delta = reward + q_next - q_max - self._rho
43 | self._rho += self._beta(state, action) * delta
44 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/sarsa.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.algorithms.value.td import TD
2 | from mushroom_rl.approximators.table import Table
3 |
4 |
5 | class SARSA(TD):
6 | """
7 | SARSA algorithm.
8 |
9 | """
10 | def __init__(self, mdp_info, policy, learning_rate):
11 | Q = Table(mdp_info.size)
12 |
13 | super().__init__(mdp_info, policy, Q, learning_rate)
14 |
15 | def _update(self, state, action, reward, next_state, absorbing):
16 | q_current = self.Q[state, action]
17 |
18 | self.next_action, _ = self.draw_action(next_state)
19 | q_next = self.Q[next_state, self.next_action] if not absorbing else 0.
20 |
21 | self.Q[state, action] = q_current + self._alpha(state, action) * (
22 | reward + self.mdp_info.gamma * q_next - q_current)
23 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/sarsa_lambda.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.algorithms.value.td import TD
2 | from mushroom_rl.rl_utils.eligibility_trace import EligibilityTrace
3 | from mushroom_rl.approximators.table import Table
4 | from mushroom_rl.rl_utils.parameters import to_parameter
5 |
6 |
7 | class SARSALambda(TD):
8 | """
9 | The SARSA(lambda) algorithm for finite MDPs.
10 |
11 | """
12 | def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, trace='replacing'):
13 | """
14 | Constructor.
15 |
16 | Args:
17 | lambda_coeff ([float, Parameter]): eligibility trace coefficient;
18 | trace (str, 'replacing'): type of eligibility trace to use.
19 |
20 | """
21 | Q = Table(mdp_info.size)
22 | self._lambda = to_parameter(lambda_coeff)
23 |
24 | self.e = EligibilityTrace(Q.shape, trace)
25 | self._add_save_attr(
26 | _lambda='mushroom',
27 | e='mushroom'
28 | )
29 |
30 | super().__init__(mdp_info, policy, Q, learning_rate)
31 |
32 | def _update(self, state, action, reward, next_state, absorbing):
33 | q_current = self.Q[state, action]
34 |
35 | self.next_action, _ = self.draw_action(next_state)
36 | q_next = self.Q[next_state, self.next_action] if not absorbing else 0.
37 |
38 | delta = reward + self.mdp_info.gamma * q_next - q_current
39 | self.e.update(state, action)
40 |
41 | self.Q.table += self._alpha(state, action) * delta * self.e.table
42 | self.e.table *= self.mdp_info.gamma * self._lambda()
43 |
44 | def episode_start(self, initial_state, episode_info):
45 | self.e.reset()
46 |
47 | return super().episode_start(initial_state, episode_info)
48 |
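_update combines the on-policy TD error with an eligibility trace over the whole table (notation introduced here; a' is the action just drawn for the next step, and e is first updated at (s, a) according to the chosen trace type):

    \delta = r + \gamma\, Q(s', a') - Q(s, a), \qquad
    Q \leftarrow Q + \alpha(s, a)\, \delta\, e, \qquad
    e \leftarrow \gamma\, \lambda\, e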
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.algorithms.value.td import TD
4 | from mushroom_rl.approximators import Regressor
5 | from mushroom_rl.rl_utils.parameters import to_parameter
6 |
7 |
8 | class SARSALambdaContinuous(TD):
9 | """
10 | Continuous version of SARSA(lambda) algorithm.
11 |
12 | """
13 | def __init__(self, mdp_info, policy, approximator, learning_rate, lambda_coeff, approximator_params=None):
14 | """
15 | Constructor.
16 |
17 | Args:
18 | lambda_coeff ([float, Parameter]): eligibility trace coefficient.
19 |
20 | """
21 | approximator_params = dict() if approximator_params is None else approximator_params
22 |
23 | Q = Regressor(approximator, **approximator_params)
24 | self.e = np.zeros(Q.weights_size)
25 | self._lambda = to_parameter(lambda_coeff)
26 |
27 | self._add_save_attr(
28 | _lambda='primitive',
29 | e='numpy'
30 | )
31 |
32 | super().__init__(mdp_info, policy, Q, learning_rate)
33 |
34 | def _update(self, state, action, reward, next_state, absorbing):
35 | q_current = self.Q.predict(state, action)
36 |
37 | alpha = self._alpha(state, action)
38 |
39 | self.e = self.mdp_info.gamma * self._lambda() * self.e + self.Q.diff(state, action)
40 |
41 | self.next_action, _ = self.draw_action(next_state)
42 | q_next = self.Q.predict(next_state, self.next_action) if not absorbing else 0.
43 |
44 | delta = reward + self.mdp_info.gamma * q_next - q_current
45 |
46 | theta = self.Q.get_weights()
47 | theta += alpha * delta * self.e
48 | self.Q.set_weights(theta)
49 |
50 | def episode_start(self, initial_state, episode_info):
51 | self.e = np.zeros(self.Q.weights_size)
52 |
53 | return super().episode_start(initial_state, episode_info)
54 |
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/speedy_q_learning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 |
4 | from mushroom_rl.algorithms.value.td import TD
5 | from mushroom_rl.approximators.table import Table
6 |
7 |
8 | class SpeedyQLearning(TD):
9 | """
10 | Speedy Q-Learning algorithm.
11 | "Speedy Q-Learning". Ghavamzadeh et. al.. 2011.
12 |
13 | """
14 | def __init__(self, mdp_info, policy, learning_rate):
15 | Q = Table(mdp_info.size)
16 | self.old_q = deepcopy(Q)
17 |
18 | self._add_save_attr(old_q='mushroom')
19 |
20 | super().__init__(mdp_info, policy, Q, learning_rate)
21 |
22 | def _update(self, state, action, reward, next_state, absorbing):
23 | old_q = deepcopy(self.Q)
24 |
25 | max_q_cur = np.max(self.Q[next_state, :]) if not absorbing else 0.
26 | max_q_old = np.max(self.old_q[next_state, :]) if not absorbing else 0.
27 |
28 | target_cur = reward + self.mdp_info.gamma * max_q_cur
29 | target_old = reward + self.mdp_info.gamma * max_q_old
30 |
31 | alpha = self._alpha(state, action)
32 | q_cur = self.Q[state, action]
33 | self.Q[state, action] = q_cur + alpha * (target_old - q_cur) + (
34 | 1. - alpha) * (target_cur - target_old)
35 |
36 | self.old_q = old_q
37 |
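Writing \hat{T}Q(s, a) = r + \gamma \max_{a'} Q(s', a') for the empirical Bellman operator, the update above reads (notation introduced here, with Q_{k-1} the table stored from the previous step):

    Q_{k+1}(s, a) = Q_k(s, a) + \alpha\,\big(\hat{T}Q_{k-1}(s, a) - Q_k(s, a)\big)
                    + (1 - \alpha)\,\big(\hat{T}Q_k(s, a) - \hat{T}Q_{k-1}(s, a)\big)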
--------------------------------------------------------------------------------
/mushroom_rl/algorithms/value/td/td.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.core import Agent
4 |
5 |
6 | class TD(Agent):
7 | """
8 | Implements functions to run TD algorithms.
9 |
10 | """
11 | def __init__(self, mdp_info, policy, approximator, learning_rate):
12 | """
13 | Constructor.
14 |
15 | Args:
16 | approximator: the approximator to use to fit the Q-function;
17 | learning_rate (Parameter): the learning rate.
18 |
19 | """
20 | self._alpha = learning_rate
21 |
22 | policy.set_q(approximator)
23 | self.Q = approximator
24 |
25 | self._add_save_attr(_alpha='mushroom', Q='mushroom')
26 |
27 | super().__init__(mdp_info, policy)
28 |
29 | def fit(self, dataset):
30 | assert len(dataset) == 1
31 |
32 | state, action, reward, next_state, absorbing, _ = dataset.item()
33 | self._update(state, action, reward, next_state, absorbing)
34 |
35 | def _update(self, state, action, reward, next_state, absorbing):
36 | """
37 | Update the Q-table.
38 |
39 | Args:
40 | state (np.ndarray): state;
41 | action (np.ndarray): action;
42 | reward (np.ndarray): reward;
43 | next_state (np.ndarray): next state;
44 | absorbing (np.ndarray): absorbing flag.
45 |
46 | """
47 | pass
48 |
49 | def _post_load(self):
50 | self.policy.set_q(self.Q)
--------------------------------------------------------------------------------
/mushroom_rl/approximators/__init__.py:
--------------------------------------------------------------------------------
1 | from .regressor import Regressor
2 |
3 |
--------------------------------------------------------------------------------
/mushroom_rl/approximators/_implementations/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/mushroom_rl/approximators/_implementations/generic_regressor.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.core.serialization import Serializable
2 |
3 |
4 | class GenericRegressor(Serializable):
5 | """
6 | This class is used to create a regressor that approximates a generic
7 | function. An arbitrary number of inputs and outputs is supported.
8 |
9 | """
10 | def __init__(self, approximator, n_inputs, **params):
11 | """
12 | Constructor.
13 |
14 | Args:
15 | approximator (class): the model class used to approximate a
16 | generic function;
17 | n_inputs (int): number of inputs of the regressor;
18 | **params: parameters dictionary to the regressor;
19 |
20 | """
21 | self._n_inputs = n_inputs
22 | self.model = approximator(**params)
23 |
24 | self._add_save_attr(
25 | _n_inputs='primitive',
26 | model=self._get_serialization_method(approximator)
27 | )
28 |
29 | def fit(self, *z, **fit_params):
30 | """
31 | Fit the model.
32 |
33 | Args:
34 | *z: list of inputs and targets;
35 | **fit_params: other parameters used by the fit method of the
36 | regressor.
37 |
38 | """
39 | self.model.fit(*z, **fit_params)
40 |
41 | def predict(self, *x, **predict_params):
42 | """
43 | Predict.
44 |
45 | Args:
46 | *x (list): list of inputs;
47 | **predict_params: other parameters used by the predict method
48 | of the regressor.
49 |
50 | Returns:
51 | The predictions of the model.
52 |
53 | """
54 | return self.model.predict(*x, **predict_params)
55 |
56 | def reset(self):
57 | """
58 | Reset the model parameters.
59 |
60 | """
61 | try:
62 | self.model.reset()
63 | except AttributeError:
64 | raise NotImplementedError('Attempt to reset weights of a'
65 | ' non-parametric regressor.')
66 |
67 | @property
68 | def weights_size(self):
69 | return self.model.weights_size
70 |
71 | def get_weights(self):
72 | return self.model.get_weights()
73 |
74 | def set_weights(self, w):
75 | self.model.set_weights(w)
76 |
77 | def diff(self, *x):
78 | return self.model.diff(*x)
79 |
80 | def __len__(self):
81 | return len(self.model)
82 |
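A minimal usage sketch of the contract above: the wrapped `approximator` only needs `fit` and `predict`. `MeanModel` below is a made-up stand-in, not a library class, and the regressor is imported directly from this private module purely for illustration.

import numpy as np

from mushroom_rl.approximators._implementations.generic_regressor import GenericRegressor


class MeanModel:
    """Toy model that predicts the mean of the training targets."""

    def fit(self, x, y):
        self._mean = y.mean(axis=0)

    def predict(self, x):
        return np.tile(self._mean, (len(x), 1))


x = np.random.rand(20, 3)
y = np.random.rand(20, 2)

regressor = GenericRegressor(MeanModel, n_inputs=1)
regressor.fit(x, y)
print(regressor.predict(x).shape)  # (20, 2)
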
--------------------------------------------------------------------------------
/mushroom_rl/approximators/ensemble_table.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.approximators.table import Table
2 | from mushroom_rl.approximators.ensemble import Ensemble
3 |
4 |
5 | class EnsembleTable(Ensemble):
6 | """
7 | This class implements functions to manage table ensembles.
8 |
9 | """
10 | def __init__(self, n_models, shape, **params):
11 | """
12 | Constructor.
13 |
14 | Args:
15 | n_models (int): number of models in the ensemble;
16 | shape (np.ndarray): shape of each table in the ensemble;
17 | **params: parameters dictionary to create each regressor.
18 |
19 | """
20 | params['shape'] = shape
21 | super().__init__(Table, n_models, **params)
22 |
23 | @property
24 | def n_actions(self):
25 | return self._model[0].shape[-1]
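A quick construction sketch, assuming (as the constructor above suggests) that Ensemble instantiates `n_models` Table objects with the given shape; the numbers are arbitrary.

from mushroom_rl.approximators.ensemble_table import EnsembleTable

# Two Q-tables over 10 states and 4 actions
ensemble = EnsembleTable(n_models=2, shape=(10, 4))
print(ensemble.n_actions)  # 4
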
--------------------------------------------------------------------------------
/mushroom_rl/approximators/parametric/__init__.py:
--------------------------------------------------------------------------------
1 | from .linear import LinearApproximator
2 | from .torch_approximator import TorchApproximator, NumpyTorchApproximator
3 | from .cmac import CMAC
4 |
5 |
--------------------------------------------------------------------------------
/mushroom_rl/approximators/parametric/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .linear_network import LinearNetwork
--------------------------------------------------------------------------------
/mushroom_rl/approximators/parametric/networks/linear_network.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class LinearNetwork(nn.Module):
5 | def __init__(self, input_shape, output_shape, use_bias=False, gain=None, **kwargs):
6 | super().__init__()
7 |
8 | n_input = input_shape[-1]
9 | n_output = output_shape[0]
10 |
11 | self._f = nn.Linear(n_input, n_output, bias=use_bias)
12 |
13 | if gain is None:
14 | gain = nn.init.calculate_gain('linear')
15 |
16 | nn.init.xavier_uniform_(self._f.weight, gain=gain)
17 |
18 | def forward(self, state, **kwargs):
19 | return self._f(state)
20 |
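A quick shape check for the module above (a sketch; the shapes are arbitrary):

import torch

from mushroom_rl.approximators.parametric.networks import LinearNetwork

net = LinearNetwork(input_shape=(4,), output_shape=(2,))
out = net(torch.randn(8, 4))
print(out.shape)  # torch.Size([8, 2])
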
--------------------------------------------------------------------------------
/mushroom_rl/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .array_backend import ArrayBackend
2 | from .core import Core
3 | from .dataset import DatasetInfo, Dataset, VectorizedDataset
4 | from .environment import Environment, MDPInfo
5 | from .agent import Agent, AgentInfo
6 | from .serialization import Serializable
7 | from .logger import Logger
8 |
9 | from .extra_info import ExtraInfo
10 |
11 | from .vectorized_core import VectorCore
12 | from .vectorized_env import VectorizedEnvironment
13 | from .multiprocess_environment import MultiprocessEnvironment
14 |
15 | import mushroom_rl.environments
16 |
17 | __all__ = ['ArrayBackend', 'Core', 'DatasetInfo', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo',
18 | 'Serializable', 'Logger', 'ExtraInfo', 'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']
19 |
--------------------------------------------------------------------------------
/mushroom_rl/core/_impl/__init__.py:
--------------------------------------------------------------------------------
1 | from .numpy_dataset import NumpyDataset
2 | from .torch_dataset import TorchDataset
3 | from .list_dataset import ListDataset
4 | from .core_logic import CoreLogic
5 | from .vectorized_core_logic import VectorizedCoreLogic
6 |
--------------------------------------------------------------------------------
/mushroom_rl/core/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .console_logger import ConsoleLogger
2 | from .data_logger import DataLogger
3 | from .logger import Logger
4 |
5 |
6 | __all__ = ['Logger', 'ConsoleLogger', 'DataLogger']
7 |
--------------------------------------------------------------------------------
/mushroom_rl/core/logger/logger.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from pathlib import Path
3 |
4 | from .console_logger import ConsoleLogger
5 | from .data_logger import DataLogger
6 |
7 |
8 | class Logger(DataLogger, ConsoleLogger):
9 | """
10 | This class implements the logging functionality. It can be used to
11 | automatically create a log directory, save numpy data arrays and the current agent.
12 |
13 | """
14 | def __init__(self, log_name='', results_dir='./logs', log_console=False,
15 | use_timestamp=False, append=False, seed=None, **kwargs):
16 | """
17 | Constructor.
18 |
19 | Args:
20 | log_name (string, ''): name of the current experiment directory; if not
21 | specified, the current timestamp is used;
22 | results_dir (string, './logs'): name of the base logging directory.
23 | If set to None, no directory is created;
24 | log_console (bool, False): whether or not to log the console output;
25 | use_timestamp (bool, False): If true, adds the current timestamp to
26 | the folder name;
27 | append (bool, False): If true, the logger will append the new
28 | data logged to the one already existing in the directory;
29 | seed (int, None): seed for the current run. It can be optionally
30 | specified to add a seed suffix for each data file logged;
31 | **kwargs: other parameters for ConsoleLogger class.
32 |
33 | """
34 |
35 | if log_console:
36 | assert results_dir is not None
37 |
38 | timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
39 |
40 | if not log_name:
41 | log_name = timestamp
42 | elif use_timestamp:
43 | log_name += '_' + timestamp
44 |
45 | if results_dir:
46 | results_dir = Path(results_dir) / log_name
47 | results_dir.mkdir(parents=True, exist_ok=True)
48 |
49 | suffix = '' if seed is None else '-' + str(seed)
50 |
51 | DataLogger.__init__(self, results_dir, suffix=suffix, append=append)
52 | ConsoleLogger.__init__(self, log_name, results_dir if log_console else None,
53 | suffix=suffix, **kwargs)
54 |
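A construction sketch for the class above; the actual console and data logging methods are inherited from ConsoleLogger and DataLogger and are not shown here.

from mushroom_rl.core import Logger

# Creates ./logs/my_experiment/ and, because seed=0, passes a '-0'
# suffix on to the data and console loggers.
logger = Logger(log_name='my_experiment', results_dir='./logs',
                log_console=True, seed=0)
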
--------------------------------------------------------------------------------
/mushroom_rl/distributions/__init__.py:
--------------------------------------------------------------------------------
1 | from .distribution import Distribution
2 | from .gaussian import GaussianDistribution, GaussianDiagonalDistribution, GaussianCholeskyDistribution
3 | from .torch_distribution import AbstractGaussianTorchDistribution, DiagonalGaussianTorchDistribution
4 | from .torch_distribution import CholeskyGaussianTorchDistribution
5 |
--------------------------------------------------------------------------------
/mushroom_rl/environments/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | Atari = None
3 | from .atari import Atari
4 | Atari.register()
5 | except ImportError:
6 | pass
7 |
8 | try:
9 | Gymnasium = None
10 | from .gymnasium_env import Gymnasium
11 | Gymnasium.register()
12 | except ImportError:
13 | pass
14 |
15 | try:
16 | DMControl = None
17 | from .dm_control_env import DMControl
18 | DMControl.register()
19 | except ImportError:
20 | pass
21 |
22 | try:
23 | MiniGrid = None
24 | from .minigrid_env import MiniGrid
25 | MiniGrid.register()
26 | except ImportError:
27 | pass
28 |
29 | try:
30 | iGibson = None
31 | from .igibson_env import iGibson
32 | iGibson.register()
33 | except ImportError:
34 | import logging
35 | logging.disable(logging.NOTSET)
36 |
37 | try:
38 | Habitat = None
39 | from .habitat_env import Habitat
40 | Habitat.register()
41 | except ImportError:
42 | pass
43 |
44 | try:
45 | MuJoCo = None
46 | from .mujoco import MuJoCo, MultiMuJoCo
47 | from .mujoco_envs import *
48 | except ImportError:
49 | pass
50 |
51 | try:
52 | OmniIsaacGymEnv = None
53 | from .omni_isaac_gym_env import OmniIsaacGymEnv
54 | except ImportError:
55 | pass
56 |
57 | try:
58 | PyBullet = None
59 | from .pybullet import PyBullet
60 | from .pybullet_envs import *
61 | except ImportError:
62 | pass
63 |
64 | try:
65 | IsaacSim = None
66 | from .isaacsim_env import IsaacSim
67 | except ImportError:
68 | pass
69 |
70 | from .generators.simple_chain import generate_simple_chain
71 |
72 | from .car_on_hill import CarOnHill
73 | CarOnHill.register()
74 |
75 | from .cart_pole import CartPole
76 | CartPole.register()
77 |
78 | from .finite_mdp import FiniteMDP
79 | FiniteMDP.register()
80 |
81 | from .grid_world import GridWorld, GridWorldVanHasselt
82 | GridWorld.register()
83 | GridWorldVanHasselt.register()
84 |
85 | from .inverted_pendulum import InvertedPendulum
86 | InvertedPendulum.register()
87 |
88 | from .lqr import LQR
89 | LQR.register()
90 |
91 | from .puddle_world import PuddleWorld
92 | PuddleWorld.register()
93 |
94 | from .segway import Segway
95 | Segway.register()
96 |
97 | from .ship_steering import ShipSteering
98 | ShipSteering.register()
99 |
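The try/except blocks above make every heavy dependency optional: an environment whose backend is not installed is simply skipped instead of breaking the import. Registered environments can then be instantiated by name through the Environment interface; the sketch below assumes the usual `Environment.make` registry constructor and the default registration name.

from mushroom_rl.core import Environment

# CartPole is registered unconditionally above, so this should always work
env = Environment.make('CartPole')
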
--------------------------------------------------------------------------------
/mushroom_rl/environments/finite_mdp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.core import Environment, MDPInfo
4 | from mushroom_rl.rl_utils import spaces
5 |
6 |
7 | class FiniteMDP(Environment):
8 | """
9 | Finite Markov Decision Process.
10 |
11 | """
12 | def __init__(self, p, rew, mu=None, gamma=.9, horizon=np.inf, dt=1e-1):
13 | """
14 | Constructor.
15 |
16 | Args:
17 | p (np.ndarray): transition probability matrix;
18 | rew (np.ndarray): reward matrix;
19 | mu (np.ndarray, None): initial state probability distribution;
20 | gamma (float, .9): discount factor;
21 | horizon (int, np.inf): the horizon;
22 | dt (float, 1e-1): the control timestep of the environment.
23 |
24 | """
25 | assert p.shape == rew.shape
26 | assert mu is None or p.shape[0] == mu.size
27 |
28 | # MDP parameters
29 | self.p = p
30 | self.r = rew
31 | self.mu = mu
32 |
33 | # MDP properties
34 | observation_space = spaces.Discrete(p.shape[0])
35 | action_space = spaces.Discrete(p.shape[1])
36 | horizon = horizon
37 | gamma = gamma
38 | mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt)
39 |
40 | super().__init__(mdp_info)
41 |
42 | def reset(self, state=None):
43 | if state is None:
44 | if self.mu is not None:
45 | self._state = np.array(
46 | [np.random.choice(self.mu.size, p=self.mu)])
47 | else:
48 | self._state = np.array([np.random.choice(self.p.shape[0])])
49 | else:
50 | self._state = state
51 |
52 | return self._state, {}
53 |
54 | def step(self, action):
55 | p = self.p[self._state[0], action[0], :]
56 | next_state = np.array([np.random.choice(p.size, p=p)])
57 | absorbing = not np.any(self.p[next_state[0]])
58 | reward = self.r[self._state[0], action[0], next_state[0]]
59 |
60 | self._state = next_state
61 |
62 | return self._state, reward, absorbing, {}
63 |
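A tiny hand-built instance of the class above (a sketch; the matrices are arbitrary): two states and two actions, with state 1 absorbing, which `step` detects from its all-zero transition rows.

import numpy as np

from mushroom_rl.environments import FiniteMDP

# p[s, a, s'] and rew[s, a, s'] must share the same shape
p = np.zeros((2, 2, 2))
p[0, 0, 0] = 1.    # action 0 keeps the agent in state 0
p[0, 1, 1] = 1.    # action 1 moves it to the absorbing state 1
rew = np.zeros((2, 2, 2))
rew[0, 1, 1] = 1.  # reward for reaching state 1

mdp = FiniteMDP(p, rew, mu=np.array([1., 0.]))
state, _ = mdp.reset()
next_state, reward, absorbing, _ = mdp.step(np.array([1]))
print(next_state, reward, absorbing)  # [1] 1.0 True
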
--------------------------------------------------------------------------------
/mushroom_rl/environments/generators/__init__.py:
--------------------------------------------------------------------------------
1 | from .simple_chain import generate_simple_chain
2 | from .grid_world import generate_grid_world
3 | from .taxi import generate_taxi
4 |
--------------------------------------------------------------------------------
/mushroom_rl/environments/generators/simple_chain.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.environments.finite_mdp import FiniteMDP
4 |
5 |
6 | def generate_simple_chain(state_n, goal_states, prob, rew, mu=None, gamma=.9,
7 | horizon=100):
8 | """
9 | Simple chain generator.
10 |
11 | Args:
12 | state_n (int): number of states;
13 | goal_states (list): list of goal states;
14 | prob (float): probability of success of an action;
15 | rew (float): reward obtained in goal states;
16 | mu (np.ndarray): initial state probability distribution;
17 | gamma (float, .9): discount factor;
18 | horizon (int, 100): the horizon.
19 |
20 | Returns:
21 | A FiniteMDP object built with the provided parameters.
22 |
23 | """
24 | p = compute_probabilities(state_n, prob)
25 | r = compute_reward(state_n, goal_states, rew)
26 |
27 | assert mu is None or len(mu) == state_n
28 |
29 | return FiniteMDP(p, r, mu, gamma, horizon)
30 |
31 |
32 | def compute_probabilities(state_n, prob):
33 | """
34 | Compute the transition probability matrix.
35 |
36 | Args:
37 | state_n (int): number of states;
38 | prob (float): probability of success of an action.
39 |
40 | Returns:
41 | The transition probability matrix.
42 |
43 | """
44 | p = np.zeros((state_n, 2, state_n))
45 |
46 | for i in range(state_n):
47 | if i == 0:
48 | p[i, 1, i] = 1.
49 | else:
50 | p[i, 1, i] = 1. - prob
51 | p[i, 1, i - 1] = prob
52 |
53 | if i == state_n - 1:
54 | p[i, 0, i] = 1.
55 | else:
56 | p[i, 0, i] = 1. - prob
57 | p[i, 0, i + 1] = prob
58 |
59 | return p
60 |
61 |
62 | def compute_reward(state_n, goal_states, rew):
63 | """
64 | Compute the reward matrix.
65 |
66 | Args:
67 | state_n (int): number of states;
68 | goal_states (list): list of goal states;
69 | rew (float): reward obtained in goal states.
70 |
71 | Returns:
72 | The reward matrix.
73 |
74 | """
75 | r = np.zeros((state_n, 2, state_n))
76 |
77 | for g in goal_states:
78 | if g != 0:
79 | r[g - 1, 0, g] = rew
80 |
81 | if g != state_n - 1:
82 | r[g + 1, 1, g] = rew
83 |
84 | return r
85 |
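A usage sketch for the generator above (the numbers are arbitrary); the resulting FiniteMDP can be stepped directly, or driven by an agent through a Core loop as in the examples folder.

import numpy as np

from mushroom_rl.environments import generate_simple_chain

env = generate_simple_chain(state_n=5, goal_states=[2], prob=.8, rew=1.)

state, _ = env.reset()
next_state, reward, absorbing, _ = env.step(np.array([0]))
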
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .cartpole import CartPole
2 | from .a1_walking import A1Walking
3 | from .honey_badger_walking import HoneyBadgerWalking
4 | from .silver_badger_walking import SilverBadgerWalking
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/.thumbs/256x256/a1.usd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/.thumbs/256x256/a1.usd.png
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/.thumbs/256x256/instanceable_meshes.usd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/.thumbs/256x256/instanceable_meshes.usd.png
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/a1.usd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/a1.usd
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/instanceable_meshes.usd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/a1/instanceable_meshes.usd
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/cartpole/cartpole.usd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/cartpole/cartpole.usd
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/.thumbs/256x256/honey_badger.usd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/.thumbs/256x256/honey_badger.usd.png
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/.thumbs/256x256/instanceable_meshes.usd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/.thumbs/256x256/instanceable_meshes.usd.png
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/honey_badger.usd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/honey_badger.usd
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/instanceable_meshes.usd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/honey_badger/instanceable_meshes.usd
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/.thumbs/256x256/instanceable_meshes.usd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/.thumbs/256x256/instanceable_meshes.usd.png
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/.thumbs/256x256/silver_badger.usd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/.thumbs/256x256/silver_badger.usd.png
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/instanceable_meshes.usd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/instanceable_meshes.usd
--------------------------------------------------------------------------------
/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/silver_badger.usd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/isaacsim_envs/robots_usds/silver_badger/silver_badger.usd
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .ball_in_a_cup import BallInACup
2 | from .air_hockey import AirHockeyHit, AirHockeyDefend, AirHockeyPrepare, AirHockeyRepel
3 | from .ant import Ant
4 | from .half_cheetah import HalfCheetah
5 | from .hopper import Hopper
6 | from .walker_2d import Walker2D
7 |
8 | BallInACup.register()
9 | AirHockeyHit.register()
10 | AirHockeyDefend.register()
11 | AirHockeyPrepare.register()
12 | AirHockeyRepel.register()
13 | Ant.register()
14 | HalfCheetah.register()
15 | Hopper.register()
16 | Walker2D.register()
17 |
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/air_hockey/__init__.py:
--------------------------------------------------------------------------------
1 | from .hit import AirHockeyHit
2 | from .defend import AirHockeyDefend
3 | from .prepare import AirHockeyPrepare
4 | from .repel import AirHockeyRepel
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/air_hockey/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/air_hockey/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/air_hockey/double.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/air_hockey/double.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/air_hockey/single.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/air_hockey/single.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/air_hockey/table.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/air_hockey/table.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ant/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ant/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/base_link_convex.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/base_link_convex.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/base_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/base_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split1.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split10.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split10.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split11.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split11.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split12.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split12.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split13.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split13.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split14.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split14.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split15.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split15.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split16.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split16.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split17.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split17.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split18.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split18.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split2.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split3.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split4.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split5.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split5.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split6.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split6.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split7.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split7.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split8.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split8.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split9.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/cup_split9.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/elbow_link_convex.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/elbow_link_convex.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/elbow_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/elbow_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_convex_decomposition_p1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_convex_decomposition_p1.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_convex_decomposition_p2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_convex_decomposition_p2.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/forearm_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p1.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p2.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_convex_decomposition_p3.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_pitch_link_convex.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_pitch_link_convex.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_pitch_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/shoulder_pitch_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_convex_decomposition_p1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_convex_decomposition_p1.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_convex_decomposition_p2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_convex_decomposition_p2.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/upper_arm_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_palm_link_convex.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_palm_link_convex.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_palm_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_palm_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p1.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p2.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_convex_decomposition_p3.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_pitch_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_convex_decomposition_p1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_convex_decomposition_p1.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_convex_decomposition_p2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_convex_decomposition_p2.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/ball_in_a_cup/meshes/wrist_yaw_link_fine.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/half_cheetah/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/half_cheetah/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/hopper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/hopper/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/cube.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/cube.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/hand.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/hand.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/link0.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link0.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/link1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link1.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/link2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link2.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/link3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link3.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/link4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link4.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/link6.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link6.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/link7.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/link7.stl
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/assets/table.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/assets/table.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/peg_insertion.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/peg_insertion.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/pick.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/pick.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/push.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/push.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/panda/reach.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/panda/reach.xml
--------------------------------------------------------------------------------
/mushroom_rl/environments/mujoco_envs/data/walker_2d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/mujoco_envs/data/walker_2d/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/environments/pybullet_envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .air_hockey import *
2 |
--------------------------------------------------------------------------------
/mushroom_rl/environments/pybullet_envs/air_hockey/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from .hit import AirHockeyHitBullet
3 | from .defend import AirHockeyDefendBullet
4 | from .prepare import AirHockeyPrepareBullet
5 | from .repel import AirHockeyRepelBullet
6 |
7 |
8 | AirHockeyHitBullet.register()
9 | AirHockeyDefendBullet.register()
10 | AirHockeyPrepareBullet.register()
11 | AirHockeyRepelBullet.register()
12 | except ImportError:
13 | pass
14 |
--------------------------------------------------------------------------------
/mushroom_rl/environments/pybullet_envs/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/environments/pybullet_envs/data/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/features/__init__.py:
--------------------------------------------------------------------------------
1 | from .features import Features, get_action_features
2 |
3 | __all__ = ['Features', 'get_action_features']
4 |
--------------------------------------------------------------------------------
/mushroom_rl/features/_implementations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/features/_implementations/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/features/_implementations/basis_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .features_implementation import FeaturesImplementation
4 |
5 |
6 | class BasisFeatures(FeaturesImplementation):
7 | def __init__(self, basis):
8 | self._basis = basis
9 |
10 | def __call__(self, *args):
11 | x = self._concatenate(args)
12 |
13 | y = list()
14 |
15 | x = np.atleast_2d(x)
16 | for s in x:
17 | out = np.empty(self.size)
18 |
19 | for i, bf in enumerate(self._basis):
20 | out[i] = bf(s)
21 |
22 | y.append(out)
23 |
24 | if len(y) == 1:
25 | y = y[0]
26 | else:
27 | y = np.array(y)
28 |
29 | return y
30 |
31 | @property
32 | def size(self):
33 | return len(self._basis)
34 |
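A self-contained sketch of the contract above: any sequence of callables works as a basis, and the output is a single vector for one sample or a matrix for a batch. The polynomial lambdas are just an illustration, not the library's PolynomialBasis class.

import numpy as np

from mushroom_rl.features._implementations.basis_features import BasisFeatures

basis = [lambda s: 1., lambda s: s[0], lambda s: s[0] ** 2]
phi = BasisFeatures(basis)

print(phi(np.array([2.])))          # [1. 2. 4.]
print(phi(np.array([[1.], [3.]])))  # rows [1. 1. 1.] and [1. 3. 9.]
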
--------------------------------------------------------------------------------
/mushroom_rl/features/_implementations/features_implementation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class FeaturesImplementation(object):
5 | def __call__(self, *x):
6 | """
7 | Evaluate the feature vector in the given raw input. If more than one
8 | element is passed, the raw input is concatenated before computing the
9 | features.
10 |
11 | Args:
12 | *x (list): the raw input.
13 |
14 | Returns:
15 | The features vector computed from the raw input.
16 |
17 | """
18 | pass
19 |
20 | @staticmethod
21 | def _concatenate(args):
22 | if len(args) > 1:
23 | x = np.concatenate(args, axis=-1)
24 | else:
25 | x = args[0]
26 |
27 | return x
28 |
29 | @property
30 | def size(self):
31 | """
32 | Returns:
33 | The number of elements in the features vector.
34 |
35 | """
36 | raise NotImplementedError
37 |
--------------------------------------------------------------------------------
/mushroom_rl/features/_implementations/functional_features.py:
--------------------------------------------------------------------------------
1 | from .features_implementation import FeaturesImplementation
2 |
3 |
4 | class FunctionalFeatures(FeaturesImplementation):
5 | def __init__(self, n_outputs, function):
6 | self._n_outputs = n_outputs
7 | self._function = function if function is not None else self._identity
8 |
9 | def __call__(self, *args):
10 | x = self._concatenate(args)
11 |
12 | return self._function(x)
13 |
14 | def _identity(self, x):
15 | return x
16 |
17 | @property
18 | def size(self):
19 | return self._n_outputs
20 |
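The same idea in an even smaller sketch: the feature map is an arbitrary callable and `n_outputs` simply declares its output size.

import numpy as np

from mushroom_rl.features._implementations.functional_features import FunctionalFeatures

phi = FunctionalFeatures(n_outputs=2, function=lambda x: x ** 2)
print(phi(np.array([1., 3.])))  # [1. 9.]
print(phi.size)                 # 2
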
--------------------------------------------------------------------------------
/mushroom_rl/features/_implementations/tiles_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .features_implementation import FeaturesImplementation
4 |
5 |
6 | class TilesFeatures(FeaturesImplementation):
7 | def __init__(self, tiles):
8 |
9 | if isinstance(tiles, list):
10 | self._tiles = tiles
11 | else:
12 | self._tiles = [tiles]
13 | self._size = 0
14 |
15 | for tiling in self._tiles:
16 | self._size += tiling.size
17 |
18 | def __call__(self, *args):
19 | x = self._concatenate(args)
20 |
21 | y = list()
22 |
23 | x = np.atleast_2d(x)
24 | for s in x:
25 | out = np.zeros(self._size)
26 |
27 | offset = 0
28 | for tiling in self._tiles:
29 | index = tiling(s)
30 |
31 | if index is not None:
32 | out[index + offset] = 1.
33 |
34 | offset += tiling.size
35 |
36 | y.append(out)
37 |
38 | if len(y) == 1:
39 | y = y[0]
40 | else:
41 | y = np.array(y)
42 |
43 | return y
44 |
45 | def compute_indexes(self, *args):
46 | x = self._concatenate(args)
47 |
48 | y = list()
49 |
50 | x = np.atleast_2d(x)
51 | for s in x:
52 | out = list()
53 |
54 | offset = 0
55 | for tiling in self._tiles:
56 | index = tiling(s)
57 |
58 | if index is not None:
59 | out.append(index + offset)
60 |
61 | offset += tiling.size
62 |
63 | y.append(out)
64 |
65 | if len(y) == 1:
66 | return y[0]
67 | else:
68 | return y
69 |
70 | @property
71 | def size(self):
72 | return self._size
73 |
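TilesFeatures only needs objects that are callable on a state and expose a `size`; `MockTiling` below is a made-up stand-in used to show the one-hot encoding (the library's real tilings are Tiles and VoronoiTiles from mushroom_rl.features.tiles).

import numpy as np

from mushroom_rl.features._implementations.tiles_features import TilesFeatures


class MockTiling:
    """Toy tiling over [0, 1): four equal bins on the first state component."""

    size = 4

    def __call__(self, s):
        index = int(s[0] * self.size)
        return index if 0 <= index < self.size else None


phi = TilesFeatures([MockTiling(), MockTiling()])
print(phi(np.array([0.3])))  # [0. 1. 0. 0. 0. 1. 0. 0.]
print(phi.size)              # 8
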
--------------------------------------------------------------------------------
/mushroom_rl/features/_implementations/torch_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from .features_implementation import FeaturesImplementation
5 | from mushroom_rl.utils.torch import TorchUtils
6 |
7 |
8 | class TorchFeatures(FeaturesImplementation):
9 | def __init__(self, tensor_list):
10 | self._phi = tensor_list
11 |
12 | def __call__(self, *args):
13 | x = self._concatenate(args)
14 |
15 | x = TorchUtils.to_float_tensor(np.atleast_2d(x))
16 |
17 | y_list = [self._phi[i].forward(x) for i in range(len(self._phi))]
18 | y = torch.cat(y_list, 1).squeeze()
19 |
20 | y = y.detach().cpu().numpy()
21 |
22 | if y.shape[0] == 1:
23 | return y[0]
24 | else:
25 | return y
26 |
27 | @property
28 | def size(self):
29 | return np.sum([phi.size for phi in self._phi])
30 |
--------------------------------------------------------------------------------
/mushroom_rl/features/basis/__init__.py:
--------------------------------------------------------------------------------
1 | from .gaussian_rbf import GaussianRBF
2 | from .polynomial import PolynomialBasis
3 | from .fourier import FourierBasis
4 |
5 | __all__ = ['GaussianRBF', 'PolynomialBasis', 'FourierBasis']
6 |
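Usage note: these basis functions are consumed by the Features factory via the basis_list argument (see the LSPI test later in this dump, which uses a single PolynomialBasis). A minimal sketch, assuming the GaussianRBF.generate factory; the grid size and bounds are illustrative:

import numpy as np

from mushroom_rl.features import Features
from mushroom_rl.features.basis import GaussianRBF

# 3x3 grid of Gaussian RBFs over the box [-1, 1] x [-1, 1]
basis = GaussianRBF.generate([3, 3], np.array([-1., -1.]), np.array([1., 1.]))
phi = Features(basis_list=basis)

x = np.array([0.1, -0.4])
print(phi(x).shape)  # (9,): one activation per basis function
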
--------------------------------------------------------------------------------
/mushroom_rl/features/tensors/__init__.py:
--------------------------------------------------------------------------------
1 | from .basis_tensor import GenericBasisTensor, GaussianRBFTensor, VonMisesBFTensor
2 | from .constant_tensor import ConstantTensor
3 | from .random_fourier_tensor import RandomFourierBasis
4 |
5 | __all__ = ['GenericBasisTensor', 'GaussianRBFTensor', 'VonMisesBFTensor', 'ConstantTensor', 'RandomFourierBasis']
--------------------------------------------------------------------------------
/mushroom_rl/features/tensors/constant_tensor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from mushroom_rl.utils.torch import TorchUtils
5 |
6 |
7 | class ConstantTensor(nn.Module):
8 | """
9 | PyTorch module to implement a constant function (always one).
10 |
11 | """
12 |
13 | def forward(self, x):
14 | return torch.ones(x.shape[0], 1).to(TorchUtils.get_device())
15 |
16 | @property
17 | def size(self):
18 | return 1
19 |
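Usage note: a minimal sketch of the constant tensor above; it simply produces a bias-like feature of value one for every element of a batch:

import torch

from mushroom_rl.features.tensors import ConstantTensor

phi = ConstantTensor()
x = torch.randn(4, 3)   # batch of 4 inputs with 3 dimensions each

print(phi(x))           # column of ones with shape (4, 1)
print(phi.size)         # 1
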
--------------------------------------------------------------------------------
/mushroom_rl/features/tiles/__init__.py:
--------------------------------------------------------------------------------
1 | from .tiles import Tiles
2 | from .voronoi import VoronoiTiles
3 |
4 | __all__ = ['Tiles', 'VoronoiTiles']
5 |
--------------------------------------------------------------------------------
/mushroom_rl/policy/__init__.py:
--------------------------------------------------------------------------------
1 | from .policy import Policy, ParametricPolicy
2 | from .vector_policy import VectorPolicy
3 | from .noise_policy import OrnsteinUhlenbeckPolicy, ClippedGaussianPolicy
4 | from .td_policy import TDPolicy, Boltzmann, EpsGreedy, Mellowmax
5 | from .gaussian_policy import GaussianPolicy, DiagonalGaussianPolicy, \
6 | StateStdGaussianPolicy, StateLogStdGaussianPolicy
7 | from .deterministic_policy import DeterministicPolicy
8 | from .torch_policy import TorchPolicy, GaussianTorchPolicy, BoltzmannTorchPolicy
9 | from .recurrent_torch_policy import RecurrentGaussianTorchPolicy
10 | from .promps import ProMP
11 | from .dmp import DMP
12 |
13 |
14 | __all_td__ = ['TDPolicy', 'Boltzmann', 'EpsGreedy', 'Mellowmax']
15 | __all_parametric__ = ['ParametricPolicy', 'GaussianPolicy',
16 | 'DiagonalGaussianPolicy', 'StateStdGaussianPolicy',
17 | 'StateLogStdGaussianPolicy', 'ProMP']
18 | __all_torch__ = ['TorchPolicy', 'GaussianTorchPolicy', 'BoltzmannTorchPolicy']
19 | __all_noise__ = ['OrnsteinUhlenbeckPolicy', 'ClippedGaussianPolicy']
20 | __all_mp__ = ['ProMP', 'DMP']
21 |
22 | __all__ = ['Policy', 'DeterministicPolicy', ] \
23 | + __all_td__ + __all_parametric__ + __all_torch__ + __all_noise__ + __all_mp__
24 |
--------------------------------------------------------------------------------
/mushroom_rl/policy/deterministic_policy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .policy import ParametricPolicy
3 |
4 |
5 | class DeterministicPolicy(ParametricPolicy):
6 | """
7 | Simple parametric policy representing a deterministic policy. As
8 | deterministic policies are degenerate probability functions where all
9 | the probability mass is on the deterministic action, they are not
10 | differentiable, even if the mean value approximator is differentiable.
11 |
12 | """
13 | def __init__(self, mu, policy_state_shape=None):
14 | """
15 | Constructor.
16 |
17 | Args:
18 | mu (Regressor): the regressor representing the action to select
19 | in each state.
20 |
21 | """
22 | super().__init__(policy_state_shape)
23 |
24 | self._approximator = mu
25 | self._predict_params = dict()
26 |
27 | self._add_save_attr(_approximator='mushroom',
28 | _predict_params='pickle')
29 |
30 | def get_regressor(self):
31 | """
32 | Getter.
33 |
34 | Returns:
35 | The regressor that is used to map states to actions.
36 |
37 | """
38 | return self._approximator
39 |
40 | def __call__(self, state, action, policy_state=None):
41 | policy_action = self._approximator.predict(state, **self._predict_params)
42 |
43 | return 1. if np.array_equal(action, policy_action) else 0.
44 |
45 | def draw_action(self, state, policy_state=None):
46 | return self._approximator.predict(state, **self._predict_params), None
47 |
48 | def set_weights(self, weights):
49 | self._approximator.set_weights(weights)
50 |
51 | def get_weights(self):
52 | return self._approximator.get_weights()
53 |
54 | @property
55 | def weights_size(self):
56 | return self._approximator.weights_size
57 |
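Usage note: a minimal sketch of DeterministicPolicy on top of a linear regressor, mirroring the dedicated test at the end of this dump (input and output shapes are illustrative):

import numpy as np

from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import LinearApproximator
from mushroom_rl.policy import DeterministicPolicy

mu = Regressor(LinearApproximator, input_shape=(5,), output_shape=(2,))
pi = DeterministicPolicy(mu)

state = np.random.randn(5)
action, _ = pi.draw_action(state)   # the regressor's prediction; no policy state
print(pi(state, action))            # 1.0: all probability mass is on that action
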
--------------------------------------------------------------------------------
/mushroom_rl/rl_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .eligibility_trace import EligibilityTrace, ReplacingTrace, AccumulatingTrace
2 | from .optimizers import Optimizer, AdamOptimizer, SGDOptimizer, AdaptiveOptimizer
3 | from .parameters import Parameter, DecayParameter, LinearParameter, to_parameter
4 | from .preprocessors import StandardizationPreprocessor, MinMaxPreprocessor
5 | from .replay_memory import ReplayMemory, PrioritizedReplayMemory
6 | from .running_stats import RunningStandardization, RunningAveragedWindow, RunningExpWeightedAverage
7 | from .spaces import Box, Discrete
8 | from .value_functions import compute_advantage, compute_advantage_montecarlo, compute_gae
9 | from .variance_parameters import VarianceDecreasingParameter, VarianceIncreasingParameter
10 | from .variance_parameters import WindowedVarianceParameter, WindowedVarianceIncreasingParameter
--------------------------------------------------------------------------------
/mushroom_rl/rl_utils/eligibility_trace.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.approximators.table import Table
2 |
3 |
4 | def EligibilityTrace(shape, name='replacing'):
5 | """
6 | Factory method to create an eligibility trace of the provided type.
7 |
8 | Args:
9 | shape (list): shape of the eligibility trace table;
10 | name (str, 'replacing'): type of the eligibility trace.
11 |
12 | Returns:
13 | The eligibility trace table of the provided shape and type.
14 |
15 | """
16 | if name == 'replacing':
17 | return ReplacingTrace(shape)
18 | elif name == 'accumulating':
19 | return AccumulatingTrace(shape)
20 | else:
21 | raise ValueError('Unknown type of trace.')
22 |
23 |
24 | class ReplacingTrace(Table):
25 | """
26 | Replacing trace.
27 |
28 | """
29 | def reset(self):
30 | self.table[:] = 0.
31 |
32 | def update(self, state, action):
33 | self.table[state, action] = 1.
34 |
35 |
36 | class AccumulatingTrace(Table):
37 | """
38 | Accumulating trace.
39 |
40 | """
41 | def reset(self):
42 | self.table[:] = 0.
43 |
44 | def update(self, state, action):
45 | self.table[state, action] += 1.
46 |
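Usage note: a minimal sketch of the factory above on a small discrete state-action space (the shape is illustrative):

from mushroom_rl.rl_utils import EligibilityTrace

# Trace table over 10 states and 2 actions
e = EligibilityTrace((10, 2), name='replacing')

e.update(3, 1)          # mark the visited state-action pair
print(e.table[3, 1])    # 1.0
e.reset()               # zero the whole trace, e.g. at the start of a new episode
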
--------------------------------------------------------------------------------
/mushroom_rl/solvers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/mushroom_rl/solvers/__init__.py
--------------------------------------------------------------------------------
/mushroom_rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .angles import normalize_angle_positive, normalize_angle, shortest_angular_distance
2 | from .angles import quat_to_euler, euler_to_quat, euler_to_mat, mat_to_euler
3 | from .features import uniform_grid
4 | from .frames import LazyFrames, preprocess_frame
5 | from .numerical_gradient import numerical_diff_dist, numerical_diff_function, numerical_diff_policy
6 | from .minibatches import minibatch_number, minibatch_generator
7 | from .plot import plot_mean_conf, get_mean_and_confidence
8 | from .record import VideoRecorder
9 | from .torch import TorchUtils, CategoricalWrapper
10 | from .viewer import Viewer, CV2Viewer, ImageViewer
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | __extras__ = []
2 |
3 | try:
4 | from mushroom_rl.utils.callbacks.plot_dataset import PlotDataset
5 | __extras__.append('PlotDataset')
6 | except ImportError:
7 | pass
8 |
9 | from .callback import Callback, CallbackList
10 | from .collect_dataset import CollectDataset
11 | from .collect_max_q import CollectMaxQ
12 | from .collect_q import CollectQ
13 | from .collect_parameters import CollectParameters
14 |
15 | __all__ = ['Callback', 'CollectDataset', 'CollectQ', 'CollectMaxQ',
16 | 'CollectParameters'] + __extras__
17 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/callbacks/callback.py:
--------------------------------------------------------------------------------
1 | class Callback(object):
2 | """
3 | Interface for all basic callbacks. Implements a list in which it is possible
4 | to store data, as well as methods to query and clean the content stored by
5 | the callback.
6 |
7 | """
8 | def __call__(self, dataset):
9 | """
10 | Add samples to the samples list.
11 |
12 | Args:
13 | dataset (Dataset): the samples to collect.
14 |
15 | """
16 | raise NotImplementedError
17 |
18 | def get(self):
19 | """
20 | Returns:
21 | The current collected data.
22 |
23 | """
24 | raise NotImplementedError
25 |
26 | def clean(self):
27 | """
28 | Delete the currently stored data.
29 |
30 | """
31 | raise NotImplementedError
32 |
33 |
34 | class CallbackList(Callback):
35 | """
36 | Simple interface for callbacks that store the collected data in a single list.
37 |
38 | """
39 | def __init__(self):
40 | self._data_list = list()
41 |
42 | def get(self):
43 | """
44 | Returns:
45 | The current collected data.
46 |
47 | """
48 | return self._data_list
49 |
50 | def clean(self):
51 | """
52 | Delete the currently stored data.
53 |
54 | """
55 | self._data_list = list()
56 |
57 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/callbacks/collect_dataset.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.utils.callbacks.callback import Callback
2 | from mushroom_rl.core.dataset import VectorizedDataset
3 |
4 |
5 | class CollectDataset(Callback):
6 | """
7 | This callback can be used to collect samples during the learning of the
8 | agent.
9 |
10 | """
11 | def __init__(self):
12 | """
13 | Constructor.
14 | """
15 | self._dataset = None
16 |
17 | def __call__(self, dataset):
18 | if isinstance(dataset, VectorizedDataset):
19 | dataset = dataset.flatten()
20 |
21 | if self._dataset is None:
22 | self._dataset = dataset.copy()
23 | else:
24 | self._dataset += dataset
25 |
26 | def clean(self):
27 | self._dataset.clear()
28 |
29 | def get(self):
30 | return self._dataset
31 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/callbacks/collect_max_q.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.utils.callbacks.callback import CallbackList
2 | import numpy as np
3 |
4 |
5 | class CollectMaxQ(CallbackList):
6 | """
7 | This callback can be used to collect the maximum action value in a given
8 | state at each call.
9 |
10 | """
11 | def __init__(self, approximator, state):
12 | """
13 | Constructor.
14 |
15 | Args:
16 | approximator ([Table, EnsembleTable]): the approximator to use;
17 | state (np.ndarray): the state to consider.
18 |
19 | """
20 | self._approximator = approximator
21 | self._state = state
22 |
23 | super().__init__()
24 |
25 | def __call__(self, dataset):
26 | q = self._approximator.predict(self._state)
27 | max_q = np.max(q)
28 |
29 | self._data_list.append(max_q)
30 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/callbacks/collect_parameters.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.utils.callbacks.callback import CallbackList
2 | import numpy as np
3 |
4 |
5 | class CollectParameters(CallbackList):
6 | """
7 | This callback can be used to collect the values of a parameter
8 | (e.g. learning rate) during a run of the agent.
9 |
10 | """
11 | def __init__(self, parameter, *idx):
12 | """
13 | Constructor.
14 |
15 | Args:
16 | parameter (Parameter): the parameter whose values have to be
17 | collected;
18 | *idx (list): index of the parameter when the ``parameter`` is
19 | tabular.
20 |
21 | """
22 | self._parameter = parameter
23 | self._idx = idx
24 |
25 | super().__init__()
26 |
27 | def __call__(self, dataset):
28 | value = self._parameter.get_value(*self._idx)
29 | if isinstance(value, np.ndarray):
30 | value = np.array(value)
31 | self._data_list.append(value)
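Usage note: in practice these callbacks are handed to the Core so that they are invoked on every collected dataset; since CollectParameters ignores the dataset argument, it can also be exercised in isolation. A minimal sketch with a constant Parameter (the value is illustrative):

from mushroom_rl.rl_utils.parameters import Parameter
from mushroom_rl.utils.callbacks import CollectParameters

epsilon = Parameter(value=0.1)
callback = CollectParameters(epsilon)

callback(None)          # the dataset is unused by this particular callback
print(callback.get())   # [0.1]
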
32 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/callbacks/collect_q.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 |
4 | from mushroom_rl.utils.callbacks.callback import CallbackList
5 | from mushroom_rl.approximators.ensemble_table import EnsembleTable
6 |
7 |
8 | class CollectQ(CallbackList):
9 | """
10 | This callback can be used to collect the action values in all states at the
11 | current time step.
12 |
13 | """
14 | def __init__(self, approximator):
15 | """
16 | Constructor.
17 |
18 | Args:
19 | approximator ([Table, EnsembleTable]): the approximator to use to
20 | predict the action values.
21 |
22 | """
23 | self._approximator = approximator
24 |
25 | super().__init__()
26 |
27 | def __call__(self, dataset):
28 | if isinstance(self._approximator, EnsembleTable):
29 | qs = list()
30 | for m in self._approximator.model:
31 | qs.append(m.table)
32 | self._data_list.append(deepcopy(np.mean(qs, 0)))
33 | else:
34 | self._data_list.append(deepcopy(self._approximator.table))
35 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def uniform_grid(n_centers, low, high, eta=0.25, cyclic=False):
5 | """
6 | This function is used to create the parameters of uniformly spaced radial
7 | basis functions with an overlap fraction of ``eta``. It creates a uniformly
8 | spaced grid of ``n_centers[i]`` points in each dimension i, and also returns
9 | a vector containing the appropriate width of each radial basis function.
10 |
11 | Args:
12 | n_centers (list): number of centers of each dimension;
13 | low (np.ndarray): lowest value for each dimension;
14 | high (np.ndarray): highest value for each dimension;
15 | eta (float, 0.25): overlap between two radial basis functions;
16 | cyclic (bool, False): whether the state space is a ring or not.
17 |
18 | Returns:
19 | The uniformly spaced grid and the width vector.
20 |
21 | """
22 | assert 0 < eta < 1.0
23 |
24 | n_features = len(low)
25 | w = np.zeros(n_features)
26 | c = list()
27 | tot_points = 1
28 | for i, n in enumerate(n_centers):
29 | start = low[i]
30 | end = high[i]
31 | # m = abs(start - end) / n
32 | if n == 1:
33 | w[i] = abs(end - start) / 2
34 | c_i = (start + end) / 2.
35 | c.append(np.array([c_i]))
36 | else:
37 | if cyclic:
38 | end_new = end - abs(end-start) / n
39 | else:
40 | end_new = end
41 | w[i] = (1 + eta) * abs(end_new - start) / n
42 | c_i = np.linspace(start, end_new, n)
43 | c.append(c_i)
44 | tot_points *= n
45 |
46 | n_rows = 1
47 | n_cols = 0
48 |
49 | grid = np.zeros((tot_points, n_features))
50 |
51 | for discrete_values in c:
52 | i1 = 0
53 | dim = len(discrete_values)
54 |
55 | for i in range(dim):
56 | for r in range(n_rows):
57 | idx_r = r + i * n_rows
58 | for col in range(n_cols):
59 | grid[idx_r, col] = grid[r, col]
60 | grid[idx_r, n_cols] = discrete_values[i1]
61 |
62 | i1 += 1
63 |
64 | n_cols += 1
65 | n_rows *= len(discrete_values)
66 |
67 | return grid, w
68 |
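Usage note: a minimal sketch of uniform_grid (the numbers of centers, bounds and overlap are illustrative):

import numpy as np

from mushroom_rl.utils import uniform_grid

# 3 centers along the first dimension in [0, 1], 2 centers along the second in [-1, 1]
centers, widths = uniform_grid([3, 2], np.array([0., -1.]), np.array([1., 1.]), eta=0.25)

print(centers.shape)    # (6, 2): one row per radial basis function center
print(widths)           # per-dimension widths of the radial basis functions
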
--------------------------------------------------------------------------------
/mushroom_rl/utils/frames.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | cv2.ocl.setUseOpenCL(False)
5 |
6 |
7 | class LazyFrames(object):
8 | """
9 | From OpenAI Baseline.
10 | https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
11 |
12 | This class provides a solution to optimize the use of memory when
13 | concatenating different frames, e.g. Atari frames in DQN. The frames are
14 | individually stored in a list and, when numpy arrays containing them are
15 | created, the reference to each frame is used instead of a copy.
16 |
17 | """
18 | def __init__(self, frames, history_length):
19 | self._frames = frames
20 |
21 | assert len(self._frames) == history_length
22 |
23 | def __array__(self, dtype=None):
24 | out = np.array(self._frames)
25 | if dtype is not None:
26 | out = out.astype(dtype)
27 |
28 | return out
29 |
30 | def copy(self):
31 | return self
32 |
33 | @property
34 | def shape(self):
35 | return (len(self._frames),) + self._frames[0].shape
36 |
37 |
38 | def preprocess_frame(obs, img_size):
39 | """
40 | Convert a frame from RGB to grayscale and resize it.
41 |
42 | Args:
43 | obs (np.ndarray): array representing an rgb frame;
44 | img_size (tuple): target size for images.
45 |
46 | Returns:
47 | The transformed frame as an 8-bit integer array.
48 |
49 | """
50 | image = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
51 | image = cv2.resize(image, img_size, interpolation=cv2.INTER_LINEAR)
52 |
53 | return np.array(image, dtype=np.uint8)
54 |
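Usage note: a minimal sketch combining the two utilities above (the frame and target sizes are illustrative, Atari-like preprocessing):

import numpy as np

from mushroom_rl.utils.frames import LazyFrames, preprocess_frame

# Fabricated 210x160 RGB frame
rgb = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
gray = preprocess_frame(rgb, (84, 84))              # grayscale, resized, uint8

history = LazyFrames([gray] * 4, history_length=4)  # stack of the last 4 frames
print(history.shape)                                # (4, 84, 84)
print(np.array(history, dtype=np.float32).shape)    # materialized only when needed
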
--------------------------------------------------------------------------------
/mushroom_rl/utils/isaac_sim/__init__.py:
--------------------------------------------------------------------------------
1 | from .observation_helper import ObservationHelper, ObservationType
2 | from .collision_helper import CollisionHelper
3 | from .action_helper import ActionType
--------------------------------------------------------------------------------
/mushroom_rl/utils/isaac_sim/action_helper.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | class ActionType(Enum):
4 | EFFORT = "joint_efforts"
5 | POSITION = "joint_positions"
6 | VELOCITY = "joint_velocities"
--------------------------------------------------------------------------------
/mushroom_rl/utils/isaac_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def convert_task_observation(observation):
5 | obs_t = observation
6 | for _ in range(5):
7 | if torch.is_tensor(obs_t):
8 | break
9 | obs_t = obs_t[list(obs_t.keys())[0]]
10 | return obs_t
--------------------------------------------------------------------------------
/mushroom_rl/utils/minibatches.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def minibatch_number(size, batch_size):
5 | """
6 | Function to retrieve the number of batches, given a batch size.
7 |
8 | Args:
9 | size (int): size of the dataset;
10 | batch_size (int): size of the batches.
11 |
12 | Returns:
13 | The number of minibatches in the dataset.
14 |
15 | """
16 | return int(np.ceil(size / batch_size))
17 |
18 |
19 | def minibatch_generator(batch_size, *dataset):
20 | """
21 | Generator that creates a minibatch from the full dataset.
22 |
23 | Args:
24 | batch_size (int): the maximum size of each minibatch;
25 | dataset: the dataset to be split.
26 |
27 | Returns:
28 | The current minibatch.
29 |
30 | """
31 | size = len(dataset[0])
32 | num_batches = minibatch_number(size, batch_size)
33 | indexes = np.arange(0, size, 1)
34 | np.random.shuffle(indexes)
35 | batches = [(i * batch_size, min(size, (i + 1) * batch_size))
36 | for i in range(0, num_batches)]
37 |
38 | for (batch_start, batch_end) in batches:
39 | batch = []
40 | for i in range(len(dataset)):
41 | batch.append(dataset[i][indexes[batch_start:batch_end]])
42 | yield batch
43 |
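Usage note: a minimal sketch of minibatch_generator; it shuffles the indexes once and then yields aligned slices of every array passed as the dataset (shapes are illustrative):

import numpy as np

from mushroom_rl.utils.minibatches import minibatch_generator

states = np.random.randn(10, 4)
actions = np.random.randn(10, 2)

for s_batch, a_batch in minibatch_generator(3, states, actions):
    print(s_batch.shape, a_batch.shape)   # (3, 4) (3, 2), with a final (1, 4) (1, 2) batch
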
--------------------------------------------------------------------------------
/mushroom_rl/utils/mujoco/__init__.py:
--------------------------------------------------------------------------------
1 | from .viewer import MujocoViewer
2 | from .observation_helper import ObservationHelper, ObservationType
3 | from .kinematics import forward_kinematics
4 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/mujoco/kinematics.py:
--------------------------------------------------------------------------------
1 | import mujoco
2 |
3 |
4 | def forward_kinematics(mj_model, mj_data, q, body_name):
5 | """
6 | Compute the forward kinematics of the robot.
7 |
8 | Args:
9 | mj_model (mujoco.MjModel): mujoco MjModel of the robot-only model
10 | mj_data (mujoco.MjData): mujoco MjData object generated from the model
11 | q (np.array): joint configuration for which the forward kinematics are computed
12 | body_name (str): name of the body for which the fk is computed
13 |
14 | Returns (np.array(3), np.array(3, 3)):
15 | Position and Orientation of the body with the name body_name
16 | """
17 |
18 | mj_data.qpos[:len(q)] = q
19 | mujoco.mj_fwdPosition(mj_model, mj_data)
20 | return mj_data.body(body_name).xpos.copy(), mj_data.body(body_name).xmat.reshape(3, 3).copy()
21 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats as st
3 |
4 |
5 | def get_mean_and_confidence(data):
6 | """
7 | Compute the mean and the 95% confidence interval of the given data.
8 |
9 | Args:
10 | data (np.ndarray): Array of experiment data of shape (n_runs, n_epochs).
11 |
12 | Returns:
13 | The mean of the dataset at each epoch along with the confidence interval.
14 |
15 | """
16 | mean = np.mean(data, axis=0)
17 | se = st.sem(data, axis=0)
18 | n = len(data)
19 | _, interval = st.t.interval(0.95, n-1, scale=se)
20 | return mean, interval
21 |
22 |
23 | def plot_mean_conf(data, ax, color='blue', line='-', facecolor=None, alpha=0.4, label=None):
24 | """
25 | Method to plot mean and confidence interval for data on matplotlib axes.
26 |
27 | Args:
28 | data (np.ndarray): Array of experiment data of shape (n_runs, n_epochs);
29 | ax (plt.Axes): matplotlib axes where to create the curve;
30 | color (str, 'blue'): matplotlib color identifier for the mean curve;
31 | line (str, '-'): matplotlib line type to be used for the mean curve;
32 | facecolor (str, None): matplotlib color identifier for the confidence interval;
33 | alpha (float, 0.4): transparency of the confidence interval;
34 | label (str, None): legend label for the plotted curve.
35 |
36 |
37 | """
38 | facecolor = color if facecolor is None else facecolor
39 |
40 | mean, conf = get_mean_and_confidence(np.array(data))
41 | upper_bound = mean + conf
42 | lower_bound = mean - conf
43 |
44 | ax.plot(mean, color=color, linestyle=line, label=label)
45 | ax.fill_between(np.arange(np.size(mean)), upper_bound, lower_bound, facecolor=facecolor, alpha=alpha)
46 |
47 |
48 |
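Usage note: a minimal sketch of plot_mean_conf on fabricated learning curves (the data is random and purely illustrative):

import numpy as np
import matplotlib.pyplot as plt

from mushroom_rl.utils.plot import plot_mean_conf

# Fake results: 5 independent runs, 50 epochs each
data = np.random.randn(5, 50).cumsum(axis=1)

fig, ax = plt.subplots()
plot_mean_conf(data, ax, color='blue', label='mean return')
ax.legend()
plt.show()
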
--------------------------------------------------------------------------------
/mushroom_rl/utils/plots/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = []
2 |
3 | try:
4 | from .plot_item_buffer import PlotItemBuffer
5 | __all__.append('PlotItemBuffer')
6 |
7 | from .databuffer import DataBuffer
8 | __all__.append('DataBuffer')
9 |
10 | from .window import Window
11 | __all__.append('Window')
12 |
13 | from .common_plots import Actions, LenOfEpisodeTraining, Observations,\
14 | RewardPerEpisode, RewardPerStep
15 |
16 | __all__ += ['Actions', 'LenOfEpisodeTraining', 'Observations',
17 | 'RewardPerEpisode', 'RewardPerStep']
18 |
19 | except ImportError:
20 | pass
21 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/pybullet/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from .observation import PyBulletObservationType
3 | from .index_map import IndexMap
4 | from .viewer import PyBulletViewer
5 | from .joints_helper import JointsHelper
6 | except ImportError:
7 | pass
8 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/pybullet/joints_helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .observation import PyBulletObservationType
4 |
5 |
6 | class JointsHelper(object):
7 | def __init__(self, client, indexer, observation_spec):
8 | self._joint_pos_indexes = list()
9 | self._joint_velocity_indexes = list()
10 | joint_limits_low = list()
11 | joint_limits_high = list()
12 | joint_velocity_limits = list()
13 | for joint_name, obs_type in observation_spec:
14 | joint_idx = indexer.get_index(joint_name, obs_type)
15 | if obs_type == PyBulletObservationType.JOINT_VEL:
16 | self._joint_velocity_indexes.append(joint_idx[0])
17 |
18 | model_id, joint_id = indexer.joint_map[joint_name]
19 | joint_info = client.getJointInfo(model_id, joint_id)
20 | joint_velocity_limits.append(joint_info[11])
21 |
22 | elif obs_type == PyBulletObservationType.JOINT_POS:
23 | self._joint_pos_indexes.append(joint_idx[0])
24 |
25 | model_id, joint_id = indexer.joint_map[joint_name]
26 | joint_info = client.getJointInfo(model_id, joint_id)
27 | joint_limits_low.append(joint_info[8])
28 | joint_limits_high.append(joint_info[9])
29 |
30 | self._joint_limits_low = np.array(joint_limits_low)
31 | self._joint_limits_high = np.array(joint_limits_high)
32 | self._joint_velocity_limits = np.array(joint_velocity_limits)
33 |
34 | def positions(self, state):
35 | return state[self._joint_pos_indexes]
36 |
37 | def velocities(self, state):
38 | return state[self._joint_velocity_indexes]
39 |
40 | def limits(self):
41 | return self._joint_limits_low, self._joint_limits_high
42 |
43 | def velocity_limits(self):
44 | return self._joint_velocity_limits
45 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/pybullet/observation.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class PyBulletObservationType(Enum):
5 | """
6 | An enum indicating the type of data that should be added to the observation
7 | of the environment; it can be body, joint, or link positions and velocities, or contact flags.
8 |
9 | """
10 | __order__ = "BODY_POS BODY_LIN_VEL BODY_ANG_VEL JOINT_POS JOINT_VEL LINK_POS LINK_LIN_VEL LINK_ANG_VEL CONTACT_FLAG"
11 | BODY_POS = 0
12 | BODY_LIN_VEL = 1
13 | BODY_ANG_VEL = 2
14 | JOINT_POS = 3
15 | JOINT_VEL = 4
16 | LINK_POS = 5
17 | LINK_LIN_VEL = 6
18 | LINK_ANG_VEL = 7
19 | CONTACT_FLAG = 8
--------------------------------------------------------------------------------
/mushroom_rl/utils/pybullet/viewer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pybullet
3 | from mushroom_rl.utils.viewer import ImageViewer
4 |
5 |
6 | class PyBulletViewer(ImageViewer):
7 | def __init__(self, client, dt, size=(500, 500), distance=4, origin=(0, 0, 1), angles=(0, -45, 60),
8 | fov=60, aspect=1, near_val=0.01, far_val=100):
9 | self._client = client
10 | self._size = size
11 | self._distance = distance
12 | self._origin = origin
13 | self._angles = angles
14 | self._fov = fov
15 | self._aspect = aspect
16 | self._near_val = near_val
17 | self._far_val = far_val
18 | super().__init__(size, dt)
19 |
20 | def display(self):
21 | img = self._get_image()
22 | super().display(img)
23 |
24 | return img
25 |
26 | def _get_image(self):
27 | view_matrix = self._client.computeViewMatrixFromYawPitchRoll(cameraTargetPosition=self._origin,
28 | distance=self._distance,
29 | roll=self._angles[0],
30 | pitch=self._angles[1],
31 | yaw=self._angles[2],
32 | upAxisIndex=2)
33 | proj_matrix = self._client.computeProjectionMatrixFOV(fov=self._fov, aspect=self._aspect,
34 | nearVal=self._near_val, farVal=self._far_val)
35 | (_, _, px, _, _) = self._client.getCameraImage(width=self._size[0],
36 | height=self._size[1],
37 | viewMatrix=view_matrix,
38 | projectionMatrix=proj_matrix,
39 | renderer=pybullet.ER_BULLET_HARDWARE_OPENGL)
40 |
41 | rgb_array = np.reshape(np.array(px), (self._size[0], self._size[1], -1))
42 | rgb_array = rgb_array[:, :, :3]
43 | return rgb_array
44 |
--------------------------------------------------------------------------------
/mushroom_rl/utils/quaternions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.utils.angles import mat_to_euler, euler_to_quat
4 |
5 |
6 | def normalize_quaternion(q):
7 | norm = np.linalg.norm(q)
8 | return q / norm
9 |
10 |
11 | def quaternion_distance(q1, q2):
12 | q1 = normalize_quaternion(q1)
13 | q2 = normalize_quaternion(q2)
14 |
15 | cos_half_angle = np.abs(np.dot(q1, q2))
16 |
17 | theta = 2 * np.arccos(cos_half_angle)
18 | return theta / 2
19 |
20 |
21 | def mat_to_quat(mat):
22 | euler = mat_to_euler(mat)
23 | quat = euler_to_quat(euler)
24 | return quat
25 |
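Usage note: a minimal sketch of quaternion_distance; note that, as implemented above, it returns half of the rotation angle between the two orientations:

import numpy as np

from mushroom_rl.utils.quaternions import normalize_quaternion, quaternion_distance

q_identity = np.array([1., 0., 0., 0.])
q_z90 = normalize_quaternion(np.array([np.cos(np.pi / 4), 0., 0., np.sin(np.pi / 4)]))

print(quaternion_distance(q_identity, q_z90))   # ~pi/4, i.e. half of the 90-degree angle
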
--------------------------------------------------------------------------------
/mushroom_rl/utils/record.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import datetime
3 | from pathlib import Path
4 |
5 |
6 | class VideoRecorder(object):
7 | """
8 | Simple video recorder that creates a video from a stream of images.
9 |
10 | """
11 |
12 | def __init__(self, path="./mushroom_rl_recordings", tag=None, video_name=None, fps=60):
13 | """
14 | Constructor.
15 |
16 | Args:
17 | path: Path at which videos will be stored.
18 | tag: Name of the directory at path in which the video will be stored. If None, a timestamp will be created.
19 | video_name: Name of the video without extension. Default is "recording".
20 | fps: Frame rate of the video.
21 | """
22 |
23 | if tag is None:
24 | date_time = datetime.datetime.now()
25 | tag = date_time.strftime("%d-%m-%Y_%H-%M-%S")
26 |
27 | self._path = Path(path)
28 | self._path = self._path / tag
29 |
30 | self._video_name = video_name if video_name else "recording"
31 | self._counter = 0
32 |
33 | self._fps = fps
34 |
35 | self._video_writer = None
36 |
37 | def __call__(self, frame):
38 | """
39 | Args:
40 | frame (np.ndarray): Frame to be added to the video (H, W, RGB)
41 | """
42 | assert frame is not None
43 |
44 | if self._video_writer is None:
45 | height, width = frame.shape[:2]
46 | self._create_video_writer(height, width)
47 |
48 | self._video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
49 |
50 | def _create_video_writer(self, height, width):
51 |
52 | name = self._video_name
53 | if self._counter > 0:
54 | name += f"-{self._counter}.mp4"
55 | else:
56 | name += ".mp4"
57 |
58 | self._path.mkdir(parents=True, exist_ok=True)
59 |
60 | path = self._path / name
61 |
62 | self._video_writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc('m', 'p', '4', 'v'),
63 | self._fps, (width, height))
64 |
65 | def stop(self):
66 | cv2.destroyAllWindows()
67 | self._video_writer.release()
68 | self._video_writer = None
69 | self._counter += 1
70 |
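Usage note: a minimal sketch feeding fabricated frames to the recorder (path, tag, frame size and fps are illustrative):

import numpy as np

from mushroom_rl.utils.record import VideoRecorder

recorder = VideoRecorder(path='./mushroom_rl_recordings', tag='demo', fps=30)

for _ in range(60):
    frame = np.random.randint(0, 256, size=(120, 160, 3), dtype=np.uint8)
    recorder(frame)     # frames are expected as (H, W, RGB)

recorder.stop()         # finalizes ./mushroom_rl_recordings/demo/recording.mp4
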
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel", "numpy"]
3 |
4 | [project]
5 | name = "mushroom-rl"
6 | dependencies = [
7 | "numpy",
8 | "scipy",
9 | "scikit-learn",
10 | "matplotlib",
11 | "joblib",
12 | "tqdm",
13 | "pygame",
14 | "opencv-python>=4.7",
15 | "torch"
16 | ]
17 | requires-python = ">=3.8"
18 | authors = [
19 | { name="Carlo D'Eramo", email="carlo.deramo@gmail.com"},
20 | { name = "Davide Tateo", email = "davide@robot-learning.de" }
21 | ]
22 | maintainers = [
23 | { name = "Davide Tateo", email = "davide@robot-learning.de" }
24 | ]
25 | description = "A Python library for Reinforcement Learning experiments."
26 | readme = "README.rst"
27 | license = { file= "LICENSE" }
28 | keywords = ["Reinforcement Learning", "Machine Learning", "Robotics"]
29 | classifiers = [
30 | "Programming Language :: Python :: 3",
31 | "License :: OSI Approved :: MIT License",
32 | "Operating System :: OS Independent",
33 | ]
34 |
35 | dynamic = ["version", "optional-dependencies"]
36 |
37 | [project.urls]
38 | Homepage = "https://github.com/MushroomRL"
39 | Documentation = "https://mushroomrl.readthedocs.io/en/latest/"
40 | Repository = "https://github.com/MushroomRL/mushroom-rl"
41 | Issues = "https://github.com/MushroomRL/mushroom-rl/issues"
42 |
--------------------------------------------------------------------------------
/tests/algorithms/test_dpg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from datetime import datetime
4 | from helper.utils import TestUtils as tu
5 |
6 | from mushroom_rl.core import Agent
7 | from mushroom_rl.algorithms.actor_critic import COPDAC_Q
8 | from mushroom_rl.core import Core
9 | from mushroom_rl.environments import *
10 | from mushroom_rl.features import Features
11 | from mushroom_rl.features.tiles import Tiles
12 | from mushroom_rl.approximators import Regressor
13 | from mushroom_rl.approximators.parametric import LinearApproximator
14 | from mushroom_rl.policy import GaussianPolicy
15 | from mushroom_rl.rl_utils.parameters import Parameter
16 |
17 |
18 | def learn_copdac_q():
19 | n_steps = 50
20 | mdp = InvertedPendulum(horizon=n_steps)
21 | np.random.seed(1)
22 | torch.manual_seed(1)
23 | torch.cuda.manual_seed(1)
24 |
25 | # Agent
26 | n_tilings = 1
27 | alpha_theta = Parameter(5e-3 / n_tilings)
28 | alpha_omega = Parameter(0.5 / n_tilings)
29 | alpha_v = Parameter(0.5 / n_tilings)
30 | tilings = Tiles.generate(n_tilings, [2, 2],
31 | mdp.info.observation_space.low,
32 | mdp.info.observation_space.high + 1e-3)
33 |
34 | phi = Features(tilings=tilings)
35 |
36 | input_shape = (phi.size,)
37 |
38 | mu = Regressor(LinearApproximator, input_shape=input_shape, output_shape=mdp.info.action_space.shape, phi=phi)
39 |
40 | sigma = 1e-1 * np.eye(1)
41 | policy = GaussianPolicy(mu, sigma)
42 |
43 | agent = COPDAC_Q(mdp.info, policy, mu, alpha_theta, alpha_omega, alpha_v, value_function_features=phi)
44 |
45 | # Train
46 | core = Core(agent, mdp)
47 |
48 | core.learn(n_episodes=2, n_episodes_per_fit=1)
49 |
50 | return agent
51 |
52 |
53 | def test_copdac_q():
54 | policy = learn_copdac_q().policy
55 | w = policy.get_weights()
56 | w_test = np.array([[0.0, -6.62180045e-07, 0.0, -4.23972882e-02]])
57 |
58 | assert np.allclose(w, w_test)
59 |
60 |
61 | def test_copdac_q_save(tmpdir):
62 | agent_path = tmpdir / 'agent_{}'.format(datetime.now().strftime("%H%M%S%f"))
63 |
64 | agent_save = learn_copdac_q()
65 |
66 | agent_save.save(agent_path)
67 | agent_load = Agent.load(agent_path)
68 |
69 | for att, method in vars(agent_save).items():
70 | save_attr = getattr(agent_save, att)
71 | load_attr = getattr(agent_load, att)
72 |
73 | tu.assert_eq(save_attr, load_attr)
74 |
--------------------------------------------------------------------------------
/tests/algorithms/test_lspi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from datetime import datetime
4 | from helper.utils import TestUtils as tu
5 |
6 | from mushroom_rl.core import Agent
7 | from mushroom_rl.algorithms.value import LSPI
8 | from mushroom_rl.core import Core
9 | from mushroom_rl.environments import *
10 | from mushroom_rl.features import Features
11 | from mushroom_rl.features.basis import PolynomialBasis
12 | from mushroom_rl.policy import EpsGreedy
13 | from mushroom_rl.rl_utils.parameters import Parameter
14 |
15 |
16 | def learn_lspi():
17 | np.random.seed(1)
18 |
19 | # MDP
20 | mdp = CartPole()
21 |
22 | # Policy
23 | epsilon = Parameter(value=1.)
24 | pi = EpsGreedy(epsilon=epsilon)
25 |
26 | # Agent
27 | basis = [PolynomialBasis()]
28 | features = Features(basis_list=basis)
29 |
30 | fit_params = dict()
31 | approximator_params = dict(input_shape=(features.size,),
32 | output_shape=(mdp.info.action_space.n,),
33 | n_actions=mdp.info.action_space.n,
34 | phi=features)
35 | agent = LSPI(mdp.info, pi, approximator_params=approximator_params, fit_params=fit_params)
36 |
37 | # Algorithm
38 | core = Core(agent, mdp)
39 |
40 | # Train
41 | core.learn(n_episodes=10, n_episodes_per_fit=10)
42 |
43 | return agent
44 |
45 |
46 | def test_lspi():
47 |
48 | w = learn_lspi().approximator.get_weights()
49 | w_test = np.array([-1.67115903, -1.43755615, -1.67115903])
50 |
51 | assert np.allclose(w, w_test)
52 |
53 |
54 | def test_lspi_save(tmpdir):
55 | agent_path = tmpdir / 'agent_{}'.format(datetime.now().strftime("%H%M%S%f"))
56 |
57 | agent_save = learn_lspi()
58 |
59 | agent_save.save(agent_path)
60 | agent_load = Agent.load(agent_path)
61 |
62 | for att, method in vars(agent_save).items():
63 | save_attr = getattr(agent_save, att)
64 | load_attr = getattr(agent_load, att)
65 |
66 | tu.assert_eq(save_attr, load_attr)
67 |
--------------------------------------------------------------------------------
/tests/core/test_core.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.core import Agent, Core
4 | from mushroom_rl.environments import Atari
5 |
6 | from mushroom_rl.policy import Policy
7 |
8 |
9 | class RandomDiscretePolicy(Policy):
10 | def __init__(self, n):
11 | super().__init__()
12 | self._n = n
13 |
14 | def draw_action(self, state, policy_state=None):
15 | return [np.random.randint(self._n)], None
16 |
17 |
18 | class DummyAgent(Agent):
19 | def __init__(self, mdp_info):
20 | policy = RandomDiscretePolicy(mdp_info.action_space.n)
21 | super().__init__(mdp_info, policy)
22 |
23 | def fit(self, dataset):
24 | pass
25 |
26 |
27 | def test_core():
28 | mdp = Atari(name='ALE/Breakout-v5', repeat_action_probability=0.0)
29 |
30 | agent = DummyAgent(mdp.info)
31 |
32 | core = Core(agent, mdp)
33 |
34 | np.random.seed(2)
35 | mdp.seed(2)
36 |
37 | core.learn(n_steps=100, n_steps_per_fit=1)
38 |
39 | dataset = core.evaluate(n_steps=20)
40 |
41 | assert 'lives' in dataset.info
42 | assert 'episode_frame_number' in dataset.info
43 | assert 'frame_number' in dataset.info
44 |
45 | info_lives = np.array(dataset.info['lives'])
46 |
47 | print(info_lives)
48 | lives_gt = np.array([5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.])
49 | assert len(info_lives) == 20
50 | assert np.all(info_lives == lives_gt)
51 | assert len(dataset) == 20
52 |
53 |
54 |
--------------------------------------------------------------------------------
/tests/core/test_logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mushroom_rl.core import Logger
3 |
4 |
5 | def test_logger(tmpdir):
6 | logger_1 = Logger('test', seed=1, results_dir=tmpdir)
7 | logger_2 = Logger('test', seed=2, results_dir=tmpdir)
8 |
9 | for i in range(3):
10 | logger_1.log_numpy(a=i, b=2*i+1)
11 | logger_2.log_numpy(a=2*i+1, b=i)
12 |
13 | a_1 = np.load(str(tmpdir / 'test' / 'a-1.npy'))
14 | a_2 = np.load(str(tmpdir / 'test' / 'a-2.npy'))
15 | b_1 = np.load(str(tmpdir / 'test' / 'b-1.npy'))
16 | b_2 = np.load(str(tmpdir / 'test' / 'b-2.npy'))
17 |
18 | assert np.array_equal(a_1, np.arange(3))
19 | assert np.array_equal(b_2, np.arange(3))
20 | assert np.array_equal(a_1, b_2)
21 | assert np.array_equal(b_1, a_2)
22 |
23 | logger_1_bis = Logger('test', append=True, seed=1, results_dir=tmpdir)
24 |
25 | logger_1_bis.log_numpy(a=3, b=7)
26 | a_1 = np.load(str(tmpdir / 'test' / 'a-1.npy'))
27 | b_2 = np.load(str(tmpdir / 'test' / 'b-2.npy'))
28 |
29 | assert np.array_equal(a_1, np.arange(4))
30 | assert np.array_equal(b_2, np.arange(3))
31 |
--------------------------------------------------------------------------------
/tests/core/test_serialization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from mushroom_rl.core import Serializable
5 | from mushroom_rl.utils import TorchUtils
6 |
7 |
8 | class DummyClass(Serializable):
9 | def __init__(self):
10 | self.torch_tensor = torch.randn(2, 2).to(TorchUtils.get_device())
11 | self.numpy_array = np.random.randn(3, 4)
12 | self.scalar = 1
13 | self.dictionary = {'a': 'test', 'b': 5, 'd': (2, 3)}
14 | self.not_saved = 'test2'
15 |
16 | self._add_save_attr(
17 | torch_tensor='torch',
18 | numpy_array='numpy',
19 | scalar='primitive',
20 | dictionary='pickle',
21 | not_saved='none'
22 | )
23 |
24 | def __eq__(self, other):
25 | f1 = torch.equal(self.torch_tensor.cpu(), other.torch_tensor.cpu())
26 | f2 = np.array_equal(self.numpy_array, other.numpy_array)
27 | f3 = self.scalar == other.scalar
28 | f4 = self.dictionary == other.dictionary
29 |
30 | return f1 and f2 and f3 and f4
31 |
32 |
33 | def test_serialization(tmpdir):
34 | TorchUtils.set_default_device('cpu')
35 |
36 | a = DummyClass()
37 | a.save(tmpdir / 'test.msh')
38 |
39 | b = Serializable.load(tmpdir / 'test.msh')
40 |
41 | assert a == b
42 | assert b.not_saved is None
43 |
44 |
45 | def test_serialization_cuda_cpu(tmpdir):
46 | if torch.cuda.is_available():
47 | TorchUtils.set_default_device('cuda')
48 |
49 | a = DummyClass()
50 | a.save(tmpdir / 'test.msh')
51 |
52 | TorchUtils.set_default_device('cpu')
53 |
54 | assert a.torch_tensor.device.type == 'cuda'
55 |
56 | b = Serializable.load(tmpdir / 'test.msh')
57 |
58 | assert b.torch_tensor.device.type == 'cpu'
59 |
60 | assert a == b
61 |
62 |
63 | def test_serialization_cpu_cuda(tmpdir):
64 | if torch.cuda.is_available():
65 | TorchUtils.set_default_device('cpu')
66 |
67 | a = DummyClass()
68 | a.save(tmpdir / 'test.msh')
69 |
70 | TorchUtils.set_default_device('cuda')
71 |
72 | assert a.torch_tensor.device.type == 'cpu'
73 |
74 | b = Serializable.load(tmpdir / 'test.msh')
75 |
76 | assert b.torch_tensor.device.type == 'cuda'
77 |
78 | assert a == b
79 |
80 | TorchUtils.set_default_device('cpu')
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/tests/distributions/test_distribution_interface.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.distributions import Distribution
2 |
3 |
4 | def abstract_method_tester(f, *args):
5 | try:
6 | f(*args)
7 | except NotImplementedError:
8 | pass
9 | else:
10 | assert False
11 |
12 |
13 | def test_distribution_interface():
14 | tmp = Distribution()
15 |
16 | abstract_method_tester(tmp.sample)
17 | abstract_method_tester(tmp.log_pdf, None)
18 | abstract_method_tester(tmp.__call__, None)
19 | abstract_method_tester(tmp.entropy)
20 | abstract_method_tester(tmp.mle, None)
21 | abstract_method_tester(tmp.diff_log, None)
22 | abstract_method_tester(tmp.diff, None)
23 |
24 | abstract_method_tester(tmp.get_parameters)
25 | abstract_method_tester(tmp.set_parameters, None)
26 |
27 | try:
28 | tmp.parameters_size
29 | except NotImplementedError:
30 | pass
31 | else:
32 | assert False
33 |
--------------------------------------------------------------------------------
/tests/environments/grid.txt:
--------------------------------------------------------------------------------
1 | #####
2 | #S.*#
3 | #...#
4 | ##..#
5 | #G..#
6 | #####
7 |
--------------------------------------------------------------------------------
/tests/environments/mujoco_envs/air_hockey_defend_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/mujoco_envs/air_hockey_defend_data.npy
--------------------------------------------------------------------------------
/tests/environments/mujoco_envs/air_hockey_hit_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/mujoco_envs/air_hockey_hit_data.npy
--------------------------------------------------------------------------------
/tests/environments/mujoco_envs/air_hockey_prepare_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/mujoco_envs/air_hockey_prepare_data.npy
--------------------------------------------------------------------------------
/tests/environments/mujoco_envs/air_hockey_repel_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/mujoco_envs/air_hockey_repel_data.npy
--------------------------------------------------------------------------------
/tests/environments/mujoco_envs/test_ball_in_a_cup.py:
--------------------------------------------------------------------------------
1 | try:
2 | from mushroom_rl.environments.mujoco_envs import BallInACup
3 | import numpy as np
4 |
5 | def linear_movement(start, end, n_steps, i):
6 | t = np.minimum(1., float(i) / float(n_steps))
7 | return start + (end - start) * t
8 |
9 |
10 | def test_ball_in_a_cup():
11 | env = BallInACup()
12 |
13 | des_pos = np.array([0.0, -0.58760536, 0.0, 1.36004913, 0.0, -0.32072943, -1.57])
14 | p_gains = np.array([200, 300, 100, 100, 10, 10, 2.5])/5
15 | d_gains = np.array([7, 15, 5, 2.5, 0.3, 0.3, 0.05])/10
16 |
17 | obs_0, _ = env.reset()
18 |
19 | for _ in [1,2]:
20 | obs, _ = env.reset()
21 |
22 | assert np.array_equal(obs, obs_0)
23 | done = False
24 | i = 0
25 | while not done:
26 | q_cmd = linear_movement(env.init_robot_pos, des_pos, 100, i)
27 | q_curr = obs[0:14:2]
28 | qdot_cur = obs[1:14:2]
29 | pos_err = q_cmd - q_curr
30 |
31 | a = env._data.qfrc_bias[:7] + p_gains * pos_err - d_gains * qdot_cur
32 | #a = np.zeros(7)
33 |
34 | # Check the observations
35 | assert np.allclose(obs[0:14:2], env._data.qpos[0:7])
36 | assert np.allclose(obs[1:14:2], env._data.qvel[0:7])
37 | # assert np.allclose(obs[14:17], env._data.xpos[40])
38 | # assert np.allclose(obs[17:], env._data.cvel[40])
39 | obs, reward, done, info = env.step(a)
40 |
41 | i += 1
42 |
43 | except ImportError:
44 | pass
45 |
--------------------------------------------------------------------------------
/tests/environments/pybullet_envs/air_hockey_defend_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/pybullet_envs/air_hockey_defend_data.npy
--------------------------------------------------------------------------------
/tests/environments/pybullet_envs/air_hockey_hit_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/pybullet_envs/air_hockey_hit_data.npy
--------------------------------------------------------------------------------
/tests/environments/pybullet_envs/air_hockey_prepare_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/pybullet_envs/air_hockey_prepare_data.npy
--------------------------------------------------------------------------------
/tests/environments/pybullet_envs/air_hockey_repel_data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/pybullet_envs/air_hockey_repel_data.npy
--------------------------------------------------------------------------------
/tests/environments/taxi.txt:
--------------------------------------------------------------------------------
1 | S#F.#.G
2 | .#..#..
3 | .......
4 | ##...##
5 | ......F
6 | F.....#
7 |
--------------------------------------------------------------------------------
/tests/environments/test_atari_1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MushroomRL/mushroom-rl/d9a8a99287549f06065b7c75e5a92af0ffb4b5c5/tests/environments/test_atari_1.npy
--------------------------------------------------------------------------------
/tests/environments/test_mujoco.py:
--------------------------------------------------------------------------------
1 | try:
2 | from mushroom_rl.environments.dm_control_env import DMControl
3 | import numpy as np
4 |
5 | def test_dm_control():
6 | np.random.seed(1)
7 | mdp = DMControl('hopper', 'hop', 1000, .99, task_kwargs={'random': 1})
8 | mdp.reset()
9 | for i in range(10):
10 | ns, r, ab, _ = mdp.step(
11 | np.random.rand(mdp.info.action_space.shape[0]))
12 | ns_test = np.array([-0.25868173, -2.24011367, 0.45346572, -0.55528368,
13 | 0.51603826, -0.21782316, -0.58708578, -2.04541986,
14 | -17.24931206, 5.42227781, 21.39084468, -2.42071806,
15 | 3.85448837, 0., 0.])
16 |
17 | assert np.allclose(ns, ns_test)
18 | except ImportError:
19 | pass
20 |
--------------------------------------------------------------------------------
/tests/policy/test_deterministic_policy.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.policy import DeterministicPolicy
2 | from mushroom_rl.approximators.regressor import Regressor
3 | from mushroom_rl.approximators.parametric import LinearApproximator
4 | from mushroom_rl.utils.numerical_gradient import numerical_diff_policy
5 |
6 | import numpy as np
7 |
8 |
9 | def test_deterministic_policy():
10 | np.random.seed(88)
11 |
12 | n_dims = 5
13 |
14 | approximator = Regressor(LinearApproximator,
15 | input_shape=(n_dims,),
16 | output_shape=(2,))
17 |
18 | pi = DeterministicPolicy(approximator)
19 |
20 | w_new = np.random.rand(pi.weights_size)
21 |
22 | w_old = pi.get_weights()
23 | pi.set_weights(w_new)
24 |
25 | assert np.array_equal(w_new, approximator.get_weights())
26 | assert not np.array_equal(w_old, w_new)
27 | assert np.array_equal(w_new, pi.get_weights())
28 |
29 | s_test_1 = np.random.randn(5)
30 | s_test_2 = np.random.randn(5)
31 | a_test = approximator.predict(s_test_1)
32 |
33 | assert pi.get_regressor() == approximator
34 |
35 | assert pi(s_test_1, a_test) == 1
36 | assert pi(s_test_2, a_test) == 0
37 |
38 | a_stored = np.array([-1.86941072, -0.1789696])
39 | action, _ = pi.draw_action(s_test_1)
40 | assert np.allclose(action, a_stored)
41 |
42 |
--------------------------------------------------------------------------------
/tests/policy/test_noise_policy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mushroom_rl.approximators import Regressor
4 | from mushroom_rl.approximators.parametric import TorchApproximator
5 | from mushroom_rl.approximators.parametric.networks import LinearNetwork
6 | from mushroom_rl.policy import OrnsteinUhlenbeckPolicy, ClippedGaussianPolicy
7 |
8 |
9 | def test_ornstein_uhlenbeck_policy():
10 | torch.manual_seed(42)
11 |
12 | mu = Regressor(TorchApproximator, network=LinearNetwork, input_shape=(5,), output_shape=(2,))
13 | pi = OrnsteinUhlenbeckPolicy(mu, sigma=torch.ones(1) * .2, theta=.15, dt=1e-2)
14 |
15 | w = torch.randn(pi.weights_size)
16 | pi.set_weights(w)
17 | assert torch.equal(pi.get_weights(), w)
18 |
19 | state = torch.randn(5)
20 |
21 | policy_state = pi.reset()
22 |
23 | action, policy_state = pi.draw_action(state, policy_state)
24 | action_test = torch.tensor([-0.7055691481, 1.1255935431])
25 | assert torch.allclose(action, action_test)
26 |
27 | policy_state = pi.reset()
28 | action, policy_state = pi.draw_action(state, policy_state)
29 | action_test = torch.tensor([-0.7114595175, 1.1141412258])
30 | assert torch.allclose(action, action_test)
31 |
32 | try:
33 | pi(state, action)
34 | except NotImplementedError:
35 | pass
36 | else:
37 | assert False
38 |
39 |
40 | def test_clipped_gaussian_policy():
41 | torch.manual_seed(1)
42 |
43 | low = -torch.ones(2)
44 | high = torch.ones(2)
45 |
46 | mu = Regressor(TorchApproximator, network=LinearNetwork, input_shape=(5,), output_shape=(2,))
47 | pi = ClippedGaussianPolicy(mu, torch.eye(2), low, high)
48 |
49 | w = torch.randn(pi.weights_size)
50 | pi.set_weights(w)
51 | assert torch.equal(pi.get_weights(), w)
52 |
53 | state = torch.randn(5)
54 |
55 | action, _ = pi.draw_action(state)
56 | action_test = torch.tensor([-1.0, 1.0])
57 | assert torch.allclose(action, action_test)
58 |
59 | action, _ = pi.draw_action(state)
60 | action_test = torch.tensor([0.4926533699, 1.0])
61 | assert torch.allclose(action, action_test)
62 |
63 | try:
64 | pi(state, action)
65 | except NotImplementedError:
66 | pass
67 | else:
68 | assert False
69 |
70 | # TODO Missing test for clipped gaussian!
71 |
72 |
--------------------------------------------------------------------------------
/tests/policy/test_policy_interface.py:
--------------------------------------------------------------------------------
1 | from mushroom_rl.policy import Policy, ParametricPolicy
2 |
3 |
4 | def abstract_method_tester(f, ex, *args):
5 | try:
6 | f(*args)
7 | except ex:
8 | pass
9 | else:
10 | assert False
11 |
12 |
13 | def test_policy_interface():
14 | tmp = Policy()
15 | abstract_method_tester(tmp.__call__, NotImplementedError, None, None, None)
16 | abstract_method_tester(tmp.draw_action, NotImplementedError, None, None)
17 | tmp.reset()
18 |
19 |
20 | def test_parametric_policy():
21 | tmp = ParametricPolicy()
22 | abstract_method_tester(tmp.diff_log, RuntimeError, None, None, None)
23 | abstract_method_tester(tmp.diff, RuntimeError, None, None, None)
24 | abstract_method_tester(tmp.set_weights, NotImplementedError, None)
25 | abstract_method_tester(tmp.get_weights, NotImplementedError)
26 | try:
27 | tmp.weights_size
28 | except NotImplementedError:
29 | pass
30 | else:
31 | assert False
32 |
--------------------------------------------------------------------------------
/tests/solvers/test_car_on_hill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mushroom_rl.environments.car_on_hill import CarOnHill
4 | from mushroom_rl.solvers.car_on_hill import solve_car_on_hill
5 |
6 |
7 | def test_car_on_hill():
8 | mdp = CarOnHill()
9 | mdp._discrete_actions = np.array([-8., 8.])
10 |
11 | states = np.array([[-.5, 0], [0., 0.], [.5, 0.]])
12 | actions = np.array([[0], [1], [0]])
13 | q = solve_car_on_hill(mdp, states, actions, .95)
14 | q_test = np.array([0.5688000922764597, 0.48767497911552954,
15 | 0.5688000922764597])
16 |
17 | assert np.allclose(q, q_test)
18 |
--------------------------------------------------------------------------------
/tests/test_imports.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def test_imports():
4 | import mushroom_rl
5 |
6 | import mushroom_rl.algorithms
7 | import mushroom_rl.algorithms.actor_critic
8 | import mushroom_rl.algorithms.actor_critic.classic_actor_critic
9 | import mushroom_rl.algorithms.actor_critic.deep_actor_critic
10 | import mushroom_rl.algorithms.policy_search
11 | import mushroom_rl.algorithms.policy_search.black_box_optimization
12 | import mushroom_rl.algorithms.policy_search.policy_gradient
13 | import mushroom_rl.algorithms.value
14 | import mushroom_rl.algorithms.value.batch_td
15 | import mushroom_rl.algorithms.value.td
16 | import mushroom_rl.algorithms.value.dqn
17 |
18 | import mushroom_rl.approximators
19 | import mushroom_rl.approximators._implementations
20 | import mushroom_rl.approximators.parametric
21 |
22 | import mushroom_rl.core
23 |
24 | import mushroom_rl.distributions
25 |
26 | import mushroom_rl.environments
27 | import mushroom_rl.environments.generators
28 |
29 | try:
30 | import mujoco
31 | except ImportError:
32 | pass
33 | else:
34 | import mushroom_rl.environments.mujoco_envs
35 |
36 | import mushroom_rl.features
37 | import mushroom_rl.features._implementations
38 | import mushroom_rl.features.basis
39 | import mushroom_rl.features.tensors
40 | import mushroom_rl.features.tiles
41 |
42 | import mushroom_rl.policy
43 |
44 | import mushroom_rl.solvers
45 |
46 | import mushroom_rl.utils
47 |
48 |
--------------------------------------------------------------------------------