├── .github
├── GITHUB_ACTIONS.md
├── ISSUE_TEMPLATE
│ ├── bug_report.yaml
│ └── config.yml
└── workflows
│ ├── pre-commit.yml
│ ├── python-publish-manual.yml
│ ├── tests-jax.yml
│ └── tests-torch.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
├── Makefile
├── README.md
├── make.bat
├── requirements.txt
└── source
│ ├── 404.rst
│ ├── _static
│ ├── css
│ │ ├── s5defs-roles.css
│ │ └── skrl.css
│ ├── data
│ │ ├── 404-dark.svg
│ │ ├── 404-light.svg
│ │ ├── favicon.ico
│ │ ├── logo-dark-mode.png
│ │ ├── logo-jax.svg
│ │ ├── logo-light-mode.png
│ │ ├── logo-torch.svg
│ │ ├── skrl-up-transparent.png
│ │ └── skrl-up.png
│ └── imgs
│ │ ├── data_tensorboard.jpg
│ │ ├── example_bidexhands.png
│ │ ├── example_deepmind.png
│ │ ├── example_gym.png
│ │ ├── example_isaacgym.png
│ │ ├── example_isaaclab.png
│ │ ├── example_omniverse_isaacgym.png
│ │ ├── example_parallel.jpg
│ │ ├── example_robosuite.png
│ │ ├── example_shimmy.png
│ │ ├── manual_trainer-dark.svg
│ │ ├── manual_trainer-light.svg
│ │ ├── model_categorical-dark.svg
│ │ ├── model_categorical-light.svg
│ │ ├── model_categorical_cnn-dark.svg
│ │ ├── model_categorical_cnn-light.svg
│ │ ├── model_categorical_mlp-dark.svg
│ │ ├── model_categorical_mlp-light.svg
│ │ ├── model_categorical_rnn-dark.svg
│ │ ├── model_categorical_rnn-light.svg
│ │ ├── model_deterministic-dark.svg
│ │ ├── model_deterministic-light.svg
│ │ ├── model_deterministic_cnn-dark.svg
│ │ ├── model_deterministic_cnn-light.svg
│ │ ├── model_deterministic_mlp-dark.svg
│ │ ├── model_deterministic_mlp-light.svg
│ │ ├── model_deterministic_rnn-dark.svg
│ │ ├── model_deterministic_rnn-light.svg
│ │ ├── model_gaussian-dark.svg
│ │ ├── model_gaussian-light.svg
│ │ ├── model_gaussian_cnn-dark.svg
│ │ ├── model_gaussian_cnn-light.svg
│ │ ├── model_gaussian_mlp-dark.svg
│ │ ├── model_gaussian_mlp-light.svg
│ │ ├── model_gaussian_rnn-dark.svg
│ │ ├── model_gaussian_rnn-light.svg
│ │ ├── model_multicategorical-dark.svg
│ │ ├── model_multicategorical-light.svg
│ │ ├── model_multivariate_gaussian-dark.svg
│ │ ├── model_multivariate_gaussian-light.svg
│ │ ├── multi_agent_wrapping-dark.svg
│ │ ├── multi_agent_wrapping-light.svg
│ │ ├── noise_gaussian.png
│ │ ├── noise_ornstein_uhlenbeck.png
│ │ ├── parallel_trainer-dark.svg
│ │ ├── parallel_trainer-light.svg
│ │ ├── rl_schema-dark.svg
│ │ ├── rl_schema-light.svg
│ │ ├── sequential_trainer-dark.svg
│ │ ├── sequential_trainer-light.svg
│ │ ├── utils_tensorboard_file_iterator.svg
│ │ ├── wrapping-dark.svg
│ │ └── wrapping-light.svg
│ ├── api
│ ├── agents.rst
│ ├── agents
│ │ ├── a2c.rst
│ │ ├── amp.rst
│ │ ├── cem.rst
│ │ ├── ddpg.rst
│ │ ├── ddqn.rst
│ │ ├── dqn.rst
│ │ ├── ppo.rst
│ │ ├── q_learning.rst
│ │ ├── rpo.rst
│ │ ├── sac.rst
│ │ ├── sarsa.rst
│ │ ├── td3.rst
│ │ └── trpo.rst
│ ├── config
│ │ └── frameworks.rst
│ ├── envs.rst
│ ├── envs
│ │ ├── isaac_gym.rst
│ │ ├── isaaclab.rst
│ │ ├── multi_agents_wrapping.rst
│ │ ├── omniverse_isaac_gym.rst
│ │ └── wrapping.rst
│ ├── memories.rst
│ ├── memories
│ │ └── random.rst
│ ├── models.rst
│ ├── models
│ │ ├── categorical.rst
│ │ ├── deterministic.rst
│ │ ├── gaussian.rst
│ │ ├── multicategorical.rst
│ │ ├── multivariate_gaussian.rst
│ │ ├── shared_model.rst
│ │ └── tabular.rst
│ ├── multi_agents.rst
│ ├── multi_agents
│ │ ├── ippo.rst
│ │ └── mappo.rst
│ ├── resources.rst
│ ├── resources
│ │ ├── noises.rst
│ │ ├── noises
│ │ │ ├── gaussian.rst
│ │ │ └── ornstein_uhlenbeck.rst
│ │ ├── optimizers.rst
│ │ ├── optimizers
│ │ │ └── adam.rst
│ │ ├── preprocessors.rst
│ │ ├── preprocessors
│ │ │ └── running_standard_scaler.rst
│ │ ├── schedulers.rst
│ │ └── schedulers
│ │ │ └── kl_adaptive.rst
│ ├── trainers.rst
│ ├── trainers
│ │ ├── manual.rst
│ │ ├── parallel.rst
│ │ ├── sequential.rst
│ │ └── step.rst
│ ├── utils.rst
│ └── utils
│ │ ├── distributed.rst
│ │ ├── huggingface.rst
│ │ ├── isaacgym_utils.rst
│ │ ├── model_instantiators.rst
│ │ ├── omniverse_isaacgym_utils.rst
│ │ ├── postprocessing.rst
│ │ ├── runner.rst
│ │ ├── seed.rst
│ │ └── spaces.rst
│ ├── conf.py
│ ├── examples
│ ├── bidexhands
│ │ ├── jax_bidexhands_shadow_hand_over_ippo.py
│ │ ├── jax_bidexhands_shadow_hand_over_mappo.py
│ │ ├── torch_bidexhands_shadow_hand_over_ippo.py
│ │ └── torch_bidexhands_shadow_hand_over_mappo.py
│ ├── deepmind
│ │ ├── dm_manipulation_stack_sac.py
│ │ └── dm_suite_cartpole_swingup_ddpg.py
│ ├── gym
│ │ ├── jax_gym_cartpole_cem.py
│ │ ├── jax_gym_cartpole_dqn.py
│ │ ├── jax_gym_cartpole_vector_dqn.py
│ │ ├── jax_gym_pendulum_ddpg.py
│ │ ├── jax_gym_pendulum_ppo.py
│ │ ├── jax_gym_pendulum_sac.py
│ │ ├── jax_gym_pendulum_td3.py
│ │ ├── jax_gym_pendulum_vector_ddpg.py
│ │ ├── torch_gym_cartpole_cem.py
│ │ ├── torch_gym_cartpole_dqn.py
│ │ ├── torch_gym_cartpole_vector_dqn.py
│ │ ├── torch_gym_frozen_lake_q_learning.py
│ │ ├── torch_gym_frozen_lake_vector_q_learning.py
│ │ ├── torch_gym_pendulum_ddpg.py
│ │ ├── torch_gym_pendulum_ppo.py
│ │ ├── torch_gym_pendulum_sac.py
│ │ ├── torch_gym_pendulum_td3.py
│ │ ├── torch_gym_pendulum_trpo.py
│ │ ├── torch_gym_pendulum_vector_ddpg.py
│ │ ├── torch_gym_pendulumnovel_ddpg.py
│ │ ├── torch_gym_pendulumnovel_ddpg_gru.py
│ │ ├── torch_gym_pendulumnovel_ddpg_lstm.py
│ │ ├── torch_gym_pendulumnovel_ddpg_rnn.py
│ │ ├── torch_gym_pendulumnovel_ppo.py
│ │ ├── torch_gym_pendulumnovel_ppo_gru.py
│ │ ├── torch_gym_pendulumnovel_ppo_lstm.py
│ │ ├── torch_gym_pendulumnovel_ppo_rnn.py
│ │ ├── torch_gym_pendulumnovel_sac.py
│ │ ├── torch_gym_pendulumnovel_sac_gru.py
│ │ ├── torch_gym_pendulumnovel_sac_lstm.py
│ │ ├── torch_gym_pendulumnovel_sac_rnn.py
│ │ ├── torch_gym_pendulumnovel_td3.py
│ │ ├── torch_gym_pendulumnovel_td3_gru.py
│ │ ├── torch_gym_pendulumnovel_td3_lstm.py
│ │ ├── torch_gym_pendulumnovel_td3_rnn.py
│ │ ├── torch_gym_pendulumnovel_trpo.py
│ │ ├── torch_gym_pendulumnovel_trpo_gru.py
│ │ ├── torch_gym_pendulumnovel_trpo_lstm.py
│ │ ├── torch_gym_pendulumnovel_trpo_rnn.py
│ │ ├── torch_gym_taxi_sarsa.py
│ │ └── torch_gym_taxi_vector_sarsa.py
│ ├── gymnasium
│ │ ├── jax_gymnasium_cartpole_cem.py
│ │ ├── jax_gymnasium_cartpole_dqn.py
│ │ ├── jax_gymnasium_cartpole_vector_dqn.py
│ │ ├── jax_gymnasium_pendulum_ddpg.py
│ │ ├── jax_gymnasium_pendulum_ppo.py
│ │ ├── jax_gymnasium_pendulum_sac.py
│ │ ├── jax_gymnasium_pendulum_td3.py
│ │ ├── jax_gymnasium_pendulum_vector_ddpg.py
│ │ ├── torch_gymnasium_cartpole_cem.py
│ │ ├── torch_gymnasium_cartpole_dqn.py
│ │ ├── torch_gymnasium_cartpole_vector_dqn.py
│ │ ├── torch_gymnasium_frozen_lake_q_learning.py
│ │ ├── torch_gymnasium_frozen_lake_vector_q_learning.py
│ │ ├── torch_gymnasium_pendulum_ddpg.py
│ │ ├── torch_gymnasium_pendulum_ppo.py
│ │ ├── torch_gymnasium_pendulum_sac.py
│ │ ├── torch_gymnasium_pendulum_td3.py
│ │ ├── torch_gymnasium_pendulum_trpo.py
│ │ ├── torch_gymnasium_pendulum_vector_ddpg.py
│ │ ├── torch_gymnasium_pendulumnovel_ddpg.py
│ │ ├── torch_gymnasium_pendulumnovel_ddpg_gru.py
│ │ ├── torch_gymnasium_pendulumnovel_ddpg_lstm.py
│ │ ├── torch_gymnasium_pendulumnovel_ddpg_rnn.py
│ │ ├── torch_gymnasium_pendulumnovel_ppo.py
│ │ ├── torch_gymnasium_pendulumnovel_ppo_gru.py
│ │ ├── torch_gymnasium_pendulumnovel_ppo_lstm.py
│ │ ├── torch_gymnasium_pendulumnovel_ppo_rnn.py
│ │ ├── torch_gymnasium_pendulumnovel_sac.py
│ │ ├── torch_gymnasium_pendulumnovel_sac_gru.py
│ │ ├── torch_gymnasium_pendulumnovel_sac_lstm.py
│ │ ├── torch_gymnasium_pendulumnovel_sac_rnn.py
│ │ ├── torch_gymnasium_pendulumnovel_td3.py
│ │ ├── torch_gymnasium_pendulumnovel_td3_gru.py
│ │ ├── torch_gymnasium_pendulumnovel_td3_lstm.py
│ │ ├── torch_gymnasium_pendulumnovel_td3_rnn.py
│ │ ├── torch_gymnasium_pendulumnovel_trpo.py
│ │ ├── torch_gymnasium_pendulumnovel_trpo_gru.py
│ │ ├── torch_gymnasium_pendulumnovel_trpo_lstm.py
│ │ ├── torch_gymnasium_pendulumnovel_trpo_rnn.py
│ │ ├── torch_gymnasium_taxi_sarsa.py
│ │ └── torch_gymnasium_taxi_vector_sarsa.py
│ ├── isaacgym
│ │ ├── jax_allegro_hand_ppo.py
│ │ ├── jax_ant_ddpg.py
│ │ ├── jax_ant_ppo.py
│ │ ├── jax_ant_sac.py
│ │ ├── jax_ant_td3.py
│ │ ├── jax_anymal_ppo.py
│ │ ├── jax_anymal_terrain_ppo.py
│ │ ├── jax_ball_balance_ppo.py
│ │ ├── jax_cartpole_ppo.py
│ │ ├── jax_factory_task_nut_bolt_pick_ppo.py
│ │ ├── jax_factory_task_nut_bolt_place_ppo.py
│ │ ├── jax_factory_task_nut_bolt_screw_ppo.py
│ │ ├── jax_franka_cabinet_ppo.py
│ │ ├── jax_franka_cube_stack_ppo.py
│ │ ├── jax_humanoid_ppo.py
│ │ ├── jax_ingenuity_ppo.py
│ │ ├── jax_quadcopter_ppo.py
│ │ ├── jax_shadow_hand_ppo.py
│ │ ├── jax_trifinger_ppo.py
│ │ ├── torch_allegro_hand_ppo.py
│ │ ├── torch_allegro_kuka_ppo.py
│ │ ├── torch_ant_ddpg.py
│ │ ├── torch_ant_ddpg_td3_sac_parallel_unshared_memory.py
│ │ ├── torch_ant_ddpg_td3_sac_sequential_shared_memory.py
│ │ ├── torch_ant_ddpg_td3_sac_sequential_unshared_memory.py
│ │ ├── torch_ant_ppo.py
│ │ ├── torch_ant_sac.py
│ │ ├── torch_ant_td3.py
│ │ ├── torch_anymal_ppo.py
│ │ ├── torch_anymal_terrain_ppo.py
│ │ ├── torch_ball_balance_ppo.py
│ │ ├── torch_cartpole_ppo.py
│ │ ├── torch_factory_task_nut_bolt_pick_ppo.py
│ │ ├── torch_factory_task_nut_bolt_place_ppo.py
│ │ ├── torch_factory_task_nut_bolt_screw_ppo.py
│ │ ├── torch_franka_cabinet_ppo.py
│ │ ├── torch_franka_cube_stack_ppo.py
│ │ ├── torch_humanoid_amp.py
│ │ ├── torch_humanoid_ppo.py
│ │ ├── torch_ingenuity_ppo.py
│ │ ├── torch_quadcopter_ppo.py
│ │ ├── torch_shadow_hand_ppo.py
│ │ ├── torch_trifinger_ppo.py
│ │ └── trpo_cartpole.py
│ ├── isaaclab
│ │ ├── jax_ant_ddpg.py
│ │ ├── jax_ant_ppo.py
│ │ ├── jax_ant_sac.py
│ │ ├── jax_ant_td3.py
│ │ ├── jax_cartpole_ppo.py
│ │ ├── jax_humanoid_ppo.py
│ │ ├── jax_lift_franka_ppo.py
│ │ ├── jax_reach_franka_ppo.py
│ │ ├── jax_velocity_anymal_c_ppo.py
│ │ ├── torch_ant_ddpg.py
│ │ ├── torch_ant_ppo.py
│ │ ├── torch_ant_sac.py
│ │ ├── torch_ant_td3.py
│ │ ├── torch_cartpole_ppo.py
│ │ ├── torch_humanoid_ppo.py
│ │ ├── torch_lift_franka_ppo.py
│ │ ├── torch_reach_franka_ppo.py
│ │ └── torch_velocity_anymal_c_ppo.py
│ ├── isaacsim
│ │ ├── torch_isaacsim_cartpole_ppo.py
│ │ └── torch_isaacsim_jetbot_ppo.py
│ ├── omniisaacgym
│ │ ├── jax_allegro_hand_ppo.py
│ │ ├── jax_ant_ddpg.py
│ │ ├── jax_ant_mt_ppo.py
│ │ ├── jax_ant_ppo.py
│ │ ├── jax_ant_sac.py
│ │ ├── jax_ant_td3.py
│ │ ├── jax_anymal_ppo.py
│ │ ├── jax_anymal_terrain_ppo.py
│ │ ├── jax_ball_balance_ppo.py
│ │ ├── jax_cartpole_mt_ppo.py
│ │ ├── jax_cartpole_ppo.py
│ │ ├── jax_crazyflie_ppo.py
│ │ ├── jax_factory_task_nut_bolt_pick_ppo.py
│ │ ├── jax_franka_cabinet_ppo.py
│ │ ├── jax_humanoid_ppo.py
│ │ ├── jax_ingenuity_ppo.py
│ │ ├── jax_quadcopter_ppo.py
│ │ ├── jax_shadow_hand_ppo.py
│ │ ├── torch_allegro_hand_ppo.py
│ │ ├── torch_ant_ddpg.py
│ │ ├── torch_ant_ddpg_td3_sac_parallel_unshared_memory.py
│ │ ├── torch_ant_ddpg_td3_sac_sequential_shared_memory.py
│ │ ├── torch_ant_ddpg_td3_sac_sequential_unshared_memory.py
│ │ ├── torch_ant_mt_ppo.py
│ │ ├── torch_ant_ppo.py
│ │ ├── torch_ant_sac.py
│ │ ├── torch_ant_td3.py
│ │ ├── torch_anymal_ppo.py
│ │ ├── torch_anymal_terrain_ppo.py
│ │ ├── torch_ball_balance_ppo.py
│ │ ├── torch_cartpole_mt_ppo.py
│ │ ├── torch_cartpole_ppo.py
│ │ ├── torch_crazyflie_ppo.py
│ │ ├── torch_factory_task_nut_bolt_pick_ppo.py
│ │ ├── torch_franka_cabinet_ppo.py
│ │ ├── torch_humanoid_ppo.py
│ │ ├── torch_ingenuity_ppo.py
│ │ ├── torch_quadcopter_ppo.py
│ │ └── torch_shadow_hand_ppo.py
│ ├── real_world
│ │ ├── franka_emika_panda
│ │ │ ├── reaching_franka_isaacgym_env.py
│ │ │ ├── reaching_franka_isaacgym_skrl_eval.py
│ │ │ ├── reaching_franka_isaacgym_skrl_train.py
│ │ │ ├── reaching_franka_omniverse_isaacgym_env.py
│ │ │ ├── reaching_franka_omniverse_isaacgym_skrl_eval.py
│ │ │ ├── reaching_franka_omniverse_isaacgym_skrl_train.py
│ │ │ ├── reaching_franka_real_env.py
│ │ │ └── reaching_franka_real_skrl_eval.py
│ │ └── kuka_lbr_iiwa
│ │ │ ├── reaching_iiwa_omniverse_isaacgym_env.py
│ │ │ ├── reaching_iiwa_omniverse_isaacgym_skrl_eval.py
│ │ │ ├── reaching_iiwa_omniverse_isaacgym_skrl_train.py
│ │ │ ├── reaching_iiwa_real_env.py
│ │ │ ├── reaching_iiwa_real_ros2_env.py
│ │ │ ├── reaching_iiwa_real_ros_env.py
│ │ │ ├── reaching_iiwa_real_ros_ros2_skrl_eval.py
│ │ │ └── reaching_iiwa_real_skrl_eval.py
│ ├── robosuite
│ │ └── td3_robosuite_two_arm_lift.py
│ ├── shimmy
│ │ ├── jax_shimmy_atari_pong_dqn.py
│ │ ├── jax_shimmy_dm_control_acrobot_swingup_sparse_sac.py
│ │ ├── jax_shimmy_openai_gym_compatibility_pendulum_ddpg.py
│ │ ├── torch_shimmy_atari_pong_dqn.py
│ │ ├── torch_shimmy_dm_control_acrobot_swingup_sparse_sac.py
│ │ └── torch_shimmy_openai_gym_compatibility_pendulum_ddpg.py
│ └── utils
│ │ └── tensorboard_file_iterator.py
│ ├── index.rst
│ ├── intro
│ ├── data.rst
│ ├── examples.rst
│ ├── getting_started.rst
│ └── installation.rst
│ └── snippets
│ ├── agent.py
│ ├── agents_basic_usage.py
│ ├── categorical_model.py
│ ├── data.py
│ ├── deterministic_model.py
│ ├── gaussian_model.py
│ ├── isaacgym_utils.py
│ ├── loaders.py
│ ├── memories.py
│ ├── model_instantiators.txt
│ ├── model_mixin.py
│ ├── multi_agent.py
│ ├── multi_agents_basic_usage.py
│ ├── multicategorical_model.py
│ ├── multivariate_gaussian_model.py
│ ├── noises.py
│ ├── runner.txt
│ ├── shared_model.py
│ ├── tabular_model.py
│ ├── trainer.py
│ ├── utils_distributed.txt
│ ├── utils_postprocessing.py
│ └── wrapping.py
├── pyproject.toml
├── skrl
├── __init__.py
├── agents
│ ├── __init__.py
│ ├── jax
│ │ ├── __init__.py
│ │ ├── a2c
│ │ │ ├── __init__.py
│ │ │ └── a2c.py
│ │ ├── base.py
│ │ ├── cem
│ │ │ ├── __init__.py
│ │ │ └── cem.py
│ │ ├── ddpg
│ │ │ ├── __init__.py
│ │ │ └── ddpg.py
│ │ ├── dqn
│ │ │ ├── __init__.py
│ │ │ ├── ddqn.py
│ │ │ └── dqn.py
│ │ ├── ppo
│ │ │ ├── __init__.py
│ │ │ └── ppo.py
│ │ ├── rpo
│ │ │ ├── __init__.py
│ │ │ └── rpo.py
│ │ ├── sac
│ │ │ ├── __init__.py
│ │ │ └── sac.py
│ │ └── td3
│ │ │ ├── __init__.py
│ │ │ └── td3.py
│ └── torch
│ │ ├── __init__.py
│ │ ├── a2c
│ │ ├── __init__.py
│ │ ├── a2c.py
│ │ └── a2c_rnn.py
│ │ ├── amp
│ │ ├── __init__.py
│ │ └── amp.py
│ │ ├── base.py
│ │ ├── cem
│ │ ├── __init__.py
│ │ └── cem.py
│ │ ├── ddpg
│ │ ├── __init__.py
│ │ ├── ddpg.py
│ │ └── ddpg_rnn.py
│ │ ├── dqn
│ │ ├── __init__.py
│ │ ├── ddqn.py
│ │ └── dqn.py
│ │ ├── ppo
│ │ ├── __init__.py
│ │ ├── ppo.py
│ │ └── ppo_rnn.py
│ │ ├── q_learning
│ │ ├── __init__.py
│ │ └── q_learning.py
│ │ ├── rpo
│ │ ├── __init__.py
│ │ ├── rpo.py
│ │ └── rpo_rnn.py
│ │ ├── sac
│ │ ├── __init__.py
│ │ ├── sac.py
│ │ └── sac_rnn.py
│ │ ├── sarsa
│ │ ├── __init__.py
│ │ └── sarsa.py
│ │ ├── td3
│ │ ├── __init__.py
│ │ ├── td3.py
│ │ └── td3_rnn.py
│ │ └── trpo
│ │ ├── __init__.py
│ │ ├── trpo.py
│ │ └── trpo_rnn.py
├── envs
│ ├── __init__.py
│ ├── jax.py
│ ├── loaders
│ │ ├── __init__.py
│ │ ├── jax
│ │ │ ├── __init__.py
│ │ │ ├── bidexhands_envs.py
│ │ │ ├── isaacgym_envs.py
│ │ │ ├── isaaclab_envs.py
│ │ │ └── omniverse_isaacgym_envs.py
│ │ └── torch
│ │ │ ├── __init__.py
│ │ │ ├── bidexhands_envs.py
│ │ │ ├── isaacgym_envs.py
│ │ │ ├── isaaclab_envs.py
│ │ │ └── omniverse_isaacgym_envs.py
│ ├── torch.py
│ └── wrappers
│ │ ├── __init__.py
│ │ ├── jax
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── bidexhands_envs.py
│ │ ├── brax_envs.py
│ │ ├── gym_envs.py
│ │ ├── gymnasium_envs.py
│ │ ├── isaacgym_envs.py
│ │ ├── isaaclab_envs.py
│ │ ├── omniverse_isaacgym_envs.py
│ │ └── pettingzoo_envs.py
│ │ └── torch
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── bidexhands_envs.py
│ │ ├── brax_envs.py
│ │ ├── deepmind_envs.py
│ │ ├── gym_envs.py
│ │ ├── gymnasium_envs.py
│ │ ├── isaacgym_envs.py
│ │ ├── isaaclab_envs.py
│ │ ├── omniverse_isaacgym_envs.py
│ │ ├── pettingzoo_envs.py
│ │ └── robosuite_envs.py
├── memories
│ ├── __init__.py
│ ├── jax
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── random.py
│ └── torch
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── random.py
├── models
│ ├── __init__.py
│ ├── jax
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── categorical.py
│ │ ├── deterministic.py
│ │ ├── gaussian.py
│ │ └── multicategorical.py
│ └── torch
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── categorical.py
│ │ ├── deterministic.py
│ │ ├── gaussian.py
│ │ ├── multicategorical.py
│ │ ├── multivariate_gaussian.py
│ │ └── tabular.py
├── multi_agents
│ ├── __init__.py
│ ├── jax
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── ippo
│ │ │ ├── __init__.py
│ │ │ └── ippo.py
│ │ └── mappo
│ │ │ ├── __init__.py
│ │ │ └── mappo.py
│ └── torch
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── ippo
│ │ ├── __init__.py
│ │ └── ippo.py
│ │ └── mappo
│ │ ├── __init__.py
│ │ └── mappo.py
├── resources
│ ├── __init__.py
│ ├── noises
│ │ ├── __init__.py
│ │ ├── jax
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── gaussian.py
│ │ │ └── ornstein_uhlenbeck.py
│ │ └── torch
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── gaussian.py
│ │ │ └── ornstein_uhlenbeck.py
│ ├── optimizers
│ │ ├── __init__.py
│ │ └── jax
│ │ │ ├── __init__.py
│ │ │ └── adam.py
│ ├── preprocessors
│ │ ├── __init__.py
│ │ ├── jax
│ │ │ ├── __init__.py
│ │ │ └── running_standard_scaler.py
│ │ └── torch
│ │ │ ├── __init__.py
│ │ │ └── running_standard_scaler.py
│ └── schedulers
│ │ ├── __init__.py
│ │ ├── jax
│ │ ├── __init__.py
│ │ └── kl_adaptive.py
│ │ └── torch
│ │ ├── __init__.py
│ │ └── kl_adaptive.py
├── trainers
│ ├── __init__.py
│ ├── jax
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── sequential.py
│ │ └── step.py
│ └── torch
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── parallel.py
│ │ ├── sequential.py
│ │ └── step.py
└── utils
│ ├── __init__.py
│ ├── control.py
│ ├── distributed
│ ├── __init__.py
│ └── jax
│ │ ├── __main__.py
│ │ └── launcher.py
│ ├── huggingface.py
│ ├── isaacgym_utils.py
│ ├── model_instantiators
│ ├── __init__.py
│ ├── jax
│ │ ├── __init__.py
│ │ ├── categorical.py
│ │ ├── common.py
│ │ ├── deterministic.py
│ │ ├── gaussian.py
│ │ └── multicategorical.py
│ └── torch
│ │ ├── __init__.py
│ │ ├── categorical.py
│ │ ├── common.py
│ │ ├── deterministic.py
│ │ ├── gaussian.py
│ │ ├── multicategorical.py
│ │ ├── multivariate_gaussian.py
│ │ └── shared.py
│ ├── omniverse_isaacgym_utils.py
│ ├── postprocessing.py
│ ├── runner
│ ├── __init__.py
│ ├── jax
│ │ ├── __init__.py
│ │ └── runner.py
│ └── torch
│ │ ├── __init__.py
│ │ └── runner.py
│ └── spaces
│ ├── __init__.py
│ ├── jax
│ ├── __init__.py
│ └── spaces.py
│ └── torch
│ ├── __init__.py
│ └── spaces.py
└── tests
├── __init__.py
├── agents
├── __init__.py
├── jax
│ ├── __init__.py
│ ├── test_a2c.py
│ ├── test_cem.py
│ ├── test_ddpg.py
│ ├── test_ddqn.py
│ ├── test_dqn.py
│ ├── test_ppo.py
│ ├── test_rpo.py
│ ├── test_sac.py
│ └── test_td3.py
└── torch
│ ├── __init__.py
│ ├── test_a2c.py
│ ├── test_amp.py
│ ├── test_cem.py
│ ├── test_ddpg.py
│ ├── test_ddqn.py
│ ├── test_dqn.py
│ ├── test_ppo.py
│ ├── test_rpo.py
│ ├── test_sac.py
│ ├── test_td3.py
│ └── test_trpo.py
├── envs
├── __init__.py
└── wrappers
│ ├── __init__.py
│ ├── jax
│ ├── __init__.py
│ ├── test_brax_envs.py
│ ├── test_gym_envs.py
│ ├── test_gymnasium_envs.py
│ ├── test_isaacgym_envs.py
│ ├── test_isaaclab_envs.py
│ ├── test_omniverse_isaacgym_envs.py
│ └── test_pettingzoo_envs.py
│ └── torch
│ ├── __init__.py
│ ├── test_brax_envs.py
│ ├── test_deepmind_envs.py
│ ├── test_gym_envs.py
│ ├── test_gymnasium_envs.py
│ ├── test_isaacgym_envs.py
│ ├── test_isaaclab_envs.py
│ ├── test_omniverse_isaacgym_envs.py
│ └── test_pettingzoo_envs.py
├── memories
├── __init__.py
└── torch
│ ├── __init__.py
│ └── test_base.py
├── strategies.py
├── test_jax_config.py
├── test_torch_config.py
├── utilities.py
└── utils
├── __init__.py
├── model_instantiators
├── __init__.py
├── jax
│ ├── __init__.py
│ ├── test_definition.py
│ └── test_models.py
└── torch
│ ├── __init__.py
│ ├── test_definition.py
│ └── test_models.py
└── spaces
├── __init__.py
├── jax
├── __init__.py
└── test_spaces.py
└── torch
├── __init__.py
└── test_spaces.py
/.github/GITHUB_ACTIONS.md:
--------------------------------------------------------------------------------
1 | ## GitHub Actions
2 |
3 | ### Relevant links
4 |
5 | - `runs-on`:
6 | - [Standard GitHub-hosted runners for public repositories](https://docs.github.com/en/actions/using-github-hosted-runners/using-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories)
7 | - [GitHub Actions Runner Images](https://github.com/actions/runner-images)
8 | - `actions/setup-python`:
9 | - [Building and testing Python](https://docs.github.com/en/actions/use-cases-and-examples/building-and-testing/building-and-testing-python)
10 | - [Available Python versions](https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json)
11 |
12 | ### Run GitHub Actions locally with nektos/act
13 |
14 | [nektos/act](https://nektosact.com/) is a tool to run GitHub Actions locally. Install it as a [GitHub CLI](https://cli.github.com/) extension via [these steps](https://nektosact.com/installation/gh.html).
15 |
16 | #### Useful commands
17 |
18 | * List workflows/jobs:
19 |
20 | ```bash
21 | gh act -l
22 | ```
23 |
24 | * Run a specific job:
25 |
26 | Use `--env DELETE_HOSTED_TOOL_PYTHON_CACHE=1` to delete the Python cache.
27 |
28 | ```bash
29 | gh act -j Job-ID
30 | ```
31 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: Bug report
2 | description: Submit a bug report
3 | labels: bug
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | **Your help in making skrl better is greatly appreciated!**
9 |
10 | * Please ensure that:
11 | * The issue hasn't already been reported by using the [issue search](https://github.com/Toni-SM/skrl/search?q=is%3Aissue&type=issues).
12 | * The issue (and its solution) is not listed in the skrl documentation [troubleshooting](https://skrl.readthedocs.io/en/latest/intro/installation.html#known-issues-and-troubleshooting) section.
13 |         * For questions, please consider [opening a discussion](https://github.com/Toni-SM/skrl/discussions).
14 |
15 | - type: textarea
16 | attributes:
17 | label: Description
18 | description: A clear and concise description of the bug/issue. Try to provide a minimal example to reproduce it (error/log messages are also helpful).
19 | placeholder: |
20 | Markdown formatting might be applied to the text.
21 |
22 | ```python
23 | # use triple backticks for code blocks or error/log messages
24 | ```
25 | validations:
26 | required: true
27 | - type: dropdown
28 | attributes:
29 | label: What skrl version are you using?
30 | description: The skrl version can be obtained with the command `pip show skrl`.
31 | options:
32 | - ---
33 | - 1.4.3
34 | - 1.4.2
35 | - 1.4.1
36 | - 1.4.0
37 | - 1.3.0
38 | - 1.2.0
39 | - 1.1.0
40 | - 1.0.0
41 | - 1.0.0-rc2
42 | - 1.0.0-rc1
43 | - 0.10.2 or 0.10.1
44 | - 0.10.0 or earlier
45 | - develop branch
46 | validations:
47 | required: true
48 | - type: input
49 | attributes:
50 | label: What ML framework/library version are you using?
51 | description: The version can be obtained with the command `pip show torch` or `pip show jax jaxlib flax optax`.
52 | placeholder: PyTorch version, JAX/jaxlib/Flax/Optax version, etc.
53 | - type: input
54 | attributes:
55 | label: Additional system information
56 | placeholder: Python version, OS (Linux/Windows/macOS/WSL), etc.
57 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Question
4 | url: https://github.com/Toni-SM/skrl/discussions
5 | about: Please ask questions on the Discussions tab
6 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on: [ push, pull_request ]
4 |
5 | jobs:
6 |
7 | pre-commit:
8 | name: Pre-commit hooks
9 | runs-on: ubuntu-latest
10 | steps:
11 | # setup
12 | - uses: actions/checkout@v4
13 | - name: Set up Python
14 | uses: actions/setup-python@v5
15 | with:
16 | python-version: '3.10'
17 | # install dependencies
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --quiet --upgrade pip
21 | python -m pip install --quiet pre-commit
22 | # run pre-commit
23 | - name: Run pre-commit
24 | run: |
25 | pre-commit run --all-files
26 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish-manual.yml:
--------------------------------------------------------------------------------
1 | name: pypi (manually triggered workflow)
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | job:
7 | description: 'Upload Python Package to PyPI/TestPyPI'
8 | required: true
9 | default: 'test-pypi'
10 |
11 | permissions:
12 | contents: read
13 |
14 | jobs:
15 |
16 | pypi:
17 | name: Publish package to PyPI
18 | runs-on: ubuntu-22.04
19 | if: ${{ github.event.inputs.job == 'pypi'}}
20 |
21 | steps:
22 | - uses: actions/checkout@v3
23 |
24 | - name: Set up Python
25 | uses: actions/setup-python@v3
26 | with:
27 | python-version: '3.10.16'
28 |
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 |
34 | - name: Build package
35 | run: python -m build
36 |
37 | - name: Publish package to PyPI
38 | uses: pypa/gh-action-pypi-publish@release/v1
39 | with:
40 | user: __token__
41 | password: ${{ secrets.PYPI_API_TOKEN }}
42 | verbose: true
43 |
44 | test-pypi:
45 | name: Publish package to TestPyPI
46 | runs-on: ubuntu-22.04
47 | if: ${{ github.event.inputs.job == 'test-pypi'}}
48 |
49 | steps:
50 | - uses: actions/checkout@v3
51 |
52 | - name: Set up Python
53 | uses: actions/setup-python@v3
54 | with:
55 | python-version: '3.10.16'
56 |
57 | - name: Install dependencies
58 | run: |
59 | python -m pip install --upgrade pip
60 | pip install build
61 |
62 | - name: Build package
63 | run: python -m build
64 |
65 | - name: Publish package to TestPyPI
66 | uses: pypa/gh-action-pypi-publish@release/v1
67 | with:
68 | user: __token__
69 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
70 | repository_url: https://test.pypi.org/legacy/
71 | verbose: true
72 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.6.0
4 | hooks:
5 | - id: check-ast
6 | - id: check-case-conflict
7 | - id: check-docstring-first
8 | - id: check-json
9 | - id: check-merge-conflict
10 | - id: check-toml
11 | - id: check-yaml
12 | - id: debug-statements
13 | - id: detect-private-key
14 | - id: end-of-file-fixer
15 | - id: name-tests-test
16 | args: ["--pytest-test-first"]
17 | exclude: ^(tests/strategies.py|tests/utilities.py)
18 | - id: trailing-whitespace
19 | - repo: https://github.com/codespell-project/codespell
20 | rev: v2.3.0
21 | hooks:
22 | - id: codespell
23 | exclude: ^(docs/source/_static|docs/_build|pyproject.toml)
24 | additional_dependencies:
25 | - tomli
26 | - repo: https://github.com/python/black
27 | rev: 24.8.0
28 | hooks:
29 | - id: black
30 | args: ["--line-length=120"]
31 | exclude: ^(docs/)
32 | - repo: https://github.com/pycqa/isort
33 | rev: 5.13.2
34 | hooks:
35 | - id: isort
36 | - repo: https://github.com/pre-commit/pygrep-hooks
37 | rev: v1.10.0
38 | hooks:
39 | - id: rst-backticks
40 | - id: rst-directive-colons
41 | - id: rst-inline-touching-normal
42 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the version of Python and other tools you might need
8 | build:
9 | os: ubuntu-22.04
10 | tools:
11 | python: "3.10"
12 |
13 | # Build documentation in the docs/ directory with Sphinx
14 | sphinx:
15 | configuration: docs/source/conf.py
16 | # builder: html
17 | # fail_on_warning: false
18 |
19 | # If using Sphinx, optionally build your docs in additional formats such as PDF
20 | # formats:
21 | # - pdf
22 |
23 | # Python requirements required to build your docs
24 | python:
25 | install:
26 | - requirements: docs/requirements.txt
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Toni-SM
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://pypi.org/project/skrl)
2 | [ ](https://huggingface.co/skrl)
3 | 
4 |
5 | [](https://github.com/Toni-SM/skrl)
6 |
7 | [](https://skrl.readthedocs.io/en/latest/?badge=latest)
8 | [](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml)
9 | [](https://github.com/Toni-SM/skrl/actions/workflows/tests-torch.yml)
10 | [](https://github.com/Toni-SM/skrl/actions/workflows/tests-jax.yml)
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | SKRL - Reinforcement Learning library
19 |
20 |
21 | **skrl** is an open-source modular library for Reinforcement Learning written in Python (on top of [PyTorch](https://pytorch.org/) and [JAX](https://jax.readthedocs.io)) and designed with a focus on modularity, readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI [Gym](https://www.gymlibrary.dev), Farama [Gymnasium](https://gymnasium.farama.org) and [PettingZoo](https://pettingzoo.farama.org), Google [DeepMind](https://github.com/deepmind/dm_env) and [Brax](https://github.com/google/brax), among other environment interfaces, it allows loading and configuring NVIDIA [Isaac Lab](https://isaac-sim.github.io/IsaacLab/index.html) (as well as [Isaac Gym](https://developer.nvidia.com/isaac-gym/) and [Omniverse Isaac Gym](https://github.com/isaac-sim/OmniIsaacGymEnvs)) environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run.
22 |
23 |
24 |
25 | ### Please visit the documentation for usage details and examples
26 |
27 | https://skrl.readthedocs.io
28 |
29 |
30 |
31 | > **Note:** This project is under **active continuous development**. Please make sure you always have the latest version. Visit the [develop](https://github.com/Toni-SM/skrl/tree/develop) branch or its [documentation](https://skrl.readthedocs.io/en/develop) to access the latest updates to be released.
32 |
33 |
34 |
35 | ### Citing this library
36 |
37 | To cite this library in publications, please use the following reference:
38 |
39 | ```bibtex
40 | @article{serrano2023skrl,
41 | author = {Antonio Serrano-Muñoz and Dimitrios Chrysostomou and Simon Bøgh and Nestor Arana-Arexolaleiba},
42 | title = {skrl: Modular and Flexible Library for Reinforcement Learning},
43 | journal = {Journal of Machine Learning Research},
44 | year = {2023},
45 | volume = {24},
46 | number = {254},
47 | pages = {1--9},
48 | url = {http://jmlr.org/papers/v24/23-0112.html}
49 | }
50 | ```
51 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | ## Install Sphinx and Read the Docs Sphinx Theme
4 |
5 | ```bash
6 | cd docs
7 | pip install -r requirements.txt
8 | ```
9 |
10 | ## Building the documentation
11 |
12 | ```bash
13 | cd docs
14 | make html
15 | ```
16 |
17 | Building each time a file is changed:
18 |
19 | ```bash
20 | cd docs
21 | sphinx-autobuild ./source/ _build/html
22 | ```
23 |
24 | ## Useful links
25 |
26 | - [Sphinx directives](https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html)
27 | - [Math support in Sphinx](https://www.sphinx-doc.org/en/master/usage/extensions/math.html)
28 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | furo==2024.8.6
2 | sphinx
3 | sphinx-tabs
4 | sphinx-autobuild
5 | sphinx-copybutton
6 | sphinx-notfound-page
7 | decorator
8 | numpy
9 |
--------------------------------------------------------------------------------
/docs/source/404.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | Page not found
4 | ==============
5 |
6 | .. image:: _static/data/404-light.svg
7 | :width: 50%
8 | :align: center
9 | :class: only-light
10 | :alt: 404
11 |
12 | .. image:: _static/data/404-dark.svg
13 | :width: 50%
14 | :align: center
15 | :class: only-dark
16 | :alt: 404
17 |
18 | .. raw:: html
19 |
20 |
21 |
22 |
404: Puzzle piece not found.
23 |
Did you look under the sofa cushions?
24 |
25 |
26 |
27 |
28 | Since version 1.0.0, the documentation structure has changed to improve content organization and to provide a better browsing experience.
29 | Navigate using the left sidebar or type in the search box to find what you are looking for.
30 |
--------------------------------------------------------------------------------
/docs/source/_static/css/s5defs-roles.css:
--------------------------------------------------------------------------------
1 |
2 | .black {
3 | color: black;
4 | }
5 |
6 | .gray {
7 | color: gray;
8 | }
9 |
10 | .grey {
11 | color: gray;
12 | }
13 |
14 | .silver {
15 | color: silver;
16 | }
17 |
18 | .white {
19 | color: white;
20 | }
21 |
22 | .maroon {
23 | color: maroon;
24 | }
25 |
26 | .red {
27 | color: red;
28 | }
29 |
30 | .magenta {
31 | color: magenta;
32 | }
33 |
34 | .fuchsia {
35 | color: fuchsia;
36 | }
37 |
38 | .pink {
39 | color: pink;
40 | }
41 |
42 | .orange {
43 | color: orange;
44 | }
45 |
46 | .yellow {
47 | color: yellow;
48 | }
49 |
50 | .lime {
51 | color: lime;
52 | }
53 |
54 | .green {
55 | color: #02a802;
56 | }
57 |
58 | .olive {
59 | color: olive;
60 | }
61 |
62 | .teal {
63 | color: teal;
64 | }
65 |
66 | .cyan {
67 | color: cyan;
68 | }
69 |
70 | .aqua {
71 | color: aqua;
72 | }
73 |
74 | .blue {
75 | color: #007cea;
76 | }
77 |
78 | .navy {
79 | color: navy;
80 | }
81 |
82 | .purple {
83 | color: purple;
84 | }
85 |
--------------------------------------------------------------------------------
/docs/source/_static/css/skrl.css:
--------------------------------------------------------------------------------
1 | .nowrap {
2 | white-space: nowrap;
3 | }
4 |
5 | .sidebar-brand-text {
6 | font-size: 1.25rem !important;
7 | }
8 |
9 | tbody > tr > th.stub {
10 | font-weight: normal;
11 | text-align: left;
12 | }
13 |
--------------------------------------------------------------------------------
/docs/source/_static/data/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/favicon.ico
--------------------------------------------------------------------------------
/docs/source/_static/data/logo-dark-mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/logo-dark-mode.png
--------------------------------------------------------------------------------
/docs/source/_static/data/logo-light-mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/logo-light-mode.png
--------------------------------------------------------------------------------
/docs/source/_static/data/logo-torch.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | image/svg+xml
46 |
49 |
54 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/docs/source/_static/data/skrl-up-transparent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/skrl-up-transparent.png
--------------------------------------------------------------------------------
/docs/source/_static/data/skrl-up.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/data/skrl-up.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/data_tensorboard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/data_tensorboard.jpg
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_bidexhands.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_bidexhands.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_deepmind.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_deepmind.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_gym.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_gym.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_isaacgym.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_isaacgym.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_isaaclab.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_isaaclab.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_omniverse_isaacgym.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_omniverse_isaacgym.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_parallel.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_parallel.jpg
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_robosuite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_robosuite.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/example_shimmy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/example_shimmy.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/noise_gaussian.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/noise_gaussian.png
--------------------------------------------------------------------------------
/docs/source/_static/imgs/noise_ornstein_uhlenbeck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/docs/source/_static/imgs/noise_ornstein_uhlenbeck.png
--------------------------------------------------------------------------------
/docs/source/api/envs.rst:
--------------------------------------------------------------------------------
1 | Environments
2 | ============
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | Wrapping (single-agent)
8 | Wrapping (multi-agents)
9 | Isaac Lab environments
10 | Isaac Gym environments
11 | Omniverse Isaac Gym environments
12 |
13 | The environment plays a fundamental and crucial role in defining the RL setup. It is the place where the agent interacts, and it is responsible for providing the agent with information about its current state, as well as the rewards/penalties associated with each action.
14 |
15 | .. raw:: html
16 |
17 |
18 |
19 | In this section you will find how to load environments from NVIDIA Isaac Lab (as well as Isaac Gym and Omniverse Isaac Gym) with a simple function.
20 |
21 | .. list-table::
22 | :header-rows: 1
23 |
24 | * - Loaders
25 | - .. centered:: |_4| |pytorch| |_4|
26 | - .. centered:: |_4| |jax| |_4|
27 | * - :doc:`Isaac Lab environments `
28 | - .. centered:: :math:`\blacksquare`
29 | - .. centered:: :math:`\blacksquare`
30 | * - :doc:`Isaac Gym environments `
31 | - .. centered:: :math:`\blacksquare`
32 | - .. centered:: :math:`\blacksquare`
33 | * - :doc:`Omniverse Isaac Gym environments `
34 | - .. centered:: :math:`\blacksquare`
35 | - .. centered:: :math:`\blacksquare`
36 |
37 | In addition, you will be able to :doc:`wrap single-agent ` and :doc:`multi-agent ` RL environment interfaces.
38 |
39 | .. list-table::
40 | :header-rows: 1
41 |
42 | * - Wrappers
43 | - .. centered:: |_4| |pytorch| |_4|
44 | - .. centered:: |_4| |jax| |_4|
45 | * - Bi-DexHands
46 | - .. centered:: :math:`\blacksquare`
47 | - .. centered:: :math:`\blacksquare`
48 | * - Brax
49 | - .. centered:: :math:`\blacksquare`
50 | - .. centered:: :math:`\blacksquare`
51 | * - DeepMind
52 | - .. centered:: :math:`\blacksquare`
53 | - .. centered:: :math:`\square`
54 | * - Gym
55 | - .. centered:: :math:`\blacksquare`
56 | - .. centered:: :math:`\blacksquare`
57 | * - Gymnasium
58 | - .. centered:: :math:`\blacksquare`
59 | - .. centered:: :math:`\blacksquare`
60 | * - Isaac Lab
61 | - .. centered:: :math:`\blacksquare`
62 | - .. centered:: :math:`\blacksquare`
63 | * - Isaac Gym (previews)
64 | - .. centered:: :math:`\blacksquare`
65 | - .. centered:: :math:`\blacksquare`
66 | * - Omniverse Isaac Gym |_5| |_5| |_5| |_5| |_2|
67 | - .. centered:: :math:`\blacksquare`
68 | - .. centered:: :math:`\blacksquare`
69 | * - PettingZoo
70 | - .. centered:: :math:`\blacksquare`
71 | - .. centered:: :math:`\blacksquare`
72 | * - robosuite
73 | - .. centered:: :math:`\blacksquare`
74 | - .. centered:: :math:`\square`
75 | * - Shimmy
76 | - .. centered:: :math:`\blacksquare`
77 | - .. centered:: :math:`\blacksquare`
78 |
--------------------------------------------------------------------------------
/docs/source/api/envs/isaaclab.rst:
--------------------------------------------------------------------------------
1 | Isaac Lab environments
2 | ======================
3 |
4 | .. image:: ../../_static/imgs/example_isaaclab.png
5 | :width: 100%
6 | :align: center
7 | :alt: Isaac Lab environments
8 |
9 | .. raw:: html
10 |
11 |
12 |
13 | Environments
14 | ------------
15 |
16 | The repository https://github.com/isaac-sim/IsaacLab provides the example reinforcement learning environments for Isaac Lab (Orbit and Omniverse Isaac Gym unification).
17 |
18 | These environments can be easily loaded and configured by calling a single function provided with this library. This function also makes it possible to configure the environment from the command line arguments (see Isaac Lab's `Training with an RL Agent `_) or from its parameters (:literal:`task_name`, :literal:`num_envs`, :literal:`headless`, and :literal:`cli_args`).
19 |
20 | .. note::
21 |
22 | The command line arguments have priority over the function parameters.
23 |
24 | .. note::
25 |
26 | Isaac Lab environments implement a functionality to get their configuration from the command line. Setting the :literal:`headless` option from the trainer configuration will not work. In this case, it is necessary to set the load function's :literal:`headless` argument to True or to invoke the scripts as follows: :literal:`isaaclab -p script.py --headless`.
27 |
28 | .. raw:: html
29 |
30 |
31 |
32 | Usage
33 | ^^^^^
34 |
35 | .. tabs::
36 |
37 | .. tab:: Function parameters
38 |
39 | .. tabs::
40 |
41 | .. group-tab:: |_4| |pytorch| |_4|
42 |
43 | .. literalinclude:: ../../snippets/loaders.py
44 | :language: python
45 | :emphasize-lines: 2, 5
46 | :start-after: [start-isaaclab-envs-parameters-torch]
47 | :end-before: [end-isaaclab-envs-parameters-torch]
48 |
49 | .. group-tab:: |_4| |jax| |_4|
50 |
51 | .. literalinclude:: ../../snippets/loaders.py
52 | :language: python
53 | :emphasize-lines: 2, 5
54 | :start-after: [start-isaaclab-envs-parameters-jax]
55 | :end-before: [end-isaaclab-envs-parameters-jax]
56 |
57 | .. tab:: Command line arguments (priority)
58 |
59 | .. tabs::
60 |
61 | .. group-tab:: |_4| |pytorch| |_4|
62 |
63 | .. literalinclude:: ../../snippets/loaders.py
64 | :language: python
65 | :emphasize-lines: 2, 5
66 | :start-after: [start-isaaclab-envs-cli-torch]
67 | :end-before: [end-isaaclab-envs-cli-torch]
68 |
69 | .. group-tab:: |_4| |jax| |_4|
70 |
71 | .. literalinclude:: ../../snippets/loaders.py
72 | :language: python
73 | :emphasize-lines: 2, 5
74 | :start-after: [start-isaaclab-envs-cli-jax]
75 | :end-before: [end-isaaclab-envs-cli-jax]
76 |
77 | Run the main script passing the configuration as command line arguments. For example:
78 |
79 | .. code-block::
80 |
81 | isaaclab -p main.py --task Isaac-Cartpole-v0
82 |
83 | .. raw:: html
84 |
85 |
86 |
87 | API
88 | ^^^
89 |
90 | .. autofunction:: skrl.envs.loaders.torch.load_isaaclab_env
91 |
--------------------------------------------------------------------------------
/docs/source/api/memories.rst:
--------------------------------------------------------------------------------
1 | Memories
2 | ========
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | Random
8 |
9 | Memories are storage components that allow agents to collect and use/reuse current or past experiences of their interaction with the environment or other types of information.
10 |
11 | .. raw:: html
12 |
13 |
14 |
15 | .. list-table::
16 | :header-rows: 1
17 |
18 | * - Memories
19 | - .. centered:: |_4| |pytorch| |_4|
20 | - .. centered:: |_4| |jax| |_4|
21 | * - :doc:`Random memory `
22 | - .. centered:: :math:`\blacksquare`
23 | - .. centered:: :math:`\blacksquare`
24 |
25 | Base class
26 | ----------
27 |
28 | .. note::
29 |
30 | This is the base class for all the other classes in this module.
31 | It provides the basic functionality for the other classes.
32 | **It is not intended to be used directly**.
33 |
34 | .. raw:: html
35 |
36 |
37 |
38 | Basic inheritance usage
39 | ^^^^^^^^^^^^^^^^^^^^^^^
40 |
41 | .. tabs::
42 |
43 | .. group-tab:: |_4| |pytorch| |_4|
44 |
45 | .. literalinclude:: ../snippets/memories.py
46 | :language: python
47 | :start-after: [start-base-class-torch]
48 | :end-before: [end-base-class-torch]
49 |
50 | .. group-tab:: |_4| |jax| |_4|
51 |
52 | .. literalinclude:: ../snippets/memories.py
53 | :language: python
54 | :start-after: [start-base-class-jax]
55 | :end-before: [end-base-class-jax]
56 |
57 | .. raw:: html
58 |
59 |
60 |
61 | API (PyTorch)
62 | ^^^^^^^^^^^^^
63 |
64 | .. autoclass:: skrl.memories.torch.base.Memory
65 | :undoc-members:
66 | :show-inheritance:
67 | :members:
68 |
69 | .. automethod:: __len__
70 |
71 | .. raw:: html
72 |
73 |
74 |
75 | API (JAX)
76 | ^^^^^^^^^
77 |
78 | .. autoclass:: skrl.memories.jax.base.Memory
79 | :undoc-members:
80 | :show-inheritance:
81 | :members:
82 |
83 | .. automethod:: __len__
84 |
--------------------------------------------------------------------------------
/docs/source/api/memories/random.rst:
--------------------------------------------------------------------------------
1 | Random memory
2 | =============
3 |
4 | Random sampling memory
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Usage
11 | -----
12 |
13 | .. tabs::
14 |
15 | .. group-tab:: |_4| |pytorch| |_4|
16 |
17 | .. literalinclude:: ../../snippets/memories.py
18 | :language: python
19 | :emphasize-lines: 2, 5
20 | :start-after: [start-random-torch]
21 | :end-before: [end-random-torch]
22 |
23 | .. group-tab:: |_4| |jax| |_4|
24 |
25 | .. literalinclude:: ../../snippets/memories.py
26 | :language: python
27 | :emphasize-lines: 2, 5
28 | :start-after: [start-random-jax]
29 | :end-before: [end-random-jax]
30 |
31 | .. raw:: html
32 |
33 |
34 |
35 | API (PyTorch)
36 | -------------
37 |
38 | .. autoclass:: skrl.memories.torch.random.RandomMemory
39 | :undoc-members:
40 | :show-inheritance:
41 | :inherited-members:
42 | :members:
43 |
44 | .. automethod:: __len__
45 |
46 | .. raw:: html
47 |
48 |
49 |
50 | API (JAX)
51 | ---------
52 |
53 | .. autoclass:: skrl.memories.jax.random.RandomMemory
54 | :undoc-members:
55 | :show-inheritance:
56 | :inherited-members:
57 | :members:
58 |
59 | .. automethod:: __len__
60 |
--------------------------------------------------------------------------------
/docs/source/api/models/shared_model.rst:
--------------------------------------------------------------------------------
1 | Shared model
2 | ============
3 |
4 | Sometimes it is desirable to define models that use shared layers or network to represent multiple function approximators. This practice, known as *shared parameters* (or *parameter sharing*), *shared layers*, *shared model*, *shared networks* or *joint architecture* among others, is typically justified by the following criteria:
5 |
6 | * Learning the same characteristics, especially when processing large inputs (e.g., images).
7 |
8 | * Reduce the number of parameters in the whole system.
9 |
10 | * Make the computation more efficient (single forward-pass).
11 |
12 | .. raw:: html
13 |
14 |
15 |
16 | Implementation
17 | --------------
18 |
19 | By combining the implemented mixins, it is possible to define shared models with skrl. In these cases, the use of the :literal:`role` argument (a Python string) is relevant. The agents will call the models by setting the :literal:`role` argument according to their requirements. Visit each agent's documentation (*Key* column of the table under *Spaces and models* section) to know the possible values that this parameter can take.
20 |
21 | The code snippet below shows how to define a shared model. The following practices for building shared models can be identified:
22 |
23 | * The definition of multiple inheritance must always include the :ref:`Model ` base class at the end.
24 |
25 | * The :ref:`Model ` base class constructor must be invoked before the mixins constructor.
26 |
27 | * All mixin constructors must be invoked.
28 |
29 | * Specifying the :literal:`role` argument is optional if all constructors belong to different mixins.
30 |
31 | * If multiple models of the same mixin type are required, the same constructor must be invoked as many times as needed. To do so, it is mandatory to specify the :literal:`role` argument.
32 |
33 | * The :literal:`.act(...)` method needs to be overridden to disambiguate its call.
34 |
35 | * The same instance of the shared model must be passed to all keys involved.
36 |
37 | .. raw:: html
38 |
39 |
40 |
41 | .. tabs::
42 |
43 | .. group-tab:: |_4| |pytorch| |_4|
44 |
45 | .. tabs::
46 |
47 | .. group-tab:: Single forward-pass
48 |
49 | .. warning::
50 |
51 | The implementation described for single forward-pass requires that the value-pass always follows the policy-pass (e.g.: ``PPO``) which may not be generalized to other algorithms.
52 |
53 | If this requirement is not met, other forms of "caching" the shared layers/network output could be implemented.
54 |
55 | .. literalinclude:: ../../snippets/shared_model.py
56 | :language: python
57 | :start-after: [start-mlp-single-forward-pass-torch]
58 | :end-before: [end-mlp-single-forward-pass-torch]
59 |
60 | .. group-tab:: Multiple forward-pass
61 |
62 | .. literalinclude:: ../../snippets/shared_model.py
63 | :language: python
64 | :start-after: [start-mlp-multi-forward-pass-torch]
65 | :end-before: [end-mlp-multi-forward-pass-torch]
66 |
--------------------------------------------------------------------------------
/docs/source/api/models/tabular.rst:
--------------------------------------------------------------------------------
1 | .. _models_tabular:
2 |
3 | Tabular model
4 | =============
5 |
6 | Tabular models run **discrete-domain deterministic/stochastic** policies.
7 |
8 | .. raw:: html
9 |
10 |
11 |
12 | skrl provides a Python mixin (:literal:`TabularMixin`) to assist in the creation of these types of models, allowing users to have full control over the table definitions. Note that the use of this mixin must comply with the following rules:
13 |
14 | * The definition of multiple inheritance must always include the :ref:`Model ` base class at the end.
15 |
16 | * The :ref:`Model ` base class constructor must be invoked before the mixins constructor.
17 |
18 | .. tabs::
19 |
20 | .. group-tab:: |_4| |pytorch| |_4|
21 |
22 | .. literalinclude:: ../../snippets/tabular_model.py
23 | :language: python
24 | :emphasize-lines: 1, 3-4
25 | :start-after: [start-definition-torch]
26 | :end-before: [end-definition-torch]
27 |
28 | .. raw:: html
29 |
30 |
31 |
32 | Usage
33 | -----
34 |
35 | .. tabs::
36 |
37 | .. tab:: :math:`\epsilon`-greedy policy
38 |
39 | .. tabs::
40 |
41 | .. group-tab:: |_4| |pytorch| |_4|
42 |
43 | .. literalinclude:: ../../snippets/tabular_model.py
44 | :language: python
45 | :start-after: [start-epsilon-greedy-torch]
46 | :end-before: [end-epsilon-greedy-torch]
47 |
48 | .. raw:: html
49 |
50 |
51 |
52 | API (PyTorch)
53 | -------------
54 |
55 | .. autoclass:: skrl.models.torch.tabular.TabularMixin
56 | :show-inheritance:
57 | :exclude-members: to, state_dict, load_state_dict, load, save
58 | :members:
59 |
--------------------------------------------------------------------------------
/docs/source/api/multi_agents.rst:
--------------------------------------------------------------------------------
1 | Multi-agents
2 | ============
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | IPPO
8 | MAPPO
9 |
10 | Multi-agents are autonomous entities that interact with the environment to learn and improve their behavior. Multi-agents' goal is to learn optimal policies, which are mappings between states and actions that maximize the cumulative reward received from the environment over time.
11 |
12 | .. raw:: html
13 |
14 |
15 |
16 | .. list-table::
17 | :header-rows: 1
18 |
19 | * - Multi-agents
20 | - .. centered:: |_4| |pytorch| |_4|
21 | - .. centered:: |_4| |jax| |_4|
22 | * - :doc:`Independent Proximal Policy Optimization ` (**IPPO**)
23 | - .. centered:: :math:`\blacksquare`
24 | - .. centered:: :math:`\blacksquare`
25 | * - :doc:`Multi-Agent Proximal Policy Optimization ` (**MAPPO**)
26 | - .. centered:: :math:`\blacksquare`
27 | - .. centered:: :math:`\blacksquare`
28 |
29 | Base class
30 | ----------
31 |
32 | .. note::
33 |
34 | This is the base class for all multi-agents and provides only basic functionality that is not tied to any implementation of the optimization algorithms.
35 | **It is not intended to be used directly**.
36 |
37 | .. raw:: html
38 |
39 |
40 |
41 | Basic inheritance usage
42 | ^^^^^^^^^^^^^^^^^^^^^^^
43 |
44 | .. tabs::
45 |
46 | .. tab:: Inheritance
47 |
48 | .. tabs::
49 |
50 | .. group-tab:: |_4| |pytorch| |_4|
51 |
52 | .. literalinclude:: ../snippets/multi_agent.py
53 | :language: python
54 | :start-after: [start-multi-agent-base-class-torch]
55 | :end-before: [end-multi-agent-base-class-torch]
56 |
57 | .. group-tab:: |_4| |jax| |_4|
58 |
59 | .. literalinclude:: ../snippets/multi_agent.py
60 | :language: python
61 | :start-after: [start-multi-agent-base-class-jax]
62 | :end-before: [end-multi-agent-base-class-jax]
63 |
64 | .. raw:: html
65 |
66 |
67 |
68 | API (PyTorch)
69 | ^^^^^^^^^^^^^
70 |
71 | .. autoclass:: skrl.multi_agents.torch.base.MultiAgent
72 | :undoc-members:
73 | :show-inheritance:
74 | :inherited-members:
75 | :private-members: _update, _empty_preprocessor, _get_internal_value, _as_dict
76 | :members:
77 |
78 | .. automethod:: __str__
79 |
80 | .. raw:: html
81 |
82 |
83 |
84 | API (JAX)
85 | ^^^^^^^^^
86 |
87 | .. autoclass:: skrl.multi_agents.jax.base.MultiAgent
88 | :undoc-members:
89 | :show-inheritance:
90 | :inherited-members:
91 | :private-members: _update, _empty_preprocessor, _get_internal_value, _as_dict
92 | :members:
93 |
94 | .. automethod:: __str__
95 |
--------------------------------------------------------------------------------
/docs/source/api/resources.rst:
--------------------------------------------------------------------------------
1 | Resources
2 | =========
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | Noises
8 | Preprocessors
9 | Learning rate schedulers
10 | Optimizers
11 |
12 | The *Resources* module groups a variety of components that may be used to improve the agents' performance.
13 |
14 | .. raw:: html
15 |
16 |
17 |
18 | Available resources are :doc:`noises `, input :doc:`preprocessors `, learning rate :doc:`schedulers ` and :doc:`optimizers ` (this last one only for JAX).
19 |
20 | .. list-table::
21 | :header-rows: 1
22 |
23 | * - Noises
24 | - .. centered:: |_4| |pytorch| |_4|
25 | - .. centered:: |_4| |jax| |_4|
26 | * - :doc:`Gaussian ` noise
27 | - .. centered:: :math:`\blacksquare`
28 | - .. centered:: :math:`\blacksquare`
29 | * - :doc:`Ornstein-Uhlenbeck ` noise |_2|
30 | - .. centered:: :math:`\blacksquare`
31 | - .. centered:: :math:`\blacksquare`
32 |
33 | .. list-table::
34 | :header-rows: 1
35 |
36 | * - Preprocessors
37 | - .. centered:: |_4| |pytorch| |_4|
38 | - .. centered:: |_4| |jax| |_4|
39 | * - :doc:`Running standard scaler ` |_4|
40 | - .. centered:: :math:`\blacksquare`
41 | - .. centered:: :math:`\blacksquare`
42 |
43 | .. list-table::
44 | :header-rows: 1
45 |
46 | * - Learning rate schedulers
47 | - .. centered:: |_4| |pytorch| |_4|
48 | - .. centered:: |_4| |jax| |_4|
49 | * - :doc:`KL Adaptive `
50 | - .. centered:: :math:`\blacksquare`
51 | - .. centered:: :math:`\blacksquare`
52 |
53 | .. list-table::
54 | :header-rows: 1
55 |
56 | * - Optimizers
57 | - .. centered:: |_4| |pytorch| |_4|
58 | - .. centered:: |_4| |jax| |_4|
59 | * - :doc:`Adam `\ |_5| |_5| |_5| |_5| |_5| |_5| |_3|
60 | - .. centered:: :math:`\scriptscriptstyle \texttt{PyTorch}`
61 | - .. centered:: :math:`\blacksquare`
62 |
--------------------------------------------------------------------------------
/docs/source/api/resources/noises.rst:
--------------------------------------------------------------------------------
1 | Noises
2 | ======
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | Gaussian noise
8 | Ornstein-Uhlenbeck
9 |
10 | Definition of the noises used by the agents during the exploration stage. All noises inherit from a base class that defines a uniform interface.
11 |
12 | .. raw:: html
13 |
14 |
15 |
16 | .. list-table::
17 | :header-rows: 1
18 |
19 | * - Noises
20 | - .. centered:: |_4| |pytorch| |_4|
21 | - .. centered:: |_4| |jax| |_4|
22 | * - :doc:`Gaussian ` noise
23 | - .. centered:: :math:`\blacksquare`
24 | - .. centered:: :math:`\blacksquare`
25 | * - :doc:`Ornstein-Uhlenbeck ` noise |_2|
26 | - .. centered:: :math:`\blacksquare`
27 | - .. centered:: :math:`\blacksquare`
28 |
29 | Base class
30 | ----------
31 |
32 | .. note::
33 |
34 | This is the base class for all the other classes in this module.
35 | It provides the basic functionality for the other classes.
36 | **It is not intended to be used directly**.
37 |
38 | .. raw:: html
39 |
40 |
41 |
42 | Basic inheritance usage
43 | ^^^^^^^^^^^^^^^^^^^^^^^
44 |
45 | .. tabs::
46 |
47 | .. group-tab:: |_4| |pytorch| |_4|
48 |
49 | .. literalinclude:: ../../snippets/noises.py
50 | :language: python
51 | :start-after: [start-base-class-torch]
52 | :end-before: [end-base-class-torch]
53 |
54 | .. group-tab:: |_4| |jax| |_4|
55 |
56 | .. literalinclude:: ../../snippets/noises.py
57 | :language: python
58 | :start-after: [start-base-class-jax]
59 | :end-before: [end-base-class-jax]
60 |
61 | .. raw:: html
62 |
63 |
64 |
65 | API (PyTorch)
66 | ^^^^^^^^^^^^^
67 |
68 | .. autoclass:: skrl.resources.noises.torch.base.Noise
69 | :undoc-members:
70 | :show-inheritance:
71 | :inherited-members:
72 | :members:
73 |
74 | .. raw:: html
75 |
76 |
77 |
78 | API (JAX)
79 | ^^^^^^^^^
80 |
81 | .. autoclass:: skrl.resources.noises.jax.base.Noise
82 | :undoc-members:
83 | :show-inheritance:
84 | :inherited-members:
85 | :members:
86 |
--------------------------------------------------------------------------------
/docs/source/api/resources/noises/gaussian.rst:
--------------------------------------------------------------------------------
1 | .. _gaussian-noise:
2 |
3 | Gaussian noise
4 | ==============
5 |
6 | Noise generated by normal distribution.
7 |
8 | .. raw:: html
9 |
10 |
11 |
12 | Usage
13 | -----
14 |
15 | The noise usage is defined in each agent's configuration dictionary. A noise instance is set under the :literal:`"noise"` sub-key. The following examples show how to set the noise for an agent:
16 |
17 | |
18 |
19 | .. image:: ../../../_static/imgs/noise_gaussian.png
20 | :width: 75%
21 | :align: center
22 | :alt: Gaussian noise
23 |
24 | .. raw:: html
25 |
26 |
27 |
28 | .. tabs::
29 |
30 | .. group-tab:: |_4| |pytorch| |_4|
31 |
32 | .. literalinclude:: ../../../snippets/noises.py
33 | :language: python
34 | :emphasize-lines: 1, 4
35 | :start-after: [torch-start-gaussian]
36 | :end-before: [torch-end-gaussian]
37 |
38 | .. group-tab:: |_4| |jax| |_4|
39 |
40 | .. literalinclude:: ../../../snippets/noises.py
41 | :language: python
42 | :emphasize-lines: 1, 4
43 | :start-after: [jax-start-gaussian]
44 | :end-before: [jax-end-gaussian]
45 |
46 | .. raw:: html
47 |
48 |
49 |
50 | API (PyTorch)
51 | -------------
52 |
53 | .. autoclass:: skrl.resources.noises.torch.gaussian.GaussianNoise
54 | :undoc-members:
55 | :show-inheritance:
56 | :inherited-members:
57 | :private-members: _update
58 | :members:
59 |
60 | .. raw:: html
61 |
62 |
63 |
64 | API (JAX)
65 | ---------
66 |
67 | .. autoclass:: skrl.resources.noises.jax.gaussian.GaussianNoise
68 | :undoc-members:
69 | :show-inheritance:
70 | :inherited-members:
71 | :private-members: _update
72 | :members:
73 |
--------------------------------------------------------------------------------
/docs/source/api/resources/noises/ornstein_uhlenbeck.rst:
--------------------------------------------------------------------------------
1 | .. _ornstein-uhlenbeck-noise:
2 |
3 | Ornstein-Uhlenbeck noise
4 | ========================
5 |
6 | Noise generated by a stochastic process that is characterized by its mean-reverting behavior.
7 |
8 | .. raw:: html
9 |
10 |
11 |
12 | Usage
13 | -----
14 |
15 | The noise usage is defined in each agent's configuration dictionary. A noise instance is set under the :literal:`"noise"` sub-key. The following examples show how to set the noise for an agent:
16 |
17 | |
18 |
19 | .. image:: ../../../_static/imgs/noise_ornstein_uhlenbeck.png
20 | :width: 75%
21 | :align: center
22 | :alt: Ornstein-Uhlenbeck noise
23 |
24 | .. raw:: html
25 |
26 |
27 |
28 | .. tabs::
29 |
30 | .. group-tab:: |_4| |pytorch| |_4|
31 |
32 | .. literalinclude:: ../../../snippets/noises.py
33 | :language: python
34 | :emphasize-lines: 1, 4
35 | :start-after: [torch-start-ornstein-uhlenbeck]
36 | :end-before: [torch-end-ornstein-uhlenbeck]
37 |
38 | .. group-tab:: |_4| |jax| |_4|
39 |
40 | .. literalinclude:: ../../../snippets/noises.py
41 | :language: python
42 | :emphasize-lines: 1, 4
43 | :start-after: [jax-start-ornstein-uhlenbeck]
44 | :end-before: [jax-end-ornstein-uhlenbeck]
45 |
46 | .. raw:: html
47 |
48 |
49 |
50 | API (PyTorch)
51 | -------------
52 |
53 | .. autoclass:: skrl.resources.noises.torch.ornstein_uhlenbeck.OrnsteinUhlenbeckNoise
54 | :undoc-members:
55 | :show-inheritance:
56 | :inherited-members:
57 | :private-members: _update
58 | :members:
59 |
60 | .. raw:: html
61 |
62 |
63 |
64 | API (JAX)
65 | ---------
66 |
67 | .. autoclass:: skrl.resources.noises.jax.ornstein_uhlenbeck.OrnsteinUhlenbeckNoise
68 | :undoc-members:
69 | :show-inheritance:
70 | :inherited-members:
71 | :private-members: _update
72 | :members:
73 |
--------------------------------------------------------------------------------
/docs/source/api/resources/optimizers.rst:
--------------------------------------------------------------------------------
1 | Optimizers
2 | ==========
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | Adam
8 |
9 | Optimizers are algorithms that adjust the parameters of artificial neural networks to minimize the error or loss function during the training process.
10 |
11 | .. raw:: html
12 |
13 |
14 |
15 | .. list-table::
16 | :header-rows: 1
17 |
18 | * - Optimizers
19 | - .. centered:: |_4| |pytorch| |_4|
20 | - .. centered:: |_4| |jax| |_4|
21 | * - :doc:`Adam `\ |_5| |_5| |_5| |_5| |_5| |_5| |_3|
22 | - .. centered:: :math:`\scriptscriptstyle \texttt{PyTorch}`
23 | - .. centered:: :math:`\blacksquare`
24 |
--------------------------------------------------------------------------------
/docs/source/api/resources/optimizers/adam.rst:
--------------------------------------------------------------------------------
1 | Adam
2 | ====
3 |
4 | An extension of the stochastic gradient descent algorithm that adaptively changes the learning rate for each neural network parameter.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Usage
11 | -----
12 |
13 | .. note::
14 |
15 | This class is the result of isolating the Optax optimizer that is mixed with the model parameters, as defined in the `Flax's TrainState `_ class. It is not intended to be used directly by the user, but by agent implementations.
16 |
17 | .. tabs::
18 |
19 | .. group-tab:: |_4| |jax| |_4|
20 |
21 | .. code-block:: python
22 | :emphasize-lines: 2, 5, 8
23 |
24 | # import the optimizer class
25 | from skrl.resources.optimizers.jax import Adam
26 |
27 | # instantiate the optimizer
28 | optimizer = Adam(model=model, lr=1e-3)
29 |
30 | # step the optimizer
31 | optimizer = optimizer.step(grad, model)
32 |
33 | .. raw:: html
34 |
35 |
36 |
37 | API (JAX)
38 | ---------
39 |
40 | .. autoclass:: skrl.resources.optimizers.jax.adam.Adam
41 | :show-inheritance:
42 | :inherited-members:
43 | :members:
44 |
45 | .. automethod:: __new__
46 |
--------------------------------------------------------------------------------
/docs/source/api/resources/preprocessors.rst:
--------------------------------------------------------------------------------
1 | Preprocessors
2 | =============
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | Running standard scaler
8 |
9 | Preprocessors are functions used to transform or encode raw input data into a form more suitable for the learning algorithm.
10 |
11 | .. raw:: html
12 |
13 |
14 |
15 | .. list-table::
16 | :header-rows: 1
17 |
18 | * - Preprocessors
19 | - .. centered:: |_4| |pytorch| |_4|
20 | - .. centered:: |_4| |jax| |_4|
21 | * - :doc:`Running standard scaler ` |_4|
22 | - .. centered:: :math:`\blacksquare`
23 | - .. centered:: :math:`\blacksquare`
24 |
--------------------------------------------------------------------------------
/docs/source/api/resources/schedulers/kl_adaptive.rst:
--------------------------------------------------------------------------------
1 | KL Adaptive
2 | ===========
3 |
4 | Adjust the learning rate according to the value of the Kullback-Leibler (KL) divergence.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Algorithm
11 | ---------
12 |
13 | .. raw:: html
14 |
15 |
16 |
17 | Algorithm implementation
18 | ^^^^^^^^^^^^^^^^^^^^^^^^
19 |
20 | The learning rate (:math:`\eta`) at each step is modified as follows:
21 |
22 | | **IF** :math:`\; KL >` :guilabel:`kl_factor` :guilabel:`kl_threshold` **THEN**
23 | | :math:`\eta_{t + 1} = \max(\eta_t \,/` :guilabel:`lr_factor` :math:`,` :guilabel:`min_lr` :math:`)`
24 | | **IF** :math:`\; KL <` :guilabel:`kl_threshold` :math:`/` :guilabel:`kl_factor` **THEN**
25 | | :math:`\eta_{t + 1} = \min(` :guilabel:`lr_factor` :math:`\eta_t,` :guilabel:`max_lr` :math:`)`
26 |
27 | .. raw:: html
28 |
29 |
30 |
31 | Usage
32 | -----
33 |
34 | The learning rate scheduler usage is defined in each agent's configuration dictionary. The scheduler class is set under the :literal:`"learning_rate_scheduler"` key and its arguments are set under the :literal:`"learning_rate_scheduler_kwargs"` key as a keyword argument dictionary, without specifying the optimizer (first argument). The following examples show how to set the scheduler for an agent:
35 |
36 | .. tabs::
37 |
38 | .. group-tab:: |_4| |pytorch| |_4|
39 |
40 | .. code-block:: python
41 | :emphasize-lines: 2, 5-6
42 |
43 | # import the scheduler class
44 | from skrl.resources.schedulers.torch import KLAdaptiveLR
45 |
46 | cfg = DEFAULT_CONFIG.copy()
47 | cfg["learning_rate_scheduler"] = KLAdaptiveLR
48 | cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
49 |
50 | .. group-tab:: |_4| |jax| |_4|
51 |
52 | .. code-block:: python
53 | :emphasize-lines: 2, 5-6
54 |
55 | # import the scheduler class
56 | from skrl.resources.schedulers.jax import KLAdaptiveLR # or kl_adaptive (Optax style)
57 |
58 | cfg = DEFAULT_CONFIG.copy()
59 | cfg["learning_rate_scheduler"] = KLAdaptiveLR
60 | cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
61 |
62 | .. raw:: html
63 |
64 |
65 |
66 | API (PyTorch)
67 | -------------
68 |
69 | .. autoclass:: skrl.resources.schedulers.torch.kl_adaptive.KLAdaptiveLR
70 | :show-inheritance:
71 | :inherited-members:
72 | :members:
73 |
74 | .. raw:: html
75 |
76 |
77 |
78 | API (JAX)
79 | ---------
80 |
81 | .. autofunction:: skrl.resources.schedulers.jax.kl_adaptive.KLAdaptiveLR
82 |
--------------------------------------------------------------------------------
/docs/source/api/trainers.rst:
--------------------------------------------------------------------------------
1 | Trainers
2 | ========
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | Sequential
8 | Parallel
9 | Step
10 | Manual training
11 |
12 | Trainers are responsible for orchestrating and managing the training/evaluation of agents and their interactions with the environment.
13 |
14 | .. raw:: html
15 |
16 |
17 |
18 | .. list-table::
19 | :header-rows: 1
20 |
21 | * - Trainers
22 | - .. centered:: |_4| |pytorch| |_4|
23 | - .. centered:: |_4| |jax| |_4|
24 | * - :doc:`Sequential trainer `
25 | - .. centered:: :math:`\blacksquare`
26 | - .. centered:: :math:`\blacksquare`
27 | * - :doc:`Parallel trainer `
28 | - .. centered:: :math:`\blacksquare`
29 | - .. centered:: :math:`\square`
30 | * - :doc:`Step trainer `
31 | - .. centered:: :math:`\blacksquare`
32 | - .. centered:: :math:`\blacksquare`
33 | * - :doc:`Manual training `
34 | - .. centered:: :math:`\blacksquare`
35 | - .. centered:: :math:`\blacksquare`
36 |
37 | Base class
38 | ----------
39 |
40 | .. note::
41 |
42 | This is the base class for all the other classes in this module.
43 | It provides the basic functionality for the other classes.
44 | **It is not intended to be used directly**.
45 |
46 | .. raw:: html
47 |
48 |
49 |
50 | Basic inheritance usage
51 | ^^^^^^^^^^^^^^^^^^^^^^^
52 |
53 | .. tabs::
54 |
55 | .. group-tab:: |_4| |pytorch| |_4|
56 |
57 | .. literalinclude:: ../snippets/trainer.py
58 | :language: python
59 | :start-after: [pytorch-start-base]
60 | :end-before: [pytorch-end-base]
61 |
62 | .. group-tab:: |_4| |jax| |_4|
63 |
64 | .. literalinclude:: ../snippets/trainer.py
65 | :language: python
66 | :start-after: [jax-start-base]
67 | :end-before: [jax-end-base]
68 |
69 | .. raw:: html
70 |
71 |
72 |
73 | API (PyTorch)
74 | ^^^^^^^^^^^^^
75 |
76 | .. autoclass:: skrl.trainers.torch.base.Trainer
77 | :undoc-members:
78 | :show-inheritance:
79 | :inherited-members:
80 | :private-members: _setup_agents
81 | :members:
82 |
83 | .. automethod:: __str__
84 |
85 | .. raw:: html
86 |
87 |
88 |
89 | API (JAX)
90 | ^^^^^^^^^
91 |
92 | .. autoclass:: skrl.trainers.jax.base.Trainer
93 | :undoc-members:
94 | :show-inheritance:
95 | :inherited-members:
96 | :private-members: _setup_agents
97 | :members:
98 |
99 | .. automethod:: __str__
100 |
--------------------------------------------------------------------------------
/docs/source/api/trainers/manual.rst:
--------------------------------------------------------------------------------
1 | Manual training
2 | ===============
3 |
4 | Train agents by manually controlling the training/evaluation loop.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Concept
11 | -------
12 |
13 | .. image:: ../../_static/imgs/manual_trainer-light.svg
14 | :width: 100%
15 | :align: center
16 | :class: only-light
17 | :alt: Manual trainer
18 |
19 | .. image:: ../../_static/imgs/manual_trainer-dark.svg
20 | :width: 100%
21 | :align: center
22 | :class: only-dark
23 | :alt: Manual trainer
24 |
25 | .. raw:: html
26 |
27 |
28 |
29 | Usage
30 | -----
31 |
32 | .. tabs::
33 |
34 | .. group-tab:: |_4| |pytorch| |_4|
35 |
36 | .. tabs::
37 |
38 | .. group-tab:: Training
39 |
40 | .. literalinclude:: ../../snippets/trainer.py
41 | :language: python
42 | :start-after: [pytorch-start-manual-training]
43 | :end-before: [pytorch-end-manual-training]
44 |
45 | .. group-tab:: Evaluation
46 |
47 | .. literalinclude:: ../../snippets/trainer.py
48 | :language: python
49 | :start-after: [pytorch-start-manual-evaluation]
50 | :end-before: [pytorch-end-manual-evaluation]
51 |
52 | .. group-tab:: |_4| |jax| |_4|
53 |
54 | .. tabs::
55 |
56 | .. group-tab:: Training
57 |
58 | .. literalinclude:: ../../snippets/trainer.py
59 | :language: python
60 | :start-after: [jax-start-manual-training]
61 | :end-before: [jax-end-manual-training]
62 |
63 | .. group-tab:: Evaluation
64 |
65 | .. literalinclude:: ../../snippets/trainer.py
66 | :language: python
67 | :start-after: [jax-start-manual-evaluation]
68 | :end-before: [jax-end-manual-evaluation]
69 |
70 | .. raw:: html
71 |
72 |
73 |
--------------------------------------------------------------------------------
/docs/source/api/trainers/parallel.rst:
--------------------------------------------------------------------------------
1 | Parallel trainer
2 | ================
3 |
4 | Train agents in parallel using multiple processes.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Concept
11 | -------
12 |
13 | .. image:: ../../_static/imgs/parallel_trainer-light.svg
14 | :width: 100%
15 | :align: center
16 | :class: only-light
17 | :alt: Parallel trainer
18 |
19 | .. image:: ../../_static/imgs/parallel_trainer-dark.svg
20 | :width: 100%
21 | :align: center
22 | :class: only-dark
23 | :alt: Parallel trainer
24 |
25 | .. raw:: html
26 |
27 |
28 |
29 | Usage
30 | -----
31 |
32 | .. note::
33 |
34 | Each process adds a GPU memory overhead (~1GB, although it can be much higher) due to PyTorch's CUDA kernels. See PyTorch `Issue #12873 `_ for more details
35 |
36 | .. note::
37 |
38 | At the moment, only simultaneous training and evaluation of agents with local memory (no memory sharing) is implemented
39 |
40 | .. tabs::
41 |
42 | .. group-tab:: |_4| |pytorch| |_4|
43 |
44 | .. literalinclude:: ../../snippets/trainer.py
45 | :language: python
46 | :start-after: [pytorch-start-parallel]
47 | :end-before: [pytorch-end-parallel]
48 |
49 | .. raw:: html
50 |
51 |
52 |
53 | Configuration
54 | -------------
55 |
56 | .. literalinclude:: ../../../../skrl/trainers/torch/parallel.py
57 | :language: python
58 | :start-after: [start-config-dict-torch]
59 | :end-before: [end-config-dict-torch]
60 |
61 | .. raw:: html
62 |
63 |
64 |
65 | API (PyTorch)
66 | -------------
67 |
68 | .. autoclass:: skrl.trainers.torch.parallel.PARALLEL_TRAINER_DEFAULT_CONFIG
69 |
70 | .. autoclass:: skrl.trainers.torch.parallel.ParallelTrainer
71 | :undoc-members:
72 | :show-inheritance:
73 | :inherited-members:
74 | :members:
75 |
--------------------------------------------------------------------------------
/docs/source/api/trainers/sequential.rst:
--------------------------------------------------------------------------------
1 | Sequential trainer
2 | ==================
3 |
4 | Train agents sequentially (i.e., one after the other in each interaction with the environment).
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Concept
11 | -------
12 |
13 | .. image:: ../../_static/imgs/sequential_trainer-light.svg
14 | :width: 100%
15 | :align: center
16 | :class: only-light
17 | :alt: Sequential trainer
18 |
19 | .. image:: ../../_static/imgs/sequential_trainer-dark.svg
20 | :width: 100%
21 | :align: center
22 | :class: only-dark
23 | :alt: Sequential trainer
24 |
25 | .. raw:: html
26 |
27 |
28 |
29 | Usage
30 | -----
31 |
32 | .. tabs::
33 |
34 | .. group-tab:: |_4| |pytorch| |_4|
35 |
36 | .. literalinclude:: ../../snippets/trainer.py
37 | :language: python
38 | :start-after: [pytorch-start-sequential]
39 | :end-before: [pytorch-end-sequential]
40 |
41 | .. group-tab:: |_4| |jax| |_4|
42 |
43 | .. literalinclude:: ../../snippets/trainer.py
44 | :language: python
45 | :start-after: [jax-start-sequential]
46 | :end-before: [jax-end-sequential]
47 |
48 | .. raw:: html
49 |
50 |
51 |
52 | Configuration
53 | -------------
54 |
55 | .. literalinclude:: ../../../../skrl/trainers/torch/sequential.py
56 | :language: python
57 | :start-after: [start-config-dict-torch]
58 | :end-before: [end-config-dict-torch]
59 |
60 | .. raw:: html
61 |
62 |
63 |
64 | API (PyTorch)
65 | -------------
66 |
67 | .. autoclass:: skrl.trainers.torch.sequential.SEQUENTIAL_TRAINER_DEFAULT_CONFIG
68 |
69 | .. autoclass:: skrl.trainers.torch.sequential.SequentialTrainer
70 | :undoc-members:
71 | :show-inheritance:
72 | :inherited-members:
73 | :members:
74 |
75 | .. raw:: html
76 |
77 |
78 |
79 | API (JAX)
80 | ---------
81 |
82 | .. autoclass:: skrl.trainers.jax.sequential.SEQUENTIAL_TRAINER_DEFAULT_CONFIG
83 |
84 | .. autoclass:: skrl.trainers.jax.sequential.SequentialTrainer
85 | :undoc-members:
86 | :show-inheritance:
87 | :inherited-members:
88 | :members:
89 |
--------------------------------------------------------------------------------
/docs/source/api/trainers/step.rst:
--------------------------------------------------------------------------------
1 | Step trainer
2 | ============
3 |
4 | Train agents controlling the training/evaluation loop step-by-step.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Concept
11 | -------
12 |
13 | .. image:: ../../_static/imgs/manual_trainer-light.svg
14 | :width: 100%
15 | :align: center
16 | :class: only-light
17 | :alt: Step-by-step trainer
18 |
19 | .. image:: ../../_static/imgs/manual_trainer-dark.svg
20 | :width: 100%
21 | :align: center
22 | :class: only-dark
23 | :alt: Step-by-step trainer
24 |
25 | .. raw:: html
26 |
27 |
28 |
29 | Usage
30 | -----
31 |
32 | .. tabs::
33 |
34 | .. group-tab:: |_4| |pytorch| |_4|
35 |
36 | .. literalinclude:: ../../snippets/trainer.py
37 | :language: python
38 | :start-after: [pytorch-start-step]
39 | :end-before: [pytorch-end-step]
40 |
41 | .. group-tab:: |_4| |jax| |_4|
42 |
43 | .. literalinclude:: ../../snippets/trainer.py
44 | :language: python
45 | :start-after: [jax-start-step]
46 | :end-before: [jax-end-step]
47 |
48 | .. raw:: html
49 |
50 |
51 |
52 | Configuration
53 | -------------
54 |
55 | .. literalinclude:: ../../../../skrl/trainers/torch/step.py
56 | :language: python
57 | :start-after: [start-config-dict-torch]
58 | :end-before: [end-config-dict-torch]
59 |
60 | .. raw:: html
61 |
62 |
63 |
64 | API (PyTorch)
65 | -------------
66 |
67 | .. autoclass:: skrl.trainers.torch.step.STEP_TRAINER_DEFAULT_CONFIG
68 |
69 | .. autoclass:: skrl.trainers.torch.step.StepTrainer
70 | :undoc-members:
71 | :show-inheritance:
72 | :inherited-members:
73 | :members:
74 |
75 | .. raw:: html
76 |
77 |
78 |
79 | API (JAX)
80 | ---------
81 |
82 | .. autoclass:: skrl.trainers.jax.step.STEP_TRAINER_DEFAULT_CONFIG
83 |
84 | .. autoclass:: skrl.trainers.jax.step.StepTrainer
85 | :undoc-members:
86 | :show-inheritance:
87 | :inherited-members:
88 | :members:
89 |
--------------------------------------------------------------------------------
/docs/source/api/utils.rst:
--------------------------------------------------------------------------------
1 | Utils and configurations
2 | ========================
3 |
4 | .. toctree::
5 | :hidden:
6 |
7 | ML frameworks configuration
8 | Random seed
9 | Spaces
10 | Model instantiators
11 | Runner
12 | Distributed runs
13 | Memory and Tensorboard file post-processing
14 | Hugging Face integration
15 | Isaac Gym utils
16 | Omniverse Isaac Gym utils
17 |
18 | A set of utilities and configurations for managing an RL setup is provided as part of the library.
19 |
20 | .. raw:: html
21 |
22 |
23 |
24 | .. list-table::
25 | :header-rows: 1
26 |
27 | * - Configurations
28 | - .. centered:: |_4| |pytorch| |_4|
29 | - .. centered:: |_4| |jax| |_4|
30 | * - :doc:`ML frameworks ` configuration |_5| |_5| |_5| |_5| |_5| |_2|
31 | - .. centered:: :math:`\blacksquare`
32 | - .. centered:: :math:`\blacksquare`
33 |
34 | .. list-table::
35 | :header-rows: 1
36 |
37 | * - Utils
38 | - .. centered:: |_4| |pytorch| |_4|
39 | - .. centered:: |_4| |jax| |_4|
40 | * - :doc:`Random seed `
41 | - .. centered:: :math:`\blacksquare`
42 | - .. centered:: :math:`\blacksquare`
43 | * - :doc:`Spaces `
44 | - .. centered:: :math:`\blacksquare`
45 | - .. centered:: :math:`\blacksquare`
46 | * - :doc:`Model instantiators `
47 | - .. centered:: :math:`\blacksquare`
48 | - .. centered:: :math:`\blacksquare`
49 | * - :doc:`Runner `
50 | - .. centered:: :math:`\blacksquare`
51 | - .. centered:: :math:`\blacksquare`
52 | * - :doc:`Distributed runs `
53 | - .. centered:: :math:`\blacksquare`
54 | - .. centered:: :math:`\blacksquare`
55 | * - Memory and Tensorboard :doc:`file post-processing `
56 | - .. centered:: :math:`\blacksquare`
57 | - .. centered:: :math:`\blacksquare`
58 | * - :doc:`Hugging Face integration `
59 | - .. centered:: :math:`\blacksquare`
60 | - .. centered:: :math:`\blacksquare`
61 | * - :doc:`Isaac Gym utils `
62 | - .. centered:: :math:`\blacksquare`
63 | - .. centered:: :math:`\blacksquare`
64 | * - :doc:`Omniverse Isaac Gym utils `
65 | - .. centered:: :math:`\blacksquare`
66 | - .. centered:: :math:`\blacksquare`
67 |
--------------------------------------------------------------------------------
/docs/source/api/utils/distributed.rst:
--------------------------------------------------------------------------------
1 | Distributed
2 | ===========
3 |
4 | Utilities to start multiple processes from a single program invocation in distributed learning
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | PyTorch
11 | -------
12 |
13 | PyTorch provides a Python module/console script to launch distributed runs. Visit PyTorch's `torchrun `_ documentation for more details.
14 |
15 | The following environment variables available in all processes can be accessed through the library:
16 |
17 | * ``LOCAL_RANK`` (accessible via :data:`skrl.config.torch.local_rank`): The local rank.
18 | * ``RANK`` (accessible via :data:`skrl.config.torch.rank`): The global rank.
19 | * ``WORLD_SIZE`` (accessible via :data:`skrl.config.torch.world_size`): The world size (total number of workers in the job).
20 |
21 | JAX
22 | ---
23 |
24 | According to the JAX documentation for `multi-host and multi-process environments `_, JAX doesn't automatically start multiple processes from a single program invocation.
25 |
26 | Therefore, in order to make distributed learning simpler, this library provides a module (based on the PyTorch ``torch.distributed.run`` module) for launching multi-host and multi-process learning directly from the command line.
27 |
28 | This module launches, in multiple processes, the same JAX Python program (Single Program, Multiple Data (SPMD) parallel computation technique) that defines the following environment variables for each process:
29 |
30 | * ``JAX_LOCAL_RANK`` (accessible via :data:`skrl.config.jax.local_rank`): The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node).
31 | * ``JAX_RANK`` (accessible via :data:`skrl.config.jax.rank`): The rank (ID number) of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes).
32 | * - ``JAX_WORLD_SIZE`` (accessible via :data:`skrl.config.jax.world_size`): The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes).
33 | * ``JAX_COORDINATOR_ADDR`` (accessible via :data:`skrl.config.jax.coordinator_address`): IP address where process 0 will start a JAX coordinator service.
34 | * - ``JAX_COORDINATOR_PORT`` (accessible via :data:`skrl.config.jax.coordinator_port`): Port where process 0 will start a JAX coordinator service.
35 |
36 | .. raw:: html
37 |
38 |
39 |
40 | Usage
41 | ^^^^^
42 |
43 | .. code-block:: bash
44 |
45 | $ python -m skrl.utils.distributed.jax --help
46 |
47 | .. literalinclude:: ../../snippets/utils_distributed.txt
48 | :language: text
49 | :start-after: [start-distributed-launcher-jax]
50 | :end-before: [end-distributed-launcher-jax]
51 |
52 | .. raw:: html
53 |
54 |
55 |
56 | API
57 | ^^^
58 |
59 | .. autofunction:: skrl.utils.distributed.jax.launcher.launch
60 |
--------------------------------------------------------------------------------
/docs/source/api/utils/huggingface.rst:
--------------------------------------------------------------------------------
1 | Hugging Face integration
2 | ========================
3 |
4 | The Hugging Face (HF) Hub is a platform for building, training, and deploying ML models, as well as accessing a variety of datasets and metrics for further analysis and validation.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Integration
11 | -----------
12 |
13 | .. raw:: html
14 |
15 |
16 |
17 | Download model from HF Hub
18 | ^^^^^^^^^^^^^^^^^^^^^^^^^^
19 |
20 | Several skrl-trained models (agent checkpoints) for different environments/tasks of Gym/Gymnasium, Isaac Gym, Omniverse Isaac Gym, etc. are available in the Hugging Face Hub
21 |
22 | These models can be used as comparison benchmarks, for collecting environment transitions in memory (for offline reinforcement learning, e.g.) or for pre-initialization of agents for performing similar tasks, among others
23 |
24 | Visit the `skrl organization on the Hugging Face Hub `_ to access publicly available models!
25 |
26 | .. raw:: html
27 |
28 |
29 |
30 | API
31 | ---
32 |
33 | .. autofunction:: skrl.utils.huggingface.download_model_from_huggingface
34 |
--------------------------------------------------------------------------------
/docs/source/api/utils/omniverse_isaacgym_utils.rst:
--------------------------------------------------------------------------------
1 | Omniverse Isaac Gym utils
2 | =========================
3 |
4 | Utilities for ease of programming of Omniverse Isaac Gym environments.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Control of robotic manipulators
11 | -------------------------------
12 |
13 | .. raw:: html
14 |
15 |
16 |
17 | Differential inverse kinematics
18 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
19 |
20 | This implementation attempts to unify under a single and reusable function the whole set of procedures used to compute the inverse kinematics of a robotic manipulator, originally shown in the task space controllers section of the Isaac Lab (previously Orbit) framework, but this time for Omniverse Isaac Gym.
21 |
22 | :math:`\Delta\theta =` :guilabel:`scale` :math:`J^\dagger \, \vec{e}`
23 |
24 | where
25 |
26 | | :math:`\qquad \Delta\theta \;` is the change in joint angles
27 | | :math:`\qquad \vec{e} \;` is the Cartesian pose error (position and orientation)
28 | | :math:`\qquad J^\dagger \;` is the pseudoinverse of the Jacobian estimated as follows:
29 |
30 | The pseudoinverse of the Jacobian (:math:`J^\dagger`) is estimated as follows:
31 |
32 | * - Transpose: :math:`\; J^\dagger = J^T`
33 | * - Pseudoinverse: :math:`\; J^\dagger = J^T(JJ^T)^{-1}`
34 | * Damped least-squares: :math:`\; J^\dagger = J^T(JJ^T \, +` :guilabel:`damping`:math:`{}^2 I)^{-1}`
35 | * - Singular-value decomposition: See `buss2004introduction `_ (section 6)
36 |
37 | .. raw:: html
38 |
39 |
40 |
41 | API
42 | ^^^
43 |
44 | .. autofunction:: skrl.utils.omniverse_isaacgym_utils.ik
45 |
46 | .. raw:: html
47 |
48 |
49 |
50 | OmniIsaacGymEnvs-like environment instance
51 | ------------------------------------------
52 |
53 | Instantiate a VecEnvBase-based object compatible with OmniIsaacGymEnvs for use outside of the OmniIsaacGymEnvs implementation.
54 |
55 | .. raw:: html
56 |
57 |
58 |
59 | API
60 | ^^^
61 |
62 | .. autofunction:: skrl.utils.omniverse_isaacgym_utils.get_env_instance
63 |
--------------------------------------------------------------------------------
/docs/source/api/utils/postprocessing.rst:
--------------------------------------------------------------------------------
1 | File post-processing
2 | ====================
3 |
4 | Utilities for processing files generated during training/evaluation.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Exported memories
11 | -----------------
12 |
13 | This library provides an implementation for quickly loading exported memory files to inspect their contents in future post-processing steps. See the section :ref:`Library utilities (skrl.utils module) ` for a real use case
14 |
15 | .. raw:: html
16 |
17 |
18 |
19 | Usage
20 | ^^^^^
21 |
22 | .. tabs::
23 |
24 | .. tab:: PyTorch (.pt)
25 |
26 | .. literalinclude:: ../../snippets/utils_postprocessing.py
27 | :language: python
28 | :emphasize-lines: 1, 5-6
29 | :start-after: [start-memory_file_iterator-torch]
30 | :end-before: [end-memory_file_iterator-torch]
31 |
32 | .. tab:: NumPy (.npz)
33 |
34 | .. literalinclude:: ../../snippets/utils_postprocessing.py
35 | :language: python
36 | :emphasize-lines: 1, 5-6
37 | :start-after: [start-memory_file_iterator-numpy]
38 | :end-before: [end-memory_file_iterator-numpy]
39 |
40 | .. tab:: Comma-separated values (.csv)
41 |
42 | .. literalinclude:: ../../snippets/utils_postprocessing.py
43 | :language: python
44 | :emphasize-lines: 1, 5-6
45 | :start-after: [start-memory_file_iterator-csv]
46 | :end-before: [end-memory_file_iterator-csv]
47 |
48 | .. raw:: html
49 |
50 |
51 |
52 | API
53 | ^^^
54 |
55 | .. autoclass:: skrl.utils.postprocessing.MemoryFileIterator
56 | :undoc-members:
57 | :show-inheritance:
58 | :inherited-members:
59 | :private-members: _format_numpy, _format_torch, _format_csv
60 | :members:
61 |
62 | .. automethod:: __iter__
63 | .. automethod:: __next__
64 |
65 | .. raw:: html
66 |
67 |
68 |
69 | Tensorboard files
70 | -----------------
71 |
72 | This library provides an implementation for quickly loading Tensorboard files to inspect their contents in future post-processing steps. See the section :ref:`Library utilities (skrl.utils module) ` for a real use case
73 |
74 | .. raw:: html
75 |
76 |
77 |
78 | Requirements
79 | ^^^^^^^^^^^^
80 |
81 | This utility requires the `TensorFlow `_ package to be installed to load and parse Tensorboard files:
82 |
83 | .. code-block:: bash
84 |
85 | pip install tensorflow
86 |
87 | .. raw:: html
88 |
89 |
90 |
91 | Usage
92 | ^^^^^
93 |
94 | .. tabs::
95 |
96 | .. tab:: Tensorboard (events.out.tfevents.*)
97 |
98 | .. literalinclude:: ../../snippets/utils_postprocessing.py
99 | :language: python
100 | :emphasize-lines: 1, 5-7
101 | :start-after: [start-tensorboard_file_iterator-list]
102 | :end-before: [end-tensorboard_file_iterator-list]
103 |
104 | .. raw:: html
105 |
106 |
107 |
108 | API
109 | ^^^
110 |
111 | .. autoclass:: skrl.utils.postprocessing.TensorboardFileIterator
112 | :undoc-members:
113 | :show-inheritance:
114 | :inherited-members:
115 | :members:
116 |
117 | .. automethod:: __iter__
118 | .. automethod:: __next__
119 |
--------------------------------------------------------------------------------
/docs/source/api/utils/runner.rst:
--------------------------------------------------------------------------------
1 | Runner
2 | ======
3 |
4 | Utility that configures and instantiates skrl's components to run training/evaluation workflows in a few lines of code.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Usage
11 | -----
12 |
13 | .. hint::
14 |
15 | The ``Runner`` classes encapsulate, and greatly simplify, the definitions and instantiations needed to execute RL tasks.
16 | However, such simplification hides implementation details and makes it harder to modify and read the code (models, agents, etc.).
17 |
18 | For more control and readability over the RL system setup refer to the :doc:`Examples <../../intro/examples>` section's training scripts (**recommended!**)
19 |
20 | .. raw:: html
21 |
22 |
23 |
24 | .. tabs::
25 |
26 | .. group-tab:: |_4| |pytorch| |_4|
27 |
28 | .. tabs::
29 |
30 | .. group-tab:: Python code
31 |
32 | .. literalinclude:: ../../snippets/runner.txt
33 | :language: python
34 | :emphasize-lines: 1, 10, 13, 16
35 | :start-after: [start-runner-train-torch]
36 | :end-before: [end-runner-train-torch]
37 |
38 | .. group-tab:: Example .yaml file (PPO)
39 |
40 | .. literalinclude:: ../../snippets/runner.txt
41 | :language: yaml
42 | :start-after: [start-cfg-yaml]
43 | :end-before: [end-cfg-yaml]
44 |
45 | .. group-tab:: |_4| |jax| |_4|
46 |
47 | .. tabs::
48 |
49 | .. group-tab:: Python code
50 |
51 | .. literalinclude:: ../../snippets/runner.txt
52 | :language: python
53 | :emphasize-lines: 1, 10, 13, 16
54 | :start-after: [start-runner-train-jax]
55 | :end-before: [end-runner-train-jax]
56 |
57 | .. group-tab:: Example .yaml file (PPO)
58 |
59 | .. literalinclude:: ../../snippets/runner.txt
60 | :language: yaml
61 | :start-after: [start-cfg-yaml]
62 | :end-before: [end-cfg-yaml]
63 |
64 | API (PyTorch)
65 | -------------
66 |
67 | .. autoclass:: skrl.utils.runner.torch.Runner
68 | :show-inheritance:
69 | :members:
70 |
71 | .. raw:: html
72 |
73 |
74 |
75 | API (JAX)
76 | ---------
77 |
78 | .. autoclass:: skrl.utils.runner.jax.Runner
79 | :show-inheritance:
80 | :members:
81 |
--------------------------------------------------------------------------------
/docs/source/api/utils/seed.rst:
--------------------------------------------------------------------------------
1 | Random seed
2 | ===========
3 |
4 | Utilities for seeding the random number generators.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Seed the random number generators
11 | ---------------------------------
12 |
13 | .. raw:: html
14 |
15 |
16 |
17 | API
18 | ^^^
19 |
20 | .. autofunction:: skrl.utils.set_seed
21 |
--------------------------------------------------------------------------------
/docs/source/api/utils/spaces.rst:
--------------------------------------------------------------------------------
1 | Spaces
2 | ======
3 |
4 | Utilities to operate on Gymnasium `spaces <https://gymnasium.farama.org/api/spaces/>`_.
5 |
6 | .. raw:: html
7 |
8 |
9 |
10 | Overview
11 | --------
12 |
13 | The utilities described in this section support the following Gymnasium spaces:
14 |
15 | .. list-table::
16 | :header-rows: 1
17 |
18 | * - Type
19 | - Supported spaces
20 | * - Fundamental
21 | - :py:class:`~gymnasium.spaces.Box`, :py:class:`~gymnasium.spaces.Discrete`, and :py:class:`~gymnasium.spaces.MultiDiscrete`
22 | * - Composite
23 | - :py:class:`~gymnasium.spaces.Dict` and :py:class:`~gymnasium.spaces.Tuple`
24 |
25 | The following table provides a snapshot of the space sample conversion functions:
26 |
27 | .. list-table::
28 | :header-rows: 1
29 |
30 | * - Input
31 | - Function
32 | - Output
33 | * - Space (NumPy / int)
34 | - :py:func:`~skrl.utils.spaces.torch.tensorize_space`
35 | - Space (PyTorch / JAX)
36 | * - Space (PyTorch / JAX)
37 | - :py:func:`~skrl.utils.spaces.torch.untensorize_space`
38 | - Space (NumPy / int)
39 | * - Space (PyTorch / JAX)
40 | - :py:func:`~skrl.utils.spaces.torch.flatten_tensorized_space`
41 | - PyTorch tensor / JAX array
42 | * - PyTorch tensor / JAX array
43 | - :py:func:`~skrl.utils.spaces.torch.unflatten_tensorized_space`
44 | - Space (PyTorch / JAX)
45 |
46 | .. raw:: html
47 |
48 |
49 |
50 | API (PyTorch)
51 | -------------
52 |
53 | .. autofunction:: skrl.utils.spaces.torch.compute_space_size
54 |
55 | .. autofunction:: skrl.utils.spaces.torch.convert_gym_space
56 |
57 | .. autofunction:: skrl.utils.spaces.torch.flatten_tensorized_space
58 |
59 | .. autofunction:: skrl.utils.spaces.torch.sample_space
60 |
61 | .. autofunction:: skrl.utils.spaces.torch.tensorize_space
62 |
63 | .. autofunction:: skrl.utils.spaces.torch.unflatten_tensorized_space
64 |
65 | .. autofunction:: skrl.utils.spaces.torch.untensorize_space
66 |
67 | .. raw:: html
68 |
69 |
70 |
71 | API (JAX)
72 | ---------
73 |
74 | .. autofunction:: skrl.utils.spaces.jax.compute_space_size
75 |
76 | .. autofunction:: skrl.utils.spaces.jax.convert_gym_space
77 |
78 | .. autofunction:: skrl.utils.spaces.jax.flatten_tensorized_space
79 |
80 | .. autofunction:: skrl.utils.spaces.jax.sample_space
81 |
82 | .. autofunction:: skrl.utils.spaces.jax.tensorize_space
83 |
84 | .. autofunction:: skrl.utils.spaces.jax.unflatten_tensorized_space
85 |
86 | .. autofunction:: skrl.utils.spaces.jax.untensorize_space
87 |
--------------------------------------------------------------------------------
/docs/source/examples/gym/jax_gym_cartpole_cem.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | import flax.linen as nn
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | # import the skrl components to build the RL system
8 | from skrl import config
9 | from skrl.agents.jax.cem import CEM, CEM_DEFAULT_CONFIG
10 | from skrl.envs.wrappers.jax import wrap_env
11 | from skrl.memories.jax import RandomMemory
12 | from skrl.models.jax import CategoricalMixin, Model
13 | from skrl.trainers.jax import SequentialTrainer
14 | from skrl.utils import set_seed
15 |
16 |
17 | config.jax.backend = "numpy" # or "jax"
18 |
19 |
20 | # seed for reproducibility
21 | set_seed() # e.g. `set_seed(42)` for fixed seed
22 |
23 |
24 | # define model (categorical model) using mixin
25 | class Policy(CategoricalMixin, Model):
26 | def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs):
27 | Model.__init__(self, observation_space, action_space, device, **kwargs)
28 | CategoricalMixin.__init__(self, unnormalized_log_prob)
29 |
30 | @nn.compact
31 | def __call__(self, inputs, role):
32 | x = nn.relu(nn.Dense(64)(inputs["states"]))
33 | x = nn.relu(nn.Dense(64)(x))
34 | x = nn.Dense(self.num_actions)(x)
35 | return x, {}
36 |
37 |
38 | # load and wrap the gym environment.
39 | # note: the environment version may change depending on the gym version
40 | try:
41 | env = gym.make("CartPole-v0")
42 | except gym.error.DeprecatedEnv as e:
43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("CartPole-v")][0]
44 | print("CartPole-v0 not found. Trying {}".format(env_id))
45 | env = gym.make(env_id)
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate a memory as experience replay
52 | memory = RandomMemory(memory_size=1000, num_envs=env.num_envs, device=device, replacement=False)
53 |
54 |
55 | # instantiate the agent's model (function approximator).
56 | # CEM requires 1 model, visit its documentation for more details
57 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#models
58 | models = {}
59 | models["policy"] = Policy(env.observation_space, env.action_space, device)
60 |
61 | # instantiate models' state dict
62 | for role, model in models.items():
63 | model.init_state_dict(role)
64 |
65 | # initialize models' parameters (weights and biases)
66 | for model in models.values():
67 | model.init_parameters(method_name="normal", stddev=0.1)
68 |
69 |
70 | # configure and instantiate the agent (visit its documentation to see all the options)
71 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#configuration-and-hyperparameters
72 | cfg = CEM_DEFAULT_CONFIG.copy()
73 | cfg["rollouts"] = 1000
74 | cfg["learning_starts"] = 100
75 | # logging to TensorBoard and write checkpoints (in timesteps)
76 | cfg["experiment"]["write_interval"] = 1000
77 | cfg["experiment"]["checkpoint_interval"] = 5000
78 | cfg["experiment"]["directory"] = "runs/jax/CartPole"
79 |
80 | agent = CEM(models=models,
81 | memory=memory,
82 | cfg=cfg,
83 | observation_space=env.observation_space,
84 | action_space=env.action_space,
85 | device=device)
86 |
87 |
88 | # configure and instantiate the RL trainer
89 | cfg_trainer = {"timesteps": 100000, "headless": True}
90 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
91 |
92 | # start training
93 | trainer.train()
94 |
--------------------------------------------------------------------------------
/docs/source/examples/gym/torch_gym_cartpole_cem.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | # import the skrl components to build the RL system
7 | from skrl.agents.torch.cem import CEM, CEM_DEFAULT_CONFIG
8 | from skrl.envs.wrappers.torch import wrap_env
9 | from skrl.memories.torch import RandomMemory
10 | from skrl.models.torch import CategoricalMixin, Model
11 | from skrl.trainers.torch import SequentialTrainer
12 | from skrl.utils import set_seed
13 |
14 |
15 | # seed for reproducibility
16 | set_seed() # e.g. `set_seed(42)` for fixed seed
17 |
18 |
19 | # define model (categorical model) using mixin
20 | class Policy(CategoricalMixin, Model):
21 | def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
22 | Model.__init__(self, observation_space, action_space, device)
23 | CategoricalMixin.__init__(self, unnormalized_log_prob)
24 |
25 | self.linear_layer_1 = nn.Linear(self.num_observations, 64)
26 | self.linear_layer_2 = nn.Linear(64, 64)
27 | self.output_layer = nn.Linear(64, self.num_actions)
28 |
29 | def compute(self, inputs, role):
30 | x = F.relu(self.linear_layer_1(inputs["states"]))
31 | x = F.relu(self.linear_layer_2(x))
32 | return self.output_layer(x), {}
33 |
34 |
35 | # load and wrap the gym environment.
36 | # note: the environment version may change depending on the gym version
37 | try:
38 | env = gym.make("CartPole-v0")
39 | except gym.error.DeprecatedEnv as e:
40 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("CartPole-v")][0]
41 | print("CartPole-v0 not found. Trying {}".format(env_id))
42 | env = gym.make(env_id)
43 | env = wrap_env(env)
44 |
45 | device = env.device
46 |
47 |
48 | # instantiate a memory as experience replay
49 | memory = RandomMemory(memory_size=1000, num_envs=env.num_envs, device=device, replacement=False)
50 |
51 |
52 | # instantiate the agent's model (function approximator).
53 | # CEM requires 1 model, visit its documentation for more details
54 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#models
55 | models = {}
56 | models["policy"] = Policy(env.observation_space, env.action_space, device)
57 |
58 | # initialize models' parameters (weights and biases)
59 | for model in models.values():
60 | model.init_parameters(method_name="normal_", mean=0.0, std=0.1)
61 |
62 |
63 | # configure and instantiate the agent (visit its documentation to see all the options)
64 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#configuration-and-hyperparameters
65 | cfg = CEM_DEFAULT_CONFIG.copy()
66 | cfg["rollouts"] = 1000
67 | cfg["learning_starts"] = 100
68 | # logging to TensorBoard and write checkpoints (in timesteps)
69 | cfg["experiment"]["write_interval"] = 1000
70 | cfg["experiment"]["checkpoint_interval"] = 5000
71 | cfg["experiment"]["directory"] = "runs/torch/CartPole"
72 |
73 | agent = CEM(models=models,
74 | memory=memory,
75 | cfg=cfg,
76 | observation_space=env.observation_space,
77 | action_space=env.action_space,
78 | device=device)
79 |
80 |
81 | # configure and instantiate the RL trainer
82 | cfg_trainer = {"timesteps": 100000, "headless": True}
83 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
84 |
85 | # start training
86 | trainer.train()
87 |
--------------------------------------------------------------------------------
/docs/source/examples/gym/torch_gym_frozen_lake_q_learning.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | import torch
4 |
5 | # import the skrl components to build the RL system
6 | from skrl.agents.torch.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG
7 | from skrl.envs.wrappers.torch import wrap_env
8 | from skrl.models.torch import Model, TabularMixin
9 | from skrl.trainers.torch import SequentialTrainer
10 | from skrl.utils import set_seed
11 |
12 |
13 | # seed for reproducibility
14 | set_seed() # e.g. `set_seed(42)` for fixed seed
15 |
16 |
17 | # define model (tabular model) using mixin
18 | class EpilonGreedyPolicy(TabularMixin, Model):
19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
20 | Model.__init__(self, observation_space, action_space, device)
21 | TabularMixin.__init__(self, num_envs)
22 |
23 | self.epsilon = epsilon
24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
25 | dtype=torch.float32, device=self.device)
26 |
27 | def compute(self, inputs, role):
28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]],
29 | dim=-1, keepdim=True).view(-1,1)
30 |
31 | # choose random actions for exploration according to epsilon
32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
33 | if indexes.numel():
34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
35 | return actions, {}
36 |
37 |
38 | # load and wrap the gym environment.
39 | # note: the environment version may change depending on the gym version
40 | try:
41 | env = gym.make("FrozenLake-v0")
42 | except gym.error.DeprecatedEnv as e:
43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("FrozenLake-v")][0]
44 | print("FrozenLake-v0 not found. Trying {}".format(env_id))
45 | env = gym.make(env_id)
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate the agent's model (table)
52 | # Q-learning requires 1 model, visit its documentation for more details
53 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#models
54 | models = {}
55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1)
56 |
57 |
58 | # configure and instantiate the agent (visit its documentation to see all the options)
59 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#configuration-and-hyperparameters
60 | cfg = Q_LEARNING_DEFAULT_CONFIG.copy()
61 | cfg["discount_factor"] = 0.999
62 | cfg["alpha"] = 0.4
63 | # logging to TensorBoard and write checkpoints (in timesteps)
64 | cfg["experiment"]["write_interval"] = 1600
65 | cfg["experiment"]["checkpoint_interval"] = 8000
66 | cfg["experiment"]["directory"] = "runs/torch/FrozenLake"
67 |
68 | agent = Q_LEARNING(models=models,
69 | memory=None,
70 | cfg=cfg,
71 | observation_space=env.observation_space,
72 | action_space=env.action_space,
73 | device=device)
74 |
75 |
76 | # configure and instantiate the RL trainer
77 | cfg_trainer = {"timesteps": 80000, "headless": True}
78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
79 |
80 | # start training
81 | trainer.train()
82 |
--------------------------------------------------------------------------------
/docs/source/examples/gym/torch_gym_frozen_lake_vector_q_learning.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | import torch
4 |
5 | # import the skrl components to build the RL system
6 | from skrl.agents.torch.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG
7 | from skrl.envs.wrappers.torch import wrap_env
8 | from skrl.models.torch import Model, TabularMixin
9 | from skrl.trainers.torch import SequentialTrainer
10 | from skrl.utils import set_seed
11 |
12 |
13 | # seed for reproducibility
14 | set_seed() # e.g. `set_seed(42)` for fixed seed
15 |
16 |
17 | # define model (tabular model) using mixin
18 | class EpilonGreedyPolicy(TabularMixin, Model):
19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
20 | Model.__init__(self, observation_space, action_space, device)
21 | TabularMixin.__init__(self, num_envs)
22 |
23 | self.epsilon = epsilon
24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
25 | dtype=torch.float32, device=self.device)
26 |
27 | def compute(self, inputs, role):
28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]],
29 | dim=-1, keepdim=True).view(-1,1)
30 |
31 | # choose random actions for exploration according to epsilon
32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
33 | if indexes.numel():
34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
35 | return actions, {}
36 |
37 |
38 | # load and wrap the gym environment.
39 | # note: the environment version may change depending on the gym version
40 | try:
41 | env = gym.vector.make("FrozenLake-v0", num_envs=10, asynchronous=False)
42 | except gym.error.DeprecatedEnv as e:
43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("FrozenLake-v")][0]
44 | print("FrozenLake-v0 not found. Trying {}".format(env_id))
45 | env = gym.vector.make(env_id, num_envs=10, asynchronous=False)
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate the agent's model (table)
52 | # Q-learning requires 1 model, visit its documentation for more details
53 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#models
54 | models = {}
55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1)
56 |
57 |
58 | # configure and instantiate the agent (visit its documentation to see all the options)
59 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#configuration-and-hyperparameters
60 | cfg = Q_LEARNING_DEFAULT_CONFIG.copy()
61 | cfg["discount_factor"] = 0.999
62 | cfg["alpha"] = 0.4
63 | # logging to TensorBoard and write checkpoints (in timesteps)
64 | cfg["experiment"]["write_interval"] = 1600
65 | cfg["experiment"]["checkpoint_interval"] = 8000
66 | cfg["experiment"]["directory"] = "runs/torch/FrozenLake"
67 |
68 | agent = Q_LEARNING(models=models,
69 | memory=None,
70 | cfg=cfg,
71 | observation_space=env.observation_space,
72 | action_space=env.action_space,
73 | device=device)
74 |
75 |
76 | # configure and instantiate the RL trainer
77 | cfg_trainer = {"timesteps": 80000, "headless": True}
78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
79 |
80 | # start training
81 | trainer.train()
82 |
--------------------------------------------------------------------------------
/docs/source/examples/gym/torch_gym_taxi_sarsa.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | import torch
4 |
5 | # import the skrl components to build the RL system
6 | from skrl.agents.torch.sarsa import SARSA, SARSA_DEFAULT_CONFIG
7 | from skrl.envs.wrappers.torch import wrap_env
8 | from skrl.models.torch import Model, TabularMixin
9 | from skrl.trainers.torch import SequentialTrainer
10 | from skrl.utils import set_seed
11 |
12 |
13 | # seed for reproducibility
14 | set_seed() # e.g. `set_seed(42)` for fixed seed
15 |
16 |
17 | # define model (tabular model) using mixin
18 | class EpilonGreedyPolicy(TabularMixin, Model):
19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
20 | Model.__init__(self, observation_space, action_space, device)
21 | TabularMixin.__init__(self, num_envs)
22 |
23 | self.epsilon = epsilon
24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
25 | dtype=torch.float32, device=self.device)
26 |
27 | def compute(self, inputs, role):
28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]],
29 | dim=-1, keepdim=True).view(-1,1)
30 |
31 | # choose random actions for exploration according to epsilon
32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
33 | if indexes.numel():
34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
35 | return actions, {}
36 |
37 |
38 | # load and wrap the gym environment.
39 | # note: the environment version may change depending on the gym version
40 | try:
41 | env = gym.make("Taxi-v3")
42 | except gym.error.DeprecatedEnv as e:
43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("Taxi-v")][0]
44 | print("Taxi-v3 not found. Trying {}".format(env_id))
45 | env = gym.make(env_id)
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate the agent's model (table)
52 | # SARSA requires 1 model, visit its documentation for more details
53 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#models
54 | models = {}
55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1)
56 |
57 |
58 | # configure and instantiate the agent (visit its documentation to see all the options)
59 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#configuration-and-hyperparameters
60 | cfg = SARSA_DEFAULT_CONFIG.copy()
61 | cfg["discount_factor"] = 0.999
62 | cfg["alpha"] = 0.4
63 | # logging to TensorBoard and write checkpoints (in timesteps)
64 | cfg["experiment"]["write_interval"] = 1600
65 | cfg["experiment"]["checkpoint_interval"] = 8000
66 | cfg["experiment"]["directory"] = "runs/torch/Taxi"
67 |
68 | agent = SARSA(models=models,
69 | memory=None,
70 | cfg=cfg,
71 | observation_space=env.observation_space,
72 | action_space=env.action_space,
73 | device=device)
74 |
75 |
76 | # configure and instantiate the RL trainer
77 | cfg_trainer = {"timesteps": 80000, "headless": True}
78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
79 |
80 | # start training
81 | trainer.train()
82 |
--------------------------------------------------------------------------------
/docs/source/examples/gym/torch_gym_taxi_vector_sarsa.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | import torch
4 |
5 | # import the skrl components to build the RL system
6 | from skrl.agents.torch.sarsa import SARSA, SARSA_DEFAULT_CONFIG
7 | from skrl.envs.wrappers.torch import wrap_env
8 | from skrl.models.torch import Model, TabularMixin
9 | from skrl.trainers.torch import SequentialTrainer
10 | from skrl.utils import set_seed
11 |
12 |
13 | # seed for reproducibility
14 | set_seed() # e.g. `set_seed(42)` for fixed seed
15 |
16 |
17 | # define model (tabular model) using mixin
18 | class EpilonGreedyPolicy(TabularMixin, Model):
19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
20 | Model.__init__(self, observation_space, action_space, device)
21 | TabularMixin.__init__(self, num_envs)
22 |
23 | self.epsilon = epsilon
24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
25 | dtype=torch.float32, device=self.device)
26 |
27 | def compute(self, inputs, role):
28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]],
29 | dim=-1, keepdim=True).view(-1,1)
30 |
31 | # choose random actions for exploration according to epsilon
32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
33 | if indexes.numel():
34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
35 | return actions, {}
36 |
37 |
38 | # load and wrap the gym environment.
39 | # note: the environment version may change depending on the gym version
40 | try:
41 | env = gym.vector.make("Taxi-v3", num_envs=10, asynchronous=False)
42 | except gym.error.DeprecatedEnv as e:
43 | env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("Taxi-v")][0]
44 | print("Taxi-v3 not found. Trying {}".format(env_id))
45 | env = gym.vector.make(env_id, num_envs=10, asynchronous=False)
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate the agent's model (table)
52 | # SARSA requires 1 model, visit its documentation for more details
53 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#models
54 | models = {}
55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1)
56 |
57 |
58 | # configure and instantiate the agent (visit its documentation to see all the options)
59 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#configuration-and-hyperparameters
60 | cfg = SARSA_DEFAULT_CONFIG.copy()
61 | cfg["discount_factor"] = 0.999
62 | cfg["alpha"] = 0.4
63 | # logging to TensorBoard and write checkpoints (in timesteps)
64 | cfg["experiment"]["write_interval"] = 1600
65 | cfg["experiment"]["checkpoint_interval"] = 8000
66 | cfg["experiment"]["directory"] = "runs/torch/Taxi"
67 |
68 | agent = SARSA(models=models,
69 | memory=None,
70 | cfg=cfg,
71 | observation_space=env.observation_space,
72 | action_space=env.action_space,
73 | device=device)
74 |
75 |
76 | # configure and instantiate the RL trainer
77 | cfg_trainer = {"timesteps": 80000, "headless": True}
78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
79 |
80 | # start training
81 | trainer.train()
82 |
--------------------------------------------------------------------------------
/docs/source/examples/gymnasium/jax_gymnasium_cartpole_cem.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 |
3 | import flax.linen as nn
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | # import the skrl components to build the RL system
8 | from skrl import config
9 | from skrl.agents.jax.cem import CEM, CEM_DEFAULT_CONFIG
10 | from skrl.envs.wrappers.jax import wrap_env
11 | from skrl.memories.jax import RandomMemory
12 | from skrl.models.jax import CategoricalMixin, Model
13 | from skrl.trainers.jax import SequentialTrainer
14 | from skrl.utils import set_seed
15 |
16 |
17 | config.jax.backend = "numpy" # or "jax"
18 |
19 |
20 | # seed for reproducibility
21 | set_seed() # e.g. `set_seed(42)` for fixed seed
22 |
23 |
24 | # define model (categorical model) using mixin
25 | class Policy(CategoricalMixin, Model):
26 | def __init__(self, observation_space, action_space, device=None, unnormalized_log_prob=True, **kwargs):
27 | Model.__init__(self, observation_space, action_space, device, **kwargs)
28 | CategoricalMixin.__init__(self, unnormalized_log_prob)
29 |
30 | @nn.compact
31 | def __call__(self, inputs, role):
32 | x = nn.relu(nn.Dense(64)(inputs["states"]))
33 | x = nn.relu(nn.Dense(64)(x))
34 | x = nn.Dense(self.num_actions)(x)
35 | return x, {}
36 |
37 |
38 | # load and wrap the gymnasium environment.
39 | # note: the environment version may change depending on the gymnasium version
40 | try:
41 | env = gym.make("CartPole-v1")
42 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e:
43 | env_id = [spec for spec in gym.envs.registry if spec.startswith("CartPole-v")][0]
44 | print("CartPole-v0 not found. Trying {}".format(env_id))
45 | env = gym.make(env_id)
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate a memory as experience replay
52 | memory = RandomMemory(memory_size=1000, num_envs=env.num_envs, device=device, replacement=False)
53 |
54 |
55 | # instantiate the agent's model (function approximator).
56 | # CEM requires 1 model, visit its documentation for more details
57 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#models
58 | models = {}
59 | models["policy"] = Policy(env.observation_space, env.action_space, device)
60 |
61 | # instantiate models' state dict
62 | for role, model in models.items():
63 | model.init_state_dict(role)
64 |
65 | # initialize models' parameters (weights and biases)
66 | for model in models.values():
67 | model.init_parameters(method_name="normal", stddev=0.1)
68 |
69 |
70 | # configure and instantiate the agent (visit its documentation to see all the options)
71 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#configuration-and-hyperparameters
72 | cfg = CEM_DEFAULT_CONFIG.copy()
73 | cfg["rollouts"] = 1000
74 | cfg["learning_starts"] = 100
75 | # logging to TensorBoard and write checkpoints (in timesteps)
76 | cfg["experiment"]["write_interval"] = 1000
77 | cfg["experiment"]["checkpoint_interval"] = 5000
78 | cfg["experiment"]["directory"] = "runs/jax/CartPole"
79 |
80 | agent = CEM(models=models,
81 | memory=memory,
82 | cfg=cfg,
83 | observation_space=env.observation_space,
84 | action_space=env.action_space,
85 | device=device)
86 |
87 |
88 | # configure and instantiate the RL trainer
89 | cfg_trainer = {"timesteps": 100000, "headless": True}
90 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
91 |
92 | # start training
93 | trainer.train()
94 |
--------------------------------------------------------------------------------
/docs/source/examples/gymnasium/torch_gymnasium_cartpole_cem.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | # import the skrl components to build the RL system
7 | from skrl.agents.torch.cem import CEM, CEM_DEFAULT_CONFIG
8 | from skrl.envs.wrappers.torch import wrap_env
9 | from skrl.memories.torch import RandomMemory
10 | from skrl.models.torch import CategoricalMixin, Model
11 | from skrl.trainers.torch import SequentialTrainer
12 | from skrl.utils import set_seed
13 |
14 |
15 | # seed for reproducibility
16 | set_seed() # e.g. `set_seed(42)` for fixed seed
17 |
18 |
19 | # define model (categorical model) using mixin
20 | class Policy(CategoricalMixin, Model):
21 | def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
22 | Model.__init__(self, observation_space, action_space, device)
23 | CategoricalMixin.__init__(self, unnormalized_log_prob)
24 |
25 | self.linear_layer_1 = nn.Linear(self.num_observations, 64)
26 | self.linear_layer_2 = nn.Linear(64, 64)
27 | self.output_layer = nn.Linear(64, self.num_actions)
28 |
29 | def compute(self, inputs, role):
30 | x = F.relu(self.linear_layer_1(inputs["states"]))
31 | x = F.relu(self.linear_layer_2(x))
32 | return self.output_layer(x), {}
33 |
34 |
35 | # load and wrap the gymnasium environment.
36 | # note: the environment version may change depending on the gymnasium version
37 | try:
38 | env = gym.make("CartPole-v1")
39 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e:
40 | env_id = [spec for spec in gym.envs.registry if spec.startswith("CartPole-v")][0]
41 | print("CartPole-v0 not found. Trying {}".format(env_id))
42 | env = gym.make(env_id)
43 | env = wrap_env(env)
44 |
45 | device = env.device
46 |
47 |
48 | # instantiate a memory as experience replay
49 | memory = RandomMemory(memory_size=1000, num_envs=env.num_envs, device=device, replacement=False)
50 |
51 |
52 | # instantiate the agent's model (function approximator).
53 | # CEM requires 1 model, visit its documentation for more details
54 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#models
55 | models = {}
56 | models["policy"] = Policy(env.observation_space, env.action_space, device)
57 |
58 | # initialize models' parameters (weights and biases)
59 | for model in models.values():
60 | model.init_parameters(method_name="normal_", mean=0.0, std=0.1)
61 |
62 |
63 | # configure and instantiate the agent (visit its documentation to see all the options)
64 | # https://skrl.readthedocs.io/en/latest/api/agents/cem.html#configuration-and-hyperparameters
65 | cfg = CEM_DEFAULT_CONFIG.copy()
66 | cfg["rollouts"] = 1000
67 | cfg["learning_starts"] = 100
68 | # logging to TensorBoard and write checkpoints (in timesteps)
69 | cfg["experiment"]["write_interval"] = 1000
70 | cfg["experiment"]["checkpoint_interval"] = 5000
71 | cfg["experiment"]["directory"] = "runs/torch/CartPole"
72 |
73 | agent = CEM(models=models,
74 | memory=memory,
75 | cfg=cfg,
76 | observation_space=env.observation_space,
77 | action_space=env.action_space,
78 | device=device)
79 |
80 |
81 | # configure and instantiate the RL trainer
82 | cfg_trainer = {"timesteps": 100000, "headless": True}
83 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
84 |
85 | # start training
86 | trainer.train()
87 |
--------------------------------------------------------------------------------
/docs/source/examples/gymnasium/torch_gymnasium_frozen_lake_q_learning.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 |
3 | import torch
4 |
5 | # import the skrl components to build the RL system
6 | from skrl.agents.torch.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG
7 | from skrl.envs.wrappers.torch import wrap_env
8 | from skrl.models.torch import Model, TabularMixin
9 | from skrl.trainers.torch import SequentialTrainer
10 | from skrl.utils import set_seed
11 |
12 |
13 | # seed for reproducibility
14 | set_seed() # e.g. `set_seed(42)` for fixed seed
15 |
16 |
17 | # define model (tabular model) using mixin
18 | class EpilonGreedyPolicy(TabularMixin, Model):
19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
20 | Model.__init__(self, observation_space, action_space, device)
21 | TabularMixin.__init__(self, num_envs)
22 |
23 | self.epsilon = epsilon
24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
25 | dtype=torch.float32, device=self.device)
26 |
27 | def compute(self, inputs, role):
28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]],
29 | dim=-1, keepdim=True).view(-1,1)
30 |
31 | # choose random actions for exploration according to epsilon
32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
33 | if indexes.numel():
34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
35 | return actions, {}
36 |
37 |
38 | # load and wrap the gymnasium environment.
39 | # note: the environment version may change depending on the gymnasium version
40 | try:
41 | env = gym.make("FrozenLake-v0")
42 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e:
43 | env_id = [spec for spec in gym.envs.registry if spec.startswith("FrozenLake-v")][0]
44 | print("FrozenLake-v0 not found. Trying {}".format(env_id))
45 | env = gym.make(env_id)
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate the agent's model (table)
52 | # Q-learning requires 1 model, visit its documentation for more details
53 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#models
54 | models = {}
55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1)
56 |
57 |
58 | # configure and instantiate the agent (visit its documentation to see all the options)
59 | # https://skrl.readthedocs.io/en/latest/api/agents/q_learning.html#configuration-and-hyperparameters
60 | cfg = Q_LEARNING_DEFAULT_CONFIG.copy()
61 | cfg["discount_factor"] = 0.999
62 | cfg["alpha"] = 0.4
63 | # logging to TensorBoard and write checkpoints (in timesteps)
64 | cfg["experiment"]["write_interval"] = 1600
65 | cfg["experiment"]["checkpoint_interval"] = 8000
66 | cfg["experiment"]["directory"] = "runs/torch/FrozenLake"
67 |
68 | agent = Q_LEARNING(models=models,
69 | memory=None,
70 | cfg=cfg,
71 | observation_space=env.observation_space,
72 | action_space=env.action_space,
73 | device=device)
74 |
75 |
76 | # configure and instantiate the RL trainer
77 | cfg_trainer = {"timesteps": 80000, "headless": True}
78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
79 |
80 | # start training
81 | trainer.train()
82 |
--------------------------------------------------------------------------------
/docs/source/examples/gymnasium/torch_gymnasium_taxi_sarsa.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 |
3 | import torch
4 |
5 | # import the skrl components to build the RL system
6 | from skrl.agents.torch.sarsa import SARSA, SARSA_DEFAULT_CONFIG
7 | from skrl.envs.wrappers.torch import wrap_env
8 | from skrl.models.torch import Model, TabularMixin
9 | from skrl.trainers.torch import SequentialTrainer
10 | from skrl.utils import set_seed
11 |
12 |
13 | # seed for reproducibility
14 | set_seed() # e.g. `set_seed(42)` for fixed seed
15 |
16 |
17 | # define model (tabular model) using mixin
18 | class EpilonGreedyPolicy(TabularMixin, Model):
19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
20 | Model.__init__(self, observation_space, action_space, device)
21 | TabularMixin.__init__(self, num_envs)
22 |
23 | self.epsilon = epsilon
24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
25 | dtype=torch.float32, device=self.device)
26 |
27 | def compute(self, inputs, role):
28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]],
29 | dim=-1, keepdim=True).view(-1,1)
30 |
31 | # choose random actions for exploration according to epsilon
32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
33 | if indexes.numel():
34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
35 | return actions, {}
36 |
37 |
38 | # load and wrap the gymnasium environment.
39 | # note: the environment version may change depending on the gymnasium version
40 | try:
41 | env = gym.make("Taxi-v3")
42 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e:
43 | env_id = [spec for spec in gym.envs.registry if spec.startswith("Taxi-v")][0]
44 | print("Taxi-v3 not found. Trying {}".format(env_id))
45 | env = gym.make(env_id)
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate the agent's model (table)
52 | # SARSA requires 1 model, visit its documentation for more details
53 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#models
54 | models = {}
55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1)
56 |
57 |
58 | # configure and instantiate the agent (visit its documentation to see all the options)
59 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#configuration-and-hyperparameters
60 | cfg = SARSA_DEFAULT_CONFIG.copy()
61 | cfg["discount_factor"] = 0.999
62 | cfg["alpha"] = 0.4
63 | # logging to TensorBoard and write checkpoints (in timesteps)
64 | cfg["experiment"]["write_interval"] = 1600
65 | cfg["experiment"]["checkpoint_interval"] = 8000
66 | cfg["experiment"]["directory"] = "runs/torch/Taxi"
67 |
68 | agent = SARSA(models=models,
69 | memory=None,
70 | cfg=cfg,
71 | observation_space=env.observation_space,
72 | action_space=env.action_space,
73 | device=device)
74 |
75 |
76 | # configure and instantiate the RL trainer
77 | cfg_trainer = {"timesteps": 80000, "headless": True}
78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
79 |
80 | # start training
81 | trainer.train()
82 |
--------------------------------------------------------------------------------
/docs/source/examples/gymnasium/torch_gymnasium_taxi_vector_sarsa.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 |
3 | import torch
4 |
5 | # import the skrl components to build the RL system
6 | from skrl.agents.torch.sarsa import SARSA, SARSA_DEFAULT_CONFIG
7 | from skrl.envs.wrappers.torch import wrap_env
8 | from skrl.models.torch import Model, TabularMixin
9 | from skrl.trainers.torch import SequentialTrainer
10 | from skrl.utils import set_seed
11 |
12 |
13 | # seed for reproducibility
14 | set_seed() # e.g. `set_seed(42)` for fixed seed
15 |
16 |
17 | # define model (tabular model) using mixin
18 | class EpilonGreedyPolicy(TabularMixin, Model):
19 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
20 | Model.__init__(self, observation_space, action_space, device)
21 | TabularMixin.__init__(self, num_envs)
22 |
23 | self.epsilon = epsilon
24 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions),
25 | dtype=torch.float32, device=self.device)
26 |
27 | def compute(self, inputs, role):
28 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), inputs["states"]],
29 | dim=-1, keepdim=True).view(-1,1)
30 |
31 | # choose random actions for exploration according to epsilon
32 | indexes = (torch.rand(inputs["states"].shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
33 | if indexes.numel():
34 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
35 | return actions, {}
36 |
37 |
38 | # load and wrap the gymnasium environment.
39 | # note: the environment version may change depending on the gymnasium version
40 | try:
41 | env = gym.make_vec("Taxi-v3", num_envs=10, vectorization_mode="sync")
42 | except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e:
43 | env_id = [spec for spec in gym.envs.registry if spec.startswith("Taxi-v")][0]
44 | print("Taxi-v3 not found. Trying {}".format(env_id))
45 | env = gym.make_vec(env_id, num_envs=10, vectorization_mode="sync")
46 | env = wrap_env(env)
47 |
48 | device = env.device
49 |
50 |
51 | # instantiate the agent's model (table)
52 | # SARSA requires 1 model, visit its documentation for more details
53 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#models
54 | models = {}
55 | models["policy"] = EpilonGreedyPolicy(env.observation_space, env.action_space, device, num_envs=env.num_envs, epsilon=0.1)
56 |
57 |
58 | # configure and instantiate the agent (visit its documentation to see all the options)
59 | # https://skrl.readthedocs.io/en/latest/api/agents/sarsa.html#configuration-and-hyperparameters
60 | cfg = SARSA_DEFAULT_CONFIG.copy()
61 | cfg["discount_factor"] = 0.999
62 | cfg["alpha"] = 0.4
63 | # logging to TensorBoard and write checkpoints (in timesteps)
64 | cfg["experiment"]["write_interval"] = 1600
65 | cfg["experiment"]["checkpoint_interval"] = 8000
66 | cfg["experiment"]["directory"] = "runs/torch/Taxi"
67 |
68 | agent = SARSA(models=models,
69 | memory=None,
70 | cfg=cfg,
71 | observation_space=env.observation_space,
72 | action_space=env.action_space,
73 | device=device)
74 |
75 |
76 | # configure and instantiate the RL trainer
77 | cfg_trainer = {"timesteps": 80000, "headless": True}
78 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
79 |
80 | # start training
81 | trainer.train()
82 |
--------------------------------------------------------------------------------
/docs/source/examples/real_world/kuka_lbr_iiwa/reaching_iiwa_real_skrl_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # Import the skrl components to build the RL system
5 | from skrl.models.torch import Model, GaussianMixin
6 | from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
7 | from skrl.resources.preprocessors.torch import RunningStandardScaler
8 | from skrl.trainers.torch import SequentialTrainer
9 | from skrl.envs.torch import wrap_env
10 |
11 |
12 | # Define only the policy for evaluation
13 | class Policy(GaussianMixin, Model):
14 | def __init__(self, observation_space, action_space, device, clip_actions=False,
15 | clip_log_std=True, min_log_std=-20, max_log_std=2):
16 | Model.__init__(self, observation_space, action_space, device)
17 | GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
18 |
19 | self.net = nn.Sequential(nn.Linear(self.num_observations, 256),
20 | nn.ELU(),
21 | nn.Linear(256, 128),
22 | nn.ELU(),
23 | nn.Linear(128, 64),
24 | nn.ELU(),
25 | nn.Linear(64, self.num_actions))
26 | self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
27 |
28 | def compute(self, inputs, role):
29 | return self.net(inputs["states"]), self.log_std_parameter, {}
30 |
31 |
32 | # Load the environment
33 | from reaching_iiwa_real_env import ReachingIiwa
34 |
35 | control_space = "joint" # joint or cartesian
36 | env = ReachingIiwa(control_space=control_space)
37 |
38 | # wrap the environment
39 | env = wrap_env(env)
40 |
41 | device = env.device
42 |
43 |
44 | # Instantiate the agent's policy.
45 | # PPO requires 2 models, visit its documentation for more details
46 | # https://skrl.readthedocs.io/en/latest/modules/skrl.agents.ppo.html#spaces-and-models
47 | models_ppo = {}
48 | models_ppo["policy"] = Policy(env.observation_space, env.action_space, device)
49 |
50 |
51 | # Configure and instantiate the agent.
52 | # Only modify some of the default configuration, visit its documentation to see all the options
53 | # https://skrl.readthedocs.io/en/latest/modules/skrl.agents.ppo.html#configuration-and-hyperparameters
54 | cfg_ppo = PPO_DEFAULT_CONFIG.copy()
55 | cfg_ppo["random_timesteps"] = 0
56 | cfg_ppo["learning_starts"] = 0
57 | cfg_ppo["state_preprocessor"] = RunningStandardScaler
58 | cfg_ppo["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
59 | # logging to TensorBoard each 32 timesteps and ignore checkpoints
60 | cfg_ppo["experiment"]["write_interval"] = 32
61 | cfg_ppo["experiment"]["checkpoint_interval"] = 0
62 |
63 | agent = PPO(models=models_ppo,
64 | memory=None,
65 | cfg=cfg_ppo,
66 | observation_space=env.observation_space,
67 | action_space=env.action_space,
68 | device=device)
69 |
70 | # load checkpoints
71 | if control_space == "joint":
72 | agent.load("./agent_joint.pt")
73 | elif control_space == "cartesian":
74 | agent.load("./agent_cartesian.pt")
75 |
76 |
77 | # Configure and instantiate the RL trainer
78 | cfg_trainer = {"timesteps": 1000, "headless": True}
79 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)
80 |
81 | # start evaluation
82 | trainer.eval()
83 |
--------------------------------------------------------------------------------
/docs/source/examples/shimmy/jax_shimmy_atari_pong_dqn.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 |
3 | import flax.linen as nn
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | # import the skrl components to build the RL system
8 | from skrl import config
9 | from skrl.agents.jax.dqn import DQN, DQN_DEFAULT_CONFIG
10 | from skrl.envs.wrappers.jax import wrap_env
11 | from skrl.memories.jax import RandomMemory
12 | from skrl.models.jax import DeterministicMixin, Model
13 | from skrl.trainers.jax import SequentialTrainer
14 | from skrl.utils import set_seed
15 |
16 |
17 | config.jax.backend = "numpy" # or "jax"
18 |
19 |
20 | # seed for reproducibility
21 | set_seed() # e.g. `set_seed(42)` for fixed seed
22 |
23 |
24 | # define model (deterministic model) using mixin
25 | class QNetwork(DeterministicMixin, Model):
26 | def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
27 | Model.__init__(self, observation_space, action_space, device, **kwargs)
28 | DeterministicMixin.__init__(self, clip_actions)
29 |
30 | @nn.compact
31 | def __call__(self, inputs, role):
32 | x = nn.relu(nn.Dense(64)(inputs["states"]))
33 | x = nn.relu(nn.Dense(64)(x))
34 | x = nn.Dense(self.num_actions)(x)
35 | return x, {}
36 |
37 |
38 | # load and wrap the environment
39 | env = gym.make("ALE/Pong-v5")
40 | env = wrap_env(env)
41 |
42 | device = env.device
43 |
44 |
45 | # instantiate a memory as experience replay
46 | memory = RandomMemory(memory_size=15000, num_envs=env.num_envs, device=device, replacement=False)
47 |
48 |
49 | # instantiate the agent's models (function approximators).
50 | # DQN requires 2 models, visit its documentation for more details
51 | # https://skrl.readthedocs.io/en/latest/api/agents/dqn.html#models
52 | models = {}
53 | models["q_network"] = QNetwork(env.observation_space, env.action_space, device)
54 | models["target_q_network"] = QNetwork(env.observation_space, env.action_space, device)
55 |
56 | # instantiate models' state dict
57 | for role, model in models.items():
58 | model.init_state_dict(role)
59 |
60 | # initialize models' parameters (weights and biases)
61 | for model in models.values():
62 | model.init_parameters(method_name="normal", stddev=0.1)
63 |
64 |
65 | # configure and instantiate the agent (visit its documentation to see all the options)
66 | # https://skrl.readthedocs.io/en/latest/api/agents/dqn.html#configuration-and-hyperparameters
67 | cfg = DQN_DEFAULT_CONFIG.copy()
68 | cfg["learning_starts"] = 100
69 | cfg["exploration"]["final_epsilon"] = 0.04
70 | cfg["exploration"]["timesteps"] = 1500
71 | # logging to TensorBoard and write checkpoints (in timesteps)
72 | cfg["experiment"]["write_interval"] = 1000
73 | cfg["experiment"]["checkpoint_interval"] = 5000
74 | cfg["experiment"]["directory"] = "runs/jax/ALE_Pong"
75 |
76 | agent = DQN(models=models,
77 | memory=memory,
78 | cfg=cfg,
79 | observation_space=env.observation_space,
80 | action_space=env.action_space,
81 | device=device)
82 |
83 |
84 | # configure and instantiate the RL trainer
85 | cfg_trainer = {"timesteps": 50000, "headless": True}
86 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
87 |
88 | # start training
89 | trainer.train()
90 |
--------------------------------------------------------------------------------
/docs/source/examples/shimmy/torch_shimmy_atari_pong_dqn.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | # import the skrl components to build the RL system
7 | from skrl.agents.torch.dqn import DQN, DQN_DEFAULT_CONFIG
8 | from skrl.envs.wrappers.torch import wrap_env
9 | from skrl.memories.torch import RandomMemory
10 | from skrl.models.torch import DeterministicMixin, Model
11 | from skrl.trainers.torch import SequentialTrainer
12 | from skrl.utils import set_seed
13 |
14 |
15 | # seed for reproducibility
16 | set_seed() # e.g. `set_seed(42)` for fixed seed
17 |
18 |
19 | # define model (deterministic model) using mixin
20 | class QNetwork(DeterministicMixin, Model):
21 | def __init__(self, observation_space, action_space, device, clip_actions=False):
22 | Model.__init__(self, observation_space, action_space, device)
23 | DeterministicMixin.__init__(self, clip_actions)
24 |
25 | self.net = nn.Sequential(nn.Linear(self.num_observations, 64),
26 | nn.ReLU(),
27 | nn.Linear(64, 64),
28 | nn.ReLU(),
29 | nn.Linear(64, self.num_actions))
30 |
31 | def compute(self, inputs, role):
32 | return self.net(inputs["states"]), {}
33 |
34 |
35 | # load and wrap the environment
36 | env = gym.make("ALE/Pong-v5")
37 | env = wrap_env(env)
38 |
39 | device = env.device
40 |
41 |
42 | # instantiate a memory as experience replay
43 | memory = RandomMemory(memory_size=15000, num_envs=env.num_envs, device=device, replacement=False)
44 |
45 |
46 | # instantiate the agent's models (function approximators).
47 | # DQN requires 2 models, visit its documentation for more details
48 | # https://skrl.readthedocs.io/en/latest/api/agents/dqn.html#models
49 | models = {}
50 | models["q_network"] = QNetwork(env.observation_space, env.action_space, device)
51 | models["target_q_network"] = QNetwork(env.observation_space, env.action_space, device)
52 |
53 | # initialize models' parameters (weights and biases)
54 | for model in models.values():
55 | model.init_parameters(method_name="normal_", mean=0.0, std=0.1)
56 |
57 |
58 | # configure and instantiate the agent (visit its documentation to see all the options)
59 | # https://skrl.readthedocs.io/en/latest/api/agents/dqn.html#configuration-and-hyperparameters
60 | cfg = DQN_DEFAULT_CONFIG.copy()
61 | cfg["learning_starts"] = 100
62 | cfg["exploration"]["initial_epsilon"] = 1.0
63 | cfg["exploration"]["final_epsilon"] = 0.04
64 | cfg["exploration"]["timesteps"] = 1500
65 | # logging to TensorBoard and write checkpoints (in timesteps)
66 | cfg["experiment"]["write_interval"] = 1000
67 | cfg["experiment"]["checkpoint_interval"] = 5000
68 | cfg["experiment"]["directory"] = "runs/torch/ALE_Pong"
69 |
70 | agent = DQN(models=models,
71 | memory=memory,
72 | cfg=cfg,
73 | observation_space=env.observation_space,
74 | action_space=env.action_space,
75 | device=device)
76 |
77 |
78 | # configure and instantiate the RL trainer
79 | cfg_trainer = {"timesteps": 50000, "headless": True}
80 | trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])
81 |
82 | # start training
83 | trainer.train()
84 |
--------------------------------------------------------------------------------
/docs/source/examples/utils/tensorboard_file_iterator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | from skrl.utils import postprocessing
5 |
6 |
7 | labels = []
8 | rewards = []
9 |
10 | # load the Tensorboard files and iterate over them (tag: "Reward / Total reward (mean)")
11 | tensorboard_iterator = postprocessing.TensorboardFileIterator("runs/*/events.out.tfevents.*",
12 | tags=["Reward / Total reward (mean)"])
13 | for dirname, data in tensorboard_iterator:
14 | rewards.append(data["Reward / Total reward (mean)"])
15 | labels.append(dirname)
16 |
17 | # convert to numpy arrays and compute mean and std
18 | rewards = np.array(rewards)
19 | mean = np.mean(rewards[:,:,1], axis=0)
20 | std = np.std(rewards[:,:,1], axis=0)
21 |
22 | # create two subplots (one for each reward and one for the mean)
23 | fig, ax = plt.subplots(1, 2, figsize=(15, 5))
24 |
25 | # plot the rewards for each experiment
26 | for reward, label in zip(rewards, labels):
27 | ax[0].plot(reward[:,0], reward[:,1], label=label)
28 |
29 | ax[0].set_title("Total reward (for each experiment)")
30 | ax[0].set_xlabel("Timesteps")
31 | ax[0].set_ylabel("Reward")
32 | ax[0].grid(True)
33 | ax[0].legend()
34 |
35 | # plot the mean and std (across experiments)
36 | ax[1].fill_between(rewards[0,:,0], mean - std, mean + std, alpha=0.5, label="std")
37 | ax[1].plot(rewards[0,:,0], mean, label="mean")
38 |
39 | ax[1].set_title("Total reward (mean and std of all experiments)")
40 | ax[1].set_xlabel("Timesteps")
41 | ax[1].set_ylabel("Reward")
42 | ax[1].grid(True)
43 | ax[1].legend()
44 |
45 | # save and show the figure (save before show: with interactive backends,
46 | # plt.show() consumes the figure and a later savefig writes a blank image)
47 | plt.savefig("total_reward.png")
48 | plt.show()
48 |
--------------------------------------------------------------------------------
/docs/source/snippets/isaacgym_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from isaacgym import gymapi
3 |
4 | from skrl.utils import isaacgym_utils
5 |
6 |
7 | # create a web viewer instance
8 | web_viewer = isaacgym_utils.WebViewer()
9 |
10 | # configure and create simulation
11 | sim_params = gymapi.SimParams()
12 | sim_params.up_axis = gymapi.UP_AXIS_Z
13 | sim_params.gravity = gymapi.Vec3(0.0, 0.0, -9.8)
14 | sim_params.physx.solver_type = 1
15 | sim_params.physx.num_position_iterations = 4
16 | sim_params.physx.num_velocity_iterations = 1
17 | sim_params.physx.use_gpu = True
18 | sim_params.use_gpu_pipeline = True
19 |
20 | gym = gymapi.acquire_gym()
21 | sim = gym.create_sim(compute_device=0, graphics_device=0, type=gymapi.SIM_PHYSX, params=sim_params)
22 |
23 | # setup num_envs and env's grid
24 | num_envs = 1
25 | spacing = 2.0
26 | env_lower = gymapi.Vec3(-spacing, -spacing, 0.0)
27 | env_upper = gymapi.Vec3(spacing, 0.0, spacing)
28 |
29 | # add ground plane
30 | plane_params = gymapi.PlaneParams()
31 | plane_params.normal = gymapi.Vec3(0.0, 0.0, 1.0)
32 | gym.add_ground(sim, plane_params)
33 |
34 | envs = []
35 | cameras = []
36 |
37 | for i in range(num_envs):
38 | # create env
39 | env = gym.create_env(sim, env_lower, env_upper, int(math.sqrt(num_envs)))
40 |
41 | # add sphere
42 | pose = gymapi.Transform()
43 | pose.p, pose.r = gymapi.Vec3(0.0, 0.0, 1.0), gymapi.Quat(0.0, 0.0, 0.0, 1.0)
44 | gym.create_actor(env, gym.create_sphere(sim, 0.2, None), pose, "sphere", i, 0)
45 |
46 | # add camera
47 | cam_props = gymapi.CameraProperties()
48 | cam_props.width, cam_props.height = 300, 300
49 | cam_handle = gym.create_camera_sensor(env, cam_props)
50 | gym.set_camera_location(cam_handle, env, gymapi.Vec3(1, 1, 1), gymapi.Vec3(0, 0, 0))
51 |
52 | envs.append(env)
53 | cameras.append(cam_handle)
54 |
55 | # setup web viewer
56 | web_viewer.setup(gym, sim, envs, cameras)
57 |
58 | gym.prepare_sim(sim)
59 |
60 |
61 | for i in range(100000):
62 | gym.simulate(sim)
63 |
64 | # render the scene
65 | web_viewer.render(fetch_results=True,
66 | step_graphics=True,
67 | render_all_camera_sensors=True,
68 | wait_for_page_load=True)
69 |
--------------------------------------------------------------------------------
/docs/source/snippets/noises.py:
--------------------------------------------------------------------------------
1 | # [start-base-class-torch]
2 | from typing import Union, Tuple
3 |
4 | import torch
5 |
6 | from skrl.resources.noises.torch import Noise
7 |
8 |
9 | class CustomNoise(Noise):
10 | def __init__(self, device: Union[str, torch.device] = "cuda:0") -> None:
11 | """
12 | :param device: Device on which a torch tensor is or will be allocated (default: "cuda:0")
13 | :type device: str or torch.device, optional
14 | """
15 | super().__init__(device)
16 |
17 | def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
18 | """Sample noise
19 |
20 | :param size: Shape of the sampled tensor
21 | :type size: tuple or list of integers, or torch.Size
22 |
23 | :return: Sampled noise
24 | :rtype: torch.Tensor
25 | """
26 | # ================================
27 | # - sample noise
28 | # ================================
29 | # [end-base-class-torch]
30 |
31 |
32 | # [start-base-class-jax]
33 | from typing import Optional, Union, Tuple
34 |
35 | import numpy as np
36 |
37 | import jaxlib
38 | import jax.numpy as jnp
39 |
40 | from skrl.resources.noises.jax import Noise
41 |
42 |
43 | class CustomNoise(Noise):
44 | def __init__(self, device: Optional[Union[str, jaxlib.xla_extension.Device]] = None) -> None:
45 | """Custom noise
46 |
47 | :param device: Device on which a jax array is or will be allocated (default: ``None``).
48 | If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
49 | :type device: str or jaxlib.xla_extension.Device, optional
50 | """
51 | super().__init__(device)
52 |
53 | def sample(self, size: Tuple[int]) -> Union[np.ndarray, jnp.ndarray]:
54 | """Sample noise
55 |
56 | :param size: Shape of the sampled tensor
57 | :type size: tuple or list of integers
58 |
59 | :return: Sampled noise
60 | :rtype: np.ndarray or jnp.ndarray
61 | """
62 | # ================================
63 | # - sample noise
64 | # ================================
65 | # [end-base-class-jax]
66 |
67 | # =============================================================================
68 |
69 | # [torch-start-gaussian]
70 | from skrl.resources.noises.torch import GaussianNoise
71 |
72 | cfg = DEFAULT_CONFIG.copy()
73 | cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.2, device="cuda:0")
74 | # [torch-end-gaussian]
75 |
76 |
77 | # [jax-start-gaussian]
78 | from skrl.resources.noises.jax import GaussianNoise
79 |
80 | cfg = DEFAULT_CONFIG.copy()
81 | cfg["exploration"]["noise"] = GaussianNoise(mean=0, std=0.2)
82 | # [jax-end-gaussian]
83 |
84 | # =============================================================================
85 |
86 | # [torch-start-ornstein-uhlenbeck]
87 | from skrl.resources.noises.torch import OrnsteinUhlenbeckNoise
88 |
89 | cfg = DEFAULT_CONFIG.copy()
90 | cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.15, sigma=0.2, base_scale=1.0, device="cuda:0")
91 | # [torch-end-ornstein-uhlenbeck]
92 |
93 |
94 | # [jax-start-ornstein-uhlenbeck]
95 | from skrl.resources.noises.jax import OrnsteinUhlenbeckNoise
96 |
97 | cfg = DEFAULT_CONFIG.copy()
98 | cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.15, sigma=0.2, base_scale=1.0)
99 | # [jax-end-ornstein-uhlenbeck]
100 |
--------------------------------------------------------------------------------
/docs/source/snippets/tabular_model.py:
--------------------------------------------------------------------------------
1 | # [start-definition-torch]
2 | class TabularModel(TabularMixin, Model):
3 | def __init__(self, observation_space, action_space, device=None, num_envs=1):
4 | Model.__init__(self, observation_space, action_space, device)
5 | TabularMixin.__init__(self, num_envs)
6 | # [end-definition-torch]
7 |
8 | # =============================================================================
9 |
10 | # [start-epsilon-greedy-torch]
11 | import torch
12 |
13 | from skrl.models.torch import Model, TabularMixin
14 |
15 |
16 | # define the model
17 | class EpilonGreedyPolicy(TabularMixin, Model):
18 | def __init__(self, observation_space, action_space, device, num_envs=1, epsilon=0.1):
19 | Model.__init__(self, observation_space, action_space, device)
20 | TabularMixin.__init__(self, num_envs)
21 |
22 | self.epsilon = epsilon
23 | self.q_table = torch.ones((num_envs, self.num_observations, self.num_actions), dtype=torch.float32)
24 |
25 | def compute(self, inputs, role):
26 | states = inputs["states"]
27 | actions = torch.argmax(self.q_table[torch.arange(self.num_envs).view(-1, 1), states],
28 | dim=-1, keepdim=True).view(-1,1)
29 |
30 | indexes = (torch.rand(states.shape[0], device=self.device) < self.epsilon).nonzero().view(-1)
31 | if indexes.numel():
32 | actions[indexes] = torch.randint(self.num_actions, (indexes.numel(), 1), device=self.device)
33 | return actions, {}
34 |
35 |
36 | # instantiate the model (assumes there is a wrapped environment: env)
37 | policy = EpilonGreedyPolicy(observation_space=env.observation_space,
38 | action_space=env.action_space,
39 | device=env.device,
40 | num_envs=env.num_envs,
41 | epsilon=0.15)
42 | # [end-epsilon-greedy-torch]
43 |
--------------------------------------------------------------------------------
/docs/source/snippets/utils_distributed.txt:
--------------------------------------------------------------------------------
1 | [start-distributed-launcher-jax]
2 | usage: python -m skrl.utils.distributed.jax [-h] [--nnodes NNODES]
3 | [--nproc-per-node NPROC_PER_NODE] [--node-rank NODE_RANK]
4 | [--coordinator-address COORDINATOR_ADDRESS] script ...
5 |
6 | JAX Distributed Training Launcher
7 |
8 | positional arguments:
9 | script Training script path to be launched in parallel
10 | script_args Arguments for the training script
11 |
12 | options:
13 | -h, --help show this help message and exit
14 | --nnodes NNODES Number of nodes
15 | --nproc-per-node NPROC_PER_NODE, --nproc_per_node NPROC_PER_NODE
16 | Number of workers per node
17 | --node-rank NODE_RANK, --node_rank NODE_RANK
18 | Node rank for multi-node distributed training
19 | --coordinator-address COORDINATOR_ADDRESS, --coordinator_address COORDINATOR_ADDRESS
20 | IP address and port where process 0 will start a JAX service
21 | [end-distributed-launcher-jax]
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "skrl"
3 | version = "1.4.3"
4 | description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
5 | readme = "README.md"
6 | requires-python = ">=3.8"
7 | license = {text = "MIT License"}
8 | authors = [
9 | {name = "Toni-SM"},
10 | ]
11 | maintainers = [
12 | {name = "Toni-SM"},
13 | ]
14 | keywords = ["reinforcement-learning", "machine-learning", "reinforcement", "machine", "learning", "rl"]
15 | classifiers = [
16 | "License :: OSI Approved :: MIT License",
17 | "Intended Audience :: Science/Research",
18 | "Topic :: Scientific/Engineering",
19 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
20 | "Programming Language :: Python :: 3",
21 | "Operating System :: OS Independent",
22 | ]
23 | # dependencies / optional-dependencies
24 | dependencies = [
25 | "gymnasium",
26 | "packaging",
27 | "tensorboard",
28 | "tqdm",
29 | ]
30 | [project.optional-dependencies]
31 | torch = [
32 | "torch>=1.10",
33 | ]
34 | jax = [
35 | "jax>=0.4.31",
36 | "jaxlib>=0.4.31",
37 | "flax>=0.9.0",
38 | "optax",
39 | ]
40 | all = [
41 | "torch>=1.10",
42 | "jax>=0.4.31",
43 | "jaxlib>=0.4.31",
44 | "flax>=0.9.0",
45 | "optax",
46 | ]
47 | tests = [
48 | "pytest",
49 | "pytest-html",
50 | "pytest-cov",
51 | "hypothesis",
52 | ]
53 | # urls
54 | [project.urls]
55 | "Homepage" = "https://github.com/Toni-SM/skrl"
56 | "Documentation" = "https://skrl.readthedocs.io"
57 | "Discussions" = "https://github.com/Toni-SM/skrl/discussions"
58 | "Bug Reports" = "https://github.com/Toni-SM/skrl/issues"
59 | "Say Thanks!" = "https://github.com/Toni-SM"
60 | "Source" = "https://github.com/Toni-SM/skrl"
61 |
62 |
63 | [tool.black]
64 | line-length = 120
65 | extend-exclude = """
66 | (
67 | ^/docs
68 | )
69 | """
70 |
71 |
72 | [tool.codespell]
73 | # run: codespell
74 | skip = "./docs/source/_static,./docs/_build,pyproject.toml"
75 | quiet-level = 3
76 | count = true
77 |
78 |
79 | [tool.isort]
80 | profile = "black"
81 | line_length = 120
82 | lines_after_imports = 2
83 | known_test = [
84 | "warnings",
85 | "hypothesis",
86 | "pytest",
87 | ]
88 | known_annotation = ["typing"]
89 | known_framework = [
90 | "torch",
91 | "jax",
92 | "jaxlib",
93 | "flax",
94 | "optax",
95 | "numpy",
96 | ]
97 | sections = [
98 | "FUTURE",
99 | "ANNOTATION",
100 | "TEST",
101 | "STDLIB",
102 | "THIRDPARTY",
103 | "FRAMEWORK",
104 | "FIRSTPARTY",
105 | "LOCALFOLDER",
106 | ]
107 | no_lines_before = "THIRDPARTY"
108 | skip = ["docs"]
109 |
--------------------------------------------------------------------------------
/skrl/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/agents/__init__.py
--------------------------------------------------------------------------------
/skrl/agents/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.base import Agent
2 |
--------------------------------------------------------------------------------
/skrl/agents/jax/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.a2c.a2c import A2C, A2C_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/jax/cem/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.cem.cem import CEM, CEM_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/jax/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.ddpg.ddpg import DDPG, DDPG_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/jax/dqn/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.dqn.ddqn import DDQN, DDQN_DEFAULT_CONFIG
2 | from skrl.agents.jax.dqn.dqn import DQN, DQN_DEFAULT_CONFIG
3 |
--------------------------------------------------------------------------------
/skrl/agents/jax/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.ppo.ppo import PPO, PPO_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/jax/rpo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.rpo.rpo import RPO, RPO_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/jax/sac/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.sac.sac import SAC, SAC_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/jax/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.td3.td3 import TD3, TD3_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.base import Agent
2 |
--------------------------------------------------------------------------------
/skrl/agents/torch/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.a2c.a2c import A2C, A2C_DEFAULT_CONFIG
2 | from skrl.agents.torch.a2c.a2c_rnn import A2C_RNN
3 |
--------------------------------------------------------------------------------
/skrl/agents/torch/amp/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.amp.amp import AMP, AMP_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/torch/cem/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.cem.cem import CEM, CEM_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/torch/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.ddpg.ddpg import DDPG, DDPG_DEFAULT_CONFIG
2 | from skrl.agents.torch.ddpg.ddpg_rnn import DDPG_RNN
3 |
--------------------------------------------------------------------------------
/skrl/agents/torch/dqn/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.dqn.ddqn import DDQN, DDQN_DEFAULT_CONFIG
2 | from skrl.agents.torch.dqn.dqn import DQN, DQN_DEFAULT_CONFIG
3 |
--------------------------------------------------------------------------------
/skrl/agents/torch/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.ppo.ppo import PPO, PPO_DEFAULT_CONFIG
2 | from skrl.agents.torch.ppo.ppo_rnn import PPO_RNN
3 |
--------------------------------------------------------------------------------
/skrl/agents/torch/q_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.q_learning.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/torch/rpo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.rpo.rpo import RPO, RPO_DEFAULT_CONFIG
2 | from skrl.agents.torch.rpo.rpo_rnn import RPO_RNN
3 |
--------------------------------------------------------------------------------
/skrl/agents/torch/sac/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.sac.sac import SAC, SAC_DEFAULT_CONFIG
2 | from skrl.agents.torch.sac.sac_rnn import SAC_RNN
3 |
--------------------------------------------------------------------------------
/skrl/agents/torch/sarsa/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.sarsa.sarsa import SARSA, SARSA_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/agents/torch/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.td3.td3 import TD3, TD3_DEFAULT_CONFIG
2 | from skrl.agents.torch.td3.td3_rnn import TD3_RNN
3 |
--------------------------------------------------------------------------------
/skrl/agents/torch/trpo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.trpo.trpo import TRPO, TRPO_DEFAULT_CONFIG
2 | from skrl.agents.torch.trpo.trpo_rnn import TRPO_RNN
3 |
--------------------------------------------------------------------------------
/skrl/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/envs/__init__.py
--------------------------------------------------------------------------------
/skrl/envs/jax.py:
--------------------------------------------------------------------------------
1 | # TODO: Delete this file in future releases
2 |
3 | from skrl import logger # isort: skip
4 |
5 | logger.warning("Using `from skrl.envs.jax import ...` is deprecated and will be removed in future versions.")
6 | logger.warning(" - Import loaders using `from skrl.envs.loaders.jax import ...`")
7 | logger.warning(" - Import wrappers using `from skrl.envs.wrappers.jax import ...`")
8 |
9 |
10 | from skrl.envs.loaders.jax import (
11 | load_bidexhands_env,
12 | load_isaacgym_env_preview2,
13 | load_isaacgym_env_preview3,
14 | load_isaacgym_env_preview4,
15 | load_isaaclab_env,
16 | load_omniverse_isaacgym_env,
17 | )
18 | from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper, wrap_env
19 |
--------------------------------------------------------------------------------
/skrl/envs/loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/envs/loaders/__init__.py
--------------------------------------------------------------------------------
/skrl/envs/loaders/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.envs.loaders.jax.bidexhands_envs import load_bidexhands_env
2 | from skrl.envs.loaders.jax.isaacgym_envs import (
3 | load_isaacgym_env_preview2,
4 | load_isaacgym_env_preview3,
5 | load_isaacgym_env_preview4,
6 | )
7 | from skrl.envs.loaders.jax.isaaclab_envs import load_isaaclab_env
8 | from skrl.envs.loaders.jax.omniverse_isaacgym_envs import load_omniverse_isaacgym_env
9 |
--------------------------------------------------------------------------------
/skrl/envs/loaders/jax/bidexhands_envs.py:
--------------------------------------------------------------------------------
1 | # since Bi-DexHands environments are implemented on top of PyTorch, the loader is the same
2 |
3 | from skrl.envs.loaders.torch import load_bidexhands_env
4 |
--------------------------------------------------------------------------------
/skrl/envs/loaders/jax/isaacgym_envs.py:
--------------------------------------------------------------------------------
1 | # since Isaac Gym (preview) environments are implemented on top of PyTorch, the loaders are the same
2 |
3 | from skrl.envs.loaders.torch import ( # isort:skip
4 | load_isaacgym_env_preview2,
5 | load_isaacgym_env_preview3,
6 | load_isaacgym_env_preview4,
7 | )
8 |
--------------------------------------------------------------------------------
/skrl/envs/loaders/jax/isaaclab_envs.py:
--------------------------------------------------------------------------------
1 | # since Isaac Lab environments are implemented on top of PyTorch, the loader is the same
2 |
3 | from skrl.envs.loaders.torch import load_isaaclab_env
4 |
--------------------------------------------------------------------------------
/skrl/envs/loaders/jax/omniverse_isaacgym_envs.py:
--------------------------------------------------------------------------------
1 | # since Omniverse Isaac Gym environments are implemented on top of PyTorch, the loader is the same
2 |
3 | from skrl.envs.loaders.torch import load_omniverse_isaacgym_env
4 |
--------------------------------------------------------------------------------
/skrl/envs/loaders/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.envs.loaders.torch.bidexhands_envs import load_bidexhands_env
2 | from skrl.envs.loaders.torch.isaacgym_envs import (
3 | load_isaacgym_env_preview2,
4 | load_isaacgym_env_preview3,
5 | load_isaacgym_env_preview4,
6 | )
7 | from skrl.envs.loaders.torch.isaaclab_envs import load_isaaclab_env
8 | from skrl.envs.loaders.torch.omniverse_isaacgym_envs import load_omniverse_isaacgym_env
9 |
--------------------------------------------------------------------------------
/skrl/envs/torch.py:
--------------------------------------------------------------------------------
1 | # TODO: Delete this file in future releases
2 |
3 | from skrl import logger # isort: skip
4 |
5 | logger.warning("Using `from skrl.envs.torch import ...` is deprecated and will be removed in future versions.")
6 | logger.warning(" - Import loaders using `from skrl.envs.loaders.torch import ...`")
7 | logger.warning(" - Import wrappers using `from skrl.envs.wrappers.torch import ...`")
8 |
9 |
10 | from skrl.envs.loaders.torch import (
11 | load_bidexhands_env,
12 | load_isaacgym_env_preview2,
13 | load_isaacgym_env_preview3,
14 | load_isaacgym_env_preview4,
15 | load_isaaclab_env,
16 | load_omniverse_isaacgym_env,
17 | )
18 | from skrl.envs.wrappers.torch import MultiAgentEnvWrapper, Wrapper, wrap_env
19 |
--------------------------------------------------------------------------------
/skrl/envs/wrappers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/envs/wrappers/__init__.py
--------------------------------------------------------------------------------
/skrl/envs/wrappers/torch/brax_envs.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple
2 |
3 | import gymnasium
4 |
5 | import torch
6 |
7 | from skrl import logger
8 | from skrl.envs.wrappers.torch.base import Wrapper
9 | from skrl.utils.spaces.torch import (
10 | convert_gym_space,
11 | flatten_tensorized_space,
12 | tensorize_space,
13 | unflatten_tensorized_space,
14 | )
15 |
16 |
17 | class BraxWrapper(Wrapper):
18 | def __init__(self, env: Any) -> None:
19 | """Brax environment wrapper
20 |
21 | :param env: The environment to wrap
22 | :type env: Any supported Brax environment
23 | """
24 | super().__init__(env)
25 |
26 | import brax.envs.wrappers.gym
27 | import brax.envs.wrappers.torch
28 |
29 | env = brax.envs.wrappers.gym.VectorGymWrapper(env)
30 | env = brax.envs.wrappers.torch.TorchWrapper(env, device=self.device)
31 | self._env = env
32 | self._unwrapped = env.unwrapped
33 |
34 | @property
35 | def observation_space(self) -> gymnasium.Space:
36 | """Observation space"""
37 | return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True)
38 |
39 | @property
40 | def action_space(self) -> gymnasium.Space:
41 | """Action space"""
42 | return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True)
43 |
44 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
45 | """Perform a step in the environment
46 |
47 | :param actions: The actions to perform
48 | :type actions: torch.Tensor
49 |
50 | :return: Observation, reward, terminated, truncated, info
51 | :rtype: tuple of torch.Tensor and any other info
52 | """
53 | observation, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions))
54 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation))
55 | truncated = torch.zeros_like(terminated)
56 | return observation, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info
57 |
58 | def reset(self) -> Tuple[torch.Tensor, Any]:
59 | """Reset the environment
60 |
61 | :return: Observation, info
62 | :rtype: torch.Tensor and any other info
63 | """
64 | observation = self._env.reset()
65 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation))
66 | return observation, {}
67 |
68 | def render(self, *args, **kwargs) -> None:
69 | """Render the environment"""
70 | frame = self._env.render(mode="rgb_array")
71 |
72 | # render the frame using OpenCV
73 | try:
74 | import cv2
75 |
76 | cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
77 | cv2.waitKey(1)
78 | except ImportError as e:
79 | logger.warning(f"Unable to import opencv-python: {e}. Frame will not be rendered.")
80 | return frame
81 |
82 | def close(self) -> None:
83 | """Close the environment"""
84 | # self._env.close() raises AttributeError: 'VectorGymWrapper' object has no attribute 'closed'
85 | pass
86 |
--------------------------------------------------------------------------------
/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional, Tuple
2 |
3 | import torch
4 |
5 | from skrl.envs.wrappers.torch.base import Wrapper
6 | from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space
7 |
8 |
9 | class OmniverseIsaacGymWrapper(Wrapper):
10 |     def __init__(self, env: Any) -> None:
11 |         """Omniverse Isaac Gym environment wrapper
12 |
13 |         :param env: The environment to wrap
14 |         :type env: Any supported Omniverse Isaac Gym environment
15 |         """
16 |         super().__init__(env)
17 |
18 |         self._reset_once = True  # the underlying environment is only reset on the first reset() call
19 |         self._observations = None  # cached flattened observations from the latest step/reset
20 |         self._info = {}  # cached info dict from the latest step
21 |
22 |     def run(self, trainer: Optional["omni.isaac.gym.vec_env.vec_env_mt.TrainerMT"] = None) -> None:
23 |         """Run the simulation in the main thread
24 |
25 |         This method is valid only for the Omniverse Isaac Gym multi-threaded environments
26 |
27 |         :param trainer: Trainer which should implement a ``run`` method that initiates the RL loop on a new thread
28 |         :type trainer: omni.isaac.gym.vec_env.vec_env_mt.TrainerMT, optional
29 |         """
30 |         self._env.run(trainer)
31 |
32 |     def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
33 |         """Perform a step in the environment
34 |
35 |         :param actions: The actions to perform
36 |         :type actions: torch.Tensor
37 |
38 |         :return: Observation, reward, terminated, truncated, info
39 |         :rtype: tuple of torch.Tensor and any other info
40 |         """
41 |         observations, reward, terminated, self._info = self._env.step(
42 |             unflatten_tensorized_space(self.action_space, actions)
43 |         )
44 |         self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"]))
45 |         truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated)  # env time-outs
46 |         return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info
47 |
48 |     def reset(self) -> Tuple[torch.Tensor, Any]:
49 |         """Reset the environment
50 |
51 |         :return: Observation, info
52 |         :rtype: torch.Tensor and any other info
53 |         """
54 |         if self._reset_once:  # only the first call resets the underlying environment
55 |             observations = self._env.reset()
56 |             self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"]))
57 |             self._reset_once = False
58 |         return self._observations, self._info
59 |
60 |     def render(self, *args, **kwargs) -> None:
61 |         """Render the environment"""
62 |         return None
63 |
64 |     def close(self) -> None:
65 |         """Close the environment"""
66 |         self._env.close()
67 |
--------------------------------------------------------------------------------
/skrl/memories/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/memories/__init__.py
--------------------------------------------------------------------------------
/skrl/memories/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.memories.jax.base import Memory # isort:skip
2 |
3 | from skrl.memories.jax.random import RandomMemory
4 |
--------------------------------------------------------------------------------
/skrl/memories/jax/random.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import jax
4 | import numpy as np
5 |
6 | from skrl.memories.jax import Memory
7 |
8 |
9 | class RandomMemory(Memory):
10 |     def __init__(
11 |         self,
12 |         memory_size: int,
13 |         num_envs: int = 1,
14 |         device: Optional[jax.Device] = None,
15 |         export: bool = False,
16 |         export_format: str = "pt",
17 |         export_directory: str = "",
18 |         replacement: bool = True,
19 |     ) -> None:
20 |         """Random sampling memory
21 |
22 |         Sample a batch from memory randomly
23 |
24 |         :param memory_size: Maximum number of elements in the first dimension of each internal storage
25 |         :type memory_size: int
26 |         :param num_envs: Number of parallel environments (default: ``1``)
27 |         :type num_envs: int, optional
28 |         :param device: Device on which an array is or will be allocated (default: ``None``)
29 |         :type device: jax.Device, optional
30 |         :param export: Export the memory to a file (default: ``False``).
31 |                        If True, the memory will be exported when the memory is filled
32 |         :type export: bool, optional
33 |         :param export_format: Export format (default: ``"pt"``).
34 |                               Supported formats: torch (pt), numpy (np), comma separated values (csv)
35 |         :type export_format: str, optional
36 |         :param export_directory: Directory where the memory will be exported (default: ``""``).
37 |                                  If empty, the agent's experiment directory will be used
38 |         :type export_directory: str, optional
39 |         :param replacement: Flag to indicate whether the sample is with or without replacement (default: ``True``).
40 |                             Replacement implies that a value can be selected multiple times (the batch size is always guaranteed).
41 |                             Sampling without replacement will return a batch of maximum memory size if the memory size is less than the requested batch size
42 |         :type replacement: bool, optional
43 |
44 |         :raises ValueError: The export format is not supported
45 |         """
46 |         super().__init__(memory_size, num_envs, device, export, export_format, export_directory)
47 |
48 |         self._replacement = replacement  # whether sampled indexes may repeat within a batch
49 |
50 |     def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> List[List[jax.Array]]:
51 |         """Sample a batch from memory randomly
52 |
53 |         :param names: Tensors names from which to obtain the samples
54 |         :type names: tuple or list of strings
55 |         :param batch_size: Number of element to sample
56 |         :type batch_size: int
57 |         :param mini_batches: Number of mini-batches to sample (default: ``1``)
58 |         :type mini_batches: int, optional
59 |
60 |         :return: Sampled data from tensors sorted according to their position in the list of names.
61 |                  The sampled tensors will have the following shape: (batch size, data size)
62 |         :rtype: list of jax.Array list
63 |         """
64 |         # generate random indexes (uses NumPy's global RNG, not a JAX PRNG key)
65 |         if self._replacement:
66 |             indexes = np.random.randint(0, len(self), (batch_size,))  # duplicates possible; batch size always met
67 |         else:
68 |             indexes = np.random.permutation(len(self))[:batch_size]  # unique indexes; capped at memory length
69 |
70 |         return self.sample_by_index(names=names, indexes=indexes, mini_batches=mini_batches)
71 |
--------------------------------------------------------------------------------
/skrl/memories/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.memories.torch.base import Memory # isort:skip
2 |
3 | from skrl.memories.torch.random import RandomMemory
4 |
--------------------------------------------------------------------------------
/skrl/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/models/__init__.py
--------------------------------------------------------------------------------
/skrl/models/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.models.jax.base import Model # isort:skip
2 |
3 | from skrl.models.jax.categorical import CategoricalMixin
4 | from skrl.models.jax.deterministic import DeterministicMixin
5 | from skrl.models.jax.gaussian import GaussianMixin
6 | from skrl.models.jax.multicategorical import MultiCategoricalMixin
7 |
--------------------------------------------------------------------------------
/skrl/models/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.models.torch.base import Model # isort:skip
2 |
3 | from skrl.models.torch.categorical import CategoricalMixin
4 | from skrl.models.torch.deterministic import DeterministicMixin
5 | from skrl.models.torch.gaussian import GaussianMixin
6 | from skrl.models.torch.multicategorical import MultiCategoricalMixin
7 | from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin
8 | from skrl.models.torch.tabular import TabularMixin
9 |
--------------------------------------------------------------------------------
/skrl/multi_agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/multi_agents/__init__.py
--------------------------------------------------------------------------------
/skrl/multi_agents/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.jax.base import MultiAgent
2 |
--------------------------------------------------------------------------------
/skrl/multi_agents/jax/ippo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.jax.ippo.ippo import IPPO, IPPO_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/multi_agents/jax/mappo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.jax.mappo.mappo import MAPPO, MAPPO_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/multi_agents/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.torch.base import MultiAgent
2 |
--------------------------------------------------------------------------------
/skrl/multi_agents/torch/ippo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.torch.ippo.ippo import IPPO, IPPO_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/multi_agents/torch/mappo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.torch.mappo.mappo import MAPPO, MAPPO_DEFAULT_CONFIG
2 |
--------------------------------------------------------------------------------
/skrl/resources/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/resources/__init__.py
--------------------------------------------------------------------------------
/skrl/resources/noises/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/resources/noises/__init__.py
--------------------------------------------------------------------------------
/skrl/resources/noises/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.noises.jax.base import Noise # isort:skip
2 |
3 | from skrl.resources.noises.jax.gaussian import GaussianNoise
4 | from skrl.resources.noises.jax.ornstein_uhlenbeck import OrnsteinUhlenbeckNoise
5 |
--------------------------------------------------------------------------------
/skrl/resources/noises/jax/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import jax
4 | import numpy as np
5 |
6 | from skrl import config
7 |
8 |
9 | class Noise:
10 |     def __init__(self, device: Optional[Union[str, jax.Device]] = None) -> None:
11 |         """Base class representing a noise
12 |
13 |         :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
14 |                        If None, the device will be either ``"cuda"`` if available or ``"cpu"``
15 |         :type device: str or jax.Device, optional
16 |
17 |         Custom noises should override the ``sample`` method::
18 |
19 |             import jax
20 |             from skrl.resources.noises.jax import Noise
21 |
22 |             class CustomNoise(Noise):
23 |                 def __init__(self, device=None):
24 |                     super().__init__(device)
25 |
26 |                 def sample(self, size):
27 |                     return jax.random.uniform(jax.random.PRNGKey(0), size)
28 |         """
29 |         self._jax = config.jax.backend == "jax"  # True when the configured backend is JAX (otherwise NumPy)
30 |
31 |         self.device = config.jax.parse_device(device)  # resolve device via skrl's global JAX config
32 |
33 |     def sample_like(self, tensor: Union[np.ndarray, jax.Array]) -> Union[np.ndarray, jax.Array]:
34 |         """Sample a noise with the same size (shape) as the input tensor
35 |
36 |         This method will call the sampling method as follows ``.sample(tensor.shape)``
37 |
38 |         :param tensor: Input tensor used to determine output tensor size (shape)
39 |         :type tensor: np.ndarray or jax.Array
40 |
41 |         :return: Sampled noise
42 |         :rtype: np.ndarray or jax.Array
43 |
44 |         Example::
45 |
46 |             >>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
47 |             >>> noise.sample_like(x)
48 |             Array([[0.57450044, 0.09968603],
49 |                    [0.7419659 , 0.8941783 ],
50 |                    [0.59656656, 0.45325184]], dtype=float32)
51 |         """
52 |         return self.sample(tensor.shape)
53 |
54 |     def sample(self, size: Tuple[int]) -> Union[np.ndarray, jax.Array]:
55 |         """Noise sampling method to be implemented by the inheriting classes
56 |
57 |         :param size: Shape of the sampled tensor
58 |         :type size: tuple or list of int
59 |
60 |         :raises NotImplementedError: The method is not implemented by the inheriting classes
61 |
62 |         :return: Sampled noise
63 |         :rtype: np.ndarray or jax.Array
64 |         """
65 |         raise NotImplementedError("The sampling method (.sample()) is not implemented")
66 |
--------------------------------------------------------------------------------
/skrl/resources/noises/jax/gaussian.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | from functools import partial
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 |
9 | from skrl import config
10 | from skrl.resources.noises.jax import Noise
11 |
12 |
13 | # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function
14 | @partial(jax.jit, static_argnames=("shape"))
15 | def _sample(mean, std, key, iterator, shape):
16 | subkey = jax.random.fold_in(key, iterator)
17 | return jax.random.normal(subkey, shape) * std + mean
18 |
19 |
class GaussianNoise(Noise):
    def __init__(self, mean: float, std: float, device: Optional[Union[str, jax.Device]] = None) -> None:
        """Class representing a Gaussian noise

        :param mean: Mean of the normal distribution
        :type mean: float
        :param std: Standard deviation of the normal distribution
        :type std: float
        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda"`` if available or ``"cpu"``
        :type device: str or jax.Device, optional

        Example::

            >>> noise = GaussianNoise(mean=0, std=1)
        """
        super().__init__(device)

        # store distribution parameters using the active backend's array type
        if not self._jax:
            self.mean = np.array(mean)
            self.std = np.array(std)
        else:
            # per-call counter and base PRNG key for reproducible JAX sampling
            self._i = 0
            self._key = config.jax.key
            self.mean = jnp.array(mean)
            self.std = jnp.array(std)

    def sample(self, size: Tuple[int]) -> Union[np.ndarray, jax.Array]:
        """Sample a Gaussian noise

        :param size: Shape of the tensor to be sampled
        :type size: tuple or list of int

        :return: Sampled noise
        :rtype: np.ndarray or jax.Array

        Example::

            >>> noise.sample((3, 2))
            Array([[ 0.01878439, -0.12833427],
                   [ 0.06494182,  0.12490594],
                   [ 0.024447  , -0.01174496]], dtype=float32)

            >>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
            >>> noise.sample(x.shape)
            Array([[ 0.17988093, -1.2289404 ],
                   [ 0.6218886 ,  1.1961104 ],
                   [ 0.23410667, -0.11247082]], dtype=float32)
        """
        # NumPy backend: delegate to the global NumPy RNG
        if not self._jax:
            return np.random.normal(self.mean, self.std, size)
        # JAX backend: advance the counter so each call uses a distinct subkey
        self._i += 1
        return _sample(self.mean, self.std, self._key, self._i, size)
