├── .github ├── GITHUB_ACTIONS.md ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── config.yml └── workflows │ ├── pre-commit.yml │ ├── python-publish-manual.yml │ ├── tests-jax.yml │ └── tests-torch.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── README.md ├── make.bat ├── requirements.txt └── source │ ├── 404.rst │ ├── _static │ ├── css │ │ ├── s5defs-roles.css │ │ └── skrl.css │ ├── data │ │ ├── 404-dark.svg │ │ ├── 404-light.svg │ │ ├── favicon.ico │ │ ├── logo-dark-mode.png │ │ ├── logo-jax.svg │ │ ├── logo-light-mode.png │ │ ├── logo-torch.svg │ │ ├── skrl-up-transparent.png │ │ └── skrl-up.png │ └── imgs │ │ ├── data_tensorboard.jpg │ │ ├── example_bidexhands.png │ │ ├── example_deepmind.png │ │ ├── example_gym.png │ │ ├── example_isaacgym.png │ │ ├── example_isaaclab.png │ │ ├── example_omniverse_isaacgym.png │ │ ├── example_parallel.jpg │ │ ├── example_robosuite.png │ │ ├── example_shimmy.png │ │ ├── manual_trainer-dark.svg │ │ ├── manual_trainer-light.svg │ │ ├── model_categorical-dark.svg │ │ ├── model_categorical-light.svg │ │ ├── model_categorical_cnn-dark.svg │ │ ├── model_categorical_cnn-light.svg │ │ ├── model_categorical_mlp-dark.svg │ │ ├── model_categorical_mlp-light.svg │ │ ├── model_categorical_rnn-dark.svg │ │ ├── model_categorical_rnn-light.svg │ │ ├── model_deterministic-dark.svg │ │ ├── model_deterministic-light.svg │ │ ├── model_deterministic_cnn-dark.svg │ │ ├── model_deterministic_cnn-light.svg │ │ ├── model_deterministic_mlp-dark.svg │ │ ├── model_deterministic_mlp-light.svg │ │ ├── model_deterministic_rnn-dark.svg │ │ ├── model_deterministic_rnn-light.svg │ │ ├── model_gaussian-dark.svg │ │ ├── model_gaussian-light.svg │ │ ├── model_gaussian_cnn-dark.svg │ │ ├── model_gaussian_cnn-light.svg │ │ ├── model_gaussian_mlp-dark.svg │ │ ├── model_gaussian_mlp-light.svg │ │ ├── model_gaussian_rnn-dark.svg │ │ ├── 
model_gaussian_rnn-light.svg │ │ ├── model_multicategorical-dark.svg │ │ ├── model_multicategorical-light.svg │ │ ├── model_multivariate_gaussian-dark.svg │ │ ├── model_multivariate_gaussian-light.svg │ │ ├── multi_agent_wrapping-dark.svg │ │ ├── multi_agent_wrapping-light.svg │ │ ├── noise_gaussian.png │ │ ├── noise_ornstein_uhlenbeck.png │ │ ├── parallel_trainer-dark.svg │ │ ├── parallel_trainer-light.svg │ │ ├── rl_schema-dark.svg │ │ ├── rl_schema-light.svg │ │ ├── sequential_trainer-dark.svg │ │ ├── sequential_trainer-light.svg │ │ ├── utils_tensorboard_file_iterator.svg │ │ ├── wrapping-dark.svg │ │ └── wrapping-light.svg │ ├── api │ ├── agents.rst │ ├── agents │ │ ├── a2c.rst │ │ ├── amp.rst │ │ ├── cem.rst │ │ ├── ddpg.rst │ │ ├── ddqn.rst │ │ ├── dqn.rst │ │ ├── ppo.rst │ │ ├── q_learning.rst │ │ ├── rpo.rst │ │ ├── sac.rst │ │ ├── sarsa.rst │ │ ├── td3.rst │ │ └── trpo.rst │ ├── config │ │ └── frameworks.rst │ ├── envs.rst │ ├── envs │ │ ├── isaac_gym.rst │ │ ├── isaaclab.rst │ │ ├── multi_agents_wrapping.rst │ │ ├── omniverse_isaac_gym.rst │ │ └── wrapping.rst │ ├── memories.rst │ ├── memories │ │ └── random.rst │ ├── models.rst │ ├── models │ │ ├── categorical.rst │ │ ├── deterministic.rst │ │ ├── gaussian.rst │ │ ├── multicategorical.rst │ │ ├── multivariate_gaussian.rst │ │ ├── shared_model.rst │ │ └── tabular.rst │ ├── multi_agents.rst │ ├── multi_agents │ │ ├── ippo.rst │ │ └── mappo.rst │ ├── resources.rst │ ├── resources │ │ ├── noises.rst │ │ ├── noises │ │ │ ├── gaussian.rst │ │ │ └── ornstein_uhlenbeck.rst │ │ ├── optimizers.rst │ │ ├── optimizers │ │ │ └── adam.rst │ │ ├── preprocessors.rst │ │ ├── preprocessors │ │ │ └── running_standard_scaler.rst │ │ ├── schedulers.rst │ │ └── schedulers │ │ │ └── kl_adaptive.rst │ ├── trainers.rst │ ├── trainers │ │ ├── manual.rst │ │ ├── parallel.rst │ │ ├── sequential.rst │ │ └── step.rst │ ├── utils.rst │ └── utils │ │ ├── distributed.rst │ │ ├── huggingface.rst │ │ ├── isaacgym_utils.rst │ │ ├── 
model_instantiators.rst │ │ ├── omniverse_isaacgym_utils.rst │ │ ├── postprocessing.rst │ │ ├── runner.rst │ │ ├── seed.rst │ │ └── spaces.rst │ ├── conf.py │ ├── examples │ ├── bidexhands │ │ ├── jax_bidexhands_shadow_hand_over_ippo.py │ │ ├── jax_bidexhands_shadow_hand_over_mappo.py │ │ ├── torch_bidexhands_shadow_hand_over_ippo.py │ │ └── torch_bidexhands_shadow_hand_over_mappo.py │ ├── deepmind │ │ ├── dm_manipulation_stack_sac.py │ │ └── dm_suite_cartpole_swingup_ddpg.py │ ├── gym │ │ ├── jax_gym_cartpole_cem.py │ │ ├── jax_gym_cartpole_dqn.py │ │ ├── jax_gym_cartpole_vector_dqn.py │ │ ├── jax_gym_pendulum_ddpg.py │ │ ├── jax_gym_pendulum_ppo.py │ │ ├── jax_gym_pendulum_sac.py │ │ ├── jax_gym_pendulum_td3.py │ │ ├── jax_gym_pendulum_vector_ddpg.py │ │ ├── torch_gym_cartpole_cem.py │ │ ├── torch_gym_cartpole_dqn.py │ │ ├── torch_gym_cartpole_vector_dqn.py │ │ ├── torch_gym_frozen_lake_q_learning.py │ │ ├── torch_gym_frozen_lake_vector_q_learning.py │ │ ├── torch_gym_pendulum_ddpg.py │ │ ├── torch_gym_pendulum_ppo.py │ │ ├── torch_gym_pendulum_sac.py │ │ ├── torch_gym_pendulum_td3.py │ │ ├── torch_gym_pendulum_trpo.py │ │ ├── torch_gym_pendulum_vector_ddpg.py │ │ ├── torch_gym_pendulumnovel_ddpg.py │ │ ├── torch_gym_pendulumnovel_ddpg_gru.py │ │ ├── torch_gym_pendulumnovel_ddpg_lstm.py │ │ ├── torch_gym_pendulumnovel_ddpg_rnn.py │ │ ├── torch_gym_pendulumnovel_ppo.py │ │ ├── torch_gym_pendulumnovel_ppo_gru.py │ │ ├── torch_gym_pendulumnovel_ppo_lstm.py │ │ ├── torch_gym_pendulumnovel_ppo_rnn.py │ │ ├── torch_gym_pendulumnovel_sac.py │ │ ├── torch_gym_pendulumnovel_sac_gru.py │ │ ├── torch_gym_pendulumnovel_sac_lstm.py │ │ ├── torch_gym_pendulumnovel_sac_rnn.py │ │ ├── torch_gym_pendulumnovel_td3.py │ │ ├── torch_gym_pendulumnovel_td3_gru.py │ │ ├── torch_gym_pendulumnovel_td3_lstm.py │ │ ├── torch_gym_pendulumnovel_td3_rnn.py │ │ ├── torch_gym_pendulumnovel_trpo.py │ │ ├── torch_gym_pendulumnovel_trpo_gru.py │ │ ├── torch_gym_pendulumnovel_trpo_lstm.py │ │ ├── 
torch_gym_pendulumnovel_trpo_rnn.py │ │ ├── torch_gym_taxi_sarsa.py │ │ └── torch_gym_taxi_vector_sarsa.py │ ├── gymnasium │ │ ├── jax_gymnasium_cartpole_cem.py │ │ ├── jax_gymnasium_cartpole_dqn.py │ │ ├── jax_gymnasium_cartpole_vector_dqn.py │ │ ├── jax_gymnasium_pendulum_ddpg.py │ │ ├── jax_gymnasium_pendulum_ppo.py │ │ ├── jax_gymnasium_pendulum_sac.py │ │ ├── jax_gymnasium_pendulum_td3.py │ │ ├── jax_gymnasium_pendulum_vector_ddpg.py │ │ ├── torch_gymnasium_cartpole_cem.py │ │ ├── torch_gymnasium_cartpole_dqn.py │ │ ├── torch_gymnasium_cartpole_vector_dqn.py │ │ ├── torch_gymnasium_frozen_lake_q_learning.py │ │ ├── torch_gymnasium_frozen_lake_vector_q_learning.py │ │ ├── torch_gymnasium_pendulum_ddpg.py │ │ ├── torch_gymnasium_pendulum_ppo.py │ │ ├── torch_gymnasium_pendulum_sac.py │ │ ├── torch_gymnasium_pendulum_td3.py │ │ ├── torch_gymnasium_pendulum_trpo.py │ │ ├── torch_gymnasium_pendulum_vector_ddpg.py │ │ ├── torch_gymnasium_pendulumnovel_ddpg.py │ │ ├── torch_gymnasium_pendulumnovel_ddpg_gru.py │ │ ├── torch_gymnasium_pendulumnovel_ddpg_lstm.py │ │ ├── torch_gymnasium_pendulumnovel_ddpg_rnn.py │ │ ├── torch_gymnasium_pendulumnovel_ppo.py │ │ ├── torch_gymnasium_pendulumnovel_ppo_gru.py │ │ ├── torch_gymnasium_pendulumnovel_ppo_lstm.py │ │ ├── torch_gymnasium_pendulumnovel_ppo_rnn.py │ │ ├── torch_gymnasium_pendulumnovel_sac.py │ │ ├── torch_gymnasium_pendulumnovel_sac_gru.py │ │ ├── torch_gymnasium_pendulumnovel_sac_lstm.py │ │ ├── torch_gymnasium_pendulumnovel_sac_rnn.py │ │ ├── torch_gymnasium_pendulumnovel_td3.py │ │ ├── torch_gymnasium_pendulumnovel_td3_gru.py │ │ ├── torch_gymnasium_pendulumnovel_td3_lstm.py │ │ ├── torch_gymnasium_pendulumnovel_td3_rnn.py │ │ ├── torch_gymnasium_pendulumnovel_trpo.py │ │ ├── torch_gymnasium_pendulumnovel_trpo_gru.py │ │ ├── torch_gymnasium_pendulumnovel_trpo_lstm.py │ │ ├── torch_gymnasium_pendulumnovel_trpo_rnn.py │ │ ├── torch_gymnasium_taxi_sarsa.py │ │ └── torch_gymnasium_taxi_vector_sarsa.py │ ├── isaacgym │ 
│ ├── jax_allegro_hand_ppo.py │ │ ├── jax_ant_ddpg.py │ │ ├── jax_ant_ppo.py │ │ ├── jax_ant_sac.py │ │ ├── jax_ant_td3.py │ │ ├── jax_anymal_ppo.py │ │ ├── jax_anymal_terrain_ppo.py │ │ ├── jax_ball_balance_ppo.py │ │ ├── jax_cartpole_ppo.py │ │ ├── jax_factory_task_nut_bolt_pick_ppo.py │ │ ├── jax_factory_task_nut_bolt_place_ppo.py │ │ ├── jax_factory_task_nut_bolt_screw_ppo.py │ │ ├── jax_franka_cabinet_ppo.py │ │ ├── jax_franka_cube_stack_ppo.py │ │ ├── jax_humanoid_ppo.py │ │ ├── jax_ingenuity_ppo.py │ │ ├── jax_quadcopter_ppo.py │ │ ├── jax_shadow_hand_ppo.py │ │ ├── jax_trifinger_ppo.py │ │ ├── torch_allegro_hand_ppo.py │ │ ├── torch_allegro_kuka_ppo.py │ │ ├── torch_ant_ddpg.py │ │ ├── torch_ant_ddpg_td3_sac_parallel_unshared_memory.py │ │ ├── torch_ant_ddpg_td3_sac_sequential_shared_memory.py │ │ ├── torch_ant_ddpg_td3_sac_sequential_unshared_memory.py │ │ ├── torch_ant_ppo.py │ │ ├── torch_ant_sac.py │ │ ├── torch_ant_td3.py │ │ ├── torch_anymal_ppo.py │ │ ├── torch_anymal_terrain_ppo.py │ │ ├── torch_ball_balance_ppo.py │ │ ├── torch_cartpole_ppo.py │ │ ├── torch_factory_task_nut_bolt_pick_ppo.py │ │ ├── torch_factory_task_nut_bolt_place_ppo.py │ │ ├── torch_factory_task_nut_bolt_screw_ppo.py │ │ ├── torch_franka_cabinet_ppo.py │ │ ├── torch_franka_cube_stack_ppo.py │ │ ├── torch_humanoid_amp.py │ │ ├── torch_humanoid_ppo.py │ │ ├── torch_ingenuity_ppo.py │ │ ├── torch_quadcopter_ppo.py │ │ ├── torch_shadow_hand_ppo.py │ │ ├── torch_trifinger_ppo.py │ │ └── trpo_cartpole.py │ ├── isaaclab │ │ ├── jax_ant_ddpg.py │ │ ├── jax_ant_ppo.py │ │ ├── jax_ant_sac.py │ │ ├── jax_ant_td3.py │ │ ├── jax_cartpole_ppo.py │ │ ├── jax_humanoid_ppo.py │ │ ├── jax_lift_franka_ppo.py │ │ ├── jax_reach_franka_ppo.py │ │ ├── jax_velocity_anymal_c_ppo.py │ │ ├── torch_ant_ddpg.py │ │ ├── torch_ant_ppo.py │ │ ├── torch_ant_sac.py │ │ ├── torch_ant_td3.py │ │ ├── torch_cartpole_ppo.py │ │ ├── torch_humanoid_ppo.py │ │ ├── torch_lift_franka_ppo.py │ │ ├── 
torch_reach_franka_ppo.py │ │ └── torch_velocity_anymal_c_ppo.py │ ├── isaacsim │ │ ├── torch_isaacsim_cartpole_ppo.py │ │ └── torch_isaacsim_jetbot_ppo.py │ ├── omniisaacgym │ │ ├── jax_allegro_hand_ppo.py │ │ ├── jax_ant_ddpg.py │ │ ├── jax_ant_mt_ppo.py │ │ ├── jax_ant_ppo.py │ │ ├── jax_ant_sac.py │ │ ├── jax_ant_td3.py │ │ ├── jax_anymal_ppo.py │ │ ├── jax_anymal_terrain_ppo.py │ │ ├── jax_ball_balance_ppo.py │ │ ├── jax_cartpole_mt_ppo.py │ │ ├── jax_cartpole_ppo.py │ │ ├── jax_crazyflie_ppo.py │ │ ├── jax_factory_task_nut_bolt_pick_ppo.py │ │ ├── jax_franka_cabinet_ppo.py │ │ ├── jax_humanoid_ppo.py │ │ ├── jax_ingenuity_ppo.py │ │ ├── jax_quadcopter_ppo.py │ │ ├── jax_shadow_hand_ppo.py │ │ ├── torch_allegro_hand_ppo.py │ │ ├── torch_ant_ddpg.py │ │ ├── torch_ant_ddpg_td3_sac_parallel_unshared_memory.py │ │ ├── torch_ant_ddpg_td3_sac_sequential_shared_memory.py │ │ ├── torch_ant_ddpg_td3_sac_sequential_unshared_memory.py │ │ ├── torch_ant_mt_ppo.py │ │ ├── torch_ant_ppo.py │ │ ├── torch_ant_sac.py │ │ ├── torch_ant_td3.py │ │ ├── torch_anymal_ppo.py │ │ ├── torch_anymal_terrain_ppo.py │ │ ├── torch_ball_balance_ppo.py │ │ ├── torch_cartpole_mt_ppo.py │ │ ├── torch_cartpole_ppo.py │ │ ├── torch_crazyflie_ppo.py │ │ ├── torch_factory_task_nut_bolt_pick_ppo.py │ │ ├── torch_franka_cabinet_ppo.py │ │ ├── torch_humanoid_ppo.py │ │ ├── torch_ingenuity_ppo.py │ │ ├── torch_quadcopter_ppo.py │ │ └── torch_shadow_hand_ppo.py │ ├── real_world │ │ ├── franka_emika_panda │ │ │ ├── reaching_franka_isaacgym_env.py │ │ │ ├── reaching_franka_isaacgym_skrl_eval.py │ │ │ ├── reaching_franka_isaacgym_skrl_train.py │ │ │ ├── reaching_franka_omniverse_isaacgym_env.py │ │ │ ├── reaching_franka_omniverse_isaacgym_skrl_eval.py │ │ │ ├── reaching_franka_omniverse_isaacgym_skrl_train.py │ │ │ ├── reaching_franka_real_env.py │ │ │ └── reaching_franka_real_skrl_eval.py │ │ └── kuka_lbr_iiwa │ │ │ ├── reaching_iiwa_omniverse_isaacgym_env.py │ │ │ ├── 
reaching_iiwa_omniverse_isaacgym_skrl_eval.py │ │ │ ├── reaching_iiwa_omniverse_isaacgym_skrl_train.py │ │ │ ├── reaching_iiwa_real_env.py │ │ │ ├── reaching_iiwa_real_ros2_env.py │ │ │ ├── reaching_iiwa_real_ros_env.py │ │ │ ├── reaching_iiwa_real_ros_ros2_skrl_eval.py │ │ │ └── reaching_iiwa_real_skrl_eval.py │ ├── robosuite │ │ └── td3_robosuite_two_arm_lift.py │ ├── shimmy │ │ ├── jax_shimmy_atari_pong_dqn.py │ │ ├── jax_shimmy_dm_control_acrobot_swingup_sparse_sac.py │ │ ├── jax_shimmy_openai_gym_compatibility_pendulum_ddpg.py │ │ ├── torch_shimmy_atari_pong_dqn.py │ │ ├── torch_shimmy_dm_control_acrobot_swingup_sparse_sac.py │ │ └── torch_shimmy_openai_gym_compatibility_pendulum_ddpg.py │ └── utils │ │ └── tensorboard_file_iterator.py │ ├── index.rst │ ├── intro │ ├── data.rst │ ├── examples.rst │ ├── getting_started.rst │ └── installation.rst │ └── snippets │ ├── agent.py │ ├── agents_basic_usage.py │ ├── categorical_model.py │ ├── data.py │ ├── deterministic_model.py │ ├── gaussian_model.py │ ├── isaacgym_utils.py │ ├── loaders.py │ ├── memories.py │ ├── model_instantiators.txt │ ├── model_mixin.py │ ├── multi_agent.py │ ├── multi_agents_basic_usage.py │ ├── multicategorical_model.py │ ├── multivariate_gaussian_model.py │ ├── noises.py │ ├── runner.txt │ ├── shared_model.py │ ├── tabular_model.py │ ├── trainer.py │ ├── utils_distributed.txt │ ├── utils_postprocessing.py │ └── wrapping.py ├── pyproject.toml ├── skrl ├── __init__.py ├── agents │ ├── __init__.py │ ├── jax │ │ ├── __init__.py │ │ ├── a2c │ │ │ ├── __init__.py │ │ │ └── a2c.py │ │ ├── base.py │ │ ├── cem │ │ │ ├── __init__.py │ │ │ └── cem.py │ │ ├── ddpg │ │ │ ├── __init__.py │ │ │ └── ddpg.py │ │ ├── dqn │ │ │ ├── __init__.py │ │ │ ├── ddqn.py │ │ │ └── dqn.py │ │ ├── ppo │ │ │ ├── __init__.py │ │ │ └── ppo.py │ │ ├── rpo │ │ │ ├── __init__.py │ │ │ └── rpo.py │ │ ├── sac │ │ │ ├── __init__.py │ │ │ └── sac.py │ │ └── td3 │ │ │ ├── __init__.py │ │ │ └── td3.py │ └── torch │ │ ├── __init__.py │ 
│ ├── a2c │ │ ├── __init__.py │ │ ├── a2c.py │ │ └── a2c_rnn.py │ │ ├── amp │ │ ├── __init__.py │ │ └── amp.py │ │ ├── base.py │ │ ├── cem │ │ ├── __init__.py │ │ └── cem.py │ │ ├── ddpg │ │ ├── __init__.py │ │ ├── ddpg.py │ │ └── ddpg_rnn.py │ │ ├── dqn │ │ ├── __init__.py │ │ ├── ddqn.py │ │ └── dqn.py │ │ ├── ppo │ │ ├── __init__.py │ │ ├── ppo.py │ │ └── ppo_rnn.py │ │ ├── q_learning │ │ ├── __init__.py │ │ └── q_learning.py │ │ ├── rpo │ │ ├── __init__.py │ │ ├── rpo.py │ │ └── rpo_rnn.py │ │ ├── sac │ │ ├── __init__.py │ │ ├── sac.py │ │ └── sac_rnn.py │ │ ├── sarsa │ │ ├── __init__.py │ │ └── sarsa.py │ │ ├── td3 │ │ ├── __init__.py │ │ ├── td3.py │ │ └── td3_rnn.py │ │ └── trpo │ │ ├── __init__.py │ │ ├── trpo.py │ │ └── trpo_rnn.py ├── envs │ ├── __init__.py │ ├── jax.py │ ├── loaders │ │ ├── __init__.py │ │ ├── jax │ │ │ ├── __init__.py │ │ │ ├── bidexhands_envs.py │ │ │ ├── isaacgym_envs.py │ │ │ ├── isaaclab_envs.py │ │ │ └── omniverse_isaacgym_envs.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ ├── bidexhands_envs.py │ │ │ ├── isaacgym_envs.py │ │ │ ├── isaaclab_envs.py │ │ │ └── omniverse_isaacgym_envs.py │ ├── torch.py │ └── wrappers │ │ ├── __init__.py │ │ ├── jax │ │ ├── __init__.py │ │ ├── base.py │ │ ├── bidexhands_envs.py │ │ ├── brax_envs.py │ │ ├── gym_envs.py │ │ ├── gymnasium_envs.py │ │ ├── isaacgym_envs.py │ │ ├── isaaclab_envs.py │ │ ├── omniverse_isaacgym_envs.py │ │ └── pettingzoo_envs.py │ │ └── torch │ │ ├── __init__.py │ │ ├── base.py │ │ ├── bidexhands_envs.py │ │ ├── brax_envs.py │ │ ├── deepmind_envs.py │ │ ├── gym_envs.py │ │ ├── gymnasium_envs.py │ │ ├── isaacgym_envs.py │ │ ├── isaaclab_envs.py │ │ ├── omniverse_isaacgym_envs.py │ │ ├── pettingzoo_envs.py │ │ └── robosuite_envs.py ├── memories │ ├── __init__.py │ ├── jax │ │ ├── __init__.py │ │ ├── base.py │ │ └── random.py │ └── torch │ │ ├── __init__.py │ │ ├── base.py │ │ └── random.py ├── models │ ├── __init__.py │ ├── jax │ │ ├── __init__.py │ │ ├── base.py │ │ ├── 
categorical.py │ │ ├── deterministic.py │ │ ├── gaussian.py │ │ └── multicategorical.py │ └── torch │ │ ├── __init__.py │ │ ├── base.py │ │ ├── categorical.py │ │ ├── deterministic.py │ │ ├── gaussian.py │ │ ├── multicategorical.py │ │ ├── multivariate_gaussian.py │ │ └── tabular.py ├── multi_agents │ ├── __init__.py │ ├── jax │ │ ├── __init__.py │ │ ├── base.py │ │ ├── ippo │ │ │ ├── __init__.py │ │ │ └── ippo.py │ │ └── mappo │ │ │ ├── __init__.py │ │ │ └── mappo.py │ └── torch │ │ ├── __init__.py │ │ ├── base.py │ │ ├── ippo │ │ ├── __init__.py │ │ └── ippo.py │ │ └── mappo │ │ ├── __init__.py │ │ └── mappo.py ├── resources │ ├── __init__.py │ ├── noises │ │ ├── __init__.py │ │ ├── jax │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── gaussian.py │ │ │ └── ornstein_uhlenbeck.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── gaussian.py │ │ │ └── ornstein_uhlenbeck.py │ ├── optimizers │ │ ├── __init__.py │ │ └── jax │ │ │ ├── __init__.py │ │ │ └── adam.py │ ├── preprocessors │ │ ├── __init__.py │ │ ├── jax │ │ │ ├── __init__.py │ │ │ └── running_standard_scaler.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ └── running_standard_scaler.py │ └── schedulers │ │ ├── __init__.py │ │ ├── jax │ │ ├── __init__.py │ │ └── kl_adaptive.py │ │ └── torch │ │ ├── __init__.py │ │ └── kl_adaptive.py ├── trainers │ ├── __init__.py │ ├── jax │ │ ├── __init__.py │ │ ├── base.py │ │ ├── sequential.py │ │ └── step.py │ └── torch │ │ ├── __init__.py │ │ ├── base.py │ │ ├── parallel.py │ │ ├── sequential.py │ │ └── step.py └── utils │ ├── __init__.py │ ├── control.py │ ├── distributed │ ├── __init__.py │ └── jax │ │ ├── __main__.py │ │ └── launcher.py │ ├── huggingface.py │ ├── isaacgym_utils.py │ ├── model_instantiators │ ├── __init__.py │ ├── jax │ │ ├── __init__.py │ │ ├── categorical.py │ │ ├── common.py │ │ ├── deterministic.py │ │ ├── gaussian.py │ │ └── multicategorical.py │ └── torch │ │ ├── __init__.py │ │ ├── categorical.py │ │ ├── common.py │ │ ├── 
deterministic.py │ │ ├── gaussian.py │ │ ├── multicategorical.py │ │ ├── multivariate_gaussian.py │ │ └── shared.py │ ├── omniverse_isaacgym_utils.py │ ├── postprocessing.py │ ├── runner │ ├── __init__.py │ ├── jax │ │ ├── __init__.py │ │ └── runner.py │ └── torch │ │ ├── __init__.py │ │ └── runner.py │ └── spaces │ ├── __init__.py │ ├── jax │ ├── __init__.py │ └── spaces.py │ └── torch │ ├── __init__.py │ └── spaces.py └── tests ├── __init__.py ├── agents ├── __init__.py ├── jax │ ├── __init__.py │ ├── test_a2c.py │ ├── test_cem.py │ ├── test_ddpg.py │ ├── test_ddqn.py │ ├── test_dqn.py │ ├── test_ppo.py │ ├── test_rpo.py │ ├── test_sac.py │ └── test_td3.py └── torch │ ├── __init__.py │ ├── test_a2c.py │ ├── test_amp.py │ ├── test_cem.py │ ├── test_ddpg.py │ ├── test_ddqn.py │ ├── test_dqn.py │ ├── test_ppo.py │ ├── test_rpo.py │ ├── test_sac.py │ ├── test_td3.py │ └── test_trpo.py ├── envs ├── __init__.py └── wrappers │ ├── __init__.py │ ├── jax │ ├── __init__.py │ ├── test_brax_envs.py │ ├── test_gym_envs.py │ ├── test_gymnasium_envs.py │ ├── test_isaacgym_envs.py │ ├── test_isaaclab_envs.py │ ├── test_omniverse_isaacgym_envs.py │ └── test_pettingzoo_envs.py │ └── torch │ ├── __init__.py │ ├── test_brax_envs.py │ ├── test_deepmind_envs.py │ ├── test_gym_envs.py │ ├── test_gymnasium_envs.py │ ├── test_isaacgym_envs.py │ ├── test_isaaclab_envs.py │ ├── test_omniverse_isaacgym_envs.py │ └── test_pettingzoo_envs.py ├── memories ├── __init__.py └── torch │ ├── __init__.py │ └── test_base.py ├── strategies.py ├── test_jax_config.py ├── test_torch_config.py ├── utilities.py └── utils ├── __init__.py ├── model_instantiators ├── __init__.py ├── jax │ ├── __init__.py │ ├── test_definition.py │ └── test_models.py └── torch │ ├── __init__.py │ ├── test_definition.py │ └── test_models.py └── spaces ├── __init__.py ├── jax ├── __init__.py └── test_spaces.py └── torch ├── __init__.py └── test_spaces.py /.github/GITHUB_ACTIONS.md: 
-------------------------------------------------------------------------------- 1 | ## GitHub Actions 2 | 3 | ### Relevant links 4 | 5 | - `runs-on`: 6 | - [Standard GitHub-hosted runners for public repositories](https://docs.github.com/en/actions/using-github-hosted-runners/using-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories) 7 | - [GitHub Actions Runner Images](https://github.com/actions/runner-images) 8 | - `actions/setup-python`: 9 | - [Building and testing Python](https://docs.github.com/en/actions/use-cases-and-examples/building-and-testing/building-and-testing-python) 10 | - [Available Python versions](https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json) 11 | 12 | ### Run GitHub Actions locally with nektos/act 13 | 14 | [nektos/act](https://nektosact.com/) is a tool to run GitHub Actions locally. Install it as a [GitHub CLI](https://cli.github.com/) extension via [these steps](https://nektosact.com/installation/gh.html). 15 | 16 | #### Useful commands 17 | 18 | * List workflows/jobs: 19 | 20 | ```bash 21 | gh act -l 22 | ``` 23 | 24 | * Run a specific job: 25 | 26 | Use `--env DELETE_HOSTED_TOOL_PYTHON_CACHE=1` to delete the Python cache. 27 | 28 | ```bash 29 | gh act -j Job-ID 30 | ``` 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Submit a bug report 3 | labels: bug 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | **Your help in making skrl better is greatly appreciated!** 9 | 10 | * Please ensure that: 11 | * The issue hasn't already been reported by using the [issue search](https://github.com/Toni-SM/skrl/search?q=is%3Aissue&type=issues). 
 12 | * The issue (and its solution) is not listed in the skrl documentation [troubleshooting](https://skrl.readthedocs.io/en/latest/intro/installation.html#known-issues-and-troubleshooting) section. 13 | * For questions, please consider [opening a discussion](https://github.com/Toni-SM/skrl/discussions). 14 |
15 | - type: textarea 16 | attributes: 17 | label: Description 18 | description: A clear and concise description of the bug/issue. Try to provide a minimal example to reproduce it (error/log messages are also helpful). 19 | placeholder: | 20 | Markdown formatting might be applied to the text. 21 | 22 | ```python 23 | # use triple backticks for code blocks or error/log messages 24 | ``` 25 | validations: 26 | required: true 27 | - type: dropdown 28 | attributes: 29 | label: What skrl version are you using? 30 | description: The skrl version can be obtained with the command `pip show skrl`. 31 | options: 32 | - --- 33 | - 1.4.3 34 | - 1.4.2 35 | - 1.4.1 36 | - 1.4.0 37 | - 1.3.0 38 | - 1.2.0 39 | - 1.1.0 40 | - 1.0.0 41 | - 1.0.0-rc2 42 | - 1.0.0-rc1 43 | - 0.10.2 or 0.10.1 44 | - 0.10.0 or earlier 45 | - develop branch 46 | validations: 47 | required: true 48 | - type: input 49 | attributes: 50 | label: What ML framework/library version are you using? 51 | description: The version can be obtained with the command `pip show torch` or `pip show jax jaxlib flax optax`. 52 | placeholder: PyTorch version, JAX/jaxlib/Flax/Optax version, etc. 53 | - type: input 54 | attributes: 55 | label: Additional system information 56 | placeholder: Python version, OS (Linux/Windows/macOS/WSL), etc. 
57 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Question 4 | url: https://github.com/Toni-SM/skrl/discussions 5 | about: Please ask questions on the Discussions tab 6 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | 7 | pre-commit: 8 | name: Pre-commit hooks 9 | runs-on: ubuntu-latest 10 | steps: 11 | # setup 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.10' 17 | # install dependencies 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --quiet --upgrade pip 21 | python -m pip install --quiet pre-commit 22 | # run pre-commit 23 | - name: Run pre-commit 24 | run: | 25 | pre-commit run --all-files 26 | -------------------------------------------------------------------------------- /.github/workflows/python-publish-manual.yml: -------------------------------------------------------------------------------- 1 | name: pypi (manually triggered workflow) 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | job: 7 | description: 'Upload Python Package to PyPI/TestPyPI' 8 | required: true 9 | default: 'test-pypi' 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | 16 | pypi: 17 | name: Publish package to PyPI 18 | runs-on: ubuntu-22.04 19 | if: ${{ github.event.inputs.job == 'pypi'}} 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: '3.10.16' 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade 
pip 32 | pip install build 33 | 34 | - name: Build package 35 | run: python -m build 36 | 37 | - name: Publish package to PyPI 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | with: 40 | user: __token__ 41 | password: ${{ secrets.PYPI_API_TOKEN }} 42 | verbose: true 43 | 44 | test-pypi: 45 | name: Publish package to TestPyPI 46 | runs-on: ubuntu-22.04 47 | if: ${{ github.event.inputs.job == 'test-pypi'}} 48 | 49 | steps: 50 | - uses: actions/checkout@v3 51 | 52 | - name: Set up Python 53 | uses: actions/setup-python@v3 54 | with: 55 | python-version: '3.10.16' 56 | 57 | - name: Install dependencies 58 | run: | 59 | python -m pip install --upgrade pip 60 | pip install build 61 | 62 | - name: Build package 63 | run: python -m build 64 | 65 | - name: Publish package to TestPyPI 66 | uses: pypa/gh-action-pypi-publish@release/v1 67 | with: 68 | user: __token__ 69 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 70 | repository_url: https://test.pypi.org/legacy/ 71 | verbose: true 72 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-case-conflict 7 | - id: check-docstring-first 8 | - id: check-json 9 | - id: check-merge-conflict 10 | - id: check-toml 11 | - id: check-yaml 12 | - id: debug-statements 13 | - id: detect-private-key 14 | - id: end-of-file-fixer 15 | - id: name-tests-test 16 | args: ["--pytest-test-first"] 17 | exclude: ^(tests/strategies.py|tests/utilities.py) 18 | - id: trailing-whitespace 19 | - repo: https://github.com/codespell-project/codespell 20 | rev: v2.3.0 21 | hooks: 22 | - id: codespell 23 | exclude: ^(docs/source/_static|docs/_build|pyproject.toml) 24 | additional_dependencies: 25 | - tomli 26 | - repo: https://github.com/python/black 27 | rev: 24.8.0 28 | hooks: 29 | - id: 
black 30 | args: ["--line-length=120"] 31 | exclude: ^(docs/) 32 | - repo: https://github.com/pycqa/isort 33 | rev: 5.13.2 34 | hooks: 35 | - id: isort 36 | - repo: https://github.com/pre-commit/pygrep-hooks 37 | rev: v1.10.0 38 | hooks: 39 | - id: rst-backticks 40 | - id: rst-directive-colons 41 | - id: rst-inline-touching-normal 42 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the version of Python and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.10" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/source/conf.py 16 | # builder: html 17 | # fail_on_warning: false 18 | 19 | # If using Sphinx, optionally build your docs in additional formats such as PDF 20 | # formats: 21 | # - pdf 22 | 23 | # Python requirements required to build your docs 24 | python: 25 | install: 26 | - requirements: docs/requirements.txt 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Toni-SM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 
all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![pypi](https://img.shields.io/pypi/v/skrl)](https://pypi.org/project/skrl) 2 | [](https://huggingface.co/skrl) 3 | ![discussions](https://img.shields.io/github/discussions/Toni-SM/skrl) 4 |
5 | [![license](https://img.shields.io/github/license/Toni-SM/skrl)](https://github.com/Toni-SM/skrl) 6 |      7 | [![docs](https://readthedocs.org/projects/skrl/badge/?version=latest)](https://skrl.readthedocs.io/en/latest/?badge=latest) 8 | [![pre-commit](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml) 9 | [![pytest-torch](https://github.com/Toni-SM/skrl/actions/workflows/tests-torch.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/tests-torch.yml) 10 | [![pytest-jax](https://github.com/Toni-SM/skrl/actions/workflows/tests-jax.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/tests-jax.yml) 11 | 12 |
13 |

14 | 15 | 16 | 17 |

18 |

SKRL - Reinforcement Learning library

19 |
20 | 21 | **skrl** is an open-source modular library for Reinforcement Learning written in Python (on top of [PyTorch](https://pytorch.org/) and [JAX](https://jax.readthedocs.io)) and designed with a focus on modularity, readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI [Gym](https://www.gymlibrary.dev), Farama [Gymnasium](https://gymnasium.farama.org) and [PettingZoo](https://pettingzoo.farama.org), Google [DeepMind](https://github.com/deepmind/dm_env) and [Brax](https://github.com/google/brax), among other environment interfaces, it allows loading and configuring NVIDIA [Isaac Lab](https://isaac-sim.github.io/IsaacLab/index.html) (as well as [Isaac Gym](https://developer.nvidia.com/isaac-gym/) and [Omniverse Isaac Gym](https://github.com/isaac-sim/OmniIsaacGymEnvs)) environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run. 22 | 23 |
24 | 25 | ### Please, visit the documentation for usage details and examples 26 | 27 | https://skrl.readthedocs.io 28 | 29 |
30 | 31 | > **Note:** This project is under **active continuous development**. Please make sure you always have the latest version. Visit the [develop](https://github.com/Toni-SM/skrl/tree/develop) branch or its [documentation](https://skrl.readthedocs.io/en/develop) to access the latest updates to be released. 32 | 33 |
34 | 35 | ### Citing this library 36 | 37 | To cite this library in publications, please use the following reference: 38 | 39 | ```bibtex 40 | @article{serrano2023skrl, 41 | author = {Antonio Serrano-Muñoz and Dimitrios Chrysostomou and Simon Bøgh and Nestor Arana-Arexolaleiba}, 42 | title = {skrl: Modular and Flexible Library for Reinforcement Learning}, 43 | journal = {Journal of Machine Learning Research}, 44 | year = {2023}, 45 | volume = {24}, 46 | number = {254}, 47 | pages = {1--9}, 48 | url = {http://jmlr.org/papers/v24/23-0112.html} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | ## Install Sphinx and Read the Docs Sphinx Theme 4 | 5 | ```bash 6 | cd docs 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | ## Building the documentation 11 | 12 | ```bash 13 | cd docs 14 | make html 15 | ``` 16 | 17 | Building each time a file is changed: 18 | 19 | ```bash 20 | cd docs 21 | sphinx-autobuild ./source/ _build/html 22 | ``` 23 | 24 | ## Useful links 25 | 26 | - [Sphinx directives](https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html) 27 | - [Math support in Sphinx](https://www.sphinx-doc.org/en/1.0/ext/math.html) 28 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | furo==2024.8.6 2 | sphinx 3 | sphinx-tabs 4 | sphinx-autobuild 5 | sphinx-copybutton 6 | sphinx-notfound-page 7 | decorator 8 | numpy 9 | -------------------------------------------------------------------------------- /docs/source/404.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | Page not found 4 | ============== 5 | 6 | .. image:: _static/data/404-light.svg 7 | :width: 50% 8 | :align: center 9 | :class: only-light 10 | :alt: 404 11 | 12 | .. image:: _static/data/404-dark.svg 13 | :width: 50% 14 | :align: center 15 | :class: only-dark 16 | :alt: 404 17 | 18 | .. raw:: html 19 | 20 |
21 |
22 |

404: Puzzle piece not found.

23 |

Did you look under the sofa cushions?

24 |
25 |
26 |
27 | 28 | Since version 1.0.0, the documentation structure has changed to improve content organization and to provide a better browsing experience. 29 | Navigate using the left sidebar or type in the search box to find what you are looking for. 30 | -------------------------------------------------------------------------------- /docs/source/_static/css/s5defs-roles.css: -------------------------------------------------------------------------------- 1 | 2 | .black { 3 | color: black; 4 | } 5 | 6 | .gray { 7 | color: gray; 8 | } 9 | 10 | .grey { 11 | color: gray; 12 | } 13 | 14 | .silver { 15 | color: silver; 16 | } 17 | 18 | .white { 19 | color: white; 20 | } 21 | 22 | .maroon { 23 | color: maroon; 24 | } 25 | 26 | .red { 27 | color: red; 28 | } 29 | 30 | .magenta { 31 | color: magenta; 32 | } 33 | 34 | .fuchsia { 35 | color: fuchsia; 36 | } 37 | 38 | .pink { 39 | color: pink; 40 | } 41 | 42 | .orange { 43 | color: orange; 44 | } 45 | 46 | .yellow { 47 | color: yellow; 48 | } 49 | 50 | .lime { 51 | color: lime; 52 | } 53 | 54 | .green { 55 | color: #02a802; 56 | } 57 | 58 | .olive { 59 | color: olive; 60 | } 61 | 62 | .teal { 63 | color: teal; 64 | } 65 | 66 | .cyan { 67 | color: cyan; 68 | } 69 | 70 | .aqua { 71 | color: aqua; 72 | } 73 | 74 | .blue { 75 | color: #007cea; 76 | } 77 | 78 | .navy { 79 | color: navy; 80 | } 81 | 82 | .purple { 83 | color: purple; 84 | } 85 | -------------------------------------------------------------------------------- /docs/source/_static/css/skrl.css: -------------------------------------------------------------------------------- 1 | .nowrap { 2 | white-space: nowrap; 3 | } 4 | 5 | .sidebar-brand-text { 6 | font-size: 1.25rem !important; 7 | } 8 | 9 | tbody > tr > th.stub { 10 | font-weight: normal; 11 | text-align: left; 12 | } 13 | -------------------------------------------------------------------------------- /docs/source/_static/data/favicon.ico: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/favicon.ico -------------------------------------------------------------------------------- /docs/source/_static/data/logo-dark-mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/logo-dark-mode.png -------------------------------------------------------------------------------- /docs/source/_static/data/logo-light-mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/logo-light-mode.png -------------------------------------------------------------------------------- /docs/source/_static/data/logo-torch.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | image/svg+xml 46 | 49 | 54 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /docs/source/_static/data/skrl-up-transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/skrl-up-transparent.png -------------------------------------------------------------------------------- /docs/source/_static/data/skrl-up.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/skrl-up.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/data_tensorboard.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/data_tensorboard.jpg -------------------------------------------------------------------------------- /docs/source/_static/imgs/example_bidexhands.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_bidexhands.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/example_deepmind.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_deepmind.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/example_gym.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_gym.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/example_isaacgym.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_isaacgym.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/example_isaaclab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_isaaclab.png 
-------------------------------------------------------------------------------- /docs/source/_static/imgs/example_omniverse_isaacgym.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_omniverse_isaacgym.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/example_parallel.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_parallel.jpg -------------------------------------------------------------------------------- /docs/source/_static/imgs/example_robosuite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_robosuite.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/example_shimmy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_shimmy.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/noise_gaussian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/noise_gaussian.png -------------------------------------------------------------------------------- /docs/source/_static/imgs/noise_ornstein_uhlenbeck.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/noise_ornstein_uhlenbeck.png -------------------------------------------------------------------------------- /docs/source/api/envs.rst: -------------------------------------------------------------------------------- 1 | Environments 2 | ============ 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | Wrapping (single-agent) 8 | Wrapping (multi-agents) 9 | Isaac Lab environments 10 | Isaac Gym environments 11 | Omniverse Isaac Gym environments 12 | 13 | The environment plays a fundamental and crucial role in defining the RL setup. It is the place where the agent interacts, and it is responsible for providing the agent with information about its current state, as well as the rewards/penalties associated with each action. 14 | 15 | .. raw:: html 16 | 17 |

18 | 19 | In this section you will find how to load environments from NVIDIA Isaac Lab (as well as Isaac Gym and Omniverse Isaac Gym) with a simple function. 20 | 21 | .. list-table:: 22 | :header-rows: 1 23 | 24 | * - Loaders 25 | - .. centered:: |_4| |pytorch| |_4| 26 | - .. centered:: |_4| |jax| |_4| 27 | * - :doc:`Isaac Lab environments ` 28 | - .. centered:: :math:`\blacksquare` 29 | - .. centered:: :math:`\blacksquare` 30 | * - :doc:`Isaac Gym environments ` 31 | - .. centered:: :math:`\blacksquare` 32 | - .. centered:: :math:`\blacksquare` 33 | * - :doc:`Omniverse Isaac Gym environments ` 34 | - .. centered:: :math:`\blacksquare` 35 | - .. centered:: :math:`\blacksquare` 36 | 37 | In addition, you will be able to :doc:`wrap single-agent ` and :doc:`multi-agent ` RL environment interfaces. 38 | 39 | .. list-table:: 40 | :header-rows: 1 41 | 42 | * - Wrappers 43 | - .. centered:: |_4| |pytorch| |_4| 44 | - .. centered:: |_4| |jax| |_4| 45 | * - Bi-DexHands 46 | - .. centered:: :math:`\blacksquare` 47 | - .. centered:: :math:`\blacksquare` 48 | * - Brax 49 | - .. centered:: :math:`\blacksquare` 50 | - .. centered:: :math:`\blacksquare` 51 | * - DeepMind 52 | - .. centered:: :math:`\blacksquare` 53 | - .. centered:: :math:`\square` 54 | * - Gym 55 | - .. centered:: :math:`\blacksquare` 56 | - .. centered:: :math:`\blacksquare` 57 | * - Gymnasium 58 | - .. centered:: :math:`\blacksquare` 59 | - .. centered:: :math:`\blacksquare` 60 | * - Isaac Lab 61 | - .. centered:: :math:`\blacksquare` 62 | - .. centered:: :math:`\blacksquare` 63 | * - Isaac Gym (previews) 64 | - .. centered:: :math:`\blacksquare` 65 | - .. centered:: :math:`\blacksquare` 66 | * - Omniverse Isaac Gym |_5| |_5| |_5| |_5| |_2| 67 | - .. centered:: :math:`\blacksquare` 68 | - .. centered:: :math:`\blacksquare` 69 | * - PettingZoo 70 | - .. centered:: :math:`\blacksquare` 71 | - .. centered:: :math:`\blacksquare` 72 | * - robosuite 73 | - .. centered:: :math:`\blacksquare` 74 | - .. 
centered:: :math:`\square` 75 | * - Shimmy 76 | - .. centered:: :math:`\blacksquare` 77 | - .. centered:: :math:`\blacksquare` 78 | -------------------------------------------------------------------------------- /docs/source/api/envs/isaaclab.rst: -------------------------------------------------------------------------------- 1 | Isaac Lab environments 2 | ====================== 3 | 4 | .. image:: ../../_static/imgs/example_isaaclab.png 5 | :width: 100% 6 | :align: center 7 | :alt: Isaac Lab environments 8 | 9 | .. raw:: html 10 | 11 |


12 | 13 | Environments 14 | ------------ 15 | 16 | The repository https://github.com/isaac-sim/IsaacLab provides the example reinforcement learning environments for Isaac Lab (Orbit and Omniverse Isaac Gym unification). 17 | 18 | These environments can be easily loaded and configured by calling a single function provided with this library. This function also makes it possible to configure the environment from the command line arguments (see Isaac Lab's `Training with an RL Agent `_) or from its parameters (:literal:`task_name`, :literal:`num_envs`, :literal:`headless`, and :literal:`cli_args`). 19 | 20 | .. note:: 21 | 22 | The command line arguments have priority over the function parameters. 23 | 24 | .. note:: 25 | 26 | Isaac Lab environments implement functionality to get their configuration from the command line. Setting the :literal:`headless` option from the trainer configuration will not work. In this case, it is necessary to set the load function's :literal:`headless` argument to True or to invoke the scripts as follows: :literal:`isaaclab -p script.py --headless`. 27 | 28 | .. raw:: html 29 | 30 |
31 | 32 | Usage 33 | ^^^^^ 34 | 35 | .. tabs:: 36 | 37 | .. tab:: Function parameters 38 | 39 | .. tabs:: 40 | 41 | .. group-tab:: |_4| |pytorch| |_4| 42 | 43 | .. literalinclude:: ../../snippets/loaders.py 44 | :language: python 45 | :emphasize-lines: 2, 5 46 | :start-after: [start-isaaclab-envs-parameters-torch] 47 | :end-before: [end-isaaclab-envs-parameters-torch] 48 | 49 | .. group-tab:: |_4| |jax| |_4| 50 | 51 | .. literalinclude:: ../../snippets/loaders.py 52 | :language: python 53 | :emphasize-lines: 2, 5 54 | :start-after: [start-isaaclab-envs-parameters-jax] 55 | :end-before: [end-isaaclab-envs-parameters-jax] 56 | 57 | .. tab:: Command line arguments (priority) 58 | 59 | .. tabs:: 60 | 61 | .. group-tab:: |_4| |pytorch| |_4| 62 | 63 | .. literalinclude:: ../../snippets/loaders.py 64 | :language: python 65 | :emphasize-lines: 2, 5 66 | :start-after: [start-isaaclab-envs-cli-torch] 67 | :end-before: [end-isaaclab-envs-cli-torch] 68 | 69 | .. group-tab:: |_4| |jax| |_4| 70 | 71 | .. literalinclude:: ../../snippets/loaders.py 72 | :language: python 73 | :emphasize-lines: 2, 5 74 | :start-after: [start-isaaclab-envs-cli-jax] 75 | :end-before: [end-isaaclab-envs-cli-jax] 76 | 77 | Run the main script passing the configuration as command line arguments. For example: 78 | 79 | .. code-block:: 80 | 81 | isaaclab -p main.py --task Isaac-Cartpole-v0 82 | 83 | .. raw:: html 84 | 85 |
86 | 87 | API 88 | ^^^ 89 | 90 | .. autofunction:: skrl.envs.loaders.torch.load_isaaclab_env 91 | -------------------------------------------------------------------------------- /docs/source/api/memories.rst: -------------------------------------------------------------------------------- 1 | Memories 2 | ======== 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | Random 8 | 9 | Memories are storage components that allow agents to collect and use/reuse current or past experiences of their interaction with the environment or other types of information. 10 | 11 | .. raw:: html 12 | 13 |

14 | 15 | .. list-table:: 16 | :header-rows: 1 17 | 18 | * - Memories 19 | - .. centered:: |_4| |pytorch| |_4| 20 | - .. centered:: |_4| |jax| |_4| 21 | * - :doc:`Random memory ` 22 | - .. centered:: :math:`\blacksquare` 23 | - .. centered:: :math:`\blacksquare` 24 | 25 | Base class 26 | ---------- 27 | 28 | .. note:: 29 | 30 | This is the base class for all the other classes in this module. 31 | It provides the basic functionality for the other classes. 32 | **It is not intended to be used directly**. 33 | 34 | .. raw:: html 35 | 36 |
37 | 38 | Basic inheritance usage 39 | ^^^^^^^^^^^^^^^^^^^^^^^ 40 | 41 | .. tabs:: 42 | 43 | .. group-tab:: |_4| |pytorch| |_4| 44 | 45 | .. literalinclude:: ../snippets/memories.py 46 | :language: python 47 | :start-after: [start-base-class-torch] 48 | :end-before: [end-base-class-torch] 49 | 50 | .. group-tab:: |_4| |jax| |_4| 51 | 52 | .. literalinclude:: ../snippets/memories.py 53 | :language: python 54 | :start-after: [start-base-class-jax] 55 | :end-before: [end-base-class-jax] 56 | 57 | .. raw:: html 58 | 59 |
60 | 61 | API (PyTorch) 62 | ^^^^^^^^^^^^^ 63 | 64 | .. autoclass:: skrl.memories.torch.base.Memory 65 | :undoc-members: 66 | :show-inheritance: 67 | :members: 68 | 69 | .. automethod:: __len__ 70 | 71 | .. raw:: html 72 | 73 |
74 | 75 | API (JAX) 76 | ^^^^^^^^^ 77 | 78 | .. autoclass:: skrl.memories.jax.base.Memory 79 | :undoc-members: 80 | :show-inheritance: 81 | :members: 82 | 83 | .. automethod:: __len__ 84 | -------------------------------------------------------------------------------- /docs/source/api/memories/random.rst: -------------------------------------------------------------------------------- 1 | Random memory 2 | ============= 3 | 4 | Random sampling memory 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Usage 11 | ----- 12 | 13 | .. tabs:: 14 | 15 | .. group-tab:: |_4| |pytorch| |_4| 16 | 17 | .. literalinclude:: ../../snippets/memories.py 18 | :language: python 19 | :emphasize-lines: 2, 5 20 | :start-after: [start-random-torch] 21 | :end-before: [end-random-torch] 22 | 23 | .. group-tab:: |_4| |jax| |_4| 24 | 25 | .. literalinclude:: ../../snippets/memories.py 26 | :language: python 27 | :emphasize-lines: 2, 5 28 | :start-after: [start-random-jax] 29 | :end-before: [end-random-jax] 30 | 31 | .. raw:: html 32 | 33 |
34 | 35 | API (PyTorch) 36 | ------------- 37 | 38 | .. autoclass:: skrl.memories.torch.random.RandomMemory 39 | :undoc-members: 40 | :show-inheritance: 41 | :inherited-members: 42 | :members: 43 | 44 | .. automethod:: __len__ 45 | 46 | .. raw:: html 47 | 48 |
49 | 50 | API (JAX) 51 | --------- 52 | 53 | .. autoclass:: skrl.memories.jax.random.RandomMemory 54 | :undoc-members: 55 | :show-inheritance: 56 | :inherited-members: 57 | :members: 58 | 59 | .. automethod:: __len__ 60 | -------------------------------------------------------------------------------- /docs/source/api/models/shared_model.rst: -------------------------------------------------------------------------------- 1 | Shared model 2 | ============ 3 | 4 | Sometimes it is desirable to define models that use shared layers or network to represent multiple function approximators. This practice, known as *shared parameters* (or *parameter sharing*), *shared layers*, *shared model*, *shared networks* or *joint architecture* among others, is typically justified by the following criteria: 5 | 6 | * Learning the same characteristics, especially when processing large inputs (such as images, e.g.). 7 | 8 | * Reduce the number of parameters in the whole system. 9 | 10 | * Make the computation more efficient (single forward-pass). 11 | 12 | .. raw:: html 13 | 14 |

15 | 16 | Implementation 17 | -------------- 18 | 19 | By combining the implemented mixins, it is possible to define shared models with skrl. In these cases, the use of the :literal:`role` argument (a Python string) is relevant. The agents will call the models by setting the :literal:`role` argument according to their requirements. Visit each agent's documentation (*Key* column of the table under *Spaces and models* section) to know the possible values that this parameter can take. 20 | 21 | The code snippet below shows how to define a shared model. The following practices for building shared models can be identified: 22 | 23 | * The definition of multiple inheritance must always include the :ref:`Model ` base class at the end. 24 | 25 | * The :ref:`Model ` base class constructor must be invoked before the mixins constructor. 26 | 27 | * All mixin constructors must be invoked. 28 | 29 | * Specify :literal:`role` argument is optional if all constructors belong to different mixins. 30 | 31 | * If multiple models of the same mixin type are required, the same constructor must be invoked as many times as needed. To do so, it is mandatory to specify the :literal:`role` argument. 32 | 33 | * The :literal:`.act(...)` method needs to be overridden to disambiguate its call. 34 | 35 | * The same instance of the shared model must be passed to all keys involved. 36 | 37 | .. raw:: html 38 | 39 |
40 | 41 | .. tabs:: 42 | 43 | .. group-tab:: |_4| |pytorch| |_4| 44 | 45 | .. tabs:: 46 | 47 | .. group-tab:: Single forward-pass 48 | 49 | .. warning:: 50 | 51 | The implementation described for single forward-pass requires that the value-pass always follows the policy-pass (e.g.: ``PPO``) which may not be generalized to other algorithms. 52 | 53 | If this requirement is not met, other forms of "caching" the shared layers/network output could be implemented. 54 | 55 | .. literalinclude:: ../../snippets/shared_model.py 56 | :language: python 57 | :start-after: [start-mlp-single-forward-pass-torch] 58 | :end-before: [end-mlp-single-forward-pass-torch] 59 | 60 | .. group-tab:: Multiple forward-pass 61 | 62 | .. literalinclude:: ../../snippets/shared_model.py 63 | :language: python 64 | :start-after: [start-mlp-multi-forward-pass-torch] 65 | :end-before: [end-mlp-multi-forward-pass-torch] 66 | -------------------------------------------------------------------------------- /docs/source/api/models/tabular.rst: -------------------------------------------------------------------------------- 1 | .. _models_tabular: 2 | 3 | Tabular model 4 | ============= 5 | 6 | Tabular models run **discrete-domain deterministic/stochastic** policies. 7 | 8 | .. raw:: html 9 | 10 |

11 | 12 | skrl provides a Python mixin (:literal:`TabularMixin`) to assist in the creation of these types of models, allowing users to have full control over the table definitions. Note that the use of this mixin must comply with the following rules: 13 | 14 | * The definition of multiple inheritance must always include the :ref:`Model ` base class at the end. 15 | 16 | * The :ref:`Model ` base class constructor must be invoked before the mixins constructor. 17 | 18 | .. tabs:: 19 | 20 | .. group-tab:: |_4| |pytorch| |_4| 21 | 22 | .. literalinclude:: ../../snippets/tabular_model.py 23 | :language: python 24 | :emphasize-lines: 1, 3-4 25 | :start-after: [start-definition-torch] 26 | :end-before: [end-definition-torch] 27 | 28 | .. raw:: html 29 | 30 |
31 | 32 | Usage 33 | ----- 34 | 35 | .. tabs:: 36 | 37 | .. tab:: :math:`\epsilon`-greedy policy 38 | 39 | .. tabs:: 40 | 41 | .. group-tab:: |_4| |pytorch| |_4| 42 | 43 | .. literalinclude:: ../../snippets/tabular_model.py 44 | :language: python 45 | :start-after: [start-epsilon-greedy-torch] 46 | :end-before: [end-epsilon-greedy-torch] 47 | 48 | .. raw:: html 49 | 50 |
51 | 52 | API (PyTorch) 53 | ------------- 54 | 55 | .. autoclass:: skrl.models.torch.tabular.TabularMixin 56 | :show-inheritance: 57 | :exclude-members: to, state_dict, load_state_dict, load, save 58 | :members: 59 | -------------------------------------------------------------------------------- /docs/source/api/multi_agents.rst: -------------------------------------------------------------------------------- 1 | Multi-agents 2 | ============ 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | IPPO 8 | MAPPO 9 | 10 | Multi-agents are autonomous entities that interact with the environment to learn and improve their behavior. Multi-agents' goal is to learn optimal policies, which are correspondences between states and actions that maximize the cumulative reward received from the environment over time. 11 | 12 | .. raw:: html 13 | 14 |

15 | 16 | .. list-table:: 17 | :header-rows: 1 18 | 19 | * - Multi-agents 20 | - .. centered:: |_4| |pytorch| |_4| 21 | - .. centered:: |_4| |jax| |_4| 22 | * - :doc:`Independent Proximal Policy Optimization ` (**IPPO**) 23 | - .. centered:: :math:`\blacksquare` 24 | - .. centered:: :math:`\blacksquare` 25 | * - :doc:`Multi-Agent Proximal Policy Optimization ` (**MAPPO**) 26 | - .. centered:: :math:`\blacksquare` 27 | - .. centered:: :math:`\blacksquare` 28 | 29 | Base class 30 | ---------- 31 | 32 | .. note:: 33 | 34 | This is the base class for all multi-agents and provides only basic functionality that is not tied to any implementation of the optimization algorithms. 35 | **It is not intended to be used directly**. 36 | 37 | .. raw:: html 38 | 39 |
40 | 41 | Basic inheritance usage 42 | ^^^^^^^^^^^^^^^^^^^^^^^ 43 | 44 | .. tabs:: 45 | 46 | .. tab:: Inheritance 47 | 48 | .. tabs:: 49 | 50 | .. group-tab:: |_4| |pytorch| |_4| 51 | 52 | .. literalinclude:: ../snippets/multi_agent.py 53 | :language: python 54 | :start-after: [start-multi-agent-base-class-torch] 55 | :end-before: [end-multi-agent-base-class-torch] 56 | 57 | .. group-tab:: |_4| |jax| |_4| 58 | 59 | .. literalinclude:: ../snippets/multi_agent.py 60 | :language: python 61 | :start-after: [start-multi-agent-base-class-jax] 62 | :end-before: [end-multi-agent-base-class-jax] 63 | 64 | .. raw:: html 65 | 66 |
67 | 68 | API (PyTorch) 69 | ^^^^^^^^^^^^^ 70 | 71 | .. autoclass:: skrl.multi_agents.torch.base.MultiAgent 72 | :undoc-members: 73 | :show-inheritance: 74 | :inherited-members: 75 | :private-members: _update, _empty_preprocessor, _get_internal_value, _as_dict 76 | :members: 77 | 78 | .. automethod:: __str__ 79 | 80 | .. raw:: html 81 | 82 |
83 | 84 | API (JAX) 85 | ^^^^^^^^^ 86 | 87 | .. autoclass:: skrl.multi_agents.jax.base.MultiAgent 88 | :undoc-members: 89 | :show-inheritance: 90 | :inherited-members: 91 | :private-members: _update, _empty_preprocessor, _get_internal_value, _as_dict 92 | :members: 93 | 94 | .. automethod:: __str__ 95 | -------------------------------------------------------------------------------- /docs/source/api/resources.rst: -------------------------------------------------------------------------------- 1 | Resources 2 | ========= 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | Noises 8 | Preprocessors 9 | Learning rate schedulers 10 | Optimizers 11 | 12 | Resources groups a variety of components that may be used to improve the agents' performance. 13 | 14 | .. raw:: html 15 | 16 |

17 | 18 | Available resources are :doc:`noises `, input :doc:`preprocessors `, learning rate :doc:`schedulers ` and :doc:`optimizers ` (this last one only for JAX). 19 | 20 | .. list-table:: 21 | :header-rows: 1 22 | 23 | * - Noises 24 | - .. centered:: |_4| |pytorch| |_4| 25 | - .. centered:: |_4| |jax| |_4| 26 | * - :doc:`Gaussian ` noise 27 | - .. centered:: :math:`\blacksquare` 28 | - .. centered:: :math:`\blacksquare` 29 | * - :doc:`Ornstein-Uhlenbeck ` noise |_2| 30 | - .. centered:: :math:`\blacksquare` 31 | - .. centered:: :math:`\blacksquare` 32 | 33 | .. list-table:: 34 | :header-rows: 1 35 | 36 | * - Preprocessors 37 | - .. centered:: |_4| |pytorch| |_4| 38 | - .. centered:: |_4| |jax| |_4| 39 | * - :doc:`Running standard scaler ` |_4| 40 | - .. centered:: :math:`\blacksquare` 41 | - .. centered:: :math:`\blacksquare` 42 | 43 | .. list-table:: 44 | :header-rows: 1 45 | 46 | * - Learning rate schedulers 47 | - .. centered:: |_4| |pytorch| |_4| 48 | - .. centered:: |_4| |jax| |_4| 49 | * - :doc:`KL Adaptive ` 50 | - .. centered:: :math:`\blacksquare` 51 | - .. centered:: :math:`\blacksquare` 52 | 53 | .. list-table:: 54 | :header-rows: 1 55 | 56 | * - Optimizers 57 | - .. centered:: |_4| |pytorch| |_4| 58 | - .. centered:: |_4| |jax| |_4| 59 | * - :doc:`Adam `\ |_5| |_5| |_5| |_5| |_5| |_5| |_3| 60 | - .. centered:: :math:`\scriptscriptstyle \texttt{PyTorch}` 61 | - .. centered:: :math:`\blacksquare` 62 | -------------------------------------------------------------------------------- /docs/source/api/resources/noises.rst: -------------------------------------------------------------------------------- 1 | Noises 2 | ====== 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | Gaussian noise 8 | Ornstein-Uhlenbeck 9 | 10 | Definition of the noises used by the agents during the exploration stage. All noises inherit from a base class that defines a uniform interface. 11 | 12 | .. raw:: html 13 | 14 |

15 | 16 | .. list-table:: 17 | :header-rows: 1 18 | 19 | * - Noises 20 | - .. centered:: |_4| |pytorch| |_4| 21 | - .. centered:: |_4| |jax| |_4| 22 | * - :doc:`Gaussian ` noise 23 | - .. centered:: :math:`\blacksquare` 24 | - .. centered:: :math:`\blacksquare` 25 | * - :doc:`Ornstein-Uhlenbeck ` noise |_2| 26 | - .. centered:: :math:`\blacksquare` 27 | - .. centered:: :math:`\blacksquare` 28 | 29 | Base class 30 | ---------- 31 | 32 | .. note:: 33 | 34 | This is the base class for all the other classes in this module. 35 | It provides the basic functionality for the other classes. 36 | **It is not intended to be used directly**. 37 | 38 | .. raw:: html 39 | 40 |
41 | 42 | Basic inheritance usage 43 | ^^^^^^^^^^^^^^^^^^^^^^^ 44 | 45 | .. tabs:: 46 | 47 | .. group-tab:: |_4| |pytorch| |_4| 48 | 49 | .. literalinclude:: ../../snippets/noises.py 50 | :language: python 51 | :start-after: [start-base-class-torch] 52 | :end-before: [end-base-class-torch] 53 | 54 | .. group-tab:: |_4| |jax| |_4| 55 | 56 | .. literalinclude:: ../../snippets/noises.py 57 | :language: python 58 | :start-after: [start-base-class-jax] 59 | :end-before: [end-base-class-jax] 60 | 61 | .. raw:: html 62 | 63 |
64 | 65 | API (PyTorch) 66 | ^^^^^^^^^^^^^ 67 | 68 | .. autoclass:: skrl.resources.noises.torch.base.Noise 69 | :undoc-members: 70 | :show-inheritance: 71 | :inherited-members: 72 | :members: 73 | 74 | .. raw:: html 75 | 76 |
77 | 78 | API (JAX) 79 | ^^^^^^^^^ 80 | 81 | .. autoclass:: skrl.resources.noises.jax.base.Noise 82 | :undoc-members: 83 | :show-inheritance: 84 | :inherited-members: 85 | :members: 86 | -------------------------------------------------------------------------------- /docs/source/api/resources/noises/gaussian.rst: -------------------------------------------------------------------------------- 1 | .. _gaussian-noise: 2 | 3 | Gaussian noise 4 | ============== 5 | 6 | Noise generated by normal distribution. 7 | 8 | .. raw:: html 9 | 10 |

11 | 12 | Usage 13 | ----- 14 | 15 | The noise usage is defined in each agent's configuration dictionary. A noise instance is set under the :literal:`"noise"` sub-key. The following examples show how to set the noise for an agent: 16 | 17 | | 18 | 19 | .. image:: ../../../_static/imgs/noise_gaussian.png 20 | :width: 75% 21 | :align: center 22 | :alt: Gaussian noise 23 | 24 | .. raw:: html 25 | 26 |

27 | 28 | .. tabs:: 29 | 30 | .. group-tab:: |_4| |pytorch| |_4| 31 | 32 | .. literalinclude:: ../../../snippets/noises.py 33 | :language: python 34 | :emphasize-lines: 1, 4 35 | :start-after: [torch-start-gaussian] 36 | :end-before: [torch-end-gaussian] 37 | 38 | .. group-tab:: |_4| |jax| |_4| 39 | 40 | .. literalinclude:: ../../../snippets/noises.py 41 | :language: python 42 | :emphasize-lines: 1, 4 43 | :start-after: [jax-start-gaussian] 44 | :end-before: [jax-end-gaussian] 45 | 46 | .. raw:: html 47 | 48 |
49 | 50 | API (PyTorch) 51 | ------------- 52 | 53 | .. autoclass:: skrl.resources.noises.torch.gaussian.GaussianNoise 54 | :undoc-members: 55 | :show-inheritance: 56 | :inherited-members: 57 | :private-members: _update 58 | :members: 59 | 60 | .. raw:: html 61 | 62 |
63 | 64 | API (JAX) 65 | --------- 66 | 67 | .. autoclass:: skrl.resources.noises.jax.gaussian.GaussianNoise 68 | :undoc-members: 69 | :show-inheritance: 70 | :inherited-members: 71 | :private-members: _update 72 | :members: 73 | -------------------------------------------------------------------------------- /docs/source/api/resources/noises/ornstein_uhlenbeck.rst: -------------------------------------------------------------------------------- 1 | .. _ornstein-uhlenbeck-noise: 2 | 3 | Ornstein-Uhlenbeck noise 4 | ======================== 5 | 6 | Noise generated by a stochastic process that is characterized by its mean-reverting behavior. 7 | 8 | .. raw:: html 9 | 10 |

11 | 12 | Usage 13 | ----- 14 | 15 | The noise usage is defined in each agent's configuration dictionary. A noise instance is set under the :literal:`"noise"` sub-key. The following examples show how to set the noise for an agent: 16 | 17 | | 18 | 19 | .. image:: ../../../_static/imgs/noise_ornstein_uhlenbeck.png 20 | :width: 75% 21 | :align: center 22 | :alt: Ornstein-Uhlenbeck noise 23 | 24 | .. raw:: html 25 | 26 |

27 | 28 | .. tabs:: 29 | 30 | .. group-tab:: |_4| |pytorch| |_4| 31 | 32 | .. literalinclude:: ../../../snippets/noises.py 33 | :language: python 34 | :emphasize-lines: 1, 4 35 | :start-after: [torch-start-ornstein-uhlenbeck] 36 | :end-before: [torch-end-ornstein-uhlenbeck] 37 | 38 | .. group-tab:: |_4| |jax| |_4| 39 | 40 | .. literalinclude:: ../../../snippets/noises.py 41 | :language: python 42 | :emphasize-lines: 1, 4 43 | :start-after: [jax-start-ornstein-uhlenbeck] 44 | :end-before: [jax-end-ornstein-uhlenbeck] 45 | 46 | .. raw:: html 47 | 48 |
49 | 50 | API (PyTorch) 51 | ------------- 52 | 53 | .. autoclass:: skrl.resources.noises.torch.ornstein_uhlenbeck.OrnsteinUhlenbeckNoise 54 | :undoc-members: 55 | :show-inheritance: 56 | :inherited-members: 57 | :private-members: _update 58 | :members: 59 | 60 | .. raw:: html 61 | 62 |
63 | 64 | API (JAX) 65 | --------- 66 | 67 | .. autoclass:: skrl.resources.noises.jax.ornstein_uhlenbeck.OrnsteinUhlenbeckNoise 68 | :undoc-members: 69 | :show-inheritance: 70 | :inherited-members: 71 | :private-members: _update 72 | :members: 73 | -------------------------------------------------------------------------------- /docs/source/api/resources/optimizers.rst: -------------------------------------------------------------------------------- 1 | Optimizers 2 | ========== 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | Adam 8 | 9 | Optimizers are algorithms that adjust the parameters of artificial neural networks to minimize the error or loss function during the training process. 10 | 11 | .. raw:: html 12 | 13 |

14 | 15 | .. list-table:: 16 | :header-rows: 1 17 | 18 | * - Optimizers 19 | - .. centered:: |_4| |pytorch| |_4| 20 | - .. centered:: |_4| |jax| |_4| 21 | * - :doc:`Adam `\ |_5| |_5| |_5| |_5| |_5| |_5| |_3| 22 | - .. centered:: :math:`\scriptscriptstyle \texttt{PyTorch}` 23 | - .. centered:: :math:`\blacksquare` 24 | -------------------------------------------------------------------------------- /docs/source/api/resources/optimizers/adam.rst: -------------------------------------------------------------------------------- 1 | Adam 2 | ==== 3 | 4 | An extension of the stochastic gradient descent algorithm that adaptively changes the learning rate for each neural network parameter. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Usage 11 | ----- 12 | 13 | .. note:: 14 | 15 | This class is the result of isolating the Optax optimizer that is mixed with the model parameters, as defined in the `Flax's TrainState `_ class. It is not intended to be used directly by the user, but by agent implementations. 16 | 17 | .. tabs:: 18 | 19 | .. group-tab:: |_4| |jax| |_4| 20 | 21 | .. code-block:: python 22 | :emphasize-lines: 2, 5, 8 23 | 24 | # import the optimizer class 25 | from skrl.resources.optimizers.jax import Adam 26 | 27 | # instantiate the optimizer 28 | optimizer = Adam(model=model, lr=1e-3) 29 | 30 | # step the optimizer 31 | optimizer = optimizer.step(grad, model) 32 | 33 | .. raw:: html 34 | 35 |
36 | 37 | API (JAX) 38 | --------- 39 | 40 | .. autoclass:: skrl.resources.optimizers.jax.adam.Adam 41 | :show-inheritance: 42 | :inherited-members: 43 | :members: 44 | 45 | .. automethod:: __new__ 46 | -------------------------------------------------------------------------------- /docs/source/api/resources/preprocessors.rst: -------------------------------------------------------------------------------- 1 | Preprocessors 2 | ============= 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | Running standard scaler 8 | 9 | Preprocessors are functions used to transform or encode raw input data into a form more suitable for the learning algorithm. 10 | 11 | .. raw:: html 12 | 13 |

14 | 15 | .. list-table:: 16 | :header-rows: 1 17 | 18 | * - Preprocessors 19 | - .. centered:: |_4| |pytorch| |_4| 20 | - .. centered:: |_4| |jax| |_4| 21 | * - :doc:`Running standard scaler ` |_4| 22 | - .. centered:: :math:`\blacksquare` 23 | - .. centered:: :math:`\blacksquare` 24 | -------------------------------------------------------------------------------- /docs/source/api/resources/schedulers/kl_adaptive.rst: -------------------------------------------------------------------------------- 1 | KL Adaptive 2 | =========== 3 | 4 | Adjust the learning rate according to the value of the Kullback-Leibler (KL) divergence. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Algorithm 11 | --------- 12 | 13 | .. raw:: html 14 | 15 |
16 | 17 | Algorithm implementation 18 | ^^^^^^^^^^^^^^^^^^^^^^^^ 19 | 20 | The learning rate (:math:`\eta`) at each step is modified as follows: 21 | 22 | | **IF** :math:`\; KL >` :guilabel:`kl_factor` :guilabel:`kl_threshold` **THEN** 23 | | :math:`\eta_{t + 1} = \max(\eta_t \,/` :guilabel:`lr_factor` :math:`,` :guilabel:`min_lr` :math:`)` 24 | | **IF** :math:`\; KL <` :guilabel:`kl_threshold` :math:`/` :guilabel:`kl_factor` **THEN** 25 | | :math:`\eta_{t + 1} = \min(` :guilabel:`lr_factor` :math:`\eta_t,` :guilabel:`max_lr` :math:`)` 26 | 27 | .. raw:: html 28 | 29 |
30 | 31 | Usage 32 | ----- 33 | 34 | The learning rate scheduler usage is defined in each agent's configuration dictionary. The scheduler class is set under the :literal:`"learning_rate_scheduler"` key and its arguments are set under the :literal:`"learning_rate_scheduler_kwargs"` key as a keyword argument dictionary, without specifying the optimizer (first argument). The following examples show how to set the scheduler for an agent: 35 | 36 | .. tabs:: 37 | 38 | .. group-tab:: |_4| |pytorch| |_4| 39 | 40 | .. code-block:: python 41 | :emphasize-lines: 2, 5-6 42 | 43 | # import the scheduler class 44 | from skrl.resources.schedulers.torch import KLAdaptiveLR 45 | 46 | cfg = DEFAULT_CONFIG.copy() 47 | cfg["learning_rate_scheduler"] = KLAdaptiveLR 48 | cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01} 49 | 50 | .. group-tab:: |_4| |jax| |_4| 51 | 52 | .. code-block:: python 53 | :emphasize-lines: 2, 5-6 54 | 55 | # import the scheduler class 56 | from skrl.resources.schedulers.jax import KLAdaptiveLR # or kl_adaptive (Optax style) 57 | 58 | cfg = DEFAULT_CONFIG.copy() 59 | cfg["learning_rate_scheduler"] = KLAdaptiveLR 60 | cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01} 61 | 62 | .. raw:: html 63 | 64 |
65 | 66 | API (PyTorch) 67 | ------------- 68 | 69 | .. autoclass:: skrl.resources.schedulers.torch.kl_adaptive.KLAdaptiveLR 70 | :show-inheritance: 71 | :inherited-members: 72 | :members: 73 | 74 | .. raw:: html 75 | 76 |
77 | 78 | API (JAX) 79 | --------- 80 | 81 | .. autofunction:: skrl.resources.schedulers.jax.kl_adaptive.KLAdaptiveLR 82 | -------------------------------------------------------------------------------- /docs/source/api/trainers.rst: -------------------------------------------------------------------------------- 1 | Trainers 2 | ======== 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | Sequential 8 | Parallel 9 | Step 10 | Manual training 11 | 12 | Trainers are responsible for orchestrating and managing the training/evaluation of agents and their interactions with the environment. 13 | 14 | .. raw:: html 15 | 16 |

17 | 18 | .. list-table:: 19 | :header-rows: 1 20 | 21 | * - Trainers 22 | - .. centered:: |_4| |pytorch| |_4| 23 | - .. centered:: |_4| |jax| |_4| 24 | * - :doc:`Sequential trainer ` 25 | - .. centered:: :math:`\blacksquare` 26 | - .. centered:: :math:`\blacksquare` 27 | * - :doc:`Parallel trainer ` 28 | - .. centered:: :math:`\blacksquare` 29 | - .. centered:: :math:`\square` 30 | * - :doc:`Step trainer ` 31 | - .. centered:: :math:`\blacksquare` 32 | - .. centered:: :math:`\blacksquare` 33 | * - :doc:`Manual training ` 34 | - .. centered:: :math:`\blacksquare` 35 | - .. centered:: :math:`\blacksquare` 36 | 37 | Base class 38 | ---------- 39 | 40 | .. note:: 41 | 42 | This is the base class for all the other classes in this module. 43 | It provides the basic functionality for the other classes. 44 | **It is not intended to be used directly**. 45 | 46 | .. raw:: html 47 | 48 |
49 | 50 | Basic inheritance usage 51 | ^^^^^^^^^^^^^^^^^^^^^^^ 52 | 53 | .. tabs:: 54 | 55 | .. group-tab:: |_4| |pytorch| |_4| 56 | 57 | .. literalinclude:: ../snippets/trainer.py 58 | :language: python 59 | :start-after: [pytorch-start-base] 60 | :end-before: [pytorch-end-base] 61 | 62 | .. group-tab:: |_4| |jax| |_4| 63 | 64 | .. literalinclude:: ../snippets/trainer.py 65 | :language: python 66 | :start-after: [jax-start-base] 67 | :end-before: [jax-end-base] 68 | 69 | .. raw:: html 70 | 71 |
72 | 73 | API (PyTorch) 74 | ^^^^^^^^^^^^^ 75 | 76 | .. autoclass:: skrl.trainers.torch.base.Trainer 77 | :undoc-members: 78 | :show-inheritance: 79 | :inherited-members: 80 | :private-members: _setup_agents 81 | :members: 82 | 83 | .. automethod:: __str__ 84 | 85 | .. raw:: html 86 | 87 |
88 | 89 | API (JAX) 90 | ^^^^^^^^^ 91 | 92 | .. autoclass:: skrl.trainers.jax.base.Trainer 93 | :undoc-members: 94 | :show-inheritance: 95 | :inherited-members: 96 | :private-members: _setup_agents 97 | :members: 98 | 99 | .. automethod:: __str__ 100 | -------------------------------------------------------------------------------- /docs/source/api/trainers/manual.rst: -------------------------------------------------------------------------------- 1 | Manual training 2 | =============== 3 | 4 | Train agents by manually controlling the training/evaluation loop. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Concept 11 | ------- 12 | 13 | .. image:: ../../_static/imgs/manual_trainer-light.svg 14 | :width: 100% 15 | :align: center 16 | :class: only-light 17 | :alt: Manual trainer 18 | 19 | .. image:: ../../_static/imgs/manual_trainer-dark.svg 20 | :width: 100% 21 | :align: center 22 | :class: only-dark 23 | :alt: Manual trainer 24 | 25 | .. raw:: html 26 | 27 |
28 | 29 | Usage 30 | ----- 31 | 32 | .. tabs:: 33 | 34 | .. group-tab:: |_4| |pytorch| |_4| 35 | 36 | .. tabs:: 37 | 38 | .. group-tab:: Training 39 | 40 | .. literalinclude:: ../../snippets/trainer.py 41 | :language: python 42 | :start-after: [pytorch-start-manual-training] 43 | :end-before: [pytorch-end-manual-training] 44 | 45 | .. group-tab:: Evaluation 46 | 47 | .. literalinclude:: ../../snippets/trainer.py 48 | :language: python 49 | :start-after: [pytorch-start-manual-evaluation] 50 | :end-before: [pytorch-end-manual-evaluation] 51 | 52 | .. group-tab:: |_4| |jax| |_4| 53 | 54 | .. tabs:: 55 | 56 | .. group-tab:: Training 57 | 58 | .. literalinclude:: ../../snippets/trainer.py 59 | :language: python 60 | :start-after: [jax-start-manual-training] 61 | :end-before: [jax-end-manual-training] 62 | 63 | .. group-tab:: Evaluation 64 | 65 | .. literalinclude:: ../../snippets/trainer.py 66 | :language: python 67 | :start-after: [jax-start-manual-evaluation] 68 | :end-before: [jax-end-manual-evaluation] 69 | 70 | .. raw:: html 71 | 72 |
73 | -------------------------------------------------------------------------------- /docs/source/api/trainers/parallel.rst: -------------------------------------------------------------------------------- 1 | Parallel trainer 2 | ================ 3 | 4 | Train agents in parallel using multiple processes. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Concept 11 | ------- 12 | 13 | .. image:: ../../_static/imgs/parallel_trainer-light.svg 14 | :width: 100% 15 | :align: center 16 | :class: only-light 17 | :alt: Parallel trainer 18 | 19 | .. image:: ../../_static/imgs/parallel_trainer-dark.svg 20 | :width: 100% 21 | :align: center 22 | :class: only-dark 23 | :alt: Parallel trainer 24 | 25 | .. raw:: html 26 | 27 |
28 | 29 | Usage 30 | ----- 31 | 32 | .. note:: 33 | 34 | Each process adds a GPU memory overhead (~1GB, although it can be much higher) due to PyTorch's CUDA kernels. See PyTorch `Issue #12873 `_ for more details. 35 | 36 | .. note:: 37 | 38 | At the moment, only simultaneous training and evaluation of agents with local memory (no memory sharing) is implemented. 39 | 40 | .. tabs:: 41 | 42 | .. group-tab:: |_4| |pytorch| |_4| 43 | 44 | .. literalinclude:: ../../snippets/trainer.py 45 | :language: python 46 | :start-after: [pytorch-start-parallel] 47 | :end-before: [pytorch-end-parallel] 48 | 49 | .. raw:: html 50 | 51 |
52 | 53 | Configuration 54 | ------------- 55 | 56 | .. literalinclude:: ../../../../skrl/trainers/torch/parallel.py 57 | :language: python 58 | :start-after: [start-config-dict-torch] 59 | :end-before: [end-config-dict-torch] 60 | 61 | .. raw:: html 62 | 63 |
64 | 65 | API (PyTorch) 66 | ------------- 67 | 68 | .. autoclass:: skrl.trainers.torch.parallel.PARALLEL_TRAINER_DEFAULT_CONFIG 69 | 70 | .. autoclass:: skrl.trainers.torch.parallel.ParallelTrainer 71 | :undoc-members: 72 | :show-inheritance: 73 | :inherited-members: 74 | :members: 75 | -------------------------------------------------------------------------------- /docs/source/api/trainers/sequential.rst: -------------------------------------------------------------------------------- 1 | Sequential trainer 2 | ================== 3 | 4 | Train agents sequentially (i.e., one after the other in each interaction with the environment). 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Concept 11 | ------- 12 | 13 | .. image:: ../../_static/imgs/sequential_trainer-light.svg 14 | :width: 100% 15 | :align: center 16 | :class: only-light 17 | :alt: Sequential trainer 18 | 19 | .. image:: ../../_static/imgs/sequential_trainer-dark.svg 20 | :width: 100% 21 | :align: center 22 | :class: only-dark 23 | :alt: Sequential trainer 24 | 25 | .. raw:: html 26 | 27 |
28 | 29 | Usage 30 | ----- 31 | 32 | .. tabs:: 33 | 34 | .. group-tab:: |_4| |pytorch| |_4| 35 | 36 | .. literalinclude:: ../../snippets/trainer.py 37 | :language: python 38 | :start-after: [pytorch-start-sequential] 39 | :end-before: [pytorch-end-sequential] 40 | 41 | .. group-tab:: |_4| |jax| |_4| 42 | 43 | .. literalinclude:: ../../snippets/trainer.py 44 | :language: python 45 | :start-after: [jax-start-sequential] 46 | :end-before: [jax-end-sequential] 47 | 48 | .. raw:: html 49 | 50 |
51 | 52 | Configuration 53 | ------------- 54 | 55 | .. literalinclude:: ../../../../skrl/trainers/torch/sequential.py 56 | :language: python 57 | :start-after: [start-config-dict-torch] 58 | :end-before: [end-config-dict-torch] 59 | 60 | .. raw:: html 61 | 62 |
63 | 64 | API (PyTorch) 65 | ------------- 66 | 67 | .. autoclass:: skrl.trainers.torch.sequential.SEQUENTIAL_TRAINER_DEFAULT_CONFIG 68 | 69 | .. autoclass:: skrl.trainers.torch.sequential.SequentialTrainer 70 | :undoc-members: 71 | :show-inheritance: 72 | :inherited-members: 73 | :members: 74 | 75 | .. raw:: html 76 | 77 |
78 | 79 | API (JAX) 80 | --------- 81 | 82 | .. autoclass:: skrl.trainers.jax.sequential.SEQUENTIAL_TRAINER_DEFAULT_CONFIG 83 | 84 | .. autoclass:: skrl.trainers.jax.sequential.SequentialTrainer 85 | :undoc-members: 86 | :show-inheritance: 87 | :inherited-members: 88 | :members: 89 | -------------------------------------------------------------------------------- /docs/source/api/trainers/step.rst: -------------------------------------------------------------------------------- 1 | Step trainer 2 | ============ 3 | 4 | Train agents controlling the training/evaluation loop step-by-step. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Concept 11 | ------- 12 | 13 | .. image:: ../../_static/imgs/manual_trainer-light.svg 14 | :width: 100% 15 | :align: center 16 | :class: only-light 17 | :alt: Step-by-step trainer 18 | 19 | .. image:: ../../_static/imgs/manual_trainer-dark.svg 20 | :width: 100% 21 | :align: center 22 | :class: only-dark 23 | :alt: Step-by-step trainer 24 | 25 | .. raw:: html 26 | 27 |
28 | 29 | Usage 30 | ----- 31 | 32 | .. tabs:: 33 | 34 | .. group-tab:: |_4| |pytorch| |_4| 35 | 36 | .. literalinclude:: ../../snippets/trainer.py 37 | :language: python 38 | :start-after: [pytorch-start-step] 39 | :end-before: [pytorch-end-step] 40 | 41 | .. group-tab:: |_4| |jax| |_4| 42 | 43 | .. literalinclude:: ../../snippets/trainer.py 44 | :language: python 45 | :start-after: [jax-start-step] 46 | :end-before: [jax-end-step] 47 | 48 | .. raw:: html 49 | 50 |
51 | 52 | Configuration 53 | ------------- 54 | 55 | .. literalinclude:: ../../../../skrl/trainers/torch/step.py 56 | :language: python 57 | :start-after: [start-config-dict-torch] 58 | :end-before: [end-config-dict-torch] 59 | 60 | .. raw:: html 61 | 62 |
63 | 64 | API (PyTorch) 65 | ------------- 66 | 67 | .. autoclass:: skrl.trainers.torch.step.STEP_TRAINER_DEFAULT_CONFIG 68 | 69 | .. autoclass:: skrl.trainers.torch.step.StepTrainer 70 | :undoc-members: 71 | :show-inheritance: 72 | :inherited-members: 73 | :members: 74 | 75 | .. raw:: html 76 | 77 |
78 | 79 | API (JAX) 80 | --------- 81 | 82 | .. autoclass:: skrl.trainers.jax.step.STEP_TRAINER_DEFAULT_CONFIG 83 | 84 | .. autoclass:: skrl.trainers.jax.step.StepTrainer 85 | :undoc-members: 86 | :show-inheritance: 87 | :inherited-members: 88 | :members: 89 | -------------------------------------------------------------------------------- /docs/source/api/utils.rst: -------------------------------------------------------------------------------- 1 | Utils and configurations 2 | ======================== 3 | 4 | .. toctree:: 5 | :hidden: 6 | 7 | ML frameworks configuration 8 | Random seed 9 | Spaces 10 | Model instantiators 11 | Runner 12 | Distributed runs 13 | Memory and Tensorboard file post-processing 14 | Hugging Face integration 15 | Isaac Gym utils 16 | Omniverse Isaac Gym utils 17 | 18 | A set of utilities and configurations for managing an RL setup is provided as part of the library. 19 | 20 | .. raw:: html 21 | 22 |

23 | 24 | .. list-table:: 25 | :header-rows: 1 26 | 27 | * - Configurations 28 | - .. centered:: |_4| |pytorch| |_4| 29 | - .. centered:: |_4| |jax| |_4| 30 | * - :doc:`ML frameworks ` configuration |_5| |_5| |_5| |_5| |_5| |_2| 31 | - .. centered:: :math:`\blacksquare` 32 | - .. centered:: :math:`\blacksquare` 33 | 34 | .. list-table:: 35 | :header-rows: 1 36 | 37 | * - Utils 38 | - .. centered:: |_4| |pytorch| |_4| 39 | - .. centered:: |_4| |jax| |_4| 40 | * - :doc:`Random seed ` 41 | - .. centered:: :math:`\blacksquare` 42 | - .. centered:: :math:`\blacksquare` 43 | * - :doc:`Spaces ` 44 | - .. centered:: :math:`\blacksquare` 45 | - .. centered:: :math:`\blacksquare` 46 | * - :doc:`Model instantiators ` 47 | - .. centered:: :math:`\blacksquare` 48 | - .. centered:: :math:`\blacksquare` 49 | * - :doc:`Runner ` 50 | - .. centered:: :math:`\blacksquare` 51 | - .. centered:: :math:`\blacksquare` 52 | * - :doc:`Distributed runs ` 53 | - .. centered:: :math:`\blacksquare` 54 | - .. centered:: :math:`\blacksquare` 55 | * - Memory and Tensorboard :doc:`file post-processing ` 56 | - .. centered:: :math:`\blacksquare` 57 | - .. centered:: :math:`\blacksquare` 58 | * - :doc:`Hugging Face integration ` 59 | - .. centered:: :math:`\blacksquare` 60 | - .. centered:: :math:`\blacksquare` 61 | * - :doc:`Isaac Gym utils ` 62 | - .. centered:: :math:`\blacksquare` 63 | - .. centered:: :math:`\blacksquare` 64 | * - :doc:`Omniverse Isaac Gym utils ` 65 | - .. centered:: :math:`\blacksquare` 66 | - .. centered:: :math:`\blacksquare` 67 | -------------------------------------------------------------------------------- /docs/source/api/utils/distributed.rst: -------------------------------------------------------------------------------- 1 | Distributed 2 | =========== 3 | 4 | Utilities to start multiple processes from a single program invocation in distributed learning 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | PyTorch 11 | ------- 12 | 13 | PyTorch provides a Python module/console script to launch distributed runs. Visit PyTorch's `torchrun `_ documentation for more details. 14 | 15 | The following environment variables available in all processes can be accessed through the library: 16 | 17 | * ``LOCAL_RANK`` (accessible via :data:`skrl.config.torch.local_rank`): The local rank. 18 | * ``RANK`` (accessible via :data:`skrl.config.torch.rank`): The global rank. 19 | * ``WORLD_SIZE`` (accessible via :data:`skrl.config.torch.world_size`): The world size (total number of workers in the job). 20 | 21 | JAX 22 | --- 23 | 24 | According to the JAX documentation for `multi-host and multi-process environments `_, JAX doesn't automatically start multiple processes from a single program invocation. 25 | 26 | Therefore, in order to make distributed learning simpler, this library provides a module (based on the PyTorch ``torch.distributed.run`` module) for launching multi-host and multi-process learning directly from the command line. 27 | 28 | This module launches, in multiple processes, the same JAX Python program (Single Program, Multiple Data (SPMD) parallel computation technique) that defines the following environment variables for each process: 29 | 30 | * ``JAX_LOCAL_RANK`` (accessible via :data:`skrl.config.jax.local_rank`): The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node). 31 | * ``JAX_RANK`` (accessible via :data:`skrl.config.jax.rank`): The rank (ID number) of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes). 32 | * ``JAX_WORLD_SIZE`` (accessible via :data:`skrl.config.jax.world_size`): The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes). 33 | * ``JAX_COORDINATOR_ADDR`` (accessible via :data:`skrl.config.jax.coordinator_address`): IP address where process 0 will start a JAX coordinator service. 
34 | * ``JAX_COORDINATOR_PORT`` (accessible via :data:`skrl.config.jax.coordinator_address`): Port where process 0 will start a JAX coordinator service. 35 | 36 | .. raw:: html 37 | 38 |
39 | 40 | Usage 41 | ^^^^^ 42 | 43 | .. code-block:: bash 44 | 45 | $ python -m skrl.utils.distributed.jax --help 46 | 47 | .. literalinclude:: ../../snippets/utils_distributed.txt 48 | :language: text 49 | :start-after: [start-distributed-launcher-jax] 50 | :end-before: [end-distributed-launcher-jax] 51 | 52 | .. raw:: html 53 | 54 |
55 | 56 | API 57 | ^^^ 58 | 59 | .. autofunction:: skrl.utils.distributed.jax.launcher.launch 60 | -------------------------------------------------------------------------------- /docs/source/api/utils/huggingface.rst: -------------------------------------------------------------------------------- 1 | Hugging Face integration 2 | ======================== 3 | 4 | The Hugging Face (HF) Hub is a platform for building, training, and deploying ML models, as well as accessing a variety of datasets and metrics for further analysis and validation. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Integration 11 | ----------- 12 | 13 | .. raw:: html 14 | 15 |
16 | 17 | Download model from HF Hub 18 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 19 | 20 | Several skrl-trained models (agent checkpoints) for different environments/tasks of Gym/Gymnasium, Isaac Gym, Omniverse Isaac Gym, etc. are available in the Hugging Face Hub. 21 | 22 | These models can be used as comparison benchmarks, for collecting environment transitions in memory (e.g., for offline reinforcement learning) or for pre-initialization of agents for performing similar tasks, among others. 23 | 24 | Visit the `skrl organization on the Hugging Face Hub `_ to access publicly available models! 25 | 26 | .. raw:: html 27 | 28 |
29 | 30 | API 31 | --- 32 | 33 | .. autofunction:: skrl.utils.huggingface.download_model_from_huggingface 34 | -------------------------------------------------------------------------------- /docs/source/api/utils/omniverse_isaacgym_utils.rst: -------------------------------------------------------------------------------- 1 | Omniverse Isaac Gym utils 2 | ========================= 3 | 4 | Utilities for ease of programming of Omniverse Isaac Gym environments. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Control of robotic manipulators 11 | ------------------------------- 12 | 13 | .. raw:: html 14 | 15 |
16 | 17 | Differential inverse kinematics 18 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 19 | 20 | This implementation attempts to unify under a single and reusable function the whole set of procedures used to compute the inverse kinematics of a robotic manipulator, originally shown in the Isaac Lab (previously Orbit) framework's task space controllers section, but this time for Omniverse Isaac Gym. 21 | 22 | :math:`\Delta\theta =` :guilabel:`scale` :math:`J^\dagger \, \vec{e}` 23 | 24 | where 25 | 26 | | :math:`\qquad \Delta\theta \;` is the change in joint angles 27 | | :math:`\qquad \vec{e} \;` is the Cartesian pose error (position and orientation) 28 | | :math:`\qquad J^\dagger \;` is the pseudoinverse of the Jacobian estimated as follows: 29 | 30 | The pseudoinverse of the Jacobian (:math:`J^\dagger`) is estimated as follows: 31 | 32 | * Transpose: :math:`\; J^\dagger = J^T` 33 | * Pseudoinverse: :math:`\; J^\dagger = J^T(JJ^T)^{-1}` 34 | * Damped least-squares: :math:`\; J^\dagger = J^T(JJ^T \, +` :guilabel:`damping`:math:`{}^2 I)^{-1}` 35 | * Singular-value decomposition: See `buss2004introduction `_ (section 6) 36 | 37 | .. raw:: html 38 | 39 |
40 | 41 | API 42 | ^^^ 43 | 44 | .. autofunction:: skrl.utils.omniverse_isaacgym_utils.ik 45 | 46 | .. raw:: html 47 | 48 |
49 | 50 | OmniIsaacGymEnvs-like environment instance 51 | ------------------------------------------ 52 | 53 | Instantiate a VecEnvBase-based object compatible with OmniIsaacGymEnvs for use outside of the OmniIsaacGymEnvs implementation. 54 | 55 | .. raw:: html 56 | 57 |
58 | 59 | API 60 | ^^^ 61 | 62 | .. autofunction:: skrl.utils.omniverse_isaacgym_utils.get_env_instance 63 | -------------------------------------------------------------------------------- /docs/source/api/utils/postprocessing.rst: -------------------------------------------------------------------------------- 1 | File post-processing 2 | ==================== 3 | 4 | Utilities for processing files generated during training/evaluation. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Exported memories 11 | ----------------- 12 | 13 | This library provides an implementation for quickly loading exported memory files to inspect their contents in future post-processing steps. See the section :ref:`Library utilities (skrl.utils module) ` for a real use case 14 | 15 | .. raw:: html 16 | 17 |
18 | 19 | Usage 20 | ^^^^^ 21 | 22 | .. tabs:: 23 | 24 | .. tab:: PyTorch (.pt) 25 | 26 | .. literalinclude:: ../../snippets/utils_postprocessing.py 27 | :language: python 28 | :emphasize-lines: 1, 5-6 29 | :start-after: [start-memory_file_iterator-torch] 30 | :end-before: [end-memory_file_iterator-torch] 31 | 32 | .. tab:: NumPy (.npz) 33 | 34 | .. literalinclude:: ../../snippets/utils_postprocessing.py 35 | :language: python 36 | :emphasize-lines: 1, 5-6 37 | :start-after: [start-memory_file_iterator-numpy] 38 | :end-before: [end-memory_file_iterator-numpy] 39 | 40 | .. tab:: Comma-separated values (.csv) 41 | 42 | .. literalinclude:: ../../snippets/utils_postprocessing.py 43 | :language: python 44 | :emphasize-lines: 1, 5-6 45 | :start-after: [start-memory_file_iterator-csv] 46 | :end-before: [end-memory_file_iterator-csv] 47 | 48 | .. raw:: html 49 | 50 |
51 | 52 | API 53 | ^^^ 54 | 55 | .. autoclass:: skrl.utils.postprocessing.MemoryFileIterator 56 | :undoc-members: 57 | :show-inheritance: 58 | :inherited-members: 59 | :private-members: _format_numpy, _format_torch, _format_csv 60 | :members: 61 | 62 | .. automethod:: __iter__ 63 | .. automethod:: __next__ 64 | 65 | .. raw:: html 66 | 67 |
68 | 69 | Tensorboard files 70 | ----------------- 71 | 72 | This library provides an implementation for quickly loading Tensorboard files to inspect their contents in future post-processing steps. See the section :ref:`Library utilities (skrl.utils module) ` for a real use case 73 | 74 | .. raw:: html 75 | 76 |
77 | 78 | Requirements 79 | ^^^^^^^^^^^^ 80 | 81 | This utility requires the `TensorFlow `_ package to be installed to load and parse Tensorboard files: 82 | 83 | .. code-block:: bash 84 | 85 | pip install tensorflow 86 | 87 | .. raw:: html 88 | 89 |
90 | 91 | Usage 92 | ^^^^^ 93 | 94 | .. tabs:: 95 | 96 | .. tab:: Tensorboard (events.out.tfevents.*) 97 | 98 | .. literalinclude:: ../../snippets/utils_postprocessing.py 99 | :language: python 100 | :emphasize-lines: 1, 5-7 101 | :start-after: [start-tensorboard_file_iterator-list] 102 | :end-before: [end-tensorboard_file_iterator-list] 103 | 104 | .. raw:: html 105 | 106 |
107 | 108 | API 109 | ^^^ 110 | 111 | .. autoclass:: skrl.utils.postprocessing.TensorboardFileIterator 112 | :undoc-members: 113 | :show-inheritance: 114 | :inherited-members: 115 | :members: 116 | 117 | .. automethod:: __iter__ 118 | .. automethod:: __next__ 119 | -------------------------------------------------------------------------------- /docs/source/api/utils/runner.rst: -------------------------------------------------------------------------------- 1 | Runner 2 | ====== 3 | 4 | Utility that configures and instantiates skrl's components to run training/evaluation workflows in a few lines of code. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Usage 11 | ----- 12 | 13 | .. hint:: 14 | 15 | The ``Runner`` class encapsulates, and greatly simplifies, the definitions and instantiations needed to execute RL tasks. 16 | However, such simplification hides details and makes the code (models, agents, etc.) harder to modify and read. 17 | 18 | For more control and readability over the RL system setup refer to the :doc:`Examples <../../intro/examples>` section's training scripts (**recommended!**). 19 | 20 | .. raw:: html 21 | 22 |
23 | 24 | .. tabs:: 25 | 26 | .. group-tab:: |_4| |pytorch| |_4| 27 | 28 | .. tabs:: 29 | 30 | .. group-tab:: Python code 31 | 32 | .. literalinclude:: ../../snippets/runner.txt 33 | :language: python 34 | :emphasize-lines: 1, 10, 13, 16 35 | :start-after: [start-runner-train-torch] 36 | :end-before: [end-runner-train-torch] 37 | 38 | .. group-tab:: Example .yaml file (PPO) 39 | 40 | .. literalinclude:: ../../snippets/runner.txt 41 | :language: yaml 42 | :start-after: [start-cfg-yaml] 43 | :end-before: [end-cfg-yaml] 44 | 45 | .. group-tab:: |_4| |jax| |_4| 46 | 47 | .. tabs:: 48 | 49 | .. group-tab:: Python code 50 | 51 | .. literalinclude:: ../../snippets/runner.txt 52 | :language: python 53 | :emphasize-lines: 1, 10, 13, 16 54 | :start-after: [start-runner-train-jax] 55 | :end-before: [end-runner-train-jax] 56 | 57 | .. group-tab:: Example .yaml file (PPO) 58 | 59 | .. literalinclude:: ../../snippets/runner.txt 60 | :language: yaml 61 | :start-after: [start-cfg-yaml] 62 | :end-before: [end-cfg-yaml] 63 | 64 | API (PyTorch) 65 | ------------- 66 | 67 | .. autoclass:: skrl.utils.runner.torch.Runner 68 | :show-inheritance: 69 | :members: 70 | 71 | .. raw:: html 72 | 73 |
74 | 75 | API (JAX) 76 | --------- 77 | 78 | .. autoclass:: skrl.utils.runner.jax.Runner 79 | :show-inheritance: 80 | :members: 81 | -------------------------------------------------------------------------------- /docs/source/api/utils/seed.rst: -------------------------------------------------------------------------------- 1 | Random seed 2 | =========== 3 | 4 | Utilities for seeding the random number generators. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Seed the random number generators 11 | --------------------------------- 12 | 13 | .. raw:: html 14 | 15 |
16 | 17 | API 18 | ^^^ 19 | 20 | .. autofunction:: skrl.utils.set_seed 21 | -------------------------------------------------------------------------------- /docs/source/api/utils/spaces.rst: -------------------------------------------------------------------------------- 1 | Spaces 2 | ====== 3 | 4 | Utilities to operate on Gymnasium `spaces <https://gymnasium.farama.org/api/spaces/>`_. 5 | 6 | .. raw:: html 7 | 8 |

9 | 10 | Overview 11 | -------- 12 | 13 | The utilities described in this section support the following Gymnasium spaces: 14 | 15 | .. list-table:: 16 | :header-rows: 1 17 | 18 | * - Type 19 | - Supported spaces 20 | * - Fundamental 21 | - :py:class:`~gymnasium.spaces.Box`, :py:class:`~gymnasium.spaces.Discrete`, and :py:class:`~gymnasium.spaces.MultiDiscrete` 22 | * - Composite 23 | - :py:class:`~gymnasium.spaces.Dict` and :py:class:`~gymnasium.spaces.Tuple` 24 | 25 | The following table provides a snapshot of the space sample conversion functions: 26 | 27 | .. list-table:: 28 | :header-rows: 1 29 | 30 | * - Input 31 | - Function 32 | - Output 33 | * - Space (NumPy / int) 34 | - :py:func:`~skrl.utils.spaces.torch.tensorize_space` 35 | - Space (PyTorch / JAX) 36 | * - Space (PyTorch / JAX) 37 | - :py:func:`~skrl.utils.spaces.torch.untensorize_space` 38 | - Space (NumPy / int) 39 | * - Space (PyTorch / JAX) 40 | - :py:func:`~skrl.utils.spaces.torch.flatten_tensorized_space` 41 | - PyTorch tensor / JAX array 42 | * - PyTorch tensor / JAX array 43 | - :py:func:`~skrl.utils.spaces.torch.unflatten_tensorized_space` 44 | - Space (PyTorch / JAX) 45 | 46 | .. raw:: html 47 | 48 |
49 | 50 | API (PyTorch) 51 | ------------- 52 | 53 | .. autofunction:: skrl.utils.spaces.torch.compute_space_size 54 | 55 | .. autofunction:: skrl.utils.spaces.torch.convert_gym_space 56 | 57 | .. autofunction:: skrl.utils.spaces.torch.flatten_tensorized_space 58 | 59 | .. autofunction:: skrl.utils.spaces.torch.sample_space 60 | 61 | .. autofunction:: skrl.utils.spaces.torch.tensorize_space 62 | 63 | .. autofunction:: skrl.utils.spaces.torch.unflatten_tensorized_space 64 | 65 | .. autofunction:: skrl.utils.spaces.torch.untensorize_space 66 | 67 | .. raw:: html 68 | 69 |
70 | 71 | API (JAX) 72 | --------- 73 | 74 | .. autofunction:: skrl.utils.spaces.jax.compute_space_size 75 | 76 | .. autofunction:: skrl.utils.spaces.jax.convert_gym_space 77 | 78 | .. autofunction:: skrl.utils.spaces.jax.flatten_tensorized_space 79 | 80 | .. autofunction:: skrl.utils.spaces.jax.sample_space 81 | 82 | .. autofunction:: skrl.utils.spaces.jax.tensorize_space 83 | 84 | .. autofunction:: skrl.utils.spaces.jax.unflatten_tensorized_space 85 | 86 | .. autofunction:: skrl.utils.spaces.jax.untensorize_space 87 | -------------------------------------------------------------------------------- /docs/source/examples/gym/jax_gym_cartpole_cem.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | # import the skrl components to build the RL system 8 | from skrl import config 9 | from skrl.agents.jax.cem import CEM, CEM_DEFAULT_CONFIG 10 | from skrl.envs.wrappers.jax import wrap_env 11 | from skrl.memories.jax import RandomMemory 12 | from skrl.models.jax import CategoricalMixin, Model 13 | from skrl.trainers.jax import SequentialTrainer 14 | from skrl.utils import set_seed 15 | 16 | 17 | config.jax.backend = "numpy" # or "jax" 18 | 19 | 20 | # seed for reproducibility 21 | set_seed() # e.g. `set_seed(42)` for fixed seed 22 | 23 | 24 | # define model (categorical model) using mixin 25 | class Policy(CategoricalMixin, Model): 26 | def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs): 27 | Model.__init__(self, observation_space, action_space, device, **kwargs) 28 | CategoricalMixin.__init__(self, unnormalized_log_prob) 29 | 30 | @nn.compact 31 | def __call__(self, inputs, role): 32 | x = nn.relu(nn.Dense(64)(inputs["states"])) 33 | x = nn.relu(nn.Dense(64)(x)) 34 | x = nn.Dense(self.num_actions)(x) 35 | return x, {} 36 | 37 | 38 | # load and wrap the gym environment. 
39 | # note: the environment version may change depending on the gym version 40 | try: 41 | env = gym.make("CartPole-v0") 42 | except gym.error.DeprecatedEnv as e: 43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("CartPole-v")][0] 44 | print("CartPole-v0 not found. Trying {}".format(env_id)) 45 | env = gym.make(env_id) 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate a memory as experience replay 52 | memory = RandomMemory(memory_size=1000, num_envs=env.num_envs, device=device, replacement=False) 53 | 54 | 55 | # instantiate the agent's model (function approximator). 56 | # CEM requires 1 model, visit its documentation for more details 57 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#models 58 | models = {} 59 | models["policy"] = Policy(env.observation_space, env.action_space, device) 60 | 61 | # instantiate models' state dict 62 | for role, model in models.items(): 63 | model.init_state_dict(role) 64 | 65 | # initialize models' parameters (weights and biases) 66 | for model in models.values(): 67 | model.init_parameters(method_name="normal", stddev=0.1) 68 | 69 | 70 | # configure and instantiate the agent (visit its documentation to see all the options) 71 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#configuration-and-hyperparameters 72 | cfg = CEM_DEFAULT_CONFIG.copy() 73 | cfg["rollouts"] = 1000 74 | cfg["learning_starts"] = 100 75 | # logging to TensorBoard and write checkpoints (in timesteps) 76 | cfg["experiment"]["write_interval"] = 1000 77 | cfg["experiment"]["checkpoint_interval"] = 5000 78 | cfg["experiment"]["directory"] = "runs/jax/CartPole" 79 | 80 | agent = CEM(models=models, 81 | memory=memory, 82 | cfg=cfg, 83 | observation_space=env.observation_space, 84 | action_space=env.action_space, 85 | device=device) 86 | 87 | 88 | # configure and instantiate the RL trainer 89 | cfg_trainer = {"timesteps": 100000, "headless": True} 90 | trainer = 
SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 91 | 92 | # start training 93 | trainer.train() 94 | -------------------------------------------------------------------------------- /docs/source/examples/gym/torch_gym_cartpole_cem.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # import the skrl components to build the RL system 7 | from skrl.agents.torch.cem import CEM, CEM_DEFAULT_CONFIG 8 | from skrl.envs.wrappers.torch import wrap_env 9 | from skrl.memories.torch import RandomMemory 10 | from skrl.models.torch import CategoricalMixin, Model 11 | from skrl.trainers.torch import SequentialTrainer 12 | from skrl.utils import set_seed 13 | 14 | 15 | # seed for reproducibility 16 | set_seed() # e.g. `set_seed(42)` for fixed seed 17 | 18 | 19 | # define model (categorical model) using mixin 20 | class Policy(CategoricalMixin, Model): 21 | def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True): 22 | Model.__init__(self, observation_space, action_space, device) 23 | CategoricalMixin.__init__(self, unnormalized_log_prob) 24 | 25 | self.linear_layer_1 = nn.Linear(self.num_observations, 64) 26 | self.linear_layer_2 = nn.Linear(64, 64) 27 | self.output_layer = nn.Linear(64, self.num_actions) 28 | 29 | def compute(self, inputs, role): 30 | x = F.relu(self.linear_layer_1(inputs["states"])) 31 | x = F.relu(self.linear_layer_2(x)) 32 | return self.output_layer(x), {} 33 | 34 | 35 | # load and wrap the gym environment. 36 | # note: the environment version may change depending on the gym version 37 | try: 38 | env = gym.make("CartPole-v0") 39 | except gym.error.DeprecatedEnv as e: 40 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("CartPole-v")][0] 41 | print("CartPole-v0 not found. 
Trying {}".format(env_id)) 42 | env = gym.make(env_id) 43 | env = wrap_env(env) 44 | 45 | device = env.device 46 | 47 | 48 | # instantiate a memory as experience replay 49 | memory = RandomMemory(memory_size=1000, num_envs=env.num_envs, device=device, replacement=False) 50 | 51 | 52 | # instantiate the agent's model (function approximator). 53 | # CEM requires 1 model, visit its documentation for more details 54 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#models 55 | models = {} 56 | models["policy"] = Policy(env.observation_space, env.action_space, device) 57 | 58 | # initialize models' parameters (weights and biases) 59 | for model in models.values(): 60 | model.init_parameters(method_name="normal_", mean=0.0, std=0.1) 61 | 62 | 63 | # configure and instantiate the agent (visit its documentation to see all the options) 64 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#configuration-and-hyperparameters 65 | cfg = CEM_DEFAULT_CONFIG.copy() 66 | cfg["rollouts"] = 1000 67 | cfg["learning_starts"] = 100 68 | # logging to TensorBoard and write checkpoints (in timesteps) 69 | cfg["experiment"]["write_interval"] = 1000 70 | cfg["experiment"]["checkpoint_interval"] = 5000 71 | cfg["experiment"]["directory"] = "runs/torch/CartPole" 72 | 73 | agent = CEM(models=models, 74 | memory=memory, 75 | cfg=cfg, 76 | observation_space=env.observation_space, 77 | action_space=env.action_space, 78 | device=device) 79 | 80 | 81 | # configure and instantiate the RL trainer 82 | cfg_trainer = {"timesteps": 100000, "headless": True} 83 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 84 | 85 | # start training 86 | trainer.train() 87 | -------------------------------------------------------------------------------- /docs/source/examples/gym/torch_gym_frozen_lake_q_learning.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | import torch 4 | 5 | # import the skrl components to 
build the RL system 6 | from skrl.agents.torch.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG 7 | from skrl.envs.wrappers.torch import wrap_env 8 | from skrl.models.torch import Model, TabularMixin 9 | from skrl.trainers.torch import SequentialTrainer 10 | from skrl.utils import set_seed 11 | 12 | 13 | # seed for reproducibility 14 | set_seed() # e.g. `set_seed(42)` for fixed seed 15 | 16 | 17 | # define model (tabular model) using mixin 18 | class EpilonGreedyPolicy(TabularMixin, Model): 19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1): 20 | Model.__init__(self, observation_space, action_space, device) 21 | TabularMixin.__init__(self, num_envs) 22 | 23 | self.epsilon = epsilon 24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), 25 | dtype=torch.float32, device=self.device) 26 | 27 | def compute(self, inputs, role): 28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]], 29 | dim=-1, keepdim=True).view(-1,1) 30 | 31 | # choose random actions for exploration according to epsilon 32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1) 33 | if indexes.numel(): 34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device) 35 | return actions, {} 36 | 37 | 38 | # load and wrap the gym environment. 39 | # note: the environment version may change depending on the gym version 40 | try: 41 | env = gym.make("FrozenLake-v0") 42 | except gym.error.DeprecatedEnv as e: 43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("FrozenLake-v")][0] 44 | print("FrozenLake-v0 not found. 
Trying {}".format(env_id)) 45 | env = gym.make(env_id) 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate the agent's model (table) 52 | # Q-learning requires 1 model, visit its documentation for more details 53 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#models 54 | models = {} 55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1) 56 | 57 | 58 | # configure and instantiate the agent (visit its documentation to see all the options) 59 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#configuration-and-hyperparameters 60 | cfg = Q_LEARNING_DEFAULT_CONFIG.copy() 61 | cfg["discount_factor"] = 0.999 62 | cfg["alpha"] = 0.4 63 | # logging to TensorBoard and write checkpoints (in timesteps) 64 | cfg["experiment"]["write_interval"] = 1600 65 | cfg["experiment"]["checkpoint_interval"] = 8000 66 | cfg["experiment"]["directory"] = "runs/torch/FrozenLake" 67 | 68 | agent = Q_LEARNING(models=models, 69 | memory=None, 70 | cfg=cfg, 71 | observation_space=env.observation_space, 72 | action_space=env.action_space, 73 | device=device) 74 | 75 | 76 | # configure and instantiate the RL trainer 77 | cfg_trainer = {"timesteps": 80000, "headless": True} 78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 79 | 80 | # start training 81 | trainer.train() 82 | -------------------------------------------------------------------------------- /docs/source/examples/gym/torch_gym_frozen_lake_vector_q_learning.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | import torch 4 | 5 | # import the skrl components to build the RL system 6 | from skrl.agents.torch.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG 7 | from skrl.envs.wrappers.torch import wrap_env 8 | from skrl.models.torch import Model, TabularMixin 9 | from skrl.trainers.torch import SequentialTrainer 
10 | from skrl.utils import set_seed 11 | 12 | 13 | # seed for reproducibility 14 | set_seed() # e.g. `set_seed(42)` for fixed seed 15 | 16 | 17 | # define model (tabular model) using mixin 18 | class EpilonGreedyPolicy(TabularMixin, Model): 19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1): 20 | Model.__init__(self, observation_space, action_space, device) 21 | TabularMixin.__init__(self, num_envs) 22 | 23 | self.epsilon = epsilon 24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), 25 | dtype=torch.float32, device=self.device) 26 | 27 | def compute(self, inputs, role): 28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]], 29 | dim=-1, keepdim=True).view(-1,1) 30 | 31 | # choose random actions for exploration according to epsilon 32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1) 33 | if indexes.numel(): 34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device) 35 | return actions, {} 36 | 37 | 38 | # load and wrap the gym environment. 39 | # note: the environment version may change depending on the gym version 40 | try: 41 | env = gym.vector.make("FrozenLake-v0", num_envs=10, asynchronous=False) 42 | except gym.error.DeprecatedEnv as e: 43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("FrozenLake-v")][0] 44 | print("FrozenLake-v0 not found. 
Trying {}".format(env_id)) 45 | env = gym.vector.make(env_id, num_envs=10, asynchronous=False) 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate the agent's model (table) 52 | # Q-learning requires 1 model, visit its documentation for more details 53 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#models 54 | models = {} 55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1) 56 | 57 | 58 | # configure and instantiate the agent (visit its documentation to see all the options) 59 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#configuration-and-hyperparameters 60 | cfg = Q_LEARNING_DEFAULT_CONFIG.copy() 61 | cfg["discount_factor"] = 0.999 62 | cfg["alpha"] = 0.4 63 | # logging to TensorBoard and write checkpoints (in timesteps) 64 | cfg["experiment"]["write_interval"] = 1600 65 | cfg["experiment"]["checkpoint_interval"] = 8000 66 | cfg["experiment"]["directory"] = "runs/torch/FrozenLake" 67 | 68 | agent = Q_LEARNING(models=models, 69 | memory=None, 70 | cfg=cfg, 71 | observation_space=env.observation_space, 72 | action_space=env.action_space, 73 | device=device) 74 | 75 | 76 | # configure and instantiate the RL trainer 77 | cfg_trainer = {"timesteps": 80000, "headless": True} 78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 79 | 80 | # start training 81 | trainer.train() 82 | -------------------------------------------------------------------------------- /docs/source/examples/gym/torch_gym_taxi_sarsa.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | import torch 4 | 5 | # import the skrl components to build the RL system 6 | from skrl.agents.torch.sarsa import SARSA, SARSA_DEFAULT_CONFIG 7 | from skrl.envs.wrappers.torch import wrap_env 8 | from skrl.models.torch import Model, TabularMixin 9 | from skrl.trainers.torch import 
SequentialTrainer 10 | from skrl.utils import set_seed 11 | 12 | 13 | # seed for reproducibility 14 | set_seed() # e.g. `set_seed(42)` for fixed seed 15 | 16 | 17 | # define model (tabular model) using mixin 18 | class EpilonGreedyPolicy(TabularMixin, Model): 19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1): 20 | Model.__init__(self, observation_space, action_space, device) 21 | TabularMixin.__init__(self, num_envs) 22 | 23 | self.epsilon = epsilon 24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), 25 | dtype=torch.float32, device=self.device) 26 | 27 | def compute(self, inputs, role): 28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]], 29 | dim=-1, keepdim=True).view(-1,1) 30 | 31 | # choose random actions for exploration according to epsilon 32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1) 33 | if indexes.numel(): 34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device) 35 | return actions, {} 36 | 37 | 38 | # load and wrap the gym environment. 39 | # note: the environment version may change depending on the gym version 40 | try: 41 | env = gym.make("Taxi-v3") 42 | except gym.error.DeprecatedEnv as e: 43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("Taxi-v")][0] 44 | print("Taxi-v3 not found. 
Trying {}".format(env_id)) 45 | env = gym.make(env_id) 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate the agent's model (table) 52 | # SARSA requires 1 model, visit its documentation for more details 53 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#models 54 | models = {} 55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1) 56 | 57 | 58 | # configure and instantiate the agent (visit its documentation to see all the options) 59 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#configuration-and-hyperparameters 60 | cfg = SARSA_DEFAULT_CONFIG.copy() 61 | cfg["discount_factor"] = 0.999 62 | cfg["alpha"] = 0.4 63 | # logging to TensorBoard and write checkpoints (in timesteps) 64 | cfg["experiment"]["write_interval"] = 1600 65 | cfg["experiment"]["checkpoint_interval"] = 8000 66 | cfg["experiment"]["directory"] = "runs/torch/Taxi" 67 | 68 | agent = SARSA(models=models, 69 | memory=None, 70 | cfg=cfg, 71 | observation_space=env.observation_space, 72 | action_space=env.action_space, 73 | device=device) 74 | 75 | 76 | # configure and instantiate the RL trainer 77 | cfg_trainer = {"timesteps": 80000, "headless": True} 78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 79 | 80 | # start training 81 | trainer.train() 82 | -------------------------------------------------------------------------------- /docs/source/examples/gym/torch_gym_taxi_vector_sarsa.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | import torch 4 | 5 | # import the skrl components to build the RL system 6 | from skrl.agents.torch.sarsa import SARSA, SARSA_DEFAULT_CONFIG 7 | from skrl.envs.wrappers.torch import wrap_env 8 | from skrl.models.torch import Model, TabularMixin 9 | from skrl.trainers.torch import SequentialTrainer 10 | from skrl.utils import set_seed 11 | 12 | 13 | # seed 
for reproducibility 14 | set_seed() # e.g. `set_seed(42)` for fixed seed 15 | 16 | 17 | # define model (tabular model) using mixin 18 | class EpilonGreedyPolicy(TabularMixin, Model): 19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1): 20 | Model.__init__(self, observation_space, action_space, device) 21 | TabularMixin.__init__(self, num_envs) 22 | 23 | self.epsilon = epsilon 24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), 25 | dtype=torch.float32, device=self.device) 26 | 27 | def compute(self, inputs, role): 28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]], 29 | dim=-1, keepdim=True).view(-1,1) 30 | 31 | # choose random actions for exploration according to epsilon 32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1) 33 | if indexes.numel(): 34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device) 35 | return actions, {} 36 | 37 | 38 | # load and wrap the gym environment. 39 | # note: the environment version may change depending on the gym version 40 | try: 41 | env = gym.vector.make("Taxi-v3", num_envs=10, asynchronous=False) 42 | except gym.error.DeprecatedEnv as e: 43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("Taxi-v")][0] 44 | print("Taxi-v3 not found. 
Trying {}".format(env_id)) 45 | env = gym.vector.make(env_id, num_envs=10, asynchronous=False) 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate the agent's model (table) 52 | # SARSA requires 1 model, visit its documentation for more details 53 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#models 54 | models = {} 55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1) 56 | 57 | 58 | # configure and instantiate the agent (visit its documentation to see all the options) 59 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#configuration-and-hyperparameters 60 | cfg = SARSA_DEFAULT_CONFIG.copy() 61 | cfg["discount_factor"] = 0.999 62 | cfg["alpha"] = 0.4 63 | # logging to TensorBoard and write checkpoints (in timesteps) 64 | cfg["experiment"]["write_interval"] = 1600 65 | cfg["experiment"]["checkpoint_interval"] = 8000 66 | cfg["experiment"]["directory"] = "runs/torch/Taxi" 67 | 68 | agent = SARSA(models=models, 69 | memory=None, 70 | cfg=cfg, 71 | observation_space=env.observation_space, 72 | action_space=env.action_space, 73 | device=device) 74 | 75 | 76 | # configure and instantiate the RL trainer 77 | cfg_trainer = {"timesteps": 80000, "headless": True} 78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 79 | 80 | # start training 81 | trainer.train() 82 | -------------------------------------------------------------------------------- /docs/source/examples/gymnasium/jax_gymnasium_cartpole_cem.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | # import the skrl components to build the RL system 8 | from skrl import config 9 | from skrl.agents.jax.cem import CEM, CEM_DEFAULT_CONFIG 10 | from skrl.envs.wrappers.jax import wrap_env 11 | from skrl.memories.jax import 
RandomMemory 12 | from skrl.models.jax import CategoricalMixin, Model 13 | from skrl.trainers.jax import SequentialTrainer 14 | from skrl.utils import set_seed 15 | 16 | 17 | config.jax.backend = "numpy" # or "jax" 18 | 19 | 20 | # seed for reproducibility 21 | set_seed() # e.g. `set_seed(42)` for fixed seed 22 | 23 | 24 | # define model (categorical model) using mixin 25 | class Policy(CategoricalMixin, Model): 26 | def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs): 27 | Model.__init__(self, observation_space, action_space, device, **kwargs) 28 | CategoricalMixin.__init__(self, unnormalized_log_prob) 29 | 30 | @nn.compact 31 | def __call__(self, inputs, role): 32 | x = nn.relu(nn.Dense(64)(inputs["states"])) 33 | x = nn.relu(nn.Dense(64)(x)) 34 | x = nn.Dense(self.num_actions)(x) 35 | return x, {} 36 | 37 | 38 | # load and wrap the gymnasium environment. 39 | # note: the environment version may change depending on the gymnasium version 40 | try: 41 | env = gym.make("CartPole-v1") 42 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e: 43 | env_id = [spec for spec in gym.envs.registry if spec.startswith("CartPole-v")][0] 44 | print("CartPole-v0 not found. Trying {}".format(env_id)) 45 | env = gym.make(env_id) 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate a memory as experience replay 52 | memory = RandomMemory(memory_size=1000, num_envs=env.num_envs, device=device, replacement=False) 53 | 54 | 55 | # instantiate the agent's model (function approximator). 
56 | # CEM requires 1 model, visit its documentation for more details 57 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#models 58 | models = {} 59 | models["policy"] = Policy(env.observation_space, env.action_space, device) 60 | 61 | # instantiate models' state dict 62 | for role, model in models.items(): 63 | model.init_state_dict(role) 64 | 65 | # initialize models' parameters (weights and biases) 66 | for model in models.values(): 67 | model.init_parameters(method_name="normal", stddev=0.1) 68 | 69 | 70 | # configure and instantiate the agent (visit its documentation to see all the options) 71 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#configuration-and-hyperparameters 72 | cfg = CEM_DEFAULT_CONFIG.copy() 73 | cfg["rollouts"] = 1000 74 | cfg["learning_starts"] = 100 75 | # logging to TensorBoard and write checkpoints (in timesteps) 76 | cfg["experiment"]["write_interval"] = 1000 77 | cfg["experiment"]["checkpoint_interval"] = 5000 78 | cfg["experiment"]["directory"] = "runs/jax/CartPole" 79 | 80 | agent = CEM(models=models, 81 | memory=memory, 82 | cfg=cfg, 83 | observation_space=env.observation_space, 84 | action_space=env.action_space, 85 | device=device) 86 | 87 | 88 | # configure and instantiate the RL trainer 89 | cfg_trainer = {"timesteps": 100000, "headless": True} 90 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 91 | 92 | # start training 93 | trainer.train() 94 | -------------------------------------------------------------------------------- /docs/source/examples/gymnasium/torch_gymnasium_cartpole_cem.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # import the skrl components to build the RL system 7 | from skrl.agents.torch.cem import CEM, CEM_DEFAULT_CONFIG 8 | from skrl.envs.wrappers.torch import wrap_env 9 | from skrl.memories.torch import RandomMemory 
10 | from skrl.models.torch import CategoricalMixin, Model 11 | from skrl.trainers.torch import SequentialTrainer 12 | from skrl.utils import set_seed 13 | 14 | 15 | # seed for reproducibility 16 | set_seed() # e.g. `set_seed(42)` for fixed seed 17 | 18 | 19 | # define model (categorical model) using mixin 20 | class Policy(CategoricalMixin, Model): 21 | def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True): 22 | Model.__init__(self, observation_space, action_space, device) 23 | CategoricalMixin.__init__(self, unnormalized_log_prob) 24 | 25 | self.linear_layer_1 = nn.Linear(self.num_observations, 64) 26 | self.linear_layer_2 = nn.Linear(64, 64) 27 | self.output_layer = nn.Linear(64, self.num_actions) 28 | 29 | def compute(self, inputs, role): 30 | x = F.relu(self.linear_layer_1(inputs["states"])) 31 | x = F.relu(self.linear_layer_2(x)) 32 | return self.output_layer(x), {} 33 | 34 | 35 | # load and wrap the gymnasium environment. 36 | # note: the environment version may change depending on the gymnasium version 37 | try: 38 | env = gym.make("CartPole-v1") 39 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e: 40 | env_id = [spec for spec in gym.envs.registry if spec.startswith("CartPole-v")][0] 41 | print("CartPole-v0 not found. Trying {}".format(env_id)) 42 | env = gym.make(env_id) 43 | env = wrap_env(env) 44 | 45 | device = env.device 46 | 47 | 48 | # instantiate a memory as experience replay 49 | memory = RandomMemory(memory_size=1000, num_envs=env.num_envs, device=device, replacement=False) 50 | 51 | 52 | # instantiate the agent's model (function approximator). 
53 | # CEM requires 1 model, visit its documentation for more details 54 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#models 55 | models = {} 56 | models["policy"] = Policy(env.observation_space, env.action_space, device) 57 | 58 | # initialize models' parameters (weights and biases) 59 | for model in models.values(): 60 | model.init_parameters(method_name="normal_", mean=0.0, std=0.1) 61 | 62 | 63 | # configure and instantiate the agent (visit its documentation to see all the options) 64 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#configuration-and-hyperparameters 65 | cfg = CEM_DEFAULT_CONFIG.copy() 66 | cfg["rollouts"] = 1000 67 | cfg["learning_starts"] = 100 68 | # logging to TensorBoard and write checkpoints (in timesteps) 69 | cfg["experiment"]["write_interval"] = 1000 70 | cfg["experiment"]["checkpoint_interval"] = 5000 71 | cfg["experiment"]["directory"] = "runs/torch/CartPole" 72 | 73 | agent = CEM(models=models, 74 | memory=memory, 75 | cfg=cfg, 76 | observation_space=env.observation_space, 77 | action_space=env.action_space, 78 | device=device) 79 | 80 | 81 | # configure and instantiate the RL trainer 82 | cfg_trainer = {"timesteps": 100000, "headless": True} 83 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 84 | 85 | # start training 86 | trainer.train() 87 | -------------------------------------------------------------------------------- /docs/source/examples/gymnasium/torch_gymnasium_frozen_lake_q_learning.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | import torch 4 | 5 | # import the skrl components to build the RL system 6 | from skrl.agents.torch.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG 7 | from skrl.envs.wrappers.torch import wrap_env 8 | from skrl.models.torch import Model, TabularMixin 9 | from skrl.trainers.torch import SequentialTrainer 10 | from skrl.utils import set_seed 11 | 12 | 13 | # seed 
for reproducibility 14 | set_seed() # e.g. `set_seed(42)` for fixed seed 15 | 16 | 17 | # define model (tabular model) using mixin 18 | class EpilonGreedyPolicy(TabularMixin, Model): 19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1): 20 | Model.__init__(self, observation_space, action_space, device) 21 | TabularMixin.__init__(self, num_envs) 22 | 23 | self.epsilon = epsilon 24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), 25 | dtype=torch.float32, device=self.device) 26 | 27 | def compute(self, inputs, role): 28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]], 29 | dim=-1, keepdim=True).view(-1,1) 30 | 31 | # choose random actions for exploration according to epsilon 32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1) 33 | if indexes.numel(): 34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device) 35 | return actions, {} 36 | 37 | 38 | # load and wrap the gymnasium environment. 39 | # note: the environment version may change depending on the gymnasium version 40 | try: 41 | env = gym.make("FrozenLake-v0") 42 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e: 43 | env_id = [spec for spec in gym.envs.registry if spec.startswith("FrozenLake-v")][0] 44 | print("FrozenLake-v0 not found. 
Trying {}".format(env_id)) 45 | env = gym.make(env_id) 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate the agent's model (table) 52 | # Q-learning requires 1 model, visit its documentation for more details 53 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#models 54 | models = {} 55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1) 56 | 57 | 58 | # configure and instantiate the agent (visit its documentation to see all the options) 59 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#configuration-and-hyperparameters 60 | cfg = Q_LEARNING_DEFAULT_CONFIG.copy() 61 | cfg["discount_factor"] = 0.999 62 | cfg["alpha"] = 0.4 63 | # logging to TensorBoard and write checkpoints (in timesteps) 64 | cfg["experiment"]["write_interval"] = 1600 65 | cfg["experiment"]["checkpoint_interval"] = 8000 66 | cfg["experiment"]["directory"] = "runs/torch/FrozenLake" 67 | 68 | agent = Q_LEARNING(models=models, 69 | memory=None, 70 | cfg=cfg, 71 | observation_space=env.observation_space, 72 | action_space=env.action_space, 73 | device=device) 74 | 75 | 76 | # configure and instantiate the RL trainer 77 | cfg_trainer = {"timesteps": 80000, "headless": True} 78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 79 | 80 | # start training 81 | trainer.train() 82 | -------------------------------------------------------------------------------- /docs/source/examples/gymnasium/torch_gymnasium_taxi_sarsa.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | import torch 4 | 5 | # import the skrl components to build the RL system 6 | from skrl.agents.torch.sarsa import SARSA, SARSA_DEFAULT_CONFIG 7 | from skrl.envs.wrappers.torch import wrap_env 8 | from skrl.models.torch import Model, TabularMixin 9 | from skrl.trainers.torch import SequentialTrainer 10 | from 
skrl.utils import set_seed 11 | 12 | 13 | # seed for reproducibility 14 | set_seed() # e.g. `set_seed(42)` for fixed seed 15 | 16 | 17 | # define model (tabular model) using mixin 18 | class EpilonGreedyPolicy(TabularMixin, Model): 19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1): 20 | Model.__init__(self, observation_space, action_space, device) 21 | TabularMixin.__init__(self, num_envs) 22 | 23 | self.epsilon = epsilon 24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), 25 | dtype=torch.float32, device=self.device) 26 | 27 | def compute(self, inputs, role): 28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]], 29 | dim=-1, keepdim=True).view(-1,1) 30 | 31 | # choose random actions for exploration according to epsilon 32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1) 33 | if indexes.numel(): 34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device) 35 | return actions, {} 36 | 37 | 38 | # load and wrap the gymnasium environment. 39 | # note: the environment version may change depending on the gymnasium version 40 | try: 41 | env = gym.make("Taxi-v3") 42 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e: 43 | env_id = [spec for spec in gym.envs.registry if spec.startswith("Taxi-v")][0] 44 | print("Taxi-v3 not found. 
Trying {}".format(env_id)) 45 | env = gym.make(env_id) 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate the agent's model (table) 52 | # SARSA requires 1 model, visit its documentation for more details 53 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#models 54 | models = {} 55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1) 56 | 57 | 58 | # configure and instantiate the agent (visit its documentation to see all the options) 59 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#configuration-and-hyperparameters 60 | cfg = SARSA_DEFAULT_CONFIG.copy() 61 | cfg["discount_factor"] = 0.999 62 | cfg["alpha"] = 0.4 63 | # logging to TensorBoard and write checkpoints (in timesteps) 64 | cfg["experiment"]["write_interval"] = 1600 65 | cfg["experiment"]["checkpoint_interval"] = 8000 66 | cfg["experiment"]["directory"] = "runs/torch/Taxi" 67 | 68 | agent = SARSA(models=models, 69 | memory=None, 70 | cfg=cfg, 71 | observation_space=env.observation_space, 72 | action_space=env.action_space, 73 | device=device) 74 | 75 | 76 | # configure and instantiate the RL trainer 77 | cfg_trainer = {"timesteps": 80000, "headless": True} 78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 79 | 80 | # start training 81 | trainer.train() 82 | -------------------------------------------------------------------------------- /docs/source/examples/gymnasium/torch_gymnasium_taxi_vector_sarsa.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | import torch 4 | 5 | # import the skrl components to build the RL system 6 | from skrl.agents.torch.sarsa import SARSA, SARSA_DEFAULT_CONFIG 7 | from skrl.envs.wrappers.torch import wrap_env 8 | from skrl.models.torch import Model, TabularMixin 9 | from skrl.trainers.torch import SequentialTrainer 10 | from skrl.utils import 
set_seed 11 | 12 | 13 | # seed for reproducibility 14 | set_seed() # e.g. `set_seed(42)` for fixed seed 15 | 16 | 17 | # define model (tabular model) using mixin 18 | class EpilonGreedyPolicy(TabularMixin, Model): 19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1): 20 | Model.__init__(self, observation_space, action_space, device) 21 | TabularMixin.__init__(self, num_envs) 22 | 23 | self.epsilon = epsilon 24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), 25 | dtype=torch.float32, device=self.device) 26 | 27 | def compute(self, inputs, role): 28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]], 29 | dim=-1, keepdim=True).view(-1,1) 30 | 31 | # choose random actions for exploration according to epsilon 32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1) 33 | if indexes.numel(): 34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device) 35 | return actions, {} 36 | 37 | 38 | # load and wrap the gymnasium environment. 39 | # note: the environment version may change depending on the gymnasium version 40 | try: 41 | env = gym.make_vec("Taxi-v3", num_envs=10, vectorization_mode="sync") 42 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e: 43 | env_id = [spec for spec in gym.envs.registry if spec.startswith("Taxi-v")][0] 44 | print("Taxi-v3 not found. 
Trying {}".format(env_id)) 45 | env = gym.make_vec(env_id, num_envs=10, vectorization_mode="sync") 46 | env = wrap_env(env) 47 | 48 | device = env.device 49 | 50 | 51 | # instantiate the agent's model (table) 52 | # SARSA requires 1 model, visit its documentation for more details 53 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#models 54 | models = {} 55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1) 56 | 57 | 58 | # configure and instantiate the agent (visit its documentation to see all the options) 59 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#configuration-and-hyperparameters 60 | cfg = SARSA_DEFAULT_CONFIG.copy() 61 | cfg["discount_factor"] = 0.999 62 | cfg["alpha"] = 0.4 63 | # logging to TensorBoard and write checkpoints (in timesteps) 64 | cfg["experiment"]["write_interval"] = 1600 65 | cfg["experiment"]["checkpoint_interval"] = 8000 66 | cfg["experiment"]["directory"] = "runs/torch/Taxi" 67 | 68 | agent = SARSA(models=models, 69 | memory=None, 70 | cfg=cfg, 71 | observation_space=env.observation_space, 72 | action_space=env.action_space, 73 | device=device) 74 | 75 | 76 | # configure and instantiate the RL trainer 77 | cfg_trainer = {"timesteps": 80000, "headless": True} 78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 79 | 80 | # start training 81 | trainer.train() 82 | -------------------------------------------------------------------------------- /docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_skrl_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Import the skrl components to build the RL system 5 | from skrl.models.torch import Model, GaussianMixin 6 | from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG 7 | from skrl.resources.preprocessors.torch import RunningStandardScaler 8 | from 
skrl.trainers.torch import SequentialTrainer 9 | from skrl.envs.torch import wrap_env 10 | 11 | 12 | # Define only the policy for evaluation 13 | class Policy(GaussianMixin, Model): 14 | def __init__(self, observation_space, action_space, device, clip_actions=False, 15 | clip_log_std=True, min_log_std=-20, max_log_std=2): 16 | Model.__init__(self, observation_space, action_space, device) 17 | GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) 18 | 19 | self.net = nn.Sequential(nn.Linear(self.num_observations, 256), 20 | nn.ELU(), 21 | nn.Linear(256, 128), 22 | nn.ELU(), 23 | nn.Linear(128, 64), 24 | nn.ELU(), 25 | nn.Linear(64, self.num_actions)) 26 | self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions)) 27 | 28 | def compute(self, inputs, role): 29 | return self.net(inputs["states"]), self.log_std_parameter, {} 30 | 31 | 32 | # Load the environment 33 | from reaching_iiwa_real_env import ReachingIiwa 34 | 35 | control_space = "joint" # joint or cartesian 36 | env = ReachingIiwa(control_space=control_space) 37 | 38 | # wrap the environment 39 | env = wrap_env(env) 40 | 41 | device = env.device 42 | 43 | 44 | # Instantiate the agent's policy. 45 | # PPO requires 2 models, visit its documentation for more details 46 | # https://skrl.readthedocs.io/en/latest/modules/skrl.agents.ppo.html#spaces-and-models 47 | models_ppo = {} 48 | models_ppo["policy"] = Policy(env.observation_space, env.action_space, device) 49 | 50 | 51 | # Configure and instantiate the agent. 
52 | # Only modify some of the default configuration, visit its documentation to see all the options 53 | # https://skrl.readthedocs.io/en/latest/modules/skrl.agents.ppo.html#configuration-and-hyperparameters 54 | cfg_ppo = PPO_DEFAULT_CONFIG.copy() 55 | cfg_ppo["random_timesteps"] = 0 56 | cfg_ppo["learning_starts"] = 0 57 | cfg_ppo["state_preprocessor"] = RunningStandardScaler 58 | cfg_ppo["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device} 59 | # logging to TensorBoard each 32 timesteps an ignore checkpoints 60 | cfg_ppo["experiment"]["write_interval"] = 32 61 | cfg_ppo["experiment"]["checkpoint_interval"] = 0 62 | 63 | agent = PPO(models=models_ppo, 64 | memory=None, 65 | cfg=cfg_ppo, 66 | observation_space=env.observation_space, 67 | action_space=env.action_space, 68 | device=device) 69 | 70 | # load checkpoints 71 | if control_space == "joint": 72 | agent.load("./agent_joint.pt") 73 | elif control_space == "cartesian": 74 | agent.load("./agent_cartesian.pt") 75 | 76 | 77 | # Configure and instantiate the RL trainer 78 | cfg_trainer = {"timesteps": 1000, "headless": True} 79 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent) 80 | 81 | # start evaluation 82 | trainer.eval() 83 | -------------------------------------------------------------------------------- /docs/source/examples/shimmy/jax_shimmy_atari_pong_dqn.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | # import the skrl components to build the RL system 8 | from skrl import config 9 | from skrl.agents.jax.dqn import DQN, DQN_DEFAULT_CONFIG 10 | from skrl.envs.wrappers.jax import wrap_env 11 | from skrl.memories.jax import RandomMemory 12 | from skrl.models.jax import DeterministicMixin, Model 13 | from skrl.trainers.jax import SequentialTrainer 14 | from skrl.utils import set_seed 15 | 16 | 17 | 
config.jax.backend = "numpy"  # or "jax"


# seed for reproducibility
set_seed()  # e.g. `set_seed(42)` for fixed seed


# define model (deterministic model) using mixin
class QNetwork(DeterministicMixin, Model):
    """Deterministic Q-network: a 64-64 MLP producing one Q-value per discrete action."""

    def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
        Model.__init__(self, observation_space, action_space, device, **kwargs)
        DeterministicMixin.__init__(self, clip_actions)

    @nn.compact
    def __call__(self, inputs, role):
        # submodules are created in the same order as before (64 -> 64 -> num_actions),
        # so the flax parameter-tree layout is unchanged
        hidden = inputs["states"]
        for units in (64, 64):
            hidden = nn.relu(nn.Dense(units)(hidden))
        q_values = nn.Dense(self.num_actions)(hidden)
        return q_values, {}


# load and wrap the environment
env = wrap_env(gym.make("ALE/Pong-v5"))

device = env.device


# instantiate a memory as experience replay
memory = RandomMemory(memory_size=15000, num_envs=env.num_envs, device=device, replacement=False)


# instantiate the agent's models (function approximators).
50 | # DQN requires 2 models, visit its documentation for more details 51 | # https://skrl.readthedocs.io/en/latest/api/agents/dqn.html#models 52 | models = {} 53 | models["q_network"] = QNetwork(env.observation_space, env.action_space, device) 54 | models["target_q_network"] = QNetwork(env.observation_space, env.action_space, device) 55 | 56 | # instantiate models' state dict 57 | for role, model in models.items(): 58 | model.init_state_dict(role) 59 | 60 | # initialize models' parameters (weights and biases) 61 | for model in models.values(): 62 | model.init_parameters(method_name="normal", stddev=0.1) 63 | 64 | 65 | # configure and instantiate the agent (visit its documentation to see all the options) 66 | # https://skrl.readthedocs.io/en/latest/api/agents/dqn.html#configuration-and-hyperparameters 67 | cfg = DQN_DEFAULT_CONFIG.copy() 68 | cfg["learning_starts"] = 100 69 | cfg["exploration"]["final_epsilon"] = 0.04 70 | cfg["exploration"]["timesteps"] = 1500 71 | # logging to TensorBoard and write checkpoints (in timesteps) 72 | cfg["experiment"]["write_interval"] = 1000 73 | cfg["experiment"]["checkpoint_interval"] = 5000 74 | cfg["experiment"]["directory"] = "runs/torch/ALE_Pong" 75 | 76 | agent = DQN(models=models, 77 | memory=memory, 78 | cfg=cfg, 79 | observation_space=env.observation_space, 80 | action_space=env.action_space, 81 | device=device) 82 | 83 | 84 | # configure and instantiate the RL trainer 85 | cfg_trainer = {"timesteps": 50000, "headless": True} 86 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 87 | 88 | # start training 89 | trainer.train() 90 | -------------------------------------------------------------------------------- /docs/source/examples/shimmy/torch_shimmy_atari_pong_dqn.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # import the skrl components to build the RL system 7 | from 
skrl.agents.torch.dqn import DQN, DQN_DEFAULT_CONFIG 8 | from skrl.envs.wrappers.torch import wrap_env 9 | from skrl.memories.torch import RandomMemory 10 | from skrl.models.torch import DeterministicMixin, Model 11 | from skrl.trainers.torch import SequentialTrainer 12 | from skrl.utils import set_seed 13 | 14 | 15 | # seed for reproducibility 16 | set_seed() # e.g. `set_seed(42)` for fixed seed 17 | 18 | 19 | # define model (deterministic model) using mixin 20 | class QNetwork(DeterministicMixin, Model): 21 | def __init__(self, observation_space, action_space, device, clip_actions=False): 22 | Model.__init__(self, observation_space, action_space, device) 23 | DeterministicMixin.__init__(self, clip_actions) 24 | 25 | self.net = nn.Sequential(nn.Linear(self.num_observations, 64), 26 | nn.ReLU(), 27 | nn.Linear(64, 64), 28 | nn.ReLU(), 29 | nn.Linear(64, self.num_actions)) 30 | 31 | def compute(self, inputs, role): 32 | return self.net(inputs["states"]), {} 33 | 34 | 35 | # load and wrap the environment 36 | env = gym.make("ALE/Pong-v5") 37 | env = wrap_env(env) 38 | 39 | device = env.device 40 | 41 | 42 | # instantiate a memory as experience replay 43 | memory = RandomMemory(memory_size=15000, num_envs=env.num_envs, device=device, replacement=False) 44 | 45 | 46 | # instantiate the agent's models (function approximators). 
47 | # DQN requires 2 models, visit its documentation for more details 48 | # https://skrl.readthedocs.io/en/latest/api/agents/dqn.html#models 49 | models = {} 50 | models["q_network"] = QNetwork(env.observation_space, env.action_space, device) 51 | models["target_q_network"] = QNetwork(env.observation_space, env.action_space, device) 52 | 53 | # initialize models' parameters (weights and biases) 54 | for model in models.values(): 55 | model.init_parameters(method_name="normal_", mean=0.0, std=0.1) 56 | 57 | 58 | # configure and instantiate the agent (visit its documentation to see all the options) 59 | # https://skrl.readthedocs.io/en/latest/api/agents/dqn.html#configuration-and-hyperparameters 60 | cfg = DQN_DEFAULT_CONFIG.copy() 61 | cfg["learning_starts"] = 100 62 | cfg["exploration"]["initial_epsilon"] = 1.0 63 | cfg["exploration"]["final_epsilon"] = 0.04 64 | cfg["exploration"]["timesteps"] = 1500 65 | # logging to TensorBoard and write checkpoints (in timesteps) 66 | cfg["experiment"]["write_interval"] = 1000 67 | cfg["experiment"]["checkpoint_interval"] = 5000 68 | cfg["experiment"]["directory"] = "runs/torch/ALE_Pong" 69 | 70 | agent = DQN(models=models, 71 | memory=memory, 72 | cfg=cfg, 73 | observation_space=env.observation_space, 74 | action_space=env.action_space, 75 | device=device) 76 | 77 | 78 | # configure and instantiate the RL trainer 79 | cfg_trainer = {"timesteps": 50000, "headless": True} 80 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) 81 | 82 | # start training 83 | trainer.train() 84 | -------------------------------------------------------------------------------- /docs/source/examples/utils/tensorboard_file_iterator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from skrl.utils import postprocessing 5 | 6 | 7 | labels = [] 8 | rewards = [] 9 | 10 | # load the Tensorboard files and iterate over them (tag: 
"Reward / Total reward (mean)") 11 | tensorboard_iterator = postprocessing.TensorboardFileIterator("runs/*/events.out.tfevents.*", 12 | tags=["Reward / Total reward (mean)"]) 13 | for dirname, data in tensorboard_iterator: 14 | rewards.append(data["Reward / Total reward (mean)"]) 15 | labels.append(dirname) 16 | 17 | # convert to numpy arrays and compute mean and std 18 | rewards = np.array(rewards) 19 | mean = np.mean(rewards[:,:,1], axis=0) 20 | std = np.std(rewards[:,:,1], axis=0) 21 | 22 | # create two subplots (one for each reward and one for the mean) 23 | fig, ax = plt.subplots(1, 2, figsize=(15, 5)) 24 | 25 | # plot the rewards for each experiment 26 | for reward, label in zip(rewards, labels): 27 | ax[0].plot(reward[:,0], reward[:,1], label=label) 28 | 29 | ax[0].set_title("Total reward (for each experiment)") 30 | ax[0].set_xlabel("Timesteps") 31 | ax[0].set_ylabel("Reward") 32 | ax[0].grid(True) 33 | ax[0].legend() 34 | 35 | # plot the mean and std (across experiments) 36 | ax[1].fill_between(rewards[0,:,0], mean - std, mean + std, alpha=0.5, label="std") 37 | ax[1].plot(rewards[0,:,0], mean, label="mean") 38 | 39 | ax[1].set_title("Total reward (mean and std of all experiments)") 40 | ax[1].set_xlabel("Timesteps") 41 | ax[1].set_ylabel("Reward") 42 | ax[1].grid(True) 43 | ax[1].legend() 44 | 45 | # show and save the figure 46 | plt.show() 47 | plt.savefig("total_reward.png") 48 | -------------------------------------------------------------------------------- /docs/source/snippets/isaacgym_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from isaacgym import gymapi 3 | 4 | from skrl.utils import isaacgym_utils 5 | 6 | 7 | # create a web viewer instance 8 | web_viewer = isaacgym_utils.WebViewer() 9 | 10 | # configure and create simulation 11 | sim_params = gymapi.SimParams() 12 | sim_params.up_axis = gymapi.UP_AXIS_Z 13 | sim_params.gravity = gymapi.Vec3(0.0, 0.0, -9.8) 14 | 
sim_params.physx.solver_type = 1 15 | sim_params.physx.num_position_iterations = 4 16 | sim_params.physx.num_velocity_iterations = 1 17 | sim_params.physx.use_gpu = True 18 | sim_params.use_gpu_pipeline = True 19 | 20 | gym = gymapi.acquire_gym() 21 | sim = gym.create_sim(compute_device=0, graphics_device=0, type=gymapi.SIM_PHYSX, params=sim_params) 22 | 23 | # setup num_envs and env's grid 24 | num_envs = 1 25 | spacing = 2.0 26 | env_lower = gymapi.Vec3(-spacing, -spacing, 0.0) 27 | env_upper = gymapi.Vec3(spacing, 0.0, spacing) 28 | 29 | # add ground plane 30 | plane_params = gymapi.PlaneParams() 31 | plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0) 32 | gym.add_ground(sim, plane_params) 33 | 34 | envs = [] 35 | cameras = [] 36 | 37 | for i in range(num_envs): 38 | # create env 39 | env = gym.create_env(sim, env_lower, env_upper, int(math.sqrt(num_envs))) 40 | 41 | # add sphere 42 | pose = gymapi.Transform() 43 | pose.p, pose.r = gymapi.Vec3(0.0, 0.0, 1.0), gymapi.Quat(0.0, 0.0, 0.0, 1.0) 44 | gym.create_actor(env, gym.create_sphere(sim, 0.2, None), pose, "sphere", i, 0) 45 | 46 | # add camera 47 | cam_props = gymapi.CameraProperties() 48 | cam_props.width, cam_props.height = 300, 300 49 | cam_handle = gym.create_camera_sensor(env, cam_props) 50 | gym.set_camera_location(cam_handle, env, gymapi.Vec3(1, 1, 1), gymapi.Vec3(0, 0, 0)) 51 | 52 | envs.append(env) 53 | cameras.append(cam_handle) 54 | 55 | # setup web viewer 56 | web_viewer.setup(gym, sim, envs, cameras) 57 | 58 | gym.prepare_sim(sim) 59 | 60 | 61 | for i in range(100000): 62 | gym.simulate(sim) 63 | 64 | # render the scene 65 | web_viewer.render(fetch_results=True, 66 | step_graphics=True, 67 | render_all_camera_sensors=True, 68 | wait_for_page_load=True) 69 | -------------------------------------------------------------------------------- /docs/source/snippets/noises.py: -------------------------------------------------------------------------------- 1 | # [start-base-class-torch] 2 | from typing 
import Union, Tuple 3 | 4 | import torch 5 | 6 | from skrl.resources.noises.torch import Noise 7 | 8 | 9 | class CustomNoise(Noise): 10 | def __init__(self, device: Union[str, torch.device] = "cuda:0") -> None: 11 | """ 12 | :param device: Device on which a torch tensor is or will be allocated (default: "cuda:0") 13 | :type device: str or torch.device, optional 14 | """ 15 | super().__init__(device) 16 | 17 | def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor: 18 | """Sample noise 19 | 20 | :param size: Shape of the sampled tensor 21 | :type size: tuple or list of integers, or torch.Size 22 | 23 | :return: Sampled noise 24 | :rtype: torch.Tensor 25 | """ 26 | # ================================ 27 | # - sample noise 28 | # ================================ 29 | # [end-base-class-torch] 30 | 31 | 32 | # [start-base-class-jax] 33 | from typing import Optional, Union, Tuple 34 | 35 | import numpy as np 36 | 37 | import jaxlib 38 | import jax.numpy as jnp 39 | 40 | from skrl.resources.noises.jax import Noise 41 | 42 | 43 | class CustomNoise(Noise): 44 | def __init__(self, device: Optional[Union[str, jaxlib.xla_extension.Device]] = None) -> None: 45 | """Custom noise 46 | 47 | :param device: Device on which a jax array is or will be allocated (default: ``None``).
48 | If None, the device will be either ``"cuda:0"`` if available or ``"cpu"`` 49 | :type device: str or jaxlib.xla_extension.Device, optional 50 | """ 51 | super().__init__(device) 52 | 53 | def sample(self, size: Tuple[int]) -> Union[np.ndarray, jnp.ndarray]: 54 | """Sample noise 55 | 56 | :param size: Shape of the sampled tensor 57 | :type size: tuple or list of integers 58 | 59 | :return: Sampled noise 60 | :rtype: np.ndarray or jnp.ndarray 61 | """ 62 | # ================================ 63 | # - sample noise 64 | # ================================ 65 | # [end-base-class-jax] 66 | 67 | # ============================================================================= 68 | 69 | # [torch-start-gaussian] 70 | from skrl.resources.noises.torch import GaussianNoise 71 | 72 | cfg = DEFAULT_CONFIG.copy() 73 | cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.2, device="cuda:0") 74 | # [torch-end-gaussian] 75 | 76 | 77 | # [jax-start-gaussian] 78 | from skrl.resources.noises.jax import GaussianNoise 79 | 80 | cfg = DEFAULT_CONFIG.copy() 81 | cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.2) 82 | # [jax-end-gaussian] 83 | 84 | # ============================================================================= 85 | 86 | # [torch-start-ornstein-uhlenbeck] 87 | from skrl.resources.noises.torch import OrnsteinUhlenbeckNoise 88 | 89 | cfg = DEFAULT_CONFIG.copy() 90 | cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.15, sigma=0.2, base_scale=1.0, device="cuda:0") 91 | # [torch-end-ornstein-uhlenbeck] 92 | 93 | 94 | # [jax-start-ornstein-uhlenbeck] 95 | from skrl.resources.noises.jax import OrnsteinUhlenbeckNoise 96 | 97 | cfg = DEFAULT_CONFIG.copy() 98 | cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.15, sigma=0.2, base_scale=1.0) 99 | # [jax-end-ornstein-uhlenbeck] 100 | -------------------------------------------------------------------------------- /docs/source/snippets/tabular_model.py: 
-------------------------------------------------------------------------------- 1 | # [start-definition-torch] 2 | class TabularModel(TabularMixin, Model): 3 | def __init__(self, observation_space, action_space, device=None, num_envs=1): 4 | Model.__init__(self, observation_space, action_space, device) 5 | TabularMixin.__init__(self, num_envs) 6 | # [end-definition-torch] 7 | 8 | # ============================================================================= 9 | 10 | # [start-epsilon-greedy-torch] 11 | import torch 12 | 13 | from skrl.models.torch import Model, TabularMixin 14 | 15 | 16 | # define the model 17 | class EpsilonGreedyPolicy(TabularMixin, Model): 18 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1): 19 | Model.__init__(self, observation_space, action_space, device) 20 | TabularMixin.__init__(self, num_envs) 21 | 22 | self.epsilon = epsilon 23 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), dtype=torch.float32) 24 | 25 | def compute(self, inputs, role): 26 | states = inputs["states"] 27 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), states], 28 | dim=-1, keepdim=True).view(-1,1) 29 | 30 | indexes = (torch.rand(states.shape[0], device=self.device) < self.epsilon).nonzero().view(-1) 31 | if indexes.numel(): 32 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device) 33 | return actions, {} 34 | 35 | 36 | # instantiate the model (assumes there is a wrapped environment: env) 37 | policy = EpsilonGreedyPolicy(observation_space=env.observation_space, 38 | action_space=env.action_space, 39 | device=env.device, 40 | num_envs=env.num_envs, 41 | epsilon=0.15) 42 | # [end-epsilon-greedy-torch] 43 | -------------------------------------------------------------------------------- /docs/source/snippets/utils_distributed.txt: -------------------------------------------------------------------------------- 1 |
[start-distributed-launcher-jax] 2 | usage: python -m skrl.utils.distributed.jax [-h] [--nnodes NNODES] 3 | [--nproc-per-node NPROC_PER_NODE] [--node-rank NODE_RANK] 4 | [--coordinator-address COORDINATOR_ADDRESS] script ... 5 | 6 | JAX Distributed Training Launcher 7 | 8 | positional arguments: 9 | script Training script path to be launched in parallel 10 | script_args Arguments for the training script 11 | 12 | options: 13 | -h, --help show this help message and exit 14 | --nnodes NNODES Number of nodes 15 | --nproc-per-node NPROC_PER_NODE, --nproc_per_node NPROC_PER_NODE 16 | Number of workers per node 17 | --node-rank NODE_RANK, --node_rank NODE_RANK 18 | Node rank for multi-node distributed training 19 | --coordinator-address COORDINATOR_ADDRESS, --coordinator_address COORDINATOR_ADDRESS 20 | IP address and port where process 0 will start a JAX service 21 | [end-distributed-launcher-jax] 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "skrl" 3 | version = "1.4.3" 4 | description = "Modular and flexible library for reinforcement learning on PyTorch and JAX" 5 | readme = "README.md" 6 | requires-python = ">=3.8" 7 | license = {text = "MIT License"} 8 | authors = [ 9 | {name = "Toni-SM"}, 10 | ] 11 | maintainers = [ 12 | {name = "Toni-SM"}, 13 | ] 14 | keywords = ["reinforcement-learning", "machine-learning", "reinforcement", "machine", "learning", "rl"] 15 | classifiers = [ 16 | "License :: OSI Approved :: MIT License", 17 | "Intended Audience :: Science/Research", 18 | "Topic :: Scientific/Engineering", 19 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 20 | "Programming Language :: Python :: 3", 21 | "Operating System :: OS Independent", 22 | ] 23 | # dependencies / optional-dependencies 24 | dependencies = [ 25 | "gymnasium", 26 | "packaging", 27 | "tensorboard", 28 | "tqdm", 29 | ] 30 
| [project.optional-dependencies] 31 | torch = [ 32 | "torch>=1.10", 33 | ] 34 | jax = [ 35 | "jax>=0.4.31", 36 | "jaxlib>=0.4.31", 37 | "flax>=0.9.0", 38 | "optax", 39 | ] 40 | all = [ 41 | "torch>=1.10", 42 | "jax>=0.4.31", 43 | "jaxlib>=0.4.31", 44 | "flax>=0.9.0", 45 | "optax", 46 | ] 47 | tests = [ 48 | "pytest", 49 | "pytest-html", 50 | "pytest-cov", 51 | "hypothesis", 52 | ] 53 | # urls 54 | [project.urls] 55 | "Homepage" = "https://github.com/Toni-SM/skrl" 56 | "Documentation" = "https://skrl.readthedocs.io" 57 | "Discussions" = "https://github.com/Toni-SM/skrl/discussions" 58 | "Bug Reports" = "https://github.com/Toni-SM/skrl/issues" 59 | "Say Thanks!" = "https://github.com/Toni-SM" 60 | "Source" = "https://github.com/Toni-SM/skrl" 61 | 62 | 63 | [tool.black] 64 | line-length = 120 65 | extend-exclude = """ 66 | ( 67 | ^/docs 68 | ) 69 | """ 70 | 71 | 72 | [tool.codespell] 73 | # run: codespell 74 | skip = "./docs/source/_static,./docs/_build,pyproject.toml" 75 | quiet-level = 3 76 | count = true 77 | 78 | 79 | [tool.isort] 80 | profile = "black" 81 | line_length = 120 82 | lines_after_imports = 2 83 | known_test = [ 84 | "warnings", 85 | "hypothesis", 86 | "pytest", 87 | ] 88 | known_annotation = ["typing"] 89 | known_framework = [ 90 | "torch", 91 | "jax", 92 | "jaxlib", 93 | "flax", 94 | "optax", 95 | "numpy", 96 | ] 97 | sections = [ 98 | "FUTURE", 99 | "ANNOTATION", 100 | "TEST", 101 | "STDLIB", 102 | "THIRDPARTY", 103 | "FRAMEWORK", 104 | "FIRSTPARTY", 105 | "LOCALFOLDER", 106 | ] 107 | no_lines_before = "THIRDPARTY" 108 | skip = ["docs"] 109 | -------------------------------------------------------------------------------- /skrl/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/agents/__init__.py -------------------------------------------------------------------------------- /skrl/agents/jax/__init__.py: 
-------------------------------------------------------------------------------- 1 | from skrl.agents.jax.base import Agent 2 | -------------------------------------------------------------------------------- /skrl/agents/jax/a2c/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.jax.a2c.a2c import A2C, A2C_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/jax/cem/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.jax.cem.cem import CEM, CEM_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/jax/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.jax.ddpg.ddpg import DDPG, DDPG_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/jax/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.jax.dqn.ddqn import DDQN, DDQN_DEFAULT_CONFIG 2 | from skrl.agents.jax.dqn.dqn import DQN, DQN_DEFAULT_CONFIG 3 | -------------------------------------------------------------------------------- /skrl/agents/jax/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.jax.ppo.ppo import PPO, PPO_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/jax/rpo/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.jax.rpo.rpo import RPO, RPO_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/jax/sac/__init__.py: -------------------------------------------------------------------------------- 1 | from 
skrl.agents.jax.sac.sac import SAC, SAC_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/jax/td3/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.jax.td3.td3 import TD3, TD3_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.base import Agent 2 | -------------------------------------------------------------------------------- /skrl/agents/torch/a2c/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.a2c.a2c import A2C, A2C_DEFAULT_CONFIG 2 | from skrl.agents.torch.a2c.a2c_rnn import A2C_RNN 3 | -------------------------------------------------------------------------------- /skrl/agents/torch/amp/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.amp.amp import AMP, AMP_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/torch/cem/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.cem.cem import CEM, CEM_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/torch/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.ddpg.ddpg import DDPG, DDPG_DEFAULT_CONFIG 2 | from skrl.agents.torch.ddpg.ddpg_rnn import DDPG_RNN 3 | -------------------------------------------------------------------------------- /skrl/agents/torch/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.dqn.ddqn import 
DDQN, DDQN_DEFAULT_CONFIG 2 | from skrl.agents.torch.dqn.dqn import DQN, DQN_DEFAULT_CONFIG 3 | -------------------------------------------------------------------------------- /skrl/agents/torch/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.ppo.ppo import PPO, PPO_DEFAULT_CONFIG 2 | from skrl.agents.torch.ppo.ppo_rnn import PPO_RNN 3 | -------------------------------------------------------------------------------- /skrl/agents/torch/q_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.q_learning.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/torch/rpo/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.rpo.rpo import RPO, RPO_DEFAULT_CONFIG 2 | from skrl.agents.torch.rpo.rpo_rnn import RPO_RNN 3 | -------------------------------------------------------------------------------- /skrl/agents/torch/sac/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.sac.sac import SAC, SAC_DEFAULT_CONFIG 2 | from skrl.agents.torch.sac.sac_rnn import SAC_RNN 3 | -------------------------------------------------------------------------------- /skrl/agents/torch/sarsa/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.sarsa.sarsa import SARSA, SARSA_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/agents/torch/td3/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.td3.td3 import TD3, TD3_DEFAULT_CONFIG 2 | from skrl.agents.torch.td3.td3_rnn import TD3_RNN 3 | 
-------------------------------------------------------------------------------- /skrl/agents/torch/trpo/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.agents.torch.trpo.trpo import TRPO, TRPO_DEFAULT_CONFIG 2 | from skrl.agents.torch.trpo.trpo_rnn import TRPO_RNN 3 | -------------------------------------------------------------------------------- /skrl/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/envs/__init__.py -------------------------------------------------------------------------------- /skrl/envs/jax.py: -------------------------------------------------------------------------------- 1 | # TODO: Delete this file in future releases 2 | 3 | from skrl import logger # isort: skip 4 | 5 | logger.warning("Using `from skrl.envs.jax import ...` is deprecated and will be removed in future versions.") 6 | logger.warning(" - Import loaders using `from skrl.envs.loaders.jax import ...`") 7 | logger.warning(" - Import wrappers using `from skrl.envs.wrappers.jax import ...`") 8 | 9 | 10 | from skrl.envs.loaders.jax import ( 11 | load_bidexhands_env, 12 | load_isaacgym_env_preview2, 13 | load_isaacgym_env_preview3, 14 | load_isaacgym_env_preview4, 15 | load_isaaclab_env, 16 | load_omniverse_isaacgym_env, 17 | ) 18 | from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper, wrap_env 19 | -------------------------------------------------------------------------------- /skrl/envs/loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/envs/loaders/__init__.py -------------------------------------------------------------------------------- /skrl/envs/loaders/jax/__init__.py: 
-------------------------------------------------------------------------------- 1 | from skrl.envs.loaders.jax.bidexhands_envs import load_bidexhands_env 2 | from skrl.envs.loaders.jax.isaacgym_envs import ( 3 | load_isaacgym_env_preview2, 4 | load_isaacgym_env_preview3, 5 | load_isaacgym_env_preview4, 6 | ) 7 | from skrl.envs.loaders.jax.isaaclab_envs import load_isaaclab_env 8 | from skrl.envs.loaders.jax.omniverse_isaacgym_envs import load_omniverse_isaacgym_env 9 | -------------------------------------------------------------------------------- /skrl/envs/loaders/jax/bidexhands_envs.py: -------------------------------------------------------------------------------- 1 | # since Bi-DexHands environments are implemented on top of PyTorch, the loader is the same 2 | 3 | from skrl.envs.loaders.torch import load_bidexhands_env 4 | -------------------------------------------------------------------------------- /skrl/envs/loaders/jax/isaacgym_envs.py: -------------------------------------------------------------------------------- 1 | # since Isaac Gym (preview) environments are implemented on top of PyTorch, the loaders are the same 2 | 3 | from skrl.envs.loaders.torch import ( # isort:skip 4 | load_isaacgym_env_preview2, 5 | load_isaacgym_env_preview3, 6 | load_isaacgym_env_preview4, 7 | ) 8 | -------------------------------------------------------------------------------- /skrl/envs/loaders/jax/isaaclab_envs.py: -------------------------------------------------------------------------------- 1 | # since Isaac Lab environments are implemented on top of PyTorch, the loader is the same 2 | 3 | from skrl.envs.loaders.torch import load_isaaclab_env 4 | -------------------------------------------------------------------------------- /skrl/envs/loaders/jax/omniverse_isaacgym_envs.py: -------------------------------------------------------------------------------- 1 | # since Omniverse Isaac Gym environments are implemented on top of PyTorch, the loader is the same 2 | 3 
| from skrl.envs.loaders.torch import load_omniverse_isaacgym_env 4 | -------------------------------------------------------------------------------- /skrl/envs/loaders/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.envs.loaders.torch.bidexhands_envs import load_bidexhands_env 2 | from skrl.envs.loaders.torch.isaacgym_envs import ( 3 | load_isaacgym_env_preview2, 4 | load_isaacgym_env_preview3, 5 | load_isaacgym_env_preview4, 6 | ) 7 | from skrl.envs.loaders.torch.isaaclab_envs import load_isaaclab_env 8 | from skrl.envs.loaders.torch.omniverse_isaacgym_envs import load_omniverse_isaacgym_env 9 | -------------------------------------------------------------------------------- /skrl/envs/torch.py: -------------------------------------------------------------------------------- 1 | # TODO: Delete this file in future releases 2 | 3 | from skrl import logger # isort: skip 4 | 5 | logger.warning("Using `from skrl.envs.torch import ...` is deprecated and will be removed in future versions.") 6 | logger.warning(" - Import loaders using `from skrl.envs.loaders.torch import ...`") 7 | logger.warning(" - Import wrappers using `from skrl.envs.wrappers.torch import ...`") 8 | 9 | 10 | from skrl.envs.loaders.torch import ( 11 | load_bidexhands_env, 12 | load_isaacgym_env_preview2, 13 | load_isaacgym_env_preview3, 14 | load_isaacgym_env_preview4, 15 | load_isaaclab_env, 16 | load_omniverse_isaacgym_env, 17 | ) 18 | from skrl.envs.wrappers.torch import MultiAgentEnvWrapper, Wrapper, wrap_env 19 | -------------------------------------------------------------------------------- /skrl/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/envs/wrappers/__init__.py -------------------------------------------------------------------------------- 
/skrl/envs/wrappers/torch/brax_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import gymnasium 4 | 5 | import torch 6 | 7 | from skrl import logger 8 | from skrl.envs.wrappers.torch.base import Wrapper 9 | from skrl.utils.spaces.torch import ( 10 | convert_gym_space, 11 | flatten_tensorized_space, 12 | tensorize_space, 13 | unflatten_tensorized_space, 14 | ) 15 | 16 | 17 | class BraxWrapper(Wrapper): 18 | def __init__(self, env: Any) -> None: 19 | """Brax environment wrapper 20 | 21 | :param env: The environment to wrap 22 | :type env: Any supported Brax environment 23 | """ 24 | super().__init__(env) 25 | 26 | import brax.envs.wrappers.gym 27 | import brax.envs.wrappers.torch 28 | 29 | env = brax.envs.wrappers.gym.VectorGymWrapper(env) 30 | env = brax.envs.wrappers.torch.TorchWrapper(env, device=self.device) 31 | self._env = env 32 | self._unwrapped = env.unwrapped 33 | 34 | @property 35 | def observation_space(self) -> gymnasium.Space: 36 | """Observation space""" 37 | return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True) 38 | 39 | @property 40 | def action_space(self) -> gymnasium.Space: 41 | """Action space""" 42 | return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True) 43 | 44 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 45 | """Perform a step in the environment 46 | 47 | :param actions: The actions to perform 48 | :type actions: torch.Tensor 49 | 50 | :return: Observation, reward, terminated, truncated, info 51 | :rtype: tuple of torch.Tensor and any other info 52 | """ 53 | observation, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) 54 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation)) 55 | truncated = torch.zeros_like(terminated) 56 | return 
observation, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info 57 | 58 | def reset(self) -> Tuple[torch.Tensor, Any]: 59 | """Reset the environment 60 | 61 | :return: Observation, info 62 | :rtype: torch.Tensor and any other info 63 | """ 64 | observation = self._env.reset() 65 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation)) 66 | return observation, {} 67 | 68 | def render(self, *args, **kwargs) -> None: 69 | """Render the environment""" 70 | frame = self._env.render(mode="rgb_array") 71 | 72 | # render the frame using OpenCV 73 | try: 74 | import cv2 75 | 76 | cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 77 | cv2.waitKey(1) 78 | except ImportError as e: 79 | logger.warning(f"Unable to import opencv-python: {e}. Frame will not be rendered.") 80 | return frame 81 | 82 | def close(self) -> None: 83 | """Close the environment""" 84 | # self._env.close() raises AttributeError: 'VectorGymWrapper' object has no attribute 'closed' 85 | pass 86 | -------------------------------------------------------------------------------- /skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple 2 | 3 | import torch 4 | 5 | from skrl.envs.wrappers.torch.base import Wrapper 6 | from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space 7 | 8 | 9 | class OmniverseIsaacGymWrapper(Wrapper): 10 | def __init__(self, env: Any) -> None: 11 | """Omniverse Isaac Gym environment wrapper 12 | 13 | :param env: The environment to wrap 14 | :type env: Any supported Omniverse Isaac Gym environment 15 | """ 16 | super().__init__(env) 17 | 18 | self._reset_once = True 19 | self._observations = None 20 | self._info = {} 21 | 22 | def run(self, trainer: Optional["omni.isaac.gym.vec_env.vec_env_mt.TrainerMT"] = None) -> None: 23 | """Run the simulation in the 
main thread 24 | 25 | This method is valid only for the Omniverse Isaac Gym multi-threaded environments 26 | 27 | :param trainer: Trainer which should implement a ``run`` method that initiates the RL loop on a new thread 28 | :type trainer: omni.isaac.gym.vec_env.vec_env_mt.TrainerMT, optional 29 | """ 30 | self._env.run(trainer) 31 | 32 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 33 | """Perform a step in the environment 34 | 35 | :param actions: The actions to perform 36 | :type actions: torch.Tensor 37 | 38 | :return: Observation, reward, terminated, truncated, info 39 | :rtype: tuple of torch.Tensor and any other info 40 | """ 41 | observations, reward, terminated, self._info = self._env.step( 42 | unflatten_tensorized_space(self.action_space, actions) 43 | ) 44 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) 45 | truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) 46 | return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info 47 | 48 | def reset(self) -> Tuple[torch.Tensor, Any]: 49 | """Reset the environment 50 | 51 | :return: Observation, info 52 | :rtype: torch.Tensor and any other info 53 | """ 54 | if self._reset_once: 55 | observations = self._env.reset() 56 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) 57 | self._reset_once = False 58 | return self._observations, self._info 59 | 60 | def render(self, *args, **kwargs) -> None: 61 | """Render the environment""" 62 | return None 63 | 64 | def close(self) -> None: 65 | """Close the environment""" 66 | self._env.close() 67 | -------------------------------------------------------------------------------- /skrl/memories/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/memories/__init__.py -------------------------------------------------------------------------------- /skrl/memories/jax/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.memories.jax.base import Memory # isort:skip 2 | 3 | from skrl.memories.jax.random import RandomMemory 4 | -------------------------------------------------------------------------------- /skrl/memories/jax/random.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import jax 4 | import numpy as np 5 | 6 | from skrl.memories.jax import Memory 7 | 8 | 9 | class RandomMemory(Memory): 10 | def __init__( 11 | self, 12 | memory_size: int, 13 | num_envs: int = 1, 14 | device: Optional[jax.Device] = None, 15 | export: bool = False, 16 | export_format: str = "pt", 17 | export_directory: str = "", 18 | replacement=True, 19 | ) -> None: 20 | """Random sampling memory 21 | 22 | Sample a batch from memory randomly 23 | 24 | :param memory_size: Maximum number of elements in the first dimension of each internal storage 25 | :type memory_size: int 26 | :param num_envs: Number of parallel environments (default: ``1``) 27 | :type num_envs: int, optional 28 | :param device: Device on which an array is or will be allocated (default: ``None``) 29 | :type device: jax.Device, optional 30 | :param export: Export the memory to a file (default: ``False``). 31 | If True, the memory will be exported when the memory is filled 32 | :type export: bool, optional 33 | :param export_format: Export format (default: ``"pt"``). 34 | Supported formats: torch (pt), numpy (np), comma separated values (csv) 35 | :type export_format: str, optional 36 | :param export_directory: Directory where the memory will be exported (default: ``""``). 
37 | If empty, the agent's experiment directory will be used 38 | :type export_directory: str, optional 39 | :param replacement: Flag to indicate whether the sample is with or without replacement (default: ``True``). 40 | Replacement implies that a value can be selected multiple times (the batch size is always guaranteed). 41 | Sampling without replacement will return a batch of maximum memory size if the memory size is less than the requested batch size 42 | :type replacement: bool, optional 43 | 44 | :raises ValueError: The export format is not supported 45 | """ 46 | super().__init__(memory_size, num_envs, device, export, export_format, export_directory) 47 | 48 | self._replacement = replacement 49 | 50 | def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> List[List[jax.Array]]: 51 | """Sample a batch from memory randomly 52 | 53 | :param names: Tensors names from which to obtain the samples 54 | :type names: tuple or list of strings 55 | :param batch_size: Number of element to sample 56 | :type batch_size: int 57 | :param mini_batches: Number of mini-batches to sample (default: ``1``) 58 | :type mini_batches: int, optional 59 | 60 | :return: Sampled data from tensors sorted according to their position in the list of names. 
61 | The sampled tensors will have the following shape: (batch size, data size) 62 | :rtype: list of jax.Array list 63 | """ 64 | # generate random indexes 65 | if self._replacement: 66 | indexes = np.random.randint(0, len(self), (batch_size,)) 67 | else: 68 | indexes = np.random.permutation(len(self))[:batch_size] 69 | 70 | return self.sample_by_index(names=names, indexes=indexes, mini_batches=mini_batches) 71 | -------------------------------------------------------------------------------- /skrl/memories/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.memories.torch.base import Memory # isort:skip 2 | 3 | from skrl.memories.torch.random import RandomMemory 4 | -------------------------------------------------------------------------------- /skrl/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/models/__init__.py -------------------------------------------------------------------------------- /skrl/models/jax/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.models.jax.base import Model # isort:skip 2 | 3 | from skrl.models.jax.categorical import CategoricalMixin 4 | from skrl.models.jax.deterministic import DeterministicMixin 5 | from skrl.models.jax.gaussian import GaussianMixin 6 | from skrl.models.jax.multicategorical import MultiCategoricalMixin 7 | -------------------------------------------------------------------------------- /skrl/models/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.models.torch.base import Model # isort:skip 2 | 3 | from skrl.models.torch.categorical import CategoricalMixin 4 | from skrl.models.torch.deterministic import DeterministicMixin 5 | from skrl.models.torch.gaussian import GaussianMixin 6 | from 
skrl.models.torch.multicategorical import MultiCategoricalMixin 7 | from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin 8 | from skrl.models.torch.tabular import TabularMixin 9 | -------------------------------------------------------------------------------- /skrl/multi_agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/multi_agents/__init__.py -------------------------------------------------------------------------------- /skrl/multi_agents/jax/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.multi_agents.jax.base import MultiAgent 2 | -------------------------------------------------------------------------------- /skrl/multi_agents/jax/ippo/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.multi_agents.jax.ippo.ippo import IPPO, IPPO_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/multi_agents/jax/mappo/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.multi_agents.jax.mappo.mappo import MAPPO, MAPPO_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- /skrl/multi_agents/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.multi_agents.torch.base import MultiAgent 2 | -------------------------------------------------------------------------------- /skrl/multi_agents/torch/ippo/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.multi_agents.torch.ippo.ippo import IPPO, IPPO_DEFAULT_CONFIG 2 | -------------------------------------------------------------------------------- 
class Noise:
    def __init__(self, device: Optional[Union[str, jax.Device]] = None) -> None:
        """Base class representing a noise

        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda"`` if available or ``"cpu"``
        :type device: str or jax.Device, optional

        Custom noises should override the ``sample`` method::

            import jax
            from skrl.resources.noises.jax import Noise

            class CustomNoise(Noise):
                def __init__(self, device=None):
                    super().__init__(device)

                def sample(self, size):
                    return jax.random.uniform(jax.random.PRNGKey(0), size)
        """
        # resolve the target device through the global configuration
        self.device = config.jax.parse_device(device)
        # True when the configured backend is JAX (NumPy is used otherwise)
        self._jax = config.jax.backend == "jax"

    def sample_like(self, tensor: Union[np.ndarray, jax.Array]) -> Union[np.ndarray, jax.Array]:
        """Sample a noise with the same size (shape) as the input tensor

        Equivalent to calling ``.sample(tensor.shape)``

        :param tensor: Input tensor used to determine output tensor size (shape)
        :type tensor: np.ndarray or jax.Array

        :return: Sampled noise
        :rtype: np.ndarray or jax.Array

        Example::

            >>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
            >>> noise.sample_like(x)
            Array([[0.57450044, 0.09968603],
                   [0.7419659 , 0.8941783 ],
                   [0.59656656, 0.45325184]], dtype=float32)
        """
        # forward the input's shape to the concrete sampling implementation
        return self.sample(tensor.shape)

    def sample(self, size: Tuple[int]) -> Union[np.ndarray, jax.Array]:
        """Noise sampling method to be implemented by the inheriting classes

        :param size: Shape of the sampled tensor
        :type size: tuple or list of int

        :raises NotImplementedError: The method is not implemented by the inheriting classes

        :return: Sampled noise
        :rtype: np.ndarray or jax.Array
        """
        raise NotImplementedError("The sampling method (.sample()) is not implemented")
# https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function
@partial(jax.jit, static_argnames=("shape"))
def _sample(mean, std, key, iterator, shape):
    # fold the call counter into the base key so every call draws from a fresh stream
    subkey = jax.random.fold_in(key, iterator)
    return jax.random.normal(subkey, shape) * std + mean


class GaussianNoise(Noise):
    def __init__(self, mean: float, std: float, device: Optional[Union[str, jax.Device]] = None) -> None:
        """Class representing a Gaussian noise

        :param mean: Mean of the normal distribution
        :type mean: float
        :param std: Standard deviation of the normal distribution
        :type std: float
        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda"`` if available or ``"cpu"``
        :type device: str or jax.Device, optional

        Example::

            >>> noise = GaussianNoise(mean=0, std=1)
        """
        super().__init__(device)

        if not self._jax:
            # NumPy backend: plain arrays, sampling delegates to np.random
            self.mean = np.array(mean)
            self.std = np.array(std)
            return
        # JAX backend: keep the global PRNG key plus a per-instance call counter
        self._i = 0
        self._key = config.jax.key
        self.mean = jnp.array(mean)
        self.std = jnp.array(std)

    def sample(self, size: Tuple[int]) -> Union[np.ndarray, jax.Array]:
        """Sample a Gaussian noise

        :param size: Shape of the sampled tensor
        :type size: tuple or list of int

        :return: Sampled noise
        :rtype: np.ndarray or jax.Array

        Example::

            >>> noise.sample((3, 2))
            Array([[ 0.01878439, -0.12833427],
                   [ 0.06494182,  0.12490594],
                   [ 0.024447  , -0.01174496]], dtype=float32)

            >>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
            >>> noise.sample(x.shape)
            Array([[ 0.17988093, -1.2289404 ],
                   [ 0.6218886 ,  1.1961104 ],
                   [ 0.23410667, -0.11247082]], dtype=float32)
        """
        if not self._jax:
            return np.random.normal(self.mean, self.std, size)
        # advance the counter so successive calls derive different subkeys
        self._i += 1
        return _sample(self.mean, self.std, self._key, self._i, size)
import config 6 | 7 | 8 | class Noise: 9 | def __init__(self, device: Optional[Union[str, torch.device]] = None) -> None: 10 | """Base class representing a noise 11 | 12 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 13 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 14 | :type device: str or torch.device, optional 15 | 16 | Custom noises should override the ``sample`` method:: 17 | 18 | import torch 19 | from skrl.resources.noises.torch import Noise 20 | 21 | class CustomNoise(Noise): 22 | def __init__(self, device=None): 23 | super().__init__(device) 24 | 25 | def sample(self, size): 26 | return torch.rand(size, device=self.device) 27 | """ 28 | self.device = config.torch.parse_device(device) 29 | 30 | def sample_like(self, tensor: torch.Tensor) -> torch.Tensor: 31 | """Sample a noise with the same size (shape) as the input tensor 32 | 33 | This method will call the sampling method as follows ``.sample(tensor.shape)`` 34 | 35 | :param tensor: Input tensor used to determine output tensor size (shape) 36 | :type tensor: torch.Tensor 37 | 38 | :return: Sampled noise 39 | :rtype: torch.Tensor 40 | 41 | Example:: 42 | 43 | >>> x = torch.rand(3, 2, device="cuda:0") 44 | >>> noise.sample_like(x) 45 | tensor([[-0.0423, -0.1325], 46 | [-0.0639, -0.0957], 47 | [-0.1367, 0.1031]], device='cuda:0') 48 | """ 49 | return self.sample(tensor.shape) 50 | 51 | def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor: 52 | """Noise sampling method to be implemented by the inheriting classes 53 | 54 | :param size: Shape of the sampled tensor 55 | :type size: tuple or list of int, or torch.Size 56 | 57 | :raises NotImplementedError: The method is not implemented by the inheriting classes 58 | 59 | :return: Sampled noise 60 | :rtype: torch.Tensor 61 | """ 62 | raise NotImplementedError("The sampling method (.sample()) is not implemented") 63 | 
-------------------------------------------------------------------------------- /skrl/resources/noises/torch/gaussian.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch.distributions import Normal 5 | 6 | from skrl.resources.noises.torch import Noise 7 | 8 | 9 | # speed up distribution construction by disabling checking 10 | Normal.set_default_validate_args(False) 11 | 12 | 13 | class GaussianNoise(Noise): 14 | def __init__(self, mean: float, std: float, device: Optional[Union[str, torch.device]] = None) -> None: 15 | """Class representing a Gaussian noise 16 | 17 | :param mean: Mean of the normal distribution 18 | :type mean: float 19 | :param std: Standard deviation of the normal distribution 20 | :type std: float 21 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 22 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 23 | :type device: str or torch.device, optional 24 | 25 | Example:: 26 | 27 | >>> noise = GaussianNoise(mean=0, std=1) 28 | """ 29 | super().__init__(device) 30 | 31 | self.distribution = Normal( 32 | loc=torch.tensor(mean, device=self.device, dtype=torch.float32), 33 | scale=torch.tensor(std, device=self.device, dtype=torch.float32), 34 | ) 35 | 36 | def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor: 37 | """Sample a Gaussian noise 38 | 39 | :param size: Shape of the sampled tensor 40 | :type size: tuple or list of int, or torch.Size 41 | 42 | :return: Sampled noise 43 | :rtype: torch.Tensor 44 | 45 | Example:: 46 | 47 | >>> noise.sample((3, 2)) 48 | tensor([[-0.4901, 1.3357], 49 | [-1.2141, 0.3323], 50 | [-0.0889, -1.1651]], device='cuda:0') 51 | 52 | >>> x = torch.rand(3, 2, device="cuda:0") 53 | >>> noise.sample(x.shape) 54 | tensor([[0.5398, 1.2009], 55 | [0.0307, 1.3065], 56 | [0.2082, 0.6116]], device='cuda:0') 57 | """ 58 | return 
class OrnsteinUhlenbeckNoise(Noise):
    def __init__(
        self,
        theta: float,
        sigma: float,
        base_scale: float,
        mean: float = 0,
        std: float = 1,
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        """Class representing an Ornstein-Uhlenbeck noise

        :param theta: Factor to apply to current internal state
        :type theta: float
        :param sigma: Factor to apply to the normal distribution
        :type sigma: float
        :param base_scale: Factor to apply to returned noise
        :type base_scale: float
        :param mean: Mean of the normal distribution (default: ``0.0``)
        :type mean: float, optional
        :param std: Standard deviation of the normal distribution (default: ``1.0``)
        :type std: float, optional
        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda"`` if available or ``"cpu"``
        :type device: str or torch.device, optional

        Example::

            >>> noise = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=0.5)
        """
        super().__init__(device)

        # internal OU state; starts as an int zero until the first sampled shape fixes it
        self.state = 0
        self.theta = theta
        self.sigma = sigma
        self.base_scale = base_scale

        # underlying Gaussian driving the process, allocated once on the target device
        loc = torch.tensor(mean, device=self.device, dtype=torch.float32)
        scale = torch.tensor(std, device=self.device, dtype=torch.float32)
        self.distribution = Normal(loc=loc, scale=scale)

    def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
        """Sample an Ornstein-Uhlenbeck noise

        :param size: Shape of the sampled tensor
        :type size: tuple or list of int, or torch.Size

        :return: Sampled noise
        :rtype: torch.Tensor

        Example::

            >>> noise.sample((3, 2))
            tensor([[-0.0452,  0.0162],
                    [ 0.0649, -0.0708],
                    [-0.0211,  0.0066]], device='cuda:0')

            >>> x = torch.rand(3, 2, device="cuda:0")
            >>> noise.sample(x.shape)
            tensor([[-0.0540,  0.0461],
                    [ 0.1117, -0.1157],
                    [-0.0074,  0.0420]], device='cuda:0')
        """
        # restart the process whenever the requested shape differs from the stored state
        if hasattr(self.state, "shape") and self.state.shape != torch.Size(size):
            self.state = 0
        # discrete OU update: x += -theta * x + sigma * N(mean, std)
        drift = -self.state * self.theta
        diffusion = self.sigma * self.distribution.sample(size)
        self.state += drift + diffusion

        return self.base_scale * self.state
skrl.resources.optimizers.jax.adam import Adam 2 | -------------------------------------------------------------------------------- /skrl/resources/preprocessors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/resources/preprocessors/__init__.py -------------------------------------------------------------------------------- /skrl/resources/preprocessors/jax/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.resources.preprocessors.jax.running_standard_scaler import RunningStandardScaler 2 | -------------------------------------------------------------------------------- /skrl/resources/preprocessors/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.resources.preprocessors.torch.running_standard_scaler import RunningStandardScaler 2 | -------------------------------------------------------------------------------- /skrl/resources/schedulers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/resources/schedulers/__init__.py -------------------------------------------------------------------------------- /skrl/resources/schedulers/jax/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.resources.schedulers.jax.kl_adaptive import KLAdaptiveLR, kl_adaptive 2 | 3 | 4 | KLAdaptiveRL = KLAdaptiveLR # known typo (compatibility with versions prior to 1.0.0) 5 | -------------------------------------------------------------------------------- /skrl/resources/schedulers/jax/kl_adaptive.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 
def KLAdaptiveLR(
    kl_threshold: float = 0.008,
    min_lr: float = 1e-6,
    max_lr: float = 1e-2,
    kl_factor: float = 2,
    lr_factor: float = 1.5,
) -> optax.Schedule:
    """Adaptive KL scheduler

    Adjusts the learning rate according to the KL divergence.
    The implementation is adapted from the *rl_games* library.

    .. note::

        This scheduler is only available for PPO at the moment.
        Applying it to other agents will not change the learning rate

    Example::

        >>> scheduler = KLAdaptiveLR(kl_threshold=0.01)
        >>> for epoch in range(100):
        >>>     # ...
        >>>     kl_divergence = ...
        >>>     new_lr = scheduler(timestep, lr, kl_divergence)

    :param kl_threshold: Threshold for KL divergence (default: ``0.008``)
    :type kl_threshold: float, optional
    :param min_lr: Lower bound for learning rate (default: ``1e-6``)
    :type min_lr: float, optional
    :param max_lr: Upper bound for learning rate (default: ``1e-2``)
    :type max_lr: float, optional
    :param kl_factor: The number used to modify the KL divergence threshold (default: ``2``)
    :type kl_factor: float, optional
    :param lr_factor: The number used to modify the learning rate (default: ``1.5``)
    :type lr_factor: float, optional

    :return: A function that maps step counts, current learning rate and KL divergence to the new learning rate value.
             If no learning rate is specified, 1.0 will be returned to mimic the Optax's scheduler behaviors.
             If the learning rate is specified but the KL divergence is not, the learning rate is returned unchanged.
    :rtype: optax.Schedule
    """

    def schedule(count: int, lr: Optional[float] = None, kl: Optional[Union[np.ndarray, float]] = None) -> float:
        # no learning rate given: behave like a plain Optax schedule
        if lr is None:
            return 1.0
        # no KL divergence given: keep the learning rate untouched
        if kl is None:
            return lr
        # shrink the rate when the policy moved too far; grow it when it moved too little
        if kl > kl_threshold * kl_factor:
            return max(lr / lr_factor, min_lr)
        if kl < kl_threshold / kl_factor:
            return min(lr * lr_factor, max_lr)
        return lr

    return schedule


# Alias to maintain naming compatibility with Optax schedulers
# https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html
kl_adaptive = KLAdaptiveLR
from skrl.trainers.torch.parallel import ParallelTrainer 4 | from skrl.trainers.torch.sequential import SequentialTrainer 5 | from skrl.trainers.torch.step import StepTrainer 6 | -------------------------------------------------------------------------------- /skrl/utils/control.py: -------------------------------------------------------------------------------- 1 | import isaacgym.torch_utils as torch_utils 2 | 3 | import torch 4 | 5 | 6 | def ik( 7 | jacobian_end_effector, current_position, current_orientation, goal_position, goal_orientation, damping_factor=0.05 8 | ): 9 | """ 10 | Damped Least Squares method: https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf 11 | """ 12 | 13 | # compute position and orientation error 14 | position_error = goal_position - current_position 15 | q_r = torch_utils.quat_mul(goal_orientation, torch_utils.quat_conjugate(current_orientation)) 16 | orientation_error = q_r[:, 0:3] * torch.sign(q_r[:, 3]).unsqueeze(-1) 17 | 18 | dpose = torch.cat([position_error, orientation_error], -1).unsqueeze(-1) 19 | 20 | # solve damped least squares (dO = J.T * V) 21 | transpose = torch.transpose(jacobian_end_effector, 1, 2) 22 | lmbda = torch.eye(6).to(jacobian_end_effector.device) * (damping_factor**2) 23 | return transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ dpose 24 | 25 | 26 | def osc( 27 | jacobian_end_effector, 28 | mass_matrix, 29 | current_position, 30 | current_orientation, 31 | goal_position, 32 | goal_orientation, 33 | current_dof_velocities, 34 | kp=5, 35 | kv=2, 36 | ): 37 | """ 38 | https://studywolf.wordpress.com/2013/09/17/robot-control-4-operation-space-control/ 39 | """ 40 | 41 | mass_matrix_end_effector = torch.inverse( 42 | jacobian_end_effector @ torch.inverse(mass_matrix) @ torch.transpose(jacobian_end_effector, 1, 2) 43 | ) 44 | 45 | # compute position and orientation error 46 | position_error = kp * (goal_position - current_position) 47 | q_r = torch_utils.quat_mul(goal_orientation, 
def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> str:
    """Download a model from Hugging Face Hub

    :param repo_id: Hugging Face user or organization name and a repo name separated by a ``/``
    :type repo_id: str
    :param filename: The name of the model file in the repo (default: ``"agent.pt"``)
    :type filename: str, optional

    :raises ImportError: The Hugging Face Hub package (huggingface-hub) is not installed
    :raises huggingface_hub.utils._errors.HfHubHTTPError: Any HTTP error raised in Hugging Face Hub

    :return: Local path of file or if networking is off, last version of file cached on disk
    :rtype: str

    Example::

        # download trained agent from the skrl organization (https://huggingface.co/skrl)
        >>> from skrl.utils.huggingface import download_model_from_huggingface
        >>> download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
        '/home/user/.cache/huggingface/hub/models--skrl--OmniIsaacGymEnvs-Cartpole-PPO/snapshots/892e629903de6bf3ef102ae760406a5dd0f6f873/agent.pt'

        # download model (e.g. "policy.pth") from another user/organization (e.g. "org/ddpg-Pendulum-v1")
        >>> from skrl.utils.huggingface import download_model_from_huggingface
        >>> download_model_from_huggingface("org/ddpg-Pendulum-v1", "policy.pth")
        '/home/user/.cache/huggingface/hub/models--org--ddpg-Pendulum-v1/snapshots/b44ee96f93ff2e296156b002a2ca4646e197ba32/policy.pth'
    """
    # log the fully-qualified file being fetched
    # (fixed: the message interpolated the literal "(unknown)" instead of the filename)
    logger.info(f"Downloading model from Hugging Face Hub: {repo_id}/{filename}")
    try:
        import huggingface_hub
    except ImportError as e:
        # keep the log entry, then raise a chained ImportError directly instead of the
        # original None-sentinel assignment followed by a separate re-check
        logger.error("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
        raise ImportError(
            "Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it"
        ) from e

    # download and cache the model from Hugging Face Hub
    return huggingface_hub.hf_hub_download(
        repo_id=repo_id, filename=filename, library_name="skrl", library_version=__version__
    )
Use 'pip install huggingface-hub' to install it") 39 | 40 | # download and cache the model from Hugging Face Hub 41 | downloaded_model_file = huggingface_hub.hf_hub_download( 42 | repo_id=repo_id, filename=filename, library_name="skrl", library_version=__version__ 43 | ) 44 | 45 | return downloaded_model_file 46 | -------------------------------------------------------------------------------- /skrl/utils/model_instantiators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/utils/model_instantiators/__init__.py -------------------------------------------------------------------------------- /skrl/utils/model_instantiators/jax/__init__.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from skrl.utils.model_instantiators.jax.categorical import categorical_model 4 | from skrl.utils.model_instantiators.jax.deterministic import deterministic_model 5 | from skrl.utils.model_instantiators.jax.gaussian import gaussian_model 6 | from skrl.utils.model_instantiators.jax.multicategorical import multicategorical_model 7 | 8 | 9 | # keep for compatibility with versions prior to 1.3.0 10 | class Shape(Enum): 11 | """ 12 | Enum to select the shape of the model's inputs and outputs 13 | """ 14 | 15 | ONE = 1 16 | STATES = 0 17 | OBSERVATIONS = 0 18 | ACTIONS = -1 19 | STATES_ACTIONS = -2 20 | -------------------------------------------------------------------------------- /skrl/utils/model_instantiators/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from skrl.utils.model_instantiators.torch.categorical import categorical_model 4 | from skrl.utils.model_instantiators.torch.deterministic import deterministic_model 5 | from skrl.utils.model_instantiators.torch.gaussian import gaussian_model 6 | 
# keep for compatibility with versions prior to 1.3.0
class Shape(Enum):
    """
    Enum to select the shape of the model's inputs and outputs

    NOTE(review): member values are sentinel codes consumed by the model
    instantiators, not literal sizes; their exact semantics are defined in the
    instantiator implementations -- confirm there before relying on them.
    """

    ONE = 1
    STATES = 0
    # same value as STATES: Enum aliasing makes OBSERVATIONS an alias of STATES
    # (Shape.OBSERVATIONS is Shape.STATES)
    OBSERVATIONS = 0
    ACTIONS = -1
    STATES_ACTIONS = -2
flatten_tensorized_space, 5 | sample_space, 6 | tensorize_space, 7 | unflatten_tensorized_space, 8 | untensorize_space, 9 | ) 10 | -------------------------------------------------------------------------------- /skrl/utils/spaces/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from skrl.utils.spaces.torch.spaces import ( 2 | compute_space_size, 3 | convert_gym_space, 4 | flatten_tensorized_space, 5 | sample_space, 6 | tensorize_space, 7 | unflatten_tensorized_space, 8 | untensorize_space, 9 | ) 10 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/__init__.py -------------------------------------------------------------------------------- /tests/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/agents/__init__.py -------------------------------------------------------------------------------- /tests/agents/jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/agents/jax/__init__.py -------------------------------------------------------------------------------- /tests/agents/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/agents/torch/__init__.py -------------------------------------------------------------------------------- /tests/envs/__init__.py: -------------------------------------------------------------------------------- 
@pytest.mark.parametrize("backend", ["jax", "numpy"])
def test_env(capsys: pytest.CaptureFixture, backend: str):
    """End-to-end conformance check of BraxWrapper against skrl's single-agent env API.

    Runs under both the "jax" and "numpy" skrl backends and verifies wrapper
    detection, the exposed spaces/properties, and the reset/step/render/close
    cycle on a batched Brax "inverted_pendulum" environment (observation dim 4,
    action dim 1, as asserted below).
    """
    config.jax.backend = backend
    # the expected array type depends on the configured skrl backend
    Array = jax.Array if backend == "jax" else np.ndarray

    num_envs = 10
    action = jnp.ones((num_envs, 1)) if backend == "jax" else np.ones((num_envs, 1))

    # load and wrap the environment: skip locally when Brax is missing,
    # but fail on GitHub Actions where the dependency is expected
    try:
        import brax.envs
    except ImportError as e:
        if is_running_on_github_actions():
            raise e
        else:
            pytest.skip(f"Unable to import Brax environment: {e}")

    original_env = brax.envs.create("inverted_pendulum", batch_size=num_envs, backend="spring")
    # both auto-detection and the explicit wrapper name must produce a BraxWrapper
    env = wrap_env(original_env, "auto")
    assert isinstance(env, BraxWrapper)
    env = wrap_env(original_env, "brax")
    assert isinstance(env, BraxWrapper)

    # check properties
    assert env.state_space is None
    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (4,)
    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
    assert isinstance(env.num_envs, int) and env.num_envs == num_envs
    assert isinstance(env.num_agents, int) and env.num_agents == 1
    assert isinstance(env.device, jax.Device)
    # check internal properties: identity (is not) rather than equality, because
    # brax's VectorGymWrapper also defines _env/_unwrapped and interferes with the check
    assert env._env is not original_env
    assert env._unwrapped is not original_env.unwrapped
    # check methods
    for _ in range(2):
        observation, info = env.reset()
        observation, info = env.reset()  # edge case: parallel environments are autoreset
        assert isinstance(observation, Array) and observation.shape == (num_envs, 4)
        assert isinstance(info, Mapping)
        for _ in range(3):
            observation, reward, terminated, truncated, info = env.step(action)
            try:
                env.render()
            except AttributeError as e:
                # rendering is best-effort: some Brax versions raise AttributeError here
                warnings.warn(f"Brax exception when rendering: {e}")
            assert isinstance(observation, Array) and observation.shape == (num_envs, 4)
            assert isinstance(reward, Array) and reward.shape == (num_envs, 1)
            assert isinstance(terminated, Array) and terminated.shape == (num_envs, 1)
            assert isinstance(truncated, Array) and truncated.shape == (num_envs, 1)
            assert isinstance(info, Mapping)

    env.close()
def test_env(capsys: pytest.CaptureFixture):
    """End-to-end conformance check of the torch BraxWrapper against skrl's single-agent env API.

    Verifies wrapper detection, the exposed spaces/properties, and the
    reset/step/render/close cycle on a batched Brax "inverted_pendulum"
    environment (observation dim 4, action dim 1, as asserted below).
    """
    num_envs = 10
    action = torch.ones((num_envs, 1))

    # load and wrap the environment: skip locally when Brax is missing,
    # but fail on GitHub Actions where the dependency is expected
    try:
        import brax.envs
    except ImportError as e:
        if is_running_on_github_actions():
            raise e
        else:
            pytest.skip(f"Unable to import Brax environment: {e}")

    original_env = brax.envs.create("inverted_pendulum", batch_size=num_envs, backend="spring")
    # both auto-detection and the explicit wrapper name must produce a BraxWrapper
    env = wrap_env(original_env, "auto")
    assert isinstance(env, BraxWrapper)
    env = wrap_env(original_env, "brax")
    assert isinstance(env, BraxWrapper)

    # check properties
    assert env.state_space is None
    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (4,)
    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
    assert isinstance(env.num_envs, int) and env.num_envs == num_envs
    assert isinstance(env.num_agents, int) and env.num_agents == 1
    assert isinstance(env.device, torch.device)
    # check internal properties: identity (is not) rather than equality, because
    # brax's VectorGymWrapper also defines _env/_unwrapped and interferes with the check
    assert env._env is not original_env
    assert env._unwrapped is not original_env.unwrapped
    # check methods
    for _ in range(2):
        observation, info = env.reset()
        observation, info = env.reset()  # edge case: parallel environments are autoreset
        assert isinstance(observation, torch.Tensor) and observation.shape == torch.Size([num_envs, 4])
        assert isinstance(info, Mapping)
        for _ in range(3):
            observation, reward, terminated, truncated, info = env.step(action)
            try:
                env.render()
            except AttributeError as e:
                # rendering is best-effort: some Brax versions raise AttributeError here
                warnings.warn(f"Brax exception when rendering: {e}")
            assert isinstance(observation, torch.Tensor) and observation.shape == torch.Size([num_envs, 4])
            assert isinstance(reward, torch.Tensor) and reward.shape == torch.Size([num_envs, 1])
            assert isinstance(terminated, torch.Tensor) and terminated.shape == torch.Size([num_envs, 1])
            assert isinstance(truncated, torch.Tensor) and truncated.shape == torch.Size([num_envs, 1])
            assert isinstance(info, Mapping)

    env.close()
properties 34 | assert env.state_space is None 35 | assert isinstance(env.observation_space, gym.Space) and sorted(list(env.observation_space.keys())) == [ 36 | "orientation", 37 | "velocity", 38 | ] 39 | assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,) 40 | assert isinstance(env.num_envs, int) and env.num_envs == num_envs 41 | assert isinstance(env.num_agents, int) and env.num_agents == 1 42 | assert isinstance(env.device, torch.device) 43 | # check internal properties 44 | assert env._env is original_env 45 | assert env._unwrapped is original_env 46 | # check methods 47 | for _ in range(2): 48 | observation, info = env.reset() 49 | assert isinstance(observation, torch.Tensor) and observation.shape == torch.Size([num_envs, 3]) 50 | assert isinstance(info, Mapping) 51 | for _ in range(3): 52 | observation, reward, terminated, truncated, info = env.step(action) 53 | if not is_running_on_github_actions(): 54 | env.render() 55 | assert isinstance(observation, torch.Tensor) and observation.shape == torch.Size([num_envs, 3]) 56 | assert isinstance(reward, torch.Tensor) and reward.shape == torch.Size([num_envs, 1]) 57 | assert isinstance(terminated, torch.Tensor) and terminated.shape == torch.Size([num_envs, 1]) 58 | assert isinstance(truncated, torch.Tensor) and truncated.shape == torch.Size([num_envs, 1]) 59 | assert isinstance(info, Mapping) 60 | 61 | env.close() 62 | -------------------------------------------------------------------------------- /tests/memories/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/memories/__init__.py -------------------------------------------------------------------------------- /tests/memories/torch/__init__.py: -------------------------------------------------------------------------------- 
import hypothesis
import hypothesis.strategies as st
import pytest

import torch

from skrl import config
from skrl.memories.torch import Memory


@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
def test_device(capsys, device):
    """A memory and the tensors it creates are placed on the device parsed from ``device``."""
    memory = Memory(memory_size=5, num_envs=1, device=device)
    memory.create_tensor("buffer", size=1)

    target_device = config.torch.parse_device(device)
    assert memory.device == target_device
    assert memory.get_tensor_by_name("buffer").device == target_device


# __len__


def test_share_memory(capsys):
    """share_memory() is callable on a memory that already holds tensors.

    NOTE(review): device="cuda" is hard-coded, so this test errors on
    CPU-only machines -- confirm whether a skip/guard is intended.
    """
    memory = Memory(memory_size=5, num_envs=1, device="cuda")
    memory.create_tensor("buffer", size=1)

    memory.share_memory()


@hypothesis.given(
    tensor_names=st.lists(
        st.text(st.characters(codec="ascii", categories=("Nd", "L")), min_size=1, max_size=5),  # codespell:ignore
        min_size=0,
        max_size=5,
        unique=True,
    )
)
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
    deadline=None,
    phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate],
)
def test_get_tensor_names(capsys, tensor_names):
    """get_tensor_names() returns the created tensor names in sorted order (empty list included)."""
    memory = Memory(memory_size=5, num_envs=1)
    for name in tensor_names:
        memory.create_tensor(name, size=1)

    assert memory.get_tensor_names() == sorted(tensor_names)


