├── .clang-format ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── pytorch-probot.yml ├── scripts │ ├── pre-build-script-win.sh │ ├── pre-build-script.sh │ ├── td_script.sh │ └── version_script.bat ├── unittest │ ├── helpers │ │ └── coverage_run_parallel.py │ ├── linux │ │ └── scripts │ │ │ ├── 10_nvidia.json │ │ │ ├── environment.yml │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ └── run_all.sh │ ├── linux_distributed │ │ └── scripts │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ ├── linux_libs │ │ ├── scripts_ataridqn │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_brax │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_all.sh │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_chess │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_d4rl │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_envpool │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_gen-dgrl │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_gym │ │ │ ├── batch_scripts.sh │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_habitat │ │ │ ├── 10_nvidia.json │ │ │ ├── environment.yml │ │ 
│ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_all.sh │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_isaaclab │ │ │ └── isaac.sh │ │ ├── scripts_jumanji │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_llm │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_meltingpot │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_minari │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_open_spiel │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_openx │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_pettingzoo │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_robohive │ │ │ ├── environment.yml │ │ │ ├── install_and_run_test.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ └── setup_env.sh │ │ ├── scripts_roboset │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_sklearn │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_smacv2 │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ 
├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_unity_mlagents │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ ├── scripts_vd4rl │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ │ └── scripts_vmas │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ ├── linux_olddeps │ │ └── scripts_gym_0_13 │ │ │ ├── batch_scripts.sh │ │ │ ├── environment.yml │ │ │ ├── install.sh │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_test.sh │ │ │ └── setup_env.sh │ ├── linux_optdeps │ │ └── scripts │ │ │ ├── 10_nvidia.json │ │ │ ├── environment.yml │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ └── run_all.sh │ ├── linux_sota │ │ └── scripts │ │ │ ├── 10_nvidia.json │ │ │ ├── environment.yml │ │ │ ├── post_process.sh │ │ │ ├── run-clang-format.py │ │ │ ├── run_all.sh │ │ │ ├── run_local.sh │ │ │ ├── run_test.sh │ │ │ └── test_sota.py │ └── windows_optdepts │ │ └── scripts │ │ ├── environment.yml │ │ ├── install_conda.bat │ │ ├── set_cuda_envs.sh │ │ ├── unittest.sh │ │ └── vc_env_helper.bat └── workflows │ ├── benchmarks.yml │ ├── benchmarks_pr.yml │ ├── build-wheels-aarch64-linux.yml │ ├── build-wheels-linux.yml │ ├── build-wheels-m1.yml │ ├── build-wheels-windows.yml │ ├── docs.yml │ ├── lint.yml │ ├── nightly_build.yml │ ├── test-linux-habitat.yml │ ├── test-linux-libs.yml │ ├── test-linux-llm.yml │ ├── test-linux-sota.yml │ ├── test-linux.yml │ └── test-windows-optdepts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── benchmark_batched_envs.py ├── conftest.py ├── ecosystem │ ├── gym_env_throughput.py │ └── 
vmas_rllib_vs_torchrl_sampling_performance.py ├── requirements.txt ├── storage │ └── benchmark_sample_latency_over_rpc.py ├── test_collectors_benchmark.py ├── test_envs_benchmark.py ├── test_objectives_benchmarks.py └── test_replaybuffer_benchmark.py ├── build_tools ├── __init__.py └── setup_helpers │ ├── __init__.py │ └── extension.py ├── docs ├── Makefile ├── make.bat ├── ppo.png ├── requirements.txt └── source │ ├── _static │ ├── css │ │ ├── custom.css │ │ ├── pytorch_theme.css │ │ └── theme.css │ ├── img │ │ ├── cartpole.gif │ │ ├── cartpole_demo.gif │ │ ├── collector-copy.png │ │ ├── dqn.png │ │ ├── dqn_td0.png │ │ ├── dqn_tdlambda.png │ │ ├── favicon.png │ │ ├── icon.png │ │ ├── invpendulum.gif │ │ ├── logo.png │ │ ├── mcts_forest.png │ │ ├── pendulum.gif │ │ ├── rename_transform.png │ │ ├── replaybuffer_traj.png │ │ ├── rollout.gif │ │ ├── rollout_recurrent.png │ │ └── transforms.png │ └── js │ │ ├── modernizr.min.js │ │ ├── theme.js │ │ └── torchrl_theme.js │ ├── _templates │ ├── class.rst │ ├── function.rst │ ├── layout.html │ ├── rl_template.rst │ ├── rl_template_fun.rst │ └── rl_template_noinherit.rst │ ├── conf.py │ ├── content_generation.py │ ├── docutils.conf │ ├── index.rst │ └── reference │ ├── collectors.rst │ ├── data.rst │ ├── envs.rst │ ├── index.rst │ ├── knowledge_base.rst │ ├── llms.rst │ ├── modules.rst │ ├── objectives.rst │ ├── trainers.rst │ └── utils.rst ├── examples ├── agents │ ├── composite_actor.py │ ├── composite_ppo.py │ ├── multi-step.py │ └── recurrent_actor.py ├── collectors │ ├── collector_device.py │ └── mp_collector_mps.py ├── distributed │ ├── collectors │ │ ├── README.md │ │ ├── multi_nodes │ │ │ ├── delayed_dist.py │ │ │ ├── delayed_rpc.py │ │ │ ├── generic.py │ │ │ ├── ray_buffer_infra.py │ │ │ ├── ray_collect.py │ │ │ ├── ray_train.py │ │ │ ├── rpc.py │ │ │ └── sync.py │ │ └── single_machine │ │ │ ├── generic.py │ │ │ ├── rpc.py │ │ │ └── sync.py │ └── replay_buffers │ │ ├── distributed_replay_buffer.py │ │ └── 
ray_buffer.py ├── envs │ ├── gym-async-info-reader.py │ └── gym_conversion_examples.py ├── memmap │ ├── memmap_speed_distributed.py │ └── memmap_td_distributed.py ├── replay-buffers │ ├── catframes-in-buffer.py │ ├── checkpoint.py │ └── filter-imcomplete-trajs.py ├── rlhf │ ├── .gitignore │ ├── README.md │ ├── config │ │ ├── train.yaml │ │ ├── train_reward.yaml │ │ └── train_rlhf.yaml │ ├── data │ │ └── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── actor_critic.py │ │ ├── reward.py │ │ └── transformer.py │ ├── requirements.txt │ ├── train.py │ ├── train_reward.py │ ├── train_rlhf.py │ └── utils.py └── video │ └── video-from-dataset.py ├── gallery └── README.rst ├── knowledge_base ├── DEBUGGING_RL.md ├── GYM.md ├── HABITAT.md ├── MUJOCO_INSTALLATION.md ├── PRO-TIPS.md ├── README.md ├── RESOURCES.md ├── VERSIONING_ISSUES.md └── VIDEO_CUSTOMISATION.md ├── mypy.ini ├── packaging ├── build_wheels.sh ├── pkg_helpers.bash ├── wheel │ └── relocate.py └── windows │ └── internal │ ├── cuda_install.bat │ ├── driver_update.bat │ ├── vc_env_helper.bat │ ├── vc_install_helper.sh │ └── vs2017_install.ps1 ├── pyproject.toml ├── pytest.ini ├── setup.cfg ├── setup.py ├── sota-check ├── README.md ├── run_a2c_atari.sh ├── run_a2c_mujoco.sh ├── run_cql_offline.sh ├── run_cql_online.sh ├── run_crossq.sh ├── run_ddpg.sh ├── run_discrete_sac.sh ├── run_dqn_atari.sh ├── run_dqn_cartpole.sh ├── run_dt.sh ├── run_dt_online.sh ├── run_impala_single_node.sh ├── run_iql_discrete.sh ├── run_iql_offline.sh ├── run_iql_online.sh ├── run_multiagent_iddpg.sh ├── run_multiagent_ippo.sh ├── run_multiagent_iql.sh ├── run_multiagent_qmix.sh ├── run_multiagent_sac.sh ├── run_ppo_atari.sh ├── run_ppo_mujoco.sh ├── run_sac.sh ├── run_td3.sh ├── run_td3bc.sh └── submitit-release-check.sh ├── sota-implementations ├── README.md ├── a2c │ ├── README.md │ ├── a2c_atari.py │ ├── a2c_mujoco.py │ ├── config_atari.yaml │ ├── config_mujoco.yaml │ ├── utils_atari.py │ └── utils_mujoco.py ├── bandits │ ├── 
README.md │ └── dqn.py ├── cql │ ├── cql_offline.py │ ├── cql_online.py │ ├── discrete_cql_config.yaml │ ├── discrete_cql_online.py │ ├── offline_config.yaml │ ├── online_config.yaml │ └── utils.py ├── crossq │ ├── config.yaml │ ├── crossq.py │ └── utils.py ├── ddpg │ ├── config.yaml │ ├── ddpg.py │ └── utils.py ├── decision_transformer │ ├── dt.py │ ├── dt_config.yaml │ ├── lamb.py │ ├── odt_config.yaml │ ├── online_dt.py │ └── utils.py ├── discrete_sac │ ├── config.yaml │ ├── discrete_sac.py │ └── utils.py ├── dqn │ ├── README.md │ ├── config_atari.yaml │ ├── config_cartpole.yaml │ ├── dqn_atari.py │ ├── dqn_cartpole.py │ ├── utils_atari.py │ └── utils_cartpole.py ├── dreamer │ ├── README.md │ ├── config.yaml │ ├── dreamer.py │ └── dreamer_utils.py ├── gail │ ├── config.yaml │ ├── gail.py │ ├── gail_utils.py │ └── ppo_utils.py ├── impala │ ├── README.md │ ├── config_multi_node_ray.yaml │ ├── config_multi_node_submitit.yaml │ ├── config_single_node.yaml │ ├── impala_multi_node_ray.py │ ├── impala_multi_node_submitit.py │ ├── impala_single_node.py │ └── utils.py ├── iql │ ├── discrete_iql.py │ ├── discrete_iql.yaml │ ├── iql_offline.py │ ├── iql_online.py │ ├── offline_config.yaml │ ├── online_config.yaml │ └── utils.py ├── media │ ├── ant_chart.png │ ├── cheetah_chart.png │ ├── halfcheetah_chart.png │ └── walker2d_chart.png ├── multiagent │ ├── README.md │ ├── iql.py │ ├── iql.yaml │ ├── maddpg_iddpg.py │ ├── maddpg_iddpg.yaml │ ├── mappo_ippo.py │ ├── mappo_ippo.yaml │ ├── qmix_vdn.py │ ├── qmix_vdn.yaml │ ├── sac.py │ ├── sac.yaml │ └── utils │ │ ├── __init__.py │ │ ├── logging.py │ │ └── utils.py ├── ppo │ ├── README.md │ ├── config_atari.yaml │ ├── config_mujoco.yaml │ ├── ppo_atari.py │ ├── ppo_mujoco.py │ ├── utils_atari.py │ └── utils_mujoco.py ├── redq │ ├── README.md │ ├── config.yaml │ ├── redq.py │ └── utils.py ├── sac │ ├── config-async.yaml │ ├── config.yaml │ ├── sac-async.py │ ├── sac.py │ └── utils.py ├── td3 │ ├── config.yaml │ ├── td3.py │ └── 
utils.py └── td3_bc │ ├── config.yaml │ ├── td3_bc.py │ └── utils.py ├── test ├── _utils_internal.py ├── assets │ ├── generate.py │ ├── openai_summarize_comparisons.zip │ ├── openai_summarize_tldr.zip │ └── tldr_batch.zip ├── conftest.py ├── llm │ ├── libs │ │ └── test_mlgym.py │ ├── mocking_classes.py │ ├── smoke_test.py │ ├── smoke_test_deps.py │ ├── test_collectors.py │ ├── test_data.py │ ├── test_envs.py │ ├── test_modules.py │ └── test_objectives.py ├── mocking_classes.py ├── opengl_rendering.py ├── smoke_test.py ├── smoke_test_deps.py ├── test_actors.py ├── test_collector.py ├── test_cost.py ├── test_distributed.py ├── test_distributions.py ├── test_env.py ├── test_exploration.py ├── test_helpers.py ├── test_libs.py ├── test_loggers.py ├── test_modules.py ├── test_postprocs.py ├── test_rb.py ├── test_rb_distributed.py ├── test_rlhf.py ├── test_shared.py ├── test_specs.py ├── test_storage_map.py ├── test_tensordictmodules.py ├── test_trainer.py ├── test_transforms.py └── test_utils.py ├── torchrl ├── __init__.py ├── _extension.py ├── _utils.py ├── collectors │ ├── __init__.py │ ├── collectors.py │ ├── distributed │ │ ├── __init__.py │ │ ├── default_configs.py │ │ ├── generic.py │ │ ├── ray.py │ │ ├── rpc.py │ │ ├── sync.py │ │ └── utils.py │ ├── llm │ │ ├── __init__.py │ │ ├── base.py │ │ ├── utils.py │ │ └── weight_update │ │ │ ├── __init__.py │ │ │ └── vllm.py │ ├── utils.py │ └── weight_update.py ├── csrc │ ├── numpy_utils.h │ ├── pybind.cpp │ ├── segment_tree.h │ ├── torch_utils.h │ ├── utils.cpp │ └── utils.h ├── data │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── atari_dqn.py │ │ ├── common.py │ │ ├── d4rl.py │ │ ├── d4rl_infos.py │ │ ├── gen_dgrl.py │ │ ├── minari_data.py │ │ ├── openml.py │ │ ├── openx.py │ │ ├── roboset.py │ │ ├── utils.py │ │ ├── vd4rl.py │ │ └── vd4rl_datasets.json │ ├── llm │ │ ├── __init__.py │ │ ├── chat.py │ │ ├── common.py │ │ ├── dataset.py │ │ ├── prompt.py │ │ ├── reward.py │ │ └── utils.py │ ├── map │ │ ├── 
__init__.py │ │ ├── hash.py │ │ ├── query.py │ │ ├── tdstorage.py │ │ ├── tree.py │ │ └── utils.py │ ├── postprocs │ │ ├── __init__.py │ │ └── postprocs.py │ ├── replay_buffers │ │ ├── __init__.py │ │ ├── checkpointers.py │ │ ├── ray_buffer.py │ │ ├── replay_buffers.py │ │ ├── samplers.py │ │ ├── scheduler.py │ │ ├── storages.py │ │ ├── utils.py │ │ └── writers.py │ ├── rlhf.py │ ├── tensor_specs.py │ └── utils.py ├── envs │ ├── __init__.py │ ├── async_envs.py │ ├── batched_envs.py │ ├── common.py │ ├── custom │ │ ├── __init__.py │ │ ├── chess.py │ │ ├── llm.py │ │ ├── pendulum.py │ │ ├── san_moves.txt │ │ └── tictactoeenv.py │ ├── env_creator.py │ ├── gym_like.py │ ├── libs │ │ ├── __init__.py │ │ ├── _gym_utils.py │ │ ├── brax.py │ │ ├── dm_control.py │ │ ├── envpool.py │ │ ├── gym.py │ │ ├── habitat.py │ │ ├── isaac_lab.py │ │ ├── isaacgym.py │ │ ├── jax_utils.py │ │ ├── jumanji.py │ │ ├── meltingpot.py │ │ ├── openml.py │ │ ├── openspiel.py │ │ ├── pettingzoo.py │ │ ├── robohive.py │ │ ├── smacv2.py │ │ ├── unity_mlagents.py │ │ ├── utils.py │ │ └── vmas.py │ ├── llm │ │ ├── __init__.py │ │ ├── chat.py │ │ ├── datasets │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── gsm8k.py │ │ │ └── ifeval.py │ │ ├── envs.py │ │ ├── libs │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ └── mlgym.py │ │ ├── reward │ │ │ ├── __init__.py │ │ │ ├── gsm8k.py │ │ │ └── ifeval │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── _instructions.py │ │ │ │ ├── _instructions_main.py │ │ │ │ ├── _instructions_registry.py │ │ │ │ ├── _instructions_util.py │ │ │ │ └── _scorer.py │ │ └── transforms │ │ │ ├── __init__.py │ │ │ ├── dataloading.py │ │ │ ├── format.py │ │ │ ├── kl.py │ │ │ └── tokenizer.py │ ├── model_based │ │ ├── __init__.py │ │ ├── common.py │ │ └── dreamer.py │ ├── transforms │ │ ├── __init__.py │ │ ├── functional.py │ │ ├── gym_transforms.py │ │ ├── llm.py │ │ ├── r3m.py │ │ ├── rb_transforms.py │ │ ├── rlhf.py │ │ ├── transforms.py │ │ ├── utils.py │ │ ├── vc1.py 
│ │ ├── vecnorm.py │ │ └── vip.py │ ├── utils.py │ └── vec_envs.py ├── modules │ ├── __init__.py │ ├── distributions │ │ ├── __init__.py │ │ ├── continuous.py │ │ ├── discrete.py │ │ ├── truncated_normal.py │ │ └── utils.py │ ├── llm │ │ ├── __init__.py │ │ ├── backends │ │ │ ├── __init__.py │ │ │ └── vllm.py │ │ ├── policies │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── transformers_wrapper.py │ │ │ └── vllm_wrapper.py │ │ └── utils.py │ ├── models │ │ ├── __init__.py │ │ ├── batchrenorm.py │ │ ├── decision_transformer.py │ │ ├── exploration.py │ │ ├── llm.py │ │ ├── model_based.py │ │ ├── models.py │ │ ├── multiagent.py │ │ ├── recipes │ │ │ └── impala.py │ │ ├── rlhf.py │ │ └── utils.py │ ├── planners │ │ ├── __init__.py │ │ ├── cem.py │ │ ├── common.py │ │ └── mppi.py │ ├── tensordict_module │ │ ├── __init__.py │ │ ├── actors.py │ │ ├── common.py │ │ ├── exploration.py │ │ ├── probabilistic.py │ │ ├── rnn.py │ │ ├── sequence.py │ │ └── world_models.py │ └── utils │ │ ├── __init__.py │ │ ├── mappings.py │ │ └── utils.py ├── objectives │ ├── __init__.py │ ├── a2c.py │ ├── common.py │ ├── cql.py │ ├── crossq.py │ ├── ddpg.py │ ├── decision_transformer.py │ ├── deprecated.py │ ├── dqn.py │ ├── dreamer.py │ ├── functional.py │ ├── gail.py │ ├── iql.py │ ├── llm │ │ ├── __init__.py │ │ └── grpo.py │ ├── multiagent │ │ ├── __init__.py │ │ └── qmixer.py │ ├── ppo.py │ ├── redq.py │ ├── reinforce.py │ ├── sac.py │ ├── td3.py │ ├── td3_bc.py │ ├── utils.py │ └── value │ │ ├── __init__.py │ │ ├── advantages.py │ │ ├── functional.py │ │ └── utils.py ├── record │ ├── __init__.py │ ├── loggers │ │ ├── __init__.py │ │ ├── common.py │ │ ├── csv.py │ │ ├── mlflow.py │ │ ├── tensorboard.py │ │ ├── utils.py │ │ └── wandb.py │ └── recorder.py └── trainers │ ├── __init__.py │ ├── helpers │ ├── __init__.py │ ├── collectors.py │ ├── envs.py │ ├── logger.py │ ├── losses.py │ ├── models.py │ ├── replay_buffer.py │ └── trainers.py │ └── trainers.py ├── tutorials ├── README.md ├── 
media │ └── transformer.png └── sphinx-tutorials │ ├── README.rst │ ├── coding_ddpg.py │ ├── coding_dqn.py │ ├── coding_ppo.py │ ├── dqn_with_rnn.py │ ├── export.py │ ├── getting-started-0.py │ ├── getting-started-1.py │ ├── getting-started-2.py │ ├── getting-started-3.py │ ├── getting-started-4.py │ ├── getting-started-5.py │ ├── multi_task.py │ ├── multiagent_competitive_ddpg.py │ ├── multiagent_ppo.py │ ├── pendulum.py │ ├── pretrained_models.py │ ├── rb_tutorial.py │ ├── run_local.sh │ ├── torchrl_demo.py │ └── torchrl_envs.py └── version.txt /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature Request]" 5 | labels: ["enhancement"] 6 | assignees: vmoens 7 | 8 | --- 9 | 10 | ## Motivation 11 | 12 | Please outline the motivation for the proposal. 13 | Is your feature request related to a problem? e.g., "I'm always frustrated when [...]". 14 | If this is related to another issue, please link here too. 15 | 16 | ## Solution 17 | 18 | A clear and concise description of what you want to happen. 19 | 20 | ## Alternatives 21 | 22 | A clear and concise description of any alternative solutions or features you've considered. 23 | 24 | ## Additional context 25 | 26 | Add any other context or screenshots about the feature request here. 27 | 28 | ## Checklist 29 | 30 | - [ ] I have checked that there is no similar issue in the repo (**required**) 31 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | Describe your changes in detail. 4 | 5 | ## Motivation and Context 6 | 7 | Why is this change required? What problem does it solve? 8 | If it fixes an open issue, please link to the issue here. 
9 | You can use the syntax `close #15213` if this solves the issue #15213 10 | 11 | - [ ] I have raised an issue to propose this change ([required](https://github.com/pytorch/rl/issues) for new features and bug fixes) 12 | 13 | ## Types of changes 14 | 15 | What types of changes does your code introduce? Remove all that do not apply: 16 | 17 | - [ ] Bug fix (non-breaking change which fixes an issue) 18 | - [ ] New feature (non-breaking change which adds core functionality) 19 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 20 | - [ ] Documentation (update in the documentation) 21 | - [ ] Example (update in the folder of examples) 22 | 23 | ## Checklist 24 | 25 | Go over all the following points, and put an `x` in all the boxes that apply. 26 | If you are unsure about any of these, don't hesitate to ask. We are here to help! 27 | 28 | - [ ] I have read the [CONTRIBUTION](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) guide (**required**) 29 | - [ ] My change requires a change to the documentation. 30 | - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). 31 | - [ ] I have updated the documentation accordingly. 
32 | -------------------------------------------------------------------------------- /.github/pytorch-probot.yml: -------------------------------------------------------------------------------- 1 | # List of workflows that will be re-run in case of failures 2 | # https://github.com/pytorch/test-infra/blob/main/torchci/lib/bot/retryBot.ts 3 | retryable_workflows: 4 | - Build M1 5 | - Wheels 6 | -------------------------------------------------------------------------------- /.github/scripts/pre-build-script-win.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install --upgrade setuptools 4 | 5 | export TORCHRL_BUILD_VERSION=0.9.0 6 | -------------------------------------------------------------------------------- /.github/scripts/pre-build-script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install --upgrade setuptools 4 | 5 | ${CONDA_RUN} pip install "pybind11[global]" 6 | ${CONDA_RUN} conda install anaconda::cmake -y 7 | ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U 8 | -------------------------------------------------------------------------------- /.github/scripts/td_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export TORCHRL_BUILD_VERSION=0.9.0 4 | pip install --upgrade setuptools 5 | 6 | # Check if ARCH is set to aarch64 7 | ARCH=${ARCH:-} # This sets ARCH to an empty string if it's not defined 8 | 9 | if pip list | grep -q torch; then 10 | echo "Torch is installed." 
11 | 12 | # ${CONDA_RUN} conda install 'anaconda::cmake>=3.22' -y 13 | 14 | ${CONDA_RUN} pip install "pybind11[global]" 15 | 16 | ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U --no-deps 17 | elif [[ -n "${SMOKE_TEST_SCRIPT:-}" ]]; then 18 | ${CONDA_RUN} ${PIP_INSTALL_TORCH} 19 | # TODO: revert when nightlies of tensordict are fixed 20 | # if [[ "$ARCH" == "aarch64" ]]; then 21 | 22 | 23 | # ${CONDA_RUN} conda install 'anaconda::cmake>=3.22' -y 24 | 25 | ${CONDA_RUN} pip install "pybind11[global]" 26 | 27 | ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U --no-deps 28 | else 29 | echo "Torch is not installed - tensordict will be installed later." 30 | fi 31 | -------------------------------------------------------------------------------- /.github/scripts/version_script.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set TORCHRL_BUILD_VERSION=0.9.0 3 | echo TORCHRL_BUILD_VERSION is set to %TORCHRL_BUILD_VERSION% 4 | 5 | @echo on 6 | 7 | set VC_VERSION_LOWER=17 8 | set VC_VERSION_UPPER=18 9 | if "%VC_YEAR%" == "2019" ( 10 | set VC_VERSION_LOWER=16 11 | set VC_VERSION_UPPER=17 12 | ) 13 | if "%VC_YEAR%" == "2017" ( 14 | set VC_VERSION_LOWER=15 15 | set VC_VERSION_UPPER=16 16 | ) 17 | 18 | for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( 19 | if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( 20 | set "VS15INSTALLDIR=%%i" 21 | set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" 22 | goto vswhere 23 | ) 24 | ) 25 | 26 | :vswhere 27 | if "%VSDEVCMD_ARGS%" == "" ( 28 | call "%VS15VCVARSALL%" x64 || exit /b 1 29 | ) else ( 30 | call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 31 | ) 32 | 33 | @echo on 34 | 35 | if "%CU_VERSION%" == "xpu" call "C:\Program Files 
(x86)\Intel\oneAPI\setvars.bat" 36 | 37 | set DISTUTILS_USE_SDK=1 38 | 39 | :: Upgrade setuptools before installing PyTorch 40 | pip install --upgrade setuptools==72.1.0 || exit /b 1 41 | 42 | set args=%1 43 | shift 44 | :start 45 | if [%1] == [] goto done 46 | set args=%args% %1 47 | shift 48 | goto start 49 | 50 | :done 51 | if "%args%" == "" ( 52 | echo Usage: vc_env_helper.bat [command] [args] 53 | echo e.g. vc_env_helper.bat cl /c test.cpp 54 | ) 55 | 56 | %args% || exit /b 1 57 | -------------------------------------------------------------------------------- /.github/unittest/linux/scripts/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /.github/unittest/linux/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - protobuf 7 | - pip: 8 | - hypothesis 9 | - future 10 | - cloudpickle 11 | - pygame 12 | - moviepy<2.0.0 13 | - tqdm 14 | - pytest 15 | - pytest-cov 16 | - pytest-mock 17 | - pytest-instafail 18 | - pytest-rerunfailures 19 | - pytest-timeout 20 | - pytest-asyncio 21 | - expecttest 22 | - pybind11[global] 23 | - pyyaml 24 | - scipy 25 | - hydra-core 26 | - tensorboard 27 | - imageio==2.26.0 28 | - wandb 29 | - dm_control 30 | - mujoco 31 | - mlflow 32 | - av 33 | - coverage 34 | - ray 35 | - transformers 36 | - ninja 37 | - timm 38 | -------------------------------------------------------------------------------- /.github/unittest/linux/scripts/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 
-------------------------------------------------------------------------------- /.github/unittest/linux_distributed/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - protobuf 7 | - pip: 8 | - hypothesis 9 | - future 10 | - cloudpickle 11 | - pygame 12 | - moviepy<2.0.0 13 | - tqdm 14 | - pytest 15 | - pytest-cov 16 | - pytest-mock 17 | - pytest-instafail 18 | - pytest-rerunfailures 19 | - pytest-asyncio 20 | - expecttest 21 | - pybind11[global] 22 | - pyyaml 23 | - scipy 24 | - hydra-core 25 | - tensorboard 26 | - imageio==2.26.0 27 | - wandb 28 | - dm_control 29 | - mujoco 30 | - mlflow 31 | - av 32 | - coverage 33 | - ray 34 | - virtualenv 35 | -------------------------------------------------------------------------------- /.github/unittest/linux_distributed/scripts/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_distributed/scripts/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | export PYTORCH_TEST_WITH_SLOW='1' 9 | python -m torch.utils.collect_env 10 | # Avoid error: "fatal: unsafe repository" 11 | git config --global --add safe.directory '*' 12 | 13 | root_dir="$(git rev-parse --show-toplevel)" 14 | env_dir="${root_dir}/env" 15 | lib_dir="${env_dir}/lib" 16 | 17 | # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found 18 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir 19 | export MKL_THREADING_LAYER=GNU 20 | export CKPT_BACKEND=torch 21 | export 
LAZY_LEGACY_OP=False 22 | export BATCHED_PIPE_TIMEOUT=60 23 | 24 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 25 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb' 26 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200 --mp_fork_if_no_cuda 27 | coverage combine 28 | coverage xml -i 29 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_ataridqn/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - pip 7 | - gsutil 8 | - pip: 9 | - hypothesis 10 | - future 11 | - cloudpickle 12 | - pytest 13 | - pytest-cov 14 | - pytest-mock 15 | - pytest-instafail 16 | - pytest-rerunfailures 17 | - pytest-error-for-skips 18 | - pytest-asyncio 19 | - expecttest 20 | - pybind11[global] 21 | - pyyaml 22 | - scipy 23 | - hydra-core 24 | - tqdm 25 | - h5py 26 | - datasets 27 | - pillow 28 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_ataridqn/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_ataridqn/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | apt-get update && apt-get remove swig -y && apt-get install -y git 
gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 cmake 9 | ln -s /usr/bin/swig3.0 /usr/bin/swig 10 | 11 | export LAZY_LEGACY_OP=False 12 | export PYTORCH_TEST_WITH_SLOW='1' 13 | python -m torch.utils.collect_env 14 | # Avoid error: "fatal: unsafe repository" 15 | git config --global --add safe.directory '*' 16 | 17 | root_dir="$(git rev-parse --show-toplevel)" 18 | env_dir="${root_dir}/env" 19 | lib_dir="${env_dir}/lib" 20 | 21 | conda deactivate && conda activate ./env 22 | 23 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestAtariDQN --error-for-skips --runslow 24 | coverage combine 25 | coverage xml -i 26 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_brax/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - hypothesis 8 | - future 9 | - cloudpickle 10 | - pytest 11 | - pytest-cov 12 | - pytest-mock 13 | - pytest-instafail 14 | - pytest-rerunfailures 15 | - pytest-error-for-skips 16 | - pytest-asyncio 17 | - expecttest 18 | - pybind11[global] 19 | - pyyaml 20 | - scipy 21 | - hydra-core 22 | - jax[cuda12] 23 | - brax 24 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_brax/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_brax/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euxo pipefail 4 | 5 | apt update 6 | apt install -y 
#!/usr/bin/env bash
# CI step: run the Brax integration tests (test_libs.py -k TestBrax) inside the
# pre-built ./env conda environment and emit a coverage XML report.

