├── assets └── csil.gif ├── .gitignore ├── requirements.txt ├── sil ├── __init__.py ├── builder.py ├── evaluator.py ├── pretraining.py ├── config.py ├── learning.py └── networks.py ├── CONTRIBUTING.md ├── experiment_logger.py ├── README.md ├── helpers.py ├── LICENSE ├── run_ppil.py ├── run_iqlearn.py └── run_csil.py /assets/csil.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/csil/HEAD/assets/csil.gif -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | chex==0.1.6 3 | cython<3 4 | d4rl @ git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl 5 | dm-acme[jax] @ git+https://github.com/google-deepmind/acme@d92e23b 6 | dm-env 7 | dm-haiku==0.0.9 8 | dm-launchpad[reverb] 9 | dm-sonnet 10 | dm-tree 11 | envlogger[tfds] 12 | flax==0.6.4 13 | gym<0.24.0 14 | ipython 15 | matplotlib 16 | notebook 17 | numpy 18 | optax==0.1.4 19 | orbax<=0.1.7 20 | patchelf 21 | rlds 22 | tabulate 23 | tensorflow==2.8.0 24 | tensorflow-datasets==4.6.0 25 | tensorflow_probability==0.15.0 26 | tqdm 27 | wandb 28 | -------------------------------------------------------------------------------- /sil/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Soft imitation learning agent.""" 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 
14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /experiment_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Loggers for experiments.""" 17 | 18 | from typing import Any, Callable, Dict, Mapping, Optional 19 | 20 | import logging 21 | import time 22 | import wandb 23 | 24 | from acme.utils.loggers import aggregators 25 | from acme.utils.loggers import asynchronous as async_logger 26 | from acme.utils.loggers import base 27 | from acme.utils.loggers import csv 28 | from acme.utils.loggers import filters 29 | from acme.utils.loggers import terminal 30 | 31 | 32 | class WeightsAndBiasesLogger(base.Logger): 33 | 34 | def __init__( 35 | self, 36 | logger: wandb.sdk.wandb_run.Run, 37 | label: str = '', 38 | time_delta: float = 0.0, 39 | ): 40 | """Initializes the Weights And Biases wrapper for Acme. 41 | 42 | Args: 43 | logger: Weights & Biases logger instances 44 | label: label string to use when logging. 45 | serialize_fn: function to call which transforms values into a str. 46 | time_delta: How often (in seconds) to write values. This can be used to 47 | minimize terminal spam, but is 0 by default---ie everything is written. 48 | """ 49 | self._label = label 50 | self._time = time.time() 51 | self._time_delta = time_delta 52 | self._logger = logger 53 | 54 | def write(self, data: base.LoggingData): 55 | """Write to weights and biases.""" 56 | now = time.time() 57 | if (now - self._time) > self._time_delta: 58 | data = base.to_numpy(data) # type: ignore 59 | if self._label: 60 | stats = {f"{self._label}/{k}": v for k, v in data.items()} 61 | else: 62 | stats = data 63 | self._logger.log(stats) # type: ignore 64 | self._time = now 65 | 66 | def close(self): 67 | pass 68 | 69 | def make_logger( 70 | label: str, 71 | wandb_logger: wandb.sdk.wandb_run.Run, 72 | steps_key: str = 'steps', 73 | save_data: bool = False, 74 | time_delta: float = 1.0, 75 | asynchronous: bool = False, 76 | print_fn: Optional[Callable[[str], None]] = None, 77 | serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy, 78 | ) -> base.Logger: 79 | """Makes a default Acme logger. 80 | 81 | Args: 82 | label: Name to give to the logger. 83 | wandb_logger: Weights and Biases logger instance. 84 | save_data: Whether to persist data. 
85 | time_delta: Time (in seconds) between logging events.
86 | asynchronous: Whether the write function should block or not.
87 | print_fn: How to print to terminal (defaults to print).
88 | serialize_fn: An optional function to apply to the write inputs before
89 | passing them to the various loggers.
90 | steps_key: Ignored.
91 | 
92 | Returns:
93 | A logger object that responds to logger.write(some_dict).
94 | """
95 | del steps_key
96 | if not print_fn:
97 | print_fn = logging.info
98 | 
99 | terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn)
100 | wandb_logger = WeightsAndBiasesLogger(
101 | logger=wandb_logger,
102 | label=label)
103 | 
104 | loggers = [terminal_logger, wandb_logger]
105 | 
106 | if save_data:
107 | loggers.append(csv.CSVLogger(label=label))
108 | 
109 | # Dispatch to all writers and filter Nones and by time.
110 | logger = aggregators.Dispatcher(loggers, serialize_fn)
111 | logger = filters.NoneFilter(logger)
112 | if asynchronous:
113 | logger = async_logger.AsyncLogger(logger)
114 | logger = filters.TimeFilter(logger, time_delta)
115 | 
116 | return logger
117 | 
118 | 
119 | def make_experiment_logger_factory(
120 | wandb_kwargs: Dict[str, Any]
121 | ) -> Callable[[str, Optional[str], int], base.Logger]:
122 | """Makes an Acme logger factory.
123 | 
124 | Args:
125 | wandb_kwargs: Dictionary of keyword arguments for wandb.init().
126 | 
127 | Returns:
128 | A logger factory function.
129 | """
130 | 
131 | # In the distributed setting, it is better to initialize the logger once and
132 | # pickle it than to initialize the W&B logging in each process.
133 | wandb_logger = wandb.init(
134 | **wandb_kwargs,
135 | )
136 | 
137 | def make_experiment_logger(label: str,
138 | steps_key: Optional[str] = None,
139 | task_instance: int = 0) -> base.Logger:
140 | del task_instance
141 | if steps_key is None:
142 | steps_key = f'{label}_steps'
143 | return make_logger(label=label, steps_key=steps_key,
144 | wandb_logger=wandb_logger,
145 | )
146 | return make_experiment_logger
147 | -------------------------------------------------------------------------------- /sil/builder.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Soft imitation learning builder.""" 17 | from typing import Iterator, List, Optional 18 | 19 | import acme 20 | from acme import adders 21 | from acme import core 22 | from acme import specs 23 | from acme import types 24 | from acme.adders import reverb as adders_reverb 25 | from acme.agents.jax import actor_core as actor_core_lib 26 | from acme.agents.jax import actors 27 | from acme.agents.jax import builders 28 | from acme.agents.jax.sac import networks as sac_networks 29 | from acme.datasets import reverb as datasets 30 | from acme.jax import networks as networks_lib 31 | from acme.jax import utils 32 | from acme.jax import variable_utils 33 | from acme.utils import counting 34 | from acme.utils import loggers 35 | import jax 36 | import optax 37 | import reverb 38 | from reverb import rate_limiters 39 | 40 | from sil import config as sil_config 41 | from sil import learning 42 | from sil import networks as sil_networks 43 | 44 | 45 | class SILBuilder( 46 | builders.ActorLearnerBuilder[ 47 | sil_networks.SILNetworks, 48 | actor_core_lib.FeedForwardPolicy, 49 | learning.ImitationSample, 50 | ] 51 | ): 52 | """Soft Imitation Learning Builder.""" 53 | 54 | def __init__(self, config: sil_config.SILConfig): 55 | """Creates a soft imitation learner, a behavior policy and an eval actor. 56 | 57 | Args: 58 | config: a config with hyperparameters 59 | """ 60 | self._config = config 61 | self._make_demonstrations = config.expert_demonstration_factory 62 | 63 | def make_learner( 64 | self, 65 | random_key: networks_lib.PRNGKey, 66 | networks: sil_networks.SILNetworks, 67 | dataset: Iterator[learning.ImitationSample], 68 | logger_fn: loggers.LoggerFactory, 69 | environment_spec: specs.EnvironmentSpec, 70 | replay_client: Optional[reverb.Client] = None, 71 | counter: Optional[counting.Counter] = None, 72 | ) -> core.Learner: 73 | del environment_spec, replay_client 74 | 75 | # Create optimizers. 
76 | policy_optimizer = optax.adam( 77 | learning_rate=self._config.actor_learning_rate 78 | ) 79 | q_optimizer = optax.adam(self._config.critic_learning_rate) 80 | r_optimizer = optax.sgd(learning_rate=self._config.reward_learning_rate) 81 | 82 | critic_loss = self._config.imitation.critic_loss_factory() 83 | reward_factory = self._config.imitation.reward_factory() 84 | 85 | n_policy_pretrainers = (len(self._config.policy_pretraining) 86 | if self._config.policy_pretraining else 0) 87 | policy_pretraining_loggers = [ 88 | logger_fn(f'pretrainer_policy{i}') 89 | for i in range(n_policy_pretrainers) 90 | ] 91 | 92 | return learning.SILLearner( 93 | networks=networks, 94 | critic_loss_def=critic_loss, 95 | reward_factory=reward_factory, 96 | tau=self._config.tau, 97 | discount=self._config.discount, 98 | critic_actor_update_ratio=self._config.critic_actor_update_ratio, 99 | entropy_coefficient=self._config.entropy_coefficient, 100 | target_entropy=self._config.target_entropy, 101 | alpha_init=self._config.alpha_init, 102 | alpha_learning_rate=self._config.alpha_learning_rate, 103 | damping=self._config.damping, 104 | rng=random_key, 105 | num_sgd_steps_per_step=self._config.num_sgd_steps_per_step, 106 | policy_optimizer=policy_optimizer, 107 | q_optimizer=q_optimizer, 108 | r_optimizer=r_optimizer, 109 | dataset=dataset, 110 | actor_bc_loss=self._config.actor_bc_loss, 111 | policy_pretraining=self._config.policy_pretraining, 112 | critic_pretraining=self._config.critic_pretraining, 113 | learner_logger=logger_fn('learner'), 114 | policy_pretraining_loggers=policy_pretraining_loggers, 115 | critic_pretraining_logger=logger_fn('pretrainer_critic'), 116 | counter=counter, 117 | ) 118 | 119 | def make_actor( 120 | self, 121 | random_key: networks_lib.PRNGKey, 122 | policy: actor_core_lib.FeedForwardPolicy, 123 | environment_spec: specs.EnvironmentSpec, 124 | variable_source: Optional[core.VariableSource] = None, 125 | adder: Optional[adders.Adder] = None, 126 | ) -> acme.Actor: 127 | del environment_spec 128 | assert variable_source is not None 129 | actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) 130 | variable_client = variable_utils.VariableClient( 131 | variable_source, 'policy', device='cpu') 132 | return actors.GenericActor( 133 | actor_core, random_key, variable_client, adder, backend='cpu') 134 | 135 | def make_replay_tables( 136 | self, 137 | environment_spec: specs.EnvironmentSpec, 138 | policy: actor_core_lib.FeedForwardPolicy, 139 | ) -> List[reverb.Table]: 140 | """Create tables to insert data into.""" 141 | del policy 142 | samples_per_insert_tolerance = ( 143 | self._config.samples_per_insert_tolerance_rate * 144 | self._config.samples_per_insert) 145 | error_buffer = self._config.min_replay_size * samples_per_insert_tolerance 146 | limiter = rate_limiters.SampleToInsertRatio( 147 | min_size_to_sample=self._config.min_replay_size, 148 | samples_per_insert=self._config.samples_per_insert, 149 | error_buffer=error_buffer) 150 | return [ 151 | reverb.Table( 152 | name=self._config.replay_table_name, 153 | sampler=reverb.selectors.Uniform(), 154 | remover=reverb.selectors.Fifo(), 155 | max_size=self._config.max_replay_size, 156 | rate_limiter=limiter, 157 | signature=adders_reverb.NStepTransitionAdder.signature( 158 | environment_spec)) 159 | ] 160 | 161 | def make_dataset_iterator( 162 | self, replay_client: reverb.Client 163 | ) -> Iterator[learning.ImitationSample]: 164 | """Create a dataset iterator to use for learning/updating the agent.""" 165 | # Replay 
buffer for demonstration data. 166 | iterator_demos = self._make_demonstrations(self._config.batch_size) 167 | 168 | # Replay buffer for online experience. 169 | iterator_online = datasets.make_reverb_dataset( 170 | table=self._config.replay_table_name, 171 | server_address=replay_client.server_address, 172 | batch_size=self._config.batch_size, 173 | prefetch_size=self._config.prefetch_size, 174 | ).as_numpy_iterator() 175 | 176 | return utils.device_put( 177 | ( 178 | learning.ImitationSample(types.Transition(*online.data), demo) 179 | for online, demo in zip(iterator_online, iterator_demos) 180 | ), 181 | jax.devices()[0], 182 | ) 183 | 184 | def make_adder( 185 | self, replay_client: reverb.Client, 186 | environment_spec: Optional[specs.EnvironmentSpec], 187 | policy: Optional[actor_core_lib.FeedForwardPolicy] 188 | ) -> Optional[adders.Adder]: 189 | """Create an adder which records data generated by the actor/environment.""" 190 | del environment_spec, policy 191 | return adders_reverb.NStepTransitionAdder( 192 | priority_fns={self._config.replay_table_name: None}, 193 | client=replay_client, 194 | n_step=self._config.n_step, 195 | discount=self._config.discount, 196 | ) 197 | 198 | def make_policy( 199 | self, 200 | networks: sil_networks.SILNetworks, 201 | environment_spec: specs.EnvironmentSpec, 202 | evaluation: bool = False, 203 | ) -> actor_core_lib.FeedForwardPolicy: 204 | """Construct the policy, which is the same as soft actor critic's policy.""" 205 | del environment_spec 206 | return sac_networks.apply_policy_and_sample( 207 | networks.to_sac(using_bc_policy=False), eval_mode=evaluation 208 | ) 209 | 210 | def make_bc_policy( 211 | self, 212 | networks: sil_networks.SILNetworks, 213 | environment_spec: specs.EnvironmentSpec, 214 | evaluation: bool = False, 215 | ) -> actor_core_lib.FeedForwardPolicy: 216 | """Construct the policy, which is the same as soft actor critic's policy.""" 217 | del environment_spec 218 | return sac_networks.apply_policy_and_sample( 219 | networks.to_sac(using_bc_policy=True), eval_mode=evaluation 220 | ) 221 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Coherent Soft Imitation Learning 2 | 3 | [![arXiv](https://img.shields.io/badge/stat.ML-arXiv%3A2305.16498-B31B1B.svg)](https://arxiv.org/abs/2305.16498) 4 | [![Python 3.7+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-376/) 5 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 6 | 7 |

8 | 9 |

10 | 
11 | 
12 | This repository contains an implementation of [coherent soft imitation learning (CSIL)](https://arxiv.org/abs/2305.16498),
13 | published at [NeurIPS 2023](https://openreview.net/forum?id=kCCD8d2aEu).
14 | 
15 | We also provide implementations of other 'soft'
16 | imitation learning (SIL) algorithms: [Inverse soft Q-learning (IQ-Learn)](https://arxiv.org/abs/2106.12142) and [proximal
17 | point imitation learning (PPIL)](https://arxiv.org/abs/2209.10968).
18 | 
19 | ## Content
20 | The implementation is built on top of [Acme](https://github.com/google-deepmind/acme) and follows their agent structure.
21 | ```
22 | .
23 | ├── run_csil.py - Example of running CSIL on continuous control tasks.
24 | ├── run_iqlearn.py - Example of running IQ-Learn on continuous control tasks.
25 | ├── run_ppil.py - Example of running PPIL on continuous control tasks.
26 | ├── soft_policy_iteration.ipynb - Evaluation of SIL algorithms in a discrete tabular setting.
27 | ├── helpers.py - Utilities such as dataset iterators and environment creation.
28 | ├── experiment_logger.py - Implements a Weights & Biases logger within the Acme framework.
29 | |
30 | ├── sil
31 | | ├── config.py - Algorithm-specific configurations for soft imitation learning (SIL).
32 | | ├── builder.py - Creates the learner, actor, and policy.
33 | | ├── evaluator.py - Creates the evaluators and video recorders.
34 | | ├── learning.py - Implements the imitation learners.
35 | | ├── networks.py - Defines the policy, reward and critic networks.
36 | | └── pretraining.py - Implements pre-training for policy and critic.
37 | ```
38 | 
39 | ## Usage
40 | 
41 | Before running any code, first activate the conda environment and set the
42 | `PYTHONPATH`:
43 | ```bash
44 | conda activate csil
45 | export PYTHONPATH=$(pwd)/..
46 | ```
47 | 
48 | To run CSIL with default settings:
49 | ```bash
50 | python run_csil.py
51 | ```
52 | This runs the online version of CSIL on HalfCheetah-v2.
53 | 
54 | The experiment configurations for each algorithm (CSIL, IQ-Learn, and PPIL) can
55 | be adjusted via the flags defined at the start of `run_*.py`.
56 | 
57 | The available tasks (specified with the `--env_name` flag) are:
58 | ```
59 | HalfCheetah-v2
60 | Ant-v2
61 | Walker2d-v2
62 | Hopper-v2
63 | Humanoid-v2
64 | door-v0 # Adroit hand
65 | hammer-v0 # Adroit hand
66 | pen-v0 # Adroit hand
67 | ```
68 | 
69 | The default setting is online soft imitation learning. To run the offline
70 | version on the Adroit door task, for example:
71 | ```bash
72 | python run_{algo_name}.py --offline=True --env_name=door-v0
73 | ```
74 | replacing `{algo_name}` with csil, iqlearn, or ppil.
75 | 
76 | We have also included a Colab [here](https://colab.research.google.com/github/google-deepmind/csil/blob/main/soft_policy_iteration.ipynb) that reproduces
77 | the discrete grid world experiments shown in the paper, for a range of imitation learning algorithms.
78 | 
79 | We highly encourage the use of accelerators (i.e., GPUs or TPUs) for CSIL. As CSIL requires a larger policy architecture, it has a slow wall-clock time if run only on CPUs.
80 | 
81 | For a reproduction of the paper's experiments, [see this Weights & Biases project](https://wandb.ai/jmw125/csil/workspace).
82 | 
83 | The additional imitation learning baselines shown in the paper [are available in Acme](https://github.com/google-deepmind/acme/tree/master/examples/baselines/imitation).
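If you want to use the environment and expert data outside of the `run_*.py` scripts, the sketch below (ours, not part of the repository) shows how `helpers.get_env_and_demonstrations` exposes the pieces those scripts are built on. The task, batch size, and printed shapes are illustrative, and it assumes the installation described below has been completed.

```python
# Minimal sketch of the helpers API used by the run scripts.
import helpers

# Environment factory, environment spec and expert-demonstration iterator,
# here with the 25 expert trajectories that the run scripts use by default.
make_env, env_spec, make_demonstrations = helpers.get_env_and_demonstrations(
    'HalfCheetah-v2', num_demonstrations=25)

env = make_env()                          # dm_env.Environment, actions rescaled to [-1, 1].
demos = make_demonstrations(256, seed=0)  # Iterator over batched Acme Transition tuples.

batch = next(demos)
print(env_spec.actions.shape)    # Action spec, e.g. (6,) for HalfCheetah-v2.
print(batch.observation.shape)   # Batched observations, e.g. (256, 17).
```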
84 | 
85 | ### Open issues
86 | 
87 | [Distributed Acme experiments currently do not finish cleanly, so they appear as 'Crashed' on W&B when they finish successfully.](https://github.com/google-deepmind/acme/issues/312#issue-1990249288)
88 | 
89 | The robomimic experiments are currently not open-sourced.
90 | 
91 | ## Citing this work
92 | 
93 | ```bibtex
94 | @inproceedings{watson2023csil,
95 | author = {Joe Watson and
96 | Sandy H. Huang and
97 | Nicolas Heess},
98 | title = {Coherent Soft Imitation Learning},
99 | booktitle = {Advances in Neural Information Processing Systems},
100 | year = {2023}
101 | }
102 | ```
103 | 
104 | ## Installation
105 | 
106 | First clone this code repository into a local directory:
107 | ```bash
108 | git clone https://github.com/google-deepmind/csil.git
109 | cd csil
110 | ```
111 | 
112 | We recommend installing required dependencies inside a
113 | [conda environment](https://www.anaconda.com/). To do this, first install
114 | [Anaconda](https://www.anaconda.com/download#downloads) and then create and
115 | activate the conda environment:
116 | ```bash
117 | conda create --name csil python=3.9
118 | conda activate csil
119 | ```
120 | CSIL is written in JAX, so first install the correct version of JAX for your system by [following the installation instructions](https://jax.readthedocs.io/en/latest/installation.html).
121 | Acme requires `jax 0.4.3` and will install that version. That version may need to be uninstalled for a CUDA-based JAX installation, e.g.
122 | ```bash
123 | pip install jax==0.4.7 https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.7+cuda12.cudnn88-cp39-cp39-manylinux2014_x86_64.whl
124 | ```
125 | 
126 | MuJoCo must also be installed in order to load the environments. Please follow
127 | the instructions [here](https://github.com/openai/mujoco-py#install-mujoco) to
128 | install the MuJoCo binary and place it in a directory where `mujoco-py` can find
129 | it.
130 | This installation uses `mujoco200`, `gym < 0.24.0` and `mujoco-py 2.0.2.5` for compatibility reasons.
131 | 
132 | Then use `pip` to install the remaining dependencies:
133 | ```bash
134 | pip install -r requirements.txt
135 | ```
136 | To verify the installation, run
137 | ```bash
138 | python -c "import jax.numpy as jnp; print(jnp.ones((1,)).device); import acme; import mujoco_py; import gym; print(gym.make('HalfCheetah-v2').reset())"
139 | ```
140 | If this fails, follow the guidance below.
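For a slightly more verbose check than the one-liner above, the sketch below (ours, not part of the repository) also reports which JAX backend is active, which is relevant to the accelerator recommendation in the Usage section.

```python
# Expanded installation check: verifies JAX, Acme, the MuJoCo bindings and gym.
# Run inside the activated csil conda environment.
import jax
import acme        # noqa: F401 -- import check only.
import mujoco_py   # noqa: F401 -- requires the mujoco200 binary to be installed.
import gym

print('JAX backend:', jax.default_backend())  # 'gpu' if the CUDA jaxlib is installed.
print('JAX devices:', jax.devices())

env = gym.make('HalfCheetah-v2')
print('Initial observation shape:', env.reset().shape)  # e.g. (17,) for HalfCheetah-v2.
```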
141 | 142 | ## Troubleshooting 143 | 144 | If you get the error 145 | ``` 146 | Command conda not found 147 | ``` 148 | then you need to add the folder where Anaconda is installed to your `PATH` 149 | variable: 150 | ```bash 151 | export PATH=/path/to/anaconda/bin:$PATH 152 | ``` 153 | 154 | If you get the error 155 | ``` 156 | ImportError: libpython3.9.so.1.0: cannot open shared object file: No such file or directory 157 | ``` 158 | first activate the conda environment and then add it to the `LD_LIBRARY_PATH`: 159 | ```bash 160 | conda activate csil 161 | export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:$CONDA_PREFIX/lib" 162 | ``` 163 | 164 | If you get the error 165 | ``` 166 | cannot find -lGL: No such file or directory 167 | ``` 168 | then install libGL with: 169 | ``` 170 | sudo apt install libgl-dev 171 | ``` 172 | 173 | 174 | If you get the error 175 | ``` 176 | fatal error: GL/glew.h: No such file or directory 177 | ``` 178 | then you need to install the following in your conda environment and update the 179 | `CPATH`: 180 | ```bash 181 | conda install -c conda-forge glew 182 | conda install -c conda-forge mesalib 183 | conda install -c menpo glfw3 184 | export CPATH=$CONDA_PREFIX/include 185 | ``` 186 | 187 | If you get the error 188 | ``` 189 | ImportError: libgmpxx.so.4: cannot open shared object file: No such file or directory 190 | ``` 191 | then you need to install the following in your conda environment and update the 192 | `CPATH`: 193 | ```bash 194 | conda install -c conda-forge gmp 195 | export CPATH=$CONDA_PREFIX/include 196 | ``` 197 | If you get the error 198 | ```commandline 199 | ImportError: ../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /lib/x86_64-linux-gnu/libLLVM-15.so.1) 200 | ``` 201 | try 202 | ```commandline 203 | mv libstdc++.so.6 libstdc++.so.6.old 204 | ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 libstdc++.so.6 205 | ``` 206 | according to [this advice](https://stackoverflow.com/a/73708979). 207 | 208 | ## License and disclaimer 209 | 210 | Copyright 2023 DeepMind Technologies Limited 211 | 212 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 213 | you may not use this file except in compliance with the Apache 2.0 license. 214 | You may obtain a copy of the Apache 2.0 license at: 215 | https://www.apache.org/licenses/LICENSE-2.0 216 | 217 | All other materials are licensed under the Creative Commons Attribution 4.0 218 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 219 | https://creativecommons.org/licenses/by/4.0/legalcode 220 | 221 | Unless required by applicable law or agreed to in writing, all software and 222 | materials distributed here under the Apache 2.0 or CC-BY licenses are 223 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 224 | either express or implied. See the licenses for the specific language governing 225 | permissions and limitations under those licenses. 226 | 227 | This is not an official Google product. 228 | -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helpers for experiments.""" 17 | 18 | import itertools 19 | from typing import Any, Callable, Iterator 20 | 21 | from acme import specs 22 | from acme import types 23 | from acme import wrappers 24 | from acme.datasets import tfds 25 | import d4rl # pylint: disable=unused-import 26 | import dm_env 27 | import gym 28 | import jax 29 | import jax.numpy as jnp 30 | import rlds 31 | import tensorflow as tf 32 | import tree 33 | 34 | from sil import learning 35 | 36 | 37 | ImitationIterator = Callable[ 38 | [int, jax.Array], Iterator[learning.ImitationSample] 39 | ] 40 | TransitionIterator = Callable[ 41 | [int, jax.Array], Iterator[types.Transition] 42 | ] 43 | 44 | # RLDS TFDS file names, see www.tensorflow.org/datasets/catalog/ 45 | EXPERT_DATASET_NAMES = { 46 | 'HalfCheetah-v2': 'locomotion/halfcheetah_sac_1M_single_policy_stochastic', 47 | 'Ant-v2': 'locomotion/ant_sac_1M_single_policy_stochastic', 48 | 'Walker2d-v2': 'locomotion/walker2d_sac_1M_single_policy_stochastic', 49 | 'Hopper-v2': 'locomotion/hopper_sac_1M_single_policy_stochastic', 50 | 'Humanoid-v2': 'locomotion/humanoid_sac_15M_single_policy_stochastic', 51 | 'door-v0': 'd4rl_adroit_door/v0-expert', 52 | 'hammer-v0': 'd4rl_adroit_hammer/v0-expert', 53 | 'pen-v0': 'd4rl_adroit_pen/v0-expert', 54 | } 55 | 56 | OFFLINE_DATASET_NAMES = { 57 | 'HalfCheetah-v2': 'd4rl_mujoco_halfcheetah/v2-full-replay', 58 | 'Ant-v2': 'd4rl_mujoco_ant/v2-full-replay', 59 | 'Walker2d-v2': 'd4rl_mujoco_walker2d/v2-full-replay', 60 | 'Hopper-v2': 'd4rl_mujoco_hopper/v2-full-replay', 61 | # No alternative dataset for Humanoid-v2. 62 | 'Humanoid-v2': 'locomotion/humanoid_sac_15M_single_policy_stochastic', 63 | # Adroit doesn't have suboptimal datasets, so use human demos as alternative 64 | # the other options are -human or -cloned (50/50 human and agent). 65 | # Human data is hard for BC to learn from so we compromise with expert. 
66 | 'door-v0': 'd4rl_adroit_door/v0-cloned', 67 | 'hammer-v0': 'd4rl_adroit_hammer/v0-cloned', 68 | 'pen-v0': 'd4rl_adroit_pen/v0-cloned', 69 | } 70 | 71 | 72 | class RandomIterator(Iterator[Any]): 73 | 74 | def __init__(self, dataset: tf.data.Dataset, batch_size: int, seed: int): 75 | dataset = dataset.shuffle(buffer_size=batch_size * 2, seed=seed) 76 | dataset = dataset.batch(batch_size, drop_remainder=True) 77 | self.iterator = itertools.cycle(dataset.as_numpy_iterator()) 78 | 79 | def __next__(self) -> Any: 80 | return self.iterator.__next__() 81 | 82 | 83 | class OfflineImitationIterator(Iterator[learning.ImitationSample]): 84 | """ImitationSample iterator for offline IL.""" 85 | 86 | def __init__( 87 | self, 88 | expert_iterator: Iterator[types.Transition], 89 | offline_iterator: Iterator[types.Transition], 90 | ): 91 | self._expert = expert_iterator 92 | self._offline = offline_iterator 93 | 94 | def __next__(self) -> learning.ImitationSample: 95 | """Combine data streams into an ImitationSample iterator.""" 96 | return learning.ImitationSample( 97 | online_sample=self._offline.__next__(), 98 | demonstration_sample=self._expert.__next__(), 99 | ) 100 | 101 | 102 | class MixedIterator(Iterator[types.Transition]): 103 | """Combine two streams of transitions 50/50.""" 104 | 105 | def __init__( 106 | self, 107 | first_iterator: Iterator[types.Transition], 108 | second_iterator: Iterator[types.Transition], 109 | with_extras: bool = False, 110 | ): 111 | self._first = first_iterator 112 | self._second = second_iterator 113 | self.with_extras = with_extras 114 | 115 | def __next__(self) -> types.Transition: 116 | """Combine data streams 50/50 into one, with the equal batch size.""" 117 | combined = tree.map_structure( 118 | lambda x, y: jnp.concatenate((x, y), axis=0), 119 | self._first.__next__(), 120 | self._second.__next__(), 121 | ) 122 | return tree.map_structure(lambda x: x[::2, ...], combined) 123 | 124 | 125 | def get_dataset_name(env_name: str, expert: bool = True) -> str: 126 | """Map environment to an expert or non-expert dataset name.""" 127 | if expert: 128 | assert ( 129 | env_name in EXPERT_DATASET_NAMES 130 | ), f"Choose from {', '.join(EXPERT_DATASET_NAMES)}" 131 | return EXPERT_DATASET_NAMES[env_name] 132 | else: # Arbitrary offline data. 133 | assert ( 134 | env_name in OFFLINE_DATASET_NAMES 135 | ), f"Choose from {', '.join(OFFLINE_DATASET_NAMES)}" 136 | return OFFLINE_DATASET_NAMES[env_name] 137 | 138 | 139 | def add_next_action_extras( 140 | transitions_iterator: tf.data.Dataset, 141 | ) -> tf.data.Dataset: 142 | """Creates transitions with next-action as extras information.""" 143 | 144 | def _add_next_action_extras( 145 | double_transitions: types.Transition, 146 | ) -> types.Transition: 147 | """Creates a new transition containing the next action in extras.""" 148 | # Observations may be dictionary or ndarray. 
149 | get_obs = lambda x: tree.map_structure(lambda y: y[0], x) 150 | obs = get_obs(double_transitions.observation) 151 | next_obs = get_obs(double_transitions.next_observation) 152 | return types.Transition( 153 | observation=obs, 154 | action=double_transitions.action[0], 155 | reward=double_transitions.reward[0], 156 | discount=double_transitions.discount[0], 157 | next_observation=next_obs, 158 | extras={'next_action': double_transitions.action[1]}, 159 | ) 160 | 161 | double_transitions = rlds.transformations.batch( 162 | transitions_iterator, size=2, shift=1, drop_remainder=True 163 | ) 164 | return double_transitions.map(_add_next_action_extras) 165 | 166 | 167 | def get_offline_dataset( 168 | task: str, 169 | environment_spec: specs.EnvironmentSpec, 170 | expert_num_demonstration: int, 171 | offline_num_demonstrations: int, 172 | expert_offline_data: bool = False, 173 | use_sarsa: bool = False, 174 | in_memory: bool = True, 175 | ) -> tuple[ImitationIterator, TransitionIterator]: 176 | """Get the offline dataset for a given task.""" 177 | expert_dataset_name = get_dataset_name(task, expert=True) 178 | offline_dataset_name = get_dataset_name(task, expert=expert_offline_data) 179 | 180 | # Note: For offline learning we take the key, not a seed. 181 | def make_offline_dataset( 182 | batch_size: int, key: jax.Array 183 | ) -> Iterator[types.Transition]: 184 | offline_transitions_iterator = tfds.get_tfds_dataset( 185 | offline_dataset_name, 186 | offline_num_demonstrations, 187 | env_spec=environment_spec, 188 | ) 189 | if use_sarsa: 190 | offline_transitions_iterator = add_next_action_extras( 191 | offline_transitions_iterator 192 | ) 193 | if in_memory: 194 | return tfds.JaxInMemoryRandomSampleIterator( 195 | dataset=offline_transitions_iterator, key=key, batch_size=batch_size 196 | ) 197 | else: 198 | return RandomIterator( 199 | offline_transitions_iterator, batch_size, seed=int(key)) 200 | 201 | def make_imitation_dataset( 202 | batch_size: int, 203 | key: jax.Array, 204 | ) -> Iterator[learning.ImitationSample]: 205 | expert_transitions_iterator = tfds.get_tfds_dataset( 206 | expert_dataset_name, expert_num_demonstration, env_spec=environment_spec 207 | ) 208 | if use_sarsa: 209 | expert_transitions_iterator = add_next_action_extras( 210 | expert_transitions_iterator 211 | ) 212 | if in_memory: 213 | expert_iterator = tfds.JaxInMemoryRandomSampleIterator( 214 | dataset=expert_transitions_iterator, key=key, batch_size=batch_size 215 | ) 216 | else: 217 | expert_iterator = RandomIterator( 218 | expert_transitions_iterator, batch_size, seed=int(key)) 219 | offline_iterator = make_offline_dataset(batch_size, key) 220 | return OfflineImitationIterator(expert_iterator, offline_iterator) 221 | 222 | return make_imitation_dataset, make_offline_dataset 223 | 224 | 225 | def get_env_and_demonstrations( 226 | task: str, 227 | num_demonstrations: int, 228 | expert: bool = True, 229 | use_sarsa: bool = False, 230 | in_memory: bool = True, 231 | ) -> tuple[ 232 | Callable[[], dm_env.Environment], 233 | specs.EnvironmentSpec, 234 | Callable[[int, int], Iterator[types.Transition]], 235 | ]: 236 | """Returns environment, spec and expert demonstration iterator.""" 237 | make_env = make_environment(task) 238 | 239 | environment_spec = specs.make_environment_spec(make_env()) 240 | 241 | # Create demonstrations function. 
242 | dataset_name = get_dataset_name(task, expert=expert) 243 | 244 | def make_demonstrations( 245 | batch_size: int, seed: int = 0 246 | ) -> Iterator[types.Transition]: 247 | transitions_iterator = tfds.get_tfds_dataset( 248 | dataset_name, num_demonstrations, env_spec=environment_spec 249 | ) 250 | if use_sarsa: 251 | transitions_iterator = add_next_action_extras(transitions_iterator) 252 | if in_memory: 253 | return tfds.JaxInMemoryRandomSampleIterator( 254 | dataset=transitions_iterator, 255 | key=jax.random.PRNGKey(seed), 256 | batch_size=batch_size, 257 | ) 258 | else: 259 | return RandomIterator( 260 | transitions_iterator, batch_size, seed=seed) 261 | return make_env, environment_spec, make_demonstrations 262 | 263 | 264 | def make_environment(task: str) -> Callable[[], dm_env.Environment]: 265 | """Makes the requested continuous control environment. 266 | 267 | Args: 268 | task: Task to load. 269 | 270 | Returns: 271 | An environment satisfying the dm_env interface expected by Acme agents. 272 | """ 273 | 274 | def make_env(): 275 | env = gym.make(task) 276 | env = wrappers.GymWrapper(env) 277 | # Make sure the environment obeys the dm_env.Environment interface. 278 | 279 | # Wrap the environment so the expected continuous action spec is [-1, 1]. 280 | # Note: This is a no-op on 'control' tasks. 281 | env = wrappers.CanonicalSpecWrapper(env, clip=True) 282 | env = wrappers.SinglePrecisionWrapper(env) 283 | return env 284 | 285 | env = make_env() 286 | env.render(mode='rgb_array') 287 | 288 | return make_env 289 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /run_ppil.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Example running proximal point imitation learning on continuous control tasks.""" 17 | 18 | from absl import flags 19 | from acme import specs 20 | from acme.agents.jax import sac 21 | from absl import app 22 | from acme.jax import experiments 23 | from acme.utils import lp_utils 24 | import jax 25 | import launchpad as lp 26 | 27 | import experiment_logger 28 | import helpers 29 | from sil import builder 30 | from sil import config as sil_config 31 | from sil import evaluator 32 | from sil import networks 33 | 34 | 35 | _DIST_FLAG = flags.DEFINE_bool( 36 | 'run_distributed', 37 | False, 38 | ( 39 | 'Should an agent be executed in a distributed ' 40 | 'way. If False, will run single-threaded.' 41 | ), 42 | ) 43 | _ENV_NAME = flags.DEFINE_string( 44 | 'env_name', 'HalfCheetah-v2', 'Which environment to run' 45 | ) 46 | _SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') 47 | _N_STEPS = flags.DEFINE_integer( 48 | 'num_steps', 250_000, 'Number of env steps to run.' 49 | ) 50 | _EVAL_RATIO = flags.DEFINE_integer( 51 | 'eval_every', 5_000, 'How often to evaluate for local runs.' 52 | ) 53 | _N_EVAL_EPS = flags.DEFINE_integer( 54 | 'evaluation_episodes', 1, 'Evaluation episodes for local runs.' 55 | ) 56 | _N_DEMONSTRATIONS = flags.DEFINE_integer( 57 | 'num_demonstrations', 25, 'No. of demonstration trajectories.' 58 | ) 59 | _N_OFFLINE_DATASET = flags.DEFINE_integer( 60 | 'num_offline_demonstrations', 1_000, 'Offline dataset size.' 61 | ) 62 | _BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size.') 63 | _DISCOUNT = flags.DEFINE_float('discount', 0.99, 'Discount factor') 64 | _ACTOR_LR = flags.DEFINE_float('actor_lr', 3e-5, 'Actor learning rate.') 65 | _CRITIC_LR = flags.DEFINE_float('critic_lr', 3e-4, 'Critic learning rate.') 66 | _REWARD_LR = flags.DEFINE_float('reward_lr', 3e-4, 'Reward learning rate.') 67 | _TAU = flags.DEFINE_float( 68 | 'tau', 69 | 0.005, 70 | ( 71 | 'Target network exponential smoothing weight.' 72 | '1. = no update, 0, = no smoothing.' 73 | ), 74 | ) 75 | _ENT_COEF = flags.DEFINE_float( 76 | 'entropy_coefficient', 77 | None, 78 | 'Entropy coefficient. Becomes adaptive if None.', 79 | ) 80 | _CRITIC_ACTOR_RATIO = flags.DEFINE_integer( 81 | 'critic_actor_update_ratio', 20, 'Critic updates per actor update.' 82 | ) 83 | _SGD_STEPS = flags.DEFINE_integer( 84 | 'sgd_steps', 1, 'SGD steps for online sample.' 85 | ) 86 | _CRITIC_NETWORK = flags.DEFINE_multi_integer( 87 | 'critic_network', [256, 256], 'Define critic architecture.' 88 | ) 89 | _REWARD_NETWORK = flags.DEFINE_multi_integer( 90 | 'reward_network', [256, 256], 'Define reward architecture.' 91 | ) 92 | _POLICY_NETWORK = flags.DEFINE_multi_integer( 93 | 'policy_network', [256, 256], 'Define policy architecture.' 
94 | ) 95 | _POLICY_MODEL = flags.DEFINE_enum( 96 | 'policy_model', 97 | networks.PolicyArchitectures.MLP.value, 98 | [e.value for e in networks.PolicyArchitectures], 99 | 'Define policy model type.', 100 | ) 101 | _CRITIC_MODEL = flags.DEFINE_enum( 102 | 'critic_model', 103 | networks.CriticArchitectures.LNMLP.value, 104 | [e.value for e in networks.CriticArchitectures], 105 | 'Define critic model type.', 106 | ) 107 | _REWARD_MODEL = flags.DEFINE_enum( 108 | 'reward_model', 109 | networks.RewardArchitectures.LNMLP.value, 110 | [e.value for e in networks.RewardArchitectures], 111 | 'Define reward model type.', 112 | ) 113 | _BET = flags.DEFINE_float( 114 | 'bellman_error_temp', 0.08, 'Temperature of logistic Bellman error term.' 115 | ) 116 | _CSIL_ALPHA = flags.DEFINE_float('csil_alpha', None, 'CSIL reward temperature.') 117 | _BC_ACTOR_LOSS = flags.DEFINE_bool( 118 | 'bc_actor_loss', False, 'Have expert BC term in actor loss.' 119 | ) 120 | _PRETRAIN_BC = flags.DEFINE_bool( 121 | 'pretrain_bc', False, 'Pretrain policy from demonstrations.' 122 | ) 123 | _BC_PRIOR = flags.DEFINE_bool( 124 | 'bc_prior', False, 'Used pretrained BC policy as prior.' 125 | ) 126 | _POLICY_PRETRAIN_STEPS = flags.DEFINE_integer( 127 | 'policy_pretrain_steps', 25_000, 'Policy pretraining steps.' 128 | ) 129 | _POLICY_PRETRAIN_LR = flags.DEFINE_float( 130 | 'policy_pretrain_lr', 1e-3, 'Policy pretraining learning rate.' 131 | ) 132 | _LOSS_TYPE = flags.DEFINE_enum( 133 | 'loss_type', 134 | sil_config.Losses.FAITHFUL.value, 135 | [e.value for e in sil_config.Losses], 136 | 'Define regression loss type.', 137 | ) 138 | _OFFLINE_FLAG = flags.DEFINE_bool('offline', False, 'Run an offline agent.') 139 | _EVAL_BC = flags.DEFINE_bool( 140 | 'eval_bc', False, 'Run evaluator of BC policy for comparison') 141 | _CHECKPOINTING = flags.DEFINE_bool( 142 | 'checkpoint', False, 'Save models during training.' 143 | ) 144 | _WANDB = flags.DEFINE_bool( 145 | 'wandb', True, 'Use weights and biases logging.' 146 | ) 147 | _NAME = flags.DEFINE_string('name', 'camera-ready', 'Experiment name') 148 | 149 | def build_experiment_config(): 150 | """Builds a P2IL experiment config which can be executed in different ways.""" 151 | # Create an environment, grab the spec, and use it to create networks. 
152 | 153 | task = _ENV_NAME.value 154 | 155 | mode = f'{"off" if _OFFLINE_FLAG.value else "on"}line' 156 | name = f'ppil_{task}_{mode}' 157 | group = (f'{name}, {_NAME.value}, ' 158 | f'ndemos={_N_DEMONSTRATIONS.value}, ' 159 | f'alpha={_ENT_COEF.value}') 160 | wandb_kwargs = { 161 | 'project': 'csil', 162 | 'name': name, 163 | 'group': group, 164 | 'tags': ['ppil', task, mode, jax.default_backend()], 165 | 'mode': 'online' if _WANDB.value else 'disabled', 166 | } 167 | 168 | logger_fact = experiment_logger.make_experiment_logger_factory(wandb_kwargs) 169 | 170 | make_env, env_spec, make_demonstrations = helpers.get_env_and_demonstrations( 171 | task, _N_DEMONSTRATIONS.value 172 | ) 173 | 174 | def environment_factory(seed: int): 175 | del seed 176 | return make_env() 177 | 178 | batch_size = _BATCH_SIZE.value 179 | seed = _SEED.value 180 | actor_lr = _ACTOR_LR.value 181 | 182 | make_demonstrations_ = lambda batchsize: make_demonstrations(batchsize, seed) 183 | 184 | if _PRETRAIN_BC.value: 185 | dataset_factory = lambda seed_: make_demonstrations(batch_size, seed_) 186 | policy_pretraining = [ 187 | sil_config.PretrainingConfig( 188 | loss=sil_config.Losses(_LOSS_TYPE.value), 189 | seed=seed, 190 | dataset_factory=dataset_factory, 191 | steps=_POLICY_PRETRAIN_STEPS.value, 192 | learning_rate=_POLICY_PRETRAIN_LR.value, 193 | use_as_reference=_BC_PRIOR.value, 194 | ), 195 | ] 196 | else: 197 | policy_pretraining = None 198 | 199 | critic_layers, reward_layers, policy_layers = ( 200 | _CRITIC_NETWORK.value, 201 | _REWARD_NETWORK.value, 202 | _POLICY_NETWORK.value, 203 | ) 204 | 205 | policy_architecture = networks.PolicyArchitectures(_POLICY_MODEL.value) 206 | critic_architecture = networks.CriticArchitectures(_CRITIC_MODEL.value) 207 | reward_architecture = networks.RewardArchitectures(_REWARD_MODEL.value) 208 | csil_alpha = _CSIL_ALPHA.value 209 | 210 | def network_factory(spec: specs.EnvironmentSpec): 211 | return networks.make_networks( 212 | spec, 213 | policy_architecture=policy_architecture, 214 | critic_architecture=critic_architecture, 215 | reward_architecture=reward_architecture, 216 | critic_hidden_layer_sizes=tuple(critic_layers), 217 | reward_hidden_layer_sizes=tuple(reward_layers), 218 | policy_hidden_layer_sizes=tuple(policy_layers), 219 | reward_policy_coherence_alpha=csil_alpha, 220 | ) 221 | 222 | if _ENT_COEF.value is not None and _ENT_COEF.value > 0.0: 223 | kwargs = {'entropy_coefficient': _ENT_COEF.value} 224 | else: 225 | kwargs = {'target_entropy': sac.target_entropy_from_env_spec(env_spec)} 226 | 227 | # Construct the agent. 
228 | config = sil_config.SILConfig( 229 | imitation=sil_config.ProximalPointConfig( 230 | bellman_error_temperature=_BET.value), 231 | actor_bc_loss=_BC_ACTOR_LOSS.value, 232 | policy_pretraining=policy_pretraining, 233 | expert_demonstration_factory=make_demonstrations_, 234 | discount=_DISCOUNT.value, 235 | critic_learning_rate=_CRITIC_LR.value, 236 | reward_learning_rate=_REWARD_LR.value, 237 | actor_learning_rate=actor_lr, 238 | num_sgd_steps_per_step=_SGD_STEPS.value, 239 | critic_actor_update_ratio=_CRITIC_ACTOR_RATIO.value, 240 | n_step=1, 241 | tau=_TAU.value, 242 | batch_size=batch_size, 243 | **kwargs, 244 | ) 245 | 246 | sil_builder = builder.SILBuilder(config) 247 | 248 | imitation_evaluator_factory = evaluator.imitation_evaluator_factory( 249 | agent_config=config, 250 | environment_factory=environment_factory, 251 | network_factory=network_factory, 252 | policy_factory=sil_builder.make_policy, 253 | logger_factory=logger_fact, 254 | ) 255 | 256 | evaluators = [imitation_evaluator_factory,] 257 | 258 | if _PRETRAIN_BC.value and _EVAL_BC.value: 259 | bc_evaluator_factory = evaluator.bc_evaluator_factory( 260 | environment_factory=environment_factory, 261 | network_factory=network_factory, 262 | policy_factory=sil_builder.make_policy, 263 | logger_factory=logger_fact, 264 | ) 265 | evaluators += [bc_evaluator_factory,] 266 | 267 | checkpoint_config = (experiments.CheckointingConfig() 268 | if _CHECKPOINTING.value else None) 269 | if _OFFLINE_FLAG.value: 270 | make_dataset, _ = helpers.get_offline_dataset( 271 | task, env_spec, _N_DEMONSTRATIONS.value, _N_OFFLINE_DATASET.value 272 | ) 273 | # The offline runner needs a random seed for the dataset. 274 | make_dataset_ = lambda k: make_dataset(batch_size, k) 275 | return experiments.OfflineExperimentConfig( 276 | builder=sil_builder, 277 | environment_factory=environment_factory, 278 | network_factory=network_factory, 279 | demonstration_dataset_factory=make_dataset_, 280 | evaluator_factories=evaluators, 281 | max_num_learner_steps=_N_STEPS.value, 282 | environment_spec=env_spec, 283 | seed=_SEED.value, 284 | logger_factory=logger_fact, 285 | checkpointing=checkpoint_config, 286 | ) 287 | else: 288 | return experiments.ExperimentConfig( 289 | builder=sil_builder, 290 | environment_factory=environment_factory, 291 | network_factory=network_factory, 292 | evaluator_factories=evaluators, 293 | seed=_SEED.value, 294 | max_num_actor_steps=_N_STEPS.value, 295 | logger_factory=logger_fact, 296 | checkpointing=checkpoint_config, 297 | ) 298 | 299 | 300 | def main(_): 301 | config = build_experiment_config() 302 | if _DIST_FLAG.value: 303 | if _OFFLINE_FLAG.value: 304 | program = experiments.make_distributed_offline_experiment( 305 | experiment=config 306 | ) 307 | else: 308 | program = experiments.make_distributed_experiment( 309 | experiment=config, num_actors=4 310 | ) 311 | lp.launch( 312 | program, 313 | xm_resources=lp_utils.make_xm_docker_resources(program), 314 | ) 315 | else: 316 | if _OFFLINE_FLAG.value: 317 | experiments.run_offline_experiment( 318 | experiment=config, 319 | eval_every=_EVAL_RATIO.value, 320 | num_eval_episodes=_N_EVAL_EPS.value, 321 | ) 322 | else: 323 | experiments.run_experiment( 324 | experiment=config, 325 | eval_every=_EVAL_RATIO.value, 326 | num_eval_episodes=_N_EVAL_EPS.value, 327 | ) 328 | 329 | 330 | if __name__ == '__main__': 331 | app.run(main) 332 | -------------------------------------------------------------------------------- /run_iqlearn.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Example running IQ-Learn on continuous control tasks.""" 17 | 18 | from absl import flags 19 | from acme import specs 20 | from acme.agents.jax import sac 21 | from absl import app 22 | from acme.jax import experiments 23 | from acme.utils import lp_utils 24 | import jax 25 | import launchpad as lp 26 | 27 | import experiment_logger 28 | import helpers 29 | from sil import builder 30 | from sil import config as sil_config 31 | from sil import evaluator 32 | from sil import networks 33 | 34 | 35 | _DIST_FLAG = flags.DEFINE_bool( 36 | 'run_distributed', 37 | False, 38 | ( 39 | 'Should an agent be executed in a distributed ' 40 | 'way. If False, will run single-threaded.' 41 | ), 42 | ) 43 | _ENV_NAME = flags.DEFINE_string( 44 | 'env_name', 'HalfCheetah-v2', 'Which environment to run' 45 | ) 46 | _SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') 47 | _N_STEPS = flags.DEFINE_integer( 48 | 'num_steps', 250_000, 'Number of env steps to run.' 49 | ) 50 | _EVAL_RATIO = flags.DEFINE_integer( 51 | 'eval_every', 5_000, 'How often to run evaluation.' 52 | ) 53 | _N_EVAL_EPS = flags.DEFINE_integer( 54 | 'evaluation_episodes', 1, 'Evaluation episodes.' 55 | ) 56 | _N_DEMOS = flags.DEFINE_integer( 57 | 'num_demonstrations', 25, 'Number of demonstration trajectories.' 58 | ) 59 | _BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size.') 60 | _N_OFFLINE_DATASET = flags.DEFINE_integer( 61 | 'num_offline_demonstrations', 25, 'Offline dataset size.' 62 | ) 63 | _DISCOUNT = flags.DEFINE_float('discount', 0.99, 'Discount factor') 64 | _ACTOR_LR = flags.DEFINE_float('actor_lr', 3e-5, 'Actor learning rate.') 65 | _CRITIC_LR = flags.DEFINE_float('critic_lr', 3e-4, 'Critic learning rate.') 66 | _REWARD_LR = flags.DEFINE_float( 67 | 'reward_lr', 3e-4, 'Reward learning rate (unused).' 68 | ) 69 | _ENT_COEF = flags.DEFINE_float( 70 | 'entropy_coefficient', 71 | 0.01, 72 | 'Entropy coefficient. Becomes adaptive if None.', 73 | ) 74 | _TAU = flags.DEFINE_float( 75 | 'tau', 76 | 0.005, 77 | ( 78 | 'Target network exponential smoothing weight.' 79 | '1. = no update, 0, = no smoothing.' 80 | ), 81 | ) 82 | _OFFLINE_FLAG = flags.DEFINE_bool('offline', False, 'Run an offline agent.') 83 | _CRITIC_ACTOR_RATIO = flags.DEFINE_integer( 84 | 'critic_actor_update_ratio', 1, 'Critic updates per actor update.' 85 | ) 86 | _SGD_STEPS = flags.DEFINE_integer( 87 | 'sgd_steps', 1, 'SGD steps for online sample.' 88 | ) 89 | _CRITIC_NETWORK = flags.DEFINE_multi_integer( 90 | 'critic_network', [256, 256], 'Define critic architecture.' 91 | ) 92 | _POLICY_NETWORK = flags.DEFINE_multi_integer( 93 | 'policy_network', [256, 256, 12, 256], 'Define policy architecture.' 
94 | ) 95 | _POLICY_MODEL = flags.DEFINE_enum( 96 | 'policy_model', 97 | networks.PolicyArchitectures.MLP.value, 98 | [e.value for e in networks.PolicyArchitectures], 99 | 'Define policy model type.', 100 | ) 101 | _CRITIC_MODEL = flags.DEFINE_enum( 102 | 'critic_model', 103 | networks.CriticArchitectures.LNMLP.value, 104 | [e.value for e in networks.CriticArchitectures], 105 | 'Define policy model type.', 106 | ) 107 | _BC_ACTOR_LOSS = flags.DEFINE_bool( 108 | 'bc_actor_loss', False, 'Have expert BC term in actor loss.' 109 | ) 110 | _PRETRAIN_BC = flags.DEFINE_bool( 111 | 'pretrain_bc', False, 'Pretrain policy from demonstrations.' 112 | ) 113 | _BC_PRIOR = flags.DEFINE_bool( 114 | 'bc_prior', False, 'Used pretrained BC policy as prior.' 115 | ) 116 | _POLICY_PRETRAIN_STEPS = flags.DEFINE_integer( 117 | 'policy_pretrain_steps', 25_000, 'Policy pretraining steps.' 118 | ) 119 | _POLICY_PRETRAIN_LR = flags.DEFINE_float( 120 | 'policy_pretrain_lr', 1e-3, 'Policy pretraining learning rate.' 121 | ) 122 | _LOSS_TYPE = flags.DEFINE_enum( 123 | 'loss_type', 124 | sil_config.Losses.FAITHFUL.value, 125 | [e.value for e in sil_config.Losses], 126 | 'Define regression loss type.', 127 | ) 128 | _LAYERNORM = flags.DEFINE_bool( 129 | 'policy_layer_norm', False, 'Use layer norm for first layer of the policy.' 130 | ) 131 | _EVAL_BC = flags.DEFINE_bool('eval_bc', False, 132 | 'Run evaluator of BC policy for comparison') 133 | _EVAL_PER_VIDEO = flags.DEFINE_integer( 134 | 'evals_per_video', 0, 'Video frequency. Disable using 0.' 135 | ) 136 | _CHECKPOINTING = flags.DEFINE_bool( 137 | 'checkpoint', False, 'Save models during training.' 138 | ) 139 | _WANDB = flags.DEFINE_bool( 140 | 'wandb', True, 'Use weights and biases logging.') 141 | 142 | _NAME = flags.DEFINE_string('name', 'camera-ready', 'Experiment name') 143 | 144 | def _build_experiment_config(): 145 | """Builds an IQ-Learn experiment config which can be executed in different ways.""" 146 | # Create an environment, grab the spec, and use it to create networks. 
147 | 148 | task = _ENV_NAME.value 149 | 150 | mode = f'{"off" if _OFFLINE_FLAG.value else "on"}line' 151 | name = f'iqlearn_{task}_{mode}' 152 | group = (f'{name}, {_NAME.value}, ' 153 | f'ndemos={_N_DEMOS.value}, ' 154 | f'alpha={_ENT_COEF.value}') 155 | wandb_kwargs = { 156 | 'project': 'csil', 157 | 'name': name, 158 | 'group': group, 159 | 'tags': ['iqlearn', task, mode, jax.default_backend()], 160 | 'config': flags.FLAGS._flags(), 161 | 'mode': 'online' if _WANDB.value else 'disabled', 162 | } 163 | 164 | logger_fact = experiment_logger.make_experiment_logger_factory(wandb_kwargs) 165 | 166 | make_env, env_spec, make_demonstrations = helpers.get_env_and_demonstrations( 167 | task, _N_DEMOS.value, use_sarsa=False 168 | ) 169 | 170 | def environment_factory(seed: int): 171 | del seed 172 | return make_env() 173 | 174 | batch_size = _BATCH_SIZE.value 175 | seed = _SEED.value 176 | actor_lr = _ACTOR_LR.value 177 | 178 | make_demonstrations_ = lambda batchsize: make_demonstrations(batchsize, seed) 179 | 180 | if _PRETRAIN_BC.value: 181 | dataset_factory = lambda seed_: make_demonstrations(batch_size, seed_) 182 | policy_pretraining = [ 183 | sil_config.PretrainingConfig( 184 | loss=sil_config.Losses(_LOSS_TYPE.value), 185 | seed=seed, 186 | dataset_factory=dataset_factory, 187 | steps=_POLICY_PRETRAIN_STEPS.value, 188 | learning_rate=_POLICY_PRETRAIN_LR.value, 189 | use_as_reference=_BC_PRIOR.value, 190 | ), 191 | ] 192 | else: 193 | policy_pretraining = None 194 | 195 | critic_layers = _CRITIC_NETWORK.value 196 | policy_layers = _POLICY_NETWORK.value 197 | policy_architecture = networks.PolicyArchitectures(_POLICY_MODEL.value) 198 | critic_architecture = networks.CriticArchitectures(_CRITIC_MODEL.value) 199 | use_layer_norm = _LAYERNORM.value 200 | 201 | def network_factory(spec: specs.EnvironmentSpec): 202 | return networks.make_networks( 203 | spec, 204 | policy_architecture=policy_architecture, 205 | critic_architecture=critic_architecture, 206 | critic_hidden_layer_sizes=tuple(critic_layers), 207 | policy_hidden_layer_sizes=tuple(policy_layers), 208 | layer_norm_policy=use_layer_norm, 209 | ) 210 | 211 | if _ENT_COEF.value is not None and _ENT_COEF.value > 0.0: 212 | kwargs = {'entropy_coefficient': _ENT_COEF.value} 213 | else: 214 | kwargs = {'target_entropy': sac.target_entropy_from_env_spec(env_spec)} 215 | 216 | # Construct the agent. 
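run_iqlearn.py trains no explicit reward network (hence the 'unused' note on the reward_lr flag above); the reward is implied by the critic through the inverse soft Bellman operator, and the CHI divergence selected just below adds a chi-squared-style regulariser on that implicit reward. A rough sketch of the identity, with q_fn and soft_value_fn standing in for the learned critic and the policy-averaged soft value (hypothetical names, not this codebase's API):

def implicit_iq_reward(q_fn, soft_value_fn, s, a, s_next, discount):
  # Inverse soft Bellman operator: r(s, a) = Q(s, a) - discount * V(s').
  return q_fn(s, a) - discount * soft_value_fn(s_next)

The imitation evaluator in sil/evaluator.py reconstructs a quantity of this kind from the critic and policy parameters when reporting an imitation return, so the SILConfig built next only needs the critic and policy networks.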
217 | config = sil_config.SILConfig( 218 | expert_demonstration_factory=make_demonstrations_, 219 | imitation=sil_config.InverseSoftQConfig( 220 | divergence=sil_config.Divergence.CHI), 221 | actor_bc_loss=_BC_ACTOR_LOSS.value, 222 | policy_pretraining=policy_pretraining, 223 | discount=_DISCOUNT.value, 224 | critic_learning_rate=_CRITIC_LR.value, 225 | reward_learning_rate=_REWARD_LR.value, 226 | actor_learning_rate=actor_lr, 227 | num_sgd_steps_per_step=_SGD_STEPS.value, 228 | tau=_TAU.value, 229 | critic_actor_update_ratio=_CRITIC_ACTOR_RATIO.value, 230 | n_step=1, 231 | batch_size=batch_size, 232 | **kwargs, 233 | ) 234 | 235 | sil_builder = builder.SILBuilder(config) 236 | 237 | imitation_evaluator_factory = evaluator.imitation_evaluator_factory( 238 | agent_config=config, 239 | environment_factory=environment_factory, 240 | network_factory=network_factory, 241 | policy_factory=sil_builder.make_policy, 242 | logger_factory=logger_fact, 243 | ) 244 | 245 | evaluators = [imitation_evaluator_factory,] 246 | 247 | if _PRETRAIN_BC.value and _EVAL_BC.value: 248 | bc_evaluator_factory = evaluator.bc_evaluator_factory( 249 | environment_factory=environment_factory, 250 | network_factory=network_factory, 251 | policy_factory=sil_builder.make_policy, 252 | logger_factory=logger_fact, 253 | ) 254 | evaluators += [bc_evaluator_factory,] 255 | 256 | if _EVAL_PER_VIDEO.value > 0: 257 | video_evaluator_factory = evaluator.video_evaluator_factory( 258 | environment_factory=environment_factory, 259 | network_factory=network_factory, 260 | policy_factory=sil_builder.make_policy, 261 | videos_per_eval=_EVAL_PER_VIDEO.value, 262 | logger_factory=logger_fact, 263 | ) 264 | evaluators += [video_evaluator_factory] 265 | 266 | checkpoint_config = (experiments.CheckointingConfig() 267 | if _CHECKPOINTING.value else None) 268 | if _OFFLINE_FLAG.value: 269 | # Note: For offline learning, the dataset needs to contain the offline and 270 | # expert data, so make_demonstrations isn't used and make_dataset combines 271 | # the two, due to how the OfflineBuilder is constructed. 272 | make_dataset, _ = helpers.get_offline_dataset( 273 | task, 274 | env_spec, 275 | _N_DEMOS.value, 276 | _N_OFFLINE_DATASET.value, 277 | use_sarsa=False, 278 | ) 279 | # Offline iterator takes RNG key. 280 | make_dataset_ = lambda k: make_dataset(batch_size, k) 281 | return experiments.OfflineExperimentConfig( 282 | builder=sil_builder, 283 | environment_factory=environment_factory, 284 | network_factory=network_factory, 285 | demonstration_dataset_factory=make_dataset_, 286 | evaluator_factories=evaluators, 287 | max_num_learner_steps=_N_STEPS.value, 288 | environment_spec=env_spec, 289 | seed=_SEED.value, 290 | logger_factory=logger_fact, 291 | checkpointing=checkpoint_config, 292 | ) 293 | else: # Online. 
294 | return experiments.ExperimentConfig( 295 | builder=sil_builder, 296 | environment_factory=environment_factory, 297 | network_factory=network_factory, 298 | evaluator_factories=evaluators, 299 | seed=_SEED.value, 300 | max_num_actor_steps=_N_STEPS.value, 301 | logger_factory=logger_fact, 302 | checkpointing=checkpoint_config, 303 | ) 304 | 305 | def main(_): 306 | config = _build_experiment_config() 307 | if _DIST_FLAG.value: 308 | if _OFFLINE_FLAG.value: 309 | program = experiments.make_distributed_offline_experiment( 310 | experiment=config 311 | ) 312 | else: 313 | program = experiments.make_distributed_experiment( 314 | experiment=config, num_actors=4 315 | ) 316 | lp.launch( 317 | program, 318 | xm_resources=lp_utils.make_xm_docker_resources(program), 319 | ) 320 | else: 321 | if _OFFLINE_FLAG.value: 322 | experiments.run_offline_experiment( 323 | experiment=config, 324 | eval_every=_EVAL_RATIO.value, 325 | num_eval_episodes=_N_EVAL_EPS.value, 326 | ) 327 | else: 328 | experiments.run_experiment( 329 | experiment=config, 330 | eval_every=_EVAL_RATIO.value, 331 | num_eval_episodes=_N_EVAL_EPS.value, 332 | ) 333 | 334 | 335 | if __name__ == '__main__': 336 | app.run(main) 337 | -------------------------------------------------------------------------------- /run_csil.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Example running coherent soft imitation learning on continuous control tasks.""" 17 | 18 | import math 19 | from typing import Iterator 20 | import wandb 21 | 22 | from absl import flags 23 | from acme import specs 24 | from acme import types 25 | from acme.agents.jax import sac 26 | from absl import app 27 | from acme.jax import experiments 28 | from acme.utils import lp_utils 29 | import jax 30 | import jax.random as rand 31 | import launchpad as lp 32 | 33 | import experiment_logger 34 | import helpers 35 | from sil import builder 36 | from sil import config as sil_config 37 | from sil import evaluator 38 | from sil import networks 39 | 40 | USE_SARSA = True 41 | 42 | _DIST_FLAG = flags.DEFINE_bool( 43 | 'run_distributed', 44 | False, 45 | ( 46 | 'Should an agent be executed in a distributed ' 47 | 'way. If False, will run single-threaded.' 48 | ), 49 | ) 50 | _ENV_NAME = flags.DEFINE_string( 51 | 'env_name', 'HalfCheetah-v2', 'Which environment to run' 52 | ) 53 | _SEED = flags.DEFINE_integer('seed', 0, 'Random seed.') 54 | _N_STEPS = flags.DEFINE_integer( 55 | 'num_steps', 250_000, 'Number of env steps to run.' 56 | ) 57 | _EVAL_RATIO = flags.DEFINE_integer( 58 | 'eval_every', 1_000, 'How often to evaluate for local runs.' 59 | ) 60 | _N_EVAL_EPS = flags.DEFINE_integer( 61 | 'evaluation_episodes', 1, 'Evaluation episodes for local runs.' 
62 | ) 63 | _N_DEMONSTRATIONS = flags.DEFINE_integer( 64 | 'num_demonstrations', 25, 'No. of demonstration trajectories.' 65 | ) 66 | _N_OFFLINE_DATASET = flags.DEFINE_integer( 67 | 'num_offline_demonstrations', 1_000, 'Offline dataset size.' 68 | ) 69 | _BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size.') 70 | _ENT_COEF = flags.DEFINE_float('entropy_coefficient', 0.01, 'Temperature') 71 | _ENT_SF = flags.DEFINE_float('ent_sf', 1.0, 'Scale entropy target.') 72 | _DAMP = flags.DEFINE_float('damping', 0.0, 'Constraint damping.') 73 | _SF = flags.DEFINE_float('scale_factor', 1.0, 'Reward loss scale factor.') 74 | _GNSF = flags.DEFINE_float( 75 | 'grad_norm_scale_factor', 1.0, 'Critic grad scale factor.' 76 | ) 77 | _DISCOUNT = flags.DEFINE_float('discount', 0.99, 'Discount factor') 78 | _ACTOR_LR = flags.DEFINE_float('actor_lr', 3e-4, 'Actor learning rate.') 79 | _CRITIC_LR = flags.DEFINE_float('critic_lr', 3e-4, 'Critic learning rate.') 80 | _REWARD_LR = flags.DEFINE_float('reward_lr', 1e-3, 'Reward learning rate.') 81 | _TAU = flags.DEFINE_float( 82 | 'tau', 83 | 0.005, 84 | ( 85 | 'Target network exponential smoothing weight.' 86 | '1. = no update, 0, = no smoothing.' 87 | ), 88 | ) 89 | _CRITIC_ACTOR_RATIO = flags.DEFINE_integer( 90 | 'critic_actor_update_ratio', 1, 'Critic updates per actor update.' 91 | ) 92 | _SGD_STEPS = flags.DEFINE_integer( 93 | 'sgd_steps', 1, 'SGD steps for online sample.' 94 | ) 95 | _CRITIC_NETWORK = flags.DEFINE_multi_integer( 96 | 'critic_network', [256, 256], 'Define critic architecture.' 97 | ) 98 | _REWARD_NETWORK = flags.DEFINE_multi_integer( 99 | 'reward_network', [256, 256], 'Define reward architecture. (Unused)' 100 | ) 101 | _POLICY_NETWORK = flags.DEFINE_multi_integer( 102 | 'policy_network', [256, 256, 12, 256], 'Define policy architecture.' 103 | ) 104 | _POLICY_MODEL = flags.DEFINE_enum( 105 | 'policy_model', 106 | networks.PolicyArchitectures.HETSTATTRI.value, 107 | [e.value for e in networks.PolicyArchitectures], 108 | 'Define policy model type.', 109 | ) 110 | _LAYERNORM = flags.DEFINE_bool( 111 | 'policy_layer_norm', False, 'Use layer norm for first layer of the policy.' 112 | ) 113 | _CRITIC_MODEL = flags.DEFINE_enum( 114 | 'critic_model', 115 | networks.CriticArchitectures.LNMLP.value, 116 | [e.value for e in networks.CriticArchitectures], 117 | 'Define critic model type.', 118 | ) 119 | _REWARD_MODEL = flags.DEFINE_enum( 120 | 'reward_model', 121 | networks.RewardArchitectures.PCSIL.value, 122 | [e.value for e in networks.RewardArchitectures], 123 | 'Define reward model type.', 124 | ) 125 | _RCALE = flags.DEFINE_float('reward_scaling', 1.0, 'Scale learned reward.') 126 | _FINETUNE_R = flags.DEFINE_bool('finetune_reward', True, 'Finetune reward.') 127 | _LOSS_TYPE = flags.DEFINE_enum( 128 | 'loss_type', 129 | sil_config.Losses.FAITHFUL.value, 130 | [e.value for e in sil_config.Losses], 131 | 'Define regression loss type.', 132 | ) 133 | _POLICY_PRETRAIN_STEPS = flags.DEFINE_integer( 134 | 'policy_pretrain_steps', 25_000, 'Policy pretraining steps.' 135 | ) 136 | _POLICY_PRETRAIN_LR = flags.DEFINE_float( 137 | 'policy_pretrain_lr', 1e-3, 'Policy pretraining learning rate.' 138 | ) 139 | _CRITIC_PRETRAIN_STEPS = flags.DEFINE_integer( 140 | 'critic_pretrain_steps', 5_000, 'Critic pretraining steps.' 141 | ) 142 | _CRITIC_PRETRAIN_LR = flags.DEFINE_float( 143 | 'critic_pretrain_lr', 1e-4, 'Critic pretraining learning rate.' 
144 | ) 145 | _EVAL_BC = flags.DEFINE_bool('eval_bc', False, 146 | 'Run evaluator of BC policy for comparison') 147 | _OFFLINE_FLAG = flags.DEFINE_bool('offline', False, 'Run an offline agent.') 148 | _EVAL_PER_VIDEO = flags.DEFINE_integer( 149 | 'evals_per_video', 0, 'Video frequency. Disable using 0.' 150 | ) 151 | _NUM_ACTORS = flags.DEFINE_integer( 152 | 'num_actors', 4, 'Number of distributed actors.' 153 | ) 154 | _CHECKPOINTING = flags.DEFINE_bool( 155 | 'checkpoint', False, 'Save models during training.' 156 | ) 157 | _WANDB = flags.DEFINE_bool( 158 | 'wandb', True, 'Use weights and biases logging.' 159 | ) 160 | _NAME = flags.DEFINE_string('name', 'camera-ready', 'Experiment name') 161 | 162 | def _build_experiment_config(): 163 | """Builds a CSIL experiment config which can be executed in different ways.""" 164 | 165 | # Create an environment, grab the spec, and use it to create networks. 166 | task = _ENV_NAME.value 167 | 168 | mode = f'{"off" if _OFFLINE_FLAG.value else "on"}line' 169 | name = f'csil_{task}_{mode}' 170 | group = (f'{name}, {_NAME.value}, ' 171 | f'ndemos={_N_DEMONSTRATIONS.value}, ' 172 | f'alpha={_ENT_COEF.value}') 173 | wandb_kwargs = { 174 | 'project': 'csil', 175 | 'name': name, 176 | 'group': group, 177 | 'tags': ['csil', task, mode, jax.default_backend()], 178 | 'config': flags.FLAGS._flags(), 179 | 'mode': 'online' if _WANDB.value else 'disabled', 180 | } 181 | 182 | logger_fact = experiment_logger.make_experiment_logger_factory(wandb_kwargs) 183 | 184 | make_env, env_spec, make_demonstrations = helpers.get_env_and_demonstrations( 185 | task, _N_DEMONSTRATIONS.value, use_sarsa=USE_SARSA, 186 | in_memory='image' not in task 187 | ) 188 | 189 | def environment_factory(seed: int): 190 | del seed 191 | return make_env() 192 | 193 | batch_size = _BATCH_SIZE.value 194 | seed = _SEED.value 195 | actor_lr = _ACTOR_LR.value 196 | 197 | make_demonstrations_ = lambda batchsize: make_demonstrations(batchsize, seed) 198 | 199 | if _ENT_COEF.value > 0.0: 200 | kwargs = {'entropy_coefficient': _ENT_COEF.value} 201 | else: 202 | target_entropy = _ENT_SF.value * sac.target_entropy_from_env_spec( 203 | env_spec, target_entropy_per_dimension=abs(_ENT_SF.value)) 204 | kwargs = {'target_entropy': target_entropy} 205 | 206 | # Important step that normalizes reward values -- do not change! 
207 | csil_alpha = _RCALE.value / math.prod(env_spec.actions.shape) 208 | 209 | policy_architecture = networks.PolicyArchitectures(_POLICY_MODEL.value) 210 | bc_policy_architecture = policy_architecture 211 | critic_architecture = networks.CriticArchitectures(_CRITIC_MODEL.value) 212 | reward_architecture = networks.RewardArchitectures(_REWARD_MODEL.value) 213 | policy_layers = _POLICY_NETWORK.value 214 | reward_layers = _REWARD_NETWORK.value 215 | critic_layers = _CRITIC_NETWORK.value 216 | use_layer_norm = _LAYERNORM.value 217 | 218 | def network_factory(spec: specs.EnvironmentSpec): 219 | return networks.make_networks( 220 | spec=spec, 221 | reward_policy_coherence_alpha=csil_alpha, 222 | policy_architecture=policy_architecture, 223 | critic_architecture=critic_architecture, 224 | reward_architecture=reward_architecture, 225 | bc_policy_architecture=bc_policy_architecture, 226 | policy_hidden_layer_sizes=tuple(policy_layers), 227 | reward_hidden_layer_sizes=tuple(reward_layers), 228 | critic_hidden_layer_sizes=tuple(critic_layers), 229 | bc_policy_hidden_layer_sizes=tuple(policy_layers), 230 | layer_norm_policy=use_layer_norm, 231 | ) 232 | 233 | demo_factory = lambda seed_: make_demonstrations(batch_size, seed_) 234 | policy_pretraining = sil_config.PretrainingConfig( 235 | loss=sil_config.Losses(_LOSS_TYPE.value), 236 | seed=seed, 237 | dataset_factory=demo_factory, 238 | steps=_POLICY_PRETRAIN_STEPS.value, 239 | learning_rate=_POLICY_PRETRAIN_LR.value, 240 | use_as_reference=True, 241 | ) 242 | if _OFFLINE_FLAG.value: 243 | _, offline_dataset = helpers.get_offline_dataset( 244 | task, 245 | env_spec, 246 | _N_DEMONSTRATIONS.value, 247 | _N_OFFLINE_DATASET.value, 248 | use_sarsa=USE_SARSA, 249 | ) 250 | 251 | def offline_pretraining_dataset(rseed: int) -> Iterator[types.Transition]: 252 | rkey = rand.PRNGKey(rseed) 253 | return helpers.MixedIterator( 254 | offline_dataset(batch_size, rkey), 255 | make_demonstrations(batch_size, rseed), 256 | ) 257 | 258 | offline_policy_pretraining = sil_config.PretrainingConfig( 259 | loss=sil_config.Losses(_LOSS_TYPE.value), 260 | seed=seed, 261 | dataset_factory=offline_pretraining_dataset, 262 | steps=_POLICY_PRETRAIN_STEPS.value, 263 | learning_rate=_POLICY_PRETRAIN_LR.value, 264 | ) 265 | policy_pretrainers = [offline_policy_pretraining, policy_pretraining] 266 | critic_dataset = demo_factory 267 | else: 268 | policy_pretrainers = [policy_pretraining,] 269 | critic_dataset = demo_factory 270 | 271 | critic_pretraining = sil_config.PretrainingConfig( 272 | seed=seed, 273 | dataset_factory=critic_dataset, 274 | steps=_CRITIC_PRETRAIN_STEPS.value, 275 | learning_rate=_CRITIC_PRETRAIN_LR.value, 276 | ) 277 | # Construct the agent. 
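The csil_alpha computed above divides the reward scale by the action dimensionality, so with reward_scaling=1.0 on HalfCheetah-v2 (a six-dimensional action space) the coherence temperature is 1/6. A rough sketch of what this temperature parameterises in the coherent-reward view, where pi_bc is the pretrained BC policy and pi_0 the reference prior; this is an illustration of the formulation, not the network implementation wired up by the CoherentConfig below:

import math

def coherence_alpha(reward_scaling: float, action_shape) -> float:
  # e.g. coherence_alpha(1.0, (6,)) == 1/6 for HalfCheetah-v2.
  return reward_scaling / math.prod(action_shape)

def coherent_reward(alpha, log_prob_bc, log_prob_prior, s, a):
  # A reward for which the BC policy is already soft-optimal:
  # r(s, a) = alpha * (log pi_bc(a|s) - log pi_0(a|s)).
  return alpha * (log_prob_bc(s, a) - log_prob_prior(s, a))

Whether this initial reward is further refined during RL is controlled by the finetune_reward flag, passed as refine_reward in the agent config constructed next.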
278 | config_ = sil_config.SILConfig( 279 | imitation=sil_config.CoherentConfig( 280 | alpha=csil_alpha, 281 | reward_scaling=_RCALE.value, 282 | refine_reward=_FINETUNE_R.value, 283 | negative_reward=( 284 | reward_architecture == networks.RewardArchitectures.NCSIL), 285 | grad_norm_sf=_GNSF.value, 286 | scale_factor=_SF.value, 287 | ), 288 | actor_bc_loss=False, 289 | policy_pretraining=policy_pretrainers, 290 | critic_pretraining=critic_pretraining, 291 | expert_demonstration_factory=make_demonstrations_, 292 | discount=_DISCOUNT.value, 293 | critic_learning_rate=_CRITIC_LR.value, 294 | reward_learning_rate=_REWARD_LR.value, 295 | actor_learning_rate=actor_lr, 296 | num_sgd_steps_per_step=_SGD_STEPS.value, 297 | critic_actor_update_ratio=_CRITIC_ACTOR_RATIO.value, 298 | n_step=1, 299 | damping=_DAMP.value, 300 | tau=_TAU.value, 301 | batch_size=batch_size, 302 | samples_per_insert=batch_size, 303 | alpha_learning_rate=1e-2, 304 | alpha_init=0.01, 305 | **kwargs, 306 | ) 307 | 308 | sil_builder = builder.SILBuilder(config_) 309 | 310 | imitation_evaluator_factory = evaluator.imitation_evaluator_factory( 311 | agent_config=config_, 312 | environment_factory=environment_factory, 313 | network_factory=network_factory, 314 | policy_factory=sil_builder.make_policy, 315 | logger_factory=logger_fact, 316 | ) 317 | 318 | evaluators = [imitation_evaluator_factory,] 319 | if _EVAL_PER_VIDEO.value > 0: 320 | video_evaluator_factory = evaluator.video_evaluator_factory( 321 | environment_factory=environment_factory, 322 | network_factory=network_factory, 323 | policy_factory=sil_builder.make_policy, 324 | videos_per_eval=_EVAL_PER_VIDEO.value, 325 | logger_factory=logger_fact, 326 | ) 327 | evaluators += [imitation_evaluator_factory,] 328 | 329 | if _EVAL_BC.value: 330 | bc_evaluator_factory = evaluator.bc_evaluator_factory( 331 | environment_factory=environment_factory, 332 | network_factory=network_factory, 333 | policy_factory=sil_builder.make_bc_policy, 334 | logger_factory=logger_fact, 335 | ) 336 | evaluators += [bc_evaluator_factory,] 337 | 338 | checkpoint_config = (experiments.CheckointingConfig() 339 | if _CHECKPOINTING.value else None) 340 | if _OFFLINE_FLAG.value: 341 | make_offline_dataset, _ = helpers.get_offline_dataset( 342 | task, 343 | env_spec, 344 | _N_DEMONSTRATIONS.value, 345 | _N_OFFLINE_DATASET.value, 346 | use_sarsa=USE_SARSA, 347 | ) 348 | make_offline_dataset_ = lambda rk: make_offline_dataset(batch_size, rk) 349 | return experiments.OfflineExperimentConfig( 350 | builder=sil_builder, 351 | environment_factory=environment_factory, 352 | network_factory=network_factory, 353 | demonstration_dataset_factory=make_offline_dataset_, 354 | evaluator_factories=evaluators, 355 | max_num_learner_steps=_N_STEPS.value, 356 | environment_spec=env_spec, 357 | seed=_SEED.value, 358 | logger_factory=logger_fact, 359 | checkpointing=checkpoint_config, 360 | ) 361 | else: 362 | return experiments.ExperimentConfig( 363 | builder=sil_builder, 364 | environment_factory=environment_factory, 365 | network_factory=network_factory, 366 | evaluator_factories=evaluators, 367 | seed=_SEED.value, 368 | max_num_actor_steps=_N_STEPS.value, 369 | logger_factory=logger_fact, 370 | checkpointing=checkpoint_config, 371 | ) 372 | 373 | 374 | def main(_): 375 | config = _build_experiment_config() 376 | if _DIST_FLAG.value: 377 | if _OFFLINE_FLAG.value: 378 | program = experiments.make_distributed_offline_experiment( 379 | experiment=config 380 | ) 381 | else: 382 | program = 
experiments.make_distributed_experiment( 383 | experiment=config, num_actors=_NUM_ACTORS.value 384 | ) 385 | lp.launch( 386 | program, 387 | xm_resources=lp_utils.make_xm_docker_resources(program), 388 | ) 389 | else: 390 | if _OFFLINE_FLAG.value: 391 | experiments.run_offline_experiment( 392 | experiment=config, 393 | eval_every=_EVAL_RATIO.value, 394 | num_eval_episodes=_N_EVAL_EPS.value, 395 | ) 396 | else: 397 | experiments.run_experiment( 398 | experiment=config, 399 | eval_every=_EVAL_RATIO.value, 400 | num_eval_episodes=_N_EVAL_EPS.value, 401 | ) 402 | 403 | 404 | if __name__ == '__main__': 405 | wandb.setup() 406 | app.run(main) 407 | wandb.finish() 408 | 409 | -------------------------------------------------------------------------------- /sil/evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Custom evaluator for soft imitation learning setting. 17 | 18 | Includes the estimated return from the learned reward function. 19 | 20 | Also captures videos. 21 | """ 22 | from typing import Dict, Optional, Sequence, Union 23 | 24 | from absl import logging 25 | import acme 26 | from acme import core 27 | from acme import environment_loop 28 | from acme import specs 29 | from acme.agents.jax import actor_core as actor_core_lib 30 | from acme.agents.jax import actors 31 | from acme.agents.jax import builders 32 | from acme.jax import networks as networks_lib 33 | from acme.jax import types 34 | from acme.jax import utils 35 | from acme.jax import variable_utils 36 | from acme.jax.experiments import config as experiment_config 37 | from acme.utils import counting 38 | from acme.utils import experiment_utils 39 | from acme.utils import loggers 40 | from acme.utils import observers as observers_lib 41 | from acme.wrappers import video as video_wrappers 42 | import dm_env 43 | import jax 44 | from jax import random 45 | import jax.numpy as jnp 46 | import numpy as np 47 | import tree 48 | 49 | from sil import config as sil_config 50 | from sil import networks as sil_networks 51 | 52 | 53 | N_EVALS_PER_VIDEO = 1 54 | 55 | 56 | class GymVideoWrapper(video_wrappers.VideoWrapper): 57 | 58 | def _render_frame(self, observation): 59 | """Renders a frame from the given environment observation.""" 60 | del observation 61 | return self.environment.render(mode='rgb_array') # pytype: disable=attribute-error 62 | 63 | def _write_frames(self): 64 | logging.info( 65 | 'Saving video: %s/%s/%d', self._path, self._filename, self._counter 66 | ) 67 | super()._write_frames() 68 | 69 | 70 | class RewardCore: 71 | """A learned implicit or explicit reward function.""" 72 | 73 | def __init__( 74 | self, 75 | networks: sil_networks.SILNetworks, 76 | reward_factory: sil_config.RewardFact, 77 | discount: float, 78 | variable_source: 
core.VariableSource, 79 | ): 80 | self._reward_variable_client = variable_utils.VariableClient( 81 | variable_source, 'reward', device='cpu' 82 | ) 83 | self._critic_variable_client = variable_utils.VariableClient( 84 | variable_source, 'critic', device='cpu' 85 | ) 86 | self._policy_variable_client = variable_utils.VariableClient( 87 | variable_source, 'policy', device='cpu' 88 | ) 89 | self._reward_fn = None 90 | 91 | def _reward( 92 | state: jnp.ndarray, 93 | action: jnp.ndarray, 94 | next_state: jnp.ndarray, 95 | transition_discount: jnp.ndarray, 96 | key: jnp.ndarray, 97 | reward_params: networks_lib.Params, 98 | critic_params: networks_lib.Params, 99 | policy_params: networks_lib.Params, 100 | ) -> jnp.ndarray: 101 | def state_action_reward_fn( 102 | state: jnp.ndarray, action: jnp.ndarray 103 | ) -> jnp.ndarray: 104 | return jnp.ravel( 105 | networks.reward_network.apply(reward_params, state, action) 106 | ) 107 | 108 | def state_action_value_fn( 109 | state: jnp.ndarray, action: jnp.ndarray 110 | ) -> jnp.ndarray: 111 | return networks.critic_network.apply( 112 | critic_params, state, action).min(axis=-1) 113 | 114 | def state_value_fn( 115 | state: jnp.ndarray, policy_key: jnp.ndarray 116 | ) -> jnp.ndarray: 117 | action_dist = networks.policy_network.apply(policy_params, state) 118 | action = action_dist.sample(seed=policy_key) 119 | v = networks.critic_network.apply( 120 | critic_params, state, action).min(axis=-1) 121 | return v 122 | 123 | reward = reward_factory( 124 | state_action_reward_fn, 125 | state_action_value_fn, 126 | state_value_fn, 127 | discount, 128 | ) 129 | return reward(state, action, next_state, transition_discount, key) 130 | 131 | self._reward_fn = jax.jit(_reward) 132 | 133 | @property 134 | def _reward_params(self) -> Sequence[networks_lib.Params]: 135 | return self._reward_variable_client.params 136 | 137 | @property 138 | def _critic_params(self) -> Sequence[networks_lib.Params]: 139 | return self._critic_variable_client.params 140 | 141 | @property 142 | def _policy_params(self) -> Sequence[networks_lib.Params]: 143 | params = self._policy_variable_client.params 144 | return params 145 | 146 | def __call__( 147 | self, 148 | state: jnp.ndarray, 149 | action: jnp.ndarray, 150 | next_state: jnp.ndarray, 151 | discount: jnp.ndarray, 152 | key: jnp.ndarray, 153 | ) -> jnp.ndarray: 154 | assert self._reward_fn is not None 155 | return self._reward_fn( 156 | state, 157 | action, 158 | next_state, 159 | discount, 160 | key, 161 | self._reward_params, 162 | self._critic_params, 163 | self._policy_params, 164 | ) 165 | 166 | def update(self, wait: bool = False): 167 | """Get updated parameters and update reward function.""" 168 | 169 | self._reward_variable_client.update(wait) 170 | self._critic_variable_client.update(wait) 171 | self._policy_variable_client.update(wait) 172 | 173 | 174 | class ImitationObserver(observers_lib.EnvLoopObserver): 175 | """Observer that evaluated using the learned reward function.""" 176 | 177 | def __init__(self, reward_fn: RewardCore): 178 | self._reward_fn = reward_fn 179 | self._imitation_return = 0.0 180 | self._current_observation = None 181 | self._episode_rewards = [] 182 | self._epsiode_imitation_rewards = [] 183 | self._key = random.PRNGKey(0) 184 | 185 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep): 186 | self._reward_fn.update() 187 | self._imitation_return = 0.0 188 | self._current_observation = timestep.observation 189 | self._episode_rewards = [] 190 | self._epsiode_imitation_rewards = 
[] 191 | 192 | def observe( 193 | self, 194 | env: dm_env.Environment, 195 | timestep: dm_env.TimeStep, 196 | action: np.ndarray, 197 | ): 198 | """Records one environment step.""" 199 | self._key, subkey = random.split(self._key) 200 | obs = tree.map_structure( 201 | lambda obs: jnp.expand_dims(obs, 0), self._current_observation 202 | ) 203 | next_obs = tree.map_structure( 204 | lambda obs: jnp.expand_dims(obs, 0), timestep.observation 205 | ) 206 | imitation_reward = self._reward_fn( 207 | obs, 208 | jnp.expand_dims(action, 0), 209 | next_obs, 210 | jnp.expand_dims(timestep.discount, 0), 211 | subkey, 212 | ).squeeze() 213 | self._current_observation = timestep.observation 214 | self._imitation_return += imitation_reward 215 | self._episode_rewards += [timestep.reward] 216 | self._epsiode_imitation_rewards += [imitation_reward] 217 | 218 | def compute_correlation(self) -> jnp.ndarray: 219 | r = jnp.asarray(self._episode_rewards) 220 | ir = jnp.asarray(self._epsiode_imitation_rewards) 221 | return jnp.corrcoef(r, ir)[0, 1] 222 | 223 | def get_metrics(self) -> Dict[str, observers_lib.Number]: 224 | """Returns metrics collected for the current episode.""" 225 | corr = self.compute_correlation() 226 | metrics = { 227 | 'imitation_return': self._imitation_return, 228 | 'episode_reward_corr': corr, 229 | } 230 | return metrics # pytype: disable=bad-return-type # jnp-array 231 | 232 | 233 | def adroit_success(env: dm_env.Environment) -> bool: 234 | # Adroit environments have get_info methods. 235 | if not hasattr(env, 'get_info'): 236 | return False 237 | info = getattr(env, 'get_info')() 238 | if 'goal_achieved' in info: 239 | return info['goal_achieved'] 240 | else: 241 | return False 242 | 243 | 244 | def _get_success_from_env(env: dm_env.Environment) -> bool: 245 | """Obtain the success flag for Adroit environments.""" 246 | return adroit_success(env) 247 | 248 | 249 | class SuccessObserver(observers_lib.EnvLoopObserver): 250 | """Observer that extracts the goal_achieved flag from Adroit mj_env tasks.""" 251 | 252 | _success: bool = False 253 | 254 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep): 255 | del env, timestep 256 | self._success = False 257 | 258 | def observe( 259 | self, 260 | env: dm_env.Environment, 261 | timestep: dm_env.TimeStep, 262 | action: np.ndarray, 263 | ): 264 | """Records one environment step.""" 265 | del timestep, action 266 | success = _get_success_from_env(env) 267 | self._success = self._success or success 268 | 269 | def get_metrics(self) -> Dict[str, observers_lib.Number]: 270 | """Returns metrics collected for the current episode.""" 271 | metrics = { 272 | 'success': 1.0 if self._success else 0.0, 273 | } 274 | return metrics 275 | 276 | 277 | def imitation_evaluator_factory( 278 | environment_factory: types.EnvironmentFactory, 279 | network_factory: experiment_config.NetworkFactory[builders.Networks], 280 | policy_factory: experiment_config.PolicyFactory[ 281 | builders.Networks, builders.Policy 282 | ], 283 | logger_factory: loggers.LoggerFactory = experiment_utils.make_experiment_logger, 284 | agent_config: Union[sil_config.SILConfig, None] = None, 285 | ) -> experiment_config.EvaluatorFactory[builders.Policy]: 286 | """Returns an imitation learning evaluator process.""" 287 | 288 | def evaluator( 289 | random_key: types.PRNGKey, 290 | variable_source: core.VariableSource, 291 | counter: counting.Counter, 292 | make_actor: experiment_config.MakeActorFn[builders.Policy], 293 | ) -> environment_loop.EnvironmentLoop: 294 | 
"""The evaluation process for imitation learning.""" 295 | # Create environment and evaluator networks. 296 | environment_key, actor_key = jax.random.split(random_key) 297 | # Environments normally require uint32 as a seed. 298 | environment = environment_factory(utils.sample_uint32(environment_key)) 299 | 300 | eval_environment = environment 301 | 302 | environment_spec = specs.make_environment_spec(environment) 303 | networks = network_factory(environment_spec) 304 | policy = policy_factory(networks, environment_spec, True) 305 | actor = make_actor(actor_key, policy, environment_spec, variable_source) 306 | 307 | success_observer = SuccessObserver() 308 | if agent_config is not None: 309 | reward_factory = agent_config.imitation.reward_factory() 310 | reward_fn = RewardCore( 311 | networks, reward_factory, agent_config.discount, variable_source 312 | ) 313 | 314 | imitation_observer = ImitationObserver(reward_fn) 315 | observers = (success_observer, imitation_observer) 316 | else: 317 | observers = (success_observer,) 318 | 319 | # Create logger and counter. 320 | counter = counting.Counter(counter, 'imitation_evaluator') 321 | logger = logger_factory('imitation_evaluator', 'actor_steps', 0) 322 | 323 | # Create the run loop and return it. 324 | return environment_loop.EnvironmentLoop( 325 | eval_environment, actor, counter, logger, observers=observers 326 | ) 327 | 328 | return evaluator 329 | 330 | 331 | def video_evaluator_factory( 332 | environment_factory: types.EnvironmentFactory, 333 | network_factory: experiment_config.NetworkFactory[builders.Networks], 334 | policy_factory: experiment_config.PolicyFactory[ 335 | builders.Networks, builders.Policy 336 | ], 337 | logger_factory: loggers.LoggerFactory = experiment_utils.make_experiment_logger, 338 | videos_per_eval: int = 0, 339 | ) -> experiment_config.EvaluatorFactory[builders.Policy]: 340 | """Returns an evaluator process that records videos.""" 341 | 342 | def evaluator( 343 | random_key: types.PRNGKey, 344 | variable_source: core.VariableSource, 345 | counter: counting.Counter, 346 | make_actor: experiment_config.MakeActorFn[builders.Policy], 347 | ) -> environment_loop.EnvironmentLoop: 348 | """The evaluation process for recording videos.""" 349 | # Create environment and evaluator networks. 350 | environment_key, actor_key = jax.random.split(random_key) 351 | # Environments normally require uint32 as a seed. 352 | environment = environment_factory(utils.sample_uint32(environment_key)) 353 | 354 | if videos_per_eval > 0: 355 | eval_environment = GymVideoWrapper( 356 | environment, 357 | record_every=videos_per_eval, 358 | frame_rate=40, 359 | filename='eval_episode', 360 | ) 361 | else: 362 | eval_environment = environment 363 | 364 | environment_spec = specs.make_environment_spec(environment) 365 | networks = network_factory(environment_spec) 366 | policy = policy_factory(networks, environment_spec, True) 367 | actor = make_actor(actor_key, policy, environment_spec, variable_source) 368 | 369 | observers = (SuccessObserver(),) 370 | 371 | # Create logger and counter. 372 | counter = counting.Counter(counter, 'video_evaluator') 373 | logger = logger_factory( 374 | 'video_evaluator', 'actor_steps', 0 375 | ) 376 | 377 | # Create the run loop and return it. 
378 | return environment_loop.EnvironmentLoop( 379 | eval_environment, actor, counter, logger, observers=observers 380 | ) 381 | 382 | return evaluator 383 | 384 | 385 | def bc_evaluator_factory( 386 | environment_factory: types.EnvironmentFactory, 387 | network_factory: experiment_config.NetworkFactory[builders.Networks], 388 | policy_factory: experiment_config.PolicyFactory[ 389 | builders.Networks, builders.Policy 390 | ], 391 | logger_factory: loggers.LoggerFactory = experiment_utils.make_experiment_logger, 392 | ) -> experiment_config.EvaluatorFactory[builders.Policy]: 393 | """Returns an imitation learning evaluator process.""" 394 | 395 | def evaluator( 396 | random_key: types.PRNGKey, 397 | variable_source: core.VariableSource, 398 | counter: counting.Counter, 399 | make_actor: experiment_config.MakeActorFn[builders.Policy], 400 | ) -> environment_loop.EnvironmentLoop: 401 | del make_actor 402 | # Create environment and evaluator networks. 403 | environment_key, actor_key = jax.random.split(random_key) 404 | # Environments normally require uint32 as a seed. 405 | environment = environment_factory(utils.sample_uint32(environment_key)) 406 | 407 | eval_environment = environment 408 | 409 | environment_spec = specs.make_environment_spec(environment) 410 | networks = network_factory(environment_spec) 411 | policy = policy_factory(networks, environment_spec, True) 412 | 413 | def make_bc_actor( 414 | random_key: networks_lib.PRNGKey, 415 | policy: actor_core_lib.FeedForwardPolicy, 416 | environment_spec: specs.EnvironmentSpec, 417 | variable_source: Optional[core.VariableSource] = None, 418 | ) -> acme.Actor: 419 | del environment_spec 420 | assert variable_source is not None 421 | actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy) 422 | variable_client = variable_utils.VariableClient( 423 | variable_source, 'bc_policy_params', device='cpu' 424 | ) 425 | return actors.GenericActor( 426 | actor_core, random_key, variable_client, None, backend='cpu' 427 | ) 428 | 429 | actor = make_bc_actor(actor_key, policy, environment_spec, variable_source) 430 | 431 | observers = (SuccessObserver(),) 432 | 433 | # Create logger and counter. 434 | counter = counting.Counter(counter, 'bc_evaluator') 435 | logger = logger_factory('bc_evaluator', 'actor_steps', 0) 436 | 437 | # Create the run loop and return it. 438 | return environment_loop.EnvironmentLoop( 439 | eval_environment, actor, counter, logger, observers=observers 440 | ) 441 | 442 | return evaluator 443 | -------------------------------------------------------------------------------- /sil/pretraining.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Methods to pretrain networks, e.g. policies, with regression. 
17 | """ 18 | import time 19 | from typing import Callable, Dict, Iterator, NamedTuple, Optional, Sequence, Tuple 20 | 21 | from acme import specs 22 | from acme import types 23 | from acme.agents.jax import bc 24 | from acme.jax import networks as networks_lib 25 | from acme.jax import types as jax_types 26 | from acme.jax import utils 27 | from acme.utils import counting 28 | from acme.utils import experiment_utils 29 | from acme.utils import loggers 30 | import jax 31 | from jax import random 32 | import jax.numpy as jnp 33 | import numpy as np 34 | import optax 35 | import tensorflow_probability.substrates.jax as tfp 36 | import tree 37 | 38 | from sil import config 39 | 40 | tfd = tfp.distributions 41 | 42 | LoggerFactory = Callable[[], loggers.Logger] 43 | 44 | 45 | BCLossWithoutAux = bc.losses.BCLossWithoutAux 46 | BCLossWithAux = bc.losses.BCLossWithAux 47 | LossArgs = [ 48 | bc.networks.BCNetworks, 49 | networks_lib.Params, 50 | networks_lib.PRNGKey, 51 | types.Transition, 52 | ] 53 | Metrics = Dict[str, jnp.ndarray] 54 | 55 | ExtendedBCLossWithAux = Tuple[ 56 | Callable[LossArgs, Tuple[jnp.ndarray, Metrics]], 57 | Callable[[networks_lib.Params], networks_lib.Params], 58 | Callable[[networks_lib.Params], networks_lib.Params], 59 | ] 60 | 61 | 62 | def no_param_change(params: networks_lib.Params) -> networks_lib.Params: 63 | return params 64 | 65 | 66 | def weight_decay(params: networks_lib.Params) -> jnp.ndarray: 67 | """Used for weight decay loss terms.""" 68 | return 0.5 * sum( 69 | jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params)) 70 | 71 | 72 | def mse(action_dimension: int) -> ExtendedBCLossWithAux: 73 | """Log probability loss.""" 74 | del action_dimension 75 | def loss( 76 | networks: bc.networks.BCNetworks, 77 | params: networks_lib.Params, 78 | key: jax_types.PRNGKey, 79 | transitions: types.Transition, 80 | ) -> Tuple[jnp.ndarray, Metrics]: 81 | dist = networks.policy_network.apply( 82 | params, transitions.observation, is_training=True, key=key, 83 | train_encoder=True, 84 | ) 85 | sample = networks.sample_fn(dist, key) 86 | entropy = -networks.log_prob(dist, sample).mean() 87 | mean_sq_error = ((dist.mode() - transitions.action) ** 2).mean() 88 | metrics = { 89 | "mse": mean_sq_error, 90 | "nllh": -networks.log_prob(dist, transitions.action).mean(), 91 | "ent": entropy, 92 | } 93 | return mean_sq_error, metrics 94 | 95 | return loss, no_param_change, no_param_change 96 | 97 | 98 | def faithful_loss(action_dimension: int) -> ExtendedBCLossWithAux: 99 | """Combines mean-squared error and negative log-likeihood. 100 | 101 | Uses stop-gradients to ensure 'faithful' MSE fit and uncertainty 102 | quantification.' 
103 | 104 | Args: 105 | action_dimension: (Unused for this loss) 106 | 107 | Returns: 108 | loss function 109 | parameter extender: for adding loss terms to pytree params 110 | parameter retracter: for removing loss terms from pytree params 111 | """ 112 | del action_dimension 113 | def loss( 114 | networks: bc.networks.BCNetworks, 115 | params: networks_lib.Params, 116 | key: jax_types.PRNGKey, 117 | transitions: types.Transition, 118 | ) -> Tuple[jnp.ndarray, Metrics]: 119 | dist, cut_dist = networks.policy_network.apply( 120 | params, 121 | transitions.observation, 122 | is_training=True, 123 | key=key, 124 | faithful_distributions=True, 125 | train_encoder=True, 126 | ) 127 | sample = networks.sample_fn(dist, key) 128 | entropy = -networks.log_prob(cut_dist, sample).mean() 129 | mean_sq_error = ((dist.mode() - transitions.action) ** 2).mean() 130 | nllh = -networks.log_prob(cut_dist, transitions.action).mean() 131 | 132 | loss_ = mean_sq_error + nllh 133 | 134 | metrics = { 135 | "mse": mean_sq_error, 136 | "nllh": nllh, 137 | "ent": entropy, 138 | } 139 | 140 | return loss_, metrics 141 | 142 | return loss, no_param_change, no_param_change 143 | 144 | 145 | def negative_loglikelihood(action_dimension: int) -> ExtendedBCLossWithAux: 146 | """Negative log likelihood loss.""" 147 | 148 | del action_dimension 149 | def loss( 150 | networks: bc.networks.BCNetworks, 151 | params: networks_lib.Params, 152 | key: jax_types.PRNGKey, 153 | transitions: types.Transition, 154 | ) -> Tuple[jnp.ndarray, Metrics]: 155 | dist = networks.policy_network.apply( 156 | params, transitions.observation, is_training=True, key=key, 157 | train_encoder=True, 158 | ) 159 | nllh = -networks.log_prob(dist, transitions.action).mean() 160 | sample = networks.sample_fn(dist, key) 161 | entropy = -networks.log_prob(dist, sample).mean() 162 | metrics = { 163 | "mse": ((dist.mode() - transitions.action) ** 2).mean(), 164 | "nllh": nllh, 165 | "ent": entropy, 166 | } 167 | return nllh, metrics 168 | 169 | return loss, no_param_change, no_param_change 170 | 171 | 172 | def zero_debiased_faithful_loss( 173 | action_dimension: int, 174 | ) -> ExtendedBCLossWithAux: 175 | """Combines mean-squared error and negative log-likeihood. 176 | 177 | Uses stop-gradients to ensure 'faithful' MSE fit and uncertainty 178 | quantification.' 
179 | 180 | Args: 181 | action_dimension: (Unused for this loss) 182 | 183 | Returns: 184 | loss function 185 | parameter extender: for adding loss terms to pytree params 186 | parameter retracter: for removing loss terms from pytree params 187 | """ 188 | 189 | def loss( 190 | networks: bc.networks.BCNetworks, 191 | params: networks_lib.Params, 192 | key: jax_types.PRNGKey, 193 | transitions: types.Transition, 194 | ) -> Tuple[jnp.ndarray, Metrics]: 195 | policy_params = params["model"] 196 | dist, cut_dist = networks.policy_network.apply( 197 | policy_params, 198 | transitions.observation, 199 | is_training=True, 200 | key=key, 201 | faithful_distributions=True, 202 | train_encoder=True, 203 | ) 204 | sample = networks.sample_fn(dist, key) 205 | entropy = -networks.log_prob(cut_dist, sample).mean() 206 | mean_sq_error = ((dist.mode() - transitions.action) ** 2).mean(axis=1) 207 | nllh = -networks.log_prob(cut_dist, transitions.action) 208 | 209 | mse1 = jnp.expand_dims(mean_sq_error, axis=1) 210 | nllh1 = jnp.expand_dims(nllh, axis=1) 211 | 212 | loss_params = params["loss"] 213 | w_virtual = loss_params["w_virtual"] 214 | w = jax.nn.softmax(w_virtual) 215 | log_w = jnp.expand_dims(jnp.log(w), axis=0) 216 | iso_sigma = jax.nn.softplus(loss_params["sigma_virtual"]) 217 | sigma = iso_sigma * jnp.ones((action_dimension,)) 218 | 219 | dist0 = tfd.MultivariateNormalDiag( 220 | loc=jnp.zeros((action_dimension,)), scale_diag=sigma 221 | ) 222 | 223 | mse0 = jnp.expand_dims((transitions.action**2).mean(axis=1), axis=1) 224 | mse0 = mse0 / iso_sigma 225 | nllh0 = -jnp.expand_dims(dist0.log_prob(transitions.action), axis=1) 226 | 227 | mse_ = jnp.concatenate([mse0, mse1], axis=1) 228 | nllh_ = jnp.concatenate([nllh0, nllh1], axis=1) 229 | 230 | # Do a mixture likelihood. 231 | # log sum_i w_i p(x|i) = logsumexp(log_p(x|i) + log_w) along a new dimension 232 | # MSE needs temp for sharp minimum estimate. 233 | mse_lse_temp = 1000.0 234 | # Don't do logsumexp do the other form. 235 | mmse = ( 236 | -jax.scipy.special.logsumexp((mse_lse_temp * -mse_ + log_w), axis=1) 237 | / mse_lse_temp 238 | ) 239 | mnllh = -jax.scipy.special.logsumexp((-nllh_ + log_w), axis=1) 240 | mmse, mnllh = mmse.mean(), mnllh.mean() 241 | 242 | loss_ = mmse + mnllh 243 | metrics = { 244 | "mse": mmse, 245 | "nllh": mnllh, 246 | "ent": entropy, 247 | } 248 | 249 | return loss_, metrics 250 | 251 | def extend_params(params: networks_lib.Params) -> networks_lib.Params: 252 | new_params = { 253 | "loss": { 254 | "w_virtual": jnp.array([0.0, 1.0]), 255 | "sigma_virtual": -1.0 * jnp.ones((1,)), 256 | }, 257 | "model": params, 258 | } 259 | return new_params 260 | 261 | def retract_params(params: networks_lib.Params) -> networks_lib.Params: 262 | return params["model"] if "model" in params else params 263 | 264 | return loss, extend_params, retract_params 265 | 266 | 267 | # Map enum to loss function. 
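The 'faithful' losses above combine an MSE term on the distribution's mode with a negative log-likelihood evaluated under a second, gradient-cut copy of the distribution, so the point estimate is trained purely by regression while the likelihood term only shapes the predictive uncertainty. A standalone Gaussian sketch of one reading of that construction (the real networks return the cut distribution themselves via faithful_distributions=True):

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

def faithful_gaussian_loss(mean, stddev, target):
  # The mean is fit by MSE alone.
  mse = jnp.mean((mean - target) ** 2)
  # The likelihood sees a stop-gradient copy of the mean, so fitting the
  # scale cannot drag the point estimate away from the data.
  cut = tfd.Normal(jax.lax.stop_gradient(mean), stddev)
  nllh = -jnp.mean(cut.log_prob(target))
  return mse + nllh

The lookup that follows maps each config.Losses member onto one of these loss constructors.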
268 | _LOOKUP = { 269 | config.Losses.FAITHFUL: faithful_loss, 270 | config.Losses.DBFAITHFUL: zero_debiased_faithful_loss, 271 | config.Losses.MSE: mse, 272 | config.Losses.NLLH: negative_loglikelihood, 273 | } 274 | 275 | 276 | def get_loss_function( 277 | loss_type: config.Losses, 278 | ) -> Callable[[int], ExtendedBCLossWithAux]: 279 | assert loss_type in _LOOKUP 280 | return _LOOKUP[loss_type] 281 | 282 | 283 | TerminateCondition = Callable[[list[dict[str, jnp.ndarray]]], bool] 284 | 285 | 286 | class EarlyStoppingBCLearner(bc.BCLearner): 287 | """Behavioural cloning learner that stops based on metrics.""" 288 | 289 | def __init__(self, terminate_condition: TerminateCondition, *args, **kwargs): 290 | self.metrics = [] 291 | self.terminate_condition = terminate_condition 292 | self.terminate = False 293 | super().__init__(*args, **kwargs) 294 | 295 | def step(self): 296 | # Get a batch of Transitions. 297 | transitions = next(self._prefetching_iterator) 298 | self._state, metrics = self._sgd_step(self._state, transitions) 299 | metrics = utils.get_from_first_device(metrics) 300 | # Compute elapsed time. 301 | timestamp = time.time() 302 | elapsed_time = timestamp - self._timestamp if self._timestamp else 0 303 | self._timestamp = timestamp 304 | # Increment counts and record the current time. 305 | counts = self._counter.increment(steps=1, walltime=elapsed_time) 306 | # Attempts to write the logs. 307 | self._logger.write({**metrics, **counts}) 308 | self.metrics += [metrics,] 309 | self.terminate = self.terminate_condition(self.metrics) 310 | 311 | 312 | def behavioural_cloning_pretraining( 313 | seed: int, 314 | env_spec: specs.EnvironmentSpec, 315 | dataset_factory: Callable[[int], Iterator[types.Transition]], 316 | policy: networks_lib.FeedForwardNetwork, 317 | loss: config.Losses = config.Losses.FAITHFUL, 318 | num_steps: int = 40_000, 319 | learning_rate: float = 1e-4, 320 | logger: Optional[loggers.Logger] = None, 321 | name: str = "", 322 | ) -> networks_lib.Params: 323 | """Trains the policy and returns the params single-threaded training loop. 324 | 325 | Args: 326 | seed: Random seed for training. 327 | env_spec: Environment specification. 328 | dataset_factory: A function that returns an iterator with demonstrations to 329 | be imitated. 330 | policy: Policy network model. 331 | loss: loss type for pretraining (e.g. MSE, log likelihood, ...) 332 | num_steps: Number of training steps. 333 | learning_rate: Used for regression. 334 | logger: Optional external object for logging. 335 | name: Name used for logger. 336 | 337 | Returns: 338 | The trained network params. 339 | """ 340 | key = random.PRNGKey(seed) 341 | 342 | logger = logger or experiment_utils.make_experiment_logger(f"pretrainer_policy{name}") 343 | 344 | # Train using log likelihood. 345 | n_actions = np.prod(env_spec.actions.shape) 346 | loss_fn = _LOOKUP[loss] 347 | bc_loss, extend_params, retract_params = loss_fn(n_actions) 348 | 349 | bc_policy_network = bc.convert_to_bc_network(policy) 350 | # Add loss terms to params here. 351 | policy_network = bc.BCPolicyNetwork( 352 | lambda key: extend_params(policy.init(key)), bc_policy_network.apply 353 | ) 354 | 355 | bc_network = bc.BCNetworks( 356 | policy_network=policy_network, 357 | log_prob=lambda params, acts: params.log_prob(acts), 358 | # For BC agent, the sample_fn is used for evaluation. 
359 | sample_fn=lambda params, key: params.sample(seed=key), 360 | ) 361 | 362 | dataset = dataset_factory(seed) 363 | 364 | counter = counting.Counter(prefix="policy_pretrainer", time_delta=0.0) 365 | 366 | history = 50 367 | ent_threshold = -2 * n_actions 368 | 369 | def terminate_condition(metrics: Sequence[Dict[str, jnp.ndarray]]) -> bool: 370 | if len(metrics) < history: 371 | return False 372 | else: 373 | return all(m["ent"] < ent_threshold for m in metrics[-history:]) 374 | 375 | learner = EarlyStoppingBCLearner( 376 | terminate_condition=terminate_condition, 377 | loss_fn=bc_loss, 378 | optimizer=optax.adam(learning_rate=learning_rate), 379 | random_key=key, 380 | networks=bc_network, 381 | prefetching_iterator=utils.sharded_prefetch(dataset), 382 | loss_has_aux=True, 383 | num_sgd_steps_per_step=1, 384 | logger=logger, 385 | counter=counter,) 386 | 387 | # Train the agent. 388 | for _ in range(num_steps): 389 | learner.step() 390 | # learner.terminate is available 391 | 392 | policy_and_loss_params = learner.get_variables(["policy"])[0] 393 | del learner # Ensure logger is closed. 394 | # Remove loss terms from params here. 395 | return retract_params(policy_and_loss_params) 396 | 397 | 398 | class TrainingState(NamedTuple): 399 | params: networks_lib.Params 400 | target_params: networks_lib.Params 401 | opt_state: optax.OptState 402 | 403 | 404 | def critic_pretraining( 405 | seed: int, 406 | dataset_factory: Callable[[int], Iterator[types.Transition]], 407 | critic: networks_lib.FeedForwardNetwork, 408 | critic_params: networks_lib.Params, 409 | reward: networks_lib.FeedForwardNetwork, 410 | reward_params: networks_lib.Params, 411 | discount_factor: float, 412 | num_steps: int = 10_000, 413 | learning_rate: float = 5e-3, 414 | counter: Optional[counting.Counter] = None, 415 | logger: Optional[loggers.Logger] = None, 416 | ) -> networks_lib.Params: 417 | """Pretrain the critic using a SARSA loss. 418 | 419 | Args: 420 | seed: for randomized training 421 | dataset_factory: SARSA data iterator for pretraining 422 | critic: critic function 423 | critic_params: initial critic params 424 | reward: reward function 425 | reward_params: known reward params 426 | discount_factor: discount used for Bellman equation 427 | num_steps: number of update steps 428 | learning_rate: learning rate of optimizer 429 | counter: used for logging 430 | logger: Optional external object for logging. 431 | 432 | Returns: 433 | Trained critic params. 434 | """ 435 | key = jax.random.PRNGKey(seed) 436 | optimiser = optax.adam(learning_rate) 437 | 438 | initial_opt_state = optimiser.init(critic_params) 439 | 440 | state = TrainingState(critic_params, critic_params, initial_opt_state) 441 | 442 | dataset_iterator = dataset_factory(seed) 443 | 444 | tau = 0.005 445 | 446 | sample = next(dataset_iterator) 447 | assert "next_action" in sample.extras, "Require SARSA dataset." 
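# ------------------------------------------------------------------------------
# Illustration (editorial sketch, not part of pretraining.py): the assert above
# requires a SARSA-style dataset in which every Transition also carries the
# action taken at the *next* step inside its `extras` mapping. Assuming a
# recorded trajectory of observations, actions and rewards, such transitions
# could be assembled roughly as follows (this is not the loader used here):
import numpy as np
from acme import types

def sarsa_transitions(observations, actions, rewards, discount=1.0):
  """Yields Transitions whose extras contain the next action."""
  for t in range(len(actions) - 1):
    yield types.Transition(
        observation=observations[t],
        action=actions[t],
        reward=rewards[t],
        discount=np.asarray(discount, dtype=np.float32),
        next_observation=observations[t + 1],
        extras={'next_action': actions[t + 1]},
    )
# ------------------------------------------------------------------------------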
448 | 449 | @jax.jit 450 | def loss( 451 | params: networks_lib.Params, 452 | target_params: networks_lib.Params, 453 | observation: jnp.ndarray, 454 | action: jnp.ndarray, 455 | next_observation: jnp.ndarray, 456 | next_action: jnp.ndarray, 457 | discount: jnp.ndarray, 458 | key: jax_types.PRNGKey, 459 | ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]: 460 | """.""" 461 | del key 462 | 463 | r = jnp.ravel(reward.apply(reward_params, observation, action)) 464 | 465 | next_v_target = critic.apply( 466 | target_params, next_observation, next_action).min(axis=-1) 467 | next_v_target = jax.lax.stop_gradient(next_v_target) 468 | discount_ = discount_factor * discount 469 | q_sarsa_target = jnp.expand_dims(r + discount_ * next_v_target, -1) 470 | q = critic.apply(params, observation, action) 471 | sarsa_loss = ((q_sarsa_target - q) ** 2).mean() 472 | 473 | def nonbatch_critic(s, a): 474 | batch_s = tree.map_structure(lambda f: f[None, ...], s) 475 | return critic.apply(params, batch_s, a[None, ...])[0] 476 | 477 | dqda = jax.vmap(jax.jacfwd(nonbatch_critic, argnums=1), in_axes=(0)) 478 | grads = dqda(observation, action) 479 | # Sum over actions, average over the rest. 480 | grad_norm = jnp.sqrt((grads**2).sum(axis=-1).mean()) 481 | 482 | loss = sarsa_loss + grad_norm 483 | 484 | metrics = { 485 | "loss": loss, 486 | "sarsa_loss": sarsa_loss, 487 | "grad_norm": grad_norm, 488 | } 489 | 490 | return loss, metrics 491 | 492 | @jax.jit 493 | def step(transition, state, key): 494 | if "next_action" in transition.extras: 495 | next_action = transition.extras["next_action"] 496 | else: 497 | next_action = transition.action 498 | values, grads = jax.value_and_grad(loss, has_aux=True)( 499 | state.params, 500 | state.target_params, 501 | transition.observation, 502 | transition.action, 503 | transition.next_observation, 504 | next_action, 505 | transition.discount, 506 | key, 507 | ) 508 | _, metrics = values 509 | updates, opt_state = optimiser.update(grads, state.opt_state) 510 | params = optax.apply_updates(state.params, updates) 511 | target_params = jax.tree_map( 512 | lambda x, y: x * (1 - tau) + y * tau, state.target_params, params 513 | ) 514 | return TrainingState(params, target_params, opt_state), metrics 515 | 516 | timestamp = time.time() 517 | counter = counter or counting.Counter( 518 | prefix="pretrainer_critic", time_delta=0.0 519 | ) 520 | logger = logger or loggers.make_default_logger( 521 | "pretrainer_critic", 522 | asynchronous=False, 523 | serialize_fn=utils.fetch_devicearray, 524 | steps_key=counter.get_steps_key(), 525 | ) 526 | for i in range(num_steps): 527 | _, key = jax.random.split(key) 528 | transitions = next(dataset_iterator) 529 | state, metrics = step(transitions, state, key) 530 | metrics["step"] = i 531 | timestamp_ = time.time() 532 | elapsed_time = timestamp_ - timestamp 533 | timestamp = timestamp_ 534 | counts = counter.increment(steps=1, walltime=elapsed_time) 535 | logger.write({**metrics, **counts}) 536 | 537 | logger.close() 538 | 539 | return state.params 540 | -------------------------------------------------------------------------------- /sil/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Soft imitation learning types, configurations and hyperparameters. 17 | 18 | Useful resources for these implementations: 19 | IQ-Learn: https://arxiv.org/abs/2106.12142 20 | https://github.com/Div99/IQ-Learn 21 | P^2IL: https://arxiv.org/abs/2209.10968 22 | https://github.com/lviano/P2IL 23 | """ 24 | import abc 25 | import dataclasses 26 | import enum 27 | from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union 28 | 29 | from acme import types 30 | from acme.agents.jax.sac import config as sac_config 31 | import jax 32 | from jax import lax 33 | from jax import random 34 | import jax.nn as jnn 35 | import jax.numpy as jnp 36 | import numpy as np 37 | import tree 38 | 39 | # Custom function and return typing for signature brevity. 40 | # Generalized reward function, including discounting. 41 | StateActionNextStateFunc = Callable[ 42 | [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], 43 | jnp.ndarray, 44 | ] 45 | # State-action critics. 46 | StateActionFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] 47 | # I.e. function regularizers. 48 | StateFunc = Callable[[jnp.ndarray], jnp.ndarray] 49 | # Soft (i.e. stochastic) value function. 50 | StochStateFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] 51 | # Losses and factories for defined here for brevity. 52 | AuxLoss = Tuple[jnp.ndarray, Dict[str, jnp.ndarray]] 53 | CriticLossFact = Callable[ 54 | [ 55 | StateActionNextStateFunc, 56 | StateActionFunc, 57 | StochStateFunc, 58 | StochStateFunc, 59 | float, 60 | types.Transition, 61 | types.Transition, 62 | jax.Array, 63 | ], 64 | AuxLoss, 65 | ] 66 | RewardFact = Callable[ 67 | [StateActionFunc, StateActionFunc, StochStateFunc, float], 68 | StateActionNextStateFunc, 69 | ] 70 | DemonstrationFact = Callable[..., Iterator[types.Transition]] 71 | ConcatenateFact = Callable[ 72 | [types.Transition, types.Transition], types.Transition 73 | ] 74 | 75 | 76 | # Used to bound policy log-likelihoods via the variance. 77 | MIN_VAR = 1e-5 78 | 79 | # Maximum reward for CSIL. 80 | # By having a lower bound on the variance, and setting alpha to be 81 | # (1 - \gamma) / dim_a, we can bound the maximum discounted return. 82 | MAX_REWARD = -0.5 * np.log(MIN_VAR) + np.log(2.) 83 | 84 | 85 | class Losses(enum.Enum): 86 | FAITHFUL = "faithful" 87 | DBFAITHFUL = "dbfaithful" 88 | MSE = "mse" 89 | NLLH = "nllh" 90 | 91 | def __str__(self) -> str: 92 | return self.value 93 | 94 | 95 | # Soft inverse Q learning has different objectives based on divergence choice 96 | # between expert and policy occupancy measure. 97 | class Divergence(enum.Enum): 98 | FORWARD_KL = 0 99 | REVERSE_KL = 1 100 | REVERSE_KL_DUAL = 2 101 | REVERSE_KL_UNBIASED = 3 102 | HELLINGER = 4 103 | JS = 5 104 | CHI = 6 105 | TOTAL_VARIATION = 7 106 | 107 | 108 | # Each f divergence can be written as 109 | # D_f[P || Q] = \sup_g E_P[g(x)] - E_Q[f^*(g(x))] 110 | # where f^* is the convex conjugate of f. 
111 | # For the objective, we require a function \phi(x) = -f^*(-x). 112 | # These functions implement the \phi(x) for a given divergence. 113 | # See Table 4 of Garg et al. (2021) and Ho et al. (2016) for more details. 114 | _forward_kl: StateFunc = lambda r: 1.0 + jnp.log(r) 115 | _reverse_kl: StateFunc = lambda r: jnp.exp(-r - 1.0) 116 | _reverse_kl_dual: StateFunc = lambda r: jnn.softmax(-r) * r.shape[0] 117 | _reverse_kl_unbiased: StateFunc = lambda r: jnp.exp(-r) 118 | _hellinger: StateFunc = lambda r: 1.0 / (1.0 + r) ** 2 119 | _chi: StateFunc = lambda r: r - r**2 / 2.0 120 | _total_variation: StateFunc = lambda r: r 121 | _js: StateFunc = lambda r: jnp.log(2.0 - jnp.exp(-r)) 122 | 123 | DIVERGENCE_REGULARIZERS = { 124 | Divergence.FORWARD_KL: _forward_kl, 125 | Divergence.REVERSE_KL: _reverse_kl, 126 | Divergence.REVERSE_KL_DUAL: _reverse_kl_dual, 127 | Divergence.REVERSE_KL_UNBIASED: _reverse_kl_unbiased, 128 | Divergence.HELLINGER: _hellinger, 129 | Divergence.CHI: _chi, 130 | Divergence.TOTAL_VARIATION: _total_variation, 131 | Divergence.JS: _js, 132 | } 133 | 134 | 135 | def concatenate(x: Any, y: Any) -> Any: 136 | return tree.map_structure(lambda x, y: jnp.concatenate((x, y), axis=0), x, y) 137 | 138 | 139 | def concatenate_transitions( 140 | x: types.Transition, y: types.Transition 141 | ) -> types.Transition: 142 | fields = ["observation", "action", "reward", "discount", "next_observation"] 143 | return types.Transition( 144 | *[concatenate(getattr(x, field), getattr(y, field)) for field in fields] 145 | ) 146 | 147 | 148 | @dataclasses.dataclass 149 | class SoftImitationConfig(abc.ABC): 150 | """Abstact base class for soft imitation learning.""" 151 | 152 | @abc.abstractmethod 153 | def critic_loss_factory(self) -> CriticLossFact: 154 | """Define the critic loss based on the algorithm.""" 155 | 156 | @abc.abstractmethod 157 | def reward_factory(self) -> RewardFact: 158 | """Define the reward based on the algorithm, i.e. implicit vs explicit.""" 159 | 160 | 161 | @dataclasses.dataclass 162 | class InverseSoftQConfig(SoftImitationConfig): 163 | """Configuration for the IQ-Learn algorithm.""" 164 | 165 | # Defines divergence-derived function regularization. 166 | divergence: Divergence = Divergence.CHI 167 | 168 | def critic_loss_factory(self) -> CriticLossFact: 169 | """Generate critic objective from hyperparameters.""" 170 | regularizer = DIVERGENCE_REGULARIZERS[self.divergence] 171 | 172 | def objective( 173 | reward_fn: StateActionNextStateFunc, 174 | state_action_value_fn: StateActionFunc, 175 | value_fn: StochStateFunc, 176 | target_value_fn: StochStateFunc, 177 | discount: float, 178 | demonstration_transitions: types.Transition, 179 | online_transitions: types.Transition, 180 | key: jax.Array, 181 | ) -> AuxLoss: 182 | """See Equation 10 of Garg et al. (2021) for reference.""" 183 | del target_value_fn 184 | key_er, key_v, key_or = random.split(key, 3) 185 | expert_reward = reward_fn( 186 | demonstration_transitions.observation, 187 | demonstration_transitions.action, 188 | demonstration_transitions.next_observation, 189 | demonstration_transitions.discount, 190 | key_er, 191 | ) 192 | phi_grad = regularizer(expert_reward).mean() 193 | expert_loss = -phi_grad.mean() 194 | 195 | # This is denoted the value of the initial state distribution in the paper 196 | # and codebase, but in practice the distribution is replaced with the 197 | # demonstration distribution. 
In practice this term ensures the Q 198 | # function is maximized at the demonstration data since here we are 199 | # minimizing the value for action sampled around the optimal expert. 200 | value_reg = (1 - discount) * value_fn( 201 | demonstration_transitions.observation, key_v 202 | ).mean() 203 | 204 | online_reward = reward_fn( 205 | online_transitions.observation, 206 | online_transitions.action, 207 | online_transitions.next_observation, 208 | online_transitions.discount, 209 | key_or, 210 | ) 211 | # See code implementation IQ-Learn/iq_learn/iq.py 212 | agent_loss = 0.5 * (online_reward**2).mean() 213 | 214 | metrics = { 215 | "expert_reward": expert_reward.mean(), 216 | "online_reward": online_reward.mean(), 217 | "value_reg": value_reg, 218 | "expert_loss": expert_loss, 219 | "online_reg": agent_loss, 220 | } 221 | return expert_loss + value_reg + agent_loss, metrics 222 | 223 | return objective 224 | 225 | def reward_factory(self) -> RewardFact: 226 | def reward_factory_( 227 | state_action_reward: StateActionFunc, 228 | state_action_value_function: StateActionFunc, 229 | state_value_function: StochStateFunc, 230 | discount_factor: float, 231 | ) -> StateActionNextStateFunc: 232 | del state_action_reward 233 | 234 | def reward_fn( 235 | state: jnp.ndarray, 236 | action: jnp.ndarray, 237 | next_state: jnp.ndarray, 238 | discount: jnp.ndarray, 239 | value_key: jnp.ndarray, 240 | ) -> jnp.ndarray: 241 | q = state_action_value_function(state, action) 242 | future_v = discount * state_value_function(next_state, value_key) 243 | return q - discount_factor * future_v 244 | 245 | return reward_fn 246 | 247 | return reward_factory_ 248 | 249 | 250 | @dataclasses.dataclass 251 | class ProximalPointConfig(SoftImitationConfig): 252 | """Configuration for the proximal point imitation learning algorithm.""" 253 | 254 | bellman_error_temperature: float = 1.0 255 | 256 | def critic_loss_factory(self) -> CriticLossFact: 257 | """Generate critic objective from hyperparameters. 258 | 259 | P^2IL's objective consists of four terms. 260 | Optimizing the experts rewards, minimizing the logistic Bellman error, 261 | optimizing the initial state value using a proxy form to improve the value 262 | function, and regularizing the reward function with a squared penality. 263 | 264 | Returns: 265 | Function that return the critic loss function. 266 | """ 267 | alpha = self.bellman_error_temperature 268 | 269 | def objective( 270 | reward_fn: StateActionNextStateFunc, 271 | state_action_value_fn: StateActionFunc, 272 | value_fn: StochStateFunc, 273 | target_value_fn: StochStateFunc, 274 | discount: float, 275 | demonstration_transitions: types.Transition, 276 | online_transitions: types.Transition, 277 | key: jax.Array, 278 | ) -> AuxLoss: 279 | """See Equation 67 and Theorem 6 of Viano et al. 
(2022) for reference.""" 280 | key_cr, key_er, key_v, key_fv, key_or = random.split(key, 5) 281 | expert_reward = reward_fn( 282 | demonstration_transitions.observation, 283 | demonstration_transitions.action, 284 | demonstration_transitions.next_observation, 285 | demonstration_transitions.discount, 286 | key_er, 287 | ) 288 | expert_reward_mean = expert_reward.mean() 289 | expert_q = state_action_value_fn( 290 | demonstration_transitions.observation, 291 | demonstration_transitions.action, 292 | ) 293 | expert_v = value_fn(demonstration_transitions.observation, key_v) 294 | expert_next_v = value_fn( 295 | demonstration_transitions.next_observation, key_fv 296 | ) 297 | expert_d = discount * demonstration_transitions.discount 298 | expert_imp_r_mean = (expert_q - expert_d * expert_next_v).mean() 299 | 300 | online_r = reward_fn( 301 | online_transitions.observation, 302 | online_transitions.action, 303 | online_transitions.next_observation, 304 | online_transitions.discount, 305 | key_or, 306 | ) 307 | online_q = state_action_value_fn( 308 | online_transitions.observation, online_transitions.action 309 | ) 310 | online_v = value_fn(online_transitions.observation, key_v) 311 | online_d = online_transitions.discount 312 | combined_transition = concatenate_transitions( 313 | demonstration_transitions, online_transitions 314 | ) 315 | combined_reward = reward_fn( 316 | combined_transition.observation, 317 | combined_transition.action, 318 | combined_transition.next_observation, 319 | online_transitions.discount, 320 | key_cr, 321 | ) 322 | combined_q = state_action_value_fn( 323 | combined_transition.observation, combined_transition.action 324 | ) 325 | # The Bellman equation needs the target value function for stability, 326 | # while the regularization shapes the current value function. 327 | combined_next_target_v = target_value_fn( 328 | combined_transition.next_observation, key_fv 329 | ) 330 | 331 | # In theory, the reward function should be jointly optimized in the 332 | # Bellman equation to minimize the on-policy rewards, however, empirically 333 | # this produced worse results as the Bellman equation is harder to 334 | # minimize. 335 | discount_ = discount * combined_transition.discount 336 | combined_q_target = combined_reward + discount_ * combined_next_target_v 337 | bellman_error = lax.stop_gradient(combined_q_target) - combined_q 338 | # In theory, the Bellman error should not be negative as the Q and V 339 | # function should roughly track in magnitude, but in practice this was not 340 | # always the case. 341 | bellman_error = jnp.abs(bellman_error) 342 | # Self-normalized importance sampling weights z in Theorem 6 of 343 | # Vivano et al. (2022), used in logsumexp objective. 344 | log_weights = lax.stop_gradient(alpha * bellman_error) 345 | norm_weights = jnn.softmax(log_weights) 346 | ess = 1.0 / (norm_weights**2).sum() # effective sample size 347 | is_bellman_error = jnp.einsum("b,b->", norm_weights, bellman_error) 348 | sq_bellman_error = (bellman_error**2).mean() 349 | rms_bellman_error = jnp.sqrt(sq_bellman_error) 350 | 351 | apprenticeship_loss = -expert_reward.mean() + jnp.einsum( 352 | "b,b->", norm_weights, combined_reward 353 | ) 354 | 355 | # This is denoted the value of the initial state distribution in the paper 356 | # and codebase, but in practice the distribution is replaced with the 357 | # demonstration distribution. 
In practice this term ensures the Q 358 | # function is maximized at the demonstration data since here we are 359 | # minimizing the value for action sampled around the optimal expert. 360 | value_reg = (1.0 - discount) * expert_v.mean() 361 | 362 | # P^2IL uses IQ-Learn Chi^2-based regularization in practice. 363 | expert_r_mean = expert_reward_mean 364 | function_reg = 0.5 * (combined_reward**2).mean() 365 | metrics = { 366 | "apprenticeship_loss": apprenticeship_loss, 367 | "expert_reward": expert_reward.mean(), 368 | "expert_reward_implicit": expert_imp_r_mean, 369 | "expert_reward_combined": expert_r_mean, 370 | "online_reward": online_r.mean(), 371 | "value_reg": value_reg, 372 | "function_reg": function_reg, 373 | "sq_bellman_error": sq_bellman_error, 374 | "is_bellman_error": is_bellman_error, 375 | "ess": ess, 376 | } 377 | return ( 378 | -expert_r_mean + is_bellman_error + value_reg + function_reg, 379 | metrics, 380 | ) 381 | 382 | return objective 383 | 384 | def reward_factory(self) -> RewardFact: 385 | """P^2IL's reward function is a straightforwad MLP.""" 386 | 387 | def reward_factory_( 388 | state_action_reward: StateActionFunc, 389 | state_action_value_function: StateActionFunc, 390 | state_value_function: StochStateFunc, 391 | discount_factor: float, 392 | ) -> StateActionNextStateFunc: 393 | del state_action_value_function, state_value_function, discount_factor 394 | 395 | def reward_fn( 396 | state: jnp.ndarray, 397 | action: jnp.ndarray, 398 | next_state: jnp.ndarray, 399 | discount: jnp.ndarray, 400 | value_key: jnp.ndarray, 401 | ) -> jnp.ndarray: 402 | del next_state, discount, value_key 403 | return state_action_reward(state, action) 404 | 405 | return reward_fn 406 | 407 | return reward_factory_ 408 | 409 | 410 | @dataclasses.dataclass 411 | class CoherentConfig(SoftImitationConfig): 412 | """Coherent soft imitation learning.""" 413 | 414 | alpha: float # temperature used in the coherent reward 415 | reward_scaling: float = 1.0 416 | scale_factor: float = 1.0 # Scaling of online reward regularization. 417 | grad_norm_sf: float = 1.0 # Critic action Jacobian regularization. 
418 | refine_reward: bool = True 419 | negative_reward: bool = False 420 | 421 | def critic_loss_factory(self) -> CriticLossFact: 422 | def objective( 423 | reward_fn: StateActionNextStateFunc, 424 | state_action_value_fn: StateActionFunc, 425 | value_fn: StochStateFunc, 426 | target_value_fn: StochStateFunc, 427 | discount: float, 428 | demonstration_transitions: types.Transition, 429 | online_transitions: types.Transition, 430 | key: jax.Array, 431 | ) -> AuxLoss: 432 | key_er, key_fv, key_or = random.split(key, 3) 433 | combined_transition = concatenate_transitions( 434 | demonstration_transitions, online_transitions 435 | ) 436 | 437 | online_reward = reward_fn( 438 | online_transitions.observation, 439 | online_transitions.action, 440 | online_transitions.next_observation, 441 | online_transitions.discount, 442 | key_or, 443 | ) 444 | 445 | expert_reward = reward_fn( 446 | demonstration_transitions.observation, 447 | demonstration_transitions.action, 448 | demonstration_transitions.next_observation, 449 | online_transitions.discount, 450 | key_er, 451 | ) 452 | 453 | reward = lax.stop_gradient(jnp.concatenate( 454 | (expert_reward, online_reward), axis=0)) 455 | 456 | state_action_value = state_action_value_fn( 457 | combined_transition.observation, combined_transition.action 458 | ) 459 | future_value = target_value_fn( 460 | combined_transition.next_observation, key_fv 461 | ) 462 | # Use knowledge to bound rogue Q values. 463 | max_value = 0. if self.negative_reward else MAX_REWARD / (1. - discount) 464 | future_value = jnp.clip(future_value, a_max=max_value) 465 | 466 | discount_ = discount * combined_transition.discount 467 | 468 | target_state_action_value = (reward + discount_ * future_value) 469 | bellman_error = state_action_value - target_state_action_value 470 | 471 | sbe = (bellman_error**2).mean() 472 | be = bellman_error.mean() 473 | 474 | # The mean expert reward corresponse to maximizing BC likelihood + const., 475 | # minimizing the online reward corresponds to minimizing KL to prior. 476 | # Use an unbiased KL estimator that can never be negative.al 477 | expert_reward_mean = expert_reward.mean() # + imp_expert_reward.mean() 478 | 479 | # In some cases the online reward can go as low as -50 480 | # even when the mean is positive. To be robust to these outliers, we just 481 | # clip the negative online rewards, as the role of this term is to 482 | # regularize the large positive values. 483 | # This KL estimator is motivated in http://joschu.net/blog/kl-approx.html 484 | if self.negative_reward: 485 | online_log_ratio = online_reward + self.reward_scaling * MAX_REWARD 486 | else: 487 | online_log_ratio = online_reward 488 | safe_online_log_ratio = jnp.maximum(online_log_ratio, -5.0) 489 | # the estimator is log r + 1/r - 1, and the reward is alpha log r 490 | policy_kl_est = ( 491 | jnp.exp(-safe_online_log_ratio) - 1. 
+ online_log_ratio 492 | ).mean() 493 | 494 | def non_batch_state_action_value_fn(s, a): 495 | batch_s = tree.map_structure(lambda f: f[None, ...], s) 496 | return state_action_value_fn(batch_s, a[None, ...])[0] 497 | dqda = jax.vmap( 498 | jax.jacrev(non_batch_state_action_value_fn, argnums=1), in_axes=(0)) 499 | grads = dqda(demonstration_transitions.observation, 500 | demonstration_transitions.action) 501 | grad_norm = jnp.sqrt((grads**2).sum(axis=1).mean()) 502 | 503 | loss = sbe 504 | if self.refine_reward: 505 | loss -= expert_reward_mean 506 | loss += self.scale_factor * policy_kl_est 507 | loss += self.grad_norm_sf * grad_norm 508 | 509 | metrics = { 510 | "critic_action_grad_norm": grad_norm, 511 | "sq_bellman_error": sbe, 512 | "expert_reward": expert_reward.mean(), 513 | "online_reward": online_reward.mean(), 514 | "kl_est": policy_kl_est, 515 | } 516 | return loss, metrics 517 | 518 | return objective 519 | 520 | def reward_factory(self) -> RewardFact: 521 | """CSIL's reward function is policy-derived.""" 522 | 523 | def reward_factory_( 524 | state_action_reward: StateActionFunc, 525 | state_action_value_function: StateActionFunc, 526 | state_value_function: StochStateFunc, 527 | discount_factor: float, 528 | ) -> StateActionNextStateFunc: 529 | del state_action_value_function, state_value_function, discount_factor 530 | 531 | def reward_fn( 532 | state: jnp.ndarray, 533 | action: jnp.ndarray, 534 | next_state: jnp.ndarray, 535 | discount: jnp.ndarray, 536 | value_key: jnp.ndarray, 537 | ) -> jnp.ndarray: 538 | del next_state, discount, value_key 539 | return state_action_reward(state, action) 540 | 541 | return reward_fn 542 | 543 | return reward_factory_ 544 | 545 | 546 | @dataclasses.dataclass 547 | class PretrainingConfig: 548 | """Parameters for pretraining a model.""" 549 | 550 | dataset_factory: DemonstrationFact 551 | learning_rate: float 552 | steps: int 553 | seed: int 554 | loss: Optional[Losses] = None # Used for policy pretraining. 555 | use_as_reference: bool = False # Use pretrained policy as prior. 556 | 557 | 558 | def null_data_factory( 559 | n_demonstrations: int, seed: int 560 | ) -> Iterator[types.Transition]: 561 | del n_demonstrations, seed 562 | raise NotImplementedError() 563 | 564 | 565 | SilConfigTypes = Union[InverseSoftQConfig, ProximalPointConfig, CoherentConfig] 566 | 567 | 568 | @dataclasses.dataclass 569 | class SILConfig(sac_config.SACConfig): 570 | """Configuration options for soft imitation learning.""" 571 | 572 | # Imitation learning hyperparameters. 573 | expert_demonstration_factory: DemonstrationFact = null_data_factory 574 | imitation: SilConfigTypes = ( 575 | dataclasses.field(default_factory=InverseSoftQConfig)) 576 | actor_bc_loss: bool = False 577 | policy_pretraining: Optional[List[PretrainingConfig]] = None 578 | critic_pretraining: Optional[PretrainingConfig] = None 579 | actor_learning_rate: float = 3e-4 580 | reward_learning_rate: float = 3e-4 581 | critic_learning_rate: float = 3e-4 582 | critic_actor_update_ratio: int = 1 583 | alpha_learning_rate: float = 3e-4 584 | alpha_init: float = 1.0 585 | damping: float = 0.0 # Entropy constraint damping. 586 | -------------------------------------------------------------------------------- /sil/learning.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Soft imitation learning learner implementation.""" 17 | 18 | from __future__ import annotations 19 | import time 20 | from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple 21 | 22 | import acme 23 | from acme import types 24 | from acme.jax import networks as networks_lib 25 | from acme.jax import utils 26 | from acme.utils import counting 27 | from acme.utils import loggers 28 | import jax 29 | from jax import lax 30 | import jax.numpy as jnp 31 | import optax 32 | 33 | from sil import config as sil_config 34 | from sil import networks as sil_networks 35 | from sil import pretraining 36 | 37 | # useful for analysis and comparing algorithms 38 | MONITOR_BC_METRICS = False 39 | 40 | 41 | class ImitationSample(NamedTuple): 42 | """For imitation learning, we require agent and demonstration experience.""" 43 | 44 | online_sample: types.Transition 45 | demonstration_sample: types.Transition 46 | 47 | 48 | class TrainingState(NamedTuple): 49 | """Contains training state for the learner.""" 50 | policy_optimizer_state: optax.OptState 51 | q_optimizer_state: optax.OptState 52 | r_optimizer_state: optax.OptState 53 | policy_params: networks_lib.Params 54 | q_params: networks_lib.Params 55 | target_q_params: networks_lib.Params 56 | r_params: networks_lib.Params 57 | key: networks_lib.PRNGKey 58 | bc_policy_params: Optional[networks_lib.Params] = None 59 | alpha_optimizer_state: Optional[optax.OptState] = None 60 | alpha_params: Optional[networks_lib.Params] = None 61 | 62 | 63 | class SILLearner(acme.Learner): 64 | """Soft imitation learning learner.""" 65 | 66 | _state: TrainingState 67 | 68 | def __init__( 69 | self, 70 | networks: sil_networks.SILNetworks, 71 | critic_loss_def: sil_config.CriticLossFact, 72 | reward_factory: sil_config.RewardFact, 73 | rng: jnp.ndarray, 74 | dataset: Iterator[ImitationSample], 75 | policy_optimizer: optax.GradientTransformation, 76 | q_optimizer: optax.GradientTransformation, 77 | r_optimizer: optax.GradientTransformation, 78 | tau: float = 0.005, 79 | discount: float = 0.99, 80 | critic_actor_update_ratio: int = 1, 81 | alpha_init: float = 1.0, 82 | alpha_learning_rate: float = 1e-3, 83 | entropy_coefficient: Optional[float] = None, 84 | target_entropy: float = 0.0, 85 | actor_bc_loss: bool = False, 86 | damping: float = 0.0, 87 | policy_pretraining: Optional[List[sil_config.PretrainingConfig]] = None, 88 | critic_pretraining: Optional[sil_config.PretrainingConfig] = None, 89 | counter: Optional[counting.Counter] = None, 90 | learner_logger: Optional[loggers.Logger] = None, 91 | policy_pretraining_loggers: Optional[List[loggers.Logger]] = None, 92 | critic_pretraining_logger: Optional[loggers.Logger] = None, 93 | num_sgd_steps_per_step: int = 1, 94 | ): 95 | """Initialize the soft imitation learning learner. 
96 | 97 | Args: 98 | networks: SIL networks 99 | critic_loss_def: loss function definition for critic 100 | reward_factory: create implicit or explicit reward functions 101 | rng: a key for random number generation. 102 | dataset: an iterator over demonstrations and online data. 103 | policy_optimizer: the policy optimizer. 104 | q_optimizer: the Q-function optimizer. 105 | r_optimizer: the reward function optimizer. 106 | tau: target smoothing coefficient. 107 | discount: discount to use for TD updates. 108 | critic_actor_update_ratio: critic updates per single actor update. 109 | alpha_init: 110 | alpha_learning_rate: 111 | entropy_coefficient: coefficient applied to the entropy bonus. If None, an 112 | adaptative coefficient will be used. 113 | target_entropy: Used to normalize entropy. Only used when 114 | entropy_coefficient is None. 115 | actor_bc_loss: add auxiliary BC term to actor objective (unused) 116 | damping: damping of KL constraint 117 | policy_pretraining: Optional config for pretraining policy 118 | critic_pretraining: Optional config for pretraining critic 119 | counter: counter object used to keep track of steps. 120 | learner_logger: logger object to be used by learner. 121 | policy_pretraining_loggers: logger objects to be used by the policy pretraining. 122 | critic_pretraining_logger: logger object to be used by critic pretraining. 123 | num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'. 124 | """ 125 | 126 | adaptive_entropy_coefficient = entropy_coefficient is None 127 | kl_bound = jnp.abs(target_entropy) 128 | 129 | if adaptive_entropy_coefficient: 130 | # Alpha is the temperature parameter that determines the relative 131 | # importance of the entropy term versus the reward. 132 | # Invert softplus to initial virtual value. 133 | if alpha_init > 4.: 134 | virtual_alpha = jnp.asarray(alpha_init, dtype=jnp.float32) 135 | else: # safely invert softplus 136 | virtual_alpha = jnp.log( 137 | jnp.exp(jnp.asarray(alpha_init, dtype=jnp.float32)) - 1.0 138 | ) 139 | alpha_optimizer = optax.sgd(learning_rate=alpha_learning_rate) 140 | alpha_optimizer_state = alpha_optimizer.init(virtual_alpha) 141 | else: 142 | if target_entropy: 143 | raise ValueError( 144 | 'target_entropy should not be set when ' 145 | 'entropy_coefficient is provided' 146 | ) 147 | 148 | def make_initial_state( 149 | key: networks_lib.PRNGKey, 150 | ) -> Tuple[TrainingState, Optional[networks_lib.Params], bool]: 151 | """Initialises the training state (parameters and optimiser state).""" 152 | key_policy, key_q, key = jax.random.split(key, 3) 153 | 154 | # In the online setting we pretrain the policy against the demonstrations. 155 | # In the offfline setting we pretrain the policy against the dataset 156 | # and demonstrations. 157 | # In both cases, we use the last trained policy for the CSIL reward, 158 | # and in the offline case, we use the first policy as the 'BC' policy 159 | # to stay close to the data. 
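# ------------------------------------------------------------------------------
# Illustration (editorial sketch with hypothetical dataset factories and step
# counts): following the comment above, an offline-style `policy_pretraining`
# argument could be a list of two PretrainingConfig entries -- one fitted to the
# broad offline data and flagged as the reference prior, and one refined on the
# demonstrations, whose parameters then also initialise the coherent reward.
example_policy_pretraining = [
    sil_config.PretrainingConfig(
        dataset_factory=offline_dataset_factory,        # Hypothetical factory.
        learning_rate=1e-4,
        steps=40_000,
        seed=0,
        loss=sil_config.Losses.FAITHFUL,
        use_as_reference=True,
    ),
    sil_config.PretrainingConfig(
        dataset_factory=demonstration_dataset_factory,  # Hypothetical factory.
        learning_rate=1e-4,
        steps=20_000,
        seed=0,
        loss=sil_config.Losses.FAITHFUL,
    ),
]
# ------------------------------------------------------------------------------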
160 | bc_policy_params = [] 161 | use_pretrained_prior = False 162 | if policy_pretraining: 163 | for i, pt in enumerate(policy_pretraining): 164 | use_pretrained_prior = use_pretrained_prior or pt.use_as_reference 165 | if not bc_policy_params: 166 | policy_ = networks_lib.FeedForwardNetwork( 167 | networks.bc_policy_network.init, 168 | networks.bc_policy_network.apply, 169 | ) 170 | else: 171 | policy_ = networks_lib.FeedForwardNetwork( 172 | lambda key: bc_policy_params[0], 173 | networks.bc_policy_network.apply, 174 | ) 175 | params = pretraining.behavioural_cloning_pretraining( 176 | loss=pt.loss, 177 | seed=pt.seed, 178 | env_spec=networks.environment_specs, 179 | dataset_factory=pt.dataset_factory, 180 | policy=policy_, 181 | learning_rate=pt.learning_rate, 182 | num_steps=pt.steps, 183 | logger=policy_pretraining_loggers[i], 184 | name=f'{i}', 185 | ) 186 | bc_policy_params += [params,] 187 | else: 188 | bc_policy_params = [None] 189 | # While IQ-Learn and P2IL use policy pretraining for the policy, CSIL can 190 | # use it only for the reward initialization. 191 | policy_match = ( 192 | networks.policy_architecture == networks.bc_policy_architecture 193 | ) 194 | 195 | if policy_match and policy_pretraining: 196 | policy_params = bc_policy_params[0].copy() 197 | else: 198 | policy_params = networks.policy_network.init(key_policy) 199 | 200 | policy_optimizer_state = policy_optimizer.init(policy_params) 201 | 202 | if networks.reward_policy_coherence and bc_policy_params[-1]: 203 | r_params = bc_policy_params[-1].copy() 204 | else: 205 | r_params = networks.reward_network.init(key_q) 206 | # Share encoder with policy if present. 207 | r_params = sil_networks.update_encoder(r_params, policy_params) 208 | 209 | r_optimizer_state = r_optimizer.init(r_params) 210 | 211 | if critic_pretraining is not None: 212 | critic_ = networks_lib.FeedForwardNetwork( 213 | networks.critic_network.init, networks.critic_network.apply 214 | ) 215 | reward_ = networks_lib.FeedForwardNetwork( 216 | networks.reward_network.init, networks.reward_network.apply 217 | ) 218 | policy_ = networks_lib.FeedForwardNetwork( 219 | networks.policy_network.init, networks.policy_network.apply 220 | ) 221 | critic_params = critic_.init(key_q) 222 | critic_params = sil_networks.update_encoder( 223 | critic_params, policy_params) 224 | critic_params = pretraining.critic_pretraining( 225 | seed=critic_pretraining.seed, 226 | dataset_factory=critic_pretraining.dataset_factory, 227 | critic=critic_, 228 | critic_params=critic_params, 229 | reward=reward_, 230 | reward_params=r_params, 231 | discount_factor=discount, 232 | num_steps=critic_pretraining.steps, 233 | learning_rate=critic_pretraining.learning_rate, 234 | logger=critic_pretraining_logger, 235 | ) 236 | else: 237 | critic_params = networks.critic_network.init(key_q) 238 | # Share encoder with policy if present. 
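        # (update_encoder, defined in sil/networks.py, partitions the parameter
        # tree on module names containing 'encoder' and copies those entries
        # from the policy params, so the critic and policy share one
        # observation encoder.)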
239 | critic_params = sil_networks.update_encoder( 240 | critic_params, policy_params) 241 | 242 | q_optimizer_state = q_optimizer.init(critic_params) 243 | 244 | state = TrainingState( 245 | policy_optimizer_state=policy_optimizer_state, 246 | q_optimizer_state=q_optimizer_state, 247 | r_optimizer_state=r_optimizer_state, 248 | policy_params=policy_params, 249 | q_params=critic_params, 250 | target_q_params=critic_params, 251 | r_params=r_params, 252 | bc_policy_params=bc_policy_params[-1], 253 | key=key, 254 | ) 255 | 256 | if adaptive_entropy_coefficient: 257 | state = state._replace( 258 | alpha_optimizer_state=alpha_optimizer_state, 259 | alpha_params=virtual_alpha, 260 | ) 261 | return state, bc_policy_params[-1], use_pretrained_prior 262 | 263 | # Create initial state. 264 | self._state, bc_policy_params, use_policy_prior = make_initial_state(rng) 265 | 266 | if use_policy_prior: 267 | assert bc_policy_params is not None 268 | 269 | def alpha_loss( 270 | virtual_alpha: jnp.ndarray, 271 | policy_params: networks_lib.Params, 272 | transitions: types.Transition, 273 | key: networks_lib.PRNGKey, 274 | ) -> jnp.ndarray: 275 | """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" 276 | dist_params = networks.policy_network.apply( 277 | policy_params, transitions.observation 278 | ) 279 | action = dist_params.sample(seed=key) 280 | log_prob = networks.log_prob(dist_params, action) 281 | alpha = jax.nn.softplus(virtual_alpha) 282 | if use_policy_prior: # bc_policy_params are not None 283 | dist_bc = networks.bc_policy_network.apply( 284 | bc_policy_params, transitions.observation 285 | ) 286 | prior_log_prob = networks.log_prob(dist_bc, action) 287 | kl = (log_prob - prior_log_prob).mean() 288 | # KL constraint. 289 | constraint = jax.lax.stop_gradient(kl_bound - kl) 290 | # Zero if kl < kl_bound, negative if violated. 291 | # We want temp to go zero if kl not violated, so don't clip here. 292 | loss = constraint 293 | # Do gradient ascent so invert sign w.r.t. actor loss term. 294 | alpha_loss = alpha * loss 295 | else: 296 | alpha_loss = ( 297 | alpha * jax.lax.stop_gradient(-log_prob - target_entropy).mean() 298 | ) 299 | return alpha_loss 300 | 301 | def critic_loss( 302 | q_params: networks_lib.Params, 303 | r_params: networks_lib.Params, 304 | policy_params: networks_lib.Params, 305 | target_q_params: networks_lib.Params, 306 | alpha: jnp.ndarray, 307 | demonstration_transitions: types.Transition, 308 | online_transitions: types.Transition, 309 | key: networks_lib.PRNGKey, 310 | ): 311 | # The key aspect of soft imitation learning is the critic objective and 312 | # reward. We obtain these from factories defined in the config. 313 | def state_action_reward_fn(state, action): 314 | return jnp.ravel(networks.reward_network.apply(r_params, state, action)) 315 | 316 | def state_action_value_fn(state, action): 317 | return networks.critic_network.apply( 318 | q_params, state, action).min(axis=-1) # reduce via min even for 1D 319 | 320 | def _state_value_fn(state, critic_params, policy_key): 321 | # SAC's soft value function, see Equation 3 of 322 | # https://arxiv.org/pdf/1812.05905.pdf. 
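        # Concretely, the code below is a single-sample Monte Carlo estimate of
        #   V(s) = E_{a ~ pi(.|s)}[min_i Q_i(s, a)
        #                          - alpha * (log pi(a|s) - log prior(a|s))],
        # where the prior is the pretrained BC policy when available and the
        # default action prior otherwise.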
323 | action_dist = networks.policy_network.apply(policy_params, state) 324 | action = action_dist.sample(seed=policy_key) 325 | policy_log_prob = networks.log_prob(action_dist, action) 326 | if use_policy_prior: # bc_policy_params have been trained 327 | prior_log_prob = networks.bc_policy_network.apply( 328 | bc_policy_params, state 329 | ).log_prob(action) 330 | else: 331 | prior_log_prob = networks.log_prob_prior(action) 332 | q = networks.critic_network.apply( 333 | critic_params, state, action).min(axis=-1) 334 | return q - alpha * (policy_log_prob - prior_log_prob) 335 | 336 | def state_value_fn( 337 | state: jnp.ndarray, key: jax.Array 338 | ) -> jnp.ndarray: 339 | return _state_value_fn(state, q_params, key) 340 | 341 | def target_state_value_fn( 342 | state: jnp.ndarray, key: jax.Array 343 | ) -> jnp.ndarray: 344 | return lax.stop_gradient(_state_value_fn(state, target_q_params, key)) 345 | 346 | reward_fn = reward_factory( 347 | state_action_reward_fn, 348 | state_action_value_fn, 349 | target_state_value_fn, 350 | discount, 351 | ) 352 | 353 | critic_loss, metrics = critic_loss_def( 354 | reward_fn, 355 | state_action_value_fn, 356 | state_value_fn, 357 | target_state_value_fn, 358 | discount, 359 | demonstration_transitions, 360 | online_transitions, 361 | key, 362 | ) 363 | return critic_loss, metrics 364 | 365 | def actor_loss( 366 | policy_params: networks_lib.Params, 367 | q_params: networks_lib.Params, 368 | alpha: jnp.ndarray, 369 | demonstration_transitions: types.Transition, 370 | online_transitions: types.Transition, 371 | key: networks_lib.PRNGKey, 372 | ) -> Tuple[jnp.ndarray, Dict[str, float | jnp.Array]]: 373 | 374 | def action_sample( 375 | observation: jnp.ndarray, 376 | action_key: jax.Array, 377 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Any]: 378 | dist = networks.policy_network.apply(policy_params, observation) 379 | sample = dist.sample(seed=action_key) 380 | log_prob = networks.log_prob(dist, sample) 381 | return dist.mode(), sample, log_prob, dist 382 | 383 | observation = sil_config.concatenate( 384 | demonstration_transitions.observation, online_transitions.observation 385 | ) 386 | expert_mode, expert_action, expert_log_prob, expert_dist = action_sample( 387 | demonstration_transitions.observation, key 388 | ) 389 | online_mode, online_action, online_log_prob, _ = action_sample( 390 | online_transitions.observation, key 391 | ) 392 | 393 | action = sil_config.concatenate(expert_action, online_action) 394 | action_mode = sil_config.concatenate(expert_mode, online_mode) 395 | log_prob = sil_config.concatenate(expert_log_prob, online_log_prob) 396 | prior_log_prob = networks.log_prob_prior(action) 397 | # Use min as reducer in case we use two or one critic functions. 398 | q = networks.critic_network.apply( 399 | q_params, observation, action).min(axis=-1) 400 | q_mode = networks.critic_network.apply( 401 | q_params, observation, action_mode).min(axis=-1) 402 | 403 | if use_policy_prior: 404 | dist_bc = networks.bc_policy_network.apply( 405 | bc_policy_params, observation 406 | ) 407 | prior_log_prob = networks.log_prob(dist_bc, action) 408 | kl = (log_prob - prior_log_prob).mean() 409 | constraint = (kl_bound - kl).sum() # Sum to reduce to scalar. 410 | clipped_constraint = jnp.clip(constraint, a_max=0.0) 411 | d = damping * clipped_constraint ** 2 412 | # Constraint is <= 0, so negate for loss. 413 | entropy_reg = -alpha * constraint + d 414 | else: # Vanilla maximum entropy regularization with uniform prior. 
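        # With a uniform prior, prior_log_prob is constant, so the KL below is
        # the negative policy entropy up to that constant and entropy_reg
        # reduces to the usual SAC-style alpha * E[log pi(a|s)] regulariser.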
415 | d = 0.0 416 | kl = (log_prob - prior_log_prob).mean() 417 | constraint = kl 418 | clipped_constraint = 0.0 419 | entropy_reg = alpha * kl 420 | 421 | actor_loss = entropy_reg - q.mean() 422 | 423 | if actor_bc_loss: 424 | 425 | # For SAC's tanh policy, the minimizing modal MSE and maximizing 426 | # loglikelihood do not appear to be mutually guaranteed, so we optimize 427 | # for both. 428 | # Incorporate BC MSE loss from TD3+BC. 429 | # https://arxiv.org/abs/2106.06860 430 | expert_se = (expert_mode - demonstration_transitions.action) ** 2 431 | bc_loss_mean = 0.5 * expert_se.mean() * jnp.abs(q).mean() 432 | 433 | # Also incorporate a log-likelihood, which should be similar in value to 434 | # the entropy as they are constructed in similar ways, so use alpha to 435 | # weight. This is like maximum likelihood with an entropy bonus. 436 | # See https://proceedings.mlr.press/v97/jacq19a/jacq19a.pdf Section 5.2. 437 | expert_demo_log_prob = networks.log_prob( 438 | expert_dist, demonstration_transitions.action 439 | ) 440 | bc_loss_mean += -alpha * expert_demo_log_prob.mean() 441 | 442 | actor_loss += bc_loss_mean 443 | 444 | metrics = { 445 | 'actor_q': q.mean(), 446 | 'actor_q_mode': q_mode.mean(), 447 | 'actor_entropy_bonus': (alpha * log_prob).mean(), 448 | 'actor_kl': kl, 449 | 'kl_bound': kl_bound, 450 | 'constraint': constraint, 451 | 'clipped_constraint': clipped_constraint, 452 | 'entropy_reg': entropy_reg, 453 | 'prior_log_prob': prior_log_prob.mean(), 454 | 'policy_log_prob': log_prob.mean(), 455 | 'damping': d, 456 | } 457 | return actor_loss, metrics 458 | 459 | alpha_grad = jax.value_and_grad(alpha_loss) 460 | critic_grad = jax.value_and_grad(critic_loss, argnums=[0, 1], has_aux=True) 461 | actor_grad = jax.value_and_grad(actor_loss, has_aux=True) 462 | 463 | def update_step( 464 | state: TrainingState, 465 | sample: ImitationSample, 466 | ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: 467 | # Update temperature, actor and critic. 468 | key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4) 469 | alpha_grads = None 470 | alpha_loss = None 471 | if adaptive_entropy_coefficient: 472 | transition = sil_config.concatenate_transitions( 473 | sample.online_sample, sample.demonstration_sample 474 | ) 475 | alpha_loss, alpha_grads = alpha_grad( 476 | state.alpha_params, state.policy_params, transition, key_alpha 477 | ) 478 | alpha = jax.nn.softplus(state.alpha_params) 479 | else: 480 | alpha = entropy_coefficient 481 | 482 | # Update critic (and reward). 483 | 484 | q_params = state.q_params 485 | r_params = state.r_params 486 | target_q_params = state.target_q_params 487 | critic_loss = None 488 | critic_grads = None 489 | critic_loss_metrics = None 490 | q_optimizer_state = None 491 | r_optimizer_state = None 492 | for _ in range(critic_actor_update_ratio): 493 | critic_losses, grads = critic_grad( 494 | q_params, 495 | r_params, 496 | state.policy_params, 497 | target_q_params, 498 | alpha, 499 | sample.demonstration_sample, 500 | sample.online_sample, 501 | key_critic, 502 | ) 503 | critic_loss, critic_loss_metrics = critic_losses 504 | critic_grads, reward_grads = grads 505 | 506 | # Apply critic gradients. 
507 | critic_update, q_optimizer_state = q_optimizer.update( 508 | critic_grads, state.q_optimizer_state, q_params 509 | ) 510 | q_params = optax.apply_updates(q_params, critic_update) 511 | 512 | reward_update, r_optimizer_state = r_optimizer.update( 513 | reward_grads, state.r_optimizer_state, r_params 514 | ) 515 | r_params = optax.apply_updates(r_params, reward_update) 516 | 517 | target_q_params = jax.tree_map( 518 | lambda x, y: x * (1 - tau) + y * tau, target_q_params, q_params 519 | ) 520 | 521 | # Update actor. 522 | actor_losses, actor_grads = actor_grad( 523 | state.policy_params, 524 | q_params, 525 | alpha, 526 | sample.demonstration_sample, 527 | sample.online_sample, 528 | key_actor, 529 | ) 530 | actor_loss, actor_loss_metrics = actor_losses 531 | actor_update, policy_optimizer_state = policy_optimizer.update( 532 | actor_grads, state.policy_optimizer_state) 533 | policy_params = optax.apply_updates(state.policy_params, actor_update) 534 | 535 | metrics = { 536 | 'critic_loss': critic_loss, 537 | 'actor_loss': actor_loss, 538 | 'critic_grad_norm': optax.global_norm(critic_grads), 539 | 'actor_grad_norm': optax.global_norm(actor_grads), 540 | } 541 | 542 | metrics.update(critic_loss_metrics) 543 | metrics.update(actor_loss_metrics) 544 | 545 | if MONITOR_BC_METRICS: 546 | # During training, expert actions should become / stay high likelihood. 547 | expert_action_dist = networks.policy_network.apply( 548 | policy_params, sample.demonstration_sample.observation 549 | ) 550 | samp = expert_action_dist.sample(seed=key) 551 | expert_ent_approx = -networks.log_prob(expert_action_dist, samp).mean() 552 | expert_llhs = networks.log_prob( 553 | expert_action_dist, sample.demonstration_sample.action 554 | ) 555 | expert_se = ( 556 | expert_action_dist.mode() - sample.demonstration_sample.action 557 | ) ** 2 558 | online_action_dist = networks.policy_network.apply( 559 | policy_params, sample.online_sample.observation 560 | ) 561 | samp = online_action_dist.sample(seed=key) 562 | online_ent_approx = -networks.log_prob(online_action_dist, samp).mean() 563 | online_llh = networks.log_prob( 564 | online_action_dist, sample.online_sample.action 565 | ).mean() 566 | online_se = (online_action_dist.mode() - sample.online_sample.action) ** 2 567 | 568 | metrics.update({ 569 | 'expert_llh_mean': expert_llhs.mean(), 570 | 'expert_llh_max': expert_llhs.max(), 571 | 'expert_llh_min': expert_llhs.min(), 572 | 'expert_mse': expert_se.mean(), 573 | 'online_llh': online_llh, 574 | 'online_mse': online_se.mean(), 575 | 'expert_ent': expert_ent_approx, 576 | 'online_ent': online_ent_approx, 577 | }) 578 | 579 | new_state = TrainingState( 580 | policy_optimizer_state=policy_optimizer_state, 581 | q_optimizer_state=q_optimizer_state, 582 | r_optimizer_state=r_optimizer_state, 583 | policy_params=policy_params, 584 | q_params=q_params, 585 | target_q_params=target_q_params, 586 | r_params=r_params, 587 | bc_policy_params=state.bc_policy_params, 588 | key=key, 589 | ) 590 | if adaptive_entropy_coefficient: 591 | # Apply alpha gradients. 
592 | alpha_update, alpha_optimizer_state = alpha_optimizer.update( 593 | alpha_grads, state.alpha_optimizer_state) 594 | alpha_params = optax.apply_updates(state.alpha_params, alpha_update) 595 | metrics.update({ 596 | 'alpha_loss': alpha_loss, 597 | 'alpha': jax.nn.softplus(alpha_params), 598 | }) 599 | new_state = new_state._replace( 600 | alpha_optimizer_state=alpha_optimizer_state, 601 | alpha_params=alpha_params) 602 | 603 | metrics['rewards_mean'] = jnp.mean( 604 | jnp.abs(jnp.mean(sample.online_sample.reward, axis=0)) 605 | ) 606 | metrics['rewards_std'] = jnp.std(sample.online_sample.reward, axis=0) 607 | 608 | return new_state, metrics 609 | 610 | # General learner book-keeping and loggers. 611 | self._counter = counter or counting.Counter() 612 | self._logger = learner_logger or loggers.make_default_logger( 613 | 'learner', 614 | asynchronous=True, 615 | serialize_fn=utils.fetch_devicearray, 616 | steps_key=self._counter.get_steps_key()) 617 | self._num_sgd_steps_per_step = num_sgd_steps_per_step 618 | 619 | # Iterator on demonstration transitions. 620 | self._iterator = dataset 621 | 622 | # Use the JIT compiler. 623 | self._update_step = jax.jit(update_step) 624 | 625 | # Do not record timestamps until after the first learning step is done. 626 | # This is to avoid including the time it takes for actors to come online and 627 | # fill the replay buffer. 628 | self._timestamp = None 629 | 630 | def step(self): 631 | 632 | metrics = {} 633 | # Update temperature, actor and critic. 634 | for _ in range(self._num_sgd_steps_per_step): 635 | sample = next(self._iterator) 636 | self._state, metrics = self._update_step(self._state, sample) 637 | 638 | # Compute elapsed time. 639 | timestamp = time.time() 640 | elapsed_time = timestamp - self._timestamp if self._timestamp else 0 641 | self._timestamp = timestamp 642 | 643 | # Increment counts and record the current time. 644 | counts = self._counter.increment(steps=1, walltime=elapsed_time) 645 | 646 | # Attempts to write the logs. 647 | self._logger.write({**metrics, **counts}) 648 | 649 | def get_variables(self, names: List[str]) -> List[Any]: 650 | variables = { 651 | 'policy': self._state.policy_params, 652 | 'critic': self._state.q_params, 653 | 'reward': self._state.r_params, 654 | 'bc_policy_params': self._state.bc_policy_params, 655 | } 656 | return [variables[name] for name in names] 657 | 658 | def save(self) -> TrainingState: 659 | return self._state 660 | 661 | def restore(self, state: TrainingState): 662 | self._state = state 663 | -------------------------------------------------------------------------------- /sil/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Soft imitation networks definition. 
17 | 18 | Builds heavily on acme/agents/jax/sac/networks.py 19 | (at https://github.com/google-deepmind/acme) 20 | """ 21 | 22 | import dataclasses 23 | import enum 24 | import math 25 | from typing import Any, Optional, Tuple, Callable, Sequence, Union 26 | 27 | from acme import core 28 | from acme import specs 29 | from acme.agents.jax.sac import networks as sac_networks 30 | from acme.jax import networks as networks_lib 31 | from acme.jax import types 32 | from acme.jax import utils 33 | import haiku as hk 34 | import haiku.initializers as hk_init 35 | import jax 36 | import jax.numpy as jnp 37 | import numpy as np 38 | import tensorflow_probability.substrates.jax as tfp 39 | 40 | from sil import config as sil_config 41 | 42 | tfd = tfp.distributions 43 | tfb = tfp.bijectors 44 | 45 | 46 | class PolicyArchitectures(enum.Enum): 47 | """Variations of policy architectures used in SIL methods.""" 48 | MLP = 'mlp' 49 | MIXMLP = 'mixmlp' 50 | HETSTATSIN = 'hetstatsin' 51 | HETSTATTRI = 'hetstattri' 52 | HETSTATPRELU = 'hetstatprelu' 53 | MIXHETSTATSIN = 'mixhetstatsin' 54 | MIXHETSTATTRI = 'mixhetstattri' 55 | MIXHETSTATPRELU = 'mixhetstatprelu' 56 | 57 | def __str__(self) -> str: 58 | return self.value 59 | 60 | 61 | class CriticArchitectures(enum.Enum): 62 | MLP = 'mlp' 63 | DOUBLE_MLP = 'double_mlp' # Used for SAC-based methods that use two critics. 64 | LNMLP = 'lnmlp' 65 | DOUBLE_LNMLP = 'double_lnmlp' 66 | STATSIN = 'statsin' 67 | STATTRI = 'stattri' 68 | STATPRELU = 'statprelu' 69 | 70 | def __str__(self) -> str: 71 | return self.value 72 | 73 | 74 | class RewardArchitectures(enum.Enum): 75 | MLP = 'mlp' 76 | LNMLP = 'lnmlp' 77 | PCSIL = 'pos_csil' 78 | NCSIL = 'neg_csil' 79 | PCONST = 'pos_const' 80 | NCONST = 'neg_const' 81 | 82 | def __str__(self) -> str: 83 | return self.value 84 | 85 | 86 | def observation_encoder(inputs: Any) -> jnp.ndarray: 87 | """Function that transforms observations into the correct vectors.""" 88 | if isinstance(inputs, jnp.ndarray): 89 | return inputs 90 | else: 91 | raise ValueError(f'Cannot convert type {type(inputs)}.') 92 | 93 | 94 | def update_encoder(params: networks_lib.Params, 95 | reference: networks_lib.Params) -> networks_lib.Params: 96 | predicate = lambda module_name, name, value: 'encoder' in module_name 97 | _, params_head = hk.data_structures.partition(predicate, params) 98 | ref_enc, _ = hk.data_structures.partition(predicate, reference) 99 | return hk.data_structures.merge(params_head, ref_enc) 100 | 101 | 102 | class Sequential(hk.Module): 103 | """Sequentially calls the given list of layers.""" 104 | 105 | def __init__( 106 | self, 107 | layers: Sequence[Callable[..., Any]], 108 | name: Optional[str] = None, 109 | ): 110 | super().__init__(name=name) 111 | self.layers = tuple(layers) 112 | 113 | def __call__(self, inputs, *args, **kwargs): 114 | """Calls all layers sequentially.""" 115 | out = inputs 116 | last_idx = len(self.layers) - 1 117 | for i, layer in enumerate(self.layers): 118 | if i == last_idx: 119 | out = layer(out, *args, **kwargs) 120 | else: 121 | out = layer(out) 122 | return out 123 | 124 | 125 | @dataclasses.dataclass 126 | class SILNetworks: 127 | """Network and pure functions for the soft imitation agent.""" 128 | 129 | environment_specs: specs.EnvironmentSpec 130 | policy_architecture: PolicyArchitectures 131 | bc_policy_architecture: PolicyArchitectures 132 | policy_network: networks_lib.FeedForwardNetwork 133 | critic_network: networks_lib.FeedForwardNetwork 134 | reward_network: networks_lib.FeedForwardNetwork 
135 |   log_prob: networks_lib.LogProbFn
136 |   log_prob_prior: Callable[[jnp.ndarray], jnp.ndarray]
137 |   sample: networks_lib.SampleFn
138 |   bc_policy_network: networks_lib.FeedForwardNetwork
139 |   reward_policy_coherence: bool = False
140 |   sample_eval: Optional[networks_lib.SampleFn] = None
141 | 
142 |   def to_sac(self, using_bc_policy: bool = False) -> sac_networks.SACNetworks:
143 |     """Cast to SAC networks to make use of the SAC helpers."""
144 |     policy_network = (
145 |         self.bc_policy_network if using_bc_policy else self.policy_network
146 |     )
147 |     return sac_networks.SACNetworks(
148 |         policy_network,
149 |         self.critic_network,
150 |         self.log_prob,
151 |         self.sample,
152 |         self.sample_eval,
153 |     )
154 | 
155 | 
156 | # From acme/agents/jax/cql/networks.py
157 | def apply_and_sample_n(
158 |     key: networks_lib.PRNGKey,
159 |     networks: SILNetworks,
160 |     params: networks_lib.Params,
161 |     obs: jnp.ndarray,
162 |     num_samples: int,
163 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
164 |   """Applies the policy and samples num_samples actions."""
165 |   dist_params = networks.policy_network.apply(params, obs)
166 |   sampled_actions = jnp.array(
167 |       [
168 |           networks.sample(dist_params, key_n)
169 |           for key_n in jax.random.split(key, num_samples)
170 |       ]
171 |   )
172 |   sampled_log_probs = networks.log_prob(dist_params, sampled_actions)
173 |   return sampled_actions, sampled_log_probs
174 | 
175 | 
176 | def default_models_to_snapshot(
177 |     networks: SILNetworks, spec: specs.EnvironmentSpec
178 | ):
179 |   """Defines default models to be snapshotted."""
180 |   dummy_obs = utils.zeros_like(spec.observations)
181 |   dummy_action = utils.zeros_like(spec.actions)
182 |   dummy_key = jax.random.PRNGKey(0)
183 | 
184 |   def critic_network(source: core.VariableSource) -> types.ModelToSnapshot:
185 |     params = source.get_variables(['critic'])[0]
186 |     return types.ModelToSnapshot(
187 |         networks.critic_network.apply,
188 |         params,
189 |         {'obs': dummy_obs, 'action': dummy_action},
190 |     )
191 | 
192 |   def reward_network(source: core.VariableSource) -> types.ModelToSnapshot:
193 |     params = source.get_variables(['reward'])[0]
194 |     return types.ModelToSnapshot(
195 |         networks.reward_network.apply,  # Snapshot the reward net, not the critic.
196 |         params,
197 |         {'obs': dummy_obs, 'action': dummy_action},
198 |     )
199 | 
200 |   def default_training_actor(
201 |       source: core.VariableSource) -> types.ModelToSnapshot:
202 |     params = source.get_variables(['policy'])[0]
203 |     return types.ModelToSnapshot(
204 |         sac_networks.apply_policy_and_sample(
205 |             networks.to_sac(), eval_mode=False
206 |         ),
207 |         params,
208 |         {'key': dummy_key, 'obs': dummy_obs},
209 |     )
210 | 
211 |   def default_eval_actor(
212 |       source: core.VariableSource) -> types.ModelToSnapshot:
213 |     params = source.get_variables(['policy'])[0]
214 |     return types.ModelToSnapshot(
215 |         sac_networks.apply_policy_and_sample(networks.to_sac(), eval_mode=True),
216 |         params,
217 |         {'key': dummy_key, 'obs': dummy_obs},
218 |     )
219 | 
220 |   return {
221 |       'critic_network': critic_network,
222 |       'reward_network': reward_network,
223 |       'default_training_actor': default_training_actor,
224 |       'default_eval_actor': default_eval_actor,
225 |   }
226 | 
227 | 
228 | default_init_normal = hk.initializers.VarianceScaling(
229 |     0.333, 'fan_out', 'normal'
230 | )
231 | # If 1D regression plotting, 0.2 is more sensible.
232 | # It's crucial this is not changed, it doesn't work otherwise.
233 | # Acme SAC uses this initializer.
234 | default_init_uniform = hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform')
235 | 
236 | 
237 | class ClampedScaleNormalTanhDistribution(hk.Module):
238 |   """Module that produces a variance-clamped TanhTransformedDistribution."""
239 | 
240 |   def __init__(
241 |       self,
242 |       num_dimensions: int,
243 |       max_log_scale: float = 0.0,
244 |       min_log_scale: float = -5.0,
245 |       w_init: hk_init.Initializer = hk_init.Orthogonal(),
246 |       b_init: hk_init.Initializer = hk_init.Constant(0.0),
247 |       name: str = 'ClampedScaleNormalTanhDistribution',
248 |   ):
249 |     """Initialization.
250 | 
251 |     Args:
252 |       num_dimensions: Number of dimensions of a distribution.
253 |       max_log_scale: Maximum log standard deviation.
254 |       min_log_scale: Minimum log standard deviation.
255 |       w_init: Initialization for linear layer weights.
256 |       b_init: Initialization for linear layer biases.
257 |       name: Name of the module that is passed to the parameters.
258 |     """
259 |     super().__init__(name=name)
260 |     assert max_log_scale > min_log_scale
261 |     self._min_log_scale = min_log_scale
262 |     self._log_scale_range = max_log_scale - self._min_log_scale
263 |     self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
264 |     self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
265 | 
266 |   def __call__(
267 |       self, inputs: jnp.ndarray, faithful_distributions: bool = False
268 |   ) -> Union[tfd.Distribution, Tuple[tfd.Distribution, tfd.Distribution]]:
269 |     loc = self._loc_layer(inputs)
270 |     if faithful_distributions:
271 |       inputs_ = jax.lax.stop_gradient(inputs)
272 |     else:
273 |       inputs_ = inputs
274 | 
275 |     log_scale_raw = self._scale_layer(inputs_)  # Operating range around [-1, 1]
276 |     log_scale_norm = 0.5 * (jax.nn.tanh(log_scale_raw) + 1.0)  # Range [0, 1]
277 |     scale = self._min_log_scale + self._log_scale_range * log_scale_norm
278 |     scale = math.sqrt(sil_config.MIN_VAR) + jnp.exp(scale)
279 |     distribution = tfd.Normal(loc=loc, scale=scale)
280 |     transformed_dist = tfd.Independent(
281 |         networks_lib.TanhTransformedDistribution(distribution),
282 |         reinterpreted_batch_ndims=1,
283 |     )
284 |     if faithful_distributions:
285 |       cut_dist = tfd.Normal(loc=jax.lax.stop_gradient(loc), scale=scale)
286 |       cut_transformed_dist = tfd.Independent(
287 |           networks_lib.TanhTransformedDistribution(cut_dist),
288 |           reinterpreted_batch_ndims=1,
289 |       )
290 |       return transformed_dist, cut_transformed_dist
291 |     else:
292 |       return transformed_dist
293 | 
294 | 
295 | class MixtureClampedScaleNormalTanhDistribution(
296 |     ClampedScaleNormalTanhDistribution):
297 |   """Module that produces a mixture of variance-clamped TanhTransformedDistributions."""
298 | 
299 |   def __init__(
300 |       self,
301 |       num_dimensions: int,
302 |       n_mixture: int,
303 |       max_log_scale: float = 0.0,
304 |       min_log_scale: float = -5.0,
305 |       w_init: hk_init.Initializer = hk_init.Orthogonal(),
306 |       b_init: hk_init.Initializer = hk_init.Constant(0.0),
307 |   ):
308 |     """Initialization.
309 | 
310 |     Args:
311 |       num_dimensions: Number of dimensions of a distribution.
312 |       n_mixture: Number of mixture components.
313 |       max_log_scale: Maximum log standard deviation.
314 |       min_log_scale: Minimum log standard deviation.
315 |       w_init: Initialization for linear layer weights.
316 |       b_init: Initialization for linear layer biases.
317 | """ 318 | self.n_mixture = n_mixture 319 | self.n_outputs = num_dimensions 320 | super().__init__( 321 | num_dimensions=num_dimensions * n_mixture, 322 | max_log_scale=max_log_scale, 323 | min_log_scale=min_log_scale, 324 | w_init=w_init, 325 | b_init=b_init, 326 | name='MixtureClampedScaleNormalTanhDistribution', 327 | ) 328 | 329 | def __call__( 330 | self, inputs: jnp.ndarray, faithful_distributions: bool = False 331 | ) -> Union[tfd.Distribution, Tuple[tfd.Distribution, tfd.Distribution]]: 332 | loc = self._loc_layer(inputs) 333 | if faithful_distributions: 334 | inputs_ = jax.lax.stop_gradient(inputs) 335 | else: 336 | inputs_ = inputs 337 | 338 | log_scale_raw = self._scale_layer(inputs_) # Operating range around [-1, 1] 339 | log_scale_norm = 0.5 * (jax.nn.tanh(log_scale_raw) + 1.0) # range [0, 1] 340 | scale = self._min_log_scale + self._log_scale_range * log_scale_norm 341 | scale = math.sqrt(sil_config.MIN_VAR) + jnp.exp(scale) 342 | 343 | log_mixture_weights = hk.get_parameter( 344 | 'log_mixture_weights', 345 | [self.n_mixture], 346 | init=hk.initializers.Constant(1.0), 347 | ) 348 | mixture_weights = jax.nn.softmax(log_mixture_weights) 349 | mixture_distribution = tfd.Categorical(probs=mixture_weights) 350 | 351 | def make_mixture(location, scale, weights): 352 | distribution = tfd.Normal(loc=location, scale=scale) 353 | 354 | transformed_distribution = tfd.Independent( 355 | networks_lib.TanhTransformedDistribution(distribution), 356 | reinterpreted_batch_ndims=1, 357 | ) 358 | 359 | return MixtureSameFamily( 360 | mixture_distribution=weights, 361 | components_distribution=transformed_distribution, 362 | ) 363 | 364 | mean = loc.reshape((-1, self.n_mixture, self.n_outputs)) 365 | stddev = scale.reshape((-1, self.n_mixture, self.n_outputs)) 366 | mixture = make_mixture(mean, stddev, mixture_distribution) 367 | if faithful_distributions: 368 | cut_mixture = make_mixture(jax.lax.stop_gradient(mean), 369 | stddev, mixture_distribution) 370 | return mixture, cut_mixture 371 | else: 372 | return mixture 373 | 374 | 375 | def _triangle_activation(x: jnp.ndarray) -> jnp.ndarray: 376 | z = jnp.floor(x / jnp.pi + 0.5) 377 | return (x - jnp.pi * z) * (-1) ** z 378 | 379 | 380 | @jax.jit 381 | def triangle_activation(x: jnp.ndarray) -> jnp.ndarray: 382 | pdiv2sqrt2 = 1.1107207345 383 | return pdiv2sqrt2 * _triangle_activation(x) 384 | 385 | 386 | @jax.jit 387 | def periodic_relu_activation(x: jnp.ndarray) -> jnp.ndarray: 388 | pdiv4 = 0.785398163 389 | pdiv2 = 1.570796326 390 | return (_triangle_activation(x) + _triangle_activation(x + pdiv2)) * pdiv4 391 | 392 | 393 | @jax.jit 394 | def sin_cos_activation(x: jnp.ndarray) -> jnp.ndarray: 395 | return jnp.sin(x) + jnp.cos(x) 396 | 397 | 398 | @jax.jit 399 | def hard_sin(x: jnp.ndarray) -> jnp.ndarray: 400 | pdiv4 = 0.785398163 # π/4 401 | return periodic_relu_activation(x - pdiv4) 402 | 403 | 404 | @jax.jit 405 | def hard_cos(x: jnp.ndarray) -> jnp.ndarray: 406 | pdiv4 = 0.785398163 # π/4 407 | return periodic_relu_activation(x + pdiv4) 408 | 409 | 410 | gaussian_init = hk.initializers.RandomNormal(1.0) 411 | 412 | 413 | class StationaryFeatures(hk.Module): 414 | """Stationary feature layer. 415 | 416 | Combines an MLP feature component (with bottleneck output) into a relatively 417 | wider feature layer that has periodic activation function. The from of the 418 | final weight distribution and periodic activation dictates the nature of the 419 | parametic stationary process. 
420 | 421 | For more details see 422 | Periodic Activation Functions Induce Stationarity, Meronen et al. 423 | https://arxiv.org/abs/2110.13572 424 | """ 425 | 426 | def __init__( 427 | self, 428 | num_dimensions: int, 429 | layers: Sequence[int], 430 | feature_dimension: int = 512, 431 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu, 432 | stationary_activation: Callable[ 433 | [jnp.ndarray], jnp.ndarray 434 | ] = sin_cos_activation, 435 | stationary_init: hk.initializers.Initializer = gaussian_init, 436 | layer_norm_mlp: bool = False, 437 | ): 438 | """Initialization. 439 | 440 | Args: 441 | num_dimensions: Number of dimensions of the output distribution. 442 | layers: feature MLP architecture up to the feature layer. 443 | feature_dimension: size of feature layer. 444 | activation: Activation of MLP network. 445 | stationary_activation: Periodic activation of feature layer. 446 | stationary_init: Random initialization of last layer 447 | layer_norm_mlp: Use layer norm in the first MLP layer. 448 | """ 449 | super().__init__(name='StationaryFeatures') 450 | self.output_dimension = num_dimensions 451 | self.feature_dimension = feature_dimension 452 | self.stationary_activation = stationary_activation 453 | self.stationary_init = stationary_init 454 | if layer_norm_mlp: 455 | self.mlp = networks_lib.LayerNormMLP( 456 | list(layers), 457 | activation=activation, 458 | w_init=hk.initializers.Orthogonal(), 459 | activate_final=True, 460 | ) 461 | else: 462 | self.mlp = hk.nets.MLP( 463 | list(layers), 464 | activation=activation, 465 | w_init=hk.initializers.Orthogonal(), 466 | activate_final=True, 467 | ) 468 | 469 | def features(self, inputs: jnp.ndarray) -> jnp.ndarray: 470 | input_dimension = inputs.shape[-1] 471 | 472 | # While the theory says that these random weights should be fixed, it's 473 | # crucial in practice to let them be trained. The distribution does not 474 | # actually change much, so they still contribute to the stationary 475 | # behaviour, and letting them be trained alleviates potential underfitting. 476 | random_weights = hk.get_parameter( 477 | 'random_weights', 478 | [input_dimension, self.feature_dimension // 2], 479 | init=self.stationary_init, 480 | ) 481 | 482 | log_lengthscales = hk.get_parameter( 483 | 'log_lengthscales', 484 | [input_dimension], 485 | init=hk.initializers.Constant(-5.0), 486 | ) 487 | 488 | ls = jnp.diag(jnp.exp(log_lengthscales)) 489 | wx = inputs @ ls @ random_weights 490 | pdiv4 = 0.785398163 # π/4 491 | f = jnp.concatenate( 492 | ( 493 | self.stationary_activation(wx + pdiv4), 494 | self.stationary_activation(wx - pdiv4), 495 | ), 496 | axis=-1, 497 | ) 498 | return f / math.sqrt(self.feature_dimension) 499 | 500 | 501 | class StationaryHeteroskedasticNormalTanhDistribution(StationaryFeatures): 502 | """Module that produces a stationary TanhTransformedDistribution.""" 503 | 504 | def __init__( 505 | self, 506 | num_dimensions: int, 507 | layers: Sequence[int], 508 | feature_dimension: int = 512, 509 | prior_variance: float = 1.0, 510 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu, 511 | stationary_activation: Callable[ 512 | [jnp.ndarray], jnp.ndarray 513 | ] = sin_cos_activation, 514 | stationary_init: hk.initializers.Initializer = gaussian_init, 515 | layer_norm_mlp: bool = False, 516 | ): 517 | """Initialization. 518 | 519 | Args: 520 | num_dimensions: Number of dimensions of the output distribution. 521 | layers: feature MLP architecture up to the feature layer. 
522 |       feature_dimension: size of feature layer.
523 |       prior_variance: initial variance of the predictive.
524 |       activation: Activation of MLP network.
525 |       stationary_activation: Periodic activation of feature layer.
526 |       stationary_init: Random initialization of last layer
527 |       layer_norm_mlp: Use layer norm in the first MLP layer.
528 |     """
529 |     self.prior_var = prior_variance
530 |     self.prior_stddev = np.sqrt(prior_variance)
531 |     super().__init__(
532 |         num_dimensions,
533 |         layers,
534 |         feature_dimension,
535 |         activation,
536 |         stationary_activation,
537 |         stationary_init,
538 |         layer_norm_mlp
539 |     )
540 | 
541 |   def __call__(
542 |       self, inputs: jnp.ndarray, faithful_distributions: bool = False
543 |   ) -> Union[tfd.Distribution, Tuple[tfd.Distribution, tfd.Distribution]]:
544 |     inputs = self.mlp(inputs)
545 |     features = self.features(inputs)
546 |     if faithful_distributions:
547 |       features_ = jax.lax.stop_gradient(features)
548 |     else:
549 |       features_ = features
550 | 
551 |     loc_weights = hk.get_parameter(
552 |         'loc_weights',
553 |         [self.feature_dimension, self.output_dimension],
554 |         init=hk.initializers.Constant(0.0),
555 |     )
556 | 
557 |     # Parameterize the PSD matrix in lower triangular form as a 'raw' vector.
558 |     # This minimizes the memory footprint to between 50-75% of the full matrix.
559 |     n_sqrt = self.feature_dimension * (self.feature_dimension + 1) // 2
560 |     scale_cross_weights_sqrt_raw = hk.get_parameter(
561 |         'scale_cross_weights_sqrt',
562 |         [self.output_dimension, n_sqrt],
563 |         init=hk.initializers.Constant(0.0),
564 |     )
565 |     # Convert the vector into a lower triangular matrix with exponentiated diagonal,
566 |     # so a vector of zeros becomes the identity matrix.
567 |     b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
568 |     scale_cross_weights_sqrt = jax.vmap(b.forward)(scale_cross_weights_sqrt_raw)
569 |     loc = features @ loc_weights
570 |     # Cholesky decomposition: A = LL^T where L is lower triangular.
571 |     # Variance is the diagonal of x @ A @ x.T = x @ L @ L.T @ x.T,
572 |     # so first compute x @ L per output dimension d.
573 |     var_sqrt = jnp.einsum('dij,bi->bdj', scale_cross_weights_sqrt, features_)
574 |     var = jnp.einsum('bdi,bdi->bd', var_sqrt, var_sqrt)
575 | 
576 |     scale = self.prior_stddev * jnp.tanh(jnp.sqrt(sil_config.MIN_VAR + var))
577 | 
578 |     distribution = tfd.Normal(loc=loc, scale=scale)
579 | 
580 |     transformed_distribution = tfd.Independent(
581 |         networks_lib.TanhTransformedDistribution(distribution),
582 |         reinterpreted_batch_ndims=1,
583 |     )
584 | 
585 |     if faithful_distributions:
586 |       cut_distribution = tfd.Normal(loc=jax.lax.stop_gradient(loc), scale=scale)
587 |       cut_transformed_distribution = tfd.Independent(
588 |           networks_lib.TanhTransformedDistribution(cut_distribution),
589 |           reinterpreted_batch_ndims=1,
590 |       )
591 |       return transformed_distribution, cut_transformed_distribution
592 |     else:
593 |       return transformed_distribution
594 | 
595 | 
596 | class MixtureSameFamily(tfd.MixtureSameFamily):
597 |   """MixtureSameFamily with mode computation."""
598 | 
599 |   def mode(self) -> jnp.ndarray:
600 |     """Return the mode of the most probable mixture component."""
601 |     mode_model = self.mixture_distribution.mode()
602 |     modes = self.components_distribution.mode()
603 |     return modes[:, mode_model, :]
604 | 
605 | 
606 | class MixtureStationaryHeteroskedasticNormalTanhDistribution(
607 |     StationaryHeteroskedasticNormalTanhDistribution
608 | ):
609 |   """Module that produces a mixture of stationary TanhTransformedDistributions."""
610 | 
611 |   def __init__(
612 |       self,
613 | 
num_dimensions: int, 614 | n_mixture: int, 615 | layers: Sequence[int], 616 | feature_dimension: int = 512, 617 | prior_variance: float = 1.0, 618 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu, 619 | stationary_activation: Callable[ 620 | [jnp.ndarray], jnp.ndarray 621 | ] = sin_cos_activation, 622 | stationary_init: hk.initializers.Initializer = gaussian_init, 623 | layer_norm_mlp: bool = False, 624 | ): 625 | """Initialization. 626 | 627 | Args: 628 | num_dimensions: Number of dimensions of the output distribution. 629 | n_mixture: Number of mixture components. 630 | layers: feature MLP architecture up to the feature layer. 631 | feature_dimension: size of feature layer. 632 | prior_variance: initial variance of the predictive. 633 | activation: Activation of MLP network. 634 | stationary_activation: Periodic activation of feature layer. 635 | stationary_init: Random initialization of last layer 636 | layer_norm_mlp: Use layer norm in the first MLP layer. 637 | """ 638 | self.n_mixture = n_mixture 639 | super().__init__( 640 | num_dimensions=num_dimensions, 641 | layers=layers, 642 | feature_dimension=feature_dimension, 643 | prior_variance=prior_variance, 644 | activation=activation, 645 | stationary_activation=stationary_activation, 646 | stationary_init=stationary_init, 647 | layer_norm_mlp=layer_norm_mlp, 648 | ) 649 | 650 | def __call__( 651 | self, inputs: jnp.ndarray, faithful_distributions: bool = False 652 | ) -> Union[tfd.Distribution, Tuple[tfd.Distribution, tfd.Distribution]]: 653 | inputs = self.mlp(inputs) 654 | features = self.features(inputs) 655 | if faithful_distributions: 656 | features_ = jax.lax.stop_gradient(features) 657 | else: 658 | features_ = features 659 | 660 | loc_weights = hk.get_parameter( 661 | 'loc_weights', 662 | [self.feature_dimension, self.n_mixture, self.output_dimension], 663 | init=hk.initializers.Orthogonal(), # For mixture diversity. 
664 |     )
665 | 
666 |     scale_cross_weights_sqrt = hk.get_parameter(
667 |         'scale_cross_weights_sqrt',
668 |         [self.output_dimension, self.feature_dimension, self.feature_dimension],
669 |         init=hk.initializers.Identity(gain=1.0),
670 |     )
671 | 
672 |     # batch x n_mixture x d_out
673 |     mean = jnp.einsum('bi,ijk->bjk', features, loc_weights)
674 |     scale_cross_weights_sqrt = jnp.tril(scale_cross_weights_sqrt)
675 |     # Cholesky decomposition: A = LL^T where L is lower triangular.
676 |     # Variance is the diagonal of x @ A @ x.T = x @ L @ L.T @ x.T,
677 |     # so first compute x @ L per output dimension d.
678 |     var_sqrt = jnp.einsum('dij,bi->bdj', scale_cross_weights_sqrt, features_)
679 |     var = jnp.einsum('bdi,bdi->bd', var_sqrt, var_sqrt)
680 | 
681 |     stddev_ = self.prior_stddev * jnp.tanh(jnp.sqrt(sil_config.MIN_VAR + var))
682 |     stddev = jnp.repeat(jnp.expand_dims(stddev_, 1), self.n_mixture, axis=1)
683 | 
684 |     assert mean.shape == stddev.shape, f'{mean.shape} != {stddev.shape}'
685 | 
686 |     log_mixture_weights = hk.get_parameter(
687 |         'log_mixture_weights',
688 |         [self.n_mixture],
689 |         init=hk.initializers.Constant(1.0),
690 |     )
691 |     mixture_weights = jax.nn.softmax(log_mixture_weights)
692 |     mixture_distribution = tfd.Categorical(probs=mixture_weights)
693 | 
694 |     def make_mixture(location, scale, weights):
695 |       distribution = tfd.Normal(loc=location, scale=scale)
696 | 
697 |       transformed_distribution = tfd.Independent(
698 |           networks_lib.TanhTransformedDistribution(distribution),
699 |           reinterpreted_batch_ndims=1,
700 |       )
701 | 
702 |       return MixtureSameFamily(
703 |           mixture_distribution=weights,
704 |           components_distribution=transformed_distribution,
705 |       )
706 | 
707 |     mixture = make_mixture(mean, stddev, mixture_distribution)
708 |     if faithful_distributions:
709 |       cut_mixture = make_mixture(
710 |           jax.lax.stop_gradient(mean), stddev, mixture_distribution)
711 |       return mixture, cut_mixture
712 |     else:
713 |       return mixture
714 | 
715 | 
716 | class StationaryMLP(StationaryFeatures):
717 |   """MLP that behaves like the mean function of a stationary process."""
718 | 
719 |   def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
720 |     inputs = self.mlp(inputs)
721 |     features = self.features(inputs)
722 | 
723 |     loc_weights = hk.get_parameter(
724 |         'loc_weights',
725 |         [self.feature_dimension, self.output_dimension],
726 |         init=hk.initializers.Constant(0.0),
727 |     )
728 |     return features @ loc_weights
729 | 
730 | 
731 | class LayerNormMLP(hk.Module):
732 |   """Simple feedforward MLP torso with initial layer-norm."""
733 | 
734 |   def __init__(
735 |       self,
736 |       layer_sizes: Sequence[int],
737 |       w_init: hk.initializers.Initializer,
738 |       activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu,
739 |       activate_final: bool = False,
740 |       name: str = 'feedforward_mlp_torso',
741 |   ):
742 |     """Construct the MLP."""
743 |     super().__init__(name=name)
744 |     assert len(layer_sizes) > 1
745 |     self._network = hk.Sequential([
746 |         hk.Linear(layer_sizes[0], w_init=w_init),
747 |         hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
748 |         hk.nets.MLP(
749 |             layer_sizes[1:],
750 |             w_init=w_init,
751 |             activation=activation,
752 |             activate_final=activate_final,
753 |         ),
754 |     ])
755 | 
756 |   def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
757 |     """Forwards the policy network."""
758 |     return self._network(inputs)
759 | 
760 | 
761 | def prior_policy_log_likelihood(
762 |     env_spec: specs.EnvironmentSpec, policy_architecture: PolicyArchitectures
763 | ) -> Callable[[jnp.ndarray], jnp.ndarray]:
764 |   """We assume a uniform
hyper prior in a [-1, 1] action space.""" 765 | del policy_architecture 766 | act_spec = env_spec.actions 767 | num_actions = np.prod(act_spec.shape, dtype=int) 768 | 769 | prior_llh = lambda x: -num_actions * jnp.log(2.0) 770 | return jax.vmap(prior_llh) 771 | 772 | 773 | def make_networks( 774 | spec: specs.EnvironmentSpec, 775 | policy_architecture: PolicyArchitectures = PolicyArchitectures.MLP, 776 | critic_architecture: CriticArchitectures = CriticArchitectures.LNMLP, 777 | reward_architecture: RewardArchitectures = RewardArchitectures.LNMLP, 778 | policy_hidden_layer_sizes: Tuple[int, ...] = (256, 256), 779 | critic_hidden_layer_sizes: Tuple[int, ...] = (256, 256), 780 | reward_hidden_layer_sizes: Tuple[int, ...] = (256, 256), 781 | reward_policy_coherence_alpha: Optional[float] = None, 782 | bc_policy_architecture: Optional[PolicyArchitectures] = None, 783 | bc_policy_hidden_layer_sizes: Optional[Tuple[int, ...]] = None, 784 | layer_norm_policy: bool = False, 785 | ) -> SILNetworks: 786 | """Creates networks used by the agent.""" 787 | 788 | num_actions = np.prod(spec.actions.shape, dtype=int) 789 | 790 | def _make_actor_fn( 791 | policy_arch: PolicyArchitectures, hidden_layer_size: Tuple[int, ...] 792 | ): 793 | assert len(hidden_layer_size) > 1 794 | 795 | def _actor_fn(obs, *args, train_encoder=False, **kwargs): 796 | if policy_arch == PolicyArchitectures.MLP: 797 | mlp = networks_lib.LayerNormMLP if layer_norm_policy else hk.nets.MLP 798 | network = Sequential([ 799 | mlp( 800 | list(hidden_layer_size), 801 | w_init=hk.initializers.Orthogonal(), 802 | activation=jax.nn.elu, 803 | activate_final=True, 804 | ), 805 | ClampedScaleNormalTanhDistribution(num_actions), 806 | ]) 807 | elif policy_arch == PolicyArchitectures.MIXMLP: 808 | mlp = networks_lib.LayerNormMLP if layer_norm_policy else hk.nets.MLP 809 | network = Sequential([ 810 | mlp( 811 | list(hidden_layer_size), 812 | w_init=hk.initializers.Orthogonal(), 813 | activation=jax.nn.elu, 814 | activate_final=True, 815 | ), 816 | MixtureClampedScaleNormalTanhDistribution(num_actions, n_mixture=5), 817 | ]) 818 | elif policy_arch == PolicyArchitectures.HETSTATSIN: 819 | network = StationaryHeteroskedasticNormalTanhDistribution( 820 | num_actions, 821 | hidden_layer_size[:-1], 822 | feature_dimension=hidden_layer_size[-1], 823 | prior_variance=0.75, 824 | stationary_activation=sin_cos_activation, 825 | layer_norm_mlp=layer_norm_policy, 826 | ) 827 | elif policy_arch == PolicyArchitectures.HETSTATTRI: 828 | network = StationaryHeteroskedasticNormalTanhDistribution( 829 | num_actions, 830 | hidden_layer_size[:-1], 831 | feature_dimension=hidden_layer_size[-1], 832 | prior_variance=0.75, 833 | stationary_activation=triangle_activation, 834 | layer_norm_mlp=layer_norm_policy, 835 | ) 836 | elif policy_arch == PolicyArchitectures.HETSTATPRELU: 837 | network = StationaryHeteroskedasticNormalTanhDistribution( 838 | num_actions, 839 | hidden_layer_size[:-1], 840 | feature_dimension=hidden_layer_size[-1], 841 | prior_variance=0.75, 842 | stationary_activation=periodic_relu_activation, 843 | layer_norm_mlp=layer_norm_policy, 844 | ) 845 | elif policy_arch == PolicyArchitectures.MIXHETSTATSIN: 846 | network = MixtureStationaryHeteroskedasticNormalTanhDistribution( 847 | num_actions, 848 | n_mixture=5, 849 | layers=hidden_layer_size[:-1], 850 | feature_dimension=hidden_layer_size[-1], 851 | prior_variance=0.75, 852 | stationary_activation=sin_cos_activation, 853 | layer_norm_mlp=layer_norm_policy, 854 | ) 855 | elif policy_arch == 
PolicyArchitectures.MIXHETSTATTRI: 856 | network = MixtureStationaryHeteroskedasticNormalTanhDistribution( 857 | num_actions, 858 | n_mixture=5, 859 | layers=hidden_layer_size[:-1], 860 | feature_dimension=hidden_layer_size[-1], 861 | prior_variance=0.75, 862 | stationary_activation=triangle_activation, 863 | layer_norm_mlp=layer_norm_policy, 864 | ) 865 | elif policy_arch == PolicyArchitectures.MIXHETSTATPRELU: 866 | network = MixtureStationaryHeteroskedasticNormalTanhDistribution( 867 | num_actions, 868 | n_mixture=5, 869 | layers=hidden_layer_size[:-1], 870 | feature_dimension=hidden_layer_size[-1], 871 | prior_variance=0.75, 872 | stationary_activation=periodic_relu_activation, 873 | layer_norm_mlp=layer_norm_policy, 874 | ) 875 | else: 876 | raise ValueError('Unknown policy architecture.') 877 | 878 | obs = observation_encoder(obs) 879 | if not train_encoder: 880 | obs = jax.lax.stop_gradient(obs) 881 | return network(obs, *args, **kwargs) 882 | 883 | return _actor_fn 884 | 885 | actor_fn = _make_actor_fn(policy_architecture, policy_hidden_layer_sizes) 886 | 887 | prior_policy_llh = prior_policy_log_likelihood( 888 | spec, 889 | policy_architecture=PolicyArchitectures.MLP, 890 | ) 891 | 892 | def _critic_fn(obs, action, train_encoder=False): 893 | if critic_architecture == CriticArchitectures.DOUBLE_MLP: # Needed for SAC. 894 | network1 = hk.Sequential([ 895 | hk.nets.MLP( 896 | list(critic_hidden_layer_sizes) + [1], 897 | w_init=hk.initializers.Orthogonal(), 898 | activation=jax.nn.elu), 899 | ]) 900 | network2 = hk.Sequential([ 901 | hk.nets.MLP( 902 | list(critic_hidden_layer_sizes) + [1], 903 | w_init=hk.initializers.Orthogonal(), 904 | activation=jax.nn.elu), 905 | ]) 906 | obs = observation_encoder(obs) 907 | input_ = jnp.concatenate([obs, action], axis=-1) 908 | value1 = network1(input_) 909 | value2 = network2(input_) 910 | return jnp.concatenate([value1, value2], axis=-1) 911 | elif critic_architecture == CriticArchitectures.MLP: 912 | network = hk.nets.MLP( 913 | list(critic_hidden_layer_sizes) + [1], 914 | w_init=hk.initializers.Orthogonal(), 915 | activation=jax.nn.elu, 916 | ) 917 | elif critic_architecture == CriticArchitectures.LNMLP: 918 | network = networks_lib.LayerNormMLP( 919 | list(critic_hidden_layer_sizes) + [1], 920 | w_init=hk.initializers.Orthogonal(), 921 | activation=jax.nn.elu, 922 | ) 923 | elif critic_architecture == CriticArchitectures.DOUBLE_LNMLP: 924 | network1 = hk.Sequential([ 925 | networks_lib.LayerNormMLP( 926 | list(critic_hidden_layer_sizes) + [1], 927 | w_init=hk.initializers.Orthogonal(), 928 | activation=jax.nn.elu), 929 | ]) 930 | network2 = hk.Sequential([ 931 | networks_lib.LayerNormMLP( 932 | list(critic_hidden_layer_sizes) + [1], 933 | w_init=hk.initializers.Orthogonal(), 934 | activation=jax.nn.elu), 935 | ]) 936 | 937 | obs = observation_encoder(obs) 938 | if not train_encoder: 939 | obs = jax.lax.stop_gradient(obs) 940 | input_ = jnp.concatenate([obs, action], axis=-1) 941 | value1 = network1(input_) 942 | value2 = network2(input_) 943 | return jnp.concatenate([value1, value2], axis=-1) 944 | elif critic_architecture == CriticArchitectures.STATSIN: 945 | network = StationaryMLP( 946 | 1, 947 | critic_hidden_layer_sizes[:-1], 948 | feature_dimension=critic_hidden_layer_sizes[-1], 949 | stationary_activation=sin_cos_activation, 950 | layer_norm_mlp=False, 951 | ) 952 | elif critic_architecture == CriticArchitectures.STATTRI: 953 | network = StationaryMLP( 954 | 1, 955 | critic_hidden_layer_sizes[:-1], 956 | 
feature_dimension=critic_hidden_layer_sizes[-1], 957 | stationary_activation=triangle_activation, 958 | layer_norm_mlp=False, 959 | ) 960 | elif critic_architecture == CriticArchitectures.STATPRELU: 961 | network = StationaryMLP( 962 | 1, 963 | critic_hidden_layer_sizes[:-1], 964 | feature_dimension=critic_hidden_layer_sizes[-1], 965 | stationary_activation=periodic_relu_activation, 966 | layer_norm_mlp=False, 967 | ) 968 | else: 969 | raise ValueError('Unknown critic architecture.') 970 | 971 | obs = observation_encoder(obs) 972 | input_ = jnp.concatenate([obs, action], axis=-1) 973 | return network(input_) 974 | 975 | reward_policy_coherence = (reward_architecture == RewardArchitectures.PCSIL or 976 | reward_architecture == RewardArchitectures.NCSIL) 977 | if reward_policy_coherence: 978 | assert reward_policy_coherence_alpha is not None 979 | assert bc_policy_architecture is not None 980 | assert bc_policy_hidden_layer_sizes is not None 981 | bc_actor_fn = _make_actor_fn( 982 | bc_policy_architecture, bc_policy_hidden_layer_sizes 983 | ) 984 | 985 | def _reward_fn(obs, action, *args, **kwargs): 986 | obs = observation_encoder(obs) 987 | alpha = reward_policy_coherence_alpha 988 | log_ratio = (bc_actor_fn(obs, *args, **kwargs).log_prob(action) # pytype: disable=attribute-error 989 | - prior_policy_llh(action)) 990 | if reward_architecture == RewardArchitectures.PCSIL: 991 | return alpha * log_ratio 992 | else: # reward_architecture == RewardArchitectures.NCSIL 993 | return alpha * (log_ratio - num_actions * sil_config.MAX_REWARD) 994 | 995 | else: 996 | bc_actor_fn = actor_fn 997 | bc_policy_architecture = policy_architecture 998 | 999 | def _reward_fn(obs, action, train_encoder=False): 1000 | if reward_architecture == RewardArchitectures.MLP: 1001 | network = hk.nets.MLP( 1002 | list(reward_hidden_layer_sizes) + [1], 1003 | w_init=hk.initializers.Orthogonal(), 1004 | ) 1005 | elif reward_architecture == RewardArchitectures.LNMLP: 1006 | network = networks_lib.LayerNormMLP( 1007 | list(reward_hidden_layer_sizes) + [1], 1008 | w_init=hk.initializers.Orthogonal(), 1009 | ) 1010 | elif reward_architecture == RewardArchitectures.PCONST: 1011 | network = jax.vmap(lambda sa: 1.0) 1012 | elif reward_architecture == RewardArchitectures.NCONST: 1013 | network = jax.vmap(lambda sa: -1.0) 1014 | else: 1015 | raise ValueError('Unknown reward architecture.') 1016 | 1017 | obs = observation_encoder(obs) 1018 | if not train_encoder: 1019 | obs = jax.lax.stop_gradient(obs) 1020 | input_ = jnp.concatenate([obs, action], axis=-1) 1021 | return network(input_) 1022 | 1023 | policy = hk.without_apply_rng(hk.transform(actor_fn)) 1024 | bc_policy = hk.without_apply_rng(hk.transform(bc_actor_fn)) 1025 | critic = hk.without_apply_rng(hk.transform(_critic_fn)) 1026 | reward = hk.without_apply_rng(hk.transform(_reward_fn)) 1027 | 1028 | # Create dummy observations and actions to create network parameters. 
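  # (Haiku's transformed `init` functions below trace a single forward pass,
  # so batched zero-valued observations and actions are enough to build the
  # parameter trees for the policy, BC policy, critic and reward networks.)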
1029 | dummy_action = utils.zeros_like(spec.actions) 1030 | dummy_obs = utils.zeros_like(spec.observations) 1031 | dummy_action = utils.add_batch_dim(dummy_action) 1032 | dummy_obs = utils.add_batch_dim(dummy_obs) 1033 | 1034 | return SILNetworks( 1035 | policy_architecture=policy_architecture, 1036 | bc_policy_architecture=bc_policy_architecture, 1037 | policy_network=networks_lib.FeedForwardNetwork( 1038 | lambda key: policy.init(key, dummy_obs), policy.apply 1039 | ), 1040 | critic_network=networks_lib.FeedForwardNetwork( 1041 | lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply 1042 | ), 1043 | reward_network=networks_lib.FeedForwardNetwork( 1044 | lambda key: reward.init(key, dummy_obs, dummy_action), reward.apply 1045 | ), 1046 | log_prob=lambda params, actions: params.log_prob(actions), 1047 | log_prob_prior=prior_policy_llh, 1048 | sample=lambda params, key: params.sample(seed=key), 1049 | # Policy eval is distribution's 'mode' (i.e. deterministic). 1050 | sample_eval=lambda params, key: params.mode(), 1051 | environment_specs=spec, 1052 | reward_policy_coherence=reward_policy_coherence, 1053 | bc_policy_network=networks_lib.FeedForwardNetwork( 1054 | lambda key: bc_policy.init(key, dummy_obs), bc_policy.apply 1055 | ), 1056 | ) 1057 | --------------------------------------------------------------------------------
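A minimal usage sketch for make_networks, assuming a small continuous-control EnvironmentSpec; the observation/action shapes, layer sizes and coherence alpha below are illustrative, not taken from the repository. With a coherent (PCSIL/NCSIL) reward, the reward transform wraps the behavioural-cloning policy, so the reward parameters are the BC policy's parameters.

import jax
import jax.numpy as jnp
import numpy as np
from acme import specs

from sil import networks as sil_networks

# Illustrative spec: 11-dim observations, 3-dim actions in [-1, 1].
spec = specs.EnvironmentSpec(
    observations=specs.Array((11,), np.float32, name='observation'),
    actions=specs.BoundedArray((3,), np.float32, -1.0, 1.0, name='action'),
    rewards=specs.Array((), np.float32, name='reward'),
    discounts=specs.BoundedArray((), np.float32, 0.0, 1.0, name='discount'),
)

nets = sil_networks.make_networks(
    spec,
    policy_architecture=sil_networks.PolicyArchitectures.MLP,
    critic_architecture=sil_networks.CriticArchitectures.DOUBLE_LNMLP,
    reward_architecture=sil_networks.RewardArchitectures.PCSIL,
    reward_policy_coherence_alpha=0.1,  # Illustrative value.
    bc_policy_architecture=sil_networks.PolicyArchitectures.MLP,
    bc_policy_hidden_layer_sizes=(256, 256),
)

key = jax.random.PRNGKey(0)
obs = jnp.zeros((1, 11))  # Batched dummy observation.
act = jnp.zeros((1, 3))   # Batched dummy action.

# The policy apply returns a tanh-transformed action distribution.
policy_params = nets.policy_network.init(key)
action_dist = nets.policy_network.apply(policy_params, obs)
action = nets.sample(action_dist, key)               # Stochastic action.
greedy_action = nets.sample_eval(action_dist, key)   # Mode of the distribution.

# Coherent reward: alpha * (log pi_BC(a|s) - log prior(a)); its parameters are
# those of the BC policy created inside the reward transform.
reward_params = nets.reward_network.init(key)
rewards = nets.reward_network.apply(reward_params, obs, act)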