@hypothesis.given(
    tensor_name=st.text(
        st.characters(codec="ascii", categories=("Nd", "L")), min_size=1, max_size=5  # codespell:ignore
    )
)
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
    deadline=None,
    phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate],
)
@pytest.mark.parametrize("keepdim", [True, False])
def test_get_tensor_by_name(capsys, tensor_name, keepdim):
    """get_tensor_by_name() keeps (memory_size, num_envs, size) with keepdim, else flattens the first two dims."""
    memory = Memory(memory_size=5, num_envs=2)
    memory.create_tensor(tensor_name, size=1)

    target_shape = (5, 2, 1) if keepdim else (10, 1)
    assert memory.get_tensor_by_name(tensor_name, keepdim=keepdim).shape == target_shape


@hypothesis.given(
    tensor_name=st.text(
        st.characters(codec="ascii", categories=("Nd", "L")), min_size=1, max_size=5  # codespell:ignore
    )
)
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
    deadline=None,
    phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate],
)
def test_set_tensor_by_name(capsys, tensor_name):
    """set_tensor_by_name() stores the given tensor so that every element reads back unchanged."""
    memory = Memory(memory_size=5, num_envs=2)
    memory.create_tensor(tensor_name, size=1)

    target_tensor = torch.arange(10, device=memory.device).reshape(5, 2, 1)
    memory.set_tensor_by_name(tensor_name, target_tensor)
    # torch.all: every element must round-trip. The previous torch.any would
    # pass as soon as a single element matched, hiding partially-failed writes.
    assert torch.all(memory.get_tensor_by_name(tensor_name, keepdim=True) == target_tensor)