73 |
--------------------------------------------------------------------------------
/skrl/resources/noises/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.noises.torch.base import Noise # isort:skip
2 |
3 | from skrl.resources.noises.torch.gaussian import GaussianNoise
4 | from skrl.resources.noises.torch.ornstein_uhlenbeck import OrnsteinUhlenbeckNoise
5 |
--------------------------------------------------------------------------------
/skrl/resources/noises/torch/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 |
5 | from skrl import config
6 |
7 |
class Noise:
    def __init__(self, device: Optional[Union[str, torch.device]] = None) -> None:
        """Base class representing a noise

        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda"`` if available or ``"cpu"``
        :type device: str or torch.device, optional

        Custom noises should override the ``sample`` method::

            import torch
            from skrl.resources.noises.torch import Noise

            class CustomNoise(Noise):
                def __init__(self, device=None):
                    super().__init__(device)

                def sample(self, size):
                    return torch.rand(size, device=self.device)
        """
        # resolve the target device through skrl's global configuration
        self.device = config.torch.parse_device(device)

    def sample_like(self, tensor: torch.Tensor) -> torch.Tensor:
        """Sample a noise with the same shape as the given tensor

        Equivalent to calling the sampling method as ``.sample(tensor.shape)``.

        :param tensor: Tensor whose shape determines the shape of the sampled noise
        :type tensor: torch.Tensor

        :return: Sampled noise
        :rtype: torch.Tensor

        Example::

            >>> x = torch.rand(3, 2, device="cuda:0")
            >>> noise.sample_like(x)
            tensor([[-0.0423, -0.1325],
                    [-0.0639, -0.0957],
                    [-0.1367,  0.1031]], device='cuda:0')
        """
        shape = tensor.shape
        return self.sample(shape)

    def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
        """Noise sampling method

        Subclasses must override this method to implement the actual sampling.

        :param size: Shape of the tensor to be sampled
        :type size: tuple or list of int, or torch.Size

        :raises NotImplementedError: The method is not implemented by the inheriting classes

        :return: Sampled noise
        :rtype: torch.Tensor
        """
        raise NotImplementedError("The sampling method (.sample()) is not implemented")
