├── requirements-dev.txt
├── requirements.txt
├── Makefile
├── scripts
│   ├── dt.sh
│   ├── cql.sh
│   ├── iql.sh
│   ├── awac.sh
│   ├── td3bc.sh
│   └── xql.sh
├── LICENSE
├── .gitignore
├── README.md
└── algos
    ├── awac.py
    ├── td3bc.py
    ├── iql.py
    ├── dt.py
    ├── xql.py
    └── cql.py
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | black
2 | blackdoc
3 | flake8
4 | isort
5 | mypy
6 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2 | numpy
3 | pydantic
4 | omegaconf
5 | flax
6 | optax
7 | black
8 | isort
9 | mypy
10 | flake8
11 | wandb
12 | tqdm
13 | git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
14 | gym==0.23.1
15 | cython==0.29.22
16 | mujoco-py==2.1.2.14
17 | distrax
18 | 
19 | 
20 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: install-dev clean format check install uninstall test diff-test
2 | 
3 | 
4 | format:
5 | black algos
6 | blackdoc algos
7 | isort algos
8 | 
9 | check:
10 | black algos --check --diff
11 | blackdoc algos --check
12 | flake8 --config pyproject.toml --ignore E203,E501,W503,E741 algos
13 | mypy --config pyproject.toml algos
14 | isort algos --check --diff
15 | 
16 | push:
17 | git add .
18 | git commit -m "."
19 | git push -u origin HEAD
20 | 
--------------------------------------------------------------------------------
/scripts/dt.sh:
--------------------------------------------------------------------------------
1 | project=dt-report
2 | 
3 | cd .. && cd algos
4 | for seed in 1 2 3 4 5
5 | do
6 | python dt.py env_name=halfcheetah-medium-v2 seed=$seed project=$project
7 | python dt.py env_name=halfcheetah-medium-expert-v2 seed=$seed project=$project
8 | python dt.py env_name=hopper-medium-v2 seed=$seed project=$project
9 | python dt.py env_name=hopper-medium-expert-v2 seed=$seed project=$project
10 | python dt.py env_name=walker2d-medium-v2 seed=$seed project=$project
11 | python dt.py env_name=walker2d-medium-expert-v2 seed=$seed project=$project
12 | done
--------------------------------------------------------------------------------
/scripts/cql.sh:
--------------------------------------------------------------------------------
1 | project=cql-report-
2 | 
3 | cd .. && cd algos
4 | for seed in 1 2 3 4 5
5 | do
6 | python cql.py env_name=halfcheetah-medium-v2 seed=$seed project=$project
7 | python cql.py env_name=halfcheetah-medium-expert-v2 seed=$seed project=$project
8 | python cql.py env_name=hopper-medium-v2 seed=$seed project=$project
9 | python cql.py env_name=hopper-medium-expert-v2 seed=$seed project=$project
10 | python cql.py env_name=walker2d-medium-v2 seed=$seed project=$project
11 | python cql.py env_name=walker2d-medium-expert-v2 seed=$seed project=$project
12 | done
--------------------------------------------------------------------------------
/scripts/iql.sh:
--------------------------------------------------------------------------------
1 | project=iql-report-separate
2 | 
3 | cd .. 
&& cd algos 4 | for seed in 1 2 3 4 5 6 7 8 9 10 5 | do 6 | python iql.py env_name=halfcheetah-medium-v2 seed=$seed project=$project 7 | python iql.py env_name=halfcheetah-medium-expert-v2 seed=$seed project=$project 8 | python iql.py env_name=hopper-medium-v2 seed=$seed project=$project 9 | python iql.py env_name=hopper-medium-expert-v2 seed=$seed project=$project 10 | python iql.py env_name=walker2d-medium-v2 seed=$seed project=$project 11 | python iql.py env_name=walker2d-medium-expert-v2 seed=$seed project=$project 12 | done -------------------------------------------------------------------------------- /scripts/awac.sh: -------------------------------------------------------------------------------- 1 | project=awac-report-separate 2 | 3 | cd .. && cd algos 4 | for seed in 1 2 3 4 5 6 7 8 9 10 5 | do 6 | python awac.py env_name=halfcheetah-medium-v2 seed=$seed project=$project 7 | python awac.py env_name=halfcheetah-medium-expert-v2 seed=$seed project=$project 8 | python awac.py env_name=hopper-medium-v2 seed=$seed project=$project 9 | python awac.py env_name=hopper-medium-expert-v2 seed=$seed project=$project 10 | python awac.py env_name=walker2d-medium-v2 seed=$seed project=$project 11 | python awac.py env_name=walker2d-medium-expert-v2 seed=$seed project=$project 12 | done -------------------------------------------------------------------------------- /scripts/td3bc.sh: -------------------------------------------------------------------------------- 1 | project=td3bc-report-separate 2 | 3 | cd .. && cd algos 4 | for seed in 1 2 3 4 5 6 7 8 9 10 5 | do 6 | python td3bc.py env_name=halfcheetah-medium-v2 seed=$seed project=$project 7 | python td3bc.py env_name=halfcheetah-medium-expert-v2 seed=$seed project=$project 8 | python td3bc.py env_name=hopper-medium-v2 seed=$seed project=$project 9 | python td3bc.py env_name=hopper-medium-expert-v2 seed=$seed project=$project 10 | python td3bc.py env_name=walker2d-medium-v2 seed=$seed project=$project 11 | python td3bc.py env_name=walker2d-medium-expert-v2 seed=$seed project=$project 12 | done -------------------------------------------------------------------------------- /scripts/xql.sh: -------------------------------------------------------------------------------- 1 | project=xql-report 2 | 3 | cd .. 
&& cd algos
4 | for seed in 1 2 3 4 5
5 | do
6 | python xql.py env_name=halfcheetah-medium-v2 seed=$seed project=$project max_clip=7.0 noise=true loss_temp=1.0
7 | python xql.py env_name=halfcheetah-medium-expert-v2 seed=$seed project=$project max_clip=5.0 loss_temp=1.0
8 | python xql.py env_name=hopper-medium-v2 seed=$seed project=$project max_clip=7.0 loss_temp=5.0
9 | python xql.py env_name=hopper-medium-expert-v2 seed=$seed project=$project max_clip=7.0 loss_temp=2.0 sample_random_times=1
10 | python xql.py env_name=walker2d-medium-v2 seed=$seed project=$project max_clip=7.0 loss_temp=10.0
11 | python xql.py env_name=walker2d-medium-expert-v2 seed=$seed project=$project max_clip=5.0 loss_temp=2.0
12 | done
13 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Soichiro Nishimori
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | wandb/
6 | D4RL
7 | fig
8 | tmp
9 | 
10 | 
11 | # C extensions
12 | *.so
13 | 
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 | 
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 | 
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 | 
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 | 
59 | # Translations
60 | *.mo
61 | *.pot
62 | 
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 | 
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 | 
73 | # Scrapy stuff:
74 | .scrapy
75 | 
76 | # Sphinx documentation
77 | docs/_build/
78 | 
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 | 
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 | 
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 | 
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 | 
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 | 
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 | 
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 | 
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 | 
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 | 
124 | # SageMath parsed files
125 | *.sage.py
126 | 
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 | 
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 | 
140 | # Rope project settings
141 | .ropeproject
142 | 
143 | # mkdocs documentation
144 | /site
145 | 
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 | 
151 | # Pyre type checker
152 | .pyre/
153 | 
154 | # pytype static type analyzer
155 | .pytype/
156 | 
157 | # Cython debug symbols
158 | cython_debug/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JAX-CORL
2 | This repository provides a JAX version of [CORL](https://github.com/tinkoff-ai/CORL): clean **single-file** implementations of offline RL algorithms with **solid performance reports**.
3 | - 🌬️ Pursuing **fast** training: speed-ups via JAX functions such as `jit` and `vmap`.
4 | - 🔪 As **simple** as possible: implement only the minimum requirements.
5 | - 💠 Focus on **a few battle-tested algorithms**: see the list [here](https://github.com/nissymori/JAX-CORL/blob/main/README.md#algorithms). 
6 | - 📈 Solid performance reports ([README](https://github.com/nissymori/JAX-CORL?tab=readme-ov-file#reports-for-d4rl-mujoco), [Wiki](https://github.com/nissymori/JAX-CORL/wiki)).
7 | 
8 | JAX-CORL complements the single-file RL ecosystem by covering the offline x JAX combination.
9 | - [CleanRL](https://github.com/vwxyzjn/cleanrl): Online x PyTorch
10 | - [purejaxrl](https://github.com/luchris429/purejaxrl): Online x JAX
11 | - [CORL](https://github.com/tinkoff-ai/CORL): Offline x PyTorch
12 | - **JAX-CORL (ours): Offline x JAX**
13 | 
14 | # Algorithms
15 | |Algorithm|Implementation|Training time (CORL)|Training time (ours)|Wandb report|
16 | |---|---|---|---|---|
17 | |[AWAC](https://arxiv.org/abs/2006.09359)| [algos/awac.py](https://github.com/nissymori/JAX-CORL/blob/main/algos/awac.py) |4.46h|11m (**24x faster**)|[link](https://api.wandb.ai/links/nissymori/mwi235j6) |
18 | |[IQL](https://arxiv.org/abs/2110.06169)| [algos/iql.py](https://github.com/nissymori/JAX-CORL/blob/main/algos/iql.py) |4.08h|9m (**28x faster**)| [link](https://wandb.ai/nissymori/xql-report/reports/XQL-mujoco--VmlldzoxMDY0MDUyNQ?accessToken=nlwulejjkfvxoddnlp0xyl0pcy8zd61aw9cw2od0pp1wlgxe34glftw3gex2v1f4) |
19 | |[TD3+BC](https://arxiv.org/pdf/2106.06860)| [algos/td3bc.py](https://github.com/nissymori/JAX-CORL/blob/main/algos/td3bc.py) |2.47h|9m (**16x faster**)| [link](https://api.wandb.ai/links/nissymori/h21py327) |
20 | |[CQL](https://arxiv.org/abs/2006.04779)| [algos/cql.py](https://github.com/nissymori/JAX-CORL/blob/main/algos/cql.py) |11.52h|56m (**12x faster**)|[link](https://api.wandb.ai/links/nissymori/cnxdwkgf)|
21 | |[XQL](https://arxiv.org/abs/2301.02328)| [algos/xql.py](https://github.com/nissymori/JAX-CORL/blob/main/algos/xql.py) | - | 12m | - |
22 | |[DT](https://arxiv.org/abs/2106.01345) | [algos/dt.py](https://github.com/nissymori/JAX-CORL/blob/main/algos/dt.py) |42m|11m (**4x faster**)|[link](https://api.wandb.ai/links/nissymori/yrpja8if)|
23 | 
24 | Training time is measured for 1,000,000 update steps without evaluation on `halfcheetah-medium-expert-v2` (training time differs little across the [D4RL](https://arxiv.org/abs/2004.07219) MuJoCo environments). Our training time includes the compile time for `jit`. The computations were performed on four [GeForce GTX 1080 Ti GPUs](https://versus.com/en/inno3d-ichill-geforce-gtx-1080-ti-x4). The PyTorch times were measured with the CORL implementations.
25 | 
26 | # Reports for D4RL mujoco
27 | 
28 | ### Normalized Score
29 | Here, we use [D4RL](https://arxiv.org/abs/2004.07219) MuJoCo control tasks as the benchmark. We report the mean and standard deviation of the average normalized score over 5 episodes and 5 seeds.
30 | We plan to extend the verification to other D4RL benchmarks such as AntMaze. For the source of the hyperparameters and the validity of the reported performance, please refer to the [Wiki](https://github.com/nissymori/JAX-CORL/wiki). 
31 | |env|AWAC|IQL|TD3+BC|CQL|XQL|DT|
32 | |---|---|---|---|---|---|---|
33 | |halfcheetah-medium-v2|$41.56 \pm 0.79$|$46.23 \pm 0.23$|$48.12 \pm 0.42$|$48.65 \pm 0.49$|$47.16 \pm 0.16$|$42.63 \pm 0.53$|
34 | |halfcheetah-medium-expert-v2|$76.61 \pm 9.60$|$92.95 \pm 0.79$|$92.99 \pm 0.11$|$53.76 \pm 14.53$|$86.33 \pm 4.89$|$70.63 \pm 14.70$|
35 | |hopper-medium-v2|$51.45 \pm 5.40$|$56.78 \pm 3.50$|$46.51 \pm 4.57$|$77.56 \pm 7.12$|$62.35 \pm 5.42$|$60.85 \pm 6.78$|
36 | |hopper-medium-expert-v2|$51.89 \pm 2.11$|$90.72 \pm 14.80$|$105.47 \pm 5.03$|$90.37 \pm 31.29$|$104.12 \pm 5.39$|$109.07 \pm 4.56$|
37 | |walker2d-medium-v2|$68.12 \pm 12.08$|$77.16 \pm 5.74$|$72.73 \pm 4.66$|$80.16 \pm 4.19$|$83.45 \pm 0.42$|$71.04 \pm 5.64$|
38 | |walker2d-medium-expert-v2|$91.36 \pm 23.13$|$109.08 \pm 0.77$|$109.17 \pm 0.71$|$110.03 \pm 0.72$|$110.06 \pm 0.22$|$99.81 \pm 17.73$|
39 | 
40 | 
41 | # How to use this codebase for your research
42 | This codebase can be used independently as a baseline for D4RL projects. It is also designed to be flexible, allowing users to develop new algorithms or adapt the existing ones to datasets other than D4RL.
43 | 
44 | For researchers interested in using this code for their projects, we provide a detailed explanation of the code's shared structure:
45 | ##### Data structure
46 | 
47 | ```py
48 | class Transition(NamedTuple):
49 |     observations: jnp.ndarray
50 |     actions: jnp.ndarray
51 |     rewards: jnp.ndarray
52 |     next_observations: jnp.ndarray
53 |     dones: jnp.ndarray
54 | 
55 | def get_dataset(...) -> Transition:
56 |     ...
57 |     return dataset
58 | ```
59 | The code includes a `Transition` class, defined as a `NamedTuple`, which contains fields for observations, actions, rewards, next observations, and done flags. The `get_dataset` function is expected to output data in the `Transition` format, making it adaptable to any dataset that conforms to this structure.
60 | 
61 | ##### Trainer class
62 | ```py
63 | class AlgoTrainState(NamedTuple):
64 |     actor: TrainState
65 |     critic: TrainState
66 | 
67 | class Algo(object):
68 |     ...
69 |     def update_actor(self, train_state: AlgoTrainState, batch: Transition, config) -> AlgoTrainState:
70 |         ...
71 |         return train_state
72 | 
73 |     def update_critic(self, train_state: AlgoTrainState, batch: Transition, config) -> AlgoTrainState:
74 |         ...
75 |         return train_state
76 | 
77 |     @partial(jax.jit, static_argnames=("n_jitted_updates",))
78 |     def update_n_times(self, train_state: AlgoTrainState, data, n_jitted_updates, config) -> AlgoTrainState:
79 |         for _ in range(n_jitted_updates):
80 |             batch = data.sample()
81 |             train_state = self.update_actor(train_state, batch, config)
82 |             train_state = self.update_critic(train_state, batch, config)
83 |         return train_state
84 | 
85 | def create_train_state(...) -> AlgoTrainState:
86 |     # initialize models...
87 |     return AlgoTrainState(
88 |         actor=actor,
89 |         critic=critic,
90 |     )
91 | ```
92 | For all algorithms, we have a `TrainState` class (e.g. `TD3BCTrainState` for TD3+BC) which bundles the `flax` train states for all models. The update logic is implemented as methods of the `Algo` classes (e.g. `TD3BC`). Both the `TrainState` and `Algo` classes are versatile and can be used outside of the provided files if the `create_train_state` function is properly implemented to meet the necessary specifications of the `TrainState` class. A short usage sketch follows below.
93 | **Note**: So far, we have not followed this structure for CQL due to technical issues. This will be handled in the near future. 
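##### Example: plugging in a custom dataset
As a usage sketch of the structure above (this example is not part of the repository; the `my_*` arrays are hypothetical placeholders for your own data), a dataset only needs to be cast into `Transition`, after which minibatches can be sampled the same way the provided algorithms do:
```py
import jax
import jax.numpy as jnp

# Cast in-memory arrays (hypothetical `my_*` placeholders) into the shared format.
dataset = Transition(
    observations=jnp.asarray(my_obs, dtype=jnp.float32),
    actions=jnp.asarray(my_actions, dtype=jnp.float32),
    rewards=jnp.asarray(my_rewards, dtype=jnp.float32),
    next_observations=jnp.asarray(my_next_obs, dtype=jnp.float32),
    dones=jnp.asarray(my_dones, dtype=jnp.float32),
)

# Sample a random minibatch; this mirrors the sampling done inside update_n_times.
rng = jax.random.PRNGKey(0)
batch_indices = jax.random.randint(rng, (256,), 0, len(dataset.observations))
batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)  # still a Transition
```
Because `Transition` is a `NamedTuple` (and therefore a JAX pytree), `jax.tree_util.tree_map` indexes every field consistently, so the resulting `batch` can be passed directly to the `update_*` methods.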
94 | 
95 | # See also
96 | **Great offline RL libraries**
97 | - [CORL](https://github.com/tinkoff-ai/CORL): Comprehensive single-file implementations of offline RL algorithms in PyTorch.
98 | 
99 | **Implementations of offline RL algorithms in JAX**
100 | - [jaxrl](https://github.com/ikostrikov/jaxrl): Includes an implementation of [AWAC](https://arxiv.org/abs/2006.09359).
101 | - [JaxCQL](https://github.com/young-geng/JaxCQL): Clean implementation of [CQL](https://arxiv.org/abs/2006.04779).
102 | - [implicit_q_learning](https://github.com/ikostrikov/implicit_q_learning): Official implementation of [IQL](https://arxiv.org/abs/2110.06169).
103 | - [decision-transformer-jax](https://github.com/yun-kwak/decision-transformer-jax): JAX implementation of [Decision Transformer](https://arxiv.org/abs/2106.01345) with Haiku.
104 | - [td3-bc-jax](https://github.com/ethanluoyc/td3_bc_jax): Direct port of the [original implementation](https://github.com/sfujim/TD3_BC) with Haiku.
105 | - [XQL](https://github.com/Div99/XQL): Official implementation.
106 | 
107 | **Single-file implementations**
108 | - [CleanRL](https://github.com/vwxyzjn/cleanrl): High-quality single-file implementations of online RL algorithms in PyTorch.
109 | - [purejaxrl](https://github.com/luchris429/purejaxrl): High-quality single-file implementations of online RL algorithms in JAX.
110 | 
111 | # Cite JAX-CORL
112 | ```
113 | @article{nishimori2024jaxcorl,
114 |   title={JAX-CORL: Clean Single-file Implementations of Offline RL Algorithms in JAX},
115 |   author={Soichiro Nishimori},
116 |   year={2024},
117 |   url={https://github.com/nissymori/JAX-CORL}
118 | }
119 | ```
120 | 
121 | # Credits
122 | - This project is inspired by [CORL](https://github.com/tinkoff-ai/CORL), clean single-file implementations of offline RL algorithms in PyTorch.
123 | - I would like to thank [@JohannesAck](https://github.com/johannesack) for his TD3-BC codebase and helpful advice.
124 | - The IQL implementation is based on [implicit_q_learning](https://github.com/ikostrikov/implicit_q_learning).
125 | - The AWAC implementation is based on [jaxrl](https://github.com/ikostrikov/jaxrl).
126 | - The CQL implementation is based on [JaxCQL](https://github.com/young-geng/JaxCQL).
127 | - The DT implementation is based on [min-decision-transformer](https://github.com/nikhilbarhate99/min-decision-transformer). 
128 | 129 | -------------------------------------------------------------------------------- /algos/awac.py: -------------------------------------------------------------------------------- 1 | # source https://github.com/ikostrikov/jaxrl 2 | # https://arxiv.org/abs/2006.09359 3 | import os 4 | import time 5 | from functools import partial 6 | from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple 7 | 8 | import d4rl 9 | import distrax 10 | import flax 11 | import flax.linen as nn 12 | import gym 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as np 16 | import optax 17 | import tqdm 18 | import wandb 19 | from flax.training.train_state import TrainState 20 | from omegaconf import OmegaConf 21 | from pydantic import BaseModel 22 | 23 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True " 24 | 25 | 26 | class AWACConfig(BaseModel): 27 | # GENERAL 28 | algo: str = "AWAC" 29 | project: str = "train-AWAC" 30 | env_name: str = "halfcheetah-medium-expert-v2" 31 | seed: int = 42 32 | eval_episodes: int = 5 33 | log_interval: int = 100000 34 | eval_interval: int = 100000 35 | batch_size: int = 256 36 | max_steps: int = int(1e6) 37 | n_jitted_updates: int = 8 38 | # DATASET 39 | data_size: int = int(1e6) 40 | normalize_state: bool = False 41 | # NETWORK 42 | actor_hidden_dims: Tuple[int, int] = (256, 256, 256, 256) 43 | critic_hidden_dims: Tuple[int, int] = (256, 256) 44 | actor_lr: float = 3e-4 45 | critic_lr: float = 3e-4 46 | # AWAC SPECIFIC 47 | _lambda: float = 1.0 48 | tau: float = 0.005 49 | discount: float = 0.99 50 | 51 | def __hash__( 52 | self, 53 | ): # make config hashable to be specified as static_argnums in jax.jit. 54 | return hash(self.__repr__()) 55 | 56 | 57 | conf_dict = OmegaConf.from_cli() 58 | config = AWACConfig(**conf_dict) 59 | 60 | 61 | def default_init(scale: Optional[float] = 1.0): 62 | return nn.initializers.variance_scaling(scale, "fan_avg", "uniform") 63 | 64 | 65 | class MLP(nn.Module): 66 | hidden_dims: Sequence[int] 67 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 68 | activate_final: bool = False 69 | kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init() 70 | add_layer_norm: bool = False 71 | layer_norm_final: bool = False 72 | 73 | @nn.compact 74 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 75 | for i, hidden_dims in enumerate(self.hidden_dims): 76 | x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x) 77 | if self.add_layer_norm: # Add layer norm after activation 78 | if self.layer_norm_final or i + 1 < len(self.hidden_dims): 79 | x = nn.LayerNorm()(x) 80 | if ( 81 | i + 1 < len(self.hidden_dims) or self.activate_final 82 | ): # Add activation after layer norm 83 | x = self.activations(x) 84 | return x 85 | 86 | 87 | class DoubleCritic(nn.Module): 88 | hidden_dims: Sequence[int] 89 | 90 | @nn.compact 91 | def __call__( 92 | self, observation: jnp.ndarray, action: jnp.ndarray 93 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 94 | x = jnp.concatenate([observation, action], axis=-1) 95 | q1 = MLP((*self.hidden_dims, 1), add_layer_norm=True)(x) 96 | q2 = MLP((*self.hidden_dims, 1), add_layer_norm=True)(x) 97 | return q1, q2 98 | 99 | 100 | class GaussianPolicy(nn.Module): 101 | hidden_dims: Sequence[int] 102 | action_dim: int 103 | log_std_min: Optional[float] = -20.0 104 | log_std_max: Optional[float] = 2.0 105 | final_fc_init_scale: float = 1e-3 106 | 107 | @nn.compact 108 | def __call__( 109 | self, observations: jnp.ndarray, temperature: float = 1.0 110 | ) -> 
distrax.Distribution: 111 | outputs = MLP( 112 | self.hidden_dims, 113 | activate_final=True, 114 | )(observations) 115 | 116 | means = nn.Dense( 117 | self.action_dim, kernel_init=default_init(self.final_fc_init_scale) 118 | )(outputs) 119 | 120 | log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,)) 121 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) 122 | 123 | distribution = distrax.MultivariateNormalDiag( 124 | loc=means, scale_diag=jnp.exp(log_stds) * temperature 125 | ) 126 | return distribution 127 | 128 | 129 | class Transition(NamedTuple): 130 | observations: jnp.ndarray 131 | actions: jnp.ndarray 132 | rewards: jnp.ndarray 133 | next_observations: jnp.ndarray 134 | dones: jnp.ndarray 135 | 136 | 137 | def get_dataset( 138 | env: gym.Env, config: AWACConfig, clip_to_eps: bool = True, eps: float = 1e-5 139 | ) -> Transition: 140 | dataset = d4rl.qlearning_dataset(env) 141 | 142 | if clip_to_eps: 143 | lim = 1 - eps 144 | dataset["actions"] = np.clip(dataset["actions"], -lim, lim) 145 | 146 | imputed_next_observations = np.roll(dataset["observations"], -1, axis=0) 147 | same_obs = np.all( 148 | np.isclose(imputed_next_observations, dataset["next_observations"], atol=1e-5), 149 | axis=-1, 150 | ) 151 | dones = 1.0 - same_obs.astype(np.float32) 152 | dones[-1] = 1 153 | 154 | dataset = Transition( 155 | observations=jnp.array(dataset["observations"], dtype=jnp.float32), 156 | actions=jnp.array(dataset["actions"], dtype=jnp.float32), 157 | rewards=jnp.array(dataset["rewards"], dtype=jnp.float32), 158 | dones=jnp.array(dones, dtype=jnp.float32), 159 | next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32), 160 | ) 161 | # shuffle data and select the first data_size samples 162 | data_size = min(config.data_size, len(dataset.observations)) 163 | rng = jax.random.PRNGKey(config.seed) 164 | rng, rng_permute, rng_select = jax.random.split(rng, 3) 165 | perm = jax.random.permutation(rng_permute, len(dataset.observations)) 166 | dataset = jax.tree_util.tree_map(lambda x: x[perm], dataset) 167 | assert len(dataset.observations) >= data_size 168 | dataset = jax.tree_util.tree_map(lambda x: x[:data_size], dataset) 169 | # normalize states 170 | obs_mean, obs_std = 0, 1 171 | if config.normalize_state: 172 | obs_mean = dataset.observations.mean(0) 173 | obs_std = dataset.observations.std(0) 174 | dataset = dataset._replace( 175 | observations=(dataset.observations - obs_mean) / (obs_std + 1e-5), 176 | next_observations=(dataset.next_observations - obs_mean) / (obs_std + 1e-5), 177 | ) 178 | return dataset, obs_mean, obs_std 179 | 180 | 181 | def target_update( 182 | model: TrainState, target_model: TrainState, tau: float 183 | ) -> Tuple[TrainState, jnp.ndarray]: 184 | new_target_params = jax.tree_util.tree_map( 185 | lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params 186 | ) 187 | return target_model.replace(params=new_target_params) 188 | 189 | 190 | def update_by_loss_grad( 191 | train_state: TrainState, loss_fn: Callable 192 | ) -> Tuple[float, Any]: 193 | grad_fn = jax.value_and_grad(loss_fn) 194 | loss, grad = grad_fn(train_state.params) 195 | new_train_state = train_state.apply_gradients(grads=grad) 196 | return new_train_state, loss 197 | 198 | 199 | class AWACTrainState(NamedTuple): 200 | rng: jax.random.PRNGKey 201 | critic: TrainState 202 | target_critic: TrainState 203 | actor: TrainState 204 | 205 | 206 | class AWAC(object): 207 | 208 | @classmethod 209 | def update_actor( 210 | self, 211 | 
train_state: AWACTrainState, 212 | batch: Transition, 213 | rng: jax.random.PRNGKey, 214 | config: AWACConfig, 215 | ) -> Tuple["AWACTrainState", jnp.ndarray]: 216 | def get_actor_loss(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray: 217 | dist = train_state.actor.apply_fn(actor_params, batch.observations) 218 | pi_actions = dist.sample(seed=rng) 219 | q_1, q_2 = train_state.critic.apply_fn( 220 | train_state.critic.params, batch.observations, pi_actions 221 | ) 222 | v = jnp.minimum(q_1, q_2) 223 | 224 | lim = 1 - 1e-5 225 | actions = jnp.clip(batch.actions, -lim, lim) 226 | q_1, q_2 = train_state.critic.apply_fn( 227 | train_state.critic.params, batch.observations, actions 228 | ) 229 | q = jnp.minimum(q_1, q_2) 230 | adv = q - v 231 | weights = jnp.exp(adv / config._lambda) 232 | 233 | weights = jax.lax.stop_gradient(weights) 234 | 235 | log_prob = dist.log_prob(batch.actions) 236 | loss = -jnp.mean(log_prob * weights).mean() 237 | return loss 238 | 239 | new_actor, actor_loss = update_by_loss_grad(train_state.actor, get_actor_loss) 240 | return train_state._replace(actor=new_actor), actor_loss 241 | 242 | @classmethod 243 | def update_critic( 244 | self, 245 | train_state: AWACTrainState, 246 | batch: Transition, 247 | rng: jax.random.PRNGKey, 248 | config: AWACConfig, 249 | ) -> Tuple["AWACTrainState", jnp.ndarray]: 250 | def get_critic_loss( 251 | critic_params: flax.core.FrozenDict[str, Any] 252 | ) -> jnp.ndarray: 253 | dist = train_state.actor.apply_fn( 254 | train_state.actor.params, batch.observations 255 | ) 256 | next_actions = dist.sample(seed=rng) 257 | n_q_1, n_q_2 = train_state.target_critic.apply_fn( 258 | train_state.target_critic.params, batch.next_observations, next_actions 259 | ) 260 | next_q = jnp.minimum(n_q_1, n_q_2) 261 | q_target = batch.rewards + config.discount * (1 - batch.dones) * next_q 262 | q_target = jax.lax.stop_gradient(q_target) 263 | 264 | q_1, q_2 = train_state.critic.apply_fn( 265 | critic_params, batch.observations, batch.actions 266 | ) 267 | 268 | loss = jnp.mean((q_1 - q_target) ** 2 + (q_2 - q_target) ** 2) 269 | return loss 270 | 271 | new_critic, critic_loss = update_by_loss_grad( 272 | train_state.critic, get_critic_loss 273 | ) 274 | return train_state._replace(critic=new_critic), critic_loss 275 | 276 | @classmethod 277 | def update_n_times( 278 | self, 279 | train_state: AWACTrainState, 280 | dataset: Transition, 281 | rng: jax.random.PRNGKey, 282 | config: AWACConfig, 283 | ) -> Tuple["AWACTrainState", Dict]: 284 | for _ in range(config.n_jitted_updates): 285 | rng, batch_rng, critic_rng, actor_rng = jax.random.split(rng, 4) 286 | batch_indices = jax.random.randint( 287 | batch_rng, (config.batch_size,), 0, len(dataset.observations) 288 | ) 289 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 290 | 291 | train_state, critic_loss = self.update_critic( 292 | train_state, batch, critic_rng, config 293 | ) 294 | new_target_critic = target_update( 295 | train_state.critic, 296 | train_state.target_critic, 297 | config.tau, 298 | ) 299 | train_state, actor_loss = self.update_actor( 300 | train_state, batch, actor_rng, config 301 | ) 302 | return train_state._replace(target_critic=new_target_critic), { 303 | "critic_loss": critic_loss, 304 | "actor_loss": actor_loss, 305 | } 306 | 307 | @classmethod 308 | def get_action( 309 | self, 310 | train_state: AWACTrainState, 311 | observations: np.ndarray, 312 | seed: jax.random.PRNGKey, 313 | temperature: float = 1.0, 314 | max_action: float = 1.0, # In D4RL envs, 
the action space is [-1, 1] 315 | ) -> jnp.ndarray: 316 | actions = train_state.actor.apply_fn( 317 | train_state.actor.params, observations=observations, temperature=temperature 318 | ).sample(seed=seed) 319 | actions = jnp.clip(actions, -max_action, max_action) 320 | return actions 321 | 322 | 323 | def create_awac_train_state( 324 | rng: jax.random.PRNGKey, 325 | observations: jnp.ndarray, 326 | actions: jnp.ndarray, 327 | config: AWACConfig, 328 | ) -> AWACTrainState: 329 | rng, actor_rng, critic_rng, value_rng = jax.random.split(rng, 4) 330 | # initialize actor 331 | action_dim = actions.shape[-1] 332 | actor_model = GaussianPolicy( 333 | config.actor_hidden_dims, 334 | action_dim=action_dim, 335 | ) 336 | actor = TrainState.create( 337 | apply_fn=actor_model.apply, 338 | params=actor_model.init(actor_rng, observations), 339 | tx=optax.adam(learning_rate=config.actor_lr), 340 | ) 341 | # initialize critic 342 | critic_model = DoubleCritic(config.critic_hidden_dims) 343 | critic = TrainState.create( 344 | apply_fn=critic_model.apply, 345 | params=critic_model.init(critic_rng, observations, actions), 346 | tx=optax.adam(learning_rate=config.critic_lr), 347 | ) 348 | # initialize target critic 349 | target_critic = TrainState.create( 350 | apply_fn=critic_model.apply, 351 | params=critic_model.init(critic_rng, observations, actions), 352 | tx=optax.adam(learning_rate=config.critic_lr), 353 | ) 354 | return AWACTrainState( 355 | rng, 356 | critic=critic, 357 | target_critic=target_critic, 358 | actor=actor, 359 | ) 360 | 361 | 362 | def evaluate( 363 | policy_fn: Callable, 364 | env: gym.Env, 365 | num_episodes: int, 366 | obs_mean: float, 367 | obs_std: float, 368 | ) -> float: 369 | episode_returns = [] 370 | for _ in range(num_episodes): 371 | episode_return = 0 372 | observation, done = env.reset(), False 373 | while not done: 374 | observation = (observation - obs_mean) / obs_std 375 | action = policy_fn(observations=observation) 376 | observation, reward, done, info = env.step(action) 377 | episode_return += reward 378 | episode_returns.append(episode_return) 379 | return env.get_normalized_score(np.mean(episode_returns)) * 100 380 | 381 | 382 | if __name__ == "__main__": 383 | wandb.init(config=config, project=config.project) 384 | rng = jax.random.PRNGKey(config.seed) 385 | env = gym.make(config.env_name) 386 | dataset, obs_mean, obs_std = get_dataset(env, config) 387 | # create train_state 388 | rng, subkey = jax.random.split(rng) 389 | example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset) 390 | train_state: AWACTrainState = create_awac_train_state( 391 | subkey, 392 | example_batch.observations, 393 | example_batch.actions, 394 | config, 395 | ) 396 | algo = AWAC() 397 | update_fn = jax.jit(algo.update_n_times, static_argnums=(3,)) 398 | act_fn = jax.jit(algo.get_action) 399 | 400 | num_steps = config.max_steps // config.n_jitted_updates 401 | eval_interval = config.eval_interval // config.n_jitted_updates 402 | start = time.time() 403 | for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True): 404 | rng, subkey = jax.random.split(rng) 405 | train_state, update_info = update_fn( 406 | train_state, 407 | dataset, 408 | subkey, 409 | config, 410 | ) 411 | if i % config.log_interval == 0: 412 | train_metrics = {f"training/{k}": v for k, v in update_info.items()} 413 | wandb.log(train_metrics, step=i) 414 | 415 | if i % eval_interval == 0: 416 | policy_fn = partial( 417 | act_fn, 418 | temperature=0.0, 419 | seed=jax.random.PRNGKey(0), 420 
| train_state=train_state, 421 | ) 422 | normalized_score = evaluate( 423 | policy_fn, env, config.eval_episodes, obs_mean, obs_std 424 | ) 425 | print(i, normalized_score) 426 | eval_metrics = {f"{config.env_name}/normalized_score": normalized_score} 427 | wandb.log(eval_metrics, step=i) 428 | # final evaluation 429 | policy_fn = partial( 430 | act_fn, 431 | temperature=0.0, 432 | seed=jax.random.PRNGKey(0), 433 | train_state=train_state, 434 | ) 435 | normalized_score = evaluate(policy_fn, env, config.eval_episodes, obs_mean, obs_std) 436 | print("Final evaluation score", normalized_score) 437 | wandb.log({f"{config.env_name}/final_normalized_score": normalized_score}) 438 | wandb.finish() 439 | -------------------------------------------------------------------------------- /algos/td3bc.py: -------------------------------------------------------------------------------- 1 | # source https://github.com/sfujim/TD3_BC 2 | # https://arxiv.org/abs/2106.06860 3 | import os 4 | import time 5 | from functools import partial 6 | from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple 7 | 8 | import d4rl 9 | import distrax 10 | import flax 11 | import flax.linen as nn 12 | import gym 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as np 16 | import optax 17 | import tqdm 18 | import wandb 19 | from flax.training.train_state import TrainState 20 | from omegaconf import OmegaConf 21 | from pydantic import BaseModel 22 | 23 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 24 | 25 | 26 | class TD3BCConfig(BaseModel): 27 | # GENERAL 28 | algo: str = "TD3-BC" 29 | project: str = "train-TD3-BC" 30 | env_name: str = "halfcheetah-medium-expert-v2" 31 | seed: int = 42 32 | eval_episodes: int = 5 33 | log_interval: int = 100000 34 | eval_interval: int = 100000 35 | batch_size: int = 256 36 | max_steps: int = int(1e6) 37 | n_jitted_updates: int = 8 38 | # DATASET 39 | data_size: int = int(1e6) 40 | normalize_state: bool = True 41 | # NETWORK 42 | hidden_dims: Sequence[int] = (256, 256) 43 | critic_lr: float = 1e-3 44 | actor_lr: float = 1e-3 45 | # TD3-BC SPECIFIC 46 | policy_freq: int = 2 # update actor every policy_freq updates 47 | alpha: float = 2.5 # BC loss weight 48 | policy_noise_std: float = 0.2 # std of policy noise 49 | policy_noise_clip: float = 0.5 # clip policy noise 50 | tau: float = 0.005 # target network update rate 51 | discount: float = 0.99 # discount factor 52 | 53 | def __hash__( 54 | self, 55 | ): # make config hashable to be specified as static_argnums in jax.jit. 
56 | return hash(self.__repr__()) 57 | 58 | 59 | conf_dict = OmegaConf.from_cli() 60 | config = TD3BCConfig(**conf_dict) 61 | 62 | 63 | def default_init(scale: Optional[float] = jnp.sqrt(2)): 64 | return nn.initializers.orthogonal(scale) 65 | 66 | 67 | class MLP(nn.Module): 68 | hidden_dims: Sequence[int] 69 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 70 | activate_final: bool = False 71 | kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init() 72 | layer_norm: bool = False 73 | 74 | @nn.compact 75 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 76 | for i, hidden_dims in enumerate(self.hidden_dims): 77 | x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x) 78 | if i + 1 < len(self.hidden_dims) or self.activate_final: 79 | if self.layer_norm: # Add layer norm after activation 80 | if i + 1 < len(self.hidden_dims): 81 | x = nn.LayerNorm()(x) 82 | x = self.activations(x) 83 | return x 84 | 85 | 86 | class DoubleCritic(nn.Module): 87 | hidden_dims: Sequence[int] 88 | 89 | @nn.compact 90 | def __call__( 91 | self, observation: jnp.ndarray, action: jnp.ndarray 92 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 93 | x = jnp.concatenate([observation, action], axis=-1) 94 | q1 = MLP((*self.hidden_dims, 1), layer_norm=True)(x) 95 | q2 = MLP((*self.hidden_dims, 1), layer_norm=True)(x) 96 | return q1, q2 97 | 98 | 99 | class TD3Actor(nn.Module): 100 | hidden_dims: Sequence[int] 101 | action_dim: int 102 | max_action: float = 1.0 # In D4RL, action is scaled to [-1, 1] 103 | 104 | @nn.compact 105 | def __call__(self, observation: jnp.ndarray) -> jnp.ndarray: 106 | action = MLP((*self.hidden_dims, self.action_dim))(observation) 107 | action = self.max_action * jnp.tanh( 108 | action 109 | ) # scale to [-max_action, max_action] 110 | return action 111 | 112 | 113 | class Transition(NamedTuple): 114 | observations: jnp.ndarray 115 | actions: jnp.ndarray 116 | rewards: jnp.ndarray 117 | next_observations: jnp.ndarray 118 | dones: jnp.ndarray 119 | 120 | 121 | def get_dataset( 122 | env: gym.Env, config: TD3BCConfig, clip_to_eps: bool = True, eps: float = 1e-5 123 | ) -> Transition: 124 | dataset = d4rl.qlearning_dataset(env) 125 | 126 | if clip_to_eps: 127 | lim = 1 - eps 128 | dataset["actions"] = np.clip(dataset["actions"], -lim, lim) 129 | 130 | imputed_next_observations = np.roll(dataset["observations"], -1, axis=0) 131 | same_obs = np.all( 132 | np.isclose(imputed_next_observations, dataset["next_observations"], atol=1e-5), 133 | axis=-1, 134 | ) 135 | dones = 1.0 - same_obs.astype(np.float32) 136 | dones[-1] = 1 137 | 138 | dataset = Transition( 139 | observations=jnp.array(dataset["observations"], dtype=jnp.float32), 140 | actions=jnp.array(dataset["actions"], dtype=jnp.float32), 141 | rewards=jnp.array(dataset["rewards"], dtype=jnp.float32), 142 | dones=jnp.array(dones, dtype=jnp.float32), 143 | next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32), 144 | ) 145 | # shuffle data and select the first data_size samples 146 | data_size = min(config.data_size, len(dataset.observations)) 147 | rng = jax.random.PRNGKey(config.seed) 148 | rng, rng_permute, rng_select = jax.random.split(rng, 3) 149 | perm = jax.random.permutation(rng_permute, len(dataset.observations)) 150 | dataset = jax.tree_util.tree_map(lambda x: x[perm], dataset) 151 | assert len(dataset.observations) >= data_size 152 | dataset = jax.tree_util.tree_map(lambda x: x[:data_size], dataset) 153 | # normalize states 154 | obs_mean, obs_std = 0, 1 155 | if 
config.normalize_state: 156 | obs_mean = dataset.observations.mean(0) 157 | obs_std = dataset.observations.std(0) 158 | dataset = dataset._replace( 159 | observations=(dataset.observations - obs_mean) / (obs_std + 1e-5), 160 | next_observations=(dataset.next_observations - obs_mean) / (obs_std + 1e-5), 161 | ) 162 | return dataset, obs_mean, obs_std 163 | 164 | 165 | def target_update( 166 | model: TrainState, target_model: TrainState, tau: float 167 | ) -> TrainState: 168 | new_target_params = jax.tree_util.tree_map( 169 | lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params 170 | ) 171 | return target_model.replace(params=new_target_params) 172 | 173 | 174 | def update_by_loss_grad( 175 | train_state: TrainState, loss_fn: Callable 176 | ) -> Tuple[TrainState, jnp.ndarray]: 177 | grad_fn = jax.value_and_grad(loss_fn) 178 | loss, grad = grad_fn(train_state.params) 179 | new_train_state = train_state.apply_gradients(grads=grad) 180 | return new_train_state, loss 181 | 182 | 183 | class TD3BCTrainState(NamedTuple): 184 | actor: TrainState 185 | critic: TrainState 186 | target_actor: TrainState 187 | target_critic: TrainState 188 | max_action: float = 1.0 189 | 190 | 191 | class TD3BC(object): 192 | @classmethod 193 | def update_actor( 194 | self, 195 | train_state: TD3BCTrainState, 196 | batch: Transition, 197 | rng: jax.random.PRNGKey, 198 | config: TD3BCConfig, 199 | ) -> Tuple["TD3BCTrainState", jnp.ndarray]: 200 | def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray: 201 | predicted_action = train_state.actor.apply_fn( 202 | actor_params, batch.observations 203 | ) 204 | critic_params = jax.lax.stop_gradient(train_state.critic.params) 205 | q_value, _ = train_state.critic.apply_fn( 206 | critic_params, batch.observations, predicted_action 207 | ) 208 | 209 | mean_abs_q = jax.lax.stop_gradient(jnp.abs(q_value).mean()) 210 | loss_lambda = config.alpha / mean_abs_q 211 | 212 | bc_loss = jnp.square(predicted_action - batch.actions).mean() 213 | loss_actor = -1.0 * q_value.mean() * loss_lambda + bc_loss 214 | return loss_actor 215 | 216 | new_actor, actor_loss = update_by_loss_grad(train_state.actor, actor_loss_fn) 217 | return train_state._replace(actor=new_actor), actor_loss 218 | 219 | @classmethod 220 | def update_critic( 221 | self, 222 | train_state: TD3BCTrainState, 223 | batch: Transition, 224 | rng: jax.random.PRNGKey, 225 | config: TD3BCConfig, 226 | ) -> Tuple["TD3BCTrainState", jnp.ndarray]: 227 | def critic_loss_fn( 228 | critic_params: flax.core.FrozenDict[str, Any] 229 | ) -> jnp.ndarray: 230 | q_pred_1, q_pred_2 = train_state.critic.apply_fn( 231 | critic_params, batch.observations, batch.actions 232 | ) 233 | target_next_action = train_state.target_actor.apply_fn( 234 | train_state.target_actor.params, batch.next_observations 235 | ) 236 | policy_noise = ( 237 | config.policy_noise_std 238 | * train_state.max_action 239 | * jax.random.normal(rng, batch.actions.shape) 240 | ) 241 | target_next_action = target_next_action + policy_noise.clip( 242 | -config.policy_noise_clip, config.policy_noise_clip 243 | ) 244 | target_next_action = target_next_action.clip( 245 | -train_state.max_action, train_state.max_action 246 | ) 247 | q_next_1, q_next_2 = train_state.target_critic.apply_fn( 248 | train_state.target_critic.params, 249 | batch.next_observations, 250 | target_next_action, 251 | ) 252 | target = batch.rewards[..., None] + config.discount * jnp.minimum( 253 | q_next_1, q_next_2 254 | ) * (1 - batch.dones[..., None]) 255 | target 
= jax.lax.stop_gradient(target) # stop gradient for target 256 | value_loss_1 = jnp.square(q_pred_1 - target) 257 | value_loss_2 = jnp.square(q_pred_2 - target) 258 | value_loss = (value_loss_1 + value_loss_2).mean() 259 | return value_loss 260 | 261 | new_critic, critic_loss = update_by_loss_grad( 262 | train_state.critic, critic_loss_fn 263 | ) 264 | return train_state._replace(critic=new_critic), critic_loss 265 | 266 | @classmethod 267 | def update_n_times( 268 | self, 269 | train_state: TD3BCTrainState, 270 | data: Transition, 271 | rng: jax.random.PRNGKey, 272 | config: TD3BCConfig, 273 | ) -> Tuple["TD3BCTrainState", Dict]: 274 | for _ in range( 275 | config.n_jitted_updates 276 | ): # we can jit for roop for static unroll 277 | rng, batch_rng = jax.random.split(rng, 2) 278 | batch_idx = jax.random.randint( 279 | batch_rng, (config.batch_size,), 0, len(data.observations) 280 | ) 281 | batch: Transition = jax.tree_util.tree_map(lambda x: x[batch_idx], data) 282 | rng, critic_rng, actor_rng = jax.random.split(rng, 3) 283 | train_state, critic_loss = self.update_critic( 284 | train_state, batch, critic_rng, config 285 | ) 286 | if _ % config.policy_freq == 0: 287 | train_state, actor_loss = self.update_actor( 288 | train_state, batch, actor_rng, config 289 | ) 290 | new_target_critic = target_update( 291 | train_state.critic, train_state.target_critic, config.tau 292 | ) 293 | new_target_actor = target_update( 294 | train_state.actor, train_state.target_actor, config.tau 295 | ) 296 | train_state = train_state._replace( 297 | target_critic=new_target_critic, 298 | target_actor=new_target_actor, 299 | ) 300 | return train_state, { 301 | "critic_loss": critic_loss, 302 | "actor_loss": actor_loss, 303 | } 304 | 305 | @classmethod 306 | def get_action( 307 | self, 308 | train_state: TD3BCTrainState, 309 | obs: jnp.ndarray, 310 | max_action: float = 1.0, # In D4RL, action is scaled to [-1, 1] 311 | ) -> jnp.ndarray: 312 | action = train_state.actor.apply_fn(train_state.actor.params, obs) 313 | action = action.clip(-max_action, max_action) 314 | return action 315 | 316 | 317 | def create_td3bc_train_state( 318 | rng: jax.random.PRNGKey, 319 | observations: jnp.ndarray, 320 | actions: jnp.ndarray, 321 | config: TD3BCConfig, 322 | ) -> TD3BCTrainState: 323 | critic_model = DoubleCritic( 324 | hidden_dims=config.hidden_dims, 325 | ) 326 | action_dim = actions.shape[-1] 327 | actor_model = TD3Actor( 328 | action_dim=action_dim, 329 | hidden_dims=config.hidden_dims, 330 | ) 331 | rng, critic_rng, actor_rng = jax.random.split(rng, 3) 332 | # initialize critic 333 | critic_train_state: TrainState = TrainState.create( 334 | apply_fn=critic_model.apply, 335 | params=critic_model.init(critic_rng, observations, actions), 336 | tx=optax.adam(config.critic_lr), 337 | ) 338 | target_critic_train_state: TrainState = TrainState.create( 339 | apply_fn=critic_model.apply, 340 | params=critic_model.init(critic_rng, observations, actions), 341 | tx=optax.adam(config.critic_lr), 342 | ) 343 | # initialize actor 344 | actor_train_state: TrainState = TrainState.create( 345 | apply_fn=actor_model.apply, 346 | params=actor_model.init(actor_rng, observations), 347 | tx=optax.adam(config.actor_lr), 348 | ) 349 | target_actor_train_state: TrainState = TrainState.create( 350 | apply_fn=actor_model.apply, 351 | params=actor_model.init(actor_rng, observations), 352 | tx=optax.adam(config.actor_lr), 353 | ) 354 | return TD3BCTrainState( 355 | actor=actor_train_state, 356 | critic=critic_train_state, 357 | 
target_actor=target_actor_train_state, 358 | target_critic=target_critic_train_state, 359 | ) 360 | 361 | 362 | def evaluate( 363 | policy_fn: Callable[[jnp.ndarray], jnp.ndarray], 364 | env: gym.Env, 365 | num_episodes: int, 366 | obs_mean, 367 | obs_std, 368 | ) -> float: # D4RL specific 369 | episode_returns = [] 370 | for _ in range(num_episodes): 371 | episode_return = 0 372 | observation, done = env.reset(), False 373 | while not done: 374 | observation = (observation - obs_mean) / obs_std 375 | action = policy_fn(obs=observation) 376 | observation, reward, done, info = env.step(action) 377 | episode_return += reward 378 | episode_returns.append(episode_return) 379 | return env.get_normalized_score(np.mean(episode_returns)) * 100 380 | 381 | 382 | if __name__ == "__main__": 383 | wandb.init(project=config.project, config=config) 384 | env = gym.make(config.env_name) 385 | rng = jax.random.PRNGKey(config.seed) 386 | dataset, obs_mean, obs_std = get_dataset(env, config) 387 | # create train_state 388 | rng, subkey = jax.random.split(rng) 389 | example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset) 390 | train_state = create_td3bc_train_state( 391 | subkey, example_batch.observations, example_batch.actions, config 392 | ) 393 | algo = TD3BC() 394 | update_fn = jax.jit(algo.update_n_times, static_argnums=(3,)) 395 | act_fn = jax.jit(algo.get_action) 396 | 397 | num_steps = config.max_steps // config.n_jitted_updates 398 | eval_interval = config.eval_interval // config.n_jitted_updates 399 | for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True): 400 | rng, update_rng = jax.random.split(rng) 401 | train_state, update_info = update_fn( 402 | train_state, 403 | dataset, 404 | update_rng, 405 | config, 406 | ) # update parameters 407 | if i % config.log_interval == 0: 408 | train_metrics = {f"training/{k}": v for k, v in update_info.items()} 409 | wandb.log(train_metrics, step=i) 410 | 411 | if i % eval_interval == 0: 412 | policy_fn = partial(act_fn, train_state=train_state) 413 | normalized_score = evaluate( 414 | policy_fn, 415 | env, 416 | num_episodes=config.eval_episodes, 417 | obs_mean=obs_mean, 418 | obs_std=obs_std, 419 | ) 420 | print(i, normalized_score) 421 | eval_metrics = {f"{config.env_name}/normalized_score": normalized_score} 422 | wandb.log(eval_metrics, step=i) 423 | # final evaluation 424 | policy_fn = partial(act_fn, train_state=train_state) 425 | normalized_score = evaluate( 426 | policy_fn, 427 | env, 428 | num_episodes=config.eval_episodes, 429 | obs_mean=obs_mean, 430 | obs_std=obs_std, 431 | ) 432 | print("Final Evaluation Score:", normalized_score) 433 | wandb.log({f"{config.env_name}/final_normalized_score": normalized_score}) 434 | wandb.finish() 435 | -------------------------------------------------------------------------------- /algos/iql.py: -------------------------------------------------------------------------------- 1 | # source https://github.com/ikostrikov/implicit_q_learning 2 | # https://arxiv.org/abs/2110.06169 3 | import os 4 | import time 5 | from functools import partial 6 | from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple 7 | 8 | import d4rl 9 | import distrax 10 | import flax 11 | import flax.linen as nn 12 | import gym 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as np 16 | import optax 17 | import tqdm 18 | import wandb 19 | from flax.training.train_state import TrainState 20 | from omegaconf import OmegaConf 21 | from pydantic import BaseModel 22 | 
23 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 24 | 25 | 26 | class IQLConfig(BaseModel): 27 | # GENERAL 28 | algo: str = "IQL" 29 | project: str = "train-IQL" 30 | env_name: str = "halfcheetah-medium-expert-v2" 31 | seed: int = 42 32 | eval_episodes: int = 5 33 | log_interval: int = 100000 34 | eval_interval: int = 100000 35 | batch_size: int = 256 36 | max_steps: int = int(1e6) 37 | n_jitted_updates: int = 8 38 | # DATASET 39 | data_size: int = int(1e6) 40 | normalize_state: bool = False 41 | normalize_reward: bool = True 42 | # NETWORK 43 | hidden_dims: Tuple[int, int] = (256, 256) 44 | actor_lr: float = 3e-4 45 | value_lr: float = 3e-4 46 | critic_lr: float = 3e-4 47 | layer_norm: bool = True 48 | opt_decay_schedule: bool = True 49 | # IQL SPECIFIC 50 | expectile: float = ( 51 | 0.7 # FYI: for Hopper-me, 0.5 produce better result. (antmaze: expectile=0.9) 52 | ) 53 | beta: float = ( 54 | 3.0 # FYI: for Hopper-me, 6.0 produce better result. (antmaze: beta=10.0) 55 | ) 56 | tau: float = 0.005 57 | discount: float = 0.99 58 | 59 | def __hash__( 60 | self, 61 | ): # make config hashable to be specified as static_argnums in jax.jit. 62 | return hash(self.__repr__()) 63 | 64 | 65 | conf_dict = OmegaConf.from_cli() 66 | config = IQLConfig(**conf_dict) 67 | 68 | 69 | def default_init(scale: Optional[float] = jnp.sqrt(2)): 70 | return nn.initializers.orthogonal(scale) 71 | 72 | 73 | class MLP(nn.Module): 74 | hidden_dims: Sequence[int] 75 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 76 | activate_final: bool = False 77 | kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init() 78 | layer_norm: bool = False 79 | 80 | @nn.compact 81 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 82 | for i, hidden_dims in enumerate(self.hidden_dims): 83 | x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x) 84 | if i + 1 < len(self.hidden_dims) or self.activate_final: 85 | if self.layer_norm: # Add layer norm after activation 86 | x = nn.LayerNorm()(x) 87 | x = self.activations(x) 88 | return x 89 | 90 | 91 | class Critic(nn.Module): 92 | hidden_dims: Sequence[int] 93 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 94 | 95 | @nn.compact 96 | def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: 97 | inputs = jnp.concatenate([observations, actions], -1) 98 | critic = MLP((*self.hidden_dims, 1), activations=self.activations)(inputs) 99 | return jnp.squeeze(critic, -1) 100 | 101 | 102 | def ensemblize(cls, num_qs, out_axes=0, **kwargs): 103 | split_rngs = kwargs.pop("split_rngs", {}) 104 | return nn.vmap( 105 | cls, 106 | variable_axes={"params": 0}, 107 | split_rngs={**split_rngs, "params": True}, 108 | in_axes=None, 109 | out_axes=out_axes, 110 | axis_size=num_qs, 111 | **kwargs, 112 | ) 113 | 114 | 115 | class ValueCritic(nn.Module): 116 | hidden_dims: Sequence[int] 117 | layer_norm: bool = False 118 | 119 | @nn.compact 120 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 121 | critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm)(observations) 122 | return jnp.squeeze(critic, -1) 123 | 124 | 125 | class GaussianPolicy(nn.Module): 126 | hidden_dims: Sequence[int] 127 | action_dim: int 128 | log_std_min: Optional[float] = -5.0 129 | log_std_max: Optional[float] = 2 130 | 131 | @nn.compact 132 | def __call__( 133 | self, observations: jnp.ndarray, temperature: float = 1.0 134 | ) -> distrax.Distribution: 135 | outputs = MLP( 136 | self.hidden_dims, 137 | activate_final=True, 138 
| )(observations) 139 | 140 | means = nn.Dense( 141 | self.action_dim, kernel_init=default_init() 142 | )(outputs) 143 | log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,)) 144 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) 145 | 146 | distribution = distrax.MultivariateNormalDiag( 147 | loc=means, scale_diag=jnp.exp(log_stds) * temperature 148 | ) 149 | return distribution 150 | 151 | 152 | class Transition(NamedTuple): 153 | observations: jnp.ndarray 154 | actions: jnp.ndarray 155 | rewards: jnp.ndarray 156 | next_observations: jnp.ndarray 157 | dones: jnp.ndarray 158 | dones_float: jnp.ndarray 159 | 160 | 161 | def get_normalization(dataset: Transition) -> float: 162 | # into numpy.ndarray 163 | dataset = jax.tree_util.tree_map(lambda x: np.array(x), dataset) 164 | returns = [] 165 | ret = 0 166 | for r, term in zip(dataset.rewards, dataset.dones_float): 167 | ret += r 168 | if term: 169 | returns.append(ret) 170 | ret = 0 171 | return (max(returns) - min(returns)) / 1000 172 | 173 | 174 | def get_dataset( 175 | env: gym.Env, config: IQLConfig, clip_to_eps: bool = True, eps: float = 1e-5 176 | ) -> Transition: 177 | dataset = d4rl.qlearning_dataset(env) 178 | 179 | if clip_to_eps: 180 | lim = 1 - eps 181 | dataset["actions"] = np.clip(dataset["actions"], -lim, lim) 182 | 183 | dones_float = np.zeros_like(dataset['rewards']) 184 | 185 | for i in range(len(dones_float) - 1): 186 | if np.linalg.norm(dataset['observations'][i + 1] - 187 | dataset['next_observations'][i] 188 | ) > 1e-6 or dataset['terminals'][i] == 1.0: 189 | dones_float[i] = 1 190 | else: 191 | dones_float[i] = 0 192 | dones_float[-1] = 1 193 | 194 | dataset = Transition( 195 | observations=jnp.array(dataset["observations"], dtype=jnp.float32), 196 | actions=jnp.array(dataset["actions"], dtype=jnp.float32), 197 | rewards=jnp.array(dataset["rewards"], dtype=jnp.float32), 198 | next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32), 199 | dones=jnp.array(dataset["terminals"], dtype=jnp.float32), 200 | dones_float=jnp.array(dones_float, dtype=jnp.float32), 201 | ) 202 | if "antmaze" in config.env_name: 203 | dataset = dataset._replace( 204 | rewards=dataset.rewards - 1.0 205 | ) 206 | # normalize states 207 | obs_mean, obs_std = 0, 1 208 | if config.normalize_state: 209 | obs_mean = dataset.observations.mean(0) 210 | obs_std = dataset.observations.std(0) 211 | dataset = dataset._replace( 212 | observations=(dataset.observations - obs_mean) / (obs_std + 1e-5), 213 | next_observations=(dataset.next_observations - obs_mean) / (obs_std + 1e-5), 214 | ) 215 | # normalize rewards 216 | if config.normalize_reward: 217 | normalizing_factor = get_normalization(dataset) 218 | dataset = dataset._replace(rewards=dataset.rewards / normalizing_factor) 219 | 220 | # shuffle data and select the first data_size samples 221 | data_size = min(config.data_size, len(dataset.observations)) 222 | rng = jax.random.PRNGKey(config.seed) 223 | rng, rng_permute, rng_select = jax.random.split(rng, 3) 224 | perm = jax.random.permutation(rng_permute, len(dataset.observations)) 225 | dataset = jax.tree_util.tree_map(lambda x: x[perm], dataset) 226 | assert len(dataset.observations) >= data_size 227 | dataset = jax.tree_util.tree_map(lambda x: x[:data_size], dataset) 228 | return dataset, obs_mean, obs_std 229 | 230 | 231 | def expectile_loss(diff, expectile=0.8) -> jnp.ndarray: 232 | weight = jnp.where(diff > 0, expectile, (1 - expectile)) 233 | return weight * (diff**2) 234 | 235 | 236 | 
def target_update( 237 | model: TrainState, target_model: TrainState, tau: float 238 | ) -> TrainState: 239 | new_target_params = jax.tree_util.tree_map( 240 | lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params 241 | ) 242 | return target_model.replace(params=new_target_params) 243 | 244 | 245 | def update_by_loss_grad( 246 | train_state: TrainState, loss_fn: Callable 247 | ) -> Tuple[TrainState, jnp.ndarray]: 248 | grad_fn = jax.value_and_grad(loss_fn) 249 | loss, grad = grad_fn(train_state.params) 250 | new_train_state = train_state.apply_gradients(grads=grad) 251 | return new_train_state, loss 252 | 253 | 254 | class IQLTrainState(NamedTuple): 255 | rng: jax.random.PRNGKey 256 | critic: TrainState 257 | target_critic: TrainState 258 | value: TrainState 259 | actor: TrainState 260 | 261 | 262 | class IQL(object): 263 | 264 | @classmethod 265 | def update_critic( 266 | self, train_state: IQLTrainState, batch: Transition, config: IQLConfig 267 | ) -> Tuple["IQLTrainState", Dict]: 268 | next_v = train_state.value.apply_fn( 269 | train_state.value.params, batch.next_observations 270 | ) 271 | target_q = batch.rewards + config.discount * (1 - batch.dones) * next_v 272 | 273 | def critic_loss_fn( 274 | critic_params: flax.core.FrozenDict[str, Any] 275 | ) -> jnp.ndarray: 276 | q1, q2 = train_state.critic.apply_fn( 277 | critic_params, batch.observations, batch.actions 278 | ) 279 | critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean() 280 | return critic_loss 281 | 282 | new_critic, critic_loss = update_by_loss_grad( 283 | train_state.critic, critic_loss_fn 284 | ) 285 | return train_state._replace(critic=new_critic), critic_loss 286 | 287 | @classmethod 288 | def update_value( 289 | self, train_state: IQLTrainState, batch: Transition, config: IQLConfig 290 | ) -> Tuple["IQLTrainState", Dict]: 291 | q1, q2 = train_state.target_critic.apply_fn( 292 | train_state.target_critic.params, batch.observations, batch.actions 293 | ) 294 | q = jax.lax.stop_gradient(jnp.minimum(q1, q2)) 295 | def value_loss_fn(value_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray: 296 | v = train_state.value.apply_fn(value_params, batch.observations) 297 | value_loss = expectile_loss(q - v, config.expectile).mean() 298 | return value_loss 299 | 300 | new_value, value_loss = update_by_loss_grad(train_state.value, value_loss_fn) 301 | return train_state._replace(value=new_value), value_loss 302 | 303 | @classmethod 304 | def update_actor( 305 | self, train_state: IQLTrainState, batch: Transition, config: IQLConfig 306 | ) -> Tuple["IQLTrainState", Dict]: 307 | v = train_state.value.apply_fn(train_state.value.params, batch.observations) 308 | q1, q2 = train_state.critic.apply_fn( 309 | train_state.target_critic.params, batch.observations, batch.actions 310 | ) 311 | q = jnp.minimum(q1, q2) 312 | exp_a = jnp.exp((q - v) * config.beta) 313 | exp_a = jnp.minimum(exp_a, 100.0) 314 | def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray: 315 | dist = train_state.actor.apply_fn(actor_params, batch.observations) 316 | log_probs = dist.log_prob(batch.actions) 317 | actor_loss = -(exp_a * log_probs).mean() 318 | return actor_loss 319 | 320 | new_actor, actor_loss = update_by_loss_grad(train_state.actor, actor_loss_fn) 321 | return train_state._replace(actor=new_actor), actor_loss 322 | 323 | @classmethod 324 | def update_n_times( 325 | self, 326 | train_state: IQLTrainState, 327 | dataset: Transition, 328 | rng: jax.random.PRNGKey, 329 | config: IQLConfig, 330 | ) 
-> Tuple["IQLTrainState", Dict]: 331 | for _ in range(config.n_jitted_updates): 332 | rng, subkey = jax.random.split(rng) 333 | batch_indices = jax.random.randint( 334 | subkey, (config.batch_size,), 0, len(dataset.observations) 335 | ) 336 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 337 | 338 | train_state, value_loss = self.update_value(train_state, batch, config) 339 | train_state, actor_loss = self.update_actor(train_state, batch, config) 340 | train_state, critic_loss = self.update_critic(train_state, batch, config) 341 | new_target_critic = target_update( 342 | train_state.critic, train_state.target_critic, config.tau 343 | ) 344 | train_state = train_state._replace(target_critic=new_target_critic) 345 | return train_state, { 346 | "value_loss": value_loss, 347 | "actor_loss": actor_loss, 348 | "critic_loss": critic_loss, 349 | } 350 | 351 | @classmethod 352 | def get_action( 353 | self, 354 | train_state: IQLTrainState, 355 | observations: np.ndarray, 356 | seed: jax.random.PRNGKey, 357 | temperature: float = 1.0, 358 | max_action: float = 1.0, # In D4RL, the action space is [-1, 1] 359 | ) -> jnp.ndarray: 360 | actions = train_state.actor.apply_fn( 361 | train_state.actor.params, observations, temperature=temperature 362 | ).sample(seed=seed) 363 | actions = jnp.clip(actions, -max_action, max_action) 364 | return actions 365 | 366 | 367 | def create_iql_train_state( 368 | rng: jax.random.PRNGKey, 369 | observations: jnp.ndarray, 370 | actions: jnp.ndarray, 371 | config: IQLConfig, 372 | ) -> IQLTrainState: 373 | rng, actor_rng, critic_rng, value_rng = jax.random.split(rng, 4) 374 | # initialize actor 375 | action_dim = actions.shape[-1] 376 | actor_model = GaussianPolicy( 377 | config.hidden_dims, 378 | action_dim=action_dim, 379 | log_std_min=-5.0, 380 | ) 381 | if config.opt_decay_schedule: 382 | schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps) 383 | actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn)) 384 | else: 385 | actor_tx = optax.adam(learning_rate=config.actor_lr) 386 | actor = TrainState.create( 387 | apply_fn=actor_model.apply, 388 | params=actor_model.init(actor_rng, observations), 389 | tx=actor_tx, 390 | ) 391 | # initialize critic 392 | critic_model = ensemblize(Critic, num_qs=2)(config.hidden_dims) 393 | critic = TrainState.create( 394 | apply_fn=critic_model.apply, 395 | params=critic_model.init(critic_rng, observations, actions), 396 | tx=optax.adam(learning_rate=config.critic_lr), 397 | ) 398 | target_critic = TrainState.create( 399 | apply_fn=critic_model.apply, 400 | params=critic_model.init(critic_rng, observations, actions), 401 | tx=optax.adam(learning_rate=config.critic_lr), 402 | ) 403 | # initialize value 404 | value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm) 405 | value = TrainState.create( 406 | apply_fn=value_model.apply, 407 | params=value_model.init(value_rng, observations), 408 | tx=optax.adam(learning_rate=config.value_lr), 409 | ) 410 | return IQLTrainState( 411 | rng, 412 | critic=critic, 413 | target_critic=target_critic, 414 | value=value, 415 | actor=actor, 416 | ) 417 | 418 | 419 | def evaluate( 420 | policy_fn, env: gym.Env, num_episodes: int, obs_mean: float, obs_std: float 421 | ) -> float: 422 | episode_returns = [] 423 | for _ in range(num_episodes): 424 | episode_return = 0 425 | observation, done = env.reset(), False 426 | while not done: 427 | observation = (observation - obs_mean) / (obs_std + 1e-5) 428 | action = 
policy_fn(observations=observation) 429 | observation, reward, done, info = env.step(action) 430 | episode_return += reward 431 | episode_returns.append(episode_return) 432 | return env.get_normalized_score(np.mean(episode_returns)) * 100 433 | 434 | 435 | if __name__ == "__main__": 436 | wandb.init(config=config, project=config.project) 437 | rng = jax.random.PRNGKey(config.seed) 438 | env = gym.make(config.env_name) 439 | dataset, obs_mean, obs_std = get_dataset(env, config) 440 | # create train_state 441 | rng, subkey = jax.random.split(rng) 442 | example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset) 443 | train_state: IQLTrainState = create_iql_train_state( 444 | subkey, 445 | example_batch.observations, 446 | example_batch.actions, 447 | config, 448 | ) 449 | 450 | algo = IQL() 451 | update_fn = jax.jit(algo.update_n_times, static_argnums=(3,)) 452 | act_fn = jax.jit(algo.get_action) 453 | num_steps = config.max_steps // config.n_jitted_updates 454 | eval_interval = config.eval_interval // config.n_jitted_updates 455 | for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True): 456 | rng, subkey = jax.random.split(rng) 457 | train_state, update_info = update_fn(train_state, dataset, subkey, config) 458 | 459 | if i % config.log_interval == 0: 460 | train_metrics = {f"training/{k}": v for k, v in update_info.items()} 461 | wandb.log(train_metrics, step=i) 462 | 463 | if i % eval_interval == 0: 464 | policy_fn = partial( 465 | act_fn, 466 | temperature=0.0, 467 | seed=jax.random.PRNGKey(0), 468 | train_state=train_state, 469 | ) 470 | normalized_score = evaluate( 471 | policy_fn, 472 | env, 473 | num_episodes=config.eval_episodes, 474 | obs_mean=obs_mean, 475 | obs_std=obs_std, 476 | ) 477 | print(i, normalized_score) 478 | eval_metrics = {f"{config.env_name}/normalized_score": normalized_score} 479 | wandb.log(eval_metrics, step=i) 480 | # final evaluation 481 | policy_fn = partial( 482 | act_fn, 483 | temperature=0.0, 484 | seed=jax.random.PRNGKey(0), 485 | train_state=train_state, 486 | ) 487 | normalized_score = evaluate( 488 | policy_fn, 489 | env, 490 | num_episodes=config.eval_episodes, 491 | obs_mean=obs_mean, 492 | obs_std=obs_std, 493 | ) 494 | print("Final Evaluation", normalized_score) 495 | wandb.log({f"{config.env_name}/final_normalized_score": normalized_score}) 496 | wandb.finish() 497 | -------------------------------------------------------------------------------- /algos/dt.py: -------------------------------------------------------------------------------- 1 | # source https://github.com/nikhilbarhate99/min-decision-transformer 2 | # https://arxiv.org/abs/2106.01345 3 | import collections 4 | import os 5 | from functools import partial 6 | from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple 7 | 8 | import d4rl 9 | import flax 10 | import gym 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | import wandb 16 | from flax import linen as nn 17 | from flax.training.train_state import TrainState 18 | from omegaconf import OmegaConf 19 | from pydantic import BaseModel 20 | from tqdm import tqdm 21 | 22 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 23 | 24 | 25 | class DTConfig(BaseModel): 26 | # GENERAL 27 | algo: str = "DT" 28 | project: str = "decision-transformer" 29 | seed: int = 0 30 | env_name: str = "halfcheetah-medium-expert-v2" 31 | batch_size: int = 64 32 | num_eval_episodes: int = 5 33 | max_eval_ep_len: int = 1000 34 | max_steps: int = 
20000 35 | eval_interval: int = 2000 36 | # NETWORK 37 | context_len: int = 20 38 | n_blocks: int = 3 39 | embed_dim: int = 128 40 | n_heads: int = 1 41 | dropout_p: float = 0.1 42 | lr: float = 0.0008 43 | wt_decay: float = 1e-4 44 | beta: Sequence = (0.9, 0.999) 45 | clip_grads: float = 0.25 46 | warmup_steps: int = 10000 47 | # DT SPECIFIC 48 | rtg_scale: int = 1000 49 | rtg_target: int = None 50 | 51 | 52 | conf_dict = OmegaConf.from_cli() 53 | config: DTConfig = DTConfig(**conf_dict) 54 | 55 | # RTG target is specific to each environment 56 | if "halfcheetah" in config.env_name: 57 | rtg_target = 12000 58 | elif "hopper" in config.env_name: 59 | rtg_target = 3600 60 | elif "walker" in config.env_name: 61 | rtg_target = 5000 62 | else: 63 | raise ValueError("We only care about Mujoco envs for now.") 64 | config.rtg_target = rtg_target 65 | 66 | 67 | def default_init(scale: Optional[float] = jnp.sqrt(2)): 68 | return nn.initializers.orthogonal(scale) 69 | 70 | 71 | class MaskedCausalAttention(nn.Module): 72 | h_dim: int 73 | max_T: int 74 | n_heads: int 75 | drop_p: float 76 | kernel_init: Callable = default_init() 77 | 78 | @nn.compact 79 | def __call__(self, x: jnp.ndarray, training=True) -> jnp.ndarray: 80 | B, T, C = x.shape 81 | N, D = self.n_heads, C // self.n_heads 82 | # rearrange q, k, v as (B, N, T, D) 83 | q = ( 84 | nn.Dense(self.h_dim, kernel_init=self.kernel_init)(x) 85 | .reshape(B, T, N, D) 86 | .transpose(0, 2, 1, 3) 87 | ) 88 | k = ( 89 | nn.Dense(self.h_dim, kernel_init=self.kernel_init)(x) 90 | .reshape(B, T, N, D) 91 | .transpose(0, 2, 1, 3) 92 | ) 93 | v = ( 94 | nn.Dense(self.h_dim, kernel_init=self.kernel_init)(x) 95 | .reshape(B, T, N, D) 96 | .transpose(0, 2, 1, 3) 97 | ) 98 | # causal mask 99 | ones = jnp.ones((self.max_T, self.max_T)) 100 | mask = jnp.tril(ones).reshape(1, 1, self.max_T, self.max_T) 101 | # weights (B, N, T, T) jax 102 | weights = jnp.einsum("bntd,bnfd->bntf", q, k) / jnp.sqrt(D) 103 | # causal mask applied to weights 104 | weights = jnp.where(mask[..., :T, :T] == 0, -jnp.inf, weights[..., :T, :T]) 105 | # normalize weights, all -inf -> 0 after softmax 106 | normalized_weights = jax.nn.softmax(weights, axis=-1) 107 | # attention (B, N, T, D) 108 | attention = nn.Dropout(self.drop_p, deterministic=not training)( 109 | jnp.einsum("bntf,bnfd->bntd", normalized_weights, v) 110 | ) 111 | # gather heads and project (B, N, T, D) -> (B, T, N*D) 112 | attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N * D) 113 | out = nn.Dropout(self.drop_p, deterministic=not training)( 114 | nn.Dense(self.h_dim)(attention) 115 | ) 116 | return out 117 | 118 | 119 | class Block(nn.Module): 120 | h_dim: int 121 | max_T: int 122 | n_heads: int 123 | drop_p: float 124 | kernel_init: Callable = default_init() 125 | 126 | @nn.compact 127 | def __call__(self, x: jnp.ndarray, training=True) -> jnp.ndarray: 128 | # Attention -> LayerNorm -> MLP -> LayerNorm 129 | x = x + MaskedCausalAttention( 130 | self.h_dim, self.max_T, self.n_heads, self.drop_p 131 | )( 132 | x, training=training 133 | ) # residual 134 | x = nn.LayerNorm()(x) 135 | # MLP 136 | out = nn.Dense(4 * self.h_dim, kernel_init=self.kernel_init)(x) 137 | out = nn.gelu(out) 138 | out = nn.Dense(self.h_dim, kernel_init=self.kernel_init)(out) 139 | out = nn.Dropout(self.drop_p, deterministic=not training)(out) 140 | # residual 141 | x = x + out 142 | x = nn.LayerNorm()(x) 143 | return x 144 | 145 | 146 | class DecisionTransformer(nn.Module): 147 | state_dim: int 148 | act_dim: int 149 | n_blocks: int 150 | 
h_dim: int 151 | context_len: int 152 | n_heads: int 153 | drop_p: float 154 | max_timestep: int = 4096 155 | kernel_init: Callable = default_init() 156 | 157 | def setup(self) -> None: 158 | self.blocks = [ 159 | Block(self.h_dim, 3 * self.context_len, self.n_heads, self.drop_p) 160 | for _ in range(self.n_blocks) 161 | ] 162 | # projection heads (project to embedding) 163 | self.embed_ln = nn.LayerNorm() 164 | self.embed_timestep = nn.Embed(self.max_timestep, self.h_dim) 165 | self.embed_rtg = nn.Dense(self.h_dim, kernel_init=self.kernel_init) 166 | self.embed_state = nn.Dense(self.h_dim, kernel_init=self.kernel_init) 167 | # continuous actions 168 | self.embed_action = nn.Dense(self.h_dim, kernel_init=self.kernel_init) 169 | self.use_action_tanh = True 170 | # prediction heads 171 | self.predict_rtg = nn.Dense(1, kernel_init=self.kernel_init) 172 | self.predict_state = nn.Dense(self.state_dim, kernel_init=self.kernel_init) 173 | self.predict_action = nn.Dense(self.act_dim, kernel_init=self.kernel_init) 174 | 175 | def __call__( 176 | self, 177 | timesteps: jnp.ndarray, 178 | states: jnp.ndarray, 179 | actions: jnp.ndarray, 180 | returns_to_go: jnp.ndarray, 181 | training=True, 182 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 183 | B, T, _ = states.shape 184 | 185 | time_embeddings = self.embed_timestep(timesteps) 186 | # time embeddings are treated similar to positional embeddings 187 | state_embeddings = self.embed_state(states) + time_embeddings 188 | action_embeddings = self.embed_action(actions) + time_embeddings 189 | returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings 190 | # stack rtg, states and actions and reshape sequence as 191 | # (r1, s1, a1, r2, s2, a2 ...) 192 | h = ( 193 | jnp.stack((returns_embeddings, state_embeddings, action_embeddings), axis=1) 194 | .transpose(0, 2, 1, 3) 195 | .reshape(B, 3 * T, self.h_dim) 196 | ) 197 | h = self.embed_ln(h) 198 | # transformer and prediction 199 | for block in self.blocks: 200 | h = block(h, training=training) 201 | # get h reshaped such that its size = (B x 3 x T x h_dim) and 202 | # h[:, 0, t] is conditioned on r_0, s_0, a_0 ... r_t 203 | # h[:, 1, t] is conditioned on r_0, s_0, a_0 ... r_t, s_t 204 | # h[:, 2, t] is conditioned on r_0, s_0, a_0 ... 
r_t, s_t, a_t 205 | h = h.reshape(B, T, 3, self.h_dim).transpose(0, 2, 1, 3) 206 | # get predictions 207 | return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a 208 | state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a 209 | action_preds = self.predict_action(h[:, 1]) 210 | if self.use_action_tanh: 211 | action_preds = jnp.tanh(action_preds) 212 | 213 | return state_preds, action_preds, return_preds 214 | 215 | 216 | def discount_cumsum(x: jnp.ndarray, gamma: float) -> jnp.ndarray: 217 | disc_cumsum = np.zeros_like(x) 218 | disc_cumsum[-1] = x[-1] 219 | for t in reversed(range(x.shape[0] - 1)): 220 | disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1] 221 | return disc_cumsum 222 | 223 | 224 | def get_traj(env_name): 225 | name = env_name 226 | print("processing: ", name) 227 | env = gym.make(name) 228 | dataset = env.get_dataset() 229 | N = dataset["rewards"].shape[0] 230 | data_ = collections.defaultdict(list) 231 | use_timeouts = False 232 | if "timeouts" in dataset: 233 | use_timeouts = True 234 | 235 | episode_step = 0 236 | paths = [] 237 | for i in range(N): 238 | done_bool = bool(dataset["terminals"][i]) 239 | if use_timeouts: 240 | final_timestep = dataset["timeouts"][i] 241 | else: 242 | final_timestep = episode_step == 1000 - 1 243 | for k in [ 244 | "observations", 245 | "next_observations", 246 | "actions", 247 | "rewards", 248 | "terminals", 249 | ]: 250 | data_[k].append(dataset[k][i]) 251 | if done_bool or final_timestep: 252 | episode_step = 0 253 | episode_data = {} 254 | for k in data_: 255 | episode_data[k] = np.array(data_[k]) 256 | paths.append(episode_data) 257 | data_ = collections.defaultdict(list) 258 | episode_step += 1 259 | returns = np.array([np.sum(p["rewards"]) for p in paths]) 260 | num_samples = np.sum([p["rewards"].shape[0] for p in paths]) 261 | print(f"Number of samples collected: {num_samples}") 262 | print( 263 | f"Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}" 264 | ) 265 | obs_mean = dataset["observations"].mean(axis=0) 266 | obs_std = dataset["observations"].std(axis=0) 267 | return paths, obs_mean, obs_std 268 | 269 | 270 | class Trajectory(NamedTuple): 271 | timesteps: np.ndarray # num_ep x max_len 272 | states: np.ndarray # num_ep x max_len x state_dim 273 | actions: np.ndarray # num_ep x max_len x act_dim 274 | returns_to_go: np.ndarray # num_ep x max_len x 1 275 | masks: np.ndarray # num_ep x max_len 276 | 277 | 278 | def padd_by_zero(arr: jnp.ndarray, pad_to: int) -> jnp.ndarray: 279 | return np.pad(arr, ((0, pad_to - arr.shape[0]), (0, 0)), mode="constant") 280 | 281 | 282 | def make_padded_trajectories( 283 | config: DTConfig, 284 | ) -> Tuple[Trajectory, int, jnp.ndarray, jnp.ndarray, jnp.ndarray]: 285 | trajectories, mean, std = get_traj(config.env_name) 286 | # Calculate returns to go for all trajectories 287 | # Normalize states 288 | max_len = 0 289 | traj_lengths = [] 290 | for traj in trajectories: 291 | traj["returns_to_go"] = discount_cumsum(traj["rewards"], 1.0) / config.rtg_scale 292 | traj["observations"] = (traj["observations"] - mean) / std 293 | max_len = max(max_len, traj["observations"].shape[0]) 294 | traj_lengths.append(traj["observations"].shape[0]) 295 | # Pad trajectories 296 | padded_trajectories = {key: [] for key in Trajectory._fields} 297 | for traj in trajectories: 298 | timesteps = np.arange(0, len(traj["observations"])) 299 | padded_trajectories["timesteps"].append( 300 | 
padd_by_zero(timesteps.reshape(-1, 1), max_len).reshape(-1) 301 | ) 302 | padded_trajectories["states"].append( 303 | padd_by_zero(traj["observations"], max_len) 304 | ) 305 | padded_trajectories["actions"].append(padd_by_zero(traj["actions"], max_len)) 306 | padded_trajectories["returns_to_go"].append( 307 | padd_by_zero(traj["returns_to_go"].reshape(-1, 1), max_len) 308 | ) 309 | padded_trajectories["masks"].append( 310 | padd_by_zero( 311 | np.ones((len(traj["observations"]), 1)).reshape(-1, 1), max_len 312 | ).reshape(-1) 313 | ) 314 | return ( 315 | Trajectory( 316 | timesteps=np.stack(padded_trajectories["timesteps"]), 317 | states=np.stack(padded_trajectories["states"]), 318 | actions=np.stack(padded_trajectories["actions"]), 319 | returns_to_go=np.stack(padded_trajectories["returns_to_go"]), 320 | masks=np.stack(padded_trajectories["masks"]), 321 | ), 322 | len(trajectories), 323 | jnp.array(traj_lengths), 324 | mean, 325 | std, 326 | ) 327 | 328 | 329 | def sample_start_idx( 330 | rng: jax.random.PRNGKey, 331 | traj_idx: int, 332 | padded_traj_length: jnp.ndarray, 333 | context_len: int, 334 | ) -> jnp.ndarray: 335 | """ 336 | Determine the start_idx for given trajectory, the trajectories are padded to max_len. 337 | Therefore, naively sample from 0, max_len will produce bunch of all zero data. 338 | To avoid that, we refer padded_traj_length, the list of actual trajectry length + context_len 339 | """ 340 | traj_len = padded_traj_length[traj_idx] 341 | start_idx = jax.random.randint(rng, (1,), 0, traj_len - context_len - 1) 342 | return start_idx 343 | 344 | 345 | def extract_traj( 346 | traj_idx: jnp.ndarray, start_idx: jnp.ndarray, traj: Trajectory, context_len: int 347 | ) -> Trajectory: 348 | """ 349 | Extract the trajectory with context_len for given traj_idx and start_idx 350 | """ 351 | return jax.tree_util.tree_map( 352 | lambda x: jax.lax.dynamic_slice_in_dim(x[traj_idx], start_idx, context_len), 353 | traj, 354 | ) 355 | 356 | 357 | @partial(jax.jit, static_argnums=(2, 3, 4)) 358 | def sample_traj_batch( 359 | rng, 360 | traj: Trajectory, 361 | batch_size: int, 362 | context_len: int, 363 | episode_num: int, 364 | padded_traj_lengths: jnp.ndarray, 365 | ) -> Trajectory: 366 | traj_idx = jax.random.randint(rng, (batch_size,), 0, episode_num) # B 367 | start_idx = jax.vmap(sample_start_idx, in_axes=(0, 0, None, None))( 368 | jax.random.split(rng, batch_size), traj_idx, padded_traj_lengths, context_len 369 | ).reshape( 370 | -1 371 | ) # B 372 | return jax.vmap(extract_traj, in_axes=(0, 0, None, None))( 373 | traj_idx, start_idx, traj, context_len 374 | ) 375 | 376 | 377 | class DTTrainState(NamedTuple): 378 | transformer: TrainState 379 | 380 | 381 | class DT(object): 382 | 383 | @classmethod 384 | def update( 385 | self, train_state: DTTrainState, batch: Trajectory, rng: jax.random.PRNGKey 386 | ) -> Tuple[Any, jnp.ndarray]: 387 | timesteps, states, actions, returns_to_go, traj_mask = ( 388 | batch.timesteps, 389 | batch.states, 390 | batch.actions, 391 | batch.returns_to_go, 392 | batch.masks, 393 | ) 394 | 395 | def loss_fn(params): 396 | state_preds, action_preds, return_preds = train_state.transformer.apply_fn( 397 | params, timesteps, states, actions, returns_to_go, rngs={"dropout": rng} 398 | ) # B x T x state_dim, B x T x act_dim, B x T x 1 399 | # mask actions 400 | actions_masked = actions * traj_mask[:, :, None] 401 | action_preds_masked = action_preds * traj_mask[:, :, None] 402 | # Calculate mean squared error loss 403 | action_loss = 
jnp.mean(jnp.square(action_preds_masked - actions_masked)) 404 | return action_loss 405 | 406 | grad_fn = jax.value_and_grad(loss_fn) 407 | loss, grad = grad_fn(train_state.transformer.params) 408 | # Apply gradient clipping 409 | transformer = train_state.transformer.apply_gradients(grads=grad) 410 | return train_state._replace(transformer=transformer), loss 411 | 412 | @classmethod 413 | def get_action( 414 | self, 415 | train_state: DTTrainState, 416 | timesteps: jnp.ndarray, 417 | states: jnp.ndarray, 418 | actions: jnp.ndarray, 419 | returns_to_go: jnp.ndarray, 420 | ) -> jnp.ndarray: 421 | state_preds, action_preds, return_preds = train_state.transformer.apply_fn( 422 | train_state.transformer.params, 423 | timesteps, 424 | states, 425 | actions, 426 | returns_to_go, 427 | training=False, 428 | ) 429 | return action_preds 430 | 431 | 432 | def create_dt_train_state( 433 | rng: jax.random.PRNGKey, state_dim: int, act_dim: int, config: DTConfig 434 | ) -> DTTrainState: 435 | model = DecisionTransformer( 436 | state_dim=state_dim, 437 | act_dim=act_dim, 438 | n_blocks=config.n_blocks, 439 | h_dim=config.embed_dim, 440 | context_len=config.context_len, 441 | n_heads=config.n_heads, 442 | drop_p=config.dropout_p, 443 | ) 444 | rng, init_rng = jax.random.split(rng) 445 | # initialize params 446 | params = model.init( 447 | init_rng, 448 | timesteps=jnp.zeros((1, config.context_len), jnp.int32), 449 | states=jnp.zeros((1, config.context_len, state_dim), jnp.float32), 450 | actions=jnp.zeros((1, config.context_len, act_dim), jnp.float32), 451 | returns_to_go=jnp.zeros((1, config.context_len, 1), jnp.float32), 452 | training=False, 453 | ) 454 | # optimizer 455 | scheduler = optax.cosine_decay_schedule( 456 | init_value=config.lr, decay_steps=config.warmup_steps 457 | ) 458 | tx = optax.chain( 459 | optax.clip_by_global_norm(config.clip_grads), 460 | optax.scale_by_schedule(scheduler), 461 | optax.adamw( 462 | learning_rate=config.lr, 463 | weight_decay=config.wt_decay, 464 | b1=config.beta[0], 465 | b2=config.beta[1], 466 | ), 467 | ) 468 | train_state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) 469 | return DTTrainState(train_state) 470 | 471 | 472 | def evaluate( 473 | policy_fn: Callable, 474 | train_state: DTTrainState, 475 | env: gym.Env, 476 | config: DTConfig, 477 | state_mean=0, 478 | state_std=1, 479 | ) -> float: 480 | eval_batch_size = 1 # required for forward pass 481 | results = {} 482 | total_reward = 0 483 | total_timesteps = 0 484 | state_dim = env.observation_space.shape[0] 485 | act_dim = env.action_space.shape[0] 486 | # same as timesteps used for training the transformer 487 | timesteps = jnp.arange(0, config.max_eval_ep_len, 1, jnp.int32) 488 | # repeat 489 | timesteps = jnp.repeat(timesteps[None, :], eval_batch_size, axis=0) 490 | for _ in range(config.num_eval_episodes): 491 | # zeros place holders 492 | actions = jnp.zeros( 493 | (eval_batch_size, config.max_eval_ep_len, act_dim), dtype=jnp.float32 494 | ) 495 | states = jnp.zeros( 496 | (eval_batch_size, config.max_eval_ep_len, state_dim), dtype=jnp.float32 497 | ) 498 | rewards_to_go = jnp.zeros( 499 | (eval_batch_size, config.max_eval_ep_len, 1), dtype=jnp.float32 500 | ) 501 | # init episode 502 | running_state = env.reset() 503 | running_reward = 0 504 | running_rtg = config.rtg_target / config.rtg_scale 505 | for t in range(config.max_eval_ep_len): 506 | total_timesteps += 1 507 | # add state in placeholder and normalize 508 | states = states.at[0, t].set((running_state - state_mean) / 
state_std) 509 | # calcualate running rtg and add in placeholder 510 | running_rtg = running_rtg - (running_reward / config.rtg_scale) 511 | rewards_to_go = rewards_to_go.at[0, t].set(running_rtg) 512 | if t < config.context_len: 513 | act_preds = policy_fn( 514 | train_state, 515 | timesteps[:, : t + 1], 516 | states[:, : t + 1], 517 | actions[:, : t + 1], 518 | rewards_to_go[:, : t + 1], 519 | ) 520 | act = act_preds[0, -1] 521 | else: 522 | act_preds = policy_fn( 523 | train_state, 524 | timesteps[:, t - config.context_len + 1 : t + 1], 525 | states[:, t - config.context_len + 1 : t + 1], 526 | actions[:, t - config.context_len + 1 : t + 1], 527 | rewards_to_go[:, t - config.context_len + 1 : t + 1], 528 | ) 529 | act = act_preds[0, -1] 530 | running_state, running_reward, done, _ = env.step(act) 531 | # add action in placeholder 532 | actions = actions.at[0, t].set(act) 533 | total_reward += running_reward 534 | if done: 535 | break 536 | normalized_score = ( 537 | env.get_normalized_score(total_reward / config.num_eval_episodes) * 100 538 | ) 539 | return normalized_score 540 | 541 | 542 | if __name__ == "__main__": 543 | wandb.init(project=config.project, config=config) 544 | env = gym.make(config.env_name) 545 | rng = jax.random.PRNGKey(config.seed) 546 | state_dim = env.observation_space.shape[0] 547 | act_dim = env.action_space.shape[0] 548 | trajectories, episode_num, traj_lengths, state_mean, state_std = ( 549 | make_padded_trajectories(config) 550 | ) 551 | # create trainer 552 | rng, subkey = jax.random.split(rng) 553 | train_state = create_dt_train_state(subkey, state_dim, act_dim, config) 554 | 555 | algo = DT() 556 | update_fn = jax.jit(algo.update) 557 | for i in tqdm(range(1, config.max_steps + 1), smoothing=0.1, dynamic_ncols=True): 558 | rng, data_rng, update_rng = jax.random.split(rng, 3) 559 | traj_batch = sample_traj_batch( 560 | data_rng, 561 | trajectories, 562 | config.batch_size, 563 | config.context_len, 564 | episode_num, 565 | traj_lengths, 566 | ) # B x T x D 567 | train_state, action_loss = update_fn(train_state, traj_batch, update_rng) # update parameters 568 | 569 | if i % config.eval_interval == 0: 570 | # evaluate on env 571 | normalized_score = evaluate( 572 | algo.get_action, train_state, env, config, state_mean, state_std 573 | ) 574 | print(i, normalized_score) 575 | wandb.log( 576 | { 577 | "action_loss": action_loss, 578 | f"{config.env_name}/normalized_score": normalized_score, 579 | "step": i, 580 | } 581 | ) 582 | # final evaluation 583 | normalized_score = evaluate( 584 | algo.get_action, train_state, env, config, state_mean, state_std 585 | ) 586 | wandb.log({f"{config.env_name}/final_normalized_score": normalized_score}) 587 | wandb.finish() 588 | -------------------------------------------------------------------------------- /algos/xql.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/Div99/XQL 2 | # Paper: https://arxiv.org/abs/2301.02328 3 | import os 4 | import time 5 | from functools import partial 6 | from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple 7 | 8 | import d4rl 9 | import distrax 10 | import flax 11 | import flax.linen as nn 12 | import gym 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as np 16 | import optax 17 | import tqdm 18 | import wandb 19 | from flax.training.train_state import TrainState 20 | from omegaconf import OmegaConf 21 | from pydantic import BaseModel 22 | 23 | os.environ["XLA_FLAGS"] = 
"--xla_gpu_triton_gemm_any=True" 24 | 25 | 26 | class XQLConfig(BaseModel): 27 | # GENERAL 28 | algo: str = "XQL" 29 | project: str = "train-XQL" 30 | env_name: str = "halfcheetah-medium-expert-v2" 31 | seed: int = 42 32 | eval_episodes: int = 5 33 | log_interval: int = 100000 34 | eval_interval: int = 100000 35 | batch_size: int = 256 36 | max_steps: int = int(1e6) 37 | n_jitted_updates: int = 8 38 | # DATASET 39 | data_size: int = int(1e6) 40 | normalize_state: bool = False 41 | normalize_reward: bool = True 42 | # NETWORK 43 | hidden_dims: Tuple[int, int] = (256, 256) 44 | actor_lr: float = 3e-4 45 | value_lr: float = 3e-4 46 | critic_lr: float = 3e-4 47 | tau: float = 0.005 48 | discount: float = 0.99 49 | # XQL SPECIFIC 50 | expectile: float = ( 51 | 0.7 # FYI: for Hopper-me, 0.5 produce better result. (antmaze: tau=0.9) 52 | ) 53 | beta: float = ( 54 | 3.0 # FYI: for Hopper-me, 6.0 produce better result. (antmaze: beta=10.0) 55 | ) 56 | # XQL SPECIFIC 57 | vanilla: bool = False # Of course, we do not use expectile loss 58 | sample_random_times: int = 0 # sample random times 59 | grad_pen: bool = False # gradient penalty 60 | lambda_gp: int = 1 # grad penalty coefficient 61 | loss_temp: float = 1.0 # loss temperature 62 | log_loss: bool = False # log loss 63 | num_v_updates: int = 1 # number of value updates 64 | max_clip: float = 7.0 # Loss clip value 65 | noise: bool = False # noise 66 | noise_std: float = 0.1 # noise std 67 | layer_norm: bool = True # layer norm 68 | 69 | def __hash__( 70 | self, 71 | ): # make config hashable to be specified as static_argnums in jax.jit. 72 | return hash(self.__repr__()) 73 | 74 | 75 | conf_dict = OmegaConf.from_cli() 76 | config = XQLConfig(**conf_dict) 77 | 78 | 79 | def default_init(scale: Optional[float] = jnp.sqrt(2)): 80 | return nn.initializers.orthogonal(scale) 81 | 82 | 83 | class MLP(nn.Module): 84 | hidden_dims: Sequence[int] 85 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 86 | activate_final: bool = False 87 | kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init() 88 | layer_norm: bool = False 89 | 90 | @nn.compact 91 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 92 | for i, hidden_dims in enumerate(self.hidden_dims): 93 | x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x) 94 | if i + 1 < len(self.hidden_dims) or self.activate_final: 95 | if self.layer_norm: # Add layer norm after activation 96 | x = nn.LayerNorm()(x) 97 | x = self.activations(x) 98 | return x 99 | 100 | 101 | class Critic(nn.Module): 102 | hidden_dims: Sequence[int] 103 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 104 | 105 | @nn.compact 106 | def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: 107 | inputs = jnp.concatenate([observations, actions], -1) 108 | critic = MLP((*self.hidden_dims, 1), activations=self.activations)(inputs) 109 | return jnp.squeeze(critic, -1) 110 | 111 | 112 | def ensemblize(cls, num_qs, out_axes=0, **kwargs): 113 | split_rngs = kwargs.pop("split_rngs", {}) 114 | return nn.vmap( 115 | cls, 116 | variable_axes={"params": 0}, 117 | split_rngs={**split_rngs, "params": True}, 118 | in_axes=None, 119 | out_axes=out_axes, 120 | axis_size=num_qs, 121 | **kwargs, 122 | ) 123 | 124 | 125 | class ValueCritic(nn.Module): 126 | hidden_dims: Sequence[int] 127 | layer_norm: bool = False 128 | 129 | @nn.compact 130 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 131 | critic = MLP((*self.hidden_dims, 1), 
layer_norm=self.layer_norm)(observations) 132 | return jnp.squeeze(critic, -1) 133 | 134 | 135 | class GaussianPolicy(nn.Module): 136 | hidden_dims: Sequence[int] 137 | action_dim: int 138 | log_std_min: Optional[float] = -5.0 139 | log_std_max: Optional[float] = 2 140 | 141 | @nn.compact 142 | def __call__( 143 | self, observations: jnp.ndarray, temperature: float = 1.0 144 | ) -> distrax.Distribution: 145 | outputs = MLP( 146 | self.hidden_dims, 147 | activate_final=True, 148 | )(observations) 149 | 150 | means = nn.Dense( 151 | self.action_dim, kernel_init=default_init() 152 | )(outputs) 153 | log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,)) 154 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) 155 | 156 | distribution = distrax.MultivariateNormalDiag( 157 | loc=means, scale_diag=jnp.exp(log_stds) * temperature 158 | ) 159 | return distribution 160 | 161 | 162 | class Transition(NamedTuple): 163 | observations: jnp.ndarray 164 | actions: jnp.ndarray 165 | rewards: jnp.ndarray 166 | next_observations: jnp.ndarray 167 | dones: jnp.ndarray 168 | 169 | 170 | def get_normalization(dataset: Transition) -> float: 171 | # into numpy.ndarray 172 | dataset = jax.tree_util.tree_map(lambda x: np.array(x), dataset) 173 | returns = [] 174 | ret = 0 175 | for r, term in zip(dataset.rewards, dataset.dones): 176 | ret += r 177 | if term: 178 | returns.append(ret) 179 | ret = 0 180 | return (max(returns) - min(returns)) / 1000 181 | 182 | 183 | def get_dataset( 184 | env: gym.Env, config: XQLConfig, clip_to_eps: bool = True, eps: float = 1e-5 185 | ) -> Transition: 186 | dataset = d4rl.qlearning_dataset(env) 187 | 188 | if clip_to_eps: 189 | lim = 1 - eps 190 | dataset["actions"] = np.clip(dataset["actions"], -lim, lim) 191 | 192 | dones_float = np.zeros_like(dataset['rewards']) 193 | 194 | for i in range(len(dones_float) - 1): 195 | if np.linalg.norm(dataset['observations'][i + 1] - 196 | dataset['next_observations'][i] 197 | ) > 1e-6 or dataset['terminals'][i] == 1.0: 198 | dones_float[i] = 1 199 | else: 200 | dones_float[i] = 0 201 | dones_float[-1] = 1 202 | 203 | dataset = Transition( 204 | observations=jnp.array(dataset["observations"], dtype=jnp.float32), 205 | actions=jnp.array(dataset["actions"], dtype=jnp.float32), 206 | rewards=jnp.array(dataset["rewards"], dtype=jnp.float32), 207 | next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32), 208 | dones=jnp.array(dones_float, dtype=jnp.float32), 209 | ) 210 | # normalize states 211 | obs_mean, obs_std = 0, 1 212 | if config.normalize_state: 213 | obs_mean = dataset.observations.mean(0) 214 | obs_std = dataset.observations.std(0) 215 | dataset = dataset._replace( 216 | observations=(dataset.observations - obs_mean) / (obs_std + 1e-5), 217 | next_observations=(dataset.next_observations - obs_mean) / (obs_std + 1e-5), 218 | ) 219 | # normalize rewards 220 | if config.normalize_reward: 221 | normalizing_factor = get_normalization(dataset) 222 | dataset = dataset._replace(rewards=dataset.rewards / normalizing_factor) 223 | 224 | # shuffle data and select the first data_size samples 225 | data_size = min(config.data_size, len(dataset.observations)) 226 | rng = jax.random.PRNGKey(config.seed) 227 | rng, rng_permute, rng_select = jax.random.split(rng, 3) 228 | perm = jax.random.permutation(rng_permute, len(dataset.observations)) 229 | dataset = jax.tree_util.tree_map(lambda x: x[perm], dataset) 230 | assert len(dataset.observations) >= data_size 231 | dataset = 
jax.tree_util.tree_map(lambda x: x[:data_size], dataset) 232 | return dataset, obs_mean, obs_std 233 | 234 | 235 | def gumbel_rescale_loss(diff, alpha, max_clip=None): 236 | """Gumbel loss J: E[e^x - x - 1]. For stability to outliers, we scale the gradients with the max value over a batch 237 | and optionally clip the exponent. This has the effect of training with an adaptive lr. 238 | """ 239 | z = diff / alpha 240 | if max_clip is not None: 241 | z = jnp.minimum(z, max_clip) # clip max value 242 | max_z = jnp.max(z, axis=0) 243 | max_z = jnp.where(max_z < -1.0, -1.0, max_z) 244 | max_z = jax.lax.stop_gradient(max_z) # Detach the gradients 245 | loss = ( 246 | jnp.exp(z - max_z) - z * jnp.exp(-max_z) - jnp.exp(-max_z) 247 | ) # scale by e^max_z 248 | return loss 249 | 250 | 251 | def gumbel_log_loss(diff, alpha=1.0): 252 | """Gumbel loss J: E[e^x - x - 1]. We can calculate the log of Gumbel loss for stability, i.e. Log(J + 1) 253 | log_gumbel_loss: log((e^x - x - 1).mean() + 1) 254 | """ 255 | diff = diff 256 | x = diff / alpha 257 | grad = grad_gumbel(x, alpha) 258 | # use analytic gradients to improve stability 259 | loss = jax.lax.stop_gradient(grad) * x 260 | return loss 261 | 262 | 263 | def grad_gumbel(x, alpha, clip_max=7): 264 | """Calculate grads of log gumbel_loss: (e^x - 1)/[(e^x - x - 1).mean() + 1] 265 | We add e^-a to both numerator and denominator to get: (e^(x-a) - e^(-a))/[(e^(x-a) - xe^(-a)).mean()] 266 | """ 267 | # clip inputs to grad in [-10, 10] to improve stability (gradient clipping) 268 | x = jnp.minimum(x, clip_max) # jnp.clip(x, a_min=-10, a_max=10) 269 | 270 | # calculate an offset `a` to prevent overflow issues 271 | x_max = jnp.max(x, axis=0) 272 | # choose `a` as max(x_max, -1) as its possible for x_max to be very small and we want the offset to be reasonable 273 | x_max = jnp.where(x_max < -1, -1, x_max) 274 | 275 | # keep track of original x 276 | x_orig = x 277 | # offsetted x 278 | x1 = x - x_max 279 | 280 | grad = (jnp.exp(x1) - jnp.exp(-x_max)) / ( 281 | jnp.mean(jnp.exp(x1) - x_orig * jnp.exp(-x_max), axis=0, keepdims=True) 282 | ) 283 | return grad 284 | 285 | 286 | def expectile_loss(diff, expectile=0.8) -> jnp.ndarray: 287 | weight = jnp.where(diff > 0, expectile, (1 - expectile)) 288 | return weight * (diff**2) 289 | 290 | 291 | def huber_loss(x, delta: float = 1.): 292 | # 0.5 * x^2 if |x| <= d 293 | # 0.5 * d^2 + d * (|x| - d) if |x| > d 294 | abs_x = jnp.abs(x) 295 | quadratic = jnp.minimum(abs_x, delta) 296 | # Same as max(abs_x - delta, 0) but avoids potentially doubling gradient. 
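# Putting the two pieces below together, huber_loss returns
#     0.5 * min(|x|, delta) ** 2 + delta * max(|x| - delta, 0),
# i.e. quadratic for |x| <= delta and linear with slope delta beyond it,
# matching the piecewise definition sketched in the comments above.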
297 | linear = abs_x - quadratic 298 | return 0.5 * quadratic**2 + delta * linear 299 | 300 | 301 | def target_update( 302 | model: TrainState, target_model: TrainState, tau: float 303 | ) -> TrainState: 304 | new_target_params = jax.tree_util.tree_map( 305 | lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params 306 | ) 307 | return target_model.replace(params=new_target_params) 308 | 309 | 310 | def update_by_loss_grad( 311 | train_state: TrainState, loss_fn: Callable 312 | ) -> Tuple[TrainState, jnp.ndarray]: 313 | grad_fn = jax.value_and_grad(loss_fn) 314 | loss, grad = grad_fn(train_state.params) 315 | new_train_state = train_state.apply_gradients(grads=grad) 316 | return new_train_state, loss 317 | 318 | 319 | class XQLTrainState(NamedTuple): 320 | rng: jax.random.PRNGKey 321 | critic: TrainState 322 | target_critic: TrainState 323 | value: TrainState 324 | actor: TrainState 325 | 326 | 327 | class XQL(object): 328 | @classmethod 329 | def update_critic( 330 | self, train_state: XQLTrainState, batch: Transition, config: XQLConfig 331 | ) -> Tuple["XQLTrainState", Dict]: 332 | next_v = train_state.value.apply_fn( 333 | train_state.value.params, batch.next_observations 334 | ) 335 | target_q = batch.rewards + config.discount * (1 - batch.dones) * next_v 336 | def critic_loss_fn( 337 | critic_params: flax.core.FrozenDict[str, Any] 338 | ) -> jnp.ndarray: 339 | v = train_state.value.apply_fn(train_state.value.params, batch.observations) 340 | def mse_loss(q, q_target, *args): 341 | x = q-q_target 342 | loss = huber_loss(x, delta=20.0) # x**2 343 | return loss.mean() 344 | 345 | q1, q2 = train_state.critic.apply_fn( 346 | critic_params, batch.observations, batch.actions 347 | ) 348 | loss_1 = mse_loss(q1, target_q, v, config.loss_temp) 349 | loss_2 = mse_loss(q2, target_q, v, config.loss_temp) 350 | loss = (loss_1 + loss_2) / 2 351 | return loss 352 | 353 | new_critic, critic_loss = update_by_loss_grad( 354 | train_state.critic, critic_loss_fn 355 | ) 356 | return train_state._replace(critic=new_critic), critic_loss 357 | 358 | @classmethod 359 | def update_value( 360 | self, train_state: XQLTrainState, batch: Transition, rng, config: XQLConfig 361 | ) -> Tuple["XQLTrainState", Dict]: 362 | actions = batch.actions 363 | 364 | rng1, rng2 = jax.random.split(rng) 365 | if config.sample_random_times > 0: 366 | # add random actions to smooth loss computation (use 1/2(rho + Unif)) 367 | times = config.sample_random_times 368 | random_action = jax.random.uniform( 369 | rng1, 370 | shape=(times * actions.shape[0], actions.shape[1]), 371 | minval=-1.0, 372 | maxval=1.0, 373 | ) 374 | obs = jnp.concatenate( 375 | [batch.observations, jnp.repeat(batch.observations, times, axis=0)], 376 | axis=0, 377 | ) 378 | acts = jnp.concatenate([batch.actions, random_action], axis=0) 379 | else: 380 | obs = batch.observations 381 | acts = batch.actions 382 | 383 | if config.noise: 384 | std = config.noise_std 385 | noise = jax.random.normal(rng2, shape=(acts.shape[0], acts.shape[1])) 386 | noise = jnp.clip(noise * std, -0.5, 0.5) 387 | acts = batch.actions + noise 388 | acts = jnp.clip(acts, -1, 1) 389 | 390 | q1, q2 = train_state.target_critic.apply_fn( 391 | train_state.target_critic.params, obs, acts 392 | ) 393 | q = jnp.minimum(q1, q2) 394 | 395 | def value_loss_fn( 396 | value_params: flax.core.FrozenDict[str, Any] 397 | ) -> Tuple[jnp.ndarray, Dict]: 398 | v = train_state.value.apply_fn(value_params, obs) 399 | 400 | if config.vanilla: 401 | value_loss = expectile_loss(q - v, 
config.expectile).mean() 402 | else: 403 | if config.log_loss: 404 | value_loss = gumbel_log_loss(q - v, alpha=config.loss_temp).mean() 405 | else: 406 | value_loss = gumbel_rescale_loss( 407 | q - v, alpha=config.loss_temp, max_clip=config.max_clip 408 | ).mean() 409 | return value_loss 410 | 411 | new_value, value_loss = update_by_loss_grad(train_state.value, value_loss_fn) 412 | return train_state._replace(value=new_value), value_loss 413 | 414 | @classmethod 415 | def update_actor( 416 | self, train_state: XQLTrainState, batch: Transition, config: XQLConfig 417 | ) -> Tuple["XQLTrainState", Dict]: 418 | def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray: 419 | v = train_state.value.apply_fn(train_state.value.params, batch.observations) 420 | q1, q2 = train_state.target_critic.apply_fn( 421 | train_state.target_critic.params, batch.observations, batch.actions 422 | ) 423 | q = jnp.minimum(q1, q2) 424 | exp_a = jnp.exp((q - v) * config.beta) 425 | exp_a = jnp.minimum(exp_a, 100.0) 426 | 427 | dist = train_state.actor.apply_fn(actor_params, batch.observations) 428 | log_probs = dist.log_prob(batch.actions) 429 | actor_loss = -(exp_a * log_probs).mean() 430 | return actor_loss 431 | 432 | new_actor, actor_loss = update_by_loss_grad(train_state.actor, actor_loss_fn) 433 | return train_state._replace(actor=new_actor), actor_loss 434 | 435 | @classmethod 436 | def update_n_times( 437 | self, 438 | train_state: XQLTrainState, 439 | dataset: Transition, 440 | rng: jax.random.PRNGKey, 441 | config: XQLConfig, 442 | ) -> Tuple["XQLTrainState", Dict]: 443 | for _ in range(config.n_jitted_updates): 444 | rng, subkey = jax.random.split(rng) 445 | batch_indices = jax.random.randint( 446 | subkey, (config.batch_size,), 0, len(dataset.observations) 447 | ) 448 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 449 | 450 | rng, subkey = jax.random.split(rng) 451 | train_state, value_loss = self.update_value( 452 | train_state, batch, subkey, config 453 | ) 454 | train_state, actor_loss = self.update_actor(train_state, batch, config) 455 | train_state, critic_loss = self.update_critic(train_state, batch, config) 456 | new_target_critic = target_update( 457 | train_state.critic, train_state.target_critic, config.tau 458 | ) 459 | train_state = train_state._replace(target_critic=new_target_critic) 460 | return train_state, { 461 | "value_loss": value_loss, 462 | "actor_loss": actor_loss, 463 | "critic_loss": critic_loss, 464 | } 465 | 466 | @classmethod 467 | def get_action( 468 | self, 469 | train_state: XQLTrainState, 470 | observations: np.ndarray, 471 | seed: jax.random.PRNGKey, 472 | temperature: float = 1.0, 473 | max_action: float = 1.0, # In D4RL, the action space is [-1, 1] 474 | ) -> jnp.ndarray: 475 | actions = train_state.actor.apply_fn( 476 | train_state.actor.params, observations, temperature=temperature 477 | ).sample(seed=seed) 478 | actions = jnp.clip(actions, -max_action, max_action) 479 | return actions 480 | 481 | 482 | def create_xql_train_state( 483 | rng: jax.random.PRNGKey, 484 | observations: jnp.ndarray, 485 | actions: jnp.ndarray, 486 | config: XQLConfig, 487 | ) -> XQLTrainState: 488 | rng, actor_rng, critic_rng, value_rng = jax.random.split(rng, 4) 489 | # initialize actor 490 | action_dim = actions.shape[-1] 491 | actor_model = GaussianPolicy( 492 | config.hidden_dims, 493 | action_dim=action_dim, 494 | log_std_min=-5.0, 495 | ) 496 | schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps) 497 | actor_tx = 
optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn)) 498 | actor = TrainState.create( 499 | apply_fn=actor_model.apply, 500 | params=actor_model.init(actor_rng, observations), 501 | tx=actor_tx, 502 | ) 503 | # initialize critic 504 | critic_model = ensemblize(Critic, num_qs=2)(config.hidden_dims) 505 | critic = TrainState.create( 506 | apply_fn=critic_model.apply, 507 | params=critic_model.init(critic_rng, observations, actions), 508 | tx=optax.adam(learning_rate=config.critic_lr), 509 | ) 510 | target_critic = TrainState.create( 511 | apply_fn=critic_model.apply, 512 | params=critic_model.init(critic_rng, observations, actions), 513 | tx=optax.adam(learning_rate=config.critic_lr), 514 | ) 515 | # initialize value 516 | value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm) 517 | value = TrainState.create( 518 | apply_fn=value_model.apply, 519 | params=value_model.init(value_rng, observations), 520 | tx=optax.adam(learning_rate=config.value_lr), 521 | ) 522 | return XQLTrainState( 523 | rng, 524 | critic=critic, 525 | target_critic=target_critic, 526 | value=value, 527 | actor=actor, 528 | ) 529 | 530 | 531 | def evaluate( 532 | policy_fn, env: gym.Env, num_episodes: int, obs_mean: float, obs_std: float 533 | ) -> float: 534 | episode_returns = [] 535 | for _ in range(num_episodes): 536 | episode_return = 0 537 | observation, done = env.reset(), False 538 | while not done: 539 | observation = (observation - obs_mean) / (obs_std + 1e-5) 540 | action = policy_fn(observations=observation) 541 | observation, reward, done, info = env.step(action) 542 | episode_return += reward 543 | episode_returns.append(episode_return) 544 | return env.get_normalized_score(np.mean(episode_returns)) * 100 545 | 546 | 547 | if __name__ == "__main__": 548 | wandb.init(config=config, project=config.project) 549 | rng = jax.random.PRNGKey(config.seed) 550 | env = gym.make(config.env_name) 551 | dataset, obs_mean, obs_std = get_dataset(env, config) 552 | 553 | # create train_state 554 | rng, subkey = jax.random.split(rng) 555 | example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset) 556 | train_state: XQLTrainState = create_xql_train_state( 557 | subkey, example_batch.observations, example_batch.actions, config 558 | ) 559 | 560 | algo = XQL() 561 | update_fn = jax.jit(algo.update_n_times, static_argnums=(3,)) 562 | act_fn = jax.jit(algo.get_action) 563 | num_steps = config.max_steps // config.n_jitted_updates 564 | eval_interval = config.eval_interval // config.n_jitted_updates 565 | for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True): 566 | rng, subkey = jax.random.split(rng) 567 | train_state, update_info = update_fn(train_state, dataset, subkey, config) 568 | 569 | if i % config.log_interval == 0: 570 | train_metrics = {f"training/{k}": v for k, v in update_info.items()} 571 | wandb.log(train_metrics, step=i) 572 | 573 | if i % eval_interval == 0: 574 | policy_fn = partial( 575 | act_fn, 576 | temperature=0.0, 577 | seed=jax.random.PRNGKey(0), 578 | train_state=train_state, 579 | ) 580 | normalized_score = evaluate( 581 | policy_fn, 582 | env, 583 | num_episodes=config.eval_episodes, 584 | obs_mean=obs_mean, 585 | obs_std=obs_std, 586 | ) 587 | print(i, normalized_score) 588 | eval_metrics = {f"{config.env_name}/normalized_score": normalized_score} 589 | wandb.log(eval_metrics, step=i) 590 | # final evaluation 591 | policy_fn = partial( 592 | act_fn, 593 | temperature=0.0, 594 | seed=jax.random.PRNGKey(0), 595 | 
train_state=train_state, 596 | ) 597 | normalized_score = evaluate( 598 | policy_fn, 599 | env, 600 | num_episodes=config.eval_episodes, 601 | obs_mean=obs_mean, 602 | obs_std=obs_std, 603 | ) 604 | print("Final Evaluation", normalized_score) 605 | wandb.log({f"{config.env_name}/final_normalized_score": normalized_score}) 606 | wandb.finish() 607 | -------------------------------------------------------------------------------- /algos/cql.py: -------------------------------------------------------------------------------- 1 | # source https://github.com/young-geng/JaxCQL 2 | # https://arxiv.org/abs/2006.04779 3 | import os 4 | import time 5 | from copy import deepcopy 6 | from functools import partial 7 | from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple 8 | 9 | import d4rl 10 | import distrax 11 | import flax 12 | import flax.linen as nn 13 | import gym 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as np 17 | import optax 18 | import tqdm 19 | import wandb 20 | from flax.training.train_state import TrainState 21 | from omegaconf import OmegaConf 22 | from pydantic import BaseModel 23 | 24 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True" 25 | 26 | 27 | class CQLConfig(BaseModel): 28 | # GENERAL 29 | also: str = "CQL" 30 | project: str = "cql-jax" 31 | env_name: str = "halfcheetah-medium-expert-v2" 32 | seed: int = 42 33 | n_jitted_updates: int = 8 34 | max_steps: int = 1000000 35 | eval_interval: int = 100000 36 | eval_episodes: int = 5 37 | # DATA 38 | data_size: int = 1000000 39 | action_dim: Optional[int] = None 40 | normalize_state: bool = False 41 | reward_scale: float = 1.0 42 | reward_bias: float = 0.0 43 | batch_size: int = 256 44 | max_traj_length: int = 1000 45 | # NETWORK 46 | hidden_dims: Tuple[int] = (256, 256) 47 | policy_lr: float = 3e-4 48 | qf_lr: float = 3e-4 49 | optimizer_type: str = "adam" 50 | soft_target_update_rate: float = 5e-3 51 | orthogonal_init: bool = False 52 | policy_log_std_multiplier: float = 1.0 53 | policy_log_std_offset: float = -1.0 54 | # CQL SPECIFIC 55 | discount: float = 0.99 56 | alpha_multiplier: float = 1.0 57 | use_automatic_entropy_tuning: bool = True 58 | backup_entropy: bool = False 59 | target_entropy: float = 0.0 60 | use_cql: bool = True 61 | cql_n_actions: int = 10 62 | cql_importance_sample: bool = True 63 | cql_lagrange: bool = False 64 | cql_target_action_gap: float = 1.0 65 | cql_temp: float = 1.0 66 | cql_min_q_weight: float = 5.0 67 | cql_max_target_backup: bool = False 68 | cql_clip_diff_min: float = -np.inf 69 | cql_clip_diff_max: float = np.inf 70 | 71 | def __hash__(self): 72 | return hash(self.__repr__()) 73 | 74 | 75 | conf_dict = OmegaConf.from_cli() 76 | config = CQLConfig(**conf_dict) 77 | 78 | 79 | def extend_and_repeat(tensor: jnp.ndarray, axis: int, repeat: int) -> jnp.ndarray: 80 | return jnp.repeat(jnp.expand_dims(tensor, axis), repeat, axis=axis) 81 | 82 | 83 | def mse_loss(val: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray: 84 | return jnp.mean(jnp.square(val - target)) 85 | 86 | 87 | def value_and_multi_grad( 88 | fun: Callable, n_outputs: int, argnums=0, has_aux=False 89 | ) -> Callable: 90 | def select_output(index: int) -> Callable: 91 | def wrapped(*args, **kwargs): 92 | if has_aux: 93 | x, *aux = fun(*args, **kwargs) 94 | return (x[index], *aux) 95 | else: 96 | x = fun(*args, **kwargs) 97 | return x[index] 98 | 99 | return wrapped 100 | 101 | grad_fns = tuple( 102 | jax.value_and_grad(select_output(i), argnums=argnums, has_aux=has_aux) 103 | for i in 
range(n_outputs) 104 | ) 105 | 106 | def multi_grad_fn(*args, **kwargs): 107 | grads = [] 108 | values = [] 109 | for grad_fn in grad_fns: 110 | (value, *aux), grad = grad_fn(*args, **kwargs) 111 | values.append(value) 112 | grads.append(grad) 113 | return (tuple(values), *aux), tuple(grads) 114 | 115 | return multi_grad_fn 116 | 117 | 118 | def update_target_network(main_params: Any, target_params: Any, tau: float) -> Any: 119 | return jax.tree_util.tree_map( 120 | lambda x, y: tau * x + (1.0 - tau) * y, main_params, target_params 121 | ) 122 | 123 | 124 | def multiple_action_q_function(forward: Callable) -> Callable: 125 | # Forward the q function with multiple actions on each state, to be used as a decorator 126 | def wrapped( 127 | self, observations: jnp.ndarray, actions: jnp.ndarray, **kwargs 128 | ) -> jnp.ndarray: 129 | multiple_actions = False 130 | batch_size = observations.shape[0] 131 | if actions.ndim == 3 and observations.ndim == 2: 132 | multiple_actions = True 133 | observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape( 134 | -1, observations.shape[-1] 135 | ) 136 | actions = actions.reshape(-1, actions.shape[-1]) 137 | q_values = forward(self, observations, actions, **kwargs) 138 | if multiple_actions: 139 | q_values = q_values.reshape(batch_size, -1) 140 | return q_values 141 | 142 | return wrapped 143 | 144 | 145 | class Scalar(nn.Module): 146 | init_value: float 147 | 148 | def setup(self) -> None: 149 | self.value = self.param("value", lambda x: self.init_value) 150 | 151 | def __call__(self) -> jnp.ndarray: 152 | return self.value 153 | 154 | 155 | class FullyConnectedNetwork(nn.Module): 156 | output_dim: int 157 | hidden_dims: Tuple[int] = (256, 256) 158 | orthogonal_init: bool = False 159 | 160 | @nn.compact 161 | def __call__(self, input_tensor: jnp.ndarray) -> jnp.ndarray: 162 | x = input_tensor 163 | for h in self.hidden_dims: 164 | if self.orthogonal_init: 165 | x = nn.Dense( 166 | h, 167 | kernel_init=jax.nn.initializers.orthogonal(jnp.sqrt(2.0)), 168 | bias_init=jax.nn.initializers.zeros, 169 | )(x) 170 | else: 171 | x = nn.Dense(h)(x) 172 | x = nn.relu(x) 173 | 174 | if self.orthogonal_init: 175 | output = nn.Dense( 176 | self.output_dim, 177 | kernel_init=jax.nn.initializers.orthogonal(1e-2), 178 | bias_init=jax.nn.initializers.zeros, 179 | )(x) 180 | else: 181 | output = nn.Dense( 182 | self.output_dim, 183 | kernel_init=jax.nn.initializers.variance_scaling( 184 | 1e-2, "fan_in", "uniform" 185 | ), 186 | bias_init=jax.nn.initializers.zeros, 187 | )(x) 188 | return output 189 | 190 | 191 | class FullyConnectedQFunction(nn.Module): 192 | observation_dim: int 193 | action_dim: int 194 | hidden_dims: Tuple[int] = (256, 256) 195 | orthogonal_init: bool = False 196 | 197 | @nn.compact 198 | @multiple_action_q_function 199 | def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: 200 | x = jnp.concatenate([observations, actions], axis=-1) 201 | x = FullyConnectedNetwork( 202 | output_dim=1, 203 | hidden_dims=self.hidden_dims, 204 | orthogonal_init=self.orthogonal_init, 205 | )(x) 206 | return jnp.squeeze(x, -1) 207 | 208 | 209 | class TanhGaussianPolicy(nn.Module): 210 | observation_dim: int 211 | action_dim: int 212 | hidden_dims: Tuple[int] = (256, 256) 213 | orthogonal_init: bool = False 214 | log_std_multiplier: float = 1.0 215 | log_std_offset: float = -1.0 216 | 217 | def setup(self) -> None: 218 | self.base_network = FullyConnectedNetwork( 219 | output_dim=2 * config.action_dim, 220 | 
hidden_dims=self.hidden_dims, 221 | orthogonal_init=self.orthogonal_init, 222 | ) 223 | self.log_std_multiplier_module = Scalar(self.log_std_multiplier) 224 | self.log_std_offset_module = Scalar(self.log_std_offset) 225 | 226 | def log_prob(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: 227 | if actions.ndim == 3: 228 | observations = extend_and_repeat(observations, 1, actions.shape[1]) 229 | base_network_output = self.base_network(observations) 230 | mean, log_std = jnp.split(base_network_output, 2, axis=-1) 231 | log_std = ( 232 | self.log_std_multiplier_module() * log_std + self.log_std_offset_module() 233 | ) 234 | log_std = jnp.clip(log_std, -20.0, 2.0) 235 | action_distribution = distrax.Transformed( 236 | distrax.MultivariateNormalDiag(mean, jnp.exp(log_std)), 237 | distrax.Block(distrax.Tanh(), ndims=1), 238 | ) 239 | return action_distribution.log_prob(actions) 240 | 241 | def __call__( 242 | self, 243 | observations: jnp.ndarray, 244 | rng: jax.random.PRNGKey, 245 | deterministic=False, 246 | repeat=None, 247 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 248 | if repeat is not None: 249 | observations = extend_and_repeat(observations, 1, repeat) 250 | base_network_output = self.base_network(observations) 251 | mean, log_std = jnp.split(base_network_output, 2, axis=-1) 252 | log_std = ( 253 | self.log_std_multiplier_module() * log_std + self.log_std_offset_module() 254 | ) 255 | log_std = jnp.clip(log_std, -20.0, 2.0) 256 | action_distribution = distrax.Transformed( 257 | distrax.MultivariateNormalDiag(mean, jnp.exp(log_std)), 258 | distrax.Block(distrax.Tanh(), ndims=1), 259 | ) 260 | if deterministic: 261 | samples = jnp.tanh(mean) 262 | log_prob = action_distribution.log_prob(samples) 263 | else: 264 | samples, log_prob = action_distribution.sample_and_log_prob(seed=rng) 265 | 266 | return samples, log_prob 267 | 268 | 269 | class Transition(NamedTuple): 270 | observations: np.ndarray 271 | actions: np.ndarray 272 | rewards: np.ndarray 273 | next_observations: np.ndarray 274 | dones: np.ndarray 275 | 276 | 277 | def get_dataset( 278 | env: gym.Env, config: CQLConfig, clip_to_eps: bool = True, eps: float = 1e-5 279 | ) -> Transition: 280 | dataset = d4rl.qlearning_dataset(env) 281 | 282 | if clip_to_eps: 283 | lim = 1 - eps 284 | dataset["actions"] = np.clip(dataset["actions"], -lim, lim) 285 | 286 | dataset = Transition( 287 | observations=jnp.array(dataset["observations"], dtype=jnp.float32), 288 | actions=jnp.array(dataset["actions"], dtype=jnp.float32), 289 | rewards=jnp.array(dataset["rewards"], dtype=jnp.float32), 290 | next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32), 291 | dones=jnp.array(dataset["terminals"], dtype=jnp.float32), 292 | ) 293 | # shuffle data and select the first data_size samples 294 | data_size = min(config.data_size, len(dataset.observations)) 295 | rng = jax.random.PRNGKey(config.seed) 296 | rng, rng_permute, rng_select = jax.random.split(rng, 3) 297 | perm = jax.random.permutation(rng_permute, len(dataset.observations)) 298 | dataset = jax.tree_util.tree_map(lambda x: x[perm], dataset) 299 | assert len(dataset.observations) >= data_size 300 | dataset = jax.tree_util.tree_map(lambda x: x[:data_size], dataset) 301 | # normalize states 302 | obs_mean, obs_std = 0, 1 303 | if config.normalize_state: 304 | obs_mean = dataset.observations.mean(0) 305 | obs_std = dataset.observations.std(0) 306 | dataset = dataset._replace( 307 | observations=(dataset.observations - obs_mean) / (obs_std + 1e-5), 308 | 
next_observations=(dataset.next_observations - obs_mean) / (obs_std + 1e-5), 309 | ) 310 | return dataset, obs_mean, obs_std 311 | 312 | 313 | def collect_metrics(metrics, names, prefix=None): 314 | collected = {} 315 | for name in names: 316 | if name in metrics: 317 | collected[name] = jnp.mean(metrics[name]) 318 | if prefix is not None: 319 | collected = { 320 | "{}/{}".format(prefix, key): value for key, value in collected.items() 321 | } 322 | return collected 323 | 324 | 325 | class CQLTrainState(NamedTuple): 326 | policy: TrainState 327 | qf1: TrainState 328 | qf2: TrainState 329 | log_alpha: TrainState 330 | alpha_prime: TrainState 331 | target_qf1_params: Any 332 | target_qf2_params: Any 333 | global_steps: int = 0 334 | 335 | def train_params(self): 336 | params_dict = { 337 | "policy": self.policy.params, 338 | "qf1": self.qf1.params, 339 | "qf2": self.qf2.params, 340 | "log_alpha": self.log_alpha.params, 341 | "alpha_prime": self.alpha_prime.params, 342 | } 343 | return params_dict 344 | 345 | def target_params(self): 346 | return {"qf1": self.target_qf1_params, "qf2": self.target_qf2_params} 347 | 348 | def model_keys(self): 349 | keys = ["policy", "qf1", "qf2", "log_alpha", "alpha_prime"] 350 | return keys 351 | 352 | def to_dict(self): 353 | _dict = { 354 | "policy": self.policy, 355 | "qf1": self.qf1, 356 | "qf2": self.qf2, 357 | "log_alpha": self.log_alpha, 358 | "alpha_prime": self.alpha_prime, 359 | } 360 | return _dict 361 | 362 | def update_from_dict( 363 | self, new_states: Dict[str, TrainState], new_target_qf_params: Dict[str, Any] 364 | ): 365 | return self._replace( 366 | policy=new_states["policy"], 367 | qf1=new_states["qf1"], 368 | qf2=new_states["qf2"], 369 | log_alpha=new_states["log_alpha"], 370 | alpha_prime=new_states["alpha_prime"], 371 | target_qf1_params=new_target_qf_params["qf1"], 372 | target_qf2_params=new_target_qf_params["qf2"], 373 | ) 374 | 375 | 376 | class CQL(object): 377 | 378 | @classmethod 379 | def update_n_times(self, train_state: CQLTrainState, dataset, rng, config, bc=False): 380 | for _ in range(config.n_jitted_updates): 381 | rng, batch_rng, update_rng = jax.random.split(rng, 3) 382 | batch_indices = jax.random.randint( 383 | batch_rng, (config.batch_size,), 0, len(dataset.observations) 384 | ) 385 | batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset) 386 | train_state, metrics = self._train_step( 387 | train_state, update_rng, batch, config, bc 388 | ) 389 | return train_state, metrics 390 | 391 | @classmethod 392 | def _train_step(self, train_state: CQLTrainState, _rng, batch, config, bc=False): 393 | policy_fn = train_state.policy.apply_fn 394 | qf_fn = train_state.qf1.apply_fn 395 | log_alpha_fn = train_state.log_alpha.apply_fn 396 | alpha_prime_fn = train_state.alpha_prime.apply_fn 397 | target_qf_params = train_state.target_params() 398 | 399 | def loss_fn(train_params): 400 | observations = batch.observations 401 | actions = batch.actions 402 | rewards = batch.rewards 403 | next_observations = batch.next_observations 404 | dones = batch.dones 405 | 406 | loss_collection = {} 407 | 408 | rng, new_actions_rng = jax.random.split(_rng) 409 | new_actions, log_pi = policy_fn( 410 | train_params["policy"], observations, new_actions_rng 411 | ) 412 | 413 | if config.use_automatic_entropy_tuning: 414 | alpha_loss = ( 415 | -log_alpha_fn(train_params["log_alpha"]) 416 | * (log_pi + config.target_entropy).mean() 417 | ) 418 | loss_collection["log_alpha"] = alpha_loss 419 | alpha = ( 420 | 
jnp.exp(log_alpha_fn(train_params["log_alpha"])) 421 | * config.alpha_multiplier 422 | ) 423 | else: 424 | alpha_loss = 0.0 425 | alpha = config.alpha_multiplier 426 | 427 | """ Policy loss """ 428 | if bc: 429 | rng, bc_rng = jax.random.split(rng) 430 | log_probs = policy_fn( 431 | train_params["policy"], 432 | observations, 433 | actions, 434 | bc_rng, 435 | method=self.policy.log_prob, 436 | ) 437 | policy_loss = (alpha * log_pi - log_probs).mean() 438 | else: 439 | q_new_actions = jnp.minimum( 440 | qf_fn(train_params["qf1"], observations, new_actions), 441 | qf_fn(train_params["qf2"], observations, new_actions), 442 | ) 443 | policy_loss = (alpha * log_pi - q_new_actions).mean() 444 | 445 | loss_collection["policy"] = policy_loss 446 | 447 | """ Q function loss """ 448 | q1_pred = qf_fn(train_params["qf1"], observations, actions) 449 | q2_pred = qf_fn(train_params["qf2"], observations, actions) 450 | 451 | if config.cql_max_target_backup: 452 | rng, cql_rng = jax.random.split(rng) 453 | new_next_actions, next_log_pi = policy_fn( 454 | train_params["policy"], 455 | next_observations, 456 | cql_rng, 457 | repeat=config.cql_n_actions, 458 | ) 459 | target_q_values = jnp.minimum( 460 | qf_fn(target_qf_params["qf1"], next_observations, new_next_actions), 461 | qf_fn(target_qf_params["qf2"], next_observations, new_next_actions), 462 | ) 463 | max_target_indices = jnp.expand_dims( 464 | jnp.argmax(target_q_values, axis=-1), axis=-1 465 | ) 466 | target_q_values = jnp.take_along_axis( 467 | target_q_values, max_target_indices, axis=-1 468 | ).squeeze(-1) 469 | next_log_pi = jnp.take_along_axis( 470 | next_log_pi, max_target_indices, axis=-1 471 | ).squeeze(-1) 472 | else: 473 | rng, cql_rng = jax.random.split(rng) 474 | new_next_actions, next_log_pi = policy_fn( 475 | train_params["policy"], next_observations, cql_rng 476 | ) 477 | target_q_values = jnp.minimum( 478 | qf_fn(target_qf_params["qf1"], next_observations, new_next_actions), 479 | qf_fn(target_qf_params["qf2"], next_observations, new_next_actions), 480 | ) 481 | 482 | if config.backup_entropy: 483 | target_q_values = target_q_values - alpha * next_log_pi 484 | 485 | td_target = jax.lax.stop_gradient( 486 | rewards + (1.0 - dones) * config.discount * target_q_values 487 | ) 488 | qf1_loss = mse_loss(q1_pred, td_target) 489 | qf2_loss = mse_loss(q2_pred, td_target) 490 | 491 | ### CQL 492 | if config.use_cql: 493 | batch_size = actions.shape[0] 494 | rng, random_rng = jax.random.split(rng) 495 | cql_random_actions = jax.random.uniform( 496 | random_rng, 497 | shape=(batch_size, config.cql_n_actions, config.action_dim), 498 | minval=-1.0, 499 | maxval=1.0, 500 | ) 501 | rng, current_rng = jax.random.split(rng) 502 | cql_current_actions, cql_current_log_pis = policy_fn( 503 | train_params["policy"], 504 | observations, 505 | current_rng, 506 | repeat=config.cql_n_actions, 507 | ) 508 | rng, next_rng = jax.random.split(rng) 509 | cql_next_actions, cql_next_log_pis = policy_fn( 510 | train_params["policy"], 511 | next_observations, 512 | next_rng, 513 | repeat=config.cql_n_actions, 514 | ) 515 | 516 | cql_q1_rand = qf_fn( 517 | train_params["qf1"], observations, cql_random_actions 518 | ) 519 | cql_q2_rand = qf_fn( 520 | train_params["qf2"], observations, cql_random_actions 521 | ) 522 | cql_q1_current_actions = qf_fn( 523 | train_params["qf1"], observations, cql_current_actions 524 | ) 525 | cql_q2_current_actions = qf_fn( 526 | train_params["qf2"], observations, cql_current_actions 527 | ) 528 | cql_q1_next_actions = qf_fn( 529 | 
train_params["qf1"], observations, cql_next_actions 530 | ) 531 | cql_q2_next_actions = qf_fn( 532 | train_params["qf2"], observations, cql_next_actions 533 | ) 534 | 535 | cql_cat_q1 = jnp.concatenate( 536 | [ 537 | cql_q1_rand, 538 | jnp.expand_dims(q1_pred, 1), 539 | cql_q1_next_actions, 540 | cql_q1_current_actions, 541 | ], 542 | axis=1, 543 | ) 544 | cql_cat_q2 = jnp.concatenate( 545 | [ 546 | cql_q2_rand, 547 | jnp.expand_dims(q2_pred, 1), 548 | cql_q2_next_actions, 549 | cql_q2_current_actions, 550 | ], 551 | axis=1, 552 | ) 553 | cql_std_q1 = jnp.std(cql_cat_q1, axis=1) 554 | cql_std_q2 = jnp.std(cql_cat_q2, axis=1) 555 | 556 | if config.cql_importance_sample: 557 | random_density = np.log(0.5**config.action_dim) 558 | cql_cat_q1 = jnp.concatenate( 559 | [ 560 | cql_q1_rand - random_density, 561 | cql_q1_next_actions - cql_next_log_pis, 562 | cql_q1_current_actions - cql_current_log_pis, 563 | ], 564 | axis=1, 565 | ) 566 | cql_cat_q2 = jnp.concatenate( 567 | [ 568 | cql_q2_rand - random_density, 569 | cql_q2_next_actions - cql_next_log_pis, 570 | cql_q2_current_actions - cql_current_log_pis, 571 | ], 572 | axis=1, 573 | ) 574 | 575 | cql_qf1_ood = ( 576 | jax.scipy.special.logsumexp(cql_cat_q1 / config.cql_temp, axis=1) 577 | * config.cql_temp 578 | ) 579 | cql_qf2_ood = ( 580 | jax.scipy.special.logsumexp(cql_cat_q2 / config.cql_temp, axis=1) 581 | * config.cql_temp 582 | ) 583 | 584 | """Subtract the log likelihood of data""" 585 | cql_qf1_diff = jnp.clip( 586 | cql_qf1_ood - q1_pred, 587 | config.cql_clip_diff_min, 588 | config.cql_clip_diff_max, 589 | ).mean() 590 | cql_qf2_diff = jnp.clip( 591 | cql_qf2_ood - q2_pred, 592 | config.cql_clip_diff_min, 593 | config.cql_clip_diff_max, 594 | ).mean() 595 | 596 | if config.cql_lagrange: 597 | alpha_prime = jnp.clip( 598 | jnp.exp(alpha_prime_fn(train_params["alpha_prime"])), 599 | a_min=0.0, 600 | a_max=1000000.0, 601 | ) 602 | cql_min_qf1_loss = alpha_prime * config.cql_min_q_weight * (cql_qf1_diff - config.cql_target_action_gap) 603 | cql_min_qf2_loss = alpha_prime * config.cql_min_q_weight * (cql_qf2_diff - config.cql_target_action_gap) 604 | 605 | alpha_prime_loss = - (cql_min_qf1_loss + cql_min_qf2_loss) * 0.5 606 | else: 607 | cql_min_qf1_loss = cql_qf1_diff * config.cql_min_q_weight 608 | cql_min_qf2_loss = cql_qf2_diff * config.cql_min_q_weight 609 | alpha_prime_loss = 0.0 610 | alpha_prime = 0.0 611 | 612 | loss_collection["alpha_prime"] = alpha_prime_loss 613 | 614 | qf1_loss = qf1_loss + cql_min_qf1_loss 615 | qf2_loss = qf2_loss + cql_min_qf2_loss 616 | 617 | loss_collection["qf1"] = qf1_loss 618 | loss_collection["qf2"] = qf2_loss 619 | return ( 620 | tuple(loss_collection[key] for key in train_state.model_keys()), 621 | locals(), 622 | ) 623 | 624 | train_params = train_state.train_params() 625 | (_, aux_values), grads = value_and_multi_grad( 626 | loss_fn, len(train_params), has_aux=True 627 | )(train_params) 628 | 629 | new_train_states = { 630 | key: train_state.to_dict()[key].apply_gradients(grads=grads[i][key]) 631 | for i, key in enumerate(train_state.model_keys()) 632 | } 633 | new_target_qf_params = {} 634 | new_target_qf_params["qf1"] = update_target_network( 635 | new_train_states["qf1"].params, 636 | target_qf_params["qf1"], 637 | config.soft_target_update_rate, 638 | ) 639 | new_target_qf_params["qf2"] = update_target_network( 640 | new_train_states["qf2"].params, 641 | target_qf_params["qf2"], 642 | config.soft_target_update_rate, 643 | ) 644 | train_state = train_state.update_from_dict( 645 | 
            new_train_states, new_target_qf_params
646 |         )
647 | 
648 |         metrics = collect_metrics(
649 |             aux_values,
650 |             [
651 |                 "log_pi",
652 |                 "policy_loss",
653 |                 "qf1_loss",
654 |                 "qf2_loss",
655 |                 "alpha_loss",
656 |                 "alpha",
657 |                 "q1_pred",
658 |                 "q2_pred",
659 |                 "target_q_values",
660 |             ],
661 |         )
662 | 
663 |         if config.use_cql:
664 |             metrics.update(
665 |                 collect_metrics(
666 |                     aux_values,
667 |                     [
668 |                         "cql_std_q1",
669 |                         "cql_std_q2",
670 |                         "cql_q1_rand",
671 |                         "cql_q2_rand", "cql_qf1_diff",
672 |                         "cql_qf2_diff",
673 |                         "cql_min_qf1_loss",
674 |                         "cql_min_qf2_loss",
675 |                         "cql_q1_current_actions",
676 |                         "cql_q2_current_actions", "cql_q1_next_actions",
677 |                         "cql_q2_next_actions",
678 |                         "alpha_prime",
679 |                         "alpha_prime_loss",
680 |                     ],
681 |                     "cql",
682 |                 )
683 |             )
684 | 
685 |         return train_state, metrics
686 | 
687 |     @classmethod
688 |     def get_action(self, train_state, obs):
689 |         action, _ = train_state.policy.apply_fn(
690 |             train_state.policy.params,
691 |             obs.reshape(1, -1),
692 |             jax.random.PRNGKey(0),  # unused when deterministic=True
693 |             deterministic=True,
694 |         )
695 |         return action.squeeze(0)
696 | 
697 | 
698 | def create_cql_train_state(
699 |     rng: jax.random.PRNGKey,
700 |     observations: jnp.ndarray,
701 |     actions: jnp.ndarray,
702 |     config: CQLConfig,
703 | ) -> CQLTrainState:
704 |     policy_model = TanhGaussianPolicy(
705 |         observation_dim=observations.shape[-1],
706 |         action_dim=actions.shape[-1],
707 |         hidden_dims=config.hidden_dims,
708 |         orthogonal_init=config.orthogonal_init,
709 |         log_std_multiplier=config.policy_log_std_multiplier,
710 |         log_std_offset=config.policy_log_std_offset,
711 |     )
712 |     qf_model = FullyConnectedQFunction(
713 |         observation_dim=observations.shape[-1],
714 |         action_dim=actions.shape[-1],
715 |         hidden_dims=config.hidden_dims,
716 |         orthogonal_init=config.orthogonal_init,
717 |     )
718 |     optimizer_class = {
719 |         "adam": optax.adam,
720 |         "sgd": optax.sgd,
721 |     }[config.optimizer_type]
722 | 
723 |     rng, policy_rng, q1_rng, q2_rng = jax.random.split(rng, 4)
724 | 
725 |     policy_params = policy_model.init(policy_rng, observations, policy_rng)
726 |     policy = TrainState.create(
727 |         params=policy_params,
728 |         tx=optimizer_class(config.policy_lr),
729 |         apply_fn=policy_model.apply,
730 |     )
731 | 
732 |     qf1_params = qf_model.init(
733 |         q1_rng,
734 |         observations,
735 |         actions,
736 |     )
737 |     qf1 = TrainState.create(
738 |         params=qf1_params,
739 |         tx=optimizer_class(config.qf_lr),
740 |         apply_fn=qf_model.apply,
741 |     )
742 |     qf2_params = qf_model.init(
743 |         q2_rng,
744 |         observations,
745 |         actions,
746 |     )
747 |     qf2 = TrainState.create(
748 |         params=qf2_params,
749 |         tx=optimizer_class(config.qf_lr),
750 |         apply_fn=qf_model.apply,
751 |     )
752 |     target_qf1_params = deepcopy(qf1_params)
753 |     target_qf2_params = deepcopy(qf2_params)
754 | 
755 |     log_alpha_model = Scalar(0.0)
756 |     rng, log_alpha_rng = jax.random.split(rng)
757 |     log_alpha = TrainState.create(
758 |         params=log_alpha_model.init(log_alpha_rng),
759 |         tx=optimizer_class(config.policy_lr),
760 |         apply_fn=log_alpha_model.apply,
761 |     )
762 | 
763 |     alpha_prime_model = Scalar(1.0)
764 |     rng, alpha_prime_rng = jax.random.split(rng)
765 |     alpha_prime = TrainState.create(
766 |         params=alpha_prime_model.init(alpha_prime_rng),
767 |         tx=optimizer_class(config.qf_lr),
768 |         apply_fn=alpha_prime_model.apply,
769 |     )
770 |     return CQLTrainState(
771 |         policy=policy,
772 |         qf1=qf1,
773 |         qf2=qf2,
774 |         log_alpha=log_alpha,
775 |         alpha_prime=alpha_prime,
776 |         target_qf1_params=target_qf1_params,
777 |         target_qf2_params=target_qf2_params,
778 |         global_steps=0,
779 |     )
780 | 
781 | 
782 | def evaluate(
783 |     policy_fn: Callable[[jnp.ndarray], jnp.ndarray],
784 |     env: gym.Env,
785 |     num_episodes: int,
786 |     obs_mean=0,
787 |     obs_std=1,
788 | ):
789 |     episode_returns = []
790 |     for _ in range(num_episodes):
791 |         obs = env.reset()
792 |         done = False
793 |         total_reward = 0
794 |         while not done:
795 |             obs = (obs - obs_mean) / obs_std  # normalize with the dataset statistics
796 |             action = policy_fn(obs=obs)
797 |             obs, reward, done, _ = env.step(action)
798 |             total_reward += reward
799 |         episode_returns.append(total_reward)
800 |     return env.get_normalized_score(np.mean(episode_returns)) * 100
801 | 
802 | 
803 | if __name__ == "__main__":
804 |     wandb.init(project=config.project, config=config)
805 |     rng = jax.random.PRNGKey(config.seed)
806 |     env = gym.make(config.env_name)
807 |     dataset, obs_mean, obs_std = get_dataset(env, config)
808 |     config.action_dim = env.action_space.shape[0]
809 |     # rescale reward
810 |     dataset = dataset._replace(rewards=dataset.rewards * config.reward_scale + config.reward_bias)
811 | 
812 |     if config.target_entropy >= 0.0:
813 |         config.target_entropy = -np.prod(env.action_space.shape).item()
814 |     # create train_state
815 |     rng, subkey = jax.random.split(rng)
816 |     example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset)
817 |     train_state = create_cql_train_state(
818 |         subkey,
819 |         example_batch.observations,
820 |         example_batch.actions,
821 |         config,
822 |     )
823 |     algo = CQL()
824 |     update_fn = jax.jit(algo.update_n_times, static_argnums=(3,))
825 |     act_fn = jax.jit(algo.get_action)
826 | 
827 |     num_steps = int(config.max_steps // config.n_jitted_updates)
828 |     eval_interval = config.eval_interval // config.n_jitted_updates
829 |     for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True):
830 |         metrics = {"step": i}
831 |         rng, update_rng = jax.random.split(rng)
832 |         train_state, update_metrics = update_fn(train_state, dataset, update_rng, config)
833 |         metrics.update(update_metrics)
834 | 
835 |         if i == 0 or (i + 1) % eval_interval == 0:
836 |             policy_fn = partial(act_fn, train_state=train_state)
837 |             normalized_score = evaluate(
838 |                 policy_fn, env, config.eval_episodes, obs_mean=obs_mean, obs_std=obs_std
839 |             )
840 |             metrics[f"{config.env_name}/normalized_score"] = normalized_score
841 |             print(config.env_name, i, metrics[f"{config.env_name}/normalized_score"])
842 |         wandb.log(metrics)
843 | 
844 |     # final evaluation
845 |     policy_fn = partial(act_fn, train_state=train_state)
846 |     normalized_score = evaluate(
847 |         policy_fn, env, config.eval_episodes, obs_mean=obs_mean, obs_std=obs_std
848 |     )
849 |     wandb.log({f"{config.env_name}/final_normalized_score": normalized_score})
850 |     print(config.env_name, i, normalized_score)
851 |     wandb.finish()
852 | 
--------------------------------------------------------------------------------