├── .github └── workflows │ └── pre-commit.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── contributing.md ├── experiments ├── configs │ ├── cql_config.py │ ├── ensemble_config.py │ ├── iql_config.py │ ├── sac_config.py │ ├── train_config.py │ └── wsrl_config.py └── scripts │ ├── adroit │ ├── launch_calql_finetune.sh │ ├── launch_cql_finetune.sh │ ├── launch_iql_finetune.sh │ ├── launch_rlpd.sh │ └── launch_wsrl_finetune.sh │ ├── antmaze │ ├── launch_calql_finetune.sh │ ├── launch_cql_finetune.sh │ ├── launch_iql_finetune.sh │ ├── launch_rlpd.sh │ └── launch_wsrl_finetune.sh │ ├── kitchen │ ├── launch_calql_finetune.sh │ ├── launch_cql_finetune.sh │ ├── launch_iql_finetune.sh │ ├── launch_rlpd.sh │ └── launch_wsrl_finetune.sh │ └── locomotion │ ├── launch_cql_finetune.sh │ ├── launch_iql_finetune.sh │ ├── launch_rlpd.sh │ └── launch_wsrl_finetune.sh ├── finetune.py ├── requirements.txt ├── setup.py └── wsrl ├── agents ├── __init__.py ├── bc.py ├── calql.py ├── cql.py ├── iql.py └── sac.py ├── common ├── common.py ├── evaluation.py ├── initialization.py ├── optimizers.py ├── typing.py └── wandb.py ├── data ├── dataset.py └── replay_buffer.py ├── envs ├── adroit_binary_dataset.py ├── d4rl_dataset.py ├── env_common.py └── wrappers │ ├── __init__.py │ ├── add_truncation.py │ ├── adroit.py │ ├── kitchen.py │ └── reward_scale.py ├── networks ├── actor_critic_nets.py ├── lagrange.py └── mlp.py └── utils ├── timer_utils.py └── train_utils.py /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.0 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | .vscode/ 3 | .onager/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | .idea/ 165 | 166 | *.ipynb 167 | wandb 168 | checkpoints 169 | log 170 | render 171 | *.png 172 | 173 | *.sif 174 | 175 | install_jax_on_matrix.sh 176 | 177 | # VSCode 178 | .vscode 179 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: check-ast 7 | - id: check-added-large-files 8 | - id: check-case-conflict 9 | - id: check-merge-conflict 10 | - id: end-of-file-fixer 11 | - id: trailing-whitespace 12 | - id: detect-private-key 13 | - id: debug-statements 14 | exclude: ^experiments/ 15 | - repo: https://github.com/psf/black 16 | rev: 22.10.0 17 | hooks: 18 | - id: black 19 | exclude: ^experiments/ 20 | - repo: https://github.com/pycqa/isort 21 | rev: 5.12.0 22 | hooks: 23 | - id: isort 24 | exclude: ^experiments/ 25 | args: ["--profile", "black", "--src", "wsrl", "--src", "experiments"] 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WSRL: Warm-Start Reinforcement Learning 2 | ![](https://zhouzypaul.github.io/images/paper-images/wsrl/wsrl.png) 3 | [![arXiv](https://img.shields.io/badge/arXiv-2412.07762-df2a2a.svg?style=for-the-badge)](http://arxiv.org/abs/2412.07762) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge)](https://opensource.org/licenses/MIT) 5 | [![Static Badge](https://img.shields.io/badge/Project-Page-a?style=for-the-badge)](https://zhouzypaul.github.io/wsrl) 6 | 7 | This is the code release for the paper [Efficient Online Reinforcement Learning Fine-Tuning Need Not Retain Offline Data](http://arxiv.org/abs/2412.07762). We provide the implementation of [WSRL](http://arxiv.org/abs/2412.07762) (Warm-Start Reinforcement Learning), as well as popular actor-critic RL algorithms in JAX and Flax: [IQL](https://arxiv.org/abs/2110.06169), [CQL](https://arxiv.org/abs/2006.04779), [CalQL](https://arxiv.org/abs/2303.05479), [SAC](https://arxiv.org/abs/1801.01290), and [RLPD](https://arxiv.org/abs/2302.02948). Variants of SAC are also supported, such as [TD3](https://arxiv.org/pdf/1802.09477) and [REDQ](https://arxiv.org/abs/2101.05982), and IQL policy extraction supports both AWR and DDPG+BC. 8 | We support the following environments: D4RL antmaze, adroit, kitchen, and Mujoco locomotion, but the code can be easily adapted to work with other environments and datasets. 9 | 10 | ![teaser](https://zhouzypaul.github.io/images/paper-images/wsrl/teaser.png) 11 | 12 | ``` 13 | @article{zhou2024efficient, 14 | author = {Zhiyuan Zhou and Andy Peng and Qiyang Li and Sergey Levine and Aviral Kumar}, 15 | title = {Efficient Online Reinforcement Learning Fine-Tuning Need Not Retain Offline Data}, 16 | conference = {arXiv Pre-print}, 17 | year = {2024}, 18 | url = {http://arxiv.org/abs/2412.07762}, 19 | } 20 | ``` 21 | 22 | 23 | ## Installation 24 | ```bash 25 | conda create -n wsrl python=3.10 -y 26 | conda activate wsrl 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | For JAX, install 31 | ``` 32 | pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 33 | ``` 34 | 35 | To use the D4RL envs, you would also need my fork of D4RL below.
36 | This fork incorporates the antmaze-ultra environments and fixes the kitchen environment rewards to be consistent between the offline dataset and the environment. 37 | ``` 38 | git clone git@github.com:zhouzypaul/D4RL.git 39 | cd D4RL 40 | pip install -e . 41 | ``` 42 | 43 | To use MuJoCo, you would also need to install MuJoCo manually into `~/.mujoco/` (see [here](https://github.com/openai/mujoco-py?tab=readme-ov-file#install-mujoco) for download instructions), and set the following environment variables 44 | ```bash 45 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin 46 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia 47 | ``` 48 | 49 | To use the adroit envs, you would need 50 | ``` 51 | git clone --recursive https://github.com/nakamotoo/mj_envs.git 52 | cd mj_envs 53 | git submodule update --remote 54 | pip install -e . 55 | ``` 56 | 57 | Download the adroit dataset from [here](https://drive.google.com/file/d/1yUdJnGgYit94X_AvV6JJP5Y3Lx2JF30Y/view) and unzip the files into `~/adroit_data/`. 58 | If you would like to put the adroit datasets into another directory, use the environment variable `DATA_DIR_PREFIX` (check out the code [here](https://github.com/zhouzypaul/wsrl/blob/4b5665987079934a926c10a09bd81bc3c48ea9fa/wsrl/envs/adroit_binary_dataset.py#L7) for more details). 59 | ```bash 60 | export DATA_DIR_PREFIX=/path/to/your/data 61 | ``` 62 | 63 | ## Running 64 | The main run script is `finetune.py`. We provide bash scripts in `experiments/scripts/` to train WSRL/IQL/CQL/CalQL/RLPD on the different environments. 65 | 66 | The shared agent configs are in `experiments/configs/*`, and the environment-specific configs are in `experiments/configs/train_config.py` and in the bash scripts. 67 | 68 | ### Pre-training 69 | For example, to run CalQL (with Q-ensemble) pre-training: 70 | ```bash 71 | # on antmaze 72 | bash experiments/scripts/antmaze/launch_calql_finetune.sh --use_redq --env antmaze-large-diverse-v2 73 | 74 | # on adroit 75 | bash experiments/scripts/adroit/launch_calql_finetune.sh --use_redq --env door-binary-v0 76 | 77 | # on kitchen 78 | bash experiments/scripts/kitchen/launch_calql_finetune.sh --use_redq --env kitchen-mixed-v0 79 | 80 | # on mujoco locomotion (CQL pre-training, since MC returns are hard to estimate) 81 | bash experiments/scripts/locomotion/launch_cql_finetune.sh --use_redq --env halfcheetah-medium-replay-v0 82 | ``` 83 | 84 | ### Fine-tuning 85 | To run WSRL fine-tuning from a pre-trained checkpoint: 86 | ```bash 87 | # on antmaze 88 | bash experiments/scripts/antmaze/launch_wsrl_finetune.sh --env antmaze-large-diverse-v2 --resume_path /path/to/checkpoint 89 | 90 | # on adroit 91 | bash experiments/scripts/adroit/launch_wsrl_finetune.sh --env door-binary-v0 --resume_path /path/to/checkpoint 92 | 93 | # on kitchen 94 | bash experiments/scripts/kitchen/launch_wsrl_finetune.sh --env kitchen-mixed-v0 --resume_path /path/to/checkpoint 95 | 96 | # on mujoco locomotion 97 | bash experiments/scripts/locomotion/launch_wsrl_finetune.sh --env halfcheetah-medium-replay-v0 --resume_path /path/to/checkpoint 98 | ``` 99 | 100 | ### No Data Retention 101 | The default setting is to not retain offline data during fine-tuning, as described in the [paper](http://arxiv.org/abs/2412.07762). However, if you wish to retain the data, you can use the `--offline_data_ratio <>` or `--online_sampling_method append` options, as shown below. Check out `finetune.py` for more details.
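For example, to retain offline data during WSRL fine-tuning on antmaze, either mix offline data into each online batch or append it to the replay buffer. Below is a minimal sketch that passes the extra flags through the launch script; the 0.5 ratio and the checkpoint path are illustrative placeholders:

```bash
# mix offline data into every online update batch (ratio value is illustrative)
bash experiments/scripts/antmaze/launch_wsrl_finetune.sh \
  --env antmaze-large-diverse-v2 \
  --resume_path /path/to/checkpoint \
  --offline_data_ratio 0.5

# or pre-load the offline dataset into the online replay buffer and sample from it
bash experiments/scripts/antmaze/launch_wsrl_finetune.sh \
  --env antmaze-large-diverse-v2 \
  --resume_path /path/to/checkpoint \
  --online_sampling_method append
```

Both flags are consumed by `finetune.py`; the launch scripts simply forward any extra arguments via `$@`.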
102 | 103 | ## Contributing 104 | For a detailed explanation of how the codebase works, please check out the [contributing.md](contributing.md) file. 105 | 106 | To enable code checks and auto-formatting, please install pre-commit hooks (run this in the root directory): 107 | ``` 108 | pre-commit install 109 | ``` 110 | The hooks should now run before every commit. If files are modified during the checks, you'll need to re-stage them and commit again. 111 | 112 | ## Credits 113 | This repo is built upon a version of Dibya Ghosh's [jaxrl_minimal](https://github.com/dibyaghosh/jaxrl_minimal) repository, which also included contributions from Kevin Black, Homer Walke, Kyle Stachowicz, and others. 114 | -------------------------------------------------------------------------------- /contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We discuss two key abstractions used heavily in this codebase: the use of `TrainState` and the expression of agents as `PytreeNodes`. 4 | 5 | ## Agents 6 | 7 | In this codebase, we represent agents as PytreeNodes (first-class Jax citizens), making them really easy to handle. The general structure of an Agent is as follows: it contains some number of neural networks, some set of configuration values, and has an update function that takes in a batch and returns an agent with updated parameters after performing some gradient update. There is usually also a `sample_actions` method to sample from the resulting policy. 8 | 9 | ```python 10 | class Agent(flax.struct.PyTreeNode): 11 | value_function: TrainState 12 | policy: TrainState 13 | config: dict = nonpytree_field() # tells Jax to not look at this (usually contains discount factor / target update speed / other hyperparams) 14 | 15 | @jax.jit 16 | def update(self, batch: Batch): 17 | ... 18 | new_value_function = ... 19 | new_policy = ... 20 | info = {'loss': 100} 21 | new_agent = self.replace(value_function=new_value_function, policy=new_policy) 22 | return new_agent, info 23 | 24 | @jax.jit 25 | def sample_actions(self, observations, *, seed): 26 | actions = ... 27 | return actions 28 | ``` 29 | 30 | ### Multiple Devices 31 | 32 | Operating on multiple GPUs / TPUs is really easy! Check out the section at the bottom of the page for how to accumulate gradients across all the GPUs. 33 | 34 | 35 | - `flax.jax_utils.replicate()`: replicates an object on all GPUs 36 | - `wsrl.common.common.shard_batch`: splits a batch evenly across all the GPUs 37 | - `flax.jax_utils.unreplicate()`: brings an object back to a single GPU 38 | 39 | ```python 40 | agent = ... 41 | batch = ... 42 | 43 | replicated_agent = replicate(agent) 44 | replicated_agent, info = replicated_agent.update(shard_batch(batch)) 45 | info = unreplicate(info) # bring info back to single device 46 | 47 | 48 | ``` 49 | ## TrainState 50 | 51 | 52 | The TrainState class (located at `wsrl.common.common.TrainState`) is a fork of Flax's TrainState class with some additional syntactic features for ease of use.
53 | 54 | The TrainState class combines a neural network module (`flax.linen.Module`) with a set of parameters for this network (along with, optionally, an optimizer). 55 | 56 | ### Creating a TrainState 57 | 58 | ```python 59 | model_def = nn.Dense(10) # nn.Module 60 | params = model_def.init(rng, x)['params'] # parameters for nn.Module 61 | tx = optax.adam(1e-3) 62 | model = TrainState.create(model_def, params, tx=tx) 63 | ``` 64 | 65 | ### Running the Model 66 | 67 | ```python 68 | model = TrainState.create(...) 69 | y_pred = model(x) 70 | ``` 71 | 72 | In some cases, the neural network module may have several functions; for example, a VAE might have an `.encode(x)` function and a `.decode(z)` function. By default, the `__call__()` method is used, but this can be specified via an argument: 73 | 74 | ```python 75 | z = model(x, method='encode') 76 | x_pred = model(z, method='decode') 77 | ``` 78 | 79 | You can also run the model with a different set of parameters than that bound to the TrainState. This is most commonly done when taking the gradient with respect to model parameters. 80 | 81 | ```python 82 | y_pred = model(x, params=other_params) 83 | ``` 84 | 85 | ```python 86 | def loss(params): 87 | y_pred = model(x, params=params) 88 | return jnp.mean((y - y_pred) ** 2) 89 | 90 | grads = jax.grad(loss)(model.params) 91 | ``` 92 | 93 | ### Optimizing a TrainState 94 | 95 | To update a model (that has a `tx`), we provide two convenience functions: `.apply_gradients` and `.apply_loss_fn`. 96 | 97 | `model.apply_gradients` takes in a set of gradients (same shape as parameters) and computes the new set of parameters using optax. 98 | 99 | ```python 100 | def loss(params): 101 | y_pred = model(x, params=params) 102 | return jnp.mean((y - y_pred) ** 2) 103 | 104 | grads = jax.grad(loss)(model.params) 105 | new_model = model.apply_gradients(grads=grads) 106 | ``` 107 | 108 | `model.apply_loss_fn()` is a convenience method that both computes the gradients and runs `.apply_gradients()`.
109 | 110 | ```python 111 | def loss(params): 112 | y_pred = model(x, params=params) 113 | return jnp.mean((y - y_pred) ** 2) 114 | 115 | new_model = model.apply_loss_fn(loss_fn=loss) 116 | ``` 117 | 118 | If the model is being run across multiple GPUs / TPUs and we wish to aggregate gradients, this can be specified with the `pmap_axis` argument (you can always use jax.lax.pmean as an alternative): 119 | 120 | ```python 121 | @functools.partial(jax.pmap, axis_name='pmap') 122 | def update(model, x, y): 123 | def loss(params): 124 | y_pred = model(x, params=params) 125 | return jnp.mean((y - y_pred) ** 2) 126 | 127 | new_model = model.apply_loss_fn(loss_fn=loss, pmap_axis='pmap') 128 | return new_model 129 | ``` 130 | -------------------------------------------------------------------------------- /experiments/configs/cql_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | import numpy as np 3 | 4 | from experiments.configs import sac_config 5 | 6 | 7 | def get_config(updates=None): 8 | config = sac_config.get_config() 9 | 10 | config.cql_n_actions = 10 11 | config.cql_action_sample_method = "uniform" 12 | config.cql_max_target_backup = True 13 | config.cql_importance_sample = True 14 | config.cql_autotune_alpha = False 15 | config.cql_alpha_lagrange_init = 1.0 16 | config.cql_alpha_lagrange_otpimizer_kwargs = ConfigDict( 17 | { 18 | "learning_rate": 3e-4, 19 | } 20 | ) 21 | config.cql_target_action_gap = 1.0 22 | config.cql_temp = 1.0 23 | config.cql_alpha = 5.0 24 | config.cql_clip_diff_min = -np.inf 25 | config.cql_clip_diff_max = np.inf 26 | config.use_td_loss = True # set this to False to essentially do BC 27 | config.use_cql_loss = True # set this to False to default to SAC 28 | 29 | # Cal-QL 30 | config.use_calql = False 31 | config.calql_bound_random_actions = False 32 | 33 | if updates is not None: 34 | config.update(ConfigDict(updates).copy_and_resolve_references()) 35 | return config 36 | -------------------------------------------------------------------------------- /experiments/configs/ensemble_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | 4 | def add_redq_config(config, updates=None): 5 | # use ensemble and layer norm 6 | config.critic_ensemble_size = 10 7 | config.critic_subsample_size = 2 8 | config.policy_network_kwargs.use_layer_norm = True 9 | config.critic_network_kwargs.use_layer_norm = True 10 | 11 | if updates is not None: 12 | config.update(ConfigDict(updates).copy_and_resolve_references()) 13 | 14 | return config 15 | -------------------------------------------------------------------------------- /experiments/configs/iql_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | 4 | def get_config(updates=None): 5 | config = ConfigDict() 6 | 7 | config.discount = 0.99 8 | config.expectile = 0.9 9 | config.temperature = 10.0 10 | config.target_update_rate = 5e-3 11 | config.actor_type = "awr" 12 | 13 | config.critic_ensemble_size = 2 14 | config.critic_subsample_size = None 15 | 16 | config.policy_network_kwargs=ConfigDict( 17 | dict( 18 | hidden_dims=(256, 256), 19 | kernel_init_type="var_scaling", 20 | kernel_scale_final=1e-2, 21 | ) 22 | ) 23 | config.critic_network_kwargs=ConfigDict( 24 | dict( 25 | hidden_dims=(256, 256), 26 | kernel_init_type="var_scaling", 27 | ) 28 | ) 29 | 
config.policy_kwargs=ConfigDict( 30 | dict( 31 | tanh_squash_distribution=False, 32 | std_parameterization="uniform", 33 | ) 34 | ) 35 | 36 | config.actor_optimizer_kwargs=ConfigDict( 37 | { 38 | "learning_rate": 3e-4, 39 | } 40 | ) 41 | config.value_critic_optimizer_kwargs=ConfigDict( 42 | { 43 | "learning_rate": 3e-4, 44 | } 45 | ) 46 | 47 | if updates is not None: 48 | config.update(ConfigDict(updates).copy_and_resolve_references()) 49 | 50 | return config 51 | -------------------------------------------------------------------------------- /experiments/configs/sac_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | 4 | def get_config(updates=None): 5 | config = ConfigDict() 6 | config.discount = 0.99 7 | config.backup_entropy = False 8 | config.target_entropy = 0.0 9 | config.soft_target_update_rate = 5e-3 10 | config.critic_ensemble_size = 2 11 | config.critic_subsample_size = None 12 | config.autotune_entropy = True 13 | config.temperature_init = 1.0 14 | 15 | # arch 16 | config.critic_network_kwargs = ConfigDict( 17 | { 18 | "hidden_dims": [256, 256], 19 | "activate_final": True, 20 | "use_layer_norm": False, 21 | } 22 | ) 23 | config.policy_network_kwargs = ConfigDict( 24 | { 25 | "hidden_dims": [256, 256], 26 | "activate_final": True, 27 | "use_layer_norm": False, 28 | } 29 | ) 30 | config.policy_kwargs = ConfigDict( 31 | { 32 | "tanh_squash_distribution": True, 33 | "std_parameterization": "exp", 34 | } 35 | ) 36 | 37 | config.actor_optimizer_kwargs = ConfigDict( 38 | { 39 | "learning_rate": 1e-4, 40 | } 41 | ) 42 | config.critic_optimizer_kwargs = ConfigDict( 43 | { 44 | "learning_rate": 3e-4, 45 | } 46 | ) 47 | config.temperature_optimizer_kwargs = ConfigDict( 48 | { 49 | "learning_rate": 1e-4, 50 | } 51 | ) 52 | 53 | if updates is not None: 54 | config.update(ConfigDict(updates).copy_and_resolve_references()) 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /experiments/configs/train_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | from experiments.configs.cql_config import get_config as get_cql_config 4 | from experiments.configs.iql_config import get_config as get_iql_config 5 | from experiments.configs.sac_config import get_config as get_sac_config 6 | from experiments.configs.wsrl_config import get_config as get_wsrl_config 7 | 8 | 9 | def get_config(config_string): 10 | 11 | possible_structures = { 12 | 13 | ######################################################## 14 | # antmaze configs # 15 | ######################################################## 16 | 17 | "antmaze_cql": ConfigDict( 18 | dict( 19 | agent_kwargs=get_cql_config( 20 | updates=dict( 21 | policy_kwargs=dict( 22 | tanh_squash_distribution=True, 23 | std_parameterization="uniform", 24 | ), 25 | critic_network_kwargs={ 26 | "hidden_dims": [256, 256, 256, 256], 27 | "activations": "relu", 28 | "kernel_scale_final": 1e-2, 29 | }, 30 | policy_network_kwargs={ 31 | "hidden_dims": [256, 256], 32 | "activations": "relu", 33 | "kernel_scale_final": 1e-2, 34 | }, 35 | cql_autotune_alpha=True, 36 | cql_target_action_gap=0.8, 37 | ) 38 | ).to_dict(), 39 | ) 40 | ), 41 | 42 | "antmaze_iql":ConfigDict( 43 | dict( 44 | agent_kwargs=get_iql_config( 45 | updates=dict( 46 | expectile=0.9, 47 | temperature=10.0, 48 | ) 49 | ).to_dict(), 50 | ) 51 | ), 52 | 53 | "antmaze_wsrl": 
ConfigDict( 54 | dict( 55 | agent_kwargs=get_wsrl_config( 56 | updates=dict( 57 | policy_kwargs=dict( 58 | tanh_squash_distribution=True, 59 | std_parameterization="uniform", 60 | ), 61 | critic_network_kwargs={ 62 | "hidden_dims": [256, 256, 256, 256], 63 | "activations": "relu", 64 | "kernel_scale_final": 1e-2, 65 | "use_layer_norm": True, 66 | }, 67 | policy_network_kwargs={ 68 | "hidden_dims": [256, 256], 69 | "activations": "relu", 70 | "kernel_scale_final": 1e-2, 71 | "use_layer_norm": True, 72 | }, 73 | max_target_backup=True, 74 | ) 75 | ).to_dict(), 76 | ) 77 | ), 78 | 79 | ######################################################## 80 | # adroit configs # 81 | ######################################################## 82 | 83 | "adroit_cql": ConfigDict( 84 | dict( 85 | agent_kwargs=get_cql_config( 86 | updates=dict( 87 | policy_kwargs=dict( 88 | tanh_squash_distribution=True, 89 | std_parameterization="exp", 90 | ), 91 | critic_network_kwargs={ 92 | "hidden_dims": [512, 512, 512], 93 | "kernel_scale_final": 1e-2, 94 | "activations": "relu", 95 | }, 96 | policy_network_kwargs={ 97 | "hidden_dims": [512, 512], 98 | "kernel_scale_final": 1e-2, 99 | "activations": "relu", 100 | }, 101 | online_cql_alpha=1.0, 102 | cql_alpha=1.0, 103 | ) 104 | ).to_dict(), 105 | ) 106 | ), 107 | 108 | "adroit_iql":ConfigDict( 109 | dict( 110 | agent_kwargs=get_iql_config( 111 | updates=dict( 112 | policy_network_kwargs=dict( 113 | hidden_dims=(256, 256), 114 | kernel_init_type="var_scaling", 115 | kernel_scale_final=1e-2, 116 | dropout_rate=0.1, 117 | ), 118 | expectile=0.7, 119 | temperature=0.5, 120 | ), 121 | ).to_dict(), 122 | ) 123 | ), 124 | 125 | "adroit_wsrl": ConfigDict( 126 | dict( 127 | agent_kwargs=get_wsrl_config( 128 | updates=dict( 129 | policy_kwargs=dict( 130 | tanh_squash_distribution=True, 131 | std_parameterization="exp", 132 | ), 133 | critic_network_kwargs={ 134 | "hidden_dims": [512, 512, 512], 135 | "kernel_scale_final": 1e-2, 136 | "activations": "relu", 137 | "use_layer_norm": True, 138 | }, 139 | policy_network_kwargs={ 140 | "hidden_dims": [512, 512], 141 | "kernel_scale_final": 1e-2, 142 | "activations": "relu", 143 | "use_layer_norm": True, 144 | }, 145 | ) 146 | ).to_dict(), 147 | ) 148 | ), 149 | 150 | ######################################################## 151 | # kitchen configs # 152 | ######################################################## 153 | 154 | "kitchen_cql": ConfigDict( 155 | dict( 156 | agent_kwargs=get_cql_config( 157 | updates=dict( 158 | policy_kwargs=dict( 159 | tanh_squash_distribution=True, 160 | std_parameterization="exp", 161 | ), 162 | critic_network_kwargs={ 163 | "hidden_dims": [512, 512, 512], 164 | "activations": "relu", 165 | }, 166 | policy_network_kwargs={ 167 | "hidden_dims": [512, 512, 512], 168 | "activations": "relu", 169 | }, 170 | online_cql_alpha=5.0, 171 | cql_alpha=5.0, 172 | cql_importance_sample=False, 173 | ) 174 | ).to_dict(), 175 | ) 176 | ), 177 | 178 | "kitchen_iql":ConfigDict( 179 | dict( 180 | agent_kwargs=get_iql_config( 181 | updates=dict( 182 | policy_network_kwargs=dict( 183 | hidden_dims=(256, 256), 184 | activations="relu", 185 | dropout_rate=0.1, 186 | ), 187 | critic_network_kwargs=dict( 188 | hidden_dims=(256, 256), 189 | activations="relu", 190 | ), 191 | expectile=0.7, 192 | temperature=0.5, 193 | ) 194 | ).to_dict(), 195 | ) 196 | ), 197 | 198 | "kitchen_wsrl": ConfigDict( 199 | dict( 200 | agent_kwargs=get_wsrl_config( 201 | updates=dict( 202 | policy_kwargs=dict( 203 | tanh_squash_distribution=True, 204 | 
std_parameterization="exp", 205 | ), 206 | critic_network_kwargs={ 207 | "hidden_dims": [512, 512, 512], 208 | "activations": "relu", 209 | "use_layer_norm": True, 210 | }, 211 | policy_network_kwargs={ 212 | "hidden_dims": [512, 512, 512], 213 | "activations": "relu", 214 | "use_layer_norm": True, 215 | }, 216 | ) 217 | ).to_dict(), 218 | ) 219 | ), 220 | 221 | ######################################################## 222 | # locomotion configs # 223 | ######################################################## 224 | 225 | "locomotion_cql": ConfigDict( 226 | dict( 227 | agent_kwargs=get_cql_config( 228 | updates=dict( 229 | critic_network_kwargs={ 230 | "hidden_dims": [256, 256], 231 | "activations": "relu", 232 | "kernel_scale_final": 1e-2, 233 | }, 234 | policy_network_kwargs={ 235 | "hidden_dims": [256, 256], 236 | "activations": "relu", 237 | "kernel_scale_final": 1e-2, 238 | }, 239 | online_cql_alpha=5.0, 240 | cql_alpha=5.0, 241 | ) 242 | ).to_dict(), 243 | ) 244 | ), 245 | 246 | "locomotion_iql":ConfigDict( 247 | dict( 248 | agent_kwargs=get_iql_config( 249 | updates=dict( 250 | expectile=0.7, 251 | temperature=3.0, 252 | ) 253 | ).to_dict(), 254 | ) 255 | ), 256 | 257 | "locomotion_wsrl": ConfigDict( 258 | dict( 259 | agent_kwargs=get_wsrl_config( 260 | updates=dict( 261 | critic_network_kwargs={ 262 | "hidden_dims": [256, 256], 263 | "activations": "relu", 264 | "kernel_scale_final": 1e-2, 265 | "use_layer_norm": True, 266 | }, 267 | policy_network_kwargs={ 268 | "hidden_dims": [256, 256], 269 | "activations": "relu", 270 | "kernel_scale_final": 1e-2, 271 | "use_layer_norm": True, 272 | }, 273 | ) 274 | ).to_dict(), 275 | ) 276 | ), 277 | } 278 | 279 | return possible_structures[config_string] 280 | -------------------------------------------------------------------------------- /experiments/configs/wsrl_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | from experiments.configs import sac_config 4 | 5 | 6 | def get_config(updates=None): 7 | config = sac_config.get_config() 8 | 9 | config.critic_ensemble_size = 10 10 | config.critic_subsample_size = 2 11 | 12 | config.policy_network_kwargs.use_layer_norm = True 13 | config.critic_network_kwargs.use_layer_norm = True 14 | 15 | if updates is not None: 16 | config.update(ConfigDict(updates).copy_and_resolve_references()) 17 | return config 18 | -------------------------------------------------------------------------------- /experiments/scripts/adroit/launch_calql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | # env: pen-binary-v0, door-binary-v0, relocate-binary-v0 6 | 7 | python finetune.py \ 8 | --agent calql \ 9 | --config experiments/configs/train_config.py:adroit_cql \ 10 | --project baselines-section \ 11 | --group no-redq-utd1 \ 12 | --warmup_steps 0 \ 13 | --num_offline_steps 20_000 \ 14 | --reward_scale 10.0 \ 15 | --reward_bias 5.0 \ 16 | --env pen-binary-v0 \ 17 | $@ 18 | -------------------------------------------------------------------------------- /experiments/scripts/adroit/launch_cql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | # env: pen-binary-v0, door-binary-v0, relocate-binary-v0 6 | 7 | python finetune.py \ 8 
| --agent cql \ 9 | --config experiments/configs/train_config.py:adroit_cql \ 10 | --project baselines-section \ 11 | --group no-redq-utd1 \ 12 | --warmup_steps 0 \ 13 | --num_offline_steps 20_000 \ 14 | --reward_scale 10.0 \ 15 | --reward_bias 5.0 \ 16 | --env pen-binary-v0 \ 17 | $@ 18 | -------------------------------------------------------------------------------- /experiments/scripts/adroit/launch_iql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | # env: pen-binary-v0, door-binary-v0, relocate-binary-v0 6 | 7 | python finetune.py \ 8 | --agent iql \ 9 | --config experiments/configs/train_config.py:adroit_iql \ 10 | --project baselines-section \ 11 | --group no-redq-utd1 \ 12 | --reward_scale 10.0 \ 13 | --reward_bias 5.0 \ 14 | --num_offline_steps 20_000 \ 15 | --log_interval 1_000 \ 16 | --eval_interval 10_000 \ 17 | --save_interval 20_000 \ 18 | $@ 19 | -------------------------------------------------------------------------------- /experiments/scripts/adroit/launch_rlpd.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent sac \ 7 | --config experiments/configs/train_config.py:adroit_wsrl \ 8 | --project baselines-section \ 9 | --num_offline_steps 0 \ 10 | --reward_scale 10.0 \ 11 | --reward_bias 5.0 \ 12 | --offline_data_ratio 0.5 \ 13 | --utd 4 \ 14 | --batch_size $((256 * 4)) \ 15 | --warmup_steps 5_000 \ 16 | $@ 17 | -------------------------------------------------------------------------------- /experiments/scripts/adroit/launch_wsrl_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | # env: pen-binary-v0, door-binary-v0, relocate-binary-v0 6 | 7 | python3 finetune.py \ 8 | --agent sac \ 9 | --config experiments/configs/train_config.py:adroit_wsrl \ 10 | --project method-section \ 11 | --num_offline_steps 20_000 \ 12 | --reward_scale 10.0 \ 13 | --reward_bias 5.0 \ 14 | --env pen-binary-v0 \ 15 | --utd 4 \ 16 | --batch_size 1024 \ 17 | --warmup_steps 5000 \ 18 | $@ 19 | -------------------------------------------------------------------------------- /experiments/scripts/antmaze/launch_calql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python finetune.py \ 6 | --agent calql \ 7 | --config experiments/configs/train_config.py:antmaze_cql \ 8 | --project baselines-section \ 9 | --group no-redq-utd1 \ 10 | --reward_scale 10.0 \ 11 | --reward_bias -5.0 \ 12 | --num_offline_steps 1_000_000 \ 13 | --env antmaze-large-diverse-v2 \ 14 | $@ 15 | -------------------------------------------------------------------------------- /experiments/scripts/antmaze/launch_cql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python finetune.py \ 6 | --agent cql \ 7 | --config experiments/configs/train_config.py:antmaze_cql \ 8 | --project baselines-section \ 9 | --group no-redq-utd1 \ 10 | 
--reward_scale 10.0 \ 11 | --reward_bias -5.0 \ 12 | --num_offline_steps 1_000_000 \ 13 | --env antmaze-large-diverse-v2 \ 14 | $@ 15 | -------------------------------------------------------------------------------- /experiments/scripts/antmaze/launch_iql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python finetune.py \ 6 | --agent iql \ 7 | --config experiments/configs/train_config.py:antmaze_iql \ 8 | --project baselines-section \ 9 | --group no-redq-utd1 \ 10 | --reward_scale 10.0 \ 11 | --reward_bias -5.0 \ 12 | --num_offline_steps 1_000_000 \ 13 | --env antmaze-large-diverse-v2 \ 14 | $@ 15 | -------------------------------------------------------------------------------- /experiments/scripts/antmaze/launch_rlpd.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent sac \ 7 | --config experiments/configs/train_config.py:antmaze_wsrl \ 8 | --project baselines-section \ 9 | --config.agent_kwargs.critic_subsample_size 1 \ 10 | --reward_scale 10.0 \ 11 | --reward_bias -5.0 \ 12 | --num_offline_steps 0 \ 13 | --num_online_steps 500_000 \ 14 | --offline_data_ratio 0.5 \ 15 | --env antmaze-large-diverse-v2 \ 16 | --utd 4 \ 17 | --batch_size 1024 \ 18 | --warmup_steps 5000 \ 19 | $@ 20 | -------------------------------------------------------------------------------- /experiments/scripts/antmaze/launch_wsrl_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent sac \ 7 | --config experiments/configs/train_config.py:antmaze_wsrl \ 8 | --project method-section \ 9 | --reward_scale 10.0 \ 10 | --reward_bias -5.0 \ 11 | --num_offline_steps 1_000_000 \ 12 | --env antmaze-large-diverse-v2 \ 13 | --utd 4 \ 14 | --batch_size 1024 \ 15 | --warmup_steps 5000 \ 16 | $@ 17 | -------------------------------------------------------------------------------- /experiments/scripts/kitchen/launch_calql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent calql \ 7 | --config experiments/configs/train_config.py:kitchen_cql \ 8 | --project baselines-section \ 9 | --group no-redq-utd1 \ 10 | --num_offline_steps 250_000 \ 11 | --reward_scale 1.0 \ 12 | --reward_bias -4.0 \ 13 | --env kitchen-partial-v0 \ 14 | $@ 15 | -------------------------------------------------------------------------------- /experiments/scripts/kitchen/launch_cql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent cql \ 7 | --config experiments/configs/train_config.py:kitchen_cql \ 8 | --project baselines-section \ 9 | --group no-redq-utd1 \ 10 | --num_offline_steps 250_000 \ 11 | --reward_scale 1.0 \ 12 | --reward_bias -4.0 \ 13 | --env kitchen-partial-v0 \ 14 | $@ 15 | -------------------------------------------------------------------------------- 
/experiments/scripts/kitchen/launch_iql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent iql \ 7 | --config experiments/configs/train_config.py:kitchen_iql \ 8 | --project baselines-section \ 9 | --group no-redq-utd1 \ 10 | --num_offline_steps 250_000 \ 11 | --reward_scale 1.0 \ 12 | --reward_bias -4.0 \ 13 | --env kitchen-mixed-v0 \ 14 | $@ 15 | -------------------------------------------------------------------------------- /experiments/scripts/kitchen/launch_rlpd.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent sac \ 7 | --config experiments/configs/train_config.py:kitchen_wsrl \ 8 | --project baselines-section \ 9 | --num_offline_steps 0 \ 10 | --offline_data_ratio 0.5 \ 11 | --reward_scale 1.0 \ 12 | --reward_bias -4.0 \ 13 | --utd 4 \ 14 | --batch_size 1024 \ 15 | --warmup_steps 5_000 \ 16 | $@ 17 | -------------------------------------------------------------------------------- /experiments/scripts/kitchen/launch_wsrl_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent sac \ 7 | --config experiments/configs/train_config.py:kitchen_wsrl \ 8 | --project kitchen-finetune \ 9 | --num_offline_steps 250_000 \ 10 | --reward_scale 1.0 \ 11 | --reward_bias -4.0 \ 12 | --env kitchen-partial-v0 \ 13 | --utd 4 \ 14 | --batch_size 1024 \ 15 | --warmup_steps 5000 \ 16 | $@ 17 | -------------------------------------------------------------------------------- /experiments/scripts/locomotion/launch_cql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent cql \ 7 | --config experiments/configs/train_config.py:locomotion_cql \ 8 | --env halfcheetah-medium-replay-v2 \ 9 | --project locomotion-finetune \ 10 | --reward_scale 1.0 \ 11 | --reward_bias 0.0 \ 12 | --num_offline_steps 250_000 \ 13 | $@ 14 | -------------------------------------------------------------------------------- /experiments/scripts/locomotion/launch_iql_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent iql \ 7 | --config experiments/configs/train_config.py:locomotion_iql \ 8 | --env halfcheetah-medium-replay-v2 \ 9 | --project locomotion-finetune \ 10 | --reward_scale 1.0 \ 11 | --reward_bias 0.0 \ 12 | --num_offline_steps 250_000 \ 13 | $@ 14 | -------------------------------------------------------------------------------- /experiments/scripts/locomotion/launch_rlpd.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent sac \ 7 | --config experiments/configs/train_config.py:locomotion_wsrl \ 8 | --project baselines-section \ 9 | 
--reward_scale 1.0 \ 10 | --reward_bias 0.0 \ 11 | --num_offline_steps 0 \ 12 | --offline_data_ratio 0.5 \ 13 | --env halfcheetah-medium-replay-v2 \ 14 | --utd 4 \ 15 | --batch_size 1024 \ 16 | --warmup_steps 5_000 \ 17 | $@ 18 | -------------------------------------------------------------------------------- /experiments/scripts/locomotion/launch_wsrl_finetune.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 2 | export PYOPENGL_PLATFORM=egl 3 | export MUJOCO_GL=egl 4 | 5 | python3 finetune.py \ 6 | --agent sac \ 7 | --config experiments/configs/train_config.py:locomotion_wsrl \ 8 | --project method-section \ 9 | --reward_scale 1.0 \ 10 | --reward_bias 0.0 \ 11 | --num_offline_steps 250_000 \ 12 | --env halfcheetah-medium-replay-v2 \ 13 | --utd 4 \ 14 | --batch_size 1024 \ 15 | --warmup_steps 5000 \ 16 | $@ 17 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import gym 5 | import jax 6 | import numpy as np 7 | import tqdm 8 | from absl import app, flags, logging 9 | from flax.training import checkpoints 10 | from ml_collections import config_flags 11 | 12 | from experiments.configs.ensemble_config import add_redq_config 13 | from wsrl.agents import agents 14 | from wsrl.common.evaluation import evaluate_with_trajectories 15 | from wsrl.common.wandb import WandBLogger 16 | from wsrl.data.replay_buffer import ReplayBuffer, ReplayBufferMC 17 | from wsrl.envs.adroit_binary_dataset import get_hand_dataset_with_mc_calculation 18 | from wsrl.envs.d4rl_dataset import ( 19 | get_d4rl_dataset, 20 | get_d4rl_dataset_with_mc_calculation, 21 | ) 22 | from wsrl.envs.env_common import get_env_type, make_gym_env 23 | from wsrl.utils.timer_utils import Timer 24 | from wsrl.utils.train_utils import concatenate_batches, subsample_batch 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | # env 29 | flags.DEFINE_string("env", "antmaze-large-diverse-v2", "Environemnt to use") 30 | flags.DEFINE_float("reward_scale", 1.0, "Reward scale.") 31 | flags.DEFINE_float("reward_bias", -1.0, "Reward bias.") 32 | flags.DEFINE_float( 33 | "clip_action", 34 | 0.99999, 35 | "Clip actions to be between [-n, n]. This is needed for tanh policies.", 36 | ) 37 | 38 | # training 39 | flags.DEFINE_integer("num_offline_steps", 1_000_000, "Number of offline epochs.") 40 | flags.DEFINE_integer("num_online_steps", 500_000, "Number of online epochs.") 41 | flags.DEFINE_float( 42 | "offline_data_ratio", 43 | 0.0, 44 | "How much offline data to retain in each online batch update", 45 | ) 46 | flags.DEFINE_string( 47 | "online_sampling_method", 48 | "mixed", 49 | """Method of sampling data during online update: mixed or append. 50 | `mixed` samples from a mix of offline and online data according to offline_data_ratio. 
51 | `append` adds offline data to replay buffer and samples from it.""", 52 | ) 53 | flags.DEFINE_bool( 54 | "online_use_cql_loss", 55 | True, 56 | """When agent is CQL/CalQL, whether to use CQL loss for the online phase (use SAC loss if False)""", 57 | ) 58 | flags.DEFINE_integer( 59 | "warmup_steps", 0, "number of warmup steps (WSRL) before performing online updates" 60 | ) 61 | 62 | # agent 63 | flags.DEFINE_string("agent", "calql", "what RL agent to use") 64 | flags.DEFINE_integer("utd", 1, "update-to-data ratio of the critic") 65 | flags.DEFINE_integer("batch_size", 256, "batch size for training") 66 | flags.DEFINE_integer("replay_buffer_capacity", int(2e6), "Replay buffer capacity") 67 | flags.DEFINE_bool("use_redq", False, "Use an ensemble of Q-functions for the agent") 68 | 69 | # experiment house keeping 70 | flags.DEFINE_integer("seed", 0, "Random seed.") 71 | flags.DEFINE_string( 72 | "save_dir", 73 | os.path.expanduser("~/wsrl_log"), 74 | "Directory to save the logs and checkpoints", 75 | ) 76 | flags.DEFINE_string("resume_path", "", "Path to resume from") 77 | flags.DEFINE_integer("log_interval", 5_000, "Log every n steps") 78 | flags.DEFINE_integer("eval_interval", 20_000, "Evaluate every n steps") 79 | flags.DEFINE_integer("save_interval", 100_000, "Save every n steps.") 80 | flags.DEFINE_integer( 81 | "n_eval_trajs", 20, "Number of trajectories to use for each evaluation." 82 | ) 83 | flags.DEFINE_bool("deterministic_eval", True, "Whether to use deterministic evaluation") 84 | 85 | # wandb 86 | flags.DEFINE_string("exp_name", "", "Experiment name for wandb logging") 87 | flags.DEFINE_string("project", None, "Wandb project folder") 88 | flags.DEFINE_string("group", None, "Wandb group of the experiment") 89 | flags.DEFINE_bool("debug", False, "If true, no logging to wandb") 90 | 91 | config_flags.DEFINE_config_file( 92 | "config", 93 | None, 94 | "File path to the training hyperparameter configuration.", 95 | lock_config=False, 96 | ) 97 | 98 | 99 | def main(_): 100 | """ 101 | house keeping 102 | """ 103 | assert FLAGS.online_sampling_method in [ 104 | "mixed", 105 | "append", 106 | ], "incorrect online sampling method" 107 | 108 | if FLAGS.use_redq: 109 | FLAGS.config.agent_kwargs = add_redq_config(FLAGS.config.agent_kwargs) 110 | 111 | min_steps_to_update = FLAGS.batch_size * (1 - FLAGS.offline_data_ratio) 112 | if FLAGS.agent == "calql": 113 | min_steps_to_update = max( 114 | min_steps_to_update, gym.make(FLAGS.env)._max_episode_steps 115 | ) 116 | 117 | """ 118 | wandb and logging 119 | """ 120 | wandb_config = WandBLogger.get_default_config() 121 | wandb_config.update( 122 | { 123 | "project": "wsrl" or FLAGS.project, 124 | "group": "wsrl" or FLAGS.group, 125 | "exp_descriptor": f"{FLAGS.exp_name}_{FLAGS.env}_{FLAGS.agent}_seed{FLAGS.seed}", 126 | } 127 | ) 128 | wandb_logger = WandBLogger( 129 | wandb_config=wandb_config, 130 | variant=FLAGS.config.to_dict(), 131 | random_str_in_identifier=True, 132 | disable_online_logging=FLAGS.debug, 133 | ) 134 | 135 | save_dir = os.path.join( 136 | FLAGS.save_dir, 137 | wandb_logger.config.project, 138 | f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}", 139 | ) 140 | 141 | """ 142 | env 143 | """ 144 | # do not clip adroit actions online following CalQL repo 145 | # https://github.com/nakamotoo/Cal-QL 146 | env_type = get_env_type(FLAGS.env) 147 | finetune_env = make_gym_env( 148 | env_name=FLAGS.env, 149 | reward_scale=FLAGS.reward_scale, 150 | reward_bias=FLAGS.reward_bias, 151 | 
scale_and_clip_action=env_type in ("antmaze", "kitchen", "locomotion"), 152 | action_clip_lim=FLAGS.clip_action, 153 | seed=FLAGS.seed, 154 | ) 155 | eval_env = make_gym_env( 156 | env_name=FLAGS.env, 157 | scale_and_clip_action=env_type in ("antmaze", "kitchen", "locomotion"), 158 | action_clip_lim=FLAGS.clip_action, 159 | seed=FLAGS.seed + 1000, 160 | ) 161 | 162 | """ 163 | load dataset 164 | """ 165 | if env_type == "adroit-binary": 166 | dataset = get_hand_dataset_with_mc_calculation( 167 | FLAGS.env, 168 | gamma=FLAGS.config.agent_kwargs.discount, 169 | reward_scale=FLAGS.reward_scale, 170 | reward_bias=FLAGS.reward_bias, 171 | clip_action=FLAGS.clip_action, 172 | ) 173 | else: 174 | if FLAGS.agent == "calql": 175 | # need dataset with mc return 176 | dataset = get_d4rl_dataset_with_mc_calculation( 177 | FLAGS.env, 178 | reward_scale=FLAGS.reward_scale, 179 | reward_bias=FLAGS.reward_bias, 180 | clip_action=FLAGS.clip_action, 181 | gamma=FLAGS.config.agent_kwargs.discount, 182 | ) 183 | else: 184 | dataset = get_d4rl_dataset( 185 | FLAGS.env, 186 | reward_scale=FLAGS.reward_scale, 187 | reward_bias=FLAGS.reward_bias, 188 | clip_action=FLAGS.clip_action, 189 | ) 190 | 191 | """ 192 | replay buffer 193 | """ 194 | replay_buffer_type = ReplayBufferMC if FLAGS.agent == "calql" else ReplayBuffer 195 | replay_buffer = replay_buffer_type( 196 | finetune_env.observation_space, 197 | finetune_env.action_space, 198 | capacity=FLAGS.replay_buffer_capacity, 199 | seed=FLAGS.seed, 200 | discount=FLAGS.config.agent_kwargs.discount if FLAGS.agent == "calql" else None, 201 | ) 202 | 203 | """ 204 | Initialize agent 205 | """ 206 | rng = jax.random.PRNGKey(FLAGS.seed) 207 | rng, construct_rng = jax.random.split(rng) 208 | example_batch = subsample_batch(dataset, FLAGS.batch_size) 209 | agent = agents[FLAGS.agent].create( 210 | rng=construct_rng, 211 | observations=example_batch["observations"], 212 | actions=example_batch["actions"], 213 | encoder_def=None, 214 | **FLAGS.config.agent_kwargs, 215 | ) 216 | 217 | if FLAGS.resume_path != "": 218 | assert os.path.exists(FLAGS.resume_path), "resume path does not exist" 219 | agent = checkpoints.restore_checkpoint(FLAGS.resume_path, target=agent) 220 | 221 | """ 222 | eval function 223 | """ 224 | 225 | def evaluate_and_log_results( 226 | eval_env, 227 | policy_fn, 228 | eval_func, 229 | step_number, 230 | wandb_logger, 231 | n_eval_trajs=FLAGS.n_eval_trajs, 232 | ): 233 | stats, trajs = eval_func( 234 | policy_fn, 235 | eval_env, 236 | n_eval_trajs, 237 | ) 238 | 239 | eval_info = { 240 | "average_return": np.mean([np.sum(t["rewards"]) for t in trajs]), 241 | "average_traj_length": np.mean([len(t["rewards"]) for t in trajs]), 242 | } 243 | if env_type == "adroit-binary": 244 | # adroit 245 | eval_info["success_rate"] = np.mean( 246 | [any(d["goal_achieved"] for d in t["infos"]) for t in trajs] 247 | ) 248 | elif env_type == "kitchen": 249 | # kitchen 250 | eval_info["num_stages_solved"] = np.mean([t["rewards"][-1] for t in trajs]) 251 | eval_info["success_rate"] = np.mean([t["rewards"][-1] for t in trajs]) / 4 252 | else: 253 | # d4rl antmaze, locomotion 254 | eval_info["success_rate"] = eval_info[ 255 | "average_normalized_return" 256 | ] = np.mean( 257 | [eval_env.get_normalized_score(np.sum(t["rewards"])) for t in trajs] 258 | ) 259 | 260 | wandb_logger.log({"evaluation": eval_info}, step=step_number) 261 | 262 | """ 263 | training loop 264 | """ 265 | timer = Timer() 266 | step = int(agent.state.step) # 0 for new agents, or load from pre-trained 267 | 
is_online_stage = False 268 | observation, info = finetune_env.reset() 269 | done = False # env done signal 270 | 271 | for _ in tqdm.tqdm(range(step, FLAGS.num_offline_steps + FLAGS.num_online_steps)): 272 | """ 273 | Switch from offline to online 274 | """ 275 | if not is_online_stage and step >= FLAGS.num_offline_steps: 276 | logging.info("Switching to online training") 277 | is_online_stage = True 278 | 279 | # upload offline data to online buffer 280 | if FLAGS.online_sampling_method == "append": 281 | offline_dataset_size = dataset["actions"].shape[0] 282 | dataset_items = dataset.items() 283 | for j in range(offline_dataset_size): 284 | transition = {k: v[j] for k, v in dataset_items} 285 | replay_buffer.insert(transition) 286 | 287 | # option for CQL and CalQL to change the online alpha, and whether to use CQL regularizer 288 | if FLAGS.agent in ("cql", "calql"): 289 | online_agent_configs = { 290 | "cql_alpha": FLAGS.config.agent_kwargs.get( 291 | "online_cql_alpha", None 292 | ), 293 | "use_cql_loss": FLAGS.online_use_cql_loss, 294 | } 295 | agent.update_config(online_agent_configs) 296 | 297 | timer.tick("total") 298 | 299 | """ 300 | Env Step 301 | """ 302 | with timer.context("env step"): 303 | if is_online_stage: 304 | rng, action_rng = jax.random.split(rng) 305 | action = agent.sample_actions(observation, seed=action_rng) 306 | next_observation, reward, done, truncated, info = finetune_env.step( 307 | action 308 | ) 309 | 310 | transition = dict( 311 | observations=observation, 312 | next_observations=next_observation, 313 | actions=action, 314 | rewards=reward, 315 | masks=1.0 - done, 316 | dones=1.0 if (done or truncated) else 0, 317 | ) 318 | replay_buffer.insert(transition) 319 | 320 | observation = next_observation 321 | if done or truncated: 322 | observation, info = finetune_env.reset() 323 | done = False 324 | 325 | """ 326 | Updates 327 | """ 328 | with timer.context("update"): 329 | # offline updates 330 | if not is_online_stage: 331 | batch = subsample_batch(dataset, FLAGS.batch_size) 332 | agent, update_info = agent.update( 333 | batch, 334 | ) 335 | 336 | # online updates 337 | else: 338 | if step - FLAGS.num_offline_steps <= max( 339 | FLAGS.warmup_steps, min_steps_to_update 340 | ): 341 | # no updates during warmup 342 | pass 343 | else: 344 | # do online updates, gather batch 345 | if FLAGS.online_sampling_method == "mixed": 346 | # batch from a mixing ratio of offline and online data 347 | batch_size_offline = int( 348 | FLAGS.batch_size * FLAGS.offline_data_ratio 349 | ) 350 | batch_size_online = FLAGS.batch_size - batch_size_offline 351 | online_batch = replay_buffer.sample(batch_size_online) 352 | offline_batch = subsample_batch(dataset, batch_size_offline) 353 | # update with the combined batch 354 | batch = concatenate_batches([online_batch, offline_batch]) 355 | elif FLAGS.online_sampling_method == "append": 356 | # batch from online replay buffer, with is initialized with offline data 357 | batch = replay_buffer.sample(FLAGS.batch_size) 358 | else: 359 | raise RuntimeError("Incorrect online sampling method") 360 | 361 | # update 362 | if FLAGS.utd > 1: 363 | agent, update_info = agent.update_high_utd( 364 | batch, 365 | utd_ratio=FLAGS.utd, 366 | ) 367 | else: 368 | agent, update_info = agent.update( 369 | batch, 370 | ) 371 | 372 | """ 373 | Advance Step 374 | """ 375 | step += 1 376 | 377 | """ 378 | Evals 379 | """ 380 | eval_steps = ( 381 | FLAGS.num_offline_steps, # finish offline training 382 | FLAGS.num_offline_steps + 1, # start of online 
training 383 | FLAGS.num_offline_steps + FLAGS.num_online_steps, # end of online training 384 | ) 385 | if step % FLAGS.eval_interval == 0 or step in eval_steps: 386 | logging.info("Evaluating...") 387 | with timer.context("evaluation"): 388 | policy_fn = partial( 389 | agent.sample_actions, argmax=FLAGS.deterministic_eval 390 | ) 391 | eval_func = partial( 392 | evaluate_with_trajectories, clip_action=FLAGS.clip_action 393 | ) 394 | 395 | evaluate_and_log_results( 396 | eval_env=eval_env, 397 | policy_fn=policy_fn, 398 | eval_func=eval_func, 399 | step_number=step, 400 | wandb_logger=wandb_logger, 401 | ) 402 | 403 | """ 404 | Save Checkpoint 405 | """ 406 | if step % FLAGS.save_interval == 0 or step == FLAGS.num_offline_steps: 407 | logging.info("Saving checkpoint...") 408 | checkpoint_path = checkpoints.save_checkpoint( 409 | save_dir, agent, step=step, keep=30 410 | ) 411 | logging.info("Saved checkpoint to %s", checkpoint_path) 412 | 413 | timer.tock("total") 414 | 415 | """ 416 | Logging 417 | """ 418 | if step % FLAGS.log_interval == 0: 419 | # check if update_info is available (False during warmup) 420 | if "update_info" in locals(): 421 | update_info = jax.device_get(update_info) 422 | wandb_logger.log({"training": update_info}, step=step) 423 | 424 | wandb_logger.log({"timer": timer.get_average_times()}, step=step) 425 | 426 | 427 | if __name__ == "__main__": 428 | app.run(main) 429 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym >= 0.26 2 | numpy==1.26.4 3 | distrax==0.1.2 4 | ml_collections >= 0.1.0 5 | tqdm >= 4.60.0 6 | chex==0.1.82 7 | optax==0.1.5 8 | absl-py >= 0.12.0 9 | scipy==1.11.2 10 | wandb >= 0.12.14 11 | einops >= 0.6.1 12 | imageio >= 2.31.1 13 | moviepy >= 1.0.3 14 | pre-commit == 3.3.3 15 | overrides 16 | cython < 3 17 | patchelf 18 | orbax-checkpoint == 0.3.5 19 | flax == 0.7.5 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="wsrl", packages=["wsrl"]) 4 | -------------------------------------------------------------------------------- /wsrl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .bc import BCAgent 2 | from .calql import CalQLAgent 3 | from .cql import CQLAgent 4 | from .iql import IQLAgent 5 | from .sac import SACAgent 6 | 7 | agents = { 8 | "bc": BCAgent, 9 | "iql": IQLAgent, 10 | "cql": CQLAgent, 11 | "calql": CalQLAgent, 12 | "sac": SACAgent, 13 | } 14 | -------------------------------------------------------------------------------- /wsrl/agents/bc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Optional 3 | 4 | import flax 5 | import flax.linen as nn 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | from flax.core import FrozenDict 11 | 12 | from wsrl.common.common import JaxRLTrainState, ModuleDict, nonpytree_field 13 | from wsrl.common.typing import Batch, PRNGKey 14 | from wsrl.networks.actor_critic_nets import Policy 15 | from wsrl.networks.mlp import MLP 16 | 17 | 18 | class BCAgent(flax.struct.PyTreeNode): 19 | state: JaxRLTrainState 20 | lr_schedule: Any = nonpytree_field() 21 | 22 | @partial(jax.jit, 
static_argnames="pmap_axis") 23 | def update(self, batch: Batch, pmap_axis: str = None): 24 | def loss_fn(params, rng): 25 | rng, key = jax.random.split(rng) 26 | dist = self.state.apply_fn( 27 | {"params": params}, 28 | batch["observations"], 29 | temperature=1.0, 30 | train=True, 31 | rngs={"dropout": key}, 32 | name="actor", 33 | ) 34 | pi_actions = dist.mode() 35 | log_probs = dist.log_prob(batch["actions"]) 36 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 37 | actor_loss = -(log_probs).mean() 38 | actor_std = dist.stddev().mean(axis=1) 39 | 40 | return actor_loss, { 41 | "actor_loss": actor_loss, 42 | "mse": mse.mean(), 43 | "entropy": -dist.log_prob(pi_actions).mean(), 44 | "log_probs": log_probs, 45 | "pi_actions": pi_actions, 46 | "mean_std": actor_std.mean(), 47 | "max_std": actor_std.max(), 48 | } 49 | 50 | # compute gradients and update params 51 | new_state, info = self.state.apply_loss_fns( 52 | loss_fn, pmap_axis=pmap_axis, has_aux=True 53 | ) 54 | 55 | # log learning rates 56 | info["lr"] = self.lr_schedule(self.state.step) 57 | 58 | return self.replace(state=new_state), info 59 | 60 | @partial(jax.jit, static_argnames="argmax") 61 | def sample_actions( 62 | self, 63 | observations: np.ndarray, 64 | *, 65 | seed: Optional[PRNGKey] = None, 66 | temperature: float = 1.0, 67 | argmax=False, 68 | ) -> jnp.ndarray: 69 | dist = self.state.apply_fn( 70 | {"params": self.state.params}, 71 | observations, 72 | temperature=temperature, 73 | name="actor", 74 | ) 75 | if argmax: 76 | assert seed is None, "Cannot specify seed when sampling deterministically" 77 | actions = dist.mode() 78 | else: 79 | actions = dist.sample(seed=seed) 80 | return actions 81 | 82 | @jax.jit 83 | def get_debug_metrics(self, batch, **kwargs): 84 | dist = self.state.apply_fn( 85 | {"params": self.state.params}, 86 | batch["observations"], 87 | temperature=1.0, 88 | name="actor", 89 | ) 90 | pi_actions = dist.mode() 91 | log_probs = dist.log_prob(batch["actions"]) 92 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 93 | 94 | return { 95 | "mse": mse, 96 | "log_probs": log_probs, 97 | "pi_actions": pi_actions, 98 | } 99 | 100 | @classmethod 101 | def create( 102 | cls, 103 | rng: PRNGKey, 104 | observations: FrozenDict, 105 | actions: jnp.ndarray, 106 | # Model architecture 107 | encoder_def: nn.Module, 108 | network_kwargs: dict = { 109 | "hidden_dims": [256, 256], 110 | }, 111 | policy_kwargs: dict = { 112 | "tanh_squash_distribution": False, 113 | }, 114 | # Optimizer 115 | learning_rate: float = 3e-4, 116 | warmup_steps: int = 1000, 117 | decay_steps: int = 1000000, 118 | **kwargs, 119 | ): 120 | network_kwargs["activate_final"] = True 121 | networks = { 122 | "actor": Policy( 123 | encoder_def, 124 | MLP(**network_kwargs), 125 | action_dim=actions.shape[-1], 126 | **policy_kwargs, 127 | ) 128 | } 129 | 130 | model_def = ModuleDict(networks) 131 | 132 | lr_schedule = optax.warmup_cosine_decay_schedule( 133 | init_value=0.0, 134 | peak_value=learning_rate, 135 | warmup_steps=warmup_steps, 136 | decay_steps=decay_steps, 137 | end_value=0.0, 138 | ) 139 | tx = optax.adam(lr_schedule) 140 | 141 | rng, init_rng = jax.random.split(rng) 142 | params = model_def.init(init_rng, actor=[observations])["params"] 143 | 144 | rng, create_rng = jax.random.split(rng) 145 | state = JaxRLTrainState.create( 146 | apply_fn=model_def.apply, 147 | params=params, 148 | txs=tx, 149 | target_params=params, 150 | rng=create_rng, 151 | ) 152 | 153 | return cls(state, lr_schedule) 154 | 
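
[Editor's note] A usage sketch may help orient readers at this point: the snippet below shows how the BCAgent defined above could be constructed and stepped once. It is illustrative only and not part of the repository; the IdentityEncoder stand-in, the dummy observation/action shapes, and the dummy batch are assumptions made for this sketch, and the actual experiments construct agents through finetune.py, the config files, and the launch scripts instead.

# Minimal, hypothetical sketch of creating and training the BCAgent above.
# IdentityEncoder and the dummy shapes are assumptions, not repository code.
import flax.linen as nn
import jax
import jax.numpy as jnp

from wsrl.agents.bc import BCAgent


class IdentityEncoder(nn.Module):
    # Stand-in encoder: returns flat state observations unchanged.
    @nn.compact
    def __call__(self, observations, **kwargs):
        return observations


obs_dim, act_dim, batch_size = 17, 6, 256
example_obs = jnp.zeros((1, obs_dim))
example_act = jnp.zeros((1, act_dim))

agent = BCAgent.create(
    rng=jax.random.PRNGKey(0),
    observations=example_obs,
    actions=example_act,
    encoder_def=IdentityEncoder(),
)

# One behavior-cloning gradient step on a dummy batch of dataset transitions.
batch = {
    "observations": jnp.zeros((batch_size, obs_dim)),
    "actions": jnp.zeros((batch_size, act_dim)),
}
agent, info = agent.update(batch)  # info includes actor_loss, mse, entropy, ...

# Greedy action for evaluation (mode of the policy distribution).
eval_action = agent.sample_actions(example_obs, argmax=True)

The IQL, CQL, Cal-QL, and SAC agents below expose the same create / update / sample_actions interface, which is what lets finetune.py switch between them via the agents registry in wsrl/agents/__init__.py.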
-------------------------------------------------------------------------------- /wsrl/agents/calql.py: -------------------------------------------------------------------------------- 1 | from wsrl.agents.cql import CQLAgent 2 | 3 | 4 | class CalQLAgent(CQLAgent): 5 | """Same agent as CQL, just add an additional check that the use_calql flag is on.""" 6 | 7 | @classmethod 8 | def create( 9 | cls, 10 | *args, 11 | **kwargs, 12 | ): 13 | kwargs["use_calql"] = True 14 | return super(CalQLAgent, cls).create( 15 | *args, 16 | **kwargs, 17 | ) 18 | -------------------------------------------------------------------------------- /wsrl/agents/cql.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of CQL in continuous action spaces. 3 | """ 4 | import copy 5 | from functools import partial 6 | from typing import Optional, Tuple 7 | 8 | import chex 9 | import flax 10 | import flax.linen as nn 11 | import jax 12 | import jax.numpy as jnp 13 | from ml_collections import ConfigDict 14 | from overrides import overrides 15 | 16 | from wsrl.agents.sac import SACAgent 17 | from wsrl.common.common import JaxRLTrainState, ModuleDict 18 | from wsrl.common.optimizers import make_optimizer 19 | from wsrl.common.typing import * 20 | from wsrl.networks.actor_critic_nets import Critic, Policy, ensemblize 21 | from wsrl.networks.lagrange import GeqLagrangeMultiplier, LeqLagrangeMultiplier 22 | from wsrl.networks.mlp import MLP 23 | 24 | 25 | class CQLAgent(SACAgent): 26 | def forward_cql_alpha_lagrange(self, *, grad_params: Optional[Params] = None): 27 | """ 28 | Forward pass for the CQL alpha Lagrange multiplier 29 | Pass grad_params to use non-default parameters (e.g. for gradients). 30 | """ 31 | return self.state.apply_fn( 32 | {"params": grad_params or self.state.params}, 33 | name="cql_alpha_lagrange", 34 | ) 35 | 36 | def _get_cql_q_diff( 37 | self, batch, rng: PRNGKey, grad_params: Optional[Params] = None 38 | ): 39 | """ 40 | most of the CQL loss logic is here 41 | It is needed for both critic_loss_fn and cql_alpha_loss_fn 42 | """ 43 | batch_size = batch["rewards"].shape[0] 44 | q_pred = self.forward_critic( 45 | batch["observations"], 46 | batch["actions"], 47 | rng, 48 | grad_params=grad_params, 49 | ) 50 | chex.assert_shape(q_pred, (self.config["critic_ensemble_size"], batch_size)) 51 | 52 | """sample random actions""" 53 | action_dim = batch["actions"].shape[-1] 54 | rng, action_rng = jax.random.split(rng) 55 | if self.config["cql_action_sample_method"] == "uniform": 56 | cql_random_actions = jax.random.uniform( 57 | action_rng, 58 | shape=(batch_size, self.config["cql_n_actions"], action_dim), 59 | minval=-1.0, 60 | maxval=1.0, 61 | ) 62 | elif self.config["cql_action_sample_method"] == "normal": 63 | cql_random_actions = jax.random.normal( 64 | action_rng, 65 | shape=(batch_size, self.config["cql_n_actions"], action_dim), 66 | ) 67 | else: 68 | raise NotImplementedError 69 | 70 | rng, current_a_rng, next_a_rng = jax.random.split(rng, 3) 71 | cql_current_actions, cql_current_log_pis = self.forward_policy_and_sample( 72 | batch["observations"], 73 | current_a_rng, 74 | repeat=self.config["cql_n_actions"], 75 | ) 76 | chex.assert_shape( 77 | cql_current_log_pis, (batch_size, self.config["cql_n_actions"]) 78 | ) 79 | 80 | cql_next_actions, cql_next_log_pis = self.forward_policy_and_sample( 81 | batch["next_observations"], 82 | next_a_rng, 83 | repeat=self.config["cql_n_actions"], 84 | ) 85 | 86 | all_sampled_actions = jnp.concatenate( 87 | [ 
88 | cql_random_actions, 89 | cql_current_actions, 90 | cql_next_actions, 91 | ], 92 | axis=1, 93 | ) 94 | 95 | """q values of randomly sampled actions""" 96 | rng, q_rng = jax.random.split(rng) 97 | cql_q_samples = self.forward_critic( 98 | batch["observations"], 99 | all_sampled_actions, # this is being vmapped over in sac.py 100 | q_rng, 101 | grad_params=grad_params, 102 | train=True, 103 | ) 104 | chex.assert_shape( 105 | cql_q_samples, 106 | ( 107 | self.config["critic_ensemble_size"], 108 | batch_size, 109 | self.config["cql_n_actions"] * 3, 110 | ), 111 | ) 112 | 113 | if self.config["critic_subsample_size"] is not None: 114 | rng, subsample_key = jax.random.split(rng) 115 | subsample_idcs = jax.random.randint( 116 | subsample_key, 117 | (self.config["critic_subsample_size"],), 118 | 0, 119 | self.config["critic_ensemble_size"], 120 | ) 121 | cql_q_samples = cql_q_samples[subsample_idcs] 122 | q_pred = q_pred[subsample_idcs] 123 | critic_size = self.config["critic_subsample_size"] 124 | else: 125 | critic_size = self.config["critic_ensemble_size"] 126 | """Cal-QL""" 127 | if self.config["use_calql"]: 128 | if self.config["calql_bound_random_actions"]: 129 | mc_lower_bound = jnp.repeat( 130 | batch["mc_returns"].reshape(-1, 1), 131 | self.config["cql_n_actions"] * 3, 132 | axis=1, 133 | ) 134 | else: 135 | fake_lower_bound = jnp.repeat( 136 | jnp.ones_like(batch["mc_returns"].reshape(-1, 1)) * (-jnp.inf), 137 | self.config["cql_n_actions"], 138 | axis=1, 139 | ) 140 | mc_lower_bound = jnp.repeat( 141 | batch["mc_returns"].reshape(-1, 1), 142 | self.config["cql_n_actions"] * 2, 143 | axis=1, 144 | ) 145 | mc_lower_bound = jnp.concatenate( 146 | [fake_lower_bound, mc_lower_bound], axis=1 147 | ) 148 | chex.assert_shape( 149 | mc_lower_bound, (batch_size, self.config["cql_n_actions"] * 3) 150 | ) 151 | 152 | num_vals = jnp.size(cql_q_samples) 153 | calql_bound_rate = jnp.sum(cql_q_samples < mc_lower_bound) / num_vals 154 | cql_q_samples = jnp.maximum(cql_q_samples, mc_lower_bound) 155 | 156 | if self.config["cql_importance_sample"]: 157 | random_density = jnp.log(0.5**action_dim) 158 | 159 | importance_prob = jnp.concatenate( 160 | [ 161 | jnp.broadcast_to( 162 | random_density, (batch_size, self.config["cql_n_actions"]) 163 | ), 164 | cql_current_log_pis, 165 | cql_next_log_pis, # this order matters, should match all_sampled_actions 166 | ], 167 | axis=1, 168 | ) 169 | cql_q_samples = cql_q_samples - importance_prob # broadcast over dim 0 170 | else: 171 | cql_q_samples = jnp.concatenate( 172 | [ 173 | cql_q_samples, 174 | jnp.expand_dims(q_pred, -1), 175 | ], 176 | axis=-1, 177 | ) 178 | cql_q_samples -= jnp.log(cql_q_samples.shape[-1]) * self.config["cql_temp"] 179 | chex.assert_shape( 180 | cql_q_samples, 181 | ( 182 | critic_size, 183 | batch_size, 184 | 3 * self.config["cql_n_actions"] + 1, 185 | ), 186 | ) 187 | 188 | """log sum exp of the ood actions""" 189 | cql_ood_values = ( 190 | jax.scipy.special.logsumexp( 191 | cql_q_samples / self.config["cql_temp"], axis=-1 192 | ) 193 | * self.config["cql_temp"] 194 | ) 195 | chex.assert_shape(cql_ood_values, (critic_size, batch_size)) 196 | 197 | cql_q_diff = cql_ood_values - q_pred 198 | info = { 199 | "cql_ood_values": cql_ood_values.mean(), 200 | } 201 | if self.config["use_calql"]: 202 | info["calql_bound_rate"] = calql_bound_rate 203 | 204 | return cql_q_diff, info 205 | 206 | @overrides 207 | def _compute_next_actions(self, batch, rng): 208 | """ 209 | compute the next actions but with repeat cql_n_actions times 210 | this 
should only be used when calculating critic loss using 211 | cql_max_target_backup 212 | """ 213 | sample_n_actions = ( 214 | self.config["cql_n_actions"] 215 | if self.config["cql_max_target_backup"] 216 | else None 217 | ) 218 | next_actions, next_actions_log_probs = self.forward_policy_and_sample( 219 | batch["next_observations"], 220 | rng, 221 | repeat=sample_n_actions, 222 | ) 223 | return next_actions, next_actions_log_probs 224 | 225 | @overrides 226 | def _process_target_next_qs(self, target_next_qs, next_actions_log_probs): 227 | """add cql_max_target_backup option""" 228 | 229 | if self.config["cql_max_target_backup"]: 230 | max_target_indices = jnp.expand_dims( 231 | jnp.argmax(target_next_qs, axis=-1), axis=-1 232 | ) 233 | target_next_qs = jnp.take_along_axis( 234 | target_next_qs, max_target_indices, axis=-1 235 | ).squeeze(-1) 236 | next_actions_log_probs = jnp.take_along_axis( 237 | next_actions_log_probs, max_target_indices, axis=-1 238 | ).squeeze(-1) 239 | 240 | assert not self.config["backup_entropy"], "Need to call the super() fn" 241 | 242 | return target_next_qs 243 | 244 | @overrides 245 | def critic_loss_fn(self, batch, params: Params, rng: PRNGKey): 246 | """add CQL loss on top of SAC loss""" 247 | if self.config["use_td_loss"]: 248 | td_loss, td_loss_info = super().critic_loss_fn(batch, params, rng) 249 | else: 250 | td_loss, td_loss_info = 0.0, {} 251 | 252 | if self.config["use_cql_loss"]: 253 | 254 | cql_q_diff, cql_intermediate_results = self._get_cql_q_diff( 255 | batch, rng, params 256 | ) 257 | 258 | """auto tune cql alpha""" 259 | if self.config["cql_autotune_alpha"]: 260 | alpha = self.forward_cql_alpha_lagrange() 261 | cql_loss = (cql_q_diff - self.config["cql_target_action_gap"]).mean() 262 | else: 263 | alpha = self.config["cql_alpha"] 264 | cql_loss = jnp.clip( 265 | cql_q_diff, 266 | self.config["cql_clip_diff_min"], 267 | self.config["cql_clip_diff_max"], 268 | ).mean() 269 | 270 | critic_loss = td_loss + alpha * cql_loss 271 | cql_loss_info = { 272 | "cql_loss": cql_loss, 273 | "cql_alpha": alpha, 274 | "cql_diff": cql_q_diff.mean(), 275 | **cql_intermediate_results, 276 | } 277 | else: 278 | critic_loss = td_loss 279 | cql_loss_info = {} 280 | 281 | info = { 282 | **td_loss_info, 283 | **cql_loss_info, 284 | "critic_loss": critic_loss, 285 | "td_loss": td_loss, 286 | } 287 | 288 | return critic_loss, info 289 | 290 | def cql_alpha_lagrange_penalty( 291 | self, qvals_diff, *, grad_params: Optional[Params] = None 292 | ): 293 | return self.state.apply_fn( 294 | {"params": grad_params or self.state.params}, 295 | lhs=qvals_diff, 296 | rhs=self.config["cql_target_action_gap"], 297 | name="cql_alpha_lagrange", 298 | ) 299 | 300 | def cql_alpha_loss_fn(self, batch, params: Params, rng: PRNGKey): 301 | """recompute cql_q_diff without gradients (not optimal for runtime)""" 302 | cql_q_diff, _ = self._get_cql_q_diff(batch, rng) 303 | 304 | cql_alpha_loss = self.cql_alpha_lagrange_penalty( 305 | qvals_diff=cql_q_diff.mean(), 306 | grad_params=params, 307 | ) 308 | lmbda = self.forward_cql_alpha_lagrange() 309 | 310 | return cql_alpha_loss, { 311 | "cql_alpha_loss": cql_alpha_loss, 312 | "cql_alpha_lagrange_multiplier": lmbda, 313 | } 314 | 315 | @overrides 316 | def loss_fns(self, batch): 317 | losses = super().loss_fns(batch) 318 | if self.config["cql_autotune_alpha"]: 319 | losses["cql_alpha_lagrange"] = partial(self.cql_alpha_loss_fn, batch) 320 | 321 | return losses 322 | 323 | def update( 324 | self, 325 | batch: Batch, 326 | pmap_axis: str = None, 
327 | networks_to_update: set = set({"actor", "critic"}), 328 | ): 329 | """update super() to perhaps include updating CQL lagrange multiplier""" 330 | if not isinstance(networks_to_update, frozenset): 331 | if self.config["autotune_entropy"]: 332 | networks_to_update.add("temperature") 333 | if self.config["cql_autotune_alpha"]: 334 | networks_to_update.add("cql_alpha_lagrange") 335 | 336 | return super().update( 337 | batch, 338 | pmap_axis=pmap_axis, 339 | networks_to_update=frozenset(networks_to_update), 340 | ) 341 | 342 | @partial(jax.jit, static_argnames=("utd_ratio", "pmap_axis")) 343 | def update_high_utd( 344 | self, 345 | batch: Batch, 346 | *, 347 | utd_ratio: int, 348 | pmap_axis: Optional[str] = None, 349 | ) -> Tuple["SACAgent", dict]: 350 | """ 351 | same as super().update_high_utd, but also considers the CQL alpha lagrange loss 352 | """ 353 | batch_size = batch["rewards"].shape[0] 354 | assert ( 355 | batch_size % utd_ratio == 0 356 | ), f"Batch size {batch_size} must be divisible by UTD ratio {utd_ratio}" 357 | minibatch_size = batch_size // utd_ratio 358 | chex.assert_tree_shape_prefix(batch, (batch_size,)) 359 | 360 | def scan_body(carry: Tuple[SACAgent], data: Tuple[Batch]): 361 | (agent,) = carry 362 | (minibatch,) = data 363 | agent, info = agent.update( 364 | minibatch, 365 | pmap_axis=pmap_axis, 366 | networks_to_update=frozenset({"critic"}), 367 | ) 368 | return (agent,), info 369 | 370 | def make_minibatch(data: jnp.ndarray): 371 | return jnp.reshape(data, (utd_ratio, minibatch_size) + data.shape[1:]) 372 | 373 | minibatches = jax.tree_map(make_minibatch, batch) 374 | 375 | (agent,), critic_infos = jax.lax.scan(scan_body, (self,), (minibatches,)) 376 | 377 | critic_infos = jax.tree_map(lambda x: jnp.mean(x, axis=0), critic_infos) 378 | del critic_infos["actor"] 379 | del critic_infos["temperature"] 380 | 381 | # Take one gradient descent step on the actor, temperature, and cql_alpha_lagrange 382 | networks_to_update = set(("actor", "temperature")) 383 | if self.config["cql_autotune_alpha"]: # only diff from super().update_high_utd 384 | networks_to_update.add("cql_alpha_lagrange") 385 | agent, actor_temp_infos = agent.update( 386 | batch, 387 | pmap_axis=pmap_axis, 388 | networks_to_update=frozenset(networks_to_update), 389 | ) 390 | del actor_temp_infos["critic"] 391 | 392 | infos = {**critic_infos, **actor_temp_infos} 393 | 394 | return agent, infos 395 | 396 | @classmethod 397 | def create( 398 | cls, 399 | rng: PRNGKey, 400 | observations: Data, 401 | actions: jnp.ndarray, 402 | # Model arch 403 | encoder_def: nn.Module, 404 | shared_encoder: bool = False, 405 | critic_network_type: str = "mlp", 406 | critic_network_kwargs: dict = { 407 | "hidden_dims": [256, 256], 408 | }, 409 | policy_network_kwargs: dict = { 410 | "hidden_dims": [256, 256], 411 | }, 412 | policy_kwargs: dict = { 413 | "tanh_squash_distribution": True, 414 | "std_parameterization": "exp", 415 | }, 416 | **kwargs, 417 | ): 418 | # update algorithm config 419 | config = ConfigDict(kwargs) 420 | 421 | if shared_encoder: 422 | encoders = { 423 | "actor": encoder_def, 424 | "critic": encoder_def, 425 | } 426 | else: 427 | encoders = { 428 | "actor": encoder_def, 429 | "critic": copy.deepcopy(encoder_def), 430 | } 431 | 432 | # Define networks 433 | policy_def = Policy( 434 | encoder=encoders["actor"], 435 | network=MLP(**policy_network_kwargs), 436 | action_dim=actions.shape[-1], 437 | **policy_kwargs, 438 | name="actor", 439 | ) 440 | 441 | critic_backbone = partial(MLP, 
**critic_network_kwargs) 442 | critic_backbone = ensemblize(critic_backbone, config.critic_ensemble_size)( 443 | name="critic_ensemble" 444 | ) 445 | 446 | critic_def = partial( 447 | Critic, 448 | encoder=encoders["critic"], 449 | network=critic_backbone, 450 | )(name="critic") 451 | 452 | temperature_def = GeqLagrangeMultiplier( 453 | init_value=config.temperature_init, 454 | constraint_shape=(), 455 | name="temperature", 456 | ) 457 | if config["cql_autotune_alpha"]: 458 | cql_alpha_lagrange_def = LeqLagrangeMultiplier( 459 | init_value=config.cql_alpha_lagrange_init, 460 | constraint_shape=(), 461 | name="cql_alpha_lagrange", 462 | ) 463 | 464 | # model def 465 | networks = { 466 | "actor": policy_def, 467 | "critic": critic_def, 468 | "temperature": temperature_def, 469 | } 470 | if config["cql_autotune_alpha"]: 471 | networks["cql_alpha_lagrange"] = cql_alpha_lagrange_def 472 | model_def = ModuleDict(networks) 473 | 474 | # Define optimizers 475 | txs = { 476 | "actor": make_optimizer(**config.actor_optimizer_kwargs), 477 | "critic": make_optimizer(**config.critic_optimizer_kwargs), 478 | "temperature": make_optimizer(**config.temperature_optimizer_kwargs), 479 | } 480 | if config["cql_autotune_alpha"]: 481 | txs["cql_alpha_lagrange"] = make_optimizer( 482 | **config.cql_alpha_lagrange_otpimizer_kwargs 483 | ) 484 | 485 | # init params 486 | rng, init_rng = jax.random.split(rng) 487 | extra_kwargs = {} 488 | if config["cql_autotune_alpha"]: 489 | extra_kwargs["cql_alpha_lagrange"] = [] 490 | network_input = observations 491 | params = model_def.init( 492 | init_rng, 493 | actor=[network_input], 494 | critic=[network_input, actions], 495 | temperature=[], 496 | **extra_kwargs, 497 | )["params"] 498 | 499 | # create 500 | rng, create_rng = jax.random.split(rng) 501 | state = JaxRLTrainState.create( 502 | apply_fn=model_def.apply, 503 | params=params, 504 | txs=txs, 505 | target_params=params, 506 | rng=create_rng, 507 | ) 508 | 509 | # config 510 | if config.target_entropy >= 0.0: 511 | config.target_entropy = -actions.shape[-1] 512 | config = flax.core.FrozenDict(config) 513 | return cls(state, config) 514 | -------------------------------------------------------------------------------- /wsrl/agents/iql.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Optional 4 | 5 | import chex 6 | import flax 7 | import flax.linen as nn 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | from flax.core import FrozenDict 12 | 13 | from wsrl.common.common import JaxRLTrainState, ModuleDict, nonpytree_field 14 | from wsrl.common.optimizers import make_optimizer 15 | from wsrl.common.typing import Batch, Data, Params, PRNGKey 16 | from wsrl.networks.actor_critic_nets import Critic, Policy, ValueCritic, ensemblize 17 | from wsrl.networks.mlp import MLP 18 | 19 | 20 | def expectile_loss(diff, expectile=0.5): 21 | weight = jnp.where(diff > 0, expectile, (1 - expectile)) 22 | return weight * (diff**2) 23 | 24 | 25 | def iql_value_loss(q, v, expectile): 26 | value_loss = expectile_loss(q - v, expectile) 27 | return value_loss.mean(), { 28 | "value_loss": value_loss.mean(), 29 | "uncentered_loss": jnp.mean((q - v) ** 2), 30 | "v": v.mean(), 31 | } 32 | 33 | 34 | def iql_critic_loss(q, q_target): 35 | """mse loss""" 36 | critic_loss = jnp.square(q - q_target) 37 | return critic_loss.mean(), { 38 | "td_loss": critic_loss.mean(), 39 | "q": q.mean(), 40 | } 41 | 42 | 43 | def 
awr_actor_loss(q, v, dist, actions, temperature=1.0, adv_clip_max=100.0, mask=None): 44 | adv = q - v 45 | 46 | exp_adv = jnp.exp(adv * temperature) 47 | exp_adv = jnp.minimum(exp_adv, adv_clip_max) 48 | 49 | log_probs = dist.log_prob(actions) 50 | actor_loss = -(exp_adv * log_probs) 51 | 52 | if mask is not None: 53 | actor_loss *= mask 54 | actor_loss = jnp.sum(actor_loss) / jnp.sum(mask) 55 | else: 56 | actor_loss = jnp.mean(actor_loss) 57 | 58 | behavior_mse = jnp.square(dist.mode() - actions).sum(-1) 59 | 60 | return actor_loss, { 61 | "actor_loss": actor_loss, 62 | "behavior_logprob": log_probs.mean(), 63 | "behavior_entropy": -log_probs.mean(), 64 | "behavior_mse": behavior_mse.mean(), 65 | "adv_mean": adv.mean(), 66 | "adv_max": adv.max(), 67 | "adv_min": adv.min(), 68 | "predicted actions": dist.mode(), 69 | "dataset actions": actions, 70 | } 71 | 72 | 73 | def ddpg_bc_actor_loss(q, dist, actions, bc_loss_weight, mask=None): 74 | ddpg_objective = q # policy action values 75 | bc_loss = -dist.log_prob(actions) 76 | actor_loss = -ddpg_objective + bc_loss_weight * bc_loss 77 | if mask is not None: 78 | actor_loss *= mask 79 | actor_loss = jnp.sum(actor_loss) / jnp.sum(mask) 80 | else: 81 | actor_loss = jnp.mean(actor_loss) 82 | return actor_loss, { 83 | "bc_loss": bc_loss.mean(), 84 | "ddpg_objective": ddpg_objective.mean(), 85 | "actor_loss": actor_loss, 86 | } 87 | 88 | 89 | class IQLAgent(flax.struct.PyTreeNode): 90 | state: JaxRLTrainState 91 | config: dict = nonpytree_field() 92 | 93 | def forward_policy( 94 | self, 95 | observations: Data, 96 | rng: Optional[PRNGKey] = None, 97 | *, 98 | grad_params: Optional[Params] = None, 99 | train: bool = True, 100 | ): 101 | """ 102 | Forward pass for policy network. 103 | Pass grad_params to use non-default parameters (e.g. for gradients) 104 | """ 105 | if train: 106 | assert rng is not None, "Must specify rng when training" 107 | return self.state.apply_fn( 108 | {"params": grad_params or self.state.params}, 109 | observations, 110 | name="actor", 111 | rngs={"dropout": rng} if train else {}, 112 | train=train, 113 | ) 114 | 115 | def forward_critic( 116 | self, 117 | observations: Data, 118 | actions: jax.Array, 119 | rng: Optional[PRNGKey] = None, 120 | *, 121 | grad_params: Optional[Params] = None, 122 | train: bool = True, 123 | ) -> jax.Array: 124 | """ 125 | Forward pass for critic network. 126 | Pass grad_params to use non-default parameters (e.g. for gradients). 127 | """ 128 | if train: 129 | assert rng is not None, "Must specify rng when training" 130 | qs = self.state.apply_fn( 131 | {"params": grad_params or self.state.params}, 132 | observations, 133 | actions, 134 | name="critic", 135 | rngs={"dropout": rng} if train else {}, 136 | train=train, 137 | ) 138 | 139 | return qs 140 | 141 | def forward_target_critic( 142 | self, 143 | observations: Data, 144 | actions: jax.Array, 145 | rng: Optional[PRNGKey] = None, 146 | ) -> jax.Array: 147 | """ 148 | Forward pass for target critic network. 149 | Pass grad_params to use non-default parameters (e.g. for gradients). 150 | """ 151 | return self.forward_critic( 152 | observations, actions, train=False, grad_params=self.state.target_params 153 | ) 154 | 155 | @partial(jax.jit, static_argnames="train") 156 | def forward_value( 157 | self, 158 | observations: Data, 159 | rng: Optional[PRNGKey] = None, 160 | *, 161 | grad_params: Optional[Params] = None, 162 | train: bool = True, 163 | ) -> jax.Array: 164 | """ 165 | Forward pass for value network. 
166 | Pass grad_params 167 | """ 168 | if train: 169 | assert rng is not None, "Must specify rng when training" 170 | return self.state.apply_fn( 171 | {"params": grad_params or self.state.params}, 172 | observations, 173 | name="value", 174 | rngs={"dropout": rng} if train else {}, 175 | train=train, 176 | ) 177 | 178 | def forward_target_value( 179 | self, 180 | observations: Data, 181 | rng: PRNGKey, 182 | ) -> jax.Array: 183 | """ 184 | Forward pass for target value network. 185 | Pass grad_params to use non-default parameters (e.g. for gradients). 186 | """ 187 | return self.forward_value( 188 | observations, rng=rng, grad_params=self.state.target_params 189 | ) 190 | 191 | def _get_ensemble_q_value(self, q, rng): 192 | """ 193 | subsample to a single critic value given an ensemble 194 | """ 195 | if self.config["critic_subsample_size"] is not None: 196 | # REDQ 197 | rng, subsample_key = jax.random.split(rng) 198 | subsample_idcs = jax.random.randint( 199 | subsample_key, 200 | (self.config["critic_subsample_size"],), 201 | 0, 202 | self.config["critic_ensemble_size"], 203 | ) 204 | q = q[subsample_idcs] 205 | q = jnp.min(q, axis=0) 206 | else: 207 | # double Q 208 | q = jnp.min(q, axis=0) 209 | return q, rng 210 | 211 | def critic_loss_fn(self, batch, params: Params, rng: PRNGKey): 212 | batch_size = batch["rewards"].shape[0] 213 | 214 | rng, key = jax.random.split(rng) 215 | next_v = self.forward_value(batch["next_observations"], key) 216 | target_q = batch["rewards"] + self.config["discount"] * next_v * batch["masks"] 217 | chex.assert_shape(target_q, (batch_size,)) 218 | 219 | rng, key = jax.random.split(rng) 220 | q = self.forward_critic( 221 | batch["observations"], 222 | batch["actions"], 223 | key, 224 | grad_params=params, 225 | ) 226 | chex.assert_shape(q, (self.config["critic_ensemble_size"], batch_size)) 227 | 228 | # MSE loss 229 | critic_loss = jnp.square(q - target_q) 230 | chex.assert_shape( 231 | critic_loss, (self.config["critic_ensemble_size"], batch_size) 232 | ) 233 | 234 | return critic_loss.mean(), { 235 | "td_loss": critic_loss.mean(), 236 | "q": q.mean(), 237 | "target_q": target_q.mean(), 238 | } 239 | 240 | def value_loss_fn(self, batch, params: Params, rng: PRNGKey): 241 | rng, key = jax.random.split(rng) 242 | q = self.forward_target_critic( 243 | batch["observations"], batch["actions"], key 244 | ) # no gradient 245 | q, rng = self._get_ensemble_q_value(q, rng) # min over Q functions 246 | 247 | rng, key = jax.random.split(rng) 248 | v = self.forward_value(batch["observations"], key, grad_params=params) 249 | 250 | # expectile loss 251 | return iql_value_loss(q, v, self.config["expectile"]) 252 | 253 | def policy_loss_fn(self, batch, params: Params, rng: PRNGKey): 254 | rng, key = jax.random.split(rng) 255 | 256 | if self.config["update_actor_with_target_adv"]: 257 | critic_fn = self.forward_target_critic 258 | else: 259 | # Seohong: not using the target will make updates faster 260 | critic_fn = self.forward_critic 261 | 262 | rng, key = jax.random.split(rng) 263 | dist = self.forward_policy(batch["observations"], key, grad_params=params) 264 | mask = batch.get("actor_loss_mask", None) 265 | 266 | if self.config["actor_type"] == "awr": 267 | 268 | q = critic_fn(batch["observations"], batch["actions"], key) # no gradient 269 | q, rng = self._get_ensemble_q_value(q, rng) # min over Q functions 270 | 271 | rng, key = jax.random.split(rng) 272 | v = self.forward_value(batch["observations"], key) # no gradients 273 | 274 | return awr_actor_loss( 275 | q, 276 
| v, 277 | dist, 278 | batch["actions"], 279 | self.config["temperature"], 280 | mask=mask, 281 | ) 282 | elif self.config["actor_type"] == "ddpg+bc": 283 | 284 | rng, key = jax.random.split(rng) 285 | policy_a = dist.sample(seed=key) 286 | policy_q = critic_fn(batch["observations"], policy_a) # no gradient 287 | policy_q, rng = self._get_ensemble_q_value( 288 | policy_q, rng 289 | ) # min over Q functions 290 | 291 | return ddpg_bc_actor_loss( 292 | policy_q, 293 | dist, 294 | batch["actions"], 295 | self.config["actor_bc_loss_weight"], 296 | mask=mask, 297 | ) 298 | else: 299 | raise NotImplementedError 300 | 301 | @partial(jax.jit, static_argnames="pmap_axis") 302 | def update(self, batch: Batch, pmap_axis: str = None): 303 | rng, new_rng = jax.random.split(self.state.rng) 304 | batch_size = batch["rewards"].shape[0] 305 | 306 | loss_fns = { 307 | "critic": partial(self.critic_loss_fn, batch), 308 | "value": partial(self.value_loss_fn, batch), 309 | "actor": partial(self.policy_loss_fn, batch), 310 | } 311 | 312 | # compute gradients and update params 313 | new_state, info = self.state.apply_loss_fns( 314 | loss_fns, pmap_axis=pmap_axis, has_aux=True 315 | ) 316 | 317 | # update the target params 318 | new_state = new_state.target_update(self.config["target_update_rate"]) 319 | 320 | # update rng 321 | new_state = new_state.replace(rng=new_rng) 322 | 323 | # Log learning rates 324 | for name, opt_state in new_state.opt_states.items(): 325 | if ( 326 | hasattr(opt_state, "hyperparams") 327 | and "learning_rate" in opt_state.hyperparams.keys() 328 | ): 329 | info[f"{name}_lr"] = opt_state.hyperparams["learning_rate"] 330 | 331 | return self.replace(state=new_state), info 332 | 333 | @partial(jax.jit, static_argnames="argmax") 334 | def sample_actions( 335 | self, 336 | observations: np.ndarray, 337 | *, 338 | seed: Optional[PRNGKey] = None, 339 | argmax=False, 340 | ) -> jnp.ndarray: 341 | dist = self.forward_policy(observations, seed, train=False) 342 | if argmax: 343 | assert seed is None, "Cannot specify seed when sampling deterministically" 344 | actions = dist.mode() 345 | else: 346 | actions = dist.sample(seed=seed) 347 | return actions 348 | 349 | @jax.jit 350 | def get_debug_metrics(self, batch, **kwargs): 351 | 352 | dist = self.forward_policy(batch["observations"], train=False) 353 | pi_actions = dist.mode() 354 | log_probs = dist.log_prob(batch["actions"]) 355 | mse = ((pi_actions - batch["actions"]) ** 2).mean() 356 | 357 | v = self.forward_value(batch["observations"], train=False) 358 | next_v = self.forward_value(batch["next_observations"], train=False) 359 | target = batch["rewards"] + self.config["discount"] * next_v * batch["masks"] 360 | q = self.forward_critic(batch["observations"], batch["actions"], train=False) 361 | q, _ = self._get_ensemble_q_value(q, self.state.rng) # min over Q functions 362 | q_target = self.forward_target_critic(batch["observations"], batch["actions"]) 363 | q_target, _ = self._get_ensemble_q_value( 364 | q_target, self.state.rng 365 | ) # min over Q functions 366 | 367 | metrics = { 368 | "log_probs": log_probs, 369 | "action_mse": mse, 370 | "pi_actions": pi_actions, 371 | "v": v.mean(), 372 | "q": q.mean(), 373 | "value loss": expectile_loss(q_target - v, self.config["expectile"]).mean(), 374 | "critic mse loss": jnp.square(target - q).mean(), 375 | "advantage": target - v, 376 | "qf_advantage": q - v, 377 | } 378 | 379 | return metrics 380 | 381 | @classmethod 382 | def create( 383 | cls, 384 | rng: PRNGKey, 385 | observations: FrozenDict, 
386 | actions: jnp.ndarray, 387 | # Model architecture 388 | encoder_def: nn.Module, 389 | shared_encoder: bool = True, 390 | critic_ensemble_size: int = 2, 391 | critic_subsample_size: Optional[int] = None, 392 | policy_network_kwargs: dict = { 393 | "hidden_dims": [256, 256], 394 | "kernel_init_type": "var_scaling", 395 | "kernel_scale_final": 1e-2, 396 | }, 397 | critic_network_kwargs: dict = { 398 | "hidden_dims": [256, 256], 399 | "kernel_init_type": "var_scaling", 400 | }, 401 | policy_kwargs: dict = { 402 | "tanh_squash_distribution": False, 403 | "std_parameterization": "exp", 404 | }, 405 | # Optimizer 406 | actor_optimizer_kwargs={ 407 | "learning_rate": 3e-4, 408 | }, 409 | value_critic_optimizer_kwargs={ 410 | "learning_rate": 3e-4, 411 | }, 412 | # Algorithm config 413 | discount=0.99, 414 | expectile=0.9, 415 | temperature=1.0, 416 | target_update_rate=0.005, 417 | update_actor_with_target_adv=True, 418 | actor_type="awr", 419 | actor_bc_loss_weight=0.0, 420 | **kwargs, 421 | ): 422 | assert actor_type in ("awr", "ddpg+bc") 423 | assert not ( 424 | actor_bc_loss_weight > 0 and actor_type == "awr" 425 | ), "BC loss is not yet supported with AWR" 426 | 427 | if shared_encoder: 428 | encoders = { 429 | "actor": encoder_def, 430 | "value": encoder_def, 431 | "critic": encoder_def, 432 | } 433 | else: 434 | encoders = { 435 | "actor": encoder_def, 436 | "value": copy.deepcopy(encoder_def), 437 | "critic": copy.deepcopy(encoder_def), 438 | } 439 | 440 | networks = { 441 | "actor": Policy( 442 | encoders["actor"], 443 | MLP(**policy_network_kwargs), 444 | action_dim=actions.shape[-1], 445 | **policy_kwargs, 446 | ), 447 | "value": ValueCritic(encoders["value"], MLP(**critic_network_kwargs)), 448 | "critic": Critic( 449 | encoders["critic"], 450 | network=ensemblize( 451 | partial(MLP, **critic_network_kwargs), critic_ensemble_size 452 | )(name="critic_ensemble"), 453 | ), 454 | } 455 | 456 | model_def = ModuleDict(networks) 457 | 458 | txs = { 459 | "actor": make_optimizer(**actor_optimizer_kwargs), 460 | "value": make_optimizer(**value_critic_optimizer_kwargs), 461 | "critic": make_optimizer(**value_critic_optimizer_kwargs), 462 | } 463 | 464 | rng, init_rng = jax.random.split(rng) 465 | params = model_def.init( 466 | init_rng, 467 | actor=[observations], 468 | value=[observations], 469 | critic=[observations, actions], 470 | )["params"] 471 | 472 | rng, create_rng = jax.random.split(rng) 473 | state = JaxRLTrainState.create( 474 | apply_fn=model_def.apply, 475 | params=params, 476 | txs=txs, 477 | target_params=params, 478 | rng=create_rng, 479 | ) 480 | 481 | config = flax.core.FrozenDict( 482 | dict( 483 | discount=discount, 484 | temperature=temperature, 485 | target_update_rate=target_update_rate, 486 | expectile=expectile, 487 | critic_ensemble_size=critic_ensemble_size, 488 | critic_subsample_size=critic_subsample_size, 489 | update_actor_with_target_adv=update_actor_with_target_adv, 490 | actor_type=actor_type, 491 | actor_bc_loss_weight=actor_bc_loss_weight, 492 | ) 493 | ) 494 | return cls(state, config) 495 | -------------------------------------------------------------------------------- /wsrl/agents/sac.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Optional, Tuple, Union 4 | 5 | import chex 6 | import distrax 7 | import flax 8 | import flax.linen as nn 9 | import jax 10 | import jax.numpy as jnp 11 | from absl import flags 12 | 13 | from wsrl.common.common 
import JaxRLTrainState, ModuleDict, nonpytree_field 14 | from wsrl.common.optimizers import make_optimizer 15 | from wsrl.common.typing import Batch, Data, Params, PRNGKey 16 | from wsrl.networks.actor_critic_nets import Critic, Policy, ensemblize 17 | from wsrl.networks.lagrange import GeqLagrangeMultiplier 18 | from wsrl.networks.mlp import MLP 19 | 20 | FLAGS = flags.FLAGS 21 | 22 | 23 | class SACAgent(flax.struct.PyTreeNode): 24 | """ 25 | Online actor-critic supporting several different algorithms depending on configuration: 26 | - SAC (default) 27 | - TD3 (policy_kwargs={"std_parameterization": "fixed", "fixed_std": 0.1}) 28 | - REDQ (critic_ensemble_size=10, critic_subsample_size=2) 29 | - SAC-ensemble (critic_ensemble_size>>1) 30 | """ 31 | 32 | state: JaxRLTrainState 33 | config: dict = nonpytree_field() 34 | 35 | def forward_critic( 36 | self, 37 | observations: Union[Data, Tuple[Data, Data]], 38 | actions: jax.Array, 39 | rng: Optional[PRNGKey] = None, 40 | *, 41 | grad_params: Optional[Params] = None, 42 | train: bool = True, 43 | ) -> jax.Array: 44 | """ 45 | Forward pass for critic network. 46 | Pass grad_params to use non-default parameters (e.g. for gradients). 47 | """ 48 | if train: 49 | assert rng is not None, "Must specify rng when training" 50 | if jnp.ndim(actions) == 3: 51 | # forward the q function with multiple actions on each state 52 | q = jax.vmap( 53 | lambda a: self.state.apply_fn( 54 | {"params": grad_params or self.state.params}, 55 | observations, 56 | a, 57 | name="critic", 58 | rngs={"dropout": rng} if train else {}, 59 | train=train, 60 | ), 61 | in_axes=1, 62 | out_axes=-1, 63 | )( 64 | actions 65 | ) # (ensemble_size, batch_size, n_actions) 66 | else: 67 | # forward the q function on 1 action on each state 68 | q = self.state.apply_fn( 69 | {"params": grad_params or self.state.params}, 70 | observations, 71 | actions, 72 | name="critic", 73 | rngs={"dropout": rng} if train else {}, 74 | train=train, 75 | ) # (ensemble_size, batch_size) 76 | 77 | return q 78 | 79 | def forward_target_critic( 80 | self, 81 | observations: Union[Data, Tuple[Data, Data]], 82 | actions: jax.Array, 83 | rng: PRNGKey, 84 | ) -> jax.Array: 85 | """ 86 | Forward pass for target critic network. 87 | Pass grad_params to use non-default parameters (e.g. for gradients). 88 | """ 89 | return self.forward_critic( 90 | observations, actions, rng=rng, grad_params=self.state.target_params 91 | ) 92 | 93 | def forward_policy( 94 | self, 95 | observations: Union[Data, Tuple[Data, Data]], 96 | rng: Optional[PRNGKey] = None, 97 | *, 98 | grad_params: Optional[Params] = None, 99 | train: bool = True, 100 | ) -> distrax.Distribution: 101 | """ 102 | Forward pass for policy network. 103 | Pass grad_params to use non-default parameters (e.g. for gradients). 
104 | """ 105 | if train: 106 | assert rng is not None, "Must specify rng when training" 107 | return self.state.apply_fn( 108 | {"params": grad_params or self.state.params}, 109 | observations, 110 | name="actor", 111 | rngs={"dropout": rng} if train else {}, 112 | train=train, 113 | ) 114 | 115 | def forward_policy_and_sample( 116 | self, 117 | obs: Data, 118 | rng: PRNGKey, 119 | *, 120 | grad_params: Optional[Params] = None, 121 | repeat=None, 122 | ): 123 | rng, sample_rng = jax.random.split(rng) 124 | action_dist = self.forward_policy(obs, rng, grad_params=grad_params) 125 | if repeat: 126 | new_actions, log_pi = action_dist.sample_and_log_prob( 127 | seed=sample_rng, sample_shape=repeat 128 | ) 129 | new_actions = jnp.transpose( 130 | new_actions, (1, 0, 2) 131 | ) # (batch, repeat, action_dim) 132 | log_pi = jnp.transpose(log_pi, (1, 0)) # (batch, repeat) 133 | else: 134 | new_actions, log_pi = action_dist.sample_and_log_prob(seed=sample_rng) 135 | return new_actions, log_pi 136 | 137 | def forward_temperature( 138 | self, *, grad_params: Optional[Params] = None 139 | ) -> distrax.Distribution: 140 | """ 141 | Forward pass for temperature Lagrange multiplier. 142 | Pass grad_params to use non-default parameters (e.g. for gradients). 143 | """ 144 | return self.state.apply_fn( 145 | {"params": grad_params or self.state.params}, name="temperature" 146 | ) 147 | 148 | @jax.jit 149 | def forward_value( 150 | self, 151 | observations: Union[Data, Tuple[Data, Data]], 152 | *, 153 | train: bool = False, 154 | ) -> jax.Array: 155 | """ 156 | Get the option state-value function 157 | This is never needed in training, only for evaluation 158 | """ 159 | pi_dist = self.forward_policy(observations, train=False) 160 | action = pi_dist.mode() 161 | q = self.forward_critic(observations, action, train=False) 162 | q = q.min(axis=0) 163 | return q 164 | 165 | def temperature_lagrange_penalty( 166 | self, entropy: jnp.ndarray, *, grad_params: Optional[Params] = None 167 | ) -> distrax.Distribution: 168 | """ 169 | Forward pass for Lagrange penalty for temperature. 170 | Pass grad_params to use non-default parameters (e.g. for gradients). 171 | """ 172 | return self.state.apply_fn( 173 | {"params": grad_params or self.state.params}, 174 | lhs=entropy, 175 | rhs=self.config["target_entropy"], 176 | name="temperature", 177 | ) 178 | 179 | def _compute_next_actions(self, batch, rng): 180 | """shared computation between loss functions""" 181 | batch_size = batch["rewards"].shape[0] 182 | sample_n_actions = ( 183 | self.config["n_actions"] if self.config["max_target_backup"] else None 184 | ) 185 | 186 | next_actions, next_actions_log_probs = self.forward_policy_and_sample( 187 | batch["next_observations"], 188 | rng, 189 | repeat=sample_n_actions, 190 | ) 191 | 192 | if sample_n_actions: 193 | chex.assert_shape(next_actions_log_probs, (batch_size, sample_n_actions)) 194 | else: 195 | chex.assert_shape(next_actions_log_probs, (batch_size,)) 196 | return next_actions, next_actions_log_probs 197 | 198 | def _process_target_next_qs(self, target_next_qs, next_actions_log_probs): 199 | """classes that inherit this class can add to this function 200 | e.g. 
CQL will add the cql_max_target_backup option 201 | """ 202 | if self.config["backup_entropy"]: 203 | temperature = self.forward_temperature() 204 | target_next_qs = target_next_qs - temperature * next_actions_log_probs 205 | 206 | if self.config["max_target_backup"]: 207 | max_target_indices = jnp.expand_dims( 208 | jnp.argmax(target_next_qs, axis=-1), axis=-1 209 | ) 210 | target_next_qs = jnp.take_along_axis( 211 | target_next_qs, max_target_indices, axis=-1 212 | ).squeeze(-1) 213 | next_actions_log_probs = jnp.take_along_axis( 214 | next_actions_log_probs, max_target_indices, axis=-1 215 | ).squeeze(-1) 216 | 217 | return target_next_qs 218 | 219 | def critic_loss_fn(self, batch, params: Params, rng: PRNGKey): 220 | """classes that inherit this class can change this function""" 221 | batch_size = batch["rewards"].shape[0] 222 | rng, next_action_sample_key = jax.random.split(rng) 223 | next_actions, next_actions_log_probs = self._compute_next_actions( 224 | batch, next_action_sample_key 225 | ) 226 | # (batch_size, ) for sac, (batch_size, cql_n_actions) for cql 227 | 228 | # Evaluate next Qs for all ensemble members (cheap because we're only doing the forward pass) 229 | target_next_qs = self.forward_target_critic( 230 | batch["next_observations"], 231 | next_actions, 232 | rng=rng, 233 | ) # (critic_ensemble_size, batch_size) 234 | 235 | # Subsample if requested 236 | if self.config["critic_subsample_size"] is not None: 237 | rng, subsample_key = jax.random.split(rng) 238 | subsample_idcs = jax.random.randint( 239 | subsample_key, 240 | (self.config["critic_subsample_size"],), 241 | 0, 242 | self.config["critic_ensemble_size"], 243 | ) 244 | target_next_qs = target_next_qs[subsample_idcs] 245 | 246 | # Minimum Q across (subsampled) ensemble members 247 | target_next_min_q = target_next_qs.min(axis=0) 248 | chex.assert_equal_shape([target_next_min_q, next_actions_log_probs]) 249 | # (batch_size,) for sac, (batch_size, cql_n_actions) for cql 250 | 251 | target_next_min_q = self._process_target_next_qs( 252 | target_next_min_q, 253 | next_actions_log_probs, 254 | ) 255 | 256 | target_q = ( 257 | batch["rewards"] 258 | + self.config["discount"] * batch["masks"] * target_next_min_q 259 | ) 260 | chex.assert_shape(target_q, (batch_size,)) 261 | 262 | predicted_qs = self.forward_critic( 263 | batch["observations"], 264 | batch["actions"], 265 | rng=rng, 266 | grad_params=params, 267 | ) 268 | chex.assert_shape( 269 | predicted_qs, (self.config["critic_ensemble_size"], batch_size) 270 | ) 271 | 272 | # MSE loss 273 | target_qs = target_q[None].repeat(self.config["critic_ensemble_size"], axis=0) 274 | chex.assert_equal_shape([predicted_qs, target_qs]) 275 | critic_loss = jnp.mean((predicted_qs - target_qs) ** 2) 276 | 277 | info = { 278 | "critic_loss": critic_loss, 279 | "predicted_qs": jnp.mean(predicted_qs), 280 | "target_qs": jnp.mean(target_q), 281 | } 282 | 283 | return critic_loss, info 284 | 285 | def policy_loss_fn(self, batch, params: Params, rng: PRNGKey): 286 | batch_size = batch["rewards"].shape[0] 287 | temperature = self.forward_temperature() 288 | 289 | rng, policy_rng, sample_rng, critic_rng = jax.random.split(rng, 4) 290 | action_distributions = self.forward_policy( 291 | batch["observations"], 292 | rng=policy_rng, 293 | grad_params=params, 294 | ) 295 | actions, log_probs = action_distributions.sample_and_log_prob(seed=sample_rng) 296 | 297 | predicted_qs = self.forward_critic( 298 | batch["observations"], 299 | actions, 300 | rng=critic_rng, 301 | ) 302 | predicted_q = 
predicted_qs.min(axis=0) 303 | chex.assert_shape(predicted_q, (batch_size,)) 304 | chex.assert_shape(log_probs, (batch_size,)) 305 | 306 | nll_objective = -jnp.mean( 307 | action_distributions.log_prob(jnp.clip(batch["actions"], -0.99, 0.99)) 308 | ) 309 | actor_objective = predicted_q 310 | actor_loss = -jnp.mean(actor_objective) + jnp.mean(temperature * log_probs) 311 | 312 | info = { 313 | "actor_loss": actor_loss, 314 | "actor_nll": nll_objective, 315 | "temperature": temperature, 316 | "entropy": -log_probs.mean(), 317 | "log_probs": log_probs, 318 | "actions_mse": ((actions - batch["actions"]) ** 2).sum(axis=-1).mean(), 319 | "dataset_rewards": batch["rewards"], 320 | "mc_returns": batch.get("mc_returns", None), 321 | "actions": actions, 322 | } 323 | 324 | # optionally add BC regularization 325 | if self.config.get("bc_loss_weight", 0.0) > 0: 326 | bc_loss = -action_distributions.log_prob(batch["actions"]).mean() 327 | 328 | info["actor_q_loss"] = actor_loss 329 | info["bc_loss"] = bc_loss 330 | info["actor_bc_loss_weight"] = self.config["bc_loss_weight"] 331 | 332 | actor_loss = ( 333 | actor_loss * (1 - self.config["bc_loss_weight"]) 334 | + bc_loss * self.config["bc_loss_weight"] 335 | ) 336 | info["actor_loss"] = actor_loss 337 | 338 | return actor_loss, info 339 | 340 | def temperature_loss_fn(self, batch, params: Params, rng: PRNGKey): 341 | rng, next_action_sample_key = jax.random.split(rng) 342 | next_actions, next_actions_log_probs = self._compute_next_actions( 343 | batch, next_action_sample_key 344 | ) 345 | 346 | entropy = -next_actions_log_probs.mean() 347 | temperature_loss = self.temperature_lagrange_penalty( 348 | entropy, 349 | grad_params=params, 350 | ) 351 | return temperature_loss, {"temperature_loss": temperature_loss} 352 | 353 | def loss_fns(self, batch): 354 | return { 355 | "critic": partial(self.critic_loss_fn, batch), 356 | "actor": partial(self.policy_loss_fn, batch), 357 | "temperature": partial(self.temperature_loss_fn, batch), 358 | } 359 | 360 | @partial(jax.jit, static_argnames=("pmap_axis", "networks_to_update")) 361 | def update( 362 | self, 363 | batch: Batch, 364 | *, 365 | pmap_axis: str = None, 366 | networks_to_update: frozenset[str] = frozenset( 367 | {"actor", "critic", "temperature"} 368 | ), 369 | ) -> Tuple["SACAgent", dict]: 370 | """ 371 | Take one gradient step on all (or a subset) of the networks in the agent. 372 | 373 | Parameters: 374 | batch: Batch of data to use for the update. Should have keys: 375 | "observations", "actions", "next_observations", "rewards", "masks". 376 | pmap_axis: Axis to use for pmap (if None, no pmap is used). 377 | networks_to_update: Names of networks to update (default: all networks). 378 | For example, in high-UTD settings it's common to update the critic 379 | many times and only update the actor (and other networks) once. 380 | Returns: 381 | Tuple of (new agent, info dict). 
382 | """ 383 | batch_size = batch["rewards"].shape[0] 384 | chex.assert_tree_shape_prefix(batch, (batch_size,)) 385 | 386 | rng, key = jax.random.split(self.state.rng) 387 | 388 | # Compute gradients and update params 389 | loss_fns = self.loss_fns(batch) 390 | 391 | # Only compute gradients for specified steps 392 | assert networks_to_update.issubset( 393 | loss_fns.keys() 394 | ), f"Invalid gradient steps: {networks_to_update}" 395 | for key in loss_fns.keys() - networks_to_update: 396 | loss_fns[key] = lambda params, rng: (0.0, {}) 397 | 398 | new_state, info = self.state.apply_loss_fns( 399 | loss_fns, pmap_axis=pmap_axis, has_aux=True 400 | ) 401 | 402 | # Update target network (if requested) 403 | if "critic" in networks_to_update: 404 | new_state = new_state.target_update(self.config["soft_target_update_rate"]) 405 | 406 | # Update RNG 407 | new_state = new_state.replace(rng=rng) 408 | 409 | # Log learning rates 410 | for name, opt_state in new_state.opt_states.items(): 411 | if ( 412 | hasattr(opt_state, "hyperparams") 413 | and "learning_rate" in opt_state.hyperparams.keys() 414 | ): 415 | info[f"{name}_lr"] = opt_state.hyperparams["learning_rate"] 416 | 417 | return self.replace(state=new_state), info 418 | 419 | @partial(jax.jit, static_argnames=("argmax",)) 420 | def sample_actions( 421 | self, 422 | observations: Data, 423 | *, 424 | seed: Optional[PRNGKey] = None, 425 | argmax: bool = False, 426 | **kwargs, 427 | ) -> jnp.ndarray: 428 | """ 429 | Sample actions from the policy network, **using an external RNG** (or approximating the argmax by the mode). 430 | The internal RNG will not be updated. 431 | """ 432 | dist = self.forward_policy(observations, rng=seed, train=False) 433 | if argmax: 434 | assert seed is None, "Cannot specify seed when sampling deterministically" 435 | return dist.mode() 436 | else: 437 | return dist.sample(seed=seed) 438 | 439 | @jax.jit 440 | def get_debug_metrics(self, batch, **kwargs): 441 | rng, critic_rng, actor_rng = jax.random.split(self.state.rng, 3) 442 | critic_loss, critic_info = self.critic_loss_fn( 443 | batch, self.state.params, critic_rng 444 | ) 445 | policy_loss, policy_info = self.policy_loss_fn( 446 | batch, self.state.params, actor_rng 447 | ) 448 | 449 | metrics = {**critic_info, **policy_info} 450 | 451 | return metrics 452 | 453 | def update_config(self, new_config): 454 | """update the frozen self.config""" 455 | object.__setattr__(self, "config", self.config.copy(new_config)) 456 | 457 | @classmethod 458 | def _create_common( 459 | cls, 460 | rng: PRNGKey, 461 | observations: Data, 462 | actions: jnp.ndarray, 463 | # Models 464 | actor_def: nn.Module, 465 | critic_def: nn.Module, 466 | temperature_def: nn.Module, 467 | # Optimizer 468 | actor_optimizer_kwargs={ 469 | "learning_rate": 3e-4, 470 | }, 471 | critic_optimizer_kwargs={ 472 | "learning_rate": 3e-4, 473 | }, 474 | temperature_optimizer_kwargs={ 475 | "learning_rate": 3e-4, 476 | }, 477 | # Algorithm config 478 | discount: float = 0.99, 479 | n_actions: int = 10, 480 | max_target_backup: bool = False, 481 | soft_target_update_rate: float = 0.005, 482 | target_entropy: Optional[float] = None, 483 | backup_entropy: bool = False, 484 | critic_ensemble_size: int = 2, 485 | critic_subsample_size: Optional[int] = None, 486 | # bc loss: 487 | bc_loss_weight: float = 0.0, 488 | **kwargs, 489 | ): 490 | """common part of both create() methods. 
491 | this helper is not meant to be called directly; use the create() classmethod below.""" 492 | networks = { 493 | "actor": actor_def, 494 | "critic": critic_def, 495 | "temperature": temperature_def, 496 | } 497 | 498 | model_def = ModuleDict(networks) 499 | 500 | # Define optimizers 501 | txs = { 502 | "actor": make_optimizer(**actor_optimizer_kwargs), 503 | "critic": make_optimizer(**critic_optimizer_kwargs), 504 | "temperature": make_optimizer(**temperature_optimizer_kwargs), 505 | } 506 | 507 | rng, init_rng = jax.random.split(rng) 508 | network_input = observations 509 | params = model_def.init( 510 | init_rng, 511 | actor=[network_input], 512 | critic=[network_input, actions], 513 | temperature=[], 514 | )["params"] 515 | 516 | rng, create_rng = jax.random.split(rng) 517 | state = JaxRLTrainState.create( 518 | apply_fn=model_def.apply, 519 | params=params, 520 | txs=txs, 521 | target_params=params, 522 | rng=create_rng, 523 | ) 524 | 525 | # Config 526 | if target_entropy is None or target_entropy >= 0.0: 527 | target_entropy = -actions.shape[-1] 528 | 529 | return cls( 530 | state=state, 531 | config=dict( 532 | critic_ensemble_size=critic_ensemble_size, 533 | critic_subsample_size=critic_subsample_size, 534 | discount=discount, 535 | soft_target_update_rate=soft_target_update_rate, 536 | target_entropy=target_entropy, 537 | backup_entropy=backup_entropy, 538 | bc_loss_weight=bc_loss_weight, 539 | n_actions=n_actions, 540 | max_target_backup=max_target_backup, 541 | **kwargs, 542 | ), 543 | ) 544 | 545 | @classmethod 546 | def create( 547 | cls, 548 | rng: PRNGKey, 549 | observations: Data, 550 | actions: jnp.ndarray, 551 | # Model architecture 552 | encoder_def: nn.Module, 553 | shared_encoder: bool = True, 554 | critic_network_kwargs: dict = { 555 | "hidden_dims": [256, 256], 556 | }, 557 | policy_network_kwargs: dict = { 558 | "hidden_dims": [256, 256], 559 | }, 560 | policy_kwargs: dict = { 561 | "tanh_squash_distribution": True, 562 | "std_parameterization": "exp", 563 | }, 564 | critic_ensemble_size: int = 2, 565 | critic_subsample_size: Optional[int] = None, 566 | temperature_init: float = 1.0, 567 | **kwargs, 568 | ): 569 | """ 570 | Create a new agent with an MLP actor and an ensembled MLP critic built on 571 | top of the given encoder_def. The encoder is shared between actor and critic 572 | when shared_encoder=True, and deep-copied otherwise.
573 | """ 574 | 575 | if shared_encoder: 576 | encoders = { 577 | "actor": encoder_def, 578 | "critic": encoder_def, 579 | } 580 | else: 581 | encoders = { 582 | "actor": encoder_def, 583 | "critic": copy.deepcopy(encoder_def), 584 | } 585 | 586 | # Define networks 587 | policy_def = Policy( 588 | encoder=encoders["actor"], 589 | network=MLP(**policy_network_kwargs), 590 | action_dim=actions.shape[-1], 591 | **policy_kwargs, 592 | name="actor", 593 | ) 594 | 595 | critic_backbone = partial(MLP, **critic_network_kwargs) 596 | critic_backbone = ensemblize(critic_backbone, critic_ensemble_size)( 597 | name="critic_ensemble" 598 | ) 599 | critic_def = partial( 600 | Critic, 601 | encoder=encoders["critic"], 602 | network=critic_backbone, 603 | )(name="critic") 604 | 605 | temperature_def = GeqLagrangeMultiplier( 606 | init_value=temperature_init, 607 | constraint_shape=(), 608 | constraint_type="geq", 609 | name="temperature", 610 | ) 611 | 612 | return cls._create_common( 613 | rng, 614 | observations, 615 | actions, 616 | actor_def=policy_def, 617 | critic_def=critic_def, 618 | temperature_def=temperature_def, 619 | critic_ensemble_size=critic_ensemble_size, 620 | critic_subsample_size=critic_subsample_size, 621 | **kwargs, 622 | ) 623 | 624 | @partial(jax.jit, static_argnames=("utd_ratio", "pmap_axis")) 625 | def update_high_utd( 626 | self, 627 | batch: Batch, 628 | *, 629 | utd_ratio: int, 630 | pmap_axis: Optional[str] = None, 631 | ) -> Tuple["SACAgent", dict]: 632 | """ 633 | Fast JITted high-UTD version of `.update`. 634 | 635 | Splits the batch into minibatches, performs `utd_ratio` critic 636 | (and target) updates, and then one actor/temperature update. 637 | 638 | Batch dimension must be divisible by `utd_ratio`. 639 | """ 640 | batch_size = batch["rewards"].shape[0] 641 | assert ( 642 | batch_size % utd_ratio == 0 643 | ), f"Batch size {batch_size} must be divisible by UTD ratio {utd_ratio}" 644 | minibatch_size = batch_size // utd_ratio 645 | chex.assert_tree_shape_prefix(batch, (batch_size,)) 646 | 647 | def scan_body(carry: Tuple[SACAgent], data: Tuple[Batch]): 648 | (agent,) = carry 649 | (minibatch,) = data 650 | agent, info = agent.update( 651 | minibatch, 652 | pmap_axis=pmap_axis, 653 | networks_to_update=frozenset({"critic"}), 654 | ) 655 | return (agent,), info 656 | 657 | def make_minibatch(data: jnp.ndarray): 658 | return jnp.reshape(data, (utd_ratio, minibatch_size) + data.shape[1:]) 659 | 660 | minibatches = jax.tree_map(make_minibatch, batch) 661 | 662 | (agent,), critic_infos = jax.lax.scan(scan_body, (self,), (minibatches,)) 663 | 664 | critic_infos = jax.tree_map(lambda x: jnp.mean(x, axis=0), critic_infos) 665 | del critic_infos["actor"] 666 | del critic_infos["temperature"] 667 | 668 | # Take one gradient descent step on the actor and temperature 669 | agent, actor_temp_infos = agent.update( 670 | batch, 671 | pmap_axis=pmap_axis, 672 | networks_to_update=frozenset({"actor", "temperature"}), 673 | ) 674 | del actor_temp_infos["critic"] 675 | 676 | infos = {**critic_infos, **actor_temp_infos} 677 | 678 | return agent, infos 679 | -------------------------------------------------------------------------------- /wsrl/common/common.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Callable, Dict, Mapping, Sequence, Tuple, Union 3 | 4 | import flax 5 | import flax.linen as nn 6 | import jax 7 | import jax.numpy as jnp 8 | import optax 9 | from flax import struct 10 | 11 | from 
wsrl.common.typing import Params, PRNGKey 12 | 13 | nonpytree_field = functools.partial(flax.struct.field, pytree_node=False) 14 | 15 | 16 | def shard_batch(batch, sharding): 17 | """Shards a batch across devices along its first dimension. 18 | 19 | Args: 20 | batch: A pytree of arrays. 21 | sharding: A jax Sharding object with shape (num_devices,). 22 | """ 23 | return jax.tree_map( 24 | lambda x: jax.device_put( 25 | x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1))) 26 | ), 27 | batch, 28 | ) 29 | 30 | 31 | class ModuleDict(nn.Module): 32 | """ 33 | Utility class for wrapping a dictionary of modules. This is useful when you have multiple modules that you want to 34 | initialize all at once (creating a single `params` dictionary), but you want to be able to call them separately 35 | later. As a bonus, the modules may have sub-modules nested inside them that share parameters (e.g. an image encoder) 36 | and Flax will automatically handle this without duplicating the parameters. 37 | 38 | To initialize the modules, call `init` with no `name` kwarg, and then pass the example arguments to each module as 39 | additional kwargs. To call the modules, pass the name of the module as the `name` kwarg, and then pass the arguments 40 | to the module as additional args or kwargs. 41 | 42 | Example usage: 43 | ``` 44 | shared_encoder = Encoder() 45 | actor = Actor(encoder=shared_encoder) 46 | critic = Critic(encoder=shared_encoder) 47 | 48 | model_def = ModuleDict({"actor": actor, "critic": critic}) 49 | params = model_def.init(rng_key, actor=example_obs, critic=(example_obs, example_action)) 50 | 51 | actor_output = model_def.apply({"params": params}, example_obs, name="actor") 52 | critic_output = model_def.apply({"params": params}, example_obs, action=example_action, name="critic") 53 | ``` 54 | """ 55 | 56 | modules: Dict[str, nn.Module] 57 | 58 | @nn.compact 59 | def __call__(self, *args, name=None, **kwargs): 60 | if name is None: 61 | if kwargs.keys() != self.modules.keys(): 62 | raise ValueError( 63 | f"When `name` is not specified, kwargs must contain the arguments for each module. " 64 | f"Got kwargs keys {kwargs.keys()} but module keys {self.modules.keys()}" 65 | ) 66 | out = {} 67 | for key, value in kwargs.items(): 68 | if isinstance(value, Mapping): 69 | out[key] = self.modules[key](**value) 70 | elif isinstance(value, Sequence): 71 | out[key] = self.modules[key](*value) 72 | else: 73 | out[key] = self.modules[key](value) 74 | return out 75 | 76 | return self.modules[name](*args, **kwargs) 77 | 78 | 79 | class JaxRLTrainState(struct.PyTreeNode): 80 | """ 81 | Custom TrainState class to replace `flax.training.train_state.TrainState`. 82 | 83 | Adds support for holding target params and updating them via polyak 84 | averaging. Adds the ability to hold an rng key for dropout. 85 | 86 | Also generalizes the TrainState to support an arbitrary pytree of 87 | optimizers, `txs`. When `apply_gradients()` is called, the `grads` argument 88 | must have `txs` as a prefix. This is backwards-compatible, meaning `txs` can 89 | be a single optimizer and `grads` can be a single tree with the same 90 | structure as `self.params`. 91 | 92 | Also adds a convenience method `apply_loss_fns` that takes a pytree of loss 93 | functions with the same structure as `txs`, computes gradients, and applies 94 | them using `apply_gradients`. 95 | 96 | Attributes: 97 | step: The current training step. 98 | apply_fn: The function used to apply the model. 99 | params: The model parameters. 
100 | target_params: The target model parameters. 101 | txs: The optimizer or pytree of optimizers. 102 | opt_states: The optimizer state or pytree of optimizer states. 103 | rng: The internal rng state. 104 | """ 105 | 106 | step: int 107 | apply_fn: Callable = struct.field(pytree_node=False) 108 | params: Params 109 | target_params: Params 110 | txs: Any = struct.field(pytree_node=False) 111 | opt_states: Any 112 | rng: PRNGKey 113 | 114 | @staticmethod 115 | def _tx_tree_map(*args, **kwargs): 116 | return jax.tree_map( 117 | *args, 118 | is_leaf=lambda x: isinstance(x, optax.GradientTransformation), 119 | **kwargs, 120 | ) 121 | 122 | def target_update(self, tau: float) -> "JaxRLTrainState": 123 | """ 124 | Performs an update of the target params via polyak averaging. The new 125 | target params are given by: 126 | 127 | new_target_params = tau * params + (1 - tau) * target_params 128 | """ 129 | new_target_params = jax.tree_map( 130 | lambda p, tp: p * tau + tp * (1 - tau), self.params, self.target_params 131 | ) 132 | return self.replace(target_params=new_target_params) 133 | 134 | def apply_gradients(self, *, grads: Any) -> Tuple["JaxRLTrainState", Any]: 135 | """ 136 | Only difference from flax's TrainState is that `grads` must have 137 | `self.txs` as a tree prefix (i.e. where `self.txs` has a leaf, `grads` 138 | has a subtree with the same structure as `self.params`.) 139 | """ 140 | updates_and_new_states = self._tx_tree_map( 141 | lambda tx, opt_state, grad: tx.update(grad, opt_state, self.params), 142 | self.txs, 143 | self.opt_states, 144 | grads, 145 | ) 146 | updates = self._tx_tree_map(lambda _, x: x[0], self.txs, updates_and_new_states) 147 | new_opt_states = self._tx_tree_map( 148 | lambda _, x: x[1], self.txs, updates_and_new_states 149 | ) 150 | 151 | # not the cleanest, I know, but this flattens the leaves of `updates` 152 | # into a list where leaves are defined by `self.txs` 153 | updates_flat = [] 154 | self._tx_tree_map( 155 | lambda _, update: updates_flat.append(update), self.txs, updates 156 | ) 157 | 158 | # apply all the updates additively 159 | updates_acc = jax.tree_map( 160 | lambda *xs: jnp.sum(jnp.array(xs), axis=0), *updates_flat 161 | ) 162 | new_params = optax.apply_updates(self.params, updates_acc) 163 | 164 | return ( 165 | self.replace( 166 | step=self.step + 1, params=new_params, opt_states=new_opt_states 167 | ), 168 | updates, 169 | ) 170 | 171 | def apply_loss_fns( 172 | self, loss_fns: Any, pmap_axis: str = None, has_aux: bool = False 173 | ) -> Union["JaxRLTrainState", Tuple["JaxRLTrainState", Any]]: 174 | """ 175 | Convenience method to compute gradients based on `self.params` and apply 176 | them using `apply_gradients`. `loss_fns` must have the same structure as 177 | `txs`, and each leaf must be a function that takes two arguments: 178 | `params` and `rng`. 179 | 180 | This method automatically provides fresh rng to each loss function and 181 | updates this train state's internal rng key. 182 | 183 | Args: 184 | loss_fns: loss function or pytree of loss functions with same 185 | structure as `self.txs`. Each loss function must take `params` 186 | as the first argument and `rng` as the second argument, and return 187 | a scalar value. 188 | pmap_axis: if not None, gradients (and optionally auxiliary values) 189 | will be averaged over this axis 190 | has_aux: if True, each `loss_fn` returns a tuple of (loss, aux) where 191 | `aux` is a pytree of auxiliary values to be returned by this 192 | method. 
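        Illustrative usage sketch (the loss-function names are placeholders,
        not from this file; the dict keys must mirror the structure of
        `self.txs`):

            new_state, aux = state.apply_loss_fns(
                {"actor": actor_loss_fn, "critic": critic_loss_fn},
                pmap_axis=None,
                has_aux=True,
            )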
193 | 194 | Returns: 195 | If `has_aux` is True, returns a tuple of (new_train_state, aux). 196 | Otherwise, returns the new train state. 197 | """ 198 | # create a pytree of rngs with the same structure as `loss_fns` 199 | treedef = jax.tree_util.tree_structure(loss_fns) 200 | new_rng, *rngs = jax.random.split(self.rng, treedef.num_leaves + 1) 201 | rngs = jax.tree_util.tree_unflatten(treedef, rngs) 202 | 203 | # compute gradients 204 | grads_and_aux = jax.tree_map( 205 | lambda loss_fn, rng: jax.grad(loss_fn, has_aux=has_aux)(self.params, rng), 206 | loss_fns, 207 | rngs, 208 | ) 209 | 210 | # update rng state 211 | self = self.replace(rng=new_rng) 212 | 213 | # average across devices if necessary 214 | if pmap_axis is not None: 215 | grads_and_aux = jax.lax.pmean(grads_and_aux, axis_name=pmap_axis) 216 | 217 | # apply gradients 218 | if has_aux: 219 | grads = jax.tree_map(lambda _, x: x[0], loss_fns, grads_and_aux) 220 | aux = jax.tree_map(lambda _, x: x[1], loss_fns, grads_and_aux) 221 | new_train_state, updates = self.apply_gradients(grads=grads) 222 | # log the norm values 223 | grad_norm = optax.global_norm(grads) 224 | param_norm = optax.global_norm(self.params) 225 | update_norm = optax.global_norm(updates) 226 | aux.update( 227 | { 228 | "grad_norm": grad_norm, 229 | "param_norm": param_norm, 230 | "update_norm": update_norm, 231 | } 232 | ) 233 | return new_train_state, aux 234 | else: 235 | return self.apply_gradients(grads=grads_and_aux)[0] 236 | 237 | @classmethod 238 | def create( 239 | cls, *, apply_fn, params, txs, target_params=None, rng=jax.random.PRNGKey(0) 240 | ): 241 | """ 242 | Initializes a new train state. 243 | 244 | Args: 245 | apply_fn: The function used to apply the model, typically `model_def.apply`. 246 | params: The model parameters, typically from `model_def.init`. 247 | txs: The optimizer or pytree of optimizers. 248 | target_params: The target model parameters. 249 | rng: The rng key used to initialize the rng chain for `apply_loss_fns`. 
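        Illustrative usage sketch (the optimizer choices are placeholders,
        not defaults from this file):

            state = JaxRLTrainState.create(
                apply_fn=model_def.apply,
                params=params,
                txs={"actor": optax.adam(3e-4), "critic": optax.adam(3e-4)},
                target_params=params,
            )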
250 | """ 251 | return cls( 252 | step=0, 253 | apply_fn=apply_fn, 254 | params=params, 255 | target_params=target_params, 256 | txs=txs, 257 | opt_states=cls._tx_tree_map(lambda tx: tx.init(params), txs), 258 | rng=rng, 259 | ) 260 | -------------------------------------------------------------------------------- /wsrl/common/evaluation.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict 3 | 4 | import gym 5 | import numpy as np 6 | 7 | 8 | def flatten(d, parent_key="", sep="."): 9 | items = [] 10 | for k, v in d.items(): 11 | new_key = parent_key + sep + k if parent_key else k 12 | if hasattr(v, "items"): 13 | items.extend(flatten(v, new_key, sep=sep).items()) 14 | else: 15 | items.append((new_key, v)) 16 | return dict(items) 17 | 18 | 19 | def add_to(dict_of_lists, single_dict): 20 | for k, v in single_dict.items(): 21 | dict_of_lists[k].append(v) 22 | 23 | 24 | def evaluate( 25 | policy_fn, env: gym.Env, num_episodes: int, clip_action: float = np.inf 26 | ) -> Dict[str, float]: 27 | stats = defaultdict(list) 28 | for _ in range(num_episodes): 29 | observation, info = env.reset() 30 | add_to(stats, flatten(info)) 31 | done = False 32 | while not done: 33 | action = policy_fn(observation) 34 | action = np.clip(action, -clip_action, clip_action) 35 | observation, _, terminated, truncated, info = env.step(action) 36 | done = terminated or truncated 37 | add_to(stats, flatten(info)) 38 | add_to(stats, flatten(info, parent_key="final")) 39 | 40 | for k, v in stats.items(): 41 | stats[k] = np.mean(v) 42 | return stats 43 | 44 | 45 | def evaluate_with_trajectories( 46 | policy_fn, env: gym.Env, num_episodes: int, clip_action: float = np.inf 47 | ) -> Dict[str, float]: 48 | trajectories = [] 49 | stats = defaultdict(list) 50 | 51 | for _ in range(num_episodes): 52 | trajectory = defaultdict(list) 53 | observation, info = env.reset() 54 | add_to(stats, flatten(info)) 55 | done = False 56 | while not done: 57 | action = policy_fn(observation) 58 | action = np.clip(action, -clip_action, clip_action) 59 | next_observation, r, terminated, truncated, info = env.step(action) 60 | done = terminated or truncated 61 | transition = dict( 62 | observations=observation, 63 | next_observations=next_observation, 64 | actions=action, 65 | rewards=r, 66 | dones=done, 67 | infos=info, 68 | masks=1 - terminated, 69 | ) 70 | add_to(trajectory, transition) 71 | add_to(stats, flatten(info)) 72 | observation = next_observation 73 | add_to(stats, flatten(info, parent_key="final")) 74 | trajectories.append(trajectory) 75 | 76 | for k, v in stats.items(): 77 | stats[k] = np.mean(v) 78 | return stats, trajectories 79 | -------------------------------------------------------------------------------- /wsrl/common/initialization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | 7 | def var_scaling_init(scale: Optional[float] = 1.0): 8 | return nn.initializers.variance_scaling(scale, "fan_avg", "uniform") 9 | 10 | 11 | def orthogonal_init(scale: Optional[float] = jnp.sqrt(2.0)): 12 | return nn.initializers.orthogonal(scale) 13 | 14 | 15 | def xavier_normal_init(): 16 | return nn.initializers.xavier_normal() 17 | 18 | 19 | def kaiming_init(): 20 | return nn.initializers.kaiming_normal() 21 | 22 | 23 | def xavier_uniform_init(): 24 | return nn.initializers.xavier_uniform() 25 | 26 | 27 | 
init_fns = { 28 | None: orthogonal_init, 29 | "var_scaling": var_scaling_init, 30 | "orthogonal": orthogonal_init, 31 | "xavier_normal": xavier_normal_init, 32 | "kaiming": kaiming_init, 33 | "xavier_uniform": xavier_uniform_init, 34 | } 35 | -------------------------------------------------------------------------------- /wsrl/common/optimizers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import optax 4 | 5 | 6 | def make_optimizer( 7 | learning_rate: float = 3e-4, 8 | warmup_steps: int = 0, 9 | cosine_decay_steps: Optional[int] = None, 10 | weight_decay: Optional[float] = None, 11 | clip_grad_norm: Optional[float] = None, 12 | return_lr_schedule: bool = False, 13 | ) -> optax.GradientTransformation: 14 | if cosine_decay_steps is not None: 15 | learning_rate_schedule = optax.warmup_cosine_decay_schedule( 16 | init_value=0.0, 17 | peak_value=learning_rate, 18 | warmup_steps=warmup_steps, 19 | decay_steps=cosine_decay_steps, 20 | end_value=0.0, 21 | ) 22 | else: 23 | learning_rate_schedule = optax.join_schedules( 24 | [ 25 | optax.linear_schedule(0.0, learning_rate, warmup_steps), 26 | optax.constant_schedule(learning_rate), 27 | ], 28 | [warmup_steps], 29 | ) 30 | 31 | # Define optimizers 32 | @optax.inject_hyperparams 33 | def optimizer(learning_rate: float, weight_decay: Optional[float]): 34 | optimizer_stages = [] 35 | 36 | if clip_grad_norm is not None: 37 | optimizer_stages.append(optax.clip_by_global_norm(clip_grad_norm)) 38 | 39 | if weight_decay is not None: 40 | optimizer_stages.append( 41 | optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay) 42 | ) 43 | else: 44 | optimizer_stages.append(optax.adam(learning_rate=learning_rate)) 45 | 46 | return optax.chain(*optimizer_stages) 47 | 48 | if return_lr_schedule: 49 | return ( 50 | optimizer(learning_rate=learning_rate_schedule, weight_decay=weight_decay), 51 | learning_rate_schedule, 52 | ) 53 | else: 54 | return optimizer( 55 | learning_rate=learning_rate_schedule, weight_decay=weight_decay 56 | ) 57 | -------------------------------------------------------------------------------- /wsrl/common/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Sequence, Union 2 | 3 | import flax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | PRNGKey = Any 8 | Params = flax.core.FrozenDict[str, Any] 9 | Shape = Sequence[int] 10 | Dtype = Any # this could be a real type? 
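# Illustrative note (not from the original file): a `Batch` is typically a flat
# dict of arrays such as
#   {"observations": ..., "actions": ..., "rewards": ..., "masks": ..., "dones": ...},
# and observations may themselves be nested dicts, which is why `Data` is
# defined recursively below.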
11 | InfoDict = Dict[str, float] 12 | Array = Union[np.ndarray, jnp.ndarray] 13 | Data = Union[Array, Dict[str, "Data"]] 14 | Batch = Dict[str, Data] 15 | # A method to be passed into TrainState.__call__ 16 | ModuleMethod = Union[str, Callable, None] 17 | -------------------------------------------------------------------------------- /wsrl/common/wandb.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import random 3 | import string 4 | import tempfile 5 | from copy import copy 6 | from socket import gethostname 7 | 8 | import absl.flags as flags 9 | import ml_collections 10 | import wandb 11 | 12 | 13 | def _recursive_flatten_dict(d: dict): 14 | keys, values = [], [] 15 | for key, value in d.items(): 16 | if isinstance(value, dict): 17 | sub_keys, sub_values = _recursive_flatten_dict(value) 18 | keys += [f"{key}/{k}" for k in sub_keys] 19 | values += sub_values 20 | else: 21 | keys.append(key) 22 | values.append(value) 23 | return keys, values 24 | 25 | 26 | def generate_random_string(length=6): 27 | # Define the character set for the random string 28 | characters = string.digits # Use digits 0-9 29 | 30 | # Generate the random string by sampling from the character set 31 | random_string = "".join(random.choices(characters, k=length)) 32 | 33 | return "rnd" + random_string 34 | 35 | 36 | # Generate a 6-digit random string 37 | random_string = generate_random_string() 38 | print(random_string) 39 | 40 | 41 | class WandBLogger(object): 42 | @staticmethod 43 | def get_default_config(): 44 | config = ml_collections.ConfigDict() 45 | config.project = "wsrl" # WandB Project Name 46 | config.entity = ml_collections.config_dict.FieldReference(None, field_type=str) 47 | # Which entity to log as (default: your own user) 48 | config.exp_descriptor = "" # Run name (doesn't have to be unique) 49 | # Unique identifier for run (will be automatically generated unless 50 | # provided) 51 | config.unique_identifier = "" 52 | config.group = None 53 | return config 54 | 55 | def __init__( 56 | self, 57 | wandb_config, 58 | variant, 59 | random_str_in_identifier=False, 60 | wandb_output_dir=None, 61 | disable_online_logging=False, 62 | ): 63 | self.config = wandb_config 64 | if self.config.unique_identifier == "": 65 | self.config.unique_identifier = datetime.datetime.now().strftime( 66 | "%Y%m%d_%H%M%S" 67 | ) 68 | if random_str_in_identifier: 69 | self.config.unique_identifier += "_" + generate_random_string() 70 | 71 | self.config.experiment_id = ( 72 | self.experiment_id 73 | ) = f"{self.config.exp_descriptor}_{self.config.unique_identifier}" # NOQA 74 | 75 | print(self.config) 76 | 77 | if wandb_output_dir is None: 78 | wandb_output_dir = tempfile.mkdtemp() 79 | 80 | self._variant = copy(variant) 81 | 82 | if "hostname" not in self._variant: 83 | self._variant["hostname"] = gethostname() 84 | 85 | if disable_online_logging: 86 | mode = "disabled" 87 | else: 88 | mode = "online" 89 | 90 | self.run = wandb.init( 91 | config=self._variant, 92 | project=self.config.project, 93 | entity=self.config.entity, 94 | group=self.config.group, 95 | dir=wandb_output_dir, 96 | id=self.config.experiment_id, 97 | save_code=True, 98 | mode=mode, 99 | ) 100 | 101 | if flags.FLAGS.is_parsed(): 102 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS} 103 | else: 104 | flag_dict = {} 105 | for k in flag_dict: 106 | if isinstance(flag_dict[k], ml_collections.ConfigDict): 107 | flag_dict[k] = flag_dict[k].to_dict() 108 | wandb.config.update(flag_dict) 109 | 
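    # Illustrative usage sketch (config values are placeholders, not from this
    # file):
    #
    #   logger = WandBLogger(WandBLogger.get_default_config(), variant={"seed": 0})
    #   logger.log({"train": {"critic_loss": 0.5}}, step=100)
    #
    # `log` flattens nested dicts into "train/critic_loss"-style keys before
    # calling wandb.log.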
110 | def log(self, data: dict, step: int = None): 111 | data_flat = _recursive_flatten_dict(data) 112 | data = {k: v for k, v in zip(*data_flat)} 113 | wandb.log(data, step=step) 114 | -------------------------------------------------------------------------------- /wsrl/data/dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Dict, Iterable, Optional, Tuple, Union 3 | 4 | import jax 5 | import numpy as np 6 | from flax.core import frozen_dict 7 | from gym.utils import seeding 8 | 9 | from wsrl.common.typing import Data 10 | 11 | DatasetDict = Dict[str, Data] 12 | 13 | 14 | def _check_lengths(dataset_dict: DatasetDict, dataset_len: Optional[int] = None) -> int: 15 | for v in dataset_dict.values(): 16 | if isinstance(v, dict): 17 | dataset_len = dataset_len or _check_lengths(v, dataset_len) 18 | elif isinstance(v, np.ndarray): 19 | item_len = len(v) 20 | dataset_len = dataset_len or item_len 21 | assert dataset_len == item_len, "Inconsistent item lengths in the dataset." 22 | else: 23 | raise TypeError("Unsupported type.") 24 | return dataset_len 25 | 26 | 27 | def _subselect(dataset_dict: DatasetDict, index: np.ndarray) -> DatasetDict: 28 | new_dataset_dict = {} 29 | for k, v in dataset_dict.items(): 30 | if isinstance(v, dict): 31 | new_v = _subselect(v, index) 32 | elif isinstance(v, np.ndarray): 33 | new_v = v[index] 34 | else: 35 | raise TypeError("Unsupported type.") 36 | new_dataset_dict[k] = new_v 37 | return new_dataset_dict 38 | 39 | 40 | def _sample( 41 | dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray 42 | ) -> DatasetDict: 43 | if isinstance(dataset_dict, np.ndarray): 44 | return dataset_dict[indx] 45 | elif isinstance(dataset_dict, dict): 46 | batch = {} 47 | for k, v in dataset_dict.items(): 48 | batch[k] = _sample(v, indx) 49 | else: 50 | raise TypeError("Unsupported type.") 51 | return batch 52 | 53 | 54 | class Dataset(object): 55 | def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): 56 | self.dataset_dict = dataset_dict 57 | self.dataset_len = _check_lengths(dataset_dict) 58 | 59 | # Seeding similar to OpenAI Gym 60 | self._np_random = None 61 | if seed is not None: 62 | self.seed(seed) 63 | 64 | @property 65 | def np_random(self) -> np.random.RandomState: 66 | if self._np_random is None: 67 | self.seed() 68 | return self._np_random 69 | 70 | def seed(self, seed: Optional[int] = None) -> list: 71 | self._np_random, seed = seeding.np_random(seed) 72 | return [seed] 73 | 74 | def __len__(self) -> int: 75 | return self.dataset_len 76 | 77 | def sample( 78 | self, 79 | batch_size: int, 80 | keys: Optional[Iterable[str]] = None, 81 | indx: Optional[np.ndarray] = None, 82 | ) -> dict: 83 | if indx is None: 84 | indx = self.np_random.choice(len(self), size=batch_size, replace=True) 85 | 86 | batch = dict() 87 | 88 | if keys is None: 89 | keys = self.dataset_dict.keys() 90 | 91 | for k in keys: 92 | batch[k] = _sample(self.dataset_dict[k], indx) 93 | 94 | return batch 95 | 96 | def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: 97 | assert 0 < ratio < 1 98 | train_index = np.index_exp[: int(self.dataset_len * ratio)] 99 | test_index = np.index_exp[int(self.dataset_len * ratio) :] 100 | 101 | index = np.arange(len(self), dtype=np.int32) 102 | self.np_random.shuffle(index) 103 | train_index = index[: int(self.dataset_len * ratio)] 104 | test_index = index[int(self.dataset_len * ratio) :] 105 | 106 | train_dataset_dict = 
_subselect(self.dataset_dict, train_index) 107 | test_dataset_dict = _subselect(self.dataset_dict, test_index) 108 | return Dataset(train_dataset_dict), Dataset(test_dataset_dict) 109 | -------------------------------------------------------------------------------- /wsrl/data/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Iterable, Optional, Union 3 | 4 | import gym 5 | import gym.spaces 6 | import numpy as np 7 | from absl import flags 8 | 9 | from wsrl.data.dataset import Dataset, DatasetDict, _sample 10 | from wsrl.envs.env_common import calc_return_to_go 11 | 12 | 13 | def _init_replay_dict( 14 | obs_space: gym.Space, capacity: int 15 | ) -> Union[np.ndarray, DatasetDict]: 16 | if isinstance(obs_space, gym.spaces.Box): 17 | return np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype) 18 | elif isinstance(obs_space, gym.spaces.Dict): 19 | data_dict = {} 20 | for k, v in obs_space.spaces.items(): 21 | data_dict[k] = _init_replay_dict(v, capacity) 22 | return data_dict 23 | else: 24 | raise TypeError() 25 | 26 | 27 | def _insert_recursively( 28 | dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int 29 | ): 30 | if isinstance(dataset_dict, np.ndarray): 31 | dataset_dict[insert_index] = data_dict 32 | elif isinstance(dataset_dict, dict): 33 | assert ( 34 | dataset_dict.keys() == data_dict.keys() 35 | ), f"{dataset_dict.keys()} != {data_dict.keys()}" 36 | for k in dataset_dict.keys(): 37 | _insert_recursively(dataset_dict[k], data_dict[k], insert_index) 38 | else: 39 | raise TypeError() 40 | 41 | 42 | class ReplayBuffer(Dataset): 43 | def __init__( 44 | self, 45 | observation_space: gym.Space, 46 | action_space: gym.Space, 47 | capacity: int, 48 | next_observation_space: Optional[gym.Space] = None, 49 | seed: Optional[int] = None, 50 | discount: Optional[float] = None, 51 | ): 52 | if next_observation_space is None: 53 | next_observation_space = observation_space 54 | 55 | observation_data = _init_replay_dict(observation_space, capacity) 56 | next_observation_data = _init_replay_dict(next_observation_space, capacity) 57 | dataset_dict = dict( 58 | observations=observation_data, 59 | next_observations=next_observation_data, 60 | actions=np.empty((capacity, *action_space.shape), dtype=action_space.dtype), 61 | rewards=np.empty((capacity,), dtype=np.float32), 62 | masks=np.empty((capacity,), dtype=bool), 63 | dones=np.empty((capacity,), dtype=np.float32), 64 | ) 65 | 66 | super().__init__(dataset_dict, seed) 67 | 68 | self._size = 0 69 | self._capacity = capacity 70 | self._insert_index = 0 71 | self._sequential_index = 0 72 | self.unsampled_indices = list(range(self._size)) 73 | self._discount = discount 74 | 75 | def __len__(self) -> int: 76 | return self._size 77 | 78 | def insert(self, data_dict: DatasetDict): 79 | _insert_recursively(self.dataset_dict, data_dict, self._insert_index) 80 | 81 | self._insert_index = (self._insert_index + 1) % self._capacity 82 | self._size = min(self._size + 1, self._capacity) 83 | 84 | def sample_without_repeat( 85 | self, 86 | batch_size: int, 87 | keys: Optional[Iterable[str]] = None, 88 | ) -> dict: 89 | if keys is None: 90 | keys = self.dataset_dict.keys() 91 | 92 | batch = dict() 93 | if len(self.unsampled_indices) < batch_size: 94 | raise ValueError("Not enough samples left to sample without repeat.") 95 | selected_indices = [] 96 | for _ in range(batch_size): 97 | idx = self.np_random.randint(len(self.unsampled_indices)) 98 | 
selected_indices.append(self.unsampled_indices[idx]) 99 | # Swap the selected index with the last unselected index 100 | self.unsampled_indices[idx], self.unsampled_indices[-1] = ( 101 | self.unsampled_indices[-1], 102 | self.unsampled_indices[idx], 103 | ) 104 | # Remove the last unselected index (which is now the selected index) 105 | self.unsampled_indices.pop() 106 | 107 | for k in keys: 108 | batch[k] = _sample(self.dataset_dict[k], np.array(selected_indices)) 109 | 110 | return batch 111 | 112 | def save(self, save_dir): 113 | save_buffer_file = os.path.join(save_dir, "online_buffer.npy") 114 | save_size_file = os.path.join(save_dir, "size.npy") 115 | np.save(save_buffer_file, self.dataset_dict) 116 | np.save(save_size_file, self._size) 117 | 118 | def load(self, save_dir): 119 | # TODO: maybe make sure the dataset_dict thats being loaded has mc_returns if self is ReplayBufferMC 120 | save_buffer_file = os.path.join(save_dir, "online_buffer.npy") 121 | save_size_file = os.path.join(save_dir, "size.npy") 122 | self.dataset_dict = np.load(save_buffer_file, allow_pickle=True).item() 123 | self._size = np.load(save_size_file, allow_pickle=True).item() 124 | self.unsampled_indices = list(range(self._size)) 125 | 126 | 127 | class ReplayBufferMC(ReplayBuffer): 128 | def __init__( 129 | self, 130 | observation_space: gym.Space, 131 | action_space: gym.Space, 132 | capacity: int, 133 | next_observation_space: Optional[gym.Space] = None, 134 | seed: Optional[int] = None, 135 | discount: Optional[float] = None, 136 | ): 137 | assert discount is not None, "ReplayBufferMC requires a discount factor" 138 | super().__init__( 139 | observation_space, 140 | action_space, 141 | capacity, 142 | next_observation_space, 143 | seed, 144 | discount, 145 | ) 146 | 147 | mc_returns = np.empty((capacity,), dtype=np.float32) 148 | self.dataset_dict["mc_returns"] = mc_returns 149 | 150 | self._allow_idxs = [] 151 | self._traj_start_idx = 0 152 | 153 | def insert(self, data_dict: DatasetDict): 154 | # assumes replay buffer capacity is more than the number of online steps 155 | assert self._size < self._capacity, "replay buffer has reached capacity" 156 | 157 | data_dict["mc_returns"] = None 158 | _insert_recursively(self.dataset_dict, data_dict, self._insert_index) 159 | 160 | # if "dones" not in data_dict: 161 | # data_dict["dones"] = 1 - data_dict["masks"] 162 | 163 | if data_dict["dones"] == 1.0: 164 | # compute the mc_returns 165 | FLAGS = flags.FLAGS 166 | rewards = self.dataset_dict["rewards"][ 167 | self._traj_start_idx : self._insert_index + 1 168 | ] 169 | masks = self.dataset_dict["masks"][ 170 | self._traj_start_idx : self._insert_index + 1 171 | ] 172 | self.dataset_dict["mc_returns"][ 173 | self._traj_start_idx : self._insert_index + 1 174 | ] = calc_return_to_go( 175 | FLAGS.env, 176 | rewards, 177 | masks, 178 | self._discount, 179 | ) 180 | 181 | self._allow_idxs.extend( 182 | list(range(self._traj_start_idx, self._insert_index + 1)) 183 | ) 184 | self._traj_start_idx = self._insert_index + 1 185 | 186 | self._size += 1 187 | self._insert_index += 1 188 | 189 | def sample( 190 | self, 191 | batch_size: int, 192 | keys: Optional[Iterable[str]] = None, 193 | indx: Optional[np.ndarray] = None, 194 | ) -> dict: 195 | if indx is None: 196 | indx = self.np_random.choice( 197 | self._allow_idxs, size=batch_size, replace=True 198 | ) 199 | batch = dict() 200 | 201 | if keys is None: 202 | keys = self.dataset_dict.keys() 203 | 204 | for k in keys: 205 | batch[k] = _sample(self.dataset_dict[k], indx) 206 
| 207 | return batch 208 | -------------------------------------------------------------------------------- /wsrl/envs/adroit_binary_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | source: https://github.com/nakamotoo/Cal-QL/blob/ac6eafec22e8d60836573e1f488c7f626ce8a77e/JaxCQL/replay_buffer.py 3 | """ 4 | import os 5 | 6 | import numpy as np 7 | from absl import flags 8 | 9 | from wsrl.envs.env_common import calc_return_to_go 10 | 11 | DEMO_PATHS = os.environ.get("DATA_DIR_PREFIX", os.path.expanduser("~/adroit_data")) 12 | 13 | FLAGS = flags.FLAGS 14 | 15 | 16 | def get_hand_dataset_with_mc_calculation( 17 | env_name, 18 | gamma, 19 | add_expert_demos=True, 20 | add_bc_demos=True, 21 | reward_scale=1.0, 22 | reward_bias=0.0, 23 | pos_ind=-1, 24 | clip_action=None, 25 | ): 26 | assert env_name in [ 27 | "pen-binary-v0", 28 | "door-binary-v0", 29 | "relocate-binary-v0", 30 | "pen-binary", 31 | "door-binary", 32 | "relocate-binary", 33 | ] 34 | 35 | expert_demo_paths = { 36 | "pen-binary-v0": f"{DEMO_PATHS}/offpolicy_hand_data/pen2_sparse.npy", 37 | "door-binary-v0": f"{DEMO_PATHS}/offpolicy_hand_data/door2_sparse.npy", 38 | "relocate-binary-v0": f"{DEMO_PATHS}/offpolicy_hand_data/relocate2_sparse.npy", 39 | } 40 | 41 | bc_demo_paths = { 42 | "pen-binary-v0": f"{DEMO_PATHS}/offpolicy_hand_data/pen_bc_sparse4.npy", 43 | "door-binary-v0": f"{DEMO_PATHS}/offpolicy_hand_data/door_bc_sparse4.npy", 44 | "relocate-binary-v0": f"{DEMO_PATHS}/offpolicy_hand_data/relocate_bc_sparse4.npy", 45 | } 46 | 47 | def truncate_traj( 48 | env_name, 49 | dataset, 50 | i, 51 | gamma, 52 | start_index=None, 53 | end_index=None, 54 | ): 55 | """ 56 | This function truncates the i'th trajectory in dataset from start_index to end_index. 
57 | Since in Adroit-binary datasets, we have trajectories like [-1, -1, -1, -1, 0, 0, 0, -1, -1] which transit from neg -> pos -> neg, 58 | we truncate the trajcotry from the beginning to the last positive reward, i.e., [-1, -1, -1, -1, 0, 0, 0] 59 | """ 60 | observations = np.array(dataset[i]["observations"])[start_index:end_index] 61 | next_observations = np.array(dataset[i]["next_observations"])[ 62 | start_index:end_index 63 | ] 64 | rewards = dataset[i]["rewards"][start_index:end_index] 65 | dones = rewards == 0 # by default, adroit has -1/0 rewards 66 | actions = np.array(dataset[i]["actions"])[start_index:end_index] 67 | mc_returns = calc_return_to_go( 68 | env_name, 69 | rewards * FLAGS.reward_scale + FLAGS.reward_bias, 70 | 1 - dones, 71 | gamma, 72 | infinite_horizon=False, 73 | ) 74 | 75 | return dict( 76 | observations=observations, 77 | next_observations=next_observations, 78 | actions=actions, 79 | rewards=rewards, 80 | dones=dones, 81 | masks=1 - dones, 82 | mc_returns=mc_returns, 83 | ) 84 | 85 | dataset_list = [] 86 | dataset_bc_list = [] 87 | 88 | if add_expert_demos: 89 | print("loading expert demos from:", expert_demo_paths[env_name]) 90 | dataset = np.load(expert_demo_paths[env_name], allow_pickle=True) 91 | 92 | for i in range(len(dataset)): 93 | N = len(dataset[i]["observations"]) 94 | for j in range(len(dataset[i]["observations"])): 95 | dataset[i]["observations"][j] = dataset[i]["observations"][j][ 96 | "state_observation" 97 | ] 98 | dataset[i]["next_observations"][j] = dataset[i]["next_observations"][j][ 99 | "state_observation" 100 | ] 101 | if ( 102 | np.array(dataset[i]["rewards"]).shape 103 | != np.array(dataset[i]["terminals"]).shape 104 | ): 105 | dataset[i]["rewards"] = dataset[i]["rewards"][:N] 106 | 107 | assert ( 108 | np.array(dataset[i]["rewards"]).shape 109 | == np.array(dataset[i]["terminals"]).shape 110 | ) 111 | dataset[i].pop("terminals", None) 112 | 113 | if not (0 in dataset[i]["rewards"]): 114 | continue 115 | 116 | trunc_ind = np.where(dataset[i]["rewards"] == 0)[0][pos_ind] + 1 117 | d_pos = truncate_traj( 118 | env_name, 119 | dataset, 120 | i, 121 | gamma, 122 | start_index=None, 123 | end_index=trunc_ind, 124 | ) 125 | dataset_list.append(d_pos) 126 | 127 | if add_bc_demos: 128 | print("loading BC demos from:", bc_demo_paths[env_name]) 129 | dataset_bc = np.load(bc_demo_paths[env_name], allow_pickle=True) 130 | for i in range(len(dataset_bc)): 131 | dataset_bc[i]["rewards"] = dataset_bc[i]["rewards"].squeeze() 132 | dataset_bc[i]["dones"] = dataset_bc[i]["terminals"].squeeze() 133 | dataset_bc[i].pop("terminals", None) 134 | 135 | if not (0 in dataset_bc[i]["rewards"]): 136 | continue 137 | trunc_ind = np.where(dataset_bc[i]["rewards"] == 0)[0][pos_ind] + 1 138 | d_pos = truncate_traj( 139 | env_name, 140 | dataset_bc, 141 | i, 142 | gamma, 143 | start_index=None, 144 | end_index=trunc_ind, 145 | ) 146 | dataset_bc_list.append(d_pos) 147 | 148 | dataset = np.concatenate([dataset_list, dataset_bc_list]) 149 | 150 | print("num offline trajs:", len(dataset)) 151 | concatenated = {} 152 | for key in dataset[0].keys(): 153 | if key in ["agent_infos", "env_infos"]: 154 | continue 155 | concatenated[key] = np.concatenate( 156 | [batch[key] for batch in dataset], axis=0 157 | ).astype(np.float32) 158 | 159 | # global transforms 160 | if clip_action: 161 | concatenated["actions"] = np.clip( 162 | concatenated["actions"], -clip_action, clip_action 163 | ) 164 | concatenated["rewards"] = concatenated["rewards"] * reward_scale + reward_bias 165 | 
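    # At this point `concatenated` holds flat per-transition arrays with keys
    # observations, next_observations, actions, rewards, dones, masks, and
    # mc_returns (built by `truncate_traj` above); rewards have already been
    # scaled and biased.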
166 | return concatenated 167 | -------------------------------------------------------------------------------- /wsrl/envs/d4rl_dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Optional 3 | 4 | import d4rl 5 | import gym 6 | import gym.wrappers 7 | import numpy as np 8 | 9 | from wsrl.envs.env_common import calc_return_to_go 10 | from wsrl.utils.train_utils import concatenate_batches 11 | 12 | 13 | def get_d4rl_dataset_by_trajectory(env, dataset=None, terminate_on_end=False, **kwargs): 14 | """ 15 | This function heavily inherits from d4rl.qlearning_dataset 16 | Instead of returning a flat dataset that is a dictionary, this function 17 | returns a list of trajectories, each of which is a small dict-dataset. 18 | 19 | Returns datasets formatted for use by standard Q-learning algorithms, 20 | with observations, actions, next_observations, rewards, and a terminal 21 | flag. 22 | 23 | Args: 24 | env: An OfflineEnv object. 25 | dataset: An optional dataset to pass in for processing. If None, 26 | the dataset will default to env.get_dataset() 27 | terminate_on_end (bool): Set done=True on the last timestep 28 | in a trajectory. Default is False, and will discard the 29 | last timestep in each trajectory. 30 | **kwargs: Arguments to pass to env.get_dataset(). 31 | 32 | Returns: 33 | A dictionary containing keys: 34 | observations: An N x dim_obs array of observations. 35 | actions: An N x dim_action array of actions. 36 | next_observations: An N x dim_obs array of next observations. 37 | rewards: An N-dim float array of rewards. 38 | terminals: An N-dim boolean array of "done" or episode termination flags. 39 | """ 40 | if dataset is None: 41 | dataset = env.get_dataset(**kwargs) 42 | 43 | N = dataset["rewards"].shape[0] 44 | obs_ = [] 45 | next_obs_ = [] 46 | action_ = [] 47 | reward_ = [] 48 | done_ = [] 49 | 50 | # The newer version of the dataset adds an explicit 51 | # timeouts field. Keep old method for backwards compatability. 
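    # If the dataset has no `timeouts` field, episode boundaries are instead
    # inferred from `env._max_episode_steps` via `final_timestep` below.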
52 | use_timeouts = False 53 | if "timeouts" in dataset: 54 | use_timeouts = True 55 | 56 | episodes = [] 57 | episode_step = 0 58 | for i in range(N - 1): 59 | obs = dataset["observations"][i].astype(np.float32) 60 | new_obs = dataset["observations"][i + 1].astype(np.float32) 61 | action = dataset["actions"][i].astype(np.float32) 62 | reward = dataset["rewards"][i].astype(np.float32) 63 | done_bool = bool(dataset["terminals"][i]) 64 | 65 | if use_timeouts: 66 | final_timestep = dataset["timeouts"][i] 67 | else: 68 | final_timestep = episode_step == env._max_episode_steps - 1 69 | 70 | if done_bool or final_timestep: 71 | # record this episode and reset the stats 72 | episode = { 73 | "observations": np.array(obs_), 74 | "actions": np.array(action_), 75 | "next_observations": np.array(next_obs_), 76 | "rewards": np.array(reward_), 77 | "terminals": np.array(done_), 78 | } 79 | episodes.append(episode) 80 | 81 | episode_step = 0 82 | obs_ = [] 83 | next_obs_ = [] 84 | action_ = [] 85 | reward_ = [] 86 | done_ = [] 87 | 88 | if (not terminate_on_end) and final_timestep: 89 | # Skip this transition and don't apply terminals on the last step of an episode 90 | episode_step = 0 91 | continue 92 | 93 | obs_.append(obs) 94 | next_obs_.append(new_obs) 95 | action_.append(action) 96 | reward_.append(reward) 97 | done_.append(done_bool) 98 | episode_step += 1 99 | 100 | return episodes 101 | 102 | 103 | def get_d4rl_dataset( 104 | env, 105 | reward_scale: float = 1.0, 106 | reward_bias: float = 0.0, 107 | clip_action: Optional[float] = None, 108 | ): 109 | dataset = d4rl.qlearning_dataset(gym.make(env).unwrapped) 110 | 111 | if clip_action: 112 | dataset["actions"] = np.clip(dataset["actions"], -clip_action, clip_action) 113 | 114 | dones_float = np.zeros_like(dataset["rewards"]) 115 | 116 | if "kitchen" in env: 117 | # kitchen envs don't set the done signal correctly 118 | dones_float = dataset["rewards"] == 4 119 | 120 | else: 121 | # antmaze / locomotion envs 122 | for i in range(len(dones_float) - 1): 123 | if ( 124 | np.linalg.norm( 125 | dataset["observations"][i + 1] - dataset["next_observations"][i] 126 | ) 127 | > 1e-6 128 | or dataset["terminals"][i] == 1.0 129 | ): 130 | dones_float[i] = 1 131 | else: 132 | dones_float[i] = 0 133 | 134 | dones_float[-1] = 1 135 | 136 | # reward scale and bias 137 | dataset["rewards"] = dataset["rewards"] * reward_scale + reward_bias 138 | 139 | return dict( 140 | observations=dataset["observations"], 141 | actions=dataset["actions"], 142 | next_observations=dataset["next_observations"], 143 | rewards=dataset["rewards"], 144 | dones=np.logical_or(dataset["terminals"], dones_float), 145 | masks=1 - dataset["terminals"].astype(np.float32), 146 | ) 147 | 148 | 149 | def get_d4rl_dataset_with_mc_calculation( 150 | env_name, reward_scale, reward_bias, clip_action, gamma 151 | ): 152 | dataset = qlearning_dataset_and_calc_mc( 153 | gym.make(env_name).unwrapped, 154 | reward_scale, 155 | reward_bias, 156 | clip_action, 157 | gamma, 158 | ) 159 | 160 | return dict( 161 | observations=dataset["observations"], 162 | actions=dataset["actions"], 163 | next_observations=dataset["next_observations"], 164 | rewards=dataset["rewards"], 165 | dones=dataset["terminals"].astype(np.float32), 166 | mc_returns=dataset["mc_returns"], 167 | masks=1 - dataset["terminals"].astype(np.float32), 168 | ) 169 | 170 | 171 | def qlearning_dataset_and_calc_mc( 172 | env, 173 | reward_scale, 174 | reward_bias, 175 | clip_action, 176 | gamma, 177 | dataset=None, 178 | terminate_on_end=False, 
179 | **kwargs 180 | ): 181 | # this funtion follows d4rl.qlearning_dataset 182 | # and adds the logic to calculate the return to go 183 | 184 | if dataset is None: 185 | dataset = env.get_dataset(**kwargs) 186 | N = dataset["rewards"].shape[0] 187 | data_ = collections.defaultdict(list) 188 | episodes_dict_list = [] 189 | 190 | # The newer version of the dataset adds an explicit 191 | # timeouts field. Keep old method for backwards compatability. 192 | use_timeouts = False 193 | if "timeouts" in dataset: 194 | use_timeouts = True 195 | 196 | # manually update the terminals for kitchen envs 197 | if "kitchen" in env.unwrapped.spec.id: 198 | # kitchen envs don't set the done signal correctly 199 | dataset["terminals"] = np.logical_or( 200 | dataset["terminals"], dataset["rewards"] == 4 201 | ) 202 | 203 | # iterate over transitions, put them into trajectories 204 | episode_step = 0 205 | for i in range(N): 206 | 207 | done_bool = bool(dataset["terminals"][i]) 208 | 209 | if use_timeouts: 210 | is_final_timestep = dataset["timeouts"][i] 211 | else: 212 | is_final_timestep = episode_step == env._max_episode_steps - 1 213 | 214 | if (not terminate_on_end) and is_final_timestep or i == N - 1: 215 | # Skip this transition and don't apply terminals on the last step of an episode 216 | pass 217 | else: 218 | for k in dataset: 219 | if k in ( 220 | "actions", 221 | "next_observations", 222 | "observations", 223 | "rewards", 224 | "terminals", 225 | "timeouts", 226 | ): 227 | data_[k].append(dataset[k][i]) 228 | if "next_observations" not in dataset.keys(): 229 | data_["next_observations"].append(dataset["observations"][i + 1]) 230 | episode_step += 1 231 | if (done_bool or is_final_timestep) and episode_step > 0: 232 | episode_step = 0 233 | episode_data = {} 234 | for k in data_: 235 | episode_data[k] = np.array(data_[k]) 236 | 237 | episode_data["rewards"] = ( 238 | episode_data["rewards"] * reward_scale + reward_bias 239 | ) 240 | episode_data["mc_returns"] = calc_return_to_go( 241 | env.spec.name, 242 | episode_data["rewards"], 243 | 1 - episode_data["terminals"], 244 | gamma, 245 | reward_scale, 246 | reward_bias, 247 | ) 248 | episode_data["actions"] = np.clip( 249 | episode_data["actions"], -clip_action, clip_action 250 | ) 251 | episodes_dict_list.append(episode_data) 252 | data_ = collections.defaultdict(list) 253 | return concatenate_batches(episodes_dict_list) 254 | -------------------------------------------------------------------------------- /wsrl/envs/env_common.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import d4rl 4 | import gym 5 | import mj_envs 6 | import numpy as np 7 | from absl import flags 8 | 9 | from wsrl.envs.wrappers import ( 10 | AdroitTerminalWrapper, 11 | KitchenTerminalWrapper, 12 | ScaledRewardWrapper, 13 | TruncationWrapper, 14 | ) 15 | 16 | FLAGS = flags.FLAGS 17 | 18 | 19 | def make_gym_env( 20 | env_name: str, 21 | reward_scale: Optional[float] = None, 22 | reward_bias: Optional[float] = None, 23 | scale_and_clip_action: bool = False, 24 | action_clip_lim: Optional[float] = None, 25 | max_episode_steps: Optional[int] = None, 26 | seed: int = 0, 27 | ): 28 | """ 29 | create a gym environment for antmaze, kitchen, adroit, and locomotion tasks. 
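    Illustrative usage sketch (the environment name and reward settings are
    placeholder values, not defaults from this file):

        env = make_gym_env(
            "antmaze-large-diverse-v2",
            reward_scale=10.0,
            reward_bias=-5.0,
            seed=0,
        )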
30 | """ 31 | try: 32 | env = gym.make(env_name, seed=seed) 33 | except TypeError: 34 | # some envs don't take in seed as argument 35 | env = gym.make(env_name) 36 | 37 | # fix the done signal 38 | if "kitchen" in env_name: 39 | env = KitchenTerminalWrapper(env) 40 | if "binary" in env_name: 41 | # adroit 42 | env = AdroitTerminalWrapper(env) 43 | 44 | if max_episode_steps is not None: 45 | env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps) 46 | 47 | if scale_and_clip_action: 48 | # avoid NaNs for dist.log_prob(1.0) for tanh policies 49 | env = gym.wrappers.RescaleAction(env, -action_clip_lim, action_clip_lim) 50 | env = gym.wrappers.ClipAction(env) 51 | 52 | if reward_scale is not None and reward_bias is not None: 53 | env = ScaledRewardWrapper(env, reward_scale, reward_bias) 54 | 55 | env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=1) 56 | # 4-tuple to 5-tuple return 57 | env = TruncationWrapper(env) 58 | 59 | return env 60 | 61 | 62 | def get_env_type(env_name): 63 | """ 64 | separate the environment into different types 65 | (e.g. because different envs may need different logging / success conditions) 66 | """ 67 | if env_name in ( 68 | "pen-binary-v0", 69 | "door-binary-v0", 70 | "relocate-binary-v0", 71 | ): 72 | env_type = "adroit-binary" 73 | elif "antmaze" in env_name: 74 | env_type = "antmaze" 75 | elif "kitchen" in env_name: 76 | env_type = "kitchen" 77 | elif "halfcheetah" in env_name or "hopper" in env_name or "walker" in env_name: 78 | env_type = "locomotion" 79 | else: 80 | raise RuntimeError(f"Unknown environment type for {env_name}") 81 | 82 | return env_type 83 | 84 | 85 | def _determine_whether_sparse_reward(env_name): 86 | # return True if the environment is sparse-reward 87 | # determine if the env is sparse-reward or not 88 | if "antmaze" in env_name or env_name in [ 89 | "pen-binary-v0", 90 | "door-binary-v0", 91 | "relocate-binary-v0", 92 | "pen-binary", 93 | "door-binary", 94 | "relocate-binary", 95 | ]: 96 | is_sparse_reward = True 97 | elif ( 98 | "halfcheetah" in env_name 99 | or "hopper" in env_name 100 | or "walker" in env_name 101 | or "kitchen" in env_name 102 | ): 103 | is_sparse_reward = False 104 | else: 105 | raise NotImplementedError 106 | 107 | return is_sparse_reward 108 | 109 | 110 | # used to calculate the MC return for sparse-reward tasks. 111 | # Assumes that the environment issues two reward values: reward_pos when the 112 | # task is completed, and reward_neg at all the other steps. 113 | ENV_REWARD_INFO = { 114 | "antmaze": { # antmaze default is 0/1 reward 115 | "reward_pos": 1.0, 116 | "reward_neg": 0.0, 117 | }, 118 | "adroit-binary": { # adroit default is -1/0 reward 119 | "reward_pos": 0.0, 120 | "reward_neg": -1.0, 121 | }, 122 | } 123 | 124 | 125 | def _get_negative_reward(env_name, reward_scale, reward_bias): 126 | """ 127 | Given an environment with sparse rewards (aka there's only two reward values, 128 | the goal reward when the task is done, or the step penalty otherwise). 129 | Args: 130 | env_name: the name of the environment 131 | reward_scale: the reward scale 132 | reward_bias: the reward bias. The reward_scale and reward_bias are not applied 133 | here to scale the reward, but to determine the correct negative reward value. 
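    Example (illustrative values): antmaze uses 0/1 rewards, so with
    reward_scale=10.0 and reward_bias=-5.0 the step penalty becomes
    0.0 * 10.0 + (-5.0) = -5.0.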
134 | 135 | NOTE: this function should only be called on sparse-reward environments 136 | """ 137 | if "antmaze" in env_name: 138 | reward_neg = ( 139 | ENV_REWARD_INFO["antmaze"]["reward_neg"] * reward_scale + reward_bias 140 | ) 141 | elif env_name in [ 142 | "pen-binary-v0", 143 | "door-binary-v0", 144 | "relocate-binary-v0", 145 | ]: 146 | reward_neg = ( 147 | ENV_REWARD_INFO["adroit-binary"]["reward_neg"] * reward_scale + reward_bias 148 | ) 149 | else: 150 | raise NotImplementedError( 151 | """ 152 | If you want to try on a sparse reward env, 153 | please add the reward_neg value in the ENV_REWARD_INFO dict. 154 | """ 155 | ) 156 | 157 | return reward_neg 158 | 159 | 160 | def calc_return_to_go( 161 | env_name, 162 | rewards, 163 | masks, 164 | gamma, 165 | reward_scale=None, 166 | reward_bias=None, 167 | infinite_horizon=False, 168 | ): 169 | """ 170 | Calculat the Monte Carlo return to go given a list of reward for a single trajectory. 171 | Args: 172 | env_name: the name of the environment 173 | rewards: a list of rewards 174 | masks: a list of done masks 175 | gamma: the discount factor used to discount rewards 176 | reward_scale, reward_bias: the reward scale and bias used to determine 177 | the negative reward value for sparse-reward environments. If None, 178 | default from FLAGS values. Leave None unless for special cases. 179 | infinite_horizon: whether the MDP has inifite horizion (and therefore infinite return to go) 180 | """ 181 | if len(rewards) == 0: 182 | return np.array([]) 183 | 184 | # process sparse-reward envs 185 | if reward_scale is None or reward_bias is None: 186 | # scale and bias not applied, but used to determien the negative reward value 187 | assert reward_scale is None and reward_bias is None # both should be unset 188 | reward_scale = FLAGS.reward_scale 189 | reward_bias = FLAGS.reward_bias 190 | is_sparse_reward = _determine_whether_sparse_reward(env_name) 191 | if is_sparse_reward: 192 | reward_neg = _get_negative_reward(env_name, reward_scale, reward_bias) 193 | 194 | if is_sparse_reward and np.all(np.array(rewards) == reward_neg): 195 | """ 196 | If the env has sparse reward and the trajectory is all negative rewards, 197 | we use r / (1-gamma) as return to go. 
198 | For exapmle, if gamma = 0.99 and the rewards = [-1, -1, -1], 199 | then return_to_go = [-100, -100, -100] 200 | """ 201 | return_to_go = [float(reward_neg / (1 - gamma))] * len(rewards) 202 | else: 203 | # sum up the rewards backwards as the return to go 204 | return_to_go = [0] * len(rewards) 205 | prev_return = 0 if not infinite_horizon else float(rewards[-1] / (1 - gamma)) 206 | for i in range(len(rewards)): 207 | return_to_go[-i - 1] = rewards[-i - 1] + gamma * prev_return * ( 208 | masks[-i - 1] 209 | ) 210 | prev_return = return_to_go[-i - 1] 211 | return np.array(return_to_go, dtype=np.float32) 212 | -------------------------------------------------------------------------------- /wsrl/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from wsrl.envs.wrappers.add_truncation import TruncationWrapper 2 | from wsrl.envs.wrappers.adroit import AdroitTerminalWrapper 3 | from wsrl.envs.wrappers.kitchen import KitchenTerminalWrapper 4 | from wsrl.envs.wrappers.reward_scale import ScaledRewardWrapper 5 | -------------------------------------------------------------------------------- /wsrl/envs/wrappers/add_truncation.py: -------------------------------------------------------------------------------- 1 | from gym import Wrapper 2 | 3 | 4 | class TruncationWrapper(Wrapper): 5 | """d4rl only supports the old gym API, where env.step returns a 4-tuple without 6 | the truncated signal. Here we explicity expose the truncated signal.""" 7 | 8 | def __init__(self, env): 9 | super().__init__(env) 10 | 11 | def reset(self): 12 | s = self.env.reset() 13 | return s, {} 14 | 15 | def step(self, a): 16 | s, r, done, info = self.env.step(a) 17 | if "TimeLimit.truncated" in info: 18 | truncated = True 19 | done = False 20 | else: 21 | truncated = False 22 | return s, r, done, truncated, info 23 | -------------------------------------------------------------------------------- /wsrl/envs/wrappers/adroit.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class AdroitTerminalWrapper(gym.Wrapper): 5 | """ 6 | The original adroit environment doesn't set the done signal when the goal is 7 | achieved. This wrapper sets the done signal when the episode is done, 8 | decided by if the reward is 4 9 | """ 10 | 11 | def __init__(self, env): 12 | super().__init__(env) 13 | 14 | def step(self, action): 15 | # this wrapper should be wrapped right after environment creation and before 16 | # the Trucation wrapper, so it returns a 4-tuple 17 | obs, reward, done, info = self.env.step(action) 18 | if info["goal_achieved"]: 19 | done = True 20 | return obs, reward, done, info 21 | -------------------------------------------------------------------------------- /wsrl/envs/wrappers/kitchen.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class KitchenTerminalWrapper(gym.Wrapper): 5 | """ 6 | The original kitchen environment doesn't set the done signal when the episode is 7 | successfully completed. 
This wrapper sets the done signal when the episode is done, 8 | decided by if the reward is 4 9 | """ 10 | 11 | def __init__(self, env): 12 | super().__init__(env) 13 | 14 | def step(self, action): 15 | # this wrapper should be wrapped right after environment creation and before 16 | # the Trucation wrapper, so it returns a 4-tuple 17 | obs, reward, done, info = self.env.step(action) 18 | if reward == 4 or done: 19 | done = True 20 | return obs, reward, done, info 21 | -------------------------------------------------------------------------------- /wsrl/envs/wrappers/reward_scale.py: -------------------------------------------------------------------------------- 1 | from gym import Env, RewardWrapper 2 | 3 | 4 | class ScaledRewardWrapper(RewardWrapper): 5 | def __init__(self, env: Env, scale: float, bias: float): 6 | super().__init__(env) 7 | self.scale = scale 8 | self.bias = bias 9 | 10 | def reward(self, reward): 11 | return reward * self.scale + self.bias 12 | -------------------------------------------------------------------------------- /wsrl/networks/actor_critic_nets.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import distrax 4 | import flax.linen as nn 5 | import jax.numpy as jnp 6 | 7 | from wsrl.common.initialization import init_fns 8 | 9 | 10 | class ValueCritic(nn.Module): 11 | encoder: Optional[nn.Module] 12 | network: nn.Module 13 | init_final: Optional[float] = None 14 | kernel_init_type: Optional[str] = None 15 | 16 | def setup(self): 17 | self.init_fn = init_fns[self.kernel_init_type] 18 | 19 | @nn.compact 20 | def __call__( 21 | self, 22 | observations: jnp.ndarray, 23 | train: bool = False, 24 | ) -> jnp.ndarray: 25 | if self.encoder is None: 26 | obs_enc = observations 27 | else: 28 | obs_enc = self.encoder(observations) 29 | outputs = self.network(obs_enc, train=train) 30 | if self.init_final is not None: 31 | value = nn.Dense( 32 | 1, 33 | kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), 34 | )(outputs) 35 | else: 36 | value = nn.Dense(1, kernel_init=self.init_fn())(outputs) 37 | 38 | return jnp.squeeze(value, -1) 39 | 40 | 41 | class Critic(nn.Module): 42 | encoder: Optional[nn.Module] 43 | network: nn.Module 44 | init_final: Optional[float] = None 45 | kernel_init_type: Optional[str] = None 46 | 47 | def setup(self): 48 | self.init_fn = init_fns[self.kernel_init_type] 49 | 50 | @nn.compact 51 | def __call__( 52 | self, 53 | observations: jnp.ndarray, 54 | actions: jnp.ndarray, 55 | train: bool = False, 56 | ) -> jnp.ndarray: 57 | if self.encoder is None: 58 | obs_enc = observations 59 | else: 60 | obs_enc = self.encoder(observations) 61 | 62 | inputs = jnp.concatenate([obs_enc, actions], -1) 63 | outputs = self.network(inputs, train=train) 64 | if self.init_final is not None: 65 | value = nn.Dense( 66 | 1, 67 | kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), 68 | )(outputs) 69 | else: 70 | value = nn.Dense(1, kernel_init=self.init_fn())(outputs) 71 | 72 | return jnp.squeeze(value, -1) 73 | 74 | 75 | def ensemblize(cls, num_qs, out_axes=0): 76 | return nn.vmap( 77 | cls, 78 | variable_axes={"params": 0}, 79 | split_rngs={"params": True}, 80 | in_axes=None, 81 | out_axes=out_axes, 82 | axis_size=num_qs, 83 | ) 84 | 85 | 86 | class Policy(nn.Module): 87 | encoder: Optional[nn.Module] 88 | network: nn.Module 89 | action_dim: int 90 | init_final: Optional[float] = None 91 | std_parameterization: str = "exp" # "exp", "softplus", "fixed", or 
"uniform" 92 | std_min: Optional[float] = 1e-5 93 | std_max: Optional[float] = 10.0 94 | tanh_squash_distribution: bool = False 95 | fixed_std: Optional[jnp.ndarray] = None 96 | kernel_init_type: Optional[str] = None 97 | 98 | def setup(self): 99 | self.init_fn = init_fns[self.kernel_init_type] 100 | 101 | @nn.compact 102 | def __call__( 103 | self, observations: jnp.ndarray, temperature: float = 1.0, train: bool = False 104 | ) -> distrax.Distribution: 105 | if self.encoder is None: 106 | obs_enc = observations 107 | else: 108 | obs_enc = self.encoder(observations) 109 | 110 | outputs = self.network(obs_enc, train=train) 111 | 112 | means = nn.Dense(self.action_dim, kernel_init=self.init_fn())(outputs) 113 | if self.fixed_std is None: 114 | if self.std_parameterization == "exp": 115 | log_stds = nn.Dense(self.action_dim, kernel_init=self.init_fn())( 116 | outputs 117 | ) 118 | 119 | # # mitsuhiko ablation 120 | # base_network_output = nn.Dense(2 * self.action_dim, kernel_init=self.init_fn())( 121 | # outputs 122 | # ) 123 | # means, log_stds = jnp.split(base_network_output, 2, axis=-1) 124 | # log_stds = jnp.clip(log_stds + Scalar(-1.0)(), -20.0, 2.0) 125 | 126 | stds = jnp.exp(log_stds) 127 | 128 | elif self.std_parameterization == "softplus": 129 | stds = nn.Dense(self.action_dim, kernel_init=self.init_fn())(outputs) 130 | stds = nn.softplus(stds) 131 | elif self.std_parameterization == "uniform": 132 | log_stds = self.param( 133 | "log_stds", nn.initializers.zeros, (self.action_dim,) 134 | ) 135 | stds = jnp.exp(log_stds) 136 | else: 137 | raise ValueError( 138 | f"Invalid std_parameterization: {self.std_parameterization}" 139 | ) 140 | else: 141 | assert self.std_parameterization == "fixed" 142 | if type(self.fixed_std) == list: 143 | stds = jnp.array(self.fixed_std) 144 | else: 145 | # self.fixed_std is a float 146 | assert isinstance( 147 | self.fixed_std, (int, float) 148 | ), "fixed std must be a number" 149 | stds = jnp.array([self.fixed_std] * self.action_dim) 150 | 151 | # Clip stds to avoid numerical instability 152 | # For a normal distribution under MaxEnt, optimal std scales with sqrt(temperature) 153 | stds = jnp.clip(stds, self.std_min, self.std_max) * jnp.sqrt(temperature) 154 | 155 | if self.tanh_squash_distribution: 156 | distribution = TanhMultivariateNormalDiag( 157 | loc=means, 158 | scale_diag=stds, 159 | ) 160 | else: 161 | distribution = distrax.MultivariateNormalDiag( 162 | loc=means, 163 | scale_diag=stds, 164 | ) 165 | 166 | return distribution 167 | 168 | 169 | class TanhMultivariateNormalDiag(distrax.Transformed): 170 | def __init__( 171 | self, 172 | loc: jnp.ndarray, 173 | scale_diag: jnp.ndarray, 174 | low: Optional[jnp.ndarray] = None, 175 | high: Optional[jnp.ndarray] = None, 176 | ): 177 | distribution = distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag) 178 | 179 | layers = [] 180 | 181 | if not (low is None or high is None): 182 | 183 | def rescale_from_tanh(x): 184 | x = (x + 1) / 2 # (-1, 1) => (0, 1) 185 | return x * (high - low) + low 186 | 187 | def forward_log_det_jacobian(x): 188 | high_ = jnp.broadcast_to(high, x.shape) 189 | low_ = jnp.broadcast_to(low, x.shape) 190 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1) 191 | 192 | layers.append( 193 | distrax.Lambda( 194 | rescale_from_tanh, 195 | forward_log_det_jacobian=forward_log_det_jacobian, 196 | event_ndims_in=1, 197 | event_ndims_out=1, 198 | ) 199 | ) 200 | 201 | layers.append(distrax.Block(distrax.Tanh(), 1)) 202 | 203 | bijector = distrax.Chain(layers) 204 | 205 | 
super().__init__(distribution=distribution, bijector=bijector) 206 | 207 | def mode(self) -> jnp.ndarray: 208 | return self.bijector.forward(self.distribution.mode()) 209 | 210 | def stddev(self) -> jnp.ndarray: 211 | return self.bijector.forward(self.distribution.stddev()) 212 | 213 | 214 | class Scalar(nn.Module): 215 | init_value: float 216 | 217 | def setup(self): 218 | self.value = self.param("value", lambda x: self.init_value) 219 | 220 | def __call__(self): 221 | return self.value 222 | -------------------------------------------------------------------------------- /wsrl/networks/lagrange.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable, Optional, Sequence 3 | 4 | import chex 5 | import flax.linen as nn 6 | import jax.numpy as jnp 7 | 8 | 9 | class LagrangeMultiplier(nn.Module): 10 | init_value: float = 1.0 11 | constraint_shape: Sequence[int] = () 12 | constraint_type: str = "eq" # One of ("eq", "leq", "geq") 13 | parameterization: Optional[ 14 | str 15 | ] = None # One of ("softplus", "exp"), or None for equality constraints 16 | 17 | @nn.compact 18 | def __call__( 19 | self, *, lhs: Optional[jnp.ndarray] = None, rhs: Optional[jnp.ndarray] = None 20 | ) -> jnp.ndarray: 21 | init_value = self.init_value 22 | 23 | if self.constraint_type != "eq": 24 | assert ( 25 | init_value > 0 26 | ), "Inequality constraints must have non-negative initial multiplier values" 27 | 28 | if self.parameterization == "softplus": 29 | init_value = jnp.log(jnp.exp(init_value) - 1) 30 | elif self.parameterization == "exp": 31 | init_value = jnp.log(init_value) 32 | else: 33 | raise ValueError( 34 | f"Invalid multiplier parameterization {self.parameterization}" 35 | ) 36 | else: 37 | assert ( 38 | self.parameterization is None 39 | ), "Equality constraints must have no parameterization" 40 | 41 | multiplier = self.param( 42 | "lagrange", 43 | lambda _, shape: jnp.full(shape, init_value), 44 | self.constraint_shape, 45 | ) 46 | 47 | if self.constraint_type != "eq": 48 | if self.parameterization == "softplus": 49 | multiplier = nn.softplus(multiplier) 50 | elif self.parameterization == "exp": 51 | multiplier = jnp.exp(multiplier) 52 | else: 53 | raise ValueError( 54 | f"Invalid multiplier parameterization {self.parameterization}" 55 | ) 56 | 57 | # Return the raw multiplier 58 | if lhs is None: 59 | return multiplier 60 | 61 | # Use the multiplier to compute the Lagrange penalty 62 | if rhs is None: 63 | rhs = jnp.zeros_like(lhs) 64 | 65 | diff = lhs - rhs 66 | 67 | chex.assert_equal_shape([diff, multiplier]) 68 | 69 | if self.constraint_type == "eq": 70 | return multiplier * diff 71 | elif self.constraint_type == "geq": 72 | return multiplier * diff 73 | elif self.constraint_type == "leq": 74 | return -multiplier * diff 75 | 76 | 77 | GeqLagrangeMultiplier = partial( 78 | LagrangeMultiplier, constraint_type="geq", parameterization="softplus" 79 | ) 80 | 81 | LeqLagrangeMultiplier = partial( 82 | LagrangeMultiplier, constraint_type="leq", parameterization="softplus" 83 | ) 84 | -------------------------------------------------------------------------------- /wsrl/networks/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | from wsrl.common.initialization import init_fns 7 | 8 | 9 | class MLP(nn.Module): 10 | hidden_dims: Sequence[int] 11 | activations: 
Callable[[jnp.ndarray], jnp.ndarray] | str = nn.relu
12 |     activate_final: bool = True
13 |     use_layer_norm: bool = False
14 |     use_group_norm: bool = False
15 |     dropout_rate: Optional[float] = None
16 |     kernel_init_type: Optional[str] = None
17 |     kernel_scale_final: Optional[float] = None
18 |
19 |     def setup(self):
20 |         assert not (self.use_layer_norm and self.use_group_norm)
21 |         self.init_fn = init_fns[self.kernel_init_type]
22 |
23 |     @nn.compact
24 |     def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
25 |         activations = self.activations
26 |         if isinstance(activations, str):
27 |             activations = getattr(nn, activations)
28 |
29 |         for i, size in enumerate(self.hidden_dims):
30 |
31 |             # optionally use a different init scale for the final layer
32 |             if i + 1 == len(self.hidden_dims) and self.kernel_scale_final is not None:
33 |                 x = nn.Dense(size, kernel_init=self.init_fn(self.kernel_scale_final))(x)
34 |             else:
35 |                 x = nn.Dense(size, kernel_init=self.init_fn())(x)
36 |
37 |             # normalization and activation after each layer
38 |             if i + 1 < len(self.hidden_dims) or self.activate_final:
39 |                 if self.dropout_rate is not None and self.dropout_rate > 0:
40 |                     x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
41 |                 if self.use_layer_norm:
42 |                     x = nn.LayerNorm()(x)
43 |                 elif self.use_group_norm:
44 |                     x = nn.GroupNorm()(x)
45 |                 x = activations(x)
46 |         return x
47 |
48 |
49 | class MLPResNetBlock(nn.Module):
50 |     features: int
51 |     act: Callable
52 |     dropout_rate: float = None
53 |     use_layer_norm: bool = False
54 |
55 |     @nn.compact
56 |     def __call__(self, x, train: bool = False):
57 |         residual = x
58 |         if self.dropout_rate is not None and self.dropout_rate > 0:
59 |             x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
60 |         if self.use_layer_norm:
61 |             x = nn.LayerNorm()(x)
62 |         x = nn.Dense(self.features * 4)(x)
63 |         x = self.act(x)
64 |         x = nn.Dense(self.features)(x)
65 |
66 |         if residual.shape != x.shape:
67 |             residual = nn.Dense(self.features)(residual)
68 |
69 |         return residual + x
70 |
71 |
72 | class MLPResNet(nn.Module):
73 |     num_blocks: int
74 |     out_dim: int
75 |     dropout_rate: float = None
76 |     use_layer_norm: bool = False
77 |     hidden_dim: int = 256
78 |     activations: Callable = nn.swish
79 |     kernel_init_type: Optional[str] = None
80 |
81 |     def setup(self):
82 |         self.init_fn = init_fns[self.kernel_init_type]
83 |
84 |     @nn.compact
85 |     def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
86 |         x = nn.Dense(self.hidden_dim, kernel_init=self.init_fn())(x)
87 |         for _ in range(self.num_blocks):
88 |             x = MLPResNetBlock(
89 |                 self.hidden_dim,
90 |                 act=self.activations,
91 |                 use_layer_norm=self.use_layer_norm,
92 |                 dropout_rate=self.dropout_rate,
93 |             )(x, train=train)
94 |
95 |         x = self.activations(x)
96 |         x = nn.Dense(self.out_dim, kernel_init=self.init_fn())(x)
97 |         return x
98 |
99 |
100 | class Scalar(nn.Module):
101 |     init_value: float
102 |
103 |     def setup(self):
104 |         self.value = self.param("value", lambda x: self.init_value)
105 |
106 |     def __call__(self):
107 |         return self.value
108 |
--------------------------------------------------------------------------------
/wsrl/utils/timer_utils.py:
--------------------------------------------------------------------------------
1 | """Timer utility."""
2 |
3 | import time
4 | from collections import defaultdict
5 |
6 |
7 | class _TimerContextManager:
8 |     def __init__(self, timer: "Timer", key: str):
9 |         self.timer = timer
10 |         self.key = key
11 |
12 |     def __enter__(self):
13 |
self.timer.tick(self.key) 14 | 15 | def __exit__(self, exc_type, exc_value, exc_traceback): 16 | self.timer.tock(self.key) 17 | 18 | 19 | class Timer: 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.counts = defaultdict(int) 25 | self.times = defaultdict(float) 26 | self.start_times = {} 27 | 28 | def tick(self, key): 29 | if key in self.start_times: 30 | raise ValueError(f"Timer is already ticking for key: {key}") 31 | self.start_times[key] = time.time() 32 | 33 | def tock(self, key): 34 | if key not in self.start_times: 35 | raise ValueError(f"Timer is not ticking for key: {key}") 36 | self.counts[key] += 1 37 | self.times[key] += time.time() - self.start_times[key] 38 | del self.start_times[key] 39 | 40 | def context(self, key): 41 | """ 42 | Use this like: 43 | 44 | with timer.context("key"): 45 | # do stuff 46 | 47 | Then timer.tock("key") will be called automatically. 48 | """ 49 | return _TimerContextManager(self, key) 50 | 51 | def get_average_times(self, reset=True): 52 | ret = {key: self.times[key] / self.counts[key] for key in self.counts} 53 | if reset: 54 | self.reset() 55 | return ret 56 | -------------------------------------------------------------------------------- /wsrl/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | import numpy as np 4 | 5 | 6 | def concatenate_batches(batches): 7 | concatenated = {} 8 | for key in batches[0].keys(): 9 | if isinstance(batches[0][key], Mapping): 10 | # to concatenate batch["observations"]["image"], etc. 11 | concatenated[key] = concatenate_batches([batch[key] for batch in batches]) 12 | else: 13 | concatenated[key] = np.concatenate( 14 | [batch[key] for batch in batches], axis=0 15 | ).astype(np.float32) 16 | return concatenated 17 | 18 | 19 | def index_batch(batch, indices): 20 | indexed = {} 21 | for key in batch.keys(): 22 | if isinstance(batch[key], Mapping): 23 | # to index into batch["observations"]["image"], etc. 24 | indexed[key] = index_batch(batch[key], indices) 25 | else: 26 | indexed[key] = batch[key][indices, ...] 27 | return indexed 28 | 29 | 30 | def subsample_batch(batch, size): 31 | indices = np.random.randint(batch["rewards"].shape[0], size=size) 32 | return index_batch(batch, indices) 33 | --------------------------------------------------------------------------------
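
The files above provide the building blocks (MLP trunks, actor/critic heads, a Q-ensemble helper, and a timer) that the agents and finetune.py wire together. The snippet below is a minimal usage sketch, not part of the repository: the observation/action sizes and hidden dims are made-up illustrations, and it assumes init_fns accepts None as its default kernel-initializer key (as the modules' kernel_init_type: Optional[str] = None defaults suggest).

import jax
import jax.numpy as jnp

from wsrl.networks.actor_critic_nets import Critic, Policy, ensemblize
from wsrl.networks.mlp import MLP
from wsrl.utils.timer_utils import Timer

obs_dim, action_dim = 17, 6  # illustrative sizes, not taken from the repo configs
rng = jax.random.PRNGKey(0)
obs = jnp.zeros((1, obs_dim))

# Tanh-squashed Gaussian policy over a 2-layer MLP trunk.
policy = Policy(
    encoder=None,
    network=MLP(hidden_dims=(256, 256)),
    action_dim=action_dim,
    tanh_squash_distribution=True,
)
policy_params = policy.init(rng, obs)
dist = policy.apply(policy_params, obs)  # returns a distrax distribution
action = dist.sample(seed=rng)           # shape (1, action_dim)

# Two Q-functions vmapped over a leading parameter axis via ensemblize.
CriticEnsemble = ensemblize(Critic, num_qs=2)
critic = CriticEnsemble(encoder=None, network=MLP(hidden_dims=(256, 256)))
critic_params = critic.init(rng, obs, action)
q_values = critic.apply(critic_params, obs, action)  # shape (2, 1): one Q per ensemble member

# Time a block of work with the Timer context manager.
timer = Timer()
with timer.context("forward_pass"):
    policy.apply(policy_params, obs)
print(timer.get_average_times())  # e.g. {"forward_pass": <avg seconds>}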