├── assets
│   └── csil.gif
├── .gitignore
├── requirements.txt
├── sil
│   ├── __init__.py
│   ├── builder.py
│   ├── evaluator.py
│   ├── pretraining.py
│   ├── config.py
│   ├── learning.py
│   └── networks.py
├── CONTRIBUTING.md
├── experiment_logger.py
├── README.md
├── helpers.py
├── LICENSE
├── run_ppil.py
├── run_iqlearn.py
└── run_csil.py
/assets/csil.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/csil/HEAD/assets/csil.gif
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | chex==0.1.6
3 | cython<3
4 | d4rl @ git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
5 | dm-acme[jax] @ git+https://github.com/google-deepmind/acme@d92e23b
6 | dm-env
7 | dm-haiku==0.0.9
8 | dm-launchpad[reverb]
9 | dm-sonnet
10 | dm-tree
11 | envlogger[tfds]
12 | flax==0.6.4
13 | gym<0.24.0
14 | ipython
15 | matplotlib
16 | notebook
17 | numpy
18 | optax==0.1.4
19 | orbax<=0.1.7
20 | patchelf
21 | rlds
22 | tabulate
23 | tensorflow==2.8.0
24 | tensorflow-datasets==4.6.0
25 | tensorflow_probability==0.15.0
26 | tqdm
27 | wandb
28 |
--------------------------------------------------------------------------------
/sil/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Soft imitation learning agent."""
17 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/experiment_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Loggers for experiments."""
17 |
18 | from typing import Any, Callable, Dict, Mapping, Optional
19 |
20 | import logging
21 | import time
22 | import wandb
23 |
24 | from acme.utils.loggers import aggregators
25 | from acme.utils.loggers import asynchronous as async_logger
26 | from acme.utils.loggers import base
27 | from acme.utils.loggers import csv
28 | from acme.utils.loggers import filters
29 | from acme.utils.loggers import terminal
30 |
31 |
32 | class WeightsAndBiasesLogger(base.Logger):
33 |
34 | def __init__(
35 | self,
36 | logger: wandb.sdk.wandb_run.Run,
37 | label: str = '',
38 | time_delta: float = 0.0,
39 | ):
40 | """Initializes the Weights And Biases wrapper for Acme.
41 |
42 | Args:
43 |       logger: Weights & Biases logger instance.
44 |       label: label string to use when logging.
45 |       time_delta: How often (in seconds) to write values. This can be used
46 |         to minimize terminal spam, but is 0 by default, i.e. everything is
47 |         written.
48 | """
49 | self._label = label
50 | self._time = time.time()
51 | self._time_delta = time_delta
52 | self._logger = logger
53 |
54 | def write(self, data: base.LoggingData):
55 | """Write to weights and biases."""
56 | now = time.time()
57 | if (now - self._time) > self._time_delta:
58 | data = base.to_numpy(data) # type: ignore
59 | if self._label:
60 | stats = {f"{self._label}/{k}": v for k, v in data.items()}
61 | else:
62 | stats = data
63 | self._logger.log(stats) # type: ignore
64 | self._time = now
65 |
66 | def close(self):
67 | pass
68 |
69 | def make_logger(
70 | label: str,
71 | wandb_logger: wandb.sdk.wandb_run.Run,
72 | steps_key: str = 'steps',
73 | save_data: bool = False,
74 | time_delta: float = 1.0,
75 | asynchronous: bool = False,
76 | print_fn: Optional[Callable[[str], None]] = None,
77 | serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy,
78 | ) -> base.Logger:
79 | """Makes a default Acme logger.
80 |
81 | Args:
82 | label: Name to give to the logger.
83 | wandb_logger: Weights and Biases logger instance.
84 | save_data: Whether to persist data.
85 | time_delta: Time (in seconds) between logging events.
86 | asynchronous: Whether the write function should block or not.
87 | print_fn: How to print to terminal (defaults to print).
88 | serialize_fn: An optional function to apply to the write inputs before
89 | passing them to the various loggers.
90 | steps_key: Ignored.
91 |
92 | Returns:
93 | A logger object that responds to logger.write(some_dict).
94 | """
95 | del steps_key
96 | if not print_fn:
97 | print_fn = logging.info
98 |
99 | terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn)
100 | wandb_logger = WeightsAndBiasesLogger(
101 | logger=wandb_logger,
102 | label=label)
103 |
104 | loggers = [terminal_logger, wandb_logger]
105 |
106 | if save_data:
107 | loggers.append(csv.CSVLogger(label=label))
108 |
109 | # Dispatch to all writers and filter Nones and by time.
110 | logger = aggregators.Dispatcher(loggers, serialize_fn)
111 | logger = filters.NoneFilter(logger)
112 | if asynchronous:
113 | logger = async_logger.AsyncLogger(logger)
114 | logger = filters.TimeFilter(logger, time_delta)
115 |
116 | return logger
117 |
118 |
119 | def make_experiment_logger_factory(
120 |     wandb_kwargs: Dict[str, Any]
121 | ) -> Callable[[str, Optional[str], int], base.Logger]:
122 | """Makes an Acme logger factory.
123 |
124 | Args:
125 |     wandb_kwargs: Dictionary of keyword arguments for wandb.init().
126 |
127 | Returns:
128 | A logger factory function.
129 | """
130 |
131 |   # In the distributed setting, it is better to initialize the W&B logger once
132 |   # and pickle it than to initialize W&B logging in each process.
133 | wandb_logger = wandb.init(
134 | **wandb_kwargs,
135 | )
136 |
137 | def make_experiment_logger(label: str,
138 | steps_key: Optional[str] = None,
139 | task_instance: int = 0) -> base.Logger:
140 | del task_instance
141 | if steps_key is None:
142 | steps_key = f'{label}_steps'
143 | return make_logger(label=label, steps_key=steps_key,
144 | wandb_logger=wandb_logger,
145 | )
146 | return make_experiment_logger
147 |
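# Illustrative usage sketch (not part of the original module), assuming a local
# W&B setup; the project/run names below are placeholders and 'disabled' mode
# avoids network access:
#
#   logger_factory = make_experiment_logger_factory(
#       wandb_kwargs={'project': 'csil', 'name': 'example', 'mode': 'disabled'})
#   learner_logger = logger_factory('learner', 'learner_steps', 0)
#   learner_logger.write({'loss': 0.5})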
--------------------------------------------------------------------------------
/sil/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Soft imitation learning builder."""
17 | from typing import Iterator, List, Optional
18 |
19 | import acme
20 | from acme import adders
21 | from acme import core
22 | from acme import specs
23 | from acme import types
24 | from acme.adders import reverb as adders_reverb
25 | from acme.agents.jax import actor_core as actor_core_lib
26 | from acme.agents.jax import actors
27 | from acme.agents.jax import builders
28 | from acme.agents.jax.sac import networks as sac_networks
29 | from acme.datasets import reverb as datasets
30 | from acme.jax import networks as networks_lib
31 | from acme.jax import utils
32 | from acme.jax import variable_utils
33 | from acme.utils import counting
34 | from acme.utils import loggers
35 | import jax
36 | import optax
37 | import reverb
38 | from reverb import rate_limiters
39 |
40 | from sil import config as sil_config
41 | from sil import learning
42 | from sil import networks as sil_networks
43 |
44 |
45 | class SILBuilder(
46 | builders.ActorLearnerBuilder[
47 | sil_networks.SILNetworks,
48 | actor_core_lib.FeedForwardPolicy,
49 | learning.ImitationSample,
50 | ]
51 | ):
52 | """Soft Imitation Learning Builder."""
53 |
54 | def __init__(self, config: sil_config.SILConfig):
55 | """Creates a soft imitation learner, a behavior policy and an eval actor.
56 |
57 | Args:
58 |       config: a config with hyperparameters.
59 | """
60 | self._config = config
61 | self._make_demonstrations = config.expert_demonstration_factory
62 |
63 | def make_learner(
64 | self,
65 | random_key: networks_lib.PRNGKey,
66 | networks: sil_networks.SILNetworks,
67 | dataset: Iterator[learning.ImitationSample],
68 | logger_fn: loggers.LoggerFactory,
69 | environment_spec: specs.EnvironmentSpec,
70 | replay_client: Optional[reverb.Client] = None,
71 | counter: Optional[counting.Counter] = None,
72 | ) -> core.Learner:
73 | del environment_spec, replay_client
74 |
75 | # Create optimizers.
76 | policy_optimizer = optax.adam(
77 | learning_rate=self._config.actor_learning_rate
78 | )
79 | q_optimizer = optax.adam(self._config.critic_learning_rate)
80 | r_optimizer = optax.sgd(learning_rate=self._config.reward_learning_rate)
81 |
82 | critic_loss = self._config.imitation.critic_loss_factory()
83 | reward_factory = self._config.imitation.reward_factory()
84 |
85 | n_policy_pretrainers = (len(self._config.policy_pretraining)
86 | if self._config.policy_pretraining else 0)
87 | policy_pretraining_loggers = [
88 | logger_fn(f'pretrainer_policy{i}')
89 | for i in range(n_policy_pretrainers)
90 | ]
91 |
92 | return learning.SILLearner(
93 | networks=networks,
94 | critic_loss_def=critic_loss,
95 | reward_factory=reward_factory,
96 | tau=self._config.tau,
97 | discount=self._config.discount,
98 | critic_actor_update_ratio=self._config.critic_actor_update_ratio,
99 | entropy_coefficient=self._config.entropy_coefficient,
100 | target_entropy=self._config.target_entropy,
101 | alpha_init=self._config.alpha_init,
102 | alpha_learning_rate=self._config.alpha_learning_rate,
103 | damping=self._config.damping,
104 | rng=random_key,
105 | num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
106 | policy_optimizer=policy_optimizer,
107 | q_optimizer=q_optimizer,
108 | r_optimizer=r_optimizer,
109 | dataset=dataset,
110 | actor_bc_loss=self._config.actor_bc_loss,
111 | policy_pretraining=self._config.policy_pretraining,
112 | critic_pretraining=self._config.critic_pretraining,
113 | learner_logger=logger_fn('learner'),
114 | policy_pretraining_loggers=policy_pretraining_loggers,
115 | critic_pretraining_logger=logger_fn('pretrainer_critic'),
116 | counter=counter,
117 | )
118 |
119 | def make_actor(
120 | self,
121 | random_key: networks_lib.PRNGKey,
122 | policy: actor_core_lib.FeedForwardPolicy,
123 | environment_spec: specs.EnvironmentSpec,
124 | variable_source: Optional[core.VariableSource] = None,
125 | adder: Optional[adders.Adder] = None,
126 | ) -> acme.Actor:
127 | del environment_spec
128 | assert variable_source is not None
129 | actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
130 | variable_client = variable_utils.VariableClient(
131 | variable_source, 'policy', device='cpu')
132 | return actors.GenericActor(
133 | actor_core, random_key, variable_client, adder, backend='cpu')
134 |
135 | def make_replay_tables(
136 | self,
137 | environment_spec: specs.EnvironmentSpec,
138 | policy: actor_core_lib.FeedForwardPolicy,
139 | ) -> List[reverb.Table]:
140 | """Create tables to insert data into."""
141 | del policy
142 | samples_per_insert_tolerance = (
143 | self._config.samples_per_insert_tolerance_rate *
144 | self._config.samples_per_insert)
145 | error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
146 | limiter = rate_limiters.SampleToInsertRatio(
147 | min_size_to_sample=self._config.min_replay_size,
148 | samples_per_insert=self._config.samples_per_insert,
149 | error_buffer=error_buffer)
150 | return [
151 | reverb.Table(
152 | name=self._config.replay_table_name,
153 | sampler=reverb.selectors.Uniform(),
154 | remover=reverb.selectors.Fifo(),
155 | max_size=self._config.max_replay_size,
156 | rate_limiter=limiter,
157 | signature=adders_reverb.NStepTransitionAdder.signature(
158 | environment_spec))
159 | ]
160 |
161 | def make_dataset_iterator(
162 | self, replay_client: reverb.Client
163 | ) -> Iterator[learning.ImitationSample]:
164 | """Create a dataset iterator to use for learning/updating the agent."""
165 | # Replay buffer for demonstration data.
166 | iterator_demos = self._make_demonstrations(self._config.batch_size)
167 |
168 | # Replay buffer for online experience.
169 | iterator_online = datasets.make_reverb_dataset(
170 | table=self._config.replay_table_name,
171 | server_address=replay_client.server_address,
172 | batch_size=self._config.batch_size,
173 | prefetch_size=self._config.prefetch_size,
174 | ).as_numpy_iterator()
175 |
176 | return utils.device_put(
177 | (
178 | learning.ImitationSample(types.Transition(*online.data), demo)
179 | for online, demo in zip(iterator_online, iterator_demos)
180 | ),
181 | jax.devices()[0],
182 | )
183 |
184 | def make_adder(
185 | self, replay_client: reverb.Client,
186 | environment_spec: Optional[specs.EnvironmentSpec],
187 | policy: Optional[actor_core_lib.FeedForwardPolicy]
188 | ) -> Optional[adders.Adder]:
189 | """Create an adder which records data generated by the actor/environment."""
190 | del environment_spec, policy
191 | return adders_reverb.NStepTransitionAdder(
192 | priority_fns={self._config.replay_table_name: None},
193 | client=replay_client,
194 | n_step=self._config.n_step,
195 | discount=self._config.discount,
196 | )
197 |
198 | def make_policy(
199 | self,
200 | networks: sil_networks.SILNetworks,
201 | environment_spec: specs.EnvironmentSpec,
202 | evaluation: bool = False,
203 | ) -> actor_core_lib.FeedForwardPolicy:
204 | """Construct the policy, which is the same as soft actor critic's policy."""
205 | del environment_spec
206 | return sac_networks.apply_policy_and_sample(
207 | networks.to_sac(using_bc_policy=False), eval_mode=evaluation
208 | )
209 |
210 | def make_bc_policy(
211 | self,
212 | networks: sil_networks.SILNetworks,
213 | environment_spec: specs.EnvironmentSpec,
214 | evaluation: bool = False,
215 | ) -> actor_core_lib.FeedForwardPolicy:
216 | """Construct the policy, which is the same as soft actor critic's policy."""
217 | del environment_spec
218 | return sac_networks.apply_policy_and_sample(
219 | networks.to_sac(using_bc_policy=True), eval_mode=evaluation
220 | )
221 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Coherent Soft Imitation Learning
2 |
3 | [Paper (arXiv:2305.16498)](https://arxiv.org/abs/2305.16498)
4 | [Python](https://www.python.org/downloads/release/python-376/)
5 | [License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
6 |
7 |
8 | ![Coherent soft imitation learning (CSIL)](assets/csil.gif)
9 |
10 |
11 |
12 | This repository contains an implementation of [coherent soft imitation learning (CSIL)](https://arxiv.org/abs/2305.16498),
13 | published at [NeurIPS 2023](https://openreview.net/forum?id=kCCD8d2aEu).
14 |
15 | We also provide implementations of other 'soft'
16 | imitation learning (SIL) algorithms: [Inverse soft Q-learning (IQ-Learn)](https://arxiv.org/abs/2106.12142) and [proximal
17 | point imitation learning (PPIL)](https://arxiv.org/abs/2209.10968).
18 |
19 | ## Content
20 | The implementation is built on top of [Acme](https://github.com/google-deepmind/acme) and follows their agent structure.
21 | ```
22 | .
23 | ├── run_csil.py - Example of running CSIL on continuous control tasks.
24 | ├── run_iqlearn.py - Example of running IQ-Learn on continuous control tasks.
25 | ├── run_ppil.py - Example of running PPIL on continuous control tasks.
26 | ├── soft_policy_iteration.ipynb - Evaluation of SIL algorithms in a discrete tabular setting.
27 | ├── helpers.py - Utilities such as dataset iterators and environment creation.
28 | ├── experiment_logger.py - Implements a Weights & Biases logger within the Acme framework.
29 | |
30 | ├── sil
31 | | ├── config.py - Algorithm-specific configurations for soft imitation learning (SIL).
32 | | ├── builder.py - Creates the learner, actor, and policy.
33 | | ├── evaluator.py - Creates the evaluators and video recorders.
34 | | ├── learning.py - Implements the imitation learners.
35 | | ├── networks.py - Defines the policy, reward and critic networks.
36 | | └── pretraining.py - Implements pre-training for policy and critic.
37 | ```
38 |
39 | ## Usage
40 |
41 | Before running any code, first activate the conda environment and set the
42 | `PYTHONPATH`:
43 | ```bash
44 | conda activate csil
45 | export PYTHONPATH=$(pwd)/..
46 | ```
47 |
48 | To run CSIL with default settings:
49 | ```bash
50 | python run_csil.py
51 | ```
52 | This runs the online version of CSIL on HalfCheetah-v2.
53 |
54 | The experiment configurations for each algorithm (CSIL, IQ-Learn, and PPIL) can
55 | be adjusted via the flags defined at the start of each `run_*.py` script.
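
For example, to change the task, seed, and number of demonstrations for an online run (the flag names below are those defined in `run_ppil.py`; it is assumed the other `run_*.py` scripts define the same flags):
```bash
python run_csil.py --env_name=Hopper-v2 --seed=1 --num_demonstrations=10
```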
56 |
57 | The available tasks (specified with the `--env_name` flag) are:
58 | ```
59 | HalfCheetah-v2
60 | Ant-v2
61 | Walker2d-v2
62 | Hopper-v2
63 | Humanoid-v2
64 | door-v0 # Adroit hand
65 | hammer-v0 # Adroit hand
66 | pen-v0 # Adroit hand
67 | ```
68 |
69 | The default setting is online soft imitation learning. To run the offline
70 | version on the Adroit door task, for example:
71 | ```bash
72 | python run_{algo_name}.py --offline=True --env_name=door-v0
73 | ```
74 | replacing `{algo_name}` with either csil, iqlearn, or ppil.
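
As a concrete illustration, an offline CSIL run on the Adroit door task with the demonstration and offline dataset sizes set explicitly might look as follows (again assuming the flag set matches the one in `run_ppil.py`):
```bash
python run_csil.py --offline=True --env_name=door-v0 --num_demonstrations=25 --num_offline_demonstrations=1000
```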
75 |
76 | We have also included a Colab [here](https://colab.research.google.com/github/google-deepmind/csil/blob/main/soft_policy_iteration.ipynb) that reproduces
77 | the discrete grid world experiments shown in the paper, for a range of imitation learning algorithms.
78 |
79 | We highly encourage the use of accelerators (e.g. GPUs or TPUs) for CSIL. Because CSIL requires a larger policy architecture, training is slow in wall-clock time when run only on CPUs.
80 |
81 | For a reproduction of the paper's experiments, [see this Weights & Biases project](https://wandb.ai/jmw125/csil/workspace).
82 |
83 | The additional imitation learning baselines shown in the paper [are available in Acme](https://github.com/google-deepmind/acme/tree/master/examples/baselines/imitation).
84 |
85 | ### Open issues
86 |
87 | [Distributed Acme experiments currently do not shut down cleanly, so they appear as 'Crashed' on W&B even when they finish successfully.](https://github.com/google-deepmind/acme/issues/312#issue-1990249288)
88 |
89 | The robomimic experiments are currently not open-sourced.
90 |
91 | ## Citing this work
92 |
93 | ```bibtex
94 | @inproceedings{watson2023csil,
95 | author = {Joe Watson and
96 | Sandy H. Huang and
97 | Nicolas Heess},
98 | title = {Coherent Soft Imitation Learning},
99 | booktitle = {Advances in Neural Information Processing Systems},
100 | year = {2023}
101 | }
102 | ```
103 |
104 | ## Installation
105 |
106 | First clone this code repository into a local directory:
107 | ```bash
108 | git clone https://github.com/google-deepmind/csil.git
109 | cd csil
110 | ```
111 |
112 | We recommend installing required dependencies inside a
113 | [conda environment](https://www.anaconda.com/). To do this, first install
114 | [Anaconda](https://www.anaconda.com/download#downloads) and then create and
115 | activate the conda environment:
116 | ```bash
117 | conda create --name csil python=3.9
118 | conda activate csil
119 | ```
120 | CSIL is written in JAX, so first install the correct version of JAX for your system by [following the installation instructions](https://jax.readthedocs.io/en/latest/installation.html).
121 | Acme requires `jax 0.4.3` and will install that version. This may need to be uninstalled for a CUDA-based JAX installation, e.g.
122 | ```bash
123 | pip install jax==0.4.7 https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.7+cuda12.cudnn88-cp39-cp39-manylinux2014_x86_64.whl
124 | ```
125 |
126 | MuJoCo must also be installed in order to load the environments. Please follow
127 | the instructions [here](https://github.com/openai/mujoco-py#install-mujoco) to
128 | install the MuJoCo binary and place it in a directory where `mujoco-py` can find
129 | it.
130 | This installation uses `mujoco200`, `gym < 0.24.0` and `mujoco-py 2.0.2.5` for compatibility reasons.
131 |
132 | Then install `pip` and use it to install all the dependencies:
133 | ```bash
134 | pip install -r requirements.txt
135 | ```
136 | To verify the installation, run
137 | ```bash
138 | python -c "import jax.numpy as jnp; print(jnp.ones((1,)).device); import acme; import mujoco_py; import gym; print(gym.make('HalfCheetah-v2').reset())"
139 | ```
140 | If this fails, follow the guidance below.
141 |
142 | ## Troubleshooting
143 |
144 | If you get the error
145 | ```
146 | Command conda not found
147 | ```
148 | then you need to add the folder where Anaconda is installed to your `PATH`
149 | variable:
150 | ```bash
151 | export PATH=/path/to/anaconda/bin:$PATH
152 | ```
153 |
154 | If you get the error
155 | ```
156 | ImportError: libpython3.9.so.1.0: cannot open shared object file: No such file or directory
157 | ```
158 | first activate the conda environment and then add it to the `LD_LIBRARY_PATH`:
159 | ```bash
160 | conda activate csil
161 | export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:$CONDA_PREFIX/lib"
162 | ```
163 |
164 | If you get the error
165 | ```
166 | cannot find -lGL: No such file or directory
167 | ```
168 | then install libGL with:
169 | ```
170 | sudo apt install libgl-dev
171 | ```
172 |
173 |
174 | If you get the error
175 | ```
176 | fatal error: GL/glew.h: No such file or directory
177 | ```
178 | then you need to install the following in your conda environment and update the
179 | `CPATH`:
180 | ```bash
181 | conda install -c conda-forge glew
182 | conda install -c conda-forge mesalib
183 | conda install -c menpo glfw3
184 | export CPATH=$CONDA_PREFIX/include
185 | ```
186 |
187 | If you get the error
188 | ```
189 | ImportError: libgmpxx.so.4: cannot open shared object file: No such file or directory
190 | ```
191 | then you need to install the following in your conda environment and update the
192 | `CPATH`:
193 | ```bash
194 | conda install -c conda-forge gmp
195 | export CPATH=$CONDA_PREFIX/include
196 | ```
197 | If you get the error
198 | ```commandline
199 | ImportError: ../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /lib/x86_64-linux-gnu/libLLVM-15.so.1)
200 | ```
201 | try
202 | ```commandline
203 | mv libstdc++.so.6 libstdc++.so.6.old
204 | ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 libstdc++.so.6
205 | ```
206 | according to [this advice](https://stackoverflow.com/a/73708979).
207 |
208 | ## License and disclaimer
209 |
210 | Copyright 2023 DeepMind Technologies Limited
211 |
212 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
213 | you may not use this file except in compliance with the Apache 2.0 license.
214 | You may obtain a copy of the Apache 2.0 license at:
215 | https://www.apache.org/licenses/LICENSE-2.0
216 |
217 | All other materials are licensed under the Creative Commons Attribution 4.0
218 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
219 | https://creativecommons.org/licenses/by/4.0/legalcode
220 |
221 | Unless required by applicable law or agreed to in writing, all software and
222 | materials distributed here under the Apache 2.0 or CC-BY licenses are
223 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
224 | either express or implied. See the licenses for the specific language governing
225 | permissions and limitations under those licenses.
226 |
227 | This is not an official Google product.
228 |
--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Helpers for experiments."""
17 |
18 | import itertools
19 | from typing import Any, Callable, Iterator
20 |
21 | from acme import specs
22 | from acme import types
23 | from acme import wrappers
24 | from acme.datasets import tfds
25 | import d4rl # pylint: disable=unused-import
26 | import dm_env
27 | import gym
28 | import jax
29 | import jax.numpy as jnp
30 | import rlds
31 | import tensorflow as tf
32 | import tree
33 |
34 | from sil import learning
35 |
36 |
37 | ImitationIterator = Callable[
38 | [int, jax.Array], Iterator[learning.ImitationSample]
39 | ]
40 | TransitionIterator = Callable[
41 | [int, jax.Array], Iterator[types.Transition]
42 | ]
43 |
44 | # RLDS TFDS file names, see www.tensorflow.org/datasets/catalog/
45 | EXPERT_DATASET_NAMES = {
46 | 'HalfCheetah-v2': 'locomotion/halfcheetah_sac_1M_single_policy_stochastic',
47 | 'Ant-v2': 'locomotion/ant_sac_1M_single_policy_stochastic',
48 | 'Walker2d-v2': 'locomotion/walker2d_sac_1M_single_policy_stochastic',
49 | 'Hopper-v2': 'locomotion/hopper_sac_1M_single_policy_stochastic',
50 | 'Humanoid-v2': 'locomotion/humanoid_sac_15M_single_policy_stochastic',
51 | 'door-v0': 'd4rl_adroit_door/v0-expert',
52 | 'hammer-v0': 'd4rl_adroit_hammer/v0-expert',
53 | 'pen-v0': 'd4rl_adroit_pen/v0-expert',
54 | }
55 |
56 | OFFLINE_DATASET_NAMES = {
57 | 'HalfCheetah-v2': 'd4rl_mujoco_halfcheetah/v2-full-replay',
58 | 'Ant-v2': 'd4rl_mujoco_ant/v2-full-replay',
59 | 'Walker2d-v2': 'd4rl_mujoco_walker2d/v2-full-replay',
60 | 'Hopper-v2': 'd4rl_mujoco_hopper/v2-full-replay',
61 | # No alternative dataset for Humanoid-v2.
62 | 'Humanoid-v2': 'locomotion/humanoid_sac_15M_single_policy_stochastic',
63 |     # Adroit doesn't have suboptimal datasets; the alternatives are -human
64 |     # (human demonstrations) or -cloned (50/50 human and agent data).
65 |     # Human data is hard for BC to learn from, so we compromise with -cloned.
66 | 'door-v0': 'd4rl_adroit_door/v0-cloned',
67 | 'hammer-v0': 'd4rl_adroit_hammer/v0-cloned',
68 | 'pen-v0': 'd4rl_adroit_pen/v0-cloned',
69 | }
70 |
71 |
72 | class RandomIterator(Iterator[Any]):
73 |
74 | def __init__(self, dataset: tf.data.Dataset, batch_size: int, seed: int):
75 | dataset = dataset.shuffle(buffer_size=batch_size * 2, seed=seed)
76 | dataset = dataset.batch(batch_size, drop_remainder=True)
77 | self.iterator = itertools.cycle(dataset.as_numpy_iterator())
78 |
79 | def __next__(self) -> Any:
80 | return self.iterator.__next__()
81 |
82 |
83 | class OfflineImitationIterator(Iterator[learning.ImitationSample]):
84 | """ImitationSample iterator for offline IL."""
85 |
86 | def __init__(
87 | self,
88 | expert_iterator: Iterator[types.Transition],
89 | offline_iterator: Iterator[types.Transition],
90 | ):
91 | self._expert = expert_iterator
92 | self._offline = offline_iterator
93 |
94 | def __next__(self) -> learning.ImitationSample:
95 | """Combine data streams into an ImitationSample iterator."""
96 | return learning.ImitationSample(
97 | online_sample=self._offline.__next__(),
98 | demonstration_sample=self._expert.__next__(),
99 | )
100 |
101 |
102 | class MixedIterator(Iterator[types.Transition]):
103 | """Combine two streams of transitions 50/50."""
104 |
105 | def __init__(
106 | self,
107 | first_iterator: Iterator[types.Transition],
108 | second_iterator: Iterator[types.Transition],
109 | with_extras: bool = False,
110 | ):
111 | self._first = first_iterator
112 | self._second = second_iterator
113 | self.with_extras = with_extras
114 |
115 | def __next__(self) -> types.Transition:
116 | """Combine data streams 50/50 into one, with the equal batch size."""
117 | combined = tree.map_structure(
118 | lambda x, y: jnp.concatenate((x, y), axis=0),
119 | self._first.__next__(),
120 | self._second.__next__(),
121 | )
122 | return tree.map_structure(lambda x: x[::2, ...], combined)
123 |
124 |
125 | def get_dataset_name(env_name: str, expert: bool = True) -> str:
126 | """Map environment to an expert or non-expert dataset name."""
127 | if expert:
128 | assert (
129 | env_name in EXPERT_DATASET_NAMES
130 | ), f"Choose from {', '.join(EXPERT_DATASET_NAMES)}"
131 | return EXPERT_DATASET_NAMES[env_name]
132 | else: # Arbitrary offline data.
133 | assert (
134 | env_name in OFFLINE_DATASET_NAMES
135 | ), f"Choose from {', '.join(OFFLINE_DATASET_NAMES)}"
136 | return OFFLINE_DATASET_NAMES[env_name]
137 |
138 |
139 | def add_next_action_extras(
140 | transitions_iterator: tf.data.Dataset,
141 | ) -> tf.data.Dataset:
142 | """Creates transitions with next-action as extras information."""
143 |
144 | def _add_next_action_extras(
145 | double_transitions: types.Transition,
146 | ) -> types.Transition:
147 | """Creates a new transition containing the next action in extras."""
148 | # Observations may be dictionary or ndarray.
149 | get_obs = lambda x: tree.map_structure(lambda y: y[0], x)
150 | obs = get_obs(double_transitions.observation)
151 | next_obs = get_obs(double_transitions.next_observation)
152 | return types.Transition(
153 | observation=obs,
154 | action=double_transitions.action[0],
155 | reward=double_transitions.reward[0],
156 | discount=double_transitions.discount[0],
157 | next_observation=next_obs,
158 | extras={'next_action': double_transitions.action[1]},
159 | )
160 |
161 | double_transitions = rlds.transformations.batch(
162 | transitions_iterator, size=2, shift=1, drop_remainder=True
163 | )
164 | return double_transitions.map(_add_next_action_extras)
165 |
166 |
167 | def get_offline_dataset(
168 | task: str,
169 | environment_spec: specs.EnvironmentSpec,
170 | expert_num_demonstration: int,
171 | offline_num_demonstrations: int,
172 | expert_offline_data: bool = False,
173 | use_sarsa: bool = False,
174 | in_memory: bool = True,
175 | ) -> tuple[ImitationIterator, TransitionIterator]:
176 | """Get the offline dataset for a given task."""
177 | expert_dataset_name = get_dataset_name(task, expert=True)
178 | offline_dataset_name = get_dataset_name(task, expert=expert_offline_data)
179 |
180 | # Note: For offline learning we take the key, not a seed.
181 | def make_offline_dataset(
182 | batch_size: int, key: jax.Array
183 | ) -> Iterator[types.Transition]:
184 | offline_transitions_iterator = tfds.get_tfds_dataset(
185 | offline_dataset_name,
186 | offline_num_demonstrations,
187 | env_spec=environment_spec,
188 | )
189 | if use_sarsa:
190 | offline_transitions_iterator = add_next_action_extras(
191 | offline_transitions_iterator
192 | )
193 | if in_memory:
194 | return tfds.JaxInMemoryRandomSampleIterator(
195 | dataset=offline_transitions_iterator, key=key, batch_size=batch_size
196 | )
197 | else:
198 | return RandomIterator(
199 | offline_transitions_iterator, batch_size, seed=int(key))
200 |
201 | def make_imitation_dataset(
202 | batch_size: int,
203 | key: jax.Array,
204 | ) -> Iterator[learning.ImitationSample]:
205 | expert_transitions_iterator = tfds.get_tfds_dataset(
206 | expert_dataset_name, expert_num_demonstration, env_spec=environment_spec
207 | )
208 | if use_sarsa:
209 | expert_transitions_iterator = add_next_action_extras(
210 | expert_transitions_iterator
211 | )
212 | if in_memory:
213 | expert_iterator = tfds.JaxInMemoryRandomSampleIterator(
214 | dataset=expert_transitions_iterator, key=key, batch_size=batch_size
215 | )
216 | else:
217 | expert_iterator = RandomIterator(
218 | expert_transitions_iterator, batch_size, seed=int(key))
219 | offline_iterator = make_offline_dataset(batch_size, key)
220 | return OfflineImitationIterator(expert_iterator, offline_iterator)
221 |
222 | return make_imitation_dataset, make_offline_dataset
223 |
224 |
225 | def get_env_and_demonstrations(
226 | task: str,
227 | num_demonstrations: int,
228 | expert: bool = True,
229 | use_sarsa: bool = False,
230 | in_memory: bool = True,
231 | ) -> tuple[
232 | Callable[[], dm_env.Environment],
233 | specs.EnvironmentSpec,
234 | Callable[[int, int], Iterator[types.Transition]],
235 | ]:
236 | """Returns environment, spec and expert demonstration iterator."""
237 | make_env = make_environment(task)
238 |
239 | environment_spec = specs.make_environment_spec(make_env())
240 |
241 | # Create demonstrations function.
242 | dataset_name = get_dataset_name(task, expert=expert)
243 |
244 | def make_demonstrations(
245 | batch_size: int, seed: int = 0
246 | ) -> Iterator[types.Transition]:
247 | transitions_iterator = tfds.get_tfds_dataset(
248 | dataset_name, num_demonstrations, env_spec=environment_spec
249 | )
250 | if use_sarsa:
251 | transitions_iterator = add_next_action_extras(transitions_iterator)
252 | if in_memory:
253 | return tfds.JaxInMemoryRandomSampleIterator(
254 | dataset=transitions_iterator,
255 | key=jax.random.PRNGKey(seed),
256 | batch_size=batch_size,
257 | )
258 | else:
259 | return RandomIterator(
260 | transitions_iterator, batch_size, seed=seed)
261 | return make_env, environment_spec, make_demonstrations
262 |
263 |
264 | def make_environment(task: str) -> Callable[[], dm_env.Environment]:
265 | """Makes the requested continuous control environment.
266 |
267 | Args:
268 | task: Task to load.
269 |
270 | Returns:
271 | An environment satisfying the dm_env interface expected by Acme agents.
272 | """
273 |
274 | def make_env():
275 | env = gym.make(task)
276 | env = wrappers.GymWrapper(env)
277 | # Make sure the environment obeys the dm_env.Environment interface.
278 |
279 | # Wrap the environment so the expected continuous action spec is [-1, 1].
280 | # Note: This is a no-op on 'control' tasks.
281 | env = wrappers.CanonicalSpecWrapper(env, clip=True)
282 | env = wrappers.SinglePrecisionWrapper(env)
283 | return env
284 |
285 | env = make_env()
286 | env.render(mode='rgb_array')
287 |
288 | return make_env
289 |
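# Illustrative usage sketch (not part of the original module), assuming MuJoCo
# is installed and the HalfCheetah-v2 expert RLDS dataset is available locally;
# the batch size and seed below are placeholders:
#
#   make_env, env_spec, make_demos = get_env_and_demonstrations(
#       'HalfCheetah-v2', num_demonstrations=25)
#   demo_iterator = make_demos(256, seed=0)
#   batch = next(demo_iterator)  # A types.Transition with batched fields.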
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/run_ppil.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Example running proximal point imitation learning on continuous control tasks."""
17 |
18 | from absl import flags
19 | from acme import specs
20 | from acme.agents.jax import sac
21 | from absl import app
22 | from acme.jax import experiments
23 | from acme.utils import lp_utils
24 | import jax
25 | import launchpad as lp
26 |
27 | import experiment_logger
28 | import helpers
29 | from sil import builder
30 | from sil import config as sil_config
31 | from sil import evaluator
32 | from sil import networks
33 |
34 |
35 | _DIST_FLAG = flags.DEFINE_bool(
36 | 'run_distributed',
37 | False,
38 | (
39 | 'Should an agent be executed in a distributed '
40 | 'way. If False, will run single-threaded.'
41 | ),
42 | )
43 | _ENV_NAME = flags.DEFINE_string(
44 | 'env_name', 'HalfCheetah-v2', 'Which environment to run'
45 | )
46 | _SEED = flags.DEFINE_integer('seed', 0, 'Random seed.')
47 | _N_STEPS = flags.DEFINE_integer(
48 | 'num_steps', 250_000, 'Number of env steps to run.'
49 | )
50 | _EVAL_RATIO = flags.DEFINE_integer(
51 | 'eval_every', 5_000, 'How often to evaluate for local runs.'
52 | )
53 | _N_EVAL_EPS = flags.DEFINE_integer(
54 | 'evaluation_episodes', 1, 'Evaluation episodes for local runs.'
55 | )
56 | _N_DEMONSTRATIONS = flags.DEFINE_integer(
57 | 'num_demonstrations', 25, 'No. of demonstration trajectories.'
58 | )
59 | _N_OFFLINE_DATASET = flags.DEFINE_integer(
60 | 'num_offline_demonstrations', 1_000, 'Offline dataset size.'
61 | )
62 | _BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size.')
63 | _DISCOUNT = flags.DEFINE_float('discount', 0.99, 'Discount factor')
64 | _ACTOR_LR = flags.DEFINE_float('actor_lr', 3e-5, 'Actor learning rate.')
65 | _CRITIC_LR = flags.DEFINE_float('critic_lr', 3e-4, 'Critic learning rate.')
66 | _REWARD_LR = flags.DEFINE_float('reward_lr', 3e-4, 'Reward learning rate.')
67 | _TAU = flags.DEFINE_float(
68 | 'tau',
69 | 0.005,
70 | (
71 |         'Target network exponential smoothing weight. '
72 |         '1. = no update, 0. = no smoothing.'
73 | ),
74 | )
75 | _ENT_COEF = flags.DEFINE_float(
76 | 'entropy_coefficient',
77 | None,
78 | 'Entropy coefficient. Becomes adaptive if None.',
79 | )
80 | _CRITIC_ACTOR_RATIO = flags.DEFINE_integer(
81 | 'critic_actor_update_ratio', 20, 'Critic updates per actor update.'
82 | )
83 | _SGD_STEPS = flags.DEFINE_integer(
84 |     'sgd_steps', 1, 'SGD steps per online sample.'
85 | )
86 | _CRITIC_NETWORK = flags.DEFINE_multi_integer(
87 | 'critic_network', [256, 256], 'Define critic architecture.'
88 | )
89 | _REWARD_NETWORK = flags.DEFINE_multi_integer(
90 | 'reward_network', [256, 256], 'Define reward architecture.'
91 | )
92 | _POLICY_NETWORK = flags.DEFINE_multi_integer(
93 | 'policy_network', [256, 256], 'Define policy architecture.'
94 | )
95 | _POLICY_MODEL = flags.DEFINE_enum(
96 | 'policy_model',
97 | networks.PolicyArchitectures.MLP.value,
98 | [e.value for e in networks.PolicyArchitectures],
99 | 'Define policy model type.',
100 | )
101 | _CRITIC_MODEL = flags.DEFINE_enum(
102 | 'critic_model',
103 | networks.CriticArchitectures.LNMLP.value,
104 | [e.value for e in networks.CriticArchitectures],
105 | 'Define critic model type.',
106 | )
107 | _REWARD_MODEL = flags.DEFINE_enum(
108 | 'reward_model',
109 | networks.RewardArchitectures.LNMLP.value,
110 | [e.value for e in networks.RewardArchitectures],
111 | 'Define reward model type.',
112 | )
113 | _BET = flags.DEFINE_float(
114 | 'bellman_error_temp', 0.08, 'Temperature of logistic Bellman error term.'
115 | )
116 | _CSIL_ALPHA = flags.DEFINE_float('csil_alpha', None, 'CSIL reward temperature.')
117 | _BC_ACTOR_LOSS = flags.DEFINE_bool(
118 | 'bc_actor_loss', False, 'Have expert BC term in actor loss.'
119 | )
120 | _PRETRAIN_BC = flags.DEFINE_bool(
121 | 'pretrain_bc', False, 'Pretrain policy from demonstrations.'
122 | )
123 | _BC_PRIOR = flags.DEFINE_bool(
124 |     'bc_prior', False, 'Use pretrained BC policy as prior.'
125 | )
126 | _POLICY_PRETRAIN_STEPS = flags.DEFINE_integer(
127 | 'policy_pretrain_steps', 25_000, 'Policy pretraining steps.'
128 | )
129 | _POLICY_PRETRAIN_LR = flags.DEFINE_float(
130 | 'policy_pretrain_lr', 1e-3, 'Policy pretraining learning rate.'
131 | )
132 | _LOSS_TYPE = flags.DEFINE_enum(
133 | 'loss_type',
134 | sil_config.Losses.FAITHFUL.value,
135 | [e.value for e in sil_config.Losses],
136 | 'Define regression loss type.',
137 | )
138 | _OFFLINE_FLAG = flags.DEFINE_bool('offline', False, 'Run an offline agent.')
139 | _EVAL_BC = flags.DEFINE_bool(
140 |     'eval_bc', False, 'Run evaluator of the BC policy for comparison.')
141 | _CHECKPOINTING = flags.DEFINE_bool(
142 | 'checkpoint', False, 'Save models during training.'
143 | )
144 | _WANDB = flags.DEFINE_bool(
145 | 'wandb', True, 'Use weights and biases logging.'
146 | )
147 | _NAME = flags.DEFINE_string('name', 'camera-ready', 'Experiment name')
148 |
149 | def build_experiment_config():
150 | """Builds a P2IL experiment config which can be executed in different ways."""
151 | # Create an environment, grab the spec, and use it to create networks.
152 |
153 | task = _ENV_NAME.value
154 |
155 | mode = f'{"off" if _OFFLINE_FLAG.value else "on"}line'
156 | name = f'ppil_{task}_{mode}'
157 | group = (f'{name}, {_NAME.value}, '
158 | f'ndemos={_N_DEMONSTRATIONS.value}, '
159 | f'alpha={_ENT_COEF.value}')
160 | wandb_kwargs = {
161 | 'project': 'csil',
162 | 'name': name,
163 | 'group': group,
164 | 'tags': ['ppil', task, mode, jax.default_backend()],
165 | 'mode': 'online' if _WANDB.value else 'disabled',
166 | }
167 |
168 | logger_fact = experiment_logger.make_experiment_logger_factory(wandb_kwargs)
169 |
170 | make_env, env_spec, make_demonstrations = helpers.get_env_and_demonstrations(
171 | task, _N_DEMONSTRATIONS.value
172 | )
173 |
174 | def environment_factory(seed: int):
175 | del seed
176 | return make_env()
177 |
178 | batch_size = _BATCH_SIZE.value
179 | seed = _SEED.value
180 | actor_lr = _ACTOR_LR.value
181 |
182 | make_demonstrations_ = lambda batchsize: make_demonstrations(batchsize, seed)
183 |
184 | if _PRETRAIN_BC.value:
185 | dataset_factory = lambda seed_: make_demonstrations(batch_size, seed_)
186 | policy_pretraining = [
187 | sil_config.PretrainingConfig(
188 | loss=sil_config.Losses(_LOSS_TYPE.value),
189 | seed=seed,
190 | dataset_factory=dataset_factory,
191 | steps=_POLICY_PRETRAIN_STEPS.value,
192 | learning_rate=_POLICY_PRETRAIN_LR.value,
193 | use_as_reference=_BC_PRIOR.value,
194 | ),
195 | ]
196 | else:
197 | policy_pretraining = None
198 |
199 | critic_layers, reward_layers, policy_layers = (
200 | _CRITIC_NETWORK.value,
201 | _REWARD_NETWORK.value,
202 | _POLICY_NETWORK.value,
203 | )
204 |
205 | policy_architecture = networks.PolicyArchitectures(_POLICY_MODEL.value)
206 | critic_architecture = networks.CriticArchitectures(_CRITIC_MODEL.value)
207 | reward_architecture = networks.RewardArchitectures(_REWARD_MODEL.value)
208 | csil_alpha = _CSIL_ALPHA.value
209 |
210 | def network_factory(spec: specs.EnvironmentSpec):
211 | return networks.make_networks(
212 | spec,
213 | policy_architecture=policy_architecture,
214 | critic_architecture=critic_architecture,
215 | reward_architecture=reward_architecture,
216 | critic_hidden_layer_sizes=tuple(critic_layers),
217 | reward_hidden_layer_sizes=tuple(reward_layers),
218 | policy_hidden_layer_sizes=tuple(policy_layers),
219 | reward_policy_coherence_alpha=csil_alpha,
220 | )
221 |
222 | if _ENT_COEF.value is not None and _ENT_COEF.value > 0.0:
223 | kwargs = {'entropy_coefficient': _ENT_COEF.value}
224 | else:
225 | kwargs = {'target_entropy': sac.target_entropy_from_env_spec(env_spec)}
226 |
227 | # Construct the agent.
228 | config = sil_config.SILConfig(
229 | imitation=sil_config.ProximalPointConfig(
230 | bellman_error_temperature=_BET.value),
231 | actor_bc_loss=_BC_ACTOR_LOSS.value,
232 | policy_pretraining=policy_pretraining,
233 | expert_demonstration_factory=make_demonstrations_,
234 | discount=_DISCOUNT.value,
235 | critic_learning_rate=_CRITIC_LR.value,
236 | reward_learning_rate=_REWARD_LR.value,
237 | actor_learning_rate=actor_lr,
238 | num_sgd_steps_per_step=_SGD_STEPS.value,
239 | critic_actor_update_ratio=_CRITIC_ACTOR_RATIO.value,
240 | n_step=1,
241 | tau=_TAU.value,
242 | batch_size=batch_size,
243 | **kwargs,
244 | )
245 |
246 | sil_builder = builder.SILBuilder(config)
247 |
248 | imitation_evaluator_factory = evaluator.imitation_evaluator_factory(
249 | agent_config=config,
250 | environment_factory=environment_factory,
251 | network_factory=network_factory,
252 | policy_factory=sil_builder.make_policy,
253 | logger_factory=logger_fact,
254 | )
255 |
256 | evaluators = [imitation_evaluator_factory,]
257 |
258 | if _PRETRAIN_BC.value and _EVAL_BC.value:
259 | bc_evaluator_factory = evaluator.bc_evaluator_factory(
260 | environment_factory=environment_factory,
261 | network_factory=network_factory,
262 | policy_factory=sil_builder.make_policy,
263 | logger_factory=logger_fact,
264 | )
265 | evaluators += [bc_evaluator_factory,]
266 |
267 |   checkpoint_config = (experiments.CheckpointingConfig()
268 | if _CHECKPOINTING.value else None)
269 | if _OFFLINE_FLAG.value:
270 | make_dataset, _ = helpers.get_offline_dataset(
271 | task, env_spec, _N_DEMONSTRATIONS.value, _N_OFFLINE_DATASET.value
272 | )
273 | # The offline runner needs a random seed for the dataset.
274 | make_dataset_ = lambda k: make_dataset(batch_size, k)
275 | return experiments.OfflineExperimentConfig(
276 | builder=sil_builder,
277 | environment_factory=environment_factory,
278 | network_factory=network_factory,
279 | demonstration_dataset_factory=make_dataset_,
280 | evaluator_factories=evaluators,
281 | max_num_learner_steps=_N_STEPS.value,
282 | environment_spec=env_spec,
283 | seed=_SEED.value,
284 | logger_factory=logger_fact,
285 | checkpointing=checkpoint_config,
286 | )
287 | else:
288 | return experiments.ExperimentConfig(
289 | builder=sil_builder,
290 | environment_factory=environment_factory,
291 | network_factory=network_factory,
292 | evaluator_factories=evaluators,
293 | seed=_SEED.value,
294 | max_num_actor_steps=_N_STEPS.value,
295 | logger_factory=logger_fact,
296 | checkpointing=checkpoint_config,
297 | )
298 |
299 |
300 | def main(_):
301 | config = build_experiment_config()
302 | if _DIST_FLAG.value:
303 | if _OFFLINE_FLAG.value:
304 | program = experiments.make_distributed_offline_experiment(
305 | experiment=config
306 | )
307 | else:
308 | program = experiments.make_distributed_experiment(
309 | experiment=config, num_actors=4
310 | )
311 | lp.launch(
312 | program,
313 | xm_resources=lp_utils.make_xm_docker_resources(program),
314 | )
315 | else:
316 | if _OFFLINE_FLAG.value:
317 | experiments.run_offline_experiment(
318 | experiment=config,
319 | eval_every=_EVAL_RATIO.value,
320 | num_eval_episodes=_N_EVAL_EPS.value,
321 | )
322 | else:
323 | experiments.run_experiment(
324 | experiment=config,
325 | eval_every=_EVAL_RATIO.value,
326 | num_eval_episodes=_N_EVAL_EPS.value,
327 | )
328 |
329 |
330 | if __name__ == '__main__':
331 | app.run(main)
332 |
--------------------------------------------------------------------------------
/run_iqlearn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Example running IQ-Learn on continuous control tasks."""
17 |
18 | from absl import flags
19 | from acme import specs
20 | from acme.agents.jax import sac
21 | from absl import app
22 | from acme.jax import experiments
23 | from acme.utils import lp_utils
24 | import jax
25 | import launchpad as lp
26 |
27 | import experiment_logger
28 | import helpers
29 | from sil import builder
30 | from sil import config as sil_config
31 | from sil import evaluator
32 | from sil import networks
33 |
34 |
35 | _DIST_FLAG = flags.DEFINE_bool(
36 | 'run_distributed',
37 | False,
38 | (
39 | 'Should an agent be executed in a distributed '
40 | 'way. If False, will run single-threaded.'
41 | ),
42 | )
43 | _ENV_NAME = flags.DEFINE_string(
44 | 'env_name', 'HalfCheetah-v2', 'Which environment to run'
45 | )
46 | _SEED = flags.DEFINE_integer('seed', 0, 'Random seed.')
47 | _N_STEPS = flags.DEFINE_integer(
48 | 'num_steps', 250_000, 'Number of env steps to run.'
49 | )
50 | _EVAL_RATIO = flags.DEFINE_integer(
51 | 'eval_every', 5_000, 'How often to run evaluation.'
52 | )
53 | _N_EVAL_EPS = flags.DEFINE_integer(
54 | 'evaluation_episodes', 1, 'Evaluation episodes.'
55 | )
56 | _N_DEMOS = flags.DEFINE_integer(
57 | 'num_demonstrations', 25, 'Number of demonstration trajectories.'
58 | )
59 | _BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size.')
60 | _N_OFFLINE_DATASET = flags.DEFINE_integer(
61 | 'num_offline_demonstrations', 25, 'Offline dataset size.'
62 | )
63 | _DISCOUNT = flags.DEFINE_float('discount', 0.99, 'Discount factor')
64 | _ACTOR_LR = flags.DEFINE_float('actor_lr', 3e-5, 'Actor learning rate.')
65 | _CRITIC_LR = flags.DEFINE_float('critic_lr', 3e-4, 'Critic learning rate.')
66 | _REWARD_LR = flags.DEFINE_float(
67 | 'reward_lr', 3e-4, 'Reward learning rate (unused).'
68 | )
69 | _ENT_COEF = flags.DEFINE_float(
70 | 'entropy_coefficient',
71 | 0.01,
72 | 'Entropy coefficient. Becomes adaptive if None.',
73 | )
74 | _TAU = flags.DEFINE_float(
75 | 'tau',
76 | 0.005,
77 | (
78 | 'Target network exponential smoothing weight.'
79 |         ' 1. = no update, 0. = no smoothing.'
80 | ),
81 | )
82 | _OFFLINE_FLAG = flags.DEFINE_bool('offline', False, 'Run an offline agent.')
83 | _CRITIC_ACTOR_RATIO = flags.DEFINE_integer(
84 | 'critic_actor_update_ratio', 1, 'Critic updates per actor update.'
85 | )
86 | _SGD_STEPS = flags.DEFINE_integer(
87 | 'sgd_steps', 1, 'SGD steps for online sample.'
88 | )
89 | _CRITIC_NETWORK = flags.DEFINE_multi_integer(
90 | 'critic_network', [256, 256], 'Define critic architecture.'
91 | )
92 | _POLICY_NETWORK = flags.DEFINE_multi_integer(
93 | 'policy_network', [256, 256, 12, 256], 'Define policy architecture.'
94 | )
95 | _POLICY_MODEL = flags.DEFINE_enum(
96 | 'policy_model',
97 | networks.PolicyArchitectures.MLP.value,
98 | [e.value for e in networks.PolicyArchitectures],
99 | 'Define policy model type.',
100 | )
101 | _CRITIC_MODEL = flags.DEFINE_enum(
102 | 'critic_model',
103 | networks.CriticArchitectures.LNMLP.value,
104 | [e.value for e in networks.CriticArchitectures],
105 |     'Define critic model type.',
106 | )
107 | _BC_ACTOR_LOSS = flags.DEFINE_bool(
108 |     'bc_actor_loss', False, 'Include expert BC term in actor loss.'
109 | )
110 | _PRETRAIN_BC = flags.DEFINE_bool(
111 | 'pretrain_bc', False, 'Pretrain policy from demonstrations.'
112 | )
113 | _BC_PRIOR = flags.DEFINE_bool(
114 |     'bc_prior', False, 'Use pretrained BC policy as prior.'
115 | )
116 | _POLICY_PRETRAIN_STEPS = flags.DEFINE_integer(
117 | 'policy_pretrain_steps', 25_000, 'Policy pretraining steps.'
118 | )
119 | _POLICY_PRETRAIN_LR = flags.DEFINE_float(
120 | 'policy_pretrain_lr', 1e-3, 'Policy pretraining learning rate.'
121 | )
122 | _LOSS_TYPE = flags.DEFINE_enum(
123 | 'loss_type',
124 | sil_config.Losses.FAITHFUL.value,
125 | [e.value for e in sil_config.Losses],
126 | 'Define regression loss type.',
127 | )
128 | _LAYERNORM = flags.DEFINE_bool(
129 | 'policy_layer_norm', False, 'Use layer norm for first layer of the policy.'
130 | )
131 | _EVAL_BC = flags.DEFINE_bool('eval_bc', False,
132 | 'Run evaluator of BC policy for comparison')
133 | _EVAL_PER_VIDEO = flags.DEFINE_integer(
134 | 'evals_per_video', 0, 'Video frequency. Disable using 0.'
135 | )
136 | _CHECKPOINTING = flags.DEFINE_bool(
137 | 'checkpoint', False, 'Save models during training.'
138 | )
139 | _WANDB = flags.DEFINE_bool(
140 | 'wandb', True, 'Use weights and biases logging.')
141 |
142 | _NAME = flags.DEFINE_string('name', 'camera-ready', 'Experiment name')
143 |
144 | def _build_experiment_config():
145 | """Builds an IQ-Learn experiment config which can be executed in different ways."""
146 | # Create an environment, grab the spec, and use it to create networks.
147 |
148 | task = _ENV_NAME.value
149 |
150 | mode = f'{"off" if _OFFLINE_FLAG.value else "on"}line'
151 | name = f'iqlearn_{task}_{mode}'
152 | group = (f'{name}, {_NAME.value}, '
153 | f'ndemos={_N_DEMOS.value}, '
154 | f'alpha={_ENT_COEF.value}')
155 | wandb_kwargs = {
156 | 'project': 'csil',
157 | 'name': name,
158 | 'group': group,
159 | 'tags': ['iqlearn', task, mode, jax.default_backend()],
160 | 'config': flags.FLAGS._flags(),
161 | 'mode': 'online' if _WANDB.value else 'disabled',
162 | }
163 |
164 | logger_fact = experiment_logger.make_experiment_logger_factory(wandb_kwargs)
165 |
166 | make_env, env_spec, make_demonstrations = helpers.get_env_and_demonstrations(
167 | task, _N_DEMOS.value, use_sarsa=False
168 | )
169 |
170 | def environment_factory(seed: int):
171 | del seed
172 | return make_env()
173 |
174 | batch_size = _BATCH_SIZE.value
175 | seed = _SEED.value
176 | actor_lr = _ACTOR_LR.value
177 |
178 | make_demonstrations_ = lambda batchsize: make_demonstrations(batchsize, seed)
179 |
180 | if _PRETRAIN_BC.value:
181 | dataset_factory = lambda seed_: make_demonstrations(batch_size, seed_)
182 | policy_pretraining = [
183 | sil_config.PretrainingConfig(
184 | loss=sil_config.Losses(_LOSS_TYPE.value),
185 | seed=seed,
186 | dataset_factory=dataset_factory,
187 | steps=_POLICY_PRETRAIN_STEPS.value,
188 | learning_rate=_POLICY_PRETRAIN_LR.value,
189 | use_as_reference=_BC_PRIOR.value,
190 | ),
191 | ]
192 | else:
193 | policy_pretraining = None
194 |
195 | critic_layers = _CRITIC_NETWORK.value
196 | policy_layers = _POLICY_NETWORK.value
197 | policy_architecture = networks.PolicyArchitectures(_POLICY_MODEL.value)
198 | critic_architecture = networks.CriticArchitectures(_CRITIC_MODEL.value)
199 | use_layer_norm = _LAYERNORM.value
200 |
201 | def network_factory(spec: specs.EnvironmentSpec):
202 | return networks.make_networks(
203 | spec,
204 | policy_architecture=policy_architecture,
205 | critic_architecture=critic_architecture,
206 | critic_hidden_layer_sizes=tuple(critic_layers),
207 | policy_hidden_layer_sizes=tuple(policy_layers),
208 | layer_norm_policy=use_layer_norm,
209 | )
210 |
211 | if _ENT_COEF.value is not None and _ENT_COEF.value > 0.0:
212 | kwargs = {'entropy_coefficient': _ENT_COEF.value}
213 | else:
214 | kwargs = {'target_entropy': sac.target_entropy_from_env_spec(env_spec)}
215 |
216 | # Construct the agent.
217 | config = sil_config.SILConfig(
218 | expert_demonstration_factory=make_demonstrations_,
219 | imitation=sil_config.InverseSoftQConfig(
220 | divergence=sil_config.Divergence.CHI),
221 | actor_bc_loss=_BC_ACTOR_LOSS.value,
222 | policy_pretraining=policy_pretraining,
223 | discount=_DISCOUNT.value,
224 | critic_learning_rate=_CRITIC_LR.value,
225 | reward_learning_rate=_REWARD_LR.value,
226 | actor_learning_rate=actor_lr,
227 | num_sgd_steps_per_step=_SGD_STEPS.value,
228 | tau=_TAU.value,
229 | critic_actor_update_ratio=_CRITIC_ACTOR_RATIO.value,
230 | n_step=1,
231 | batch_size=batch_size,
232 | **kwargs,
233 | )
234 |
235 | sil_builder = builder.SILBuilder(config)
236 |
237 | imitation_evaluator_factory = evaluator.imitation_evaluator_factory(
238 | agent_config=config,
239 | environment_factory=environment_factory,
240 | network_factory=network_factory,
241 | policy_factory=sil_builder.make_policy,
242 | logger_factory=logger_fact,
243 | )
244 |
245 | evaluators = [imitation_evaluator_factory,]
246 |
247 | if _PRETRAIN_BC.value and _EVAL_BC.value:
248 | bc_evaluator_factory = evaluator.bc_evaluator_factory(
249 | environment_factory=environment_factory,
250 | network_factory=network_factory,
251 | policy_factory=sil_builder.make_policy,
252 | logger_factory=logger_fact,
253 | )
254 | evaluators += [bc_evaluator_factory,]
255 |
256 | if _EVAL_PER_VIDEO.value > 0:
257 | video_evaluator_factory = evaluator.video_evaluator_factory(
258 | environment_factory=environment_factory,
259 | network_factory=network_factory,
260 | policy_factory=sil_builder.make_policy,
261 | videos_per_eval=_EVAL_PER_VIDEO.value,
262 | logger_factory=logger_fact,
263 | )
264 | evaluators += [video_evaluator_factory]
265 |
266 |   checkpoint_config = (experiments.CheckpointingConfig()
267 | if _CHECKPOINTING.value else None)
268 | if _OFFLINE_FLAG.value:
269 | # Note: For offline learning, the dataset needs to contain the offline and
270 | # expert data, so make_demonstrations isn't used and make_dataset combines
271 | # the two, due to how the OfflineBuilder is constructed.
272 | make_dataset, _ = helpers.get_offline_dataset(
273 | task,
274 | env_spec,
275 | _N_DEMOS.value,
276 | _N_OFFLINE_DATASET.value,
277 | use_sarsa=False,
278 | )
279 | # Offline iterator takes RNG key.
280 | make_dataset_ = lambda k: make_dataset(batch_size, k)
281 | return experiments.OfflineExperimentConfig(
282 | builder=sil_builder,
283 | environment_factory=environment_factory,
284 | network_factory=network_factory,
285 | demonstration_dataset_factory=make_dataset_,
286 | evaluator_factories=evaluators,
287 | max_num_learner_steps=_N_STEPS.value,
288 | environment_spec=env_spec,
289 | seed=_SEED.value,
290 | logger_factory=logger_fact,
291 | checkpointing=checkpoint_config,
292 | )
293 | else: # Online.
294 | return experiments.ExperimentConfig(
295 | builder=sil_builder,
296 | environment_factory=environment_factory,
297 | network_factory=network_factory,
298 | evaluator_factories=evaluators,
299 | seed=_SEED.value,
300 | max_num_actor_steps=_N_STEPS.value,
301 | logger_factory=logger_fact,
302 | checkpointing=checkpoint_config,
303 | )
304 |
305 | def main(_):
306 | config = _build_experiment_config()
307 | if _DIST_FLAG.value:
308 | if _OFFLINE_FLAG.value:
309 | program = experiments.make_distributed_offline_experiment(
310 | experiment=config
311 | )
312 | else:
313 | program = experiments.make_distributed_experiment(
314 | experiment=config, num_actors=4
315 | )
316 | lp.launch(
317 | program,
318 | xm_resources=lp_utils.make_xm_docker_resources(program),
319 | )
320 | else:
321 | if _OFFLINE_FLAG.value:
322 | experiments.run_offline_experiment(
323 | experiment=config,
324 | eval_every=_EVAL_RATIO.value,
325 | num_eval_episodes=_N_EVAL_EPS.value,
326 | )
327 | else:
328 | experiments.run_experiment(
329 | experiment=config,
330 | eval_every=_EVAL_RATIO.value,
331 | num_eval_episodes=_N_EVAL_EPS.value,
332 | )
333 |
334 |
335 | if __name__ == '__main__':
336 | app.run(main)
337 |
--------------------------------------------------------------------------------
/run_csil.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Example running coherent soft imitation learning on continuous control tasks."""
17 |
18 | import math
19 | from typing import Iterator
20 | import wandb
21 |
22 | from absl import flags
23 | from acme import specs
24 | from acme import types
25 | from acme.agents.jax import sac
26 | from absl import app
27 | from acme.jax import experiments
28 | from acme.utils import lp_utils
29 | import jax
30 | import jax.random as rand
31 | import launchpad as lp
32 |
33 | import experiment_logger
34 | import helpers
35 | from sil import builder
36 | from sil import config as sil_config
37 | from sil import evaluator
38 | from sil import networks
39 |
40 | USE_SARSA = True
41 |
42 | _DIST_FLAG = flags.DEFINE_bool(
43 | 'run_distributed',
44 | False,
45 | (
46 | 'Should an agent be executed in a distributed '
47 | 'way. If False, will run single-threaded.'
48 | ),
49 | )
50 | _ENV_NAME = flags.DEFINE_string(
51 | 'env_name', 'HalfCheetah-v2', 'Which environment to run'
52 | )
53 | _SEED = flags.DEFINE_integer('seed', 0, 'Random seed.')
54 | _N_STEPS = flags.DEFINE_integer(
55 | 'num_steps', 250_000, 'Number of env steps to run.'
56 | )
57 | _EVAL_RATIO = flags.DEFINE_integer(
58 | 'eval_every', 1_000, 'How often to evaluate for local runs.'
59 | )
60 | _N_EVAL_EPS = flags.DEFINE_integer(
61 | 'evaluation_episodes', 1, 'Evaluation episodes for local runs.'
62 | )
63 | _N_DEMONSTRATIONS = flags.DEFINE_integer(
64 | 'num_demonstrations', 25, 'No. of demonstration trajectories.'
65 | )
66 | _N_OFFLINE_DATASET = flags.DEFINE_integer(
67 | 'num_offline_demonstrations', 1_000, 'Offline dataset size.'
68 | )
69 | _BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size.')
70 | _ENT_COEF = flags.DEFINE_float('entropy_coefficient', 0.01, 'Temperature')
71 | _ENT_SF = flags.DEFINE_float('ent_sf', 1.0, 'Scale entropy target.')
72 | _DAMP = flags.DEFINE_float('damping', 0.0, 'Constraint damping.')
73 | _SF = flags.DEFINE_float('scale_factor', 1.0, 'Reward loss scale factor.')
74 | _GNSF = flags.DEFINE_float(
75 | 'grad_norm_scale_factor', 1.0, 'Critic grad scale factor.'
76 | )
77 | _DISCOUNT = flags.DEFINE_float('discount', 0.99, 'Discount factor')
78 | _ACTOR_LR = flags.DEFINE_float('actor_lr', 3e-4, 'Actor learning rate.')
79 | _CRITIC_LR = flags.DEFINE_float('critic_lr', 3e-4, 'Critic learning rate.')
80 | _REWARD_LR = flags.DEFINE_float('reward_lr', 1e-3, 'Reward learning rate.')
81 | _TAU = flags.DEFINE_float(
82 | 'tau',
83 | 0.005,
84 | (
85 | 'Target network exponential smoothing weight.'
86 |         ' 1. = no update, 0. = no smoothing.'
87 | ),
88 | )
89 | _CRITIC_ACTOR_RATIO = flags.DEFINE_integer(
90 | 'critic_actor_update_ratio', 1, 'Critic updates per actor update.'
91 | )
92 | _SGD_STEPS = flags.DEFINE_integer(
93 | 'sgd_steps', 1, 'SGD steps for online sample.'
94 | )
95 | _CRITIC_NETWORK = flags.DEFINE_multi_integer(
96 | 'critic_network', [256, 256], 'Define critic architecture.'
97 | )
98 | _REWARD_NETWORK = flags.DEFINE_multi_integer(
99 | 'reward_network', [256, 256], 'Define reward architecture. (Unused)'
100 | )
101 | _POLICY_NETWORK = flags.DEFINE_multi_integer(
102 | 'policy_network', [256, 256, 12, 256], 'Define policy architecture.'
103 | )
104 | _POLICY_MODEL = flags.DEFINE_enum(
105 | 'policy_model',
106 | networks.PolicyArchitectures.HETSTATTRI.value,
107 | [e.value for e in networks.PolicyArchitectures],
108 | 'Define policy model type.',
109 | )
110 | _LAYERNORM = flags.DEFINE_bool(
111 | 'policy_layer_norm', False, 'Use layer norm for first layer of the policy.'
112 | )
113 | _CRITIC_MODEL = flags.DEFINE_enum(
114 | 'critic_model',
115 | networks.CriticArchitectures.LNMLP.value,
116 | [e.value for e in networks.CriticArchitectures],
117 | 'Define critic model type.',
118 | )
119 | _REWARD_MODEL = flags.DEFINE_enum(
120 | 'reward_model',
121 | networks.RewardArchitectures.PCSIL.value,
122 | [e.value for e in networks.RewardArchitectures],
123 | 'Define reward model type.',
124 | )
125 | _RSCALE = flags.DEFINE_float('reward_scaling', 1.0, 'Scale learned reward.')
126 | _FINETUNE_R = flags.DEFINE_bool('finetune_reward', True, 'Finetune reward.')
127 | _LOSS_TYPE = flags.DEFINE_enum(
128 | 'loss_type',
129 | sil_config.Losses.FAITHFUL.value,
130 | [e.value for e in sil_config.Losses],
131 | 'Define regression loss type.',
132 | )
133 | _POLICY_PRETRAIN_STEPS = flags.DEFINE_integer(
134 | 'policy_pretrain_steps', 25_000, 'Policy pretraining steps.'
135 | )
136 | _POLICY_PRETRAIN_LR = flags.DEFINE_float(
137 | 'policy_pretrain_lr', 1e-3, 'Policy pretraining learning rate.'
138 | )
139 | _CRITIC_PRETRAIN_STEPS = flags.DEFINE_integer(
140 | 'critic_pretrain_steps', 5_000, 'Critic pretraining steps.'
141 | )
142 | _CRITIC_PRETRAIN_LR = flags.DEFINE_float(
143 | 'critic_pretrain_lr', 1e-4, 'Critic pretraining learning rate.'
144 | )
145 | _EVAL_BC = flags.DEFINE_bool('eval_bc', False,
146 | 'Run evaluator of BC policy for comparison')
147 | _OFFLINE_FLAG = flags.DEFINE_bool('offline', False, 'Run an offline agent.')
148 | _EVAL_PER_VIDEO = flags.DEFINE_integer(
149 | 'evals_per_video', 0, 'Video frequency. Disable using 0.'
150 | )
151 | _NUM_ACTORS = flags.DEFINE_integer(
152 | 'num_actors', 4, 'Number of distributed actors.'
153 | )
154 | _CHECKPOINTING = flags.DEFINE_bool(
155 | 'checkpoint', False, 'Save models during training.'
156 | )
157 | _WANDB = flags.DEFINE_bool(
158 | 'wandb', True, 'Use weights and biases logging.'
159 | )
160 | _NAME = flags.DEFINE_string('name', 'camera-ready', 'Experiment name')
161 |
162 | def _build_experiment_config():
163 | """Builds a CSIL experiment config which can be executed in different ways."""
164 |
165 | # Create an environment, grab the spec, and use it to create networks.
166 | task = _ENV_NAME.value
167 |
168 | mode = f'{"off" if _OFFLINE_FLAG.value else "on"}line'
169 | name = f'csil_{task}_{mode}'
170 | group = (f'{name}, {_NAME.value}, '
171 | f'ndemos={_N_DEMONSTRATIONS.value}, '
172 | f'alpha={_ENT_COEF.value}')
173 | wandb_kwargs = {
174 | 'project': 'csil',
175 | 'name': name,
176 | 'group': group,
177 | 'tags': ['csil', task, mode, jax.default_backend()],
178 | 'config': flags.FLAGS._flags(),
179 | 'mode': 'online' if _WANDB.value else 'disabled',
180 | }
181 |
182 | logger_fact = experiment_logger.make_experiment_logger_factory(wandb_kwargs)
183 |
184 | make_env, env_spec, make_demonstrations = helpers.get_env_and_demonstrations(
185 | task, _N_DEMONSTRATIONS.value, use_sarsa=USE_SARSA,
186 | in_memory='image' not in task
187 | )
188 |
189 | def environment_factory(seed: int):
190 | del seed
191 | return make_env()
192 |
193 | batch_size = _BATCH_SIZE.value
194 | seed = _SEED.value
195 | actor_lr = _ACTOR_LR.value
196 |
197 | make_demonstrations_ = lambda batchsize: make_demonstrations(batchsize, seed)
198 |
199 | if _ENT_COEF.value > 0.0:
200 | kwargs = {'entropy_coefficient': _ENT_COEF.value}
201 | else:
202 | target_entropy = _ENT_SF.value * sac.target_entropy_from_env_spec(
203 | env_spec, target_entropy_per_dimension=abs(_ENT_SF.value))
204 | kwargs = {'target_entropy': target_entropy}
205 |
206 | # Important step that normalizes reward values -- do not change!
207 |   csil_alpha = _RSCALE.value / math.prod(env_spec.actions.shape)
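  # For example, HalfCheetah-v2 has 6-dimensional actions, so with the default
  # reward_scaling of 1.0 this gives csil_alpha = 1 / 6 (approx. 0.167).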
208 |
209 | policy_architecture = networks.PolicyArchitectures(_POLICY_MODEL.value)
210 | bc_policy_architecture = policy_architecture
211 | critic_architecture = networks.CriticArchitectures(_CRITIC_MODEL.value)
212 | reward_architecture = networks.RewardArchitectures(_REWARD_MODEL.value)
213 | policy_layers = _POLICY_NETWORK.value
214 | reward_layers = _REWARD_NETWORK.value
215 | critic_layers = _CRITIC_NETWORK.value
216 | use_layer_norm = _LAYERNORM.value
217 |
218 | def network_factory(spec: specs.EnvironmentSpec):
219 | return networks.make_networks(
220 | spec=spec,
221 | reward_policy_coherence_alpha=csil_alpha,
222 | policy_architecture=policy_architecture,
223 | critic_architecture=critic_architecture,
224 | reward_architecture=reward_architecture,
225 | bc_policy_architecture=bc_policy_architecture,
226 | policy_hidden_layer_sizes=tuple(policy_layers),
227 | reward_hidden_layer_sizes=tuple(reward_layers),
228 | critic_hidden_layer_sizes=tuple(critic_layers),
229 | bc_policy_hidden_layer_sizes=tuple(policy_layers),
230 | layer_norm_policy=use_layer_norm,
231 | )
232 |
233 | demo_factory = lambda seed_: make_demonstrations(batch_size, seed_)
234 | policy_pretraining = sil_config.PretrainingConfig(
235 | loss=sil_config.Losses(_LOSS_TYPE.value),
236 | seed=seed,
237 | dataset_factory=demo_factory,
238 | steps=_POLICY_PRETRAIN_STEPS.value,
239 | learning_rate=_POLICY_PRETRAIN_LR.value,
240 | use_as_reference=True,
241 | )
242 | if _OFFLINE_FLAG.value:
243 | _, offline_dataset = helpers.get_offline_dataset(
244 | task,
245 | env_spec,
246 | _N_DEMONSTRATIONS.value,
247 | _N_OFFLINE_DATASET.value,
248 | use_sarsa=USE_SARSA,
249 | )
250 |
251 | def offline_pretraining_dataset(rseed: int) -> Iterator[types.Transition]:
252 | rkey = rand.PRNGKey(rseed)
253 | return helpers.MixedIterator(
254 | offline_dataset(batch_size, rkey),
255 | make_demonstrations(batch_size, rseed),
256 | )
257 |
258 | offline_policy_pretraining = sil_config.PretrainingConfig(
259 | loss=sil_config.Losses(_LOSS_TYPE.value),
260 | seed=seed,
261 | dataset_factory=offline_pretraining_dataset,
262 | steps=_POLICY_PRETRAIN_STEPS.value,
263 | learning_rate=_POLICY_PRETRAIN_LR.value,
264 | )
265 | policy_pretrainers = [offline_policy_pretraining, policy_pretraining]
266 | critic_dataset = demo_factory
267 | else:
268 | policy_pretrainers = [policy_pretraining,]
269 | critic_dataset = demo_factory
270 |
271 | critic_pretraining = sil_config.PretrainingConfig(
272 | seed=seed,
273 | dataset_factory=critic_dataset,
274 | steps=_CRITIC_PRETRAIN_STEPS.value,
275 | learning_rate=_CRITIC_PRETRAIN_LR.value,
276 | )
277 | # Construct the agent.
278 | config_ = sil_config.SILConfig(
279 | imitation=sil_config.CoherentConfig(
280 | alpha=csil_alpha,
281 |         reward_scaling=_RSCALE.value,
282 | refine_reward=_FINETUNE_R.value,
283 | negative_reward=(
284 | reward_architecture == networks.RewardArchitectures.NCSIL),
285 | grad_norm_sf=_GNSF.value,
286 | scale_factor=_SF.value,
287 | ),
288 | actor_bc_loss=False,
289 | policy_pretraining=policy_pretrainers,
290 | critic_pretraining=critic_pretraining,
291 | expert_demonstration_factory=make_demonstrations_,
292 | discount=_DISCOUNT.value,
293 | critic_learning_rate=_CRITIC_LR.value,
294 | reward_learning_rate=_REWARD_LR.value,
295 | actor_learning_rate=actor_lr,
296 | num_sgd_steps_per_step=_SGD_STEPS.value,
297 | critic_actor_update_ratio=_CRITIC_ACTOR_RATIO.value,
298 | n_step=1,
299 | damping=_DAMP.value,
300 | tau=_TAU.value,
301 | batch_size=batch_size,
302 | samples_per_insert=batch_size,
303 | alpha_learning_rate=1e-2,
304 | alpha_init=0.01,
305 | **kwargs,
306 | )
307 |
308 | sil_builder = builder.SILBuilder(config_)
309 |
310 | imitation_evaluator_factory = evaluator.imitation_evaluator_factory(
311 | agent_config=config_,
312 | environment_factory=environment_factory,
313 | network_factory=network_factory,
314 | policy_factory=sil_builder.make_policy,
315 | logger_factory=logger_fact,
316 | )
317 |
318 | evaluators = [imitation_evaluator_factory,]
319 | if _EVAL_PER_VIDEO.value > 0:
320 | video_evaluator_factory = evaluator.video_evaluator_factory(
321 | environment_factory=environment_factory,
322 | network_factory=network_factory,
323 | policy_factory=sil_builder.make_policy,
324 | videos_per_eval=_EVAL_PER_VIDEO.value,
325 | logger_factory=logger_fact,
326 | )
327 |     evaluators += [video_evaluator_factory,]
328 |
329 | if _EVAL_BC.value:
330 | bc_evaluator_factory = evaluator.bc_evaluator_factory(
331 | environment_factory=environment_factory,
332 | network_factory=network_factory,
333 | policy_factory=sil_builder.make_bc_policy,
334 | logger_factory=logger_fact,
335 | )
336 | evaluators += [bc_evaluator_factory,]
337 |
338 |   checkpoint_config = (experiments.CheckpointingConfig()
339 | if _CHECKPOINTING.value else None)
340 | if _OFFLINE_FLAG.value:
341 | make_offline_dataset, _ = helpers.get_offline_dataset(
342 | task,
343 | env_spec,
344 | _N_DEMONSTRATIONS.value,
345 | _N_OFFLINE_DATASET.value,
346 | use_sarsa=USE_SARSA,
347 | )
348 | make_offline_dataset_ = lambda rk: make_offline_dataset(batch_size, rk)
349 | return experiments.OfflineExperimentConfig(
350 | builder=sil_builder,
351 | environment_factory=environment_factory,
352 | network_factory=network_factory,
353 | demonstration_dataset_factory=make_offline_dataset_,
354 | evaluator_factories=evaluators,
355 | max_num_learner_steps=_N_STEPS.value,
356 | environment_spec=env_spec,
357 | seed=_SEED.value,
358 | logger_factory=logger_fact,
359 | checkpointing=checkpoint_config,
360 | )
361 | else:
362 | return experiments.ExperimentConfig(
363 | builder=sil_builder,
364 | environment_factory=environment_factory,
365 | network_factory=network_factory,
366 | evaluator_factories=evaluators,
367 | seed=_SEED.value,
368 | max_num_actor_steps=_N_STEPS.value,
369 | logger_factory=logger_fact,
370 | checkpointing=checkpoint_config,
371 | )
372 |
373 |
374 | def main(_):
375 | config = _build_experiment_config()
376 | if _DIST_FLAG.value:
377 | if _OFFLINE_FLAG.value:
378 | program = experiments.make_distributed_offline_experiment(
379 | experiment=config
380 | )
381 | else:
382 | program = experiments.make_distributed_experiment(
383 | experiment=config, num_actors=_NUM_ACTORS.value
384 | )
385 | lp.launch(
386 | program,
387 | xm_resources=lp_utils.make_xm_docker_resources(program),
388 | )
389 | else:
390 | if _OFFLINE_FLAG.value:
391 | experiments.run_offline_experiment(
392 | experiment=config,
393 | eval_every=_EVAL_RATIO.value,
394 | num_eval_episodes=_N_EVAL_EPS.value,
395 | )
396 | else:
397 | experiments.run_experiment(
398 | experiment=config,
399 | eval_every=_EVAL_RATIO.value,
400 | num_eval_episodes=_N_EVAL_EPS.value,
401 | )
402 |
403 |
404 | if __name__ == '__main__':
405 | wandb.setup()
406 | app.run(main)
407 | wandb.finish()
408 |
409 |
--------------------------------------------------------------------------------
/sil/evaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Custom evaluator for soft imitation learning setting.
17 |
18 | Includes the estimated return from the learned reward function.
19 |
20 | Also captures videos.
21 | """
22 | from typing import Dict, Optional, Sequence, Union
23 |
24 | from absl import logging
25 | import acme
26 | from acme import core
27 | from acme import environment_loop
28 | from acme import specs
29 | from acme.agents.jax import actor_core as actor_core_lib
30 | from acme.agents.jax import actors
31 | from acme.agents.jax import builders
32 | from acme.jax import networks as networks_lib
33 | from acme.jax import types
34 | from acme.jax import utils
35 | from acme.jax import variable_utils
36 | from acme.jax.experiments import config as experiment_config
37 | from acme.utils import counting
38 | from acme.utils import experiment_utils
39 | from acme.utils import loggers
40 | from acme.utils import observers as observers_lib
41 | from acme.wrappers import video as video_wrappers
42 | import dm_env
43 | import jax
44 | from jax import random
45 | import jax.numpy as jnp
46 | import numpy as np
47 | import tree
48 |
49 | from sil import config as sil_config
50 | from sil import networks as sil_networks
51 |
52 |
53 | N_EVALS_PER_VIDEO = 1
54 |
55 |
56 | class GymVideoWrapper(video_wrappers.VideoWrapper):
57 |
58 | def _render_frame(self, observation):
59 | """Renders a frame from the given environment observation."""
60 | del observation
61 | return self.environment.render(mode='rgb_array') # pytype: disable=attribute-error
62 |
63 | def _write_frames(self):
64 | logging.info(
65 | 'Saving video: %s/%s/%d', self._path, self._filename, self._counter
66 | )
67 | super()._write_frames()
68 |
69 |
70 | class RewardCore:
71 | """A learned implicit or explicit reward function."""
72 |
73 | def __init__(
74 | self,
75 | networks: sil_networks.SILNetworks,
76 | reward_factory: sil_config.RewardFact,
77 | discount: float,
78 | variable_source: core.VariableSource,
79 | ):
80 | self._reward_variable_client = variable_utils.VariableClient(
81 | variable_source, 'reward', device='cpu'
82 | )
83 | self._critic_variable_client = variable_utils.VariableClient(
84 | variable_source, 'critic', device='cpu'
85 | )
86 | self._policy_variable_client = variable_utils.VariableClient(
87 | variable_source, 'policy', device='cpu'
88 | )
89 | self._reward_fn = None
90 |
91 | def _reward(
92 | state: jnp.ndarray,
93 | action: jnp.ndarray,
94 | next_state: jnp.ndarray,
95 | transition_discount: jnp.ndarray,
96 | key: jnp.ndarray,
97 | reward_params: networks_lib.Params,
98 | critic_params: networks_lib.Params,
99 | policy_params: networks_lib.Params,
100 | ) -> jnp.ndarray:
101 | def state_action_reward_fn(
102 | state: jnp.ndarray, action: jnp.ndarray
103 | ) -> jnp.ndarray:
104 | return jnp.ravel(
105 | networks.reward_network.apply(reward_params, state, action)
106 | )
107 |
108 | def state_action_value_fn(
109 | state: jnp.ndarray, action: jnp.ndarray
110 | ) -> jnp.ndarray:
111 | return networks.critic_network.apply(
112 | critic_params, state, action).min(axis=-1)
113 |
114 | def state_value_fn(
115 | state: jnp.ndarray, policy_key: jnp.ndarray
116 | ) -> jnp.ndarray:
117 | action_dist = networks.policy_network.apply(policy_params, state)
118 | action = action_dist.sample(seed=policy_key)
119 | v = networks.critic_network.apply(
120 | critic_params, state, action).min(axis=-1)
121 | return v
122 |
123 | reward = reward_factory(
124 | state_action_reward_fn,
125 | state_action_value_fn,
126 | state_value_fn,
127 | discount,
128 | )
129 | return reward(state, action, next_state, transition_discount, key)
130 |
131 | self._reward_fn = jax.jit(_reward)
132 |
133 | @property
134 | def _reward_params(self) -> Sequence[networks_lib.Params]:
135 | return self._reward_variable_client.params
136 |
137 | @property
138 | def _critic_params(self) -> Sequence[networks_lib.Params]:
139 | return self._critic_variable_client.params
140 |
141 | @property
142 | def _policy_params(self) -> Sequence[networks_lib.Params]:
143 | params = self._policy_variable_client.params
144 | return params
145 |
146 | def __call__(
147 | self,
148 | state: jnp.ndarray,
149 | action: jnp.ndarray,
150 | next_state: jnp.ndarray,
151 | discount: jnp.ndarray,
152 | key: jnp.ndarray,
153 | ) -> jnp.ndarray:
154 | assert self._reward_fn is not None
155 | return self._reward_fn(
156 | state,
157 | action,
158 | next_state,
159 | discount,
160 | key,
161 | self._reward_params,
162 | self._critic_params,
163 | self._policy_params,
164 | )
165 |
166 | def update(self, wait: bool = False):
167 | """Get updated parameters and update reward function."""
168 |
169 | self._reward_variable_client.update(wait)
170 | self._critic_variable_client.update(wait)
171 | self._policy_variable_client.update(wait)
172 |
173 |
174 | class ImitationObserver(observers_lib.EnvLoopObserver):
175 | """Observer that evaluated using the learned reward function."""
176 |
177 | def __init__(self, reward_fn: RewardCore):
178 | self._reward_fn = reward_fn
179 | self._imitation_return = 0.0
180 | self._current_observation = None
181 | self._episode_rewards = []
182 |     self._episode_imitation_rewards = []
183 | self._key = random.PRNGKey(0)
184 |
185 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep):
186 | self._reward_fn.update()
187 | self._imitation_return = 0.0
188 | self._current_observation = timestep.observation
189 | self._episode_rewards = []
190 |     self._episode_imitation_rewards = []
191 |
192 | def observe(
193 | self,
194 | env: dm_env.Environment,
195 | timestep: dm_env.TimeStep,
196 | action: np.ndarray,
197 | ):
198 | """Records one environment step."""
199 | self._key, subkey = random.split(self._key)
200 | obs = tree.map_structure(
201 | lambda obs: jnp.expand_dims(obs, 0), self._current_observation
202 | )
203 | next_obs = tree.map_structure(
204 | lambda obs: jnp.expand_dims(obs, 0), timestep.observation
205 | )
206 | imitation_reward = self._reward_fn(
207 | obs,
208 | jnp.expand_dims(action, 0),
209 | next_obs,
210 | jnp.expand_dims(timestep.discount, 0),
211 | subkey,
212 | ).squeeze()
213 | self._current_observation = timestep.observation
214 | self._imitation_return += imitation_reward
215 | self._episode_rewards += [timestep.reward]
216 |     self._episode_imitation_rewards += [imitation_reward]
217 |
218 | def compute_correlation(self) -> jnp.ndarray:
219 | r = jnp.asarray(self._episode_rewards)
220 |     ir = jnp.asarray(self._episode_imitation_rewards)
221 | return jnp.corrcoef(r, ir)[0, 1]
222 |
223 | def get_metrics(self) -> Dict[str, observers_lib.Number]:
224 | """Returns metrics collected for the current episode."""
225 | corr = self.compute_correlation()
226 | metrics = {
227 | 'imitation_return': self._imitation_return,
228 | 'episode_reward_corr': corr,
229 | }
230 | return metrics # pytype: disable=bad-return-type # jnp-array
231 |
232 |
233 | def adroit_success(env: dm_env.Environment) -> bool:
234 | # Adroit environments have get_info methods.
235 | if not hasattr(env, 'get_info'):
236 | return False
237 | info = getattr(env, 'get_info')()
238 | if 'goal_achieved' in info:
239 | return info['goal_achieved']
240 | else:
241 | return False
242 |
243 |
244 | def _get_success_from_env(env: dm_env.Environment) -> bool:
245 | """Obtain the success flag for Adroit environments."""
246 | return adroit_success(env)
247 |
248 |
249 | class SuccessObserver(observers_lib.EnvLoopObserver):
250 | """Observer that extracts the goal_achieved flag from Adroit mj_env tasks."""
251 |
252 | _success: bool = False
253 |
254 | def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep):
255 | del env, timestep
256 | self._success = False
257 |
258 | def observe(
259 | self,
260 | env: dm_env.Environment,
261 | timestep: dm_env.TimeStep,
262 | action: np.ndarray,
263 | ):
264 | """Records one environment step."""
265 | del timestep, action
266 | success = _get_success_from_env(env)
267 | self._success = self._success or success
268 |
269 | def get_metrics(self) -> Dict[str, observers_lib.Number]:
270 | """Returns metrics collected for the current episode."""
271 | metrics = {
272 | 'success': 1.0 if self._success else 0.0,
273 | }
274 | return metrics
275 |
276 |
277 | def imitation_evaluator_factory(
278 | environment_factory: types.EnvironmentFactory,
279 | network_factory: experiment_config.NetworkFactory[builders.Networks],
280 | policy_factory: experiment_config.PolicyFactory[
281 | builders.Networks, builders.Policy
282 | ],
283 | logger_factory: loggers.LoggerFactory = experiment_utils.make_experiment_logger,
284 | agent_config: Union[sil_config.SILConfig, None] = None,
285 | ) -> experiment_config.EvaluatorFactory[builders.Policy]:
286 | """Returns an imitation learning evaluator process."""
287 |
288 | def evaluator(
289 | random_key: types.PRNGKey,
290 | variable_source: core.VariableSource,
291 | counter: counting.Counter,
292 | make_actor: experiment_config.MakeActorFn[builders.Policy],
293 | ) -> environment_loop.EnvironmentLoop:
294 | """The evaluation process for imitation learning."""
295 | # Create environment and evaluator networks.
296 | environment_key, actor_key = jax.random.split(random_key)
297 | # Environments normally require uint32 as a seed.
298 | environment = environment_factory(utils.sample_uint32(environment_key))
299 |
300 | eval_environment = environment
301 |
302 | environment_spec = specs.make_environment_spec(environment)
303 | networks = network_factory(environment_spec)
304 | policy = policy_factory(networks, environment_spec, True)
305 | actor = make_actor(actor_key, policy, environment_spec, variable_source)
306 |
307 | success_observer = SuccessObserver()
308 | if agent_config is not None:
309 | reward_factory = agent_config.imitation.reward_factory()
310 | reward_fn = RewardCore(
311 | networks, reward_factory, agent_config.discount, variable_source
312 | )
313 |
314 | imitation_observer = ImitationObserver(reward_fn)
315 | observers = (success_observer, imitation_observer)
316 | else:
317 | observers = (success_observer,)
318 |
319 | # Create logger and counter.
320 | counter = counting.Counter(counter, 'imitation_evaluator')
321 | logger = logger_factory('imitation_evaluator', 'actor_steps', 0)
322 |
323 | # Create the run loop and return it.
324 | return environment_loop.EnvironmentLoop(
325 | eval_environment, actor, counter, logger, observers=observers
326 | )
327 |
328 | return evaluator
329 |
330 |
331 | def video_evaluator_factory(
332 | environment_factory: types.EnvironmentFactory,
333 | network_factory: experiment_config.NetworkFactory[builders.Networks],
334 | policy_factory: experiment_config.PolicyFactory[
335 | builders.Networks, builders.Policy
336 | ],
337 | logger_factory: loggers.LoggerFactory = experiment_utils.make_experiment_logger,
338 | videos_per_eval: int = 0,
339 | ) -> experiment_config.EvaluatorFactory[builders.Policy]:
340 | """Returns an evaluator process that records videos."""
341 |
342 | def evaluator(
343 | random_key: types.PRNGKey,
344 | variable_source: core.VariableSource,
345 | counter: counting.Counter,
346 | make_actor: experiment_config.MakeActorFn[builders.Policy],
347 | ) -> environment_loop.EnvironmentLoop:
348 | """The evaluation process for recording videos."""
349 | # Create environment and evaluator networks.
350 | environment_key, actor_key = jax.random.split(random_key)
351 | # Environments normally require uint32 as a seed.
352 | environment = environment_factory(utils.sample_uint32(environment_key))
353 |
354 | if videos_per_eval > 0:
355 | eval_environment = GymVideoWrapper(
356 | environment,
357 | record_every=videos_per_eval,
358 | frame_rate=40,
359 | filename='eval_episode',
360 | )
361 | else:
362 | eval_environment = environment
363 |
364 | environment_spec = specs.make_environment_spec(environment)
365 | networks = network_factory(environment_spec)
366 | policy = policy_factory(networks, environment_spec, True)
367 | actor = make_actor(actor_key, policy, environment_spec, variable_source)
368 |
369 | observers = (SuccessObserver(),)
370 |
371 | # Create logger and counter.
372 | counter = counting.Counter(counter, 'video_evaluator')
373 | logger = logger_factory(
374 | 'video_evaluator', 'actor_steps', 0
375 | )
376 |
377 | # Create the run loop and return it.
378 | return environment_loop.EnvironmentLoop(
379 | eval_environment, actor, counter, logger, observers=observers
380 | )
381 |
382 | return evaluator
383 |
384 |
385 | def bc_evaluator_factory(
386 | environment_factory: types.EnvironmentFactory,
387 | network_factory: experiment_config.NetworkFactory[builders.Networks],
388 | policy_factory: experiment_config.PolicyFactory[
389 | builders.Networks, builders.Policy
390 | ],
391 | logger_factory: loggers.LoggerFactory = experiment_utils.make_experiment_logger,
392 | ) -> experiment_config.EvaluatorFactory[builders.Policy]:
393 | """Returns an imitation learning evaluator process."""
394 |
395 | def evaluator(
396 | random_key: types.PRNGKey,
397 | variable_source: core.VariableSource,
398 | counter: counting.Counter,
399 | make_actor: experiment_config.MakeActorFn[builders.Policy],
400 | ) -> environment_loop.EnvironmentLoop:
401 | del make_actor
402 | # Create environment and evaluator networks.
403 | environment_key, actor_key = jax.random.split(random_key)
404 | # Environments normally require uint32 as a seed.
405 | environment = environment_factory(utils.sample_uint32(environment_key))
406 |
407 | eval_environment = environment
408 |
409 | environment_spec = specs.make_environment_spec(environment)
410 | networks = network_factory(environment_spec)
411 | policy = policy_factory(networks, environment_spec, True)
412 |
413 | def make_bc_actor(
414 | random_key: networks_lib.PRNGKey,
415 | policy: actor_core_lib.FeedForwardPolicy,
416 | environment_spec: specs.EnvironmentSpec,
417 | variable_source: Optional[core.VariableSource] = None,
418 | ) -> acme.Actor:
419 | del environment_spec
420 | assert variable_source is not None
421 | actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
422 | variable_client = variable_utils.VariableClient(
423 | variable_source, 'bc_policy_params', device='cpu'
424 | )
425 | return actors.GenericActor(
426 | actor_core, random_key, variable_client, None, backend='cpu'
427 | )
428 |
429 | actor = make_bc_actor(actor_key, policy, environment_spec, variable_source)
430 |
431 | observers = (SuccessObserver(),)
432 |
433 | # Create logger and counter.
434 | counter = counting.Counter(counter, 'bc_evaluator')
435 | logger = logger_factory('bc_evaluator', 'actor_steps', 0)
436 |
437 | # Create the run loop and return it.
438 | return environment_loop.EnvironmentLoop(
439 | eval_environment, actor, counter, logger, observers=observers
440 | )
441 |
442 | return evaluator
443 |
--------------------------------------------------------------------------------
/sil/pretraining.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Methods to pretrain networks, e.g. policies, with regression.
17 | """
18 | import time
19 | from typing import Callable, Dict, Iterator, NamedTuple, Optional, Sequence, Tuple
20 |
21 | from acme import specs
22 | from acme import types
23 | from acme.agents.jax import bc
24 | from acme.jax import networks as networks_lib
25 | from acme.jax import types as jax_types
26 | from acme.jax import utils
27 | from acme.utils import counting
28 | from acme.utils import experiment_utils
29 | from acme.utils import loggers
30 | import jax
31 | from jax import random
32 | import jax.numpy as jnp
33 | import numpy as np
34 | import optax
35 | import tensorflow_probability.substrates.jax as tfp
36 | import tree
37 |
38 | from sil import config
39 |
40 | tfd = tfp.distributions
41 |
42 | LoggerFactory = Callable[[], loggers.Logger]
43 |
44 |
45 | BCLossWithoutAux = bc.losses.BCLossWithoutAux
46 | BCLossWithAux = bc.losses.BCLossWithAux
47 | LossArgs = [
48 | bc.networks.BCNetworks,
49 | networks_lib.Params,
50 | networks_lib.PRNGKey,
51 | types.Transition,
52 | ]
53 | Metrics = Dict[str, jnp.ndarray]
54 |
55 | ExtendedBCLossWithAux = Tuple[
56 | Callable[LossArgs, Tuple[jnp.ndarray, Metrics]],
57 | Callable[[networks_lib.Params], networks_lib.Params],
58 | Callable[[networks_lib.Params], networks_lib.Params],
59 | ]
60 |
61 |
62 | def no_param_change(params: networks_lib.Params) -> networks_lib.Params:
63 | return params
64 |
65 |
66 | def weight_decay(params: networks_lib.Params) -> jnp.ndarray:
67 | """Used for weight decay loss terms."""
68 | return 0.5 * sum(
69 | jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
70 |
71 |
72 | def mse(action_dimension: int) -> ExtendedBCLossWithAux:
73 | """Log probability loss."""
74 | del action_dimension
75 | def loss(
76 | networks: bc.networks.BCNetworks,
77 | params: networks_lib.Params,
78 | key: jax_types.PRNGKey,
79 | transitions: types.Transition,
80 | ) -> Tuple[jnp.ndarray, Metrics]:
81 | dist = networks.policy_network.apply(
82 | params, transitions.observation, is_training=True, key=key,
83 | train_encoder=True,
84 | )
85 | sample = networks.sample_fn(dist, key)
86 | entropy = -networks.log_prob(dist, sample).mean()
87 | mean_sq_error = ((dist.mode() - transitions.action) ** 2).mean()
88 | metrics = {
89 | "mse": mean_sq_error,
90 | "nllh": -networks.log_prob(dist, transitions.action).mean(),
91 | "ent": entropy,
92 | }
93 | return mean_sq_error, metrics
94 |
95 | return loss, no_param_change, no_param_change
96 |
97 |
98 | def faithful_loss(action_dimension: int) -> ExtendedBCLossWithAux:
99 | """Combines mean-squared error and negative log-likeihood.
100 |
101 | Uses stop-gradients to ensure 'faithful' MSE fit and uncertainty
102 |   quantification.
103 |
104 | Args:
105 | action_dimension: (Unused for this loss)
106 |
107 | Returns:
108 | loss function
109 | parameter extender: for adding loss terms to pytree params
110 | parameter retracter: for removing loss terms from pytree params
111 | """
112 | del action_dimension
113 | def loss(
114 | networks: bc.networks.BCNetworks,
115 | params: networks_lib.Params,
116 | key: jax_types.PRNGKey,
117 | transitions: types.Transition,
118 | ) -> Tuple[jnp.ndarray, Metrics]:
119 | dist, cut_dist = networks.policy_network.apply(
120 | params,
121 | transitions.observation,
122 | is_training=True,
123 | key=key,
124 | faithful_distributions=True,
125 | train_encoder=True,
126 | )
127 | sample = networks.sample_fn(dist, key)
128 | entropy = -networks.log_prob(cut_dist, sample).mean()
129 | mean_sq_error = ((dist.mode() - transitions.action) ** 2).mean()
130 | nllh = -networks.log_prob(cut_dist, transitions.action).mean()
131 |
132 | loss_ = mean_sq_error + nllh
133 |
134 | metrics = {
135 | "mse": mean_sq_error,
136 | "nllh": nllh,
137 | "ent": entropy,
138 | }
139 |
140 | return loss_, metrics
141 |
142 | return loss, no_param_change, no_param_change
143 |
144 |
145 | def negative_loglikelihood(action_dimension: int) -> ExtendedBCLossWithAux:
146 | """Negative log likelihood loss."""
147 |
148 | del action_dimension
149 | def loss(
150 | networks: bc.networks.BCNetworks,
151 | params: networks_lib.Params,
152 | key: jax_types.PRNGKey,
153 | transitions: types.Transition,
154 | ) -> Tuple[jnp.ndarray, Metrics]:
155 | dist = networks.policy_network.apply(
156 | params, transitions.observation, is_training=True, key=key,
157 | train_encoder=True,
158 | )
159 | nllh = -networks.log_prob(dist, transitions.action).mean()
160 | sample = networks.sample_fn(dist, key)
161 | entropy = -networks.log_prob(dist, sample).mean()
162 | metrics = {
163 | "mse": ((dist.mode() - transitions.action) ** 2).mean(),
164 | "nllh": nllh,
165 | "ent": entropy,
166 | }
167 | return nllh, metrics
168 |
169 | return loss, no_param_change, no_param_change
170 |
171 |
172 | def zero_debiased_faithful_loss(
173 | action_dimension: int,
174 | ) -> ExtendedBCLossWithAux:
175 | """Combines mean-squared error and negative log-likeihood.
176 |
177 | Uses stop-gradients to ensure 'faithful' MSE fit and uncertainty
178 |   quantification.
179 |
180 | Args:
181 |     action_dimension: Action dimensionality, used for the zero-mean component.
182 |
183 | Returns:
184 | loss function
185 | parameter extender: for adding loss terms to pytree params
186 | parameter retracter: for removing loss terms from pytree params
187 | """
188 |
189 | def loss(
190 | networks: bc.networks.BCNetworks,
191 | params: networks_lib.Params,
192 | key: jax_types.PRNGKey,
193 | transitions: types.Transition,
194 | ) -> Tuple[jnp.ndarray, Metrics]:
195 | policy_params = params["model"]
196 | dist, cut_dist = networks.policy_network.apply(
197 | policy_params,
198 | transitions.observation,
199 | is_training=True,
200 | key=key,
201 | faithful_distributions=True,
202 | train_encoder=True,
203 | )
204 | sample = networks.sample_fn(dist, key)
205 | entropy = -networks.log_prob(cut_dist, sample).mean()
206 | mean_sq_error = ((dist.mode() - transitions.action) ** 2).mean(axis=1)
207 | nllh = -networks.log_prob(cut_dist, transitions.action)
208 |
209 | mse1 = jnp.expand_dims(mean_sq_error, axis=1)
210 | nllh1 = jnp.expand_dims(nllh, axis=1)
211 |
212 | loss_params = params["loss"]
213 | w_virtual = loss_params["w_virtual"]
214 | w = jax.nn.softmax(w_virtual)
215 | log_w = jnp.expand_dims(jnp.log(w), axis=0)
216 | iso_sigma = jax.nn.softplus(loss_params["sigma_virtual"])
217 | sigma = iso_sigma * jnp.ones((action_dimension,))
218 |
219 | dist0 = tfd.MultivariateNormalDiag(
220 | loc=jnp.zeros((action_dimension,)), scale_diag=sigma
221 | )
222 |
223 | mse0 = jnp.expand_dims((transitions.action**2).mean(axis=1), axis=1)
224 | mse0 = mse0 / iso_sigma
225 | nllh0 = -jnp.expand_dims(dist0.log_prob(transitions.action), axis=1)
226 |
227 | mse_ = jnp.concatenate([mse0, mse1], axis=1)
228 | nllh_ = jnp.concatenate([nllh0, nllh1], axis=1)
229 |
230 | # Do a mixture likelihood.
231 | # log sum_i w_i p(x|i) = logsumexp(log_p(x|i) + log_w) along a new dimension
232 | # MSE needs temp for sharp minimum estimate.
233 | mse_lse_temp = 1000.0
234 |     # Don't do a plain logsumexp of probabilities; use the tempered form.
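    # Identity used below: with temperature T and mixture weights w,
    #   -(1/T) * logsumexp(-T*x + log_w) = -(1/T) * log(sum_i w_i exp(-T*x_i)),
    # which approaches min_i x_i for large T. Hence `mmse` is a weighted
    # soft-minimum over the two components, while `mnllh` (T = 1) is the exact
    # mixture negative log-likelihood -log(sum_i w_i p_i(a)).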
235 | mmse = (
236 | -jax.scipy.special.logsumexp((mse_lse_temp * -mse_ + log_w), axis=1)
237 | / mse_lse_temp
238 | )
239 | mnllh = -jax.scipy.special.logsumexp((-nllh_ + log_w), axis=1)
240 | mmse, mnllh = mmse.mean(), mnllh.mean()
241 |
242 | loss_ = mmse + mnllh
243 | metrics = {
244 | "mse": mmse,
245 | "nllh": mnllh,
246 | "ent": entropy,
247 | }
248 |
249 | return loss_, metrics
250 |
251 | def extend_params(params: networks_lib.Params) -> networks_lib.Params:
252 | new_params = {
253 | "loss": {
254 | "w_virtual": jnp.array([0.0, 1.0]),
255 | "sigma_virtual": -1.0 * jnp.ones((1,)),
256 | },
257 | "model": params,
258 | }
259 | return new_params
260 |
261 | def retract_params(params: networks_lib.Params) -> networks_lib.Params:
262 | return params["model"] if "model" in params else params
263 |
264 | return loss, extend_params, retract_params
265 |
266 |
267 | # Map enum to loss function.
268 | _LOOKUP = {
269 | config.Losses.FAITHFUL: faithful_loss,
270 | config.Losses.DBFAITHFUL: zero_debiased_faithful_loss,
271 | config.Losses.MSE: mse,
272 | config.Losses.NLLH: negative_loglikelihood,
273 | }
274 |
275 |
276 | def get_loss_function(
277 | loss_type: config.Losses,
278 | ) -> Callable[[int], ExtendedBCLossWithAux]:
279 | assert loss_type in _LOOKUP
280 | return _LOOKUP[loss_type]
281 |
282 |
283 | TerminateCondition = Callable[[list[dict[str, jnp.ndarray]]], bool]
284 |
285 |
286 | class EarlyStoppingBCLearner(bc.BCLearner):
287 | """Behavioural cloning learner that stops based on metrics."""
288 |
289 | def __init__(self, terminate_condition: TerminateCondition, *args, **kwargs):
290 | self.metrics = []
291 | self.terminate_condition = terminate_condition
292 | self.terminate = False
293 | super().__init__(*args, **kwargs)
294 |
295 | def step(self):
296 | # Get a batch of Transitions.
297 | transitions = next(self._prefetching_iterator)
298 | self._state, metrics = self._sgd_step(self._state, transitions)
299 | metrics = utils.get_from_first_device(metrics)
300 | # Compute elapsed time.
301 | timestamp = time.time()
302 | elapsed_time = timestamp - self._timestamp if self._timestamp else 0
303 | self._timestamp = timestamp
304 | # Increment counts and record the current time.
305 | counts = self._counter.increment(steps=1, walltime=elapsed_time)
306 | # Attempts to write the logs.
307 | self._logger.write({**metrics, **counts})
308 | self.metrics += [metrics,]
309 | self.terminate = self.terminate_condition(self.metrics)
310 |
311 |
312 | def behavioural_cloning_pretraining(
313 | seed: int,
314 | env_spec: specs.EnvironmentSpec,
315 | dataset_factory: Callable[[int], Iterator[types.Transition]],
316 | policy: networks_lib.FeedForwardNetwork,
317 | loss: config.Losses = config.Losses.FAITHFUL,
318 | num_steps: int = 40_000,
319 | learning_rate: float = 1e-4,
320 | logger: Optional[loggers.Logger] = None,
321 | name: str = "",
322 | ) -> networks_lib.Params:
323 | """Trains the policy and returns the params single-threaded training loop.
324 |
325 | Args:
326 | seed: Random seed for training.
327 | env_spec: Environment specification.
328 | dataset_factory: A function that returns an iterator with demonstrations to
329 | be imitated.
330 | policy: Policy network model.
331 | loss: loss type for pretraining (e.g. MSE, log likelihood, ...)
332 | num_steps: Number of training steps.
333 |     learning_rate: Learning rate for the regression optimizer.
334 | logger: Optional external object for logging.
335 | name: Name used for logger.
336 |
337 | Returns:
338 | The trained network params.
339 | """
340 | key = random.PRNGKey(seed)
341 |
342 | logger = logger or experiment_utils.make_experiment_logger(f"pretrainer_policy{name}")
343 |
344 | # Train using log likelihood.
345 | n_actions = np.prod(env_spec.actions.shape)
346 | loss_fn = _LOOKUP[loss]
347 | bc_loss, extend_params, retract_params = loss_fn(n_actions)
348 |
349 | bc_policy_network = bc.convert_to_bc_network(policy)
350 | # Add loss terms to params here.
351 | policy_network = bc.BCPolicyNetwork(
352 | lambda key: extend_params(policy.init(key)), bc_policy_network.apply
353 | )
354 |
355 | bc_network = bc.BCNetworks(
356 | policy_network=policy_network,
357 | log_prob=lambda params, acts: params.log_prob(acts),
358 | # For BC agent, the sample_fn is used for evaluation.
359 | sample_fn=lambda params, key: params.sample(seed=key),
360 | )
361 |
362 | dataset = dataset_factory(seed)
363 |
364 | counter = counting.Counter(prefix="policy_pretrainer", time_delta=0.0)
365 |
366 | history = 50
367 | ent_threshold = -2 * n_actions
368 |
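# Early-stopping heuristic: stop once the sampled policy entropy has stayed
# below -2 nats per action dimension for `history` consecutive steps, i.e. the
# policy has become very concentrated.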
369 | def terminate_condition(metrics: Sequence[Dict[str, jnp.ndarray]]) -> bool:
370 | if len(metrics) < history:
371 | return False
372 | else:
373 | return all(m["ent"] < ent_threshold for m in metrics[-history:])
374 |
375 | learner = EarlyStoppingBCLearner(
376 | terminate_condition=terminate_condition,
377 | loss_fn=bc_loss,
378 | optimizer=optax.adam(learning_rate=learning_rate),
379 | random_key=key,
380 | networks=bc_network,
381 | prefetching_iterator=utils.sharded_prefetch(dataset),
382 | loss_has_aux=True,
383 | num_sgd_steps_per_step=1,
384 | logger=logger,
385 | counter=counter,)
386 |
387 | # Train the agent.
388 | for _ in range(num_steps):
389 | learner.step()
390 |     # learner.terminate is available here to break out of the loop early.
391 |
392 | policy_and_loss_params = learner.get_variables(["policy"])[0]
393 | del learner # Ensure logger is closed.
394 | # Remove loss terms from params here.
395 | return retract_params(policy_and_loss_params)
396 |
397 |
398 | class TrainingState(NamedTuple):
399 | params: networks_lib.Params
400 | target_params: networks_lib.Params
401 | opt_state: optax.OptState
402 |
403 |
404 | def critic_pretraining(
405 | seed: int,
406 | dataset_factory: Callable[[int], Iterator[types.Transition]],
407 | critic: networks_lib.FeedForwardNetwork,
408 | critic_params: networks_lib.Params,
409 | reward: networks_lib.FeedForwardNetwork,
410 | reward_params: networks_lib.Params,
411 | discount_factor: float,
412 | num_steps: int = 10_000,
413 | learning_rate: float = 5e-3,
414 | counter: Optional[counting.Counter] = None,
415 | logger: Optional[loggers.Logger] = None,
416 | ) -> networks_lib.Params:
417 | """Pretrain the critic using a SARSA loss.
418 |
419 | Args:
420 | seed: for randomized training
421 | dataset_factory: SARSA data iterator for pretraining
422 | critic: critic function
423 | critic_params: initial critic params
424 | reward: reward function
425 | reward_params: known reward params
426 | discount_factor: discount used for Bellman equation
427 | num_steps: number of update steps
428 | learning_rate: learning rate of optimizer
429 | counter: used for logging
430 | logger: Optional external object for logging.
431 |
432 | Returns:
433 | Trained critic params.
434 | """
435 | key = jax.random.PRNGKey(seed)
436 | optimiser = optax.adam(learning_rate)
437 |
438 | initial_opt_state = optimiser.init(critic_params)
439 |
440 | state = TrainingState(critic_params, critic_params, initial_opt_state)
441 |
442 | dataset_iterator = dataset_factory(seed)
443 |
444 | tau = 0.005
445 |
446 | sample = next(dataset_iterator)
447 | assert "next_action" in sample.extras, "Require SARSA dataset."
448 |
449 | @jax.jit
450 | def loss(
451 | params: networks_lib.Params,
452 | target_params: networks_lib.Params,
453 | observation: jnp.ndarray,
454 | action: jnp.ndarray,
455 | next_observation: jnp.ndarray,
456 | next_action: jnp.ndarray,
457 | discount: jnp.ndarray,
458 | key: jax_types.PRNGKey,
459 | ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
460 | """."""
461 | del key
462 |
463 | r = jnp.ravel(reward.apply(reward_params, observation, action))
464 |
465 | next_v_target = critic.apply(
466 | target_params, next_observation, next_action).min(axis=-1)
467 | next_v_target = jax.lax.stop_gradient(next_v_target)
468 | discount_ = discount_factor * discount
469 | q_sarsa_target = jnp.expand_dims(r + discount_ * next_v_target, -1)
470 | q = critic.apply(params, observation, action)
471 | sarsa_loss = ((q_sarsa_target - q) ** 2).mean()
472 |
473 | def nonbatch_critic(s, a):
474 | batch_s = tree.map_structure(lambda f: f[None, ...], s)
475 | return critic.apply(params, batch_s, a[None, ...])[0]
476 |
477 | dqda = jax.vmap(jax.jacfwd(nonbatch_critic, argnums=1), in_axes=(0))
478 | grads = dqda(observation, action)
479 | # Sum over actions, average over the rest.
480 | grad_norm = jnp.sqrt((grads**2).sum(axis=-1).mean())
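# Note: this action-Jacobian penalty mirrors the grad_norm regularizer used in
# CoherentConfig.critic_loss_factory and keeps the pretrained critic smooth
# with respect to actions.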
481 |
482 | loss = sarsa_loss + grad_norm
483 |
484 | metrics = {
485 | "loss": loss,
486 | "sarsa_loss": sarsa_loss,
487 | "grad_norm": grad_norm,
488 | }
489 |
490 | return loss, metrics
491 |
492 | @jax.jit
493 | def step(transition, state, key):
494 | if "next_action" in transition.extras:
495 | next_action = transition.extras["next_action"]
496 | else:
497 | next_action = transition.action
498 | values, grads = jax.value_and_grad(loss, has_aux=True)(
499 | state.params,
500 | state.target_params,
501 | transition.observation,
502 | transition.action,
503 | transition.next_observation,
504 | next_action,
505 | transition.discount,
506 | key,
507 | )
508 | _, metrics = values
509 | updates, opt_state = optimiser.update(grads, state.opt_state)
510 | params = optax.apply_updates(state.params, updates)
511 | target_params = jax.tree_map(
512 | lambda x, y: x * (1 - tau) + y * tau, state.target_params, params
513 | )
514 | return TrainingState(params, target_params, opt_state), metrics
515 |
516 | timestamp = time.time()
517 | counter = counter or counting.Counter(
518 | prefix="pretrainer_critic", time_delta=0.0
519 | )
520 | logger = logger or loggers.make_default_logger(
521 | "pretrainer_critic",
522 | asynchronous=False,
523 | serialize_fn=utils.fetch_devicearray,
524 | steps_key=counter.get_steps_key(),
525 | )
526 | for i in range(num_steps):
527 | _, key = jax.random.split(key)
528 | transitions = next(dataset_iterator)
529 | state, metrics = step(transitions, state, key)
530 | metrics["step"] = i
531 | timestamp_ = time.time()
532 | elapsed_time = timestamp_ - timestamp
533 | timestamp = timestamp_
534 | counts = counter.increment(steps=1, walltime=elapsed_time)
535 | logger.write({**metrics, **counts})
536 |
537 | logger.close()
538 |
539 | return state.params
540 |
--------------------------------------------------------------------------------
/sil/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Soft imitation learning types, configurations and hyperparameters.
17 |
18 | Useful resources for these implementations:
19 | IQ-Learn: https://arxiv.org/abs/2106.12142
20 | https://github.com/Div99/IQ-Learn
21 | P^2IL: https://arxiv.org/abs/2209.10968
22 | https://github.com/lviano/P2IL
23 | """
24 | import abc
25 | import dataclasses
26 | import enum
27 | from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
28 |
29 | from acme import types
30 | from acme.agents.jax.sac import config as sac_config
31 | import jax
32 | from jax import lax
33 | from jax import random
34 | import jax.nn as jnn
35 | import jax.numpy as jnp
36 | import numpy as np
37 | import tree
38 |
39 | # Custom function and return typing for signature brevity.
40 | # Generalized reward function, including discounting.
41 | StateActionNextStateFunc = Callable[
42 | [jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
43 | jnp.ndarray,
44 | ]
45 | # State-action critics.
46 | StateActionFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
47 | # I.e. function regularizers.
48 | StateFunc = Callable[[jnp.ndarray], jnp.ndarray]
49 | # Soft (i.e. stochastic) value function.
50 | StochStateFunc = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
51 | # Losses and factories are defined here for brevity.
52 | AuxLoss = Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]
53 | CriticLossFact = Callable[
54 | [
55 | StateActionNextStateFunc,
56 | StateActionFunc,
57 | StochStateFunc,
58 | StochStateFunc,
59 | float,
60 | types.Transition,
61 | types.Transition,
62 | jax.Array,
63 | ],
64 | AuxLoss,
65 | ]
66 | RewardFact = Callable[
67 | [StateActionFunc, StateActionFunc, StochStateFunc, float],
68 | StateActionNextStateFunc,
69 | ]
70 | DemonstrationFact = Callable[..., Iterator[types.Transition]]
71 | ConcatenateFact = Callable[
72 | [types.Transition, types.Transition], types.Transition
73 | ]
74 |
75 |
76 | # Used to bound policy log-likelihoods via the variance.
77 | MIN_VAR = 1e-5
78 |
79 | # Maximum reward for CSIL.
80 | # By having a lower bound on the variance, and setting alpha to be
81 | # (1 - \gamma) / dim_a, we can bound the maximum discounted return.
82 | MAX_REWARD = -0.5 * np.log(MIN_VAR) + np.log(2.)
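# Since each per-step reward is at most MAX_REWARD, the discounted return is
# bounded by MAX_REWARD / (1 - discount), which CoherentConfig uses below to
# clip rogue target values.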
83 |
84 |
85 | class Losses(enum.Enum):
86 | FAITHFUL = "faithful"
87 | DBFAITHFUL = "dbfaithful"
88 | MSE = "mse"
89 | NLLH = "nllh"
90 |
91 | def __str__(self) -> str:
92 | return self.value
93 |
94 |
95 | # Soft inverse Q learning has different objectives based on divergence choice
96 | # between expert and policy occupancy measure.
97 | class Divergence(enum.Enum):
98 | FORWARD_KL = 0
99 | REVERSE_KL = 1
100 | REVERSE_KL_DUAL = 2
101 | REVERSE_KL_UNBIASED = 3
102 | HELLINGER = 4
103 | JS = 5
104 | CHI = 6
105 | TOTAL_VARIATION = 7
106 |
107 |
108 | # Each f divergence can be written as
109 | # D_f[P || Q] = \sup_g E_P[g(x)] - E_Q[f^*(g(x))]
110 | # where f^* is the convex conjugate of f.
111 | # For the objective, we require a function \phi(x) = -f^*(-x).
112 | # These functions implement the \phi(x) for a given divergence.
113 | # See Table 4 of Garg et al. (2021) and Ho et al. (2016) for more details.
114 | _forward_kl: StateFunc = lambda r: 1.0 + jnp.log(r)
115 | _reverse_kl: StateFunc = lambda r: jnp.exp(-r - 1.0)
116 | _reverse_kl_dual: StateFunc = lambda r: jnn.softmax(-r) * r.shape[0]
117 | _reverse_kl_unbiased: StateFunc = lambda r: jnp.exp(-r)
118 | _hellinger: StateFunc = lambda r: 1.0 / (1.0 + r) ** 2
119 | _chi: StateFunc = lambda r: r - r**2 / 2.0
120 | _total_variation: StateFunc = lambda r: r
121 | _js: StateFunc = lambda r: jnp.log(2.0 - jnp.exp(-r))
122 |
123 | DIVERGENCE_REGULARIZERS = {
124 | Divergence.FORWARD_KL: _forward_kl,
125 | Divergence.REVERSE_KL: _reverse_kl,
126 | Divergence.REVERSE_KL_DUAL: _reverse_kl_dual,
127 | Divergence.REVERSE_KL_UNBIASED: _reverse_kl_unbiased,
128 | Divergence.HELLINGER: _hellinger,
129 | Divergence.CHI: _chi,
130 | Divergence.TOTAL_VARIATION: _total_variation,
131 | Divergence.JS: _js,
132 | }
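# Example (sketch): selecting and applying a regularizer in the IQ-Learn loss.
#   phi = DIVERGENCE_REGULARIZERS[Divergence.CHI]
#   phi(expert_reward)  # == expert_reward - expert_reward**2 / 2, the term
#                       # maximized over expert transitions in the critic loss.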
133 |
134 |
135 | def concatenate(x: Any, y: Any) -> Any:
136 | return tree.map_structure(lambda x, y: jnp.concatenate((x, y), axis=0), x, y)
137 |
138 |
139 | def concatenate_transitions(
140 | x: types.Transition, y: types.Transition
141 | ) -> types.Transition:
142 | fields = ["observation", "action", "reward", "discount", "next_observation"]
143 | return types.Transition(
144 | *[concatenate(getattr(x, field), getattr(y, field)) for field in fields]
145 | )
146 |
147 |
148 | @dataclasses.dataclass
149 | class SoftImitationConfig(abc.ABC):
150 | """Abstact base class for soft imitation learning."""
151 |
152 | @abc.abstractmethod
153 | def critic_loss_factory(self) -> CriticLossFact:
154 | """Define the critic loss based on the algorithm."""
155 |
156 | @abc.abstractmethod
157 | def reward_factory(self) -> RewardFact:
158 | """Define the reward based on the algorithm, i.e. implicit vs explicit."""
159 |
160 |
161 | @dataclasses.dataclass
162 | class InverseSoftQConfig(SoftImitationConfig):
163 | """Configuration for the IQ-Learn algorithm."""
164 |
165 | # Defines divergence-derived function regularization.
166 | divergence: Divergence = Divergence.CHI
167 |
168 | def critic_loss_factory(self) -> CriticLossFact:
169 | """Generate critic objective from hyperparameters."""
170 | regularizer = DIVERGENCE_REGULARIZERS[self.divergence]
171 |
172 | def objective(
173 | reward_fn: StateActionNextStateFunc,
174 | state_action_value_fn: StateActionFunc,
175 | value_fn: StochStateFunc,
176 | target_value_fn: StochStateFunc,
177 | discount: float,
178 | demonstration_transitions: types.Transition,
179 | online_transitions: types.Transition,
180 | key: jax.Array,
181 | ) -> AuxLoss:
182 | """See Equation 10 of Garg et al. (2021) for reference."""
183 | del target_value_fn
184 | key_er, key_v, key_or = random.split(key, 3)
185 | expert_reward = reward_fn(
186 | demonstration_transitions.observation,
187 | demonstration_transitions.action,
188 | demonstration_transitions.next_observation,
189 | demonstration_transitions.discount,
190 | key_er,
191 | )
192 | phi_grad = regularizer(expert_reward).mean()
193 | expert_loss = -phi_grad.mean()
194 |
195 | # This is denoted the value of the initial state distribution in the paper
196 | # and codebase, but in practice the distribution is replaced with the
197 |       # demonstration distribution. In practice, this term ensures the Q
198 |       # function is maximized at the demonstration data, since we are
199 |       # minimizing the value for actions sampled around the optimal expert.
200 | value_reg = (1 - discount) * value_fn(
201 | demonstration_transitions.observation, key_v
202 | ).mean()
203 |
204 | online_reward = reward_fn(
205 | online_transitions.observation,
206 | online_transitions.action,
207 | online_transitions.next_observation,
208 | online_transitions.discount,
209 | key_or,
210 | )
211 | # See code implementation IQ-Learn/iq_learn/iq.py
212 | agent_loss = 0.5 * (online_reward**2).mean()
213 |
214 | metrics = {
215 | "expert_reward": expert_reward.mean(),
216 | "online_reward": online_reward.mean(),
217 | "value_reg": value_reg,
218 | "expert_loss": expert_loss,
219 | "online_reg": agent_loss,
220 | }
221 | return expert_loss + value_reg + agent_loss, metrics
222 |
223 | return objective
224 |
225 | def reward_factory(self) -> RewardFact:
226 | def reward_factory_(
227 | state_action_reward: StateActionFunc,
228 | state_action_value_function: StateActionFunc,
229 | state_value_function: StochStateFunc,
230 | discount_factor: float,
231 | ) -> StateActionNextStateFunc:
232 | del state_action_reward
233 |
234 | def reward_fn(
235 | state: jnp.ndarray,
236 | action: jnp.ndarray,
237 | next_state: jnp.ndarray,
238 | discount: jnp.ndarray,
239 | value_key: jnp.ndarray,
240 | ) -> jnp.ndarray:
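# IQ-Learn's implicit reward: r(s, a, s') = Q(s, a) - gamma * d * V(s'),
# where d is the per-transition discount (zero at episode termination).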
241 | q = state_action_value_function(state, action)
242 | future_v = discount * state_value_function(next_state, value_key)
243 | return q - discount_factor * future_v
244 |
245 | return reward_fn
246 |
247 | return reward_factory_
248 |
249 |
250 | @dataclasses.dataclass
251 | class ProximalPointConfig(SoftImitationConfig):
252 | """Configuration for the proximal point imitation learning algorithm."""
253 |
254 | bellman_error_temperature: float = 1.0
255 |
256 | def critic_loss_factory(self) -> CriticLossFact:
257 | """Generate critic objective from hyperparameters.
258 |
259 | P^2IL's objective consists of four terms.
260 |     Optimizing the expert's rewards, minimizing the logistic Bellman error,
261 |     optimizing the initial state value using a proxy form to improve the value
262 |     function, and regularizing the reward function with a squared penalty.
263 |
264 | Returns:
265 |       The critic loss function.
266 | """
267 | alpha = self.bellman_error_temperature
268 |
269 | def objective(
270 | reward_fn: StateActionNextStateFunc,
271 | state_action_value_fn: StateActionFunc,
272 | value_fn: StochStateFunc,
273 | target_value_fn: StochStateFunc,
274 | discount: float,
275 | demonstration_transitions: types.Transition,
276 | online_transitions: types.Transition,
277 | key: jax.Array,
278 | ) -> AuxLoss:
279 | """See Equation 67 and Theorem 6 of Viano et al. (2022) for reference."""
280 | key_cr, key_er, key_v, key_fv, key_or = random.split(key, 5)
281 | expert_reward = reward_fn(
282 | demonstration_transitions.observation,
283 | demonstration_transitions.action,
284 | demonstration_transitions.next_observation,
285 | demonstration_transitions.discount,
286 | key_er,
287 | )
288 | expert_reward_mean = expert_reward.mean()
289 | expert_q = state_action_value_fn(
290 | demonstration_transitions.observation,
291 | demonstration_transitions.action,
292 | )
293 | expert_v = value_fn(demonstration_transitions.observation, key_v)
294 | expert_next_v = value_fn(
295 | demonstration_transitions.next_observation, key_fv
296 | )
297 | expert_d = discount * demonstration_transitions.discount
298 | expert_imp_r_mean = (expert_q - expert_d * expert_next_v).mean()
299 |
300 | online_r = reward_fn(
301 | online_transitions.observation,
302 | online_transitions.action,
303 | online_transitions.next_observation,
304 | online_transitions.discount,
305 | key_or,
306 | )
307 | online_q = state_action_value_fn(
308 | online_transitions.observation, online_transitions.action
309 | )
310 | online_v = value_fn(online_transitions.observation, key_v)
311 | online_d = online_transitions.discount
312 | combined_transition = concatenate_transitions(
313 | demonstration_transitions, online_transitions
314 | )
315 | combined_reward = reward_fn(
316 | combined_transition.observation,
317 | combined_transition.action,
318 | combined_transition.next_observation,
319 | online_transitions.discount,
320 | key_cr,
321 | )
322 | combined_q = state_action_value_fn(
323 | combined_transition.observation, combined_transition.action
324 | )
325 | # The Bellman equation needs the target value function for stability,
326 | # while the regularization shapes the current value function.
327 | combined_next_target_v = target_value_fn(
328 | combined_transition.next_observation, key_fv
329 | )
330 |
331 | # In theory, the reward function should be jointly optimized in the
332 | # Bellman equation to minimize the on-policy rewards, however, empirically
333 | # this produced worse results as the Bellman equation is harder to
334 | # minimize.
335 | discount_ = discount * combined_transition.discount
336 | combined_q_target = combined_reward + discount_ * combined_next_target_v
337 | bellman_error = lax.stop_gradient(combined_q_target) - combined_q
338 | # In theory, the Bellman error should not be negative as the Q and V
339 | # function should roughly track in magnitude, but in practice this was not
340 | # always the case.
341 | bellman_error = jnp.abs(bellman_error)
342 | # Self-normalized importance sampling weights z in Theorem 6 of
343 |       # Viano et al. (2022), used in the logsumexp objective.
344 | log_weights = lax.stop_gradient(alpha * bellman_error)
345 | norm_weights = jnn.softmax(log_weights)
346 | ess = 1.0 / (norm_weights**2).sum() # effective sample size
347 | is_bellman_error = jnp.einsum("b,b->", norm_weights, bellman_error)
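# Note: with the weights held fixed by stop_gradient, the gradient of this
# weighted sum matches that of the logistic Bellman error
# (1 / alpha) * logsumexp(alpha * |bellman_error|) over the batch.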
348 | sq_bellman_error = (bellman_error**2).mean()
349 | rms_bellman_error = jnp.sqrt(sq_bellman_error)
350 |
351 | apprenticeship_loss = -expert_reward.mean() + jnp.einsum(
352 | "b,b->", norm_weights, combined_reward
353 | )
354 |
355 | # This is denoted the value of the initial state distribution in the paper
356 | # and codebase, but in practice the distribution is replaced with the
357 |       # demonstration distribution. In practice, this term ensures the Q
358 |       # function is maximized at the demonstration data, since we are
359 |       # minimizing the value for actions sampled around the optimal expert.
360 | value_reg = (1.0 - discount) * expert_v.mean()
361 |
362 | # P^2IL uses IQ-Learn Chi^2-based regularization in practice.
363 | expert_r_mean = expert_reward_mean
364 | function_reg = 0.5 * (combined_reward**2).mean()
365 | metrics = {
366 | "apprenticeship_loss": apprenticeship_loss,
367 | "expert_reward": expert_reward.mean(),
368 | "expert_reward_implicit": expert_imp_r_mean,
369 | "expert_reward_combined": expert_r_mean,
370 | "online_reward": online_r.mean(),
371 | "value_reg": value_reg,
372 | "function_reg": function_reg,
373 | "sq_bellman_error": sq_bellman_error,
374 | "is_bellman_error": is_bellman_error,
375 | "ess": ess,
376 | }
377 | return (
378 | -expert_r_mean + is_bellman_error + value_reg + function_reg,
379 | metrics,
380 | )
381 |
382 | return objective
383 |
384 | def reward_factory(self) -> RewardFact:
385 | """P^2IL's reward function is a straightforwad MLP."""
386 |
387 | def reward_factory_(
388 | state_action_reward: StateActionFunc,
389 | state_action_value_function: StateActionFunc,
390 | state_value_function: StochStateFunc,
391 | discount_factor: float,
392 | ) -> StateActionNextStateFunc:
393 | del state_action_value_function, state_value_function, discount_factor
394 |
395 | def reward_fn(
396 | state: jnp.ndarray,
397 | action: jnp.ndarray,
398 | next_state: jnp.ndarray,
399 | discount: jnp.ndarray,
400 | value_key: jnp.ndarray,
401 | ) -> jnp.ndarray:
402 | del next_state, discount, value_key
403 | return state_action_reward(state, action)
404 |
405 | return reward_fn
406 |
407 | return reward_factory_
408 |
409 |
410 | @dataclasses.dataclass
411 | class CoherentConfig(SoftImitationConfig):
412 | """Coherent soft imitation learning."""
413 |
414 | alpha: float # temperature used in the coherent reward
415 | reward_scaling: float = 1.0
416 | scale_factor: float = 1.0 # Scaling of online reward regularization.
417 | grad_norm_sf: float = 1.0 # Critic action Jacobian regularization.
418 | refine_reward: bool = True
419 | negative_reward: bool = False
420 |
421 | def critic_loss_factory(self) -> CriticLossFact:
422 | def objective(
423 | reward_fn: StateActionNextStateFunc,
424 | state_action_value_fn: StateActionFunc,
425 | value_fn: StochStateFunc,
426 | target_value_fn: StochStateFunc,
427 | discount: float,
428 | demonstration_transitions: types.Transition,
429 | online_transitions: types.Transition,
430 | key: jax.Array,
431 | ) -> AuxLoss:
432 | key_er, key_fv, key_or = random.split(key, 3)
433 | combined_transition = concatenate_transitions(
434 | demonstration_transitions, online_transitions
435 | )
436 |
437 | online_reward = reward_fn(
438 | online_transitions.observation,
439 | online_transitions.action,
440 | online_transitions.next_observation,
441 | online_transitions.discount,
442 | key_or,
443 | )
444 |
445 | expert_reward = reward_fn(
446 | demonstration_transitions.observation,
447 | demonstration_transitions.action,
448 | demonstration_transitions.next_observation,
449 | online_transitions.discount,
450 | key_er,
451 | )
452 |
453 | reward = lax.stop_gradient(jnp.concatenate(
454 | (expert_reward, online_reward), axis=0))
455 |
456 | state_action_value = state_action_value_fn(
457 | combined_transition.observation, combined_transition.action
458 | )
459 | future_value = target_value_fn(
460 | combined_transition.next_observation, key_fv
461 | )
462 |       # Use prior knowledge of the reward bound to clip rogue Q values.
463 | max_value = 0. if self.negative_reward else MAX_REWARD / (1. - discount)
464 | future_value = jnp.clip(future_value, a_max=max_value)
465 |
466 | discount_ = discount * combined_transition.discount
467 |
468 | target_state_action_value = (reward + discount_ * future_value)
469 | bellman_error = state_action_value - target_state_action_value
470 |
471 | sbe = (bellman_error**2).mean()
472 | be = bellman_error.mean()
473 |
474 |       # The mean expert reward corresponds to maximizing the BC likelihood (plus a
475 |       # constant); minimizing the online reward corresponds to minimizing the KL to
476 |       # the prior. Use an unbiased KL estimator that can never be negative.
477 | expert_reward_mean = expert_reward.mean() # + imp_expert_reward.mean()
478 |
479 | # In some cases the online reward can go as low as -50
480 | # even when the mean is positive. To be robust to these outliers, we just
481 | # clip the negative online rewards, as the role of this term is to
482 | # regularize the large positive values.
483 | # This KL estimator is motivated in http://joschu.net/blog/kl-approx.html
484 | if self.negative_reward:
485 | online_log_ratio = online_reward + self.reward_scaling * MAX_REWARD
486 | else:
487 | online_log_ratio = online_reward
488 | safe_online_log_ratio = jnp.maximum(online_log_ratio, -5.0)
489 | # the estimator is log r + 1/r - 1, and the reward is alpha log r
490 | policy_kl_est = (
491 | jnp.exp(-safe_online_log_ratio) - 1. + online_log_ratio
492 | ).mean()
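# With x = log r, the estimator exp(-x) - 1 + x is >= 0 with equality at x = 0;
# the clipping above only caps exp(-x) (at e^5) for very negative log-ratios.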
493 |
494 | def non_batch_state_action_value_fn(s, a):
495 | batch_s = tree.map_structure(lambda f: f[None, ...], s)
496 | return state_action_value_fn(batch_s, a[None, ...])[0]
497 | dqda = jax.vmap(
498 | jax.jacrev(non_batch_state_action_value_fn, argnums=1), in_axes=(0))
499 | grads = dqda(demonstration_transitions.observation,
500 | demonstration_transitions.action)
501 | grad_norm = jnp.sqrt((grads**2).sum(axis=1).mean())
502 |
503 | loss = sbe
504 | if self.refine_reward:
505 | loss -= expert_reward_mean
506 | loss += self.scale_factor * policy_kl_est
507 | loss += self.grad_norm_sf * grad_norm
508 |
509 | metrics = {
510 | "critic_action_grad_norm": grad_norm,
511 | "sq_bellman_error": sbe,
512 | "expert_reward": expert_reward.mean(),
513 | "online_reward": online_reward.mean(),
514 | "kl_est": policy_kl_est,
515 | }
516 | return loss, metrics
517 |
518 | return objective
519 |
520 | def reward_factory(self) -> RewardFact:
521 | """CSIL's reward function is policy-derived."""
522 |
523 | def reward_factory_(
524 | state_action_reward: StateActionFunc,
525 | state_action_value_function: StateActionFunc,
526 | state_value_function: StochStateFunc,
527 | discount_factor: float,
528 | ) -> StateActionNextStateFunc:
529 | del state_action_value_function, state_value_function, discount_factor
530 |
531 | def reward_fn(
532 | state: jnp.ndarray,
533 | action: jnp.ndarray,
534 | next_state: jnp.ndarray,
535 | discount: jnp.ndarray,
536 | value_key: jnp.ndarray,
537 | ) -> jnp.ndarray:
538 | del next_state, discount, value_key
539 | return state_action_reward(state, action)
540 |
541 | return reward_fn
542 |
543 | return reward_factory_
544 |
545 |
546 | @dataclasses.dataclass
547 | class PretrainingConfig:
548 | """Parameters for pretraining a model."""
549 |
550 | dataset_factory: DemonstrationFact
551 | learning_rate: float
552 | steps: int
553 | seed: int
554 | loss: Optional[Losses] = None # Used for policy pretraining.
555 | use_as_reference: bool = False # Use pretrained policy as prior.
556 |
557 |
558 | def null_data_factory(
559 | n_demonstrations: int, seed: int
560 | ) -> Iterator[types.Transition]:
561 | del n_demonstrations, seed
562 | raise NotImplementedError()
563 |
564 |
565 | SilConfigTypes = Union[InverseSoftQConfig, ProximalPointConfig, CoherentConfig]
566 |
567 |
568 | @dataclasses.dataclass
569 | class SILConfig(sac_config.SACConfig):
570 | """Configuration options for soft imitation learning."""
571 |
572 | # Imitation learning hyperparameters.
573 | expert_demonstration_factory: DemonstrationFact = null_data_factory
574 | imitation: SilConfigTypes = (
575 | dataclasses.field(default_factory=InverseSoftQConfig))
576 | actor_bc_loss: bool = False
577 | policy_pretraining: Optional[List[PretrainingConfig]] = None
578 | critic_pretraining: Optional[PretrainingConfig] = None
579 | actor_learning_rate: float = 3e-4
580 | reward_learning_rate: float = 3e-4
581 | critic_learning_rate: float = 3e-4
582 | critic_actor_update_ratio: int = 1
583 | alpha_learning_rate: float = 3e-4
584 | alpha_init: float = 1.0
585 | damping: float = 0.0 # Entropy constraint damping.
586 |
--------------------------------------------------------------------------------
/sil/learning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Soft imitation learning learner implementation."""
17 |
18 | from __future__ import annotations
19 | import time
20 | from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple
21 |
22 | import acme
23 | from acme import types
24 | from acme.jax import networks as networks_lib
25 | from acme.jax import utils
26 | from acme.utils import counting
27 | from acme.utils import loggers
28 | import jax
29 | from jax import lax
30 | import jax.numpy as jnp
31 | import optax
32 |
33 | from sil import config as sil_config
34 | from sil import networks as sil_networks
35 | from sil import pretraining
36 |
37 | # useful for analysis and comparing algorithms
38 | MONITOR_BC_METRICS = False
39 |
40 |
41 | class ImitationSample(NamedTuple):
42 | """For imitation learning, we require agent and demonstration experience."""
43 |
44 | online_sample: types.Transition
45 | demonstration_sample: types.Transition
46 |
47 |
48 | class TrainingState(NamedTuple):
49 | """Contains training state for the learner."""
50 | policy_optimizer_state: optax.OptState
51 | q_optimizer_state: optax.OptState
52 | r_optimizer_state: optax.OptState
53 | policy_params: networks_lib.Params
54 | q_params: networks_lib.Params
55 | target_q_params: networks_lib.Params
56 | r_params: networks_lib.Params
57 | key: networks_lib.PRNGKey
58 | bc_policy_params: Optional[networks_lib.Params] = None
59 | alpha_optimizer_state: Optional[optax.OptState] = None
60 | alpha_params: Optional[networks_lib.Params] = None
61 |
62 |
63 | class SILLearner(acme.Learner):
64 | """Soft imitation learning learner."""
65 |
66 | _state: TrainingState
67 |
68 | def __init__(
69 | self,
70 | networks: sil_networks.SILNetworks,
71 | critic_loss_def: sil_config.CriticLossFact,
72 | reward_factory: sil_config.RewardFact,
73 | rng: jnp.ndarray,
74 | dataset: Iterator[ImitationSample],
75 | policy_optimizer: optax.GradientTransformation,
76 | q_optimizer: optax.GradientTransformation,
77 | r_optimizer: optax.GradientTransformation,
78 | tau: float = 0.005,
79 | discount: float = 0.99,
80 | critic_actor_update_ratio: int = 1,
81 | alpha_init: float = 1.0,
82 | alpha_learning_rate: float = 1e-3,
83 | entropy_coefficient: Optional[float] = None,
84 | target_entropy: float = 0.0,
85 | actor_bc_loss: bool = False,
86 | damping: float = 0.0,
87 | policy_pretraining: Optional[List[sil_config.PretrainingConfig]] = None,
88 | critic_pretraining: Optional[sil_config.PretrainingConfig] = None,
89 | counter: Optional[counting.Counter] = None,
90 | learner_logger: Optional[loggers.Logger] = None,
91 | policy_pretraining_loggers: Optional[List[loggers.Logger]] = None,
92 | critic_pretraining_logger: Optional[loggers.Logger] = None,
93 | num_sgd_steps_per_step: int = 1,
94 | ):
95 | """Initialize the soft imitation learning learner.
96 |
97 | Args:
98 | networks: SIL networks
99 | critic_loss_def: loss function definition for critic
100 | reward_factory: create implicit or explicit reward functions
101 | rng: a key for random number generation.
102 | dataset: an iterator over demonstrations and online data.
103 | policy_optimizer: the policy optimizer.
104 | q_optimizer: the Q-function optimizer.
105 | r_optimizer: the reward function optimizer.
106 | tau: target smoothing coefficient.
107 | discount: discount to use for TD updates.
108 | critic_actor_update_ratio: critic updates per single actor update.
109 |       alpha_init: initial value of the entropy temperature alpha.
110 |       alpha_learning_rate: learning rate for the adaptive temperature.
111 | entropy_coefficient: coefficient applied to the entropy bonus. If None, an
112 |         adaptive coefficient will be used.
113 | target_entropy: Used to normalize entropy. Only used when
114 | entropy_coefficient is None.
115 | actor_bc_loss: add auxiliary BC term to actor objective (unused)
116 | damping: damping of KL constraint
117 | policy_pretraining: Optional config for pretraining policy
118 | critic_pretraining: Optional config for pretraining critic
119 | counter: counter object used to keep track of steps.
120 | learner_logger: logger object to be used by learner.
121 | policy_pretraining_loggers: logger objects to be used by the policy pretraining.
122 | critic_pretraining_logger: logger object to be used by critic pretraining.
123 | num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'.
124 | """
125 |
126 | adaptive_entropy_coefficient = entropy_coefficient is None
127 | kl_bound = jnp.abs(target_entropy)
128 |
129 | if adaptive_entropy_coefficient:
130 | # Alpha is the temperature parameter that determines the relative
131 | # importance of the entropy term versus the reward.
132 | # Invert softplus to initial virtual value.
133 | if alpha_init > 4.:
134 | virtual_alpha = jnp.asarray(alpha_init, dtype=jnp.float32)
135 | else: # safely invert softplus
136 | virtual_alpha = jnp.log(
137 | jnp.exp(jnp.asarray(alpha_init, dtype=jnp.float32)) - 1.0
138 | )
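# softplus(x) = log(1 + exp(x)) has inverse log(exp(y) - 1). For y > 4,
# softplus(y) is approximately y (error < 0.02), so the raw value is used
# directly, which also avoids overflowing exp(y) for large initial values.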
139 | alpha_optimizer = optax.sgd(learning_rate=alpha_learning_rate)
140 | alpha_optimizer_state = alpha_optimizer.init(virtual_alpha)
141 | else:
142 | if target_entropy:
143 | raise ValueError(
144 | 'target_entropy should not be set when '
145 | 'entropy_coefficient is provided'
146 | )
147 |
148 | def make_initial_state(
149 | key: networks_lib.PRNGKey,
150 | ) -> Tuple[TrainingState, Optional[networks_lib.Params], bool]:
151 | """Initialises the training state (parameters and optimiser state)."""
152 | key_policy, key_q, key = jax.random.split(key, 3)
153 |
154 | # In the online setting we pretrain the policy against the demonstrations.
155 |       # In the offline setting we pretrain the policy against the dataset
156 | # and demonstrations.
157 | # In both cases, we use the last trained policy for the CSIL reward,
158 | # and in the offline case, we use the first policy as the 'BC' policy
159 | # to stay close to the data.
160 | bc_policy_params = []
161 | use_pretrained_prior = False
162 | if policy_pretraining:
163 | for i, pt in enumerate(policy_pretraining):
164 | use_pretrained_prior = use_pretrained_prior or pt.use_as_reference
165 | if not bc_policy_params:
166 | policy_ = networks_lib.FeedForwardNetwork(
167 | networks.bc_policy_network.init,
168 | networks.bc_policy_network.apply,
169 | )
170 | else:
171 | policy_ = networks_lib.FeedForwardNetwork(
172 | lambda key: bc_policy_params[0],
173 | networks.bc_policy_network.apply,
174 | )
175 | params = pretraining.behavioural_cloning_pretraining(
176 | loss=pt.loss,
177 | seed=pt.seed,
178 | env_spec=networks.environment_specs,
179 | dataset_factory=pt.dataset_factory,
180 | policy=policy_,
181 | learning_rate=pt.learning_rate,
182 | num_steps=pt.steps,
183 | logger=policy_pretraining_loggers[i],
184 | name=f'{i}',
185 | )
186 | bc_policy_params += [params,]
187 | else:
188 | bc_policy_params = [None]
189 | # While IQ-Learn and P2IL use policy pretraining for the policy, CSIL can
190 | # use it only for the reward initialization.
191 | policy_match = (
192 | networks.policy_architecture == networks.bc_policy_architecture
193 | )
194 |
195 | if policy_match and policy_pretraining:
196 | policy_params = bc_policy_params[0].copy()
197 | else:
198 | policy_params = networks.policy_network.init(key_policy)
199 |
200 | policy_optimizer_state = policy_optimizer.init(policy_params)
201 |
202 | if networks.reward_policy_coherence and bc_policy_params[-1]:
203 | r_params = bc_policy_params[-1].copy()
204 | else:
205 | r_params = networks.reward_network.init(key_q)
206 | # Share encoder with policy if present.
207 | r_params = sil_networks.update_encoder(r_params, policy_params)
208 |
209 | r_optimizer_state = r_optimizer.init(r_params)
210 |
211 | if critic_pretraining is not None:
212 | critic_ = networks_lib.FeedForwardNetwork(
213 | networks.critic_network.init, networks.critic_network.apply
214 | )
215 | reward_ = networks_lib.FeedForwardNetwork(
216 | networks.reward_network.init, networks.reward_network.apply
217 | )
218 | policy_ = networks_lib.FeedForwardNetwork(
219 | networks.policy_network.init, networks.policy_network.apply
220 | )
221 | critic_params = critic_.init(key_q)
222 | critic_params = sil_networks.update_encoder(
223 | critic_params, policy_params)
224 | critic_params = pretraining.critic_pretraining(
225 | seed=critic_pretraining.seed,
226 | dataset_factory=critic_pretraining.dataset_factory,
227 | critic=critic_,
228 | critic_params=critic_params,
229 | reward=reward_,
230 | reward_params=r_params,
231 | discount_factor=discount,
232 | num_steps=critic_pretraining.steps,
233 | learning_rate=critic_pretraining.learning_rate,
234 | logger=critic_pretraining_logger,
235 | )
236 | else:
237 | critic_params = networks.critic_network.init(key_q)
238 | # Share encoder with policy if present.
239 | critic_params = sil_networks.update_encoder(
240 | critic_params, policy_params)
241 |
242 | q_optimizer_state = q_optimizer.init(critic_params)
243 |
244 | state = TrainingState(
245 | policy_optimizer_state=policy_optimizer_state,
246 | q_optimizer_state=q_optimizer_state,
247 | r_optimizer_state=r_optimizer_state,
248 | policy_params=policy_params,
249 | q_params=critic_params,
250 | target_q_params=critic_params,
251 | r_params=r_params,
252 | bc_policy_params=bc_policy_params[-1],
253 | key=key,
254 | )
255 |
256 | if adaptive_entropy_coefficient:
257 | state = state._replace(
258 | alpha_optimizer_state=alpha_optimizer_state,
259 | alpha_params=virtual_alpha,
260 | )
261 | return state, bc_policy_params[-1], use_pretrained_prior
262 |
263 | # Create initial state.
264 | self._state, bc_policy_params, use_policy_prior = make_initial_state(rng)
265 |
266 | if use_policy_prior:
267 | assert bc_policy_params is not None
268 |
269 | def alpha_loss(
270 | virtual_alpha: jnp.ndarray,
271 | policy_params: networks_lib.Params,
272 | transitions: types.Transition,
273 | key: networks_lib.PRNGKey,
274 | ) -> jnp.ndarray:
275 | """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
276 | dist_params = networks.policy_network.apply(
277 | policy_params, transitions.observation
278 | )
279 | action = dist_params.sample(seed=key)
280 | log_prob = networks.log_prob(dist_params, action)
281 | alpha = jax.nn.softplus(virtual_alpha)
282 | if use_policy_prior: # bc_policy_params are not None
283 | dist_bc = networks.bc_policy_network.apply(
284 | bc_policy_params, transitions.observation
285 | )
286 | prior_log_prob = networks.log_prob(dist_bc, action)
287 | kl = (log_prob - prior_log_prob).mean()
288 | # KL constraint.
289 | constraint = jax.lax.stop_gradient(kl_bound - kl)
290 | # Zero if kl < kl_bound, negative if violated.
291 |         # We want the temperature to go to zero if the KL bound is not violated, so don't clip here.
292 | loss = constraint
293 | # Do gradient ascent so invert sign w.r.t. actor loss term.
294 | alpha_loss = alpha * loss
295 | else:
296 | alpha_loss = (
297 | alpha * jax.lax.stop_gradient(-log_prob - target_entropy).mean()
298 | )
299 | return alpha_loss
300 |
301 | def critic_loss(
302 | q_params: networks_lib.Params,
303 | r_params: networks_lib.Params,
304 | policy_params: networks_lib.Params,
305 | target_q_params: networks_lib.Params,
306 | alpha: jnp.ndarray,
307 | demonstration_transitions: types.Transition,
308 | online_transitions: types.Transition,
309 | key: networks_lib.PRNGKey,
310 | ):
311 | # The key aspect of soft imitation learning is the critic objective and
312 | # reward. We obtain these from factories defined in the config.
313 | def state_action_reward_fn(state, action):
314 | return jnp.ravel(networks.reward_network.apply(r_params, state, action))
315 |
316 | def state_action_value_fn(state, action):
317 | return networks.critic_network.apply(
318 | q_params, state, action).min(axis=-1) # reduce via min even for 1D
319 |
320 | def _state_value_fn(state, critic_params, policy_key):
321 | # SAC's soft value function, see Equation 3 of
322 | # https://arxiv.org/pdf/1812.05905.pdf.
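# With a single sampled action a ~ pi(.|s):
#   V(s) = Q(s, a) - alpha * (log pi(a|s) - log p(a)),
# which, for a uniform prior p, reduces to the usual SAC soft value
# Q(s, a) - alpha * log pi(a|s) up to an additive constant.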
323 | action_dist = networks.policy_network.apply(policy_params, state)
324 | action = action_dist.sample(seed=policy_key)
325 | policy_log_prob = networks.log_prob(action_dist, action)
326 | if use_policy_prior: # bc_policy_params have been trained
327 | prior_log_prob = networks.bc_policy_network.apply(
328 | bc_policy_params, state
329 | ).log_prob(action)
330 | else:
331 | prior_log_prob = networks.log_prob_prior(action)
332 | q = networks.critic_network.apply(
333 | critic_params, state, action).min(axis=-1)
334 | return q - alpha * (policy_log_prob - prior_log_prob)
335 |
336 | def state_value_fn(
337 | state: jnp.ndarray, key: jax.Array
338 | ) -> jnp.ndarray:
339 | return _state_value_fn(state, q_params, key)
340 |
341 | def target_state_value_fn(
342 | state: jnp.ndarray, key: jax.Array
343 | ) -> jnp.ndarray:
344 | return lax.stop_gradient(_state_value_fn(state, target_q_params, key))
345 |
346 | reward_fn = reward_factory(
347 | state_action_reward_fn,
348 | state_action_value_fn,
349 | target_state_value_fn,
350 | discount,
351 | )
352 |
353 | critic_loss, metrics = critic_loss_def(
354 | reward_fn,
355 | state_action_value_fn,
356 | state_value_fn,
357 | target_state_value_fn,
358 | discount,
359 | demonstration_transitions,
360 | online_transitions,
361 | key,
362 | )
363 | return critic_loss, metrics
364 |
365 | def actor_loss(
366 | policy_params: networks_lib.Params,
367 | q_params: networks_lib.Params,
368 | alpha: jnp.ndarray,
369 | demonstration_transitions: types.Transition,
370 | online_transitions: types.Transition,
371 | key: networks_lib.PRNGKey,
372 |     ) -> Tuple[jnp.ndarray, Dict[str, float | jnp.ndarray]]:
373 |
374 | def action_sample(
375 | observation: jnp.ndarray,
376 | action_key: jax.Array,
377 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Any]:
378 | dist = networks.policy_network.apply(policy_params, observation)
379 | sample = dist.sample(seed=action_key)
380 | log_prob = networks.log_prob(dist, sample)
381 | return dist.mode(), sample, log_prob, dist
382 |
383 | observation = sil_config.concatenate(
384 | demonstration_transitions.observation, online_transitions.observation
385 | )
386 | expert_mode, expert_action, expert_log_prob, expert_dist = action_sample(
387 | demonstration_transitions.observation, key
388 | )
389 | online_mode, online_action, online_log_prob, _ = action_sample(
390 | online_transitions.observation, key
391 | )
392 |
393 | action = sil_config.concatenate(expert_action, online_action)
394 | action_mode = sil_config.concatenate(expert_mode, online_mode)
395 | log_prob = sil_config.concatenate(expert_log_prob, online_log_prob)
396 | prior_log_prob = networks.log_prob_prior(action)
397 | # Use min as reducer in case we use two or one critic functions.
398 | q = networks.critic_network.apply(
399 | q_params, observation, action).min(axis=-1)
400 | q_mode = networks.critic_network.apply(
401 | q_params, observation, action_mode).min(axis=-1)
402 |
403 | if use_policy_prior:
404 | dist_bc = networks.bc_policy_network.apply(
405 | bc_policy_params, observation
406 | )
407 | prior_log_prob = networks.log_prob(dist_bc, action)
408 | kl = (log_prob - prior_log_prob).mean()
409 | constraint = (kl_bound - kl).sum() # Sum to reduce to scalar.
410 | clipped_constraint = jnp.clip(constraint, a_max=0.0)
411 | d = damping * clipped_constraint ** 2
412 | # Constraint is <= 0, so negate for loss.
413 | entropy_reg = -alpha * constraint + d
414 | else: # Vanilla maximum entropy regularization with uniform prior.
415 | d = 0.0
416 | kl = (log_prob - prior_log_prob).mean()
417 | constraint = kl
418 | clipped_constraint = 0.0
419 | entropy_reg = alpha * kl
420 |
421 | actor_loss = entropy_reg - q.mean()
422 |
423 | if actor_bc_loss:
424 |
425 |         # For SAC's tanh policy, minimizing the modal MSE and maximizing the
426 |         # log-likelihood do not appear to be mutually guaranteed, so we optimize
427 | # for both.
428 | # Incorporate BC MSE loss from TD3+BC.
429 | # https://arxiv.org/abs/2106.06860
430 | expert_se = (expert_mode - demonstration_transitions.action) ** 2
431 | bc_loss_mean = 0.5 * expert_se.mean() * jnp.abs(q).mean()
432 |
433 | # Also incorporate a log-likelihood, which should be similar in value to
434 | # the entropy as they are constructed in similar ways, so use alpha to
435 | # weight. This is like maximum likelihood with an entropy bonus.
436 | # See https://proceedings.mlr.press/v97/jacq19a/jacq19a.pdf Section 5.2.
437 | expert_demo_log_prob = networks.log_prob(
438 | expert_dist, demonstration_transitions.action
439 | )
440 | bc_loss_mean += -alpha * expert_demo_log_prob.mean()
441 |
442 | actor_loss += bc_loss_mean
443 |
444 | metrics = {
445 | 'actor_q': q.mean(),
446 | 'actor_q_mode': q_mode.mean(),
447 | 'actor_entropy_bonus': (alpha * log_prob).mean(),
448 | 'actor_kl': kl,
449 | 'kl_bound': kl_bound,
450 | 'constraint': constraint,
451 | 'clipped_constraint': clipped_constraint,
452 | 'entropy_reg': entropy_reg,
453 | 'prior_log_prob': prior_log_prob.mean(),
454 | 'policy_log_prob': log_prob.mean(),
455 | 'damping': d,
456 | }
457 | return actor_loss, metrics
458 |
459 | alpha_grad = jax.value_and_grad(alpha_loss)
460 | critic_grad = jax.value_and_grad(critic_loss, argnums=[0, 1], has_aux=True)
461 | actor_grad = jax.value_and_grad(actor_loss, has_aux=True)
462 |
463 | def update_step(
464 | state: TrainingState,
465 | sample: ImitationSample,
466 | ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
467 | # Update temperature, actor and critic.
468 | key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)
469 | alpha_grads = None
470 | alpha_loss = None
471 | if adaptive_entropy_coefficient:
472 | transition = sil_config.concatenate_transitions(
473 | sample.online_sample, sample.demonstration_sample
474 | )
475 | alpha_loss, alpha_grads = alpha_grad(
476 | state.alpha_params, state.policy_params, transition, key_alpha
477 | )
478 | alpha = jax.nn.softplus(state.alpha_params)
479 | else:
480 | alpha = entropy_coefficient
481 |
482 | # Update critic (and reward).
483 |
484 | q_params = state.q_params
485 | r_params = state.r_params
486 | target_q_params = state.target_q_params
487 | critic_loss = None
488 | critic_grads = None
489 | critic_loss_metrics = None
490 | q_optimizer_state = None
491 | r_optimizer_state = None
492 | for _ in range(critic_actor_update_ratio):
493 | critic_losses, grads = critic_grad(
494 | q_params,
495 | r_params,
496 | state.policy_params,
497 | target_q_params,
498 | alpha,
499 | sample.demonstration_sample,
500 | sample.online_sample,
501 | key_critic,
502 | )
503 | critic_loss, critic_loss_metrics = critic_losses
504 | critic_grads, reward_grads = grads
505 |
506 | # Apply critic gradients.
507 | critic_update, q_optimizer_state = q_optimizer.update(
508 | critic_grads, state.q_optimizer_state, q_params
509 | )
510 | q_params = optax.apply_updates(q_params, critic_update)
511 |
512 | reward_update, r_optimizer_state = r_optimizer.update(
513 | reward_grads, state.r_optimizer_state, r_params
514 | )
515 | r_params = optax.apply_updates(r_params, reward_update)
516 |
517 | target_q_params = jax.tree_map(
518 | lambda x, y: x * (1 - tau) + y * tau, target_q_params, q_params
519 | )
520 |
521 | # Update actor.
522 | actor_losses, actor_grads = actor_grad(
523 | state.policy_params,
524 | q_params,
525 | alpha,
526 | sample.demonstration_sample,
527 | sample.online_sample,
528 | key_actor,
529 | )
530 | actor_loss, actor_loss_metrics = actor_losses
531 | actor_update, policy_optimizer_state = policy_optimizer.update(
532 | actor_grads, state.policy_optimizer_state)
533 | policy_params = optax.apply_updates(state.policy_params, actor_update)
534 |
535 | metrics = {
536 | 'critic_loss': critic_loss,
537 | 'actor_loss': actor_loss,
538 | 'critic_grad_norm': optax.global_norm(critic_grads),
539 | 'actor_grad_norm': optax.global_norm(actor_grads),
540 | }
541 |
542 | metrics.update(critic_loss_metrics)
543 | metrics.update(actor_loss_metrics)
544 |
545 | if MONITOR_BC_METRICS:
546 | # During training, expert actions should become / stay high likelihood.
547 | expert_action_dist = networks.policy_network.apply(
548 | policy_params, sample.demonstration_sample.observation
549 | )
550 | samp = expert_action_dist.sample(seed=key)
551 | expert_ent_approx = -networks.log_prob(expert_action_dist, samp).mean()
552 | expert_llhs = networks.log_prob(
553 | expert_action_dist, sample.demonstration_sample.action
554 | )
555 | expert_se = (
556 | expert_action_dist.mode() - sample.demonstration_sample.action
557 | ) ** 2
558 | online_action_dist = networks.policy_network.apply(
559 | policy_params, sample.online_sample.observation
560 | )
561 | samp = online_action_dist.sample(seed=key)
562 | online_ent_approx = -networks.log_prob(online_action_dist, samp).mean()
563 | online_llh = networks.log_prob(
564 | online_action_dist, sample.online_sample.action
565 | ).mean()
566 | online_se = (online_action_dist.mode() - sample.online_sample.action) ** 2
567 |
568 | metrics.update({
569 | 'expert_llh_mean': expert_llhs.mean(),
570 | 'expert_llh_max': expert_llhs.max(),
571 | 'expert_llh_min': expert_llhs.min(),
572 | 'expert_mse': expert_se.mean(),
573 | 'online_llh': online_llh,
574 | 'online_mse': online_se.mean(),
575 | 'expert_ent': expert_ent_approx,
576 | 'online_ent': online_ent_approx,
577 | })
578 |
579 | new_state = TrainingState(
580 | policy_optimizer_state=policy_optimizer_state,
581 | q_optimizer_state=q_optimizer_state,
582 | r_optimizer_state=r_optimizer_state,
583 | policy_params=policy_params,
584 | q_params=q_params,
585 | target_q_params=target_q_params,
586 | r_params=r_params,
587 | bc_policy_params=state.bc_policy_params,
588 | key=key,
589 | )
590 | if adaptive_entropy_coefficient:
591 | # Apply alpha gradients.
592 | alpha_update, alpha_optimizer_state = alpha_optimizer.update(
593 | alpha_grads, state.alpha_optimizer_state)
594 | alpha_params = optax.apply_updates(state.alpha_params, alpha_update)
595 | metrics.update({
596 | 'alpha_loss': alpha_loss,
597 | 'alpha': jax.nn.softplus(alpha_params),
598 | })
599 | new_state = new_state._replace(
600 | alpha_optimizer_state=alpha_optimizer_state,
601 | alpha_params=alpha_params)
602 |
603 | metrics['rewards_mean'] = jnp.mean(
604 | jnp.abs(jnp.mean(sample.online_sample.reward, axis=0))
605 | )
606 | metrics['rewards_std'] = jnp.std(sample.online_sample.reward, axis=0)
607 |
608 | return new_state, metrics
609 |
610 | # General learner book-keeping and loggers.
611 | self._counter = counter or counting.Counter()
612 | self._logger = learner_logger or loggers.make_default_logger(
613 | 'learner',
614 | asynchronous=True,
615 | serialize_fn=utils.fetch_devicearray,
616 | steps_key=self._counter.get_steps_key())
617 | self._num_sgd_steps_per_step = num_sgd_steps_per_step
618 |
619 | # Iterator on demonstration transitions.
620 | self._iterator = dataset
621 |
622 | # Use the JIT compiler.
623 | self._update_step = jax.jit(update_step)
624 |
625 | # Do not record timestamps until after the first learning step is done.
626 | # This is to avoid including the time it takes for actors to come online and
627 | # fill the replay buffer.
628 | self._timestamp = None
629 |
630 | def step(self):
631 |
632 | metrics = {}
633 | # Update temperature, actor and critic.
634 | for _ in range(self._num_sgd_steps_per_step):
635 | sample = next(self._iterator)
636 | self._state, metrics = self._update_step(self._state, sample)
637 |
638 | # Compute elapsed time.
639 | timestamp = time.time()
640 | elapsed_time = timestamp - self._timestamp if self._timestamp else 0
641 | self._timestamp = timestamp
642 |
643 | # Increment counts and record the current time.
644 | counts = self._counter.increment(steps=1, walltime=elapsed_time)
645 |
646 | # Attempts to write the logs.
647 | self._logger.write({**metrics, **counts})
648 |
649 | def get_variables(self, names: List[str]) -> List[Any]:
650 | variables = {
651 | 'policy': self._state.policy_params,
652 | 'critic': self._state.q_params,
653 | 'reward': self._state.r_params,
654 | 'bc_policy_params': self._state.bc_policy_params,
655 | }
656 | return [variables[name] for name in names]
657 |
658 | def save(self) -> TrainingState:
659 | return self._state
660 |
661 | def restore(self, state: TrainingState):
662 | self._state = state
663 |
--------------------------------------------------------------------------------
/sil/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Soft imitation networks definition.
17 |
18 | Builds heavily on acme/agents/jax/sac/networks.py
19 | (at https://github.com/google-deepmind/acme)
20 | """
21 |
22 | import dataclasses
23 | import enum
24 | import math
25 | from typing import Any, Optional, Tuple, Callable, Sequence, Union
26 |
27 | from acme import core
28 | from acme import specs
29 | from acme.agents.jax.sac import networks as sac_networks
30 | from acme.jax import networks as networks_lib
31 | from acme.jax import types
32 | from acme.jax import utils
33 | import haiku as hk
34 | import haiku.initializers as hk_init
35 | import jax
36 | import jax.numpy as jnp
37 | import numpy as np
38 | import tensorflow_probability.substrates.jax as tfp
39 |
40 | from sil import config as sil_config
41 |
42 | tfd = tfp.distributions
43 | tfb = tfp.bijectors
44 |
45 |
46 | class PolicyArchitectures(enum.Enum):
47 | """Variations of policy architectures used in SIL methods."""
48 | MLP = 'mlp'
49 | MIXMLP = 'mixmlp'
50 | HETSTATSIN = 'hetstatsin'
51 | HETSTATTRI = 'hetstattri'
52 | HETSTATPRELU = 'hetstatprelu'
53 | MIXHETSTATSIN = 'mixhetstatsin'
54 | MIXHETSTATTRI = 'mixhetstattri'
55 | MIXHETSTATPRELU = 'mixhetstatprelu'
56 |
57 | def __str__(self) -> str:
58 | return self.value
59 |
60 |
61 | class CriticArchitectures(enum.Enum):
62 | MLP = 'mlp'
63 | DOUBLE_MLP = 'double_mlp' # Used for SAC-based methods that use two critics.
64 | LNMLP = 'lnmlp'
65 | DOUBLE_LNMLP = 'double_lnmlp'
66 | STATSIN = 'statsin'
67 | STATTRI = 'stattri'
68 | STATPRELU = 'statprelu'
69 |
70 | def __str__(self) -> str:
71 | return self.value
72 |
73 |
74 | class RewardArchitectures(enum.Enum):
75 | MLP = 'mlp'
76 | LNMLP = 'lnmlp'
77 | PCSIL = 'pos_csil'
78 | NCSIL = 'neg_csil'
79 | PCONST = 'pos_const'
80 | NCONST = 'neg_const'
81 |
82 | def __str__(self) -> str:
83 | return self.value
84 |
85 |
86 | def observation_encoder(inputs: Any) -> jnp.ndarray:
87 | """Function that transforms observations into the correct vectors."""
88 | if isinstance(inputs, jnp.ndarray):
89 | return inputs
90 | else:
91 | raise ValueError(f'Cannot convert type {type(inputs)}.')
92 |
93 |
94 | def update_encoder(params: networks_lib.Params,
95 | reference: networks_lib.Params) -> networks_lib.Params:
96 | predicate = lambda module_name, name, value: 'encoder' in module_name
97 | _, params_head = hk.data_structures.partition(predicate, params)
98 | ref_enc, _ = hk.data_structures.partition(predicate, reference)
99 | return hk.data_structures.merge(params_head, ref_enc)
100 |
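   | # Illustrative usage sketch (hedged; `policy_params` and `bc_params` are
   | # hypothetical parameter trees containing modules named '.../encoder/...'):
   | #
   | #   synced_params = update_encoder(policy_params, reference=bc_params)
   | #
   | # `partition` splits each tree on whether 'encoder' appears in the module
   | # name, and `merge` recombines the non-encoder head of `params` with the
   | # encoder of `reference`, i.e. only the encoder weights are overwritten.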
101 |
102 | class Sequential(hk.Module):
103 | """Sequentially calls the given list of layers."""
104 |
105 | def __init__(
106 | self,
107 | layers: Sequence[Callable[..., Any]],
108 | name: Optional[str] = None,
109 | ):
110 | super().__init__(name=name)
111 | self.layers = tuple(layers)
112 |
113 | def __call__(self, inputs, *args, **kwargs):
114 | """Calls all layers sequentially."""
115 | out = inputs
116 | last_idx = len(self.layers) - 1
117 | for i, layer in enumerate(self.layers):
118 | if i == last_idx:
119 | out = layer(out, *args, **kwargs)
120 | else:
121 | out = layer(out)
122 | return out
123 |
124 |
125 | @dataclasses.dataclass
126 | class SILNetworks:
127 | """Network and pure functions for the soft imitation agent."""
128 |
129 | environment_specs: specs.EnvironmentSpec
130 | policy_architecture: PolicyArchitectures
131 | bc_policy_architecture: PolicyArchitectures
132 | policy_network: networks_lib.FeedForwardNetwork
133 | critic_network: networks_lib.FeedForwardNetwork
134 | reward_network: networks_lib.FeedForwardNetwork
135 | log_prob: networks_lib.LogProbFn
136 | log_prob_prior: Callable[[jnp.ndarray], jnp.ndarray]
137 | sample: networks_lib.SampleFn
138 | bc_policy_network: networks_lib.FeedForwardNetwork
139 | reward_policy_coherence: bool = False
140 | sample_eval: Optional[networks_lib.SampleFn] = None
141 |
142 | def to_sac(self, using_bc_policy: bool = False) -> sac_networks.SACNetworks:
143 | """Cast to SAC policy to make use of the SAC helpers."""
144 | policy_network = (
145 | self.bc_policy_network if using_bc_policy else self.policy_network
146 | )
147 | return sac_networks.SACNetworks(
148 | policy_network,
149 | self.critic_network,
150 | self.log_prob,
151 | self.sample,
152 | self.sample_eval,
153 | )
154 |
155 |
156 | # From acme/agents/jax/cql/networks.py
157 | def apply_and_sample_n(
158 | key: networks_lib.PRNGKey,
159 | networks: SILNetworks,
160 | params: networks_lib.Params,
161 | obs: jnp.ndarray,
162 | num_samples: int,
163 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
164 | """Applies the policy and samples num_samples actions."""
165 | dist_params = networks.policy_network.apply(params, obs)
166 | sampled_actions = jnp.array(
167 | [
168 | networks.sample(dist_params, key_n)
169 | for key_n in jax.random.split(key, num_samples)
170 | ]
171 | )
172 | sampled_log_probs = networks.log_prob(dist_params, sampled_actions)
173 | return sampled_actions, sampled_log_probs
174 |
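   | # Shape note (a sketch under assumed conventions, not asserted by the code):
   | # with `obs` of shape [B, obs_dim], `sampled_actions` has shape
   | # [num_samples, B, act_dim] and `sampled_log_probs` has shape
   | # [num_samples, B], since `log_prob` broadcasts the extra sample axis
   | # against the batch of distributions in `dist_params`.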
175 |
176 | def default_models_to_snapshot(
177 | networks: SILNetworks, spec: specs.EnvironmentSpec
178 | ):
179 | """Defines default models to be snapshotted."""
180 | dummy_obs = utils.zeros_like(spec.observations)
181 | dummy_action = utils.zeros_like(spec.actions)
182 | dummy_key = jax.random.PRNGKey(0)
183 |
184 | def critic_network(source: core.VariableSource) -> types.ModelToSnapshot:
185 | params = source.get_variables(['critic'])[0]
186 | return types.ModelToSnapshot(
187 | networks.critic_network.apply,
188 | params,
189 | {'obs': dummy_obs, 'action': dummy_action},
190 | )
191 |
192 | def reward_network(source: core.VariableSource) -> types.ModelToSnapshot:
193 | params = source.get_variables(['reward'])[0]
194 | return types.ModelToSnapshot(
195 | networks.critic_network.apply,
196 | params,
197 | {'obs': dummy_obs, 'action': dummy_action},
198 | )
199 |
200 | def default_training_actor(
201 | source: core.VariableSource) -> types.ModelToSnapshot:
202 | params = source.get_variables(['policy'])[0]
203 | return types.ModelToSnapshot(
204 | sac_networks.apply_policy_and_sample(
205 | networks.to_sac(), eval_mode=False
206 | ),
207 | params,
208 | {'key': dummy_key, 'obs': dummy_obs},
209 | )
210 |
211 | def default_eval_actor(
212 | source: core.VariableSource) -> types.ModelToSnapshot:
213 | params = source.get_variables(['policy'])[0]
214 | return types.ModelToSnapshot(
215 | sac_networks.apply_policy_and_sample(networks.to_sac(), eval_mode=True),
216 | params,
217 | {'key': dummy_key, 'obs': dummy_obs},
218 | )
219 |
220 | return {
221 | 'critic_network': critic_network,
222 | 'reward_network': reward_network,
223 | 'default_training_actor': default_training_actor,
224 | 'default_eval_actor': default_eval_actor,
225 | }
226 |
227 |
228 | default_init_normal = hk.initializers.VarianceScaling(
229 | 0.333, 'fan_out', 'normal'
230 | )
231 | # For 1D regression plotting, a scale of 0.2 would be more sensible, but it
232 | # is crucial that the value above is not changed: training fails otherwise.
233 | # Acme's SAC agent uses the initializer below.
234 | default_init_uniform = hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform')
235 |
236 |
237 | class ClampedScaleNormalTanhDistribution(hk.Module):
238 |   """Module that produces a variance-clamped TanhTransformedDistribution."""
239 |
240 | def __init__(
241 | self,
242 | num_dimensions: int,
243 | max_log_scale: float = 0.0,
244 | min_log_scale: float = -5.0,
245 | w_init: hk_init.Initializer = hk_init.Orthogonal(),
246 | b_init: hk_init.Initializer = hk_init.Constant(0.0),
247 | name: str = 'ClampedScaleNormalTanhDistribution',
248 | ):
249 | """Initialization.
250 |
251 | Args:
252 | num_dimensions: Number of dimensions of a distribution.
253 | max_log_scale: Maximum log standard deviation.
254 | min_log_scale: Minimum log standard deviation.
255 | w_init: Initialization for linear layer weights.
256 | b_init: Initialization for linear layer biases.
257 |       name: Name of the module; used to prefix its parameters.
258 | """
259 | super().__init__(name=name)
260 | assert max_log_scale > min_log_scale
261 | self._min_log_scale = min_log_scale
262 | self._log_scale_range = max_log_scale - self._min_log_scale
263 | self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
264 | self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
265 |
266 | def __call__(
267 | self, inputs: jnp.ndarray, faithful_distributions: bool = False
268 | ) -> Union[tfd.Distribution, Tuple[tfd.Distribution, tfd.Distribution]]:
269 | loc = self._loc_layer(inputs)
270 | if faithful_distributions:
271 | inputs_ = jax.lax.stop_gradient(inputs)
272 | else:
273 | inputs_ = inputs
274 |
275 | log_scale_raw = self._scale_layer(inputs_) # Operating range around [-1, 1]
276 | log_scale_norm = 0.5 * (jax.nn.tanh(log_scale_raw) + 1.0) # Range [0, 1]
277 | scale = self._min_log_scale + self._log_scale_range * log_scale_norm
278 | scale = math.sqrt(sil_config.MIN_VAR) + jnp.exp(scale)
279 | distribution = tfd.Normal(loc=loc, scale=scale)
280 | transformed_dist = tfd.Independent(
281 | networks_lib.TanhTransformedDistribution(distribution),
282 | reinterpreted_batch_ndims=1,
283 | )
284 | if faithful_distributions:
285 | cut_dist = tfd.Normal(loc=jax.lax.stop_gradient(loc), scale=scale)
286 | cut_transformed_dist = tfd.Independent(
287 | networks_lib.TanhTransformedDistribution(cut_dist),
288 | reinterpreted_batch_ndims=1,
289 | )
290 | return transformed_dist, cut_transformed_dist
291 | else:
292 | return transformed_dist
293 |
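   | # A hedged numeric sketch of the clamping above, using the default
   | # max_log_scale=0.0 and min_log_scale=-5.0 (sil_config.MIN_VAR is assumed
   | # to be a small positive constant):
   | #
   | #   log_scale_norm = 0.5 * (tanh(raw) + 1.0)         # in (0, 1)
   | #   log_scale      = -5.0 + 5.0 * log_scale_norm     # in (-5, 0)
   | #   scale          = sqrt(MIN_VAR) + exp(log_scale)  # in (~e**-5, ~1) + floor
   | #
   | # so the standard deviation is bounded away from zero and saturates near one.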
294 |
295 | class MixtureClampedScaleNormalTanhDistribution(
296 | ClampedScaleNormalTanhDistribution):
297 |   """Produces a variance-clamped mixture of TanhTransformedDistributions."""
298 |
299 | def __init__(
300 | self,
301 | num_dimensions: int,
302 | n_mixture: int,
303 | max_log_scale: float = 0.0,
304 | min_log_scale: float = -5.0,
305 | w_init: hk_init.Initializer = hk_init.Orthogonal(),
306 | b_init: hk_init.Initializer = hk_init.Constant(0.0),
307 | ):
308 | """Initialization.
309 |
310 | Args:
311 | num_dimensions: Number of dimensions of a distribution.
312 |       n_mixture: Number of mixture components.
313 | max_log_scale: Maximum log standard deviation.
314 | min_log_scale: Minimum log standard deviation.
315 | w_init: Initialization for linear layer weights.
316 | b_init: Initialization for linear layer biases.
317 | """
318 | self.n_mixture = n_mixture
319 | self.n_outputs = num_dimensions
320 | super().__init__(
321 | num_dimensions=num_dimensions * n_mixture,
322 | max_log_scale=max_log_scale,
323 | min_log_scale=min_log_scale,
324 | w_init=w_init,
325 | b_init=b_init,
326 | name='MixtureClampedScaleNormalTanhDistribution',
327 | )
328 |
329 | def __call__(
330 | self, inputs: jnp.ndarray, faithful_distributions: bool = False
331 | ) -> Union[tfd.Distribution, Tuple[tfd.Distribution, tfd.Distribution]]:
332 | loc = self._loc_layer(inputs)
333 | if faithful_distributions:
334 | inputs_ = jax.lax.stop_gradient(inputs)
335 | else:
336 | inputs_ = inputs
337 |
338 | log_scale_raw = self._scale_layer(inputs_) # Operating range around [-1, 1]
339 | log_scale_norm = 0.5 * (jax.nn.tanh(log_scale_raw) + 1.0) # range [0, 1]
340 | scale = self._min_log_scale + self._log_scale_range * log_scale_norm
341 | scale = math.sqrt(sil_config.MIN_VAR) + jnp.exp(scale)
342 |
343 | log_mixture_weights = hk.get_parameter(
344 | 'log_mixture_weights',
345 | [self.n_mixture],
346 | init=hk.initializers.Constant(1.0),
347 | )
348 | mixture_weights = jax.nn.softmax(log_mixture_weights)
349 | mixture_distribution = tfd.Categorical(probs=mixture_weights)
350 |
351 | def make_mixture(location, scale, weights):
352 | distribution = tfd.Normal(loc=location, scale=scale)
353 |
354 | transformed_distribution = tfd.Independent(
355 | networks_lib.TanhTransformedDistribution(distribution),
356 | reinterpreted_batch_ndims=1,
357 | )
358 |
359 | return MixtureSameFamily(
360 | mixture_distribution=weights,
361 | components_distribution=transformed_distribution,
362 | )
363 |
364 | mean = loc.reshape((-1, self.n_mixture, self.n_outputs))
365 | stddev = scale.reshape((-1, self.n_mixture, self.n_outputs))
366 | mixture = make_mixture(mean, stddev, mixture_distribution)
367 | if faithful_distributions:
368 | cut_mixture = make_mixture(jax.lax.stop_gradient(mean),
369 | stddev, mixture_distribution)
370 | return mixture, cut_mixture
371 | else:
372 | return mixture
373 |
374 |
375 | def _triangle_activation(x: jnp.ndarray) -> jnp.ndarray:
376 | z = jnp.floor(x / jnp.pi + 0.5)
377 | return (x - jnp.pi * z) * (-1) ** z
378 |
379 |
380 | @jax.jit
381 | def triangle_activation(x: jnp.ndarray) -> jnp.ndarray:
382 | pdiv2sqrt2 = 1.1107207345
383 | return pdiv2sqrt2 * _triangle_activation(x)
384 |
385 |
386 | @jax.jit
387 | def periodic_relu_activation(x: jnp.ndarray) -> jnp.ndarray:
388 | pdiv4 = 0.785398163
389 | pdiv2 = 1.570796326
390 | return (_triangle_activation(x) + _triangle_activation(x + pdiv2)) * pdiv4
391 |
392 |
393 | @jax.jit
394 | def sin_cos_activation(x: jnp.ndarray) -> jnp.ndarray:
395 | return jnp.sin(x) + jnp.cos(x)
396 |
397 |
398 | @jax.jit
399 | def hard_sin(x: jnp.ndarray) -> jnp.ndarray:
400 | pdiv4 = 0.785398163 # π/4
401 | return periodic_relu_activation(x - pdiv4)
402 |
403 |
404 | @jax.jit
405 | def hard_cos(x: jnp.ndarray) -> jnp.ndarray:
406 | pdiv4 = 0.785398163 # π/4
407 | return periodic_relu_activation(x + pdiv4)
408 |
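   | # Constants sketch (assumed interpretations, matching the values above):
   | # pdiv2sqrt2 = pi / (2 * sqrt(2)) ~= 1.1107, pdiv2 = pi / 2, pdiv4 = pi / 4.
   | # _triangle_activation is a triangle wave of period 2*pi with range
   | # [-pi/2, pi/2]; periodic_relu_activation adds a copy shifted by a quarter
   | # period, and hard_sin / hard_cos are its -pi/4 and +pi/4 phase shifts.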
409 |
410 | gaussian_init = hk.initializers.RandomNormal(1.0)
411 |
412 |
413 | class StationaryFeatures(hk.Module):
414 | """Stationary feature layer.
415 |
416 |   Combines an MLP feature component (with a bottleneck output) with a
417 |   relatively wide feature layer that has a periodic activation function. The
418 |   form of the final weight distribution and the periodic activation dictate
419 |   the nature of the parametric stationary process.
420 |
421 | For more details see
422 | Periodic Activation Functions Induce Stationarity, Meronen et al.
423 | https://arxiv.org/abs/2110.13572
424 | """
425 |
426 | def __init__(
427 | self,
428 | num_dimensions: int,
429 | layers: Sequence[int],
430 | feature_dimension: int = 512,
431 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu,
432 | stationary_activation: Callable[
433 | [jnp.ndarray], jnp.ndarray
434 | ] = sin_cos_activation,
435 | stationary_init: hk.initializers.Initializer = gaussian_init,
436 | layer_norm_mlp: bool = False,
437 | ):
438 | """Initialization.
439 |
440 | Args:
441 | num_dimensions: Number of dimensions of the output distribution.
442 | layers: feature MLP architecture up to the feature layer.
443 | feature_dimension: size of feature layer.
444 | activation: Activation of MLP network.
445 | stationary_activation: Periodic activation of feature layer.
446 |       stationary_init: Random initialization of the last layer.
447 | layer_norm_mlp: Use layer norm in the first MLP layer.
448 | """
449 | super().__init__(name='StationaryFeatures')
450 | self.output_dimension = num_dimensions
451 | self.feature_dimension = feature_dimension
452 | self.stationary_activation = stationary_activation
453 | self.stationary_init = stationary_init
454 | if layer_norm_mlp:
455 | self.mlp = networks_lib.LayerNormMLP(
456 | list(layers),
457 | activation=activation,
458 | w_init=hk.initializers.Orthogonal(),
459 | activate_final=True,
460 | )
461 | else:
462 | self.mlp = hk.nets.MLP(
463 | list(layers),
464 | activation=activation,
465 | w_init=hk.initializers.Orthogonal(),
466 | activate_final=True,
467 | )
468 |
469 | def features(self, inputs: jnp.ndarray) -> jnp.ndarray:
470 | input_dimension = inputs.shape[-1]
471 |
472 | # While the theory says that these random weights should be fixed, it's
473 | # crucial in practice to let them be trained. The distribution does not
474 | # actually change much, so they still contribute to the stationary
475 | # behaviour, and letting them be trained alleviates potential underfitting.
476 | random_weights = hk.get_parameter(
477 | 'random_weights',
478 | [input_dimension, self.feature_dimension // 2],
479 | init=self.stationary_init,
480 | )
481 |
482 | log_lengthscales = hk.get_parameter(
483 | 'log_lengthscales',
484 | [input_dimension],
485 | init=hk.initializers.Constant(-5.0),
486 | )
487 |
488 | ls = jnp.diag(jnp.exp(log_lengthscales))
489 | wx = inputs @ ls @ random_weights
490 | pdiv4 = 0.785398163 # π/4
491 | f = jnp.concatenate(
492 | (
493 | self.stationary_activation(wx + pdiv4),
494 | self.stationary_activation(wx - pdiv4),
495 | ),
496 | axis=-1,
497 | )
498 | return f / math.sqrt(self.feature_dimension)
499 |
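   | # Hedged usage sketch (the function and variable names here are
   | # illustrative, not part of this file):
   | #
   | #   def feature_fn(x):
   | #     module = StationaryFeatures(num_dimensions=1, layers=(64, 64))
   | #     return module.features(module.mlp(x))
   | #
   | #   init, apply = hk.transform(feature_fn)
   | #   params = init(jax.random.PRNGKey(0), jnp.zeros((1, obs_dim)))
   | #   phi = apply(params, None, observations)  # [batch, feature_dimension]
   | #
   | # With sin_cos_activation the two phase-shifted copies behave like random
   | # Fourier features, so phi @ phi.T approximates a stationary kernel.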
500 |
501 | class StationaryHeteroskedasticNormalTanhDistribution(StationaryFeatures):
502 | """Module that produces a stationary TanhTransformedDistribution."""
503 |
504 | def __init__(
505 | self,
506 | num_dimensions: int,
507 | layers: Sequence[int],
508 | feature_dimension: int = 512,
509 | prior_variance: float = 1.0,
510 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu,
511 | stationary_activation: Callable[
512 | [jnp.ndarray], jnp.ndarray
513 | ] = sin_cos_activation,
514 | stationary_init: hk.initializers.Initializer = gaussian_init,
515 | layer_norm_mlp: bool = False,
516 | ):
517 | """Initialization.
518 |
519 | Args:
520 | num_dimensions: Number of dimensions of the output distribution.
521 | layers: feature MLP architecture up to the feature layer.
522 | feature_dimension: size of feature layer.
523 | prior_variance: initial variance of the predictive.
524 | activation: Activation of MLP network.
525 | stationary_activation: Periodic activation of feature layer.
526 |       stationary_init: Random initialization of the last layer.
527 | layer_norm_mlp: Use layer norm in the first MLP layer.
528 | """
529 | self.prior_var = prior_variance
530 | self.prior_stddev = np.sqrt(prior_variance)
531 | super().__init__(
532 | num_dimensions,
533 | layers,
534 | feature_dimension,
535 | activation,
536 | stationary_activation,
537 | stationary_init,
538 | layer_norm_mlp
539 | )
540 |
541 | def __call__(
542 | self, inputs: jnp.ndarray, faithful_distributions: bool = False
543 | ) -> Union[tfd.Distribution, Tuple[tfd.Distribution, tfd.Distribution]]:
544 | inputs = self.mlp(inputs)
545 | features = self.features(inputs)
546 | if faithful_distributions:
547 | features_ = jax.lax.stop_gradient(features)
548 | else:
549 | features_ = features
550 |
551 | loc_weights = hk.get_parameter(
552 | 'loc_weights',
553 | [self.feature_dimension, self.output_dimension],
554 | init=hk.initializers.Constant(0.0),
555 | )
556 |
557 |     # Parameterize the PSD matrix in lower-triangular form as a 'raw' vector.
558 |     # This reduces the memory footprint to 50-75% of the full matrix.
559 | n_sqrt = self.feature_dimension * (self.feature_dimension + 1) // 2
560 | scale_cross_weights_sqrt_raw = hk.get_parameter(
561 | 'scale_cross_weights_sqrt',
562 | [self.output_dimension, n_sqrt],
563 | init=hk.initializers.Constant(0.0),
564 | )
565 |     # Convert the vector into a lower-triangular matrix with an exponentiated
566 |     # diagonal, so that a vector of zeros becomes the identity matrix.
567 | b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
568 | scale_cross_weights_sqrt = jax.vmap(b.forward)(scale_cross_weights_sqrt_raw)
569 | loc = features @ loc_weights
570 |     # Cholesky decomposition: A = LL^T where L is lower triangular.
571 |     # The variance is the diagonal of x @ A @ x.T = x @ L @ L.T @ x.T,
572 |     # so first compute x @ L per output dimension d.
573 | var_sqrt = jnp.einsum('dij,bi->bdj', scale_cross_weights_sqrt, features_)
574 | var = jnp.einsum('bdi,bdi->bd', var_sqrt, var_sqrt)
575 |
576 | scale = self.prior_stddev * jnp.tanh(jnp.sqrt(sil_config.MIN_VAR + var))
577 |
578 | distribution = tfd.Normal(loc=loc, scale=scale)
579 |
580 | transformed_distribution = tfd.Independent(
581 | networks_lib.TanhTransformedDistribution(distribution),
582 | reinterpreted_batch_ndims=1,
583 | )
584 |
585 | if faithful_distributions:
586 | cut_distribution = tfd.Normal(loc=jax.lax.stop_gradient(loc), scale=scale)
587 | cut_transformed_distribution = tfd.Independent(
588 | networks_lib.TanhTransformedDistribution(cut_distribution),
589 | reinterpreted_batch_ndims=1,
590 | )
591 | return transformed_distribution, cut_transformed_distribution
592 | else:
593 | return transformed_distribution
594 |
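   | # Hedged equivalence note for the einsum-based variance above: for each
   | # output d, with A_d = L_d @ L_d.T built from scale_cross_weights_sqrt[d],
   | #
   | #   var[:, d] == jnp.einsum('bi,ij,bj->b', features_, A_d, features_)
   | #
   | # i.e. the diagonal of features_ @ A_d @ features_.T, computed without
   | # forming any batch-by-batch matrix.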
595 |
596 | class MixtureSameFamily(tfd.MixtureSameFamily):
597 | """MixtureSameFamily with mode computation."""
598 |
599 | def mode(self) -> jnp.ndarray:
600 |     """Return the mode of the most probable mixture component."""
601 | mode_model = self.mixture_distribution.mode()
602 | modes = self.components_distribution.mode()
603 | return modes[:, mode_model, :]
604 |
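   | # Note: the mode of a tanh-transformed mixture has no closed form in general;
   | # the approximation above returns the mode of the most probable component,
   | # which is what the deterministic `sample_eval` path relies on.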
605 |
606 | class MixtureStationaryHeteroskedasticNormalTanhDistribution(
607 | StationaryHeteroskedasticNormalTanhDistribution
608 | ):
609 |   """Produces a stationary mixture of TanhTransformedDistributions."""
610 |
611 | def __init__(
612 | self,
613 | num_dimensions: int,
614 | n_mixture: int,
615 | layers: Sequence[int],
616 | feature_dimension: int = 512,
617 | prior_variance: float = 1.0,
618 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu,
619 | stationary_activation: Callable[
620 | [jnp.ndarray], jnp.ndarray
621 | ] = sin_cos_activation,
622 | stationary_init: hk.initializers.Initializer = gaussian_init,
623 | layer_norm_mlp: bool = False,
624 | ):
625 | """Initialization.
626 |
627 | Args:
628 | num_dimensions: Number of dimensions of the output distribution.
629 | n_mixture: Number of mixture components.
630 | layers: feature MLP architecture up to the feature layer.
631 | feature_dimension: size of feature layer.
632 | prior_variance: initial variance of the predictive.
633 | activation: Activation of MLP network.
634 | stationary_activation: Periodic activation of feature layer.
635 |       stationary_init: Random initialization of the last layer.
636 | layer_norm_mlp: Use layer norm in the first MLP layer.
637 | """
638 | self.n_mixture = n_mixture
639 | super().__init__(
640 | num_dimensions=num_dimensions,
641 | layers=layers,
642 | feature_dimension=feature_dimension,
643 | prior_variance=prior_variance,
644 | activation=activation,
645 | stationary_activation=stationary_activation,
646 | stationary_init=stationary_init,
647 | layer_norm_mlp=layer_norm_mlp,
648 | )
649 |
650 | def __call__(
651 | self, inputs: jnp.ndarray, faithful_distributions: bool = False
652 | ) -> Union[tfd.Distribution, Tuple[tfd.Distribution, tfd.Distribution]]:
653 | inputs = self.mlp(inputs)
654 | features = self.features(inputs)
655 | if faithful_distributions:
656 | features_ = jax.lax.stop_gradient(features)
657 | else:
658 | features_ = features
659 |
660 | loc_weights = hk.get_parameter(
661 | 'loc_weights',
662 | [self.feature_dimension, self.n_mixture, self.output_dimension],
663 | init=hk.initializers.Orthogonal(), # For mixture diversity.
664 | )
665 |
666 | scale_cross_weights_sqrt = hk.get_parameter(
667 | 'scale_cross_weights_sqrt',
668 | [self.output_dimension, self.feature_dimension, self.feature_dimension],
669 | init=hk.initializers.Identity(gain=1.0),
670 | )
671 |
672 | # batch x n_mixture x d_out
673 | mean = jnp.einsum('bi,ijk->bjk', features, loc_weights)
674 | scale_cross_weights_sqrt = jnp.tril(scale_cross_weights_sqrt)
675 |     # Cholesky decomposition: A = LL^T where L is lower triangular.
676 |     # The variance is the diagonal of x @ A @ x.T = x @ L @ L.T @ x.T,
677 |     # so first compute x @ L per output dimension d.
678 | var_sqrt = jnp.einsum('dij,bi->bdj', scale_cross_weights_sqrt, features_)
679 | var = jnp.einsum('bdi,bdi->bd', var_sqrt, var_sqrt)
680 |
681 | stddev_ = self.prior_stddev * jnp.tanh(jnp.sqrt(sil_config.MIN_VAR + var))
682 | stddev = jnp.repeat(jnp.expand_dims(stddev_, 1), self.n_mixture, axis=1)
683 |
684 | assert mean.shape == stddev.shape, f'{mean.shape} != {stddev.shape}'
685 |
686 | log_mixture_weights = hk.get_parameter(
687 | 'log_mixture_weights',
688 | [self.n_mixture],
689 | init=hk.initializers.Constant(1.0),
690 | )
691 | mixture_weights = jax.nn.softmax(log_mixture_weights)
692 | mixture_distribution = tfd.Categorical(probs=mixture_weights)
693 |
694 | def make_mixture(location, scale, weights):
695 | distribution = tfd.Normal(loc=location, scale=scale)
696 |
697 | transformed_distribution = tfd.Independent(
698 | networks_lib.TanhTransformedDistribution(distribution),
699 | reinterpreted_batch_ndims=1,
700 | )
701 |
702 | return MixtureSameFamily(
703 | mixture_distribution=weights,
704 | components_distribution=transformed_distribution,
705 | )
706 |
707 | mixture = make_mixture(mean, stddev, mixture_distribution)
708 | if faithful_distributions:
709 | cut_mixture = make_mixture(
710 | jax.lax.stop_gradient(mean), stddev, mixture_distribution)
711 | return mixture, cut_mixture
712 | else:
713 | return mixture
714 |
715 |
716 | class StationaryMLP(StationaryFeatures):
717 | """MLP that behaves like the mean function of a stationary process."""
718 |
719 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
720 | inputs = self.mlp(inputs)
721 | features = self.features(inputs)
722 |
723 | loc_weights = hk.get_parameter(
724 | 'loc_weights',
725 | [self.feature_dimension, self.output_dimension],
726 | init=hk.initializers.Constant(0.0),
727 | )
728 | return features @ loc_weights
729 |
730 |
731 | class LayerNormMLP(hk.Module):
732 | """Simple feedforward MLP torso with initial layer-norm."""
733 |
734 | def __init__(
735 | self,
736 | layer_sizes: Sequence[int],
737 | w_init: hk.initializers.Initializer,
738 | activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu,
739 | activate_final: bool = False,
740 | name: str = 'feedforward_mlp_torso',
741 | ):
742 | """Construct the MLP."""
743 | super().__init__(name=name)
744 | assert len(layer_sizes) > 1
745 | self._network = hk.Sequential([
746 | hk.Linear(layer_sizes[0], w_init=w_init),
747 | hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
748 | hk.nets.MLP(
749 | layer_sizes[1:],
750 | w_init=w_init,
751 | activation=activation,
752 | activate_final=activate_final,
753 | ),
754 | ])
755 |
756 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
757 | """Forwards the policy network."""
758 | return self._network(inputs)
759 |
760 |
761 | def prior_policy_log_likelihood(
762 | env_spec: specs.EnvironmentSpec, policy_architecture: PolicyArchitectures
763 | ) -> Callable[[jnp.ndarray], jnp.ndarray]:
764 |   """Assumes a uniform prior policy over the [-1, 1] action space."""
765 | del policy_architecture
766 | act_spec = env_spec.actions
767 | num_actions = np.prod(act_spec.shape, dtype=int)
768 |
769 | prior_llh = lambda x: -num_actions * jnp.log(2.0)
770 | return jax.vmap(prior_llh)
771 |
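   | # Worked example (with a hypothetical 6-dimensional action spec): the
   | # uniform density over [-1, 1]**6 is 2**-6, so the prior log-likelihood is
   | # -6 * log(2) ~= -4.159 for every action, and the vmapped function just
   | # broadcasts that constant over the batch.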
772 |
773 | def make_networks(
774 | spec: specs.EnvironmentSpec,
775 | policy_architecture: PolicyArchitectures = PolicyArchitectures.MLP,
776 | critic_architecture: CriticArchitectures = CriticArchitectures.LNMLP,
777 | reward_architecture: RewardArchitectures = RewardArchitectures.LNMLP,
778 | policy_hidden_layer_sizes: Tuple[int, ...] = (256, 256),
779 | critic_hidden_layer_sizes: Tuple[int, ...] = (256, 256),
780 | reward_hidden_layer_sizes: Tuple[int, ...] = (256, 256),
781 | reward_policy_coherence_alpha: Optional[float] = None,
782 | bc_policy_architecture: Optional[PolicyArchitectures] = None,
783 | bc_policy_hidden_layer_sizes: Optional[Tuple[int, ...]] = None,
784 | layer_norm_policy: bool = False,
785 | ) -> SILNetworks:
786 | """Creates networks used by the agent."""
787 |
788 | num_actions = np.prod(spec.actions.shape, dtype=int)
789 |
790 | def _make_actor_fn(
791 | policy_arch: PolicyArchitectures, hidden_layer_size: Tuple[int, ...]
792 | ):
793 | assert len(hidden_layer_size) > 1
794 |
795 | def _actor_fn(obs, *args, train_encoder=False, **kwargs):
796 | if policy_arch == PolicyArchitectures.MLP:
797 | mlp = networks_lib.LayerNormMLP if layer_norm_policy else hk.nets.MLP
798 | network = Sequential([
799 | mlp(
800 | list(hidden_layer_size),
801 | w_init=hk.initializers.Orthogonal(),
802 | activation=jax.nn.elu,
803 | activate_final=True,
804 | ),
805 | ClampedScaleNormalTanhDistribution(num_actions),
806 | ])
807 | elif policy_arch == PolicyArchitectures.MIXMLP:
808 | mlp = networks_lib.LayerNormMLP if layer_norm_policy else hk.nets.MLP
809 | network = Sequential([
810 | mlp(
811 | list(hidden_layer_size),
812 | w_init=hk.initializers.Orthogonal(),
813 | activation=jax.nn.elu,
814 | activate_final=True,
815 | ),
816 | MixtureClampedScaleNormalTanhDistribution(num_actions, n_mixture=5),
817 | ])
818 | elif policy_arch == PolicyArchitectures.HETSTATSIN:
819 | network = StationaryHeteroskedasticNormalTanhDistribution(
820 | num_actions,
821 | hidden_layer_size[:-1],
822 | feature_dimension=hidden_layer_size[-1],
823 | prior_variance=0.75,
824 | stationary_activation=sin_cos_activation,
825 | layer_norm_mlp=layer_norm_policy,
826 | )
827 | elif policy_arch == PolicyArchitectures.HETSTATTRI:
828 | network = StationaryHeteroskedasticNormalTanhDistribution(
829 | num_actions,
830 | hidden_layer_size[:-1],
831 | feature_dimension=hidden_layer_size[-1],
832 | prior_variance=0.75,
833 | stationary_activation=triangle_activation,
834 | layer_norm_mlp=layer_norm_policy,
835 | )
836 | elif policy_arch == PolicyArchitectures.HETSTATPRELU:
837 | network = StationaryHeteroskedasticNormalTanhDistribution(
838 | num_actions,
839 | hidden_layer_size[:-1],
840 | feature_dimension=hidden_layer_size[-1],
841 | prior_variance=0.75,
842 | stationary_activation=periodic_relu_activation,
843 | layer_norm_mlp=layer_norm_policy,
844 | )
845 | elif policy_arch == PolicyArchitectures.MIXHETSTATSIN:
846 | network = MixtureStationaryHeteroskedasticNormalTanhDistribution(
847 | num_actions,
848 | n_mixture=5,
849 | layers=hidden_layer_size[:-1],
850 | feature_dimension=hidden_layer_size[-1],
851 | prior_variance=0.75,
852 | stationary_activation=sin_cos_activation,
853 | layer_norm_mlp=layer_norm_policy,
854 | )
855 | elif policy_arch == PolicyArchitectures.MIXHETSTATTRI:
856 | network = MixtureStationaryHeteroskedasticNormalTanhDistribution(
857 | num_actions,
858 | n_mixture=5,
859 | layers=hidden_layer_size[:-1],
860 | feature_dimension=hidden_layer_size[-1],
861 | prior_variance=0.75,
862 | stationary_activation=triangle_activation,
863 | layer_norm_mlp=layer_norm_policy,
864 | )
865 | elif policy_arch == PolicyArchitectures.MIXHETSTATPRELU:
866 | network = MixtureStationaryHeteroskedasticNormalTanhDistribution(
867 | num_actions,
868 | n_mixture=5,
869 | layers=hidden_layer_size[:-1],
870 | feature_dimension=hidden_layer_size[-1],
871 | prior_variance=0.75,
872 | stationary_activation=periodic_relu_activation,
873 | layer_norm_mlp=layer_norm_policy,
874 | )
875 | else:
876 | raise ValueError('Unknown policy architecture.')
877 |
878 | obs = observation_encoder(obs)
879 | if not train_encoder:
880 | obs = jax.lax.stop_gradient(obs)
881 | return network(obs, *args, **kwargs)
882 |
883 | return _actor_fn
884 |
885 | actor_fn = _make_actor_fn(policy_architecture, policy_hidden_layer_sizes)
886 |
887 | prior_policy_llh = prior_policy_log_likelihood(
888 | spec,
889 | policy_architecture=PolicyArchitectures.MLP,
890 | )
891 |
892 | def _critic_fn(obs, action, train_encoder=False):
893 | if critic_architecture == CriticArchitectures.DOUBLE_MLP: # Needed for SAC.
894 | network1 = hk.Sequential([
895 | hk.nets.MLP(
896 | list(critic_hidden_layer_sizes) + [1],
897 | w_init=hk.initializers.Orthogonal(),
898 | activation=jax.nn.elu),
899 | ])
900 | network2 = hk.Sequential([
901 | hk.nets.MLP(
902 | list(critic_hidden_layer_sizes) + [1],
903 | w_init=hk.initializers.Orthogonal(),
904 | activation=jax.nn.elu),
905 | ])
906 | obs = observation_encoder(obs)
907 | input_ = jnp.concatenate([obs, action], axis=-1)
908 | value1 = network1(input_)
909 | value2 = network2(input_)
910 | return jnp.concatenate([value1, value2], axis=-1)
911 | elif critic_architecture == CriticArchitectures.MLP:
912 | network = hk.nets.MLP(
913 | list(critic_hidden_layer_sizes) + [1],
914 | w_init=hk.initializers.Orthogonal(),
915 | activation=jax.nn.elu,
916 | )
917 | elif critic_architecture == CriticArchitectures.LNMLP:
918 | network = networks_lib.LayerNormMLP(
919 | list(critic_hidden_layer_sizes) + [1],
920 | w_init=hk.initializers.Orthogonal(),
921 | activation=jax.nn.elu,
922 | )
923 | elif critic_architecture == CriticArchitectures.DOUBLE_LNMLP:
924 | network1 = hk.Sequential([
925 | networks_lib.LayerNormMLP(
926 | list(critic_hidden_layer_sizes) + [1],
927 | w_init=hk.initializers.Orthogonal(),
928 | activation=jax.nn.elu),
929 | ])
930 | network2 = hk.Sequential([
931 | networks_lib.LayerNormMLP(
932 | list(critic_hidden_layer_sizes) + [1],
933 | w_init=hk.initializers.Orthogonal(),
934 | activation=jax.nn.elu),
935 | ])
936 |
937 | obs = observation_encoder(obs)
938 | if not train_encoder:
939 | obs = jax.lax.stop_gradient(obs)
940 | input_ = jnp.concatenate([obs, action], axis=-1)
941 | value1 = network1(input_)
942 | value2 = network2(input_)
943 | return jnp.concatenate([value1, value2], axis=-1)
944 | elif critic_architecture == CriticArchitectures.STATSIN:
945 | network = StationaryMLP(
946 | 1,
947 | critic_hidden_layer_sizes[:-1],
948 | feature_dimension=critic_hidden_layer_sizes[-1],
949 | stationary_activation=sin_cos_activation,
950 | layer_norm_mlp=False,
951 | )
952 | elif critic_architecture == CriticArchitectures.STATTRI:
953 | network = StationaryMLP(
954 | 1,
955 | critic_hidden_layer_sizes[:-1],
956 | feature_dimension=critic_hidden_layer_sizes[-1],
957 | stationary_activation=triangle_activation,
958 | layer_norm_mlp=False,
959 | )
960 | elif critic_architecture == CriticArchitectures.STATPRELU:
961 | network = StationaryMLP(
962 | 1,
963 | critic_hidden_layer_sizes[:-1],
964 | feature_dimension=critic_hidden_layer_sizes[-1],
965 | stationary_activation=periodic_relu_activation,
966 | layer_norm_mlp=False,
967 | )
968 | else:
969 | raise ValueError('Unknown critic architecture.')
970 |
971 | obs = observation_encoder(obs)
972 | input_ = jnp.concatenate([obs, action], axis=-1)
973 | return network(input_)
974 |
975 | reward_policy_coherence = (reward_architecture == RewardArchitectures.PCSIL or
976 | reward_architecture == RewardArchitectures.NCSIL)
977 | if reward_policy_coherence:
978 | assert reward_policy_coherence_alpha is not None
979 | assert bc_policy_architecture is not None
980 | assert bc_policy_hidden_layer_sizes is not None
981 | bc_actor_fn = _make_actor_fn(
982 | bc_policy_architecture, bc_policy_hidden_layer_sizes
983 | )
984 |
985 | def _reward_fn(obs, action, *args, **kwargs):
986 | obs = observation_encoder(obs)
987 | alpha = reward_policy_coherence_alpha
988 | log_ratio = (bc_actor_fn(obs, *args, **kwargs).log_prob(action) # pytype: disable=attribute-error
989 | - prior_policy_llh(action))
990 | if reward_architecture == RewardArchitectures.PCSIL:
991 | return alpha * log_ratio
992 | else: # reward_architecture == RewardArchitectures.NCSIL
993 | return alpha * (log_ratio - num_actions * sil_config.MAX_REWARD)
994 |
995 | else:
996 | bc_actor_fn = actor_fn
997 | bc_policy_architecture = policy_architecture
998 |
999 | def _reward_fn(obs, action, train_encoder=False):
1000 | if reward_architecture == RewardArchitectures.MLP:
1001 | network = hk.nets.MLP(
1002 | list(reward_hidden_layer_sizes) + [1],
1003 | w_init=hk.initializers.Orthogonal(),
1004 | )
1005 | elif reward_architecture == RewardArchitectures.LNMLP:
1006 | network = networks_lib.LayerNormMLP(
1007 | list(reward_hidden_layer_sizes) + [1],
1008 | w_init=hk.initializers.Orthogonal(),
1009 | )
1010 | elif reward_architecture == RewardArchitectures.PCONST:
1011 | network = jax.vmap(lambda sa: 1.0)
1012 | elif reward_architecture == RewardArchitectures.NCONST:
1013 | network = jax.vmap(lambda sa: -1.0)
1014 | else:
1015 | raise ValueError('Unknown reward architecture.')
1016 |
1017 | obs = observation_encoder(obs)
1018 | if not train_encoder:
1019 | obs = jax.lax.stop_gradient(obs)
1020 | input_ = jnp.concatenate([obs, action], axis=-1)
1021 | return network(input_)
1022 |
1023 | policy = hk.without_apply_rng(hk.transform(actor_fn))
1024 | bc_policy = hk.without_apply_rng(hk.transform(bc_actor_fn))
1025 | critic = hk.without_apply_rng(hk.transform(_critic_fn))
1026 | reward = hk.without_apply_rng(hk.transform(_reward_fn))
1027 |
1028 | # Create dummy observations and actions to create network parameters.
1029 | dummy_action = utils.zeros_like(spec.actions)
1030 | dummy_obs = utils.zeros_like(spec.observations)
1031 | dummy_action = utils.add_batch_dim(dummy_action)
1032 | dummy_obs = utils.add_batch_dim(dummy_obs)
1033 |
1034 | return SILNetworks(
1035 | policy_architecture=policy_architecture,
1036 | bc_policy_architecture=bc_policy_architecture,
1037 | policy_network=networks_lib.FeedForwardNetwork(
1038 | lambda key: policy.init(key, dummy_obs), policy.apply
1039 | ),
1040 | critic_network=networks_lib.FeedForwardNetwork(
1041 | lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply
1042 | ),
1043 | reward_network=networks_lib.FeedForwardNetwork(
1044 | lambda key: reward.init(key, dummy_obs, dummy_action), reward.apply
1045 | ),
1046 | log_prob=lambda params, actions: params.log_prob(actions),
1047 | log_prob_prior=prior_policy_llh,
1048 | sample=lambda params, key: params.sample(seed=key),
1049 | # Policy eval is distribution's 'mode' (i.e. deterministic).
1050 | sample_eval=lambda params, key: params.mode(),
1051 | environment_specs=spec,
1052 | reward_policy_coherence=reward_policy_coherence,
1053 | bc_policy_network=networks_lib.FeedForwardNetwork(
1054 | lambda key: bc_policy.init(key, dummy_obs), bc_policy.apply
1055 | ),
1056 | )
1057 |
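   | # Hedged construction sketch (the environment and alpha value below are
   | # illustrative, not taken from this repository):
   | #
   | #   from acme import specs
   | #   spec = specs.make_environment_spec(environment)
   | #   networks = make_networks(
   | #       spec,
   | #       policy_architecture=PolicyArchitectures.HETSTATSIN,
   | #       reward_architecture=RewardArchitectures.PCSIL,
   | #       reward_policy_coherence_alpha=0.01,
   | #       bc_policy_architecture=PolicyArchitectures.HETSTATSIN,
   | #       bc_policy_hidden_layer_sizes=(256, 256),
   | #   )
   | #
   | # For the coherent (PCSIL / NCSIL) rewards, the BC policy defines the reward
   | # through its log-density ratio and is also returned as bc_policy_network.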
--------------------------------------------------------------------------------