@st.composite
def gym_space_stategy(draw, space_type: str = "", remaining_iterations: int = 5) -> "gym.spaces.Space":
    """Hypothesis strategy that draws a random (possibly nested) gym space.

    :param space_type: force a specific space type; an empty string draws one at random
    :param remaining_iterations: nesting budget; "Dict"/"Tuple" degrade to "Box" once exhausted
    :raises ValueError: if an unknown space type is requested
    """
    import gym

    # choose a space type when none was requested
    if not space_type:
        space_type = draw(st.sampled_from(["Box", "Discrete", "MultiDiscrete", "Dict", "Tuple"]))
    # recursion base case: stop nesting once the budget is exhausted
    if remaining_iterations <= 0 and space_type in ["Dict", "Tuple"]:
        space_type = "Box"

    # small positive integers are used for every dimension/cardinality draw
    bounded_int = st.integers(min_value=1, max_value=5)

    if space_type == "Box":
        dims = draw(st.lists(bounded_int, min_size=1, max_size=5))
        return gym.spaces.Box(low=-1, high=1, shape=dims)
    if space_type == "Discrete":
        return gym.spaces.Discrete(draw(bounded_int))
    if space_type == "MultiDiscrete":
        return gym.spaces.MultiDiscrete(draw(st.lists(bounded_int, min_size=1, max_size=5)))
    if space_type == "Dict":
        budget = remaining_iterations - 1
        names = draw(st.lists(st.text(st.characters(codec="ascii"), min_size=1, max_size=5), min_size=1, max_size=3))
        return gym.spaces.Dict({name: draw(gym_space_stategy(remaining_iterations=budget)) for name in names})
    if space_type == "Tuple":
        budget = remaining_iterations - 1
        children = draw(st.lists(gym_space_stategy(remaining_iterations=budget), min_size=1, max_size=3))
        return gym.spaces.Tuple(children)
    raise ValueError(f"Invalid space type: {space_type}")