set -euxo pipefail

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env


export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
# more logging
export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON

#wget https://github.com/openai/mujoco-py/blob/master/vendor/10_nvidia.json
#mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json

# this workflow only tests the libs
# Smoke-test brax/jax imports and that a CUDA tensor can be created before
# spending time in pytest.
python -c "import brax"
python -c "import brax.envs"
python -c "import jax"
python3 -c 'import torch;t = torch.ones([2,2], device="cuda:0");print(t);print("tensor device:" + str(t.device))'

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestBrax --error-for-skips
coverage combine
coverage xml -i
#!/usr/bin/env bash
# CI step: run the chess environment tests (test_env.py -k TestChessEnv) inside
# the pre-built ./env conda environment and emit a coverage XML report.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

apt-get update && apt-get install -y git wget cmake

export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
# NOTE(review): env_dir/lib_dir are not used below in this script; they appear
# to be kept for parity with the sibling run_test.sh scripts — confirm.
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

conda deactivate && conda activate ./env

# this workflow only tests the libs
python -c "import chess"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_env.py --instafail -v --durations 200 --capture no -k TestChessEnv --error-for-skips --runslow

coverage combine
coverage xml -i
#!/usr/bin/env bash
# CI step: install nightly PyTorch, tensordict and torchrl (editable) into the
# pre-built ./env conda environment for the envpool test job.

unset PYTORCH_VERSION
# For unittest, nightly PyTorch is used as the following section,
# so no need to set PYTORCH_VERSION.
# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

if [ "${CU_VERSION:-}" == cpu ] ; then
  version="cpu"
  echo "Using cpu build"
