├── src
│   ├── third_parties
│   │   └── skrl
│   │       ├── agents
│   │       │   ├── __init__.py
│   │       │   ├── jax
│   │       │   │   ├── __init__.py
│   │       │   │   ├── a2c
│   │       │   │   │   └── __init__.py
│   │       │   │   ├── cem
│   │       │   │   │   └── __init__.py
│   │       │   │   ├── ppo
│   │       │   │   │   └── __init__.py
│   │       │   │   ├── rpo
│   │       │   │   │   └── __init__.py
│   │       │   │   ├── sac
│   │       │   │   │   └── __init__.py
│   │       │   │   ├── td3
│   │       │   │   │   └── __init__.py
│   │       │   │   ├── ddpg
│   │       │   │   │   └── __init__.py
│   │       │   │   └── dqn
│   │       │   │       └── __init__.py
│   │       │   └── torch
│   │       │       ├── __init__.py
│   │       │       ├── amp
│   │       │       │   └── __init__.py
│   │       │       ├── cem
│   │       │       │   └── __init__.py
│   │       │       ├── sarsa
│   │       │       │   └── __init__.py
│   │       │       ├── q_learning
│   │       │       │   └── __init__.py
│   │       │       ├── a2c
│   │       │       │   └── __init__.py
│   │       │       ├── ppo
│   │       │       │   └── __init__.py
│   │       │       ├── rpo
│   │       │       │   └── __init__.py
│   │       │       ├── sac
│   │       │       │   └── __init__.py
│   │       │       ├── td3
│   │       │       │   └── __init__.py
│   │       │       ├── ddpg
│   │       │       │   └── __init__.py
│   │       │       ├── trpo
│   │       │       │   └── __init__.py
│   │       │       └── dqn
│   │       │           └── __init__.py
│   │       ├── envs
│   │       │   ├── __init__.py
│   │       │   ├── loaders
│   │       │   │   ├── __init__.py
│   │       │   │   ├── jax
│   │       │   │   │   ├── isaaclab_envs.py
│   │       │   │   │   ├── bidexhands_envs.py
│   │       │   │   │   ├── omniverse_isaacgym_envs.py
│   │       │   │   │   ├── isaacgym_envs.py
│   │       │   │   │   └── __init__.py
│   │       │   │   └── torch
│   │       │   │       └── __init__.py
│   │       │   ├── wrappers
│   │       │   │   ├── __init__.py
│   │       │   │   ├── torch
│   │       │   │   │   ├── omniverse_isaacgym_envs.py
│   │       │   │   │   ├── brax_envs.py
│   │       │   │   │   ├── pettingzoo_envs.py
│   │       │   │   │   ├── gymnasium_envs.py
│   │       │   │   │   ├── bidexhands_envs.py
│   │       │   │   │   ├── deepmind_envs.py
│   │       │   │   │   ├── isaacgym_envs.py
│   │       │   │   │   ├── gym_envs.py
│   │       │   │   │   ├── robosuite_envs.py
│   │       │   │   │   └── isaaclab_envs.py
│   │       │   │   └── jax
│   │       │   │       ├── brax_envs.py
│   │       │   │       ├── omniverse_isaacgym_envs.py
│   │       │   │       ├── pettingzoo_envs.py
│   │       │   │       ├── gymnasium_envs.py
│   │       │   │       └── bidexhands_envs.py
│   │       │   ├── jax.py
│   │       │   └── torch.py
│   │       ├── memories
│   │       │   ├── __init__.py
│   │       │   ├── jax
│   │       │   │   ├── __init__.py
│   │       │   │   └── random.py
│   │       │   └── torch
│   │       │       ├── __init__.py
│   │       │       └── random.py
│   │       ├── models
│   │       │   ├── __init__.py
│   │       │   ├── jax
│   │       │   │   ├── __init__.py
│   │       │   │   └── deterministic.py
│   │       │   └── torch
│   │       │       ├── __init__.py
│   │       │       └── deterministic.py
│   │       ├── trainers
│   │       │   ├── __init__.py
│   │       │   ├── jax
│   │       │   │   └── __init__.py
│   │       │   └── torch
│   │       │       └── __init__.py
│   │       ├── multi_agents
│   │       │   ├── __init__.py
│   │       │   ├── jax
│   │       │   │   ├── __init__.py
│   │       │   │   ├── ippo
│   │       │   │   │   └── __init__.py
│   │       │   │   └── mappo
│   │       │   │       └── __init__.py
│   │       │   └── torch
│   │       │       ├── __init__.py
│   │       │       ├── ippo
│   │       │       │   └── __init__.py
│   │       │       └── mappo
│   │       │           └── __init__.py
│   │       ├── resources
│   │       │   ├── __init__.py
│   │       │   ├── noises
│   │       │   │   ├── __init__.py
│   │       │   │   ├── jax
│   │       │   │   │   ├── __init__.py
│   │       │   │   │   ├── base.py
│   │       │   │   │   ├── gaussian.py
│   │       │   │   │   └── ornstein_uhlenbeck.py
│   │       │   │   └── torch
│   │       │   │       ├── __init__.py
│   │       │   │       ├── gaussian.py
│   │       │   │       ├── base.py
│   │       │   │       └── ornstein_uhlenbeck.py
│   │       │   ├── optimizers
│   │       │   │   ├── __init__.py
│   │       │   │   └── jax
│   │       │   │       ├── __init__.py
│   │       │   │       └── adam.py
│   │       │   ├── preprocessors
│   │       │   │   ├── __init__.py
│   │       │   │   ├── jax
│   │       │   │   │   └── __init__.py
│   │       │   │   └── torch
│   │       │   │       └── __init__.py
│   │       │   └── schedulers
│   │       │       ├── __init__.py
│   │       │       ├── torch
│   │       │       │   ├── __init__.py
│   │       │       │   └── kl_adaptive.py
│   │       │       └── jax
│   │       │           ├── __init__.py
│   │       │           └── kl_adaptive.py
│   │       └── utils
│   │           ├── runner
│   │           │   ├── __init__.py
│   │           │   ├── jax
│   │           │   │   └── __init__.py
│   │           │   └── torch
│   │           │       └── __init__.py
│   │           ├── spaces
│   │           │   ├── __init__.py
│   │           │   ├── jax
│   │           │   │   └── __init__.py
│   │           │   └── torch
│   │           │       └── __init__.py
│   │           ├── distributed
│   │           │   ├── __init__.py
│   │           │   └── jax
│   │           │       ├── __main__.py
│   │           │       └── launcher.py
│   │           ├── model_instantiators
│   │           │   ├── __init__.py
│   │           │   ├── jax
│   │           │   │   ├── __init__.py
│   │           │   │   ├── deterministic.py
│   │           │   │   ├── categorical.py
│   │           │   │   └── multicategorical.py
│   │           │   └── torch
│   │           │       ├── __init__.py
│   │           │       ├── deterministic.py
│   │           │       ├── categorical.py
│   │           │       ├── multicategorical.py
│   │           │       └── multivariate_gaussian.py
│   │           ├── control.py
│   │           ├── huggingface.py
│   │           └── __init__.py
│   └── isaac_quad_sim2real
│       ├── tasks
│       │   ├── race
│       │   │   ├── __init__.py
│       │   │   └── config
│       │   │       ├── __init__.py
│       │   │       └── crazyflie
│       │   │           ├── agents
│       │   │           │   ├── __init__.py
│       │   │           │   ├── rsl_rl_ppo_cfg.py
│       │   │           │   ├── skrl_mappo_cfg.yaml
│       │   │           │   └── rl_cfg.py
│       │   │           └── __init__.py
│       │   └── __init__.py
│       └── __init__.py
├── .envrc
├── AgileFlight_CoverImage.png
├── AgileFlightEmergesMultiAgentCompetition.pdf
├── setup.py
├── .gitmodules
├── .gitignore
├── config
│   └── extension.toml
├── pyproject.toml
├── README.md
├── scripts
│   └── skrl
│       └── cli_args.py
└── test.py

/src/third_parties/skrl/agents/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/memories/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.envrc:
--------------------------------------------------------------------------------
1 | conda activate isaac_quad_sim2real
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/multi_agents/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/runner/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/spaces/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/noises/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/preprocessors/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/model_instantiators/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/isaac_quad_sim2real/tasks/race/__init__.py:
--------------------------------------------------------------------------------
1 | """Drone racing environments.
2 | """
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.base import Agent
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.base import Agent
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/runner/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.utils.runner.jax.runner import Runner
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/multi_agents/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.jax.base import MultiAgent
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/runner/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.utils.runner.torch.runner import Runner
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.a2c.a2c import A2C, A2C_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/cem/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.cem.cem import CEM, CEM_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.ppo.ppo import PPO, PPO_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/rpo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.rpo.rpo import RPO, RPO_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/sac/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.sac.sac import SAC, SAC_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.td3.td3 import TD3, TD3_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/multi_agents/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.torch.base import MultiAgent
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.ddpg.ddpg import DDPG, DDPG_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/amp/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.amp.amp import AMP, AMP_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/cem/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.cem.cem import CEM, CEM_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/optimizers/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.optimizers.jax.adam import Adam
2 | 
--------------------------------------------------------------------------------
/AgileFlight_CoverImage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jirl-upenn/AgileFlight_MultiAgent/HEAD/AgileFlight_CoverImage.png
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/sarsa/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.sarsa.sarsa import SARSA, SARSA_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/multi_agents/jax/ippo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.jax.ippo.ippo import IPPO, IPPO_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/multi_agents/jax/mappo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.jax.mappo.mappo import MAPPO, MAPPO_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/multi_agents/torch/ippo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.torch.ippo.ippo import IPPO, IPPO_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/multi_agents/torch/mappo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.multi_agents.torch.mappo.mappo import MAPPO, MAPPO_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/q_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.q_learning.q_learning import Q_LEARNING, Q_LEARNING_DEFAULT_CONFIG
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/distributed/jax/__main__.py:
--------------------------------------------------------------------------------
1 | from . import launcher
2 | 
3 | 
4 | if __name__ == "__main__":
5 |     launcher.launch()
6 | 
--------------------------------------------------------------------------------
/AgileFlightEmergesMultiAgentCompetition.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jirl-upenn/AgileFlight_MultiAgent/HEAD/AgileFlightEmergesMultiAgentCompetition.pdf
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/preprocessors/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.preprocessors.jax.running_standard_scaler import RunningStandardScaler
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/preprocessors/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.preprocessors.torch.running_standard_scaler import RunningStandardScaler
2 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/memories/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.memories.jax.base import Memory  # isort:skip
2 | 
3 | from skrl.memories.jax.random import RandomMemory
4 | 
--------------------------------------------------------------------------------
/src/isaac_quad_sim2real/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Python module serving as a project/extension template.
3 | """
4 | 
5 | # Register Gym environments.
6 | from .tasks import *
7 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.a2c.a2c import A2C, A2C_DEFAULT_CONFIG
2 | from skrl.agents.torch.a2c.a2c_rnn import A2C_RNN
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.ppo.ppo import PPO, PPO_DEFAULT_CONFIG
2 | from skrl.agents.torch.ppo.ppo_rnn import PPO_RNN
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/rpo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.rpo.rpo import RPO, RPO_DEFAULT_CONFIG
2 | from skrl.agents.torch.rpo.rpo_rnn import RPO_RNN
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/sac/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.sac.sac import SAC, SAC_DEFAULT_CONFIG
2 | from skrl.agents.torch.sac.sac_rnn import SAC_RNN
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.td3.td3 import TD3, TD3_DEFAULT_CONFIG
2 | from skrl.agents.torch.td3.td3_rnn import TD3_RNN
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/memories/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.memories.torch.base import Memory  # isort:skip
2 | 
3 | from skrl.memories.torch.random import RandomMemory
4 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/jax/dqn/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.jax.dqn.ddqn import DDQN, DDQN_DEFAULT_CONFIG
2 | from skrl.agents.jax.dqn.dqn import DQN, DQN_DEFAULT_CONFIG
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.ddpg.ddpg import DDPG, DDPG_DEFAULT_CONFIG
2 | from skrl.agents.torch.ddpg.ddpg_rnn import DDPG_RNN
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/trpo/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.trpo.trpo import TRPO, TRPO_DEFAULT_CONFIG
2 | from skrl.agents.torch.trpo.trpo_rnn import TRPO_RNN
3 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/agents/torch/dqn/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.agents.torch.dqn.ddqn import DDQN, DDQN_DEFAULT_CONFIG
2 | from skrl.agents.torch.dqn.dqn import DQN, DQN_DEFAULT_CONFIG
3 | 
--------------------------------------------------------------------------------
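The per-algorithm ``__init__.py`` modules listed above only re-export each agent class together with its ``*_DEFAULT_CONFIG`` dictionary. As a minimal usage sketch (this snippet is not a file in this repository; the ``env``, ``policy_model``, ``value_model`` and ``memory`` objects are assumed to be built elsewhere), an agent is typically assembled from these exports as follows::

    from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG

    # start from the packaged defaults, then override selected hyperparameters
    cfg = PPO_DEFAULT_CONFIG.copy()
    cfg["rollouts"] = 16

    agent = PPO(
        models={"policy": policy_model, "value": value_model},  # skrl Model instances (assumed)
        memory=memory,                                          # e.g. a RandomMemory (assumed)
        cfg=cfg,
        observation_space=env.observation_space,
        action_space=env.action_space,
        device=env.device,
    )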
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | from setuptools import setup
7 | 
8 | setup()
9 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/loaders/jax/isaaclab_envs.py:
--------------------------------------------------------------------------------
1 | # since Isaac Lab environments are implemented on top of PyTorch, the loader is the same
2 | 
3 | from skrl.envs.loaders.torch import load_isaaclab_env
4 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/loaders/jax/bidexhands_envs.py:
--------------------------------------------------------------------------------
1 | # since Bi-DexHands environments are implemented on top of PyTorch, the loader is the same
2 | 
3 | from skrl.envs.loaders.torch import load_bidexhands_env
4 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/schedulers/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.schedulers.torch.kl_adaptive import KLAdaptiveLR
2 | 
3 | 
4 | KLAdaptiveRL = KLAdaptiveLR  # known typo (compatibility with versions prior to 1.0.0)
5 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/loaders/jax/omniverse_isaacgym_envs.py:
--------------------------------------------------------------------------------
1 | # since Omniverse Isaac Gym environments are implemented on top of PyTorch, the loader is the same
2 | 
3 | from skrl.envs.loaders.torch import load_omniverse_isaacgym_env
4 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/schedulers/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.schedulers.jax.kl_adaptive import KLAdaptiveLR, kl_adaptive
2 | 
3 | 
4 | KLAdaptiveRL = KLAdaptiveLR  # known typo (compatibility with versions prior to 1.0.0)
5 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/trainers/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.trainers.jax.base import Trainer, generate_equally_spaced_scopes  # isort:skip
2 | 
3 | from skrl.trainers.jax.sequential import SequentialTrainer
4 | from skrl.trainers.jax.step import StepTrainer
5 | 
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "src/third_parties/rsl_rl"]
2 | 	path = src/third_parties/rsl_rl
3 | 	url = git@github.com:Jirl-upenn/rsl_rl.git
4 | [submodule "src/third_parties/rotorpy"]
5 | 	path = src/third_parties/rotorpy
6 | 	url = git@github.com:Jirl-upenn/rotorpy.git
7 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/noises/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.noises.jax.base import Noise  # isort:skip
2 | 
3 | from skrl.resources.noises.jax.gaussian import GaussianNoise
4 | from skrl.resources.noises.jax.ornstein_uhlenbeck import OrnsteinUhlenbeckNoise
5 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/noises/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.resources.noises.torch.base import Noise  # isort:skip
2 | 
3 | from skrl.resources.noises.torch.gaussian import GaussianNoise
4 | from skrl.resources.noises.torch.ornstein_uhlenbeck import OrnsteinUhlenbeckNoise
5 | 
--------------------------------------------------------------------------------
/src/isaac_quad_sim2real/tasks/race/config/__init__.py:
--------------------------------------------------------------------------------
1 | """Configurations for drone racing environments."""
2 | 
3 | # We leave this file empty since we don't want to expose any configs in this package directly.
4 | # We still need this file to import the "config" module in the parent package.
5 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/logs/*
2 | **/outputs/*
3 | **/wandb/*
4 | 
5 | # Python
6 | .DS_Store
7 | **/*.egg-info/
8 | **/__pycache__/
9 | **/.pytest_cache/
10 | **/*.pyc
11 | **/*.pb
12 | 
13 | # IDE
14 | **/.idea/
15 | **/.vscode/
16 | 
17 | # RL-Games
18 | **/runs/*
19 | **/logs/*
20 | **/recordings/*
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/spaces/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.utils.spaces.jax.spaces import (
2 |     compute_space_size,
3 |     convert_gym_space,
4 |     flatten_tensorized_space,
5 |     sample_space,
6 |     tensorize_space,
7 |     unflatten_tensorized_space,
8 |     untensorize_space,
9 | )
10 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/spaces/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.utils.spaces.torch.spaces import (
2 |     compute_space_size,
3 |     convert_gym_space,
4 |     flatten_tensorized_space,
5 |     sample_space,
6 |     tensorize_space,
7 |     unflatten_tensorized_space,
8 |     untensorize_space,
9 | )
10 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/trainers/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.trainers.torch.base import Trainer, generate_equally_spaced_scopes  # isort:skip
2 | 
3 | from skrl.trainers.torch.parallel import ParallelTrainer
4 | from skrl.trainers.torch.sequential import SequentialTrainer
5 | from skrl.trainers.torch.step import StepTrainer
6 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/loaders/jax/isaacgym_envs.py:
--------------------------------------------------------------------------------
1 | # since Isaac Gym (preview) environments are implemented on top of PyTorch, the loaders are the same
2 | 
3 | from skrl.envs.loaders.torch import (  # isort:skip
4 |     load_isaacgym_env_preview2,
5 |     load_isaacgym_env_preview3,
6 |     load_isaacgym_env_preview4,
7 | )
8 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/models/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.models.jax.base import Model  # isort:skip
2 | 
3 | from skrl.models.jax.categorical import CategoricalMixin
4 | from skrl.models.jax.deterministic import DeterministicMixin
5 | from skrl.models.jax.gaussian import GaussianMixin
6 | from skrl.models.jax.multicategorical import MultiCategoricalMixin
7 | 
--------------------------------------------------------------------------------
/src/isaac_quad_sim2real/tasks/race/config/crazyflie/agents/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | # Copyright (c) 2022-2025, The Isaac Lab Project Developers.
7 | # All rights reserved.
8 | #
9 | # SPDX-License-Identifier: BSD-3-Clause
10 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/loaders/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.envs.loaders.jax.bidexhands_envs import load_bidexhands_env
2 | from skrl.envs.loaders.jax.isaacgym_envs import (
3 |     load_isaacgym_env_preview2,
4 |     load_isaacgym_env_preview3,
5 |     load_isaacgym_env_preview4,
6 | )
7 | from skrl.envs.loaders.jax.isaaclab_envs import load_isaaclab_env
8 | from skrl.envs.loaders.jax.omniverse_isaacgym_envs import load_omniverse_isaacgym_env
9 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/loaders/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.envs.loaders.torch.bidexhands_envs import load_bidexhands_env
2 | from skrl.envs.loaders.torch.isaacgym_envs import (
3 |     load_isaacgym_env_preview2,
4 |     load_isaacgym_env_preview3,
5 |     load_isaacgym_env_preview4,
6 | )
7 | from skrl.envs.loaders.torch.isaaclab_envs import load_isaaclab_env
8 | from skrl.envs.loaders.torch.omniverse_isaacgym_envs import load_omniverse_isaacgym_env
9 | 
--------------------------------------------------------------------------------
/src/isaac_quad_sim2real/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | """Package containing task implementations for various robotic environments."""
2 | 
3 | import os
4 | import toml
5 | 
6 | from isaaclab_tasks.utils import import_packages
7 | 
8 | ##
9 | # Register Gym environments.
10 | ##
11 | 
12 | 
13 | # The blacklist is used to prevent importing configs from sub-packages
14 | _BLACKLIST_PKGS = ["utils"]
15 | # Import all configs in this package
16 | import_packages(__name__, _BLACKLIST_PKGS)
17 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/models/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from skrl.models.torch.base import Model  # isort:skip
2 | 
3 | from skrl.models.torch.categorical import CategoricalMixin
4 | from skrl.models.torch.deterministic import DeterministicMixin
5 | from skrl.models.torch.gaussian import GaussianMixin
6 | from skrl.models.torch.multicategorical import MultiCategoricalMixin
7 | from skrl.models.torch.multivariate_gaussian import MultivariateGaussianMixin
8 | from skrl.models.torch.tabular import TabularMixin
9 | 
--------------------------------------------------------------------------------
/src/isaac_quad_sim2real/tasks/race/config/crazyflie/__init__.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | 
3 | from . import agents, ma_quadcopter_env
4 | 
5 | ##
6 | # Register Gym environments.
7 | ##
8 | 
9 | # code for multi-agent environment registry
10 | gym.register(
11 |     id="Isaac-MA-Quadcopter-Race-v0",
12 |     entry_point=ma_quadcopter_env.QuadcopterEnv,
13 |     disable_env_checker=True,
14 |     kwargs={
15 |         "env_cfg_entry_point": ma_quadcopter_env.QuadcopterEnvCfg,
16 |         "skrl_mappo_cfg_entry_point": f"{agents.__name__}:skrl_mappo_cfg.yaml",
17 |     },
18 | )
19 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/model_instantiators/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | from skrl.utils.model_instantiators.jax.categorical import categorical_model
4 | from skrl.utils.model_instantiators.jax.deterministic import deterministic_model
5 | from skrl.utils.model_instantiators.jax.gaussian import gaussian_model
6 | from skrl.utils.model_instantiators.jax.multicategorical import multicategorical_model
7 | 
8 | 
9 | # keep for compatibility with versions prior to 1.3.0
10 | class Shape(Enum):
11 |     """
12 |     Enum to select the shape of the model's inputs and outputs
13 |     """
14 | 
15 |     ONE = 1
16 |     STATES = 0
17 |     OBSERVATIONS = 0
18 |     ACTIONS = -1
19 |     STATES_ACTIONS = -2
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/jax.py:
--------------------------------------------------------------------------------
1 | # TODO: Delete this file in future releases
2 | 
3 | from skrl import logger  # isort: skip
4 | 
5 | logger.warning("Using `from skrl.envs.jax import ...` is deprecated and will be removed in future versions.")
6 | logger.warning("  - Import loaders using `from skrl.envs.loaders.jax import ...`")
7 | logger.warning("  - Import wrappers using `from skrl.envs.wrappers.jax import ...`")
8 | 
9 | 
10 | from skrl.envs.loaders.jax import (
11 |     load_bidexhands_env,
12 |     load_isaacgym_env_preview2,
13 |     load_isaacgym_env_preview3,
14 |     load_isaacgym_env_preview4,
15 |     load_isaaclab_env,
16 |     load_omniverse_isaacgym_env,
17 | )
18 | from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper, wrap_env
19 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/envs/torch.py:
--------------------------------------------------------------------------------
1 | # TODO: Delete this file in future releases
2 | 
3 | from skrl import logger  # isort: skip
4 | 
5 | logger.warning("Using `from skrl.envs.torch import ...` is deprecated and will be removed in future versions.")
6 | logger.warning("  - Import loaders using `from skrl.envs.loaders.torch import ...`")
7 | logger.warning("  - Import wrappers using `from skrl.envs.wrappers.torch import ...`")
8 | 
9 | 
10 | from skrl.envs.loaders.torch import (
11 |     load_bidexhands_env,
12 |     load_isaacgym_env_preview2,
13 |     load_isaacgym_env_preview3,
14 |     load_isaacgym_env_preview4,
15 |     load_isaaclab_env,
16 |     load_omniverse_isaacgym_env,
17 | )
18 | from skrl.envs.wrappers.torch import MultiAgentEnvWrapper, Wrapper, wrap_env
19 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/model_instantiators/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | from skrl.utils.model_instantiators.torch.categorical import categorical_model
4 | from skrl.utils.model_instantiators.torch.deterministic import deterministic_model
5 | from skrl.utils.model_instantiators.torch.gaussian import gaussian_model
6 | from skrl.utils.model_instantiators.torch.multicategorical import multicategorical_model
7 | from skrl.utils.model_instantiators.torch.multivariate_gaussian import multivariate_gaussian_model
8 | from skrl.utils.model_instantiators.torch.shared import shared_model
9 | 
10 | 
11 | # keep for compatibility with versions prior to 1.3.0
12 | class Shape(Enum):
13 |     """
14 |     Enum to select the shape of the model's inputs and outputs
15 |     """
16 | 
17 |     ONE = 1
18 |     STATES = 0
19 |     OBSERVATIONS = 0
20 |     ACTIONS = -1
21 |     STATES_ACTIONS = -2
--------------------------------------------------------------------------------
/config/extension.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | 
3 | # Semantic Versioning is used: https://semver.org/
4 | version = "0.1.0"
5 | 
6 | # Description
7 | category = "isaac_quad_sim2real"
8 | readme = "README.md"
9 | 
10 | title = "ISAAC_QUAD_SIM2REAL"
11 | author = "Lorenzo Bianchi"
12 | maintainer = "Lorenzo Bianchi"
13 | description = "IsaacLab extension to develop RL environments for drones"
14 | repository = "https://github.com/Jirl-upenn/isaac_quad_sim2real"
15 | keywords = ["extension", "template", "isaaclab"]
16 | 
17 | [dependencies]
18 | "omni.isaac.lab" = {}
19 | "omni.isaac.lab_assets" = {}
20 | "omni.isaac.lab_mimic" = {}
21 | "omni.isaac.lab_rl" = {}
22 | "omni.isaac.lab_tasks" = {}
23 | # NOTE: Add additional dependencies here
24 | 
25 | [[python.module]]
26 | name = "isaac_quad_sim2real"
27 | 
28 | [isaaclab_settings]
29 | # TODO: Uncomment and list any apt dependencies here.
30 | # If none, leave it commented out.
31 | # apt_deps = ["example_package"]
32 | # TODO: Uncomment and provide path to a ros_ws
33 | # with rosdeps to be installed. If none,
34 | # leave it commented out.
35 | # ros_ws = "path/from/extension_root/to/ros_ws" 36 | -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/control.py: -------------------------------------------------------------------------------- 1 | import isaacgym.torch_utils as torch_utils 2 | 3 | import torch 4 | 5 | 6 | def ik( 7 | jacobian_end_effector, current_position, current_orientation, goal_position, goal_orientation, damping_factor=0.05 8 | ): 9 | """ 10 | Damped Least Squares method: https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf 11 | """ 12 | 13 | # compute position and orientation error 14 | position_error = goal_position - current_position 15 | q_r = torch_utils.quat_mul(goal_orientation, torch_utils.quat_conjugate(current_orientation)) 16 | orientation_error = q_r[:, 0:3] * torch.sign(q_r[:, 3]).unsqueeze(-1) 17 | 18 | dpose = torch.cat([position_error, orientation_error], -1).unsqueeze(-1) 19 | 20 | # solve damped least squares (dO = J.T * V) 21 | transpose = torch.transpose(jacobian_end_effector, 1, 2) 22 | lmbda = torch.eye(6).to(jacobian_end_effector.device) * (damping_factor**2) 23 | return transpose @ torch.inverse(jacobian_end_effector @ transpose + lmbda) @ dpose 24 | 25 | 26 | def osc( 27 | jacobian_end_effector, 28 | mass_matrix, 29 | current_position, 30 | current_orientation, 31 | goal_position, 32 | goal_orientation, 33 | current_dof_velocities, 34 | kp=5, 35 | kv=2, 36 | ): 37 | """ 38 | https://studywolf.wordpress.com/2013/09/17/robot-control-4-operation-space-control/ 39 | """ 40 | 41 | mass_matrix_end_effector = torch.inverse( 42 | jacobian_end_effector @ torch.inverse(mass_matrix) @ torch.transpose(jacobian_end_effector, 1, 2) 43 | ) 44 | 45 | # compute position and orientation error 46 | position_error = kp * (goal_position - current_position) 47 | q_r = torch_utils.quat_mul(goal_orientation, torch_utils.quat_conjugate(current_orientation)) 48 | orientation_error = q_r[:, 0:3] * torch.sign(q_r[:, 3]).unsqueeze(-1) 49 | 50 | dpose = torch.cat([position_error, orientation_error], -1) 51 | 52 | return ( 53 | torch.transpose(jacobian_end_effector, 1, 2) @ mass_matrix_end_effector @ (kp * dpose).unsqueeze(-1) 54 | - kv * mass_matrix @ current_dof_velocities 55 | ) 56 | -------------------------------------------------------------------------------- /src/third_parties/skrl/resources/noises/torch/gaussian.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch.distributions import Normal 5 | 6 | from skrl.resources.noises.torch import Noise 7 | 8 | 9 | # speed up distribution construction by disabling checking 10 | Normal.set_default_validate_args(False) 11 | 12 | 13 | class GaussianNoise(Noise): 14 | def __init__(self, mean: float, std: float, device: Optional[Union[str, torch.device]] = None) -> None: 15 | """Class representing a Gaussian noise 16 | 17 | :param mean: Mean of the normal distribution 18 | :type mean: float 19 | :param std: Standard deviation of the normal distribution 20 | :type std: float 21 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 
22 |             If None, the device will be either ``"cuda"`` if available or ``"cpu"``
23 |         :type device: str or torch.device, optional
24 | 
25 |         Example::
26 | 
27 |             >>> noise = GaussianNoise(mean=0, std=1)
28 |         """
29 |         super().__init__(device)
30 | 
31 |         self.distribution = Normal(
32 |             loc=torch.tensor(mean, device=self.device, dtype=torch.float32),
33 |             scale=torch.tensor(std, device=self.device, dtype=torch.float32),
34 |         )
35 | 
36 |     def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
37 |         """Sample a Gaussian noise
38 | 
39 |         :param size: Shape of the sampled tensor
40 |         :type size: tuple or list of int, or torch.Size
41 | 
42 |         :return: Sampled noise
43 |         :rtype: torch.Tensor
44 | 
45 |         Example::
46 | 
47 |             >>> noise.sample((3, 2))
48 |             tensor([[-0.4901,  1.3357],
49 |                     [-1.2141,  0.3323],
50 |                     [-0.0889, -1.1651]], device='cuda:0')
51 | 
52 |             >>> x = torch.rand(3, 2, device="cuda:0")
53 |             >>> noise.sample(x.shape)
54 |             tensor([[0.5398, 1.2009],
55 |                     [0.0307, 1.3065],
56 |                     [0.2082, 0.6116]], device='cuda:0')
57 |         """
58 |         return self.distribution.sample(size)
59 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/noises/torch/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 | 
3 | import torch
4 | 
5 | from skrl import config
6 | 
7 | 
8 | class Noise:
9 |     def __init__(self, device: Optional[Union[str, torch.device]] = None) -> None:
10 |         """Base class representing a noise
11 | 
12 |         :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
13 |             If None, the device will be either ``"cuda"`` if available or ``"cpu"``
14 |         :type device: str or torch.device, optional
15 | 
16 |         Custom noises should override the ``sample`` method::
17 | 
18 |             import torch
19 |             from skrl.resources.noises.torch import Noise
20 | 
21 |             class CustomNoise(Noise):
22 |                 def __init__(self, device=None):
23 |                     super().__init__(device)
24 | 
25 |                 def sample(self, size):
26 |                     return torch.rand(size, device=self.device)
27 |         """
28 |         self.device = config.torch.parse_device(device)
29 | 
30 |     def sample_like(self, tensor: torch.Tensor) -> torch.Tensor:
31 |         """Sample a noise with the same size (shape) as the input tensor
32 | 
33 |         This method will call the sampling method as follows ``.sample(tensor.shape)``
34 | 
35 |         :param tensor: Input tensor used to determine output tensor size (shape)
36 |         :type tensor: torch.Tensor
37 | 
38 |         :return: Sampled noise
39 |         :rtype: torch.Tensor
40 | 
41 |         Example::
42 | 
43 |             >>> x = torch.rand(3, 2, device="cuda:0")
44 |             >>> noise.sample_like(x)
45 |             tensor([[-0.0423, -0.1325],
46 |                     [-0.0639, -0.0957],
47 |                     [-0.1367,  0.1031]], device='cuda:0')
48 |         """
49 |         return self.sample(tensor.shape)
50 | 
51 |     def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor:
52 |         """Noise sampling method to be implemented by the inheriting classes
53 | 
54 |         :param size: Shape of the sampled tensor
55 |         :type size: tuple or list of int, or torch.Size
56 | 
57 |         :raises NotImplementedError: The method is not implemented by the inheriting classes
58 | 
59 |         :return: Sampled noise
60 |         :rtype: torch.Tensor
61 |         """
62 |         raise NotImplementedError("The sampling method (.sample()) is not implemented")
63 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/utils/huggingface.py:
--------------------------------------------------------------------------------
1 | from skrl import __version__, logger
2 | 
3 | 
4 | def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> str:
5 |     """Download a model from Hugging Face Hub
6 | 
7 |     :param repo_id: Hugging Face user or organization name and a repo name separated by a ``/``
8 |     :type repo_id: str
9 |     :param filename: The name of the model file in the repo (default: ``"agent.pt"``)
10 |     :type filename: str, optional
11 | 
12 |     :raises ImportError: The Hugging Face Hub package (huggingface-hub) is not installed
13 |     :raises huggingface_hub.utils._errors.HfHubHTTPError: Any HTTP error raised in Hugging Face Hub
14 | 
15 |     :return: Local path of file or if networking is off, last version of file cached on disk
16 |     :rtype: str
17 | 
18 |     Example::
19 | 
20 |         # download trained agent from the skrl organization (https://huggingface.co/skrl)
21 |         >>> from skrl.utils.huggingface import download_model_from_huggingface
22 |         >>> download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
23 |         '/home/user/.cache/huggingface/hub/models--skrl--OmniIsaacGymEnvs-Cartpole-PPO/snapshots/892e629903de6bf3ef102ae760406a5dd0f6f873/agent.pt'
24 | 
25 |         # download model (e.g. "policy.pth") from another user/organization (e.g. "org/ddpg-Pendulum-v1")
26 |         >>> from skrl.utils.huggingface import download_model_from_huggingface
27 |         >>> download_model_from_huggingface("org/ddpg-Pendulum-v1", "policy.pth")
28 |         '/home/user/.cache/huggingface/hub/models--org--ddpg-Pendulum-v1/snapshots/b44ee96f93ff2e296156b002a2ca4646e197ba32/policy.pth'
29 |     """
30 |     logger.info(f"Downloading model from Hugging Face Hub: {repo_id}/{filename}")
31 |     try:
32 |         import huggingface_hub
33 |     except ImportError:
34 |         logger.error("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
35 |         huggingface_hub = None
36 | 
37 |     if huggingface_hub is None:
38 |         raise ImportError("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
39 | 
40 |     # download and cache the model from Hugging Face Hub
41 |     downloaded_model_file = huggingface_hub.hf_hub_download(
42 |         repo_id=repo_id, filename=filename, library_name="skrl", library_version=__version__
43 |     )
44 | 
45 |     return downloaded_model_file
46 | 
--------------------------------------------------------------------------------
/src/isaac_quad_sim2real/tasks/race/config/crazyflie/agents/rsl_rl_ppo_cfg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, The Isaac Lab Project Developers.
2 | # All rights reserved.
3 | #
4 | # SPDX-License-Identifier: BSD-3-Clause
5 | 
6 | from isaaclab.utils import configclass
7 | 
8 | from .rl_cfg import RslRlOnPolicyRunnerCfg, RslRlPpoActorCriticCfg, RslRlPpoAlgorithmCfg, RslRlPpoActorCriticRecurrentCfg
9 | 
10 | 
11 | @configclass
12 | class QuadcopterPPORunnerCfg(RslRlOnPolicyRunnerCfg):
13 |     num_steps_per_env = 24
14 |     max_iterations = 200
15 |     save_interval = 50
16 |     experiment_name = "quadcopter_direct"
17 |     empirical_normalization = False
18 |     wandb_project = "single-agent-sim2real"
19 |     wandb_entity = "vineetp-university-of-pennsylvania"
20 | 
21 |     # ego drone: recurrent FiLM architecture
22 |     # policy = RslRlPpoActorCriticRecurrentCfg(
23 |     #     class_name="ActorCriticRecurrentFiLM",
24 |     #     init_noise_std=1.0,
25 |     #     actor_hidden_dims=[128, 128],
26 |     #     film_hidden_dims=[3, 3],
27 |     #     cond_dim=2,
28 |     #     critic_hidden_dims=[512, 512],
29 |     #     activation="elu",
30 |     #     rnn_type="lstm",
31 |     #     rnn_hidden_size=256,
32 |     #     rnn_num_layers=2,
33 |     #     min_std=0.2,
34 |     # )
35 | 
36 |     # ego drone: non-recurrent FiLM architecture (for compatibility with saved checkpoint)
37 |     policy = RslRlPpoActorCriticCfg(
38 |         class_name="ActorCritic",
39 |         init_noise_std=1.0,
40 |         actor_hidden_dims=[128, 128],
41 |         film_hidden_dims=[3, 3],
42 |         cond_dim=2,
43 |         critic_hidden_dims=[512, 512],
44 |         activation="elu",
45 |         min_std=0.2,
46 |     )
47 | 
48 |     # adversary drone: non-recurrent FiLM
49 |     adversary_policy = RslRlPpoActorCriticCfg(
50 |         class_name="ActorCritic",
51 |         init_noise_std=1.0,
52 |         actor_hidden_dims=[128, 128],
53 |         film_hidden_dims=[3, 3],
54 |         cond_dim=2,
55 |         critic_hidden_dims=[512, 512],
56 |         activation="elu",
57 |         min_std=0.2,
58 |     )
59 | 
60 |     algorithm = RslRlPpoAlgorithmCfg(
61 |         value_loss_coef=1.0,
62 |         use_clipped_value_loss=True,
63 |         clip_param=0.2,
64 |         entropy_coef=0.0,
65 |         num_learning_epochs=5,
66 |         num_mini_batches=4,
67 |         learning_rate=5.0e-4,
68 |         schedule="adaptive",
69 |         gamma=0.99,
70 |         lam=0.95,
71 |         desired_kl=0.01,
72 |         max_grad_norm=1.0,
73 |     )
74 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=64.0"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "isaac_quad_sim2real"
7 | version = "0.1.0"
8 | keywords = ["reinforcement-learning", "isaac", "drone", "rl-pytorch"]
9 | maintainers = [
10 |     { name="Lorenzo Bianchi", email="lorebia@seas.upenn.edu" },
11 | ]
12 | authors = [
13 |     { name="Lorenzo Bianchi", email="lorebia@seas.upenn.edu" },
14 | ]
15 | description = ""
16 | readme = { file = "README.md", content-type = "text/markdown"}
17 | license = { text = "BSD-3-Clause" }
18 | 
19 | requires-python = ">=3.7"
20 | classifiers = [
21 |     "Programming Language :: Python :: 3",
22 |     "Operating System :: OS Independent",
23 | ]
24 | dependencies = []
25 | 
26 | [project.urls]
27 | Homepage = "https://github.com/Jirl-upenn/isaac_quad_sim2real"
28 | Issues = "https://github.com/Jirl-upenn/isaac_quad_sim2real/issues"
29 | 
30 | [tool.setuptools.packages.find]
31 | where = ["."]
32 | include = ["isaac_quad_sim2real*", "third_parties.rsl_rl*", "third_parties.skrl*"]
33 | 
34 | [tool.setuptools.package-data]
35 | "isaac_quad_sim2real" = ["config/*"]
36 | 
37 | [tool.isort]
38 | 
39 | py_version = 37
40 | line_length = 120
41 | group_by_package = true
42 | 
43 | # Files to skip
44 | skip_glob = [".vscode/*"]
45 | 
46 | # Order of imports
47 | sections = [
48 |     "FUTURE",
49 |     "STDLIB",
50 |     "THIRDPARTY",
51 |     "FIRSTPARTY",
52 |     "LOCALFOLDER",
53 | ]
54 | 
55 | # Extra standard libraries considered as part of python (permissive licenses)
56 | extra_standard_library = [
57 |     "numpy",
58 |     "torch",
59 |     "tensordict",
60 |     "warp",
61 |     "typing_extensions",
62 |     "git",
63 | ]
64 | # Imports from this repository
65 | known_first_party = "isaac_quad_sim2real"
66 | 
67 | [tool.pyright]
68 | 
69 | include = ["isaac_quad_sim2real"]
70 | 
71 | typeCheckingMode = "basic"
72 | pythonVersion = "3.10"
73 | pythonPlatform = "Linux"
74 | enableTypeIgnoreComments = true
75 | 
76 | # This is required as the CI pre-commit does not download the module (i.e. numpy, torch, prettytable)
77 | # Therefore, we have to ignore missing imports
78 | reportMissingImports = "none"
79 | # This is required to ignore for type checks of modules with stubs missing.
80 | reportMissingModuleSource = "none"  # -> most common: prettytable in mdp managers
81 | 
82 | reportGeneralTypeIssues = "none"  # -> raises 218 errors (usage of literal MISSING in dataclasses)
83 | reportOptionalMemberAccess = "warning"  # -> raises 8 errors
84 | reportPrivateUsage = "warning"
85 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/noises/jax/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 | 
3 | import jax
4 | import numpy as np
5 | 
6 | from skrl import config
7 | 
8 | 
9 | class Noise:
10 |     def __init__(self, device: Optional[Union[str, jax.Device]] = None) -> None:
11 |         """Base class representing a noise
12 | 
13 |         :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
14 |             If None, the device will be either ``"cuda"`` if available or ``"cpu"``
15 |         :type device: str or jax.Device, optional
16 | 
17 |         Custom noises should override the ``sample`` method::
18 | 
19 |             import jax
20 |             from skrl.resources.noises.jax import Noise
21 | 
22 |             class CustomNoise(Noise):
23 |                 def __init__(self, device=None):
24 |                     super().__init__(device)
25 | 
26 |                 def sample(self, size):
27 |                     return jax.random.uniform(jax.random.PRNGKey(0), size)
28 |         """
29 |         self._jax = config.jax.backend == "jax"
30 | 
31 |         self.device = config.jax.parse_device(device)
32 | 
33 |     def sample_like(self, tensor: Union[np.ndarray, jax.Array]) -> Union[np.ndarray, jax.Array]:
34 |         """Sample a noise with the same size (shape) as the input tensor
35 | 
36 |         This method will call the sampling method as follows ``.sample(tensor.shape)``
37 | 
38 |         :param tensor: Input tensor used to determine output tensor size (shape)
39 |         :type tensor: np.ndarray or jax.Array
40 | 
41 |         :return: Sampled noise
42 |         :rtype: np.ndarray or jax.Array
43 | 
44 |         Example::
45 | 
46 |             >>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2))
47 |             >>> noise.sample_like(x)
48 |             Array([[0.57450044, 0.09968603],
49 |                    [0.7419659 , 0.8941783 ],
50 |                    [0.59656656, 0.45325184]], dtype=float32)
51 |         """
52 |         return self.sample(tensor.shape)
53 | 
54 |     def sample(self, size: Tuple[int]) -> Union[np.ndarray, jax.Array]:
55 |         """Noise sampling method to be implemented by the inheriting classes
56 | 
57 |         :param size: Shape of the sampled tensor
58 |         :type size: tuple or list of int
59 | 
60 |         :raises NotImplementedError: The method is not implemented by the inheriting classes
61 | 
62 |         :return: Sampled noise
63 |         :rtype: np.ndarray or jax.Array
64 |         """
65 |         raise NotImplementedError("The sampling method (.sample()) is not implemented")
66 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/schedulers/jax/kl_adaptive.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 | 
3 | import numpy as np
4 | import optax
5 | 
6 | 
7 | def KLAdaptiveLR(
8 |     kl_threshold: float = 0.008,
9 |     min_lr: float = 1e-6,
10 |     max_lr: float = 1e-2,
11 |     kl_factor: float = 2,
12 |     lr_factor: float = 1.5,
13 | ) -> optax.Schedule:
14 |     """Adaptive KL scheduler
15 | 
16 |     Adjusts the learning rate according to the KL divergence.
17 |     The implementation is adapted from the *rl_games* library.
18 | 
19 |     .. note::
20 | 
21 |         This scheduler is only available for PPO at the moment.
22 |         Applying it to other agents will not change the learning rate
23 | 
24 |     Example::
25 | 
26 |         >>> scheduler = KLAdaptiveLR(kl_threshold=0.01)
27 |         >>> for epoch in range(100):
28 |         >>>     # ...
29 |         >>>     kl_divergence = ...
30 |         >>>     new_lr = scheduler(timestep, lr, kl_divergence)
31 | 
32 |     :param kl_threshold: Threshold for KL divergence (default: ``0.008``)
33 |     :type kl_threshold: float, optional
34 |     :param min_lr: Lower bound for learning rate (default: ``1e-6``)
35 |     :type min_lr: float, optional
36 |     :param max_lr: Upper bound for learning rate (default: ``1e-2``)
37 |     :type max_lr: float, optional
38 |     :param kl_factor: The number used to modify the KL divergence threshold (default: ``2``)
39 |     :type kl_factor: float, optional
40 |     :param lr_factor: The number used to modify the learning rate (default: ``1.5``)
41 |     :type lr_factor: float, optional
42 | 
43 |     :return: A function that maps step counts, current learning rate and KL divergence to the new learning rate value.
44 |         If no learning rate is specified, 1.0 will be returned to mimic the Optax's scheduler behaviors.
45 |         If the learning rate is specified but the KL divergence is not 0, the specified learning rate is returned.
46 |     :rtype: optax.Schedule
47 |     """
48 | 
49 |     def schedule(count: int, lr: Optional[float] = None, kl: Optional[Union[np.ndarray, float]] = None) -> float:
50 |         if lr is None:
51 |             return 1.0
52 |         if kl is not None:
53 |             if kl > kl_threshold * kl_factor:
54 |                 lr = max(lr / lr_factor, min_lr)
55 |             elif kl < kl_threshold / kl_factor:
56 |                 lr = min(lr * lr_factor, max_lr)
57 |         return lr
58 | 
59 |     return schedule
60 | 
61 | 
62 | # Alias to maintain naming compatibility with Optax schedulers
63 | # https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html
64 | kl_adaptive = KLAdaptiveLR
65 | 
--------------------------------------------------------------------------------
/src/third_parties/skrl/resources/noises/jax/gaussian.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 | 
3 | from functools import partial
4 | 
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 | 
9 | from skrl import config
10 | from skrl.resources.noises.jax import Noise
11 | 
12 | 
13 | # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function
14 | @partial(jax.jit, static_argnames=("shape"))
15 | def _sample(mean, std, key, iterator, shape):
16 |     subkey = jax.random.fold_in(key, iterator)
17 |     return jax.random.normal(subkey, shape) * std + mean
18 | 
19 | 
20 | class GaussianNoise(Noise):
21 |     def __init__(self, mean: float, std: float, device: Optional[Union[str, jax.Device]] = None) -> None:
22 |         """Class representing a Gaussian noise
23 | 
24 |         :param mean: Mean of the normal distribution
25 |         :type mean: float
26 |         :param std: Standard deviation of the normal distribution
27 |         :type std: float
28 |         :param device: Device on which a tensor/array is or will be allocated (default: ``None``).
29 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 30 | :type device: str or jax.Device, optional 31 | 32 | Example:: 33 | 34 | >>> noise = GaussianNoise(mean=0, std=1) 35 | """ 36 | super().__init__(device) 37 | 38 | if self._jax: 39 | self._i = 0 40 | self._key = config.jax.key 41 | self.mean = jnp.array(mean) 42 | self.std = jnp.array(std) 43 | else: 44 | self.mean = np.array(mean) 45 | self.std = np.array(std) 46 | 47 | def sample(self, size: Tuple[int]) -> Union[np.ndarray, jax.Array]: 48 | """Sample a Gaussian noise 49 | 50 | :param size: Shape of the sampled tensor 51 | :type size: tuple or list of int 52 | 53 | :return: Sampled noise 54 | :rtype: np.ndarray or jax.Array 55 | 56 | Example:: 57 | 58 | >>> noise.sample((3, 2)) 59 | Array([[ 0.01878439, -0.12833427], 60 | [ 0.06494182, 0.12490594], 61 | [ 0.024447 , -0.01174496]], dtype=float32) 62 | 63 | >>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2)) 64 | >>> noise.sample(x.shape) 65 | Array([[ 0.17988093, -1.2289404 ], 66 | [ 0.6218886 , 1.1961104 ], 67 | [ 0.23410667, -0.11247082]], dtype=float32) 68 | """ 69 | if self._jax: 70 | self._i += 1 71 | return _sample(self.mean, self.std, self._key, self._i, size) 72 | return np.random.normal(self.mean, self.std, size) 73 | -------------------------------------------------------------------------------- /src/isaac_quad_sim2real/tasks/race/config/crazyflie/agents/skrl_mappo_cfg.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | 4 | # Models are instantiated using skrl's model instantiator utility 5 | # https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html 6 | models: 7 | separate: True 8 | policy: # see gaussian_model parameters 9 | class: GaussianMixin 10 | clip_actions: False 11 | clip_log_std: True 12 | min_log_std: -20.0 13 | max_log_std: 2.0 14 | initial_log_std: 0.0 15 | network: 16 | - name: net 17 | input: STATES 18 | layers: [512, 512, 256, 128] 19 | activations: elu 20 | output: ACTIONS 21 | value: # see deterministic_model parameters 22 | class: DeterministicMixin 23 | clip_actions: False 24 | network: 25 | - name: net 26 | input: STATES 27 | layers: [512, 512, 256, 256, 128, 128] 28 | activations: elu 29 | output: ONE 30 | 31 | 32 | # Rollout memory 33 | # https://skrl.readthedocs.io/en/latest/api/memories/random.html 34 | memory: 35 | class: RandomMemory 36 | memory_size: -1 # automatically determined (same as agent:rollouts) 37 | 38 | 39 | # MAPPO agent configuration (field names are from MAPPO_DEFAULT_CONFIG) 40 | # https://skrl.readthedocs.io/en/latest/api/multi_agents/mappo.html 41 | agent: 42 | class: MAPPO 43 | rollouts: 16 44 | learning_epochs: 5 45 | mini_batches: 4 46 | discount_factor: 0.99 47 | lambda: 0.95 48 | learning_rate: 1.0e-04 #prev 5.0e-04 49 | learning_rate_scheduler: null #prev KLAdaptiveLR 50 | learning_rate_scheduler_kwargs: 51 | kl_threshold: 0.016 52 | state_preprocessor: null #prev RunningStandardScaler 53 | state_preprocessor_kwargs: null 54 | shared_state_preprocessor: RunningStandardScaler 55 | shared_state_preprocessor_kwargs: null 56 | value_preprocessor: null #prev RunningStandardScaler 57 | value_preprocessor_kwargs: null 58 | random_timesteps: 0 59 | learning_starts: 0 60 | grad_norm_clip: 0.5 61 | ratio_clip: 0.2 62 | value_clip: 0.2 63 | clip_predicted_values: True 64 | entropy_loss_scale: 0.0 65 | value_loss_scale: 1.0 66 | kl_threshold: 0.0 67 | rewards_shaper_scale: 1.0 68 | time_limit_bootstrap: False 69 | # logging and 
checkpoint 70 | experiment: 71 | directory: "mappo_race" 72 | experiment_name: "" 73 | write_interval: auto 74 | checkpoint_interval: auto 75 | wandb: true 76 | wandb_kwargs: 77 | project: "isaac_quad_sim2real" 78 | entity: null # set to your wandb username/team 79 | tags: ["mappo", "race", "multi-agent", "quadcopter"] 80 | group: "mappo_training" 81 | notes: "Multi-agent quadcopter racing with MAPPO" 82 | 83 | 84 | # Sequential trainer 85 | # https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html 86 | trainer: 87 | class: SequentialTrainer 88 | timesteps: 36000 89 | environment_info: log 90 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/omniverse_isaacgym_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple 2 | 3 | import torch 4 | 5 | from skrl.envs.wrappers.torch.base import Wrapper 6 | from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space 7 | 8 | 9 | class OmniverseIsaacGymWrapper(Wrapper): 10 | def __init__(self, env: Any) -> None: 11 | """Omniverse Isaac Gym environment wrapper 12 | 13 | :param env: The environment to wrap 14 | :type env: Any supported Omniverse Isaac Gym environment 15 | """ 16 | super().__init__(env) 17 | 18 | self._reset_once = True 19 | self._observations = None 20 | self._info = {} 21 | 22 | def run(self, trainer: Optional["omni.isaac.gym.vec_env.vec_env_mt.TrainerMT"] = None) -> None: 23 | """Run the simulation in the main thread 24 | 25 | This method is valid only for the Omniverse Isaac Gym multi-threaded environments 26 | 27 | :param trainer: Trainer which should implement a ``run`` method that initiates the RL loop on a new thread 28 | :type trainer: omni.isaac.gym.vec_env.vec_env_mt.TrainerMT, optional 29 | """ 30 | self._env.run(trainer) 31 | 32 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 33 | """Perform a step in the environment 34 | 35 | :param actions: The actions to perform 36 | :type actions: torch.Tensor 37 | 38 | :return: Observation, reward, terminated, truncated, info 39 | :rtype: tuple of torch.Tensor and any other info 40 | """ 41 | observations, reward, terminated, self._info = self._env.step( 42 | unflatten_tensorized_space(self.action_space, actions) 43 | ) 44 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) 45 | truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) 46 | return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info 47 | 48 | def reset(self) -> Tuple[torch.Tensor, Any]: 49 | """Reset the environment 50 | 51 | :return: Observation, info 52 | :rtype: torch.Tensor and any other info 53 | """ 54 | if self._reset_once: 55 | observations = self._env.reset() 56 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) 57 | self._reset_once = False 58 | return self._observations, self._info 59 | 60 | def render(self, *args, **kwargs) -> None: 61 | """Render the environment""" 62 | return None 63 | 64 | def close(self) -> None: 65 | """Close the environment""" 66 | self._env.close() 67 | -------------------------------------------------------------------------------- /src/third_parties/skrl/resources/noises/torch/ornstein_uhlenbeck.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch.distributions import Normal 5 | 6 | from skrl.resources.noises.torch import Noise 7 | 8 | 9 | # speed up distribution construction by disabling checking 10 | Normal.set_default_validate_args(False) 11 | 12 | 13 | class OrnsteinUhlenbeckNoise(Noise): 14 | def __init__( 15 | self, 16 | theta: float, 17 | sigma: float, 18 | base_scale: float, 19 | mean: float = 0, 20 | std: float = 1, 21 | device: Optional[Union[str, torch.device]] = None, 22 | ) -> None: 23 | """Class representing an Ornstein-Uhlenbeck noise 24 | 25 | :param theta: Factor to apply to current internal state 26 | :type theta: float 27 | :param sigma: Factor to apply to the normal distribution 28 | :type sigma: float 29 | :param base_scale: Factor to apply to returned noise 30 | :type base_scale: float 31 | :param mean: Mean of the normal distribution (default: ``0.0``) 32 | :type mean: float, optional 33 | :param std: Standard deviation of the normal distribution (default: ``1.0``) 34 | :type std: float, optional 35 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 36 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 37 | :type device: str or torch.device, optional 38 | 39 | Example:: 40 | 41 | >>> noise = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=0.5) 42 | """ 43 | super().__init__(device) 44 | 45 | self.state = 0 46 | self.theta = theta 47 | self.sigma = sigma 48 | self.base_scale = base_scale 49 | 50 | self.distribution = Normal( 51 | loc=torch.tensor(mean, device=self.device, dtype=torch.float32), 52 | scale=torch.tensor(std, device=self.device, dtype=torch.float32), 53 | ) 54 | 55 | def sample(self, size: Union[Tuple[int], torch.Size]) -> torch.Tensor: 56 | """Sample an Ornstein-Uhlenbeck noise 57 | 58 | :param size: Shape of the sampled tensor 59 | :type size: tuple or list of int, or torch.Size 60 | 61 | :return: Sampled noise 62 | :rtype: torch.Tensor 63 | 64 | Example:: 65 | 66 | >>> noise.sample((3, 2)) 67 | tensor([[-0.0452, 0.0162], 68 | [ 0.0649, -0.0708], 69 | [-0.0211, 0.0066]], device='cuda:0') 70 | 71 | >>> x = torch.rand(3, 2, device="cuda:0") 72 | >>> noise.sample(x.shape) 73 | tensor([[-0.0540, 0.0461], 74 | [ 0.1117, -0.1157], 75 | [-0.0074, 0.0420]], device='cuda:0') 76 | """ 77 | if hasattr(self.state, "shape") and self.state.shape != torch.Size(size): 78 | self.state = 0 79 | self.state += -self.state * self.theta + self.sigma * self.distribution.sample(size) 80 | 81 | return self.base_scale * self.state 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Agile Flight Emerges from Multi-Agent Competitive Racing 2 | 3 | [![Agile Flight Emerges from Multi-Agent Competitive Racing](AgileFlight_CoverImage.png)](https://youtu.be/AIUfCbEJX6E) 4 | 5 | This repository contains the code for training and evaluating our multi-agent quadcopter racing policy from the paper [Agile Flight Emerges from Multi-Agent Competitive Racing](https://arxiv.org/abs/2512.11781). 6 | To train the Dense Single (DS) and Sparse Single (SS) policies, please navigate to the [AgileFlight_SingleAgent branch](https://github.com/Jirl-upenn/AgileFlight_MultiAgent/tree/AgileFlight_SingleAgent). 
7 | ## Paper and Video 8 | 9 | Paper: [arXiv](https://arxiv.org/abs/2512.11781) 10 | 11 | Video: [YouTube](https://youtu.be/AIUfCbEJX6E) 12 | 13 | ```bibtex 14 | @misc{pasumarti2025agileflightemergesmultiagent, 15 | title={Agile Flight Emerges from Multi-Agent Competitive Racing}, 16 | author={Vineet Pasumarti and Lorenzo Bianchi and Antonio Loquercio}, 17 | year={2025}, 18 | eprint={2512.11781}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.RO}, 21 | url={https://arxiv.org/abs/2512.11781}, 22 | } 23 | ``` 24 | 25 | ## Setup 26 | 27 | ### Prerequisites 28 | 29 | - GPU with CUDA support 30 | - NVIDIA Isaac Sim v4.5.0 31 | - NVIDIA Isaac Lab v2.1.0 32 | - Ubuntu 20.04 / 22.04 (recommended) 33 | 34 | ### Installation 35 | 36 | 1. Clone the repository: 37 | 38 | ```bash 39 | # It is critical that the project repo and the Isaac Lab directory are at the same level 40 | git clone -b AgileFlight_MultiAgent https://github.com/Jirl-upenn/isaac_quad_sim2real.git 41 | cd isaac_quad_sim2real 42 | ``` 43 | 44 | 2. Create and activate your Isaac Lab v2.1.0 conda environment ([Isaac Lab installation guide](https://isaac-sim.github.io/IsaacLab/main/source/setup/installation/index.html)) 45 | 46 | 3. Install the package and dependencies: 47 | 48 | ```bash 49 | # Install the main package 50 | pip install -e . 51 | ``` 52 | 53 | ## Training Examples 54 | 55 | The main training script uses a modified `mappo.py` from the [skrl](https://skrl.readthedocs.io/) library. 56 | 57 | ```bash 58 | # Train our policy on the Complex Track with walls 59 | python scripts/skrl/ma_train_race.py \ 60 | --task Isaac-MA-Quadcopter-Race-v0 \ 61 | --num_envs 10240 \ 62 | --algorithm MAPPO \ 63 | --max_iterations 10000 \ 64 | --headless \ 65 | --use_wall \ 66 | --track complex 67 | ``` 68 | 69 | ```bash 70 | # Train our policy on the Lemniscate Track 71 | python scripts/skrl/ma_train_race.py \ 72 | --task Isaac-MA-Quadcopter-Race-v0 \ 73 | --num_envs 10240 \ 74 | --algorithm MAPPO \ 75 | --max_iterations 10000 \ 76 | --headless \ 77 | --track lemniscate 78 | ``` 79 | 80 | ## Evaluation 81 | 82 | To evaluate a trained policy: 83 | 84 | ```bash 85 | # Evaluate our policy on the Complex Track with walls 86 | python scripts/skrl/ma_play_race.py \ 87 | --task Isaac-MA-Quadcopter-Race-v0 \ 88 | --num_envs 1 \ 89 | --algorithm MAPPO \ 90 | --track complex \ 91 | --use_wall \ 92 | --checkpoint path/to/checkpoint.pt \ 93 | --video \ 94 | --video_length 1000 \ 95 | --headless 96 | ``` -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/brax_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import gymnasium 4 | 5 | import torch 6 | 7 | from skrl import logger 8 | from skrl.envs.wrappers.torch.base import Wrapper 9 | from skrl.utils.spaces.torch import ( 10 | convert_gym_space, 11 | flatten_tensorized_space, 12 | tensorize_space, 13 | unflatten_tensorized_space, 14 | ) 15 | 16 | 17 | class BraxWrapper(Wrapper): 18 | def __init__(self, env: Any) -> None: 19 | """Brax environment wrapper 20 | 21 | :param env: The environment to wrap 22 | :type env: Any supported Brax environment 23 | """ 24 | super().__init__(env) 25 | 26 | import brax.envs.wrappers.gym 27 | import brax.envs.wrappers.torch 28 | 29 | env = brax.envs.wrappers.gym.VectorGymWrapper(env) 30 | env = brax.envs.wrappers.torch.TorchWrapper(env, device=self.device) 31 | self._env = env 32 | self._unwrapped = env.unwrapped 33 | 34 | 
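# --- Hedged usage sketch, not part of the original skrl source. ---
# skrl wrappers such as this one are normally obtained through the `wrap_env`
# helper rather than constructed directly; the Brax environment name and batch
# size below are illustrative assumptions:
#
#     import brax.envs
#     from skrl.envs.wrappers.torch import wrap_env
#
#     env = brax.envs.create("inverted_pendulum", batch_size=2048)
#     env = wrap_env(env, wrapper="brax")  # returns this BraxWrapper
#     observation, info = env.reset()     # flattened torch tensors on env.device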
@property 35 | def observation_space(self) -> gymnasium.Space: 36 | """Observation space""" 37 | return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True) 38 | 39 | @property 40 | def action_space(self) -> gymnasium.Space: 41 | """Action space""" 42 | return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True) 43 | 44 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 45 | """Perform a step in the environment 46 | 47 | :param actions: The actions to perform 48 | :type actions: torch.Tensor 49 | 50 | :return: Observation, reward, terminated, truncated, info 51 | :rtype: tuple of torch.Tensor and any other info 52 | """ 53 | observation, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) 54 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation)) 55 | truncated = torch.zeros_like(terminated) 56 | return observation, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info 57 | 58 | def reset(self) -> Tuple[torch.Tensor, Any]: 59 | """Reset the environment 60 | 61 | :return: Observation, info 62 | :rtype: torch.Tensor and any other info 63 | """ 64 | observation = self._env.reset() 65 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation)) 66 | return observation, {} 67 | 68 | def render(self, *args, **kwargs) -> None: 69 | """Render the environment""" 70 | frame = self._env.render(mode="rgb_array") 71 | 72 | # render the frame using OpenCV 73 | try: 74 | import cv2 75 | 76 | cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 77 | cv2.waitKey(1) 78 | except ImportError as e: 79 | logger.warning(f"Unable to import opencv-python: {e}. Frame will not be rendered.") 80 | return frame 81 | 82 | def close(self) -> None: 83 | """Close the environment""" 84 | # self._env.close() raises AttributeError: 'VectorGymWrapper' object has no attribute 'closed' 85 | pass 86 | -------------------------------------------------------------------------------- /src/third_parties/skrl/memories/jax/random.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import jax 4 | import numpy as np 5 | 6 | from skrl.memories.jax import Memory 7 | 8 | 9 | class RandomMemory(Memory): 10 | def __init__( 11 | self, 12 | memory_size: int, 13 | num_envs: int = 1, 14 | device: Optional[jax.Device] = None, 15 | export: bool = False, 16 | export_format: str = "pt", 17 | export_directory: str = "", 18 | replacement=True, 19 | ) -> None: 20 | """Random sampling memory 21 | 22 | Sample a batch from memory randomly 23 | 24 | :param memory_size: Maximum number of elements in the first dimension of each internal storage 25 | :type memory_size: int 26 | :param num_envs: Number of parallel environments (default: ``1``) 27 | :type num_envs: int, optional 28 | :param device: Device on which an array is or will be allocated (default: ``None``) 29 | :type device: jax.Device, optional 30 | :param export: Export the memory to a file (default: ``False``). 31 | If True, the memory will be exported when the memory is filled 32 | :type export: bool, optional 33 | :param export_format: Export format (default: ``"pt"``). 
34 | Supported formats: torch (pt), numpy (np), comma separated values (csv) 35 | :type export_format: str, optional 36 | :param export_directory: Directory where the memory will be exported (default: ``""``). 37 | If empty, the agent's experiment directory will be used 38 | :type export_directory: str, optional 39 | :param replacement: Flag to indicate whether the sample is with or without replacement (default: ``True``). 40 | Replacement implies that a value can be selected multiple times (the batch size is always guaranteed). 41 | Sampling without replacement will return a batch of at most the memory size if the memory holds fewer elements than the requested batch size 42 | :type replacement: bool, optional 43 | 44 | :raises ValueError: The export format is not supported 45 | """ 46 | super().__init__(memory_size, num_envs, device, export, export_format, export_directory) 47 | 48 | self._replacement = replacement 49 | 50 | def sample(self, names: Tuple[str], batch_size: int, mini_batches: int = 1) -> List[List[jax.Array]]: 51 | """Sample a batch from memory randomly 52 | 53 | :param names: Tensor names from which to obtain the samples 54 | :type names: tuple or list of strings 55 | :param batch_size: Number of elements to sample 56 | :type batch_size: int 57 | :param mini_batches: Number of mini-batches to sample (default: ``1``) 58 | :type mini_batches: int, optional 59 | 60 | :return: Sampled data from tensors sorted according to their position in the list of names. 61 | The sampled tensors will have the following shape: (batch size, data size) 62 | :rtype: list of jax.Array list 63 | """ 64 | # generate random indexes 65 | if self._replacement: 66 | indexes = np.random.randint(0, len(self), (batch_size,)) 67 | else: 68 | indexes = np.random.permutation(len(self))[:batch_size] 69 | 70 | return self.sample_by_index(names=names, indexes=indexes, mini_batches=mini_batches) 71 | -------------------------------------------------------------------------------- /src/third_parties/skrl/resources/noises/jax/ornstein_uhlenbeck.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | from functools import partial 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | from skrl import config 10 | from skrl.resources.noises.jax import Noise 11 | 12 | 13 | # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function 14 | @partial(jax.jit, static_argnames=("shape")) 15 | def _sample(theta, sigma, state, mean, std, key, iterator, shape): 16 | subkey = jax.random.fold_in(key, iterator) 17 | return state * (1 - theta) + sigma * (jax.random.normal(subkey, shape) * std + mean)  # x_{t+1} = x_t - theta * x_t + sigma * N(mean, std), as in the NumPy branch of ``sample`` and the PyTorch implementation 18 | 19 | 20 | class OrnsteinUhlenbeckNoise(Noise): 21 | def __init__( 22 | self, 23 | theta: float, 24 | sigma: float, 25 | base_scale: float, 26 | mean: float = 0, 27 | std: float = 1, 28 | device: Optional[Union[str, jax.Device]] = None, 29 | ) -> None: 30 | """Class representing an Ornstein-Uhlenbeck noise 31 | 32 | :param theta: Factor to apply to current internal state 33 | :type theta: float 34 | :param sigma: Factor to apply to the normal distribution 35 | :type sigma: float 36 | :param base_scale: Factor to apply to returned noise 37 | :type base_scale: float 38 | :param mean: Mean of the normal distribution (default: ``0.0``) 39 | :type mean: float, optional 40 | :param std: Standard deviation of the normal distribution (default: ``1.0``) 41 | :type std: float, optional 42 | :param device: Device on which a
tensor/array is or will be allocated (default: ``None``). 43 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 44 | :type device: str or jax.Device, optional 45 | 46 | Example:: 47 | 48 | >>> noise = OrnsteinUhlenbeckNoise(theta=0.1, sigma=0.2, base_scale=0.5) 49 | """ 50 | super().__init__(device) 51 | 52 | self.state = 0 53 | self.theta = theta 54 | self.sigma = sigma 55 | self.base_scale = base_scale 56 | 57 | if self._jax: 58 | self.mean = jnp.array(mean) 59 | self.std = jnp.array(std) 60 | 61 | self._i = 0 62 | self._key = config.jax.key 63 | else: 64 | self.mean = np.array(mean) 65 | self.std = np.array(std) 66 | 67 | def sample(self, size: Tuple[int]) -> Union[np.ndarray, jax.Array]: 68 | """Sample an Ornstein-Uhlenbeck noise 69 | 70 | :param size: Shape of the sampled tensor 71 | :type size: tuple or list of int 72 | 73 | :return: Sampled noise 74 | :rtype: np.ndarray or jax.Array 75 | 76 | Example:: 77 | 78 | >>> noise.sample((3, 2)) 79 | Array([[ 0.01878439, -0.12833427], 80 | [ 0.06494182, 0.12490594], 81 | [ 0.024447 , -0.01174496]], dtype=float32) 82 | 83 | >>> x = jax.random.uniform(jax.random.PRNGKey(0), (3, 2)) 84 | >>> noise.sample(x.shape) 85 | Array([[ 0.17988093, -1.2289404 ], 86 | [ 0.6218886 , 1.1961104 ], 87 | [ 0.23410667, -0.11247082]], dtype=float32) 88 | """ 89 | if hasattr(self.state, "shape") and self.state.shape != size: 90 | self.state = 0 91 | if self._jax: 92 | self._i += 1 93 | self.state = _sample(self.theta, self.sigma, self.state, self.mean, self.std, self._key, self._i, size) 94 | else: 95 | self.state += -self.state * self.theta + self.sigma * np.random.normal(self.mean, self.std, size) 96 | return self.base_scale * self.state 97 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/jax/brax_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Union 2 | 3 | import gymnasium 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | from skrl import logger 10 | from skrl.envs.wrappers.jax.base import Wrapper 11 | from skrl.utils.spaces.jax import ( 12 | convert_gym_space, 13 | flatten_tensorized_space, 14 | tensorize_space, 15 | unflatten_tensorized_space, 16 | ) 17 | 18 | 19 | class BraxWrapper(Wrapper): 20 | def __init__(self, env: Any) -> None: 21 | """Brax environment wrapper 22 | 23 | :param env: The environment to wrap 24 | :type env: Any supported Brax environment 25 | """ 26 | super().__init__(env) 27 | 28 | import brax.envs.wrappers.gym 29 | 30 | env = brax.envs.wrappers.gym.VectorGymWrapper(env) 31 | self._env = env 32 | self._unwrapped = env.unwrapped 33 | 34 | @property 35 | def observation_space(self) -> gymnasium.Space: 36 | """Observation space""" 37 | return convert_gym_space(self._unwrapped.observation_space, squeeze_batch_dimension=True) 38 | 39 | @property 40 | def action_space(self) -> gymnasium.Space: 41 | """Action space""" 42 | return convert_gym_space(self._unwrapped.action_space, squeeze_batch_dimension=True) 43 | 44 | def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ 45 | Union[np.ndarray, jax.Array], 46 | Union[np.ndarray, jax.Array], 47 | Union[np.ndarray, jax.Array], 48 | Union[np.ndarray, jax.Array], 49 | Any, 50 | ]: 51 | """Perform a step in the environment 52 | 53 | :param actions: The actions to perform 54 | :type actions: np.ndarray or jax.Array 55 | 56 | :return: Observation, reward, terminated, truncated, info 57 | :rtype: 
tuple of np.ndarray or jax.Array and any other info 58 | """ 59 | observation, reward, terminated, info = self._env.step(unflatten_tensorized_space(self.action_space, actions)) 60 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, device=self.device)) 61 | truncated = jnp.zeros_like(terminated) 62 | if not self._jax: 63 | observation = np.asarray(jax.device_get(observation)) 64 | reward = np.asarray(jax.device_get(reward)) 65 | terminated = np.asarray(jax.device_get(terminated)) 66 | truncated = np.asarray(jax.device_get(truncated)) 67 | return observation, reward.reshape(-1, 1), terminated.reshape(-1, 1), truncated.reshape(-1, 1), info 68 | 69 | def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: 70 | """Reset the environment 71 | 72 | :return: Observation, info 73 | :rtype: np.ndarray or jax.Array and any other info 74 | """ 75 | observation = self._env.reset() 76 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, device=self.device)) 77 | if not self._jax: 78 | observation = np.asarray(jax.device_get(observation)) 79 | return observation, {} 80 | 81 | def render(self, *args, **kwargs) -> None: 82 | """Render the environment""" 83 | frame = self._env.render(mode="rgb_array") 84 | 85 | # render the frame using OpenCV 86 | try: 87 | import cv2 88 | 89 | cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 90 | cv2.waitKey(1) 91 | except ImportError as e: 92 | logger.warning(f"Unable to import opencv-python: {e}. Frame will not be rendered.") 93 | return frame 94 | 95 | def close(self) -> None: 96 | """Close the environment""" 97 | # self._env.close() raises AttributeError: 'VectorGymWrapper' object has no attribute 'closed' 98 | pass 99 | -------------------------------------------------------------------------------- /scripts/skrl/cli_args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, The Isaac Lab Project Developers. 2 | # All rights reserved. 3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from __future__ import annotations 7 | 8 | import argparse 9 | import random 10 | from typing import TYPE_CHECKING 11 | 12 | if TYPE_CHECKING: 13 | from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg 14 | 15 | 16 | def add_rsl_rl_args(parser: argparse.ArgumentParser): 17 | """Add RSL-RL arguments to the parser. 18 | 19 | Args: 20 | parser: The parser to add the arguments to. 21 | """ 22 | # create a new argument group 23 | arg_group = parser.add_argument_group("rsl_rl", description="Arguments for RSL-RL agent.") 24 | # -- experiment arguments 25 | arg_group.add_argument( 26 | "--experiment_name", type=str, default=None, help="Name of the experiment folder where logs will be stored." 27 | ) 28 | arg_group.add_argument("--run_name", type=str, default=None, help="Run name suffix to the log directory.") 29 | # -- load arguments 30 | arg_group.add_argument("--resume", type=bool, default=None, help="Whether to resume from a checkpoint.") 31 | arg_group.add_argument("--load_run", type=str, default=None, help="Name of the run folder to resume from.") 32 | arg_group.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file to resume from.") 33 | # -- logger arguments 34 | arg_group.add_argument( 35 | "--logger", type=str, default=None, choices={"wandb", "tensorboard", "neptune"}, help="Logger module to use." 
36 | ) 37 | arg_group.add_argument( 38 | "--log_project_name", type=str, default=None, help="Name of the logging project when using wandb or neptune." 39 | ) 40 | 41 | 42 | def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg: 43 | """Parse configuration for RSL-RL agent based on inputs. 44 | 45 | Args: 46 | task_name: The name of the environment. 47 | args_cli: The command line arguments. 48 | 49 | Returns: 50 | The parsed configuration for RSL-RL agent based on inputs. 51 | """ 52 | from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry 53 | 54 | # load the default configuration 55 | rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") 56 | rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli) 57 | return rslrl_cfg 58 | 59 | 60 | def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace): 61 | """Update configuration for RSL-RL agent based on inputs. 62 | 63 | Args: 64 | agent_cfg: The configuration for RSL-RL agent. 65 | args_cli: The command line arguments. 66 | 67 | Returns: 68 | The updated configuration for RSL-RL agent based on inputs. 69 | """ 70 | # override the default configuration with CLI arguments 71 | if hasattr(args_cli, "seed") and args_cli.seed is not None: 72 | # randomly sample a seed if seed = -1 73 | if args_cli.seed == -1: 74 | args_cli.seed = random.randint(0, 10000) 75 | agent_cfg.seed = args_cli.seed 76 | if args_cli.resume is not None: 77 | agent_cfg.resume = args_cli.resume 78 | if args_cli.load_run is not None: 79 | agent_cfg.load_run = args_cli.load_run 80 | if args_cli.checkpoint is not None: 81 | agent_cfg.load_checkpoint = args_cli.checkpoint 82 | if args_cli.run_name is not None: 83 | agent_cfg.run_name = args_cli.run_name 84 | if args_cli.logger is not None: 85 | agent_cfg.logger = args_cli.logger 86 | # set the project name for wandb and neptune 87 | if agent_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name: 88 | agent_cfg.wandb_project = args_cli.log_project_name 89 | agent_cfg.neptune_project = args_cli.log_project_name 90 | 91 | return agent_cfg 92 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/pettingzoo_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Tuple 2 | 3 | import collections 4 | 5 | import torch 6 | 7 | from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper 8 | from skrl.utils.spaces.torch import ( 9 | flatten_tensorized_space, 10 | tensorize_space, 11 | unflatten_tensorized_space, 12 | untensorize_space, 13 | ) 14 | 15 | 16 | class PettingZooWrapper(MultiAgentEnvWrapper): 17 | def __init__(self, env: Any) -> None: 18 | """PettingZoo (parallel) environment wrapper 19 | 20 | :param env: The environment to wrap 21 | :type env: Any supported PettingZoo (parallel) environment 22 | """ 23 | super().__init__(env) 24 | 25 | def step(self, actions: Mapping[str, torch.Tensor]) -> Tuple[ 26 | Mapping[str, torch.Tensor], 27 | Mapping[str, torch.Tensor], 28 | Mapping[str, torch.Tensor], 29 | Mapping[str, torch.Tensor], 30 | Mapping[str, Any], 31 | ]: 32 | """Perform a step in the environment 33 | 34 | :param actions: The actions to perform 35 | :type actions: dictionary of torch.Tensor 36 | 37 | :return: Observation, reward, terminated, truncated, info 38 | :rtype: tuple of dictionaries torch.Tensor and any other info 39 | """ 40 | actions = { 41 | uid: 
untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) 42 | for uid, action in actions.items() 43 | } 44 | observations, rewards, terminated, truncated, infos = self._env.step(actions) 45 | 46 | # convert response to torch 47 | observations = { 48 | uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, device=self.device)) 49 | for uid, value in observations.items() 50 | } 51 | rewards = { 52 | uid: torch.tensor(value, device=self.device, dtype=torch.float32).view(self.num_envs, -1) 53 | for uid, value in rewards.items() 54 | } 55 | terminated = { 56 | uid: torch.tensor(value, device=self.device, dtype=torch.bool).view(self.num_envs, -1) 57 | for uid, value in terminated.items() 58 | } 59 | truncated = { 60 | uid: torch.tensor(value, device=self.device, dtype=torch.bool).view(self.num_envs, -1) 61 | for uid, value in truncated.items() 62 | } 63 | return observations, rewards, terminated, truncated, infos 64 | 65 | def state(self) -> torch.Tensor: 66 | """Get the environment state 67 | 68 | :return: State 69 | :rtype: torch.Tensor 70 | """ 71 | return flatten_tensorized_space( 72 | tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), device=self.device) 73 | ) 74 | 75 | def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: 76 | """Reset the environment 77 | 78 | :return: Observation, info 79 | :rtype: tuple of dictionaries of torch.Tensor and any other info 80 | """ 81 | outputs = self._env.reset() 82 | if isinstance(outputs, collections.abc.Mapping): 83 | observations = outputs 84 | infos = {uid: {} for uid in self.possible_agents} 85 | else: 86 | observations, infos = outputs 87 | 88 | # convert response to torch 89 | observations = { 90 | uid: flatten_tensorized_space(tensorize_space(self.observation_spaces[uid], value, device=self.device)) 91 | for uid, value in observations.items() 92 | } 93 | return observations, infos 94 | 95 | def render(self, *args, **kwargs) -> Any: 96 | """Render the environment""" 97 | return self._env.render(*args, **kwargs) 98 | 99 | def close(self) -> None: 100 | """Close the environment""" 101 | self._env.close() 102 | -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import os 4 | import random 5 | import sys 6 | import time 7 | 8 | import numpy as np 9 | 10 | from skrl import config, logger 11 | 12 | 13 | def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int: 14 | """ 15 | Set the seed for the random number generators 16 | 17 | .. note:: 18 | 19 | In distributed runs, the worker/process seed will be incremented (counting from the defined value) according to its rank 20 | 21 | .. warning:: 22 | 23 | Due to NumPy's legacy seeding constraint the seed must be between 0 and 2**32 - 1. 24 | Otherwise a NumPy exception (``ValueError: Seed must be between 0 and 2**32 - 1``) will be raised 25 | 26 | Modified packages: 27 | 28 | - random 29 | - numpy 30 | - torch (if available) 31 | - jax (skrl's PRNG key: ``config.jax.key``) 32 | 33 | Example:: 34 | 35 | # fixed seed 36 | >>> from skrl.utils import set_seed 37 | >>> set_seed(42) 38 | [skrl:INFO] Seed: 42 39 | 42 40 | 41 | # random seed 42 | >>> from skrl.utils import set_seed 43 | >>> set_seed() 44 | [skrl:INFO] Seed: 1776118066 45 | 1776118066 46 | 47 | # enable deterministic. 
The following environment variables should be established: 48 | # - CUDA 10.1: CUDA_LAUNCH_BLOCKING=1 49 | # - CUDA 10.2 or later: CUBLAS_WORKSPACE_CONFIG=:16:8 or CUBLAS_WORKSPACE_CONFIG=:4096:8 50 | >>> from skrl.utils import set_seed 51 | >>> set_seed(42, deterministic=True) 52 | [skrl:INFO] Seed: 42 53 | [skrl:WARNING] PyTorch/cuDNN deterministic algorithms are enabled. This may affect performance 54 | 42 55 | 56 | :param seed: The seed to set. If None, a random seed will be generated (default: ``None``) 57 | :type seed: int, optional 58 | :param deterministic: Whether PyTorch is configured to use deterministic algorithms (default: ``False``). 59 | The following environment variables should be established for CUDA 10.1 (``CUDA_LAUNCH_BLOCKING=1``) 60 | and for CUDA 10.2 or later (``CUBLAS_WORKSPACE_CONFIG=:16:8`` or ``CUBLAS_WORKSPACE_CONFIG=:4096:8``). 61 | See PyTorch `Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html>`_ for details 62 | :type deterministic: bool, optional 63 | 64 | :return: Seed 65 | :rtype: int 66 | """ 67 | # generate a random seed 68 | if seed is None: 69 | try: 70 | seed = int.from_bytes(os.urandom(4), byteorder=sys.byteorder) 71 | except NotImplementedError: 72 | seed = int(time.time() * 1000) 73 | seed %= 2**31 # keep the seed within NumPy's legacy seeding range (0 to 2**32 - 1) 74 | seed = int(seed) 75 | 76 | # set different seeds in distributed runs 77 | if config.torch.is_distributed: 78 | seed += config.torch.rank 79 | if config.jax.is_distributed: 80 | seed += config.jax.rank 81 | 82 | logger.info(f"Seed: {seed}") 83 | 84 | # numpy 85 | random.seed(seed) 86 | np.random.seed(seed) 87 | 88 | # torch 89 | try: 90 | import torch 91 | 92 | torch.manual_seed(seed) 93 | torch.cuda.manual_seed(seed) 94 | torch.cuda.manual_seed_all(seed) 95 | 96 | if deterministic: 97 | torch.backends.cudnn.benchmark = False 98 | torch.backends.cudnn.deterministic = True 99 | 100 | # On CUDA 10.1, set environment variable CUDA_LAUNCH_BLOCKING=1 101 | # On CUDA 10.2 or later, set environment variable CUBLAS_WORKSPACE_CONFIG=:16:8 or CUBLAS_WORKSPACE_CONFIG=:4096:8 102 | 103 | logger.warning("PyTorch/cuDNN deterministic algorithms are enabled. This may affect performance") 104 | except ImportError: 105 | pass 106 | except Exception as e: 107 | logger.warning(f"PyTorch seeding error: {e}") 108 | 109 | # jax 110 | config.jax.key = seed 111 | 112 | return seed 113 | -------------------------------------------------------------------------------- /src/third_parties/skrl/resources/schedulers/torch/kl_adaptive.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from packaging import version 4 | 5 | import torch 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | 9 | class KLAdaptiveLR(_LRScheduler): 10 | def __init__( 11 | self, 12 | optimizer: torch.optim.Optimizer, 13 | kl_threshold: float = 0.008, 14 | min_lr: float = 1e-6, 15 | max_lr: float = 1e-2, 16 | kl_factor: float = 2, 17 | lr_factor: float = 1.5, 18 | last_epoch: int = -1, 19 | verbose: bool = False, 20 | ) -> None: 21 | """Adaptive KL scheduler 22 | 23 | Adjusts the learning rate according to the KL divergence. 24 | The implementation is adapted from the *rl_games* library. 25 | 26 | .. note:: 27 | 28 | This scheduler is only available for PPO at the moment. 29 | Applying it to other agents will not change the learning rate. 30 | 31 | Example:: 32 | 33 | >>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01) 34 | >>> for epoch in range(100): 35 | >>> # ... 
36 | >>> kl_divergence = ... 37 | >>> scheduler.step(kl_divergence) 38 | 39 | :param optimizer: Wrapped optimizer 40 | :type optimizer: torch.optim.Optimizer 41 | :param kl_threshold: Threshold for KL divergence (default: ``0.008``) 42 | :type kl_threshold: float, optional 43 | :param min_lr: Lower bound for learning rate (default: ``1e-6``) 44 | :type min_lr: float, optional 45 | :param max_lr: Upper bound for learning rate (default: ``1e-2``) 46 | :type max_lr: float, optional 47 | :param kl_factor: The number used to modify the KL divergence threshold (default: ``2``) 48 | :type kl_factor: float, optional 49 | :param lr_factor: The number used to modify the learning rate (default: ``1.5``) 50 | :type lr_factor: float, optional 51 | :param last_epoch: The index of last epoch (default: ``-1``) 52 | :type last_epoch: int, optional 53 | :param verbose: Verbose mode (default: ``False``) 54 | :type verbose: bool, optional 55 | """ 56 | if version.parse(torch.__version__) >= version.parse("2.7"): 57 | super().__init__(optimizer, last_epoch) 58 | else: 59 | if version.parse(torch.__version__) >= version.parse("2.2"): 60 | verbose = "deprecated" 61 | super().__init__(optimizer, last_epoch, verbose) 62 | 63 | self.kl_threshold = kl_threshold 64 | self.min_lr = min_lr 65 | self.max_lr = max_lr 66 | self._kl_factor = kl_factor 67 | self._lr_factor = lr_factor 68 | 69 | self._last_lr = [group["lr"] for group in self.optimizer.param_groups] 70 | 71 | def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[int] = None) -> None: 72 | """ 73 | Step scheduler 74 | 75 | Example:: 76 | 77 | >>> kl = torch.distributions.kl_divergence(p, q) 78 | >>> kl 79 | tensor([0.0332, 0.0500, 0.0383, ..., 0.0076, 0.0240, 0.0164]) 80 | >>> scheduler.step(kl.mean()) 81 | 82 | >>> kl = 0.0046 83 | >>> scheduler.step(kl) 84 | 85 | :param kl: KL divergence (default: ``None``). 86 | If None, no adjustment is made. 87 | If tensor, the number of elements must be 1 88 | :type kl: torch.Tensor, float or None, optional 89 | :param epoch: Epoch (default: ``None``) 90 | :type epoch: int, optional 91 | """ 92 | if kl is not None: 93 | for group in self.optimizer.param_groups: 94 | if kl > self.kl_threshold * self._kl_factor: 95 | group["lr"] = max(group["lr"] / self._lr_factor, self.min_lr) 96 | elif kl < self.kl_threshold / self._kl_factor: 97 | group["lr"] = min(group["lr"] * self._lr_factor, self.max_lr) 98 | 99 | self._last_lr = [group["lr"] for group in self.optimizer.param_groups] 100 | -------------------------------------------------------------------------------- /src/third_parties/skrl/models/torch/deterministic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Tuple, Union 2 | 3 | import gymnasium 4 | 5 | import torch 6 | 7 | 8 | class DeterministicMixin: 9 | def __init__(self, clip_actions: bool = False, role: str = "") -> None: 10 | """Deterministic mixin model (deterministic model) 11 | 12 | :param clip_actions: Flag to indicate whether the actions should be clipped to the action space (default: ``False``) 13 | :type clip_actions: bool, optional 14 | :param role: Role played by the model (default: ``""``) 15 | :type role: str, optional 16 | 17 | Example:: 18 | 19 | # define the model 20 | >>> import torch 21 | >>> import torch.nn as nn 22 | >>> from skrl.models.torch import Model, DeterministicMixin 23 | >>> 24 | >>> class Value(DeterministicMixin, Model): 25 | ... 
def __init__(self, observation_space, action_space, device="cuda:0", clip_actions=False): 26 | ... Model.__init__(self, observation_space, action_space, device) 27 | ... DeterministicMixin.__init__(self, clip_actions) 28 | ... 29 | ... self.net = nn.Sequential(nn.Linear(self.num_observations, 32), 30 | ... nn.ELU(), 31 | ... nn.Linear(32, 32), 32 | ... nn.ELU(), 33 | ... nn.Linear(32, 1)) 34 | ... 35 | ... def compute(self, inputs, role): 36 | ... return self.net(inputs["states"]), {} 37 | ... 38 | >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) 39 | >>> # and an action_space: gymnasium.spaces.Box with shape (8,) 40 | >>> model = Value(observation_space, action_space) 41 | >>> 42 | >>> print(model) 43 | Value( 44 | (net): Sequential( 45 | (0): Linear(in_features=60, out_features=32, bias=True) 46 | (1): ELU(alpha=1.0) 47 | (2): Linear(in_features=32, out_features=32, bias=True) 48 | (3): ELU(alpha=1.0) 49 | (4): Linear(in_features=32, out_features=1, bias=True) 50 | ) 51 | ) 52 | """ 53 | self._d_clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) 54 | 55 | if self._d_clip_actions: 56 | self._d_clip_actions_min = torch.tensor(self.action_space.low, device=self.device, dtype=torch.float32) 57 | self._d_clip_actions_max = torch.tensor(self.action_space.high, device=self.device, dtype=torch.float32) 58 | 59 | def act( 60 | self, inputs: Mapping[str, Union[torch.Tensor, Any]], role: str = "" 61 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]: 62 | """Act deterministically in response to the state of the environment 63 | 64 | :param inputs: Model inputs. The most common keys are: 65 | 66 | - ``"states"``: state of the environment used to make the decision 67 | - ``"taken_actions"``: actions taken by the policy for the given states 68 | :type inputs: dict where the values are typically torch.Tensor 69 | :param role: Role played by the model (default: ``""``) 70 | :type role: str, optional 71 | 72 | :return: Model output. The first component is the action to be taken by the agent. 73 | The second component is ``None``. 
The third component is a dictionary containing extra output values 74 | :rtype: tuple of torch.Tensor, torch.Tensor or None, and dict 75 | 76 | Example:: 77 | 78 | >>> # given a batch of sample states with shape (4096, 60) 79 | >>> actions, _, outputs = model.act({"states": states}) 80 | >>> print(actions.shape, outputs) 81 | torch.Size([4096, 1]) {} 82 | """ 83 | # map from observations/states to actions 84 | actions, outputs = self.compute(inputs, role) 85 | 86 | # clip actions 87 | if self._d_clip_actions: 88 | actions = torch.clamp(actions, min=self._d_clip_actions_min, max=self._d_clip_actions_max) 89 | 90 | return actions, None, outputs 91 | -------------------------------------------------------------------------------- /src/third_parties/skrl/memories/torch/random.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | 5 | from skrl.memories.torch import Memory 6 | 7 | 8 | class RandomMemory(Memory): 9 | def __init__( 10 | self, 11 | memory_size: int, 12 | num_envs: int = 1, 13 | device: Optional[Union[str, torch.device]] = None, 14 | export: bool = False, 15 | export_format: str = "pt", 16 | export_directory: str = "", 17 | replacement=True, 18 | ) -> None: 19 | """Random sampling memory 20 | 21 | Sample a batch from memory randomly 22 | 23 | :param memory_size: Maximum number of elements in the first dimension of each internal storage 24 | :type memory_size: int 25 | :param num_envs: Number of parallel environments (default: ``1``) 26 | :type num_envs: int, optional 27 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 28 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 29 | :type device: str or torch.device, optional 30 | :param export: Export the memory to a file (default: ``False``). 31 | If True, the memory will be exported when the memory is filled 32 | :type export: bool, optional 33 | :param export_format: Export format (default: ``"pt"``). 34 | Supported formats: torch (pt), numpy (np), comma separated values (csv) 35 | :type export_format: str, optional 36 | :param export_directory: Directory where the memory will be exported (default: ``""``). 37 | If empty, the agent's experiment directory will be used 38 | :type export_directory: str, optional 39 | :param replacement: Flag to indicate whether the sample is with or without replacement (default: ``True``). 40 | Replacement implies that a value can be selected multiple times (the batch size is always guaranteed). 
41 | Sampling without replacement will return a batch of at most the memory size if the memory holds fewer elements than the requested batch size 42 | :type replacement: bool, optional 43 | 44 | :raises ValueError: The export format is not supported 45 | """ 46 | super().__init__(memory_size, num_envs, device, export, export_format, export_directory) 47 | 48 | self._replacement = replacement 49 | 50 | def sample( 51 | self, names: Tuple[str], batch_size: int, mini_batches: int = 1, sequence_length: int = 1 52 | ) -> List[List[torch.Tensor]]: 53 | """Sample a batch from memory randomly 54 | 55 | :param names: Tensor names from which to obtain the samples 56 | :type names: tuple or list of strings 57 | :param batch_size: Number of elements to sample 58 | :type batch_size: int 59 | :param mini_batches: Number of mini-batches to sample (default: ``1``) 60 | :type mini_batches: int, optional 61 | :param sequence_length: Length of each sequence (default: ``1``) 62 | :type sequence_length: int, optional 63 | 64 | :return: Sampled data from tensors sorted according to their position in the list of names. 65 | The sampled tensors will have the following shape: (batch size, data size) 66 | :rtype: list of torch.Tensor list 67 | """ 68 | # compute valid memory sizes 69 | size = len(self) 70 | if sequence_length > 1: 71 | sequence_indexes = torch.arange(0, self.num_envs * sequence_length, self.num_envs) 72 | size -= sequence_indexes[-1].item() 73 | 74 | # generate random indexes 75 | if self._replacement: 76 | indexes = torch.randint(0, size, (batch_size,)) 77 | else: 78 | # details about the random sampling performance can be found here: 79 | # https://discuss.pytorch.org/t/torch-equivalent-of-numpy-random-choice/16146/19 80 | indexes = torch.randperm(size, dtype=torch.long)[:batch_size] 81 | 82 | # generate sequence indexes 83 | if sequence_length > 1: 84 | indexes = (sequence_indexes.repeat(indexes.shape[0], 1) + indexes.view(-1, 1)).view(-1) 85 | 86 | self.sampling_indexes = indexes 87 | return self.sample_by_index(names=names, indexes=indexes, mini_batches=mini_batches) 88 | -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/distributed/jax/launcher.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Sequence 2 | 3 | import argparse 4 | import multiprocessing as mp 5 | import os 6 | import subprocess 7 | import sys 8 | 9 | 10 | def _get_args_parser() -> argparse.ArgumentParser: 11 | """Instantiate and configure the argument parser object 12 | 13 | :return: Argument parser object 14 | :rtype: argparse.ArgumentParser 15 | """ 16 | parser = argparse.ArgumentParser(description="JAX Distributed Training Launcher") 17 | 18 | # worker/node size related arguments 19 | parser.add_argument("--nnodes", type=int, default=1, help="Number of nodes") 20 | parser.add_argument("--nproc-per-node", "--nproc_per_node", type=int, default=1, help="Number of workers per node") 21 | parser.add_argument( 22 | "--node-rank", "--node_rank", type=int, default=0, help="Node rank for multi-node distributed training" 23 | ) 24 | 25 | # coordinator related arguments 26 | parser.add_argument( 27 | "--coordinator-address", 28 | "--coordinator_address", 29 | type=str, 30 | default="127.0.0.1:5000", 31 | help="IP address and port where process 0 will start a JAX service", 32 | ) 33 | 34 | # positional arguments 35 | parser.add_argument("script", type=str, help="Training script path to be launched in parallel") 36 | 
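# --- Hedged usage sketch, not part of the original skrl source. ---
# Everything after the script path is forwarded to the script untouched, so a
# typical invocation (module path assumed from this package layout, where the
# launcher sits next to a ``__main__.py``) would look like:
#
#     python -m skrl.utils.distributed.jax --nnodes 1 --nproc-per-node 2 train.py --env my_env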
parser.add_argument("script_args", nargs="...", help="Arguments for the training script") 37 | 38 | return parser 39 | 40 | 41 | def _start_processes( 42 | cmd: Sequence[str], 43 | envs: Sequence[Mapping[str, str]], 44 | nprocs: int, 45 | daemon: bool = False, 46 | start_method: str = "spawn", 47 | ) -> None: 48 | """Start child processes according the specified configuration and wait for them to join 49 | 50 | :param cmd: Command to run on each child process 51 | :type cmd: list of str 52 | :param envs: List of environment variables for each child process 53 | :type envs: list of dictionaries 54 | :param nprocs: Number of child processes to start 55 | :type nprocs: int 56 | :param daemon: Whether the child processes are daemonic (default: ``False``). 57 | See Python multiprocessing module for more details 58 | :type daemon: bool 59 | :param start_method: Method which should be used to start child processes (default: ``"spawn"``). 60 | See Python multiprocessing module for more details 61 | :type start_method: str 62 | """ 63 | mp.set_start_method(method=start_method, force=True) 64 | 65 | processes = [] 66 | for i in range(nprocs): 67 | process = mp.Process(target=_process, args=(cmd, envs[i]), daemon=daemon) 68 | processes.append(process) 69 | process.start() 70 | 71 | for process in processes: 72 | process.join() 73 | 74 | 75 | def _process(cmd: Sequence[str], env: Mapping[str, str]) -> None: 76 | """Run a command in the current process 77 | 78 | :param cmd: Command to run 79 | :type cmd: list of str 80 | :param envs: Environment variables for the current process 81 | :type envs: dict 82 | """ 83 | subprocess.run(cmd, env=env) 84 | 85 | 86 | def launch(): 87 | """Main entry point for launching distributed runs""" 88 | args = _get_args_parser().parse_args() 89 | 90 | # validate distributed config 91 | if args.nnodes < 1: 92 | print(f"[ERROR] Number of nodes ({args.nnodes}) must be greater than 0") 93 | exit() 94 | if args.node_rank >= args.nnodes: 95 | print(f"[ERROR] Node rank ({args.node_rank}) is out of range for the available number of nodes ({args.nnodes})") 96 | exit() 97 | 98 | # define custom environment variables (see skrl.config.jax for more details) 99 | envs = [] 100 | for i in range(args.nnodes): 101 | if i == args.node_rank: 102 | for j in range(args.nproc_per_node): 103 | env = os.environ.copy() 104 | env["JAX_LOCAL_RANK"] = str(j) 105 | env["JAX_RANK"] = str(i * args.nproc_per_node + j) 106 | env["JAX_WORLD_SIZE"] = str(args.nnodes * args.nproc_per_node) 107 | env["JAX_COORDINATOR_ADDR"] = args.coordinator_address.split(":")[0] 108 | env["JAX_COORDINATOR_PORT"] = args.coordinator_address.split(":")[1] 109 | envs.append(env) 110 | 111 | # spawn processes 112 | cmd = [sys.executable, args.script, *args.script_args] 113 | _start_processes(cmd, envs, nprocs=args.nproc_per_node) 114 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/gymnasium_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import gymnasium 4 | 5 | import torch 6 | 7 | from skrl import logger 8 | from skrl.envs.wrappers.torch.base import Wrapper 9 | from skrl.utils.spaces.torch import ( 10 | flatten_tensorized_space, 11 | tensorize_space, 12 | unflatten_tensorized_space, 13 | untensorize_space, 14 | ) 15 | 16 | 17 | class GymnasiumWrapper(Wrapper): 18 | def __init__(self, env: Any) -> None: 19 | """Gymnasium environment wrapper 20 | 21 | :param env: The 
environment to wrap 22 | :type env: Any supported Gymnasium environment 23 | """ 24 | super().__init__(env) 25 | 26 | self._vectorized = False 27 | try: 28 | self._vectorized = self._vectorized or isinstance(env, gymnasium.vector.VectorEnv) 29 | except Exception as e: 30 | pass 31 | try: 32 | self._vectorized = self._vectorized or isinstance(env, gymnasium.experimental.vector.VectorEnv) 33 | except Exception as e: 34 | logger.warning(f"Failed to check for a vectorized environment: {e}") 35 | if self._vectorized: 36 | self._reset_once = True 37 | self._observation = None 38 | self._info = None 39 | 40 | @property 41 | def observation_space(self) -> gymnasium.Space: 42 | """Observation space""" 43 | if self._vectorized: 44 | return self._env.single_observation_space 45 | return self._env.observation_space 46 | 47 | @property 48 | def action_space(self) -> gymnasium.Space: 49 | """Action space""" 50 | if self._vectorized: 51 | return self._env.single_action_space 52 | return self._env.action_space 53 | 54 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 55 | """Perform a step in the environment 56 | 57 | :param actions: The actions to perform 58 | :type actions: torch.Tensor 59 | 60 | :return: Observation, reward, terminated, truncated, info 61 | :rtype: tuple of torch.Tensor and any other info 62 | """ 63 | actions = untensorize_space( 64 | self.action_space, 65 | unflatten_tensorized_space(self.action_space, actions), 66 | squeeze_batch_dimension=not self._vectorized, 67 | ) 68 | 69 | observation, reward, terminated, truncated, info = self._env.step(actions) 70 | 71 | # convert response to torch 72 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, device=self.device)) 73 | reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1) 74 | terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) 75 | truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) 76 | 77 | # save observation and info for vectorized envs 78 | if self._vectorized: 79 | self._observation = observation 80 | self._info = info 81 | 82 | return observation, reward, terminated, truncated, info 83 | 84 | def reset(self) -> Tuple[torch.Tensor, Any]: 85 | """Reset the environment 86 | 87 | :return: Observation, info 88 | :rtype: torch.Tensor and any other info 89 | """ 90 | # handle vectorized environments (vector environments are autoreset) 91 | if self._vectorized: 92 | if self._reset_once: 93 | observation, self._info = self._env.reset() 94 | self._observation = flatten_tensorized_space( 95 | tensorize_space(self.observation_space, observation, device=self.device) 96 | ) 97 | self._reset_once = False 98 | return self._observation, self._info 99 | 100 | observation, info = self._env.reset() 101 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, device=self.device)) 102 | return observation, info 103 | 104 | def render(self, *args, **kwargs) -> Any: 105 | """Render the environment""" 106 | if self._vectorized: 107 | return self._env.call("render", *args, **kwargs) 108 | return self._env.render(*args, **kwargs) 109 | 110 | def close(self) -> None: 111 | """Close the environment""" 112 | self._env.close() 113 | -------------------------------------------------------------------------------- /src/third_parties/skrl/models/jax/deterministic.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Optional, Tuple, Union 2 | 3 | import gymnasium 4 | 5 | import flax 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | 10 | 11 | class DeterministicMixin: 12 | def __init__(self, clip_actions: bool = False, role: str = "") -> None: 13 | """Deterministic mixin model (deterministic model) 14 | 15 | :param clip_actions: Flag to indicate whether the actions should be clipped to the action space (default: ``False``) 16 | :type clip_actions: bool, optional 17 | :param role: Role played by the model (default: ``""``) 18 | :type role: str, optional 19 | 20 | Example:: 21 | 22 | # define the model 23 | >>> import flax.linen as nn 24 | >>> from skrl.models.jax import Model, DeterministicMixin 25 | >>> 26 | >>> class Value(DeterministicMixin, Model): 27 | ... def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs): 28 | ... Model.__init__(self, observation_space, action_space, device, **kwargs) 29 | ... DeterministicMixin.__init__(self, clip_actions) 30 | ... 31 | ... @nn.compact # marks the given module method allowing inlined submodules 32 | ... def __call__(self, inputs, role): 33 | ... x = nn.elu(nn.Dense(32)(inputs["states"])) 34 | ... x = nn.elu(nn.Dense(32)(x)) 35 | ... x = nn.Dense(1)(x) 36 | ... return x, {} 37 | ... 38 | >>> # given an observation_space: gymnasium.spaces.Box with shape (60,) 39 | >>> # and an action_space: gymnasium.spaces.Box with shape (8,) 40 | >>> model = Value(observation_space, action_space) 41 | >>> 42 | >>> print(model) 43 | Value( 44 | # attributes 45 | observation_space = Box(-1.0, 1.0, (60,), float32) 46 | action_space = Box(-1.0, 1.0, (8,), float32) 47 | device = StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0) 48 | ) 49 | """ 50 | self._d_clip_actions = clip_actions and isinstance(self.action_space, gymnasium.Space) 51 | 52 | if self._d_clip_actions: 53 | self._d_clip_actions_min = jnp.array(self.action_space.low, dtype=jnp.float32) 54 | self._d_clip_actions_max = jnp.array(self.action_space.high, dtype=jnp.float32) 55 | 56 | # https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.IncorrectPostInitOverrideError 57 | flax.linen.Module.__post_init__(self) 58 | 59 | def act( 60 | self, 61 | inputs: Mapping[str, Union[Union[np.ndarray, jax.Array], Any]], 62 | role: str = "", 63 | params: Optional[jax.Array] = None, 64 | ) -> Tuple[jax.Array, Union[jax.Array, None], Mapping[str, Union[jax.Array, Any]]]: 65 | """Act deterministically in response to the state of the environment 66 | 67 | :param inputs: Model inputs. The most common keys are: 68 | 69 | - ``"states"``: state of the environment used to make the decision 70 | - ``"taken_actions"``: actions taken by the policy for the given states 71 | :type inputs: dict where the values are typically np.ndarray or jax.Array 72 | :param role: Role played by the model (default: ``""``) 73 | :type role: str, optional 74 | :param params: Parameters used to compute the output (default: ``None``). 75 | If ``None``, internal parameters will be used 76 | :type params: jax.Array, optional 77 | 78 | :return: Model output. The first component is the action to be taken by the agent. 79 | The second component is ``None``.
The third component is a dictionary containing extra output values 80 | :rtype: tuple of jax.Array, jax.Array or None, and dict 81 | 82 | Example:: 83 | 84 | >>> # given a batch of sample states with shape (4096, 60) 85 | >>> actions, _, outputs = model.act({"states": states}) 86 | >>> print(actions.shape, outputs) 87 | (4096, 1) {} 88 | """ 89 | # map from observations/states to actions 90 | actions, outputs = self.apply(self.state_dict.params if params is None else params, inputs, role) 91 | 92 | # clip actions 93 | if self._d_clip_actions: 94 | actions = jnp.clip(actions, a_min=self._d_clip_actions_min, a_max=self._d_clip_actions_max) 95 | 96 | return actions, None, outputs 97 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/bidexhands_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Sequence, Tuple 2 | 3 | import gymnasium 4 | 5 | import torch 6 | 7 | from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper 8 | from skrl.utils.spaces.torch import convert_gym_space 9 | 10 | 11 | class BiDexHandsWrapper(MultiAgentEnvWrapper): 12 | def __init__(self, env: Any) -> None: 13 | """Bi-DexHands wrapper 14 | 15 | :param env: The environment to wrap 16 | :type env: Any supported Bi-DexHands environment 17 | """ 18 | super().__init__(env) 19 | 20 | self._reset_once = True 21 | self._states = None 22 | self._observations = None 23 | self._info = {} 24 | 25 | @property 26 | def agents(self) -> Sequence[str]: 27 | """Names of all current agents 28 | 29 | These may be changed as an environment progresses (i.e. agents can be added or removed) 30 | """ 31 | return self.possible_agents 32 | 33 | @property 34 | def possible_agents(self) -> Sequence[str]: 35 | """Names of all possible agents the environment could generate 36 | 37 | These cannot be changed as an environment progresses 38 | """ 39 | return [f"agent_{i}" for i in range(self.num_agents)] 40 | 41 | @property 42 | def state_spaces(self) -> Mapping[str, gymnasium.Space]: 43 | """State spaces 44 | 45 | Since the state space is a global view of the environment (and therefore the same for all the agents), 46 | this property returns a dictionary (for consistency with the other space-related properties) with the same 47 | space for all the agents 48 | """ 49 | return { 50 | uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.share_observation_space) 51 | } 52 | 53 | @property 54 | def observation_spaces(self) -> Mapping[str, gymnasium.Space]: 55 | """Observation spaces""" 56 | return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.observation_space)} 57 | 58 | @property 59 | def action_spaces(self) -> Mapping[str, gymnasium.Space]: 60 | """Action spaces""" 61 | return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.action_space)} 62 | 63 | def step(self, actions: Mapping[str, torch.Tensor]) -> Tuple[ 64 | Mapping[str, torch.Tensor], 65 | Mapping[str, torch.Tensor], 66 | Mapping[str, torch.Tensor], 67 | Mapping[str, torch.Tensor], 68 | Mapping[str, Any], 69 | ]: 70 | """Perform a step in the environment 71 | 72 | :param actions: The actions to perform 73 | :type actions: dictionary of torch.Tensor 74 | 75 | :return: Observation, reward, terminated, truncated, info 76 | :rtype: tuple of dictionaries of torch.Tensor and any other info 77 | """ 78 | actions = [actions[uid] for uid in
self.possible_agents] 79 | observations, states, rewards, terminated, _, _ = self._env.step(actions) 80 | 81 | self._states = states[:, 0] 82 | self._observations = {uid: observations[:, i] for i, uid in enumerate(self.possible_agents)} 83 | rewards = {uid: rewards[:, i].view(-1, 1) for i, uid in enumerate(self.possible_agents)} 84 | terminated = {uid: terminated[:, i].view(-1, 1) for i, uid in enumerate(self.possible_agents)} 85 | truncated = {uid: torch.zeros_like(value) for uid, value in terminated.items()} 86 | 87 | return self._observations, rewards, terminated, truncated, self._info 88 | 89 | def state(self) -> torch.Tensor: 90 | """Get the environment state 91 | 92 | :return: State 93 | :rtype: torch.Tensor 94 | """ 95 | return self._states 96 | 97 | def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: 98 | """Reset the environment 99 | 100 | :return: Observation, info 101 | :rtype: tuple of dictionaries of torch.Tensor and any other info 102 | """ 103 | if self._reset_once: 104 | observations, states, _ = self._env.reset() 105 | self._states = states[:, 0] 106 | self._observations = {uid: observations[:, i] for i, uid in enumerate(self.possible_agents)} 107 | self._reset_once = False 108 | return self._observations, self._info 109 | 110 | def render(self, *args, **kwargs) -> None: 111 | """Render the environment""" 112 | return None 113 | 114 | def close(self) -> None: 115 | """Close the environment""" 116 | pass 117 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/jax/omniverse_isaacgym_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple, Union 2 | 3 | import jax 4 | import jax.dlpack as jax_dlpack 5 | import numpy as np 6 | 7 | 8 | try: 9 | import torch 10 | import torch.utils.dlpack as torch_dlpack 11 | except: 12 | pass # TODO: show warning message 13 | else: 14 | from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space 15 | 16 | from skrl import logger 17 | from skrl.envs.wrappers.jax.base import Wrapper 18 | 19 | 20 | # ML frameworks conversion utilities 21 | # jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: DLPack tensor is on GPU, but no GPU backend was provided. 22 | _CPU = jax.devices()[0].device_kind.lower() == "cpu" 23 | if _CPU: 24 | logger.warning("OmniIsaacGymEnvs runs on GPU, but there is no GPU backend for JAX. 
JAX operations will run on CPU.") 25 | 26 | 27 | def _jax2torch(array, device, from_jax=True): 28 | if from_jax: 29 | return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(array)).to(device=device) 30 | return torch.tensor(array, device=device) 31 | 32 | 33 | def _torch2jax(tensor, to_jax=True): 34 | if to_jax: 35 | return jax_dlpack.from_dlpack( 36 | torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous()) 37 | ) 38 | return tensor.cpu().numpy() 39 | 40 | 41 | class OmniverseIsaacGymWrapper(Wrapper): 42 | def __init__(self, env: Any) -> None: 43 | """Omniverse Isaac Gym environment wrapper 44 | 45 | :param env: The environment to wrap 46 | :type env: Any supported Omniverse Isaac Gym environment 47 | """ 48 | super().__init__(env) 49 | 50 | self._env_device = torch.device(self._unwrapped.device) 51 | self._reset_once = True 52 | self._observations = None 53 | self._info = {} 54 | 55 | def run(self, trainer: Optional["omni.isaac.gym.vec_env.vec_env_mt.TrainerMT"] = None) -> None: 56 | """Run the simulation in the main thread 57 | 58 | This method is valid only for the Omniverse Isaac Gym multi-threaded environments 59 | 60 | :param trainer: Trainer which should implement a ``run`` method that initiates the RL loop on a new thread 61 | :type trainer: omni.isaac.gym.vec_env.vec_env_mt.TrainerMT, optional 62 | """ 63 | self._env.run(trainer) 64 | 65 | def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ 66 | Union[np.ndarray, jax.Array], 67 | Union[np.ndarray, jax.Array], 68 | Union[np.ndarray, jax.Array], 69 | Union[np.ndarray, jax.Array], 70 | Any, 71 | ]: 72 | """Perform a step in the environment 73 | 74 | :param actions: The actions to perform 75 | :type actions: np.ndarray or jax.Array 76 | 77 | :return: Observation, reward, terminated, truncated, info 78 | :rtype: tuple of np.ndarray or jax.Array and any other info 79 | """ 80 | actions = _jax2torch(actions, self._env_device, self._jax) 81 | 82 | with torch.no_grad(): 83 | observations, reward, terminated, self._info = self._env.step( 84 | unflatten_tensorized_space(self.action_space, actions) 85 | ) 86 | 87 | observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) 88 | terminated = terminated.to(dtype=torch.int8) 89 | truncated = ( 90 | self._info["time_outs"].to(dtype=torch.int8) if "time_outs" in self._info else torch.zeros_like(terminated) 91 | ) 92 | 93 | self._observations = _torch2jax(observations, self._jax) 94 | return ( 95 | self._observations, 96 | _torch2jax(reward.view(-1, 1), self._jax), 97 | _torch2jax(terminated.view(-1, 1), self._jax), 98 | _torch2jax(truncated.view(-1, 1), self._jax), 99 | self._info, 100 | ) 101 | 102 | def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: 103 | """Reset the environment 104 | 105 | :return: Observation, info 106 | :rtype: np.ndarray or jax.Array and any other info 107 | """ 108 | if self._reset_once: 109 | observations = self._env.reset() 110 | observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) 111 | self._observations = _torch2jax(observations, self._jax) 112 | self._reset_once = False 113 | return self._observations, self._info 114 | 115 | def render(self, *args, **kwargs) -> None: 116 | """Render the environment""" 117 | return None 118 | 119 | def close(self) -> None: 120 | """Close the environment""" 121 | self._env.close() 122 | -------------------------------------------------------------------------------- 
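A minimal usage sketch of the _jax2torch/_torch2jax helpers defined in the wrapper above (not part of the skrl sources; the shapes, device, and variable names are illustrative assumptions):

# Sketch: JAX <-> PyTorch interchange via DLPack, as used by OmniverseIsaacGymWrapper.
# Assumes a GPU-backed JAX installation; with a CPU-only backend the helpers above
# fall back to copying tensors through host memory instead.
import jax.numpy as jnp
import torch

actions = jnp.zeros((64, 8), dtype=jnp.float32)              # e.g. a batch of actions from a JAX policy
torch_actions = _jax2torch(actions, torch.device("cuda:0"))  # zero-copy handoff to the torch-based simulator
observations = torch.zeros((64, 60), device="cuda:0")        # e.g. observations produced by the simulator
jax_observations = _torch2jax(observations)                  # back to a jax.Array for the agent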
/src/third_parties/skrl/utils/model_instantiators/torch/deterministic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Optional, Sequence, Tuple, Union 2 | 3 | import textwrap 4 | import gymnasium 5 | 6 | import torch 7 | import torch.nn as nn # noqa 8 | 9 | from skrl.models.torch import DeterministicMixin # noqa 10 | from skrl.models.torch import Model 11 | from skrl.utils.model_instantiators.torch.common import one_hot_encoding # noqa 12 | from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers 13 | from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa 14 | 15 | 16 | def deterministic_model( 17 | observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 18 | action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 19 | device: Optional[Union[str, torch.device]] = None, 20 | clip_actions: bool = False, 21 | network: Sequence[Mapping[str, Any]] = [], 22 | output: Union[str, Sequence[str]] = "", 23 | return_source: bool = False, 24 | *args, 25 | **kwargs, 26 | ) -> Union[Model, str]: 27 | """Instantiate a deterministic model 28 | 29 | :param observation_space: Observation/state space or shape (default: None). 30 | If it is not None, the num_observations property will contain the size of that space 31 | :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional 32 | :param action_space: Action space or shape (default: None). 33 | If it is not None, the num_actions property will contain the size of that space 34 | :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional 35 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 36 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 37 | :type device: str or torch.device, optional 38 | :param clip_actions: Flag to indicate whether the actions should be clipped (default: False) 39 | :type clip_actions: bool, optional 40 | :param network: Network definition (default: []) 41 | :type network: list of dict, optional 42 | :param output: Output expression (default: "") 43 | :type output: list or str, optional 44 | :param return_source: Whether to return the source string containing the model class used to 45 | instantiate the model rather than the model instance (default: False). 
46 | :type return_source: bool, optional 47 | 48 | :return: Deterministic model instance or definition source 49 | :rtype: Model 50 | """ 51 | # compatibility with versions prior to 1.3.0 52 | if not network and kwargs: 53 | network, output = convert_deprecated_parameters(kwargs) 54 | 55 | # parse model definition 56 | containers, output = generate_containers(network, output, embed_output=True, indent=1) 57 | 58 | # network definitions 59 | networks = [] 60 | forward: list[str] = [] 61 | for container in containers: 62 | networks.append(f'self.{container["name"]}_container = {container["sequential"]}') 63 | forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') 64 | # process output 65 | if output["modules"]: 66 | networks.append(f'self.output_layer = {output["modules"][0]}') 67 | forward.append(f'output = self.output_layer({container["name"]})') 68 | if output["output"]: 69 | forward.append(f'output = {output["output"]}') 70 | else: 71 | forward[-1] = forward[-1].replace(f'{container["name"]} =', "output =", 1) 72 | 73 | # build substitutions and indent content 74 | networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:] 75 | forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:] 76 | 77 | template = f"""class DeterministicModel(DeterministicMixin, Model): 78 | def __init__(self, observation_space, action_space, device, clip_actions): 79 | Model.__init__(self, observation_space, action_space, device) 80 | DeterministicMixin.__init__(self, clip_actions) 81 | 82 | {networks} 83 | 84 | def compute(self, inputs, role=""): 85 | states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) 86 | taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) 87 | {forward} 88 | return output, {{}} 89 | """ 90 | # return source 91 | if return_source: 92 | return template 93 | 94 | # instantiate model 95 | _locals = {} 96 | exec(template, globals(), _locals) 97 | return _locals["DeterministicModel"]( 98 | observation_space=observation_space, action_space=action_space, device=device, clip_actions=clip_actions 99 | ) 100 | -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/model_instantiators/jax/deterministic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Optional, Sequence, Tuple, Union 2 | 3 | import textwrap 4 | import gymnasium 5 | 6 | import flax.linen as nn # noqa 7 | import jax 8 | import jax.numpy as jnp # noqa 9 | 10 | from skrl.models.jax import DeterministicMixin # noqa 11 | from skrl.models.jax import Model # noqa 12 | from skrl.utils.model_instantiators.jax.common import one_hot_encoding # noqa 13 | from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers 14 | from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa 15 | 16 | 17 | def deterministic_model( 18 | observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 19 | action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 20 | device: Optional[Union[str, jax.Device]] = None, 21 | clip_actions: bool = False, 22 | network: Sequence[Mapping[str, Any]] = [], 23 | output: Union[str, Sequence[str]] = "", 24 | return_source: bool = False, 25 | *args, 26 | **kwargs, 27 | ) -> Union[Model, str]: 28 | """Instantiate a deterministic model 29 | 30 | :param observation_space: Observation/state space or shape 
(default: None). 31 | If it is not None, the num_observations property will contain the size of that space 32 | :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional 33 | :param action_space: Action space or shape (default: None). 34 | If it is not None, the num_actions property will contain the size of that space 35 | :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional 36 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 37 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 38 | :type device: str or jax.Device, optional 39 | :param clip_actions: Flag to indicate whether the actions should be clipped (default: False) 40 | :type clip_actions: bool, optional 41 | :param network: Network definition (default: []) 42 | :type network: list of dict, optional 43 | :param output: Output expression (default: "") 44 | :type output: list or str, optional 45 | :param return_source: Whether to return the source string containing the model class used to 46 | instantiate the model rather than the model instance (default: False). 47 | :type return_source: bool, optional 48 | 49 | :return: Deterministic model instance or definition source 50 | :rtype: Model 51 | """ 52 | # compatibility with versions prior to 1.3.0 53 | if not network and kwargs: 54 | network, output = convert_deprecated_parameters(kwargs) 55 | 56 | # parse model definition 57 | containers, output = generate_containers(network, output, embed_output=True, indent=1) 58 | 59 | # network definitions 60 | networks = [] 61 | forward: list[str] = [] 62 | for container in containers: 63 | networks.append(f'self.{container["name"]}_container = {container["sequential"]}') 64 | forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') 65 | # process output 66 | if output["modules"]: 67 | networks.append(f'self.output_layer = {output["modules"][0]}') 68 | forward.append(f'output = self.output_layer({container["name"]})') 69 | if output["output"]: 70 | forward.append(f'output = {output["output"]}') 71 | else: 72 | forward[-1] = forward[-1].replace(f'{container["name"]} =', "output =", 1) 73 | 74 | # build substitutions and indent content 75 | networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:] 76 | forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:] 77 | 78 | template = f"""class DeterministicModel(DeterministicMixin, Model): 79 | def __init__(self, observation_space, action_space, device, clip_actions=False, **kwargs): 80 | Model.__init__(self, observation_space, action_space, device, **kwargs) 81 | DeterministicMixin.__init__(self, clip_actions) 82 | 83 | def setup(self): 84 | {networks} 85 | 86 | def __call__(self, inputs, role): 87 | states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) 88 | taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) 89 | {forward} 90 | return output, {{}} 91 | """ 92 | # return source 93 | if return_source: 94 | return template 95 | 96 | # instantiate model 97 | _locals = {} 98 | exec(template, globals(), _locals) 99 | return _locals["DeterministicModel"]( 100 | observation_space=observation_space, action_space=action_space, device=device, clip_actions=clip_actions 101 | ) 102 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/jax/pettingzoo_envs.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Tuple, Union 2 | 3 | import collections 4 | 5 | import jax 6 | import numpy as np 7 | 8 | from skrl.envs.wrappers.jax.base import MultiAgentEnvWrapper 9 | from skrl.utils.spaces.jax import ( 10 | flatten_tensorized_space, 11 | tensorize_space, 12 | unflatten_tensorized_space, 13 | untensorize_space, 14 | ) 15 | 16 | 17 | class PettingZooWrapper(MultiAgentEnvWrapper): 18 | def __init__(self, env: Any) -> None: 19 | """PettingZoo (parallel) environment wrapper 20 | 21 | :param env: The environment to wrap 22 | :type env: Any supported PettingZoo (parallel) environment 23 | """ 24 | super().__init__(env) 25 | 26 | def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> Tuple[ 27 | Mapping[str, Union[np.ndarray, jax.Array]], 28 | Mapping[str, Union[np.ndarray, jax.Array]], 29 | Mapping[str, Union[np.ndarray, jax.Array]], 30 | Mapping[str, Union[np.ndarray, jax.Array]], 31 | Mapping[str, Any], 32 | ]: 33 | """Perform a step in the environment 34 | 35 | :param actions: The actions to perform 36 | :type actions: dict of np.ndarray or jax.Array 37 | 38 | :return: Observation, reward, terminated, truncated, info 39 | :rtype: tuple of dict of np.ndarray or jax.Array and any other info 40 | """ 41 | if self._jax: 42 | actions = jax.device_get(actions) 43 | actions = { 44 | uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) 45 | for uid, action in actions.items() 46 | } 47 | observations, rewards, terminated, truncated, infos = self._env.step(actions) 48 | 49 | # convert response to numpy or jax 50 | observations = { 51 | uid: flatten_tensorized_space( 52 | tensorize_space(self.observation_spaces[uid], value, device=self.device, _jax=False), _jax=False 53 | ) 54 | for uid, value in observations.items() 55 | } 56 | rewards = {uid: np.array(value, dtype=np.float32).reshape(self.num_envs, -1) for uid, value in rewards.items()} 57 | terminated = { 58 | uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in terminated.items() 59 | } 60 | truncated = {uid: np.array(value, dtype=np.int8).reshape(self.num_envs, -1) for uid, value in truncated.items()} 61 | if self._jax: 62 | observations = {uid: jax.device_put(value, device=self.device) for uid, value in observations.items()} 63 | rewards = {uid: jax.device_put(value, device=self.device) for uid, value in rewards.items()} 64 | terminated = {uid: jax.device_put(value, device=self.device) for uid, value in terminated.items()} 65 | truncated = {uid: jax.device_put(value, device=self.device) for uid, value in truncated.items()} 66 | return observations, rewards, terminated, truncated, infos 67 | 68 | def state(self) -> Union[np.ndarray, jax.Array]: 69 | """Get the environment state 70 | 71 | :return: State 72 | :rtype: np.ndarray or jax.Array 73 | """ 74 | state = flatten_tensorized_space( 75 | tensorize_space(next(iter(self.state_spaces.values())), self._env.state(), device=self.device, _jax=False), 76 | _jax=False, 77 | ) 78 | if self._jax: 79 | state = jax.device_put(state, device=self.device) 80 | return state 81 | 82 | def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Any]]: 83 | """Reset the environment 84 | 85 | :return: Observation, info 86 | :rtype: tuple of dict of np.ndarray or jax.Array and any other info 87 | """ 88 | outputs = self._env.reset() 89 | if isinstance(outputs, collections.abc.Mapping): 90 | 
observations = outputs 91 | infos = {uid: {} for uid in self.possible_agents} 92 | else: 93 | observations, infos = outputs 94 | 95 | # convert response to numpy or jax 96 | observations = { 97 | uid: flatten_tensorized_space( 98 | tensorize_space(self.observation_spaces[uid], value, device=self.device, _jax=False), _jax=False 99 | ) 100 | for uid, value in observations.items() 101 | } 102 | if self._jax: 103 | observations = {uid: jax.device_put(value, device=self.device) for uid, value in observations.items()} 104 | return observations, infos 105 | 106 | def render(self, *args, **kwargs) -> Any: 107 | """Render the environment""" 108 | return self._env.render(*args, **kwargs) 109 | 110 | def close(self) -> None: 111 | """Close the environment""" 112 | self._env.close() 113 | -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/model_instantiators/torch/categorical.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Optional, Sequence, Tuple, Union 2 | 3 | import textwrap 4 | import gymnasium 5 | 6 | import torch 7 | import torch.nn as nn # noqa 8 | 9 | from skrl.models.torch import CategoricalMixin # noqa 10 | from skrl.models.torch import Model 11 | from skrl.utils.model_instantiators.torch.common import one_hot_encoding # noqa 12 | from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers 13 | from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa 14 | 15 | 16 | def categorical_model( 17 | observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 18 | action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 19 | device: Optional[Union[str, torch.device]] = None, 20 | unnormalized_log_prob: bool = True, 21 | network: Sequence[Mapping[str, Any]] = [], 22 | output: Union[str, Sequence[str]] = "", 23 | return_source: bool = False, 24 | *args, 25 | **kwargs, 26 | ) -> Union[Model, str]: 27 | """Instantiate a categorical model 28 | 29 | :param observation_space: Observation/state space or shape (default: None). 30 | If it is not None, the num_observations property will contain the size of that space 31 | :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional 32 | :param action_space: Action space or shape (default: None). 33 | If it is not None, the num_actions property will contain the size of that space 34 | :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional 35 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 36 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 37 | :type device: str or torch.device, optional 38 | :param unnormalized_log_prob: Flag to indicate how the model's output should be interpreted (default: True).
39 | If True, the model's output is interpreted as unnormalized log probabilities 40 | (it can be any real number), otherwise as normalized probabilities 41 | (the output must be non-negative, finite and have a non-zero sum) 42 | :type unnormalized_log_prob: bool, optional 43 | :param network: Network definition (default: []) 44 | :type network: list of dict, optional 45 | :param output: Output expression (default: "") 46 | :type output: list or str, optional 47 | :param return_source: Whether to return the source string containing the model class used to 48 | instantiate the model rather than the model instance (default: False). 49 | :type return_source: bool, optional 50 | 51 | :return: Categorical model instance or definition source 52 | :rtype: Model 53 | """ 54 | # compatibility with versions prior to 1.3.0 55 | if not network and kwargs: 56 | network, output = convert_deprecated_parameters(kwargs) 57 | 58 | # parse model definition 59 | containers, output = generate_containers(network, output, embed_output=True, indent=1) 60 | 61 | # network definitions 62 | networks = [] 63 | forward: list[str] = [] 64 | for container in containers: 65 | networks.append(f'self.{container["name"]}_container = {container["sequential"]}') 66 | forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') 67 | # process output 68 | if output["modules"]: 69 | networks.append(f'self.output_layer = {output["modules"][0]}') 70 | forward.append(f'output = self.output_layer({container["name"]})') 71 | if output["output"]: 72 | forward.append(f'output = {output["output"]}') 73 | else: 74 | forward[-1] = forward[-1].replace(f'{container["name"]} =', "output =", 1) 75 | 76 | # build substitutions and indent content 77 | networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:] 78 | forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:] 79 | 80 | template = f"""class CategoricalModel(CategoricalMixin, Model): 81 | def __init__(self, observation_space, action_space, device, unnormalized_log_prob): 82 | Model.__init__(self, observation_space, action_space, device) 83 | CategoricalMixin.__init__(self, unnormalized_log_prob) 84 | 85 | {networks} 86 | 87 | def compute(self, inputs, role=""): 88 | states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) 89 | taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) 90 | {forward} 91 | return output, {{}} 92 | """ 93 | # return source 94 | if return_source: 95 | return template 96 | 97 | # instantiate model 98 | _locals = {} 99 | exec(template, globals(), _locals) 100 | return _locals["CategoricalModel"]( 101 | observation_space=observation_space, 102 | action_space=action_space, 103 | device=device, 104 | unnormalized_log_prob=unnormalized_log_prob, 105 | ) 106 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/deepmind_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import collections 4 | import gymnasium 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from skrl import logger 10 | from skrl.envs.wrappers.torch.base import Wrapper 11 | from skrl.utils.spaces.torch import ( 12 | flatten_tensorized_space, 13 | tensorize_space, 14 | unflatten_tensorized_space, 15 | untensorize_space, 16 | ) 17 | 18 | 19 | class DeepMindWrapper(Wrapper): 20 | def __init__(self, env: Any) -> None: 21 | """DeepMind 
environment wrapper 22 | 23 | :param env: The environment to wrap 24 | :type env: Any supported DeepMind environment 25 | """ 26 | super().__init__(env) 27 | 28 | from dm_env import specs 29 | 30 | self._specs = specs 31 | 32 | @property 33 | def observation_space(self) -> gymnasium.Space: 34 | """Observation space""" 35 | return self._spec_to_space(self._env.observation_spec()) 36 | 37 | @property 38 | def action_space(self) -> gymnasium.Space: 39 | """Action space""" 40 | return self._spec_to_space(self._env.action_spec()) 41 | 42 | def _spec_to_space(self, spec: Any) -> gymnasium.Space: 43 | """Convert the DeepMind spec to a gymnasium space 44 | 45 | :param spec: The DeepMind spec to convert 46 | :type spec: Any supported DeepMind spec 47 | 48 | :raises: ValueError if the spec type is not supported 49 | 50 | :return: The gymnasium space 51 | :rtype: gymnasium.Space 52 | """ 53 | if isinstance(spec, self._specs.DiscreteArray): 54 | return gymnasium.spaces.Discrete(spec.num_values) 55 | elif isinstance(spec, self._specs.BoundedArray): 56 | return gymnasium.spaces.Box( 57 | shape=spec.shape, 58 | dtype=spec.dtype, 59 | low=spec.minimum if spec.minimum.ndim else np.full(spec.shape, spec.minimum), 60 | high=spec.maximum if spec.maximum.ndim else np.full(spec.shape, spec.maximum), 61 | ) 62 | elif isinstance(spec, self._specs.Array): 63 | return gymnasium.spaces.Box( 64 | shape=spec.shape, 65 | dtype=spec.dtype, 66 | low=np.full(spec.shape, float("-inf")), 67 | high=np.full(spec.shape, float("inf")), 68 | ) 69 | elif isinstance(spec, collections.OrderedDict): 70 | return gymnasium.spaces.Dict({k: self._spec_to_space(v) for k, v in spec.items()}) 71 | else: 72 | raise ValueError(f"Spec type {type(spec)} not supported. Please report this issue") 73 | 74 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 75 | """Perform a step in the environment 76 | 77 | :param actions: The actions to perform 78 | :type actions: torch.Tensor 79 | 80 | :return: Observation, reward, terminated, truncated, info 81 | :rtype: tuple of torch.Tensor and any other info 82 | """ 83 | actions = untensorize_space(self.action_space, unflatten_tensorized_space(self.action_space, actions)) 84 | timestep = self._env.step(actions) 85 | 86 | observation = flatten_tensorized_space( 87 | tensorize_space(self.observation_space, timestep.observation, device=self.device) 88 | ) 89 | reward = timestep.reward if timestep.reward is not None else 0 90 | terminated = timestep.last() 91 | truncated = False 92 | info = {} 93 | 94 | # convert response to torch 95 | return ( 96 | observation, 97 | torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1), 98 | torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), 99 | torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), 100 | info, 101 | ) 102 | 103 | def reset(self) -> Tuple[torch.Tensor, Any]: 104 | """Reset the environment 105 | 106 | :return: Observation, info 107 | :rtype: torch.Tensor and any other info 108 | """ 109 | timestep = self._env.reset() 110 | observation = flatten_tensorized_space( 111 | tensorize_space(self.observation_space, timestep.observation, device=self.device) 112 | ) 113 | return observation, {} 114 | 115 | def render(self, *args, **kwargs) -> np.ndarray: 116 | """Render the environment 117 | 118 | OpenCV is used to render the environment.
119 | Install OpenCV with ``pip install opencv-python`` 120 | """ 121 | frame = self._env.physics.render(480, 640, camera_id=0) 122 | 123 | # render the frame using OpenCV 124 | try: 125 | import cv2 126 | 127 | cv2.imshow("env", cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 128 | cv2.waitKey(1) 129 | except ImportError as e: 130 | logger.warning(f"Unable to import opencv-python: {e}. Frame will not be rendered.") 131 | return frame 132 | 133 | def close(self) -> None: 134 | """Close the environment""" 135 | self._env.close() 136 | -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/model_instantiators/jax/categorical.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Optional, Sequence, Tuple, Union 2 | 3 | import textwrap 4 | import gymnasium 5 | 6 | import flax.linen as nn # noqa 7 | import jax 8 | import jax.numpy as jnp # noqa 9 | 10 | from skrl.models.jax import CategoricalMixin # noqa 11 | from skrl.models.jax import Model # noqa 12 | from skrl.utils.model_instantiators.jax.common import one_hot_encoding # noqa 13 | from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers 14 | from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa 15 | 16 | 17 | def categorical_model( 18 | observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 19 | action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 20 | device: Optional[Union[str, jax.Device]] = None, 21 | unnormalized_log_prob: bool = True, 22 | network: Sequence[Mapping[str, Any]] = [], 23 | output: Union[str, Sequence[str]] = "", 24 | return_source: bool = False, 25 | *args, 26 | **kwargs, 27 | ) -> Union[Model, str]: 28 | """Instantiate a categorical model 29 | 30 | :param observation_space: Observation/state space or shape (default: None). 31 | If it is not None, the num_observations property will contain the size of that space 32 | :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional 33 | :param action_space: Action space or shape (default: None). 34 | If it is not None, the num_actions property will contain the size of that space 35 | :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional 36 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 37 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 38 | :type device: str or jax.Device, optional 39 | :param unnormalized_log_prob: Flag to indicate how the model's output should be interpreted (default: True). 40 | If True, the model's output is interpreted as unnormalized log probabilities 41 | (it can be any real number), otherwise as normalized probabilities 42 | (the output must be non-negative, finite and have a non-zero sum) 43 | :type unnormalized_log_prob: bool, optional 44 | :param network: Network definition (default: []) 45 | :type network: list of dict, optional 46 | :param output: Output expression (default: "") 47 | :type output: list or str, optional 48 | :param return_source: Whether to return the source string containing the model class used to 49 | instantiate the model rather than the model instance (default: False).
50 | :type return_source: bool, optional 51 | 52 | :return: Categorical model instance or definition source 53 | :rtype: Model 54 | """ 55 | # compatibility with versions prior to 1.3.0 56 | if not network and kwargs: 57 | network, output = convert_deprecated_parameters(kwargs) 58 | 59 | # parse model definition 60 | containers, output = generate_containers(network, output, embed_output=True, indent=1) 61 | 62 | # network definitions 63 | networks = [] 64 | forward: list[str] = [] 65 | for container in containers: 66 | networks.append(f'self.{container["name"]}_container = {container["sequential"]}') 67 | forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') 68 | # process output 69 | if output["modules"]: 70 | networks.append(f'self.output_layer = {output["modules"][0]}') 71 | forward.append(f'output = self.output_layer({container["name"]})') 72 | if output["output"]: 73 | forward.append(f'output = {output["output"]}') 74 | else: 75 | forward[-1] = forward[-1].replace(f'{container["name"]} =', "output =", 1) 76 | 77 | # build substitutions and indent content 78 | networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:] 79 | forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:] 80 | 81 | template = f"""class CategoricalModel(CategoricalMixin, Model): 82 | def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, **kwargs): 83 | Model.__init__(self, observation_space, action_space, device, **kwargs) 84 | CategoricalMixin.__init__(self, unnormalized_log_prob) 85 | 86 | def setup(self): 87 | {networks} 88 | 89 | def __call__(self, inputs, role): 90 | states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) 91 | taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) 92 | {forward} 93 | return output, {{}} 94 | """ 95 | # return source 96 | if return_source: 97 | return template 98 | 99 | # instantiate model 100 | _locals = {} 101 | exec(template, globals(), _locals) 102 | return _locals["CategoricalModel"]( 103 | observation_space=observation_space, 104 | action_space=action_space, 105 | device=device, 106 | unnormalized_log_prob=unnormalized_log_prob, 107 | ) 108 | -------------------------------------------------------------------------------- /src/third_parties/skrl/resources/optimizers/jax/adam.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import functools 4 | 5 | import flax 6 | import jax 7 | import optax 8 | 9 | from skrl.models.jax import Model 10 | 11 | 12 | # https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function 13 | @functools.partial(jax.jit, static_argnames=("transformation")) 14 | def _step(transformation, grad, state, state_dict): 15 | # optax transform 16 | params, optimizer_state = transformation.update(grad, state, state_dict.params) 17 | # apply transformation 18 | params = optax.apply_updates(state_dict.params, params) 19 | return optimizer_state, state_dict.replace(params=params) 20 | 21 | 22 | @functools.partial(jax.jit, static_argnames=("transformation")) 23 | def _step_with_scale(transformation, grad, state, state_dict, scale): 24 | # optax transform 25 | params, optimizer_state = transformation.update(grad, state, state_dict.params) 26 | # custom scale 27 | # https://optax.readthedocs.io/en/latest/api/transformations.html#optax.scale 28 | params = jax.tree_util.tree_map(lambda params: scale * params, 
params) 29 | # apply transformation 30 | params = optax.apply_updates(state_dict.params, params) 31 | return optimizer_state, state_dict.replace(params=params) 32 | 33 | 34 | class Adam: 35 | def __new__(cls, model: Model, lr: float = 1e-3, grad_norm_clip: float = 0, scale: bool = True) -> "Optimizer": 36 | """Adam optimizer 37 | 38 | Adapted from Optax's Adam optimizer 39 | to support custom scale (learning rate) 40 | 41 | :param model: Model 42 | :type model: skrl.models.jax.Model 43 | :param lr: Learning rate (default: ``1e-3``) 44 | :type lr: float, optional 45 | :param grad_norm_clip: Clipping coefficient for the norm of the gradients (default: ``0``). 46 | Disabled if less than or equal to zero 47 | :type grad_norm_clip: float, optional 48 | :param scale: Whether to instantiate the optimizer as-is or remove the scaling step (default: ``True``). 49 | Remove the scaling step if a custom learning rate is to be applied during optimization steps 50 | :type scale: bool, optional 51 | 52 | :return: Adam optimizer 53 | :rtype: flax.struct.PyTreeNode 54 | 55 | Example:: 56 | 57 | >>> optimizer = Adam(model=policy, lr=5e-4) 58 | >>> # step the optimizer given a computed gradient (grad) 59 | >>> optimizer = optimizer.step(grad, policy) 60 | 61 | # apply custom learning rate during optimization steps 62 | >>> optimizer = Adam(model=policy, lr=5e-4, scale=False) 63 | >>> # step the optimizer given a computed gradient and an updated learning rate (lr) 64 | >>> optimizer = optimizer.step(grad, policy, lr) 65 | """ 66 | 67 | class Optimizer(flax.struct.PyTreeNode): 68 | """Optimizer 69 | 70 | This class is the result of isolating the Optax optimizer, 71 | which is mixed with the model parameters, from Flax's TrainState class 72 | 73 | https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#train-state 74 | """ 75 | 76 | transformation: optax.GradientTransformation = flax.struct.field(pytree_node=False) 77 | state: optax.OptState = flax.struct.field(pytree_node=True) 78 | 79 | @classmethod 80 | def _create(cls, *, transformation, state, **kwargs): 81 | return cls(transformation=transformation, state=state, **kwargs) 82 | 83 | def step(self, grad: jax.Array, model: Model, lr: Optional[float] = None) -> "Optimizer": 84 | """Performs a single optimization step 85 | 86 | :param grad: Gradients 87 | :type grad: jax.Array 88 | :param model: Model 89 | :type model: skrl.models.jax.Model 90 | :param lr: Learning rate.
91 | If given, a scale optimization step will be performed 92 | :type lr: float, optional 93 | 94 | :return: Optimizer 95 | :rtype: flax.struct.PyTreeNode 96 | """ 97 | if lr is None: 98 | optimizer_state, model.state_dict = _step(self.transformation, grad, self.state, model.state_dict) 99 | else: 100 | optimizer_state, model.state_dict = _step_with_scale( 101 | self.transformation, grad, self.state, model.state_dict, -lr 102 | ) 103 | return self.replace(state=optimizer_state) 104 | 105 | # default optax transformation 106 | if scale: 107 | transformation = optax.adam(learning_rate=lr) 108 | # optax transformation without scaling step 109 | else: 110 | transformation = optax.scale_by_adam() 111 | 112 | # clip updates using their global norm 113 | if grad_norm_clip > 0: 114 | transformation = optax.chain(optax.clip_by_global_norm(grad_norm_clip), transformation) 115 | 116 | return Optimizer._create(transformation=transformation, state=transformation.init(model.state_dict.params)) 117 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/jax/gymnasium_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Union 2 | 3 | import gymnasium 4 | 5 | import jax 6 | import numpy as np 7 | 8 | from skrl import logger 9 | from skrl.envs.wrappers.jax.base import Wrapper 10 | from skrl.utils.spaces.jax import ( 11 | flatten_tensorized_space, 12 | tensorize_space, 13 | unflatten_tensorized_space, 14 | untensorize_space, 15 | ) 16 | 17 | 18 | class GymnasiumWrapper(Wrapper): 19 | def __init__(self, env: Any) -> None: 20 | """Gymnasium environment wrapper 21 | 22 | :param env: The environment to wrap 23 | :type env: Any supported Gymnasium environment 24 | """ 25 | super().__init__(env) 26 | 27 | self._vectorized = False 28 | try: 29 | self._vectorized = self._vectorized or isinstance(env, gymnasium.vector.VectorEnv) 30 | except Exception as e: 31 | pass 32 | try: 33 | self._vectorized = self._vectorized or isinstance(env, gymnasium.experimental.vector.VectorEnv) 34 | except Exception as e: 35 | logger.warning(f"Failed to check for a vectorized environment: {e}") 36 | if self._vectorized: 37 | self._reset_once = True 38 | self._observation = None 39 | self._info = None 40 | 41 | @property 42 | def observation_space(self) -> gymnasium.Space: 43 | """Observation space""" 44 | if self._vectorized: 45 | return self._env.single_observation_space 46 | return self._env.observation_space 47 | 48 | @property 49 | def action_space(self) -> gymnasium.Space: 50 | """Action space""" 51 | if self._vectorized: 52 | return self._env.single_action_space 53 | return self._env.action_space 54 | 55 | def step(self, actions: Union[np.ndarray, jax.Array]) -> Tuple[ 56 | Union[np.ndarray, jax.Array], 57 | Union[np.ndarray, jax.Array], 58 | Union[np.ndarray, jax.Array], 59 | Union[np.ndarray, jax.Array], 60 | Any, 61 | ]: 62 | """Perform a step in the environment 63 | 64 | :param actions: The actions to perform 65 | :type actions: np.ndarray or jax.Array 66 | 67 | :return: Observation, reward, terminated, truncated, info 68 | :rtype: tuple of np.ndarray or jax.Array and any other info 69 | """ 70 | if self._jax or isinstance(actions, jax.Array): 71 | actions = np.asarray(jax.device_get(actions)) 72 | actions = untensorize_space( 73 | self.action_space, 74 | unflatten_tensorized_space(self.action_space, actions), 75 | squeeze_batch_dimension=not self._vectorized, 76 | ) 77 | 78 | observation, 
reward, terminated, truncated, info = self._env.step(actions) 79 | 80 | # convert response to numpy or jax 81 | observation = flatten_tensorized_space( 82 | tensorize_space(self.observation_space, observation, device=self.device, _jax=False), _jax=False 83 | ) 84 | reward = np.array(reward, dtype=np.float32).reshape(self.num_envs, -1) 85 | terminated = np.array(terminated, dtype=np.int8).reshape(self.num_envs, -1) 86 | truncated = np.array(truncated, dtype=np.int8).reshape(self.num_envs, -1) 87 | if self._jax: 88 | observation = jax.device_put(observation, device=self.device) 89 | reward = jax.device_put(reward, device=self.device) 90 | terminated = jax.device_put(terminated, device=self.device) 91 | truncated = jax.device_put(truncated, device=self.device) 92 | 93 | # save observation and info for vectorized envs 94 | if self._vectorized: 95 | self._observation = observation 96 | self._info = info 97 | 98 | return observation, reward, terminated, truncated, info 99 | 100 | def reset(self) -> Tuple[Union[np.ndarray, jax.Array], Any]: 101 | """Reset the environment 102 | 103 | :return: Observation, info 104 | :rtype: np.ndarray or jax.Array and any other info 105 | """ 106 | # handle vectorized environments (vector environments are autoreset) 107 | if self._vectorized: 108 | if self._reset_once: 109 | observation, self._info = self._env.reset() 110 | self._observation = flatten_tensorized_space( 111 | tensorize_space(self.observation_space, observation, device=self.device, _jax=False), _jax=False 112 | ) 113 | if self._jax: 114 | self._observation = jax.device_put(self._observation, device=self.device) 115 | self._reset_once = False 116 | return self._observation, self._info 117 | 118 | observation, info = self._env.reset() 119 | 120 | # convert response to numpy or jax 121 | observation = flatten_tensorized_space( 122 | tensorize_space(self.observation_space, observation, device=self.device, _jax=False), _jax=False 123 | ) 124 | if self._jax: 125 | observation = jax.device_put(observation, device=self.device) 126 | return observation, info 127 | 128 | def render(self, *args, **kwargs) -> Any: 129 | """Render the environment""" 130 | if self._vectorized: 131 | return self._env.call("render", *args, **kwargs) 132 | return self._env.render(*args, **kwargs) 133 | 134 | def close(self) -> None: 135 | """Close the environment""" 136 | self._env.close() 137 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/isaacgym_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Union 2 | 3 | import gymnasium 4 | 5 | import torch 6 | 7 | from skrl.envs.wrappers.torch.base import Wrapper 8 | from skrl.utils.spaces.torch import ( 9 | convert_gym_space, 10 | flatten_tensorized_space, 11 | tensorize_space, 12 | unflatten_tensorized_space, 13 | ) 14 | 15 | 16 | class IsaacGymPreview2Wrapper(Wrapper): 17 | def __init__(self, env: Any) -> None: 18 | """Isaac Gym environment (preview 2) wrapper 19 | 20 | :param env: The environment to wrap 21 | :type env: Any supported Isaac Gym environment (preview 2) environment 22 | """ 23 | super().__init__(env) 24 | 25 | self._reset_once = True 26 | self._observations = None 27 | self._info = {} 28 | 29 | @property 30 | def observation_space(self) -> gymnasium.Space: 31 | """Observation space""" 32 | return convert_gym_space(self._unwrapped.observation_space) 33 | 34 | @property 35 | def action_space(self) -> gymnasium.Space: 36 | 
"""Action space""" 37 | return convert_gym_space(self._unwrapped.action_space) 38 | 39 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 40 | """Perform a step in the environment 41 | 42 | :param actions: The actions to perform 43 | :type actions: torch.Tensor 44 | 45 | :return: Observation, reward, terminated, truncated, info 46 | :rtype: tuple of torch.Tensor and any other info 47 | """ 48 | observations, reward, terminated, self._info = self._env.step( 49 | unflatten_tensorized_space(self.action_space, actions) 50 | ) 51 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations)) 52 | truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) 53 | return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info 54 | 55 | def reset(self) -> Tuple[torch.Tensor, Any]: 56 | """Reset the environment 57 | 58 | :return: Observation, info 59 | :rtype: torch.Tensor and any other info 60 | """ 61 | if self._reset_once: 62 | observations = self._env.reset() 63 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations)) 64 | self._reset_once = False 65 | return self._observations, self._info 66 | 67 | def render(self, *args, **kwargs) -> None: 68 | """Render the environment""" 69 | return None 70 | 71 | def close(self) -> None: 72 | """Close the environment""" 73 | pass 74 | 75 | 76 | class IsaacGymPreview3Wrapper(Wrapper): 77 | def __init__(self, env: Any) -> None: 78 | """Isaac Gym environment (preview 3) wrapper 79 | 80 | :param env: The environment to wrap 81 | :type env: Any supported Isaac Gym environment (preview 3) environment 82 | """ 83 | super().__init__(env) 84 | 85 | self._reset_once = True 86 | self._observations = None 87 | self._info = {} 88 | 89 | @property 90 | def observation_space(self) -> gymnasium.Space: 91 | """Observation space""" 92 | return convert_gym_space(self._unwrapped.observation_space) 93 | 94 | @property 95 | def action_space(self) -> gymnasium.Space: 96 | """Action space""" 97 | return convert_gym_space(self._unwrapped.action_space) 98 | 99 | @property 100 | def state_space(self) -> Union[gymnasium.Space, None]: 101 | """State space""" 102 | try: 103 | if self.num_states: 104 | return convert_gym_space(self._unwrapped.state_space) 105 | except: 106 | pass 107 | return None 108 | 109 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 110 | """Perform a step in the environment 111 | 112 | :param actions: The actions to perform 113 | :type actions: torch.Tensor 114 | 115 | :return: Observation, reward, terminated, truncated, info 116 | :rtype: tuple of torch.Tensor and any other info 117 | """ 118 | observations, reward, terminated, self._info = self._env.step( 119 | unflatten_tensorized_space(self.action_space, actions) 120 | ) 121 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) 122 | truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated) 123 | return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info 124 | 125 | def reset(self) -> Tuple[torch.Tensor, Any]: 126 | """Reset the environment 127 | 128 | :return: Observation, info 129 | :rtype: torch.Tensor and any other info 130 | """ 131 | if self._reset_once: 132 | observations = 
self._env.reset() 133 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["obs"])) 134 | self._reset_once = False 135 | return self._observations, self._info 136 | 137 | def render(self, *args, **kwargs) -> None: 138 | """Render the environment""" 139 | return None 140 | 141 | def close(self) -> None: 142 | """Close the environment""" 143 | pass 144 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """Launch Isaac Sim Simulator first.""" 2 | 3 | import argparse 4 | 5 | from omni.isaac.lab.app import AppLauncher 6 | 7 | # add argparse arguments 8 | parser = argparse.ArgumentParser(description="Example on using the contact sensor.") 9 | parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.") 10 | # append AppLauncher cli args 11 | AppLauncher.add_app_launcher_args(parser) 12 | # parse the arguments 13 | args_cli = parser.parse_args() 14 | 15 | # launch omniverse app 16 | app_launcher = AppLauncher(args_cli) 17 | simulation_app = app_launcher.app 18 | 19 | """Rest everything follows.""" 20 | 21 | import torch 22 | 23 | import omni.isaac.lab.sim as sim_utils 24 | from omni.isaac.lab.assets import AssetBaseCfg, RigidObjectCfg 25 | from omni.isaac.lab.scene import InteractiveScene, InteractiveSceneCfg 26 | from omni.isaac.lab.sensors import ContactSensorCfg 27 | from omni.isaac.lab.utils import configclass 28 | 29 | ## 30 | # Pre-defined configs 31 | ## 32 | from omni.isaac.lab_assets import ANYMAL_C_CFG # isort: skip 33 | 34 | 35 | @configclass 36 | class ContactSensorSceneCfg(InteractiveSceneCfg): 37 | """Design the scene with sensors on the robot.""" 38 | 39 | # ground plane 40 | ground = AssetBaseCfg(prim_path="/World/defaultGroundPlane", spawn=sim_utils.GroundPlaneCfg()) 41 | 42 | # lights 43 | dome_light = AssetBaseCfg( 44 | prim_path="/World/Light", spawn=sim_utils.DomeLightCfg(intensity=3000.0, color=(0.75, 0.75, 0.75)) 45 | ) 46 | 47 | # robot 48 | robot = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") 49 | 50 | # Rigid Object 51 | cube = RigidObjectCfg( 52 | prim_path="{ENV_REGEX_NS}/Cube", 53 | spawn=sim_utils.CuboidCfg( 54 | size=(0.5, 0.5, 0.1), 55 | rigid_props=sim_utils.RigidBodyPropertiesCfg(), 56 | mass_props=sim_utils.MassPropertiesCfg(mass=100.0), 57 | collision_props=sim_utils.CollisionPropertiesCfg(), 58 | physics_material=sim_utils.RigidBodyMaterialCfg(static_friction=1.0), 59 | visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.0, 1.0, 0.0), metallic=0.2), 60 | ), 61 | init_state=RigidObjectCfg.InitialStateCfg(pos=(0.5, 0.5, 0.05)), 62 | ) 63 | 64 | contact_forces_H = ContactSensorCfg( 65 | prim_path="{ENV_REGEX_NS}/Robot/.*H_FOOT", 66 | update_period=0.0, 67 | history_length=6, 68 | debug_vis=True, 69 | ) 70 | 71 | 72 | def run_simulator(sim: sim_utils.SimulationContext, scene: InteractiveScene): 73 | """Run the simulator.""" 74 | # Define simulation stepping 75 | sim_dt = sim.get_physics_dt() 76 | sim_time = 0.0 77 | count = 0 78 | 79 | # Simulate physics 80 | while simulation_app.is_running(): 81 | 82 | if count % 500 == 0: 83 | # reset counter 84 | count = 0 85 | # reset the scene entities 86 | # root state 87 | # we offset the root state by the origin since the states are written in simulation world frame 88 | # if this is not done, then the robots will be spawned at the (0, 0, 0) of the simulation world 89 | root_state = 
scene["robot"].data.default_root_state.clone() 90 | root_state[:, :3] += scene.env_origins 91 | scene["robot"].write_root_pose_to_sim(root_state[:, :7]) 92 | scene["robot"].write_root_velocity_to_sim(root_state[:, 7:]) 93 | # set joint positions with some noise 94 | joint_pos, joint_vel = ( 95 | scene["robot"].data.default_joint_pos.clone(), 96 | scene["robot"].data.default_joint_vel.clone(), 97 | ) 98 | joint_pos += torch.rand_like(joint_pos) * 0.1 99 | scene["robot"].write_joint_state_to_sim(joint_pos, joint_vel) 100 | # clear internal buffers 101 | scene.reset() 102 | print("[INFO]: Resetting robot state...") 103 | # Apply default actions to the robot 104 | # -- generate actions/commands 105 | targets = scene["robot"].data.default_joint_pos 106 | # -- apply action to the robot 107 | scene["robot"].set_joint_position_target(targets) 108 | # -- write data to sim 109 | scene.write_data_to_sim() 110 | # perform step 111 | sim.step() 112 | # update sim-time 113 | sim_time += sim_dt 114 | count += 1 115 | # update buffers 116 | scene.update(sim_dt) 117 | 118 | # print information from the sensors 119 | print(scene["contact_forces_H"]) 120 | print("Received force matrix of: ", scene["contact_forces_H"].data.force_matrix_w) 121 | print("Received contact force of: ", scene["contact_forces_H"].data.net_forces_w) 122 | 123 | def main(): 124 | """Main function.""" 125 | 126 | # Initialize the simulation context 127 | sim_cfg = sim_utils.SimulationCfg(dt=0.005, device=args_cli.device) 128 | sim = sim_utils.SimulationContext(sim_cfg) 129 | # Set main camera 130 | sim.set_camera_view(eye=[3.5, 3.5, 3.5], target=[0.0, 0.0, 0.0]) 131 | # design scene 132 | scene_cfg = ContactSensorSceneCfg(num_envs=args_cli.num_envs, env_spacing=2.0) 133 | scene = InteractiveScene(scene_cfg) 134 | # Play the simulator 135 | sim.reset() 136 | # Now we are ready! 137 | print("[INFO]: Setup complete...") 138 | # Run the simulator 139 | run_simulator(sim, scene) 140 | 141 | 142 | if __name__ == "__main__": 143 | # run the main function 144 | main() 145 | # close sim app 146 | simulation_app.close() -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/model_instantiators/torch/multicategorical.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Optional, Sequence, Tuple, Union 2 | 3 | import textwrap 4 | import gymnasium 5 | 6 | import torch 7 | import torch.nn as nn # noqa 8 | 9 | from skrl.models.torch import MultiCategoricalMixin # noqa 10 | from skrl.models.torch import Model 11 | from skrl.utils.model_instantiators.torch.common import one_hot_encoding # noqa 12 | from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers 13 | from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa 14 | 15 | 16 | def multicategorical_model( 17 | observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 18 | action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 19 | device: Optional[Union[str, torch.device]] = None, 20 | unnormalized_log_prob: bool = True, 21 | reduction: str = "sum", 22 | network: Sequence[Mapping[str, Any]] = [], 23 | output: Union[str, Sequence[str]] = "", 24 | return_source: bool = False, 25 | *args, 26 | **kwargs, 27 | ) -> Union[Model, str]: 28 | """Instantiate a multi-categorical model 29 | 30 | :param observation_space: Observation/state space or shape (default: None). 
31 | If it is not None, the num_observations property will contain the size of that space 32 | :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional 33 | :param action_space: Action space or shape (default: None). 34 | If it is not None, the num_actions property will contain the size of that space 35 | :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional 36 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 37 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 38 | :type device: str or torch.device, optional 39 | :param unnormalized_log_prob: Flag to indicate how the model's output is to be interpreted (default: True). 40 | If True, the model's output is interpreted as unnormalized log probabilities 41 | (it can be any real number), otherwise as normalized probabilities 42 | (the output must be non-negative, finite and have a non-zero sum) 43 | :type unnormalized_log_prob: bool, optional 44 | :param reduction: Reduction method for returning the log probability density function (default: ``"sum"``). 45 | Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log probability density 46 | function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)`` 47 | :type reduction: str, optional 48 | :param network: Network definition (default: []) 49 | :type network: list of dict, optional 50 | :param output: Output expression (default: "") 51 | :type output: list or str, optional 52 | :param return_source: Whether to return the source string containing the model class used to 53 | instantiate the model rather than the model instance (default: False). 54 | :type return_source: bool, optional 55 | 56 | :return: Multi-Categorical model instance or definition source 57 | :rtype: Model 58 | """ 59 | # compatibility with versions prior to 1.3.0 60 | if not network and kwargs: 61 | network, output = convert_deprecated_parameters(kwargs) 62 | 63 | # parse model definition 64 | containers, output = generate_containers(network, output, embed_output=True, indent=1) 65 | 66 | # network definitions 67 | networks = [] 68 | forward: list[str] = [] 69 | for container in containers: 70 | networks.append(f'self.{container["name"]}_container = {container["sequential"]}') 71 | forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') 72 | # process output 73 | if output["modules"]: 74 | networks.append(f'self.output_layer = {output["modules"][0]}') 75 | forward.append(f'output = self.output_layer({container["name"]})') 76 | if output["output"]: 77 | forward.append(f'output = {output["output"]}') 78 | else: 79 | forward[-1] = forward[-1].replace(f'{container["name"]} =', "output =", 1) 80 | 81 | # build substitutions and indent content 82 | networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:] 83 | forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:] 84 | 85 | template = f"""class MultiCategoricalModel(MultiCategoricalMixin, Model): 86 | def __init__(self, observation_space, action_space, device, unnormalized_log_prob, reduction="sum"): 87 | Model.__init__(self, observation_space, action_space, device) 88 | MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction) 89 | 90 | {networks} 91 | 92 | def compute(self, inputs, role=""): 93 | states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) 94 | taken_actions = 
unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) 95 | {forward} 96 | return output, {{}} 97 | """ 98 | # return source 99 | if return_source: 100 | return template 101 | 102 | # instantiate model 103 | _locals = {} 104 | exec(template, globals(), _locals) 105 | return _locals["MultiCategoricalModel"]( 106 | observation_space=observation_space, 107 | action_space=action_space, 108 | device=device, 109 | unnormalized_log_prob=unnormalized_log_prob, 110 | reduction=reduction, 111 | ) 112 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/gym_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import gymnasium 4 | from packaging import version 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from skrl import logger 10 | from skrl.envs.wrappers.torch.base import Wrapper 11 | from skrl.utils.spaces.torch import ( 12 | convert_gym_space, 13 | flatten_tensorized_space, 14 | tensorize_space, 15 | unflatten_tensorized_space, 16 | untensorize_space, 17 | ) 18 | 19 | 20 | class GymWrapper(Wrapper): 21 | def __init__(self, env: Any) -> None: 22 | """OpenAI Gym environment wrapper 23 | 24 | :param env: The environment to wrap 25 | :type env: Any supported OpenAI Gym environment 26 | """ 27 | super().__init__(env) 28 | 29 | # hack to fix: module 'numpy' has no attribute 'bool8' 30 | try: 31 | np.bool8 32 | except AttributeError: 33 | np.bool8 = np.bool 34 | 35 | import gym 36 | 37 | self._vectorized = False 38 | try: 39 | if isinstance(env, gym.vector.VectorEnv): 40 | self._vectorized = True 41 | self._reset_once = True 42 | self._observation = None 43 | self._info = None 44 | except Exception as e: 45 | logger.warning(f"Failed to check for a vectorized environment: {e}") 46 | 47 | self._deprecated_api = version.parse(gym.__version__) < version.parse("0.25.0") 48 | if self._deprecated_api: 49 | logger.warning(f"Using a deprecated version of OpenAI Gym's API: {gym.__version__}") 50 | 51 | @property 52 | def observation_space(self) -> gymnasium.Space: 53 | """Observation space""" 54 | if self._vectorized: 55 | return convert_gym_space(self._env.single_observation_space) 56 | return convert_gym_space(self._env.observation_space) 57 | 58 | @property 59 | def action_space(self) -> gymnasium.Space: 60 | """Action space""" 61 | if self._vectorized: 62 | return convert_gym_space(self._env.single_action_space) 63 | return convert_gym_space(self._env.action_space) 64 | 65 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 66 | """Perform a step in the environment 67 | 68 | :param actions: The actions to perform 69 | :type actions: torch.Tensor 70 | 71 | :return: Observation, reward, terminated, truncated, info 72 | :rtype: tuple of torch.Tensor and any other info 73 | """ 74 | actions = untensorize_space( 75 | self.action_space, 76 | unflatten_tensorized_space(self.action_space, actions), 77 | squeeze_batch_dimension=not self._vectorized, 78 | ) 79 | 80 | if self._deprecated_api: 81 | observation, reward, terminated, info = self._env.step(actions) 82 | # truncated: https://gymnasium.farama.org/tutorials/handling_time_limits 83 | if isinstance(info, (tuple, list)): 84 | truncated = np.array([d.get("TimeLimit.truncated", False) for d in info], dtype=terminated.dtype) 85 | terminated *= np.logical_not(truncated) 86 | info = {} 87 | else: 88 | truncated = 
info.get("TimeLimit.truncated", False) 89 | if truncated: 90 | terminated = False 91 | else: 92 | observation, reward, terminated, truncated, info = self._env.step(actions) 93 | 94 | # convert response to torch 95 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, device=self.device)) 96 | reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1) 97 | terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) 98 | truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1) 99 | 100 | # save observation and info for vectorized envs 101 | if self._vectorized: 102 | self._observation = observation 103 | self._info = info 104 | 105 | return observation, reward, terminated, truncated, info 106 | 107 | def reset(self) -> Tuple[torch.Tensor, Any]: 108 | """Reset the environment 109 | 110 | :return: Observation, info 111 | :rtype: torch.Tensor and any other info 112 | """ 113 | # handle vectorized environments (vector environments are autoreset) 114 | if self._vectorized: 115 | if self._reset_once: 116 | if self._deprecated_api: 117 | observation = self._env.reset() 118 | self._info = {} 119 | else: 120 | observation, self._info = self._env.reset() 121 | self._observation = flatten_tensorized_space( 122 | tensorize_space(self.observation_space, observation, device=self.device) 123 | ) 124 | self._reset_once = False 125 | return self._observation, self._info 126 | 127 | if self._deprecated_api: 128 | observation = self._env.reset() 129 | info = {} 130 | else: 131 | observation, info = self._env.reset() 132 | observation = flatten_tensorized_space(tensorize_space(self.observation_space, observation, device=self.device)) 133 | return observation, info 134 | 135 | def render(self, *args, **kwargs) -> Any: 136 | """Render the environment""" 137 | if self._vectorized: 138 | return None 139 | return self._env.render(*args, **kwargs) 140 | 141 | def close(self) -> None: 142 | """Close the environment""" 143 | self._env.close() 144 | -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/model_instantiators/jax/multicategorical.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Optional, Sequence, Tuple, Union 2 | 3 | import textwrap 4 | import gymnasium 5 | 6 | import flax.linen as nn # noqa 7 | import jax 8 | import jax.numpy as jnp # noqa 9 | 10 | from skrl.models.jax import MultiCategoricalMixin # noqa 11 | from skrl.models.jax import Model 12 | from skrl.utils.model_instantiators.jax.common import one_hot_encoding # noqa 13 | from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers 14 | from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa 15 | 16 | 17 | def multicategorical_model( 18 | observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 19 | action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 20 | device: Optional[Union[str, jax.Device]] = None, 21 | unnormalized_log_prob: bool = True, 22 | reduction: str = "sum", 23 | network: Sequence[Mapping[str, Any]] = [], 24 | output: Union[str, Sequence[str]] = "", 25 | return_source: bool = False, 26 | *args, 27 | **kwargs, 28 | ) -> Union[Model, str]: 29 | """Instantiate a multi-categorical model 30 | 31 | :param observation_space: Observation/state space or shape 
(default: None). 31 | If it is not None, the num_observations property will contain the size of that space 32 | :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional 33 | :param action_space: Action space or shape (default: None). 34 | If it is not None, the num_actions property will contain the size of that space 35 | :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional 36 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 37 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 38 | :type device: str or jax.Device, optional 39 | :param unnormalized_log_prob: Flag to indicate how the model's output is to be interpreted (default: True). 40 | If True, the model's output is interpreted as unnormalized log probabilities 41 | (it can be any real number), otherwise as normalized probabilities 42 | (the output must be non-negative, finite and have a non-zero sum) 43 | :type unnormalized_log_prob: bool, optional 44 | :param reduction: Reduction method for returning the log probability density function (default: ``"sum"``). 45 | Supported values are ``"mean"``, ``"sum"``, ``"prod"`` and ``"none"``. If ``"none"``, the log probability density 46 | function is returned as a tensor of shape ``(num_samples, num_actions)`` instead of ``(num_samples, 1)`` 47 | :type reduction: str, optional 48 | :param network: Network definition (default: []) 49 | :type network: list of dict, optional 50 | :param output: Output expression (default: "") 51 | :type output: list or str, optional 52 | :param return_source: Whether to return the source string containing the model class used to 53 | instantiate the model rather than the model instance (default: False). 
55 | :type return_source: bool, optional 56 | 57 | :return: Multi-Categorical model instance or definition source 58 | :rtype: Model 59 | """ 60 | # compatibility with versions prior to 1.3.0 61 | if not network and kwargs: 62 | network, output = convert_deprecated_parameters(kwargs) 63 | 64 | # parse model definition 65 | containers, output = generate_containers(network, output, embed_output=True, indent=1) 66 | 67 | # network definitions 68 | networks = [] 69 | forward: list[str] = [] 70 | for container in containers: 71 | networks.append(f'self.{container["name"]}_container = {container["sequential"]}') 72 | forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') 73 | # process output 74 | if output["modules"]: 75 | networks.append(f'self.output_layer = {output["modules"][0]}') 76 | forward.append(f'output = self.output_layer({container["name"]})') 77 | if output["output"]: 78 | forward.append(f'output = {output["output"]}') 79 | else: 80 | forward[-1] = forward[-1].replace(f'{container["name"]} =', "output =", 1) 81 | 82 | # build substitutions and indent content 83 | networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:] 84 | forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:] 85 | 86 | template = f"""class MultiCategoricalModel(MultiCategoricalMixin, Model): 87 | def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True, reduction="sum", **kwargs): 88 | Model.__init__(self, observation_space, action_space, device, **kwargs) 89 | MultiCategoricalMixin.__init__(self, unnormalized_log_prob, reduction) 90 | 91 | def setup(self): 92 | {networks} 93 | 94 | def __call__(self, inputs, role): 95 | states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) 96 | taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) 97 | {forward} 98 | return output, {{}} 99 | """ 100 | # return source 101 | if return_source: 102 | return template 103 | 104 | # instantiate model 105 | _locals = {} 106 | exec(template, globals(), _locals) 107 | return _locals["MultiCategoricalModel"]( 108 | observation_space=observation_space, 109 | action_space=action_space, 110 | device=device, 111 | unnormalized_log_prob=unnormalized_log_prob, 112 | reduction=reduction, 113 | ) 114 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/robosuite_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple 2 | 3 | import collections 4 | import gymnasium 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from skrl.envs.wrappers.torch.base import Wrapper 10 | from skrl.utils.spaces.torch import convert_gym_space 11 | 12 | 13 | class RobosuiteWrapper(Wrapper): 14 | def __init__(self, env: Any) -> None: 15 | """Robosuite environment wrapper 16 | 17 | :param env: The environment to wrap 18 | :type env: Any supported robosuite environment 19 | """ 20 | super().__init__(env) 21 | 22 | # observation and action spaces 23 | self._observation_space = self._spec_to_space(self._env.observation_spec()) 24 | self._action_space = self._spec_to_space(self._env.action_spec) 25 | 26 | @property 27 | def state_space(self) -> gymnasium.Space: 28 | """State space 29 | 30 | An alias for the ``observation_space`` property 31 | """ 32 | return convert_gym_space(self._observation_space) 33 | 34 | @property 35 | def observation_space(self) -> 
gymnasium.Space: 36 | """Observation space""" 37 | return convert_gym_space(self._observation_space) 38 | 39 | @property 40 | def action_space(self) -> gymnasium.Space: 41 | """Action space""" 42 | return convert_gym_space(self._action_space) 43 | 44 | def _spec_to_space(self, spec: Any) -> gymnasium.Space: 45 | """Convert the robosuite spec to a Gym space 46 | 47 | :param spec: The robosuite spec to convert 48 | :type spec: Any supported robosuite spec 49 | 50 | :raises: ValueError if the spec type is not supported 51 | 52 | :return: The Gym space 53 | :rtype: gymnasium.Space 54 | """ 55 | if type(spec) is tuple: 56 | return gymnasium.spaces.Box(shape=spec[0].shape, dtype=np.float32, low=spec[0], high=spec[1]) 57 | elif isinstance(spec, np.ndarray): 58 | return gymnasium.spaces.Box( 59 | shape=spec.shape, 60 | dtype=np.float32, 61 | low=np.full(spec.shape, float("-inf")), 62 | high=np.full(spec.shape, float("inf")), 63 | ) 64 | elif isinstance(spec, collections.OrderedDict): 65 | return gymnasium.spaces.Dict({k: self._spec_to_space(v) for k, v in spec.items()}) 66 | else: 67 | raise ValueError(f"Spec type {type(spec)} not supported. Please report this issue") 68 | 69 | def _observation_to_tensor(self, observation: Any, spec: Optional[Any] = None) -> torch.Tensor: 70 | """Convert the observation to a flat tensor 71 | 72 | :param observation: The observation to convert to a tensor 73 | :type observation: Any supported observation 74 | 75 | :raises: ValueError if the observation spec type is not supported 76 | 77 | :return: The observation as a flat tensor 78 | :rtype: torch.Tensor 79 | """ 80 | spec = spec if spec is not None else self._env.observation_spec() 81 | 82 | if isinstance(spec, np.ndarray): 83 | return torch.tensor(observation, device=self.device, dtype=torch.float32).reshape(self.num_envs, -1) 84 | elif isinstance(spec, collections.OrderedDict): 85 | return torch.cat( 86 | [self._observation_to_tensor(observation[k], spec[k]) for k in sorted(spec.keys())], dim=-1 87 | ).reshape(self.num_envs, -1) 88 | else: 89 | raise ValueError(f"Observation spec type {type(spec)} not supported. Please report this issue") 90 | 91 | def _tensor_to_action(self, actions: torch.Tensor) -> Any: 92 | """Convert the action to the robosuite expected format 93 | 94 | :param actions: The actions to perform 95 | :type actions: torch.Tensor 96 | 97 | :raise ValueError: If the action space type is not supported 98 | 99 | :return: The action in the robosuite expected format 100 | :rtype: Any supported robosuite action 101 | """ 102 | spec = self._env.action_spec 103 | 104 | if type(spec) is tuple: 105 | return np.array(actions.cpu().numpy(), dtype=np.float32).reshape(spec[0].shape) 106 | else: 107 | raise ValueError(f"Action spec type {type(spec)} not supported. 
Please report this issue") 108 | 109 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 110 | """Perform a step in the environment 111 | 112 | :param actions: The actions to perform 113 | :type actions: torch.Tensor 114 | 115 | :return: Observation, reward, terminated, truncated, info 116 | :rtype: tuple of torch.Tensor and any other info 117 | """ 118 | observation, reward, terminated, info = self._env.step(self._tensor_to_action(actions)) 119 | truncated = False 120 | info = {} 121 | 122 | # convert response to torch 123 | return ( 124 | self._observation_to_tensor(observation), 125 | torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1), 126 | torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), 127 | torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1), 128 | info, 129 | ) 130 | 131 | def reset(self) -> Tuple[torch.Tensor, Any]: 132 | """Reset the environment 133 | 134 | :return: The state of the environment 135 | :rtype: torch.Tensor 136 | """ 137 | observation = self._env.reset() 138 | return self._observation_to_tensor(observation), {} 139 | 140 | def render(self, *args, **kwargs) -> None: 141 | """Render the environment""" 142 | self._env.render(*args, **kwargs) 143 | 144 | def close(self) -> None: 145 | """Close the environment""" 146 | self._env.close() 147 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/jax/bidexhands_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Sequence, Tuple, Union 2 | 3 | import gymnasium 4 | 5 | import jax 6 | import jax.dlpack as jax_dlpack 7 | import numpy as np 8 | 9 | 10 | try: 11 | import torch 12 | import torch.utils.dlpack as torch_dlpack 13 | except: 14 | pass # TODO: show warning message 15 | 16 | from skrl.envs.wrappers.jax.base import MultiAgentEnvWrapper 17 | from skrl.utils.spaces.jax import convert_gym_space 18 | 19 | 20 | # ML frameworks conversion utilities 21 | # jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: DLPack tensor is on GPU, but no GPU backend was provided. 22 | _CPU = jax.devices()[0].device_kind.lower() == "cpu" 23 | 24 | 25 | def _jax2torch(array, device, from_jax=True): 26 | if from_jax: 27 | return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(array)).to(device=device) 28 | return torch.tensor(array, device=device) 29 | 30 | 31 | def _torch2jax(tensor, to_jax=True): 32 | if to_jax: 33 | return jax_dlpack.from_dlpack( 34 | torch_dlpack.to_dlpack(tensor.contiguous().cpu() if _CPU else tensor.contiguous()) 35 | ) 36 | return tensor.cpu().numpy() 37 | 38 | 39 | class BiDexHandsWrapper(MultiAgentEnvWrapper): 40 | def __init__(self, env: Any) -> None: 41 | """Bi-DexHands wrapper 42 | 43 | :param env: The environment to wrap 44 | :type env: Any supported Bi-DexHands environment 45 | """ 46 | super().__init__(env) 47 | 48 | self._reset_once = True 49 | self._states = None 50 | self._observations = None 51 | self._info = {} 52 | 53 | @property 54 | def agents(self) -> Sequence[str]: 55 | """Names of all current agents 56 | 57 | These may be changed as an environment progresses (i.e. 
agents can be added or removed) 58 | """ 59 | return self.possible_agents 60 | 61 | @property 62 | def possible_agents(self) -> Sequence[str]: 63 | """Names of all possible agents the environment could generate 64 | 65 | These cannot be changed as an environment progresses 66 | """ 67 | return [f"agent_{i}" for i in range(self.num_agents)] 68 | 69 | @property 70 | def state_spaces(self) -> Mapping[str, gymnasium.Space]: 71 | """State spaces 72 | 73 | Since the state space is a global view of the environment (and therefore the same for all the agents), 74 | this property returns a dictionary (for consistency with the other space-related properties) with the same 75 | space for all the agents 76 | """ 77 | return { 78 | uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.share_observation_space) 79 | } 80 | 81 | @property 82 | def observation_spaces(self) -> Mapping[str, gymnasium.Space]: 83 | """Observation spaces""" 84 | return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.observation_space)} 85 | 86 | @property 87 | def action_spaces(self) -> Mapping[str, gymnasium.Space]: 88 | """Action spaces""" 89 | return {uid: convert_gym_space(space) for uid, space in zip(self.possible_agents, self._env.action_space)} 90 | 91 | def step(self, actions: Mapping[str, Union[np.ndarray, jax.Array]]) -> Tuple[ 92 | Mapping[str, Union[np.ndarray, jax.Array]], 93 | Mapping[str, Union[np.ndarray, jax.Array]], 94 | Mapping[str, Union[np.ndarray, jax.Array]], 95 | Mapping[str, Union[np.ndarray, jax.Array]], 96 | Mapping[str, Any], 97 | ]: 98 | """Perform a step in the environment 99 | 100 | :param actions: The actions to perform 101 | :type actions: dict of np.ndarray or jax.Array 102 | 103 | :return: Observation, reward, terminated, truncated, info 104 | :rtype: tuple of dict of np.ndarray or jax.Array and any other info 105 | """ 106 | actions = [_jax2torch(actions[uid], self._env.rl_device, self._jax) for uid in self.possible_agents] 107 | 108 | with torch.no_grad(): 109 | observations, states, rewards, terminated, _, _ = self._env.step(actions) 110 | 111 | observations = _torch2jax(observations, self._jax) 112 | states = _torch2jax(states, self._jax) 113 | rewards = _torch2jax(rewards, self._jax) 114 | terminated = _torch2jax(terminated.to(dtype=torch.int8), self._jax) 115 | 116 | self._states = states[:, 0] 117 | self._observations = {uid: observations[:, i] for i, uid in enumerate(self.possible_agents)} 118 | rewards = {uid: rewards[:, i].reshape(-1, 1) for i, uid in enumerate(self.possible_agents)} 119 | terminated = {uid: terminated[:, i].reshape(-1, 1) for i, uid in enumerate(self.possible_agents)} 120 | truncated = terminated 121 | 122 | return self._observations, rewards, terminated, truncated, self._info 123 | 124 | def state(self) -> Union[np.ndarray, jax.Array]: 125 | """Get the environment state 126 | 127 | :return: State 128 | :rtype: np.ndarray or jax.Array 129 | """ 130 | return self._states 131 | 132 | def reset(self) -> Tuple[Mapping[str, Union[np.ndarray, jax.Array]], Mapping[str, Any]]: 133 | """Reset the environment 134 | 135 | :return: Observation, info 136 | :rtype: tuple of dict of np.ndarray or jax.Array and any other info 137 | """ 138 | if self._reset_once: 139 | observations, states, _ = self._env.reset() 140 | 141 | observations = _torch2jax(observations, self._jax) 142 | states = _torch2jax(states, self._jax) 143 | 144 | self._states = states[:, 0] 145 | self._observations = {uid: observations[:, i] for i, 
uid in enumerate(self.possible_agents)} 146 | self._reset_once = False 147 | return self._observations, self._info 148 | 149 | def render(self, *args, **kwargs) -> None: 150 | """Render the environment""" 151 | return None 152 | 153 | def close(self) -> None: 154 | """Close the environment""" 155 | pass 156 | -------------------------------------------------------------------------------- /src/third_parties/skrl/utils/model_instantiators/torch/multivariate_gaussian.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Optional, Sequence, Tuple, Union 2 | 3 | import textwrap 4 | import gymnasium 5 | 6 | import torch 7 | import torch.nn as nn # noqa 8 | 9 | from skrl.models.torch import MultivariateGaussianMixin # noqa 10 | from skrl.models.torch import Model 11 | from skrl.utils.model_instantiators.torch.common import one_hot_encoding # noqa 12 | from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers 13 | from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa 14 | 15 | 16 | def multivariate_gaussian_model( 17 | observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 18 | action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None, 19 | device: Optional[Union[str, torch.device]] = None, 20 | clip_actions: bool = False, 21 | clip_log_std: bool = True, 22 | min_log_std: float = -20, 23 | max_log_std: float = 2, 24 | initial_log_std: float = 0, 25 | fixed_log_std: bool = False, 26 | network: Sequence[Mapping[str, Any]] = [], 27 | output: Union[str, Sequence[str]] = "", 28 | return_source: bool = False, 29 | *args, 30 | **kwargs, 31 | ) -> Union[Model, str]: 32 | """Instantiate a multivariate Gaussian model 33 | 34 | :param observation_space: Observation/state space or shape (default: None). 35 | If it is not None, the num_observations property will contain the size of that space 36 | :type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional 37 | :param action_space: Action space or shape (default: None). 38 | If it is not None, the num_actions property will contain the size of that space 39 | :type action_space: int, tuple or list of integers, gymnasium.Space or None, optional 40 | :param device: Device on which a tensor/array is or will be allocated (default: ``None``). 41 | If None, the device will be either ``"cuda"`` if available or ``"cpu"`` 42 | :type device: str or torch.device, optional 43 | :param clip_actions: Flag to indicate whether the actions should be clipped (default: False) 44 | :type clip_actions: bool, optional 45 | :param clip_log_std: Flag to indicate whether the log standard deviations should be clipped (default: True) 46 | :type clip_log_std: bool, optional 47 | :param min_log_std: Minimum value of the log standard deviation (default: -20) 48 | :type min_log_std: float, optional 49 | :param max_log_std: Maximum value of the log standard deviation (default: 2) 50 | :type max_log_std: float, optional 51 | :param initial_log_std: Initial value for the log standard deviation (default: 0) 52 | :type initial_log_std: float, optional 53 | :param fixed_log_std: Whether the log standard deviation parameter should be fixed (default: False). 
54 | Fixed parameters have the gradient computation deactivated 55 | :type fixed_log_std: bool, optional 56 | :param network: Network definition (default: []) 57 | :type network: list of dict, optional 58 | :param output: Output expression (default: "") 59 | :type output: list or str, optional 60 | :param return_source: Whether to return the source string containing the model class used to 61 | instantiate the model rather than the model instance (default: False). 62 | :type return_source: bool, optional 63 | 64 | :return: Multivariate Gaussian model instance or definition source 65 | :rtype: Model 66 | """ 67 | # compatibility with versions prior to 1.3.0 68 | if not network and kwargs: 69 | network, output = convert_deprecated_parameters(kwargs) 70 | 71 | # parse model definition 72 | containers, output = generate_containers(network, output, embed_output=True, indent=1) 73 | 74 | # network definitions 75 | networks = [] 76 | forward: list[str] = [] 77 | for container in containers: 78 | networks.append(f'self.{container["name"]}_container = {container["sequential"]}') 79 | forward.append(f'{container["name"]} = self.{container["name"]}_container({container["input"]})') 80 | # process output 81 | if output["modules"]: 82 | networks.append(f'self.output_layer = {output["modules"][0]}') 83 | forward.append(f'output = self.output_layer({container["name"]})') 84 | if output["output"]: 85 | forward.append(f'output = {output["output"]}') 86 | else: 87 | forward[-1] = forward[-1].replace(f'{container["name"]} =', "output =", 1) 88 | 89 | # build substitutions and indent content 90 | networks = textwrap.indent("\n".join(networks), prefix=" " * 8)[8:] 91 | forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:] 92 | 93 | template = f"""class MultivariateGaussianModel(MultivariateGaussianMixin, Model): 94 | def __init__(self, observation_space, action_space, device, clip_actions, 95 | clip_log_std, min_log_std, max_log_std): 96 | Model.__init__(self, observation_space, action_space, device) 97 | MultivariateGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) 98 | 99 | {networks} 100 | self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={float(initial_log_std)}), requires_grad={not fixed_log_std}) 101 | 102 | def compute(self, inputs, role=""): 103 | states = unflatten_tensorized_space(self.observation_space, inputs.get("states")) 104 | taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions")) 105 | {forward} 106 | return output, self.log_std_parameter, {{}} 107 | """ 108 | # return source 109 | if return_source: 110 | return template 111 | 112 | # instantiate model 113 | _locals = {} 114 | exec(template, globals(), _locals) 115 | return _locals["MultivariateGaussianModel"]( 116 | observation_space=observation_space, 117 | action_space=action_space, 118 | device=device, 119 | clip_actions=clip_actions, 120 | clip_log_std=clip_log_std, 121 | min_log_std=min_log_std, 122 | max_log_std=max_log_std, 123 | ) 124 | -------------------------------------------------------------------------------- /src/isaac_quad_sim2real/tasks/race/config/crazyflie/agents/rl_cfg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, The Isaac Lab Project Developers. 2 | # All rights reserved. 
3 | # 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | 6 | from dataclasses import MISSING 7 | from typing import Literal 8 | 9 | from isaaclab.utils import configclass 10 | 11 | 12 | @configclass 13 | class RslRlPpoActorCriticCfg: 14 | """Configuration for the PPO actor-critic networks.""" 15 | 16 | class_name: str = "ActorCritic" 17 | """The policy class name. Default is ActorCritic.""" 18 | 19 | init_noise_std: float = MISSING 20 | """The initial noise standard deviation for the policy.""" 21 | 22 | actor_hidden_dims: list[int] = MISSING 23 | """The hidden dimensions of the actor network.""" 24 | 25 | film_hidden_dims: list[int] = MISSING 26 | """The hidden dimensions of the FiLM network.""" 27 | 28 | cond_dim: int = MISSING 29 | """The number of conditioning variables.""" 30 | 31 | critic_hidden_dims: list[int] = MISSING 32 | """The hidden dimensions of the critic network.""" 33 | 34 | activation: str = MISSING 35 | """The activation function for the actor and critic networks.""" 36 | 37 | min_std: float = MISSING 38 | """The minimum standard deviation for the policy.""" 39 | 40 | 41 | @configclass 42 | class RslRlPpoActorCriticRecurrentCfg: 43 | """Configuration for the recurrent PPO actor-critic networks with FiLM.""" 44 | 45 | class_name: str = "ActorCriticRecurrentFiLM" 46 | """The policy class name.""" 47 | 48 | init_noise_std: float = MISSING 49 | """The initial noise standard deviation for the policy.""" 50 | 51 | actor_hidden_dims: list[int] = MISSING 52 | """The hidden dimensions of the actor network.""" 53 | 54 | film_hidden_dims: list[int] = MISSING 55 | """The hidden dimensions of the FiLM network.""" 56 | 57 | cond_dim: int = MISSING 58 | """The number of conditioning variables.""" 59 | 60 | critic_hidden_dims: list[int] = MISSING 61 | """The hidden dimensions of the critic network.""" 62 | 63 | activation: str = MISSING 64 | """The activation function for the actor and critic networks.""" 65 | 66 | min_std: float = MISSING 67 | """The minimum standard deviation for the policy.""" 68 | 69 | rnn_type: str = "lstm" 70 | """The type of RNN to use (lstm or gru).""" 71 | 72 | rnn_hidden_size: int = MISSING 73 | """The hidden size of the RNN.""" 74 | 75 | rnn_num_layers: int = MISSING 76 | """The number of RNN layers.""" 77 | 78 | 79 | @configclass 80 | class RslRlPpoAlgorithmCfg: 81 | """Configuration for the PPO algorithm.""" 82 | 83 | class_name: str = "PPO" 84 | """The algorithm class name. 
Default is PPO.""" 85 | 86 | value_loss_coef: float = MISSING 87 | """The coefficient for the value loss.""" 88 | 89 | use_clipped_value_loss: bool = MISSING 90 | """Whether to use clipped value loss.""" 91 | 92 | clip_param: float = MISSING 93 | """The clipping parameter for the policy.""" 94 | 95 | entropy_coef: float = MISSING 96 | """The coefficient for the entropy loss.""" 97 | 98 | num_learning_epochs: int = MISSING 99 | """The number of learning epochs per update.""" 100 | 101 | num_mini_batches: int = MISSING 102 | """The number of mini-batches per update.""" 103 | 104 | learning_rate: float = MISSING 105 | """The learning rate for the policy.""" 106 | 107 | schedule: str = MISSING 108 | """The learning rate schedule.""" 109 | 110 | gamma: float = MISSING 111 | """The discount factor.""" 112 | 113 | lam: float = MISSING 114 | """The lambda parameter for Generalized Advantage Estimation (GAE).""" 115 | 116 | desired_kl: float = MISSING 117 | """The desired KL divergence.""" 118 | 119 | max_grad_norm: float = MISSING 120 | """The maximum gradient norm.""" 121 | 122 | 123 | @configclass 124 | class RslRlOnPolicyRunnerCfg: 125 | """Configuration of the runner for on-policy algorithms.""" 126 | 127 | seed: int = 42 128 | """The seed for the experiment. Default is 42.""" 129 | 130 | device: str = "cuda:0" 131 | """The device for the rl-agent. Default is cuda:0.""" 132 | 133 | num_steps_per_env: int = MISSING 134 | """The number of steps per environment per update.""" 135 | 136 | max_iterations: int = MISSING 137 | """The maximum number of iterations.""" 138 | 139 | empirical_normalization: bool = MISSING 140 | """Whether to use empirical normalization.""" 141 | 142 | policy: RslRlPpoActorCriticCfg = MISSING 143 | """The policy configuration.""" 144 | 145 | algorithm: RslRlPpoAlgorithmCfg = MISSING 146 | """The algorithm configuration.""" 147 | 148 | ## 149 | # Checkpointing parameters 150 | ## 151 | 152 | save_interval: int = MISSING 153 | """The number of iterations between saves.""" 154 | 155 | experiment_name: str = MISSING 156 | """The experiment name.""" 157 | 158 | run_name: str = "" 159 | """The run name. Default is empty string. 160 | 161 | The name of the run directory is typically the time-stamp at execution. If the run name is not empty, 162 | then it is appended to the run directory's name, i.e. the logging directory's name will become 163 | ``{time-stamp}_{run_name}``. 164 | """ 165 | 166 | ## 167 | # Logging parameters 168 | ## 169 | 170 | logger: Literal["tensorboard", "neptune", "wandb"] = "tensorboard" 171 | """The logger to use. Default is tensorboard.""" 172 | 173 | neptune_project: str = "isaaclab" 174 | """The neptune project name. Default is "isaaclab".""" 175 | 176 | wandb_project: str = "isaaclab" 177 | """The wandb project name. Default is "isaaclab".""" 178 | 179 | ## 180 | # Loading parameters 181 | ## 182 | 183 | resume: bool = False 184 | """Whether to resume. Default is False.""" 185 | 186 | load_run: str = ".*" 187 | """The run directory to load. Default is ".*" (all). 188 | 189 | If regex expression, the latest (alphabetical order) matching run will be loaded. 190 | """ 191 | 192 | load_checkpoint: str = "model_.*.pt" 193 | """The checkpoint file to load. Default is ``"model_.*.pt"`` (all). 194 | 195 | If regex expression, the latest (alphabetical order) matching file will be loaded. 
196 | """ 197 | -------------------------------------------------------------------------------- /src/third_parties/skrl/envs/wrappers/torch/isaaclab_envs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Tuple, Union 2 | 3 | import gymnasium 4 | 5 | import torch 6 | 7 | from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper, Wrapper 8 | from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space, unflatten_tensorized_space 9 | 10 | 11 | class IsaacLabWrapper(Wrapper): 12 | def __init__(self, env: Any) -> None: 13 | """Isaac Lab environment wrapper 14 | 15 | :param env: The environment to wrap 16 | :type env: Any supported Isaac Lab environment 17 | """ 18 | super().__init__(env) 19 | 20 | self._reset_once = True 21 | self._observations = None 22 | self._info = {} 23 | 24 | @property 25 | def state_space(self) -> Union[gymnasium.Space, None]: 26 | """State space""" 27 | try: 28 | return self._unwrapped.single_observation_space["critic"] 29 | except KeyError: 30 | pass 31 | try: 32 | return self._unwrapped.state_space 33 | except AttributeError: 34 | return None 35 | 36 | @property 37 | def observation_space(self) -> gymnasium.Space: 38 | """Observation space""" 39 | try: 40 | return self._unwrapped.single_observation_space["policy"] 41 | except: 42 | return self._unwrapped.observation_space["policy"] 43 | 44 | @property 45 | def action_space(self) -> gymnasium.Space: 46 | """Action space""" 47 | try: 48 | return self._unwrapped.single_action_space 49 | except: 50 | return self._unwrapped.action_space 51 | 52 | def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]: 53 | """Perform a step in the environment 54 | 55 | :param actions: The actions to perform 56 | :type actions: torch.Tensor 57 | 58 | :return: Observation, reward, terminated, truncated, info 59 | :rtype: tuple of torch.Tensor and any other info 60 | """ 61 | actions = unflatten_tensorized_space(self.action_space, actions) 62 | observations, reward, terminated, truncated, self._info = self._env.step(actions) 63 | self._observations = flatten_tensorized_space(tensorize_space(self.observation_space, observations["policy"])) 64 | return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info 65 | 66 | def reset(self) -> Tuple[torch.Tensor, Any]: 67 | """Reset the environment 68 | 69 | :return: Observation, info 70 | :rtype: torch.Tensor and any other info 71 | """ 72 | if self._reset_once: 73 | observations, self._info = self._env.reset() 74 | self._observations = flatten_tensorized_space( 75 | tensorize_space(self.observation_space, observations["policy"]) 76 | ) 77 | self._reset_once = False 78 | return self._observations, self._info 79 | 80 | def render(self, *args, **kwargs) -> None: 81 | """Render the environment""" 82 | return None 83 | 84 | def close(self) -> None: 85 | """Close the environment""" 86 | self._env.close() 87 | 88 | 89 | class IsaacLabMultiAgentWrapper(MultiAgentEnvWrapper): 90 | def __init__(self, env: Any) -> None: 91 | """Isaac Lab environment wrapper for multi-agent implementation 92 | 93 | :param env: The environment to wrap 94 | :type env: Any supported Isaac Lab environment 95 | """ 96 | super().__init__(env) 97 | 98 | self._reset_once = True 99 | self._observations = None 100 | self._info = {} 101 | 102 | def step(self, actions: Mapping[str, torch.Tensor]) -> Tuple[ 103 | Mapping[str, torch.Tensor], 104 | 
Mapping[str, torch.Tensor], 105 | Mapping[str, torch.Tensor], 106 | Mapping[str, torch.Tensor], 107 | Mapping[str, Any], 108 | ]: 109 | """Perform a step in the environment 110 | 111 | :param actions: The actions to perform 112 | :type actions: dictionary of torch.Tensor 113 | 114 | :return: Observation, reward, terminated, truncated, info 115 | :rtype: tuple of dictionaries torch.Tensor and any other info 116 | """ 117 | actions = {k: unflatten_tensorized_space(self.action_spaces[k], v) for k, v in actions.items()} 118 | observations, rewards, terminated, truncated, self._info = self._env.step(actions) 119 | self._observations = { 120 | k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) for k, v in observations.items() 121 | } 122 | return ( 123 | self._observations, 124 | {k: v.view(-1, 1) for k, v in rewards.items()}, 125 | {k: v.view(-1, 1) for k, v in terminated.items()}, 126 | {k: v.view(-1, 1) for k, v in truncated.items()}, 127 | self._info, 128 | ) 129 | 130 | def reset(self) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, Any]]: 131 | """Reset the environment 132 | 133 | :return: Observation, info 134 | :rtype: torch.Tensor and any other info 135 | """ 136 | if self._reset_once: 137 | observations, self._info = self._env.reset() 138 | self._observations = { 139 | k: flatten_tensorized_space(tensorize_space(self.observation_spaces[k], v)) 140 | for k, v in observations.items() 141 | } 142 | self._reset_once = False 143 | return self._observations, self._info 144 | 145 | def state(self) -> torch.Tensor: 146 | """Get the environment state 147 | 148 | :return: State 149 | :rtype: torch.Tensor 150 | """ 151 | try: 152 | state = self._env.state() 153 | except AttributeError: # 'OrderEnforcing' object has no attribute 'state' 154 | state = self._unwrapped.state() 155 | if state is not None: 156 | return flatten_tensorized_space(tensorize_space(next(iter(self.state_spaces.values())), state)) 157 | return state 158 | 159 | def render(self, *args, **kwargs) -> None: 160 | """Render the environment""" 161 | return None 162 | 163 | def close(self) -> None: 164 | """Close the environment""" 165 | self._env.close() 166 | --------------------------------------------------------------------------------
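Usage sketch: the torch wrappers above all share the same reset()/step()/close() contract, so a training loop only ever sees flattened tensors of shape (num_envs, dim). The minimal loop below drives the IsaacLabWrapper by hand, under these assumptions: `env` is an already-created Isaac Lab vectorized environment (its construction is elided), the vendored src/third_parties/skrl directory is importable as `skrl`, and the wrapper base class exposes the num_envs and device properties that GymWrapper above also relies on. The zero-action tensor is only a placeholder for a policy.

import gymnasium
import torch

from skrl.envs.wrappers.torch.isaaclab_envs import IsaacLabWrapper

wrapped = IsaacLabWrapper(env)  # `env` is assumed to be created elsewhere

# the single underlying reset happens on the first call; subsequent calls
# return the most recently cached observations
observations, info = wrapped.reset()

for _ in range(100):
    # placeholder policy: zero actions in the flattened (num_envs, action_dim)
    # layout that step() unflattens before forwarding to the environment
    actions = torch.zeros(
        (wrapped.num_envs, gymnasium.spaces.flatdim(wrapped.action_space)),
        device=wrapped.device,
    )
    observations, rewards, terminated, truncated, info = wrapped.step(actions)

wrapped.close()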
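Along the same lines, the model-instantiator functions above turn a declarative network definition into generated class source that is exec'd into a Model subclass. A sketch of calling the torch multicategorical_model follows; the network/output dictionary schema shown follows skrl's documented model-instantiator format and is an assumption here, so it may need adjusting to whatever generate_containers in the vendored common.py actually accepts.

import gymnasium

from skrl.utils.model_instantiators.torch.multicategorical import multicategorical_model

# illustrative spaces: 8-dimensional observations, two 3-way discrete action heads
observation_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(8,))
action_space = gymnasium.spaces.MultiDiscrete([3, 3])

model = multicategorical_model(
    observation_space=observation_space,
    action_space=action_space,
    device="cpu",
    unnormalized_log_prob=True,  # treat outputs as unnormalized log-probabilities
    reduction="sum",
    # declarative definition consumed by generate_containers() (schema assumed)
    network=[{"name": "net", "input": "STATES", "layers": [64, 64], "activations": "elu"}],
    output="ACTIONS",
)

# passing return_source=True instead returns the generated class source string,
# which is useful for inspecting what exec() will compile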