63 |
--------------------------------------------------------------------------------
/skrl/resources/noises/torch/gaussian.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | from torch.distributions import Normal
5 |
6 | from skrl.resources.noises.torch import Noise
7 |
8 |
# globally disable argument validation when constructing Normal distributions,
# to speed up distribution creation during sampling
Normal.set_default_validate_args(False)
11 |
12 |
class GaussianNoise(Noise):
    def __init__(self, mean: float, std: float, device: Optional[Union[str, torch.device]] = None) -> None:
        """Class representing a Gaussian noise

        :param mean: Mean of the normal distribution
        :type mean: float
        :param std: Standard deviation of the normal distribution
        :type std: float
        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda"`` if available or ``"cpu"``
        :type device: str or torch.device, optional

        Example::

            >>> noise = GaussianNoise(mean=0, std=1)
        """
        super().__init__(device)

        # build the underlying normal distribution on the resolved device
        loc = torch.tensor(mean, device=self.device, dtype=torch.float32)
        scale = torch.tensor(std, device=self.device, dtype=torch.float32)
        self.distribution = Normal(loc=loc, scale=scale)

    def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
        """Sample a Gaussian noise

        :param size: Shape of the tensor to be sampled
        :type size: tuple or list of int, or torch.Size

        :return: Sampled noise
        :rtype: torch.Tensor

        Example::

            >>> noise.sample((3, 2))
            tensor([[-0.4901,  1.3357],
                    [-1.2141,  0.3323],
                    [-0.0889, -1.1651]], device='cuda:0')

            >>> x = torch.rand(3, 2, device="cuda:0")
            >>> noise.sample(x.shape)
            tensor([[0.5398, 1.2009],
                    [0.0307, 1.3065],
                    [0.2082, 0.6116]], device='cuda:0')
        """
        # delegate directly to the underlying distribution
        return self.distribution.sample(size)
