├── .dockerignore ├── .github ├── FUNDING.yml ├── issue_template.md ├── pull_request_template.md └── workflows │ ├── constraints.txt │ ├── poetry-lock-export-ubuntu.yaml │ ├── poetry-lock-export.yaml │ ├── pre-commit.yml │ ├── tests.yaml │ └── utils_test.yaml ├── .gitignore ├── .gitpod.Dockerfile ├── .gitpod.yml ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── benchmark ├── c51.sh ├── cleanrl_1gpu.slurm_template ├── ddpg.sh ├── ddpg_plot.sh ├── dqn.sh ├── ppg.sh ├── ppo.sh ├── ppo_plot.sh ├── ppo_trxl.sh ├── pqn.sh ├── pqn_plot.sh ├── qdagger.sh ├── rainbow.sh ├── rnd.sh ├── rpo.sh ├── sac.sh ├── sac_atari.sh ├── sac_plot.sh ├── td3.sh ├── td3_plot.sh └── zoo.sh ├── cleanrl ├── c51.py ├── c51_atari.py ├── c51_atari_jax.py ├── c51_jax.py ├── ddpg_continuous_action.py ├── ddpg_continuous_action_jax.py ├── dqn.py ├── dqn_atari.py ├── dqn_atari_jax.py ├── dqn_jax.py ├── ppg_procgen.py ├── ppo.py ├── ppo_atari.py ├── ppo_atari_envpool.py ├── ppo_atari_envpool_xla_jax.py ├── ppo_atari_envpool_xla_jax_scan.py ├── ppo_atari_lstm.py ├── ppo_atari_multigpu.py ├── ppo_continuous_action.py ├── ppo_continuous_action_isaacgym │ ├── isaacgym │ │ ├── poetry.lock │ │ └── pyproject.toml │ └── ppo_continuous_action_isaacgym.py ├── ppo_pettingzoo_ma_atari.py ├── ppo_procgen.py ├── ppo_rnd_envpool.py ├── ppo_trxl │ ├── enjoy.py │ ├── poetry.lock │ ├── pom_env.py │ ├── ppo_trxl.py │ └── pyproject.toml ├── pqn.py ├── pqn_atari_envpool.py ├── pqn_atari_envpool_lstm.py ├── qdagger_dqn_atari_impalacnn.py ├── qdagger_dqn_atari_jax_impalacnn.py ├── rainbow_atari.py ├── rpo_continuous_action.py ├── sac_atari.py ├── sac_continuous_action.py ├── td3_continuous_action.py └── td3_continuous_action_jax.py ├── cleanrl_utils ├── __init__.py ├── add_header.py ├── benchmark.py ├── buffers.py ├── docker_build.py ├── docker_queue.py ├── enjoy.py ├── evals │ ├── __init__.py │ ├── c51_eval.py │ ├── c51_jax_eval.py │ ├── ddpg_eval.py │ ├── ddpg_jax_eval.py │ ├── 
dqn_eval.py │ ├── dqn_jax_eval.py │ ├── ppo_envpool_jax_eval.py │ ├── ppo_eval.py │ ├── td3_eval.py │ └── td3_jax_eval.py ├── huggingface.py ├── paper_plot.py ├── plot.py ├── plot_individual.py ├── reproduce.py ├── resume.py ├── submit_exp.py └── tuner.py ├── cloud ├── .gitignore ├── examples │ ├── submit_exp.sh │ └── terminate_all.sh ├── main.tf └── modules │ └── cleanrl │ ├── main.tf │ ├── setups.tf │ └── variables.tf ├── docs ├── CNAME ├── advanced │ ├── hyperparameter-tuning.md │ ├── optuna-dashboard-1.png │ ├── optuna-dashboard-2.png │ ├── optuna-results.png │ └── resume-training.md ├── benchmark │ ├── ddpg.md │ ├── ppo.md │ ├── ppo_atari.md │ ├── ppo_atari_envpool.md │ ├── ppo_atari_envpool_runtimes.md │ ├── ppo_atari_envpool_xla_jax.md │ ├── ppo_atari_envpool_xla_jax_runtimes.md │ ├── ppo_atari_envpool_xla_jax_scan.md │ ├── ppo_atari_envpool_xla_jax_scan_runtimes.md │ ├── ppo_atari_lstm.md │ ├── ppo_atari_lstm_runtimes.md │ ├── ppo_atari_multigpu.md │ ├── ppo_atari_multigpu_runtimes.md │ ├── ppo_atari_runtimes.md │ ├── ppo_continuous_action.md │ ├── ppo_continuous_action_runtimes.md │ ├── ppo_envpool.md │ ├── ppo_envpool_runtimes.md │ ├── ppo_procgen.md │ ├── ppo_procgen_runtimes.md │ ├── ppo_runtimes.md │ ├── sac.md │ ├── sac_runtimes.md │ ├── td3.md │ └── td3_runtimes.md ├── blog │ ├── .authors.yml │ ├── .meta.yml │ ├── index.md │ └── posts │ │ └── cleanrl-v1.md ├── cleanrl-supported-papers-projects.md ├── cloud │ ├── aws_batch1.png │ ├── aws_batch2.png │ ├── installation.md │ ├── submit-experiments.md │ └── wandb.png ├── contribution.md ├── css │ ├── custom.css │ └── termynal.css ├── get-started │ ├── CleanRL_Huggingface_Integration_Demo.ipynb │ ├── basic-usage.md │ ├── benchmark-utility.md │ ├── colab-badge.svg │ ├── examples.md │ ├── experiment-tracking.md │ ├── installation.md │ ├── tensorboard.png │ ├── videos.png │ ├── videos2.png │ └── zoo.md ├── index.md ├── js │ ├── chat.js │ ├── custom.js │ └── termynal.js ├── rl-algorithms │ ├── c51.md │ ├── c51 
│ │ ├── Acrobot-v1.png │ │ ├── BeamRiderNoFrameskip-v4.png │ │ ├── BreakoutNoFrameskip-v4.png │ │ ├── CartPole-v1.png │ │ ├── MountainCar-v0.png │ │ ├── PongNoFrameskip-v4.png │ │ └── jax │ │ │ ├── Acrobot-v1-time.png │ │ │ ├── Acrobot-v1.png │ │ │ ├── BeamRiderNoFrameskip-v4-time.png │ │ │ ├── BeamRiderNoFrameskip-v4.png │ │ │ ├── BreakoutNoFrameskip-v4-time.png │ │ │ ├── BreakoutNoFrameskip-v4.png │ │ │ ├── CartPole-v1-time.png │ │ │ ├── CartPole-v1.png │ │ │ ├── MountainCar-v0-time.png │ │ │ ├── MountainCar-v0.png │ │ │ ├── PongNoFrameskip-v4-time.png │ │ │ └── PongNoFrameskip-v4.png │ ├── ddpg-jax │ │ ├── HalfCheetah-v2-time.png │ │ ├── HalfCheetah-v2.png │ │ ├── Hopper-v2-time.png │ │ ├── Hopper-v2.png │ │ ├── Walker2d-v2-time.png │ │ └── Walker2d-v2.png │ ├── ddpg.md │ ├── ddpg │ │ └── ddpg.png │ ├── dqn.md │ ├── dqn │ │ ├── Acrobot-v1.png │ │ ├── BeamRiderNoFrameskip-v4.png │ │ ├── BreakoutNoFrameskip-v4.png │ │ ├── CartPole-v1.png │ │ ├── MountainCar-v0.png │ │ ├── PongNoFrameskip-v4.png │ │ └── jax │ │ │ ├── Acrobot-v1-time.png │ │ │ ├── Acrobot-v1.png │ │ │ ├── BeamRiderNoFrameskip-v4-time.png │ │ │ ├── BeamRiderNoFrameskip-v4.png │ │ │ ├── BreakoutNoFrameskip-v4-time.png │ │ │ ├── BreakoutNoFrameskip-v4.png │ │ │ ├── CartPole-v1-time.png │ │ │ ├── CartPole-v1.png │ │ │ ├── MountainCar-v0-time.png │ │ │ ├── MountainCar-v0.png │ │ │ ├── PongNoFrameskip-v4-time.png │ │ │ └── PongNoFrameskip-v4.png │ ├── overview.md │ ├── ppg.md │ ├── ppg │ │ ├── BigFish.png │ │ ├── BossFight.png │ │ ├── StarPilot.png │ │ ├── comparison │ │ │ ├── BigFish.png │ │ │ ├── BossFight.png │ │ │ └── StarPilot.png │ │ └── ppg-ppo.png │ ├── ppo-isaacgymenvs.md │ ├── ppo-rnd.md │ ├── ppo-rnd │ │ ├── MontezumaRevenge-v5-time.png │ │ └── MontezumaRevenge-v5.png │ ├── ppo-trxl.md │ ├── ppo-trxl │ │ └── compare.png │ ├── ppo.md │ ├── ppo │ │ ├── Acrobot-v1.png │ │ ├── BeamRider-time.png │ │ ├── BeamRider.png │ │ ├── BeamRiderNoFrameskip-v4.png │ │ ├── 
BeamRiderNoFrameskip-v4multigpu-time.png │ │ ├── BeamRiderNoFrameskip-v4multigpu.png │ │ ├── BigFish.png │ │ ├── BossFight.png │ │ ├── Breakout-a.png │ │ ├── Breakout-time-a.png │ │ ├── Breakout-time.png │ │ ├── Breakout.png │ │ ├── BreakoutNoFrameskip-v4.png │ │ ├── BreakoutNoFrameskip-v4multigpu-time.png │ │ ├── BreakoutNoFrameskip-v4multigpu.png │ │ ├── CartPole-v1.png │ │ ├── HalfCheetah-v2.png │ │ ├── Hopper-v2.png │ │ ├── MountainCar-v0.png │ │ ├── Pong-time.png │ │ ├── Pong.png │ │ ├── PongNoFrameskip-v4.png │ │ ├── PongNoFrameskip-v4multigpu-time.png │ │ ├── PongNoFrameskip-v4multigpu.png │ │ ├── StarPilot.png │ │ ├── Walker2d-v2.png │ │ ├── isaacgymenvs │ │ │ ├── AllegroHand-c-time.png │ │ │ ├── AllegroHand-c.png │ │ │ ├── AllegroHand-time.png │ │ │ ├── AllegroHand.png │ │ │ ├── Ant-time.png │ │ │ ├── Ant.png │ │ │ ├── Anymal-time.png │ │ │ ├── Anymal.png │ │ │ ├── BallBalance-time.png │ │ │ ├── BallBalance.png │ │ │ ├── Cartpole-time.png │ │ │ ├── Cartpole.png │ │ │ ├── Humanoid-time.png │ │ │ ├── Humanoid.png │ │ │ ├── ShadowHand-c-time.png │ │ │ ├── ShadowHand-c.png │ │ │ ├── ShadowHand-time.png │ │ │ ├── ShadowHand.png │ │ │ └── old │ │ │ │ ├── AllegroHand.png │ │ │ │ ├── Ant.png │ │ │ │ ├── Anymal.png │ │ │ │ ├── BallBalance.png │ │ │ │ ├── Cartpole.png │ │ │ │ ├── Humanoid.png │ │ │ │ └── ShadowHand.png │ │ ├── lstm │ │ │ ├── BeamRiderNoFrameskip-v4.png │ │ │ ├── BreakoutNoFrameskip-v4.png │ │ │ └── PongNoFrameskip-v4.png │ │ ├── pong_v3.png │ │ ├── ppo-1-title.png │ │ ├── ppo-2-title.png │ │ ├── ppo-3-title.png │ │ ├── ppo_atari_envpool_xla_jax │ │ │ ├── atari_hns.md │ │ │ ├── atari_returns.md │ │ │ ├── hms_each_game.png │ │ │ ├── hms_each_game.svg │ │ │ ├── hns_ppo_vs_baselines.png │ │ │ ├── hns_ppo_vs_baselines.svg │ │ │ ├── hns_ppo_vs_baselines2.svg │ │ │ ├── hns_ppo_vs_r2d2.png │ │ │ ├── hns_ppo_vs_r2d2.svg │ │ │ ├── runset_0_hms_bar.png │ │ │ ├── runset_0_hms_bar.svg │ │ │ ├── runset_1_hms_bar.png │ │ │ └── runset_1_hms_bar.svg │ │ ├── 
ppo_atari_envpool_xla_jax_scan │ │ │ ├── compare-time.png │ │ │ └── compare.png │ │ ├── ppo_continuous_action_gymnasium_dm_control.png │ │ ├── ppo_continuous_action_gymnasium_mujoco_v2.png │ │ ├── ppo_continuous_action_gymnasium_mujoco_v4.png │ │ ├── ppo_continuous_action_v2_vs_v4.png │ │ ├── surround_v2.png │ │ └── tennis_v3.png │ ├── pqn.md │ ├── pqn │ │ ├── pqn.png │ │ ├── pqn_lstm.png │ │ └── pqn_state.png │ ├── qdagger.md │ ├── qdagger │ │ ├── BeamRiderNoFrameskip-v4.png │ │ ├── BreakoutNoFrameskip-v4.png │ │ ├── PongNoFrameskip-v4.png │ │ ├── compare.png │ │ └── jax │ │ │ ├── BeamRiderNoFrameskip-v4.png │ │ │ ├── BreakoutNoFrameskip-v4.png │ │ │ ├── PongNoFrameskip-v4.png │ │ │ └── compare.png │ ├── rainbow.md │ ├── rainbow │ │ ├── rainbow_c51_dqn_bars.png │ │ ├── rainbow_env_curves.png │ │ └── rainbow_sample_eff.png │ ├── rpo.md │ ├── rpo │ │ ├── dm_control_all_ppo_rpo_8M.png │ │ ├── gym.png │ │ ├── mujoco_v2_failure_0_5.png │ │ ├── mujoco_v2_part1.png │ │ ├── mujoco_v2_part2.png │ │ ├── mujoco_v2_part2_0_5.png │ │ ├── mujoco_v4_failure_0_5.png │ │ ├── mujoco_v4_part1.png │ │ ├── mujoco_v4_part2.png │ │ └── mujoco_v4_part2_0_5.png │ ├── sac.md │ ├── sac │ │ ├── BeamRiderNoFrameskip-v4.png │ │ ├── BreakoutNoFrameskip-v4.png │ │ ├── HalfCheetah-v2.png │ │ ├── Hopper-v2.png │ │ ├── PongNoFrameskip-v4.png │ │ └── Walker2d-v2.png │ ├── td3-jax │ │ ├── HalfCheetah-v2-time.png │ │ ├── HalfCheetah-v2.png │ │ ├── Hopper-v2-time.png │ │ ├── Hopper-v2.png │ │ ├── Walker2d-v2-time.png │ │ └── Walker2d-v2.png │ ├── td3.md │ └── td3 │ │ ├── HalfCheetah-v2.png │ │ ├── Hopper-v2.png │ │ ├── Humanoid-v2.png │ │ ├── InvertedPendulum-v2.png │ │ ├── Pusher-v2.png │ │ └── Walker2d-v2.png ├── rlops │ ├── docs-update.png │ ├── rlops.png │ └── tags.png ├── static │ ├── blog │ │ └── cleanrl-v1 │ │ │ ├── github-action.png │ │ │ ├── hf.png │ │ │ └── rlops.png │ ├── o1.png │ ├── o2.png │ ├── o3.png │ └── pre-commit.png └── stylesheets │ └── extra.css ├── entrypoint.sh ├── mkdocs.yml 
├── poetry.lock ├── pyproject.toml ├── requirements ├── requirements-atari.txt ├── requirements-cloud.txt ├── requirements-dm_control.txt ├── requirements-docs.txt ├── requirements-envpool.txt ├── requirements-jax.txt ├── requirements-memory_gym.txt ├── requirements-mujoco.txt ├── requirements-optuna.txt ├── requirements-pettingzoo.txt ├── requirements-procgen.txt └── requirements.txt ├── tests ├── test_atari.py ├── test_atari_gymnasium.py ├── test_atari_jax_gymnasium.py ├── test_atari_multigpu.py ├── test_classic_control.py ├── test_classic_control_gymnasium.py ├── test_classic_control_jax_gymnasium.py ├── test_enjoy.py ├── test_envpool.py ├── test_jax_compute_gae.py ├── test_mujoco.py ├── test_pettingzoo_ma_atari.py ├── test_procgen.py ├── test_tuner.py └── test_utils.py └── tuner_example.py /.dockerignore: -------------------------------------------------------------------------------- 1 | *.mp4 2 | *.pyc 3 | *.pyo 4 | *.log 5 | *.json 6 | **/wandb 7 | **/runs 8 | **/videos 9 | .git 10 | *.tfevents.* -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [vwxyzjn]# Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: ['https://www.buymeacoffee.com/dosssman']# Replace with up to 4 custom sponsorship URLs e.g., 
['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/issue_template.md: -------------------------------------------------------------------------------- 1 | ## Problem Description 2 | 3 | 4 | ## Checklist 5 | - [ ] I have installed dependencies via `poetry install` (see [CleanRL's installation guideline](https://docs.cleanrl.dev/get-started/installation/). 6 | - [ ] I have checked that there is no similar [issue](https://github.com/vwxyzjn/cleanrl/issues) in the repo. 7 | - [ ] I have checked the [documentation site](https://docs.cleanrl.dev/) and found not relevant information in [GitHub issues](https://github.com/vwxyzjn/cleanrl/issues). 8 | 9 | ## Current Behavior 10 | 11 | 12 | ## Expected Behavior 13 | 14 | 15 | ## Possible Solution 16 | 17 | 18 | ## Steps to Reproduce 19 | 20 | 21 | 1. 22 | 2. 23 | 3. 24 | 4. 25 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## Types of changes 5 | 6 | - [ ] Bug fix 7 | - [ ] New feature 8 | - [ ] New algorithm 9 | - [ ] Documentation 10 | 11 | ## Checklist: 12 | 13 | 14 | - [ ] I've read the [CONTRIBUTION](https://docs.cleanrl.dev/contribution/) guide (**required**). 15 | - [ ] I have ensured `pre-commit run --all-files` passes (**required**). 16 | - [ ] I have updated the tests accordingly (if applicable). 17 | - [ ] I have updated the documentation and previewed the changes via `mkdocs serve`. 18 | - [ ] I have explained note-worthy implementation details. 19 | - [ ] I have explained the logged metrics. 20 | - [ ] I have added links to the original paper and related papers. 
21 | 22 | If you need to run benchmark experiments for a performance-impacting changes: 23 | 24 | - [ ] I have contacted @vwxyzjn to obtain access to the [openrlbenchmark W&B team](https://wandb.ai/openrlbenchmark). 25 | - [ ] I have used the [benchmark utility](/get-started/benchmark-utility/) to submit the tracked experiments to the [openrlbenchmark/cleanrl](https://wandb.ai/openrlbenchmark/cleanrl) W&B project, optionally with `--capture_video`. 26 | - [ ] I have performed RLops with `python -m openrlbenchmark.rlops`. 27 | - For new feature or bug fix: 28 | - [ ] I have used the RLops utility to understand the performance impact of the changes and confirmed there is no regression. 29 | - For new algorithm: 30 | - [ ] I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation). 31 | - [ ] I have added the learning curves generated by the `python -m openrlbenchmark.rlops` utility to the documentation. 32 | - [ ] I have added links to the tracked experiments in W&B, generated by `python -m openrlbenchmark.rlops ....your_args... --report`, to the documentation. 
33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/constraints.txt: -------------------------------------------------------------------------------- 1 | pip==22.0.3 2 | poetry==1.1.13 3 | virtualenv==20.13.1 -------------------------------------------------------------------------------- /.github/workflows/poetry-lock-export-ubuntu.yaml: -------------------------------------------------------------------------------- 1 | name: Poetry lock and export (Ubuntu host) 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | workflow_dispatch: 7 | 8 | concurrency: 9 | group: ${{ github.head_ref || github.run_id }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | create-ubuntu-lock-file: 14 | runs-on: ubuntu-latest 15 | timeout-minutes: 30 16 | strategy: 17 | matrix: 18 | python-version: [3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Create virtualenv 29 | run: | 30 | which python 31 | python -m venv venv 32 | source venv/bin/activate 33 | python -m pip install --constraint=.github/workflows/constraints.txt --upgrade pip 34 | which python 35 | 36 | - name: Install Poetry 37 | run: | 38 | curl -sSL https://install.python-poetry.org | python3 - 39 | export PATH="/Users/runner/.local/bin:$PATH" 40 | poetry --version 41 | poetry config virtualenvs.in-project true 42 | poetry config virtualenvs.create false 43 | poetry config virtualenvs.path venv 44 | source venv/bin/activate 45 | which python 46 | 47 | - name: Run Poetry update 48 | run: | 49 | source venv/bin/activate 50 | export PATH="/Users/runner/.local/bin:$PATH" 51 | rm poetry.lock 52 | poetry update isaacgym 53 | poetry lock --no-update 54 | poetry export -f requirements.txt --output requirements.txt 55 | 56 | - name: Create Pull Request 57 | uses: peter-evans/create-pull-request@v3 58 | 
with: 59 | commit-message: Update poetry.lock and requirements.txt 60 | title: Update poetry.lock and requirements.txt 61 | body: Update poetry.lock and requirements.txt 62 | branch: poetry-lock-and-export -------------------------------------------------------------------------------- /.github/workflows/poetry-lock-export.yaml: -------------------------------------------------------------------------------- 1 | name: Poetry lock and export (MacOs host) 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | workflow_dispatch: 7 | 8 | concurrency: 9 | group: ${{ github.head_ref || github.run_id }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | create-macos-lock-file: 14 | runs-on: macos-latest 15 | timeout-minutes: 30 16 | strategy: 17 | matrix: 18 | python-version: [3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Create virtualenv 29 | run: | 30 | which python 31 | python -m venv venv 32 | source venv/bin/activate 33 | python -m pip install --constraint=.github/workflows/constraints.txt --upgrade pip 34 | which python 35 | 36 | - name: Install Poetry 37 | run: | 38 | curl -sSL https://install.python-poetry.org | python3 - 39 | export PATH="/Users/runner/.local/bin:$PATH" 40 | poetry --version 41 | poetry config virtualenvs.in-project true 42 | poetry config virtualenvs.create false 43 | poetry config virtualenvs.path venv 44 | source venv/bin/activate 45 | which python 46 | 47 | - name: Run Poetry update 48 | run: | 49 | source venv/bin/activate 50 | export PATH="/Users/runner/.local/bin:$PATH" 51 | rm poetry.lock 52 | poetry update isaacgym 53 | poetry lock --no-update 54 | poetry export -f requirements.txt --output requirements.txt 55 | 56 | - name: Create Pull Request 57 | uses: peter-evans/create-pull-request@v3 58 | with: 59 | commit-message: Update poetry.lock and requirements.txt 60 
| title: Update poetry.lock and requirements.txt 61 | body: Update poetry.lock and requirements.txt 62 | branch: poetry-lock-and-export -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | paths-ignore: 6 | - 'docs/blog/*' # dummy ignore 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: [3.9] 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | with: 17 | fetch-depth: 0 18 | submodules: recursive 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - uses: pre-commit/action@v2.0.3 24 | with: 25 | extra_args: --hook-stage manual --all-files 26 | -------------------------------------------------------------------------------- /.github/workflows/utils_test.yaml: -------------------------------------------------------------------------------- 1 | name: utils_test 2 | on: 3 | pull_request: 4 | paths-ignore: 5 | - '**/README.md' 6 | - 'docs/**/*' 7 | - 'cloud/**/*' 8 | jobs: 9 | ci: 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | python-version: ["3.8", "3.9", "3.10"] 14 | poetry-version: ["1.7"] 15 | os: [ubuntu-22.04] 16 | runs-on: ${{ matrix.os }} 17 | steps: 18 | - uses: actions/checkout@v2 19 | - uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Run image 23 | uses: abatilo/actions-poetry@v2.0.0 24 | with: 25 | poetry-version: ${{ matrix.poetry-version }} 26 | 27 | - name: Install test dependencies 28 | run: poetry install -E pytest 29 | - name: Install cloud dependencies 30 | run: poetry install -E "pytest cloud" 31 | - name: Downgrade setuptools 32 | run: poetry run pip install setuptools==59.5.0 33 | - name: Run utils tests 34 | run: poetry run pytest 
tests/test_utils.py 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | slurm 2 | .aim 3 | runs 4 | balance_bot.xml 5 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/examples 6 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/isaacgym 7 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/LICENSE.txt 8 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/rlgpu_conda_env.yml 9 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/setup.py 10 | 11 | IsaacGym_Preview_3_Package.tar.gz 12 | IsaacGym_Preview_4_Package.tar.gz 13 | cleanrl_hpopt.db 14 | debug.sh.docker.sh 15 | docker_cache 16 | rl-video-*.mp4 17 | rl-video-*.json 18 | cleanrl_utils/charts_episode_reward 19 | tutorials 20 | .DS_Store 21 | *.tfevents.* 22 | wandb 23 | openaigym.* 24 | videos/* 25 | cleanrl/videos/* 26 | benchmark/**/*.svg 27 | benchmark/**/*.pkl 28 | mjkey.txt 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib64/ 46 | parts/ 47 | sdist/ 48 | var/ 49 | wheels/ 50 | *.egg-info/ 51 | .installed.cfg 52 | *.egg 53 | MANIFEST 54 | 55 | # PyInstaller 56 | # Usually these files are written by a python script from a template 57 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
58 | *.manifest 59 | *.spec 60 | 61 | # Installer logs 62 | pip-log.txt 63 | pip-delete-this-directory.txt 64 | 65 | # Unit test / coverage reports 66 | htmlcov/ 67 | .tox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # pyenv 103 | # .python-version 104 | 105 | # celery beat schedule file 106 | celerybeat-schedule 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | -------------------------------------------------------------------------------- /.gitpod.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gitpod/workspace-full-vnc:latest 2 | USER gitpod 3 | RUN if ! 
grep -q "export PIP_USER=no" "$HOME/.bashrc"; then printf '%s\n' "export PIP_USER=no" >> "$HOME/.bashrc"; fi 4 | 5 | # install ubuntu dependencies 6 | ENV DEBIAN_FRONTEND=noninteractive 7 | RUN sudo apt-get update && \ 8 | sudo apt-get -y install xvfb ffmpeg git build-essential python-opengl 9 | 10 | # install python dependencies 11 | RUN mkdir cleanrl_utils && touch cleanrl_utils/__init__.py 12 | RUN pip install poetry --upgrade 13 | RUN poetry config virtualenvs.in-project true 14 | 15 | # install mujoco_py 16 | RUN sudo apt-get -y install wget unzip software-properties-common \ 17 | libgl1-mesa-dev \ 18 | libgl1-mesa-glx \ 19 | libglew-dev \ 20 | libosmesa6-dev patchelf 21 | -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | image: 2 | file: .gitpod.Dockerfile 3 | 4 | tasks: 5 | - init: poetry install 6 | 7 | # vscode: 8 | # extensions: 9 | # - learnpack.learnpack-vscode 10 | 11 | github: 12 | prebuilds: 13 | # enable for the master/default branch (defaults to true) 14 | master: true 15 | # enable for all branches in this repo (defaults to false) 16 | branches: true 17 | # enable for pull requests coming from this repo (defaults to true) 18 | pullRequests: true 19 | # enable for pull requests coming from forks (defaults to false) 20 | pullRequestsFromForks: true 21 | # add a "Review in Gitpod" button as a comment to pull requests (defaults to true) 22 | addComment: false 23 | # add a "Review in Gitpod" button to pull requests (defaults to false) 24 | addBadge: false 25 | # add a label once the prebuild is ready to pull requests (defaults to false) 26 | addLabel: prebuilt-in-gitpod 27 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/pyupgrade 3 | 
rev: v2.31.1 4 | hooks: 5 | - id: pyupgrade 6 | args: 7 | - --py37-plus 8 | - repo: https://github.com/PyCQA/isort 9 | rev: 5.12.0 10 | hooks: 11 | - id: isort 12 | args: 13 | - --profile=black 14 | - --skip-glob=wandb/**/* 15 | - --thirdparty=wandb 16 | - repo: https://github.com/myint/autoflake 17 | rev: v1.4 18 | hooks: 19 | - id: autoflake 20 | args: 21 | - -r 22 | - --exclude=wandb 23 | - --in-place 24 | - --remove-unused-variables 25 | - --remove-all-unused-imports 26 | - repo: https://github.com/python/black 27 | rev: 22.3.0 28 | hooks: 29 | - id: black 30 | args: 31 | - --line-length=127 32 | - --exclude=wandb 33 | - repo: https://github.com/codespell-project/codespell 34 | rev: v2.1.0 35 | hooks: 36 | - id: codespell 37 | args: 38 | - --ignore-words-list=nd,reacher,thist,ths,magent,ba 39 | - --skip=docs/css/termynal.css,docs/js/termynal.js,docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb 40 | - repo: https://github.com/python-poetry/poetry 41 | rev: 1.3.2 42 | hooks: 43 | - id: poetry-export 44 | name: poetry-export requirements.txt 45 | args: ["--without-hashes", "-o", "requirements/requirements.txt"] 46 | stages: [manual] 47 | - id: poetry-export 48 | name: poetry-export requirements-atari.txt 49 | args: ["--without-hashes", "-o", "requirements/requirements-atari.txt", "-E", "atari"] 50 | stages: [manual] 51 | - id: poetry-export 52 | name: poetry-export requirements-mujoco.txt 53 | args: ["--without-hashes", "-o", "requirements/requirements-mujoco.txt", "-E", "mujoco"] 54 | stages: [manual] 55 | - id: poetry-export 56 | name: poetry-export requirements-dm_control.txt 57 | args: ["--without-hashes", "-o", "requirements/requirements-dm_control.txt", "-E", "dm_control"] 58 | stages: [manual] 59 | - id: poetry-export 60 | name: poetry-export requirements-procgen.txt 61 | args: ["--without-hashes", "-o", "requirements/requirements-procgen.txt", "-E", "procgen"] 62 | stages: [manual] 63 | - id: poetry-export 64 | name: poetry-export 
requirements-envpool.txt 65 | args: ["--without-hashes", "-o", "requirements/requirements-envpool.txt", "-E", "envpool"] 66 | stages: [manual] 67 | - id: poetry-export 68 | name: poetry-export requirements-pettingzoo.txt 69 | args: ["--without-hashes", "-o", "requirements/requirements-pettingzoo.txt", "-E", "pettingzoo"] 70 | stages: [manual] 71 | - id: poetry-export 72 | name: poetry-export requirements-jax.txt 73 | args: ["--without-hashes", "-o", "requirements/requirements-jax.txt", "-E", "jax"] 74 | stages: [manual] 75 | - id: poetry-export 76 | name: poetry-export requirements-optuna.txt 77 | args: ["--without-hashes", "-o", "requirements/requirements-optuna.txt", "-E", "optuna"] 78 | stages: [manual] 79 | - id: poetry-export 80 | name: poetry-export requirements-docs.txt 81 | args: ["--without-hashes", "-o", "requirements/requirements-docs.txt", "-E", "docs"] 82 | stages: [manual] 83 | - id: poetry-export 84 | name: poetry-export requirements-cloud.txt 85 | args: ["--without-hashes", "-o", "requirements/requirements-cloud.txt", "-E", "cloud"] 86 | stages: [manual] 87 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Please check out https://docs.cleanrl.dev/contribution/ for more detail. 
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.4.2-runtime-ubuntu20.04 2 | 3 | # install ubuntu dependencies 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | RUN apt-get update && \ 6 | apt-get -y install python3-pip xvfb ffmpeg git build-essential python-opengl 7 | RUN ln -s /usr/bin/python3 /usr/bin/python 8 | 9 | # install python dependencies 10 | RUN mkdir cleanrl_utils && touch cleanrl_utils/__init__.py 11 | RUN pip install poetry --upgrade 12 | COPY pyproject.toml pyproject.toml 13 | COPY poetry.lock poetry.lock 14 | RUN poetry install 15 | 16 | # install mujoco_py 17 | RUN apt-get -y install wget unzip software-properties-common \ 18 | libgl1-mesa-dev \ 19 | libgl1-mesa-glx \ 20 | libglew-dev \ 21 | libosmesa6-dev patchelf 22 | RUN poetry install -E "atari mujoco_py" 23 | RUN poetry run python -c "import mujoco_py" 24 | 25 | COPY entrypoint.sh /usr/local/bin/ 26 | RUN chmod 777 /usr/local/bin/entrypoint.sh 27 | ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] 28 | 29 | # copy local files 30 | COPY ./cleanrl /cleanrl 31 | -------------------------------------------------------------------------------- /benchmark/c51.sh: -------------------------------------------------------------------------------- 1 | poetry install 2 | OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ 3 | --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ 4 | --command "poetry run python cleanrl/c51.py --no_cuda --track --capture_video" \ 5 | --num-seeds 3 \ 6 | --workers 9 7 | 8 | poetry install -E atari 9 | OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ 10 | --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ 11 | --command "poetry run python cleanrl/c51_atari.py --track --capture_video" \ 12 | --num-seeds 3 \ 13 | --workers 1 14 | 15 | poetry install -E "jax" 16 | 
poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 17 | CUDA_VISIBLE_DEVICES=-1 xvfb-run -a python -m cleanrl_utils.benchmark \ 18 | --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ 19 | --command "poetry run python cleanrl/c51_jax.py --track --capture_video" \ 20 | --num-seeds 3 \ 21 | --workers 1 22 | 23 | poetry install -E "atari jax" 24 | poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 25 | xvfb-run -a python -m cleanrl_utils.benchmark \ 26 | --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ 27 | --command "poetry run python cleanrl/c51_atari_jax.py --track --capture_video" \ 28 | --num-seeds 3 \ 29 | --workers 1 30 | -------------------------------------------------------------------------------- /benchmark/cleanrl_1gpu.slurm_template: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=low-priority 3 | #SBATCH --partition=production-cluster 4 | #SBATCH --gpus-per-task={{gpus_per_task}} 5 | #SBATCH --cpus-per-gpu={{cpus_per_gpu}} 6 | #SBATCH --ntasks={{ntasks}} 7 | #SBATCH --output=slurm/logs/%x_%j.out 8 | #SBATCH --array={{array}} 9 | #SBATCH --mem-per-cpu=12G 10 | #SBATCH --exclude=ip-26-0-146-[33,100,122-123,149,183,212,249],ip-26-0-147-[6,94,120,141],ip-26-0-152-[71,101,119,178,186,207,211],ip-26-0-153-[6,62,112,132,166,251],ip-26-0-154-[38,65],ip-26-0-155-[164,174,187,217],ip-26-0-156-[13,40],ip-26-0-157-27 11 | ##SBATCH --nodelist=ip-26-0-147-204 12 | {{nodes}} 13 | 14 | env_ids={{env_ids}} 15 | seeds={{seeds}} 16 | env_id=${env_ids[$SLURM_ARRAY_TASK_ID / {{len_seeds}}]} 17 | seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]} 18 | 19 | echo "Running task $SLURM_ARRAY_TASK_ID with env_id: $env_id and seed: $seed" 20 | 21 | srun {{command}} --env-id $env_id --seed $seed # 22 | 
-------------------------------------------------------------------------------- /benchmark/ddpg.sh: -------------------------------------------------------------------------------- 1 | poetry install -E "mujoco" 2 | python -m cleanrl_utils.benchmark \ 3 | --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ 4 | --command "poetry run python cleanrl/ddpg_continuous_action.py --track" \ 5 | --num-seeds 3 \ 6 | --workers 18 \ 7 | --slurm-gpus-per-task 1 \ 8 | --slurm-ntasks 1 \ 9 | --slurm-total-cpus 10 \ 10 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 11 | 12 | poetry install -E "mujoco jax" 13 | poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 14 | poetry run python -m cleanrl_utils.benchmark \ 15 | --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ 16 | --command "poetry run python cleanrl/ddpg_continuous_action_jax.py --track" \ 17 | --num-seeds 3 \ 18 | --workers 18 \ 19 | --slurm-gpus-per-task 1 \ 20 | --slurm-ntasks 1 \ 21 | --slurm-total-cpus 10 \ 22 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 23 | -------------------------------------------------------------------------------- /benchmark/ddpg_plot.sh: -------------------------------------------------------------------------------- 1 | python -m openrlbenchmark.rlops \ 2 | --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ 3 | 'ddpg_continuous_action?tag=pr-424' \ 4 | --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ 5 | --no-check-empty-runs \ 6 | --pc.ncols 3 \ 7 | --pc.ncols-legend 2 \ 8 | --output-filename benchmark/cleanrl/ddpg \ 9 | --scan-history 10 | 11 | python -m openrlbenchmark.rlops \ 12 | --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ 13 | 
'ddpg_continuous_action?tag=pr-424' \ 14 | 'ddpg_continuous_action_jax?tag=pr-424' \ 15 | --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ 16 | --no-check-empty-runs \ 17 | --pc.ncols 3 \ 18 | --pc.ncols-legend 2 \ 19 | --output-filename benchmark/cleanrl/ddpg_jax \ 20 | --scan-history 21 | -------------------------------------------------------------------------------- /benchmark/dqn.sh: -------------------------------------------------------------------------------- 1 | poetry install 2 | OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ 3 | --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ 4 | --command "poetry run python cleanrl/dqn.py --no_cuda --track --capture_video" \ 5 | --num-seeds 3 \ 6 | --workers 9 7 | 8 | poetry install -E atari 9 | OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ 10 | --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ 11 | --command "poetry run python cleanrl/dqn_atari.py --track --capture_video" \ 12 | --num-seeds 3 \ 13 | --workers 1 14 | 15 | poetry install -E jax 16 | poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 17 | xvfb-run -a python -m cleanrl_utils.benchmark \ 18 | --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ 19 | --command "poetry run python cleanrl/dqn_jax.py --track --capture_video" \ 20 | --num-seeds 3 \ 21 | --workers 1 22 | 23 | poetry install -E "atari jax" 24 | poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 25 | xvfb-run -a python -m cleanrl_utils.benchmark \ 26 | --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ 27 | --command "poetry run python cleanrl/dqn_atari_jax.py --track --capture_video" \ 28 | --num-seeds 3 \ 29 | --workers 1 30 | 
-------------------------------------------------------------------------------- /benchmark/ppg.sh: -------------------------------------------------------------------------------- 1 | # export WANDB_ENTITY=openrlbenchmark 2 | 3 | poetry install -E procgen 4 | xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ 5 | --env-ids starpilot bossfight bigfish \ 6 | --command "poetry run python cleanrl/ppg_procgen.py --track --capture_video" \ 7 | --num-seeds 3 \ 8 | --workers 1 9 | -------------------------------------------------------------------------------- /benchmark/ppo_trxl.sh: -------------------------------------------------------------------------------- 1 | # export WANDB_ENTITY=openrlbenchmark 2 | 3 | cd cleanrl/ppo_trxl 4 | poetry install 5 | OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ 6 | --env-ids MortarMayhem-Grid-v0 \ 7 | --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --norm_adv --trxl_memory_length 119 --total_timesteps 100000000" \ 8 | --num-seeds 3 \ 9 | --workers 32 \ 10 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 11 | 12 | OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ 13 | --env-ids MortarMayhem-v0 \ 14 | --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 275" \ 15 | --num-seeds 3 \ 16 | --workers 32 \ 17 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 18 | 19 | OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ 20 | --env-ids MysteryPath-Grid-v0 \ 21 | --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 96 --total_timesteps 100000000" \ 22 | --num-seeds 3 \ 23 | --workers 32 \ 24 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 25 | 26 | OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ 27 | --env-ids MysteryPath-v0 \ 28 | --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 256" \ 29 | --num-seeds 3 \ 30 | --workers 
32 \ 31 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 32 | 33 | OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ 34 | --env-ids SearingSpotlights-v0 \ 35 | --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 256" \ 36 | --num-seeds 3 \ 37 | --workers 32 \ 38 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 39 | 40 | OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ 41 | --env-ids Endless-SearingSpotlights-v0 \ 42 | --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 256 --total_timesteps 350000000" \ 43 | --num-seeds 3 \ 44 | --workers 32 \ 45 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 46 | 47 | OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ 48 | --env-ids Endless-MortarMayhem-v0 Endless-MysteryPath-v0 \ 49 | --command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 256 --total_timesteps 350000000" \ 50 | --num-seeds 3 \ 51 | --workers 32 \ 52 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 53 | -------------------------------------------------------------------------------- /benchmark/pqn.sh: -------------------------------------------------------------------------------- 1 | poetry install 2 | OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ 3 | --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ 4 | --command "poetry run python cleanrl/pqn.py --no_cuda --track" \ 5 | --num-seeds 3 \ 6 | --workers 9 \ 7 | --slurm-gpus-per-task 1 \ 8 | --slurm-ntasks 1 \ 9 | --slurm-total-cpus 10 \ 10 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 11 | 12 | poetry install -E envpool 13 | poetry run python -m cleanrl_utils.benchmark \ 14 | --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \ 15 | --command "poetry run python cleanrl/pqn_atari_envpool.py --track" \ 16 | --num-seeds 3 \ 17 | 
--workers 9 \ 18 | --slurm-gpus-per-task 1 \ 19 | --slurm-ntasks 1 \ 20 | --slurm-total-cpus 10 \ 21 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 22 | 23 | poetry install -E envpool 24 | poetry run python -m cleanrl_utils.benchmark \ 25 | --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \ 26 | --command "poetry run python cleanrl/pqn_atari_envpool_lstm.py --track" \ 27 | --num-seeds 3 \ 28 | --workers 9 \ 29 | --slurm-gpus-per-task 1 \ 30 | --slurm-ntasks 1 \ 31 | --slurm-total-cpus 10 \ 32 | --slurm-template-path benchmark/cleanrl_1gpu.slurm_template 33 | -------------------------------------------------------------------------------- /benchmark/pqn_plot.sh: -------------------------------------------------------------------------------- 1 | 2 | python -m openrlbenchmark.rlops \ 3 | --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ 4 | 'pqn?tag=pr-494&cl=CleanRL PQN' \ 5 | --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ 6 | --no-check-empty-runs \ 7 | --pc.ncols 3 \ 8 | --pc.ncols-legend 2 \ 9 | --output-filename benchmark/cleanrl/pqn \ 10 | --scan-history 11 | 12 | python -m openrlbenchmark.rlops \ 13 | --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ 14 | 'pqn_atari_envpool?tag=pr-494&cl=CleanRL PQN' \ 15 | --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \ 16 | --no-check-empty-runs \ 17 | --pc.ncols 3 \ 18 | --pc.ncols-legend 3 \ 19 | --rliable \ 20 | --rc.score_normalization_method maxmin \ 21 | --rc.normalized_score_threshold 1.0 \ 22 | --rc.sample_efficiency_plots \ 23 | --rc.sample_efficiency_and_walltime_efficiency_method Median \ 24 | --rc.performance_profile_plots \ 25 | --rc.aggregate_metrics_plots \ 26 | --rc.sample_efficiency_num_bootstrap_reps 10 \ 27 | --rc.performance_profile_num_bootstrap_reps 10 \ 28 | --rc.interval_estimates_num_bootstrap_reps 10 \ 29 | --output-filename static/0compare \ 
30 | --scan-history 31 | 32 | python -m openrlbenchmark.rlops \ 33 | --filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ 34 | 'pqn_atari_envpool_lstm?tag=pr-494&cl=CleanRL PQN' \ 35 | --env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \ 36 | --no-check-empty-runs \ 37 | --pc.ncols 3 \ 38 | --pc.ncols-legend 3 \ 39 | --rliable \ 40 | --rc.score_normalization_method maxmin \ 41 | --rc.normalized_score_threshold 1.0 \ 42 | --rc.sample_efficiency_plots \ 43 | --rc.sample_efficiency_and_walltime_efficiency_method Median \ 44 | --rc.performance_profile_plots \ 45 | --rc.aggregate_metrics_plots \ 46 | --rc.sample_efficiency_num_bootstrap_reps 10 \ 47 | --rc.performance_profile_num_bootstrap_reps 10 \ 48 | --rc.interval_estimates_num_bootstrap_reps 10 \ 49 | --output-filename static/0compare \ 50 | --scan-history 51 | -------------------------------------------------------------------------------- /benchmark/qdagger.sh: -------------------------------------------------------------------------------- 1 | poetry install -E atari 2 | OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ 3 | --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ 4 | --command "poetry run python cleanrl/qdagger_dqn_atari_impalacnn.py --track --capture_video" \ 5 | --num-seeds 3 \ 6 | --workers 1 7 | 8 | 9 | poetry install -E "atari jax" 10 | poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 11 | xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ 12 | --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ 13 | --command "poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --track --capture_video" \ 14 | --num-seeds 3 \ 15 | --workers 1 16 | -------------------------------------------------------------------------------- /benchmark/rainbow.sh: 
# benchmark/rpo.sh — RPO benchmark runs (dm_control, classic control, mujoco v4/v2)
# fix: `poetry install "mujoco dm_control"` was missing the -E/--extras flag; poetry
# rejects bare positional arguments, so the dependency install failed outright.
poetry install -E "mujoco dm_control"
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
    --env-ids dm_control/acrobot-swingup-v0 dm_control/acrobot-swingup_sparse-v0 dm_control/ball_in_cup-catch-v0 dm_control/cartpole-balance-v0 dm_control/cartpole-balance_sparse-v0 dm_control/cartpole-swingup-v0 dm_control/cartpole-swingup_sparse-v0 dm_control/cartpole-two_poles-v0 dm_control/cartpole-three_poles-v0 dm_control/cheetah-run-v0 dm_control/dog-stand-v0 dm_control/dog-walk-v0 dm_control/dog-trot-v0 dm_control/dog-run-v0 dm_control/dog-fetch-v0 dm_control/finger-spin-v0 dm_control/finger-turn_easy-v0 dm_control/finger-turn_hard-v0 dm_control/fish-upright-v0 dm_control/fish-swim-v0 dm_control/hopper-stand-v0 dm_control/hopper-hop-v0 dm_control/humanoid-stand-v0 dm_control/humanoid-walk-v0 dm_control/humanoid-run-v0 dm_control/humanoid-run_pure_state-v0 dm_control/humanoid_CMU-stand-v0 dm_control/humanoid_CMU-run-v0 dm_control/lqr-lqr_2_1-v0 dm_control/lqr-lqr_6_2-v0 dm_control/manipulator-bring_ball-v0 dm_control/manipulator-bring_peg-v0 dm_control/manipulator-insert_ball-v0 dm_control/manipulator-insert_peg-v0 dm_control/pendulum-swingup-v0 dm_control/point_mass-easy-v0 dm_control/point_mass-hard-v0 dm_control/quadruped-walk-v0 dm_control/quadruped-run-v0 dm_control/quadruped-escape-v0 dm_control/quadruped-fetch-v0 dm_control/reacher-easy-v0 dm_control/reacher-hard-v0 dm_control/stacker-stack_2-v0 dm_control/stacker-stack_4-v0 dm_control/swimmer-swimmer6-v0 dm_control/swimmer-swimmer15-v0 dm_control/walker-stand-v0 dm_control/walker-walk-v0 dm_control/walker-run-v0 \
    --command "poetry run python cleanrl/rpo_continuous_action.py --no_cuda --track" \
    --num-seeds 10 \
    --workers 1

poetry run pip install box2d-py==2.3.5
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
    --env-ids Pendulum-v1 BipedalWalker-v3 \
    --command "poetry run python cleanrl/rpo_continuous_action.py --no_cuda --track --capture_video" \
    --num-seeds 1 \
    --workers 1

poetry install -E mujoco
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
    --env-ids HumanoidStandup-v4 Humanoid-v4 InvertedPendulum-v4 Walker2d-v4 \
    --command "poetry run python cleanrl/rpo_continuous_action.py --no_cuda --track --capture_video" \
    --num-seeds 10 \
    --workers 1

# NOTE(review): the v2 environments below require the legacy mujoco-py backend —
# confirm the `mujoco` extra actually provides them before running.
poetry install -E mujoco
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
    --env-ids HumanoidStandup-v2 Humanoid-v2 InvertedPendulum-v2 Walker2d-v2 \
    --command "poetry run python cleanrl/rpo_continuous_action.py --no_cuda --track --capture_video" \
    --num-seeds 10 \
    --workers 1

poetry install -E mujoco
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
    --env-ids Ant-v4 InvertedDoublePendulum-v4 Reacher-v4 Pusher-v4 Hopper-v4 HalfCheetah-v4 Swimmer-v4 \
    --command "poetry run python cleanrl/rpo_continuous_action.py --rpo-alpha 0.01 --no_cuda --track --capture_video" \
    --num-seeds 10 \
    --workers 1

poetry install -E mujoco
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
    --env-ids Ant-v2 InvertedDoublePendulum-v2 Reacher-v2 Pusher-v2 Hopper-v2 HalfCheetah-v2 Swimmer-v2 \
    --command "poetry run python cleanrl/rpo_continuous_action.py --rpo-alpha 0.01 --no_cuda --track --capture_video" \
    --num-seeds 10 \
    --workers 1
# benchmark/td3_plot.sh — plot TD3 (torch + JAX) and SAC learning curves from W&B runs.
# fix: the first invocation passed --filters twice, with the second occurrence
# immediately followed by --env-ids, producing an empty filter group; siblings
# (ddpg_plot.sh, sac_plot.sh) use a single --filters per invocation.
python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'td3_continuous_action?tag=pr-424' \
        'td3_continuous_action_jax?tag=pr-424' \
    --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
    --no-check-empty-runs \
    --pc.ncols 3 \
    --pc.ncols-legend 2 \
    --output-filename benchmark/cleanrl/td3 \
    --scan-history

python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'sac_continuous_action?tag=pr-424' \
    --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
    --no-check-empty-runs \
    --pc.ncols 3 \
    --pc.ncols-legend 2 \
    --output-filename benchmark/cleanrl/sac \
    --scan-history
--capture_video --save-model --upload-model --hf-entity cleanrl" \ 19 | --num-seeds 1 \ 20 | --workers 1 21 | 22 | xvfb-run -a python -m cleanrl_utils.benchmark \ 23 | --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ 24 | --command "poetry run python cleanrl/dqn_atari.py --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ 25 | --num-seeds 1 \ 26 | --workers 1 27 | 28 | python -m cleanrl_utils.benchmark \ 29 | --env-ids Pong-v5 BeamRider-v5 Breakout-v5 \ 30 | --command "poetry run python cleanrl/ppo_atari_envpool_xla_jax_scan.py --track --save-model --upload-model --hf-entity cleanrl" \ 31 | --num-seeds 1 \ 32 | --workers 1 33 | 34 | CUDA_VISIBLE_DEVICES="1" taskset --cpu-list 16,17,18,19,20,21,22,23 python -m cleanrl_utils.benchmark \ 35 | --env-ids Breakout-v5 \ 36 | --command "poetry run python cleanrl/ppo_atari_envpool_xla_jax_scan.py --track --save-model --upload-model --hf-entity cleanrl" \ 37 | --num-seeds 1 \ 38 | --workers 1 39 | -------------------------------------------------------------------------------- /cleanrl/ppo_continuous_action_isaacgym/isaacgym/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "isaacgym" 3 | version = "1.0.preview4" 4 | description = "" 5 | authors = ["Costa Huang "] 6 | include = ["isaacgym/**/*", "examples/**/*"] 7 | packages = [ 8 | { include = "isaacgym" }, 9 | ] 10 | 11 | [tool.poetry.dependencies] 12 | python = ">=3.7.1" 13 | gym = "0.23.1" 14 | torch = "^1.12.0" 15 | torchvision = "^0.13.0" 16 | PyYAML = ">=5.3.1" 17 | scipy = ">=1.5.0" 18 | numpy = ">=1.16.4" 19 | Pillow = "^9.2.0" 20 | imageio = "^2.19.5" 21 | ninja = "^1.10.2" 22 | 23 | [tool.poetry.dev-dependencies] 24 | 25 | [build-system] 26 | requires = ["poetry-core>=1.0.0"] 27 | build-backend = "poetry.core.masonry.api" -------------------------------------------------------------------------------- /cleanrl/ppo_trxl/enjoy.py: 
"""Play one episode with a pre-trained PPO-TrXL agent (local file or Hugging Face hub)."""
from dataclasses import dataclass

import gymnasium as gym
import torch
import tyro
from ppo_trxl import Agent, make_env


@dataclass
class Args:
    hub: bool = False
    """whether to load the model from the huggingface hub or from the local disk"""
    name: str = "Endless-MortarMayhem-v0_12.nn"
    """path to the model file"""


if __name__ == "__main__":
    # Parse command line arguments and retrieve model path
    cli_args = tyro.cli(Args)
    if cli_args.hub:
        try:
            from huggingface_hub import hf_hub_download

            path = hf_hub_download(repo_id="LilHairdy/cleanrl_memory_gym", filename=cli_args.name)
        except Exception as e:
            # fix: the bare `except:` hid the real cause (missing package vs. bad
            # model name vs. network failure); chain it so the user can tell.
            raise RuntimeError(
                "Cannot load model from the huggingface hub. Please install the huggingface_hub pypi package and verify the model name. You can also download the model from the hub manually and load it from disk."
            ) from e
    else:
        path = cli_args.name

    # Load the pre-trained model and the original args used to train it
    checkpoint = torch.load(path)
    args = checkpoint["args"]
    args = type("Args", (), args)  # re-materialize the training args dict as an attribute namespace

    # Init environment and reset
    env = make_env(args.env_id, 0, False, "", "human")()
    obs, _ = env.reset()
    env.render()

    # Determine maximum episode steps
    max_episode_steps = env.spec.max_episode_steps
    if not max_episode_steps:
        max_episode_steps = env.max_episode_steps
    if max_episode_steps <= 0:
        max_episode_steps = 1024  # Memory Gym envs have max_episode_steps set to -1
        # Max episode steps impacts positional encoding, so make sure to set this accordingly

    # Setup agent and load its model parameters
    action_space_shape = (
        (env.action_space.n,) if isinstance(env.action_space, gym.spaces.Discrete) else tuple(env.action_space.nvec)
    )
    agent = Agent(args, env.observation_space, action_space_shape, max_episode_steps)
    agent.load_state_dict(checkpoint["model_weights"])

    # Setup Transformer-XL memory, mask and indices
    memory = torch.zeros((1, max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32)
    # lower-triangular mask (diagonal excluded): step t may attend to strictly earlier steps
    memory_mask = torch.tril(torch.ones((args.trxl_memory_length, args.trxl_memory_length)), diagonal=-1)
    repetitions = torch.repeat_interleave(
        torch.arange(0, args.trxl_memory_length).unsqueeze(0), args.trxl_memory_length - 1, dim=0
    ).long()
    memory_indices = torch.stack(
        [torch.arange(i, i + args.trxl_memory_length) for i in range(max_episode_steps - args.trxl_memory_length + 1)]
    ).long()
    memory_indices = torch.cat((repetitions, memory_indices))

    # Run episode
    done = False
    t = 0
    while not done:
        # Prepare observation and sliding memory window for step t
        obs = torch.Tensor(obs).unsqueeze(0)
        memory_window = memory[0, memory_indices[t].unsqueeze(0)]
        t_ = max(0, min(t, args.trxl_memory_length - 1))
        mask = memory_mask[t_].unsqueeze(0)
        indices = memory_indices[t].unsqueeze(0)
        # Forward agent
        action, _, _, _, new_memory = agent.get_action_and_value(obs, memory_window, mask, indices)
        memory[:, t] = new_memory
        # Step
        obs, reward, termination, truncation, info = env.step(action.cpu().squeeze().numpy())
        env.render()
        done = termination or truncation
        t += 1

    # NOTE(review): assumes the final step's `info` carries episode statistics
    # (e.g. via a RecordEpisodeStatistics-style wrapper inside make_env) — confirm.
    if "r" in info["episode"].keys():
        print(f"Episode return: {info['episode']['r'][0]}, Episode length: {info['episode']['l'][0]}")
    else:
        print(f"Episode return: {info['reward']}, Episode length: {info['length']}")
    env.close()
import os


def add_header(dirname: str) -> None:
    """
    Add a header string with documentation link
    to each ``.py`` file in the directory `dirname`.

    The header is idempotent: files that already start with it are skipped.
    """
    for filename in os.listdir(dirname):
        if filename.endswith(".py"):
            filepath = os.path.join(dirname, filename)
            with open(filepath) as f:
                lines = f.readlines()

            # Derive the docs anchor from the file name, e.g.
            # "dqn_atari.py" -> algo "dqn", anchor "#dqn_ataripy".
            exp_name = filename.split(".")[0]
            algo_name = exp_name.split("_")[0]
            header_string = f"# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/{algo_name}/#{exp_name}py"

            # robustness fix: `lines[0]` raised IndexError on empty files
            if not lines or not lines[0].startswith(header_string):
                # fix: the f-string had no placeholder (printed a literal instead of the file name)
                print(f"adding headers for {filename}")
                lines.insert(0, header_string + "\n")
                with open(filepath, "w") as f:
                    f.writelines(lines)
"""
See https://github.com/docker/docker-py/issues/2395
At the moment, nvidia-container-toolkit still includes nvidia-container-runtime.
So, you can still add nvidia-container-runtime as a runtime in /etc/docker/daemon.json:

{
    "runtimes": {
        "nvidia": {
            "path": "nvidia-container-runtime",
            "runtimeArgs": []
        }
    }
}
Then restart the docker service (sudo systemctl restart docker) and use runtime="nvidia" in docker-py as before.
"""


import argparse
import shlex
import time

import docker

parser = argparse.ArgumentParser(description="CleanRL Docker Submission")
# Common arguments
parser.add_argument("--exp-script", type=str, default="test1.sh", help="the file name of this experiment")
parser.add_argument("--num-vcpus", type=int, default=16, help="total number of vcpus used in the host machine")
parser.add_argument("--frequency", type=int, default=1, help="the number of seconds to check container update status")
args = parser.parse_args()

client = docker.from_env()

with open(args.exp_script) as f:
    lines = f.readlines()

# Parse each script line into a [image, env_vars, command] task.
# NOTE(review): assumes every non-empty line contains a literal "-e" token
# followed by exactly one env var, then the image, then the command — confirm
# against the experiment-script format; lines without "-e" are parsed wrongly.
tasks = []
for line in lines:
    line = line.replace("\n", "")  # fix: str.replace returns a new string; the result was previously discarded
    line_split = shlex.split(line)
    for idx, item in enumerate(line_split):
        if item == "-e":
            break
    env_vars = line_split[idx + 1 : idx + 2]
    image = line_split[idx + 2]
    commands = line_split[idx + 3 :]
    tasks += [[image, env_vars, commands]]

# Greedy scheduler: pin one task per free vcpu, polling container status every
# `frequency` seconds and reclaiming vcpus of exited containers.
running_containers = []
vcpus = list(range(args.num_vcpus))
while len(tasks) != 0:
    time.sleep(args.frequency)

    # drop finished containers and reclaim their vcpus
    new_running_containers = []
    for item in running_containers:
        c = item[0]
        c.reload()  # refresh cached status from the docker daemon
        if c.status != "exited":
            new_running_containers += [item]
        else:
            print(f"✅ task on vcpu {item[1]} has finished")
            vcpus += [item[1]]
    running_containers = new_running_containers

    if len(vcpus) != 0:
        task = tasks.pop()
        vcpu = vcpus.pop()
        c = client.containers.run(image=task[0], environment=task[1], command=task[2], cpuset_cpus=str(vcpu), detach=True)
        running_containers += [[c, vcpu]]
        print("========================")
        print(f"remaining tasks={len(tasks)}, running containers={len(running_containers)}")
        print(f"running on vcpu {vcpu}", task)
25 | return args 26 | 27 | 28 | if __name__ == "__main__": 29 | args = parse_args() 30 | Model, make_env, evaluate = MODELS[args.exp_name]() 31 | if not args.hf_repository: 32 | args.hf_repository = f"{args.hf_entity}/{args.env_id}-{args.exp_name}-seed{args.seed}" 33 | print(f"loading saved models from {args.hf_repository}...") 34 | model_path = hf_hub_download(repo_id=args.hf_repository, filename=f"{args.exp_name}.cleanrl_model") 35 | evaluate( 36 | model_path, 37 | make_env, 38 | args.env_id, 39 | eval_episodes=args.eval_episodes, 40 | run_name=f"eval", 41 | Model=Model, 42 | capture_video=args.capture_video, 43 | ) 44 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/__init__.py: -------------------------------------------------------------------------------- 1 | def dqn(): 2 | import cleanrl.dqn 3 | import cleanrl_utils.evals.dqn_eval 4 | 5 | return cleanrl.dqn.QNetwork, cleanrl.dqn.make_env, cleanrl_utils.evals.dqn_eval.evaluate 6 | 7 | 8 | def dqn_atari(): 9 | import cleanrl.dqn_atari 10 | import cleanrl_utils.evals.dqn_eval 11 | 12 | return cleanrl.dqn_atari.QNetwork, cleanrl.dqn_atari.make_env, cleanrl_utils.evals.dqn_eval.evaluate 13 | 14 | 15 | def dqn_jax(): 16 | import cleanrl.dqn_jax 17 | import cleanrl_utils.evals.dqn_jax_eval 18 | 19 | return cleanrl.dqn_jax.QNetwork, cleanrl.dqn_jax.make_env, cleanrl_utils.evals.dqn_jax_eval.evaluate 20 | 21 | 22 | def dqn_atari_jax(): 23 | import cleanrl.dqn_atari_jax 24 | import cleanrl_utils.evals.dqn_jax_eval 25 | 26 | return cleanrl.dqn_atari_jax.QNetwork, cleanrl.dqn_atari_jax.make_env, cleanrl_utils.evals.dqn_jax_eval.evaluate 27 | 28 | 29 | def c51(): 30 | import cleanrl.c51 31 | import cleanrl_utils.evals.c51_eval 32 | 33 | return cleanrl.c51.QNetwork, cleanrl.c51.make_env, cleanrl_utils.evals.c51_eval.evaluate 34 | 35 | 36 | def c51_atari(): 37 | import cleanrl.c51_atari 38 | import cleanrl_utils.evals.c51_eval 39 | 40 | return 
cleanrl.c51_atari.QNetwork, cleanrl.c51_atari.make_env, cleanrl_utils.evals.c51_eval.evaluate 41 | 42 | 43 | def c51_jax(): 44 | import cleanrl.c51_jax 45 | import cleanrl_utils.evals.c51_jax_eval 46 | 47 | return cleanrl.c51_jax.QNetwork, cleanrl.c51_jax.make_env, cleanrl_utils.evals.c51_jax_eval.evaluate 48 | 49 | 50 | def c51_atari_jax(): 51 | import cleanrl.c51_atari_jax 52 | import cleanrl_utils.evals.c51_jax_eval 53 | 54 | return cleanrl.c51_atari_jax.QNetwork, cleanrl.c51_atari_jax.make_env, cleanrl_utils.evals.c51_jax_eval.evaluate 55 | 56 | 57 | def ppo_atari_envpool_xla_jax_scan(): 58 | import cleanrl.ppo_atari_envpool_xla_jax_scan 59 | import cleanrl_utils.evals.ppo_envpool_jax_eval 60 | 61 | return ( 62 | ( 63 | cleanrl.ppo_atari_envpool_xla_jax_scan.Network, 64 | cleanrl.ppo_atari_envpool_xla_jax_scan.Actor, 65 | cleanrl.ppo_atari_envpool_xla_jax_scan.Critic, 66 | ), 67 | cleanrl.ppo_atari_envpool_xla_jax_scan.make_env, 68 | cleanrl_utils.evals.ppo_envpool_jax_eval.evaluate, 69 | ) 70 | 71 | 72 | MODELS = { 73 | "dqn": dqn, 74 | "dqn_atari": dqn_atari, 75 | "dqn_jax": dqn_jax, 76 | "dqn_atari_jax": dqn_atari_jax, 77 | "c51": c51, 78 | "c51_atari": c51_atari, 79 | "c51_jax": c51_jax, 80 | "c51_atari_jax": c51_atari_jax, 81 | "ppo_atari_envpool_xla_jax_scan": ppo_atari_envpool_xla_jax_scan, 82 | } 83 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/c51_eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | from argparse import Namespace 3 | from typing import Callable 4 | 5 | import gymnasium as gym 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def evaluate( 11 | model_path: str, 12 | make_env: Callable, 13 | env_id: str, 14 | eval_episodes: int, 15 | run_name: str, 16 | Model: torch.nn.Module, 17 | device: torch.device = torch.device("cpu"), 18 | epsilon: float = 0.05, 19 | capture_video: bool = True, 20 | ): 21 | envs = 
gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 22 | model_data = torch.load(model_path, map_location="cpu") 23 | args = Namespace(**model_data["args"]) 24 | model = Model(envs, n_atoms=args.n_atoms, v_min=args.v_min, v_max=args.v_max) 25 | model.load_state_dict(model_data["model_weights"]) 26 | model = model.to(device) 27 | model.eval() 28 | 29 | obs, _ = envs.reset() 30 | episodic_returns = [] 31 | while len(episodic_returns) < eval_episodes: 32 | if random.random() < epsilon: 33 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 34 | else: 35 | actions, _ = model.get_action(torch.Tensor(obs).to(device)) 36 | actions = actions.cpu().numpy() 37 | next_obs, _, _, _, infos = envs.step(actions) 38 | if "final_info" in infos: 39 | for info in infos["final_info"]: 40 | if "episode" not in info: 41 | continue 42 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 43 | episodic_returns += [info["episode"]["r"]] 44 | obs = next_obs 45 | 46 | return episodic_returns 47 | 48 | 49 | if __name__ == "__main__": 50 | from huggingface_hub import hf_hub_download 51 | 52 | from cleanrl.c51 import QNetwork, make_env 53 | 54 | model_path = hf_hub_download(repo_id="cleanrl/CartPole-v1-c51-seed1", filename="c51.cleanrl_model") 55 | evaluate( 56 | model_path, 57 | make_env, 58 | "CartPole-v1", 59 | eval_episodes=10, 60 | run_name=f"eval", 61 | Model=QNetwork, 62 | device="cpu", 63 | capture_video=False, 64 | ) 65 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/c51_jax_eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | from argparse import Namespace 3 | from typing import Callable 4 | 5 | import flax 6 | import flax.linen as nn 7 | import gymnasium as gym 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | 12 | 13 | def evaluate( 14 | model_path: str, 
15 | make_env: Callable, 16 | env_id: str, 17 | eval_episodes: int, 18 | run_name: str, 19 | Model: nn.Module, 20 | epsilon: float = 0.05, 21 | capture_video: bool = True, 22 | seed=1, 23 | ): 24 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 25 | obs, _ = envs.reset() 26 | model_data = None 27 | with open(model_path, "rb") as f: 28 | model_data = flax.serialization.from_bytes(model_data, f.read()) 29 | args = Namespace(**model_data["args"]) 30 | model = Model(action_dim=envs.single_action_space.n, n_atoms=args.n_atoms) 31 | # q_key = jax.random.PRNGKey(seed) 32 | params = model_data["model_weights"] 33 | model.apply = jax.jit(model.apply) 34 | atoms = jnp.asarray(np.linspace(args.v_min, args.v_max, num=args.n_atoms)) 35 | 36 | episodic_returns = [] 37 | while len(episodic_returns) < eval_episodes: 38 | if random.random() < epsilon: 39 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 40 | else: 41 | pmfs = model.apply(params, obs) 42 | q_vals = (pmfs * atoms).sum(axis=-1) 43 | actions = q_vals.argmax(axis=-1) 44 | actions = jax.device_get(actions) 45 | next_obs, _, _, _, infos = envs.step(actions) 46 | if "final_info" in infos: 47 | for info in infos["final_info"]: 48 | if "episode" not in info: 49 | continue 50 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 51 | episodic_returns += [info["episode"]["r"]] 52 | obs = next_obs 53 | 54 | return episodic_returns 55 | 56 | 57 | if __name__ == "__main__": 58 | from huggingface_hub import hf_hub_download 59 | 60 | from cleanrl.c51_jax import QNetwork, make_env 61 | 62 | model_path = hf_hub_download(repo_id="cleanrl/CartPole-v1-c51_jax-seed1", filename="c51_jax.cleanrl_model") 63 | evaluate( 64 | model_path, 65 | make_env, 66 | "CartPole-v1", 67 | eval_episodes=10, 68 | run_name=f"eval", 69 | Model=QNetwork, 70 | capture_video=False, 71 | ) 72 | 
-------------------------------------------------------------------------------- /cleanrl_utils/evals/ddpg_eval.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import gymnasium as gym 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def evaluate( 9 | model_path: str, 10 | make_env: Callable, 11 | env_id: str, 12 | eval_episodes: int, 13 | run_name: str, 14 | Model: nn.Module, 15 | device: torch.device = torch.device("cpu"), 16 | capture_video: bool = True, 17 | exploration_noise: float = 0.1, 18 | ): 19 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 20 | actor = Model[0](envs).to(device) 21 | qf = Model[1](envs).to(device) 22 | actor_params, qf_params = torch.load(model_path, map_location=device) 23 | actor.load_state_dict(actor_params) 24 | actor.eval() 25 | qf.load_state_dict(qf_params) 26 | qf.eval() 27 | # note: qf is not used in this script 28 | 29 | obs, _ = envs.reset() 30 | episodic_returns = [] 31 | while len(episodic_returns) < eval_episodes: 32 | with torch.no_grad(): 33 | actions = actor(torch.Tensor(obs).to(device)) 34 | actions += torch.normal(0, actor.action_scale * exploration_noise) 35 | actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) 36 | 37 | next_obs, _, _, _, infos = envs.step(actions) 38 | if "final_info" in infos: 39 | for info in infos["final_info"]: 40 | if "episode" not in info: 41 | continue 42 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 43 | episodic_returns += [info["episode"]["r"]] 44 | obs = next_obs 45 | 46 | return episodic_returns 47 | 48 | 49 | if __name__ == "__main__": 50 | from huggingface_hub import hf_hub_download 51 | 52 | from cleanrl.ddpg_continuous_action import Actor, QNetwork, make_env 53 | 54 | model_path = hf_hub_download( 55 | repo_id="cleanrl/HalfCheetah-v4-ddpg_continuous_action-seed1", 
filename="ddpg_continuous_action.cleanrl_model" 56 | ) 57 | evaluate( 58 | model_path, 59 | make_env, 60 | "HalfCheetah-v4", 61 | eval_episodes=10, 62 | run_name=f"eval", 63 | Model=(Actor, QNetwork), 64 | device="cpu", 65 | capture_video=False, 66 | ) 67 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/ddpg_jax_eval.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import flax 4 | import flax.linen as nn 5 | import gymnasium as gym 6 | import jax 7 | import numpy as np 8 | 9 | 10 | def evaluate( 11 | model_path: str, 12 | make_env: Callable, 13 | env_id: str, 14 | eval_episodes: int, 15 | run_name: str, 16 | Model: nn.Module, 17 | capture_video: bool = True, 18 | exploration_noise: float = 0.1, 19 | seed=1, 20 | ): 21 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 22 | obs, _ = envs.reset() 23 | 24 | Actor, QNetwork = Model 25 | action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0) 26 | action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0) 27 | actor = Actor( 28 | action_dim=np.prod(envs.single_action_space.shape), 29 | action_scale=action_scale, 30 | action_bias=action_bias, 31 | ) 32 | qf = QNetwork() 33 | key = jax.random.PRNGKey(seed) 34 | key, actor_key, qf_key = jax.random.split(key, 3) 35 | actor_params = actor.init(actor_key, obs) 36 | qf_params = qf.init(qf_key, obs, envs.action_space.sample()) 37 | # note: qf_params is not used in this script 38 | with open(model_path, "rb") as f: 39 | (actor_params, qf_params) = flax.serialization.from_bytes((actor_params, qf_params), f.read()) 40 | actor.apply = jax.jit(actor.apply) 41 | qf.apply = jax.jit(qf.apply) 42 | 43 | episodic_returns = [] 44 | while len(episodic_returns) < eval_episodes: 45 | actions = actor.apply(actor_params, obs) 46 | actions = np.array( 47 | [ 48 | 
(jax.device_get(actions)[0] + np.random.normal(0, action_scale * exploration_noise)[0]).clip( 49 | envs.single_action_space.low, envs.single_action_space.high 50 | ) 51 | ] 52 | ) 53 | 54 | next_obs, _, _, _, infos = envs.step(actions) 55 | if "final_info" in infos: 56 | for info in infos["final_info"]: 57 | if "episode" not in info: 58 | continue 59 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 60 | episodic_returns += [info["episode"]["r"]] 61 | obs = next_obs 62 | 63 | return episodic_returns 64 | 65 | 66 | if __name__ == "__main__": 67 | from huggingface_hub import hf_hub_download 68 | 69 | from cleanrl.ddpg_continuous_action_jax import Actor, QNetwork, make_env 70 | 71 | model_path = hf_hub_download( 72 | repo_id="cleanrl/HalfCheetah-v4-ddpg_continuous_action_jax-seed1", filename="ddpg_continuous_action_jax.cleanrl_model" 73 | ) 74 | evaluate( 75 | model_path, 76 | make_env, 77 | "HalfCheetah-v4", 78 | eval_episodes=10, 79 | run_name=f"eval", 80 | Model=(Actor, QNetwork), 81 | exploration_noise=0.1, 82 | ) 83 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/dqn_eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Callable 3 | 4 | import gymnasium as gym 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def evaluate( 10 | model_path: str, 11 | make_env: Callable, 12 | env_id: str, 13 | eval_episodes: int, 14 | run_name: str, 15 | Model: torch.nn.Module, 16 | device: torch.device = torch.device("cpu"), 17 | epsilon: float = 0.05, 18 | capture_video: bool = True, 19 | ): 20 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 21 | model = Model(envs).to(device) 22 | model.load_state_dict(torch.load(model_path, map_location=device)) 23 | model.eval() 24 | 25 | obs, _ = envs.reset() 26 | episodic_returns = [] 27 | while len(episodic_returns) < eval_episodes: 
28 | if random.random() < epsilon: 29 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 30 | else: 31 | q_values = model(torch.Tensor(obs).to(device)) 32 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 33 | next_obs, _, _, _, infos = envs.step(actions) 34 | if "final_info" in infos: 35 | for info in infos["final_info"]: 36 | if "episode" not in info: 37 | continue 38 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 39 | episodic_returns += [info["episode"]["r"]] 40 | obs = next_obs 41 | 42 | return episodic_returns 43 | 44 | 45 | if __name__ == "__main__": 46 | from huggingface_hub import hf_hub_download 47 | 48 | from cleanrl.dqn import QNetwork, make_env 49 | 50 | model_path = hf_hub_download(repo_id="cleanrl/CartPole-v1-dqn-seed1", filename="q_network.pth") 51 | evaluate( 52 | model_path, 53 | make_env, 54 | "CartPole-v1", 55 | eval_episodes=10, 56 | run_name=f"eval", 57 | Model=QNetwork, 58 | device="cpu", 59 | capture_video=False, 60 | ) 61 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/dqn_jax_eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Callable 3 | 4 | import flax 5 | import flax.linen as nn 6 | import gymnasium as gym 7 | import jax 8 | import numpy as np 9 | 10 | 11 | def evaluate( 12 | model_path: str, 13 | make_env: Callable, 14 | env_id: str, 15 | eval_episodes: int, 16 | run_name: str, 17 | Model: nn.Module, 18 | epsilon: float = 0.05, 19 | capture_video: bool = True, 20 | seed=1, 21 | ): 22 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 23 | obs, _ = envs.reset() 24 | model = Model(action_dim=envs.single_action_space.n) 25 | q_key = jax.random.PRNGKey(seed) 26 | params = model.init(q_key, obs) 27 | with open(model_path, "rb") as f: 28 | params = flax.serialization.from_bytes(params, 
f.read()) 29 | model.apply = jax.jit(model.apply) 30 | 31 | episodic_returns = [] 32 | while len(episodic_returns) < eval_episodes: 33 | if random.random() < epsilon: 34 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 35 | else: 36 | q_values = model.apply(params, obs) 37 | actions = q_values.argmax(axis=-1) 38 | actions = jax.device_get(actions) 39 | next_obs, _, _, _, infos = envs.step(actions) 40 | if "final_info" in infos: 41 | for info in infos["final_info"]: 42 | if "episode" not in info: 43 | continue 44 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 45 | episodic_returns += [info["episode"]["r"]] 46 | obs = next_obs 47 | 48 | return episodic_returns 49 | 50 | 51 | if __name__ == "__main__": 52 | from huggingface_hub import hf_hub_download 53 | 54 | from cleanrl.dqn_jax import QNetwork, make_env 55 | 56 | model_path = hf_hub_download(repo_id="vwxyzjn/CartPole-v1-dqn_jax-seed1", filename="dqn_jax.cleanrl_model") 57 | evaluate( 58 | model_path, 59 | make_env, 60 | "CartPole-v1", 61 | eval_episodes=10, 62 | run_name=f"eval", 63 | Model=QNetwork, 64 | capture_video=False, 65 | ) 66 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/ppo_eval.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import gymnasium as gym 4 | import torch 5 | 6 | 7 | def evaluate( 8 | model_path: str, 9 | make_env: Callable, 10 | env_id: str, 11 | eval_episodes: int, 12 | run_name: str, 13 | Model: torch.nn.Module, 14 | device: torch.device = torch.device("cpu"), 15 | capture_video: bool = True, 16 | gamma: float = 0.99, 17 | ): 18 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name, gamma)]) 19 | agent = Model(envs).to(device) 20 | agent.load_state_dict(torch.load(model_path, map_location=device)) 21 | agent.eval() 22 | 23 | obs, _ = envs.reset() 24 | 
episodic_returns = [] 25 | while len(episodic_returns) < eval_episodes: 26 | actions, _, _, _ = agent.get_action_and_value(torch.Tensor(obs).to(device)) 27 | next_obs, _, _, _, infos = envs.step(actions.cpu().numpy()) 28 | if "final_info" in infos: 29 | for info in infos["final_info"]: 30 | if "episode" not in info: 31 | continue 32 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 33 | episodic_returns += [info["episode"]["r"]] 34 | obs = next_obs 35 | 36 | return episodic_returns 37 | 38 | 39 | if __name__ == "__main__": 40 | from huggingface_hub import hf_hub_download 41 | 42 | from cleanrl.ppo_continuous_action import Agent, make_env 43 | 44 | model_path = hf_hub_download( 45 | repo_id="sdpkjc/Hopper-v4-ppo_continuous_action-seed1", filename="ppo_continuous_action.cleanrl_model" 46 | ) 47 | evaluate( 48 | model_path, 49 | make_env, 50 | "Hopper-v4", 51 | eval_episodes=10, 52 | run_name=f"eval", 53 | Model=Agent, 54 | device="cpu", 55 | capture_video=False, 56 | ) 57 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/td3_eval.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import gymnasium as gym 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def evaluate( 9 | model_path: str, 10 | make_env: Callable, 11 | env_id: str, 12 | eval_episodes: int, 13 | run_name: str, 14 | Model: nn.Module, 15 | device: torch.device = torch.device("cpu"), 16 | capture_video: bool = True, 17 | exploration_noise: float = 0.1, 18 | ): 19 | envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 20 | actor = Model[0](envs).to(device) 21 | qf1 = Model[1](envs).to(device) 22 | qf2 = Model[1](envs).to(device) 23 | actor_params, qf1_params, qf2_params = torch.load(model_path, map_location=device) 24 | actor.load_state_dict(actor_params) 25 | actor.eval() 26 | qf1.load_state_dict(qf1_params) 
27 | qf2.load_state_dict(qf2_params) 28 | qf1.eval() 29 | qf2.eval() 30 | # note: qf1 and qf2 are not used in this script 31 | 32 | obs, _ = envs.reset() 33 | episodic_returns = [] 34 | while len(episodic_returns) < eval_episodes: 35 | with torch.no_grad(): 36 | actions = actor(torch.Tensor(obs).to(device)) 37 | actions += torch.normal(0, actor.action_scale * exploration_noise) 38 | actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) 39 | 40 | next_obs, _, _, _, infos = envs.step(actions) 41 | if "final_info" in infos: 42 | for info in infos["final_info"]: 43 | if "episode" not in info: 44 | continue 45 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 46 | episodic_returns += [info["episode"]["r"]] 47 | obs = next_obs 48 | 49 | return episodic_returns 50 | 51 | 52 | if __name__ == "__main__": 53 | from huggingface_hub import hf_hub_download 54 | 55 | from cleanrl.td3_continuous_action import Actor, QNetwork, make_env 56 | 57 | model_path = hf_hub_download( 58 | repo_id="cleanrl/HalfCheetah-v4-td3_continuous_action-seed1", filename="td3_continuous_action.cleanrl_model" 59 | ) 60 | evaluate( 61 | model_path, 62 | make_env, 63 | "HalfCheetah-v4", 64 | eval_episodes=10, 65 | run_name=f"eval", 66 | Model=(Actor, QNetwork), 67 | device="cpu", 68 | capture_video=False, 69 | ) 70 | -------------------------------------------------------------------------------- /cleanrl_utils/evals/td3_jax_eval.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import flax 4 | import flax.linen as nn 5 | import gymnasium as gym 6 | import jax 7 | import numpy as np 8 | 9 | 10 | def evaluate( 11 | model_path: str, 12 | make_env: Callable, 13 | env_id: str, 14 | eval_episodes: int, 15 | run_name: str, 16 | Model: nn.Module, 17 | capture_video: bool = True, 18 | exploration_noise: float = 0.1, 19 | seed=1, 20 | ): 21 | envs = 
gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)]) 22 | max_action = float(envs.single_action_space.high[0]) 23 | obs, _ = envs.reset() 24 | 25 | Actor, QNetwork = Model 26 | action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0) 27 | action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0) 28 | actor = Actor( 29 | action_dim=np.prod(envs.single_action_space.shape), 30 | action_scale=action_scale, 31 | action_bias=action_bias, 32 | ) 33 | qf = QNetwork() 34 | key = jax.random.PRNGKey(seed) 35 | key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) 36 | actor_params = actor.init(actor_key, obs) 37 | qf1_params = qf.init(qf1_key, obs, envs.action_space.sample()) 38 | qf2_params = qf.init(qf2_key, obs, envs.action_space.sample()) 39 | with open(model_path, "rb") as f: 40 | (actor_params, qf1_params, qf2_params) = flax.serialization.from_bytes( 41 | (actor_params, qf1_params, qf2_params), f.read() 42 | ) 43 | # note: qf1_params and qf2_params are not used in this script 44 | actor.apply = jax.jit(actor.apply) 45 | qf.apply = jax.jit(qf.apply) 46 | 47 | episodic_returns = [] 48 | while len(episodic_returns) < eval_episodes: 49 | actions = actor.apply(actor_params, obs) 50 | actions = np.array( 51 | [ 52 | ( 53 | jax.device_get(actions)[0] 54 | + np.random.normal(0, max_action * exploration_noise, size=envs.single_action_space.shape) 55 | ).clip(envs.single_action_space.low, envs.single_action_space.high) 56 | ] 57 | ) 58 | 59 | next_obs, _, _, _, infos = envs.step(actions) 60 | if "final_info" in infos: 61 | for info in infos["final_info"]: 62 | if "episode" not in info: 63 | continue 64 | print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}") 65 | episodic_returns += [info["episode"]["r"]] 66 | obs = next_obs 67 | 68 | return episodic_returns 69 | 70 | 71 | if __name__ == "__main__": 72 | from huggingface_hub import hf_hub_download 73 | 74 | from 
cleanrl.td3_continuous_action_jax import Actor, QNetwork, make_env 75 | 76 | model_path = hf_hub_download( 77 | repo_id="cleanrl/HalfCheetah-v4-td3_continuous_action_jax-seed1", filename="td3_continuous_action_jax.cleanrl_model" 78 | ) 79 | evaluate( 80 | model_path, 81 | make_env, 82 | "HalfCheetah-v4", 83 | eval_episodes=10, 84 | run_name=f"eval", 85 | Model=(Actor, QNetwork), 86 | exploration_noise=0.1, 87 | ) 88 | -------------------------------------------------------------------------------- /cleanrl_utils/reproduce.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from distutils.util import strtobool 3 | 4 | import requests 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description="CleanRL Plots") 8 | # Common arguments 9 | parser.add_argument( 10 | "--run", 11 | type=str, 12 | default="cleanrl/cleanrl.benchmark/runs/thq5rgnz", 13 | help="the name of wandb project (e.g. cleanrl/cleanrl)", 14 | ) 15 | parser.add_argument( 16 | "--remove-entity", 17 | type=lambda x: bool(strtobool(x)), 18 | default=True, 19 | nargs="?", 20 | const=True, 21 | help="if toggled, the wandb-entity will be removed", 22 | ) 23 | args = parser.parse_args() 24 | uri = args.run.replace("/runs", "") 25 | 26 | requirements_txt_url = f"https://api.wandb.ai/files/{uri}/requirements.txt" 27 | metadata_url = f"https://api.wandb.ai/files/{uri}/wandb-metadata.json" 28 | metadata = requests.get(url=metadata_url).json() 29 | 30 | if args.remove_entity: 31 | a = [] 32 | wandb_entity_idx = None 33 | for i in range(len(metadata["args"])): 34 | if metadata["args"][i] == "--wandb-entity": 35 | wandb_entity_idx = i 36 | continue 37 | if wandb_entity_idx and i == wandb_entity_idx + 1: 38 | continue 39 | a += [metadata["args"][i]] 40 | else: 41 | a = metadata["args"] 42 | 43 | program = ["python"] + [metadata["program"]] + a 44 | 45 | print( 46 | f""" 47 | # run the following 48 | python3 -m venv venv 49 | source 
venv/bin/activate 50 | pip install -r {requirements_txt_url} 51 | curl -OL https://api.wandb.ai/files/{uri}/code/{metadata["codePath"]} 52 | {" ".join(program)} 53 | """ 54 | ) 55 | -------------------------------------------------------------------------------- /cloud/.gitignore: -------------------------------------------------------------------------------- 1 | *.tfstate** 2 | *.lock.** 3 | *.terraform -------------------------------------------------------------------------------- /cloud/examples/submit_exp.sh: -------------------------------------------------------------------------------- 1 | python -m cleanrl.submit_exp --exp-script offline_dqn_cql_atari_visual.sh \ 2 | --algo offline_dqn_cql_atari_visual.py \ 3 | --total-timesteps 10000000 \ 4 | --env-ids BeamRiderNoFrameskip-v4 QbertNoFrameskip-v4 SpaceInvadersNoFrameskip-v4 PongNoFrameskip-v4 BreakoutNoFrameskip-v4 \ 5 | --wandb-project-name cleanrl.benchmark \ 6 | --other-args "--wandb-entity cleanrl --cuda True" \ 7 | --job-queue cleanrl_gpu_large_memory \ 8 | --job-definition cleanrl \ 9 | --num-seed 2 \ 10 | --num-vcpu 16 \ 11 | --num-gpu 1 \ 12 | --num-memory 63000 \ 13 | --num-hours 48.0 \ 14 | --submit-aws $SUBMIT_AWS 15 | 16 | python ppg_procgen_impala_cnn.py --env-id starpilot --capture_video --track --wandb-entity cleanrl --wandb-project cleanrl.benchmark --seed 1 17 | 18 | python -m cleanrl.utils.submit_exp --exp-script ppo.sh \ 19 | --algo ppo.py \ 20 | --total-timesteps 100000 \ 21 | --env-ids CartPole-v0 \ 22 | --wandb-project-name cleanrl \ 23 | --other-args "--wandb-entity cleanrl --cuda True" \ 24 | --job-queue gpu \ 25 | --job-definition cleanrl \ 26 | --num-seed 1 \ 27 | --num-vcpu 1 \ 28 | --num-gpu 1 \ 29 | --num-memory 13000 \ 30 | --num-hours 48.0 \ 31 | --submit-aws $SUBMIT_AWS 32 | 33 | python -m cleanrl.utils.submit_exp --exp-script ppo.sh \ 34 | --algo ppo.py \ 35 | --total-timesteps 100000 \ 36 | --env-ids CartPole-v0 \ 37 | --wandb-project-name cleanrl \ 38 | --other-args 
"--wandb-entity cleanrl --cuda True" \ 39 | --job-queue cpu \ 40 | --job-definition cleanrl \ 41 | --num-seed 1 \ 42 | --num-vcpu 1 \ 43 | --num-memory 2000 \ 44 | --num-hours 48.0 \ 45 | --submit-aws $SUBMIT_AWS 46 | 47 | 48 | python -m cleanrl.utils.submit_exp --exp-script ppo.sh \ 49 | --algo ppo.py \ 50 | --other-args "--env-id CartPole-v0 --wandb-project-name cleanrl --total-timesteps 100000 --wandb-entity cleanrl --cuda True" \ 51 | --job-queue cpu \ 52 | --job-definition cleanrl \ 53 | --num-seed 1 \ 54 | --num-vcpu 1 \ 55 | --num-memory 2000 \ 56 | --num-hours 48.0 \ -------------------------------------------------------------------------------- /cloud/examples/terminate_all.sh: -------------------------------------------------------------------------------- 1 | # #! /bin/bash 2 | for i in $(aws batch list-jobs --job-queue cleanrl_gpu --job-status running --output text --query jobSummaryList[*].[jobId]) 3 | do 4 | echo "Deleting Job: $i" 5 | aws batch terminate-job --job-id $i --reason "Terminating job." 6 | echo "Job $i deleted" 7 | done 8 | 9 | #! /bin/bash 10 | for i in $(aws batch list-jobs --job-queue cleanrl_gpu --job-status runnable --output text --query jobSummaryList[*].[jobId]) 11 | do 12 | echo "Deleting Job: $i" 13 | aws batch terminate-job --job-id $i --reason "Terminating job." 
14 | echo "Job $i deleted" 15 | done -------------------------------------------------------------------------------- /cloud/main.tf: -------------------------------------------------------------------------------- 1 | terraform { 2 | required_providers { 3 | aws = { 4 | source = "hashicorp/aws" 5 | version = "~> 3.27" 6 | } 7 | } 8 | 9 | required_version = ">= 0.14.9" 10 | } 11 | 12 | provider "aws" { 13 | profile = "default" 14 | # region = "us-west-2" 15 | } 16 | 17 | module "cleanrl" { 18 | source = "./modules/cleanrl" 19 | spot_bid_percentage = "50" 20 | instance_types = [ 21 | "g4dn.4xlarge", # 16 vCPU, 64GB, $1.204, GPU 22 | "g4dn.xlarge", # 4 vCPU, 16GB, $0.526, GPU 23 | "r5ad.large", # 2 vCPU, 16GB, $0.131 24 | "c5a.large", # 2 vCPU, 4GB, $0.077 25 | # ARM-based 26 | "c6g.medium", # 1 vCPU, 2GB, $0.034 27 | "m6gd.medium", # 1 vCPU, 4GB, $0.0452 28 | ] 29 | } -------------------------------------------------------------------------------- /cloud/modules/cleanrl/main.tf: -------------------------------------------------------------------------------- 1 | ############ 2 | # On-demand resources 3 | ############ 4 | 5 | resource "aws_batch_compute_environment" "on_demand" { 6 | count = length(var.instance_types) 7 | compute_environment_name = replace(var.instance_types[count.index], ".", "-") 8 | compute_resources { 9 | instance_role = aws_iam_instance_profile.ecs_instance_role.arn 10 | instance_type = [ 11 | var.instance_types[count.index], 12 | ] 13 | max_vcpus = var.max_vcpus 14 | min_vcpus = 0 15 | security_group_ids = [ 16 | aws_security_group.sample.id, 17 | ] 18 | subnets = data.aws_subnet_ids.all_default_subnets.ids 19 | type = "EC2" 20 | allocation_strategy = var.on_demand_allocation_strategy 21 | } 22 | service_role = aws_iam_role.aws_batch_service_role.arn 23 | type = "MANAGED" 24 | depends_on = [aws_iam_role_policy_attachment.aws_batch_service_role] 25 | } 26 | 27 | resource "aws_batch_job_queue" "on_demand" { 28 | count = length(var.instance_types) 
29 | name = replace(var.instance_types[count.index], ".", "-") 30 | state = "ENABLED" 31 | priority = 100 32 | compute_environments = [ 33 | aws_batch_compute_environment.on_demand[count.index].arn, 34 | ] 35 | } 36 | 37 | ############ 38 | # Spot resources 39 | ############ 40 | 41 | resource "aws_batch_compute_environment" "spot" { 42 | count = length(var.instance_types) 43 | compute_environment_name = replace("${var.instance_types[count.index]}-spot", ".", "-") 44 | compute_resources { 45 | instance_role = aws_iam_instance_profile.ecs_instance_role.arn 46 | instance_type = [ 47 | var.instance_types[count.index], 48 | ] 49 | max_vcpus = var.max_vcpus 50 | min_vcpus = 0 51 | security_group_ids = [ 52 | aws_security_group.sample.id, 53 | ] 54 | subnets = data.aws_subnet_ids.all_default_subnets.ids 55 | type = "SPOT" 56 | bid_percentage = var.spot_bid_percentage 57 | allocation_strategy = var.spot_allocation_strategy 58 | spot_iam_fleet_role = aws_iam_role.AWS_EC2_spot_fleet_role.arn 59 | } 60 | service_role = aws_iam_role.aws_batch_service_role.arn 61 | type = "MANAGED" 62 | depends_on = [aws_iam_role_policy_attachment.aws_batch_service_role] 63 | } 64 | 65 | resource "aws_batch_job_queue" "spot" { 66 | count = length(var.instance_types) 67 | name = replace("${var.instance_types[count.index]}-spot", ".", "-") 68 | state = "ENABLED" 69 | priority = 100 70 | compute_environments = [ 71 | aws_batch_compute_environment.spot[count.index].arn, 72 | ] 73 | } 74 | -------------------------------------------------------------------------------- /cloud/modules/cleanrl/setups.tf: -------------------------------------------------------------------------------- 1 | resource "aws_iam_role" "ecs_instance_role" { 2 | name = "ecs_instance_role" 3 | assume_role_policy = < 36 | 37 | 38 | ## Resume training 39 | 40 | The second step is to automatically download the `agent.pt` from the URL above and resume training as follows: 41 | 42 | 43 | ```python linenums="1" hl_lines="6-16" 44 | 
num_updates = args.total_timesteps // args.batch_size 45 | 46 | CHECKPOINT_FREQUENCY = 50 47 | starting_update = 1 48 | 49 | if args.track and wandb.run.resumed: 50 | starting_update = run.summary.get("charts/update") + 1 51 | global_step = starting_update * args.batch_size 52 | api = wandb.Api() 53 | run = api.run(f"{run.entity}/{run.project}/{run.id}") 54 | model = run.file("agent.pt") 55 | model.download(f"models/{experiment_name}/") 56 | agent.load_state_dict(torch.load( 57 | f"models/{experiment_name}/agent.pt", map_location=device)) 58 | agent.eval() 59 | print(f"resumed at update {starting_update}") 60 | 61 | for update in range(starting_update, num_updates + 1): 62 | # ... do rollouts and train models 63 | 64 | if args.track: 65 | # make sure to tune `CHECKPOINT_FREQUENCY` 66 | # so models are not saved too frequently 67 | if update % CHECKPOINT_FREQUENCY == 0: 68 | torch.save(agent.state_dict(), f"{wandb.run.dir}/agent.pt") 69 | wandb.save(f"{wandb.run.dir}/agent.pt", policy="now") 70 | ``` 71 | 72 | To resume training, note the ID of the experiment is `21421tda` as in the URL [https://wandb.ai/costa-huang/cleanRL/runs/21421tda](https://wandb.ai/costa-huang/cleanRL/runs/21421tda), so we need to pass in the ID via environment variable to trigger the resume mode of W&B: 73 | 74 | ``` 75 | WANDB_RUN_ID=21421tda WANDB_RESUME=must python ppo_gridnet.py --prod-mode --capture_video 76 | ``` -------------------------------------------------------------------------------- /docs/benchmark/ddpg.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ddpg_continuous_action ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ddpg_continuous_action_jax ({'tag': ['pr-424']}) | 2 | |:--------------------|:-----------------------------------------------------------------------|:---------------------------------------------------------------------------| 3 | | HalfCheetah-v4 | 10374.07 ± 157.37 | 8638.60 ± 1954.46 | 4 
| | Walker2d-v4 | 1240.16 ± 390.10 | 1427.23 ± 104.91 | 5 | | Hopper-v4 | 1576.78 ± 818.98 | 1208.52 ± 659.22 | 6 | | InvertedPendulum-v4 | 642.68 ± 69.56 | 804.30 ± 87.60 | 7 | | Humanoid-v4 | 1699.56 ± 694.22 | 1513.61 ± 248.60 | 8 | | Pusher-v4 | -77.30 ± 38.78 | -38.56 ± 4.47 | -------------------------------------------------------------------------------- /docs/benchmark/ppo.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo ({'tag': ['pr-424']}) | 2 | |:---------------|:----------------------------------------------------| 3 | | CartPole-v1 | 490.04 ± 6.12 | 4 | | Acrobot-v1 | -86.36 ± 1.32 | 5 | | MountainCar-v0 | -200.00 ± 0.00 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari ({'tag': ['pr-424']}) | 2 | |:------------------------|:----------------------------------------------------------| 3 | | PongNoFrameskip-v4 | 20.36 ± 0.20 | 4 | | BeamRiderNoFrameskip-v4 | 1915.93 ± 484.58 | 5 | | BreakoutNoFrameskip-v4 | 414.66 ± 28.09 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_envpool.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_envpool ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari ({'tag': ['pr-424']}) | 2 | |:-------------|:------------------------------------------------------------------|:----------------------------------------------------------| 3 | | Pong-v5 | 20.45 ± 0.09 | 20.36 ± 0.20 | 4 | | BeamRider-v5 | 2501.85 ± 210.52 | 1915.93 ± 484.58 | 5 | | Breakout-v5 | 211.24 ± 151.84 | 414.66 ± 28.09 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_envpool_runtimes.md: 
-------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_envpool ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari ({'tag': ['pr-424']}) | 2 | |:-------------|------------------------------------------------------------------:|----------------------------------------------------------:| 3 | | Pong-v5 | 178.375 | 281.071 | 4 | | BeamRider-v5 | 182.944 | 284.941 | 5 | | Breakout-v5 | 151.384 | 264.077 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_envpool_xla_jax_scan.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_envpool_xla_jax ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari_envpool_xla_jax_scan ({'tag': ['pr-424']}) | 2 | |:-------------|:--------------------------------------------------------------------------|:-------------------------------------------------------------------------------| 3 | | Pong-v5 | 20.82 ± 0.21 | 20.52 ± 0.32 | 4 | | BeamRider-v5 | 2678.73 ± 426.42 | 2860.61 ± 801.30 | 5 | | Breakout-v5 | 420.92 ± 16.75 | 423.90 ± 5.49 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_envpool_xla_jax_scan_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_envpool_xla_jax ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari_envpool_xla_jax_scan ({'tag': ['pr-424']}) | 2 | |:-------------|--------------------------------------------------------------------------:|-------------------------------------------------------------------------------:| 3 | | Pong-v5 | 34.3237 | 34.701 | 4 | | BeamRider-v5 | 37.1076 | 37.2449 | 5 | | Breakout-v5 | 39.576 | 39.775 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_lstm.md: 
-------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_lstm ({'tag': ['pr-424']}) | 2 | |:------------------------|:---------------------------------------------------------------| 3 | | PongNoFrameskip-v4 | 19.81 ± 0.62 | 4 | | BeamRiderNoFrameskip-v4 | 1299.25 ± 509.90 | 5 | | BreakoutNoFrameskip-v4 | 113.42 ± 5.85 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_lstm_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_lstm ({'tag': ['pr-424']}) | 2 | |:------------------------|---------------------------------------------------------------:| 3 | | PongNoFrameskip-v4 | 317.607 | 4 | | BeamRiderNoFrameskip-v4 | 314.864 | 5 | | BreakoutNoFrameskip-v4 | 383.724 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_multigpu.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_multigpu ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari ({'tag': ['pr-424']}) | 2 | |:------------------------|:-------------------------------------------------------------------|:----------------------------------------------------------| 3 | | PongNoFrameskip-v4 | 20.34 ± 0.43 | 20.36 ± 0.20 | 4 | | BeamRiderNoFrameskip-v4 | 2414.65 ± 643.74 | 1915.93 ± 484.58 | 5 | | BreakoutNoFrameskip-v4 | 414.94 ± 20.60 | 414.66 ± 28.09 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_multigpu_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_multigpu ({'tag': ['pr-424']}) | 2 | |:------------------------|-------------------------------------------------------------------:| 3 | | PongNoFrameskip-v4 | 276.599 | 4 | | 
BeamRiderNoFrameskip-v4 | 280.902 | 5 | | BreakoutNoFrameskip-v4 | 270.532 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_atari_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari ({'tag': ['pr-424']}) | 2 | |:------------------------|----------------------------------------------------------:| 3 | | PongNoFrameskip-v4 | 281.071 | 4 | | BeamRiderNoFrameskip-v4 | 284.941 | 5 | | BreakoutNoFrameskip-v4 | 264.077 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_continuous_action.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_continuous_action ({'tag': ['pr-424']}) | 2 | |:-------------------------------------|:----------------------------------------------------------------------| 3 | | HalfCheetah-v4 | 1442.64 ± 46.03 | 4 | | Walker2d-v4 | 2287.95 ± 571.78 | 5 | | Hopper-v4 | 2382.86 ± 271.74 | 6 | | InvertedPendulum-v4 | 963.09 ± 22.20 | 7 | | Humanoid-v4 | 716.11 ± 49.08 | 8 | | Pusher-v4 | -40.38 ± 7.15 | 9 | | dm_control/acrobot-swingup-v0 | 25.60 ± 6.30 | 10 | | dm_control/acrobot-swingup_sparse-v0 | 1.35 ± 0.27 | 11 | | dm_control/ball_in_cup-catch-v0 | 619.26 ± 278.67 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_continuous_action_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_continuous_action ({'tag': ['pr-424']}) | 2 | |:-------------------------------------|----------------------------------------------------------------------:| 3 | | HalfCheetah-v4 | 25.3589 | 4 | | Walker2d-v4 | 24.3157 | 5 | | Hopper-v4 | 25.7066 | 6 | | InvertedPendulum-v4 | 23.7672 | 7 | | Humanoid-v4 | 49.5592 | 8 | | Pusher-v4 | 28.8162 | 9 | | dm_control/acrobot-swingup-v0 | 
26.5793 | 10 | | dm_control/acrobot-swingup_sparse-v0 | 25.1265 | 11 | | dm_control/ball_in_cup-catch-v0 | 26.1947 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_envpool.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_envpool_xla_jax ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari_envpool_xla_jax_scan ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari_envpool ({'tag': ['pr-424']}) | 2 | |:-------------|:--------------------------------------------------------------------------|:-------------------------------------------------------------------------------|:------------------------------------------------------------------| 3 | | Pong-v5 | 20.82 ± 0.21 | 20.52 ± 0.32 | 20.45 ± 0.09 | 4 | | BeamRider-v5 | 2678.73 ± 426.42 | 2860.61 ± 801.30 | 2501.85 ± 210.52 | 5 | | Breakout-v5 | 420.92 ± 16.75 | 423.90 ± 5.49 | 211.24 ± 151.84 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_envpool_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_atari_envpool_xla_jax ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari_envpool_xla_jax_scan ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/ppo_atari_envpool ({'tag': ['pr-424']}) | 2 | |:-------------|--------------------------------------------------------------------------:|-------------------------------------------------------------------------------:|------------------------------------------------------------------:| 3 | | Pong-v5 | 34.3237 | 34.701 | 178.375 | 4 | | BeamRider-v5 | 37.1076 | 37.2449 | 182.944 | 5 | | Breakout-v5 | 39.576 | 39.775 | 151.384 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_procgen.md: 
-------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_procgen ({'tag': ['pr-424']}) | 2 | |:----------|:------------------------------------------------------------| 3 | | starpilot | 30.99 ± 1.96 | 4 | | bossfight | 8.85 ± 0.33 | 5 | | bigfish | 16.46 ± 2.71 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_procgen_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo_procgen ({'tag': ['pr-424']}) | 2 | |:----------|------------------------------------------------------------:| 3 | | starpilot | 114.649 | 4 | | bossfight | 128.679 | 5 | | bigfish | 107.788 | -------------------------------------------------------------------------------- /docs/benchmark/ppo_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/ppo ({'tag': ['pr-424']}) | 2 | |:---------------|----------------------------------------------------:| 3 | | CartPole-v1 | 10.4737 | 4 | | Acrobot-v1 | 15.4606 | 5 | | MountainCar-v0 | 6.95995 | -------------------------------------------------------------------------------- /docs/benchmark/sac.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/sac_continuous_action ({'tag': ['pr-424']}) | 2 | |:--------------------|:----------------------------------------------------------------------| 3 | | HalfCheetah-v4 | 9634.89 ± 1423.73 | 4 | | Walker2d-v4 | 3591.45 ± 911.33 | 5 | | Hopper-v4 | 2310.46 ± 342.82 | 6 | | InvertedPendulum-v4 | 909.37 ± 55.66 | 7 | | Humanoid-v4 | 4996.29 ± 686.40 | 8 | | Pusher-v4 | -22.45 ± 0.51 | -------------------------------------------------------------------------------- /docs/benchmark/sac_runtimes.md: -------------------------------------------------------------------------------- 1 | | | 
openrlbenchmark/cleanrl/sac_continuous_action ({'tag': ['pr-424']}) | 2 | |:--------------------|----------------------------------------------------------------------:| 3 | | HalfCheetah-v4 | 174.778 | 4 | | Walker2d-v4 | 161.161 | 5 | | Hopper-v4 | 173.242 | 6 | | InvertedPendulum-v4 | 179.042 | 7 | | Humanoid-v4 | 177.31 | 8 | | Pusher-v4 | 172.123 | -------------------------------------------------------------------------------- /docs/benchmark/td3.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/td3_continuous_action ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/td3_continuous_action_jax ({'tag': ['pr-424']}) | 2 | |:--------------------|:----------------------------------------------------------------------|:--------------------------------------------------------------------------| 3 | | HalfCheetah-v4 | 9583.22 ± 126.09 | 9345.93 ± 770.54 | 4 | | Walker2d-v4 | 4057.59 ± 658.78 | 3686.19 ± 141.23 | 5 | | Hopper-v4 | 3134.61 ± 360.18 | 2940.10 ± 655.63 | 6 | | InvertedPendulum-v4 | 968.99 ± 25.80 | 988.94 ± 8.86 | 7 | | Humanoid-v4 | 5035.36 ± 21.67 | 5033.22 ± 122.14 | 8 | | Pusher-v4 | -30.92 ± 1.05 | -29.18 ± 1.02 | -------------------------------------------------------------------------------- /docs/benchmark/td3_runtimes.md: -------------------------------------------------------------------------------- 1 | | | openrlbenchmark/cleanrl/td3_continuous_action ({'tag': ['pr-424']}) | openrlbenchmark/cleanrl/td3_continuous_action_jax ({'tag': ['pr-424']}) | 2 | |:--------------------|----------------------------------------------------------------------:|--------------------------------------------------------------------------:| 3 | | HalfCheetah-v4 | 87.353 | 39.5119 | 4 | | Walker2d-v4 | 80.8592 | 34.0497 | 5 | | Hopper-v4 | 90.9921 | 33.4079 | 6 | | InvertedPendulum-v4 | 70.4218 | 30.2624 | 7 | | Humanoid-v4 | 79.1624 | 70.2437 | 8 | | Pusher-v4 | 95.2208 | 39.6051 | 
-------------------------------------------------------------------------------- /docs/blog/.authors.yml: -------------------------------------------------------------------------------- 1 | costa: 2 | name: Costa Huang 3 | description: Lead dev of CleanRL 4 | avatar: https://avatars.githubusercontent.com/u/5555347 5 | -------------------------------------------------------------------------------- /docs/blog/.meta.yml: -------------------------------------------------------------------------------- 1 | # comments: true 2 | # hide: 3 | # - feedback 4 | -------------------------------------------------------------------------------- /docs/blog/index.md: -------------------------------------------------------------------------------- 1 | # Blog 2 | -------------------------------------------------------------------------------- /docs/cloud/aws_batch1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/cloud/aws_batch1.png -------------------------------------------------------------------------------- /docs/cloud/aws_batch2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/cloud/aws_batch2.png -------------------------------------------------------------------------------- /docs/cloud/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | The rough idea behind the cloud integration is to package our code into a docker container and use AWS Batch to 4 | run thousands of experiments concurrently. 
5 | 6 | ## Prerequisites 7 | 8 | * Terraform (see installation tutorial [here](https://learn.hashicorp.com/tutorials/terraform/install-cli)) 9 | 10 | We use Terraform to define our infrastructure with AWS Batch, which you can spin up as follows 11 | 12 | ```bash 13 | # assuming you are at the root of the CleanRL project 14 | poetry install -E cloud 15 | cd cloud 16 | python -m awscli configure 17 | terraform init 18 | export AWS_DEFAULT_REGION=$(aws configure get region --profile default) 19 | terraform apply 20 | ``` 21 | 22 | 23 | 24 | !!! note 25 | Don't worry about the cost of spinning up these AWS Batch compute environments and job queues. They are completely free and you are only charged when you submit experiments. 26 | 27 | 28 | Then your AWS Batch console should look like 29 | 30 | ![aws_batch1.png](aws_batch1.png) 31 | 32 | 33 | ### Clean Up 34 | Uninstalling/Deleting the infrastructure is pretty straightforward: 35 | ``` 36 | export AWS_DEFAULT_REGION=$(aws configure get region --profile default) 37 | terraform destroy 38 | ``` 39 | -------------------------------------------------------------------------------- /docs/cloud/wandb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/cloud/wandb.png -------------------------------------------------------------------------------- /docs/css/custom.css: -------------------------------------------------------------------------------- 1 | .termynal-comment { 2 | color: #4a968f; 3 | font-style: italic; 4 | display: block; 5 | } 6 | 7 | .termy [data-termynal] { 8 | white-space: pre-wrap; 9 | } 10 | 11 | a.external-link::after { 12 | /* \00A0 is a non-breaking space 13 | to make the mark be on the same line as the link 14 | */ 15 | content: "\00A0[↪]"; 16 | } 17 | 18 | a.internal-link::after { 19 | /* \00A0 is a non-breaking space 20 | to make the mark be on the same line as the link 21 | 
*/ 22 | content: "\00A0↪"; 23 | } 24 | 25 | .shadow { 26 | box-shadow: 5px 5px 10px #999; 27 | } 28 | 29 | /* Give space to lower icons so Gitter chat doesn't get on top of them */ 30 | .md-footer-meta { 31 | padding-bottom: 2em; 32 | } 33 | 34 | .user-list { 35 | display: flex; 36 | flex-wrap: wrap; 37 | margin-bottom: 2rem; 38 | } 39 | 40 | .user-list-center { 41 | justify-content: space-evenly; 42 | } 43 | 44 | .user { 45 | margin: 1em; 46 | min-width: 7em; 47 | } 48 | 49 | .user .avatar-wrapper { 50 | width: 80px; 51 | height: 80px; 52 | margin: 10px auto; 53 | overflow: hidden; 54 | border-radius: 50%; 55 | position: relative; 56 | } 57 | 58 | .user .avatar-wrapper img { 59 | position: absolute; 60 | top: 50%; 61 | left: 50%; 62 | transform: translate(-50%, -50%); 63 | } 64 | 65 | .user .title { 66 | text-align: center; 67 | } 68 | 69 | .user .count { 70 | font-size: 80%; 71 | text-align: center; 72 | } 73 | 74 | a.announce-link:link, 75 | a.announce-link:visited { 76 | color: #fff; 77 | } 78 | 79 | a.announce-link:hover { 80 | color: var(--md-accent-fg-color); 81 | } 82 | 83 | .announce-wrapper { 84 | display: flex; 85 | justify-content: space-between; 86 | flex-wrap: wrap; 87 | align-items: center; 88 | } 89 | 90 | .announce-wrapper div.item { 91 | display: none; 92 | } 93 | 94 | .announce-wrapper .sponsor-badge { 95 | display: block; 96 | position: absolute; 97 | top: -5px; 98 | right: 0; 99 | font-size: 0.5rem; 100 | color: #999; 101 | background-color: #666; 102 | border-radius: 10px; 103 | padding: 0 10px; 104 | z-index: 10; 105 | } 106 | 107 | .announce-wrapper .sponsor-image { 108 | display: block; 109 | border-radius: 20px; 110 | } 111 | 112 | .announce-wrapper>div { 113 | min-height: 40px; 114 | display: flex; 115 | align-items: center; 116 | } 117 | 118 | .twitter { 119 | color: #00acee; 120 | } 121 | -------------------------------------------------------------------------------- /docs/css/termynal.css: 
-------------------------------------------------------------------------------- 1 | /** 2 | * termynal.js 3 | * 4 | * @author Ines Montani 5 | * @version 0.0.1 6 | * @license MIT 7 | */ 8 | 9 | :root { 10 | --color-bg: #252a33; 11 | --color-text: #eee; 12 | --color-text-subtle: #a2a2a2; 13 | } 14 | 15 | [data-termynal] { 16 | width: 750px; 17 | max-width: 100%; 18 | background: var(--color-bg); 19 | color: var(--color-text); 20 | /* font-size: 18px; */ 21 | font-size: 15px; 22 | /* font-family: 'Fira Mono', Consolas, Menlo, Monaco, 'Courier New', Courier, monospace; */ 23 | font-family: 'Roboto Mono', 'Fira Mono', Consolas, Menlo, Monaco, 'Courier New', Courier, monospace; 24 | border-radius: 4px; 25 | padding: 75px 45px 35px; 26 | position: relative; 27 | -webkit-box-sizing: border-box; 28 | box-sizing: border-box; 29 | } 30 | 31 | [data-termynal]:before { 32 | content: ''; 33 | position: absolute; 34 | top: 15px; 35 | left: 15px; 36 | display: inline-block; 37 | width: 15px; 38 | height: 15px; 39 | border-radius: 50%; 40 | /* A little hack to display the window buttons in one pseudo element. */ 41 | background: #d9515d; 42 | -webkit-box-shadow: 25px 0 0 #f4c025, 50px 0 0 #3ec930; 43 | box-shadow: 25px 0 0 #f4c025, 50px 0 0 #3ec930; 44 | } 45 | 46 | [data-termynal]:after { 47 | content: 'bash'; 48 | position: absolute; 49 | color: var(--color-text-subtle); 50 | top: 5px; 51 | left: 0; 52 | width: 100%; 53 | text-align: center; 54 | } 55 | 56 | a[data-terminal-control] { 57 | text-align: right; 58 | display: block; 59 | color: #aebbff; 60 | } 61 | 62 | [data-ty] { 63 | display: block; 64 | line-height: 2; 65 | } 66 | 67 | [data-ty]:before { 68 | /* Set up defaults and ensure empty lines are displayed. 
*/ 69 | content: ''; 70 | display: inline-block; 71 | vertical-align: middle; 72 | } 73 | 74 | [data-ty="input"]:before, 75 | [data-ty-prompt]:before { 76 | margin-right: 0.75em; 77 | color: var(--color-text-subtle); 78 | } 79 | 80 | [data-ty="input"]:before { 81 | content: '$'; 82 | } 83 | 84 | [data-ty][data-ty-prompt]:before { 85 | content: attr(data-ty-prompt); 86 | } 87 | 88 | [data-ty-cursor]:after { 89 | content: attr(data-ty-cursor); 90 | font-family: monospace; 91 | margin-left: 0.5em; 92 | -webkit-animation: blink 1s infinite; 93 | animation: blink 1s infinite; 94 | } 95 | 96 | 97 | /* Cursor animation */ 98 | 99 | @-webkit-keyframes blink { 100 | 50% { 101 | opacity: 0; 102 | } 103 | } 104 | 105 | @keyframes blink { 106 | 50% { 107 | opacity: 0; 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /docs/get-started/colab-badge.svg: -------------------------------------------------------------------------------- 1 | Open in ColabOpen in Colab 2 | -------------------------------------------------------------------------------- /docs/get-started/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ## Atari 4 | ``` 5 | poetry shell 6 | 7 | poetry install -E atari 8 | python cleanrl/dqn_atari.py --env-id BreakoutNoFrameskip-v4 9 | python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4 10 | python cleanrl/ppo_atari.py --env-id BreakoutNoFrameskip-v4 11 | python cleanrl/sac_atari.py --env-id BreakoutNoFrameskip-v4 12 | 13 | # NEW: 3-4x side-effects free speed up with envpool's atari (only available to linux) 14 | poetry install -E envpool 15 | python cleanrl/ppo_atari_envpool.py --env-id BreakoutNoFrameskip-v4 16 | # Learn Pong-v5 in ~5-10 mins 17 | # Side effects such as lower sample efficiency might occur 18 | poetry run python ppo_atari_envpool.py --clip-coef=0.2 --num-envs=16 --num-minibatches=8 --num-steps=128 --update-epochs=3 19 | ``` 20 
| ### Demo 21 | 22 | 23 | 24 | You can also run training scripts in other games, such as: 25 | 26 | ## Classic Control 27 | ``` 28 | poetry shell 29 | 30 | python cleanrl/dqn.py --env-id CartPole-v1 31 | python cleanrl/ppo.py --env-id CartPole-v1 32 | python cleanrl/c51.py --env-id CartPole-v1 33 | ``` 34 | 35 | ## Procgen 36 | ``` 37 | poetry shell 38 | 39 | poetry install -E procgen 40 | python cleanrl/ppo_procgen.py --env-id starpilot 41 | python cleanrl/ppg_procgen.py --env-id starpilot 42 | ``` 43 | 44 | 45 | ## PPO + LSTM 46 | ``` 47 | poetry shell 48 | 49 | poetry install -E atari 50 | python cleanrl/ppo_atari_lstm.py --env-id BreakoutNoFrameskip-v4 51 | ``` 52 | -------------------------------------------------------------------------------- /docs/get-started/experiment-tracking.md: -------------------------------------------------------------------------------- 1 | # Experiment tracking 2 | 3 | To use experiment tracking with wandb, run with the `--track` flag, which will also 4 | upload the videos recorded by the `--capture_video` flag. 
5 | ```bash 6 | poetry shell 7 | wandb login # only required for the first time 8 | python cleanrl/ppo.py --track --capture_video 9 | ``` 10 | 11 | 12 | 13 | 14 | The console will output the url for the tracked experiment like the following 15 | 16 | ```bash 17 | wandb: View project at https://wandb.ai/costa-huang/cleanRL 18 | wandb: View run at https://wandb.ai/costa-huang/cleanRL/runs/10dwbgeh 19 | ``` 20 | 21 | When you open the URL, it's going to look like the following page: 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/get-started/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Prerequisites 4 | 5 | * Python >=3.7.1,<3.11 6 | * [Poetry 1.2.1+](https://python-poetry.org) 7 | 8 | Simply run the following command for a quick start 9 | 10 | ```bash 11 | git clone https://github.com/vwxyzjn/cleanrl.git && cd cleanrl 12 | poetry install 13 | ``` 14 | 15 | 16 | 17 | 18 | !!! warning "`poetry install` hangs / stucks" 19 | 20 | Since 1.2+ `poetry` added some keyring authentication mechanisms that may cause `poetry install` hang or stuck. See [:material-github: python-poetry/poetry#1917](https://github.com/python-poetry/poetry/issues/1917). To fix this issue, try 21 | 22 | ```bash 23 | export PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring 24 | poetry install 25 | ``` 26 | 27 | 28 | !!! note "Working with different CUDA versions for `torch`" 29 | 30 | By default, the `torch` wheel is built with CUDA 10.2. If you are using newer NVIDIA GPUs (e.g., 3060 TI), you may need to specifically install CUDA 11.3 wheels by overriding the `torch` dependency with `pip`: 31 | 32 | ```bash 33 | poetry run pip install "torch==1.12.1" --upgrade --extra-index-url https://download.pytorch.org/whl/cu113 34 | ``` 35 | 36 | 37 | !!! 
note "Working with PyPI mirrors" 38 | 39 | Users in some countries (e.g., China) can usually speed up package installation via faster PyPI mirrors. If this helps you, try appending the following lines to the [pyproject.toml](https://github.com/vwxyzjn/cleanrl/blob/master/pyproject.toml) at the root of this repository and run `poetry install` 40 | 41 | ```toml 42 | [[tool.poetry.source]] 43 | name = "douban" 44 | url = "https://pypi.doubanio.com/simple/" 45 | default = true 46 | ``` 47 | 48 | 49 | ## Install via `pip` 50 | 51 | While we recommend using `poetry` to manage environments and dependencies, the traditional `requirements.txt` are available: 52 | 53 | ```bash 54 | # core dependencies 55 | pip install -r requirements/requirements.txt 56 | 57 | # optional dependencies 58 | pip install -r requirements/requirements-atari.txt 59 | pip install -r requirements/requirements-mujoco.txt 60 | pip install -r requirements/requirements-mujoco_py.txt 61 | pip install -r requirements/requirements-procgen.txt 62 | pip install -r requirements/requirements-envpool.txt 63 | pip install -r requirements/requirements-pettingzoo.txt 64 | pip install -r requirements/requirements-jax.txt 65 | pip install -r requirements/requirements-docs.txt 66 | pip install -r requirements/requirements-cloud.txt 67 | ``` 68 | 69 | 70 | ## Optional Dependencies 71 | 72 | CleanRL makes it easy to install optional dependencies for common RL environments 73 | and various development utilities. 
These optional dependencies are defined at the 74 | [`pyproject.toml`](https://github.com/vwxyzjn/cleanrl/blob/6afb51624a6fd51775b8351dd25099bd778cb1b1/pyproject.toml#L22-L37) as [poetry dependency groups](https://python-poetry.org/docs/master/managing-dependencies/#dependency-groups): 75 | 76 | 77 | ```toml 78 | [tool.poetry.group.atari] 79 | optional = true 80 | [tool.poetry.group.atari.dependencies] 81 | ale-py = "0.7.4" 82 | AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2"} 83 | opencv-python = "^4.6.0.66" 84 | 85 | [tool.poetry.group.procgen] 86 | optional = true 87 | [tool.poetry.group.procgen.dependencies] 88 | procgen = "^0.10.7" 89 | ``` 90 | 91 | You can install them using the following command 92 | 93 | ```bash 94 | poetry install -E atari 95 | poetry install -E mujoco 96 | poetry install -E mujoco_py 97 | poetry install -E dm_control 98 | poetry install -E procgen 99 | poetry install -E envpool 100 | poetry install -E pettingzoo 101 | poetry install -E jax 102 | poetry install -E optuna 103 | poetry install -E docs 104 | poetry install -E cloud 105 | ``` 106 | -------------------------------------------------------------------------------- /docs/get-started/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/get-started/tensorboard.png -------------------------------------------------------------------------------- /docs/get-started/videos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/get-started/videos.png -------------------------------------------------------------------------------- /docs/get-started/videos2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/get-started/videos2.png -------------------------------------------------------------------------------- /docs/js/chat.js: -------------------------------------------------------------------------------- 1 | ((window.gitter = {}).chat = {}).options = { 2 | room: 'tiangolo/fastapi' 3 | }; 4 | -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/Acrobot-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/Acrobot-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/BeamRiderNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/CartPole-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/CartPole-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/MountainCar-v0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/MountainCar-v0.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/Acrobot-v1-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/Acrobot-v1-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/Acrobot-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/Acrobot-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/BeamRiderNoFrameskip-v4-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/BeamRiderNoFrameskip-v4-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/BeamRiderNoFrameskip-v4.png -------------------------------------------------------------------------------- 
/docs/rl-algorithms/c51/jax/BreakoutNoFrameskip-v4-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/BreakoutNoFrameskip-v4-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/CartPole-v1-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/CartPole-v1-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/CartPole-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/CartPole-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/MountainCar-v0-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/MountainCar-v0-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/MountainCar-v0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/MountainCar-v0.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/PongNoFrameskip-v4-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/PongNoFrameskip-v4-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/c51/jax/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/c51/jax/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ddpg-jax/HalfCheetah-v2-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ddpg-jax/HalfCheetah-v2-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ddpg-jax/HalfCheetah-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ddpg-jax/HalfCheetah-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ddpg-jax/Hopper-v2-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ddpg-jax/Hopper-v2-time.png -------------------------------------------------------------------------------- 
/docs/rl-algorithms/ddpg-jax/Hopper-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ddpg-jax/Hopper-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ddpg-jax/Walker2d-v2-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ddpg-jax/Walker2d-v2-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ddpg-jax/Walker2d-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ddpg-jax/Walker2d-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ddpg/ddpg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ddpg/ddpg.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/Acrobot-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/Acrobot-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/BeamRiderNoFrameskip-v4.png 
-------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/CartPole-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/CartPole-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/MountainCar-v0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/MountainCar-v0.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/Acrobot-v1-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/Acrobot-v1-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/Acrobot-v1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/Acrobot-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/BeamRiderNoFrameskip-v4-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/BeamRiderNoFrameskip-v4-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/BeamRiderNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/BreakoutNoFrameskip-v4-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/BreakoutNoFrameskip-v4-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/CartPole-v1-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/CartPole-v1-time.png 
-------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/CartPole-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/CartPole-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/MountainCar-v0-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/MountainCar-v0-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/MountainCar-v0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/MountainCar-v0.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/PongNoFrameskip-v4-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/PongNoFrameskip-v4-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/dqn/jax/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/dqn/jax/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppg/BigFish.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppg/BigFish.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppg/BossFight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppg/BossFight.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppg/StarPilot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppg/StarPilot.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppg/comparison/BigFish.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppg/comparison/BigFish.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppg/comparison/BossFight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppg/comparison/BossFight.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppg/comparison/StarPilot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppg/comparison/StarPilot.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppg/ppg-ppo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppg/ppg-ppo.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo-rnd/MontezumaRevenge-v5-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo-rnd/MontezumaRevenge-v5-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo-rnd/MontezumaRevenge-v5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo-rnd/MontezumaRevenge-v5.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo-trxl/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo-trxl/compare.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Acrobot-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Acrobot-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BeamRider-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BeamRider-time.png 
-------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BeamRider.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BeamRider.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BeamRiderNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BeamRiderNoFrameskip-v4multigpu-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BeamRiderNoFrameskip-v4multigpu-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BeamRiderNoFrameskip-v4multigpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BeamRiderNoFrameskip-v4multigpu.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BigFish.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BigFish.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BossFight.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BossFight.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Breakout-a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Breakout-a.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Breakout-time-a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Breakout-time-a.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Breakout-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Breakout-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Breakout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Breakout.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BreakoutNoFrameskip-v4multigpu-time.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BreakoutNoFrameskip-v4multigpu-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/BreakoutNoFrameskip-v4multigpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/BreakoutNoFrameskip-v4multigpu.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/CartPole-v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/CartPole-v1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/HalfCheetah-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/HalfCheetah-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Hopper-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Hopper-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/MountainCar-v0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/MountainCar-v0.png 
-------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Pong-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Pong-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Pong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Pong.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/PongNoFrameskip-v4multigpu-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/PongNoFrameskip-v4multigpu-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/PongNoFrameskip-v4multigpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/PongNoFrameskip-v4multigpu.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/StarPilot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/StarPilot.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/Walker2d-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/Walker2d-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/AllegroHand-c-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/AllegroHand-c-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/AllegroHand-c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/AllegroHand-c.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/AllegroHand-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/AllegroHand-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/AllegroHand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/AllegroHand.png 
-------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/Ant-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/Ant-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/Ant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/Ant.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/Anymal-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/Anymal-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/Anymal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/Anymal.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/BallBalance-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/BallBalance-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/BallBalance.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/BallBalance.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/Cartpole-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/Cartpole-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/Cartpole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/Cartpole.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/Humanoid-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/Humanoid-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/Humanoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/Humanoid.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/ShadowHand-c-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/ShadowHand-c-time.png 
-------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/ShadowHand-c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/ShadowHand-c.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/ShadowHand-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/ShadowHand-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/ShadowHand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/ShadowHand.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/old/AllegroHand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/old/AllegroHand.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/old/Ant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/old/Ant.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/old/Anymal.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/old/Anymal.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/old/BallBalance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/old/BallBalance.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/old/Cartpole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/old/Cartpole.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/old/Humanoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/old/Humanoid.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/isaacgymenvs/old/ShadowHand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/isaacgymenvs/old/ShadowHand.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/lstm/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/lstm/BeamRiderNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/lstm/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/lstm/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/lstm/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/lstm/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/pong_v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/pong_v3.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo-1-title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo-1-title.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo-2-title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo-2-title.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo-3-title.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo-3-title.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/hms_each_game.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/hms_each_game.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/hns_ppo_vs_baselines.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/hns_ppo_vs_baselines.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/hns_ppo_vs_r2d2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/hns_ppo_vs_r2d2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/runset_0_hms_bar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/runset_0_hms_bar.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/runset_1_hms_bar.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax/runset_1_hms_bar.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax_scan/compare-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax_scan/compare-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax_scan/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_atari_envpool_xla_jax_scan/compare.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_continuous_action_gymnasium_dm_control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_continuous_action_gymnasium_dm_control.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_continuous_action_gymnasium_mujoco_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_continuous_action_gymnasium_mujoco_v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_continuous_action_gymnasium_mujoco_v4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_continuous_action_gymnasium_mujoco_v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/ppo_continuous_action_v2_vs_v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/ppo_continuous_action_v2_vs_v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/surround_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/surround_v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/ppo/tennis_v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/ppo/tennis_v3.png -------------------------------------------------------------------------------- /docs/rl-algorithms/pqn/pqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/pqn/pqn.png -------------------------------------------------------------------------------- /docs/rl-algorithms/pqn/pqn_lstm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/pqn/pqn_lstm.png -------------------------------------------------------------------------------- 
/docs/rl-algorithms/pqn/pqn_state.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/pqn/pqn_state.png -------------------------------------------------------------------------------- /docs/rl-algorithms/qdagger/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/qdagger/BeamRiderNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/qdagger/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/qdagger/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/qdagger/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/qdagger/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/qdagger/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/qdagger/compare.png -------------------------------------------------------------------------------- /docs/rl-algorithms/qdagger/jax/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/qdagger/jax/BeamRiderNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/qdagger/jax/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/qdagger/jax/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/qdagger/jax/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/qdagger/jax/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/qdagger/jax/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/qdagger/jax/compare.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rainbow/rainbow_c51_dqn_bars.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rainbow/rainbow_c51_dqn_bars.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rainbow/rainbow_env_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rainbow/rainbow_env_curves.png 
-------------------------------------------------------------------------------- /docs/rl-algorithms/rainbow/rainbow_sample_eff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rainbow/rainbow_sample_eff.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/dm_control_all_ppo_rpo_8M.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/dm_control_all_ppo_rpo_8M.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/gym.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/gym.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/mujoco_v2_failure_0_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/mujoco_v2_failure_0_5.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/mujoco_v2_part1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/mujoco_v2_part1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/mujoco_v2_part2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/mujoco_v2_part2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/mujoco_v2_part2_0_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/mujoco_v2_part2_0_5.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/mujoco_v4_failure_0_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/mujoco_v4_failure_0_5.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/mujoco_v4_part1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/mujoco_v4_part1.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/mujoco_v4_part2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/mujoco_v4_part2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/rpo/mujoco_v4_part2_0_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/rpo/mujoco_v4_part2_0_5.png -------------------------------------------------------------------------------- 
/docs/rl-algorithms/sac/BeamRiderNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/sac/BeamRiderNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/sac/BreakoutNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/sac/BreakoutNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/sac/HalfCheetah-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/sac/HalfCheetah-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/sac/Hopper-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/sac/Hopper-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/sac/PongNoFrameskip-v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/sac/PongNoFrameskip-v4.png -------------------------------------------------------------------------------- /docs/rl-algorithms/sac/Walker2d-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/sac/Walker2d-v2.png 
-------------------------------------------------------------------------------- /docs/rl-algorithms/td3-jax/HalfCheetah-v2-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3-jax/HalfCheetah-v2-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3-jax/HalfCheetah-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3-jax/HalfCheetah-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3-jax/Hopper-v2-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3-jax/Hopper-v2-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3-jax/Hopper-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3-jax/Hopper-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3-jax/Walker2d-v2-time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3-jax/Walker2d-v2-time.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3-jax/Walker2d-v2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3-jax/Walker2d-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3/HalfCheetah-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3/HalfCheetah-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3/Hopper-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3/Hopper-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3/Humanoid-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3/Humanoid-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3/InvertedPendulum-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3/InvertedPendulum-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3/Pusher-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3/Pusher-v2.png -------------------------------------------------------------------------------- /docs/rl-algorithms/td3/Walker2d-v2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rl-algorithms/td3/Walker2d-v2.png -------------------------------------------------------------------------------- /docs/rlops/docs-update.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rlops/docs-update.png -------------------------------------------------------------------------------- /docs/rlops/rlops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rlops/rlops.png -------------------------------------------------------------------------------- /docs/rlops/tags.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/rlops/tags.png -------------------------------------------------------------------------------- /docs/static/blog/cleanrl-v1/github-action.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/static/blog/cleanrl-v1/github-action.png -------------------------------------------------------------------------------- /docs/static/blog/cleanrl-v1/hf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/static/blog/cleanrl-v1/hf.png -------------------------------------------------------------------------------- /docs/static/blog/cleanrl-v1/rlops.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/static/blog/cleanrl-v1/rlops.png -------------------------------------------------------------------------------- /docs/static/o1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/static/o1.png -------------------------------------------------------------------------------- /docs/static/o2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/static/o2.png -------------------------------------------------------------------------------- /docs/static/o3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/static/o3.png -------------------------------------------------------------------------------- /docs/static/pre-commit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vwxyzjn/cleanrl/dcc289fc6f0bda492fa7360a155262cf826b12a5/docs/static/pre-commit.png -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | .grid-container { 2 | display: grid; 3 | grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); 4 | image-rendering: -webkit-optimize-contrast; 5 | grid-gap: 50px; 6 | } -------------------------------------------------------------------------------- /entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | Xvfb :1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset &> xvfb.log & 3 | export DISPLAY=:1 4 | set -e 
5 | # bash -c "echo vm.overcommit_memory=1 >> /etc/sysctl.conf" && sysctl -p 6 | exec "$@" 7 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: CleanRL 2 | theme: 3 | name: material 4 | features: 5 | # - navigation.instant 6 | - navigation.tracking 7 | # - navigation.tabs 8 | # - navigation.tabs.sticky 9 | - navigation.sections 10 | - navigation.expand 11 | - navigation.top 12 | - search.suggest 13 | - search.highlight 14 | palette: 15 | - media: "(prefers-color-scheme: dark)" 16 | scheme: slate 17 | primary: teal 18 | accent: light green 19 | toggle: 20 | icon: material/lightbulb 21 | name: Switch to light mode 22 | - media: "(prefers-color-scheme: light)" 23 | scheme: default 24 | primary: green 25 | accent: deep orange 26 | toggle: 27 | icon: material/lightbulb-outline 28 | name: Switch to dark mode 29 | plugins: 30 | - search 31 | nav: 32 | - Overview: index.md 33 | - Get Started: 34 | - get-started/installation.md 35 | - get-started/basic-usage.md 36 | - get-started/experiment-tracking.md 37 | - get-started/examples.md 38 | - get-started/benchmark-utility.md 39 | - get-started/zoo.md 40 | - RL Algorithms: 41 | - rl-algorithms/overview.md 42 | - rl-algorithms/ppo.md 43 | - rl-algorithms/dqn.md 44 | - rl-algorithms/c51.md 45 | - rl-algorithms/ddpg.md 46 | - rl-algorithms/sac.md 47 | - rl-algorithms/td3.md 48 | - rl-algorithms/ppg.md 49 | - rl-algorithms/ppo-rnd.md 50 | - rl-algorithms/rpo.md 51 | - rl-algorithms/qdagger.md 52 | - rl-algorithms/ppo-trxl.md 53 | - rl-algorithms/pqn.md 54 | - rl-algorithms/rainbow.md 55 | - Advanced: 56 | - advanced/hyperparameter-tuning.md 57 | - advanced/resume-training.md 58 | - Community: 59 | - contribution.md 60 | - cleanrl-supported-papers-projects.md 61 | - Cloud Integration: 62 | - cloud/installation.md 63 | - cloud/submit-experiments.md 64 | #adding git repo 65 | repo_url: 
https://github.com/vwxyzjn/cleanrl 66 | repo_name: vwxyzjn/cleanrl 67 | #markdown_extensions 68 | markdown_extensions: 69 | - pymdownx.superfences 70 | - pymdownx.tabbed: 71 | alternate_style: true 72 | - abbr 73 | - pymdownx.highlight 74 | - pymdownx.inlinehilite 75 | - pymdownx.superfences 76 | - pymdownx.snippets 77 | - admonition 78 | - pymdownx.details 79 | - attr_list 80 | - md_in_html 81 | - footnotes 82 | - markdown_include.include: 83 | base_path: docs 84 | - pymdownx.emoji: 85 | emoji_index: !!python/name:materialx.emoji.twemoji 86 | emoji_generator: !!python/name:materialx.emoji.to_svg 87 | - pymdownx.arithmatex: 88 | generic: true 89 | # - toc: 90 | # permalink: true 91 | # - markdown.extensions.codehilite: 92 | # guess_lang: false 93 | # - admonition 94 | # - codehilite 95 | # - extra 96 | # - pymdownx.superfences: 97 | # custom_fences: 98 | # - name: mermaid 99 | # class: mermaid 100 | # format: !!python/name:pymdownx.superfences.fence_code_format '' 101 | # - pymdownx.tabbed 102 | extra_css: 103 | - stylesheets/extra.css 104 | # extra_javascript: 105 | # - js/termynal.js 106 | # - js/custom.js 107 | #footer 108 | extra: 109 | social: 110 | - icon: fontawesome/solid/envelope 111 | link: mailto:costa.huang@outlook.com 112 | - icon: fontawesome/brands/twitter 113 | link: https://twitter.com/vwxyzjn 114 | - icon: fontawesome/brands/github 115 | link: https://github.com/vwxyzjn/cleanrl 116 | copyright: Copyright © 2021, CleanRL. All rights reserved. 
117 | extra_javascript: 118 | # - javascripts/mathjax.js 119 | # - https://polyfill.io/v3/polyfill.min.js?features=es6 120 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 121 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "cleanrl" 3 | version = "2.0.0b1" 4 | description = "High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features" 5 | authors = ["Costa Huang <costa.huang@outlook.com>"] 6 | packages = [ 7 | { include = "cleanrl" }, 8 | { include = "cleanrl_utils" }, 9 | ] 10 | keywords = ["reinforcement", "machine", "learning", "research"] 11 | license="MIT" 12 | readme = "README.md" 13 | 14 | [tool.poetry.dependencies] 15 | python = ">=3.8,<3.11" 16 | tensorboard = "^2.10.0" 17 | wandb = "^0.13.11" 18 | gym = "0.23.1" 19 | torch = ">=1.12.1" 20 | stable-baselines3 = "2.0.0" 21 | gymnasium = ">=0.28.1" 22 | moviepy = "^1.0.3" 23 | pygame = "2.1.0" 24 | huggingface-hub = "^0.11.1" 25 | rich = "<12.0" 26 | tenacity = "^8.2.2" 27 | tyro = "^0.5.10" 28 | pyyaml = "^6.0.1" 29 | 30 | ale-py = {version = "0.8.1", optional = true} 31 | AutoROM = {extras = ["accept-rom-license"], version = "~0.4.2", optional = true} 32 | opencv-python = {version = "^4.6.0.66", optional = true} 33 | procgen = {version = "^0.10.7", optional = true} 34 | pytest = {version = "^7.1.3", optional = true} 35 | mujoco = {version = "<=2.3.3", optional = true} 36 | imageio = {version = "^2.14.1", optional = true} 37 | mkdocs-material = {version = "^8.4.3", optional = true} 38 | markdown-include = {version = "^0.7.0", optional = true} 39 | openrlbenchmark = {version = "^0.1.1b4", optional = true} 40 | jax = {version = "0.4.8", optional = true} 41 | jaxlib = {version = "0.4.7", optional = true} 42 | flax = {version = "0.6.8", optional = true} 43 | optuna = {version = "^3.0.1", optional = true} 44 
| optuna-dashboard = {version = "^0.7.2", optional = true} 45 | envpool = {version = "^0.6.4", optional = true} 46 | PettingZoo = {version = "1.18.1", optional = true} 47 | SuperSuit = {version = "3.4.0", optional = true} 48 | multi-agent-ale-py = {version = "0.1.11", optional = true} 49 | boto3 = {version = "^1.24.70", optional = true} 50 | awscli = {version = "^1.31.0", optional = true} 51 | shimmy = {version = ">=1.1.0", optional = true} 52 | dm-control = {version = ">=1.0.10", optional = true} 53 | h5py = {version = ">=3.7.0", optional = true} 54 | optax = {version = "0.1.4", optional = true} 55 | chex = {version = "0.1.5", optional = true} 56 | numpy = ">=1.21.6" 57 | 58 | [tool.poetry.group.dev.dependencies] 59 | pre-commit = "^2.20.0" 60 | 61 | [build-system] 62 | requires = ["poetry-core"] 63 | build-backend = "poetry.core.masonry.api" 64 | 65 | [tool.poetry.extras] 66 | atari = ["ale-py", "AutoROM", "opencv-python", "shimmy"] 67 | procgen = ["procgen"] 68 | plot = ["pandas", "seaborn"] 69 | pytest = ["pytest"] 70 | mujoco = ["mujoco", "imageio"] 71 | jax = ["jax", "jaxlib", "flax"] 72 | docs = ["mkdocs-material", "markdown-include", "openrlbenchmark"] 73 | envpool = ["envpool"] 74 | optuna = ["optuna", "optuna-dashboard"] 75 | pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"] 76 | cloud = ["boto3", "awscli"] 77 | dm_control = ["shimmy", "mujoco", "dm-control", "h5py"] 78 | 79 | # dependencies for algorithm variant (useful when you want to run a specific algorithm) 80 | dqn = [] 81 | dqn_atari = ["ale-py", "AutoROM", "opencv-python"] 82 | dqn_jax = ["jax", "jaxlib", "flax"] 83 | dqn_atari_jax = [ 84 | "ale-py", "AutoROM", "opencv-python", # atari 85 | "jax", "jaxlib", "flax" # jax 86 | ] 87 | c51 = [] 88 | c51_atari = ["ale-py", "AutoROM", "opencv-python"] 89 | c51_jax = ["jax", "jaxlib", "flax"] 90 | c51_atari_jax = [ 91 | "ale-py", "AutoROM", "opencv-python", # atari 92 | "jax", "jaxlib", "flax" # jax 93 | ] 94 | 
ppo_atari_envpool_xla_jax_scan = [ 95 | "ale-py", "AutoROM", "opencv-python", # atari 96 | "jax", "jaxlib", "flax", # jax 97 | "envpool", # envpool 98 | ] 99 | qdagger_dqn_atari_impalacnn = [ 100 | "ale-py", "AutoROM", "opencv-python" 101 | ] 102 | qdagger_dqn_atari_jax_impalacnn = [ 103 | "ale-py", "AutoROM", "opencv-python", # atari 104 | "jax", "jaxlib", "flax", # jax 105 | ] 106 | -------------------------------------------------------------------------------- /tests/test_atari.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_ppo(): 5 | subprocess.run( 6 | "python cleanrl/ppo_atari.py --num-envs 1 --num-steps 64 --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_ppo_lstm(): 13 | subprocess.run( 14 | "python cleanrl/ppo_atari_lstm.py --num-envs 4 --num-steps 64 --total-timesteps 256", 15 | shell=True, 16 | check=True, 17 | ) 18 | -------------------------------------------------------------------------------- /tests/test_atari_gymnasium.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_dqn(): 5 | subprocess.run( 6 | "python cleanrl/dqn_atari.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_dqn_eval(): 13 | subprocess.run( 14 | "python cleanrl/dqn_atari.py --save-model --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 15 | shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_qdagger_dqn_atari_impalacnn(): 21 | subprocess.run( 22 | "python cleanrl/qdagger_dqn_atari_impalacnn.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4 --teacher-steps 16 --offline-steps 16 --teacher-eval-episodes 1", 23 | shell=True, 24 | check=True, 25 | ) 26 | 27 | 28 | def test_qdagger_dqn_atari_impalacnn_eval(): 29 | subprocess.run( 30 | "python 
cleanrl/qdagger_dqn_atari_impalacnn.py --save-model --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4 --teacher-steps 16 --offline-steps 16 --teacher-eval-episodes 1", 31 | shell=True, 32 | check=True, 33 | ) 34 | 35 | 36 | def test_c51_atari(): 37 | subprocess.run( 38 | "python cleanrl/c51_atari.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 39 | shell=True, 40 | check=True, 41 | ) 42 | 43 | 44 | def test_c51_atari_eval(): 45 | subprocess.run( 46 | "python cleanrl/c51_atari.py --save-model --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 47 | shell=True, 48 | check=True, 49 | ) 50 | 51 | 52 | def test_sac(): 53 | subprocess.run( 54 | "python cleanrl/sac_atari.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 55 | shell=True, 56 | check=True, 57 | ) 58 | -------------------------------------------------------------------------------- /tests/test_atari_jax_gymnasium.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_dqn_jax(): 5 | subprocess.run( 6 | "python cleanrl/dqn_atari_jax.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_dqn_jax_eval(): 13 | subprocess.run( 14 | "python cleanrl/dqn_atari_jax.py --save-model --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 15 | shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_qdagger_dqn_atari_jax_impalacnn(): 21 | subprocess.run( 22 | "python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4 --teacher-steps 16 --offline-steps 16 --teacher-eval-episodes 1", 23 | shell=True, 24 | check=True, 25 | ) 26 | 27 | 28 | def test_qdagger_dqn_atari_jax_impalacnn_eval(): 29 | subprocess.run( 30 | "python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --save-model 
--learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4 --teacher-steps 16 --offline-steps 16 --teacher-eval-episodes 1", 31 | shell=True, 32 | check=True, 33 | ) 34 | 35 | 36 | def test_c51_atari_jax(): 37 | subprocess.run( 38 | "python cleanrl/c51_atari_jax.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 39 | shell=True, 40 | check=True, 41 | ) 42 | 43 | 44 | def test_c51_atari_jax_eval(): 45 | subprocess.run( 46 | "python cleanrl/c51_atari_jax.py --save-model --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4", 47 | shell=True, 48 | check=True, 49 | ) 50 | -------------------------------------------------------------------------------- /tests/test_atari_multigpu.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_ppo_multigpu(): 5 | subprocess.run( 6 | "torchrun --standalone --nnodes=1 --nproc_per_node=2 cleanrl/ppo_atari_multigpu.py --num-envs 8 --num-steps 32 --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | -------------------------------------------------------------------------------- /tests/test_classic_control.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_ppo(): 5 | subprocess.run( 6 | "python cleanrl/ppo.py --num-envs 1 --num-steps 64 --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | -------------------------------------------------------------------------------- /tests/test_classic_control_gymnasium.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_dqn(): 5 | subprocess.run( 6 | "python cleanrl/dqn.py --learning-starts 200 --total-timesteps 205", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_c51(): 13 | subprocess.run( 14 | "python cleanrl/c51.py --learning-starts 200 --total-timesteps 205", 15 | 
shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_c51_eval(): 21 | subprocess.run( 22 | "python cleanrl/c51.py --save-model --learning-starts 200 --total-timesteps 205", 23 | shell=True, 24 | check=True, 25 | ) 26 | -------------------------------------------------------------------------------- /tests/test_classic_control_jax_gymnasium.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_dqn_jax(): 5 | subprocess.run( 6 | "python cleanrl/dqn_jax.py --learning-starts 200 --total-timesteps 205", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_c51_jax(): 13 | subprocess.run( 14 | "python cleanrl/c51_jax.py --learning-starts 200 --total-timesteps 205", 15 | shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_c51_jax_eval(): 21 | subprocess.run( 22 | "python cleanrl/c51_jax.py --save-model --learning-starts 200 --total-timesteps 205", 23 | shell=True, 24 | check=True, 25 | ) 26 | -------------------------------------------------------------------------------- /tests/test_enjoy.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_dqn(): 5 | subprocess.run( 6 | "python enjoy.py --exp-name dqn --env CartPole-v1 --eval-episodes 1", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_dqn_atari(): 13 | subprocess.run( 14 | "python enjoy.py --exp-name dqn_atari --env BreakoutNoFrameskip-v4 --eval-episodes 1", 15 | shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_dqn_jax(): 21 | subprocess.run( 22 | "python enjoy.py --exp-name dqn_jax --env CartPole-v1 --eval-episodes 1", 23 | shell=True, 24 | check=True, 25 | ) 26 | -------------------------------------------------------------------------------- /tests/test_envpool.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_ppo_atari_envpool(): 5 | 
subprocess.run( 6 | "python cleanrl/ppo_atari_envpool.py --num-envs 8 --num-steps 32 --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_ppo_rnd_envpool(): 13 | subprocess.run( 14 | "python cleanrl/ppo_rnd_envpool.py --num-envs 8 --num-steps 32 --num-iterations-obs-norm-init 1 --total-timesteps 256", 15 | shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_ppo_atari_envpool_xla_jax(): 21 | subprocess.run( 22 | "python cleanrl/ppo_atari_envpool_xla_jax.py --num-envs 8 --num-steps 6 --update-epochs 1 --num-minibatches 1 --total-timesteps 256", 23 | shell=True, 24 | check=True, 25 | ) 26 | 27 | 28 | def test_ppo_atari_envpool_xla_jax_scan(): 29 | subprocess.run( 30 | "python cleanrl/ppo_atari_envpool_xla_jax_scan.py --num-envs 8 --num-steps 6 --update-epochs 1 --num-minibatches 1 --total-timesteps 256", 31 | shell=True, 32 | check=True, 33 | ) 34 | 35 | 36 | def test_ppo_atari_envpool_xla_jax_scan_eval(): 37 | subprocess.run( 38 | "python cleanrl/ppo_atari_envpool_xla_jax_scan.py --save-model --num-envs 8 --num-steps 6 --update-epochs 1 --num-minibatches 1 --total-timesteps 256", 39 | shell=True, 40 | check=True, 41 | ) 42 | -------------------------------------------------------------------------------- /tests/test_jax_compute_gae.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from functools import partial 3 | from typing import Callable 4 | 5 | import flax 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | 10 | 11 | def test_compute_gae(): 12 | @flax.struct.dataclass 13 | class Storage: 14 | dones: jnp.array 15 | values: jnp.array 16 | advantages: jnp.array 17 | returns: jnp.array 18 | rewards: jnp.array 19 | 20 | def compute_gae_once(carry, inp, gamma, gae_lambda): 21 | advantages = carry 22 | nextdone, nextvalues, curvalues, reward = inp 23 | nextnonterminal = 1.0 - nextdone 24 | 25 | delta = reward + gamma * nextvalues * 
nextnonterminal - curvalues 26 | advantages = delta + gamma * gae_lambda * nextnonterminal * advantages 27 | return advantages, advantages 28 | 29 | def compute_gae_scan( 30 | next_done: np.ndarray, 31 | next_value: np.ndarray, 32 | storage: Storage, 33 | num_envs: int, 34 | compute_gae_once_fn: Callable, 35 | ): 36 | advantages = jnp.zeros((num_envs,)) 37 | dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0) 38 | values = jnp.concatenate([storage.values, next_value[None, :]], axis=0) 39 | _, advantages = jax.lax.scan( 40 | compute_gae_once_fn, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True 41 | ) 42 | storage = storage.replace( 43 | advantages=advantages, 44 | returns=advantages + storage.values, 45 | ) 46 | return storage 47 | 48 | def compute_gae_python_loop( 49 | next_done: np.ndarray, next_value: np.ndarray, storage: Storage, num_steps: int, gamma: float, gae_lambda: float 50 | ): 51 | storage = storage.replace(advantages=storage.advantages.at[:].set(0.0)) 52 | lastgaelam = 0 53 | for t in reversed(range(num_steps)): 54 | if t == num_steps - 1: 55 | nextnonterminal = 1.0 - next_done 56 | nextvalues = next_value 57 | else: 58 | nextnonterminal = 1.0 - storage.dones[t + 1] 59 | nextvalues = storage.values[t + 1] 60 | delta = storage.rewards[t] + gamma * nextvalues * nextnonterminal - storage.values[t] 61 | lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam 62 | storage = storage.replace(advantages=storage.advantages.at[t].set(lastgaelam)) 63 | storage = storage.replace(returns=storage.advantages + storage.values) 64 | return storage 65 | 66 | num_steps = 123 67 | num_envs = 7 68 | gamma = 0.99 69 | gae_lambda = 0.95 70 | seed = 42 71 | compute_gae_once_fn = partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda) 72 | compute_gae_scan = jax.jit(partial(compute_gae_scan, num_envs=num_envs, compute_gae_once_fn=compute_gae_once_fn)) 73 | compute_gae_python_loop = jax.jit( 74 | 
partial(compute_gae_python_loop, num_steps=num_steps, gamma=gamma, gae_lambda=gae_lambda) 75 | ) 76 | key = jax.random.PRNGKey(seed) 77 | key, *k = jax.random.split(key, 6) 78 | storage1 = Storage( 79 | dones=jax.random.randint(k[0], (num_steps, num_envs), 0, 2), 80 | values=jax.random.uniform(k[1], (num_steps, num_envs)), 81 | advantages=jnp.zeros((num_steps, num_envs)), 82 | returns=jnp.zeros((num_steps, num_envs)), 83 | rewards=jax.random.uniform(k[2], (num_steps, num_envs), minval=-1, maxval=1), 84 | ) 85 | storage2 = deepcopy(storage1) 86 | next_value = jax.random.uniform(k[3], (num_envs,)) 87 | next_done = jax.random.randint(k[4], (num_envs,), 0, 2) 88 | storage1 = compute_gae_scan(next_done, next_value, storage1) 89 | storage2 = compute_gae_python_loop(next_done, next_value, storage2) 90 | assert (storage1.advantages == storage2.advantages).all() 91 | assert (storage1.returns == storage2.returns).all() 92 | -------------------------------------------------------------------------------- /tests/test_mujoco.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_mujoco(): 5 | """ 6 | Test mujoco 7 | """ 8 | subprocess.run( 9 | "python cleanrl/ddpg_continuous_action.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", 10 | shell=True, 11 | check=True, 12 | ) 13 | subprocess.run( 14 | "python cleanrl/ddpg_continuous_action_jax.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", 15 | shell=True, 16 | check=True, 17 | ) 18 | subprocess.run( 19 | "python cleanrl/td3_continuous_action.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", 20 | shell=True, 21 | check=True, 22 | ) 23 | subprocess.run( 24 | "python cleanrl/td3_continuous_action_jax.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", 25 | shell=True, 26 | check=True, 27 | ) 28 | subprocess.run( 29 | "python 
cleanrl/sac_continuous_action.py --env-id Hopper-v4 --batch-size 128 --total-timesteps 135", 30 | shell=True, 31 | check=True, 32 | ) 33 | subprocess.run( 34 | "python cleanrl/ppo_continuous_action.py --env-id Hopper-v4 --num-envs 1 --num-steps 64 --total-timesteps 128", 35 | shell=True, 36 | check=True, 37 | ) 38 | subprocess.run( 39 | "python cleanrl/ppo_continuous_action.py --env-id dm_control/cartpole-balance-v0 --num-envs 1 --num-steps 64 --total-timesteps 128", 40 | shell=True, 41 | check=True, 42 | ) 43 | subprocess.run( 44 | "python cleanrl/rpo_continuous_action.py --env-id Hopper-v4 --num-envs 1 --num-steps 64 --total-timesteps 128", 45 | shell=True, 46 | check=True, 47 | ) 48 | subprocess.run( 49 | "python cleanrl/rpo_continuous_action.py --env-id dm_control/cartpole-balance-v0 --num-envs 1 --num-steps 64 --total-timesteps 128", 50 | shell=True, 51 | check=True, 52 | ) 53 | 54 | 55 | def test_mujoco_eval(): 56 | """ 57 | Test mujoco_eval 58 | """ 59 | subprocess.run( 60 | "python cleanrl/ddpg_continuous_action.py --save-model --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", 61 | shell=True, 62 | check=True, 63 | ) 64 | subprocess.run( 65 | "python cleanrl/ddpg_continuous_action_jax.py --save-model --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105", 66 | shell=True, 67 | check=True, 68 | ) 69 | -------------------------------------------------------------------------------- /tests/test_pettingzoo_ma_atari.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_ppo(): 5 | subprocess.run( 6 | "python cleanrl/ppo_pettingzoo_ma_atari.py --num-steps 32 --num-envs 6 --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | -------------------------------------------------------------------------------- /tests/test_procgen.py: -------------------------------------------------------------------------------- 1 | import 
subprocess 2 | 3 | 4 | def test_ppo(): 5 | subprocess.run( 6 | "python cleanrl/ppo_procgen.py --num-envs 1 --num-steps 64 --total-timesteps 256 --num-minibatches 2", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_ppg(): 13 | subprocess.run( 14 | "python cleanrl/ppg_procgen.py --num-envs 1 --num-steps 64 --total-timesteps 256 --num-minibatches 2 --n-iteration 1", 15 | shell=True, 16 | check=True, 17 | ) 18 | -------------------------------------------------------------------------------- /tests/test_tuner.py: -------------------------------------------------------------------------------- 1 | import optuna 2 | 3 | from cleanrl_utils.tuner import Tuner 4 | 5 | 6 | def test_tuner(): 7 | tuner = Tuner( 8 | script="cleanrl/ppo.py", 9 | metric="charts/episodic_return", 10 | metric_last_n_average_window=50, 11 | direction="maximize", 12 | target_scores={ 13 | "CartPole-v1": [0, 500], 14 | "Acrobot-v1": [-500, 0], 15 | }, 16 | params_fn=lambda trial: { 17 | "learning-rate": trial.suggest_float("learning-rate", 0.0003, 0.003, log=True), 18 | "num-minibatches": trial.suggest_categorical("num-minibatches", [1, 2, 4]), 19 | "update-epochs": trial.suggest_categorical("update-epochs", [1, 2, 4]), 20 | "num-steps": trial.suggest_categorical("num-steps", [1200]), 21 | "vf-coef": trial.suggest_float("vf-coef", 0, 5), 22 | "max-grad-norm": trial.suggest_float("max-grad-norm", 0, 5), 23 | "total-timesteps": 1200, 24 | "num-envs": 1, 25 | }, 26 | pruner=optuna.pruners.MedianPruner(n_startup_trials=5), 27 | sampler=optuna.samplers.TPESampler(), 28 | # wandb_kwargs={"project": "cleanrl"}, 29 | ) 30 | tuner.tune( 31 | num_trials=1, 32 | num_seeds=1, 33 | ) 34 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_submit_exp_no_build(): 5 | subprocess.run( 6 | "poetry run python -m 
cleanrl_utils.submit_exp --docker-tag vwxyzjn/cleanrl:latest --wandb-key xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 7 | shell=True, 8 | check=True, 9 | ) 10 | -------------------------------------------------------------------------------- /tuner_example.py: -------------------------------------------------------------------------------- 1 | import optuna 2 | 3 | from cleanrl_utils.tuner import Tuner 4 | 5 | tuner = Tuner( 6 | script="cleanrl/ppo.py", 7 | metric="charts/episodic_return", 8 | metric_last_n_average_window=50, 9 | direction="maximize", 10 | aggregation_type="average", 11 | target_scores={ 12 | "CartPole-v1": [0, 500], 13 | "Acrobot-v1": [-500, 0], 14 | }, 15 | params_fn=lambda trial: { 16 | "learning-rate": trial.suggest_float("learning-rate", 0.0003, 0.003, log=True), 17 | "num-minibatches": trial.suggest_categorical("num-minibatches", [1, 2, 4]), 18 | "update-epochs": trial.suggest_categorical("update-epochs", [1, 2, 4, 8]), 19 | "num-steps": trial.suggest_categorical("num-steps", [5, 16, 32, 64, 128]), 20 | "vf-coef": trial.suggest_float("vf-coef", 0, 5), 21 | "max-grad-norm": trial.suggest_float("max-grad-norm", 0, 5), 22 | "total-timesteps": 100000, 23 | "num-envs": 16, 24 | }, 25 | pruner=optuna.pruners.MedianPruner(n_startup_trials=5), 26 | sampler=optuna.samplers.TPESampler(), 27 | ) 28 | tuner.tune( 29 | num_trials=100, 30 | num_seeds=3, 31 | ) 32 | --------------------------------------------------------------------------------