├── CHANGELOG.md
├── doc
│   └── artifacts
│       ├── SAC.png
│       ├── TD3.png
│       ├── ppo.png
│       ├── sac_gpu.png
│       ├── ppo_speed.png
│       ├── sac_speed.png
│       └── td3_speed.png
├── requirements
│   ├── requirements.txt
│   ├── requirements-envpool.txt
│   ├── requirements-jax.txt
│   ├── requirements-mujoco.txt
│   └── requirements-atari.txt
├── tests
│   ├── test_sac_continuous.py
│   ├── test_ppo_continuous.py
│   ├── test_dqn.py
│   ├── test_td3_continuous.py
│   └── test_atari.py
├── .gitpod.Dockerfile
├── .gitpod.yml
├── run.sh
├── CONTRIBUTING.md
├── .gitignore
├── .pre-commit-config.yaml
├── mkdocs.yml
├── CODE_OF_CONDUCT.md
├── leanrl
│   ├── dqn.py
│   ├── dqn_jax.py
│   ├── dqn_torchcompile.py
│   ├── td3_continuous_action.py
│   ├── td3_continuous_action_jax.py
│   ├── sac_continuous_action.py
│   ├── ppo_continuous_action.py
│   ├── td3_continuous_action_torchcompile.py
│   ├── ppo_atari_envpool.py
│   ├── sac_continuous_action_torchcompile.py
│   ├── ppo_continuous_action_torchcompile.py
│   └── ppo_atari_envpool_torchcompile.py
├── README.md
└── LICENSE

/CHANGELOG.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/doc/artifacts/SAC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/LeanRL/HEAD/doc/artifacts/SAC.png
--------------------------------------------------------------------------------
/doc/artifacts/TD3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/LeanRL/HEAD/doc/artifacts/TD3.png
--------------------------------------------------------------------------------
/doc/artifacts/ppo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/LeanRL/HEAD/doc/artifacts/ppo.png
--------------------------------------------------------------------------------
/doc/artifacts/sac_gpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/LeanRL/HEAD/doc/artifacts/sac_gpu.png
--------------------------------------------------------------------------------
/doc/artifacts/ppo_speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/LeanRL/HEAD/doc/artifacts/ppo_speed.png
--------------------------------------------------------------------------------
/doc/artifacts/sac_speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/LeanRL/HEAD/doc/artifacts/sac_speed.png
--------------------------------------------------------------------------------
/doc/artifacts/td3_speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/meta-pytorch/LeanRL/HEAD/doc/artifacts/td3_speed.png
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | gymnasium<1.0.0
2 | jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
3 | matplotlib
4 | moviepy
5 | numpy<2.0
6 | pandas
7 | protobuf
8 | pygame
9 | stable-baselines3
10 | tqdm
11 | wandb
12 | torchrl
13 | tensordict
14 | tyro
15 |
--------------------------------------------------------------------------------
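The requirements files above and below are plain pip requirement sets exported per extra. A minimal installation sketch, assuming a plain pip workflow (the Gitpod image instead uses poetry) and the paths shown in the tree:

```bash
# base dependencies (gymnasium, stable-baselines3, torchrl, tensordict, wandb, tyro, ...)
pip install -r requirements/requirements.txt
# optional extras live in sibling files, e.g. MuJoCo or Atari support
pip install -r requirements/requirements-mujoco.txt
pip install -r requirements/requirements-atari.txt
```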
/requirements/requirements-envpool.txt: -------------------------------------------------------------------------------- 1 | gym<0.26 2 | envpool 3 | jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" 4 | matplotlib 5 | moviepy 6 | numpy<2.0 7 | pandas 8 | protobuf 9 | pygame 10 | stable-baselines3 11 | tensordict 12 | torchrl 13 | tqdm 14 | tyro 15 | wandb 16 | -------------------------------------------------------------------------------- /requirements/requirements-jax.txt: -------------------------------------------------------------------------------- 1 | flax==0.6.8 2 | gym 3 | gymnasium<1.0.0 4 | jax-jumpy==1.0.0 5 | jax-jumpy==1.0.0 6 | jax[cuda]==0.4.8 7 | matplotlib 8 | moviepy 9 | numpy<2.0 10 | pandas 11 | protobuf 12 | pygame 13 | stable-baselines3 14 | tensordict 15 | torchrl 16 | tqdm 17 | tyro 18 | wandb -------------------------------------------------------------------------------- /requirements/requirements-mujoco.txt: -------------------------------------------------------------------------------- 1 | gym 2 | gymnasium[mujoco]<1.0.0 3 | jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" 4 | matplotlib 5 | moviepy 6 | numpy<2.0 7 | pandas 8 | protobuf 9 | pygame 10 | stable-baselines3 11 | tqdm 12 | wandb 13 | torchrl 14 | tensordict 15 | tyro 16 | -------------------------------------------------------------------------------- /requirements/requirements-atari.txt: -------------------------------------------------------------------------------- 1 | gymnasium[atari,accept-rom-license]<1.0.0 2 | jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" 3 | matplotlib 4 | moviepy 5 | numpy<2.0 6 | pandas 7 | protobuf 8 | pygame 9 | stable-baselines3 10 | tqdm 11 | wandb 12 | torchrl 13 | tensordict 14 | tyro 15 | -------------------------------------------------------------------------------- /tests/test_sac_continuous.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_sac_continuous_action(): 5 | subprocess.run( 6 | "python leanrl/sac_continuous_action.py --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_sac_continuous_action_torchcompile(): 13 | subprocess.run( 14 | "python leanrl/sac_continuous_action_torchcompile.py --total-timesteps 256 --compile --cudagraphs", 15 | shell=True, 16 | check=True, 17 | ) 18 | -------------------------------------------------------------------------------- /tests/test_ppo_continuous.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_ppo_continuous_action(): 5 | subprocess.run( 6 | "python leanrl/ppo_continuous_action.py --num-envs 1 --num-steps 64 --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_ppo_continuous_action_torchcompile(): 13 | subprocess.run( 14 | "python leanrl/ppo_continuous_action_torchcompile.py --num-envs 1 --num-steps 64 --total-timesteps 256 --compile --cudagraphs", 15 | shell=True, 16 | check=True, 17 | ) 18 | -------------------------------------------------------------------------------- /tests/test_dqn.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_dqn(): 5 | subprocess.run( 6 | "python leanrl/dqn.py --num-envs 1 --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_dqn_jax(): 13 | subprocess.run( 14 | "python leanrl/dqn_jax.py 
--num-envs 1 --total-timesteps 256", 15 | shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_dqn_torchcompile(): 21 | subprocess.run( 22 | "python leanrl/dqn_torchcompile.py --num-envs 1 --total-timesteps 256 --compile --cudagraphs", 23 | shell=True, 24 | check=True, 25 | ) 26 | -------------------------------------------------------------------------------- /tests/test_td3_continuous.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_td3_continuous_action(): 5 | subprocess.run( 6 | "python leanrl/td3_continuous_action.py --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_td3_continuous_action_jax(): 13 | subprocess.run( 14 | "python leanrl/td3_continuous_action_jax.py --total-timesteps 256", 15 | shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_td3_continuous_action_torchcompile(): 21 | subprocess.run( 22 | "python leanrl/td3_continuous_action_torchcompile.py --total-timesteps 256 --compile --cudagraphs", 23 | shell=True, 24 | check=True, 25 | ) 26 | -------------------------------------------------------------------------------- /.gitpod.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gitpod/workspace-full-vnc:latest 2 | USER gitpod 3 | RUN if ! grep -q "export PIP_USER=no" "$HOME/.bashrc"; then printf '%s\n' "export PIP_USER=no" >> "$HOME/.bashrc"; fi 4 | 5 | # install ubuntu dependencies 6 | ENV DEBIAN_FRONTEND=noninteractive 7 | RUN sudo apt-get update && \ 8 | sudo apt-get -y install xvfb ffmpeg git build-essential python-opengl 9 | 10 | # install python dependencies 11 | RUN mkdir cleanrl_utils && touch cleanrl_utils/__init__.py 12 | RUN pip install poetry --upgrade 13 | RUN poetry config virtualenvs.in-project true 14 | 15 | # install mujoco_py 16 | RUN sudo apt-get -y install wget unzip software-properties-common \ 17 | libgl1-mesa-dev \ 18 | libgl1-mesa-glx \ 19 | libglew-dev \ 20 | libosmesa6-dev patchelf 21 | -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | image: 2 | file: .gitpod.Dockerfile 3 | 4 | tasks: 5 | - init: poetry install 6 | 7 | # vscode: 8 | # extensions: 9 | # - learnpack.learnpack-vscode 10 | 11 | github: 12 | prebuilds: 13 | # enable for the master/default branch (defaults to true) 14 | master: true 15 | # enable for all branches in this repo (defaults to false) 16 | branches: true 17 | # enable for pull requests coming from this repo (defaults to true) 18 | pullRequests: true 19 | # enable for pull requests coming from forks (defaults to false) 20 | pullRequestsFromForks: true 21 | # add a "Review in Gitpod" button as a comment to pull requests (defaults to true) 22 | addComment: false 23 | # add a "Review in Gitpod" button to pull requests (defaults to false) 24 | addBadge: false 25 | # add a label once the prebuild is ready to pull requests (defaults to false) 26 | addLabel: prebuilt-in-gitpod 27 | -------------------------------------------------------------------------------- /tests/test_atari.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def test_ppo(): 5 | subprocess.run( 6 | "python leanrl/ppo_atari.py --num-envs 1 --num-steps 64 --total-timesteps 256", 7 | shell=True, 8 | check=True, 9 | ) 10 | 11 | 12 | def test_ppo_envpool(): 13 | subprocess.run( 14 | "python 
leanrl/ppo_atari_envpool.py --num-envs 1 --num-steps 64 --total-timesteps 256", 15 | shell=True, 16 | check=True, 17 | ) 18 | 19 | 20 | def test_ppo_atari_envpool_torchcompile(): 21 | subprocess.run( 22 | "python leanrl/ppo_atari_envpool_torchcompile.py --num-envs 1 --num-steps 64 --total-timesteps 256 --compile --cudagraphs", 23 | shell=True, 24 | check=True, 25 | ) 26 | 27 | 28 | def test_ppo_atari_envpool_xla_jax(): 29 | subprocess.run( 30 | "python leanrl/ppo_atari_envpool_xla_jax.py --num-envs 1 --num-steps 64 --total-timesteps 256", 31 | shell=True, 32 | check=True, 33 | ) 34 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Execute scripts with different seeds and additional arguments for torchcompile scripts 3 | scripts=( 4 | leanrl/ppo_continuous_action.py 5 | leanrl/ppo_continuous_action_torchcompile.py 6 | leanrl/dqn.py 7 | leanrl/dqn_jax.py 8 | leanrl/dqn_torchcompile.py 9 | leanrl/td3_continuous_action_jax.py 10 | leanrl/td3_continuous_action.py 11 | leanrl/td3_continuous_action_torchcompile.py 12 | leanrl/ppo_atari_envpool.py 13 | leanrl/ppo_atari_envpool_torchcompile.py 14 | leanrl/ppo_atari_envpool_xla_jax.py 15 | leanrl/sac_continuous_action.py 16 | leanrl/sac_continuous_action_torchcompile.py 17 | ) 18 | for script in "${scripts[@]}"; do 19 | for seed in 21 31 41; do 20 | if [[ $script == *_torchcompile.py ]]; then 21 | python $script --seed=$seed --cudagraphs 22 | python $script --seed=$seed --cudagraphs --compile 23 | python $script --seed=$seed --compile 24 | python $script --seed=$seed 25 | else 26 | python $script --seed=$seed 27 | fi 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We welcome contribution from the community. 4 | The project is - as is CleanRL - under MIT license which is a very permissive license. 5 | 6 | 7 | ## Getting Started with Contributions 8 | To contribute to this project, please follow these steps: 9 | 10 | ### 1. Clone and Fork the Repository 11 | First, clone the repository using the following command: 12 | ```bash 13 | git clone https://github.com/meta-pytorch/leanrl.git 14 | ``` 15 | 16 | Then, fork the repository by clicking the "Fork" button on the top-right corner of the GitHub page. 17 | This will create a copy of the repository in your own account. 18 | Add the fork to your local list of remote forks: 19 | ```bash 20 | git remote add https://github.com//leanrl.git 21 | ``` 22 | 23 | ### 2. Create a New Branch 24 | Create a new branch for your changes using the following command: 25 | ```bash 26 | git checkout -b [branch-name] 27 | ``` 28 | Choose a descriptive name for your branch that indicates the type of change you're making (e.g., `fix-bug-123`, `add-feature-xyz`, etc.). 29 | ### 3. Make Changes and Commit 30 | Make your changes to the codebase, then add them to the staging area using: 31 | ```bash 32 | git add 33 | ``` 34 | 35 | Commit your changes with a clear and concise commit message: 36 | ```bash 37 | git commit -m "[commit-message]" 38 | ``` 39 | Follow standard commit message guidelines, such as starting with a verb (e.g., "Fix", "Add", "Update") and keeping it short. 40 | 41 | ### 4. 
Push Your Changes 42 | Push your changes to your forked repository using: 43 | ```bash 44 | git push --set-upstream 45 | ``` 46 | 47 | ### 5. Create a Pull Request 48 | Finally, create a pull request to merge your changes into the main repository. Go to your forked repository on GitHub, click on the "New pull request" button, and select the branch you just pushed. Fill out the pull request form with a clear description of your changes and submit it. 49 | 50 | We'll review your pull request and provide feedback or merge it into the main repository. 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | slurm 2 | .aim 3 | runs 4 | balance_bot.xml 5 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/examples 6 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/isaacgym 7 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/LICENSE.txt 8 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/rlgpu_conda_env.yml 9 | cleanrl/ppo_continuous_action_isaacgym/isaacgym/setup.py 10 | 11 | IsaacGym_Preview_3_Package.tar.gz 12 | IsaacGym_Preview_4_Package.tar.gz 13 | cleanrl_hpopt.db 14 | debug.sh.docker.sh 15 | docker_cache 16 | rl-video-*.mp4 17 | rl-video-*.json 18 | cleanrl_utils/charts_episode_reward 19 | tutorials 20 | .DS_Store 21 | *.tfevents.* 22 | wandb 23 | openaigym.* 24 | videos/* 25 | cleanrl/videos/* 26 | benchmark/**/*.svg 27 | benchmark/**/*.pkl 28 | mjkey.txt 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib64/ 46 | parts/ 47 | sdist/ 48 | var/ 49 | wheels/ 50 | *.egg-info/ 51 | .installed.cfg 52 | *.egg 53 | MANIFEST 54 | 55 | # PyInstaller 56 | # Usually these files are written by a python script from a template 57 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
58 | *.manifest 59 | *.spec 60 | 61 | # Installer logs 62 | pip-log.txt 63 | pip-delete-this-directory.txt 64 | 65 | # Unit test / coverage reports 66 | htmlcov/ 67 | .tox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # pyenv 103 | # .python-version 104 | 105 | # celery beat schedule file 106 | celerybeat-schedule 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/pyupgrade 3 | rev: v2.31.1 4 | hooks: 5 | - id: pyupgrade 6 | args: 7 | - --py37-plus 8 | - repo: https://github.com/PyCQA/isort 9 | rev: 5.12.0 10 | hooks: 11 | - id: isort 12 | args: 13 | - --profile=black 14 | - --skip-glob=wandb/**/* 15 | - --thirdparty=wandb 16 | - repo: https://github.com/myint/autoflake 17 | rev: v1.4 18 | hooks: 19 | - id: autoflake 20 | args: 21 | - -r 22 | - --exclude=wandb 23 | - --in-place 24 | - --remove-unused-variables 25 | - --remove-all-unused-imports 26 | - repo: https://github.com/python/black 27 | rev: 22.3.0 28 | hooks: 29 | - id: black 30 | args: 31 | - --line-length=127 32 | - --exclude=wandb 33 | - repo: https://github.com/codespell-project/codespell 34 | rev: v2.1.0 35 | hooks: 36 | - id: codespell 37 | args: 38 | - --ignore-words-list=nd,reacher,thist,ths,magent,ba 39 | - --skip=docs/css/termynal.css,docs/js/termynal.js,docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb 40 | - repo: https://github.com/python-poetry/poetry 41 | rev: 1.3.2 42 | hooks: 43 | - id: poetry-export 44 | name: poetry-export requirements.txt 45 | args: ["--without-hashes", "-o", "requirements/requirements.txt"] 46 | stages: [manual] 47 | - id: poetry-export 48 | name: poetry-export requirements-atari.txt 49 | args: ["--without-hashes", "-o", "requirements/requirements-atari.txt", "-E", "atari"] 50 | stages: [manual] 51 | - id: poetry-export 52 | name: poetry-export requirements-mujoco.txt 53 | args: ["--without-hashes", "-o", "requirements/requirements-mujoco.txt", "-E", "mujoco"] 54 | stages: [manual] 55 | - id: poetry-export 56 | name: poetry-export requirements-dm_control.txt 57 | args: ["--without-hashes", "-o", "requirements/requirements-dm_control.txt", "-E", "dm_control"] 58 | stages: [manual] 59 | - id: poetry-export 60 | name: poetry-export requirements-procgen.txt 61 | args: ["--without-hashes", "-o", "requirements/requirements-procgen.txt", "-E", "procgen"] 62 | stages: [manual] 63 | - id: poetry-export 64 | name: poetry-export requirements-envpool.txt 65 | args: ["--without-hashes", "-o", "requirements/requirements-envpool.txt", "-E", "envpool"] 66 | stages: [manual] 67 | - 
id: poetry-export 68 | name: poetry-export requirements-pettingzoo.txt 69 | args: ["--without-hashes", "-o", "requirements/requirements-pettingzoo.txt", "-E", "pettingzoo"] 70 | stages: [manual] 71 | - id: poetry-export 72 | name: poetry-export requirements-jax.txt 73 | args: ["--without-hashes", "-o", "requirements/requirements-jax.txt", "-E", "jax"] 74 | stages: [manual] 75 | - id: poetry-export 76 | name: poetry-export requirements-optuna.txt 77 | args: ["--without-hashes", "-o", "requirements/requirements-optuna.txt", "-E", "optuna"] 78 | stages: [manual] 79 | - id: poetry-export 80 | name: poetry-export requirements-docs.txt 81 | args: ["--without-hashes", "-o", "requirements/requirements-docs.txt", "-E", "docs"] 82 | stages: [manual] 83 | - id: poetry-export 84 | name: poetry-export requirements-cloud.txt 85 | args: ["--without-hashes", "-o", "requirements/requirements-cloud.txt", "-E", "cloud"] 86 | stages: [manual] 87 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: CleanRL 2 | theme: 3 | name: material 4 | features: 5 | # - navigation.instant 6 | - navigation.tracking 7 | # - navigation.tabs 8 | # - navigation.tabs.sticky 9 | - navigation.sections 10 | - navigation.expand 11 | - navigation.top 12 | - search.suggest 13 | - search.highlight 14 | palette: 15 | - media: "(prefers-color-scheme: dark)" 16 | scheme: slate 17 | primary: teal 18 | accent: light green 19 | toggle: 20 | icon: material/lightbulb 21 | name: Switch to light mode 22 | - media: "(prefers-color-scheme: light)" 23 | scheme: default 24 | primary: green 25 | accent: deep orange 26 | toggle: 27 | icon: material/lightbulb-outline 28 | name: Switch to dark mode 29 | plugins: 30 | - search 31 | nav: 32 | - Overview: index.md 33 | - Get Started: 34 | - get-started/installation.md 35 | - get-started/basic-usage.md 36 | - get-started/experiment-tracking.md 37 | - get-started/examples.md 38 | - get-started/benchmark-utility.md 39 | - get-started/zoo.md 40 | - RL Algorithms: 41 | - rl-algorithms/overview.md 42 | - rl-algorithms/ppo.md 43 | - rl-algorithms/dqn.md 44 | - rl-algorithms/c51.md 45 | - rl-algorithms/ddpg.md 46 | - rl-algorithms/sac.md 47 | - rl-algorithms/td3.md 48 | - rl-algorithms/ppg.md 49 | - rl-algorithms/ppo-rnd.md 50 | - rl-algorithms/rpo.md 51 | - rl-algorithms/qdagger.md 52 | - Advanced: 53 | - advanced/hyperparameter-tuning.md 54 | - advanced/resume-training.md 55 | - Community: 56 | - contribution.md 57 | - leanrl-supported-papers-projects.md 58 | - Cloud Integration: 59 | - cloud/installation.md 60 | - cloud/submit-experiments.md 61 | #adding git repo 62 | repo_url: https://github.com/vwxyzjn/cleanrl 63 | repo_name: vwxyzjn/leanrl 64 | #markdown_extensions 65 | markdown_extensions: 66 | - pymdownx.superfences 67 | - pymdownx.tabbed: 68 | alternate_style: true 69 | - abbr 70 | - pymdownx.highlight 71 | - pymdownx.inlinehilite 72 | - pymdownx.superfences 73 | - pymdownx.snippets 74 | - admonition 75 | - pymdownx.details 76 | - attr_list 77 | - md_in_html 78 | - footnotes 79 | - markdown_include.include: 80 | base_path: docs 81 | - pymdownx.emoji: 82 | emoji_index: !!python/name:materialx.emoji.twemoji 83 | emoji_generator: !!python/name:materialx.emoji.to_svg 84 | - pymdownx.arithmatex: 85 | generic: true 86 | # - toc: 87 | # permalink: true 88 | # - markdown.extensions.codehilite: 89 | # guess_lang: false 90 | # - admonition 91 | # - codehilite 92 | # - extra 93 | 
# - pymdownx.superfences: 94 | # custom_fences: 95 | # - name: mermaid 96 | # class: mermaid 97 | # format: !!python/name:pymdownx.superfences.fence_code_format '' 98 | # - pymdownx.tabbed 99 | extra_css: 100 | - stylesheets/extra.css 101 | # extra_javascript: 102 | # - js/termynal.js 103 | # - js/custom.js 104 | #footer 105 | extra: 106 | social: 107 | - icon: fontawesome/solid/envelope 108 | link: mailto:costa.huang@outlook.com 109 | - icon: fontawesome/brands/twitter 110 | link: https://twitter.com/vwxyzjn 111 | - icon: fontawesome/brands/github 112 | link: https://github.com/vwxyzjn/cleanrl 113 | copyright: Copyright © 2021, CleanRL. All rights reserved. 114 | extra_javascript: 115 | # - javascripts/mathjax.js 116 | # - https://polyfill.io/v3/polyfill.min.js?features=es6 117 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 118 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /leanrl/dqn.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy 2 | import os 3 | import random 4 | import time 5 | from collections import deque 6 | from dataclasses import dataclass 7 | 8 | import gymnasium as gym 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import tqdm 15 | import tyro 16 | import wandb 17 | from stable_baselines3.common.buffers import ReplayBuffer 18 | 19 | 20 | @dataclass 21 | class Args: 22 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 23 | """the name of this experiment""" 24 | seed: int = 1 25 | """seed of the experiment""" 26 | torch_deterministic: bool = True 27 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 28 | cuda: bool = True 29 | """if toggled, cuda will be enabled by default""" 30 | capture_video: bool = False 31 | """whether to capture videos of the agent performances (check out `videos` folder)""" 32 | 33 | # Algorithm specific arguments 34 | env_id: str = "CartPole-v1" 35 | """the id of the environment""" 36 | total_timesteps: int = 500000 37 | """total timesteps of the experiments""" 38 | learning_rate: float = 2.5e-4 39 | """the learning rate of the optimizer""" 40 | num_envs: int = 1 41 | """the number of parallel game environments""" 42 | buffer_size: int = 10000 43 | """the replay memory buffer size""" 44 | gamma: float = 0.99 45 | """the discount factor gamma""" 46 | tau: float = 1.0 47 | """the target network update rate""" 48 | target_network_frequency: int = 500 49 | """the timesteps it takes to update the target network""" 50 | batch_size: int = 128 51 | """the batch size of sample from the reply memory""" 52 | start_e: float = 1 53 | """the starting epsilon for exploration""" 54 | end_e: float = 0.05 55 | """the ending epsilon for exploration""" 56 | exploration_fraction: float = 0.5 57 | """the fraction of `total-timesteps` it 
takes from start-e to go end-e""" 58 | learning_starts: int = 10000 59 | """timestep to start learning""" 60 | train_frequency: int = 10 61 | """the frequency of training""" 62 | 63 | measure_burnin: int = 3 64 | """Number of burn-in iterations for speed measure.""" 65 | 66 | 67 | def make_env(env_id, seed, idx, capture_video, run_name): 68 | def thunk(): 69 | if capture_video and idx == 0: 70 | env = gym.make(env_id, render_mode="rgb_array") 71 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 72 | else: 73 | env = gym.make(env_id) 74 | env = gym.wrappers.RecordEpisodeStatistics(env) 75 | env.action_space.seed(seed) 76 | 77 | return env 78 | 79 | return thunk 80 | 81 | 82 | # ALGO LOGIC: initialize agent here: 83 | class QNetwork(nn.Module): 84 | def __init__(self, env): 85 | super().__init__() 86 | self.network = nn.Sequential( 87 | nn.Linear(np.array(env.single_observation_space.shape).prod(), 120), 88 | nn.ReLU(), 89 | nn.Linear(120, 84), 90 | nn.ReLU(), 91 | nn.Linear(84, env.single_action_space.n), 92 | ) 93 | 94 | def forward(self, x): 95 | return self.network(x) 96 | 97 | 98 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 99 | slope = (end_e - start_e) / duration 100 | return max(slope * t + start_e, end_e) 101 | 102 | 103 | if __name__ == "__main__": 104 | import stable_baselines3 as sb3 105 | 106 | if sb3.__version__ < "2.0": 107 | raise ValueError( 108 | """Ongoing migration: run the following command to install the new dependencies: 109 | 110 | poetry run pip install "stable_baselines3==2.0.0a1" 111 | """ 112 | ) 113 | args = tyro.cli(Args) 114 | assert args.num_envs == 1, "vectorized envs are not supported at the moment" 115 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" 116 | 117 | wandb.init( 118 | project="dqn", 119 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 120 | config=vars(args), 121 | save_code=True, 122 | ) 123 | 124 | # TRY NOT TO MODIFY: seeding 125 | random.seed(args.seed) 126 | np.random.seed(args.seed) 127 | torch.manual_seed(args.seed) 128 | torch.backends.cudnn.deterministic = args.torch_deterministic 129 | 130 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 131 | 132 | # env setup 133 | envs = gym.vector.SyncVectorEnv( 134 | [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] 135 | ) 136 | assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" 137 | 138 | q_network = QNetwork(envs).to(device) 139 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) 140 | target_network = QNetwork(envs).to(device) 141 | target_network.load_state_dict(q_network.state_dict()) 142 | 143 | rb = ReplayBuffer( 144 | args.buffer_size, 145 | envs.single_observation_space, 146 | envs.single_action_space, 147 | device, 148 | handle_timeout_termination=False, 149 | ) 150 | start_time = None 151 | 152 | # TRY NOT TO MODIFY: start the game 153 | obs, _ = envs.reset(seed=args.seed) 154 | pbar = tqdm.tqdm(range(args.total_timesteps)) 155 | avg_returns = deque(maxlen=20) 156 | 157 | for global_step in pbar: 158 | if global_step == args.learning_starts + args.measure_burnin: 159 | start_time = time.time() 160 | global_step_start = global_step 161 | 162 | # ALGO LOGIC: put action logic here 163 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step) 164 | if random.random() < epsilon: 165 | actions = 
np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 166 | else: 167 | q_values = q_network(torch.Tensor(obs).to(device)) 168 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 169 | 170 | # TRY NOT TO MODIFY: execute the game and log data. 171 | next_obs, rewards, terminations, truncations, infos = envs.step(actions) 172 | 173 | # TRY NOT TO MODIFY: record rewards for plotting purposes 174 | if "final_info" in infos: 175 | for info in infos["final_info"]: 176 | if info and "episode" in info: 177 | avg_returns.append(info["episode"]["r"]) 178 | desc = f"global_step={global_step}, episodic_return={np.array(avg_returns).mean()}" 179 | 180 | # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` 181 | real_next_obs = next_obs.copy() 182 | for idx, trunc in enumerate(truncations): 183 | if trunc: 184 | real_next_obs[idx] = infos["final_observation"][idx] 185 | rb.add(obs, real_next_obs, actions, rewards, terminations, infos) 186 | 187 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 188 | obs = next_obs 189 | 190 | # ALGO LOGIC: training. 191 | if global_step > args.learning_starts: 192 | if global_step % args.train_frequency == 0: 193 | data = rb.sample(args.batch_size) 194 | with torch.no_grad(): 195 | target_max, _ = target_network(data.next_observations).max(dim=1) 196 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten()) 197 | old_val = q_network(data.observations).gather(1, data.actions).squeeze() 198 | loss = F.mse_loss(td_target, old_val) 199 | 200 | # optimize the model 201 | optimizer.zero_grad() 202 | loss.backward() 203 | optimizer.step() 204 | 205 | # update target network 206 | if global_step % args.target_network_frequency == 0: 207 | for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()): 208 | target_network_param.data.copy_( 209 | args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data 210 | ) 211 | 212 | if global_step % 100 == 0 and start_time is not None: 213 | speed = (global_step - global_step_start) / (time.time() - start_time) 214 | pbar.set_description(f"speed: {speed: 4.2f} sps, " + desc) 215 | with torch.no_grad(): 216 | logs = { 217 | "episode_return": torch.tensor(avg_returns).mean(), 218 | "loss": loss.mean(), 219 | "epsilon": epsilon, 220 | } 221 | wandb.log( 222 | { 223 | "speed": speed, 224 | **logs, 225 | }, 226 | step=global_step, 227 | ) 228 | 229 | envs.close() 230 | -------------------------------------------------------------------------------- /leanrl/dqn_jax.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_jaxpy 2 | import os 3 | import random 4 | import time 5 | from collections import deque 6 | from dataclasses import dataclass 7 | 8 | import flax 9 | import flax.linen as nn 10 | import gymnasium as gym 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | import tqdm 16 | import tyro 17 | import wandb 18 | from flax.training.train_state import TrainState 19 | from stable_baselines3.common.buffers import ReplayBuffer 20 | 21 | 22 | @dataclass 23 | class Args: 24 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 25 | """the name of this experiment""" 26 | seed: int = 1 27 | """seed of the experiment""" 28 | capture_video: bool = False 29 | """whether to capture videos of the agent performances (check out `videos` 
folder)""" 30 | 31 | # Algorithm specific arguments 32 | env_id: str = "CartPole-v1" 33 | """the id of the environment""" 34 | total_timesteps: int = 500000 35 | """total timesteps of the experiments""" 36 | learning_rate: float = 2.5e-4 37 | """the learning rate of the optimizer""" 38 | num_envs: int = 1 39 | """the number of parallel game environments""" 40 | buffer_size: int = 10000 41 | """the replay memory buffer size""" 42 | gamma: float = 0.99 43 | """the discount factor gamma""" 44 | tau: float = 1.0 45 | """the target network update rate""" 46 | target_network_frequency: int = 500 47 | """the timesteps it takes to update the target network""" 48 | batch_size: int = 128 49 | """the batch size of sample from the reply memory""" 50 | start_e: float = 1 51 | """the starting epsilon for exploration""" 52 | end_e: float = 0.05 53 | """the ending epsilon for exploration""" 54 | exploration_fraction: float = 0.5 55 | """the fraction of `total-timesteps` it takes from start-e to go end-e""" 56 | learning_starts: int = 10000 57 | """timestep to start learning""" 58 | train_frequency: int = 10 59 | """the frequency of training""" 60 | 61 | measure_burnin: int = 3 62 | """Number of burn-in iterations for speed measure.""" 63 | 64 | 65 | def make_env(env_id, seed, idx, capture_video, run_name): 66 | def thunk(): 67 | if capture_video and idx == 0: 68 | env = gym.make(env_id, render_mode="rgb_array") 69 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 70 | else: 71 | env = gym.make(env_id) 72 | env = gym.wrappers.RecordEpisodeStatistics(env) 73 | env.action_space.seed(seed) 74 | 75 | return env 76 | 77 | return thunk 78 | 79 | 80 | # ALGO LOGIC: initialize agent here: 81 | class QNetwork(nn.Module): 82 | action_dim: int 83 | 84 | @nn.compact 85 | def __call__(self, x: jnp.ndarray): 86 | x = nn.Dense(120)(x) 87 | x = nn.relu(x) 88 | x = nn.Dense(84)(x) 89 | x = nn.relu(x) 90 | x = nn.Dense(self.action_dim)(x) 91 | return x 92 | 93 | 94 | class TrainState(TrainState): 95 | target_params: flax.core.FrozenDict 96 | 97 | 98 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 99 | slope = (end_e - start_e) / duration 100 | return max(slope * t + start_e, end_e) 101 | 102 | 103 | if __name__ == "__main__": 104 | import stable_baselines3 as sb3 105 | 106 | if sb3.__version__ < "2.0": 107 | raise ValueError( 108 | """Ongoing migration: run the following command to install the new dependencies: 109 | 110 | poetry run pip install "stable_baselines3==2.0.0a1" 111 | """ 112 | ) 113 | args = tyro.cli(Args) 114 | assert args.num_envs == 1, "vectorized envs are not supported at the moment" 115 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" 116 | 117 | wandb.init( 118 | project="dqn", 119 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 120 | config=vars(args), 121 | save_code=True, 122 | ) 123 | 124 | # TRY NOT TO MODIFY: seeding 125 | random.seed(args.seed) 126 | np.random.seed(args.seed) 127 | key = jax.random.PRNGKey(args.seed) 128 | key, q_key = jax.random.split(key, 2) 129 | 130 | # env setup 131 | envs = gym.vector.SyncVectorEnv( 132 | [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] 133 | ) 134 | assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" 135 | 136 | obs, _ = envs.reset(seed=args.seed) 137 | q_network = QNetwork(action_dim=envs.single_action_space.n) 138 | q_state = TrainState.create( 139 | apply_fn=q_network.apply, 
140 | params=q_network.init(q_key, obs), 141 | target_params=q_network.init(q_key, obs), 142 | tx=optax.adam(learning_rate=args.learning_rate), 143 | ) 144 | 145 | q_network.apply = jax.jit(q_network.apply) 146 | # This step is not necessary as init called on same observation and key will always lead to same initializations 147 | q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1)) 148 | 149 | rb = ReplayBuffer( 150 | args.buffer_size, 151 | envs.single_observation_space, 152 | envs.single_action_space, 153 | "cpu", 154 | handle_timeout_termination=False, 155 | ) 156 | 157 | @jax.jit 158 | def update(q_state, observations, actions, next_observations, rewards, dones): 159 | q_next_target = q_network.apply(q_state.target_params, next_observations) # (batch_size, num_actions) 160 | q_next_target = jnp.max(q_next_target, axis=-1) # (batch_size,) 161 | next_q_value = rewards + (1 - dones) * args.gamma * q_next_target 162 | 163 | def mse_loss(params): 164 | q_pred = q_network.apply(params, observations) # (batch_size, num_actions) 165 | q_pred = q_pred[jnp.arange(q_pred.shape[0]), actions.squeeze()] # (batch_size,) 166 | return ((q_pred - next_q_value) ** 2).mean(), q_pred 167 | 168 | (loss_value, q_pred), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params) 169 | q_state = q_state.apply_gradients(grads=grads) 170 | return loss_value, q_pred, q_state 171 | 172 | start_time = None 173 | 174 | # TRY NOT TO MODIFY: start the game 175 | obs, _ = envs.reset(seed=args.seed) 176 | avg_returns = deque(maxlen=20) 177 | pbar = tqdm.tqdm(range(args.total_timesteps)) 178 | 179 | for global_step in pbar: 180 | if global_step == args.learning_starts + args.measure_burnin: 181 | start_time = time.time() 182 | global_step_start = global_step 183 | 184 | # ALGO LOGIC: put action logic here 185 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step) 186 | if random.random() < epsilon: 187 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 188 | else: 189 | q_values = q_network.apply(q_state.params, obs) 190 | actions = q_values.argmax(axis=-1) 191 | actions = jax.device_get(actions) 192 | 193 | # TRY NOT TO MODIFY: execute the game and log data. 194 | next_obs, rewards, terminations, truncations, infos = envs.step(actions) 195 | 196 | # TRY NOT TO MODIFY: record rewards for plotting purposes 197 | if "final_info" in infos: 198 | for info in infos["final_info"]: 199 | if info and "episode" in info: 200 | avg_returns.append(info["episode"]["r"]) 201 | desc = f"global_step={global_step}, episodic_return={np.array(avg_returns).mean()}" 202 | 203 | # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` 204 | real_next_obs = next_obs.copy() 205 | for idx, trunc in enumerate(truncations): 206 | if trunc: 207 | real_next_obs[idx] = infos["final_observation"][idx] 208 | rb.add(obs, real_next_obs, actions, rewards, terminations, infos) 209 | 210 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 211 | obs = next_obs 212 | 213 | # ALGO LOGIC: training. 
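        # Training kicks in after `learning_starts` env steps: every `train_frequency`
        # steps a batch is sampled from the replay buffer and the jitted `update`
        # regresses Q(s, a) onto the TD target r + gamma * (1 - done) * max_a' Q_target(s', a')
        # with one Adam step on the MSE loss. Every `target_network_frequency` steps the
        # target parameters are pulled toward the online parameters via
        # optax.incremental_update with step size `tau` (the default tau=1.0 is a hard copy).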
214 | if global_step > args.learning_starts: 215 | if global_step % args.train_frequency == 0: 216 | data = rb.sample(args.batch_size) 217 | # perform a gradient-descent step 218 | loss, old_val, q_state = update( 219 | q_state, 220 | data.observations.numpy(), 221 | data.actions.numpy(), 222 | data.next_observations.numpy(), 223 | data.rewards.flatten().numpy(), 224 | data.dones.flatten().numpy(), 225 | ) 226 | 227 | # update target network 228 | if global_step % args.target_network_frequency == 0: 229 | q_state = q_state.replace( 230 | target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau) 231 | ) 232 | if global_step % 100 == 0 and start_time is not None: 233 | speed = (global_step - global_step_start) / (time.time() - start_time) 234 | pbar.set_description(f"speed: {speed: 4.2f} sps, " + desc) 235 | logs = { 236 | "episode_return": np.array(avg_returns).mean(), 237 | } 238 | wandb.log( 239 | { 240 | "speed": speed, 241 | **logs, 242 | }, 243 | step=global_step, 244 | ) 245 | 246 | envs.close() 247 | -------------------------------------------------------------------------------- /leanrl/dqn_torchcompile.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy 2 | import math 3 | import os 4 | import random 5 | import time 6 | from collections import deque 7 | from dataclasses import dataclass 8 | 9 | import gymnasium as gym 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | import tqdm 16 | import tyro 17 | import wandb 18 | from tensordict import TensorDict, from_module 19 | from tensordict.nn import CudaGraphModule 20 | from torchrl.data import LazyTensorStorage, ReplayBuffer 21 | 22 | 23 | @dataclass 24 | class Args: 25 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 26 | """the name of this experiment""" 27 | seed: int = 1 28 | """seed of the experiment""" 29 | torch_deterministic: bool = True 30 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 31 | cuda: bool = True 32 | """if toggled, cuda will be enabled by default""" 33 | capture_video: bool = False 34 | """whether to capture videos of the agent performances (check out `videos` folder)""" 35 | 36 | # Algorithm specific arguments 37 | env_id: str = "CartPole-v1" 38 | """the id of the environment""" 39 | total_timesteps: int = 500000 40 | """total timesteps of the experiments""" 41 | learning_rate: float = 2.5e-4 42 | """the learning rate of the optimizer""" 43 | num_envs: int = 1 44 | """the number of parallel game environments""" 45 | buffer_size: int = 10000 46 | """the replay memory buffer size""" 47 | gamma: float = 0.99 48 | """the discount factor gamma""" 49 | tau: float = 1.0 50 | """the target network update rate""" 51 | target_network_frequency: int = 500 52 | """the timesteps it takes to update the target network""" 53 | batch_size: int = 128 54 | """the batch size of sample from the reply memory""" 55 | start_e: float = 1 56 | """the starting epsilon for exploration""" 57 | end_e: float = 0.05 58 | """the ending epsilon for exploration""" 59 | exploration_fraction: float = 0.5 60 | """the fraction of `total-timesteps` it takes from start-e to go end-e""" 61 | learning_starts: int = 10000 62 | """timestep to start learning""" 63 | train_frequency: int = 10 64 | """the frequency of training""" 65 | 66 | measure_burnin: int = 3 67 | """Number of 
burn-in iterations for speed measure.""" 68 | 69 | compile: bool = False 70 | """whether to use torch.compile.""" 71 | cudagraphs: bool = False 72 | """whether to use cudagraphs on top of compile.""" 73 | 74 | 75 | def make_env(env_id, seed, idx, capture_video, run_name): 76 | def thunk(): 77 | if capture_video and idx == 0: 78 | env = gym.make(env_id, render_mode="rgb_array") 79 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 80 | else: 81 | env = gym.make(env_id) 82 | env = gym.wrappers.RecordEpisodeStatistics(env) 83 | env.action_space.seed(seed) 84 | 85 | return env 86 | 87 | return thunk 88 | 89 | 90 | # ALGO LOGIC: initialize agent here: 91 | class QNetwork(nn.Module): 92 | def __init__(self, n_obs, n_act, device=None): 93 | super().__init__() 94 | self.network = nn.Sequential( 95 | nn.Linear(n_obs, 120, device=device), 96 | nn.ReLU(), 97 | nn.Linear(120, 84, device=device), 98 | nn.ReLU(), 99 | nn.Linear(84, n_act, device=device), 100 | ) 101 | 102 | def forward(self, x): 103 | return self.network(x) 104 | 105 | 106 | def linear_schedule(start_e: float, end_e: float, duration: int): 107 | slope = (end_e - start_e) / duration 108 | slope = torch.tensor(slope, device=device) 109 | while True: 110 | yield slope.clamp_min(end_e) 111 | 112 | 113 | if __name__ == "__main__": 114 | args = tyro.cli(Args) 115 | assert args.num_envs == 1, "vectorized envs are not supported at the moment" 116 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" 117 | 118 | wandb.init( 119 | project="dqn", 120 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 121 | config=vars(args), 122 | save_code=True, 123 | ) 124 | 125 | # TRY NOT TO MODIFY: seeding 126 | random.seed(args.seed) 127 | np.random.seed(args.seed) 128 | torch.manual_seed(args.seed) 129 | torch.backends.cudnn.deterministic = args.torch_deterministic 130 | 131 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu") 132 | 133 | # env setup 134 | envs = gym.vector.SyncVectorEnv( 135 | [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] 136 | ) 137 | 138 | assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" 139 | n_act = envs.single_action_space.n 140 | n_obs = math.prod(envs.single_observation_space.shape) 141 | 142 | q_network = QNetwork(n_obs=n_obs, n_act=n_act, device=device) 143 | q_network_detach = QNetwork(n_obs=n_obs, n_act=n_act, device=device) 144 | params_vals = from_module(q_network).detach() 145 | params_vals.to_module(q_network_detach) 146 | 147 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate, capturable=args.cudagraphs and not args.compile) 148 | 149 | target_network = QNetwork(n_obs=n_obs, n_act=n_act, device=device) 150 | target_params = params_vals.clone().lock_() 151 | target_params.to_module(target_network) 152 | 153 | def update(data): 154 | with torch.no_grad(): 155 | target_max, _ = target_network(data["next_observations"]).max(dim=1) 156 | td_target = data["rewards"].flatten() + args.gamma * target_max * (~data["dones"].flatten()).float() 157 | old_val = q_network(data["observations"]).gather(1, data["actions"].unsqueeze(-1)).squeeze() 158 | loss = F.mse_loss(td_target, old_val) 159 | 160 | # optimize the model 161 | optimizer.zero_grad() 162 | loss.backward() 163 | optimizer.step() 164 | return loss.detach() 165 | 166 | def policy(obs, epsilon): 167 | q_values = q_network_detach(obs) 168 | 
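        # Vectorised epsilon-greedy with no Python branching: the greedy action (from a
        # detached copy of the online network's parameters) and a random candidate are
        # computed for every env, and torch.where picks between them element-wise, so the
        # policy remains a single tensor-ops graph that torch.compile can capture with
        # fullgraph=True.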
actions = torch.argmax(q_values, dim=1) 169 | actions_random = torch.rand(actions.shape, device=actions.device).mul(n_act).floor().to(torch.long) 170 | # actions_random = torch.randint_like(actions, n_act) 171 | use_policy = torch.rand(actions.shape, device=actions.device).gt(epsilon) 172 | return torch.where(use_policy, actions, actions_random) 173 | 174 | rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device)) 175 | 176 | if args.compile: 177 | mode = None # "reduce-overhead" if not args.cudagraphs else None 178 | update = torch.compile(update, mode=mode) 179 | policy = torch.compile(policy, mode=mode, fullgraph=True) 180 | 181 | if args.cudagraphs: 182 | update = CudaGraphModule(update) 183 | policy = CudaGraphModule(policy) 184 | 185 | start_time = None 186 | 187 | # TRY NOT TO MODIFY: start the game 188 | obs, _ = envs.reset(seed=args.seed) 189 | obs = torch.as_tensor(obs, device=device, dtype=torch.float) 190 | eps_schedule = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps) 191 | avg_returns = deque(maxlen=20) 192 | 193 | pbar = tqdm.tqdm(range(args.total_timesteps)) 194 | transitions = [] 195 | for global_step in pbar: 196 | if global_step == args.learning_starts + args.measure_burnin: 197 | start_time = time.time() 198 | global_step_start = global_step 199 | 200 | # ALGO LOGIC: put action logic here 201 | epsilon = next(eps_schedule) 202 | actions = policy(obs, epsilon) 203 | 204 | # TRY NOT TO MODIFY: execute the game and log data. 205 | next_obs, rewards, terminations, truncations, infos = envs.step(actions.cpu().numpy()) 206 | 207 | # TRY NOT TO MODIFY: record rewards for plotting purposes 208 | if "final_info" in infos: 209 | for info in infos["final_info"]: 210 | if info and "episode" in info: 211 | avg_returns.append(info["episode"]["r"]) 212 | desc = f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean()}" 213 | 214 | next_obs = torch.as_tensor(next_obs, dtype=torch.float).to(device, non_blocking=True) 215 | terminations = torch.as_tensor(terminations, dtype=torch.bool).to(device, non_blocking=True) 216 | rewards = torch.as_tensor(rewards, dtype=torch.float).to(device, non_blocking=True) 217 | 218 | real_next_obs = None 219 | for idx, trunc in enumerate(truncations): 220 | if trunc: 221 | if real_next_obs is None: 222 | real_next_obs = next_obs.clone() 223 | real_next_obs[idx] = torch.as_tensor(infos["final_observation"][idx], device=device, dtype=torch.float) 224 | if real_next_obs is None: 225 | real_next_obs = next_obs 226 | # obs = torch.as_tensor(obs, device=device, dtype=torch.float) 227 | transitions.append( 228 | TensorDict._new_unsafe( 229 | observations=obs, 230 | next_observations=real_next_obs, 231 | actions=actions, 232 | rewards=rewards, 233 | terminations=terminations, 234 | dones=terminations, 235 | batch_size=obs.shape[:1], 236 | device=device, 237 | ) 238 | ) 239 | 240 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 241 | obs = next_obs 242 | 243 | # ALGO LOGIC: training. 
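        # Transitions are buffered as TensorDicts in a plain Python list and flushed into
        # the torchrl ReplayBuffer with a single `extend` call right before sampling; the
        # sampled batch is then passed to `update`, which may be torch.compile'd and/or
        # wrapped in a CudaGraphModule depending on the flags. The target network is
        # refreshed with a single vectorised `lerp_` on the stacked parameter TensorDict
        # instead of a per-parameter Python loop.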
244 | if global_step > args.learning_starts: 245 | if global_step % args.train_frequency == 0: 246 | rb.extend(torch.cat(transitions)) 247 | transitions = [] 248 | data = rb.sample(args.batch_size) 249 | loss = update(data) 250 | # update target network 251 | if global_step % args.target_network_frequency == 0: 252 | target_params.lerp_(params_vals, args.tau) 253 | 254 | if global_step % 100 == 0 and start_time is not None: 255 | speed = (global_step - global_step_start) / (time.time() - start_time) 256 | pbar.set_description(f"speed: {speed: 4.2f} sps, " f"epsilon: {epsilon.cpu().item(): 4.2f}, " + desc) 257 | with torch.no_grad(): 258 | logs = { 259 | "episode_return": torch.tensor(avg_returns).mean(), 260 | "loss": loss.mean(), 261 | "epsilon": epsilon, 262 | } 263 | wandb.log( 264 | { 265 | "speed": speed, 266 | **logs, 267 | }, 268 | step=global_step, 269 | ) 270 | 271 | envs.close() 272 | -------------------------------------------------------------------------------- /leanrl/td3_continuous_action.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy 2 | import os 3 | import random 4 | import time 5 | from collections import deque 6 | from dataclasses import dataclass 7 | 8 | import gymnasium as gym 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import tqdm 15 | import tyro 16 | import wandb 17 | from stable_baselines3.common.buffers import ReplayBuffer 18 | 19 | 20 | @dataclass 21 | class Args: 22 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 23 | """the name of this experiment""" 24 | seed: int = 1 25 | """seed of the experiment""" 26 | torch_deterministic: bool = True 27 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 28 | cuda: bool = True 29 | """if toggled, cuda will be enabled by default""" 30 | capture_video: bool = False 31 | """whether to capture videos of the agent performances (check out `videos` folder)""" 32 | 33 | # Algorithm specific arguments 34 | env_id: str = "HalfCheetah-v4" 35 | """the id of the environment""" 36 | total_timesteps: int = 1000000 37 | """total timesteps of the experiments""" 38 | learning_rate: float = 3e-4 39 | """the learning rate of the optimizer""" 40 | buffer_size: int = int(1e6) 41 | """the replay memory buffer size""" 42 | gamma: float = 0.99 43 | """the discount factor gamma""" 44 | tau: float = 0.005 45 | """target smoothing coefficient (default: 0.005)""" 46 | batch_size: int = 256 47 | """the batch size of sample from the reply memory""" 48 | policy_noise: float = 0.2 49 | """the scale of policy noise""" 50 | exploration_noise: float = 0.1 51 | """the scale of exploration noise""" 52 | learning_starts: int = 25e3 53 | """timestep to start learning""" 54 | policy_frequency: int = 2 55 | """the frequency of training policy (delayed)""" 56 | noise_clip: float = 0.5 57 | """noise clip parameter of the Target Policy Smoothing Regularization""" 58 | 59 | measure_burnin: int = 3 60 | 61 | 62 | def make_env(env_id, seed, idx, capture_video, run_name): 63 | def thunk(): 64 | if capture_video and idx == 0: 65 | env = gym.make(env_id, render_mode="rgb_array") 66 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 67 | else: 68 | env = gym.make(env_id) 69 | env = gym.wrappers.RecordEpisodeStatistics(env) 70 | env.action_space.seed(seed) 71 | return env 72 | 73 | return 
thunk 74 | 75 | 76 | # ALGO LOGIC: initialize agent here: 77 | class QNetwork(nn.Module): 78 | def __init__(self, env): 79 | super().__init__() 80 | self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256) 81 | self.fc2 = nn.Linear(256, 256) 82 | self.fc3 = nn.Linear(256, 1) 83 | 84 | def forward(self, x, a): 85 | x = torch.cat([x, a], 1) 86 | x = F.relu(self.fc1(x)) 87 | x = F.relu(self.fc2(x)) 88 | x = self.fc3(x) 89 | return x 90 | 91 | 92 | class Actor(nn.Module): 93 | def __init__(self, env): 94 | super().__init__() 95 | self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256) 96 | self.fc2 = nn.Linear(256, 256) 97 | self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape)) 98 | # action rescaling 99 | self.register_buffer( 100 | "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32) 101 | ) 102 | self.register_buffer( 103 | "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32) 104 | ) 105 | 106 | def forward(self, x): 107 | x = F.relu(self.fc1(x)) 108 | x = F.relu(self.fc2(x)) 109 | x = torch.tanh(self.fc_mu(x)) 110 | return x * self.action_scale + self.action_bias 111 | 112 | 113 | if __name__ == "__main__": 114 | import stable_baselines3 as sb3 115 | 116 | if sb3.__version__ < "2.0": 117 | raise ValueError( 118 | """Ongoing migration: run the following command to install the new dependencies: 119 | poetry run pip install "stable_baselines3==2.0.0a1" 120 | """ 121 | ) 122 | 123 | args = tyro.cli(Args) 124 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" 125 | 126 | wandb.init( 127 | project="td3_continuous_action", 128 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 129 | config=vars(args), 130 | save_code=True, 131 | ) 132 | 133 | # TRY NOT TO MODIFY: seeding 134 | random.seed(args.seed) 135 | np.random.seed(args.seed) 136 | torch.manual_seed(args.seed) 137 | torch.backends.cudnn.deterministic = args.torch_deterministic 138 | 139 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 140 | 141 | # env setup 142 | envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) 143 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 144 | 145 | actor = Actor(envs).to(device) 146 | qf1 = QNetwork(envs).to(device) 147 | qf2 = QNetwork(envs).to(device) 148 | qf1_target = QNetwork(envs).to(device) 149 | qf2_target = QNetwork(envs).to(device) 150 | target_actor = Actor(envs).to(device) 151 | target_actor.load_state_dict(actor.state_dict()) 152 | qf1_target.load_state_dict(qf1.state_dict()) 153 | qf2_target.load_state_dict(qf2.state_dict()) 154 | q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate) 155 | actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate) 156 | 157 | envs.single_observation_space.dtype = np.float32 158 | rb = ReplayBuffer( 159 | args.buffer_size, 160 | envs.single_observation_space, 161 | envs.single_action_space, 162 | device, 163 | handle_timeout_termination=False, 164 | ) 165 | start_time = time.time() 166 | 167 | # TRY NOT TO MODIFY: start the game 168 | obs, _ = envs.reset(seed=args.seed) 169 | pbar = tqdm.tqdm(range(args.total_timesteps)) 170 | start_time = None 171 | max_ep_ret = -float("inf") 172 | avg_returns = deque(maxlen=20) 173 | desc = "" 174 | 
for global_step in pbar: 175 | if global_step == args.measure_burnin + args.learning_starts: 176 | start_time = time.time() 177 | measure_burnin = global_step 178 | 179 | # ALGO LOGIC: put action logic here 180 | if global_step < args.learning_starts: 181 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 182 | else: 183 | with torch.no_grad(): 184 | actions = actor(torch.Tensor(obs).to(device)) 185 | actions += torch.normal(0, actor.action_scale * args.exploration_noise) 186 | actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) 187 | 188 | # TRY NOT TO MODIFY: execute the game and log data. 189 | next_obs, rewards, terminations, truncations, infos = envs.step(actions) 190 | 191 | # TRY NOT TO MODIFY: record rewards for plotting purposes 192 | if "final_info" in infos: 193 | for info in infos["final_info"]: 194 | r = float(info["episode"]["r"]) 195 | max_ep_ret = max(max_ep_ret, r) 196 | avg_returns.append(r) 197 | desc = ( 198 | f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" 199 | ) 200 | 201 | # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` 202 | real_next_obs = next_obs.copy() 203 | for idx, trunc in enumerate(truncations): 204 | if trunc: 205 | real_next_obs[idx] = infos["final_observation"][idx] 206 | rb.add(obs, real_next_obs, actions, rewards, terminations, infos) 207 | 208 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 209 | obs = next_obs 210 | 211 | # ALGO LOGIC: training. 212 | if global_step > args.learning_starts: 213 | data = rb.sample(args.batch_size) 214 | with torch.no_grad(): 215 | clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp( 216 | -args.noise_clip, args.noise_clip 217 | ) * target_actor.action_scale 218 | 219 | next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp( 220 | envs.single_action_space.low[0], envs.single_action_space.high[0] 221 | ) 222 | qf1_next_target = qf1_target(data.next_observations, next_state_actions) 223 | qf2_next_target = qf2_target(data.next_observations, next_state_actions) 224 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) 225 | next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) 226 | 227 | qf1_a_values = qf1(data.observations, data.actions).view(-1) 228 | qf2_a_values = qf2(data.observations, data.actions).view(-1) 229 | qf1_loss = F.mse_loss(qf1_a_values, next_q_value) 230 | qf2_loss = F.mse_loss(qf2_a_values, next_q_value) 231 | qf_loss = qf1_loss + qf2_loss 232 | 233 | # optimize the model 234 | q_optimizer.zero_grad() 235 | qf_loss.backward() 236 | q_optimizer.step() 237 | 238 | if global_step % args.policy_frequency == 0: 239 | actor_loss = -qf1(data.observations, actor(data.observations)).mean() 240 | actor_optimizer.zero_grad() 241 | actor_loss.backward() 242 | actor_optimizer.step() 243 | 244 | # update the target network 245 | for param, target_param in zip(actor.parameters(), target_actor.parameters()): 246 | target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) 247 | for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): 248 | target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) 249 | for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): 250 | target_param.data.copy_(args.tau * param.data + (1 - 
args.tau) * target_param.data) 251 | 252 | if (global_step % 100 == 0) and start_time is not None: 253 | speed = (global_step - measure_burnin) / (time.time() - start_time) 254 | pbar.set_description(f"{speed: 4.4f} sps, " + desc) 255 | with torch.no_grad(): 256 | logs = { 257 | "episode_return": torch.tensor(avg_returns).mean(), 258 | "actor_loss": actor_loss.mean(), 259 | "qf_loss": qf_loss.mean(), 260 | } 261 | wandb.log( 262 | { 263 | "speed": speed, 264 | **logs, 265 | }, 266 | step=global_step, 267 | ) 268 | 269 | envs.close() 270 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | [![LeanRL](https://img.shields.io/badge/Discord-LeanRL-blue)](https://discord.com/channels/1171857748607115354/1289142756614213697)
2 |
3 | # LeanRL - Turbo-implementations of CleanRL scripts
4 |
5 | LeanRL is a lightweight library consisting of single-file, PyTorch-based implementations of popular Reinforcement
6 | Learning (RL) algorithms.
7 | The primary goal of this library is to inform the RL PyTorch user base of optimization tricks to cut training time by
8 | half or more.
9 |
10 | More precisely, LeanRL is a fork of CleanRL, where hand-picked scripts have been re-written using PyTorch 2 features,
11 | mainly [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) and
12 | [`cudagraphs`](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).
13 | The goal is to provide guidance on how to run your RL script at full speed with minimal impact on the user experience.
14 |
15 | ## Key Features:
16 |
17 | * 📜 Single-file implementation
18 |   * We stick to the original spirit of CleanRL which is to keep *every detail about an algorithm variant in a single standalone file.*
19 | * 🚀 Fast implementations:
20 |   * We provide an optimized, lean version of the PyTorch scripts (`_torchcompile.py`) where data copies
21 |     and code execution have been optimized thanks to four tools:
22 |     * 🖥️ `torch.compile` to reduce the overhead and fuse operators whenever possible;
23 |     * 📈 `cudagraphs` to isolate all the CUDA operations and eliminate the cost of entering the compiled code;
24 |     * 📖 `tensordict` to speed up and clarify data copies on CUDA, facilitate functional calls and fast target parameter updates;
25 |     * 🗺️ `torch.vmap` to vectorize the execution of the Q-value networks, when needed.
26 |   * We provide a somewhat lighter version of each script, removing some logging and checkpointing-related lines of code,
27 |     to focus on the time spent optimizing the models.
28 |   * If available, we do the same with the JAX version of the code.
29 | * 🪛 Local Reproducibility via Seeding
30 |
31 | **Disclaimer**: This repo is a highly simplified version of CleanRL that lacks many features such as detailed logging
32 | or checkpointing - its only purpose is to provide various versions of similar training scripts to measure the plain
33 | runtime under various constraints. However, we welcome contributions that re-implement these features.
34 |
35 | ## Speed-ups
36 |
37 | There are three sources of speed-ups in the code proposed here:
38 |
39 | - **torch.compile**: Introduced in PyTorch 2.0, `torch.compile` serves as the primary framework for accelerating the
40 |   execution of PyTorch code during both training and inference phases. This compiler translates Python code into a
41 |   series of elementary operations and identifies opportunities for fusion. A significant advantage of `torch.compile` is
42 |   its ability to minimize the overhead of transitioning between the Python interpreter and the C++ runtime.
43 |   Unlike PyTorch's eager execution mode, which requires numerous such boundary crossings, `torch.compile` generates a
44 |   single C++ executable, thereby minimizing the need to frequently revert to Python. Additionally, `torch.compile` is
45 |   notably resilient to graph breaks, which occur when an operation is not supported by the compiler (due to design
46 |   constraints or pending integration of the Python operator). This robustness ensures that virtually any Python code can
47 |   be compiled in principle.
48 | - **cudagraphs**: Reinforcement Learning (RL) is typically constrained by significant CPU overhead. Unlike other machine
49 |   learning domains where networks might be deep, RL commonly employs shallower networks. When using `torch.compile`,
50 |   there is a minor CPU overhead associated with the execution of the compiled code itself (e.g., guard checks).
51 |   This overhead can negate the benefits of operator fusion, especially since the functions being compiled are already
52 |   quick to execute. To address this, PyTorch offers cudagraph support. Utilizing cudagraphs involves capturing the operations
53 |   executed on a CUDA device, using device buffers, and replaying the same operations graph later. If the graph's buffers
54 |   (content) are updated in place, new results can be generated. Here is how a typical cudagraph pipeline appears:
55 |
56 | ```python
57 | g = torch.cuda.CUDAGraph()
58 | with torch.cuda.graph(g):
59 |     # x_buffer, y_buffer are example tensors of the desired shape
60 |     z_buffer = func(x_buffer, y_buffer)
61 | # later on, with a new x and y we want to pass to func
62 | x_buffer.copy_(x)
63 | y_buffer.copy_(y)
64 | g.replay()
65 | z = z_buffer.clone()
66 | ```
67 |
68 | This has some strong requirements (all tensors must be on CUDA, and dynamic shapes are not supported).
69 | Because we are explicitly avoiding the `torch.compile` entry cost, this is much faster. Cudagraphs can also be used
70 | without `torch.compile`, but by using both simultaneously we can benefit from both operator fusion and cudagraph
71 | speed-ups.
72 | As one can see, using cudagraphs directly like this is a bit convoluted and not very Pythonic. Fortunately, the `tensordict`
73 | library provides a `CudaGraphModule` class that wraps an `nn.Module` or a plain training function and allows for flexible
74 | and safe usage of cudagraphs.
75 |
76 | **To reproduce these results in your own code base**: look for the calls to `torch.compile` and the `CudaGraphModule` wrapper
77 | within the `*_torchcompile.py` scripts.
78 |
79 | You can also look into `run.sh` for the exact commands we used to run the scripts.
80 |
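Concretely, the `*_torchcompile.py` scripts write the training step as a plain PyTorch function and then wrap it, first with `torch.compile` and then with `CudaGraphModule`. Below is a minimal, self-contained sketch of that pattern; the linear model, the shapes and the dummy loop are illustrative placeholders and are not taken from any particular script:

```python
import torch
from tensordict.nn import CudaGraphModule

device = torch.device("cuda")  # cudagraphs require a CUDA device
model = torch.nn.Linear(8, 1, device=device)
# `capturable=True` keeps the Adam state updates on-device so they can be captured.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, capturable=True)


def update(obs, target):
    # A stand-in training step: forward, backward and optimizer update.
    loss = torch.nn.functional.mse_loss(model(obs), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()


# First let torch.compile fuse operators and cut Python overhead...
update = torch.compile(update)
# ...then capture the whole step in a CUDA graph so that replaying it skips the
# compiled-code entry cost. This requires CUDA tensors and static shapes across calls.
update = CudaGraphModule(update)

for _ in range(10):
    obs = torch.randn(256, 8, device=device)
    target = torch.randn(256, 1, device=device)
    loss = update(obs, target)
```

In the actual scripts the optimizers are created with `capturable=args.cudagraphs and not args.compile`, and each wrapper is only applied when the corresponding `--compile` / `--cudagraphs` flag is passed.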
81 | The following table displays speed-ups obtained on an H100-equipped node with TODO CPU cores.
82 | All models were executed on the GPU; simulation was done on the CPU.
83 |

| Algorithm | PyTorch speed (fps) - CleanRL implementation | PyTorch speed (fps) - LeanRL implementation | PyTorch speed (fps) - compile | PyTorch speed (fps) - compile+cudagraphs | Overall speed-up |
| --- | --- | --- | --- | --- | --- |
| PPO (Atari) | 1022 | 3728 | 3841 | 6809 | 6.8x |
| PPO (Continuous action) | 652 | 683 | 908 | 1774 | 2.7x |
| SAC (Continuous action) | 127 | 130 | 255 | 725 | 5.7x |
| TD3 (Continuous action) | 272 | 247 | 272 | 936 | 3.4x |

153 | These figures are displayed in the plots below. All runs were executed for an identical number of steps across 3
154 | different seeds.
155 | Fluctuations in the results are due to seeding artifacts, not implementation details (which are identical across
156 | scripts).
157 |
158 |
159 | **SAC (HalfCheetah-v4)**
160 |
161 | ![SAC.png](doc/artifacts/SAC.png)
162 |
163 | ![sac_speed.png](doc/artifacts/sac_speed.png)
164 |
165 |
166 |
167 |
168 | **TD3 (HalfCheetah-v4)**
169 |
170 | ![TD3.png](doc/artifacts/TD3.png)
171 |
172 | ![td3_speed.png](doc/artifacts/td3_speed.png)
173 |
174 |
175 |
176 |
177 | **PPO (Atari - Breakout-v5)**
178 |
179 | ![ppo.png](doc/artifacts/ppo.png)
180 |
181 | ![ppo_speed.png](doc/artifacts/ppo_speed.png)
182 |
183 |
184 |
185 | ### GPU utilization
186 |
187 | Using `torch.compile` and cudagraphs also makes better use of your GPU.
188 | To show this, we plot the GPU utilization throughout training for SAC. The Area Under The Curve (AUC)
189 | measures the total usage of the GPU over the course of the training loop execution.
190 | As this plot shows, the combined usage of compile and cudagraphs brings the GPU utilization to its minimum value,
191 | meaning that you can train more models in a shorter time by utilizing these features together.
192 |
193 | ![sac_gpu.png](doc/artifacts/sac_gpu.png)
194 |
195 | ### Tips to accelerate your code in eager mode
196 |
197 | There may be multiple reasons your RL code is running slower than it should.
198 | Here are some off-the-shelf tips to get a better runtime (see the short sketch after this list):
199 |
200 | - Don't send tensors to the device using `to(device)` if you can instantiate them directly there. For instance,
201 |   prefer `randn((), device=device)` to `randn(()).to(device)`.
202 | - Avoid pinning memory in your code unless you have thoroughly tested that it accelerates runtime (see
203 |   [this tutorial](https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html) for more info).
204 | - Avoid calling `tensor.item()` in between CUDA operations.
205 |   This triggers a CUDA synchronization and blocks
206 |   your code. Do the logging after all code (forward / backward / optim) has completed.
207 |   See how to find sync points
208 |   [here](https://pytorch.org/docs/stable/generated/torch.cuda.set_sync_debug_mode.html#torch-cuda-set-sync-debug-mode).
209 | - Avoid frequent calls to `eval()` or `train()` in eager mode.
210 | - Avoid calling `args.attribute` often in the code, especially with [Hydra](https://hydra.cc/docs/). Instead, cache
211 |   the args values in your script as global workspace variables.
212 | - In general, in-place operations are not preferable to regular ones: don't load your code with `mul_` and `add_` unless
213 |   absolutely necessary.
214 |
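To make the first and third tips concrete, here is a small sketch; the shapes and the dummy loss are placeholders and do not come from any LeanRL script:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preferred: allocate directly on the target device.
noise = torch.randn((256, 6), device=device)
# Slower: allocates on the CPU first, then issues an extra copy to the GPU.
noise_slow = torch.randn((256, 6)).to(device)

# Keep per-step metrics as tensors inside the hot loop; calling `.item()` here
# would force a CUDA synchronization at every iteration.
losses = []
for _ in range(100):
    loss = (noise * noise_slow).sum()  # stand-in for a real forward/backward pass
    losses.append(loss.detach())

# Move everything to the CPU once, after the loop, when logging.
logged = torch.stack(losses).cpu().tolist()
```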
215 | ## Get started
216 |
217 | Unlike CleanRL, LeanRL does not currently support `poetry`.
218 |
219 | Prerequisites:
220 | * Clone the repo locally:
221 | ```bash
222 | git clone https://github.com/meta-pytorch/leanrl.git && cd leanrl
223 | ```
224 | * `pip install -r requirements/requirements.txt` for basic requirements, or another `.txt` file for specific applications.
225 |
226 | Once the dependencies have been installed, run the scripts as follows:
227 |
228 | ```bash
229 | python leanrl/ppo_atari_envpool_torchcompile.py \
230 |     --seed 1 \
231 |     --total-timesteps 50000 \
232 |     --compile \
233 |     --cudagraphs
234 | ```
235 |
236 | Altogether, the installation steps will generally look like this:
237 | ```bash
238 | conda create -n leanrl python=3.10 -y
239 | conda activate leanrl
240 | python -m pip install --upgrade --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124
241 | python -m pip install -r requirements/requirements.txt
242 | python -m pip install -r requirements/requirements-atari.txt
243 | python -m pip install -r requirements/requirements-envpool.txt
244 | python -m pip install -r requirements/requirements-mujoco.txt
245 |
246 | python leanrl/ppo_atari_envpool_torchcompile.py \
247 |     --seed 1 \
248 |     --compile \
249 |     --cudagraphs
250 |
251 | ```
252 |
253 | ## Citing CleanRL
254 |
255 | LeanRL does not have a citation yet; credit should be given to CleanRL instead.
256 | To cite CleanRL in your work, please cite the CleanRL technical [paper](https://www.jmlr.org/papers/v23/21-1342.html):
257 |
258 | ```bibtex
259 | @article{huang2022cleanrl,
260 |   author = {Shengyi Huang and Rousslan Fernand Julien Dossa and Chang Ye and Jeff Braga and Dipam Chakraborty and Kinal Mehta and João G.M. Araújo},
261 |   title = {CleanRL: High-quality Single-file Implementations of Deep Reinforcement Learning Algorithms},
262 |   journal = {Journal of Machine Learning Research},
263 |   year = {2022},
264 |   volume = {23},
265 |   number = {274},
266 |   pages = {1--18},
267 |   url = {http://jmlr.org/papers/v23/21-1342.html}
268 | }
269 | ```
270 |
271 |
272 | ## Acknowledgement
273 |
274 | LeanRL is forked from [CleanRL](https://github.com/vwxyzjn/cleanrl).
275 |
276 | CleanRL is a community-powered project whose contributors run experiments on a variety of hardware.
277 |
278 | * We thank many contributors for using their own computers to run experiments.
279 | * We thank Google's [TPU research cloud](https://sites.research.google/trc/about/) for providing TPU resources.
280 | * We thank [Hugging Face](https://huggingface.co/)'s cluster for providing GPU resources.
281 |
282 | ## License
283 | LeanRL is MIT licensed, as found in the LICENSE file.
284 | -------------------------------------------------------------------------------- /leanrl/td3_continuous_action_jax.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy 2 | import os 3 | import random 4 | import time 5 | from collections import deque 6 | from dataclasses import dataclass 7 | 8 | import flax 9 | import flax.linen as nn 10 | import gymnasium as gym 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | import tqdm 16 | import tyro 17 | import wandb 18 | from flax.training.train_state import TrainState 19 | from stable_baselines3.common.buffers import ReplayBuffer 20 | 21 | 22 | @dataclass 23 | class Args: 24 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 25 | """the name of this experiment""" 26 | seed: int = 1 27 | """seed of the experiment""" 28 | capture_video: bool = False 29 | """whether to capture videos of the agent performances (check out `videos` folder)""" 30 | 31 | # Algorithm specific arguments 32 | env_id: str = "HalfCheetah-v4" 33 | """the id of the environment""" 34 | total_timesteps: int = 1000000 35 | """total timesteps of the experiments""" 36 | learning_rate: float = 3e-4 37 | """the learning rate of the optimizer""" 38 | buffer_size: int = int(1e6) 39 | """the replay memory buffer size""" 40 | gamma: float = 0.99 41 | """the discount factor gamma""" 42 | tau: float = 0.005 43 | """target smoothing coefficient (default: 0.005)""" 44 | batch_size: int = 256 45 | """the batch size of sample from the reply memory""" 46 | policy_noise: float = 0.2 47 | """the scale of policy noise""" 48 | exploration_noise: float = 0.1 49 | """the scale of exploration noise""" 50 | learning_starts: int = 25e3 51 | """timestep to start learning""" 52 | policy_frequency: int = 2 53 | """the frequency of training policy (delayed)""" 54 | noise_clip: float = 0.5 55 | """noise clip parameter of the Target Policy Smoothing Regularization""" 56 | 57 | measure_burnin: int = 3 58 | 59 | 60 | def make_env(env_id, seed, idx, capture_video, run_name): 61 | def thunk(): 62 | if capture_video and idx == 0: 63 | env = gym.make(env_id,
render_mode="rgb_array") 64 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 65 | else: 66 | env = gym.make(env_id) 67 | env = gym.wrappers.RecordEpisodeStatistics(env) 68 | env.action_space.seed(seed) 69 | return env 70 | 71 | return thunk 72 | 73 | 74 | # ALGO LOGIC: initialize agent here: 75 | class QNetwork(nn.Module): 76 | @nn.compact 77 | def __call__(self, x: jnp.ndarray, a: jnp.ndarray): 78 | x = jnp.concatenate([x, a], -1) 79 | x = nn.Dense(256)(x) 80 | x = nn.relu(x) 81 | x = nn.Dense(256)(x) 82 | x = nn.relu(x) 83 | x = nn.Dense(1)(x) 84 | return x 85 | 86 | 87 | class Actor(nn.Module): 88 | action_dim: int 89 | action_scale: jnp.ndarray 90 | action_bias: jnp.ndarray 91 | 92 | @nn.compact 93 | def __call__(self, x): 94 | x = nn.Dense(256)(x) 95 | x = nn.relu(x) 96 | x = nn.Dense(256)(x) 97 | x = nn.relu(x) 98 | x = nn.Dense(self.action_dim)(x) 99 | x = nn.tanh(x) 100 | x = x * self.action_scale + self.action_bias 101 | return x 102 | 103 | 104 | class TrainState(TrainState): 105 | target_params: flax.core.FrozenDict 106 | 107 | 108 | if __name__ == "__main__": 109 | import stable_baselines3 as sb3 110 | 111 | if sb3.__version__ < "2.0": 112 | raise ValueError( 113 | """Ongoing migration: run the following command to install the new dependencies: 114 | poetry run pip install "stable_baselines3==2.0.0a1" 115 | """ 116 | ) 117 | args = tyro.cli(Args) 118 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" 119 | 120 | wandb.init( 121 | project="td3_continuous_action", 122 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 123 | config=vars(args), 124 | save_code=True, 125 | ) 126 | 127 | # TRY NOT TO MODIFY: seeding 128 | random.seed(args.seed) 129 | np.random.seed(args.seed) 130 | key = jax.random.PRNGKey(args.seed) 131 | key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) 132 | 133 | # env setup 134 | envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) 135 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 136 | 137 | max_action = float(envs.single_action_space.high[0]) 138 | envs.single_observation_space.dtype = np.float32 139 | rb = ReplayBuffer( 140 | args.buffer_size, 141 | envs.single_observation_space, 142 | envs.single_action_space, 143 | device="cpu", 144 | handle_timeout_termination=False, 145 | ) 146 | 147 | # TRY NOT TO MODIFY: start the game 148 | obs, _ = envs.reset(seed=args.seed) 149 | 150 | actor = Actor( 151 | action_dim=np.prod(envs.single_action_space.shape), 152 | action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0), 153 | action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0), 154 | ) 155 | actor_state = TrainState.create( 156 | apply_fn=actor.apply, 157 | params=actor.init(actor_key, obs), 158 | target_params=actor.init(actor_key, obs), 159 | tx=optax.adam(learning_rate=args.learning_rate), 160 | ) 161 | qf = QNetwork() 162 | qf1_state = TrainState.create( 163 | apply_fn=qf.apply, 164 | params=qf.init(qf1_key, obs, envs.action_space.sample()), 165 | target_params=qf.init(qf1_key, obs, envs.action_space.sample()), 166 | tx=optax.adam(learning_rate=args.learning_rate), 167 | ) 168 | qf2_state = TrainState.create( 169 | apply_fn=qf.apply, 170 | params=qf.init(qf2_key, obs, envs.action_space.sample()), 171 | target_params=qf.init(qf2_key, obs, envs.action_space.sample()), 172 | tx=optax.adam(learning_rate=args.learning_rate), 173 | ) 174 | actor.apply = 
jax.jit(actor.apply) 175 | qf.apply = jax.jit(qf.apply) 176 | 177 | @jax.jit 178 | def update_critic( 179 | actor_state: TrainState, 180 | qf1_state: TrainState, 181 | qf2_state: TrainState, 182 | observations: np.ndarray, 183 | actions: np.ndarray, 184 | next_observations: np.ndarray, 185 | rewards: np.ndarray, 186 | terminations: np.ndarray, 187 | key: jnp.ndarray, 188 | ): 189 | # TODO Maybe pre-generate a lot of random keys 190 | # also check https://jax.readthedocs.io/en/latest/jax.random.html 191 | key, noise_key = jax.random.split(key, 2) 192 | clipped_noise = ( 193 | jnp.clip( 194 | (jax.random.normal(noise_key, actions.shape) * args.policy_noise), 195 | -args.noise_clip, 196 | args.noise_clip, 197 | ) 198 | * actor.action_scale 199 | ) 200 | next_state_actions = jnp.clip( 201 | actor.apply(actor_state.target_params, next_observations) + clipped_noise, 202 | envs.single_action_space.low, 203 | envs.single_action_space.high, 204 | ) 205 | qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1) 206 | qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1) 207 | min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target) 208 | next_q_value = (rewards + (1 - terminations) * args.gamma * (min_qf_next_target)).reshape(-1) 209 | 210 | def mse_loss(params): 211 | qf_a_values = qf.apply(params, observations, actions).squeeze() 212 | return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean() 213 | 214 | (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params) 215 | (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params) 216 | qf1_state = qf1_state.apply_gradients(grads=grads1) 217 | qf2_state = qf2_state.apply_gradients(grads=grads2) 218 | 219 | return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key 220 | 221 | @jax.jit 222 | def update_actor( 223 | actor_state: TrainState, 224 | qf1_state: TrainState, 225 | qf2_state: TrainState, 226 | observations: np.ndarray, 227 | ): 228 | def actor_loss(params): 229 | return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean() 230 | 231 | actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) 232 | actor_state = actor_state.apply_gradients(grads=grads) 233 | actor_state = actor_state.replace( 234 | target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) 235 | ) 236 | 237 | qf1_state = qf1_state.replace( 238 | target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) 239 | ) 240 | qf2_state = qf2_state.replace( 241 | target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau) 242 | ) 243 | return actor_state, (qf1_state, qf2_state), actor_loss_value 244 | 245 | pbar = tqdm.tqdm(range(args.total_timesteps)) 246 | start_time = None 247 | max_ep_ret = -float("inf") 248 | avg_returns = deque(maxlen=20) 249 | desc = "" 250 | 251 | for global_step in pbar: 252 | if global_step == args.measure_burnin + args.learning_starts: 253 | start_time = time.time() 254 | measure_burnin = global_step 255 | 256 | # ALGO LOGIC: put action logic here 257 | if global_step < args.learning_starts: 258 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 259 | else: 260 | actions = actor.apply(actor_state.params, obs) 261 | actions = np.array( 262 
| [ 263 | ( 264 | jax.device_get(actions)[0] 265 | + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape) 266 | ).clip(envs.single_action_space.low, envs.single_action_space.high) 267 | ] 268 | ) 269 | 270 | # TRY NOT TO MODIFY: execute the game and log data. 271 | next_obs, rewards, terminations, truncations, infos = envs.step(actions) 272 | 273 | # TRY NOT TO MODIFY: record rewards for plotting purposes 274 | if "final_info" in infos: 275 | for info in infos["final_info"]: 276 | r = float(info["episode"]["r"]) 277 | max_ep_ret = max(max_ep_ret, r) 278 | avg_returns.append(r) 279 | desc = f"global_step={global_step}, episodic_return={np.array(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" 280 | 281 | # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation` 282 | real_next_obs = next_obs.copy() 283 | for idx, trunc in enumerate(truncations): 284 | if trunc: 285 | real_next_obs[idx] = infos["final_observation"][idx] 286 | rb.add(obs, real_next_obs, actions, rewards, terminations, infos) 287 | 288 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 289 | obs = next_obs 290 | 291 | # ALGO LOGIC: training. 292 | if global_step > args.learning_starts: 293 | data = rb.sample(args.batch_size) 294 | 295 | (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic( 296 | actor_state, 297 | qf1_state, 298 | qf2_state, 299 | data.observations.numpy(), 300 | data.actions.numpy(), 301 | data.next_observations.numpy(), 302 | data.rewards.flatten().numpy(), 303 | data.dones.flatten().numpy(), 304 | key, 305 | ) 306 | 307 | if global_step % args.policy_frequency == 0: 308 | actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor( 309 | actor_state, 310 | qf1_state, 311 | qf2_state, 312 | data.observations.numpy(), 313 | ) 314 | 315 | if global_step % 100 == 0 and start_time is not None: 316 | speed = (global_step - measure_burnin) / (time.time() - start_time) 317 | pbar.set_description(f"{speed: 4.4f} sps, " + desc) 318 | logs = { 319 | "episode_return": np.array(avg_returns).mean(), 320 | } 321 | wandb.log( 322 | { 323 | "speed": speed, 324 | **logs, 325 | }, 326 | step=global_step, 327 | ) 328 | 329 | envs.close() 330 | -------------------------------------------------------------------------------- /leanrl/sac_continuous_action.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy 2 | import os 3 | import random 4 | import time 5 | from collections import deque 6 | from dataclasses import dataclass 7 | 8 | import gymnasium as gym 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import tqdm 15 | import tyro 16 | import wandb 17 | from stable_baselines3.common.buffers import ReplayBuffer 18 | 19 | 20 | @dataclass 21 | class Args: 22 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 23 | """the name of this experiment""" 24 | seed: int = 1 25 | """seed of the experiment""" 26 | torch_deterministic: bool = True 27 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 28 | cuda: bool = True 29 | """if toggled, cuda will be enabled by default""" 30 | capture_video: bool = False 31 | """whether to capture videos of the agent performances (check out `videos` folder)""" 32 | 33 | # Algorithm specific arguments 34 | env_id: str 
= "HalfCheetah-v4" 35 | """the environment id of the task""" 36 | total_timesteps: int = 1000000 37 | """total timesteps of the experiments""" 38 | buffer_size: int = int(1e6) 39 | """the replay memory buffer size""" 40 | gamma: float = 0.99 41 | """the discount factor gamma""" 42 | tau: float = 0.005 43 | """target smoothing coefficient (default: 0.005)""" 44 | batch_size: int = 256 45 | """the batch size of sample from the reply memory""" 46 | learning_starts: int = 5e3 47 | """timestep to start learning""" 48 | policy_lr: float = 3e-4 49 | """the learning rate of the policy network optimizer""" 50 | q_lr: float = 1e-3 51 | """the learning rate of the Q network network optimizer""" 52 | policy_frequency: int = 2 53 | """the frequency of training policy (delayed)""" 54 | target_network_frequency: int = 1 # Denis Yarats' implementation delays this by 2. 55 | """the frequency of updates for the target nerworks""" 56 | alpha: float = 0.2 57 | """Entropy regularization coefficient.""" 58 | autotune: bool = True 59 | """automatic tuning of the entropy coefficient""" 60 | 61 | measure_burnin: int = 3 62 | 63 | 64 | def make_env(env_id, seed, idx, capture_video, run_name): 65 | def thunk(): 66 | if capture_video and idx == 0: 67 | env = gym.make(env_id, render_mode="rgb_array") 68 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 69 | else: 70 | env = gym.make(env_id) 71 | env = gym.wrappers.RecordEpisodeStatistics(env) 72 | env.action_space.seed(seed) 73 | return env 74 | 75 | return thunk 76 | 77 | 78 | # ALGO LOGIC: initialize agent here: 79 | class SoftQNetwork(nn.Module): 80 | def __init__(self, env): 81 | super().__init__() 82 | self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256) 83 | self.fc2 = nn.Linear(256, 256) 84 | self.fc3 = nn.Linear(256, 1) 85 | 86 | def forward(self, x, a): 87 | x = torch.cat([x, a], 1) 88 | x = F.relu(self.fc1(x)) 89 | x = F.relu(self.fc2(x)) 90 | x = self.fc3(x) 91 | return x 92 | 93 | 94 | LOG_STD_MAX = 2 95 | LOG_STD_MIN = -5 96 | 97 | 98 | class Actor(nn.Module): 99 | def __init__(self, env): 100 | super().__init__() 101 | self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256) 102 | self.fc2 = nn.Linear(256, 256) 103 | self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape)) 104 | self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape)) 105 | # action rescaling 106 | self.register_buffer( 107 | "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32) 108 | ) 109 | self.register_buffer( 110 | "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32) 111 | ) 112 | 113 | def forward(self, x): 114 | x = F.relu(self.fc1(x)) 115 | x = F.relu(self.fc2(x)) 116 | mean = self.fc_mean(x) 117 | log_std = self.fc_logstd(x) 118 | log_std = torch.tanh(log_std) 119 | log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats 120 | 121 | return mean, log_std 122 | 123 | def get_action(self, x): 124 | mean, log_std = self(x) 125 | std = log_std.exp() 126 | normal = torch.distributions.Normal(mean, std) 127 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 128 | y_t = torch.tanh(x_t) 129 | action = y_t * self.action_scale + self.action_bias 130 | log_prob = normal.log_prob(x_t) 131 | # Enforcing Action Bound 132 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) 133 | 
log_prob = log_prob.sum(1, keepdim=True) 134 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 135 | return action, log_prob, mean 136 | 137 | 138 | if __name__ == "__main__": 139 | import stable_baselines3 as sb3 140 | 141 | if sb3.__version__ < "2.0": 142 | raise ValueError( 143 | """Ongoing migration: run the following command to install the new dependencies: 144 | poetry run pip install "stable_baselines3==2.0.0a1" 145 | """ 146 | ) 147 | 148 | args = tyro.cli(Args) 149 | 150 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" 151 | 152 | wandb.init( 153 | project="sac_continuous_action", 154 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 155 | config=vars(args), 156 | save_code=True, 157 | ) 158 | 159 | # TRY NOT TO MODIFY: seeding 160 | random.seed(args.seed) 161 | np.random.seed(args.seed) 162 | torch.manual_seed(args.seed) 163 | torch.backends.cudnn.deterministic = args.torch_deterministic 164 | 165 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 166 | 167 | # env setup 168 | envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) 169 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 170 | 171 | max_action = float(envs.single_action_space.high[0]) 172 | 173 | actor = Actor(envs).to(device) 174 | qf1 = SoftQNetwork(envs).to(device) 175 | qf2 = SoftQNetwork(envs).to(device) 176 | qf1_target = SoftQNetwork(envs).to(device) 177 | qf2_target = SoftQNetwork(envs).to(device) 178 | qf1_target.load_state_dict(qf1.state_dict()) 179 | qf2_target.load_state_dict(qf2.state_dict()) 180 | q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr) 181 | actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr) 182 | 183 | # Automatic entropy tuning 184 | if args.autotune: 185 | target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item() 186 | log_alpha = torch.zeros(1, requires_grad=True, device=device) 187 | alpha = log_alpha.exp().item() 188 | a_optimizer = optim.Adam([log_alpha], lr=args.q_lr) 189 | else: 190 | alpha = args.alpha 191 | 192 | envs.single_observation_space.dtype = np.float32 193 | rb = ReplayBuffer( 194 | args.buffer_size, 195 | envs.single_observation_space, 196 | envs.single_action_space, 197 | device, 198 | handle_timeout_termination=False, 199 | ) 200 | start_time = time.time() 201 | 202 | # TRY NOT TO MODIFY: start the game 203 | obs, _ = envs.reset(seed=args.seed) 204 | pbar = tqdm.tqdm(range(args.total_timesteps)) 205 | start_time = None 206 | max_ep_ret = -float("inf") 207 | avg_returns = deque(maxlen=20) 208 | desc = "" 209 | for global_step in pbar: 210 | if global_step == args.measure_burnin + args.learning_starts: 211 | start_time = time.time() 212 | measure_burnin = global_step 213 | 214 | # ALGO LOGIC: put action logic here 215 | if global_step < args.learning_starts: 216 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 217 | else: 218 | actions, _, _ = actor.get_action(torch.Tensor(obs).to(device)) 219 | actions = actions.detach().cpu().numpy() 220 | 221 | # TRY NOT TO MODIFY: execute the game and log data. 
222 | next_obs, rewards, terminations, truncations, infos = envs.step(actions) 223 | 224 | # TRY NOT TO MODIFY: record rewards for plotting purposes 225 | if "final_info" in infos: 226 | for info in infos["final_info"]: 227 | r = float(info["episode"]["r"]) 228 | max_ep_ret = max(max_ep_ret, r) 229 | avg_returns.append(r) 230 | desc = ( 231 | f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" 232 | ) 233 | 234 | # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` 235 | real_next_obs = next_obs.copy() 236 | for idx, trunc in enumerate(truncations): 237 | if trunc: 238 | real_next_obs[idx] = infos["final_observation"][idx] 239 | rb.add(obs, real_next_obs, actions, rewards, terminations, infos) 240 | 241 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 242 | obs = next_obs 243 | 244 | # ALGO LOGIC: training. 245 | if global_step > args.learning_starts: 246 | data = rb.sample(args.batch_size) 247 | with torch.no_grad(): 248 | next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations) 249 | qf1_next_target = qf1_target(data.next_observations, next_state_actions) 250 | qf2_next_target = qf2_target(data.next_observations, next_state_actions) 251 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi 252 | next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) 253 | 254 | qf1_a_values = qf1(data.observations, data.actions).view(-1) 255 | qf2_a_values = qf2(data.observations, data.actions).view(-1) 256 | qf1_loss = F.mse_loss(qf1_a_values, next_q_value) 257 | qf2_loss = F.mse_loss(qf2_a_values, next_q_value) 258 | qf_loss = qf1_loss + qf2_loss 259 | 260 | # optimize the model 261 | q_optimizer.zero_grad() 262 | qf_loss.backward() 263 | q_optimizer.step() 264 | 265 | if global_step % args.policy_frequency == 0: # TD 3 Delayed update support 266 | for _ in range( 267 | args.policy_frequency 268 | ): # compensate for the delay by doing 'actor_update_interval' instead of 1 269 | pi, log_pi, _ = actor.get_action(data.observations) 270 | qf1_pi = qf1(data.observations, pi) 271 | qf2_pi = qf2(data.observations, pi) 272 | min_qf_pi = torch.min(qf1_pi, qf2_pi) 273 | actor_loss = ((alpha * log_pi) - min_qf_pi).mean() 274 | 275 | actor_optimizer.zero_grad() 276 | actor_loss.backward() 277 | actor_optimizer.step() 278 | 279 | if args.autotune: 280 | with torch.no_grad(): 281 | _, log_pi, _ = actor.get_action(data.observations) 282 | alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean() 283 | 284 | a_optimizer.zero_grad() 285 | alpha_loss.backward() 286 | a_optimizer.step() 287 | alpha = log_alpha.exp().item() 288 | 289 | # update the target networks 290 | if global_step % args.target_network_frequency == 0: 291 | for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): 292 | target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) 293 | for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): 294 | target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) 295 | 296 | if global_step % 100 == 0 and start_time is not None: 297 | speed = (global_step - measure_burnin) / (time.time() - start_time) 298 | pbar.set_description(f"{speed: 4.4f} sps, " + desc) 299 | with torch.no_grad(): 300 | logs = { 301 | "episode_return": torch.tensor(avg_returns).mean(), 302 | "actor_loss": actor_loss.mean(), 303 | 
"alpha_loss": alpha_loss.mean(), 304 | "qf_loss": qf_loss.mean(), 305 | } 306 | wandb.log( 307 | { 308 | "speed": speed, 309 | **logs, 310 | }, 311 | step=global_step, 312 | ) 313 | 314 | envs.close() 315 | -------------------------------------------------------------------------------- /leanrl/ppo_continuous_action.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy 2 | import os 3 | import random 4 | import time 5 | from collections import deque 6 | from dataclasses import dataclass 7 | 8 | import gymnasium as gym 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import tqdm 14 | import tyro 15 | import wandb 16 | from torch.distributions.normal import Normal 17 | 18 | 19 | @dataclass 20 | class Args: 21 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 22 | """the name of this experiment""" 23 | seed: int = 1 24 | """seed of the experiment""" 25 | torch_deterministic: bool = True 26 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 27 | cuda: bool = True 28 | """if toggled, cuda will be enabled by default""" 29 | capture_video: bool = False 30 | """whether to capture videos of the agent performances (check out `videos` folder)""" 31 | 32 | # Algorithm specific arguments 33 | env_id: str = "HalfCheetah-v4" 34 | """the id of the environment""" 35 | total_timesteps: int = 1000000 36 | """total timesteps of the experiments""" 37 | learning_rate: float = 3e-4 38 | """the learning rate of the optimizer""" 39 | num_envs: int = 1 40 | """the number of parallel game environments""" 41 | num_steps: int = 2048 42 | """the number of steps to run in each environment per policy rollout""" 43 | anneal_lr: bool = True 44 | """Toggle learning rate annealing for policy and value networks""" 45 | gamma: float = 0.99 46 | """the discount factor gamma""" 47 | gae_lambda: float = 0.95 48 | """the lambda for the general advantage estimation""" 49 | num_minibatches: int = 32 50 | """the number of mini-batches""" 51 | update_epochs: int = 10 52 | """the K epochs to update the policy""" 53 | norm_adv: bool = True 54 | """Toggles advantages normalization""" 55 | clip_coef: float = 0.2 56 | """the surrogate clipping coefficient""" 57 | clip_vloss: bool = True 58 | """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" 59 | ent_coef: float = 0.0 60 | """coefficient of the entropy""" 61 | vf_coef: float = 0.5 62 | """coefficient of the value function""" 63 | max_grad_norm: float = 0.5 64 | """the maximum norm for the gradient clipping""" 65 | target_kl: float = None 66 | """the target KL divergence threshold""" 67 | 68 | # to be filled in runtime 69 | batch_size: int = 0 70 | """the batch size (computed in runtime)""" 71 | minibatch_size: int = 0 72 | """the mini-batch size (computed in runtime)""" 73 | num_iterations: int = 0 74 | """the number of iterations (computed in runtime)""" 75 | 76 | measure_burnin: int = 3 77 | 78 | 79 | def make_env(env_id, idx, capture_video, run_name, gamma): 80 | def thunk(): 81 | if capture_video and idx == 0: 82 | env = gym.make(env_id, render_mode="rgb_array") 83 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 84 | else: 85 | env = gym.make(env_id) 86 | env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space 87 | env = gym.wrappers.RecordEpisodeStatistics(env) 88 | env = 
gym.wrappers.ClipAction(env) 89 | env = gym.wrappers.NormalizeObservation(env) 90 | env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) 91 | env = gym.wrappers.NormalizeReward(env, gamma=gamma) 92 | env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) 93 | return env 94 | 95 | return thunk 96 | 97 | 98 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 99 | torch.nn.init.orthogonal_(layer.weight, std) 100 | torch.nn.init.constant_(layer.bias, bias_const) 101 | return layer 102 | 103 | 104 | class Agent(nn.Module): 105 | def __init__(self, envs): 106 | super().__init__() 107 | self.critic = nn.Sequential( 108 | layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), 109 | nn.Tanh(), 110 | layer_init(nn.Linear(64, 64)), 111 | nn.Tanh(), 112 | layer_init(nn.Linear(64, 1), std=1.0), 113 | ) 114 | self.actor_mean = nn.Sequential( 115 | layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), 116 | nn.Tanh(), 117 | layer_init(nn.Linear(64, 64)), 118 | nn.Tanh(), 119 | layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01), 120 | ) 121 | self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape))) 122 | 123 | def get_value(self, x): 124 | return self.critic(x) 125 | 126 | def get_action_and_value(self, x, action=None): 127 | action_mean = self.actor_mean(x) 128 | action_logstd = self.actor_logstd.expand_as(action_mean) 129 | action_std = torch.exp(action_logstd) 130 | probs = Normal(action_mean, action_std) 131 | if action is None: 132 | action = probs.sample() 133 | return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x) 134 | 135 | 136 | if __name__ == "__main__": 137 | args = tyro.cli(Args) 138 | args.batch_size = int(args.num_envs * args.num_steps) 139 | args.minibatch_size = int(args.batch_size // args.num_minibatches) 140 | args.num_iterations = args.total_timesteps // args.batch_size 141 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" 142 | 143 | wandb.init( 144 | project="ppo_continuous_action", 145 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 146 | config=vars(args), 147 | save_code=True, 148 | ) 149 | 150 | # TRY NOT TO MODIFY: seeding 151 | random.seed(args.seed) 152 | np.random.seed(args.seed) 153 | torch.manual_seed(args.seed) 154 | torch.backends.cudnn.deterministic = args.torch_deterministic 155 | 156 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 157 | 158 | # env setup 159 | envs = gym.vector.SyncVectorEnv( 160 | [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)] 161 | ) 162 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 163 | 164 | agent = Agent(envs).to(device) 165 | optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) 166 | 167 | # ALGO Logic: Storage setup 168 | obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) 169 | actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) 170 | logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) 171 | rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) 172 | dones = torch.zeros((args.num_steps, args.num_envs)).to(device) 173 | values = torch.zeros((args.num_steps, args.num_envs)).to(device) 174 | avg_returns = deque(maxlen=20) 175 
| 176 | # TRY NOT TO MODIFY: start the game 177 | global_step = 0 178 | next_obs, _ = envs.reset(seed=args.seed) 179 | next_obs = torch.Tensor(next_obs).to(device) 180 | next_done = torch.zeros(args.num_envs).to(device) 181 | max_ep_ret = -float("inf") 182 | pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) 183 | global_step_burnin = None 184 | start_time = None 185 | desc = "" 186 | 187 | for iteration in pbar: 188 | if iteration == args.measure_burnin: 189 | global_step_burnin = global_step 190 | start_time = time.time() 191 | 192 | # Annealing the rate if instructed to do so. 193 | if args.anneal_lr: 194 | frac = 1.0 - (iteration - 1.0) / args.num_iterations 195 | lrnow = frac * args.learning_rate 196 | optimizer.param_groups[0]["lr"] = lrnow 197 | 198 | for step in range(0, args.num_steps): 199 | global_step += args.num_envs 200 | obs[step] = next_obs 201 | dones[step] = next_done 202 | 203 | # ALGO LOGIC: action logic 204 | with torch.no_grad(): 205 | action, logprob, _, value = agent.get_action_and_value(next_obs) 206 | values[step] = value.flatten() 207 | actions[step] = action 208 | logprobs[step] = logprob 209 | 210 | # TRY NOT TO MODIFY: execute the game and log data. 211 | next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) 212 | next_done = np.logical_or(terminations, truncations) 213 | rewards[step] = torch.tensor(reward).to(device).view(-1) 214 | next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) 215 | 216 | if "final_info" in infos: 217 | for info in infos["final_info"]: 218 | r = float(info["episode"]["r"].reshape(())) 219 | max_ep_ret = max(max_ep_ret, r) 220 | avg_returns.append(r) 221 | desc = f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" 222 | 223 | # bootstrap value if not done 224 | with torch.no_grad(): 225 | next_value = agent.get_value(next_obs).reshape(1, -1) 226 | advantages = torch.zeros_like(rewards).to(device) 227 | lastgaelam = 0 228 | for t in reversed(range(args.num_steps)): 229 | if t == args.num_steps - 1: 230 | nextnonterminal = 1.0 - next_done 231 | nextvalues = next_value 232 | else: 233 | nextnonterminal = 1.0 - dones[t + 1] 234 | nextvalues = values[t + 1] 235 | delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] 236 | advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam 237 | returns = advantages + values 238 | 239 | # flatten the batch 240 | b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) 241 | b_logprobs = logprobs.reshape(-1) 242 | b_actions = actions.reshape((-1,) + envs.single_action_space.shape) 243 | b_advantages = advantages.reshape(-1) 244 | b_returns = returns.reshape(-1) 245 | b_values = values.reshape(-1) 246 | 247 | # Optimizing the policy and value network 248 | b_inds = np.arange(args.batch_size) 249 | clipfracs = [] 250 | for epoch in range(args.update_epochs): 251 | np.random.shuffle(b_inds) 252 | for start in range(0, args.batch_size, args.minibatch_size): 253 | end = start + args.minibatch_size 254 | mb_inds = b_inds[start:end] 255 | 256 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) 257 | logratio = newlogprob - b_logprobs[mb_inds] 258 | ratio = logratio.exp() 259 | 260 | with torch.no_grad(): 261 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 262 | old_approx_kl = (-logratio).mean() 263 | approx_kl = ((ratio - 1) - logratio).mean() 264 | 
clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] 265 | 266 | mb_advantages = b_advantages[mb_inds] 267 | if args.norm_adv: 268 | mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) 269 | 270 | # Policy loss 271 | pg_loss1 = -mb_advantages * ratio 272 | pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 273 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 274 | 275 | # Value loss 276 | newvalue = newvalue.view(-1) 277 | if args.clip_vloss: 278 | v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 279 | v_clipped = b_values[mb_inds] + torch.clamp( 280 | newvalue - b_values[mb_inds], 281 | -args.clip_coef, 282 | args.clip_coef, 283 | ) 284 | v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 285 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 286 | v_loss = 0.5 * v_loss_max.mean() 287 | else: 288 | v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() 289 | 290 | entropy_loss = entropy.mean() 291 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 292 | 293 | optimizer.zero_grad() 294 | loss.backward() 295 | gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 296 | optimizer.step() 297 | 298 | if args.target_kl is not None and approx_kl > args.target_kl: 299 | break 300 | 301 | y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() 302 | var_y = np.var(y_true) 303 | explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y 304 | 305 | if global_step_burnin is not None and iteration % 10 == 0: 306 | speed = (global_step - global_step_burnin) / (time.time() - start_time) 307 | pbar.set_description(f"speed: {speed: 4.1f} sps, " + desc) 308 | with torch.no_grad(): 309 | logs = { 310 | "episode_return": np.array(avg_returns).mean(), 311 | "logprobs": b_logprobs.mean(), 312 | "advantages": advantages.mean(), 313 | "returns": returns.mean(), 314 | "values": values.mean(), 315 | "gn": gn, 316 | } 317 | wandb.log( 318 | { 319 | "speed": speed, 320 | **logs, 321 | }, 322 | step=global_step, 323 | ) 324 | 325 | envs.close() 326 | -------------------------------------------------------------------------------- /leanrl/td3_continuous_action_torchcompile.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy 2 | import os 3 | 4 | os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" 5 | 6 | import math 7 | import os 8 | import random 9 | import time 10 | from collections import deque 11 | from dataclasses import dataclass 12 | 13 | import gymnasium as gym 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.optim as optim 19 | import tqdm 20 | import tyro 21 | import wandb 22 | from tensordict import TensorDict, from_module, from_modules 23 | from tensordict.nn import CudaGraphModule 24 | from torchrl.data import LazyTensorStorage, ReplayBuffer 25 | 26 | 27 | @dataclass 28 | class Args: 29 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 30 | """the name of this experiment""" 31 | seed: int = 1 32 | """seed of the experiment""" 33 | torch_deterministic: bool = True 34 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 35 | cuda: bool = True 36 | """if toggled, cuda will be enabled by default""" 37 | capture_video: bool = False 38 | """whether to capture videos of the agent 
performances (check out `videos` folder)""" 39 | 40 | # Algorithm specific arguments 41 | env_id: str = "HalfCheetah-v4" 42 | """the id of the environment""" 43 | total_timesteps: int = 1000000 44 | """total timesteps of the experiments""" 45 | learning_rate: float = 3e-4 46 | """the learning rate of the optimizer""" 47 | buffer_size: int = int(1e6) 48 | """the replay memory buffer size""" 49 | gamma: float = 0.99 50 | """the discount factor gamma""" 51 | tau: float = 0.005 52 | """target smoothing coefficient (default: 0.005)""" 53 | batch_size: int = 256 54 | """the batch size of sample from the reply memory""" 55 | policy_noise: float = 0.2 56 | """the scale of policy noise""" 57 | exploration_noise: float = 0.1 58 | """the scale of exploration noise""" 59 | learning_starts: int = 25e3 60 | """timestep to start learning""" 61 | policy_frequency: int = 2 62 | """the frequency of training policy (delayed)""" 63 | noise_clip: float = 0.5 64 | """noise clip parameter of the Target Policy Smoothing Regularization""" 65 | 66 | measure_burnin: int = 3 67 | """Number of burn-in iterations for speed measure.""" 68 | 69 | compile: bool = False 70 | """whether to use torch.compile.""" 71 | cudagraphs: bool = False 72 | """whether to use cudagraphs on top of compile.""" 73 | 74 | 75 | def make_env(env_id, seed, idx, capture_video, run_name): 76 | def thunk(): 77 | if capture_video and idx == 0: 78 | env = gym.make(env_id, render_mode="rgb_array") 79 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 80 | else: 81 | env = gym.make(env_id) 82 | env = gym.wrappers.RecordEpisodeStatistics(env) 83 | env.action_space.seed(seed) 84 | return env 85 | 86 | return thunk 87 | 88 | 89 | # ALGO LOGIC: initialize agent here: 90 | class QNetwork(nn.Module): 91 | def __init__(self, n_obs, n_act, device=None): 92 | super().__init__() 93 | self.fc1 = nn.Linear(n_obs + n_act, 256, device=device) 94 | self.fc2 = nn.Linear(256, 256, device=device) 95 | self.fc3 = nn.Linear(256, 1, device=device) 96 | 97 | def forward(self, x, a): 98 | x = torch.cat([x, a], 1) 99 | x = F.relu(self.fc1(x)) 100 | x = F.relu(self.fc2(x)) 101 | x = self.fc3(x) 102 | return x 103 | 104 | 105 | class Actor(nn.Module): 106 | def __init__(self, n_obs, n_act, env, exploration_noise=1, device=None): 107 | super().__init__() 108 | self.fc1 = nn.Linear(n_obs, 256, device=device) 109 | self.fc2 = nn.Linear(256, 256, device=device) 110 | self.fc_mu = nn.Linear(256, n_act, device=device) 111 | # action rescaling 112 | self.register_buffer( 113 | "action_scale", 114 | torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32, device=device), 115 | ) 116 | self.register_buffer( 117 | "action_bias", 118 | torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32, device=device), 119 | ) 120 | self.register_buffer("exploration_noise", torch.as_tensor(exploration_noise, device=device)) 121 | 122 | def forward(self, obs): 123 | obs = F.relu(self.fc1(obs)) 124 | obs = F.relu(self.fc2(obs)) 125 | obs = self.fc_mu(obs).tanh() 126 | return obs * self.action_scale + self.action_bias 127 | 128 | def explore(self, obs): 129 | act = self(obs) 130 | return act + torch.randn_like(act).mul(self.action_scale * self.exploration_noise) 131 | 132 | 133 | if __name__ == "__main__": 134 | args = tyro.cli(Args) 135 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" 136 | 137 | wandb.init( 138 | project="td3_continuous_action", 139 | 
name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 140 | config=vars(args), 141 | save_code=True, 142 | ) 143 | 144 | # TRY NOT TO MODIFY: seeding 145 | random.seed(args.seed) 146 | np.random.seed(args.seed) 147 | torch.manual_seed(args.seed) 148 | torch.backends.cudnn.deterministic = args.torch_deterministic 149 | 150 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 151 | 152 | # env setup 153 | envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) 154 | n_act = math.prod(envs.single_action_space.shape) 155 | n_obs = math.prod(envs.single_observation_space.shape) 156 | action_low, action_high = float(envs.single_action_space.low[0]), float(envs.single_action_space.high[0]) 157 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 158 | 159 | actor = Actor(env=envs, n_obs=n_obs, n_act=n_act, device=device, exploration_noise=args.exploration_noise) 160 | actor_detach = Actor(env=envs, n_obs=n_obs, n_act=n_act, device=device, exploration_noise=args.exploration_noise) 161 | # Copy params to actor_detach without grad 162 | from_module(actor).data.to_module(actor_detach) 163 | policy = actor_detach.explore 164 | 165 | def get_params_qnet(): 166 | qf1 = QNetwork(n_obs=n_obs, n_act=n_act, device=device) 167 | qf2 = QNetwork(n_obs=n_obs, n_act=n_act, device=device) 168 | 169 | qnet_params = from_modules(qf1, qf2, as_module=True) 170 | qnet_target_params = qnet_params.data.clone() 171 | 172 | # discard params of net 173 | qnet = QNetwork(n_obs=n_obs, n_act=n_act, device="meta") 174 | qnet_params.to_module(qnet) 175 | 176 | return qnet_params, qnet_target_params, qnet 177 | 178 | def get_params_actor(actor): 179 | target_actor = Actor(env=envs, device="meta", n_act=n_act, n_obs=n_obs) 180 | actor_params = from_module(actor).data 181 | target_actor_params = actor_params.clone() 182 | target_actor_params.to_module(target_actor) 183 | return actor_params, target_actor_params, target_actor 184 | 185 | qnet_params, qnet_target_params, qnet = get_params_qnet() 186 | actor_params, target_actor_params, target_actor = get_params_actor(actor) 187 | 188 | q_optimizer = optim.Adam( 189 | qnet_params.values(include_nested=True, leaves_only=True), 190 | lr=args.learning_rate, 191 | capturable=args.cudagraphs and not args.compile, 192 | ) 193 | actor_optimizer = optim.Adam( 194 | list(actor.parameters()), lr=args.learning_rate, capturable=args.cudagraphs and not args.compile 195 | ) 196 | 197 | envs.single_observation_space.dtype = np.float32 198 | rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device)) 199 | 200 | def batched_qf(params, obs, action, next_q_value=None): 201 | with params.to_module(qnet): 202 | vals = qnet(obs, action) 203 | if next_q_value is not None: 204 | loss_val = F.mse_loss(vals.view(-1), next_q_value) 205 | return loss_val 206 | return vals 207 | 208 | policy_noise = args.policy_noise 209 | noise_clip = args.noise_clip 210 | action_scale = target_actor.action_scale 211 | 212 | def update_main(data): 213 | observations = data["observations"] 214 | next_observations = data["next_observations"] 215 | actions = data["actions"] 216 | rewards = data["rewards"] 217 | dones = data["dones"] 218 | clipped_noise = torch.randn_like(actions) 219 | clipped_noise = clipped_noise.mul(policy_noise).clamp(-noise_clip, noise_clip).mul(action_scale) 220 | 221 | next_state_actions = (target_actor(next_observations) + 
clipped_noise).clamp(action_low, action_high) 222 | 223 | qf_next_target = torch.vmap(batched_qf, (0, None, None))(qnet_target_params, next_observations, next_state_actions) 224 | min_qf_next_target = qf_next_target.min(0).values 225 | next_q_value = rewards.flatten() + (~dones.flatten()).float() * args.gamma * min_qf_next_target.flatten() 226 | 227 | qf_loss = torch.vmap(batched_qf, (0, None, None, None))(qnet_params, observations, actions, next_q_value) 228 | qf_loss = qf_loss.sum(0) 229 | 230 | # optimize the model 231 | q_optimizer.zero_grad() 232 | qf_loss.backward() 233 | q_optimizer.step() 234 | return TensorDict(qf_loss=qf_loss.detach()) 235 | 236 | def update_pol(data): 237 | actor_optimizer.zero_grad() 238 | with qnet_params.data[0].to_module(qnet): 239 | actor_loss = -qnet(data["observations"], actor(data["observations"])).mean() 240 | 241 | actor_loss.backward() 242 | actor_optimizer.step() 243 | return TensorDict(actor_loss=actor_loss.detach()) 244 | 245 | def extend_and_sample(transition): 246 | rb.extend(transition) 247 | return rb.sample(args.batch_size) 248 | 249 | if args.compile: 250 | mode = None # "reduce-overhead" if not args.cudagraphs else None 251 | update_main = torch.compile(update_main, mode=mode) 252 | update_pol = torch.compile(update_pol, mode=mode) 253 | policy = torch.compile(policy, mode=mode) 254 | 255 | if args.cudagraphs: 256 | update_main = CudaGraphModule(update_main, in_keys=[], out_keys=[], warmup=5) 257 | update_pol = CudaGraphModule(update_pol, in_keys=[], out_keys=[], warmup=5) 258 | policy = CudaGraphModule(policy) 259 | 260 | # TRY NOT TO MODIFY: start the game 261 | obs, _ = envs.reset(seed=args.seed) 262 | obs = torch.as_tensor(obs, device=device, dtype=torch.float) 263 | pbar = tqdm.tqdm(range(args.total_timesteps)) 264 | start_time = None 265 | max_ep_ret = -float("inf") 266 | avg_returns = deque(maxlen=20) 267 | desc = "" 268 | 269 | for global_step in pbar: 270 | if global_step == args.measure_burnin + args.learning_starts: 271 | start_time = time.time() 272 | measure_burnin = global_step 273 | 274 | # ALGO LOGIC: put action logic here 275 | if global_step < args.learning_starts: 276 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 277 | else: 278 | actions = policy(obs=obs) 279 | actions = actions.clamp(action_low, action_high).cpu().numpy() 280 | 281 | # TRY NOT TO MODIFY: execute the game and log data. 
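        # Annotation (not in the original source): Gymnasium vector envs return separate `terminations`
        # and `truncations` from step(). The transition built below stores `dones=terminations` only, so
        # episodes cut off by a time limit still bootstrap through the target Q-value, while
        # `final_observation` (handled a few lines further down) supplies the true next observation for
        # truncated episodes.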
282 | next_obs, rewards, terminations, truncations, infos = envs.step(actions) 283 | 284 | # TRY NOT TO MODIFY: record rewards for plotting purposes 285 | if "final_info" in infos: 286 | for info in infos["final_info"]: 287 | r = float(info["episode"]["r"].reshape(())) 288 | max_ep_ret = max(max_ep_ret, r) 289 | avg_returns.append(r) 290 | desc = ( 291 | f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" 292 | ) 293 | 294 | # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` 295 | next_obs = torch.as_tensor(next_obs, device=device, dtype=torch.float) 296 | real_next_obs = next_obs.clone() 297 | if "final_observation" in infos: 298 | real_next_obs[truncations] = torch.as_tensor( 299 | np.asarray(list(infos["final_observation"][truncations]), dtype=np.float32), device=device, dtype=torch.float 300 | ) 301 | # obs = torch.as_tensor(obs, device=device, dtype=torch.float) 302 | transition = TensorDict( 303 | observations=obs, 304 | next_observations=real_next_obs, 305 | actions=torch.as_tensor(actions, device=device, dtype=torch.float), 306 | rewards=torch.as_tensor(rewards, device=device, dtype=torch.float), 307 | terminations=terminations, 308 | dones=terminations, 309 | batch_size=obs.shape[0], 310 | device=device, 311 | ) 312 | 313 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 314 | obs = next_obs 315 | data = extend_and_sample(transition) 316 | 317 | # ALGO LOGIC: training. 318 | if global_step > args.learning_starts: 319 | out_main = update_main(data) 320 | if global_step % args.policy_frequency == 0: 321 | out_main.update(update_pol(data)) 322 | 323 | # update the target networks 324 | # lerp is defined as x' = x + w (y-x), which is equivalent to x' = (1-w) x + w y 325 | qnet_target_params.lerp_(qnet_params.data, args.tau) 326 | target_actor_params.lerp_(actor_params.data, args.tau) 327 | 328 | if global_step % 100 == 0 and start_time is not None: 329 | speed = (global_step - measure_burnin) / (time.time() - start_time) 330 | pbar.set_description(f"{speed: 4.4f} sps, " + desc) 331 | with torch.no_grad(): 332 | logs = { 333 | "episode_return": torch.tensor(avg_returns).mean(), 334 | "actor_loss": out_main["actor_loss"].mean(), 335 | "qf_loss": out_main["qf_loss"].mean(), 336 | } 337 | wandb.log( 338 | { 339 | "speed": speed, 340 | **logs, 341 | }, 342 | step=global_step, 343 | ) 344 | 345 | envs.close() 346 | -------------------------------------------------------------------------------- /leanrl/ppo_atari_envpool.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy 2 | import os 3 | import random 4 | import time 5 | from collections import deque 6 | from dataclasses import dataclass 7 | 8 | import envpool 9 | import gym 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import tqdm 15 | import tyro 16 | import wandb 17 | from torch.distributions.categorical import Categorical 18 | 19 | 20 | @dataclass 21 | class Args: 22 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 23 | """the name of this experiment""" 24 | seed: int = 1 25 | """seed of the experiment""" 26 | torch_deterministic: bool = True 27 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 28 | cuda: bool = True 29 | """if toggled, cuda will be enabled by default""" 30 | capture_video: bool = False 31 | """whether to 
capture videos of the agent performances (check out `videos` folder)""" 32 | 33 | # Algorithm specific arguments 34 | env_id: str = "Breakout-v5" 35 | """the id of the environment""" 36 | total_timesteps: int = 10000000 37 | """total timesteps of the experiments""" 38 | learning_rate: float = 2.5e-4 39 | """the learning rate of the optimizer""" 40 | num_envs: int = 8 41 | """the number of parallel game environments""" 42 | num_steps: int = 128 43 | """the number of steps to run in each environment per policy rollout""" 44 | anneal_lr: bool = True 45 | """Toggle learning rate annealing for policy and value networks""" 46 | gamma: float = 0.99 47 | """the discount factor gamma""" 48 | gae_lambda: float = 0.95 49 | """the lambda for the general advantage estimation""" 50 | num_minibatches: int = 4 51 | """the number of mini-batches""" 52 | update_epochs: int = 4 53 | """the K epochs to update the policy""" 54 | norm_adv: bool = True 55 | """Toggles advantages normalization""" 56 | clip_coef: float = 0.1 57 | """the surrogate clipping coefficient""" 58 | clip_vloss: bool = True 59 | """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" 60 | ent_coef: float = 0.01 61 | """coefficient of the entropy""" 62 | vf_coef: float = 0.5 63 | """coefficient of the value function""" 64 | max_grad_norm: float = 0.5 65 | """the maximum norm for the gradient clipping""" 66 | target_kl: float = None 67 | """the target KL divergence threshold""" 68 | 69 | # to be filled in runtime 70 | batch_size: int = 0 71 | """the batch size (computed in runtime)""" 72 | minibatch_size: int = 0 73 | """the mini-batch size (computed in runtime)""" 74 | num_iterations: int = 0 75 | """the number of iterations (computed in runtime)""" 76 | 77 | measure_burnin: int = 3 78 | """Number of burn-in iterations for speed measure.""" 79 | 80 | 81 | class RecordEpisodeStatistics(gym.Wrapper): 82 | def __init__(self, env, deque_size=100): 83 | super().__init__(env) 84 | self.num_envs = getattr(env, "num_envs", 1) 85 | self.episode_returns = None 86 | self.episode_lengths = None 87 | 88 | def reset(self, **kwargs): 89 | observations = super().reset(**kwargs) 90 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) 91 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 92 | self.lives = np.zeros(self.num_envs, dtype=np.int32) 93 | self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) 94 | self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 95 | return observations 96 | 97 | def step(self, action): 98 | observations, rewards, dones, infos = super().step(action) 99 | self.episode_returns += infos["reward"] 100 | self.episode_lengths += 1 101 | self.returned_episode_returns[:] = self.episode_returns 102 | self.returned_episode_lengths[:] = self.episode_lengths 103 | self.episode_returns *= 1 - infos["terminated"] 104 | self.episode_lengths *= 1 - infos["terminated"] 105 | infos["r"] = self.returned_episode_returns 106 | infos["l"] = self.returned_episode_lengths 107 | return ( 108 | observations, 109 | rewards, 110 | dones, 111 | infos, 112 | ) 113 | 114 | 115 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 116 | torch.nn.init.orthogonal_(layer.weight, std) 117 | torch.nn.init.constant_(layer.bias, bias_const) 118 | return layer 119 | 120 | 121 | class Agent(nn.Module): 122 | def __init__(self, envs): 123 | super().__init__() 124 | self.network = nn.Sequential( 125 | layer_init(nn.Conv2d(4, 32, 8, stride=4)), 126 | nn.ReLU(), 127 
| layer_init(nn.Conv2d(32, 64, 4, stride=2)), 128 | nn.ReLU(), 129 | layer_init(nn.Conv2d(64, 64, 3, stride=1)), 130 | nn.ReLU(), 131 | nn.Flatten(), 132 | layer_init(nn.Linear(64 * 7 * 7, 512)), 133 | nn.ReLU(), 134 | ) 135 | self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01) 136 | self.critic = layer_init(nn.Linear(512, 1), std=1) 137 | 138 | def get_value(self, x): 139 | return self.critic(self.network(x / 255.0)) 140 | 141 | def get_action_and_value(self, x, action=None): 142 | hidden = self.network(x / 255.0) 143 | logits = self.actor(hidden) 144 | probs = Categorical(logits=logits) 145 | if action is None: 146 | action = probs.sample() 147 | return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) 148 | 149 | 150 | if __name__ == "__main__": 151 | args = tyro.cli(Args) 152 | 153 | args.batch_size = int(args.num_envs * args.num_steps) 154 | args.minibatch_size = int(args.batch_size // args.num_minibatches) 155 | args.num_iterations = args.total_timesteps // args.batch_size 156 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}" 157 | 158 | wandb.init( 159 | project="ppo_atari", 160 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 161 | config=vars(args), 162 | save_code=True, 163 | ) 164 | 165 | # TRY NOT TO MODIFY: seeding 166 | random.seed(args.seed) 167 | np.random.seed(args.seed) 168 | torch.manual_seed(args.seed) 169 | torch.backends.cudnn.deterministic = args.torch_deterministic 170 | 171 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 172 | 173 | # env setup 174 | envs = envpool.make( 175 | args.env_id, 176 | env_type="gym", 177 | num_envs=args.num_envs, 178 | episodic_life=True, 179 | reward_clip=True, 180 | seed=args.seed, 181 | ) 182 | envs.num_envs = args.num_envs 183 | envs.single_action_space = envs.action_space 184 | envs.single_observation_space = envs.observation_space 185 | envs = RecordEpisodeStatistics(envs) 186 | assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported" 187 | 188 | agent = Agent(envs).to(device) 189 | optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) 190 | 191 | # ALGO Logic: Storage setup 192 | obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) 193 | actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) 194 | logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) 195 | rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) 196 | dones = torch.zeros((args.num_steps, args.num_envs)).to(device) 197 | values = torch.zeros((args.num_steps, args.num_envs)).to(device) 198 | avg_returns = deque(maxlen=20) 199 | 200 | # TRY NOT TO MODIFY: start the game 201 | global_step = 0 202 | next_obs = torch.Tensor(envs.reset()).to(device) 203 | next_done = torch.zeros(args.num_envs).to(device) 204 | max_ep_ret = -float("inf") 205 | pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) 206 | global_step_burnin = None 207 | start_time = None 208 | desc = "" 209 | 210 | for iteration in pbar: 211 | if iteration == args.measure_burnin: 212 | global_step_burnin = global_step 213 | start_time = time.time() 214 | 215 | # Annealing the rate if instructed to do so. 
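        # Annotation (not in the original source): linear LR annealing — at iteration i the learning rate
        # is learning_rate * (1 - (i - 1) / num_iterations), i.e. it starts at the full value and decays
        # to (roughly) zero on the final iteration.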
216 | if args.anneal_lr: 217 | frac = 1.0 - (iteration - 1.0) / args.num_iterations 218 | lrnow = frac * args.learning_rate 219 | optimizer.param_groups[0]["lr"] = lrnow 220 | 221 | for step in range(0, args.num_steps): 222 | global_step += args.num_envs 223 | obs[step] = next_obs 224 | dones[step] = next_done 225 | 226 | # ALGO LOGIC: action logic 227 | with torch.no_grad(): 228 | action, logprob, _, value = agent.get_action_and_value(next_obs) 229 | values[step] = value.flatten() 230 | actions[step] = action 231 | logprobs[step] = logprob 232 | 233 | # TRY NOT TO MODIFY: execute the game and log data. 234 | next_obs, reward, next_done, info = envs.step(action.cpu().numpy()) 235 | rewards[step] = torch.tensor(reward).to(device).view(-1) 236 | next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) 237 | 238 | for idx, d in enumerate(next_done): 239 | if d and info["lives"][idx] == 0: 240 | r = float(info["r"][idx]) 241 | max_ep_ret = max(max_ep_ret, r) 242 | avg_returns.append(r) 243 | desc = f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" 244 | 245 | # bootstrap value if not done 246 | with torch.no_grad(): 247 | next_value = agent.get_value(next_obs).reshape(1, -1) 248 | advantages = torch.zeros_like(rewards).to(device) 249 | lastgaelam = 0 250 | for t in reversed(range(args.num_steps)): 251 | if t == args.num_steps - 1: 252 | nextnonterminal = 1.0 - next_done 253 | nextvalues = next_value 254 | else: 255 | nextnonterminal = 1.0 - dones[t + 1] 256 | nextvalues = values[t + 1] 257 | delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] 258 | advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam 259 | returns = advantages + values 260 | 261 | # flatten the batch 262 | b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) 263 | b_logprobs = logprobs.reshape(-1) 264 | b_actions = actions.reshape((-1,) + envs.single_action_space.shape) 265 | b_advantages = advantages.reshape(-1) 266 | b_returns = returns.reshape(-1) 267 | b_values = values.reshape(-1) 268 | 269 | # Optimizing the policy and value network 270 | b_inds = np.arange(args.batch_size) 271 | clipfracs = [] 272 | for epoch in range(args.update_epochs): 273 | np.random.shuffle(b_inds) 274 | for start in range(0, args.batch_size, args.minibatch_size): 275 | end = start + args.minibatch_size 276 | mb_inds = b_inds[start:end] 277 | 278 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) 279 | logratio = newlogprob - b_logprobs[mb_inds] 280 | ratio = logratio.exp() 281 | 282 | with torch.no_grad(): 283 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 284 | old_approx_kl = (-logratio).mean() 285 | approx_kl = ((ratio - 1) - logratio).mean() 286 | clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] 287 | 288 | mb_advantages = b_advantages[mb_inds] 289 | if args.norm_adv: 290 | mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) 291 | 292 | # Policy loss 293 | pg_loss1 = -mb_advantages * ratio 294 | pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 295 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 296 | 297 | # Value loss 298 | newvalue = newvalue.view(-1) 299 | if args.clip_vloss: 300 | v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 301 | v_clipped = b_values[mb_inds] + torch.clamp( 302 | 
newvalue - b_values[mb_inds], 303 | -args.clip_coef, 304 | args.clip_coef, 305 | ) 306 | v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 307 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 308 | v_loss = 0.5 * v_loss_max.mean() 309 | else: 310 | v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() 311 | 312 | entropy_loss = entropy.mean() 313 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 314 | 315 | optimizer.zero_grad() 316 | loss.backward() 317 | gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 318 | optimizer.step() 319 | 320 | if args.target_kl is not None and approx_kl > args.target_kl: 321 | break 322 | 323 | y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() 324 | var_y = np.var(y_true) 325 | explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y 326 | 327 | if global_step_burnin is not None and iteration % 10 == 0: 328 | speed = (global_step - global_step_burnin) / (time.time() - start_time) 329 | pbar.set_description(f"speed: {speed: 4.1f} sps, " + desc) 330 | with torch.no_grad(): 331 | logs = { 332 | "episode_return": np.array(avg_returns).mean(), 333 | "logprobs": b_logprobs.mean(), 334 | "advantages": advantages.mean(), 335 | "returns": returns.mean(), 336 | "values": values.mean(), 337 | "gn": gn, 338 | } 339 | wandb.log( 340 | { 341 | "speed": speed, 342 | **logs, 343 | }, 344 | step=global_step, 345 | ) 346 | 347 | envs.close() 348 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LeanRL developers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- 24 | Code in `cleanrl/ddpg_continuous_action.py` and `cleanrl/td3_continuous_action.py` are adapted from https://github.com/sfujim/TD3 25 | 26 | MIT License 27 | 28 | Copyright (c) 2020 Scott Fujimoto 29 | 30 | Permission is hereby granted, free of charge, to any person obtaining a copy 31 | of this software and associated documentation files (the "Software"), to deal 32 | in the Software without restriction, including without limitation the rights 33 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 34 | copies of the Software, and to permit persons to whom the Software is 35 | furnished to do so, subject to the following conditions: 36 | 37 | The above copyright notice and this permission notice shall be included in all 38 | copies or substantial portions of the Software. 39 | 40 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 41 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 42 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 43 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 44 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 45 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 46 | SOFTWARE. 47 | 48 | -------------------------------------------------------------------------------- 49 | Code in `cleanrl/sac_continuous_action.py` is inspired and adapted from [haarnoja/sac](https://github.com/haarnoja/sac), [openai/spinningup](https://github.com/openai/spinningup), [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic), [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3), and [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac). 50 | 51 | - [haarnoja/sac](https://github.com/haarnoja/sac/blob/8258e33633c7e37833cc39315891e77adfbe14b2/LICENSE.txt) 52 | 53 | COPYRIGHT 54 | 55 | All contributions by the University of California: 56 | Copyright (c) 2017, 2018 The Regents of the University of California (Regents) 57 | All rights reserved. 58 | 59 | All other contributions: 60 | Copyright (c) 2017, 2018, the respective contributors 61 | All rights reserved. 62 | 63 | SAC uses a shared copyright model: each contributor holds copyright over 64 | their contributions to the SAC codebase. The project versioning records all such 65 | contribution and copyright details. If a contributor wants to further mark 66 | their specific copyright on a particular contribution, they should indicate 67 | their copyright solely in the commit message of the change when it is 68 | committed. 69 | 70 | LICENSE 71 | 72 | Redistribution and use in source and binary forms, with or without 73 | modification, are permitted provided that the following conditions are met: 74 | 75 | 1. Redistributions of source code must retain the above copyright notice, this 76 | list of conditions and the following disclaimer. 77 | 2. Redistributions in binary form must reproduce the above copyright notice, 78 | this list of conditions and the following disclaimer in the documentation 79 | and/or other materials provided with the distribution. 
80 | 81 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 82 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 83 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 84 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 85 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 86 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 87 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 88 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 89 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 90 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 91 | 92 | CONTRIBUTION AGREEMENT 93 | 94 | By contributing to the SAC repository through pull-request, comment, 95 | or otherwise, the contributor releases their content to the 96 | license and copyright terms herein. 97 | 98 | - [openai/spinningup](https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/LICENSE) 99 | 100 | The MIT License 101 | 102 | Copyright (c) 2018 OpenAI (http://openai.com) 103 | 104 | Permission is hereby granted, free of charge, to any person obtaining a copy 105 | of this software and associated documentation files (the "Software"), to deal 106 | in the Software without restriction, including without limitation the rights 107 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 108 | copies of the Software, and to permit persons to whom the Software is 109 | furnished to do so, subject to the following conditions: 110 | 111 | The above copyright notice and this permission notice shall be included in 112 | all copies or substantial portions of the Software. 113 | 114 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 115 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 116 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 117 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 118 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 119 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 120 | THE SOFTWARE. 121 | 122 | - [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3/blob/44e53ff8115e8f4bff1d5218f10c8c7d1a4cfc12/LICENSE) 123 | 124 | The MIT License 125 | 126 | Copyright (c) 2019 Antonin Raffin 127 | 128 | Permission is hereby granted, free of charge, to any person obtaining a copy 129 | of this software and associated documentation files (the "Software"), to deal 130 | in the Software without restriction, including without limitation the rights 131 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 132 | copies of the Software, and to permit persons to whom the Software is 133 | furnished to do so, subject to the following conditions: 134 | 135 | The above copyright notice and this permission notice shall be included in 136 | all copies or substantial portions of the Software. 137 | 138 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 139 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 140 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 141 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 142 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 143 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 144 | THE SOFTWARE. 145 | 146 | - [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac/blob/81c5b536d3a1c5616b2531e446450df412a064fb/LICENSE) 147 | 148 | MIT License 149 | 150 | Copyright (c) 2019 Denis Yarats 151 | 152 | Permission is hereby granted, free of charge, to any person obtaining a copy 153 | of this software and associated documentation files (the "Software"), to deal 154 | in the Software without restriction, including without limitation the rights 155 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 156 | copies of the Software, and to permit persons to whom the Software is 157 | furnished to do so, subject to the following conditions: 158 | 159 | The above copyright notice and this permission notice shall be included in all 160 | copies or substantial portions of the Software. 161 | 162 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 163 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 164 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 165 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 166 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 167 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 168 | SOFTWARE. 169 | 170 | - [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic/blob/master/LICENSE) 171 | 172 | MIT License 173 | 174 | Copyright (c) 2018 Pranjal Tandon 175 | 176 | Permission is hereby granted, free of charge, to any person obtaining a copy 177 | of this software and associated documentation files (the "Software"), to deal 178 | in the Software without restriction, including without limitation the rights 179 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 180 | copies of the Software, and to permit persons to whom the Software is 181 | furnished to do so, subject to the following conditions: 182 | 183 | The above copyright notice and this permission notice shall be included in all 184 | copies or substantial portions of the Software. 185 | 186 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 187 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 188 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 189 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 190 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 191 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 192 | SOFTWARE. 
193 | 194 | 195 | --------------------------------------------------------------------------------- 196 | The CONTRIBUTING.md is adopted from https://github.com/entity-neural-network/incubator/blob/2a0c38b30828df78c47b0318c76a4905020618dd/CONTRIBUTING.md 197 | and https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md 198 | 199 | MIT License 200 | 201 | Copyright (c) 2021 Entity Neural Network developers 202 | 203 | Permission is hereby granted, free of charge, to any person obtaining a copy 204 | of this software and associated documentation files (the "Software"), to deal 205 | in the Software without restriction, including without limitation the rights 206 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 207 | copies of the Software, and to permit persons to whom the Software is 208 | furnished to do so, subject to the following conditions: 209 | 210 | The above copyright notice and this permission notice shall be included in all 211 | copies or substantial portions of the Software. 212 | 213 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 214 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 215 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 216 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 217 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 218 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 219 | SOFTWARE. 220 | 221 | 222 | 223 | MIT License 224 | 225 | Copyright (c) 2020 Stable-Baselines Team 226 | 227 | Permission is hereby granted, free of charge, to any person obtaining a copy 228 | of this software and associated documentation files (the "Software"), to deal 229 | in the Software without restriction, including without limitation the rights 230 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 231 | copies of the Software, and to permit persons to whom the Software is 232 | furnished to do so, subject to the following conditions: 233 | 234 | The above copyright notice and this permission notice shall be included in all 235 | copies or substantial portions of the Software. 236 | 237 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 238 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 239 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 240 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 241 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 242 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 243 | SOFTWARE. 244 | 245 | 246 | --------------------------------------------------------------------------------- 247 | The cleanrl/ppo_continuous_action_isaacgym.py is contributed by Nvidia 248 | 249 | SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
250 | SPDX-License-Identifier: MIT 251 | 252 | Permission is hereby granted, free of charge, to any person obtaining a 253 | copy of this software and associated documentation files (the "Software"), 254 | to deal in the Software without restriction, including without limitation 255 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 256 | and/or sell copies of the Software, and to permit persons to whom the 257 | Software is furnished to do so, subject to the following conditions: 258 | 259 | The above copyright notice and this permission notice shall be included in 260 | all copies or substantial portions of the Software. 261 | 262 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 263 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 264 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 265 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 266 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 267 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 268 | DEALINGS IN THE SOFTWARE. 269 | 270 | -------------------------------------------------------------------------------- 271 | 272 | Code in `cleanrl/qdagger_dqn_atari_impalacnn.py` and `cleanrl/qdagger_dqn_atari_jax_impalacnn.py` are adapted from https://github.com/google-research/reincarnating_rl 273 | 274 | **NOTE: the original repo did not fill out the copyright section in their license 275 | so the following copyright notice is copied as is per the license requirement. 276 | See https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/LICENSE#L189 277 | 278 | 279 | Copyright [yyyy] [name of copyright owner] 280 | 281 | Licensed under the Apache License, Version 2.0 (the "License"); 282 | you may not use this file except in compliance with the License. 283 | You may obtain a copy of the License at 284 | 285 | http://www.apache.org/licenses/LICENSE-2.0 286 | 287 | Unless required by applicable law or agreed to in writing, software 288 | distributed under the License is distributed on an "AS IS" BASIS, 289 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 290 | See the License for the specific language governing permissions and 291 | limitations under the License. 
292 | -------------------------------------------------------------------------------- /leanrl/sac_continuous_action_torchcompile.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy 2 | import os 3 | 4 | os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" 5 | 6 | import math 7 | import os 8 | import random 9 | import time 10 | from collections import deque 11 | from dataclasses import dataclass 12 | 13 | import gymnasium as gym 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.optim as optim 19 | import tqdm 20 | import tyro 21 | import wandb 22 | from tensordict import TensorDict, from_module, from_modules 23 | from tensordict.nn import CudaGraphModule, TensorDictModule 24 | 25 | # from stable_baselines3.common.buffers import ReplayBuffer 26 | from torchrl.data import LazyTensorStorage, ReplayBuffer 27 | 28 | 29 | @dataclass 30 | class Args: 31 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 32 | """the name of this experiment""" 33 | seed: int = 1 34 | """seed of the experiment""" 35 | torch_deterministic: bool = True 36 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 37 | cuda: bool = True 38 | """if toggled, cuda will be enabled by default""" 39 | capture_video: bool = False 40 | """whether to capture videos of the agent performances (check out `videos` folder)""" 41 | 42 | # Algorithm specific arguments 43 | env_id: str = "HalfCheetah-v4" 44 | """the environment id of the task""" 45 | total_timesteps: int = 1000000 46 | """total timesteps of the experiments""" 47 | buffer_size: int = int(1e6) 48 | """the replay memory buffer size""" 49 | gamma: float = 0.99 50 | """the discount factor gamma""" 51 | tau: float = 0.005 52 | """target smoothing coefficient (default: 0.005)""" 53 | batch_size: int = 256 54 | """the batch size of sample from the reply memory""" 55 | learning_starts: int = 5e3 56 | """timestep to start learning""" 57 | policy_lr: float = 3e-4 58 | """the learning rate of the policy network optimizer""" 59 | q_lr: float = 1e-3 60 | """the learning rate of the Q network network optimizer""" 61 | policy_frequency: int = 2 62 | """the frequency of training policy (delayed)""" 63 | target_network_frequency: int = 1 # Denis Yarats' implementation delays this by 2. 
64 | """the frequency of updates for the target nerworks""" 65 | alpha: float = 0.2 66 | """Entropy regularization coefficient.""" 67 | autotune: bool = True 68 | """automatic tuning of the entropy coefficient""" 69 | 70 | compile: bool = False 71 | """whether to use torch.compile.""" 72 | cudagraphs: bool = False 73 | """whether to use cudagraphs on top of compile.""" 74 | 75 | measure_burnin: int = 3 76 | """Number of burn-in iterations for speed measure.""" 77 | 78 | 79 | def make_env(env_id, seed, idx, capture_video, run_name): 80 | def thunk(): 81 | if capture_video and idx == 0: 82 | env = gym.make(env_id, render_mode="rgb_array") 83 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 84 | else: 85 | env = gym.make(env_id) 86 | env = gym.wrappers.RecordEpisodeStatistics(env) 87 | env.action_space.seed(seed) 88 | return env 89 | 90 | return thunk 91 | 92 | 93 | # ALGO LOGIC: initialize agent here: 94 | class SoftQNetwork(nn.Module): 95 | def __init__(self, env, n_act, n_obs, device=None): 96 | super().__init__() 97 | self.fc1 = nn.Linear(n_act + n_obs, 256, device=device) 98 | self.fc2 = nn.Linear(256, 256, device=device) 99 | self.fc3 = nn.Linear(256, 1, device=device) 100 | 101 | def forward(self, x, a): 102 | x = torch.cat([x, a], 1) 103 | x = F.relu(self.fc1(x)) 104 | x = F.relu(self.fc2(x)) 105 | x = self.fc3(x) 106 | return x 107 | 108 | 109 | LOG_STD_MAX = 2 110 | LOG_STD_MIN = -5 111 | 112 | 113 | class Actor(nn.Module): 114 | def __init__(self, env, n_obs, n_act, device=None): 115 | super().__init__() 116 | self.fc1 = nn.Linear(n_obs, 256, device=device) 117 | self.fc2 = nn.Linear(256, 256, device=device) 118 | self.fc_mean = nn.Linear(256, n_act, device=device) 119 | self.fc_logstd = nn.Linear(256, n_act, device=device) 120 | # action rescaling 121 | self.register_buffer( 122 | "action_scale", 123 | torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32, device=device), 124 | ) 125 | self.register_buffer( 126 | "action_bias", 127 | torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32, device=device), 128 | ) 129 | 130 | def forward(self, x): 131 | x = F.relu(self.fc1(x)) 132 | x = F.relu(self.fc2(x)) 133 | mean = self.fc_mean(x) 134 | log_std = self.fc_logstd(x) 135 | log_std = torch.tanh(log_std) 136 | log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats 137 | 138 | return mean, log_std 139 | 140 | def get_action(self, x): 141 | mean, log_std = self(x) 142 | std = log_std.exp() 143 | normal = torch.distributions.Normal(mean, std) 144 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 145 | y_t = torch.tanh(x_t) 146 | action = y_t * self.action_scale + self.action_bias 147 | log_prob = normal.log_prob(x_t) 148 | # Enforcing Action Bound 149 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) 150 | log_prob = log_prob.sum(1, keepdim=True) 151 | mean = torch.tanh(mean) * self.action_scale + self.action_bias 152 | return action, log_prob, mean 153 | 154 | 155 | if __name__ == "__main__": 156 | args = tyro.cli(Args) 157 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" 158 | 159 | wandb.init( 160 | project="sac_continuous_action", 161 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 162 | config=vars(args), 163 | save_code=True, 164 | ) 165 | 166 | # TRY NOT TO MODIFY: seeding 167 | random.seed(args.seed) 168 | np.random.seed(args.seed) 169 | 
torch.manual_seed(args.seed) 170 | torch.backends.cudnn.deterministic = args.torch_deterministic 171 | 172 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 173 | 174 | # env setup 175 | envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) 176 | n_act = math.prod(envs.single_action_space.shape) 177 | n_obs = math.prod(envs.single_observation_space.shape) 178 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 179 | 180 | max_action = float(envs.single_action_space.high[0]) 181 | 182 | actor = Actor(envs, device=device, n_act=n_act, n_obs=n_obs) 183 | actor_detach = Actor(envs, device=device, n_act=n_act, n_obs=n_obs) 184 | # Copy params to actor_detach without grad 185 | from_module(actor).data.to_module(actor_detach) 186 | policy = TensorDictModule(actor_detach.get_action, in_keys=["observation"], out_keys=["action"]) 187 | 188 | def get_q_params(): 189 | qf1 = SoftQNetwork(envs, device=device, n_act=n_act, n_obs=n_obs) 190 | qf2 = SoftQNetwork(envs, device=device, n_act=n_act, n_obs=n_obs) 191 | qnet_params = from_modules(qf1, qf2, as_module=True) 192 | qnet_target = qnet_params.data.clone() 193 | 194 | # discard params of net 195 | qnet = SoftQNetwork(envs, device="meta", n_act=n_act, n_obs=n_obs) 196 | qnet_params.to_module(qnet) 197 | 198 | return qnet_params, qnet_target, qnet 199 | 200 | qnet_params, qnet_target, qnet = get_q_params() 201 | 202 | q_optimizer = optim.Adam(qnet.parameters(), lr=args.q_lr, capturable=args.cudagraphs and not args.compile) 203 | actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, capturable=args.cudagraphs and not args.compile) 204 | 205 | # Automatic entropy tuning 206 | if args.autotune: 207 | target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item() 208 | log_alpha = torch.zeros(1, requires_grad=True, device=device) 209 | alpha = log_alpha.detach().exp() 210 | a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, capturable=args.cudagraphs and not args.compile) 211 | else: 212 | alpha = torch.as_tensor(args.alpha, device=device) 213 | 214 | envs.single_observation_space.dtype = np.float32 215 | rb = ReplayBuffer(storage=LazyTensorStorage(args.buffer_size, device=device)) 216 | 217 | def batched_qf(params, obs, action, next_q_value=None): 218 | with params.to_module(qnet): 219 | vals = qnet(obs, action) 220 | if next_q_value is not None: 221 | loss_val = F.mse_loss(vals.view(-1), next_q_value) 222 | return loss_val 223 | return vals 224 | 225 | def update_main(data): 226 | # optimize the model 227 | q_optimizer.zero_grad() 228 | with torch.no_grad(): 229 | next_state_actions, next_state_log_pi, _ = actor.get_action(data["next_observations"]) 230 | qf_next_target = torch.vmap(batched_qf, (0, None, None))( 231 | qnet_target, data["next_observations"], next_state_actions 232 | ) 233 | min_qf_next_target = qf_next_target.min(dim=0).values - alpha * next_state_log_pi 234 | next_q_value = data["rewards"].flatten() + ( 235 | ~data["dones"].flatten() 236 | ).float() * args.gamma * min_qf_next_target.view(-1) 237 | 238 | qf_a_values = torch.vmap(batched_qf, (0, None, None, None))( 239 | qnet_params, data["observations"], data["actions"], next_q_value 240 | ) 241 | qf_loss = qf_a_values.sum(0) 242 | 243 | qf_loss.backward() 244 | q_optimizer.step() 245 | return TensorDict(qf_loss=qf_loss.detach()) 246 | 247 | def update_pol(data): 248 | actor_optimizer.zero_grad() 249 | 
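        # Annotation (not in the original source): SAC actor objective — maximize
        # E[min_i Q_i(s, a) - alpha * log pi(a|s)] with a ~ pi(.|s), implemented below by minimizing
        # (alpha * log_pi - min_qf_pi).mean(). The vmapped Q-networks are evaluated with
        # `qnet_params.data` (detached parameters), so this backward pass produces gradients for the
        # actor only.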
pi, log_pi, _ = actor.get_action(data["observations"]) 250 | qf_pi = torch.vmap(batched_qf, (0, None, None))(qnet_params.data, data["observations"], pi) 251 | min_qf_pi = qf_pi.min(0).values 252 | actor_loss = ((alpha * log_pi) - min_qf_pi).mean() 253 | 254 | actor_loss.backward() 255 | actor_optimizer.step() 256 | 257 | if args.autotune: 258 | a_optimizer.zero_grad() 259 | with torch.no_grad(): 260 | _, log_pi, _ = actor.get_action(data["observations"]) 261 | alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean() 262 | 263 | alpha_loss.backward() 264 | a_optimizer.step() 265 | return TensorDict(alpha=alpha.detach(), actor_loss=actor_loss.detach(), alpha_loss=alpha_loss.detach()) 266 | 267 | def extend_and_sample(transition): 268 | rb.extend(transition) 269 | return rb.sample(args.batch_size) 270 | 271 | is_extend_compiled = False 272 | if args.compile: 273 | mode = None # "reduce-overhead" if not args.cudagraphs else None 274 | update_main = torch.compile(update_main, mode=mode) 275 | update_pol = torch.compile(update_pol, mode=mode) 276 | policy = torch.compile(policy, mode=mode) 277 | 278 | if args.cudagraphs: 279 | update_main = CudaGraphModule(update_main, in_keys=[], out_keys=[]) 280 | update_pol = CudaGraphModule(update_pol, in_keys=[], out_keys=[]) 281 | # policy = CudaGraphModule(policy) 282 | 283 | # TRY NOT TO MODIFY: start the game 284 | obs, _ = envs.reset(seed=args.seed) 285 | obs = torch.as_tensor(obs, device=device, dtype=torch.float) 286 | pbar = tqdm.tqdm(range(args.total_timesteps)) 287 | start_time = None 288 | max_ep_ret = -float("inf") 289 | avg_returns = deque(maxlen=20) 290 | desc = "" 291 | 292 | for global_step in pbar: 293 | if global_step == args.measure_burnin + args.learning_starts: 294 | start_time = time.time() 295 | measure_burnin = global_step 296 | 297 | # ALGO LOGIC: put action logic here 298 | if global_step < args.learning_starts: 299 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 300 | else: 301 | actions = policy(obs) 302 | actions = actions.cpu().numpy() 303 | 304 | # TRY NOT TO MODIFY: execute the game and log data. 
305 | next_obs, rewards, terminations, truncations, infos = envs.step(actions) 306 | 307 | # TRY NOT TO MODIFY: record rewards for plotting purposes 308 | if "final_info" in infos: 309 | for info in infos["final_info"]: 310 | r = float(info["episode"]["r"]) 311 | max_ep_ret = max(max_ep_ret, r) 312 | avg_returns.append(r) 313 | desc = ( 314 | f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" 315 | ) 316 | 317 | # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` 318 | next_obs = torch.as_tensor(next_obs, device=device, dtype=torch.float) 319 | real_next_obs = next_obs.clone() 320 | for idx, trunc in enumerate(truncations): 321 | if trunc: 322 | real_next_obs[idx] = torch.as_tensor(infos["final_observation"][idx], device=device, dtype=torch.float) 323 | # obs = torch.as_tensor(obs, device=device, dtype=torch.float) 324 | transition = TensorDict( 325 | observations=obs, 326 | next_observations=real_next_obs, 327 | actions=torch.as_tensor(actions, device=device, dtype=torch.float), 328 | rewards=torch.as_tensor(rewards, device=device, dtype=torch.float), 329 | terminations=terminations, 330 | dones=terminations, 331 | batch_size=obs.shape[0], 332 | device=device, 333 | ) 334 | 335 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 336 | obs = next_obs 337 | data = extend_and_sample(transition) 338 | 339 | # ALGO LOGIC: training. 340 | if global_step > args.learning_starts: 341 | out_main = update_main(data) 342 | if global_step % args.policy_frequency == 0: # TD 3 Delayed update support 343 | for _ in range( 344 | args.policy_frequency 345 | ): # compensate for the delay by doing 'actor_update_interval' instead of 1 346 | out_main.update(update_pol(data)) 347 | 348 | alpha.copy_(log_alpha.detach().exp()) 349 | 350 | # update the target networks 351 | if global_step % args.target_network_frequency == 0: 352 | # lerp is defined as x' = x + w (y-x), which is equivalent to x' = (1-w) x + w y 353 | qnet_target.lerp_(qnet_params.data, args.tau) 354 | 355 | if global_step % 100 == 0 and start_time is not None: 356 | speed = (global_step - measure_burnin) / (time.time() - start_time) 357 | pbar.set_description(f"{speed: 4.4f} sps, " + desc) 358 | with torch.no_grad(): 359 | logs = { 360 | "episode_return": torch.tensor(avg_returns).mean(), 361 | "actor_loss": out_main["actor_loss"].mean(), 362 | "alpha_loss": out_main.get("alpha_loss", 0), 363 | "qf_loss": out_main["qf_loss"].mean(), 364 | } 365 | wandb.log( 366 | { 367 | "speed": speed, 368 | **logs, 369 | }, 370 | step=global_step, 371 | ) 372 | 373 | envs.close() 374 | -------------------------------------------------------------------------------- /leanrl/ppo_continuous_action_torchcompile.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy 2 | import os 3 | 4 | os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" 5 | 6 | import math 7 | import os 8 | import random 9 | import time 10 | from collections import deque 11 | from dataclasses import dataclass 12 | from typing import Tuple 13 | 14 | import gymnasium as gym 15 | import numpy as np 16 | import tensordict 17 | import torch 18 | import torch.nn as nn 19 | import torch.optim as optim 20 | import tqdm 21 | import tyro 22 | import wandb 23 | from tensordict import from_module 24 | from tensordict.nn import CudaGraphModule 25 | from 
torch.distributions.normal import Normal 26 | 27 | 28 | @dataclass 29 | class Args: 30 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 31 | """the name of this experiment""" 32 | seed: int = 1 33 | """seed of the experiment""" 34 | torch_deterministic: bool = True 35 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 36 | cuda: bool = True 37 | """if toggled, cuda will be enabled by default""" 38 | capture_video: bool = False 39 | """whether to capture videos of the agent performances (check out `videos` folder)""" 40 | 41 | # Algorithm specific arguments 42 | env_id: str = "HalfCheetah-v4" 43 | """the id of the environment""" 44 | total_timesteps: int = 1000000 45 | """total timesteps of the experiments""" 46 | learning_rate: float = 3e-4 47 | """the learning rate of the optimizer""" 48 | num_envs: int = 1 49 | """the number of parallel game environments""" 50 | num_steps: int = 2048 51 | """the number of steps to run in each environment per policy rollout""" 52 | anneal_lr: bool = True 53 | """Toggle learning rate annealing for policy and value networks""" 54 | gamma: float = 0.99 55 | """the discount factor gamma""" 56 | gae_lambda: float = 0.95 57 | """the lambda for the general advantage estimation""" 58 | num_minibatches: int = 32 59 | """the number of mini-batches""" 60 | update_epochs: int = 10 61 | """the K epochs to update the policy""" 62 | norm_adv: bool = True 63 | """Toggles advantages normalization""" 64 | clip_coef: float = 0.2 65 | """the surrogate clipping coefficient""" 66 | clip_vloss: bool = True 67 | """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" 68 | ent_coef: float = 0.0 69 | """coefficient of the entropy""" 70 | vf_coef: float = 0.5 71 | """coefficient of the value function""" 72 | max_grad_norm: float = 0.5 73 | """the maximum norm for the gradient clipping""" 74 | target_kl: float = None 75 | """the target KL divergence threshold""" 76 | 77 | # to be filled in runtime 78 | batch_size: int = 0 79 | """the batch size (computed in runtime)""" 80 | minibatch_size: int = 0 81 | """the mini-batch size (computed in runtime)""" 82 | num_iterations: int = 0 83 | """the number of iterations (computed in runtime)""" 84 | 85 | measure_burnin: int = 3 86 | """Number of burn-in iterations for speed measure.""" 87 | 88 | compile: bool = False 89 | """whether to use torch.compile.""" 90 | cudagraphs: bool = False 91 | """whether to use cudagraphs on top of compile.""" 92 | 93 | 94 | def make_env(env_id, idx, capture_video, run_name, gamma): 95 | def thunk(): 96 | if capture_video and idx == 0: 97 | env = gym.make(env_id, render_mode="rgb_array") 98 | env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") 99 | else: 100 | env = gym.make(env_id) 101 | env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space 102 | env = gym.wrappers.RecordEpisodeStatistics(env) 103 | env = gym.wrappers.ClipAction(env) 104 | env = gym.wrappers.NormalizeObservation(env) 105 | env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) 106 | env = gym.wrappers.NormalizeReward(env, gamma=gamma) 107 | env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) 108 | return env 109 | 110 | return thunk 111 | 112 | 113 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 114 | torch.nn.init.orthogonal_(layer.weight, std) 115 | torch.nn.init.constant_(layer.bias, bias_const) 116 | return layer 117 | 118 | 119 | class Agent(nn.Module): 120 | def 
__init__(self, n_obs, n_act, device=None): 121 | super().__init__() 122 | self.critic = nn.Sequential( 123 | layer_init(nn.Linear(n_obs, 64, device=device)), 124 | nn.Tanh(), 125 | layer_init(nn.Linear(64, 64, device=device)), 126 | nn.Tanh(), 127 | layer_init(nn.Linear(64, 1, device=device), std=1.0), 128 | ) 129 | self.actor_mean = nn.Sequential( 130 | layer_init(nn.Linear(n_obs, 64, device=device)), 131 | nn.Tanh(), 132 | layer_init(nn.Linear(64, 64, device=device)), 133 | nn.Tanh(), 134 | layer_init(nn.Linear(64, n_act, device=device), std=0.01), 135 | ) 136 | self.actor_logstd = nn.Parameter(torch.zeros(1, n_act, device=device)) 137 | 138 | def get_value(self, x): 139 | return self.critic(x) 140 | 141 | def get_action_and_value(self, obs, action=None): 142 | action_mean = self.actor_mean(obs) 143 | action_logstd = self.actor_logstd.expand_as(action_mean) 144 | action_std = torch.exp(action_logstd) 145 | probs = Normal(action_mean, action_std) 146 | if action is None: 147 | action = action_mean + action_std * torch.randn_like(action_mean) 148 | return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(obs) 149 | 150 | 151 | def gae(next_obs, next_done, container): 152 | # bootstrap value if not done 153 | next_value = get_value(next_obs).reshape(-1) 154 | lastgaelam = 0 155 | nextnonterminals = (~container["dones"]).float().unbind(0) 156 | vals = container["vals"] 157 | vals_unbind = vals.unbind(0) 158 | rewards = container["rewards"].unbind(0) 159 | 160 | advantages = [] 161 | nextnonterminal = (~next_done).float() 162 | nextvalues = next_value 163 | for t in range(args.num_steps - 1, -1, -1): 164 | cur_val = vals_unbind[t] 165 | delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - cur_val 166 | advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam) 167 | lastgaelam = advantages[-1] 168 | 169 | nextnonterminal = nextnonterminals[t] 170 | nextvalues = cur_val 171 | 172 | advantages = container["advantages"] = torch.stack(list(reversed(advantages))) 173 | container["returns"] = advantages + vals 174 | return container 175 | 176 | 177 | def rollout(obs, done, avg_returns=[]): 178 | ts = [] 179 | for step in range(args.num_steps): 180 | # ALGO LOGIC: action logic 181 | action, logprob, _, value = policy(obs=obs) 182 | 183 | # TRY NOT TO MODIFY: execute the game and log data. 
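        # step_func (defined under __main__) hands the action to the gym vector env
        # as a CPU NumPy array and converts obs/reward/done back into tensors; the
        # raw `infos` dict is only inspected for episode logging and is never stored
        # in the rollout TensorDict.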
184 | next_obs, reward, next_done, infos = step_func(action) 185 | 186 | if "final_info" in infos: 187 | for info in infos["final_info"]: 188 | r = float(info["episode"]["r"].reshape(())) 189 | # max_ep_ret = max(max_ep_ret, r) 190 | avg_returns.append(r) 191 | # desc = f"global_step={global_step}, episodic_return={torch.tensor(avg_returns).mean(): 4.2f} (max={max_ep_ret: 4.2f})" 192 | 193 | ts.append( 194 | tensordict.TensorDict._new_unsafe( 195 | obs=obs, 196 | # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action) 197 | dones=done, 198 | vals=value.flatten(), 199 | actions=action, 200 | logprobs=logprob, 201 | rewards=reward, 202 | batch_size=(args.num_envs,), 203 | ) 204 | ) 205 | 206 | obs = next_obs = next_obs.to(device, non_blocking=True) 207 | done = next_done.to(device, non_blocking=True) 208 | 209 | container = torch.stack(ts, 0).to(device) 210 | return next_obs, done, container 211 | 212 | 213 | def update(obs, actions, logprobs, advantages, returns, vals): 214 | optimizer.zero_grad() 215 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions) 216 | logratio = newlogprob - logprobs 217 | ratio = logratio.exp() 218 | 219 | with torch.no_grad(): 220 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 221 | old_approx_kl = (-logratio).mean() 222 | approx_kl = ((ratio - 1) - logratio).mean() 223 | clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean() 224 | 225 | if args.norm_adv: 226 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 227 | 228 | # Policy loss 229 | pg_loss1 = -advantages * ratio 230 | pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 231 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 232 | 233 | # Value loss 234 | newvalue = newvalue.view(-1) 235 | if args.clip_vloss: 236 | v_loss_unclipped = (newvalue - returns) ** 2 237 | v_clipped = vals + torch.clamp( 238 | newvalue - vals, 239 | -args.clip_coef, 240 | args.clip_coef, 241 | ) 242 | v_loss_clipped = (v_clipped - returns) ** 2 243 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 244 | v_loss = 0.5 * v_loss_max.mean() 245 | else: 246 | v_loss = 0.5 * ((newvalue - returns) ** 2).mean() 247 | 248 | entropy_loss = entropy.mean() 249 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 250 | 251 | loss.backward() 252 | gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 253 | optimizer.step() 254 | 255 | return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn 256 | 257 | 258 | update = tensordict.nn.TensorDictModule( 259 | update, 260 | in_keys=["obs", "actions", "logprobs", "advantages", "returns", "vals"], 261 | out_keys=["approx_kl", "v_loss", "pg_loss", "entropy_loss", "old_approx_kl", "clipfrac", "gn"], 262 | ) 263 | 264 | if __name__ == "__main__": 265 | args = tyro.cli(Args) 266 | 267 | batch_size = int(args.num_envs * args.num_steps) 268 | args.minibatch_size = batch_size // args.num_minibatches 269 | args.batch_size = args.num_minibatches * args.minibatch_size 270 | args.num_iterations = args.total_timesteps // args.batch_size 271 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" 272 | 273 | wandb.init( 274 | project="ppo_continuous_action", 275 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 276 | config=vars(args), 277 | save_code=True, 278 | ) 279 | 280 | # TRY NOT TO MODIFY: seeding 281 | 
random.seed(args.seed) 282 | np.random.seed(args.seed) 283 | torch.manual_seed(args.seed) 284 | torch.backends.cudnn.deterministic = args.torch_deterministic 285 | 286 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 287 | 288 | ####### Environment setup ####### 289 | envs = gym.vector.SyncVectorEnv( 290 | [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)] 291 | ) 292 | n_act = math.prod(envs.single_action_space.shape) 293 | n_obs = math.prod(envs.single_observation_space.shape) 294 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 295 | 296 | # Register step as a special op not to graph break 297 | # @torch.library.custom_op("mylib::step", mutates_args=()) 298 | def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 299 | next_obs_np, reward, terminations, truncations, info = envs.step(action.cpu().numpy()) 300 | next_done = np.logical_or(terminations, truncations) 301 | return torch.as_tensor(next_obs_np, dtype=torch.float), torch.as_tensor(reward), torch.as_tensor(next_done), info 302 | 303 | ####### Agent ####### 304 | agent = Agent(n_obs, n_act, device=device) 305 | # Make a version of agent with detached params 306 | agent_inference = Agent(n_obs, n_act, device=device) 307 | agent_inference_p = from_module(agent).data 308 | agent_inference_p.to_module(agent_inference) 309 | 310 | ####### Optimizer ####### 311 | optimizer = optim.Adam( 312 | agent.parameters(), 313 | lr=torch.tensor(args.learning_rate, device=device), 314 | eps=1e-5, 315 | capturable=args.cudagraphs and not args.compile, 316 | ) 317 | 318 | ####### Executables ####### 319 | # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule 320 | policy = agent_inference.get_action_and_value 321 | get_value = agent_inference.get_value 322 | 323 | # Compile policy 324 | if args.compile: 325 | policy = torch.compile(policy) 326 | gae = torch.compile(gae, fullgraph=True) 327 | update = torch.compile(update) 328 | 329 | if args.cudagraphs: 330 | policy = CudaGraphModule(policy) 331 | gae = CudaGraphModule(gae) 332 | update = CudaGraphModule(update) 333 | 334 | avg_returns = deque(maxlen=20) 335 | global_step = 0 336 | container_local = None 337 | next_obs = torch.tensor(envs.reset()[0], device=device, dtype=torch.float) 338 | next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool) 339 | # max_ep_ret = -float("inf") 340 | pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) 341 | # desc = "" 342 | global_step_burnin = None 343 | for iteration in pbar: 344 | if iteration == args.measure_burnin: 345 | global_step_burnin = global_step 346 | start_time = time.time() 347 | 348 | # Annealing the rate if instructed to do so. 
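        # Note: the optimizer was created with a tensor learning rate (and
        # capturable=True when cudagraphs are used without torch.compile), so the
        # annealing below writes into that tensor with copy_() instead of rebinding
        # param_groups[0]["lr"]; the in-place update is what a captured CUDA graph
        # or compiled step keeps seeing.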
349 | if args.anneal_lr: 350 | frac = 1.0 - (iteration - 1.0) / args.num_iterations 351 | lrnow = frac * args.learning_rate 352 | optimizer.param_groups[0]["lr"].copy_(lrnow) 353 | 354 | torch.compiler.cudagraph_mark_step_begin() 355 | next_obs, next_done, container = rollout(next_obs, next_done, avg_returns=avg_returns) 356 | global_step += container.numel() 357 | 358 | container = gae(next_obs, next_done, container) 359 | container_flat = container.view(-1) 360 | 361 | # Optimizing the policy and value network 362 | clipfracs = [] 363 | for epoch in range(args.update_epochs): 364 | b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size) 365 | for b in b_inds: 366 | container_local = container_flat[b] 367 | 368 | out = update(container_local, tensordict_out=tensordict.TensorDict()) 369 | if args.target_kl is not None and out["approx_kl"] > args.target_kl: 370 | break 371 | else: 372 | continue 373 | break 374 | 375 | if global_step_burnin is not None and iteration % 10 == 0: 376 | speed = (global_step - global_step_burnin) / (time.time() - start_time) 377 | r = container["rewards"].mean() 378 | r_max = container["rewards"].max() 379 | avg_returns_t = torch.tensor(avg_returns).mean() 380 | 381 | with torch.no_grad(): 382 | logs = { 383 | "episode_return": np.array(avg_returns).mean(), 384 | "logprobs": container["logprobs"].mean(), 385 | "advantages": container["advantages"].mean(), 386 | "returns": container["returns"].mean(), 387 | "vals": container["vals"].mean(), 388 | "gn": out["gn"].mean(), 389 | } 390 | 391 | lr = optimizer.param_groups[0]["lr"] 392 | pbar.set_description( 393 | f"speed: {speed: 4.1f} sps, " 394 | f"reward avg: {r :4.2f}, " 395 | f"reward max: {r_max:4.2f}, " 396 | f"returns: {avg_returns_t: 4.2f}," 397 | f"lr: {lr: 4.2f}" 398 | ) 399 | wandb.log( 400 | {"speed": speed, "episode_return": avg_returns_t, "r": r, "r_max": r_max, "lr": lr, **logs}, step=global_step 401 | ) 402 | 403 | envs.close() 404 | -------------------------------------------------------------------------------- /leanrl/ppo_atari_envpool_torchcompile.py: -------------------------------------------------------------------------------- 1 | # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy 2 | import os 3 | 4 | os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" 5 | 6 | import os 7 | import random 8 | import time 9 | from collections import deque 10 | from dataclasses import dataclass 11 | 12 | import envpool 13 | 14 | # import gymnasium as gym 15 | import gym 16 | import numpy as np 17 | import tensordict 18 | import torch 19 | import torch.nn as nn 20 | import torch.optim as optim 21 | import tqdm 22 | import tyro 23 | import wandb 24 | from tensordict import from_module 25 | from tensordict.nn import CudaGraphModule 26 | from torch.distributions.categorical import Categorical, Distribution 27 | 28 | Distribution.set_default_validate_args(False) 29 | 30 | # This is a quick fix while waiting for https://github.com/pytorch/pytorch/pull/138080 to land 31 | Categorical.logits = property(Categorical.__dict__["logits"].wrapped) 32 | Categorical.probs = property(Categorical.__dict__["probs"].wrapped) 33 | 34 | torch.set_float32_matmul_precision("high") 35 | 36 | 37 | @dataclass 38 | class Args: 39 | exp_name: str = os.path.basename(__file__)[: -len(".py")] 40 | """the name of this experiment""" 41 | seed: int = 1 42 | """seed of the experiment""" 43 | torch_deterministic: bool = True 44 | """if toggled, 
`torch.backends.cudnn.deterministic=False`""" 45 | cuda: bool = True 46 | """if toggled, cuda will be enabled by default""" 47 | capture_video: bool = False 48 | """whether to capture videos of the agent performances (check out `videos` folder)""" 49 | 50 | # Algorithm specific arguments 51 | env_id: str = "Breakout-v5" 52 | """the id of the environment""" 53 | total_timesteps: int = 10000000 54 | """total timesteps of the experiments""" 55 | learning_rate: float = 2.5e-4 56 | """the learning rate of the optimizer""" 57 | num_envs: int = 8 58 | """the number of parallel game environments""" 59 | num_steps: int = 128 60 | """the number of steps to run in each environment per policy rollout""" 61 | anneal_lr: bool = True 62 | """Toggle learning rate annealing for policy and value networks""" 63 | gamma: float = 0.99 64 | """the discount factor gamma""" 65 | gae_lambda: float = 0.95 66 | """the lambda for the general advantage estimation""" 67 | num_minibatches: int = 4 68 | """the number of mini-batches""" 69 | update_epochs: int = 4 70 | """the K epochs to update the policy""" 71 | norm_adv: bool = True 72 | """Toggles advantages normalization""" 73 | clip_coef: float = 0.1 74 | """the surrogate clipping coefficient""" 75 | clip_vloss: bool = True 76 | """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" 77 | ent_coef: float = 0.01 78 | """coefficient of the entropy""" 79 | vf_coef: float = 0.5 80 | """coefficient of the value function""" 81 | max_grad_norm: float = 0.5 82 | """the maximum norm for the gradient clipping""" 83 | target_kl: float = None 84 | """the target KL divergence threshold""" 85 | 86 | # to be filled in runtime 87 | batch_size: int = 0 88 | """the batch size (computed in runtime)""" 89 | minibatch_size: int = 0 90 | """the mini-batch size (computed in runtime)""" 91 | num_iterations: int = 0 92 | """the number of iterations (computed in runtime)""" 93 | 94 | measure_burnin: int = 3 95 | """Number of burn-in iterations for speed measure.""" 96 | 97 | compile: bool = False 98 | """whether to use torch.compile.""" 99 | cudagraphs: bool = False 100 | """whether to use cudagraphs on top of compile.""" 101 | 102 | 103 | class RecordEpisodeStatistics(gym.Wrapper): 104 | def __init__(self, env, deque_size=100): 105 | super().__init__(env) 106 | self.num_envs = getattr(env, "num_envs", 1) 107 | self.episode_returns = None 108 | self.episode_lengths = None 109 | 110 | def reset(self, **kwargs): 111 | observations = super().reset(**kwargs) 112 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) 113 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 114 | self.lives = np.zeros(self.num_envs, dtype=np.int32) 115 | self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) 116 | self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 117 | return observations 118 | 119 | def step(self, action): 120 | observations, rewards, dones, infos = super().step(action) 121 | self.episode_returns += infos["reward"] 122 | self.episode_lengths += 1 123 | self.returned_episode_returns[:] = self.episode_returns 124 | self.returned_episode_lengths[:] = self.episode_lengths 125 | self.episode_returns *= 1 - infos["terminated"] 126 | self.episode_lengths *= 1 - infos["terminated"] 127 | infos["r"] = self.returned_episode_returns 128 | infos["l"] = self.returned_episode_lengths 129 | return ( 130 | observations, 131 | rewards, 132 | dones, 133 | infos, 134 | ) 135 | 136 | 137 | def layer_init(layer, 
std=np.sqrt(2), bias_const=0.0): 138 | torch.nn.init.orthogonal_(layer.weight, std) 139 | torch.nn.init.constant_(layer.bias, bias_const) 140 | return layer 141 | 142 | 143 | class Agent(nn.Module): 144 | def __init__(self, envs, device=None): 145 | super().__init__() 146 | self.network = nn.Sequential( 147 | layer_init(nn.Conv2d(4, 32, 8, stride=4, device=device)), 148 | nn.ReLU(), 149 | layer_init(nn.Conv2d(32, 64, 4, stride=2, device=device)), 150 | nn.ReLU(), 151 | layer_init(nn.Conv2d(64, 64, 3, stride=1, device=device)), 152 | nn.ReLU(), 153 | nn.Flatten(), 154 | layer_init(nn.Linear(64 * 7 * 7, 512, device=device)), 155 | nn.ReLU(), 156 | ) 157 | self.actor = layer_init(nn.Linear(512, envs.single_action_space.n, device=device), std=0.01) 158 | self.critic = layer_init(nn.Linear(512, 1, device=device), std=1) 159 | 160 | def get_value(self, x): 161 | return self.critic(self.network(x / 255.0)) 162 | 163 | def get_action_and_value(self, obs, action=None): 164 | hidden = self.network(obs / 255.0) 165 | logits = self.actor(hidden) 166 | probs = Categorical(logits=logits) 167 | if action is None: 168 | action = probs.sample() 169 | return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) 170 | 171 | 172 | def gae(next_obs, next_done, container): 173 | # bootstrap value if not done 174 | next_value = get_value(next_obs).reshape(-1) 175 | lastgaelam = 0 176 | nextnonterminals = (~container["dones"]).float().unbind(0) 177 | vals = container["vals"] 178 | vals_unbind = vals.unbind(0) 179 | rewards = container["rewards"].unbind(0) 180 | 181 | advantages = [] 182 | nextnonterminal = (~next_done).float() 183 | nextvalues = next_value 184 | for t in range(args.num_steps - 1, -1, -1): 185 | cur_val = vals_unbind[t] 186 | delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - cur_val 187 | advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam) 188 | lastgaelam = advantages[-1] 189 | 190 | nextnonterminal = nextnonterminals[t] 191 | nextvalues = cur_val 192 | 193 | advantages = container["advantages"] = torch.stack(list(reversed(advantages))) 194 | container["returns"] = advantages + vals 195 | return container 196 | 197 | 198 | def rollout(obs, done, avg_returns=[]): 199 | ts = [] 200 | for step in range(args.num_steps): 201 | torch.compiler.cudagraph_mark_step_begin() 202 | action, logprob, _, value = policy(obs=obs) 203 | next_obs_np, reward, next_done, info = envs.step(action.cpu().numpy()) 204 | next_obs = torch.as_tensor(next_obs_np) 205 | reward = torch.as_tensor(reward) 206 | next_done = torch.as_tensor(next_done) 207 | 208 | idx = next_done 209 | if idx.any(): 210 | idx = idx & torch.as_tensor(info["lives"] == 0, device=next_done.device, dtype=torch.bool) 211 | if idx.any(): 212 | r = torch.as_tensor(info["r"]) 213 | avg_returns.extend(r[idx]) 214 | 215 | ts.append( 216 | tensordict.TensorDict._new_unsafe( 217 | obs=obs, 218 | # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action) 219 | dones=done, 220 | vals=value.flatten(), 221 | actions=action, 222 | logprobs=logprob, 223 | rewards=reward, 224 | batch_size=(args.num_envs,), 225 | ) 226 | ) 227 | 228 | obs = next_obs = next_obs.to(device, non_blocking=True) 229 | done = next_done.to(device, non_blocking=True) 230 | 231 | container = torch.stack(ts, 0).to(device) 232 | return next_obs, done, container 233 | 234 | 235 | def update(obs, actions, logprobs, advantages, returns, vals): 236 | optimizer.zero_grad() 237 | _, newlogprob, 
entropy, newvalue = agent.get_action_and_value(obs, actions) 238 | logratio = newlogprob - logprobs 239 | ratio = logratio.exp() 240 | 241 | with torch.no_grad(): 242 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 243 | old_approx_kl = (-logratio).mean() 244 | approx_kl = ((ratio - 1) - logratio).mean() 245 | clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean() 246 | 247 | if args.norm_adv: 248 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 249 | 250 | # Policy loss 251 | pg_loss1 = -advantages * ratio 252 | pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 253 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 254 | 255 | # Value loss 256 | newvalue = newvalue.view(-1) 257 | if args.clip_vloss: 258 | v_loss_unclipped = (newvalue - returns) ** 2 259 | v_clipped = vals + torch.clamp( 260 | newvalue - vals, 261 | -args.clip_coef, 262 | args.clip_coef, 263 | ) 264 | v_loss_clipped = (v_clipped - returns) ** 2 265 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 266 | v_loss = 0.5 * v_loss_max.mean() 267 | else: 268 | v_loss = 0.5 * ((newvalue - returns) ** 2).mean() 269 | 270 | entropy_loss = entropy.mean() 271 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 272 | 273 | loss.backward() 274 | gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 275 | optimizer.step() 276 | 277 | return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn 278 | 279 | 280 | update = tensordict.nn.TensorDictModule( 281 | update, 282 | in_keys=["obs", "actions", "logprobs", "advantages", "returns", "vals"], 283 | out_keys=["approx_kl", "v_loss", "pg_loss", "entropy_loss", "old_approx_kl", "clipfrac", "gn"], 284 | ) 285 | 286 | if __name__ == "__main__": 287 | args = tyro.cli(Args) 288 | 289 | batch_size = int(args.num_envs * args.num_steps) 290 | args.minibatch_size = batch_size // args.num_minibatches 291 | args.batch_size = args.num_minibatches * args.minibatch_size 292 | args.num_iterations = args.total_timesteps // args.batch_size 293 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{args.compile}__{args.cudagraphs}" 294 | 295 | wandb.init( 296 | project="ppo_atari", 297 | name=f"{os.path.splitext(os.path.basename(__file__))[0]}-{run_name}", 298 | config=vars(args), 299 | save_code=True, 300 | ) 301 | 302 | # TRY NOT TO MODIFY: seeding 303 | random.seed(args.seed) 304 | np.random.seed(args.seed) 305 | torch.manual_seed(args.seed) 306 | torch.backends.cudnn.deterministic = args.torch_deterministic 307 | 308 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 309 | 310 | ####### Environment setup ####### 311 | envs = envpool.make( 312 | args.env_id, 313 | env_type="gym", 314 | num_envs=args.num_envs, 315 | episodic_life=True, 316 | reward_clip=True, 317 | seed=args.seed, 318 | ) 319 | envs.num_envs = args.num_envs 320 | envs.single_action_space = envs.action_space 321 | envs.single_observation_space = envs.observation_space 322 | envs = RecordEpisodeStatistics(envs) 323 | assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported" 324 | 325 | # def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 326 | # next_obs_np, reward, next_done, info = envs.step(action.cpu().numpy()) 327 | # return torch.as_tensor(next_obs_np), torch.as_tensor(reward), torch.as_tensor(next_done), info 328 | 329 | ####### Agent ####### 330 | 
agent = Agent(envs, device=device) 331 | # Make a version of agent with detached params 332 | agent_inference = Agent(envs, device=device) 333 | agent_inference_p = from_module(agent).data 334 | agent_inference_p.to_module(agent_inference) 335 | 336 | ####### Optimizer ####### 337 | optimizer = optim.Adam( 338 | agent.parameters(), 339 | lr=torch.tensor(args.learning_rate, device=device), 340 | eps=1e-5, 341 | capturable=args.cudagraphs and not args.compile, 342 | ) 343 | 344 | ####### Executables ####### 345 | # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule 346 | policy = agent_inference.get_action_and_value 347 | get_value = agent_inference.get_value 348 | 349 | # Compile policy 350 | if args.compile: 351 | mode = "reduce-overhead" if not args.cudagraphs else None 352 | policy = torch.compile(policy, mode=mode) 353 | gae = torch.compile(gae, fullgraph=True, mode=mode) 354 | update = torch.compile(update, mode=mode) 355 | 356 | if args.cudagraphs: 357 | policy = CudaGraphModule(policy, warmup=20) 358 | #gae = CudaGraphModule(gae, warmup=20) 359 | update = CudaGraphModule(update, warmup=20) 360 | 361 | avg_returns = deque(maxlen=20) 362 | global_step = 0 363 | container_local = None 364 | next_obs = torch.tensor(envs.reset(), device=device, dtype=torch.uint8) 365 | next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool) 366 | max_ep_ret = -float("inf") 367 | pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) 368 | desc = "" 369 | global_step_burnin = None 370 | for iteration in pbar: 371 | if iteration == args.measure_burnin: 372 | global_step_burnin = global_step 373 | start_time = time.time() 374 | 375 | # Annealing the rate if instructed to do so. 376 | if args.anneal_lr: 377 | frac = 1.0 - (iteration - 1.0) / args.num_iterations 378 | lrnow = frac * args.learning_rate 379 | optimizer.param_groups[0]["lr"].copy_(lrnow) 380 | 381 | torch.compiler.cudagraph_mark_step_begin() 382 | next_obs, next_done, container = rollout(next_obs, next_done, avg_returns=avg_returns) 383 | global_step += container.numel() 384 | 385 | torch.compiler.cudagraph_mark_step_begin() 386 | container = gae(next_obs, next_done, container) 387 | container_flat = container.view(-1) 388 | 389 | # Optimizing the policy and value network 390 | clipfracs = [] 391 | for epoch in range(args.update_epochs): 392 | b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size) 393 | for b in b_inds: 394 | container_local = container_flat[b] 395 | 396 | torch.compiler.cudagraph_mark_step_begin() 397 | out = update(container_local, tensordict_out=tensordict.TensorDict()) 398 | if args.target_kl is not None and out["approx_kl"] > args.target_kl: 399 | break 400 | else: 401 | continue 402 | break 403 | 404 | if global_step_burnin is not None and iteration % 10 == 0: 405 | cur_time = time.time() 406 | speed = (global_step - global_step_burnin) / (cur_time - start_time) 407 | global_step_burnin = global_step 408 | start_time = cur_time 409 | 410 | r = container["rewards"].mean() 411 | r_max = container["rewards"].max() 412 | avg_returns_t = torch.tensor(avg_returns).mean() 413 | 414 | with torch.no_grad(): 415 | logs = { 416 | "episode_return": np.array(avg_returns).mean(), 417 | "logprobs": container["logprobs"].mean(), 418 | "advantages": container["advantages"].mean(), 419 | "returns": container["returns"].mean(), 420 | "vals": container["vals"].mean(), 421 | "gn": out["gn"].mean(), 422 | } 423 | 424 | lr = optimizer.param_groups[0]["lr"] 
425 | pbar.set_description( 426 | f"speed: {speed: 4.1f} sps, " 427 | f"reward avg: {r :4.2f}, " 428 | f"reward max: {r_max:4.2f}, " 429 | f"returns: {avg_returns_t: 4.2f}," 430 | f"lr: {lr: 4.2f}" 431 | ) 432 | wandb.log( 433 | {"speed": speed, "episode_return": avg_returns_t, "r": r, "r_max": r_max, "lr": lr, **logs}, step=global_step 434 | ) 435 | 436 | envs.close() 437 | --------------------------------------------------------------------------------
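Both `*_torchcompile.py` scripts above follow the same wrapping recipe: express the update step as a plain function of tensors, lift it into a `tensordict.nn.TensorDictModule` so that it consumes and returns a single `TensorDict`, then optionally pass it through `torch.compile` and capture it with `CudaGraphModule`. The sketch below distills that recipe on a toy regression step so the moving parts are visible in isolation; the linear model, batch shape, hyperparameters, and key names ("obs", "target", "loss") are illustrative assumptions, not anything taken from the repository.

# Minimal sketch of the wrapping recipe used by the *_torchcompile.py scripts.
# The model, shapes and key names are assumptions chosen for illustration only.
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule, TensorDictModule

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_compile = True
use_cudagraphs = device.type == "cuda"

net = nn.Linear(8, 1, device=device)
# capturable=True keeps Adam's state on-device so a CUDA graph can replay the step
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, capturable=use_cudagraphs)


def update(obs, target):
    # one optimization step expressed purely as tensor ops
    optimizer.zero_grad()
    loss = ((net(obs) - target) ** 2).mean()
    loss.backward()
    optimizer.step()
    return loss.detach()


# TensorDict-in / TensorDict-out, mirroring the PPO `update` wrappers above
update = TensorDictModule(update, in_keys=["obs", "target"], out_keys=["loss"])

if use_compile:
    update = torch.compile(update)
if use_cudagraphs:
    update = CudaGraphModule(update, warmup=5)  # a few eager warmup runs before capture

for step in range(10):
    torch.compiler.cudagraph_mark_step_begin()
    batch = TensorDict(
        obs=torch.randn(32, 8, device=device),
        target=torch.randn(32, 1, device=device),
        batch_size=[32],
    )
    out = update(batch, tensordict_out=TensorDict())
    if step % 5 == 0:
        print(f"step {step}: loss={float(out['loss']):.4f}")

Keeping every per-step quantity inside one TensorDict lets a single compiled, graph-captured callable run the whole update without Python-side unpacking, which mirrors how the PPO scripts feed each minibatch slice to `update(container_local, tensordict_out=...)`.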