59 |
--------------------------------------------------------------------------------
/skrl/resources/noises/torch/ornstein_uhlenbeck.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | from torch.distributions import Normal
5 |
6 | from skrl.resources.noises.torch import Noise
7 |
8 |
# globally disable argument validation when constructing Normal distributions,
# to speed up distribution creation during sampling
Normal.set_default_validate_args(False)
11 |
12 |
class OrnsteinUhlenbeckNoise(Noise):
    def __init__(
        self,
        theta: float,
        sigma: float,
        base_scale: float,
        mean: float = 0,
        std: float = 1,
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        """Class representing an Ornstein-Uhlenbeck noise

        :param theta: Factor to apply to current internal state
        :type theta: float
        :param sigma: Factor to apply to the normal distribution
        :type sigma: float
        :param base_scale: Factor to apply to returned noise
        :type base_scale: float
        :param mean: Mean of the normal distribution (default: ``0.0``)
        :type mean: float, optional
        :param std: Standard deviation of the normal distribution (default: ``1.0``)
        :type std: float, optional
        :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda"`` if available or ``"cpu"``
        :type device: str or torch.device, optional

        Example::

            >>> noise = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=0.5)
        """
        super().__init__(device)

        # internal OU state: starts as the scalar 0 and is promoted to a tensor of the
        # sampled shape by the first call to .sample()
        self.state = 0
        self.theta = theta
        self.sigma = sigma
        self.base_scale = base_scale

        # Gaussian distribution driving the stochastic term of the OU process
        self.distribution = Normal(
            loc=torch.tensor(mean, device=self.device, dtype=torch.float32),
            scale=torch.tensor(std, device=self.device, dtype=torch.float32),
        )

    def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
        """Sample an Ornstein-Uhlenbeck noise

        :param size: Shape of the sampled tensor
        :type size: tuple or list of int, or torch.Size

        :return: Sampled noise
        :rtype: torch.Tensor

        Example::

            >>> noise.sample((3, 2))
            tensor([[-0.0452,  0.0162],
                    [ 0.0649, -0.0708],
                    [-0.0211,  0.0066]], device='cuda:0')

            >>> x = torch.rand(3, 2, device="cuda:0")
            >>> noise.sample(x.shape)
            tensor([[-0.0540,  0.0461],
                    [ 0.1117, -0.1157],
                    [-0.0074,  0.0420]], device='cuda:0')
        """
        # reset the internal state when the requested shape changes
        # (hasattr is False before the first sample, while state is still the scalar 0)
        if hasattr(self.state, "shape") and self.state.shape != torch.Size(size):
            self.state = 0
        # OU update: x <- x - theta * x + sigma * N(mean, std)
        # NOTE: += is an in-place tensor update once state has been promoted to a tensor
        self.state += -self.state * self.theta + self.sigma * self.distribution.sample(size)

        return self.base_scale * self.state
82 |
--------------------------------------------------------------------------------
/skrl/resources/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/resources/optimizers/__init__.py
--------------------------------------------------------------------------------
/skrl/resources/optimizers/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.optimizers.jax.adam import Adam
2 |
--------------------------------------------------------------------------------
/skrl/resources/preprocessors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/resources/preprocessors/__init__.py
--------------------------------------------------------------------------------
/skrl/resources/preprocessors/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.preprocessors.jax.running_standard_scaler import RunningStandardScaler
2 |
--------------------------------------------------------------------------------
/skrl/resources/preprocessors/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.preprocessors.torch.running_standard_scaler import RunningStandardScaler
2 |
--------------------------------------------------------------------------------
/skrl/resources/schedulers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/resources/schedulers/__init__.py
--------------------------------------------------------------------------------
/skrl/resources/schedulers/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.schedulers.jax.kl_adaptive import KLAdaptiveLR, kl_adaptive
2 |
3 |
4 | KLAdaptiveRL = KLAdaptiveLR # known typo (compatibility with versions prior to 1.0.0)
5 |
--------------------------------------------------------------------------------
/skrl/resources/schedulers/jax/kl_adaptive.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import numpy as np
4 | import optax
5 |
6 |
def KLAdaptiveLR(
    kl_threshold: float = 0.008,
    min_lr: float = 1e-6,
    max_lr: float = 1e-2,
    kl_factor: float = 2,
    lr_factor: float = 1.5,
) -> optax.Schedule:
    """Adaptive KL scheduler

    Adjusts the learning rate according to the KL divergence.
    The implementation is adapted from the *rl_games* library.

    .. note::

        This scheduler is only available for PPO at the moment.
        Applying it to other agents will not change the learning rate

    Example::

        >>> scheduler = KLAdaptiveLR(kl_threshold=0.01)
        >>> for epoch in range(100):
        >>>     # ...
        >>>     kl_divergence = ...
        >>>     new_lr = scheduler(timestep, lr, kl_divergence)

    :param kl_threshold: Threshold for KL divergence (default: ``0.008``)
    :type kl_threshold: float, optional
    :param min_lr: Lower bound for learning rate (default: ``1e-6``)
    :type min_lr: float, optional
    :param max_lr: Upper bound for learning rate (default: ``1e-2``)
    :type max_lr: float, optional
    :param kl_factor: The number used to modify the KL divergence threshold (default: ``2``)
    :type kl_factor: float, optional
    :param lr_factor: The number used to modify the learning rate (default: ``1.5``)
    :type lr_factor: float, optional

    :return: A function that maps step counts, current learning rate and KL divergence to the new learning rate value.
        If no learning rate is specified, 1.0 will be returned to mimic the Optax's scheduler behaviors.
        If the learning rate is specified and the KL divergence is ``None`` (or lies within the threshold band),
        the learning rate is returned unchanged.
    :rtype: optax.Schedule
    """

    # ``count`` is accepted for Optax schedule-signature compatibility but is not used
    def schedule(count: int, lr: Optional[float] = None, kl: Optional[Union[np.ndarray, float]] = None) -> float:
        # mimic Optax schedules when no learning rate is given
        if lr is None:
            return 1.0
        if kl is not None:
            # KL too large -> decrease the learning rate (bounded below by min_lr)
            if kl > kl_threshold * kl_factor:
                lr = max(lr / lr_factor, min_lr)
            # KL too small -> increase the learning rate (bounded above by max_lr)
            elif kl < kl_threshold / kl_factor:
                lr = min(lr * lr_factor, max_lr)
        return lr

    return schedule
60 |
61 |
62 | # Alias to maintain naming compatibility with Optax schedulers
63 | # https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html
64 | kl_adaptive = KLAdaptiveLR
65 |
--------------------------------------------------------------------------------
/skrl/resources/schedulers/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.schedulers.torch.kl_adaptive import KLAdaptiveLR
2 |
3 |
4 | KLAdaptiveRL = KLAdaptiveLR # known typo (compatibility with versions prior to 1.0.0)
5 |
--------------------------------------------------------------------------------
/skrl/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/trainers/__init__.py
--------------------------------------------------------------------------------
/skrl/trainers/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.trainers.jax.base import Trainer, generate_equally_spaced_scopes # isort:skip
2 |
3 | from skrl.trainers.jax.sequential import SequentialTrainer
4 | from skrl.trainers.jax.step import StepTrainer
5 |
--------------------------------------------------------------------------------
/skrl/trainers/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.trainers.torch.base import Trainer, generate_equally_spaced_scopes # isort:skip
2 |
3 | from skrl.trainers.torch.parallel import ParallelTrainer
4 | from skrl.trainers.torch.sequential import SequentialTrainer
5 | from skrl.trainers.torch.step import StepTrainer
6 |
--------------------------------------------------------------------------------
/skrl/utils/control.py:
--------------------------------------------------------------------------------
1 | import isaacgym.torch_utils as torch_utils
2 |
3 | import torch
4 |
5 |
def ik(
    jacobian_end_effector, current_position, current_orientation, goal_position, goal_orientation, damping_factor=0.05
):
    """Differential inverse kinematics using the Damped Least Squares method

    https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf

    :param jacobian_end_effector: Batched end-effector Jacobian
        (assumes shape (N, 6, num_dofs) — the damping term is a 6x6 identity)
    :param current_position: Current end-effector positions, batched
    :param current_orientation: Current end-effector orientations as quaternions
        (xyzw layout per isaacgym's torch_utils)
    :param goal_position: Goal end-effector positions, batched
    :param goal_orientation: Goal end-effector orientations as quaternions
    :param damping_factor: Damping factor for the least-squares solution (default: ``0.05``)

    :return: DOF position deltas that drive the end effector toward the goal pose
    """

    # compute position and orientation error
    position_error = goal_position - current_position
    q_r = torch_utils.quat_mul(goal_orientation, torch_utils.quat_conjugate(current_orientation))
    # vector part of the error quaternion, sign-corrected by the scalar part (index 3 in xyzw layout)
    orientation_error = q_r[:, 0:3] * torch.sign(q_r[:, 3]).unsqueeze(-1)

    dpose = torch.cat([position_error, orientation_error], -1).unsqueeze(-1)

    # solve damped least squares (dO = J.T * V)
    transpose = torch.transpose(jacobian_end_effector, 1, 2)
    # create the damping term directly on the Jacobian's device/dtype
    # (avoids a CPU allocation followed by a device copy)
    lmbda = torch.eye(6, device=jacobian_end_effector.device, dtype=jacobian_end_effector.dtype) * (
        damping_factor**2
    )
    # solve the linear system instead of forming the explicit inverse
    # (numerically more stable and faster than torch.inverse(...) @ dpose)
    return transpose @ torch.linalg.solve(jacobian_end_effector @ transpose + lmbda, dpose)
24 |
25 |
def osc(
    jacobian_end_effector,
    mass_matrix,
    current_position,
    current_orientation,
    goal_position,
    goal_orientation,
    current_dof_velocities,
    kp=5,
    kv=2,
):
    """Operational Space Control

    https://studywolf.wordpress.com/2013/09/17/robot-control-4-operation-space-control/

    :param jacobian_end_effector: Batched end-effector Jacobian
    :param mass_matrix: Batched joint-space mass (inertia) matrix
    :param current_position: Current end-effector positions, batched
    :param current_orientation: Current end-effector orientations as quaternions
        (xyzw layout per isaacgym's torch_utils)
    :param goal_position: Goal end-effector positions, batched
    :param goal_orientation: Goal end-effector orientations as quaternions
    :param current_dof_velocities: Current DOF velocities, batched
    :param kp: Proportional (stiffness) gain (default: ``5``)
    :param kv: Derivative (damping) gain (default: ``2``)

    :return: Joint torques implementing the operational space controller
    """

    # hoist the repeated Jacobian transpose
    jacobian_transpose = torch.transpose(jacobian_end_effector, 1, 2)

    # task-space (end-effector) inertia matrix: inv(J @ inv(M) @ J^T)
    mass_matrix_end_effector = torch.linalg.inv(
        jacobian_end_effector @ torch.linalg.inv(mass_matrix) @ jacobian_transpose
    )

    # compute position and orientation error
    # NOTE(review): kp is applied here AND again below in (kp * dpose), so the position error is
    # effectively scaled by kp**2 while the orientation error is scaled by kp — confirm intended
    position_error = kp * (goal_position - current_position)
    q_r = torch_utils.quat_mul(goal_orientation, torch_utils.quat_conjugate(current_orientation))
    # vector part of the error quaternion, sign-corrected by the scalar part (index 3 in xyzw layout)
    orientation_error = q_r[:, 0:3] * torch.sign(q_r[:, 3]).unsqueeze(-1)

    dpose = torch.cat([position_error, orientation_error], -1)

    return (
        jacobian_transpose @ mass_matrix_end_effector @ (kp * dpose).unsqueeze(-1)
        - kv * mass_matrix @ current_dof_velocities
    )
56 |
--------------------------------------------------------------------------------
/skrl/utils/distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/utils/distributed/__init__.py
--------------------------------------------------------------------------------
/skrl/utils/distributed/jax/__main__.py:
--------------------------------------------------------------------------------
1 | from . import launcher
2 |
3 |
4 | if __name__ == "__main__":
5 | launcher.launch()
6 |
--------------------------------------------------------------------------------
/skrl/utils/huggingface.py:
--------------------------------------------------------------------------------
1 | from skrl import __version__, logger
2 |
3 |
def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> str:
    """Download a model from Hugging Face Hub

    :param repo_id: Hugging Face user or organization name and a repo name separated by a ``/``
    :type repo_id: str
    :param filename: The name of the model file in the repo (default: ``"agent.pt"``)
    :type filename: str, optional

    :raises ImportError: The Hugging Face Hub package (huggingface-hub) is not installed
    :raises huggingface_hub.utils._errors.HfHubHTTPError: Any HTTP error raised in Hugging Face Hub

    :return: Local path of file or if networking is off, last version of file cached on disk
    :rtype: str

    Example::

        # download trained agent from the skrl organization (https://huggingface.co/skrl)
        >>> from skrl.utils.huggingface import download_model_from_huggingface
        >>> download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
        '/home/user/.cache/huggingface/hub/models--skrl--OmniIsaacGymEnvs-Cartpole-PPO/snapshots/892e629903de6bf3ef102ae760406a5dd0f6f873/agent.pt'

        # download model (e.g. "policy.pth") from another user/organization (e.g. "org/ddpg-Pendulum-v1")
        >>> from skrl.utils.huggingface import download_model_from_huggingface
        >>> download_model_from_huggingface("org/ddpg-Pendulum-v1", "policy.pth")
        '/home/user/.cache/huggingface/hub/models--org--ddpg-Pendulum-v1/snapshots/b44ee96f93ff2e296156b002a2ca4646e197ba32/policy.pth'
    """
    # log the actual file being requested (was previously a literal "(unknown)")
    logger.info(f"Downloading model from Hugging Face Hub: {repo_id}/{filename}")
    try:
        import huggingface_hub
    except ImportError as e:
        # raise immediately with exception chaining instead of using a None sentinel
        logger.error("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
        raise ImportError(
            "Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it"
        ) from e

    # download and cache the model from Hugging Face Hub
    downloaded_model_file = huggingface_hub.hf_hub_download(
        repo_id=repo_id, filename=filename, library_name="skrl", library_version=__version__
    )

    return downloaded_model_file
46 |
--------------------------------------------------------------------------------
/skrl/utils/model_instantiators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/utils/model_instantiators/__init__.py
--------------------------------------------------------------------------------
/skrl/utils/model_instantiators/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from skrl.utils.model_instantiators.jax.categorical import categorical_model
4 | from skrl.utils.model_instantiators.jax.deterministic import deterministic_model
5 | from skrl.utils.model_instantiators.jax.gaussian import gaussian_model
6 | from skrl.utils.model_instantiators.jax.multicategorical import multicategorical_model
7 |
8 |
# keep for compatibility with versions prior to 1.3.0
class Shape(Enum):
    """
    Enum to select the shape of the model's inputs and outputs

    Sentinel values interpreted by the model instantiators to resolve input/output sizes.
    """

    ONE = 1
    STATES = 0
    OBSERVATIONS = 0  # same value as STATES, so this member is an Enum alias for STATES
    ACTIONS = -1
    STATES_ACTIONS = -2
20 |
--------------------------------------------------------------------------------
/skrl/utils/model_instantiators/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from skrl.utils.model_instantiators.torch.categorical import categorical_model
4 | from skrl.utils.model_instantiators.torch.deterministic import deterministic_model
5 | from skrl.utils.model_instantiators.torch.gaussian import gaussian_model
6 | from skrl.utils.model_instantiators.torch.multicategorical import multicategorical_model
7 | from skrl.utils.model_instantiators.torch.multivariate_gaussian import multivariate_gaussian_model
8 | from skrl.utils.model_instantiators.torch.shared import shared_model
9 |
10 |
# keep for compatibility with versions prior to 1.3.0
class Shape(Enum):
    """
    Enum to select the shape of the model's inputs and outputs

    Sentinel values interpreted by the model instantiators to resolve input/output sizes.
    """

    ONE = 1
    STATES = 0
    OBSERVATIONS = 0  # same value as STATES, so this member is an Enum alias for STATES
    ACTIONS = -1
    STATES_ACTIONS = -2
22 |
--------------------------------------------------------------------------------
/skrl/utils/runner/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/utils/runner/__init__.py
--------------------------------------------------------------------------------
/skrl/utils/runner/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.utils.runner.jax.runner import Runner
2 |
--------------------------------------------------------------------------------
/skrl/utils/runner/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.utils.runner.torch.runner import Runner
2 |
--------------------------------------------------------------------------------
/skrl/utils/spaces/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/skrl/utils/spaces/__init__.py
--------------------------------------------------------------------------------
/skrl/utils/spaces/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.utils.spaces.jax.spaces import (
2 | compute_space_size,
3 | convert_gym_space,
4 | flatten_tensorized_space,
5 | sample_space,
6 | tensorize_space,
7 | unflatten_tensorized_space,
8 | untensorize_space,
9 | )
10 |
--------------------------------------------------------------------------------
/skrl/utils/spaces/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.utils.spaces.torch.spaces import (
2 | compute_space_size,
3 | convert_gym_space,
4 | flatten_tensorized_space,
5 | sample_space,
6 | tensorize_space,
7 | unflatten_tensorized_space,
8 | untensorize_space,
9 | )
10 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/__init__.py
--------------------------------------------------------------------------------
/tests/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/agents/__init__.py
--------------------------------------------------------------------------------
/tests/agents/jax/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/agents/jax/__init__.py
--------------------------------------------------------------------------------
/tests/agents/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/agents/torch/__init__.py
--------------------------------------------------------------------------------
/tests/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/envs/__init__.py
--------------------------------------------------------------------------------
/tests/envs/wrappers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/envs/wrappers/__init__.py
--------------------------------------------------------------------------------
/tests/envs/wrappers/jax/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/envs/wrappers/jax/__init__.py
--------------------------------------------------------------------------------
/tests/envs/wrappers/jax/test_brax_envs.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import warnings
3 |
4 | from collections.abc import Mapping
5 | import gymnasium
6 |
7 | import jax
8 | import jax.numpy as jnp
9 | import numpy as np
10 |
11 | from skrl import config
12 | from skrl.envs.wrappers.jax import BraxWrapper, wrap_env
13 |
14 | from ....utilities import is_running_on_github_actions
15 |
16 |
@pytest.mark.parametrize("backend", ["jax", "numpy"])
def test_env(capsys: pytest.CaptureFixture, backend: str):
    """Smoke-test the Brax wrapper under both skrl backends (jax and numpy):
    wrapper detection, exposed properties, and the reset/step/render/close cycle"""
    config.jax.backend = backend
    # expected array type depends on the active skrl backend
    Array = jax.Array if backend == "jax" else np.ndarray

    num_envs = 10
    action = jnp.ones((num_envs, 1)) if backend == "jax" else np.ones((num_envs, 1))

    # load and wrap the environment (skip locally if Brax is unavailable)
    try:
        import brax.envs
    except ImportError as e:
        if is_running_on_github_actions():
            raise e
        else:
            pytest.skip(f"Unable to import Brax environment: {e}")

    original_env = brax.envs.create("inverted_pendulum", batch_size=num_envs, backend="spring")
    # both "auto" detection and the explicit "brax" identifier should yield a BraxWrapper
    env = wrap_env(original_env, "auto")
    assert isinstance(env, BraxWrapper)
    env = wrap_env(original_env, "brax")
    assert isinstance(env, BraxWrapper)

    # check properties
    assert env.state_space is None
    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (4,)
    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
    assert isinstance(env.num_envs, int) and env.num_envs == num_envs
    assert isinstance(env.num_agents, int) and env.num_agents == 1
    assert isinstance(env.device, jax.Device)
    # check internal properties
    assert env._env is not original_env  # brax's VectorGymWrapper interferes with the checking (it has _env)
    assert env._unwrapped is not original_env.unwrapped  # brax's VectorGymWrapper interferes with the checking
    # check methods
    for _ in range(2):
        observation, info = env.reset()
        observation, info = env.reset()  # edge case: parallel environments are autoreset
        assert isinstance(observation, Array) and observation.shape == (num_envs, 4)
        assert isinstance(info, Mapping)
        for _ in range(3):
            observation, reward, terminated, truncated, info = env.step(action)
            try:
                env.render()
            except AttributeError as e:
                warnings.warn(f"Brax exception when rendering: {e}")
            assert isinstance(observation, Array) and observation.shape == (num_envs, 4)
            assert isinstance(reward, Array) and reward.shape == (num_envs, 1)
            assert isinstance(terminated, Array) and terminated.shape == (num_envs, 1)
            assert isinstance(truncated, Array) and truncated.shape == (num_envs, 1)
            assert isinstance(info, Mapping)

    env.close()
69 |
--------------------------------------------------------------------------------
/tests/envs/wrappers/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/envs/wrappers/torch/__init__.py
--------------------------------------------------------------------------------
/tests/envs/wrappers/torch/test_brax_envs.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import warnings
3 |
4 | from collections.abc import Mapping
5 | import gymnasium
6 |
7 | import torch
8 |
9 | from skrl.envs.wrappers.torch import BraxWrapper, wrap_env
10 |
11 | from ....utilities import is_running_on_github_actions
12 |
13 |
def test_env(capsys: pytest.CaptureFixture):
    """Check the torch Brax wrapper: wrapping modes, exposed properties, and the
    reset/step/render/close API on a batched inverted-pendulum environment."""
    batch_size = 10
    actions = torch.ones((batch_size, 1))

    # import the environment (skip locally when brax is missing, fail on CI)
    try:
        import brax.envs
    except ImportError as e:
        if not is_running_on_github_actions():
            pytest.skip(f"Unable to import Brax environment: {e}")
        raise e

    # wrap the environment (both automatic detection and explicit wrapper name)
    raw_env = brax.envs.create("inverted_pendulum", batch_size=batch_size, backend="spring")
    for wrapper_name in ("auto", "brax"):
        env = wrap_env(raw_env, wrapper_name)
        assert isinstance(env, BraxWrapper)

    # check properties
    assert env.state_space is None
    assert isinstance(env.observation_space, gymnasium.Space) and env.observation_space.shape == (4,)
    assert isinstance(env.action_space, gymnasium.Space) and env.action_space.shape == (1,)
    assert isinstance(env.num_envs, int) and env.num_envs == batch_size
    assert isinstance(env.num_agents, int) and env.num_agents == 1
    assert isinstance(env.device, torch.device)
    # check internal properties
    # brax's VectorGymWrapper interferes with the checking (it has _env)
    assert env._env is not raw_env
    assert env._unwrapped is not raw_env.unwrapped

    # check methods
    observation_shape = torch.Size([batch_size, 4])
    scalar_shape = torch.Size([batch_size, 1])
    for _ in range(2):
        observation, info = env.reset()
        observation, info = env.reset()  # edge case: parallel environments are autoreset
        assert isinstance(observation, torch.Tensor) and observation.shape == observation_shape
        assert isinstance(info, Mapping)
        for _ in range(3):
            observation, reward, terminated, truncated, info = env.step(actions)
            try:
                env.render()
            except AttributeError as e:
                warnings.warn(f"Brax exception when rendering: {e}")
            assert isinstance(observation, torch.Tensor) and observation.shape == observation_shape
            assert isinstance(reward, torch.Tensor) and reward.shape == scalar_shape
            assert isinstance(terminated, torch.Tensor) and terminated.shape == scalar_shape
            assert isinstance(truncated, torch.Tensor) and truncated.shape == scalar_shape
            assert isinstance(info, Mapping)

    env.close()
62 |
--------------------------------------------------------------------------------
/tests/envs/wrappers/torch/test_deepmind_envs.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import warnings
3 |
4 | from collections.abc import Mapping
5 | import gymnasium as gym
6 |
7 | import torch
8 |
9 | from skrl.envs.wrappers.torch import DeepMindWrapper, wrap_env
10 |
11 | from ....utilities import is_running_on_github_actions
12 |
13 |
def test_env(capsys: pytest.CaptureFixture):
    """Check the torch DeepMind (dm_control) wrapper: wrapping modes, exposed
    properties, and the reset/step/render/close API on the pendulum swingup task."""
    num_envs = 1
    action = torch.ones((num_envs, 1))

    # import the environment (skip locally when dm_control is missing, fail on CI)
    try:
        from dm_control import suite
    except ImportError as e:
        if not is_running_on_github_actions():
            pytest.skip(f"Unable to import DeepMind environment: {e}")
        raise e

    # wrap the environment (both automatic detection and explicit wrapper name)
    original_env = suite.load(domain_name="pendulum", task_name="swingup")
    for wrapper_name in ("auto", "dm"):
        env = wrap_env(original_env, wrapper_name)
        assert isinstance(env, DeepMindWrapper)

    # check properties
    assert env.state_space is None
    assert isinstance(env.observation_space, gym.Space)
    assert sorted(env.observation_space.keys()) == ["orientation", "velocity"]
    assert isinstance(env.action_space, gym.Space) and env.action_space.shape == (1,)
    assert isinstance(env.num_envs, int) and env.num_envs == num_envs
    assert isinstance(env.num_agents, int) and env.num_agents == 1
    assert isinstance(env.device, torch.device)
    # check internal properties
    assert env._env is original_env
    assert env._unwrapped is original_env

    # check methods
    observation_shape = torch.Size([num_envs, 3])
    scalar_shape = torch.Size([num_envs, 1])
    for _ in range(2):
        observation, info = env.reset()
        assert isinstance(observation, torch.Tensor) and observation.shape == observation_shape
        assert isinstance(info, Mapping)
        for _ in range(3):
            observation, reward, terminated, truncated, info = env.step(action)
            # rendering opens a viewer window, which is not possible on CI
            if not is_running_on_github_actions():
                env.render()
            assert isinstance(observation, torch.Tensor) and observation.shape == observation_shape
            assert isinstance(reward, torch.Tensor) and reward.shape == scalar_shape
            assert isinstance(terminated, torch.Tensor) and terminated.shape == scalar_shape
            assert isinstance(truncated, torch.Tensor) and truncated.shape == scalar_shape
            assert isinstance(info, Mapping)

    env.close()
62 |
--------------------------------------------------------------------------------
/tests/memories/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/memories/__init__.py
--------------------------------------------------------------------------------
/tests/memories/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/memories/torch/__init__.py
--------------------------------------------------------------------------------
/tests/memories/torch/test_base.py:
--------------------------------------------------------------------------------
1 | import hypothesis
2 | import hypothesis.strategies as st
3 | import pytest
4 |
5 | import torch
6 |
7 | from skrl import config
8 | from skrl.memories.torch import Memory
9 |
10 |
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
def test_device(capsys, device):
    """The memory and its created tensors must live on the parsed target device."""
    expected_device = config.torch.parse_device(device)

    memory = Memory(memory_size=5, num_envs=1, device=device)
    memory.create_tensor("buffer", size=1)

    assert memory.device == expected_device
    assert memory.get_tensor_by_name("buffer").device == expected_device
19 |
20 |
21 | # __len__
22 |
23 |
def test_share_memory(capsys):
    """Sharing the memory buffers across processes should not raise.

    The device was hard-coded to "cuda", which makes the test fail (rather than
    skip) on CPU-only machines; explicitly skip when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    memory = Memory(memory_size=5, num_envs=1, device="cuda")
    memory.create_tensor("buffer", size=1)

    memory.share_memory()
29 |
30 |
@hypothesis.given(
    tensor_names=st.lists(
        st.text(st.characters(codec="ascii", categories=("Nd", "L")), min_size=1, max_size=5),  # codespell:ignore
        min_size=0,
        max_size=5,
        unique=True,
    )
)
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
    deadline=None,
    phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate],
)
def test_get_tensor_names(capsys, tensor_names):
    """Registered tensor names must be reported in sorted order."""
    memory = Memory(memory_size=5, num_envs=1)
    for tensor_name in tensor_names:
        memory.create_tensor(tensor_name, size=1)

    assert memory.get_tensor_names() == sorted(tensor_names)
50 |
51 |
@hypothesis.given(
    tensor_name=st.text(
        st.characters(codec="ascii", categories=("Nd", "L")), min_size=1, max_size=5  # codespell:ignore
    )
)
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
    deadline=None,
    phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate],
)
@pytest.mark.parametrize("keepdim", [True, False])
def test_get_tensor_by_name(capsys, tensor_name, keepdim):
    """keepdim=True keeps the (memory_size, num_envs, size) layout; otherwise
    the first two dimensions are flattened into one."""
    memory = Memory(memory_size=5, num_envs=2)
    memory.create_tensor(tensor_name, size=1)

    expected_shape = (5, 2, 1) if keepdim else (10, 1)
    assert memory.get_tensor_by_name(tensor_name, keepdim=keepdim).shape == expected_shape
69 |
70 |
@hypothesis.given(
    tensor_name=st.text(
        st.characters(codec="ascii", categories=("Nd", "L")), min_size=1, max_size=5  # codespell:ignore
    )
)
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
    deadline=None,
    phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate],
)
def test_set_tensor_by_name(capsys, tensor_name):
    """Setting a tensor by name must overwrite the whole internal buffer."""
    memory = Memory(memory_size=5, num_envs=2)
    memory.create_tensor(tensor_name, size=1)

    target_tensor = torch.arange(10, device=memory.device).reshape(5, 2, 1)
    memory.set_tensor_by_name(tensor_name, target_tensor)
    # torch.all (not torch.any): every element must match, otherwise a partial
    # or wrong write would still pass the assertion
    assert torch.all(memory.get_tensor_by_name(tensor_name, keepdim=True) == target_tensor)
88 |
--------------------------------------------------------------------------------
/tests/strategies.py:
--------------------------------------------------------------------------------
1 | import hypothesis.strategies as st
2 |
3 | import gymnasium
4 |
5 |
@st.composite
def gymnasium_space_stategy(draw, space_type: str = "", remaining_iterations: int = 5) -> gymnasium.spaces.Space:
    """Hypothesis strategy drawing a random (possibly nested) gymnasium space.

    :param space_type: Force a specific space type; drawn at random when empty.
    :param remaining_iterations: Nesting budget for Dict/Tuple spaces.
    :raises ValueError: If the space type is not recognized.
    """
    if not space_type:
        space_type = draw(st.sampled_from(["Box", "Discrete", "MultiDiscrete", "Dict", "Tuple"]))
    # recursion base case: stop nesting once the budget is exhausted
    if remaining_iterations <= 0 and space_type in ("Dict", "Tuple"):
        space_type = "Box"

    sizes = st.lists(st.integers(min_value=1, max_value=5), min_size=1, max_size=5)
    if space_type == "Box":
        return gymnasium.spaces.Box(low=-1, high=1, shape=draw(sizes))
    if space_type == "Discrete":
        return gymnasium.spaces.Discrete(draw(st.integers(min_value=1, max_value=5)))
    if space_type == "MultiDiscrete":
        return gymnasium.spaces.MultiDiscrete(draw(sizes))
    if space_type == "Dict":
        keys = draw(st.lists(st.text(st.characters(codec="ascii"), min_size=1, max_size=5), min_size=1, max_size=3))
        return gymnasium.spaces.Dict(
            {key: draw(gymnasium_space_stategy(remaining_iterations=remaining_iterations - 1)) for key in keys}
        )
    if space_type == "Tuple":
        return gymnasium.spaces.Tuple(
            draw(
                st.lists(gymnasium_space_stategy(remaining_iterations=remaining_iterations - 1), min_size=1, max_size=3)
            )
        )
    raise ValueError(f"Invalid space type: {space_type}")
36 |
37 |
@st.composite
def gym_space_stategy(draw, space_type: str = "", remaining_iterations: int = 5) -> "gym.spaces.Space":
    """Hypothesis strategy drawing a random (possibly nested) legacy gym space.

    :param space_type: Force a specific space type; drawn at random when empty.
    :param remaining_iterations: Nesting budget for Dict/Tuple spaces.
    :raises ValueError: If the space type is not recognized.
    """
    # imported lazily: gym is an optional (deprecated) dependency
    import gym

    if not space_type:
        space_type = draw(st.sampled_from(["Box", "Discrete", "MultiDiscrete", "Dict", "Tuple"]))
    # recursion base case: stop nesting once the budget is exhausted
    if remaining_iterations <= 0 and space_type in ("Dict", "Tuple"):
        space_type = "Box"

    sizes = st.lists(st.integers(min_value=1, max_value=5), min_size=1, max_size=5)
    if space_type == "Box":
        return gym.spaces.Box(low=-1, high=1, shape=draw(sizes))
    if space_type == "Discrete":
        return gym.spaces.Discrete(draw(st.integers(min_value=1, max_value=5)))
    if space_type == "MultiDiscrete":
        return gym.spaces.MultiDiscrete(draw(sizes))
    if space_type == "Dict":
        keys = draw(st.lists(st.text(st.characters(codec="ascii"), min_size=1, max_size=5), min_size=1, max_size=3))
        return gym.spaces.Dict(
            {key: draw(gym_space_stategy(remaining_iterations=remaining_iterations - 1)) for key in keys}
        )
    if space_type == "Tuple":
        return gym.spaces.Tuple(
            draw(st.lists(gym_space_stategy(remaining_iterations=remaining_iterations - 1), min_size=1, max_size=3))
        )
    raise ValueError(f"Invalid space type: {space_type}")
68 |
--------------------------------------------------------------------------------
/tests/test_torch_config.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import hypothesis
4 | import hypothesis.strategies as st
5 | import pytest
6 |
7 | import os
8 |
9 | import torch
10 |
11 | from skrl import _Config, config
12 |
13 |
@pytest.mark.parametrize("device", [None, "cpu", "cuda", "cuda:0", "cuda:10", "edge-case"])
@pytest.mark.parametrize("validate", [True, False])
def test_parse_device(capsys, device: Union[str, None], validate: bool):
    """parse_device falls back to the default device for None/unknown specifiers
    and, when validating, for out-of-range CUDA device indices."""
    default_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    target_device = None
    if device in [None, "edge-case"]:
        target_device = default_device
    elif device.startswith("cuda"):
        # "cuda" without an index is treated as index 0
        index = int(f"{device}:0".split(":")[1])
        if validate and index >= torch.cuda.device_count():
            target_device = default_device
    if target_device is None:
        target_device = torch.device(device)

    assert config.torch.parse_device(device, validate=validate) == target_device
28 |
29 |
@pytest.mark.parametrize("device", [None, "cpu", "cuda", "cuda:0", "cuda:10", "edge-case"])
def test_device(capsys, device: Union[str, None]):
    """The device setter/getter round-trips, falling back to the default device
    for None and unrecognized specifiers."""
    if device is None or device == "edge-case":
        target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        target_device = torch.device(device)

    # check setter/getter
    config.torch.device = device
    assert config.torch.device == target_device
40 |
41 |
@hypothesis.given(
    local_rank=st.integers(),
    rank=st.integers(),
    world_size=st.integers(),
)
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
    deadline=None,
    phases=[hypothesis.Phase.explicit, hypothesis.Phase.reuse, hypothesis.Phase.generate],
)
def test_distributed(capsys, local_rank: int, rank: int, world_size: int):
    """_Config reads the distributed launcher variables from the environment.

    The original test mutated LOCAL_RANK/RANK/WORLD_SIZE in os.environ without
    restoring them, leaking state into later tests; save and restore them.
    """
    saved = {key: os.environ.get(key) for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE")}
    try:
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        is_distributed = world_size > 1

        if is_distributed:
            # initialization is expected to fail outside a real distributed launcher
            with pytest.raises(ValueError, match="Error initializing torch.distributed"):
                config = _Config()
            return
        else:
            config = _Config()
        assert config.torch.local_rank == local_rank
        assert config.torch.rank == rank
        assert config.torch.world_size == world_size
        assert config.torch.is_distributed == is_distributed
        assert config.torch._device == f"cuda:{local_rank}"
    finally:
        # restore the original environment
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
69 |
--------------------------------------------------------------------------------
/tests/utilities.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import gymnasium
4 |
5 | import numpy as np
6 |
7 |
def is_device_available(device, *, backend) -> bool:
    """Return whether ``device`` can actually be used by the given backend.

    :param device: Device specifier to probe (e.g. "cpu", "cuda:0").
    :param backend: Machine learning framework name. Only "torch" is supported.
    :return: True if a tensor can be allocated on the device, False otherwise.
    :raises ValueError: If the backend is not supported.
    """
    if backend == "torch":
        import torch

        try:
            # allocating a tiny tensor is a reliable availability probe
            torch.zeros((1,), device=device)
        except Exception:  # any failure (invalid name, missing driver, ...) means unavailable
            return False
    else:
        raise ValueError(f"Invalid backend: {backend}")
    return True
19 |
20 |
def is_running_on_github_actions() -> bool:
    """Return whether the current process is running inside a GitHub Actions workflow
    (detected via the GITHUB_ACTIONS environment variable)."""
    return "GITHUB_ACTIONS" in os.environ
23 |
24 |
def get_test_mixed_precision(default):
    """Resolve the mixed-precision setting from the SKRL_TEST_MIXED_PRECISION
    environment variable.

    :param default: Value to return when the variable enables mixed precision.
    :return: ``default`` when enabled, ``False`` when disabled or unset.
    :raises ValueError: If the variable holds an unrecognized value.
    """
    value = os.environ.get("SKRL_TEST_MIXED_PRECISION")
    if value is None:
        return False
    normalized = value.lower()
    if normalized in ("true", "1", "y", "yes"):
        return default
    if normalized in ("false", "0", "n", "no"):
        return False
    raise ValueError(f"Invalid value for environment variable SKRL_TEST_MIXED_PRECISION: {value}")
34 |
35 |
class BaseEnv(gymnasium.Env):
    """Minimal abstract gymnasium environment used as a base for test doubles.

    Subclasses must implement ``_sample_observation``. Rewards and the
    terminated/truncated flags are drawn at random on every step.
    """

    def __init__(self, observation_space, action_space, num_envs, device):
        # stored as plain attributes consumed by the wrappers under test
        self.device = device
        self.num_envs = num_envs
        self.action_space = action_space
        self.observation_space = observation_space

    def _sample_observation(self):
        # to be provided by concrete test environments
        raise NotImplementedError

    def step(self, actions):
        # scalar outputs for a single environment, batched arrays otherwise;
        # episodes terminate/truncate at random with ~5% probability per step
        if self.num_envs == 1:
            rewards = random.random()
            terminated = random.random() > 0.95
            truncated = random.random() > 0.95
        else:
            rewards = np.random.random((self.num_envs,))
            terminated = np.random.random((self.num_envs,)) > 0.95
            truncated = np.random.random((self.num_envs,)) > 0.95

        return self._sample_observation(), rewards, terminated, truncated, {}

    def reset(self, *, seed=None, options=None):
        # accept (and ignore) the standard gymnasium reset keyword arguments so
        # callers/wrappers that forward seed/options do not break; the original
        # signature took no arguments and diverged from gymnasium.Env.reset
        return self._sample_observation(), {}

    def render(self, *args, **kwargs):
        # rendering is a no-op for test environments
        pass

    def close(self, *args, **kwargs):
        # nothing to release
        pass
66 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/model_instantiators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/model_instantiators/__init__.py
--------------------------------------------------------------------------------
/tests/utils/model_instantiators/jax/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/model_instantiators/jax/__init__.py
--------------------------------------------------------------------------------
/tests/utils/model_instantiators/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/model_instantiators/torch/__init__.py
--------------------------------------------------------------------------------
/tests/utils/spaces/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/spaces/__init__.py
--------------------------------------------------------------------------------
/tests/utils/spaces/jax/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/spaces/jax/__init__.py
--------------------------------------------------------------------------------
/tests/utils/spaces/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Toni-SM/skrl/90adbbc1593ebb7ea5f98f39bd22f7e88d1198b2/tests/utils/spaces/torch/__init__.py
--------------------------------------------------------------------------------