else
  # Derive a dotted CUDA version ("11.8", "12.1", ...) from CU_VERSION
  # ("cu118", "cu121", ...): 4-char values have a 1-digit minor, 5-char a 2-digit major.
  if [[ ${#CU_VERSION} -eq 4 ]]; then
    CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
  elif [[ ${#CU_VERSION} -eq 5 ]]; then
    CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
  fi
  echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
  version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
fi

# submodules
git submodule sync && git submodule update --init --recursive

# Fix: the previous single message claimed "cu128" unconditionally (even for
# CPU builds) and was missing its trailing newline.
if [ "${CU_VERSION:-}" == cpu ] ; then
  printf "Installing PyTorch (nightly, cpu)\n"
  pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
else
  printf "Installing PyTorch (nightly, cu128)\n"
  pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
fi

# smoke test
python -c "import functorch"

# install tensordict
pip install git+https://github.com/pytorch/tensordict

printf "* Installing torchrl\n"
python setup.py develop
6 | # IMPORTANT: As a consequence, we can't guarantee TorchRL compatibility with 7 | # rendering with this version of gym / mujoco-py. 8 | 9 | set -e 10 | 11 | eval "$(./conda/bin/conda shell.bash hook)" 12 | conda activate ./env 13 | 14 | export PYTORCH_TEST_WITH_SLOW='1' 15 | export LAZY_LEGACY_OP=False 16 | python -m torch.utils.collect_env 17 | # Avoid error: "fatal: unsafe repository" 18 | git config --global --add safe.directory '*' 19 | 20 | root_dir="$(git rev-parse --show-toplevel)" 21 | env_dir="${root_dir}/env" 22 | lib_dir="${env_dir}/lib" 23 | 24 | # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found 25 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir 26 | export MKL_THREADING_LAYER=GNU 27 | 28 | # this workflow only tests the libs 29 | python -c "import envpool" 30 | 31 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestEnvPool --error-for-skips 32 | coverage combine 33 | coverage xml -i 34 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_gen-dgrl/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - hypothesis 8 | - future 9 | - cloudpickle 10 | - pytest 11 | - pytest-cov 12 | - pytest-mock 13 | - pytest-instafail 14 | - pytest-rerunfailures 15 | - pytest-error-for-skips 16 | - pytest-asyncio 17 | - expecttest 18 | - pybind11[global] 19 | - pyyaml 20 | - scipy 21 | - hydra-core 22 | - huggingface_hub 23 | - tqdm 24 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_gen-dgrl/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda 
#!/usr/bin/env bash
# CI entry point: run the Gen-DGRL dataset tests (test_libs.py -k TestGenDGRL)
# inside the pre-built ./env conda environment and emit a coverage XML report.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

# swig3.0 replaces the distro 'swig' package, then is exposed under the plain
# 'swig' name expected by downstream builds.
apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 cmake
# Fix: -f makes the script idempotent — a plain `ln -s` fails under `set -e`
# when the link already exists (e.g. on a re-run of the job).
ln -sf /usr/bin/swig3.0 /usr/bin/swig

export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

conda deactivate && conda activate ./env

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestGenDGRL --error-for-skips --runslow
coverage combine
coverage xml -i
#!/usr/bin/env bash
# CI step: run smoke tests plus the gym integration tests ("gym and not isaac")
# inside the pre-built ./env conda environment and emit a coverage XML report.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym'

# Start a virtual framebuffer on display :99 so rendering tests can run
# headless; its output is discarded and the process is left in the background.
unset LD_PRELOAD
export DISPLAY=:99
Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 &
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips --mp_fork
coverage combine
coverage xml -i
#!/usr/bin/env bash
# CI step: install PyTorch (nightly or stable per $TORCH_VERSION), tensordict
# and torchrl (editable) into the pre-built ./env conda environment.

unset PYTORCH_VERSION

set -e
set -v

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

# Derive a dotted CUDA version ("11.8", "12.1", ...) from CU_VERSION
# ("cu118", "cu121", ...): 4-char values have a 1-digit minor, 5-char a 2-digit major.
if [[ ${#CU_VERSION} -eq 4 ]]; then
  CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
elif [[ ${#CU_VERSION} -eq 5 ]]; then
  CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
fi
echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"

# submodules
git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
  pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
elif [[ "$TORCH_VERSION" == "stable" ]]; then
  pip3 install torch --index-url https://download.pytorch.org/whl/cu128
else
  # Fix: an unrecognized TORCH_VERSION previously skipped the torch install
  # silently; surface it so the later smoke-test failure is explainable.
  printf 'WARNING: unknown TORCH_VERSION=%s - torch not installed\n' "${TORCH_VERSION:-}" >&2
fi

# install tensordict (from git on non-release builds, from PyPI otherwise)
if [[ "$RELEASE" == 0 ]]; then
  pip3 install git+https://github.com/pytorch/tensordict.git
else
  pip3 install tensordict
fi

# smoke test
python3 -c "import functorch;import tensordict"

printf "* Installing torchrl\n"
python setup.py develop

# smoke test
python3 -c "import torchrl"
#!/usr/bin/env bash
# CI step: run the Jumanji integration tests (test_libs.py -k TestJumanji)
# inside the pre-built ./env conda environment and emit a coverage XML report.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
apt-get update && apt-get install -y git wget cmake


export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
# more logging
export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON

#wget https://github.com/openai/mujoco-py/blob/master/vendor/10_nvidia.json
#mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json

# this workflow only tests the libs
python -c "import jumanji"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestJumanji --error-for-skips --runslow
coverage combine
coverage xml -i
To use CUDA with multiprocessing, you must use the 'spawn' start method 15 | export VLLM_WORKER_MULTIPROC_METHOD=spawn 16 | python -m torch.utils.collect_env 17 | # Avoid error: "fatal: unsafe repository" 18 | git config --global --add safe.directory '*' 19 | 20 | root_dir="$(git rev-parse --show-toplevel)" 21 | env_dir="${root_dir}/env" 22 | lib_dir="${env_dir}/lib" 23 | 24 | conda deactivate && conda activate ./env 25 | 26 | python -c "import transformers, datasets" 27 | 28 | pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips 29 | 30 | python examples/rlhf/train_rlhf.py \ 31 | sys.device=cuda:0 sys.ref_device=cuda:0 \ 32 | model.name_or_path=gpt2 train.max_epochs=2 \ 33 | data.batch_size=2 train.ppo.ppo_batch_size=2 \ 34 | train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \ 35 | train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \ 36 | data.block_size=110 io.logger=csv 37 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_meltingpot/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - cloudpickle 8 | - torch 9 | - pytest 10 | - pytest-cov 11 | - pytest-mock 12 | - pytest-instafail 13 | - pytest-rerunfailures 14 | - pytest-error-for-skips 15 | - pytest-asyncio 16 | - expecttest 17 | - pybind11[global] 18 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_meltingpot/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_meltingpot/run_test.sh: 
#!/usr/bin/env bash
# CI step: run the Melting Pot integration tests (test_libs.py -k TestMeltingpot)
# inside the pre-built ./env conda environment and emit a coverage XML report.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
apt-get update && apt-get install -y git wget cmake


export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
# more logging
export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON

# this workflow only tests the libs
python -c "import meltingpot"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestMeltingpot --error-for-skips
coverage combine
coverage xml -i
#!/usr/bin/env bash
# CI entry point: run the Minari dataset tests (test_libs.py -k TestMinari)
# inside the pre-built ./env conda environment and emit a coverage XML report.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

# swig3.0 replaces the distro 'swig' package, then is exposed under the plain
# 'swig' name expected by downstream builds.
apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 cmake
# Fix: -f makes the script idempotent — a plain `ln -s` fails under `set -e`
# when the link already exists (e.g. on a re-run of the job).
ln -sf /usr/bin/swig3.0 /usr/bin/swig

export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

conda deactivate && conda activate ./env

# this workflow only tests the libs
python -c "import minari"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestMinari --error-for-skips --runslow
coverage combine
coverage xml -i
#!/usr/bin/env bash
# CI step: run the OpenSpiel integration tests (test_libs.py -k TestOpenSpiel)
# inside the pre-built ./env conda environment and emit a coverage XML report.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

apt-get update && apt-get install -y git wget cmake

export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

conda deactivate && conda activate ./env

# this workflow only tests the libs
# (open_spiel's import name is 'pyspiel', not 'open_spiel')
python -c "import pyspiel"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenSpiel --error-for-skips --runslow

coverage combine
coverage xml -i
#!/usr/bin/env bash
# CI entry point: run the OpenX dataset tests (test_libs.py -k TestOpenX)
# inside the pre-built ./env conda environment and emit a coverage XML report.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

# swig3.0 replaces the distro 'swig' package, then is exposed under the plain
# 'swig' name expected by downstream builds.
apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 cmake
# Fix: -f makes the script idempotent — a plain `ln -s` fails under `set -e`
# when the link already exists (e.g. on a re-run of the job).
ln -sf /usr/bin/swig3.0 /usr/bin/swig

export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

conda deactivate && conda activate ./env

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenX --error-for-skips --runslow
coverage combine
coverage xml -i
pytest-mock 16 | - pytest-instafail 17 | - pytest-rerunfailures 18 | - pytest-error-for-skips 19 | - expecttest 20 | - pybind11[global] 21 | - pyyaml 22 | - autorom[accept-rom-license] 23 | - pettingzoo[all]==1.24.3 24 | - gymnasium<1.0.0 25 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_pettingzoo/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_pettingzoo/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | apt-get update && apt-get install -y git wget cmake 8 | 9 | 10 | export PYTORCH_TEST_WITH_SLOW='1' 11 | export LAZY_LEGACY_OP=False 12 | python -m torch.utils.collect_env 13 | # Avoid error: "fatal: unsafe repository" 14 | git config --global --add safe.directory '*' 15 | 16 | root_dir="$(git rev-parse --show-toplevel)" 17 | env_dir="${root_dir}/env" 18 | lib_dir="${env_dir}/lib" 19 | 20 | # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found 21 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir 22 | export MKL_THREADING_LAYER=GNU 23 | # more logging 24 | export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON 25 | 26 | # this workflow only tests the libs 27 | python -c "import pettingzoo" 28 | 29 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestPettingZoo --error-for-skips 30 | coverage combine 31 | coverage xml -i 32 | -------------------------------------------------------------------------------- 
/.github/unittest/linux_libs/scripts_robohive/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - protobuf 7 | - pip: 8 | # Initial version is required to install Atari ROMS in setup_env.sh 9 | - gymnasium<1.0 10 | - hypothesis 11 | - future 12 | - cloudpickle 13 | - pygame 14 | - moviepy<2.0.0 15 | - tqdm 16 | - pytest 17 | - pytest-cov 18 | - pytest-mock 19 | - pytest-instafail 20 | - pytest-rerunfailures 21 | - pytest-error-for-skips 22 | - pytest-asyncio 23 | - expecttest 24 | - pybind11[global] 25 | - pyyaml 26 | - scipy 27 | - hydra-core 28 | - patchelf 29 | - mujoco==2.3.3 30 | - dm_control==1.0.11 31 | - numpy<2.0.0 32 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_robohive/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_roboset/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - hypothesis 8 | - future 9 | - cloudpickle 10 | - pytest 11 | - pytest-cov 12 | - pytest-mock 13 | - pytest-instafail 14 | - pytest-rerunfailures 15 | - pytest-error-for-skips 16 | - pytest-asyncio 17 | - expecttest 18 | - pybind11[global] 19 | - pyyaml 20 | - scipy 21 | - hydra-core 22 | - huggingface_hub 23 | - tqdm 24 | - h5py 25 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_roboset/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_roboset/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 cmake 9 | ln -s /usr/bin/swig3.0 /usr/bin/swig 10 | 11 | export PYTORCH_TEST_WITH_SLOW='1' 12 | export LAZY_LEGACY_OP=False 13 | python -m torch.utils.collect_env 14 | # Avoid error: "fatal: unsafe repository" 15 | git config --global --add safe.directory '*' 16 | 17 | root_dir="$(git rev-parse --show-toplevel)" 18 | env_dir="${root_dir}/env" 19 | lib_dir="${env_dir}/lib" 20 | 21 | conda deactivate && conda activate ./env 22 | 23 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestRoboset --error-for-skips --runslow 24 | coverage combine 25 | coverage xml -i 26 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_sklearn/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - hypothesis 8 | - future 9 | - cloudpickle 10 | - pytest 11 | - pytest-cov 12 | - pytest-mock 13 | - pytest-instafail 14 | - pytest-rerunfailures 15 | - pytest-error-for-skips 16 | - pytest-asyncio 17 | - expecttest 18 | - pybind11[global] 19 | - pyyaml 20 | - scipy 21 | - hydra-core 22 | - scikit-learn 23 | - pandas 24 | -------------------------------------------------------------------------------- 
/.github/unittest/linux_libs/scripts_sklearn/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_sklearn/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | apt-get update && apt-get install -y git gcc cmake 9 | ln -s /usr/bin/swig3.0 /usr/bin/swig 10 | 11 | export PYTORCH_TEST_WITH_SLOW='1' 12 | export LAZY_LEGACY_OP=False 13 | python -m torch.utils.collect_env 14 | # Avoid error: "fatal: unsafe repository" 15 | git config --global --add safe.directory '*' 16 | 17 | root_dir="$(git rev-parse --show-toplevel)" 18 | env_dir="${root_dir}/env" 19 | lib_dir="${env_dir}/lib" 20 | 21 | conda deactivate && conda activate ./env 22 | 23 | # this workflow only tests the libs 24 | python -c "import sklearn, pandas" 25 | 26 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenML --error-for-skips --runslow 27 | coverage combine 28 | coverage xml -i 29 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_smacv2/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - cloudpickle 8 | - gym 9 | - gym-notices 10 | - zipp 11 | - pytest 12 | - pytest-cov 13 | - pytest-mock 14 | - pytest-instafail 15 | - pytest-rerunfailures 16 | - pytest-error-for-skips 17 | - pytest-asyncio 18 | - expecttest 19 | - pybind11[global] 20 | - pyyaml 21 | - numpy==1.23.0 22 | - 
git+https://github.com/oxwhirl/smacv2.git 23 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_smacv2/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_smacv2/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | apt-get update && apt-get install -y git wget cmake 8 | 9 | 10 | export PYTORCH_TEST_WITH_SLOW='1' 11 | export LAZY_LEGACY_OP=False 12 | python -m torch.utils.collect_env 13 | # Avoid error: "fatal: unsafe repository" 14 | git config --global --add safe.directory '*' 15 | 16 | root_dir="$(git rev-parse --show-toplevel)" 17 | env_dir="${root_dir}/env" 18 | lib_dir="${env_dir}/lib" 19 | export SC2PATH="${root_dir}/StarCraftII" 20 | echo 'SC2PATH is set to ' "$SC2PATH" 21 | 22 | # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found 23 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir 24 | export MKL_THREADING_LAYER=GNU 25 | # more logging 26 | export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON 27 | 28 | # this workflow only tests the libs 29 | python -c "import smacv2" 30 | 31 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestSmacv2 --error-for-skips 32 | coverage combine 33 | coverage xml -i 34 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 
| - pytorch 3 | - defaults 4 | dependencies: 5 | - python==3.10.12 6 | - pip 7 | - pip: 8 | - mlagents_envs==1.0.0 9 | - hypothesis 10 | - future 11 | - cloudpickle 12 | - pytest 13 | - pytest-cov 14 | - pytest-mock 15 | - pytest-instafail 16 | - pytest-rerunfailures 17 | - pytest-error-for-skips 18 | - expecttest 19 | - pybind11[global] 20 | - pyyaml 21 | - scipy 22 | - hydra-core 23 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | apt-get update && apt-get install -y git wget cmake 9 | 10 | export PYTORCH_TEST_WITH_SLOW='1' 11 | export LAZY_LEGACY_OP=False 12 | python -m torch.utils.collect_env 13 | # Avoid error: "fatal: unsafe repository" 14 | git config --global --add safe.directory '*' 15 | 16 | root_dir="$(git rev-parse --show-toplevel)" 17 | env_dir="${root_dir}/env" 18 | lib_dir="${env_dir}/lib" 19 | 20 | conda deactivate && conda activate ./env 21 | 22 | # this workflow only tests the libs 23 | python -c "import mlagents_envs" 24 | 25 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow 26 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_transforms.py --instafail -v --durations 200 --capture no -k test_transform_env[unity] 27 | 28 | coverage combine 29 | coverage xml -i 30 | 
-------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_vd4rl/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - hypothesis 8 | - future 9 | - cloudpickle 10 | - pytest 11 | - pytest-cov 12 | - pytest-mock 13 | - pytest-instafail 14 | - pytest-rerunfailures 15 | - pytest-error-for-skips 16 | - pytest-asyncio 17 | - expecttest 18 | - pybind11[global] 19 | - pyyaml 20 | - scipy 21 | - hydra-core 22 | - huggingface_hub 23 | - tqdm 24 | - h5py 25 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_vd4rl/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_vd4rl/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 cmake 9 | ln -s /usr/bin/swig3.0 /usr/bin/swig 10 | 11 | export PYTORCH_TEST_WITH_SLOW='1' 12 | export LAZY_LEGACY_OP=False 13 | python -m torch.utils.collect_env 14 | # Avoid error: "fatal: unsafe repository" 15 | git config --global --add safe.directory '*' 16 | 17 | root_dir="$(git rev-parse --show-toplevel)" 18 | env_dir="${root_dir}/env" 19 | lib_dir="${env_dir}/lib" 20 | 21 | conda deactivate && conda activate ./env 22 | 23 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v 
--durations 200 --capture no -k TestVD4RL --error-for-skips --runslow 24 | coverage combine 25 | coverage xml -i 26 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_vmas/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - cloudpickle 8 | - gym 9 | - gym-notices 10 | - numpy 11 | - pyglet==1.5.27 12 | - six 13 | - torch 14 | - vmas 15 | - zipp 16 | - pytest 17 | - pytest-cov 18 | - pytest-mock 19 | - pytest-instafail 20 | - pytest-rerunfailures 21 | - pytest-error-for-skips 22 | - pytest-asyncio 23 | - expecttest 24 | - pybind11[global] 25 | - pyyaml 26 | - scipy 27 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_vmas/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_libs/scripts_vmas/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | apt-get update && apt-get install -y git wget cmake 8 | 9 | 10 | export PYTORCH_TEST_WITH_SLOW='1' 11 | export LAZY_LEGACY_OP=False 12 | python -m torch.utils.collect_env 13 | # Avoid error: "fatal: unsafe repository" 14 | git config --global --add safe.directory '*' 15 | 16 | root_dir="$(git rev-parse --show-toplevel)" 17 | env_dir="${root_dir}/env" 18 | lib_dir="${env_dir}/lib" 19 | 20 | # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found 21 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir 22 | export 
MKL_THREADING_LAYER=GNU 23 | # more logging 24 | export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON 25 | 26 | # this workflow only tests the libs 27 | python -c "import vmas" 28 | 29 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestVmas --error-for-skips 30 | coverage combine 31 | coverage xml -i 32 | -------------------------------------------------------------------------------- /.github/unittest/linux_olddeps/scripts_gym_0_13/batch_scripts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Runs a batch of scripts in a row to allow docker run to keep installed libraries 4 | # and env variables across runs. 5 | 6 | DIR="$(cd "$(dirname "$0")" && pwd)" 7 | 8 | $DIR/install.sh 9 | $DIR/run_test.sh 10 | -------------------------------------------------------------------------------- /.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - protobuf 7 | - pip: 8 | - hypothesis 9 | - future 10 | - cloudpickle 11 | - gym[atari]==0.13 12 | - pygame 13 | - moviepy<2.0.0 14 | - tqdm 15 | - pytest 16 | - pytest-cov 17 | - pytest-mock 18 | - pytest-instafail 19 | - pytest-rerunfailures 20 | - expecttest 21 | - pyyaml 22 | - scipy 23 | - hydra-core 24 | - mujoco 25 | - patchelf 26 | - pyopengl==3.1.4 27 | - ray 28 | - av 29 | - h5py 30 | -------------------------------------------------------------------------------- /.github/unittest/linux_olddeps/scripts_gym_0_13/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -v 5 | 6 | eval "$(./conda/bin/conda shell.bash hook)" 7 | conda activate ./env 8 | 
-------------------------------------------------------------------------------- /.github/unittest/linux_olddeps/scripts_gym_0_13/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -v 5 | 6 | eval "$(./conda/bin/conda shell.bash hook)" 7 | conda activate ./env 8 | 9 | export PYTORCH_TEST_WITH_SLOW='1' 10 | export LAZY_LEGACY_OP=False 11 | python -m torch.utils.collect_env 12 | # Avoid error: "fatal: unsafe repository" 13 | git config --global --add safe.directory '*' 14 | 15 | root_dir="$(git rev-parse --show-toplevel)" 16 | env_dir="${root_dir}/env" 17 | lib_dir="${env_dir}/lib" 18 | 19 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir 20 | #export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/work/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin 21 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/pytorch/rl/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin 22 | export MKL_THREADING_LAYER=GNU 23 | export BATCHED_PIPE_TIMEOUT=60 24 | 25 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 26 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym' 27 | 28 | export DISPLAY=:99 29 | Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 & 30 | 31 | CKPT_BACKEND=torch MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest \ 32 | --instafail -v \ 33 | --durations 200 \ 34 | --ignore test/test_distributed.py \ 35 | --ignore test/test_rlhf.py \ 36 | --ignore test/llm \ 37 | --mp_fork_if_no_cuda 38 | 39 | #pytest --instafail -v --durations 200 40 | #python test/test_libs.py 41 | coverage combine 42 | coverage xml -i 43 | -------------------------------------------------------------------------------- /.github/unittest/linux_optdeps/scripts/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_optdeps/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - hypothesis 8 | - future 9 | - cloudpickle 10 | - pytest 11 | - pytest-cov 12 | - pytest-mock 13 | - pytest-instafail 14 | - pytest-rerunfailures 15 | - pytest-timeout 16 | - expecttest 17 | - pybind11[global] 18 | - pyyaml 19 | - scipy 20 | - coverage 21 | - ray 22 | -------------------------------------------------------------------------------- /.github/unittest/linux_optdeps/scripts/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_sota/scripts/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_sota/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - protobuf 7 | - pip: 8 | - hypothesis 9 | - future 10 | - cloudpickle 11 | - pygame 12 | - moviepy<2.0.0 13 | - tqdm 14 | - pytest 15 | - pytest-cov 16 | - pytest-mock 17 | - pytest-instafail 18 | - pytest-rerunfailures 19 | - expecttest 20 | - pybind11[global] 21 | - pyyaml 22 | - scipy 23 | - hydra-core 24 | - imageio==2.26.0 25 | - dm_control 26 | - 
mujoco 27 | - mlflow 28 | - av 29 | - coverage 30 | - vmas 31 | - transformers 32 | -------------------------------------------------------------------------------- /.github/unittest/linux_sota/scripts/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | -------------------------------------------------------------------------------- /.github/unittest/linux_sota/scripts/run_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -v 5 | 6 | # Read script from line 29 7 | filename=".github/unittest/linux_examples/scripts/run_test.sh" 8 | start_line=29 9 | script=$(tail -n +$start_line "$filename") 10 | script="set -e"$'\n'"$script" 11 | 12 | # Replace "cuda:0" with "cpu" 13 | script="${script//cuda:0/cpu}" 14 | 15 | # Remove any instances of ".github/unittest/helpers/coverage_run_parallel.py" 16 | script="${script//.github\/unittest\/helpers\/coverage_run_parallel.py}" 17 | script="${script//coverage combine}" 18 | script="${script//coverage xml -i}" 19 | 20 | # Execute the modified script 21 | echo "$script" | bash 22 | -------------------------------------------------------------------------------- /.github/unittest/linux_sota/scripts/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -v 5 | 6 | # Initialize an error flag 7 | error_occurred=0 8 | # Function to handle errors 9 | error_handler() { 10 | echo "Error on line $1" 11 | error_occurred=1 12 | } 13 | # Trap ERR to call the error_handler function with the failing line number 14 | trap 'error_handler $LINENO' ERR 15 | 16 | export PYTORCH_TEST_WITH_SLOW='1' 17 | python -m torch.utils.collect_env 18 | # Avoid error: "fatal: unsafe repository" 19 | git config --global --add safe.directory 
'*' 20 | 21 | root_dir="$(git rev-parse --show-toplevel)" 22 | env_dir="${root_dir}/env" 23 | lib_dir="${env_dir}/lib" 24 | 25 | # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found 26 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir 27 | export MKL_THREADING_LAYER=GNU 28 | export CUDA_LAUNCH_BLOCKING=1 29 | 30 | python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 31 | 32 | coverage run -m pytest .github/unittest/linux_sota/scripts/test_sota.py --instafail --durations 200 -vvv --capture no 33 | 34 | coverage combine 35 | coverage xml -i 36 | -------------------------------------------------------------------------------- /.github/unittest/windows_optdepts/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | dependencies: 5 | - pip 6 | - pip: 7 | - hypothesis 8 | - future 9 | - cloudpickle 10 | - pytest 11 | - pytest-cov 12 | - pytest-mock 13 | - pytest-instafail 14 | - pytest-rerunfailures 15 | - expecttest 16 | - pyyaml 17 | - scipy 18 | - coverage 19 | -------------------------------------------------------------------------------- /.github/unittest/windows_optdepts/scripts/install_conda.bat: -------------------------------------------------------------------------------- 1 | start /wait "" "%miniconda_exe%" /S /InstallationType=JustMe /RegisterPython=0 /AddToPath=0 /D=%tmp_conda% 2 | -------------------------------------------------------------------------------- /.github/unittest/windows_optdepts/scripts/vc_env_helper.bat: -------------------------------------------------------------------------------- 1 | @echo on 2 | 3 | set VC_VERSION_LOWER=16 4 | set VC_VERSION_UPPER=17 5 | 6 | for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) 
do ( 7 | if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( 8 | set "VS15INSTALLDIR=%%i" 9 | set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" 10 | goto vswhere 11 | ) 12 | ) 13 | 14 | :vswhere 15 | if "%VSDEVCMD_ARGS%" == "" ( 16 | call "%VS15VCVARSALL%" x64 || exit /b 1 17 | ) else ( 18 | call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 19 | ) 20 | 21 | @echo on 22 | 23 | set DISTUTILS_USE_SDK=1 24 | 25 | set args=%1 26 | shift 27 | :start 28 | if [%1] == [] goto done 29 | set args=%args% %1 30 | shift 31 | goto start 32 | 33 | :done 34 | if "%args%" == "" ( 35 | echo Usage: vc_env_helper.bat [command] [args] 36 | echo e.g. vc_env_helper.bat cl /c test.cpp 37 | ) 38 | 39 | %args% || exit /b 1 40 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: check-docstring-first 6 | - id: check-toml 7 | - id: check-yaml 8 | exclude: packaging/.* 9 | - id: mixed-line-ending 10 | args: [--fix=lf] 11 | - id: end-of-file-fixer 12 | 13 | - repo: https://github.com/omnilib/ufmt 14 | rev: v2.0.0b2 15 | hooks: 16 | - id: ufmt 17 | additional_dependencies: 18 | - black == 22.3.0 19 | - usort == 1.0.3 20 | - libcst == 0.4.7 21 | 22 | - repo: https://github.com/pycqa/flake8 23 | rev: 4.0.1 24 | hooks: 25 | - id: flake8 26 | args: [--config=setup.cfg] 27 | additional_dependencies: 28 | - flake8-bugbear==22.10.27 29 | - flake8-comprehensions==3.10.1 30 | - torchfix==0.0.2 31 | - flake8-print==5.0.0 32 | 33 | - repo: https://github.com/PyCQA/pydocstyle 34 | rev: 6.1.1 35 | hooks: 36 | - id: pydocstyle 37 | files: ^torchrl/ 38 | 39 | - repo: https://github.com/asottile/pyupgrade 40 | rev: v3.9.0 41 | hooks: 42 | - id: pyupgrade 43 | args: [--py38-plus] 44 | 45 | - repo: local 46 | hooks: 47 | - id: autoflake 48 | name: 
autoflake 49 | entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports 50 | language: system 51 | types: [python] 52 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: If you use this software, please cite it as below. 3 | title: TorchRL 4 | authors: 5 | - family-names: PyTorch Team 6 | url: https://pytorch.org/rl 7 | preferred-citation: 8 | type: conference-paper 9 | title: "TorchRL: A data-driven decision-making library for PyTorch" 10 | authors: 11 | - family-names: Bou 12 | given-names: Albert 13 | - family-names: Bettini 14 | given-names: Matteo 15 | - family-names: Dittert 16 | given-names: Sebastian 17 | - family-names: Kumar 18 | given-names: Vikash 19 | - family-names: Sodhani 20 | given-names: Shagun 21 | - family-names: Yang 22 | given-names: Xiaomeng 23 | - family-names: De Fabritiis 24 | given-names: Gianni 25 | - family-names: Moens 26 | given-names: Vincent 27 | collection-title: arXiv 28 | year: 2023 29 | url: https://arxiv.org/abs/2306.00577 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /benchmarks/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest-benchmark 2 | tenacity 3 | -------------------------------------------------------------------------------- /build_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /build_tools/setup_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .extension import CMakeBuild, get_ext_modules 7 | 8 | __all__ = ["CMakeBuild", "get_ext_modules"] 9 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS = -v 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/ppo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/ppo.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | sphinx-copybutton 4 | sphinx-gallery 5 | sphinx===5.0.0 6 | Jinja2==3.1.4 7 | sphinx-autodoc-typehints 8 | sphinx-serve==1.0.1 9 | git+https://github.com/vmoens/aafig@4319769eae88fff8e3464858f3cf8c277f35335d 10 | sphinxcontrib-htmlhelp 11 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 12 | myst-parser 13 | docutils 14 | sphinx_design 15 | 16 | torchvision 17 | dm_control 18 | mujoco 19 | gym[classic_control,accept-rom-license,ale-py,atari] 20 | pygame 21 | tqdm 22 | ipython 23 | imageio[ffmpeg,pyav] 24 | memory_profiler 25 | pyrender 26 | pytest 27 | vmas 28 | onnxscript 29 | onnxruntime 30 | onnx 31 | psutil 32 | -------------------------------------------------------------------------------- /docs/source/_static/img/cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/cartpole.gif -------------------------------------------------------------------------------- /docs/source/_static/img/cartpole_demo.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/cartpole_demo.gif -------------------------------------------------------------------------------- /docs/source/_static/img/collector-copy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/collector-copy.png -------------------------------------------------------------------------------- /docs/source/_static/img/dqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/dqn.png -------------------------------------------------------------------------------- /docs/source/_static/img/dqn_td0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/dqn_td0.png -------------------------------------------------------------------------------- /docs/source/_static/img/dqn_tdlambda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/dqn_tdlambda.png -------------------------------------------------------------------------------- /docs/source/_static/img/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/favicon.png -------------------------------------------------------------------------------- /docs/source/_static/img/icon.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/icon.png -------------------------------------------------------------------------------- /docs/source/_static/img/invpendulum.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/invpendulum.gif -------------------------------------------------------------------------------- /docs/source/_static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/logo.png -------------------------------------------------------------------------------- /docs/source/_static/img/mcts_forest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/mcts_forest.png -------------------------------------------------------------------------------- /docs/source/_static/img/pendulum.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/pendulum.gif -------------------------------------------------------------------------------- /docs/source/_static/img/rename_transform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/rename_transform.png -------------------------------------------------------------------------------- /docs/source/_static/img/replaybuffer_traj.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/replaybuffer_traj.png -------------------------------------------------------------------------------- /docs/source/_static/img/rollout.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/rollout.gif -------------------------------------------------------------------------------- /docs/source/_static/img/rollout_recurrent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/rollout_recurrent.png -------------------------------------------------------------------------------- /docs/source/_static/img/transforms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/docs/source/_static/img/transforms.png -------------------------------------------------------------------------------- /docs/source/_templates/class.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: {{ module }} 2 | 3 | 4 | {{ name | underline}} 5 | 6 | .. autoclass:: {{ name }} 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/_templates/function.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: {{ module }} 2 | 3 | 4 | {{ name | underline}} 5 | 6 | .. autofunction:: {{ name }} 7 | -------------------------------------------------------------------------------- /docs/source/_templates/rl_template.rst: -------------------------------------------------------------------------------- 1 | .. 
currentmodule:: {{ module }} 2 | 3 | 4 | {{ name | underline}} 5 | 6 | .. autoclass:: {{ name }} 7 | :members: 8 | :inherited-members: 9 | -------------------------------------------------------------------------------- /docs/source/_templates/rl_template_fun.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: {{ module }} 2 | 3 | 4 | {{ name | underline}} 5 | 6 | .. autofunction:: {{ name }} 7 | -------------------------------------------------------------------------------- /docs/source/_templates/rl_template_noinherit.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: {{ module }} 2 | 3 | 4 | {{ name | underline}} 5 | 6 | .. autoclass:: {{ name }} 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/docutils.conf: -------------------------------------------------------------------------------- 1 | [html writers] 2 | table_style: colwidths-auto # Necessary for the table generated by autosummary to look decent 3 | -------------------------------------------------------------------------------- /docs/source/reference/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | 7 | collectors 8 | data 9 | envs 10 | llms 11 | modules 12 | objectives 13 | trainers 14 | utils 15 | -------------------------------------------------------------------------------- /docs/source/reference/knowledge_base.rst: -------------------------------------------------------------------------------- 1 | Knowledge Base 2 | ============== 3 | 4 | .. _ref_knowledge_base: 5 | 6 | .. include:: ../../../knowledge_base/README.md 7 | :start-line: 1 8 | :parser: myst_parser.sphinx_ 9 | 10 | .. 
toctree:: 11 | :glob: 12 | :maxdepth: 1 13 | :caption: Contents: 14 | 15 | generated/knowledge_base/* 16 | -------------------------------------------------------------------------------- /docs/source/reference/utils.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: torchrl 2 | 3 | torchrl._utils package 4 | ==================== 5 | 6 | Set of utility methods that are used internally by the library. 7 | 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | :template: rl_template.rst 12 | 13 | implement_for 14 | set_auto_unwrap_transformed_env 15 | auto_unwrap_transformed_env 16 | -------------------------------------------------------------------------------- /examples/distributed/collectors/README.md: -------------------------------------------------------------------------------- 1 | # Distributed data collection examples 2 | 3 | If your algorithm is bound by the data collection speed, you may consider using 4 | distributed data collector to make your training faster. 5 | TorchRL offers a bunch of distributed data collectors that you can use 6 | to increase the collection speed tenfold or more. 7 | 8 | These examples are divided in a single machine and a multi-node series. 9 | 10 | Refer to the [documentation](https://pytorch.org/rl/reference/collectors.html) 11 | for more insight on what you can expect do 12 | and how these tools should be used. 13 | -------------------------------------------------------------------------------- /examples/distributed/collectors/multi_nodes/ray_collect.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example use of a distributed collector 3 | ====================================== 4 | 5 | This example illustrates how a TorchRL collector can be converted into a distributed collector. 6 | 7 | This example should create 3 collector instances, 1 local and 2 remote, but 4 instances seem to 8 | be created. Why? 
9 | """ 10 | from tensordict.nn import TensorDictModule 11 | from torch import nn 12 | from torchrl._utils import logger as torchrl_logger 13 | from torchrl.collectors.distributed.ray import RayCollector 14 | from torchrl.envs.libs.gym import GymEnv 15 | 16 | 17 | if __name__ == "__main__": 18 | 19 | # 1. Create environment factory 20 | def env_maker(): 21 | return GymEnv("Pendulum-v1", device="cpu") 22 | 23 | policy = TensorDictModule( 24 | nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"] 25 | ) 26 | 27 | # 2. Define distributed collector 28 | remote_config = { 29 | "num_cpus": 1, 30 | "num_gpus": 0, 31 | "memory": 5 * 1024**3, 32 | "object_store_memory": 2 * 1024**3, 33 | } 34 | distributed_collector = RayCollector( 35 | [env_maker], 36 | policy, 37 | total_frames=10000, 38 | frames_per_batch=200, 39 | remote_configs=remote_config, 40 | ) 41 | 42 | # Sample batches until reaching total_frames 43 | counter = 0 44 | num_frames = 0 45 | for batch in distributed_collector: 46 | counter += 1 47 | num_frames += batch.shape.numel() 48 | torchrl_logger.info(f"batch {counter}, total frames {num_frames}") 49 | -------------------------------------------------------------------------------- /examples/rlhf/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.bin 3 | *.pt 4 | *.json 5 | -------------------------------------------------------------------------------- /examples/rlhf/config/train.yaml: -------------------------------------------------------------------------------- 1 | io: 2 | eval_interval: 200 3 | log_interval: 50 4 | eval_iters: 100 5 | data: 6 | batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size 7 | block_size: 550 8 | model: 9 | name_or_path: gpt2 # gpt2 for pre-trained, local path for checkpoint 10 | out_dir: ./out 11 | dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ 12 | train: 13 | grad_clip: 1.0 # clip gradients at this value, or disable 
if == 0.0 14 | max_iters: 5000 # total number of training iterations 15 | gradient_accumulation_steps: 2 # used to simulate larger batch sizes 16 | always_save_checkpoint: False # if True, always save a checkpoint after each evaluation in out_dir 17 | decay_lr: True # whether to decay the learning rate 18 | optimizer: 19 | # keyword arguments for torch.optim.AdamW 20 | lr: 1.0e-5 21 | weight_decay: 1.0e-1 22 | betas: [0.9, 0.95] 23 | scheduler: 24 | # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR 25 | T_max: 5000 # maximum number of iterations 26 | eta_min: 1.0e-6 # minimum learning rate 27 | sys: 28 | device: cuda # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks 29 | dtype: bfloat16 # float32, bfloat16, or float16, the latter will auto implement a GradScaler 30 | compile: True # use PyTorch 2.0 to compile the model to be faster 31 | -------------------------------------------------------------------------------- /examples/rlhf/config/train_reward.yaml: -------------------------------------------------------------------------------- 1 | io: 2 | eval_interval: 200 3 | log_interval: 50 4 | eval_iters: 100 5 | data: 6 | batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size 7 | block_size: 550 8 | model: 9 | name_or_path: ./out 10 | dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ 11 | reward_model: 12 | out_dir: ./out_reward 13 | init_from: scratch # 'scratch' or 'resume' - if "resume" model will be loaded from out_dir_reward 14 | train: 15 | grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 16 | max_iters: 20000 # total number of training iterations 17 | gradient_accumulation_steps: 2 # used to simulate larger batch sizes 18 | always_save_checkpoint: False # if True, always save a checkpoint after each eval 19 | decay_lr: False # whether to decay the learning rate 20 | optimizer: 21 | # keyword arguments for torch.optim.AdamW 22 | lr: 1.0e-5 23 | weight_decay: 1.0e-1 
24 | betas: [0.9, 0.95] 25 | scheduler: 26 | # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR 27 | T_max: 20000 28 | eta_min: 1.0e-6 29 | sys: 30 | device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 31 | dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 32 | compile: True # use PyTorch 2.0 to compile the model to be faster 33 | -------------------------------------------------------------------------------- /examples/rlhf/config/train_rlhf.yaml: -------------------------------------------------------------------------------- 1 | io: 2 | eval_interval: 6 3 | log_interval: 1 4 | eval_iters: 10 5 | logger: wandb 6 | project_name: torchrl_example_rlhf 7 | group_name: null 8 | data: 9 | batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size 10 | block_size: 550 11 | num_workers: 1 12 | model: 13 | name_or_path: ./out 14 | out_dir: ./out_rlhf 15 | dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ 16 | reward_model: 17 | name_or_path: ./out_reward 18 | train: 19 | grad_clip: 1.0 20 | max_epochs: 1000 # total number of training iterations 21 | always_save_checkpoint: True # if True, always save a checkpoint after each eval 22 | decay_lr: True 23 | optimizer: 24 | # keyword arguments for torch.optim.AdamW 25 | lr: 5.0e-5 26 | weight_decay: 0.0 # 01 27 | betas: [0.9, 0.999] 28 | scheduler: 29 | # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR 30 | T_max: 3000 # max_epochs * num_rollouts / ppo_batch_size 31 | eta_min: 5.0e-6 32 | ppo: 33 | episode_length: 50 34 | ppo_batch_size: 16 35 | ppo_num_epochs: 3 36 | num_rollouts_per_epoch: 32 37 | sys: 38 | device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 39 | ref_device: cuda:1 # device of reference model 40 | dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 41 | compile: False # 
use PyTorch 2.0 to compile the model to be faster 42 | -------------------------------------------------------------------------------- /examples/rlhf/data/__init__.py: -------------------------------------------------------------------------------- 1 | from torchrl.data.llm.prompt import get_prompt_dataloader_tldr 2 | 3 | __all__ = ["get_prompt_dataloader_tldr"] 4 | -------------------------------------------------------------------------------- /examples/rlhf/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /examples/rlhf/models/actor_critic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from __future__ import annotations 6 | 7 | from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator 8 | from torchrl.modules.tensordict_module.common import VmapModule 9 | 10 | from .transformer import init_transformer 11 | 12 | __all__ = ["init_actor_critic"] 13 | 14 | 15 | def init_actor_critic(model_cfg, sys_cfg): 16 | 17 | transformer_name_or_path = model_cfg.name_or_path 18 | dropout = model_cfg.dropout 19 | 20 | device = sys_cfg.device 21 | compile_model = sys_cfg.compile 22 | base_model = init_transformer( 23 | transformer_name_or_path, 24 | dropout, 25 | device, 26 | as_tensordictmodule=False, 27 | compile_model=compile_model, 28 | inference=True, 29 | ) 30 | model = LMHeadActorValueOperator(base_model) 31 | model.to(device) 32 | model.eval() 33 | actor = model.get_policy_operator() 34 | critic = model.get_value_operator() 35 | critic_head = model.get_value_head() 36 | 37 | return actor, VmapModule(critic, mock=True), critic_head, base_model 38 | -------------------------------------------------------------------------------- /examples/rlhf/models/reward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import warnings 6 | 7 | import torch 8 | from tensordict.nn import TensorDictModule 9 | from torchrl._utils import logger as torchrl_logger 10 | 11 | from torchrl.modules.models.llm import GPT2RewardModel 12 | 13 | 14 | def init_reward_model( 15 | transformer_path=None, reward_model_path=None, device=None, compile_model=False 16 | ): 17 | if transformer_path is None and reward_model_path is None: 18 | warnings.warn( 19 | "You did not provide a path to the reward model, a naive reward model will be used instead." 
20 | ) 21 | model = GPT2RewardModel() 22 | else: 23 | if not ((transformer_path is None) ^ (reward_model_path is None)): 24 | raise ValueError( 25 | "Exactly one of transformer_path or reward_model_path should be specified." 26 | ) 27 | if transformer_path is not None: 28 | model = GPT2RewardModel(transformer_path) 29 | else: 30 | model = GPT2RewardModel.from_pretrained(reward_model_path) 31 | 32 | model.to(device) 33 | if compile_model: 34 | torchrl_logger.info("Compiling the reward model...") 35 | model = torch.compile(model) 36 | 37 | model = TensorDictModule( 38 | model, 39 | in_keys=["input_ids", "attention_mask"], 40 | out_keys=["rewards", "end_scores"], 41 | ) 42 | return model 43 | -------------------------------------------------------------------------------- /examples/rlhf/models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import torch 6 | from tensordict.nn import TensorDictModule 7 | from torchrl._utils import logger as torchrl_logger 8 | from transformers import GPT2LMHeadModel 9 | 10 | 11 | def init_transformer( 12 | name_or_path, 13 | dropout, 14 | device, 15 | compile_model, 16 | as_tensordictmodule=True, 17 | inference=False, 18 | ): 19 | model_kwargs = { 20 | "resid_pdrop": dropout, 21 | "embd_pdrop": dropout, 22 | "attn_pdrop": dropout, 23 | "summary_first_dropout": dropout, 24 | } 25 | model = GPT2LMHeadModel.from_pretrained( 26 | name_or_path, return_dict=False, **model_kwargs 27 | ) 28 | model.to(device) 29 | 30 | if compile_model: 31 | torchrl_logger.info("Compiling transformer model...") 32 | model = torch.compile(model) 33 | 34 | if as_tensordictmodule: 35 | model = TensorDictModule( 36 | model, 37 | in_keys={ 38 | "input_ids": "input_ids", 39 | "attention_mask": "attention_mask", 40 | "labels": "labels", 41 | }, 42 | out_keys=["logits"] if inference else ["loss", "logits"], 43 | ) 44 | return model 45 | -------------------------------------------------------------------------------- /examples/rlhf/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | hydra-core 3 | matplotlib 4 | numpy 5 | PyYAML 6 | requests 7 | tiktoken 8 | tqdm 9 | transformers 10 | git+https://github.com/pytorch/rl 11 | git+https://github.com/pytorch-labs/tensordict 12 | -------------------------------------------------------------------------------- /examples/video/video-from-dataset.py: -------------------------------------------------------------------------------- 1 | """Video from dataset example. 2 | 3 | This example shows how to save a video from a dataset. 4 | 5 | To run it, you will need to install the openx requirements as well as torchvision. 
6 | """ 7 | 8 | from torchrl.data.datasets import OpenXExperienceReplay 9 | from torchrl.record import CSVLogger, VideoRecorder 10 | 11 | # Create a logger that saves videos as mp4 12 | logger = CSVLogger("./dump", video_format="mp4") 13 | 14 | 15 | # We use the VideoRecorder transform to save register the images coming from the batch. 16 | t = VideoRecorder( 17 | logger=logger, tag="pixels", in_keys=[("next", "observation", "image")] 18 | ) 19 | # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False) 20 | dataset = OpenXExperienceReplay( 21 | "cmu_stretch", 22 | batch_size=2000, 23 | slice_len=200, 24 | download=True, 25 | strict_length=False, 26 | transform=t, 27 | ) 28 | 29 | # Get a batch of data and visualize it 30 | for _ in dataset: 31 | # The transform has seen the data since it's in the replay buffer 32 | t.dump() 33 | break 34 | 35 | # Alternatively, we can build the dataset without the VideoRecorder and call it manually: 36 | dataset = OpenXExperienceReplay( 37 | "cmu_stretch", 38 | batch_size=2000, 39 | slice_len=200, 40 | download=True, 41 | strict_length=False, 42 | ) 43 | 44 | # Get a batch of data and visualize it 45 | for data in dataset: 46 | t(data) 47 | t.dump() 48 | break 49 | -------------------------------------------------------------------------------- /gallery/README.rst: -------------------------------------------------------------------------------- 1 | Example gallery 2 | =============== 3 | 4 | Below is a gallery of examples 5 | -------------------------------------------------------------------------------- /knowledge_base/GYM.md: -------------------------------------------------------------------------------- 1 | # Working with gym 2 | 3 | ## What is OpenAI Gym? 4 | 5 | OpenAI Gym is a python library that provides the tooling for coding and using 6 | environments in RL contexts. The environments can be either simulators or real 7 | world systems (such as robots or games). 
8 | Due to its ease of use, Gym has been widely adopted as one of the main APIs for 9 | environment interaction in RL and control. 10 | 11 | Historically, Gym was started by OpenAI on [https://github.com/openai/gym](https://github.com/openai/gym). 12 | Since then, OpenAI has ceased to maintain it and the library has been forked out 13 | in [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) by the Farama Foundation. 14 | 15 | Check the [Gym documentation](https://www.gymlibrary.dev/) for further details 16 | about the installation and usage. 17 | 18 | ## Versioning 19 | The OpenAI Gym library is known to have gone through multiple BC breaking changes 20 | and significant user-facing API modifications. 21 | In practice, TorchRL is tested against gym 0.13 and further and should work with 22 | any version in between. 23 | 24 | However, libraries built around Gym may have a custom env construction process 25 | that breaks the automatic wrapping from the `GymEnv` class. In those cases, it 26 | is best to first create the gym environment and wrap it using 27 | `torchrl.envs.libs.gym.GymWrapper`. 28 | 29 | If you run into an issue when running TorchRL with a specific version of gym, 30 | feel free to open an issue and we will gladly look into this. 31 | -------------------------------------------------------------------------------- /knowledge_base/README.md: -------------------------------------------------------------------------------- 1 | # Knowledge base 2 | 3 | This knowledge base is aimed at helping you solve common issues that you might encounter on your RL journey. 4 | It is a more practical guide than what one might usually gather from a textbook or a repo. 5 | It details pro-tips to run models on cluster. It gives indications on how to set up a 6 | conda environment, how to install packages without sudo access. It 7 | highlights common issues when installing and running RL libraries and 8 | provides ready-to-use solutions for these issues.
9 | 10 | ## Contributing 11 | 12 | Of course, this is an ever 1% complete journey, and we can't expect to cover all the RL 13 | field at any point in the future. 14 | 15 | **If you feel something is missing**, or if you bump into an issue and think that the 16 | community might benefit from your experience, please submit a PR with your contribution! 17 | -------------------------------------------------------------------------------- /knowledge_base/VERSIONING_ISSUES.md: -------------------------------------------------------------------------------- 1 | # Versioning Issues 2 | 3 | ## Pytorch version 4 | This issue is related to https://github.com/pytorch/rl/issues/689. Using PyTorch versions <2.0 and installing stable package leads to undefined symbol errors. For example: 5 | ``` 6 | ImportError: /usr/local/lib/python3.7/dist-packages/torchrl/_torchrl.so: undefined symbol: _ZN8pybind116detail11type_casterIN2at6TensorEvE4loadENS_6handleEb 7 | ``` 8 | 9 | ### How to reproduce 10 | 1. Create a Colab Notebook (at 24/11/2022 Colab environment has Python 3.7 and Pytorch 1.12 installed by default). 11 | 2. ``` !pip install torchrl ``` 12 | 3. ``` import torchrl ``` 13 | 14 | In Colab you can solve the issue by running: 15 | ``` 16 | !pip3 install torch --extra-index-url https://download.pytorch.org/whl/cpu -U 17 | ``` 18 | before the ```!pip install torchrl``` command. This will install the latest pytorch. Instructions can be found [here](https://pytorch.org/get-started/locally/). 19 | 20 | ### Workarounds 21 | There are two workarounds to this issue 22 | 1. Install/upgrade to the latest pytorch release before installing torchrl. 23 | 2. If you need to use a previous pytorch release: Install functorch version related to your torch distribution: e.g. ``` pip install functorch==0.2.0 ``` and install library from source ``` pip install git+https://github.com/pytorch/rl@ ```.
24 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | 3 | files = torchrl 4 | show_error_codes = True 5 | pretty = True 6 | allow_redefinition = True 7 | warn_redundant_casts = True 8 | 9 | [mypy-torchvision.*] 10 | 11 | ignore_errors = True 12 | ignore_missing_imports = True 13 | 14 | [mypy-numpy.*] 15 | 16 | ignore_missing_imports = True 17 | 18 | [mypy-scipy.*] 19 | 20 | ignore_missing_imports = True 21 | 22 | [mypy-pycocotools.*] 23 | 24 | ignore_missing_imports = True 25 | 26 | [mypy-lmdb.*] 27 | 28 | ignore_missing_imports = True 29 | 30 | [mypy-tqdm.*] 31 | 32 | ignore_missing_imports = True 33 | 34 | [mypy-moviepy.*] 35 | 36 | ignore_missing_imports = True 37 | 38 | [mypy-dm_control.*] 39 | 40 | ignore_missing_imports = True 41 | 42 | [mypy-dm_env.*] 43 | 44 | ignore_missing_imports = True 45 | 46 | [mypy-retro.*] 47 | 48 | ignore_missing_imports = True 49 | 50 | [mypy-gym.*] 51 | 52 | ignore_missing_imports = True 53 | 54 | [mypy-torchrl._torchrl.*] 55 | 56 | ignore_missing_imports = True 57 | -------------------------------------------------------------------------------- /packaging/build_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | . 
"$script_dir/pkg_helpers.bash" 6 | 7 | export BUILD_TYPE=wheel 8 | setup_env 9 | setup_wheel_python 10 | pip_install numpy pyyaml future ninja 11 | pip_install --upgrade setuptools 12 | setup_pip_pytorch_version 13 | python setup.py clean 14 | 15 | # Copy binaries to be included in the wheel distribution 16 | if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then 17 | python_exec="$(which python)" 18 | bin_path=$(dirname $python_exec) 19 | if [[ "$(uname)" == Darwin ]]; then 20 | # Install delocate to relocate the required binaries 21 | pip_install "delocate>=0.9" 22 | fi 23 | else 24 | # Install auditwheel to get some inspection utilities 25 | pip_install auditwheel 26 | 27 | # Point to custom libraries 28 | export LD_LIBRARY_PATH=$(pwd)/ext_libraries/lib:$LD_LIBRARY_PATH 29 | fi 30 | 31 | if [[ "$OSTYPE" == "msys" ]]; then 32 | IS_WHEEL=1 "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel 33 | else 34 | python setup.py bdist_wheel 35 | if [[ "$(uname)" != Darwin ]]; then 36 | rename "linux_x86_64" "manylinux1_x86_64" dist/*.whl 37 | fi 38 | fi 39 | -------------------------------------------------------------------------------- /packaging/windows/internal/driver_update.bat: -------------------------------------------------------------------------------- 1 | set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" 2 | curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe 3 | if errorlevel 1 exit /b 1 4 | 5 | start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot 6 | if errorlevel 1 exit /b 1 7 | 8 | del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL 9 | 10 | setlocal EnableDelayedExpansion 11 | set NVIDIA_GPU_EXISTS=0 12 | for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( 13 | set GPUS=%%i 
14 | if not "x!GPUS:NVIDIA=!" == "x!GPUS!" ( 15 | SET NVIDIA_GPU_EXISTS=1 16 | goto gpu_check_end 17 | ) 18 | ) 19 | :gpu_check_end 20 | endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% 21 | 22 | if "%NVIDIA_GPU_EXISTS%" == "0" ( 23 | echo "CUDA Driver installation Failed" 24 | exit /b 1 25 | ) 26 | -------------------------------------------------------------------------------- /packaging/windows/internal/vc_env_helper.bat: -------------------------------------------------------------------------------- 1 | @echo on 2 | 3 | set VC_VERSION_LOWER=16 4 | set VC_VERSION_UPPER=17 5 | if "%VC_YEAR%" == "2017" ( 6 | set VC_VERSION_LOWER=15 7 | set VC_VERSION_UPPER=16 8 | ) 9 | 10 | for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( 11 | if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( 12 | set "VS15INSTALLDIR=%%i" 13 | set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" 14 | goto vswhere 15 | ) 16 | ) 17 | 18 | :vswhere 19 | if "%VSDEVCMD_ARGS%" == "" ( 20 | call "%VS15VCVARSALL%" x64 || exit /b 1 21 | ) else ( 22 | call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 23 | ) 24 | 25 | @echo on 26 | 27 | set DISTUTILS_USE_SDK=1 28 | 29 | set args=%1 30 | shift 31 | :start 32 | if [%1] == [] goto done 33 | set args=%args% %1 34 | shift 35 | goto start 36 | 37 | :done 38 | if "%args%" == "" ( 39 | echo Usage: vc_env_helper.bat [command] [args] 40 | echo e.g. 
vc_env_helper.bat cl /c test.cpp 41 | ) 42 | 43 | %args% || exit /b 1 44 | -------------------------------------------------------------------------------- /packaging/windows/internal/vc_install_helper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | if [[ "$CU_VERSION" == "cu92" ]]; then 6 | export VC_YEAR=2017 7 | export VSDEVCMD_ARGS="-vcvars_ver=14.13" 8 | powershell packaging/windows/internal/vs2017_install.ps1 9 | elif [[ "$CU_VERSION" == "cu100" ]]; then 10 | export VC_YEAR=2017 11 | export VSDEVCMD_ARGS="" 12 | powershell packaging/windows/internal/vs2017_install.ps1 13 | else 14 | export VC_YEAR=2019 15 | export VSDEVCMD_ARGS="" 16 | fi 17 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.usort] 2 | first_party_detection = false 3 | 4 | [build-system] 5 | requires = ["setuptools", "wheel", "torch", "ninja"] 6 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = 3 | # show summary of all tests that did not pass 4 | -ra 5 | # Make tracebacks shorter 6 | --tb=native 7 | markers = 8 | unity_editor 9 | testpaths = 10 | test 11 | xfail_strict = True 12 | -------------------------------------------------------------------------------- /sota-check/README.md: -------------------------------------------------------------------------------- 1 | # SOTA Performance checks 2 | 3 | This folder contains a `submitit-release-check.sh` file that executes all 4 | the training scripts using `sbatch` with the default configuration and log them 5 | into a common WandB project. 6 | 7 | This script is to be executed before every release to assess the performance of 8 | the various algorithms available in torchrl.
The name of the project will include 9 | the specific commit of torchrl used to run the scripts (e.g. `torchrl-example-check-<commit>`). 10 | 11 | ## Usage 12 | 13 | To display the script usage, you can use the `--help` option: 14 | 15 | ```bash 16 | ./submitit-release-check.sh --help 17 | ``` 18 | 19 | ## Setup 20 | 21 | The following setup should allow you to run the scripts: 22 | 23 | ```bash 24 | export MUJOCO_GL=egl 25 | 26 | conda create -n rl-sota-bench python=3.10 -y 27 | conda install anaconda::libglu -y 28 | pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 29 | pip3 install "gymnasium[atari,mujoco]" vmas tqdm wandb pygame "moviepy<2.0.0" imageio submitit hydra-core transformers 30 | 31 | cd /path/to/tensordict 32 | python setup.py develop 33 | cd /path/to/torchrl 34 | python setup.py develop 35 | ``` 36 | -------------------------------------------------------------------------------- /sota-check/run_a2c_atari.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=a2c_atari 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/a2c_atari_%j.txt 8 | #SBATCH --error=slurm_errors/a2c_atari_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="a2c_atari" 13 | 14 | export PYTHONPATH=$(dirname $(dirname $PWD)) 15 | python $PYTHONPATH/sota-implementations/a2c/a2c_atari.py \ 16 | logger.backend=wandb \ 17 | logger.project_name="$project_name" \ 18 | logger.group_name="$group_name" 19 | 20 | # Capture the exit status of the Python command 21 | exit_status=$?
22 | # Write the exit status to a file 23 | if [ $exit_status -eq 0 ]; then 24 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 25 | else 26 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 27 | fi 28 | -------------------------------------------------------------------------------- /sota-check/run_a2c_mujoco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=a2c_mujoco 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/a2c_mujoco_%j.txt 8 | #SBATCH --error=slurm_errors/a2c_mujoco_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="a2c_mujoco" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/a2c/a2c_mujoco.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$?
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_cql_offline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=cql_offline 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/cql_offline_%j.txt 8 | #SBATCH --error=slurm_errors/cql_offline_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="cql_offline" 13 | 14 | export PYTHONPATH=$(dirname $(dirname $PWD)) 15 | python $PYTHONPATH/sota-implementations/cql/cql_offline.py \ 16 | logger.backend=wandb \ 17 | logger.project_name="$project_name" \ 18 | logger.group_name="$group_name" 19 | 20 | # Capture the exit status of the Python command 21 | exit_status=$?
22 | # Write the exit status to a file 23 | if [ $exit_status -eq 0 ]; then 24 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 25 | else 26 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 27 | fi 28 | -------------------------------------------------------------------------------- /sota-check/run_cql_online.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=cql_online 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/cql_online_%j.txt 8 | #SBATCH --error=slurm_errors/cql_online_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="cql_online" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/cql/cql_online.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$?
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_crossq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=crossq 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/crossq_%j.txt 8 | #SBATCH --error=slurm_errors/crossq_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="crossq" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/crossq/crossq.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_ddpg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=ddpg 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/ddpg_%j.txt 8 | #SBATCH --error=slurm_errors/ddpg_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="ddpg" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/ddpg/ddpg.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_discrete_sac.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=discrete_sac 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/discrete_sac_%j.txt 8 | #SBATCH --error=slurm_errors/discrete_sac_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="discrete_sac" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/discrete_sac/discrete_sac.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_dqn_atari.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=dqn_atari 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/dqn_atari_%j.txt 8 | #SBATCH --error=slurm_errors/dqn_atari_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="dqn_atari" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/dqn/dqn_atari.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_dqn_cartpole.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=dqn_cartpole 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/dqn_cartpole_%j.txt 8 | #SBATCH --error=slurm_errors/dqn_cartpole_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="dqn_cartpole" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/dqn/dqn_cartpole.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_dt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=dt 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/dt_offline_%j.txt 8 | #SBATCH --error=slurm_errors/dt_offline_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="dt_offline" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/decision_transformer/dt.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_dt_online.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=dt_online 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/dt_online_%j.txt 8 | #SBATCH --error=slurm_errors/dt_online_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="dt_online" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/decision_transformer/online_dt.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_impala_single_node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=impala_1node 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/impala_1node_%j.txt 8 | #SBATCH --error=slurm_errors/impala_1node_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="impala_1node" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/impala/impala_single_node.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_iql_discrete.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=iql_discrete 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/iql_discrete_%j.txt 8 | #SBATCH --error=slurm_errors/iql_discrete_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="iql_discrete" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/iql/discrete_iql.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_iql_offline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=iql_offline 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/iql_offline_%j.txt 8 | #SBATCH --error=slurm_errors/iql_offline_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="iql_offline" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/iql/iql_offline.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_iql_online.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=iql_online 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/iql_online_%j.txt 8 | #SBATCH --error=slurm_errors/iql_online_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="iql_online" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/iql/iql_online.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$?
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_multiagent_iddpg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=marl_iddpg 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/marl_iddpg_%j.txt 8 | #SBATCH --error=slurm_errors/marl_iddpg_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="marl_iddpg" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/multiagent/maddpg_iddpg.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_multiagent_ippo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=marl_ippo 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/marl_ippo_%j.txt 8 | #SBATCH --error=slurm_errors/marl_ippo_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="mappo_ippo" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/multiagent/mappo_ippo.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_multiagent_iql.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=marl_iql 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/marl_iql_%j.txt 8 | #SBATCH --error=slurm_errors/marl_iql_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="marl_iql" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/multiagent/iql.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_multiagent_qmix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=marl_qmix_vdn 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/marl_qmix_vdn_%j.txt 8 | #SBATCH --error=slurm_errors/marl_qmix_vdn_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="marl_qmix_vdn" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/multiagent/qmix_vdn.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_multiagent_sac.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=marl_sac 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/marl_sac_%j.txt 8 | #SBATCH --error=slurm_errors/marl_sac_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="marl_sac" 13 | 14 | export PYTHONPATH=$(dirname $(dirname $PWD)) 15 | python $PYTHONPATH/sota-implementations/multiagent/sac.py \ 16 | logger.backend=wandb \ 17 | logger.project_name="$project_name" \ 18 | logger.group_name="$group_name" 19 | 20 | # Capture the exit status of the Python command 21 | exit_status=$? 
22 | # Write the exit status to a file 23 | if [ $exit_status -eq 0 ]; then 24 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 25 | else 26 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 27 | fi 28 | -------------------------------------------------------------------------------- /sota-check/run_ppo_atari.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=ppo_atari 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/ppo_atari_%j.txt 8 | #SBATCH --error=slurm_errors/ppo_atari_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="ppo_atari" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/ppo/ppo_atari.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_ppo_mujoco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=ppo_mujoco 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/ppo_mujoco_%j.txt 8 | #SBATCH --error=slurm_errors/ppo_mujoco_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="ppo_mujoco" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/ppo/ppo_mujoco.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_sac.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=sac 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/sac_%j.txt 8 | #SBATCH --error=slurm_errors/sac_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="sac" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/sac/sac.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_td3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=td3 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/td3_%j.txt 8 | #SBATCH --error=slurm_errors/td3_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="td3" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/td3/td3.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-check/run_td3bc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=td3bc_offline 4 | #SBATCH --ntasks=32 5 | #SBATCH --cpus-per-task=1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --output=slurm_logs/td3bc_offline_%j.txt 8 | #SBATCH --error=slurm_errors/td3bc_offline_%j.txt 9 | 10 | current_commit=$(git rev-parse --short HEAD) 11 | project_name="torchrl-example-check-$current_commit" 12 | group_name="td3bc_offline" 13 | export PYTHONPATH=$(dirname $(dirname $PWD)) 14 | python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \ 15 | logger.backend=wandb \ 16 | logger.project_name="$project_name" \ 17 | logger.group_name="$group_name" 18 | 19 | # Capture the exit status of the Python command 20 | exit_status=$? 
21 | # Write the exit status to a file 22 | if [ $exit_status -eq 0 ]; then 23 | echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log 24 | else 25 | echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log 26 | fi 27 | -------------------------------------------------------------------------------- /sota-implementations/a2c/config_atari.yaml: -------------------------------------------------------------------------------- 1 | # Environment 2 | env: 3 | env_name: PongNoFrameskip-v4 4 | backend: gymnasium 5 | num_envs: 16 6 | 7 | # collector 8 | collector: 9 | frames_per_batch: 800 10 | total_frames: 40_000_000 11 | 12 | # logger 13 | logger: 14 | backend: wandb 15 | project_name: torchrl_example_a2c 16 | group_name: null 17 | exp_name: Atari_Schulman17 18 | test_interval: 40_000_000 19 | num_test_episodes: 3 20 | video: False 21 | 22 | # Optim 23 | optim: 24 | lr: 0.0001 25 | eps: 1.0e-8 26 | weight_decay: 0.0 27 | max_grad_norm: 40.0 28 | anneal_lr: True 29 | 30 | # loss 31 | loss: 32 | gamma: 0.99 33 | mini_batch_size: 80 34 | gae_lambda: 0.95 35 | critic_coef: 0.25 36 | entropy_coef: 0.01 37 | loss_critic_type: l2 38 | device: 39 | 40 | compile: 41 | compile: False 42 | compile_mode: 43 | cudagraphs: False 44 | -------------------------------------------------------------------------------- /sota-implementations/a2c/config_mujoco.yaml: -------------------------------------------------------------------------------- 1 | # task and env 2 | env: 3 | env_name: HalfCheetah-v4 4 | 5 | # collector 6 | collector: 7 | frames_per_batch: 640 8 | total_frames: 1_000_000 9 | 10 | # logger 11 | logger: 12 | backend: wandb 13 | project_name: torchrl_example_a2c 14 | group_name: null 15 | exp_name: Mujoco_Schulman17 16 | test_interval: 1_000_000 17 | num_test_episodes: 5 18 | video: False 19 | 20 | # Optim 21 | optim: 22 | lr: 3e-4 23 | weight_decay: 0.0 24 | anneal_lr: False 25 | 26 | # loss 27 | loss: 28 | gamma: 0.99 29 | mini_batch_size: 64 30 | gae_lambda: 0.95 31 |
critic_coef: 0.25 32 | entropy_coef: 0.0 33 | loss_critic_type: l2 34 | device: 35 | 36 | compile: 37 | compile: False 38 | compile_mode: default 39 | cudagraphs: False 40 | -------------------------------------------------------------------------------- /sota-implementations/bandits/README.md: -------------------------------------------------------------------------------- 1 | # Bandits example 2 | 3 | ## Note: 4 | This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the 5 | benchmarking of future releases, to ensure that it can be successfully run with the release code and that the 6 | results are consistent. For now, be aware that this additional check has not been performed in the case of this 7 | specific example. 8 | -------------------------------------------------------------------------------- /sota-implementations/cql/discrete_cql_config.yaml: -------------------------------------------------------------------------------- 1 | # Task and env 2 | env: 3 | name: CartPole-v1 4 | task: "" 5 | backend: gymnasium 6 | n_samples_stats: 1000 7 | max_episode_steps: 200 8 | seed: 0 9 | 10 | # Collector 11 | collector: 12 | frames_per_batch: 200 13 | total_frames: 1_000_000 14 | multi_step: 0 15 | init_random_frames: 1000 16 | env_per_collector: 1 17 | device: 18 | max_frames_per_traj: 200 19 | annealing_frames: 10000 20 | eps_start: 1.0 21 | eps_end: 0.01 22 | 23 | # Logger 24 | logger: 25 | backend: wandb 26 | project_name: torchrl_example_cql 27 | group_name: null 28 | exp_name: cql_cartpole_gym 29 | log_interval: 5000 # record interval in frames 30 | eval_steps: 200 31 | mode: online 32 | eval_iter: 1000 33 | video: False 34 | 35 | # Buffer 36 | replay_buffer: 37 | prb: 0 38 | buffer_prefetch: 64 39 | size: 1_000_000 40 | scratch_dir: null 41 | 42 | # Optimization 43 | optim: 44 | utd_ratio: 1 45 | device: null 46 | lr: 1e-3 47 | weight_decay: 0.0 48 | batch_size: 256 49 | 50 | # Policy and model 
51 | model: 52 | hidden_sizes: [256, 256] 53 | activation: relu 54 | 55 | # loss 56 | loss: 57 | loss_function: l2 58 | gamma: 0.99 59 | tau: 0.005 60 | 61 | compile: 62 | compile: False 63 | compile_mode: 64 | cudagraphs: False 65 | -------------------------------------------------------------------------------- /sota-implementations/cql/offline_config.yaml: -------------------------------------------------------------------------------- 1 | # env and task 2 | env: 3 | name: Hopper-v4 4 | task: "" 5 | library: gym 6 | n_samples_stats: 1000 7 | seed: 0 8 | backend: gymnasium 9 | 10 | # logger 11 | logger: 12 | backend: wandb 13 | project_name: torchrl_example_cql 14 | group_name: null 15 | exp_name: cql_${replay_buffer.dataset} 16 | # eval iter in gradient steps 17 | eval_iter: 5000 18 | eval_steps: 1000 19 | mode: online 20 | eval_envs: 5 21 | video: False 22 | 23 | # replay buffer 24 | replay_buffer: 25 | dataset: hopper-medium-v2 26 | batch_size: 256 27 | 28 | # optimization 29 | optim: 30 | device: null 31 | actor_lr: 3e-4 32 | critic_lr: 3e-4 33 | weight_decay: 0.0 34 | gradient_steps: 1_000_000 35 | policy_eval_start: 40_000 36 | 37 | # policy and model 38 | model: 39 | hidden_sizes: [256, 256] 40 | activation: relu 41 | default_policy_scale: 1.0 42 | scale_lb: 0.1 43 | 44 | # loss 45 | loss: 46 | loss_function: l2 47 | gamma: 0.99 48 | tau: 0.005 49 | # CQL specific hyperparameter 50 | temperature: 1.0 51 | min_q_weight: 1.0 52 | max_q_backup: False 53 | deterministic_backup: False 54 | num_random: 10 55 | with_lagrange: True 56 | lagrange_thresh: 5.0 # tau 57 | 58 | compile: 59 | compile: False 60 | compile_mode: 61 | cudagraphs: False 62 | -------------------------------------------------------------------------------- /sota-implementations/cql/online_config.yaml: -------------------------------------------------------------------------------- 1 | # Task and env 2 | env: 3 | name: Pendulum-v1 4 | task: "" 5 | n_samples_stats: 1000 6 | seed: 0 7 | 
train_num_envs: 1 8 | eval_num_envs: 1 9 | backend: gymnasium 10 | 11 | # Collector 12 | collector: 13 | frames_per_batch: 1000 14 | total_frames: 1_000_000 15 | multi_step: 0 16 | init_random_frames: 5_000 17 | env_per_collector: 1 18 | device: 19 | max_frames_per_traj: 1000 20 | 21 | 22 | # logger 23 | logger: 24 | backend: wandb 25 | project_name: torchrl_example_cql 26 | group_name: null 27 | exp_name: cql_${env.name} 28 | log_interval: 5000 # record interval in frames 29 | mode: online 30 | eval_steps: 1000 31 | video: False 32 | 33 | # Buffer 34 | replay_buffer: 35 | prb: 0 36 | buffer_prefetch: 64 37 | size: 1_000_000 38 | 39 | # Optimization 40 | optim: 41 | utd_ratio: 1 42 | device: null 43 | actor_lr: 3e-4 44 | critic_lr: 3e-4 45 | weight_decay: 0.0 46 | batch_size: 256 47 | optim_steps_per_batch: 200 48 | 49 | # Policy and model 50 | model: 51 | hidden_sizes: [256, 256] 52 | activation: relu 53 | default_policy_scale: 1.0 54 | scale_lb: 0.1 55 | 56 | # loss 57 | loss: 58 | loss_function: l2 59 | gamma: 0.99 60 | tau: 0.005 61 | # CQL hyperparameter 62 | temperature: 1.0 63 | min_q_weight: 1.0 64 | max_q_backup: False 65 | deterministic_backup: False 66 | num_random: 10 67 | with_lagrange: True 68 | lagrange_thresh: 10.0 69 | 70 | compile: 71 | compile: False 72 | compile_mode: 73 | cudagraphs: False 74 | -------------------------------------------------------------------------------- /sota-implementations/crossq/config.yaml: -------------------------------------------------------------------------------- 1 | # environment and task 2 | env: 3 | name: HalfCheetah-v4 4 | task: "" 5 | library: gym 6 | max_episode_steps: 1000 7 | seed: 42 8 | 9 | # collector 10 | collector: 11 | total_frames: 1_000_000 12 | init_random_frames: 25000 13 | frames_per_batch: 1000 14 | init_env_steps: 1000 15 | device: 16 | env_per_collector: 1 17 | reset_at_each_iter: False 18 | 19 | # replay buffer 20 | replay_buffer: 21 | size: 1000000 22 | prb: 0 # use prioritized experience 
replay 23 | scratch_dir: null 24 | 25 | # optim 26 | optim: 27 | utd_ratio: 1.0 28 | policy_update_delay: 3 29 | gamma: 0.99 30 | loss_function: l2 31 | lr: 1.0e-3 32 | weight_decay: 0.0 33 | batch_size: 256 34 | alpha_init: 1.0 35 | adam_eps: 1.0e-8 36 | beta1: 0.5 37 | beta2: 0.999 38 | 39 | # network 40 | network: 41 | batch_norm_momentum: 0.01 42 | warmup_steps: 100000 43 | critic_hidden_sizes: [2048, 2048] 44 | actor_hidden_sizes: [256, 256] 45 | critic_activation: relu 46 | actor_activation: relu 47 | default_policy_scale: 1.0 48 | scale_lb: 0.1 49 | device: 50 | 51 | compile: 52 | compile: False 53 | compile_mode: 54 | cudagraphs: False 55 | 56 | # logging 57 | logger: 58 | backend: wandb 59 | project_name: torchrl_example_crossQ 60 | group_name: null 61 | exp_name: ${env.name}_CrossQ 62 | mode: online 63 | eval_iter: 25000 64 | -------------------------------------------------------------------------------- /sota-implementations/ddpg/config.yaml: -------------------------------------------------------------------------------- 1 | # environment and task 2 | env: 3 | name: HalfCheetah-v4 4 | task: "" 5 | library: gymnasium 6 | max_episode_steps: 1000 7 | seed: 42 8 | 9 | # collector 10 | collector: 11 | total_frames: 1_000_000 12 | init_random_frames: 25_000 13 | frames_per_batch: 1000 14 | init_env_steps: 1000 15 | reset_at_each_iter: False 16 | device: 17 | env_per_collector: 1 18 | 19 | 20 | # replay buffer 21 | replay_buffer: 22 | size: 1000000 23 | prb: 0 # use prioritized experience replay 24 | scratch_dir: null 25 | 26 | # optimization 27 | optim: 28 | utd_ratio: 1.0 29 | gamma: 0.99 30 | loss_function: l2 31 | lr: 3.0e-4 32 | weight_decay: 1e-4 33 | batch_size: 256 34 | target_update_polyak: 0.995 35 | device: null 36 | 37 | # network 38 | network: 39 | hidden_sizes: [256, 256] 40 | activation: relu 41 | noise_type: "ou" # ou or gaussian 42 | 43 | compile: 44 | compile: False 45 | compile_mode: 46 | cudagraphs: False 47 | 48 | # logging 49 | logger: 
50 | backend: wandb 51 | project_name: torchrl_example_ddpg 52 | group_name: null 53 | exp_name: ${env.name}_DDPG 54 | mode: online 55 | eval_iter: 25000 56 | video: False 57 | num_eval_envs: 1 58 | -------------------------------------------------------------------------------- /sota-implementations/discrete_sac/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | # task and env 3 | env: 4 | name: CartPole-v1 5 | task: "" 6 | library: gym 7 | seed: 42 8 | max_episode_steps: 500 9 | 10 | # collector 11 | collector: 12 | total_frames: 25000 13 | init_random_frames: 1000 14 | init_env_steps: 1000 15 | frames_per_batch: 500 16 | reset_at_each_iter: False 17 | device: null 18 | env_per_collector: 1 19 | num_workers: 1 20 | 21 | # replay buffer 22 | replay_buffer: 23 | prb: 0 # use prioritized experience replay 24 | size: 1000000 25 | scratch_dir: null 26 | 27 | # optim 28 | optim: 29 | utd_ratio: 1.0 30 | gamma: 0.99 31 | batch_size: 256 32 | lr: 3.0e-4 33 | weight_decay: 0.0 34 | target_update_polyak: 0.995 35 | target_entropy_weight: 0.2 36 | target_entropy: "auto" 37 | loss_function: l2 38 | # default is 0.98 but needs to be decreased for env 39 | # with small action space 40 | 41 | # network 42 | network: 43 | hidden_sizes: [256, 256] 44 | activation: relu 45 | device: null 46 | 47 | compile: 48 | compile: False 49 | compile_mode: 50 | cudagraphs: False 51 | 52 | # logging 53 | logger: 54 | backend: wandb 55 | project_name: torchrl_example_discrete_sac 56 | group_name: null 57 | exp_name: ${env.name}_DiscreteSAC 58 | mode: online 59 | eval_iter: 5000 60 | video: False 61 | -------------------------------------------------------------------------------- /sota-implementations/dqn/README.md: -------------------------------------------------------------------------------- 1 | ## Reproducing Deep Q-Learning (DQN) Algorithm Results 2 | 3 | This repository contains scripts that enable training agents using the Deep 
```bash 22 | python dqn_cartpole.py 23 | ``` 24 | You can execute the DQN algorithm on Atari environments by running the following command: 25 | 26 | ```bash 27 | python dqn_atari.py 28 | ``` 29 | 30 | 31 |
exp_name: DQN 29 | test_interval: 1_000_000 30 | num_test_episodes: 3 31 | video: False 32 | 33 | # Optim 34 | optim: 35 | lr: 0.00025 36 | max_grad_norm: 10 37 | 38 | # loss 39 | loss: 40 | gamma: 0.99 41 | hard_update_freq: 10_000 42 | num_updates: 100 43 | 44 | compile: 45 | compile: False 46 | compile_mode: default 47 | cudagraphs: False 48 | -------------------------------------------------------------------------------- /sota-implementations/dqn/config_cartpole.yaml: -------------------------------------------------------------------------------- 1 | device: null 2 | 3 | # Environment 4 | env: 5 | env_name: CartPole-v1 6 | 7 | # collector 8 | collector: 9 | total_frames: 500_100 10 | frames_per_batch: 1000 11 | eps_start: 1.0 12 | eps_end: 0.05 13 | annealing_frames: 250_000 14 | init_random_frames: 10_000 15 | 16 | # buffer 17 | buffer: 18 | buffer_size: 10_000 19 | batch_size: 128 20 | 21 | # logger 22 | logger: 23 | backend: wandb 24 | project_name: torchrl_example_dqn 25 | group_name: null 26 | exp_name: DQN 27 | test_interval: 50_000 28 | num_test_episodes: 5 29 | video: False 30 | 31 | # Optim 32 | optim: 33 | lr: 2.5e-4 34 | max_grad_norm: 10 35 | 36 | # loss 37 | loss: 38 | gamma: 0.99 39 | hard_update_freq: 50 40 | num_updates: 100 41 | 42 | compile: 43 | compile: False 44 | compile_mode: 45 | cudagraphs: False 46 | -------------------------------------------------------------------------------- /sota-implementations/dreamer/README.md: -------------------------------------------------------------------------------- 1 | # Dreamer example 2 | 3 | ## Note: 4 | This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the 5 | benchmarking of future releases, to ensure that it can be successfully run with the release code and that the 6 | results are consistent. For now, be aware that this additional check has not been performed in the case of this 7 | specific example. 
8 | -------------------------------------------------------------------------------- /sota-implementations/dreamer/config.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | name: cheetah 3 | task: run 4 | seed: 0 5 | backend: dm_control 6 | frame_skip: 2 7 | from_pixels: True 8 | grayscale: False 9 | image_size : 64 10 | horizon: 500 11 | n_parallel_envs: 8 12 | device: cpu 13 | 14 | collector: 15 | total_frames: 5_000_000 16 | init_random_frames: 3000 17 | frames_per_batch: 1000 18 | device: 19 | 20 | optimization: 21 | train_every: 1000 22 | grad_clip: 100 23 | 24 | world_model_lr: 6e-4 25 | actor_lr: 8e-5 26 | value_lr: 8e-5 27 | kl_scale: 1.0 28 | free_nats: 3.0 29 | optim_steps_per_batch: 80 30 | gamma: 0.99 31 | lmbda: 0.95 32 | imagination_horizon: 15 33 | compile: False 34 | compile_backend: inductor 35 | use_autocast: True 36 | 37 | networks: 38 | exploration_noise: 0.3 39 | device: 40 | state_dim: 30 41 | rssm_hidden_dim: 200 42 | hidden_dim: 400 43 | activation: "elu" 44 | 45 | 46 | replay_buffer: 47 | batch_size: 2500 48 | buffer_size: 1000000 49 | batch_length: 50 50 | scratch_dir: null 51 | 52 | logger: 53 | backend: wandb 54 | project: dreamer-v1 55 | exp_name: ${env.name}-${env.task}-${env.seed} 56 | mode: online 57 | # eval interval, in collection counts 58 | eval_iter: 10 59 | eval_rollout_steps: 500 60 | video: False 61 | -------------------------------------------------------------------------------- /sota-implementations/gail/config.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | env_name: HalfCheetah-v4 3 | seed: 42 4 | backend: gymnasium 5 | 6 | logger: 7 | backend: wandb 8 | project_name: gail 9 | group_name: null 10 | exp_name: gail_ppo 11 | test_interval: 5000 12 | num_test_episodes: 5 13 | video: False 14 | mode: online 15 | 16 | ppo: 17 | collector: 18 | frames_per_batch: 2048 19 | total_frames: 1_000_000 20 | 21 | optim: 22 | lr: 3e-4 23 
| weight_decay: 0.0 24 | anneal_lr: True 25 | 26 | loss: 27 | gamma: 0.99 28 | mini_batch_size: 64 29 | ppo_epochs: 10 30 | gae_lambda: 0.95 31 | clip_epsilon: 0.2 32 | anneal_clip_epsilon: False 33 | critic_coef: 0.25 34 | entropy_coef: 0.0 35 | loss_critic_type: l2 36 | 37 | gail: 38 | hidden_dim: 128 39 | lr: 3e-4 40 | use_grad_penalty: False 41 | gp_lambda: 10.0 42 | device: null 43 | 44 | compile: 45 | compile: False 46 | compile_mode: default 47 | cudagraphs: False 48 | 49 | replay_buffer: 50 | dataset: halfcheetah-expert-v2 51 | batch_size: 256 52 | -------------------------------------------------------------------------------- /sota-implementations/impala/config_multi_node_submitit.yaml: -------------------------------------------------------------------------------- 1 | # Environment 2 | env: 3 | env_name: PongNoFrameskip-v4 4 | backend: gymnasium 5 | 6 | # Device for the forward and backward passes 7 | local_device: 8 | 9 | # SLURM config 10 | slurm_config: 11 | timeout_min: 10 12 | slurm_partition: train 13 | slurm_cpus_per_task: 1 14 | slurm_gpus_per_node: 1 15 | 16 | # collector 17 | collector: 18 | backend: gloo 19 | frames_per_batch: 80 20 | total_frames: 200_000_000 21 | num_workers: 1 22 | 23 | # logger 24 | logger: 25 | backend: wandb 26 | project_name: torchrl_example_impala_submitit 27 | group_name: null 28 | exp_name: Atari_IMPALA 29 | test_interval: 200_000_000 30 | num_test_episodes: 3 31 | 32 | # Optim 33 | optim: 34 | lr: 0.0006 35 | eps: 1e-8 36 | weight_decay: 0.0 37 | momentum: 0.0 38 | alpha: 0.99 39 | max_grad_norm: 40.0 40 | anneal_lr: True 41 | 42 | # loss 43 | loss: 44 | gamma: 0.99 45 | batch_size: 32 46 | sgd_updates: 1 47 | critic_coef: 0.5 48 | entropy_coef: 0.01 49 | loss_critic_type: l2 50 | -------------------------------------------------------------------------------- /sota-implementations/impala/config_single_node.yaml: -------------------------------------------------------------------------------- 1 | # Environment 2 | 
env: 3 | env_name: PongNoFrameskip-v4 4 | backend: gymnasium 5 | 6 | # Device for the forward and backward passes 7 | device: 8 | 9 | # collector 10 | collector: 11 | frames_per_batch: 80 12 | total_frames: 200_000_000 13 | num_workers: 12 14 | 15 | # logger 16 | logger: 17 | backend: wandb 18 | project_name: torchrl_example_impala 19 | group_name: null 20 | exp_name: Atari_IMPALA 21 | test_interval: 200_000_000 22 | num_test_episodes: 3 23 | 24 | # Optim 25 | optim: 26 | lr: 0.0006 27 | eps: 1e-8 28 | weight_decay: 0.0 29 | momentum: 0.0 30 | alpha: 0.99 31 | max_grad_norm: 40.0 32 | anneal_lr: True 33 | 34 | # loss 35 | loss: 36 | gamma: 0.99 37 | batch_size: 32 38 | sgd_updates: 1 39 | critic_coef: 0.5 40 | entropy_coef: 0.01 41 | loss_critic_type: l2 42 | -------------------------------------------------------------------------------- /sota-implementations/iql/discrete_iql.yaml: -------------------------------------------------------------------------------- 1 | # task and env 2 | env: 3 | name: CartPole-v1 4 | task: "" 5 | n_samples_stats: 1000 6 | seed: 0 7 | train_num_envs: 1 8 | eval_num_envs: 1 9 | backend: gymnasium 10 | 11 | 12 | # collector 13 | collector: 14 | frames_per_batch: 200 15 | total_frames: 20000 16 | init_random_frames: 1000 17 | env_per_collector: 1 18 | device: 19 | max_frames_per_traj: 200 20 | 21 | # logger 22 | logger: 23 | backend: wandb 24 | project_name: torchrl_example_discrete_iql 25 | exp_name: iql_${env.name} 26 | group_name: null 27 | log_interval: 5000 # record interval in frames 28 | eval_steps: 200 29 | mode: online 30 | eval_iter: 1000 31 | video: False 32 | 33 | # replay buffer 34 | replay_buffer: 35 | prb: 0 36 | buffer_prefetch: 64 37 | size: 1_000_000 38 | 39 | # optimization 40 | optim: 41 | utd_ratio: 1 42 | device: null 43 | lr: 3e-4 44 | weight_decay: 0.0 45 | batch_size: 256 46 | 47 | # network 48 | model: 49 | hidden_sizes: [256, 256] 50 | activation: relu 51 | 52 | 53 | # loss 54 | loss: 55 | loss_function: l2 56 
| gamma: 0.99 57 | hard_update_interval: 10 58 | 59 | # IQL specific hyperparameter 60 | temperature: 100 61 | expectile: 0.8 62 | 63 | compile: 64 | compile: False 65 | compile_mode: default 66 | cudagraphs: False 67 | -------------------------------------------------------------------------------- /sota-implementations/iql/offline_config.yaml: -------------------------------------------------------------------------------- 1 | # env and task 2 | env: 3 | name: HalfCheetah-v4 4 | task: "" 5 | exp_name: iql_${replay_buffer.dataset} 6 | n_samples_stats: 1000 7 | seed: 0 8 | backend: gymnasium 9 | 10 | # logger 11 | logger: 12 | backend: wandb 13 | project_name: torchrl_example_iql 14 | exp_name: iql_${replay_buffer.dataset} 15 | group_name: null 16 | eval_iter: 500 17 | eval_steps: 1000 18 | mode: online 19 | eval_envs: 5 20 | video: False 21 | 22 | # replay buffer 23 | replay_buffer: 24 | dataset: halfcheetah-medium-v2 25 | batch_size: 256 26 | 27 | # optimization 28 | optim: 29 | device: null 30 | lr: 3e-4 31 | weight_decay: 0.0 32 | gradient_steps: 50000 33 | 34 | # network 35 | model: 36 | hidden_sizes: [256, 256] 37 | activation: relu 38 | default_policy_scale: 1.0 39 | scale_lb: 0.1 40 | 41 | # loss 42 | loss: 43 | loss_function: l2 44 | gamma: 0.99 45 | tau: 0.005 46 | 47 | # IQL specific hyperparameter 48 | temperature: 3.0 49 | expectile: 0.7 50 | 51 | compile: 52 | compile: False 53 | compile_mode: 54 | cudagraphs: False 55 | -------------------------------------------------------------------------------- /sota-implementations/iql/online_config.yaml: -------------------------------------------------------------------------------- 1 | # task and env 2 | env: 3 | name: Pendulum-v1 4 | task: "" 5 | n_samples_stats: 1000 6 | seed: 0 7 | train_num_envs: 1 8 | eval_num_envs: 1 9 | backend: gymnasium 10 | 11 | # collector 12 | collector: 13 | frames_per_batch: 200 14 | total_frames: 20000 15 | multi_step: 0 16 | init_random_frames: 5000 17 | env_per_collector: 1 
18 | device: 19 | max_frames_per_traj: 200 20 | 21 | # logger 22 | logger: 23 | backend: wandb 24 | project_name: torchrl_example_iql 25 | exp_name: iql_${env.name} 26 | group_name: null 27 | log_interval: 5000 # record interval in frames 28 | eval_steps: 200 29 | mode: online 30 | eval_iter: 1000 31 | video: False 32 | 33 | # replay buffer 34 | replay_buffer: 35 | prb: 0 36 | buffer_prefetch: 64 37 | size: 1_000_000 38 | 39 | # optimization 40 | optim: 41 | utd_ratio: 1 42 | device: null 43 | lr: 3e-4 44 | weight_decay: 0.0 45 | batch_size: 256 46 | optim_steps_per_batch: 200 47 | 48 | # network 49 | model: 50 | hidden_sizes: [256, 256] 51 | activation: relu 52 | default_policy_scale: 1.0 53 | scale_lb: 0.1 54 | 55 | # loss 56 | loss: 57 | loss_function: l2 58 | gamma: 0.99 59 | tau: 0.005 60 | 61 | # IQL specific hyperparameter 62 | temperature: 3.0 63 | expectile: 0.7 64 | 65 | compile: 66 | compile: False 67 | compile_mode: 68 | cudagraphs: False 69 | -------------------------------------------------------------------------------- /sota-implementations/media/ant_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/sota-implementations/media/ant_chart.png -------------------------------------------------------------------------------- /sota-implementations/media/cheetah_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/sota-implementations/media/cheetah_chart.png -------------------------------------------------------------------------------- /sota-implementations/media/halfcheetah_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/sota-implementations/media/halfcheetah_chart.png 
-------------------------------------------------------------------------------- /sota-implementations/media/walker2d_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/sota-implementations/media/walker2d_chart.png -------------------------------------------------------------------------------- /sota-implementations/multiagent/iql.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | 3 | env: 4 | max_steps: 100 5 | scenario_name: "balance" 6 | scenario: 7 | n_agents: 3 8 | device: ??? # These values will be populated dynamically 9 | vmas_envs: ??? 10 | 11 | model: 12 | shared_parameters: True 13 | 14 | collector: 15 | frames_per_batch: 60_000 # Frames sampled each sampling iteration 16 | n_iters: 500 # Number of sampling/training iterations 17 | total_frames: ??? 18 | 19 | buffer: 20 | memory_size: ??? 21 | 22 | loss: 23 | gamma: 0.9 24 | tau: 0.005 # For target net 25 | 26 | train: 27 | num_epochs: 45 # optimization steps per batch of data collected 28 | minibatch_size: 4096 # size of minibatches used in each epoch 29 | lr: 5e-5 30 | max_grad_norm: 40.0 31 | device: ??? 32 | 33 | eval: 34 | evaluation_interval: 20 35 | evaluation_episodes: 200 36 | 37 | logger: 38 | backend: wandb # Delete to remove logging 39 | project_name: null 40 | group_name: null 41 | -------------------------------------------------------------------------------- /sota-implementations/multiagent/maddpg_iddpg.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | 3 | env: 4 | max_steps: 100 5 | scenario_name: "balance" 6 | scenario: 7 | n_agents: 3 8 | device: ??? # These values will be populated dynamically 9 | vmas_envs: ??? 
10 | 11 | model: 12 | shared_parameters: False # MADDPG paper does not use shared params because reward function can be different 13 | centralised_critic: True # MADDPG if True, IDDPG if False 14 | 15 | collector: 16 | frames_per_batch: 60_000 # Frames sampled each sampling iteration 17 | n_iters: 500 # Number of sampling/training iterations 18 | total_frames: ??? 19 | 20 | buffer: 21 | memory_size: ??? 22 | 23 | loss: 24 | gamma: 0.9 25 | tau: 0.005 # For target net 26 | 27 | train: 28 | num_epochs: 45 # optimization steps per batch of data collected 29 | minibatch_size: 4096 # size of minibatches used in each epoch 30 | lr: 5e-5 31 | max_grad_norm: 40.0 32 | device: ??? 33 | 34 | eval: 35 | evaluation_interval: 20 36 | evaluation_episodes: 200 37 | 38 | logger: 39 | backend: wandb # Delete to remove logging 40 | project_name: null 41 | group_name: null 42 | -------------------------------------------------------------------------------- /sota-implementations/multiagent/mappo_ippo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | 3 | env: 4 | max_steps: 100 5 | scenario_name: "balance" 6 | scenario: 7 | n_agents: 3 8 | device: ??? # These values will be populated dynamically 9 | vmas_envs: ??? 10 | 11 | model: 12 | shared_parameters: True 13 | centralised_critic: True # MAPPO if True, IPPO if False 14 | 15 | collector: 16 | frames_per_batch: 60_000 # Frames sampled each sampling iteration 17 | n_iters: 500 # Number of sampling/training iterations 18 | total_frames: ??? 19 | 20 | buffer: 21 | memory_size: ??? 22 | 23 | loss: 24 | gamma: 0.9 25 | lmbda: 0.9 26 | entropy_eps: 0 27 | clip_epsilon: 0.2 28 | 29 | train: 30 | num_epochs: 45 # optimization steps per batch of data collected 31 | minibatch_size: 4096 # size of minibatches used in each epoch 32 | lr: 5e-5 33 | max_grad_norm: 40.0 34 | device: ??? 
35 | 36 | eval: 37 | evaluation_interval: 20 38 | evaluation_episodes: 200 39 | 40 | logger: 41 | backend: wandb # Delete to remove logging 42 | project_name: null 43 | group_name: null 44 | -------------------------------------------------------------------------------- /sota-implementations/multiagent/qmix_vdn.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | 3 | env: 4 | max_steps: 100 5 | scenario_name: "balance" 6 | scenario: 7 | n_agents: 3 8 | device: ??? # These values will be populated dynamically 9 | vmas_envs: ??? 10 | 11 | model: 12 | shared_parameters: True 13 | 14 | collector: 15 | frames_per_batch: 60_000 # Frames sampled each sampling iteration 16 | n_iters: 500 # Number of sampling/training iterations 17 | total_frames: ??? 18 | 19 | buffer: 20 | memory_size: ??? 21 | 22 | loss: 23 | mixer_type: "qmix" # or "vdn" 24 | gamma: 0.9 25 | tau: 0.005 # For target net 26 | 27 | train: 28 | num_epochs: 45 # optimization steps per batch of data collected 29 | minibatch_size: 4096 # size of minibatches used in each epoch 30 | lr: 5e-5 31 | max_grad_norm: 40.0 32 | device: ??? 33 | 34 | eval: 35 | evaluation_interval: 20 36 | evaluation_episodes: 200 37 | 38 | logger: 39 | backend: wandb # Delete to remove logging 40 | project_name: null 41 | group_name: null 42 | -------------------------------------------------------------------------------- /sota-implementations/multiagent/sac.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | 3 | env: 4 | continuous_actions: True # False for discrete sac 5 | categorical_actions: False 6 | max_steps: 100 7 | scenario_name: "balance" 8 | scenario: 9 | n_agents: 3 10 | device: ??? # These values will be populated dynamically 11 | vmas_envs: ??? 
12 | 13 | model: 14 | shared_parameters: True 15 | centralised_critic: True 16 | 17 | collector: 18 | frames_per_batch: 60_000 # Frames sampled each sampling iteration 19 | n_iters: 500 # Number of sampling/training iterations 20 | total_frames: ??? 21 | 22 | buffer: 23 | memory_size: ??? 24 | 25 | loss: 26 | gamma: 0.9 27 | tau: 0.005 # For target net 28 | 29 | train: 30 | num_epochs: 45 # optimization steps per batch of data collected 31 | minibatch_size: 4096 # size of minibatches used in each epoch 32 | lr: 5e-5 33 | max_grad_norm: 2.0 34 | device: ??? 35 | 36 | eval: 37 | evaluation_interval: 20 38 | evaluation_episodes: 200 39 | 40 | logger: 41 | backend: wandb # Delete to remove logging 42 | project_name: null 43 | group_name: null 44 | -------------------------------------------------------------------------------- /sota-implementations/multiagent/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /sota-implementations/multiagent/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from __future__ import annotations 6 | 7 | from tensordict import unravel_key 8 | from torchrl.envs import Transform 9 | 10 | 11 | def swap_last(source, dest): 12 | source = unravel_key(source) 13 | dest = unravel_key(dest) 14 | if isinstance(source, str): 15 | if isinstance(dest, str): 16 | return dest 17 | return dest[-1] 18 | if isinstance(dest, str): 19 | return source[:-1] + (dest,) 20 | return source[:-1] + (dest[-1],) 21 | 22 | 23 | class DoneTransform(Transform): 24 | """Expands the 'done' entries (incl. terminated) to match the reward shape. 25 | 26 | Can be appended to a replay buffer or a collector. 27 | """ 28 | 29 | def __init__(self, reward_key, done_keys): 30 | super().__init__() 31 | self.reward_key = reward_key 32 | self.done_keys = done_keys 33 | 34 | def forward(self, tensordict): 35 | for done_key in self.done_keys: 36 | new_name = swap_last(self.reward_key, done_key) 37 | tensordict.set( 38 | ("next", new_name), 39 | tensordict.get(("next", done_key)) 40 | .unsqueeze(-1) 41 | .expand(tensordict.get(("next", self.reward_key)).shape), 42 | ) 43 | return tensordict 44 | -------------------------------------------------------------------------------- /sota-implementations/ppo/README.md: -------------------------------------------------------------------------------- 1 | ## Reproducing Proximal Policy Optimization (PPO) Algorithm Results 2 | 3 | This repository contains scripts that enable training agents using the Proximal Policy Optimization (PPO) Algorithm on MuJoCo and Atari environments. We follow the original paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) by Schulman et al. (2017) to implement the PPO algorithm but introduce the improvement of computing the Generalised Advantage Estimator (GAE) at every epoch. 4 | 5 | 6 | ## Examples Structure 7 | 8 | Please note that each example is independent of each other for the sake of simplicity. Each example contains the following files: 9 | 10 | 1. 
**Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. ppo_atari.py). 11 | 12 | 2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils_atari.py). 13 | 14 | 3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml). 15 | 16 | 17 | ## Running the Examples 18 | 19 | You can execute the PPO algorithm on Atari environments by running the following command: 20 | 21 | ```bash 22 | python ppo_atari.py 23 | ``` 24 | 25 | You can execute the PPO algorithm on MuJoCo environments by running the following command: 26 | 27 | ```bash 28 | python ppo_mujoco.py 29 | ``` 30 | -------------------------------------------------------------------------------- /sota-implementations/ppo/config_atari.yaml: -------------------------------------------------------------------------------- 1 | # Environment 2 | env: 3 | env_name: PongNoFrameskip-v4 4 | num_envs: 8 5 | backend: gymnasium 6 | 7 | # collector 8 | collector: 9 | frames_per_batch: 4096 10 | total_frames: 40_000_000 11 | 12 | # logger 13 | logger: 14 | backend: wandb 15 | project_name: torchrl_example_ppo 16 | group_name: null 17 | exp_name: Atari_Schulman17 18 | test_interval: 40_000_000 19 | num_test_episodes: 3 20 | video: False 21 | 22 | # Optim 23 | optim: 24 | lr: 2.5e-4 25 | eps: 1.0e-6 26 | weight_decay: 0.0 27 | max_grad_norm: 0.5 28 | anneal_lr: True 29 | device: 30 | 31 | # loss 32 | loss: 33 | gamma: 0.99 34 | mini_batch_size: 1024 35 | ppo_epochs: 3 36 | gae_lambda: 0.95 37 | clip_epsilon: 0.1 38 | anneal_clip_epsilon: True 39 | critic_coef: 1.0 40 | entropy_coef: 0.01 41 | loss_critic_type: l2 42 | 43 | compile: 44 | compile: False 45 | compile_mode: 46 | cudagraphs: False 47 | 
-------------------------------------------------------------------------------- /sota-implementations/ppo/config_mujoco.yaml: -------------------------------------------------------------------------------- 1 | # task and env 2 | env: 3 | env_name: HalfCheetah-v4 4 | 5 | # collector 6 | collector: 7 | frames_per_batch: 2048 8 | total_frames: 1_000_000 9 | 10 | # logger 11 | logger: 12 | backend: wandb 13 | project_name: torchrl_example_ppo 14 | group_name: null 15 | exp_name: Mujoco_Schulman17 16 | test_interval: 1_000_000 17 | num_test_episodes: 5 18 | video: False 19 | 20 | # Optim 21 | optim: 22 | lr: 3e-4 23 | weight_decay: 0.0 24 | anneal_lr: True 25 | device: 26 | 27 | # loss 28 | loss: 29 | gamma: 0.99 30 | mini_batch_size: 64 31 | ppo_epochs: 10 32 | gae_lambda: 0.95 33 | clip_epsilon: 0.2 34 | anneal_clip_epsilon: False 35 | critic_coef: 0.25 36 | entropy_coef: 0.0 37 | loss_critic_type: l2 38 | 39 | compile: 40 | compile: False 41 | compile_mode: 42 | cudagraphs: False 43 | -------------------------------------------------------------------------------- /sota-implementations/redq/README.md: -------------------------------------------------------------------------------- 1 | # REDQ example 2 | 3 | ## Note: 4 | This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the 5 | benchmarking of future releases, to ensure that it can be successfully run with the release code and that the 6 | results are consistent. For now, be aware that this additional check has not been performed in the case of this 7 | specific example. 
8 | -------------------------------------------------------------------------------- /sota-implementations/sac/config-async.yaml: -------------------------------------------------------------------------------- 1 | # environment and task 2 | env: 3 | name: HalfCheetah-v4 4 | task: "" 5 | library: gymnasium 6 | max_episode_steps: 1000 7 | seed: 42 8 | 9 | # collector 10 | collector: 11 | total_frames: -1 12 | init_random_frames: 25000 13 | frames_per_batch: 8000 14 | init_env_steps: 1000 15 | device: cuda:1 16 | env_per_collector: 16 17 | reset_at_each_iter: False 18 | update_freq: 10_000 19 | 20 | # replay buffer 21 | replay_buffer: 22 | size: 100_000 # Small buffer size to keep only recent elements 23 | prb: 0 # use prioritized experience replay 24 | scratch_dir: 25 | 26 | # optim 27 | optim: 28 | utd_ratio: 1.0 29 | gamma: 0.99 30 | loss_function: l2 31 | lr: 3.0e-4 32 | weight_decay: 0.0 33 | batch_size: 256 34 | target_update_polyak: 0.995 35 | alpha_init: 1.0 36 | adam_eps: 1.0e-8 37 | 38 | # network 39 | network: 40 | hidden_sizes: [256, 256] 41 | activation: relu 42 | default_policy_scale: 1.0 43 | scale_lb: 0.1 44 | device: 45 | 46 | # logging 47 | logger: 48 | backend: wandb 49 | project_name: torchrl_example_sac 50 | group_name: null 51 | exp_name: ${env.name}_SAC 52 | mode: online 53 | log_freq: 25000 # logging freq in updates 54 | video: False 55 | 56 | compile: 57 | compile: False 58 | compile_mode: 59 | cudagraphs: False 60 | -------------------------------------------------------------------------------- /sota-implementations/sac/config.yaml: -------------------------------------------------------------------------------- 1 | # environment and task 2 | env: 3 | name: HalfCheetah-v4 4 | task: "" 5 | library: gymnasium 6 | max_episode_steps: 1000 7 | seed: 42 8 | 9 | # collector 10 | collector: 11 | total_frames: 1_000_000 12 | init_random_frames: 25000 13 | frames_per_batch: 1000 14 | init_env_steps: 1000 15 | device: 16 | env_per_collector: 8 17 | 
reset_at_each_iter: False 18 | 19 | # replay buffer 20 | replay_buffer: 21 | size: 1000000 22 | prb: 0 # use prioritized experience replay 23 | scratch_dir: 24 | 25 | # optim 26 | optim: 27 | utd_ratio: 1.0 28 | gamma: 0.99 29 | loss_function: l2 30 | lr: 3.0e-4 31 | weight_decay: 0.0 32 | batch_size: 256 33 | target_update_polyak: 0.995 34 | alpha_init: 1.0 35 | adam_eps: 1.0e-8 36 | 37 | # network 38 | network: 39 | hidden_sizes: [256, 256] 40 | activation: relu 41 | default_policy_scale: 1.0 42 | scale_lb: 0.1 43 | device: 44 | 45 | # logging 46 | logger: 47 | backend: wandb 48 | project_name: torchrl_example_sac 49 | group_name: null 50 | exp_name: ${env.name}_SAC 51 | mode: online 52 | eval_iter: 25000 53 | video: False 54 | 55 | compile: 56 | compile: False 57 | compile_mode: 58 | cudagraphs: False 59 | -------------------------------------------------------------------------------- /sota-implementations/td3/config.yaml: -------------------------------------------------------------------------------- 1 | # task and env 2 | env: 3 | name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency 4 | task: "" 5 | library: gymnasium 6 | seed: 42 7 | max_episode_steps: 1000 8 | 9 | # collector 10 | collector: 11 | total_frames: 1000000 12 | init_random_frames: 25_000 13 | init_env_steps: 1000 14 | frames_per_batch: 1000 15 | reset_at_each_iter: False 16 | device: 17 | env_per_collector: 8 18 | num_workers: 1 19 | 20 | # replay buffer 21 | replay_buffer: 22 | prb: 0 # use prioritized experience replay 23 | size: 1000000 24 | scratch_dir: null 25 | 26 | # optim 27 | optim: 28 | utd_ratio: 1.0 29 | gamma: 0.99 30 | loss_function: l2 31 | lr: 3.0e-4 32 | weight_decay: 0.0 33 | adam_eps: 1e-4 34 | batch_size: 256 35 | target_update_polyak: 0.995 36 | policy_update_delay: 2 37 | policy_noise: 0.2 38 | noise_clip: 0.5 39 | 40 | # network 41 | network: 42 | hidden_sizes: [256, 256] 43 | activation: relu 44 | device: null 45 | 46 | # logging 47 | logger: 48 | backend: 
wandb 49 | project_name: torchrl_example_td3 50 | group_name: null 51 | exp_name: ${env.name}_TD3 52 | mode: online 53 | eval_iter: 25000 54 | video: False 55 | 56 | compile: 57 | compile: False 58 | compile_mode: 59 | cudagraphs: False 60 | -------------------------------------------------------------------------------- /sota-implementations/td3_bc/config.yaml: -------------------------------------------------------------------------------- 1 | # task and env 2 | env: 3 | name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency 4 | task: "" 5 | library: gymnasium 6 | seed: 42 7 | max_episode_steps: 1000 8 | 9 | # replay buffer 10 | replay_buffer: 11 | dataset: halfcheetah-medium-v2 12 | batch_size: 256 13 | 14 | # optim 15 | optim: 16 | gradient_steps: 100000 17 | gamma: 0.99 18 | loss_function: l2 19 | lr: 3.0e-4 20 | weight_decay: 0.0 21 | adam_eps: 1e-4 22 | batch_size: 256 23 | target_update_polyak: 0.995 24 | policy_update_delay: 2 25 | policy_noise: 0.2 26 | noise_clip: 0.5 27 | alpha: 2.5 28 | 29 | # network 30 | network: 31 | hidden_sizes: [256, 256] 32 | activation: relu 33 | device: null 34 | 35 | # logging 36 | logger: 37 | backend: wandb 38 | project_name: td3+bc_${replay_buffer.dataset} 39 | group_name: null 40 | exp_name: TD3+BC_${replay_buffer.dataset} 41 | mode: online 42 | eval_iter: 5000 43 | eval_steps: 1000 44 | eval_envs: 1 45 | video: False 46 | 47 | compile: 48 | compile: False 49 | compile_mode: 50 | cudagraphs: False 51 | -------------------------------------------------------------------------------- /test/assets/openai_summarize_comparisons.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/test/assets/openai_summarize_comparisons.zip -------------------------------------------------------------------------------- /test/assets/openai_summarize_tldr.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/test/assets/openai_summarize_tldr.zip -------------------------------------------------------------------------------- /test/assets/tldr_batch.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/test/assets/tldr_batch.zip -------------------------------------------------------------------------------- /test/llm/smoke_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | import argparse 8 | 9 | import pytest 10 | 11 | 12 | def test_import(): 13 | pass 14 | 15 | 16 | if __name__ == "__main__": 17 | args, unknown = argparse.ArgumentParser().parse_known_args() 18 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) 19 | -------------------------------------------------------------------------------- /test/llm/smoke_test_deps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from __future__ import annotations 6 | 7 | import argparse 8 | 9 | import pytest 10 | 11 | 12 | if __name__ == "__main__": 13 | args, unknown = argparse.ArgumentParser().parse_known_args() 14 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) 15 | -------------------------------------------------------------------------------- /test/smoke_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | 8 | def test_imports(): 9 | from torchrl.data import ( 10 | PrioritizedReplayBuffer, 11 | ReplayBuffer, 12 | TensorSpec, 13 | ) # noqa: F401 14 | from torchrl.envs import Transform, TransformedEnv # noqa: F401 15 | from torchrl.envs.gym_like import GymLikeEnv # noqa: F401 16 | from torchrl.modules import SafeModule # noqa: F401 17 | from torchrl.objectives.common import LossModule # noqa: F401 18 | 19 | PrioritizedReplayBuffer(alpha=1.1, beta=1.1) 20 | -------------------------------------------------------------------------------- /torchrl/collectors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from torchrl.envs.utils import RandomPolicy 7 | 8 | from .collectors import ( 9 | aSyncDataCollector, 10 | DataCollectorBase, 11 | MultiaSyncDataCollector, 12 | MultiSyncDataCollector, 13 | SyncDataCollector, 14 | ) 15 | from .weight_update import ( 16 | MultiProcessedWeightUpdate, 17 | RayWeightUpdater, 18 | VanillaWeightUpdater, 19 | WeightUpdaterBase, 20 | ) 21 | 22 | __all__ = [ 23 | "RandomPolicy", 24 | "WeightUpdaterBase", 25 | "VanillaWeightUpdater", 26 | "RayWeightUpdater", 27 | "MultiProcessedWeightUpdate", 28 | "aSyncDataCollector", 29 | "DataCollectorBase", 30 | "MultiaSyncDataCollector", 31 | "MultiSyncDataCollector", 32 | "SyncDataCollector", 33 | ] 34 | -------------------------------------------------------------------------------- /torchrl/collectors/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .generic import ( 7 | DEFAULT_SLURM_CONF, 8 | DistributedDataCollector, 9 | DistributedWeightUpdater, 10 | ) 11 | from .ray import RayCollector 12 | from .rpc import RPCDataCollector 13 | from .sync import DistributedSyncDataCollector 14 | from .utils import submitit_delayed_launcher 15 | 16 | __all__ = [ 17 | "DEFAULT_SLURM_CONF", 18 | "DistributedDataCollector", 19 | "DistributedWeightUpdater", 20 | "DistributedSyncDataCollector", 21 | "RPCDataCollector", 22 | "RPCDataCollector", 23 | "RayCollector", 24 | "submitit_delayed_launcher", 25 | ] 26 | -------------------------------------------------------------------------------- /torchrl/collectors/distributed/default_configs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | import os 8 | 9 | TCP_PORT = os.environ.get("TCP_PORT", "10003") 10 | IDLE_TIMEOUT = os.environ.get("RCP_IDLE_TIMEOUT", 10) 11 | 12 | MAX_TIME_TO_CONNECT = 1000 13 | 14 | SLEEP_INTERVAL = 1e-6 15 | 16 | DEFAULT_SLURM_CONF = { 17 | "timeout_min": 10, 18 | "slurm_partition": "train", 19 | "slurm_cpus_per_task": 32, 20 | "slurm_gpus_per_node": 0, 21 | } #: Default value of the SLURM jobs 22 | 23 | DEFAULT_SLURM_CONF_MAIN = { 24 | "timeout_min": 10, 25 | "slurm_partition": "train", 26 | "slurm_cpus_per_task": 32, 27 | "slurm_gpus_per_node": 1, 28 | } #: Default value of the SLURM main job 29 | 30 | DEFAULT_TENSORPIPE_OPTIONS = { 31 | "num_worker_threads": 16, 32 | "rpc_timeout": 10_000, 33 | "_transports": ["uv"], 34 | } 35 | -------------------------------------------------------------------------------- /torchrl/collectors/llm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .base import LLMCollector 7 | from .weight_update import vLLMUpdater 8 | 9 | __all__ = ["vLLMUpdater", "LLMCollector"] 10 | -------------------------------------------------------------------------------- /torchrl/collectors/llm/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from __future__ import annotations 6 | 7 | import importlib.util 8 | 9 | from queue import Full as QueueFull, Queue 10 | 11 | from tensordict import TensorDictBase 12 | 13 | from torchrl._utils import logger as torchrl_logger 14 | 15 | _has_ray = importlib.util.find_spec("ray") is not None 16 | 17 | 18 | class _QueueAsRB: 19 | def __init__(self, queue: Queue | ray.util.queue.Queue): # noqa 20 | if not _has_ray: 21 | raise ImportError("Ray not installed.") 22 | self.queue = queue 23 | 24 | def extend(self, data: TensorDictBase): 25 | from ray.util.queue import Full as RayQueueFull 26 | 27 | # unbind the data and put in the queue 28 | for item in data.unbind(0): 29 | while True: 30 | try: 31 | self.queue.put_nowait(item) 32 | break 33 | except (QueueFull, RayQueueFull): 34 | self.queue.get() # Remove the oldest item to make space 35 | torchrl_logger.warn("rollout queue full. Discarding data.") 36 | return 37 | -------------------------------------------------------------------------------- /torchrl/collectors/llm/weight_update/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | from .vllm import vLLMUpdater 8 | 9 | __all__ = ["vLLMUpdater"] 10 | -------------------------------------------------------------------------------- /torchrl/csrc/numpy_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | namespace py = pybind11; 15 | 16 | namespace torchrl { 17 | namespace utils { 18 | 19 | template 20 | std::vector NumpyArrayShape(const py::array_t& arr) { 21 | const int64_t ndim = arr.ndim(); 22 | std::vector shape(ndim); 23 | for (int64_t i = 0; i < ndim; ++i) { 24 | shape[i] = static_cast(arr.shape(i)); 25 | } 26 | return shape; 27 | } 28 | 29 | template 30 | py::array_t NumpyEmptyLike(const py::array_t& src) { 31 | py::array_t dst(src.size()); 32 | const std::vector shape = NumpyArrayShape(src); 33 | dst.resize(shape); 34 | return dst; 35 | } 36 | 37 | } // namespace utils 38 | } // namespace torchrl 39 | -------------------------------------------------------------------------------- /torchrl/csrc/pybind.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #include "segment_tree.h" 14 | #include "utils.h" 15 | 16 | namespace py = pybind11; 17 | 18 | PYBIND11_MODULE(_torchrl, m) { 19 | torchrl::DefineSumSegmentTree("Fp32", m); 20 | torchrl::DefineSumSegmentTree("Fp64", m); 21 | 22 | torchrl::DefineMinSegmentTree("Fp32", m); 23 | torchrl::DefineMinSegmentTree("Fp64", m); 24 | 25 | m.def("safetanh", &safetanh, "Safe Tanh"); 26 | m.def("safeatanh", &safeatanh, "Safe Inverse Tanh"); 27 | } 28 | -------------------------------------------------------------------------------- /torchrl/csrc/torch_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | 8 | #include 9 | 10 | #include 11 | 12 | namespace torchrl { 13 | namespace utils { 14 | 15 | template 16 | struct TorchDataType; 17 | 18 | template <> 19 | struct TorchDataType { 20 | static constexpr torch::ScalarType value = torch::kInt64; 21 | }; 22 | 23 | template <> 24 | struct TorchDataType { 25 | static constexpr torch::ScalarType value = torch::kFloat; 26 | }; 27 | 28 | template <> 29 | struct TorchDataType { 30 | static constexpr torch::ScalarType value = torch::kDouble; 31 | }; 32 | 33 | } // namespace utils 34 | } // namespace torchrl 35 | -------------------------------------------------------------------------------- /torchrl/csrc/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | // utils.h 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6); 13 | torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6); 14 | 15 | class SafeTanh : public torch::autograd::Function { 16 | public: 17 | static torch::Tensor forward(torch::autograd::AutogradContext* ctx, 18 | torch::Tensor input, float eps); 19 | static torch::autograd::tensor_list backward( 20 | torch::autograd::AutogradContext* ctx, 21 | torch::autograd::tensor_list grad_outputs); 22 | }; 23 | 24 | class SafeInvTanh : public torch::autograd::Function { 25 | public: 26 | static torch::Tensor forward(torch::autograd::AutogradContext* ctx, 27 | torch::Tensor input, float eps); 28 | static torch::autograd::tensor_list backward( 29 | torch::autograd::AutogradContext* ctx, 30 | torch::autograd::tensor_list grad_outputs); 31 | }; 32 | -------------------------------------------------------------------------------- /torchrl/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /torchrl/data/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from __future__ import annotations 6 | 7 | import os 8 | 9 | 10 | def _get_root_dir(dataset: str): 11 | return os.path.join(os.path.expanduser("~"), ".cache", "torchrl", dataset) 12 | -------------------------------------------------------------------------------- /torchrl/data/datasets/vd4rl_datasets.json: -------------------------------------------------------------------------------- 1 | ["distracting/walker_walk_random/64px/easy", "distracting/walker_walk_random/64px/hard", "distracting/walker_walk_random/64px/medium", "main/cheetah_run/expert/64px", "main/cheetah_run/medium/64px", "main/cheetah_run/medium_expert/64px", "main/cheetah_run/medium_replay/64px", "main/cheetah_run/random/64px", "main/humanoid_walk/expert/64px", "main/humanoid_walk/medium/64px", "main/humanoid_walk/medium_expert/64px", "main/humanoid_walk/medium_replay/64px", "main/humanoid_walk/random/64px", "main/walker_walk/expert/64px", "main/walker_walk/medium/64px", "main/walker_walk/medium_expert/64px", "main/walker_walk/medium_replay/64px", "main/walker_walk/random/64px", "multitask/cheetah_run_random/64px", "multitask/walker_walk_random/64px"] 2 | -------------------------------------------------------------------------------- /torchrl/data/llm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .chat import History 7 | from .common import LLMData 8 | from .dataset import ( 9 | create_infinite_iterator, 10 | get_dataloader, 11 | TensorDictTokenizer, 12 | TokenizedDatasetLoader, 13 | ) 14 | from .prompt import PromptData, PromptTensorDictTokenizer 15 | from .reward import PairwiseDataset, RewardData 16 | from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel 17 | 18 | __all__ = [ 19 | "AdaptiveKLController", 20 | "History", 21 | "ConstantKLController", 22 | "LLMData", 23 | "PairwiseDataset", 24 | "PromptData", 25 | "PromptTensorDictTokenizer", 26 | "RewardData", 27 | "RolloutFromModel", 28 | "TensorDictTokenizer", 29 | "TokenizedDatasetLoader", 30 | "create_infinite_iterator", 31 | "get_dataloader", 32 | ] 33 | -------------------------------------------------------------------------------- /torchrl/data/map/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hash import BinaryToDecimal, RandomProjectionHash, SipHash 7 | from .query import HashToInt, QueryModule 8 | from .tdstorage import TensorDictMap, TensorMap 9 | from .tree import MCTSForest, Tree 10 | 11 | __all__ = [ 12 | "BinaryToDecimal", 13 | "RandomProjectionHash", 14 | "SipHash", 15 | "HashToInt", 16 | "QueryModule", 17 | "TensorDictMap", 18 | "TensorMap", 19 | "MCTSForest", 20 | "Tree", 21 | ] 22 | -------------------------------------------------------------------------------- /torchrl/data/postprocs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .postprocs import DensifyReward, MultiStep 7 | 8 | __all__ = ["MultiStep", "DensifyReward"] 9 | -------------------------------------------------------------------------------- /torchrl/data/rlhf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import warnings 6 | 7 | from torchrl.data.llm import ( 8 | AdaptiveKLController, 9 | ConstantKLController, 10 | create_infinite_iterator, 11 | get_dataloader, 12 | PairwiseDataset, 13 | PromptData, 14 | PromptTensorDictTokenizer, 15 | RewardData, 16 | RolloutFromModel, 17 | TensorDictTokenizer, 18 | TokenizedDatasetLoader, 19 | ) 20 | 21 | __all__ = [ 22 | "create_infinite_iterator", 23 | "get_dataloader", 24 | "TensorDictTokenizer", 25 | "TokenizedDatasetLoader", 26 | "PromptData", 27 | "PromptTensorDictTokenizer", 28 | "PairwiseDataset", 29 | "RewardData", 30 | "AdaptiveKLController", 31 | "ConstantKLController", 32 | "RolloutFromModel", 33 | ] 34 | 35 | warnings.warn( 36 | "Imports from torchrl.data.rlhf have moved to torchrl.data.llm. " 37 | "torchrl.data.rlhf will be deprecated in v0.10.", 38 | category=DeprecationWarning, 39 | ) 40 | -------------------------------------------------------------------------------- /torchrl/envs/custom/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .chess import ChessEnv 7 | from .llm import LLMHashingEnv 8 | from .pendulum import PendulumEnv 9 | from .tictactoeenv import TicTacToeEnv 10 | 11 | __all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv"] 12 | -------------------------------------------------------------------------------- /torchrl/envs/llm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .chat import ChatEnv, DatasetChatEnv 7 | from .datasets import ( 8 | GSM8KEnv, 9 | GSM8KPrepareQuestion, 10 | IFEvalData, 11 | IFEvalEnv, 12 | make_gsm8k_env, 13 | ) 14 | from .envs import LLMEnv, LLMHashingEnv 15 | from .libs import make_mlgym, MLGymWrapper 16 | from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer 17 | from .transforms import ( 18 | as_nested_tensor, 19 | as_padded_tensor, 20 | DataLoadingPrimer, 21 | KLRewardTransform, 22 | TemplateTransform, 23 | Tokenizer, 24 | ) 25 | 26 | __all__ = [ 27 | "ChatEnv", 28 | "DatasetChatEnv", 29 | "GSM8KEnv", 30 | "make_gsm8k_env", 31 | "GSM8KPrepareQuestion", 32 | "IFEvalEnv", 33 | "IFEvalData", 34 | "LLMEnv", 35 | "LLMHashingEnv", 36 | "as_nested_tensor", 37 | "as_padded_tensor", 38 | "DataLoadingPrimer", 39 | "GSM8KRewardParser", 40 | "make_mlgym", 41 | "IFEvalScoreData", 42 | "MLGymWrapper", 43 | "KLRewardTransform", 44 | "TemplateTransform", 45 | "Tokenizer", 46 | "IfEvalScorer", 47 | ] 48 | -------------------------------------------------------------------------------- /torchrl/envs/llm/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | This folder contains utils for specific datasets, such as reward parsers or pre-build 4 | environments. 
5 | -------------------------------------------------------------------------------- /torchrl/envs/llm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env 8 | from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer 9 | 10 | __all__ = [ 11 | "make_gsm8k_env", 12 | "GSM8KPrepareQuestion", 13 | "GSM8KEnv", 14 | "IFEvalEnv", 15 | "IFEvalData", 16 | "IfEvalScorer", 17 | ] 18 | -------------------------------------------------------------------------------- /torchrl/envs/llm/libs/README.md: -------------------------------------------------------------------------------- 1 | ## Library wrappers 2 | 3 | This folder offers a list of wrappers for popular tooling libraries. 4 | -------------------------------------------------------------------------------- /torchrl/envs/llm/libs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .mlgym import make_mlgym, MLGymWrapper 7 | 8 | __all__ = ["make_mlgym", "MLGymWrapper"] 9 | -------------------------------------------------------------------------------- /torchrl/envs/llm/reward/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from __future__ import annotations 6 | 7 | from .gsm8k import GSM8KRewardParser 8 | from .ifeval import IFEvalScoreData, IfEvalScorer 9 | 10 | __all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData"] 11 | -------------------------------------------------------------------------------- /torchrl/envs/llm/reward/ifeval/README.md: -------------------------------------------------------------------------------- 1 | # Adapted Code from SkyThought 2 | 3 | This project includes code adapted from [SkyThought](https://github.com/NovaSky-AI/SkyThought), specifically the file 4 | [`ifeval_scorer.py`](https://github.com/NovaSky-AI/SkyThought/blob/2e5db2b26be63c5545d93be4ad08f5ca46449776/skythought/evals/scoring/ifeval/ifeval_scorer.py). 5 | 6 | Parts of these files are themselves copied from other sources with a similar license. 7 | 8 | The original code is distributed under the Apache 2.0 license, which can be found in the SkyThought repository: [Apache 2.0 License](https://github.com/NovaSky-AI/SkyThought/blob/main/LICENSE). 9 | 10 | ### Modifications 11 | 12 | Modifications were made to the original code according to the terms of the Apache 2.0 license. The changes include 13 | TorchRL formatting of the data using TensorDict and TorchRL's transforms. 14 | -------------------------------------------------------------------------------- /torchrl/envs/llm/reward/ifeval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from __future__ import annotations 7 | 8 | from ._scorer import IFEvalScoreData, IfEvalScorer 9 | 10 | __all__ = ["IfEvalScorer", "IFEvalScoreData"] 11 | -------------------------------------------------------------------------------- /torchrl/envs/llm/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .dataloading import as_nested_tensor, as_padded_tensor, DataLoadingPrimer 7 | from .format import TemplateTransform 8 | from .kl import KLRewardTransform 9 | from .tokenizer import Tokenizer 10 | 11 | __all__ = [ 12 | "DataLoadingPrimer", 13 | "Tokenizer", 14 | "TemplateTransform", 15 | "KLRewardTransform", 16 | "as_nested_tensor", 17 | "as_padded_tensor", 18 | ] 19 | -------------------------------------------------------------------------------- /torchrl/envs/model_based/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .common import ModelBasedEnvBase 7 | from .dreamer import DreamerDecoder, DreamerEnv 8 | 9 | __all__ = ["ModelBasedEnvBase", "DreamerDecoder", "DreamerEnv"] 10 | -------------------------------------------------------------------------------- /torchrl/envs/transforms/rlhf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import warnings 6 | 7 | from .llm import ( 8 | as_nested_tensor, 9 | as_padded_tensor, 10 | DataLoadingPrimer, 11 | KLRewardTransform, 12 | ) 13 | 14 | __all__ = [ 15 | "as_padded_tensor", 16 | "as_nested_tensor", 17 | "DataLoadingPrimer", 18 | "KLRewardTransform", 19 | ] 20 | 21 | warnings.warn( 22 | "Imports from torchrl.envs.transforms.rlhf have moved to torchrl.envs.transforms.llm. " 23 | "torchrl.envs.transforms.rlhf will be deprecated in v0.10.", 24 | category=DeprecationWarning, 25 | ) 26 | -------------------------------------------------------------------------------- /torchrl/envs/vec_envs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | import warnings 8 | 9 | warnings.warn("vec_env.py has moved to batch_envs.py.", category=DeprecationWarning) 10 | 11 | from .batched_envs import * # noqa: F403, F401 12 | -------------------------------------------------------------------------------- /torchrl/modules/llm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from __future__ import annotations 6 | 7 | from .backends import ( 8 | LLMOnDevice, 9 | make_vllm_worker, 10 | stateless_init_process_group, 11 | vLLMWorker, 12 | ) 13 | 14 | from .policies import CategoricalSequential, TransformersWrapper, vLLMWrapper 15 | 16 | __all__ = [ 17 | "CategoricalSequential", 18 | "LLMOnDevice", 19 | "TransformersWrapper", 20 | "make_vllm_worker", 21 | "stateless_init_process_group", 22 | "vLLMWorker", 23 | "vLLMWrapper", 24 | ] 25 | -------------------------------------------------------------------------------- /torchrl/modules/llm/backends/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | from .vllm import ( 8 | LLMOnDevice, 9 | make_vllm_worker, 10 | stateless_init_process_group, 11 | vLLMWorker, 12 | ) 13 | 14 | __all__ = [ 15 | "vLLMWorker", 16 | "stateless_init_process_group", 17 | "make_vllm_worker", 18 | "LLMOnDevice", 19 | ] 20 | -------------------------------------------------------------------------------- /torchrl/modules/llm/policies/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from __future__ import annotations 7 | 8 | from .common import CategoricalSequential 9 | from .transformers_wrapper import TransformersWrapper 10 | 11 | from .vllm_wrapper import vLLMWrapper 12 | 13 | __all__ = ["TransformersWrapper", "vLLMWrapper", "CategoricalSequential"] 14 | -------------------------------------------------------------------------------- /torchrl/modules/llm/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from __future__ import annotations 7 | 8 | import contextlib 9 | import os 10 | 11 | import torch 12 | 13 | 14 | @contextlib.contextmanager 15 | def _cuda_visible_devices(devices: list[torch.device | int]): 16 | devices = [torch.device(d).index if not isinstance(d, int) else d for d in devices] 17 | CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES") 18 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices)) 19 | yield 20 | if CUDA_VISIBLE_DEVICES: 21 | os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES 22 | else: 23 | os.unsetenv("CUDA_VISIBLE_DEVICES") 24 | -------------------------------------------------------------------------------- /torchrl/modules/models/rlhf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import warnings 6 | 7 | from .llm import GPT2RewardModel 8 | 9 | __all__ = ["GPT2RewardModel"] 10 | 11 | warnings.warn( 12 | "Imports from torchrl.modules.models.rlhf have moved to torchrl.modules.models.llm. 
" 13 | "torchrl.modules.models.rlhf will be deprecated in v0.10.", 14 | category=DeprecationWarning, 15 | ) 16 | -------------------------------------------------------------------------------- /torchrl/modules/planners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .cem import CEMPlanner 7 | from .common import MPCPlannerBase 8 | from .mppi import MPPIPlanner 9 | 10 | __all__ = ["CEMPlanner", "MPCPlannerBase", "MPPIPlanner"] 11 | -------------------------------------------------------------------------------- /torchrl/modules/tensordict_module/world_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | from tensordict.nn import TensorDictModule, TensorDictSequential 8 | 9 | 10 | class WorldModelWrapper(TensorDictSequential): 11 | """World model wrapper. 12 | 13 | This module wraps together a transition model and a reward model. 14 | The transition model is used to predict an imaginary world state. 15 | The reward model is used to predict the reward of the imagined transition. 16 | 17 | Args: 18 | transition_model (TensorDictModule): a transition model that generates a new world states. 19 | reward_model (TensorDictModule): a reward model, that reads the world state and returns a reward. 
20 | 21 | """ 22 | 23 | def __init__( 24 | self, transition_model: TensorDictModule, reward_model: TensorDictModule 25 | ): 26 | super().__init__(transition_model, reward_model) 27 | 28 | def get_transition_model_operator(self) -> TensorDictModule: 29 | """Returns a transition operator that maps either an observation to a world state or a world state to the next world state.""" 30 | return self.module[0] 31 | 32 | def get_reward_operator(self) -> TensorDictModule: 33 | """Returns a reward operator that maps a world state to a reward.""" 34 | return self.module[1] 35 | -------------------------------------------------------------------------------- /torchrl/modules/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from collections import OrderedDict 7 | 8 | import torch 9 | from packaging import version 10 | 11 | 12 | if version.parse(torch.__version__) >= version.parse("1.12.0"): 13 | from torch.nn.parameter import _ParameterMeta 14 | else: 15 | pass 16 | 17 | # Metaclass to combine _TensorMeta and the instance check override for Parameter. 18 | class _ParameterMeta(torch._C._TensorMeta): 19 | # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. 
20 | def __instancecheck__(self, instance): 21 | return super().__instancecheck__(instance) or ( 22 | isinstance(instance, torch.Tensor) 23 | and getattr(instance, "_is_param", False) 24 | ) 25 | 26 | 27 | from .mappings import biased_softplus, inv_softplus, mappings 28 | from .utils import get_primers_from_module 29 | 30 | __all__ = [ 31 | "OrderedDict", 32 | "torch", 33 | "version", 34 | "biased_softplus", 35 | "inv_softplus", 36 | "mappings", 37 | "get_primers_from_module", 38 | ] 39 | -------------------------------------------------------------------------------- /torchrl/modules/utils/mappings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | from tensordict.nn.utils import biased_softplus, expln, inv_softplus, mappings 8 | 9 | __all__ = ["biased_softplus", "expln", "inv_softplus", "mappings"] 10 | -------------------------------------------------------------------------------- /torchrl/objectives/llm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage 8 | 9 | __all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage"] 10 | -------------------------------------------------------------------------------- /torchrl/objectives/multiagent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .qmixer import QMixerLoss 7 | 8 | __all__ = ["QMixerLoss"] 9 | -------------------------------------------------------------------------------- /torchrl/objectives/value/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .advantages import ( 7 | GAE, 8 | TD0Estimate, 9 | TD0Estimator, 10 | TD1Estimate, 11 | TD1Estimator, 12 | TDLambdaEstimate, 13 | TDLambdaEstimator, 14 | ValueEstimatorBase, 15 | VTrace, 16 | ) 17 | 18 | __all__ = [ 19 | "GAE", 20 | "TD0Estimate", 21 | "TD0Estimator", 22 | "TD1Estimate", 23 | "TD1Estimator", 24 | "TDLambdaEstimate", 25 | "TDLambdaEstimator", 26 | "ValueEstimatorBase", 27 | "VTrace", 28 | ] 29 | -------------------------------------------------------------------------------- /torchrl/record/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger 7 | from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder 8 | 9 | __all__ = [ 10 | "CSVLogger", 11 | "MLFlowLogger", 12 | "TensorboardLogger", 13 | "WandbLogger", 14 | "PixelRenderTransform", 15 | "TensorDictRecorder", 16 | "VideoRecorder", 17 | ] 18 | -------------------------------------------------------------------------------- /torchrl/record/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .common import Logger 7 | 8 | from .csv import CSVLogger 9 | from .mlflow import MLFlowLogger 10 | from .tensorboard import TensorboardLogger 11 | from .utils import generate_exp_name, get_logger 12 | 13 | from .wandb import WandbLogger 14 | 15 | __all__ = [ 16 | "Logger", 17 | "CSVLogger", 18 | "MLFlowLogger", 19 | "TensorboardLogger", 20 | "generate_exp_name", 21 | "get_logger", 22 | "WandbLogger", 23 | ] 24 | -------------------------------------------------------------------------------- /torchrl/record/loggers/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from __future__ import annotations 6 | 7 | import abc 8 | from typing import Sequence 9 | 10 | from torch import Tensor 11 | 12 | 13 | __all__ = ["Logger"] 14 | 15 | 16 | class Logger: 17 | """A template for loggers.""" 18 | 19 | def __init__(self, exp_name: str, log_dir: str) -> None: 20 | self.exp_name = exp_name 21 | self.log_dir = log_dir 22 | self.experiment = self._create_experiment() 23 | 24 | @abc.abstractmethod 25 | def _create_experiment(self) -> Experiment: # noqa: F821 26 | ... 27 | 28 | @abc.abstractmethod 29 | def log_scalar(self, name: str, value: float, step: int = None) -> None: 30 | ... 31 | 32 | @abc.abstractmethod 33 | def log_video(self, name: str, video: Tensor, step: int = None, **kwargs) -> None: 34 | ... 35 | 36 | @abc.abstractmethod 37 | def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821 38 | ... 39 | 40 | @abc.abstractmethod 41 | def __repr__(self) -> str: 42 | ... 43 | 44 | @abc.abstractmethod 45 | def log_histogram(self, name: str, data: Sequence, **kwargs): 46 | ... 47 | -------------------------------------------------------------------------------- /torchrl/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .trainers import ( 7 | BatchSubSampler, 8 | ClearCudaCache, 9 | CountFramesLog, 10 | LogReward, 11 | LogScalar, 12 | LogValidationReward, 13 | mask_batch, 14 | OptimizerHook, 15 | Recorder, 16 | ReplayBufferTrainer, 17 | RewardNormalizer, 18 | SelectKeys, 19 | Trainer, 20 | TrainerHookBase, 21 | UpdateWeights, 22 | ) 23 | 24 | __all__ = [ 25 | "BatchSubSampler", 26 | "ClearCudaCache", 27 | "CountFramesLog", 28 | "LogReward", 29 | "LogScalar", 30 | "LogValidationReward", 31 | "mask_batch", 32 | "OptimizerHook", 33 | "Recorder", 34 | "ReplayBufferTrainer", 35 | "RewardNormalizer", 36 | "SelectKeys", 37 | "Trainer", 38 | "TrainerHookBase", 39 | "UpdateWeights", 40 | ] 41 | -------------------------------------------------------------------------------- /torchrl/trainers/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from .collectors import ( 7 | make_collector_offpolicy, 8 | make_collector_onpolicy, 9 | sync_async_collector, 10 | sync_sync_collector, 11 | ) 12 | from .envs import ( 13 | correct_for_frame_skip, 14 | get_stats_random_rollout, 15 | parallel_env_constructor, 16 | transformed_env_constructor, 17 | ) 18 | from .logger import LoggerConfig 19 | from .losses import make_dqn_loss, make_target_updater 20 | from .models import make_dqn_actor, make_dreamer 21 | from .replay_buffer import make_replay_buffer 22 | from .trainers import make_trainer 23 | 24 | __all__ = [ 25 | "make_collector_offpolicy", 26 | "make_collector_onpolicy", 27 | "sync_async_collector", 28 | "sync_sync_collector", 29 | "correct_for_frame_skip", 30 | "get_stats_random_rollout", 31 | "parallel_env_constructor", 32 | "transformed_env_constructor", 33 | "LoggerConfig", 34 | "make_dqn_loss", 35 | "make_target_updater", 36 | "make_dqn_actor", 37 | "make_dreamer", 38 | "make_replay_buffer", 39 | "make_trainer", 40 | ] 41 | -------------------------------------------------------------------------------- /torchrl/trainers/helpers/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from __future__ import annotations 6 | 7 | from dataclasses import dataclass, field 8 | from typing import Any 9 | 10 | 11 | @dataclass 12 | class LoggerConfig: 13 | """Logger config data-class.""" 14 | 15 | logger: str = "csv" 16 | # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv' 17 | record_video: bool = False 18 | # whether a video of the task should be rendered during logging. 19 | no_video: bool = True 20 | # whether a video of the task should be rendered during logging. 21 | exp_name: str = "" 22 | # experiment name. Used for logging directory. 
23 | # A date and uuid will be joined to account for multiple experiments with the same name. 24 | record_interval: int = 1000 25 | # number of batch collections in between two collections of validation rollouts. Default=1000. 26 | record_frames: int = 1000 27 | # number of steps in validation rollouts. " "Default=1000. 28 | recorder_log_keys: Any = field(default_factory=lambda: None) 29 | # Keys to log in the recorder 30 | offline_logging: bool = True 31 | # If True, Wandb will do the logging offline 32 | project_name: str = "" 33 | # The name of the project for WandB 34 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | Get a sense of TorchRL functionalities through our [tutorials](https://pytorch.org/rl/stable/tutorials). 4 | 5 | The ["Getting Started"](https://pytorch.org/rl/stable/index.html#getting-started) section will help you model your first training loop with the library! 6 | 7 | The rest of the tutorials is split in [Basic](https://pytorch.org/rl/stable/index.html#basics), [Intermediate](https://pytorch.org/rl/stable/index.html#intermediate) and [Advanced](https://pytorch.org/rl/stable/index.html#advanced) sections. 
8 | -------------------------------------------------------------------------------- /tutorials/media/transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rl/ba0faef45e7598e138b286387d560a9b2559ef5e/tutorials/media/transformer.png -------------------------------------------------------------------------------- /tutorials/sphinx-tutorials/README.rst: -------------------------------------------------------------------------------- 1 | README Tutos 2 | ============ 3 | 4 | Check the tutorials on torchrl documentation: https://pytorch.org/rl 5 | -------------------------------------------------------------------------------- /tutorials/sphinx-tutorials/run_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -v 5 | 6 | # Allows you to run all the tutorials without building the docset. 7 | 8 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 9 | 10 | # loop through all the .py files in the directory 11 | for file in $(ls -r "$DIR"/*.py) 12 | do 13 | # execute each Python script using the 'exec' function 14 | echo $file 15 | python -c """ 16 | with open('$file') as f: 17 | source = f.read() 18 | code = compile(source, '$file', 'exec') 19 | 20 | exec(code) 21 | """ 22 | done 23 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.9.0 2 | --------------------------------------------------------------------------------