def is_device_available(device, *, backend) -> bool:
    """Return whether ``device`` can actually allocate memory under the given backend.

    :param device: device identifier accepted by the backend (e.g. "cpu", "cuda:0")
    :param backend: computation backend; only "torch" is supported
    :return: True if a tiny allocation on the device succeeds, False otherwise
    :raises ValueError: if the backend is not supported
    """
    if backend == "torch":
        import torch

        try:
            # allocating a one-element tensor is the cheapest reliable availability probe
            torch.zeros((1,), device=device)
        except Exception:
            # any backend error (invalid device, missing driver, OOM) means unavailable
            return False
    else:
        raise ValueError(f"Invalid backend: {backend}")
    return True
def get_test_mixed_precision(default):
    """Resolve the mixed-precision setting for tests from ``SKRL_TEST_MIXED_PRECISION``.

    Returns ``default`` only when the environment variable is explicitly truthy
    ("true"/"1"/"y"/"yes"); returns False when it is unset or explicitly falsy
    ("false"/"0"/"n"/"no").

    NOTE(review): the unset branch ignores ``default`` (always False) -- confirm
    this opt-in behavior is intended.

    :raises ValueError: if the variable is set to an unrecognized value
    """
    value = os.environ.get("SKRL_TEST_MIXED_PRECISION")
    if value is None:
        return False
    normalized = value.lower()
    if normalized in ("true", "1", "y", "yes"):
        return default
    if normalized in ("false", "0", "n", "no"):
        return False
    raise ValueError(f"Invalid value for environment variable SKRL_TEST_MIXED_PRECISION: {value}")
-------------------------------------------------------------------------------- /tests/utils/model_instantiators/jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/model_instantiators/jax/__init__.py -------------------------------------------------------------------------------- /tests/utils/model_instantiators/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/model_instantiators/torch/__init__.py -------------------------------------------------------------------------------- /tests/utils/spaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/spaces/__init__.py -------------------------------------------------------------------------------- /tests/utils/spaces/jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/spaces/jax/__init__.py -------------------------------------------------------------------------------- /tests/utils/spaces/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/spaces/torch/__init__.py --------------------------------------------------------------------------------