├── jaxrl2
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── wandb_config_example.py
│   │   ├── target_update.py
│   │   ├── launch_util.py
│   │   ├── general_utils.py
│   │   ├── visualization_utils.py
│   │   └── wandb_logger.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── pixel_sac
│   │   │   ├── __init__.py
│   │   │   ├── temperature.py
│   │   │   ├── temperature_updater.py
│   │   │   ├── critic_updater.py
│   │   │   ├── actor_updater.py
│   │   │   └── pixel_sac_learner.py
│   │   ├── agent.py
│   │   └── common.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── replay_buffer.py
│   │   └── augmentations.py
│   ├── networks
│   │   ├── values
│   │   │   ├── __init__.py
│   │   │   ├── state_action_ensemble.py
│   │   │   ├── state_value.py
│   │   │   └── state_action_value.py
│   │   ├── __init__.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── spatial_softmax.py
│   │   │   ├── networks.py
│   │   │   ├── resnet_encoderv2.py
│   │   │   ├── impala_encoder.py
│   │   │   ├── resnet_encoderv1.py
│   │   │   └── cross_norm.py
│   │   ├── constants.py
│   │   ├── normal_policy.py
│   │   ├── normal_tanh_policy.py
│   │   ├── learned_std_normal_policy.py
│   │   └── mlp.py
│   └── types.py
├── examples
│   ├── __init__.py
│   ├── scripts
│   │   ├── run_aloha.sh
│   │   ├── run_libero.sh
│   │   └── run_real.sh
│   ├── launch_train_sim.py
│   ├── launch_train_real.py
│   ├── train_real.py
│   ├── train_sim.py
│   ├── train_utils_real.py
│   └── train_utils_sim.py
├── setup.py
├── .gitmodules
├── requirements.txt
├── .gitignore
└── README.md
/jaxrl2/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/jaxrl2/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/jaxrl2/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxrl2.agents.pixel_sac import PixelSACLearner
2 |
--------------------------------------------------------------------------------
/jaxrl2/data/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxrl2.data.replay_buffer import ReplayBuffer
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name="jaxrl2", packages=["jaxrl2"])
4 |
--------------------------------------------------------------------------------
/jaxrl2/agents/pixel_sac/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxrl2.agents.pixel_sac.pixel_sac_learner import PixelSACLearner
2 |
--------------------------------------------------------------------------------
/jaxrl2/networks/values/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxrl2.networks.values.state_action_ensemble import StateActionEnsemble
2 | from jaxrl2.networks.values.state_value import StateValue
3 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "LIBERO"]
2 | path = LIBERO
3 | url = git@github.com:nakamotoo/LIBERO.git
4 | [submodule "openpi"]
5 | path = openpi
6 | url = git@github.com:nakamotoo/openpi.git
7 |
--------------------------------------------------------------------------------
/jaxrl2/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxrl2.networks.mlp import MLP
2 | from jaxrl2.networks.normal_policy import NormalPolicy
3 | from jaxrl2.networks.normal_tanh_policy import NormalTanhPolicy
4 |
--------------------------------------------------------------------------------
/jaxrl2/networks/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxrl2.networks.mlp import MLP
2 | from jaxrl2.networks.normal_policy import NormalPolicy
3 | from jaxrl2.networks.normal_tanh_policy import NormalTanhPolicy
4 |
--------------------------------------------------------------------------------
/jaxrl2/types.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Union
2 |
3 | import flax
4 | import numpy as np
5 |
6 | DataType = Union[np.ndarray, Dict[str, 'DataType']]
7 | PRNGKey = Any
8 | Params = flax.core.FrozenDict[str, Any]
--------------------------------------------------------------------------------
/jaxrl2/utils/wandb_config_example.py:
--------------------------------------------------------------------------------
1 | # copy this into wandb_config.py!
2 | def get_wandb_config():
3 | return dict (
4 | WANDB_API_KEY='your api key',
5 | WANDB_EMAIL='your email',
6 | WANDB_USERNAME='user',
7 | WANDB_TEAM='team_name_if_any'
8 | )
--------------------------------------------------------------------------------
/jaxrl2/networks/constants.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import jax.numpy as jnp
3 |
4 |
5 | def default_init(scale: float = jnp.sqrt(2)):
6 | return nn.initializers.orthogonal(scale)
7 |
8 | def xavier_init():
9 | return nn.initializers.xavier_normal()
10 |
11 | def kaiming_init():
12 | return nn.initializers.kaiming_normal()
--------------------------------------------------------------------------------
/jaxrl2/agents/pixel_sac/temperature.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import jax.numpy as jnp
3 |
4 |
5 | class Temperature(nn.Module):
6 | initial_temperature: float = 1.0
7 |
8 | @nn.compact
9 | def __call__(self) -> jnp.ndarray:
10 | log_temp = self.param('log_temp',
11 | init_fn=lambda key: jnp.full(
12 | (), jnp.log(self.initial_temperature)))
13 | return jnp.exp(log_temp)
--------------------------------------------------------------------------------
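A minimal usage sketch of the Temperature module above (illustrative, not repo code): the learner wraps it in a flax TrainState so the temperature updater below can call `apply_fn` and `apply_gradients`; the Adam learning rate mirrors the `temp_lr=3e-4` default from the launch scripts.
```
import jax
import optax
from flax.training.train_state import TrainState

from jaxrl2.agents.pixel_sac.temperature import Temperature

# Build the module and wrap it in a TrainState, as the temperature updater expects.
temp_def = Temperature(initial_temperature=1.0)
temp_params = temp_def.init(jax.random.PRNGKey(0))['params']
temp = TrainState.create(apply_fn=temp_def.apply, params=temp_params, tx=optax.adam(3e-4))

# Reading the current temperature: exp(log_temp), so it is always positive.
alpha = temp.apply_fn({'params': temp.params})
```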
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.33.1
2 | distrax==0.1.5
3 | dm-control==1.0.14
4 | easydict==1.13
5 | einops==0.8.1
6 | flax==0.10.2
7 | future==1.0.0
8 | gym==0.26.2
9 | gym-notices==0.0.8
10 | gymnasium==1.0.0
11 | h5py==3.13.0
12 | imageio==2.37.0
13 | jax==0.5.0
14 | ml_collections==1.0.0
15 | ml_dtypes==0.5.1
16 | optax==0.2.4
17 | optree==0.15.0
18 | orbax-checkpoint==0.11.1
19 | orderly-set==5.4.0
20 | packaging==24.2
21 | pandas==2.2.3
22 | pillow==10.4.0
23 | scipy==1.15.2
24 | tensorflow==2.19.0
25 | etils==1.12.2
26 | matplotlib==3.9.3
27 | wandb[media]==0.19.9
28 | opencv-python==4.11.0.86
29 | robosuite==1.4.1
30 | gym-aloha==0.1.1
31 | bddl==3.5.0
32 | tf_keras==2.19.0
33 | moviepy==1.0.3
34 |
--------------------------------------------------------------------------------
/jaxrl2/utils/target_update.py:
--------------------------------------------------------------------------------
1 | import jax
2 |
3 | from jaxrl2.types import Params
4 |
5 | from functools import partial
6 |
7 | def soft_target_update(critic_params: Params, target_critic_params: Params, tau: float) -> Params:
8 | new_target_params = jax.tree_util.tree_map(lambda p, tp: p * tau + tp * (1 - tau), critic_params, target_critic_params)
9 | return new_target_params
10 |
11 | @partial(jax.pmap, axis_name='pmap', static_broadcasted_argnums=(2))
12 | def soft_target_update_parallel(critic_params: Params, target_critic_params: Params, tau: float) -> Params:
13 | new_target_params = jax.tree_util.tree_map(lambda p, tp: p * tau + tp * (1 - tau), critic_params, target_critic_params)
14 | return new_target_params
15 |
--------------------------------------------------------------------------------
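A small illustrative sketch (not from the repo) of the Polyak update implemented by `soft_target_update`; `tau=0.005` matches the default in the launch configs.
```
import jax.numpy as jnp

from jaxrl2.utils.target_update import soft_target_update

online_params = {'dense': {'kernel': jnp.ones((2, 2))}}
target_params = {'dense': {'kernel': jnp.zeros((2, 2))}}

# Each leaf becomes tau * online + (1 - tau) * target.
new_target_params = soft_target_update(online_params, target_params, tau=0.005)
```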
/jaxrl2/utils/launch_util.py:
--------------------------------------------------------------------------------
1 | from jaxrl2.utils.general_utils import AttrDict
2 |
3 | def parse_training_args(train_args_dict, parser):
4 | for k, v in train_args_dict.items():
5 | if type(v) == tuple:
6 | parser.add_argument('--' + k, nargs="+", default=v, type=type(v[0]))
7 | elif type(v) != bool:
8 | parser.add_argument('--' + k, default=v, type=type(v))
9 | else:
10 | parser.add_argument('--' + k, default=int(v), type=int)
11 | args = parser.parse_args()
12 | config = {}
13 | for key in train_args_dict.keys():
14 | config[key] = getattr(args, key)
15 | variant = AttrDict(vars(args))
16 | variant['train_kwargs'] = config
17 | return variant, args
18 |
--------------------------------------------------------------------------------
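An illustrative sketch (not repo code) of how `parse_training_args` exposes a dict of defaults as CLI flags, as the launch scripts below do: tuples become `nargs="+"` arguments and bools are mapped to 0/1 ints.
```
import argparse

from jaxrl2.utils.launch_util import parse_training_args

parser = argparse.ArgumentParser()
train_args_dict = dict(actor_lr=1e-4, hidden_dims=(128, 128), use_bottleneck=True)

# With no CLI overrides the defaults are returned;
# variant.train_kwargs holds exactly the keys from train_args_dict.
variant, args = parse_training_args(train_args_dict, parser)
```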
/jaxrl2/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | class AttrDict(dict):
2 | __setattr__ = dict.__setitem__
3 |
4 | def __getattr__(self, attr):
5 | # Take care that getattr() raises AttributeError, not KeyError.
6 | # Required e.g. for hasattr(), deepcopy and OrderedDict.
7 | try:
8 | return self.__getitem__(attr)
9 | except KeyError:
10 | raise AttributeError("Attribute %r not found" % attr)
11 |
12 | def __getstate__(self): return self
13 | def __setstate__(self, d): self = d
14 |
15 | def add_batch_dim(input):
16 | if isinstance(input, dict):
17 | for k, v in input.items():
18 | input[k] = v[None]
19 | else:
20 | input = input[None]
21 | return input
22 |
--------------------------------------------------------------------------------
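A short usage sketch (illustrative only) of the two helpers above.
```
import numpy as np

from jaxrl2.utils.general_utils import AttrDict, add_batch_dim

variant = AttrDict({'batch_size': 256})
variant.discount = 0.999        # attribute writes go through dict.__setitem__
assert variant['discount'] == 0.999

obs = {'pixels': np.zeros((64, 64, 3)), 'state': np.zeros((8,))}
batched = add_batch_dim(obs)    # adds a leading batch axis to every entry
assert batched['pixels'].shape == (1, 64, 64, 3)
```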
/jaxrl2/agents/pixel_sac/temperature_updater.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 |
3 | import jax
4 | from flax.training.train_state import TrainState
5 |
6 |
7 | def update_temperature(
8 | temp: TrainState, entropy: float,
9 | target_entropy: float) -> Tuple[TrainState, Dict[str, float]]:
10 |
11 | def temperature_loss_fn(temp_params):
12 | temperature = temp.apply_fn({'params': temp_params})
13 | temp_loss = temperature * (entropy - target_entropy).mean()
14 | return temp_loss, {
15 | 'temperature': temperature,
16 | 'temperature_loss': temp_loss
17 | }
18 |
19 | grads, info = jax.grad(temperature_loss_fn, has_aux=True)(temp.params)
20 | new_temp = temp.apply_gradients(grads=grads)
21 |
22 | return new_temp, info
--------------------------------------------------------------------------------
/examples/scripts/run_aloha.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | proj_name=DSRL_pi0_Aloha
3 | device_id=0
4 |
5 | export DISPLAY=:0
6 | export MUJOCO_GL=egl
7 | export MUJOCO_EGL_DEVICE_ID=$device_id
8 |
9 | export OPENPI_DATA_HOME=./openpi
10 | export EXP=./logs/$proj_name;
11 | export CUDA_VISIBLE_DEVICES=$device_id
12 | export XLA_PYTHON_CLIENT_PREALLOCATE=false
13 |
14 |
15 | pip install mujoco==2.3.7
16 |
17 | python3 examples/launch_train_sim.py \
18 | --algorithm pixel_sac \
19 | --env aloha_cube \
20 | --prefix dsrl_pi0_aloha \
21 | --wandb_project ${proj_name} \
22 | --batch_size 256 \
23 | --discount 0.999 \
24 | --seed 0 \
25 | --max_steps 3000000 \
26 | --eval_interval 10000 \
27 | --log_interval 500 \
28 | --eval_episodes 10 \
29 | --multi_grad_step 20 \
30 | --start_online_updates 1000 \
31 | --resize_image 64 \
32 | --action_magnitude 2.0 \
33 | --query_freq 50 \
34 | --hidden_dims 128 \
35 | --target_entropy 0.0
--------------------------------------------------------------------------------
/examples/scripts/run_libero.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | proj_name=DSRL_pi0_Libero
3 | device_id=0
4 |
5 | export DISPLAY=:0
6 | export MUJOCO_GL=egl
7 | export PYOPENGL_PLATFORM=egl
8 | export MUJOCO_EGL_DEVICE_ID=$device_id
9 |
10 | export OPENPI_DATA_HOME=./openpi
11 | export EXP=./logs/$proj_name;
12 | export CUDA_VISIBLE_DEVICES=$device_id
13 | export XLA_PYTHON_CLIENT_PREALLOCATE=false
14 |
15 | pip install mujoco==3.3.1
16 |
17 | python3 examples/launch_train_sim.py \
18 | --algorithm pixel_sac \
19 | --env libero \
20 | --prefix dsrl_pi0_libero \
21 | --wandb_project ${proj_name} \
22 | --batch_size 256 \
23 | --discount 0.999 \
24 | --seed 0 \
25 | --max_steps 500000 \
26 | --eval_interval 10000 \
27 | --log_interval 500 \
28 | --eval_episodes 10 \
29 | --multi_grad_step 20 \
30 | --start_online_updates 500 \
31 | --resize_image 64 \
32 | --action_magnitude 1.0 \
33 | --query_freq 20 \
34 | --hidden_dims 128
--------------------------------------------------------------------------------
/examples/scripts/run_real.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | proj_name=DSRL_pi0_FrankaDroid
3 | device_id=0
4 |
5 | export EXP=./logs/$proj_name;
6 | export CUDA_VISIBLE_DEVICES=$device_id
7 | export XLA_PYTHON_CLIENT_PREALLOCATE=false
8 |
9 | # Fill in Franka Droid camera IDs
10 | export LEFT_CAMERA_ID=""
11 | export RIGHT_CAMERA_ID=""
12 | export WRIST_CAMERA_ID=""
13 |
14 | # Fill in pi0 remote host and port
15 | export remote_host=""
16 | export remote_port=""
17 |
18 |
19 | python3 examples/launch_train_real.py \
20 | --algorithm pixel_sac \
21 | --env franka_droid \
22 | --prefix dsrl_pi0_real \
23 | --wandb_project ${proj_name} \
24 | --batch_size 256 \
25 | --discount 0.99 \
26 | --seed 0 \
27 | --max_steps 500000 \
28 | --eval_interval 2000 \
29 | --log_interval 100 \
30 | --multi_grad_step 30 \
31 | --resize_image 128 \
32 | --action_magnitude 2.5 \
33 | --query_freq 10 \
34 | --hidden_dims 1024 \
35 | --num_qs 2
--------------------------------------------------------------------------------
/jaxrl2/networks/values/state_action_ensemble.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Sequence
2 |
3 | import flax.linen as nn
4 | import jax.numpy as jnp
5 |
6 | from jaxrl2.networks.values.state_action_value import StateActionValue
7 |
8 |
9 | class StateActionEnsemble(nn.Module):
10 | hidden_dims: Sequence[int]
11 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
12 | num_qs: int = 2
13 | use_action_sep: bool = False
14 |
15 | @nn.compact
16 | def __call__(self, states, actions, training: bool = False):
17 |
18 | # print ('Use action sep in state action ensemble: ', self.use_action_sep)
19 | VmapCritic = nn.vmap(StateActionValue,
20 | variable_axes={'params': 0},
21 | split_rngs={'params': True},
22 | in_axes=None,
23 | out_axes=0,
24 | axis_size=self.num_qs)
25 | qs = VmapCritic(self.hidden_dims,
26 | activations=self.activations,
27 | use_action_sep=self.use_action_sep)(states, actions,
28 | training)
29 | return qs
30 |
--------------------------------------------------------------------------------
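A usage sketch for the ensemble critic above; the batch size and state/action dimensions are illustrative assumptions, not the repo's actual values.
```
import jax
import jax.numpy as jnp

from jaxrl2.networks.values.state_action_ensemble import StateActionEnsemble

critic_def = StateActionEnsemble(hidden_dims=(128, 128), num_qs=10)
states = jnp.zeros((256, 50))    # e.g. a batch of latent observations
actions = jnp.zeros((256, 32))   # e.g. a batch of latent actions

params = critic_def.init(jax.random.PRNGKey(0), states, actions)['params']
qs = critic_def.apply({'params': params}, states, actions)  # shape (num_qs, batch) = (10, 256)
```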
/jaxrl2/networks/values/state_value.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Sequence
2 |
3 | import flax.linen as nn
4 | import jax.numpy as jnp
5 |
6 | from jaxrl2.networks.mlp import MLP
7 |
8 |
9 | class StateValue(nn.Module):
10 | hidden_dims: Sequence[int]
11 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
12 |
13 | @nn.compact
14 | def __call__(self,
15 | observations: jnp.ndarray,
16 | training: bool = False) -> jnp.ndarray:
17 | critic = MLP((*self.hidden_dims, 1),
18 | activations=self.activations)(observations,
19 | training=training)
20 | return jnp.squeeze(critic, -1)
21 |
22 |
23 | class StateValueEnsemble(nn.Module):
24 | hidden_dims: Sequence[int]
25 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
26 | num_vs: int = 2
27 |
28 | @nn.compact
29 | def __call__(self, observations, training: bool = False):
30 | VmapCritic = nn.vmap(StateValue,
31 | variable_axes={'params': 0},
32 | split_rngs={'params': True},
33 | in_axes=None,
34 | out_axes=0,
35 | axis_size=self.num_vs)
36 | qs = VmapCritic(self.hidden_dims,
37 | activations=self.activations)(observations,
38 | training)
39 | return qs
40 |
--------------------------------------------------------------------------------
/jaxrl2/networks/normal_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence
2 |
3 | import distrax
4 | import flax.linen as nn
5 | import jax.numpy as jnp
6 |
7 | from jaxrl2.networks import MLP
8 | from jaxrl2.networks.constants import default_init, xavier_init
9 |
10 |
11 | class NormalPolicy(nn.Module):
12 | hidden_dims: Sequence[int]
13 | action_dim: int
14 | dropout_rate: Optional[float] = None
15 | std: Optional[float] = 1.
16 | init_scale: Optional[float] = 1.
17 | output_scale: Optional[float] = 1.
18 | init_method: str = 'xavier'
19 |
20 | @nn.compact
21 | def __call__(self,
22 | observations: jnp.ndarray,
23 | training: bool = False) -> distrax.Distribution:
24 | outputs = MLP(self.hidden_dims,
25 | activate_final=True,
26 | dropout_rate=self.dropout_rate,
27 | init_scale = self.init_scale
28 | )(observations, training=training)
29 |
30 | if self.init_method == 'xavier':
31 | # print('fc layer {}x{}'.format(outputs.shape, self.action_dim))
32 | means = nn.Dense(self.action_dim, kernel_init=xavier_init())(outputs)
33 | else:
34 | means = nn.Dense(self.action_dim, kernel_init=default_init(self.init_scale))(outputs)
35 |
36 | means *= self.output_scale
37 |
38 | return distrax.MultivariateNormalDiag(loc=means,
39 | scale_diag=jnp.ones_like(means)*self.std)
40 |
--------------------------------------------------------------------------------
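A sampling sketch for NormalPolicy (hidden sizes, observation dimension, and `action_dim` below are illustrative): the policy returns a diagonal Gaussian with a fixed standard deviation.
```
import jax
import jax.numpy as jnp

from jaxrl2.networks import NormalPolicy

policy_def = NormalPolicy(hidden_dims=(128, 128), action_dim=32, std=1.0)
obs = jnp.zeros((4, 50))

params = policy_def.init(jax.random.PRNGKey(0), obs)['params']
dist = policy_def.apply({'params': params}, obs)      # MultivariateNormalDiag with fixed std
actions = dist.sample(seed=jax.random.PRNGKey(1))     # shape (4, 32)
log_probs = dist.log_prob(actions)                    # shape (4,)
```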
/jaxrl2/networks/values/state_action_value.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Sequence
2 |
3 | import flax.linen as nn
4 | import jax.numpy as jnp
5 |
6 | import jax
7 |
8 | from jaxrl2.networks.mlp import MLP
9 | from jaxrl2.networks.mlp import MLPActionSep
10 | from jaxrl2.networks.constants import default_init
11 |
12 | from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple,
13 | Union)
14 |
15 | PRNGKey = Any
16 | Shape = Tuple[int, ...]
17 | Dtype = Any
18 | Array = Any
19 | PrecisionLike = Union[None, str, jax.lax.Precision, Tuple[str, str],
20 | Tuple[jax.lax.Precision, jax.lax.Precision]]
21 |
22 | default_kernel_init = nn.initializers.lecun_normal()
23 |
24 | class StateActionValue(nn.Module):
25 | hidden_dims: Sequence[int]
26 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
27 | use_action_sep: bool = False
28 |
29 | @nn.compact
30 | def __call__(self,
31 | observations: jnp.ndarray,
32 | actions: jnp.ndarray,
33 | training: bool = False):
34 | inputs = {'states': observations, 'actions': actions}
35 | if self.use_action_sep:
36 | critic = MLPActionSep(
37 | (*self.hidden_dims, 1),
38 | activations=self.activations,
39 | use_layer_norm=True)(inputs, training=training)
40 | else:
41 | critic = MLP((*self.hidden_dims, 1),
42 | activations=self.activations,
43 | use_layer_norm=True)(inputs, training=training)
44 | return jnp.squeeze(critic, -1)
45 |
--------------------------------------------------------------------------------
/jaxrl2/networks/encoders/spatial_softmax.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Sequence, Union
2 |
3 | import flax.linen as nn
4 | import jax
5 | import jax.numpy as jnp
6 | import numpy as np
7 | from flax.core.frozen_dict import FrozenDict
8 |
9 | from jaxrl2.networks.constants import default_init, xavier_init, kaiming_init
10 |
11 | from functools import partial
12 | from typing import Any, Callable, Sequence, Tuple
13 | import distrax
14 | import wandb
15 |
16 | ModuleDef = Any
17 |
18 | class SpatialSoftmax(nn.Module):
19 | height: int
20 | width: int
21 | channel: int
22 | pos_x: jnp.ndarray
23 | pos_y: jnp.ndarray
24 | temperature: None
25 | log_heatmap: bool = False
26 |
27 | @nn.compact
28 | def __call__(self, feature):
29 | if self.temperature == -1:
30 | from jax.nn import initializers
31 | # print("Trainable temperature parameter")
32 | temperature = self.param('softmax_temperature', initializers.ones, (1), jnp.float32)
33 | else:
34 | temperature = 1.
35 |
36 | # print(temperature)
37 | assert len(feature.shape) == 4
38 | batch_size, num_featuremaps = feature.shape[0], feature.shape[3]
39 | feature = feature.transpose(0, 3, 1, 2).reshape(batch_size, num_featuremaps, self.height * self.width)
40 |
41 | softmax_attention = nn.softmax(feature / temperature)
42 | expected_x = jnp.sum(self.pos_x * softmax_attention, axis=2, keepdims=True).reshape(batch_size, num_featuremaps)
43 | expected_y = jnp.sum(self.pos_y * softmax_attention, axis=2, keepdims=True).reshape(batch_size, num_featuremaps)
44 | expected_xy = jnp.concatenate([expected_x, expected_y], axis=1)
45 |
46 | expected_xy = jnp.reshape(expected_xy, [batch_size, 2*num_featuremaps])
47 | return expected_xy
48 |
49 |
--------------------------------------------------------------------------------
/jaxrl2/networks/encoders/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Sequence, Union
2 |
3 | import flax.linen as nn
4 | import jax
5 | import jax.numpy as jnp
6 | from flax.core.frozen_dict import FrozenDict
7 |
8 | from jaxrl2.networks.constants import default_init, xavier_init, kaiming_init
9 |
10 | from functools import partial
11 | from typing import Any, Callable, Sequence, Tuple
12 | import distrax
13 |
14 | ModuleDef = Any
15 |
16 | class Encoder(nn.Module):
17 | features: Sequence[int] = (32, 32, 32, 32)
18 | strides: Sequence[int] = (2, 1, 1, 1)
19 | padding: str = 'VALID'
20 |
21 | @nn.compact
22 | def __call__(self, observations: jnp.ndarray, training=False) -> jnp.ndarray:
23 | assert len(self.features) == len(self.strides)
24 |
25 | x = observations.astype(jnp.float32) / 255.0
26 | x = jnp.reshape(x, (*x.shape[:-2], -1))
27 |
28 | for features, stride in zip(self.features, self.strides):
29 | x = nn.Conv(features,
30 | kernel_size=(3, 3),
31 | strides=(stride, stride),
32 | kernel_init=default_init(),
33 | padding=self.padding)(x)
34 | x = nn.relu(x)
35 |
36 | return x.reshape((*x.shape[:-3], -1))
37 |
38 |
39 | class PixelMultiplexer(nn.Module):
40 | encoder: Union[nn.Module, list]
41 | network: nn.Module
42 | latent_dim: int
43 | use_bottleneck: bool=True
44 |
45 | @nn.compact
46 | def __call__(self,
47 | observations: Union[FrozenDict, Dict],
48 | actions: Optional[jnp.ndarray] = None,
49 | training: bool = False):
50 | observations = FrozenDict(observations)
51 |
52 | x = self.encoder(observations['pixels'], training)
53 | if self.use_bottleneck:
54 | x = nn.Dense(self.latent_dim, kernel_init=xavier_init())(x)
55 | x = nn.LayerNorm()(x)
56 | x = nn.tanh(x)
57 |
58 | x = observations.copy(add_or_replace={'pixels': x})
59 |
60 | # print('fully connected keys', x.keys())
61 | if actions is None:
62 | return self.network(x, training=training)
63 | else:
64 | return self.network(x, actions, training=training)
65 |
--------------------------------------------------------------------------------
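A shape sketch for the small conv Encoder above (all shapes are illustrative): note that it expects a trailing frame-stack axis, which is folded into the channel dimension before the convolutions.
```
import jax
import jax.numpy as jnp

from jaxrl2.networks.encoders.networks import Encoder

encoder_def = Encoder()                                   # features (32, 32, 32, 32), strides (2, 1, 1, 1)
pixels = jnp.zeros((8, 64, 64, 3, 1), dtype=jnp.uint8)    # (batch, H, W, C, num_stack)

params = encoder_def.init(jax.random.PRNGKey(0), pixels)['params']
features = encoder_def.apply({'params': params}, pixels)  # flattened conv features, (8, feature_dim)
```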
/jaxrl2/agents/agent.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from flax.training import checkpoints
4 | import pathlib
5 | from flax.training.train_state import TrainState
6 |
7 | from jaxrl2.agents.common import (eval_actions_jit, eval_log_prob_jit, eval_mse_jit, eval_reward_function_jit,
8 | sample_actions_jit)
9 | from jaxrl2.data.dataset import DatasetDict
10 | from jaxrl2.types import PRNGKey
11 |
12 |
13 | def get_batch_stats(actor):
14 | if hasattr(actor, 'batch_stats'):
15 | return actor.batch_stats
16 | else:
17 | return None
18 |
19 | class Agent(object):
20 | _actor: TrainState
21 | _critic: TrainState
22 | _rng: PRNGKey
23 |
24 | def eval_actions(self, observations: np.ndarray) -> np.ndarray:
25 | actions = eval_actions_jit(self._actor.apply_fn, self._actor.params,
26 | observations, get_batch_stats(self._actor))
27 | return np.asarray(actions)
28 |
29 | def eval_log_probs(self, batch: DatasetDict) -> float:
30 | return eval_log_prob_jit(self._actor.apply_fn, self._actor.params, get_batch_stats(self._actor),
31 | batch)
32 |
33 | def eval_mse(self, batch: DatasetDict) -> float:
34 | return eval_mse_jit(self._actor.apply_fn, self._actor.params, get_batch_stats(self._actor),
35 | batch)
36 |
37 | def eval_reward_function(self, batch: DatasetDict) -> float:
38 | return eval_reward_function_jit(self._actor.apply_fn, self._actor.params, self._actor.batch_stats,
39 | batch)
40 |
41 | def sample_actions(self, observations: np.ndarray) -> np.ndarray:
42 | rng, actions = sample_actions_jit(self._rng, self._actor.apply_fn,
43 | self._actor.params, observations, get_batch_stats(self._actor))
44 |
45 | self._rng = rng
46 | return np.asarray(actions)
47 |
48 | @property
49 | def _save_dict(self):
50 | return None
51 |
52 | def save_checkpoint(self, dir, step, keep_every_n_steps):
53 | checkpoints.save_checkpoint(dir, self._save_dict, step, prefix='checkpoint', overwrite=False, keep_every_n_steps=keep_every_n_steps)
54 |
55 | def restore_checkpoint(self, dir):
56 | raise NotImplementedError
57 |
58 |
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | *.manifest
30 | *.spec
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .nox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | *.py,cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | db.sqlite3-journal
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # IPython
77 | profile_default/
78 | ipython_config.py
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # pipenv
84 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
85 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
86 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
87 | # install all needed dependencies.
88 | #Pipfile.lock
89 |
90 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
91 | __pypackages__/
92 |
93 | # Celery stuff
94 | celerybeat-schedule
95 | celerybeat.pid
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
127 | .idea/
128 | tmp/
129 | wandb_config.py
130 | logs/
131 | wandb/
--------------------------------------------------------------------------------
/jaxrl2/agents/pixel_sac/critic_updater.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from flax.training.train_state import TrainState
6 |
7 | from jaxrl2.data.dataset import DatasetDict
8 | from jaxrl2.types import Params, PRNGKey
9 |
10 |
11 | def update_critic(
12 | key: PRNGKey, actor: TrainState, critic: TrainState,
13 | target_critic: TrainState, temp: TrainState, batch: DatasetDict,
14 | discount: float, backup_entropy: bool = False,
15 | critic_reduction: str = 'min') -> Tuple[TrainState, Dict[str, float]]:
16 | dist = actor.apply_fn({'params': actor.params}, batch['next_observations'])
17 | next_actions, next_log_probs = dist.sample_and_log_prob(seed=key)
18 | next_qs = target_critic.apply_fn({'params': target_critic.params},
19 | batch['next_observations'], next_actions)
20 | if critic_reduction == 'min':
21 | next_q = next_qs.min(axis=0)
22 | elif critic_reduction == 'mean':
23 | next_q = next_qs.mean(axis=0)
24 | else:
25 | raise NotImplementedError()
26 |
27 | target_q = batch['rewards'] + batch["discount"] * batch['masks'] * next_q
28 |
29 | if backup_entropy:
30 | target_q -= batch["discount"] * batch['masks'] * temp.apply_fn(
31 | {'params': temp.params}) * next_log_probs
32 |
33 | def critic_loss_fn(
34 | critic_params: Params) -> Tuple[jnp.ndarray, Dict[str, float]]:
35 | qs = critic.apply_fn({'params': critic_params}, batch['observations'],
36 | batch['actions'])
37 | critic_loss = ((qs - target_q)**2).mean()
38 | return critic_loss, {
39 | 'critic_loss': critic_loss,
40 | 'q': qs.mean(),
41 | 'target_actor_entropy': -next_log_probs.mean(),
42 | 'next_actions_sampled': next_actions.mean(),
43 | 'next_log_probs': next_log_probs.mean(),
44 | 'next_q_pi': next_qs.mean(),
45 | 'target_q': target_q.mean(),
46 | 'next_actions_mean': next_actions.mean(),
47 | 'next_actions_std': next_actions.std(),
48 | 'next_actions_min': next_actions.min(),
49 | 'next_actions_max': next_actions.max(),
52 | }
53 |
54 | grads, info = jax.grad(critic_loss_fn, has_aux=True)(critic.params)
55 | new_critic = critic.apply_gradients(grads=grads)
56 |
57 | return new_critic, info
58 |
--------------------------------------------------------------------------------
/examples/launch_train_sim.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from examples.train_sim import main
4 | from jaxrl2.utils.launch_util import parse_training_args
5 |
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser()
9 |
10 | parser.add_argument('--seed', default=42, help='Random seed.', type=int)
11 | parser.add_argument('--launch_group_id', default='', help='group id used to group runs on wandb.')
12 | parser.add_argument('--eval_episodes', default=10,help='Number of episodes used for evaluation.', type=int)
13 | parser.add_argument('--env', default='libero', help='name of environment')
14 | parser.add_argument('--log_interval', default=1000, help='Logging interval.', type=int)
15 | parser.add_argument('--eval_interval', default=5000, help='Eval interval.', type=int)
16 | parser.add_argument('--checkpoint_interval', default=-1, help='checkpoint interval.', type=int)
17 | parser.add_argument('--batch_size', default=16, help='Mini batch size.', type=int)
18 | parser.add_argument('--max_steps', default=int(1e6), help='Number of training steps.', type=int)
19 | parser.add_argument('--add_states', default=1, help='whether to add low-dim states to the observations', type=int)
20 | parser.add_argument('--wandb_project', default='cql_sim_online', help='wandb project')
21 | parser.add_argument('--start_online_updates', default=1000, help='number of steps to collect before starting online updates', type=int)
22 | parser.add_argument('--algorithm', default='pixel_sac', help='type of algorithm')
23 | parser.add_argument('--prefix', default='', help='prefix to use for wandb')
24 | parser.add_argument('--suffix', default='', help='suffix to use for wandb')
25 | parser.add_argument('--multi_grad_step', default=1, help='Number of gradient steps to take per environment step, aka UTD', type=int)
26 | parser.add_argument('--resize_image', default=-1, help='the size of image if need resizing', type=int)
27 | parser.add_argument('--query_freq', default=-1, help='query frequency', type=int)
28 |
29 | train_args_dict = dict(
30 | actor_lr=1e-4,
31 | critic_lr= 3e-4,
32 | temp_lr=3e-4,
33 | hidden_dims= (128, 128, 128),
34 | cnn_features= (32, 32, 32, 32),
35 | cnn_strides= (2, 1, 1, 1),
36 | cnn_padding= 'VALID',
37 | latent_dim= 50,
38 | discount= 0.999,
39 | tau= 0.005,
40 | critic_reduction = 'mean',
41 | dropout_rate=0.0,
42 | aug_next=1,
43 | use_bottleneck=True,
44 | encoder_type='small',
45 | encoder_norm='group',
46 | use_spatial_softmax=True,
47 | softmax_temperature=-1,
48 | target_entropy='auto',
49 | num_qs=10,
50 | action_magnitude=1.0,
51 | num_cameras=1,
52 | )
53 |
54 | variant, args = parse_training_args(train_args_dict, parser)
55 | print(variant)
56 | main(variant)
57 | sys.exit()
58 |
--------------------------------------------------------------------------------
/jaxrl2/networks/encoders/resnet_encoderv2.py:
--------------------------------------------------------------------------------
1 | # Based on:
2 | # https://github.com/google/flax/blob/main/examples/imagenet/models.py
3 | # and
4 | # https://github.com/google-research/big_transfer/blob/master/bit_jax/models.py
5 | from functools import partial
6 | from typing import Any, Callable, Sequence, Tuple
7 |
8 | import flax.linen as nn
9 | import jax.numpy as jnp
10 | from flax import linen as nn
11 |
12 | ModuleDef = Any
13 |
14 |
15 | class ResNetV2Block(nn.Module):
16 | """ResNet block."""
17 | filters: int
18 | conv: ModuleDef
19 | norm: ModuleDef
20 | act: Callable
21 | strides: Tuple[int, int] = (1, 1)
22 |
23 | @nn.compact
24 | def __call__(self, x):
25 | residual = x
26 | y = self.norm()(x)
27 | y = self.act(y)
28 | y = self.conv(self.filters, (3, 3), self.strides)(y)
29 | y = self.norm()(y)
30 | y = self.act(y)
31 | y = self.conv(self.filters, (3, 3))(y)
32 |
33 | if residual.shape != y.shape:
34 | residual = self.conv(self.filters, (1, 1), self.strides)(residual)
35 |
36 | return residual + y
37 |
38 |
39 | class MyGroupNorm(nn.GroupNorm):
40 |
41 | def __call__(self, x):
42 | if x.ndim == 3:
43 | x = x[jnp.newaxis]
44 | x = super().__call__(x)
45 | return x[0]
46 | else:
47 | return super().__call__(x)
48 |
49 |
50 | class ResNetV2Encoder(nn.Module):
51 | """ResNetV2."""
52 | stage_sizes: Sequence[int]
53 | num_filters: int = 16
54 | dtype: Any = jnp.float32
55 | act: Callable = nn.relu
56 | norm: str = 'batch'
57 |
58 | @nn.compact
59 | def __call__(self, x, train: bool = True):
60 | conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
61 | if self.norm == 'batch':
62 | norm = partial(nn.BatchNorm,
63 | use_running_average=not train,
64 | momentum=0.9,
65 | epsilon=1e-5,
66 | dtype=self.dtype)
67 | elif self.norm == 'groupnorm':
68 | norm = partial(MyGroupNorm,
69 | num_groups=4,
70 | epsilon=1e-5,
71 | dtype=self.dtype)
72 | else:
73 | raise ValueError('norm not found')
74 |
75 | x = x.astype(jnp.float32) / 255.0
76 | x = jnp.reshape(x, (*x.shape[:-2], -1))
77 |
78 | x = conv(self.num_filters, (3, 3))(x)
79 | for i, block_size in enumerate(self.stage_sizes):
80 | for j in range(block_size):
81 | strides = (2, 2) if i > 0 and j == 0 else (1, 1)
82 | x = ResNetV2Block(self.num_filters * 2**i,
83 | strides=strides,
84 | conv=conv,
85 | norm=norm,
86 | act=self.act)(x)
87 |
88 | x = norm()(x)
89 | x = self.act(x)
90 | return x.reshape((*x.shape[:-3], -1))
91 |
--------------------------------------------------------------------------------
/examples/launch_train_real.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from examples.train_real import main
4 | from jaxrl2.utils.launch_util import parse_training_args
5 |
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser()
9 |
10 | parser.add_argument('--seed', default=42, help='Random seed.', type=int)
11 | parser.add_argument('--launch_group_id', default='', help='group id used to group runs on wandb.')
12 | parser.add_argument('--eval_episodes', default=10,help='Number of episodes used for evaluation.', type=int)
13 | parser.add_argument('--env', default='libero', help='name of environment')
14 | parser.add_argument('--log_interval', default=1000, help='Logging interval.', type=int)
15 | parser.add_argument('--eval_interval', default=5000, help='Eval interval.', type=int)
16 | parser.add_argument('--checkpoint_interval', default=-1, help='checkpoint interval.', type=int)
17 | parser.add_argument('--batch_size', default=16, help='Mini batch size.', type=int)
18 | parser.add_argument('--max_steps', default=int(1e6), help='Number of training steps.', type=int)
19 | parser.add_argument('--add_states', default=1, help='whether to add low-dim states to the observations', type=int)
20 | parser.add_argument('--wandb_project', default='cql_sim_online', help='wandb project')
21 | parser.add_argument('--num_initial_traj_collect', default=1, help='number of trajectories to collect before starting online updates', type=int)
22 | parser.add_argument('--algorithm', default='pixel_sac', help='type of algorithm')
23 | parser.add_argument('--prefix', default='', help='prefix to use for wandb')
24 | parser.add_argument('--suffix', default='', help='suffix to use for wandb')
25 | parser.add_argument('--multi_grad_step', default=1, help='Number of gradient steps to take per environment step, aka UTD', type=int)
26 | parser.add_argument('--resize_image', default=-1, help='the size of image if need resizing', type=int)
27 | parser.add_argument('--query_freq', default=-1, help='query frequency', type=int)
28 | parser.add_argument('--instruction', default='put the spoon on the plate', help='language instruction for the robot')
29 |
30 | # The hyperparameters for the real robot experiments
31 | train_args_dict = dict(
32 | actor_lr=1e-4,
33 | critic_lr= 3e-4,
34 | temp_lr=3e-4,
35 | hidden_dims= (1024, 1024, 1024),
36 | cnn_features= (32, 32, 32, 32),
37 | cnn_strides= (3, 2, 2, 2),
38 | cnn_padding= 'VALID',
39 | latent_dim= 50,
40 | discount= 0.99,
41 | tau= 0.005,
42 | critic_reduction = 'min',
43 | dropout_rate=0.0,
44 | aug_next=1,
45 | use_bottleneck=True,
46 | encoder_type='small',
47 | encoder_norm='group',
48 | use_spatial_softmax=True,
49 | softmax_temperature=-1,
50 | target_entropy=0.0,
51 | num_qs=2,
52 | action_magnitude=2.5,
53 | num_cameras=3,
54 | )
55 |
56 | variant, args = parse_training_args(train_args_dict, parser)
57 | print(variant)
58 | main(variant)
59 | sys.exit()
60 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # DSRL for π₀: Diffusion Steering via Reinforcement Learning
4 |
5 | ## [[website](https://diffusion-steering.github.io)] [[paper](https://arxiv.org/abs/2506.15799)]
6 |
7 |
8 |
9 |
10 | ## Overview
11 | This repository provides the official implementation for our paper: [Steering Your Diffusion Policy with Latent Space Reinforcement Learning](https://arxiv.org/abs/2506.15799) (CoRL 2025).
12 |
13 | Specifically, it contains a JAX-based implementation of DSRL (Diffusion Steering via Reinforcement Learning) for steering a pre-trained generalist policy, [π₀](https://github.com/Physical-Intelligence/openpi), across various environments, including:
14 |
15 | - **Simulation:** Libero, Aloha
16 | - **Real Robot:** Franka
17 |
18 | If you find this repository useful for your research, please cite:
19 |
20 | ```
21 | @article{wagenmaker2025steering,
22 | author = {Andrew Wagenmaker and Mitsuhiko Nakamoto and Yunchu Zhang and Seohong Park and Waleed Yagoub and Anusha Nagabandi and Abhishek Gupta and Sergey Levine},
23 | title = {Steering Your Diffusion Policy with Latent Space Reinforcement Learning},
24 | journal = {Conference on Robot Learning (CoRL)},
25 | year = {2025},
26 | }
27 | ```
28 |
29 | ## Installation
30 | 1. Create a conda environment:
31 | ```
32 | conda create -n dsrl_pi0 python=3.11.11
33 | conda activate dsrl_pi0
34 | ```
35 |
36 | 2. Clone this repo with all submodules
37 | ```
38 | git clone git@github.com:nakamotoo/dsrl_pi0.git --recurse-submodules
39 | cd dsrl_pi0
40 | ```
41 |
42 | 3. Install all packages and dependencies
43 | ```
44 | pip install -e .
45 | pip install -r requirements.txt
46 | pip install "jax[cuda12]==0.5.0"
47 |
48 | # install openpi
49 | pip install -e openpi
50 | pip install -e openpi/packages/openpi-client
51 |
52 | # install Libero
53 | pip install -e LIBERO
54 | pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu # needed for libero
55 | ```
56 |
57 | ## Training (Simulation)
58 | Libero
59 | ```
60 | bash examples/scripts/run_libero.sh
61 | ```
62 | Aloha
63 | ```
64 | bash examples/scripts/run_aloha.sh
65 | ```
66 | ### Training Logs
67 | We provide sample W&B runs and logs: https://wandb.ai/mitsuhiko/DSRL_pi0_public
68 |
69 | ## Training (Real)
70 | For real-world experiments, we use pi0's remote hosting feature (see [here](https://github.com/Physical-Intelligence/openpi/blob/main/docs/remote_inference.md)), which lets us host the pi0 model on a higher-spec remote server in case the robot's client machine is not powerful enough.
71 |
72 | 0. Setup Franka robot and install DROID package [[link](https://github.com/droid-dataset/droid.git)]
73 |
74 | 1. [On the remote server] Host pi0 droid model on your remote server
75 | ```
76 | cd openpi && python scripts/serve_policy.py --env=DROID
77 | ```
78 | 2. [On your robot client machine] Run DSRL
79 | ```
80 | bash examples/scripts/run_real.sh
81 | ```
82 |
83 |
84 | ## Credits
85 | This repository is built upon the [jaxrl2](https://github.com/ikostrikov/jaxrl2) and [PTR](https://github.com/Asap7772/PTR) repositories.
86 | For any questions, bugs, suggestions, or improvements, please feel free to contact me at nakamoto\[at\]berkeley\[dot\]edu.
87 |
--------------------------------------------------------------------------------
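To make the "Training (Real)" setup in the README above concrete, here is a minimal client-side sketch based on openpi's remote-inference documentation. The class and method names reflect my reading of the openpi-client package and should be verified against your openpi version; the host, port, and observation contents are placeholders to adapt to your robot setup.
```
from openpi_client import websocket_client_policy

# Assumed API from openpi's remote_inference docs (verify against your openpi version).
policy = websocket_client_policy.WebsocketClientPolicy(host="your-remote-host", port=8000)

# Fill with camera images, robot state, and the language prompt in the schema
# expected by the DROID policy server.
observation = {}

result = policy.infer(observation)
action_chunk = result["actions"]
```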
/jaxrl2/networks/normal_tanh_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence
2 |
3 | import distrax
4 | import flax.linen as nn
5 | import jax.numpy as jnp
6 | from tensorflow_probability.substrates import jax as tfp
7 |
8 | from jaxrl2.networks import MLP
9 | from jaxrl2.networks.constants import default_init, xavier_init
10 |
11 |
12 | class TanhMultivariateNormalDiag(distrax.Transformed):
13 |
14 | def __init__(self,
15 | loc: jnp.ndarray,
16 | scale_diag: jnp.ndarray,
17 | low: Optional[jnp.ndarray] = None,
18 | high: Optional[jnp.ndarray] = None):
19 | distribution = distrax.MultivariateNormalDiag(loc=loc,
20 | scale_diag=scale_diag)
21 |
22 | layers = []
23 |
24 | if not (low is None or high is None):
25 |
26 | def rescale_from_tanh(x):
27 | x = (x + 1) / 2 # (-1, 1) => (0, 1)
28 | return x * (high - low) + low
29 |
30 | def forward_log_det_jacobian(x):
31 | high_ = jnp.broadcast_to(high, x.shape)
32 | low_ = jnp.broadcast_to(low, x.shape)
33 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1)
34 |
35 | layers.append(
36 | distrax.Lambda(
37 | rescale_from_tanh,
38 | forward_log_det_jacobian=forward_log_det_jacobian,
39 | event_ndims_in=1,
40 | event_ndims_out=1))
41 |
42 | layers.append(distrax.Block(distrax.Tanh(), 1))
43 |
44 | bijector = distrax.Chain(layers)
45 |
46 | super().__init__(distribution=distribution, bijector=bijector)
47 |
48 | def mode(self) -> jnp.ndarray:
49 | return self.bijector.forward(self.distribution.mode())
50 |
51 |
52 | class NormalTanhPolicy(nn.Module):
53 | hidden_dims: Sequence[int]
54 | action_dim: int
55 | dropout_rate: Optional[float] = None
56 | log_std_min: Optional[float] = -20
57 | log_std_max: Optional[float] = 2
58 | low: Optional[jnp.ndarray] = None
59 | high: Optional[jnp.ndarray] = None
60 | mlp_init_scale: float = 1.0
61 | init_method: str = 'default'
62 |
63 | @nn.compact
64 | def __call__(self,
65 | observations: jnp.ndarray,
66 | training: bool = False) -> distrax.Distribution:
67 | outputs = MLP(self.hidden_dims,
68 | activate_final=True,
69 | dropout_rate=self.dropout_rate,
70 | init_scale=self.mlp_init_scale)(observations,
71 | training=training)
72 |
73 | if self.init_method == 'xavier':
74 | means = nn.Dense(self.action_dim, kernel_init=xavier_init())(outputs)
75 | log_stds = nn.Dense(self.action_dim, kernel_init=xavier_init())(outputs)
76 | else:
77 | means = nn.Dense(self.action_dim, kernel_init=default_init(self.mlp_init_scale))(outputs)
78 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
79 |
80 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)
81 |
82 | return TanhMultivariateNormalDiag(loc=means,
83 | scale_diag=jnp.exp(log_stds) * self.mlp_init_scale,
84 | low=self.low,
85 | high=self.high)
86 |
--------------------------------------------------------------------------------
/jaxrl2/networks/encoders/impala_encoder.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import jax.numpy as jnp
3 |
4 |
5 | class ResnetStack(nn.Module):
6 | num_ch: int
7 | num_blocks: int
8 | use_max_pooling: bool = True
9 |
10 | @nn.compact
11 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
12 | initializer = nn.initializers.xavier_uniform()
13 | conv_out = nn.Conv(
14 | features=self.num_ch,
15 | kernel_size=(3, 3),
16 | strides=1,
17 | kernel_init=initializer,
18 | padding='SAME'
19 | )(observations)
20 |
21 | if self.use_max_pooling:
22 | conv_out = nn.max_pool(
23 | conv_out,
24 | window_shape=(3, 3),
25 | padding='SAME',
26 | strides=(2, 2)
27 | )
28 |
29 | for _ in range(self.num_blocks):
30 | block_input = conv_out
31 | conv_out = nn.relu(conv_out)
32 | conv_out = nn.Conv(
33 | features=self.num_ch, kernel_size=(3, 3), strides=1,
34 | padding='SAME',
35 | kernel_init=initializer)(conv_out)
36 |
37 | conv_out = nn.relu(conv_out)
38 | conv_out = nn.Conv(
39 | features=self.num_ch, kernel_size=(3, 3), strides=1,
40 | padding='SAME', kernel_init=initializer
41 | )(conv_out)
42 | conv_out += block_input
43 |
44 | return conv_out
45 |
46 |
47 | class ImpalaEncoder(nn.Module):
48 | nn_scale: int = 1
49 |
50 | def setup(self):
51 | stack_sizes = [16, 32, 32]
52 | self.stack_blocks = [
53 | ResnetStack(
54 | num_ch=stack_sizes[0] * self.nn_scale,
55 | num_blocks=2),
56 | ResnetStack(
57 | num_ch=stack_sizes[1] * self.nn_scale,
58 | num_blocks=2),
59 | ResnetStack(
60 | num_ch=stack_sizes[2] * self.nn_scale,
61 | num_blocks=2),
62 | ]
63 |
64 | @nn.compact
65 | def __call__(self, x, train=True):
66 | x = x.astype(jnp.float32) / 255.0
67 | x = jnp.reshape(x, (*x.shape[:-2], -1))
68 |
69 | conv_out = x
70 |
71 | for idx in range(len(self.stack_blocks)):
72 | conv_out = self.stack_blocks[idx](conv_out)
73 |
74 | conv_out = nn.relu(conv_out)
75 | return conv_out.reshape((*x.shape[:-3], -1))
76 |
77 |
78 | class SmallerImpalaEncoder(nn.Module):
79 | nn_scale: int = 1
80 |
81 | def setup(self):
82 | stack_sizes = [16, 32, 32]
83 | self.stack_blocks = [
84 | ResnetStack(
85 | num_ch=stack_sizes[0] * self.nn_scale,
86 | num_blocks=2),
87 | ResnetStack(
88 | num_ch=stack_sizes[1] * self.nn_scale,
89 | num_blocks=1),
90 | ResnetStack(
91 | num_ch=stack_sizes[2] * self.nn_scale,
92 | num_blocks=1),
93 | ]
94 |
95 | @nn.compact
96 | def __call__(self, x, train=True):
97 | x = x.astype(jnp.float32) / 255.0
98 | x = jnp.reshape(x, (*x.shape[:-2], -1))
99 |
100 | conv_out = x
101 |
102 | for idx in range(len(self.stack_blocks)):
103 | conv_out = self.stack_blocks[idx](conv_out)
104 |
105 | conv_out = nn.relu(conv_out)
106 | return conv_out.reshape((*x.shape[:-3], -1))
107 |
--------------------------------------------------------------------------------
/jaxrl2/agents/pixel_sac/actor_updater.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Dict, Tuple
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | from flax.training.train_state import TrainState
7 |
8 | from jaxrl2.data.dataset import DatasetDict
9 | from jaxrl2.types import Params, PRNGKey
10 |
11 |
12 | def update_actor(key: PRNGKey, actor: TrainState, critic: TrainState,
13 | temp: TrainState, batch: DatasetDict, cross_norm:bool=False, critic_reduction:str='min') -> Tuple[TrainState, Dict[str, float]]:
14 |
15 | key, key_act = jax.random.split(key, num=2)
16 |
17 | def actor_loss_fn(
18 | actor_params: Params) -> Tuple[jnp.ndarray, Dict[str, float]]:
19 | if hasattr(actor, 'batch_stats') and actor.batch_stats is not None:
20 | dist, new_model_state = actor.apply_fn({'params': actor_params, 'batch_stats': actor.batch_stats}, batch['observations'], mutable=['batch_stats'])
21 | if cross_norm:
22 | next_dist = actor.apply_fn({'params': actor_params, 'batch_stats': actor.batch_stats}, batch['next_observations'], mutable=['batch_stats'])
23 | else:
24 | next_dist = actor.apply_fn({'params': actor_params, 'batch_stats': actor.batch_stats}, batch['next_observations'])
25 | if type(next_dist) == tuple:
26 | next_dist, new_model_state = next_dist
27 | else:
28 | dist = actor.apply_fn({'params': actor_params}, batch['observations'])
29 | next_dist = actor.apply_fn({'params': actor_params}, batch['next_observations'])
30 | new_model_state = {}
31 |
32 | # For logging only
33 | mean_dist = dist.distribution._loc
34 | std_diag_dist = dist.distribution._scale_diag
35 | mean_dist_norm = jnp.linalg.norm(mean_dist, axis=-1)
36 | std_dist_norm = jnp.linalg.norm(std_diag_dist, axis=-1)
37 |
38 |
39 | actions, log_probs = dist.sample_and_log_prob(seed=key_act)
40 |
41 | if hasattr(critic, 'batch_stats') and critic.batch_stats is not None:
42 | qs, _ = critic.apply_fn({'params': critic.params, 'batch_stats': critic.batch_stats}, batch['observations'],
43 | actions, mutable=['batch_stats'])
44 | else:
45 | qs = critic.apply_fn({'params': critic.params}, batch['observations'], actions)
46 |
47 | if critic_reduction == 'min':
48 | q = qs.min(axis=0)
49 | elif critic_reduction == 'mean':
50 | q = qs.mean(axis=0)
51 | else:
52 | raise ValueError(f"Invalid critic reduction: {critic_reduction}")
53 | actor_loss = (log_probs * temp.apply_fn({'params': temp.params}) - q).mean()
54 |
55 | things_to_log = {
56 | 'actor_loss': actor_loss,
57 | 'entropy': -log_probs.mean(),
58 | 'q_pi_in_actor': q.mean(),
59 | 'mean_pi_norm': mean_dist_norm.mean(),
60 | 'std_pi_norm': std_dist_norm.mean(),
61 | 'mean_pi_avg': mean_dist.mean(),
62 | 'mean_pi_max': mean_dist.max(),
63 | 'mean_pi_min': mean_dist.min(),
64 | 'std_pi_avg': std_diag_dist.mean(),
65 | 'std_pi_max': std_diag_dist.max(),
66 | 'std_pi_min': std_diag_dist.min(),
67 | }
68 | return actor_loss, (things_to_log, new_model_state)
69 |
70 | grads, (info, new_model_state) = jax.grad(actor_loss_fn, has_aux=True)(actor.params)
71 |
72 | if 'batch_stats' in new_model_state:
73 | new_actor = actor.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
74 | else:
75 | new_actor = actor.apply_gradients(grads=grads)
76 |
77 | return new_actor, info
--------------------------------------------------------------------------------
/jaxrl2/networks/learned_std_normal_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence
2 |
3 | import distrax
4 | import flax.linen as nn
5 | import jax.numpy as jnp
6 |
7 | from jaxrl2.networks import MLP
8 | from jaxrl2.networks.constants import default_init
9 |
10 | class LearnedStdNormalPolicy(nn.Module):
11 | hidden_dims: Sequence[int]
12 | action_dim: int
13 | dropout_rate: Optional[float] = None
14 | log_std_min: Optional[float] = -20
15 | log_std_max: Optional[float] = 2
16 |
17 | @nn.compact
18 | def __call__(self,
19 | observations: jnp.ndarray,
20 | training: bool = False) -> distrax.Distribution:
21 | outputs = MLP(self.hidden_dims,
22 | activate_final=True,
23 | dropout_rate=self.dropout_rate)(observations,
24 | training=training)
25 |
26 | means = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs)
27 |
28 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs)
29 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)
30 |
31 | distribution = distrax.MultivariateNormalDiag(loc=means, scale_diag=jnp.exp(log_stds))
32 | return distribution
33 |
34 | class TanhMultivariateNormalDiag(distrax.Transformed):
35 |
36 | def __init__(self,
37 | loc: jnp.ndarray,
38 | scale_diag: jnp.ndarray,
39 | low: Optional[jnp.ndarray] = None,
40 | high: Optional[jnp.ndarray] = None):
41 | distribution = distrax.MultivariateNormalDiag(loc=loc,
42 | scale_diag=scale_diag)
43 |
44 | layers = []
45 |
46 | if not (low is None or high is None):
47 |
48 | def rescale_from_tanh(x):
49 | x = (x + 1) / 2 # (-1, 1) => (0, 1)
50 | return x * (high - low) + low
51 |
52 | def forward_log_det_jacobian(x):
53 | high_ = jnp.broadcast_to(high, x.shape)
54 | low_ = jnp.broadcast_to(low, x.shape)
55 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1)
56 |
57 | layers.append(
58 | distrax.Lambda(
59 | rescale_from_tanh,
60 | forward_log_det_jacobian=forward_log_det_jacobian,
61 | event_ndims_in=1,
62 | event_ndims_out=1))
63 |
64 | layers.append(distrax.Block(distrax.Tanh(), 1))
65 |
66 | bijector = distrax.Chain(layers)
67 |
68 | super().__init__(distribution=distribution, bijector=bijector)
69 |
70 | def mode(self) -> jnp.ndarray:
71 | return self.bijector.forward(self.distribution.mode())
72 |
73 | class LearnedStdTanhNormalPolicy(nn.Module):
74 | hidden_dims: Sequence[int]
75 | action_dim: int
76 | dropout_rate: Optional[float] = None
77 | log_std_min: Optional[float] = -20
78 | log_std_max: Optional[float] = 2
79 | low: Optional[float] = None
80 | high: Optional[float] = None
81 |
82 | @nn.compact
83 | def __call__(self,
84 | observations: jnp.ndarray,
85 | training: bool = False) -> distrax.Distribution:
86 | outputs = MLP(self.hidden_dims,
87 | activate_final=True,
88 | dropout_rate=self.dropout_rate)(observations,
89 | training=training)
90 |
91 | means = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs)
92 |
93 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs)
94 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)
95 |
96 | distribution = TanhMultivariateNormalDiag(loc=means, scale_diag=jnp.exp(log_stds), low=self.low, high=self.high)
97 | return distribution
--------------------------------------------------------------------------------
/jaxrl2/networks/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Sequence, Union
2 | from flax.core import frozen_dict
3 |
4 | import numpy as np
5 | import flax.linen as nn
6 | import jax.numpy as jnp
7 | from flax.core.frozen_dict import FrozenDict
8 |
9 | from jaxrl2.networks.constants import default_init
10 |
11 |
12 | def _flatten_dict(x: Union[FrozenDict, jnp.ndarray]):
13 | if hasattr(x, 'values'):
14 | obs = []
15 | for k, v in sorted(x.items()):
16 | # if k == "actions":
17 | # v = v[:, 0:1, ...]
18 | if k == 'state': # flatten action chunk to 1D
19 | obs.append(jnp.reshape(v, [*v.shape[:-2], np.prod(v.shape[-2:])]))
20 | # v = jnp.reshape(v, [*v.shape[:-2], np.prod(v.shape[-2:])])
21 | elif k == 'prev_action' or k == 'actions':
22 | if v.ndim > 2:
23 | # deal with action chunk
24 | obs.append(jnp.reshape(v, [*v.shape[:-2], np.prod(v.shape[-2:])]))
25 | else:
26 | obs.append(v)
27 | else:
28 | obs.append(_flatten_dict(v))
29 | return jnp.concatenate(obs, -1)
30 | else:
31 | return x
32 |
33 | def _flatten_dict_special(x):
34 | if hasattr(x, 'values'):
35 | obs = []
36 | action = None
37 | for k, v in sorted(x.items()):
38 | if k == 'state' or k == 'prev_action':
39 | obs.append(jnp.reshape(v, [*v.shape[:-2], np.prod(v.shape[-2:])]))
40 | elif k == 'actions':
41 | print ('action shape: ', v.shape)
42 | action = v
43 | else:
44 | obs.append(_flatten_dict(v))
45 | return jnp.concatenate(obs, -1), action
46 | else:
47 | return x
48 |
49 |
50 | class MLP(nn.Module):
51 | hidden_dims: Sequence[int]
52 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
53 | activate_final: int = False
54 | dropout_rate: Optional[float] = None
55 | init_scale: Optional[float] = 1.
56 | use_layer_norm: bool = False
57 |
58 | @nn.compact
59 | def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
60 | x = _flatten_dict(x)
61 | # print('mlp post flatten', x.shape)
62 |
63 | for i, size in enumerate(self.hidden_dims):
64 | x = nn.Dense(size, kernel_init=default_init(self.init_scale))(x)
65 | # print('post fc size', x.shape)
66 | if i + 1 < len(self.hidden_dims) or self.activate_final:
67 | if self.dropout_rate is not None:
68 | x = nn.Dropout(rate=self.dropout_rate)(
69 | x, deterministic=not training)
70 | if self.use_layer_norm:
71 | x = nn.LayerNorm()(x)
72 | x = self.activations(x)
73 | return x
74 |
75 |
76 | class MLPActionSep(nn.Module):
77 | hidden_dims: Sequence[int]
78 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
79 |     activate_final: bool = False
80 | dropout_rate: Optional[float] = None
81 | init_scale: Optional[float] = 1.
82 | use_layer_norm: bool = False
83 | @nn.compact
84 | def __call__(self, x: jnp.ndarray, training: bool = False):
85 | x, action = _flatten_dict_special(x)
86 | print ('mlp action sep state post flatten', x.shape)
87 | print ('mlp action sep action post flatten', action.shape)
88 |
89 | for i, size in enumerate(self.hidden_dims):
90 | x_used = jnp.concatenate([x, action], axis=-1)
91 | x = nn.Dense(size, kernel_init=default_init())(x_used)
92 | print ('FF layers: ', x_used.shape, x.shape)
93 | if i + 1 < len(self.hidden_dims) or self.activate_final:
94 | if self.dropout_rate is not None:
95 | x = nn.Dropout(rate=self.dropout_rate)(
96 | x, deterministic=not training)
97 | if self.use_layer_norm:
98 | x = nn.LayerNorm()(x)
99 | x = self.activations(x)
100 | return x
--------------------------------------------------------------------------------
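A minimal usage sketch (not part of the repository; the shapes are illustrative): `MLP` accepts a (frozen) dict observation, and `_flatten_dict` folds the trailing state/action-chunk axes into one feature vector before the dense layers.

```python
import jax
import jax.numpy as jnp
from flax.core import frozen_dict
from jaxrl2.networks.mlp import MLP

obs = frozen_dict.freeze({
    'state': jnp.zeros((4, 8, 1)),     # (batch, state_dim, history) -> flattened to (4, 8)
    'actions': jnp.zeros((4, 5, 32)),  # (batch, chunk, action_dim) -> flattened to (4, 160)
})
model = MLP(hidden_dims=(256, 256))
params = model.init(jax.random.PRNGKey(0), obs)
features = model.apply(params, obs)    # shape (4, 256)
```
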
/jaxrl2/utils/visualization_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import jax.numpy as jnp
4 | import cv2
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import matplotlib.pyplot as plt
8 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
9 |
10 |
11 | def np_unstack(array, axis):
12 | arr = np.split(array, array.shape[axis], axis)
13 | arr = [a.squeeze() for a in arr]
14 | return arr
15 |
16 | def action2img(action, res, channels, action_scale):
17 | assert action.size == 2 # can only plot 2-dimensional actions
18 | img = np.zeros((res, res, channels), dtype=np.float32).copy()
19 | start_pt = res / 2 * np.ones((2,))
20 |     end_pt = start_pt + action * action_scale * (res / 2 - 1) * np.array([1, -1]) # flip the y-component for image coordinates
21 | np2pt = lambda x: tuple(np.asarray(x, int))
22 | img = cv2.arrowedLine(img, np2pt(start_pt), np2pt(end_pt), (255, 255, 255), 1, cv2.LINE_AA, tipLength=0.2)
23 | return img
24 |
25 | def batch_action2img(actions, res, channels, action_scale=50):
26 | batch = actions.shape[0]
27 | im = np.empty((batch, res, res, channels), dtype=np.float32)
28 | for b in range(batch):
29 | im[b] = action2img(actions[b], res, channels, action_scale)
30 | return im
31 |
32 | def visualize_image_actions(images, gtruth_actions, pred_actions):
33 | gtruth_action_row1 = batch_action2img(gtruth_actions[:, :2], 128, 3, action_scale=3)
34 | gtruth_action_row1 = np.concatenate(np_unstack(gtruth_action_row1, axis=0), axis=1)
35 | pred_action_row1 = batch_action2img(pred_actions[:, :2], 128, 3, action_scale=3)
36 | pred_action_row1 = np.concatenate(np_unstack(pred_action_row1, axis=0), axis=1)
37 | sel_image_row = np.concatenate(np_unstack(images, axis=0), axis=1)
38 | image_rows = [sel_image_row, gtruth_action_row1, pred_action_row1]
39 | out = np.concatenate(image_rows, axis=0)
40 | return out
41 |
42 |
43 | def visualize_states_rewards(states, rewards, target_point):
44 | states = states.squeeze()
45 | rewards = rewards.squeeze()
46 |
47 | fig, axs = plt.subplots(7, 1)
48 | fig.set_size_inches(5, 15)
49 | canvas = FigureCanvas(fig)
50 | plt.xlim([0, len(states)])
51 |
52 | axs[0].plot(states[:, 0], linestyle='--', marker='o')
53 | axs[0].set_ylabel('states_x')
54 | axs[1].plot(states[:, 1], linestyle='--', marker='o')
55 | axs[1].set_ylabel('states_y')
56 | axs[2].plot(states[:, 2], linestyle='--', marker='o')
57 | axs[2].set_ylabel('states_z')
58 |
59 | axs[3].plot(np.abs(states[:, 0] - target_point[0]), linestyle='--', marker='o')
60 | axs[3].set_ylabel('norm_x')
61 | axs[4].plot(np.abs(states[:, 1] - target_point[1]), linestyle='--', marker='o')
62 | axs[4].set_ylabel('norm_y')
63 | axs[5].plot(np.abs(states[:, 2] - target_point[2]), linestyle='--', marker='o')
64 | axs[5].set_ylabel('norm_z')
65 |
66 | axs[6].plot(rewards, linestyle='--', marker='o')
67 | axs[6].set_ylabel('rewards')
68 |
69 | plt.tight_layout()
70 |
71 | canvas.draw() # draw the canvas, cache the renderer
72 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
73 | out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
74 |
75 | plt.close(fig)
76 | return out_image
77 |
78 | def add_text_to_images(img_list, string_list):
79 | from PIL import Image
80 | from PIL import ImageDraw
81 | out = []
82 | for im, string in zip(img_list, string_list):
83 | im = Image.fromarray(np.array(im).astype(np.uint8))
84 | draw = ImageDraw.Draw(im)
85 | draw.text((10, 10), string, fill=(255, 0, 0, 128))
86 | out.append(np.array(im))
87 | return out
88 |
89 | def sigmoid(x):
90 | return 1. / (1. + jnp.exp(-x))
91 |
92 | def visualize_image_rewards(images, gtruth_rewards, pred_rewards, obs, task_id_mapping):
93 | id_task_mapping = {v : k for (k, v) in task_id_mapping.items()}
94 | sel_images = np_unstack(images, axis=0)
95 | sel_images = add_text_to_images(sel_images, ["{:.2f} \n{:.2f} \nTask {}".format(gtruth_rewards[i], sigmoid(pred_rewards[i, 0]), np.argmax(obs['task_id'][i])) for i in range(gtruth_rewards.shape[0])])
96 | sel_image_row = np.concatenate(sel_images, axis=1)
97 | return sel_image_row
98 |
--------------------------------------------------------------------------------
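A minimal usage sketch (not part of the repository; dummy data): `visualize_image_actions` stacks a row of frames above rows of ground-truth and predicted 2-D action arrows.

```python
import numpy as np
from jaxrl2.utils.visualization_utils import visualize_image_actions

images = np.zeros((4, 128, 128, 3), dtype=np.float32)               # 4 frames
gtruth = np.random.uniform(-1, 1, size=(4, 2)).astype(np.float32)   # ground-truth 2-D actions
pred = np.random.uniform(-1, 1, size=(4, 2)).astype(np.float32)     # predicted 2-D actions
panel = visualize_image_actions(images, gtruth, pred)               # (3*128, 4*128, 3) image grid
```
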
/jaxrl2/data/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Iterable, Optional, Tuple, Union
2 | import collections
3 | import jax
4 | import numpy as np
5 | from gym.utils import seeding
6 | import jax.numpy as jnp
7 | from jaxrl2.types import DataType
8 |
9 | DatasetDict = Dict[str, DataType]
10 | from flax.core import frozen_dict
11 |
12 | def concat_recursive(batches):
13 | new_batch = {}
14 | for k, v in batches[0].items():
15 | if isinstance(v, frozen_dict.FrozenDict):
16 | new_batch[k] = concat_recursive([batches[0][k], batches[1][k]])
17 | else:
18 | new_batch[k] = np.concatenate([b[k] for b in batches], 0)
19 | return new_batch
20 |
21 | def _check_lengths(dataset_dict: DatasetDict,
22 | dataset_len: Optional[int] = None) -> int:
23 | for v in dataset_dict.values():
24 | if isinstance(v, dict):
25 | dataset_len = dataset_len or _check_lengths(v, dataset_len)
26 | elif isinstance(v, np.ndarray):
27 | item_len = len(v)
28 | dataset_len = dataset_len or item_len
29 | assert dataset_len == item_len, 'Inconsistent item lengths in the dataset.'
30 | else:
31 | raise TypeError('Unsupported type.')
32 | return dataset_len
33 |
34 |
35 | def _split(dataset_dict: DatasetDict,
36 | index: int) -> Tuple[DatasetDict, DatasetDict]:
37 | train_dataset_dict, test_dataset_dict = {}, {}
38 | for k, v in dataset_dict.items():
39 | if isinstance(v, dict):
40 | train_v, test_v = _split(v, index)
41 | elif isinstance(v, np.ndarray):
42 | train_v, test_v = v[:index], v[index:]
43 | else:
44 | raise TypeError('Unsupported type.')
45 | train_dataset_dict[k] = train_v
46 | test_dataset_dict[k] = test_v
47 | return train_dataset_dict, test_dataset_dict
48 |
49 |
50 | def _sample(dataset_dict: Union[np.ndarray, DatasetDict],
51 | indx: np.ndarray) -> DatasetDict:
52 | if isinstance(dataset_dict, np.ndarray):
53 | return dataset_dict[indx]
54 | elif isinstance(dataset_dict, dict):
55 | batch = {}
56 | for k, v in dataset_dict.items():
57 | batch[k] = _sample(v, indx)
58 | else:
59 | raise TypeError("Unsupported type.")
60 | return batch
61 |
62 |
63 | class Dataset(object):
64 |
65 | def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None):
66 | self.dataset_dict = dataset_dict
67 | self.dataset_len = _check_lengths(dataset_dict)
68 |
69 | # Seeding similar to OpenAI Gym:
70 | # https://github.com/openai/gym/blob/master/gym/spaces/space.py#L46
71 | self._np_random = None
72 | if seed is not None:
73 | self.seed(seed)
74 |
75 | @property
76 | def np_random(self) -> np.random.RandomState:
77 | if self._np_random is None:
78 | self.seed()
79 | return self._np_random
80 |
81 | def seed(self, seed: Optional[int] = None) -> list:
82 | self._np_random, seed = seeding.np_random(seed)
83 | return [seed]
84 |
85 | def __len__(self) -> int:
86 | return self.dataset_len
87 |
88 | def sample(self,
89 | batch_size: int,
90 | keys: Optional[Iterable[str]] = None,
91 | indx: Optional[np.ndarray] = None) -> frozen_dict.FrozenDict:
92 | if indx is None:
93 | if hasattr(self.np_random, 'integers'):
94 | indx = self.np_random.integers(len(self), size=batch_size)
95 | else:
96 | indx = self.np_random.randint(len(self), size=batch_size)
97 |
98 | batch = dict()
99 |
100 | if keys is None:
101 | keys = self.dataset_dict.keys()
102 |
103 | for k in keys:
104 | if isinstance(self.dataset_dict[k], dict):
105 | batch[k] = _sample(self.dataset_dict[k], indx)
106 | else:
107 | batch[k] = self.dataset_dict[k][indx]
108 |
109 | return frozen_dict.freeze(batch)
110 |
111 | def split(self, ratio: float) -> Tuple['Dataset', 'Dataset']:
112 | assert 0 < ratio and ratio < 1
113 | index = int(self.dataset_len * ratio)
114 | train_dataset_dict, test_dataset_dict = _split(self.dataset_dict,
115 | index)
116 | return Dataset(train_dataset_dict), Dataset(test_dataset_dict)
117 |
--------------------------------------------------------------------------------
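A minimal usage sketch (not part of the repository; array shapes are illustrative): `Dataset` wraps a possibly nested dict of equal-length numpy arrays and supports random batching plus an index-based train/validation split.

```python
import numpy as np
from jaxrl2.data.dataset import Dataset

data = {
    'observations': {'state': np.zeros((100, 8), dtype=np.float32)},
    'actions': np.zeros((100, 32), dtype=np.float32),
    'rewards': np.zeros((100,), dtype=np.float32),
}
dataset = Dataset(data, seed=0)
batch = dataset.sample(batch_size=16)           # FrozenDict of arrays with leading dimension 16
train_set, val_set = dataset.split(ratio=0.9)   # first 90 samples / last 10 samples
```
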
/jaxrl2/utils/wandb_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import wandb
4 | import time
5 |
6 | import dateutil.tz
7 | from collections import OrderedDict
8 | import numpy as np
9 | from numbers import Number
10 | import matplotlib
11 | matplotlib.use('Agg')
12 | import matplotlib.pyplot as plt
13 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
14 |
15 | def create_exp_name(exp_prefix, exp_id=0, seed=0):
16 | """
17 | Create a semi-unique experiment name that has a timestamp
18 | :param exp_prefix:
19 | :param exp_id:
20 | :return:
21 | """
22 | now = datetime.datetime.now(dateutil.tz.tzlocal())
23 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
24 | return "%s_%s_%04d--s-%d" % (exp_prefix, timestamp, exp_id, seed)
25 |
26 |
27 | def create_stats_ordered_dict(
28 | name,
29 | data,
30 | stat_prefix=None,
31 | always_show_all_stats=True,
32 | exclude_max_min=False,
33 | ):
34 | if stat_prefix is not None:
35 | name = "{}{}".format(stat_prefix, name)
36 | if isinstance(data, Number):
37 | return OrderedDict({name: data})
38 |
39 | if len(data) == 0:
40 | return OrderedDict()
41 |
42 | if isinstance(data, tuple):
43 | ordered_dict = OrderedDict()
44 | for number, d in enumerate(data):
45 | sub_dict = create_stats_ordered_dict(
46 | "{0}_{1}".format(name, number),
47 | d,
48 | )
49 | ordered_dict.update(sub_dict)
50 | return ordered_dict
51 |
52 | if isinstance(data, list):
53 | try:
54 | iter(data[0])
55 | except TypeError:
56 | pass
57 | else:
58 | data = np.concatenate(data)
59 |
60 | if (isinstance(data, np.ndarray) and data.size == 1
61 | and not always_show_all_stats):
62 | return OrderedDict({name: float(data)})
63 | try:
64 | stats = OrderedDict([
65 | (name + ' Mean', np.mean(data)),
66 | (name + ' Std', np.std(data)),
67 | ])
68 |     except Exception:
69 | stats = OrderedDict([
70 | (name + ' Mean', -1),
71 | (name + ' Std', -1),
72 | ])
73 | if not exclude_max_min:
74 | try:
75 | stats[name + ' Max'] = np.max(data)
76 | stats[name + ' Min'] = np.min(data)
77 |         except Exception:
78 | stats[name + ' Max'] = -1
79 | stats[name + ' Min'] = -1
80 | return stats
81 |
82 | class WandBLogger(object):
83 | def __init__(self, wandb_logging, variant, project, experiment_id, output_dir=None, group_name='', team=None):
84 | self.wandb_logging = wandb_logging
85 | output_dir = os.path.join(output_dir, experiment_id)
86 | os.makedirs(output_dir, exist_ok=True)
87 | if wandb_logging:
88 | print('wandb using experimentid: ', experiment_id)
89 | print('wandb using project: ', project)
90 | print('wandb using group: ', group_name)
91 |
92 | try:
93 | from jaxrl2.utils.wandb_config import get_wandb_config
94 | wandb_config = get_wandb_config()
95 | os.environ['WANDB_API_KEY'] = wandb_config['WANDB_API_KEY']
96 | os.environ['WANDB_USER_EMAIL'] = wandb_config['WANDB_EMAIL']
97 | os.environ['WANDB_USERNAME'] = wandb_config['WANDB_USERNAME']
98 | team = wandb_config['WANDB_TEAM'] if wandb_config['WANDB_TEAM'] != '' else None
99 |             except Exception:
100 | print('wandb_config.py not found, using default wandb config')
101 | os.environ["WANDB_MODE"] = "run"
102 | wandb.init(
103 | config=variant,
104 | project=project,
105 | dir=output_dir,
106 | id=experiment_id,
107 | settings=wandb.Settings(start_method="thread"),
108 | group=group_name,
109 | entity=team
110 | )
111 | self.output_dir = output_dir
112 |
113 |
114 | def log(self, *args, **kwargs):
115 | if self.wandb_logging:
116 | wandb.log(*args, **kwargs)
117 |
118 | def log_histogram(self, name, values, step):
119 | fig = plt.figure()
120 | canvas = FigureCanvas(fig)
121 | plt.tight_layout()
122 |
123 | plt.hist(np.array(values.flatten()), bins=100)
124 | # plt.show()
125 |
126 | canvas.draw() # draw the canvas, cache the renderer
127 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
128 | out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
129 | plt.close(fig)
130 | self.log({name: wandb.Image(out_image)}, step=step)
--------------------------------------------------------------------------------
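A minimal usage sketch (not part of the repository): `create_stats_ordered_dict` turns a batched quantity into scalar summaries that can be passed directly to `WandBLogger.log`.

```python
import numpy as np
from jaxrl2.utils.wandb_logger import create_stats_ordered_dict

q_values = np.random.randn(256)
stats = create_stats_ordered_dict('q_values', q_values, stat_prefix='critic/')
# -> {'critic/q_values Mean': ..., 'critic/q_values Std': ...,
#     'critic/q_values Max': ..., 'critic/q_values Min': ...}
```
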
/examples/train_real.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | import os
3 | import jax
4 | from jaxrl2.agents.pixel_sac.pixel_sac_learner import PixelSACLearner
5 | from jaxrl2.utils.general_utils import add_batch_dim
6 | import numpy as np
7 | import logging
8 |
9 | import gymnasium as gym
10 | from gym.spaces import Dict, Box
11 |
12 | from jaxrl2.data import ReplayBuffer
13 | from jaxrl2.utils.wandb_logger import WandBLogger, create_exp_name
14 | import tempfile
15 | from functools import partial
16 | from examples.train_utils_real import trajwise_alternating_training_loop
17 | import tensorflow as tf
18 | from jax.experimental.compilation_cache import compilation_cache
19 | from openpi_client import websocket_client_policy as _websocket_client_policy
20 | from droid.robot_env import RobotEnv
21 |
22 | home_dir = os.environ['HOME']
23 | compilation_cache.initialize_cache(os.path.join(home_dir, 'jax_compilation_cache'))
24 |
25 | def shard_batch(batch, sharding):
26 | """Shards a batch across devices along its first dimension.
27 |
28 | Args:
29 | batch: A pytree of arrays.
30 | sharding: A jax Sharding object with shape (num_devices,).
31 | """
32 | return jax.tree_util.tree_map(
33 | lambda x: jax.device_put(
34 | x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1)))
35 | ),
36 | batch,
37 | )
38 |
39 | class DummyEnv(gym.ObservationWrapper):
40 |
41 | def __init__(self, variant):
42 | self.variant = variant
43 | self.image_shape = (variant.resize_image, variant.resize_image, 3 * variant.num_cameras, 1)
44 | obs_dict = {}
45 | obs_dict['pixels'] = Box(low=0, high=255, shape=self.image_shape, dtype=np.uint8)
46 | if variant.add_states:
47 | state_dim = 8 + 2024 # 8 is the proprioceptive state's dim, 2024 is the image representation's dim
48 | obs_dict['state'] = Box(low=-1.0, high=1.0, shape=(state_dim, 1), dtype=np.float32)
49 | self.observation_space = Dict(obs_dict)
50 |         self.action_space = Box(low=-1, high=1, shape=(1, 32,), dtype=np.float32) # 32 is the dimension of pi0's noise action space
51 |
52 | def main(variant):
53 | devices = jax.local_devices()
54 | num_devices = len(devices)
55 | assert variant.batch_size % num_devices == 0
56 |     logging.info('num devices: %d', num_devices)
57 |     logging.info('batch size: %d', variant.batch_size)
58 |     # we shard the leading dimension (batch dimension) across all devices evenly
59 | sharding = jax.sharding.PositionalSharding(devices)
60 | shard_fn = partial(shard_batch, sharding=sharding)
61 |
62 | # prevent tensorflow from using GPUs
63 | tf.config.set_visible_devices([], "GPU")
64 |
65 | kwargs = variant['train_kwargs']
66 | if kwargs.pop('cosine_decay', False):
67 | kwargs['decay_steps'] = variant.max_steps
68 |
69 | if not variant.prefix:
70 | import uuid
71 | variant.prefix = str(uuid.uuid4().fields[-1])[:5]
72 |
73 | if variant.suffix:
74 | expname = create_exp_name(variant.prefix, seed=variant.seed) + f"_{variant.suffix}"
75 | else:
76 | expname = create_exp_name(variant.prefix, seed=variant.seed)
77 |
78 | outputdir = os.path.join(os.environ['EXP'], expname)
79 | variant.outputdir = outputdir
80 | if not os.path.exists(outputdir):
81 | os.makedirs(outputdir)
82 | print('writing to output dir ', outputdir)
83 |
84 | group_name = variant.prefix + '_' + variant.launch_group_id
85 | wandb_output_dir = tempfile.mkdtemp()
86 | wandb_logger = WandBLogger(variant.prefix != '', variant, variant.wandb_project, experiment_id=expname, output_dir=wandb_output_dir, group_name=group_name)
87 |
88 |
89 | agent_dp = _websocket_client_policy.WebsocketClientPolicy(
90 | host=os.environ['remote_host'],
91 | port=os.environ['remote_port']
92 | )
93 | logging.info(f"Server metadata: {agent_dp.get_server_metadata()}")
94 |
95 | logging.info("initializing environment...")
96 | env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
97 | eval_env = env
98 | logging.info("created the droid env!")
99 |
100 | assert os.environ.get('LEFT_CAMERA_ID') is not None
101 | assert os.environ.get('RIGHT_CAMERA_ID') is not None
102 | assert os.environ.get('WRIST_CAMERA_ID') is not None
103 |
104 | robot_config = dict(
105 | camera_to_use='right',
106 | left_camera_id=os.environ['LEFT_CAMERA_ID'],
107 | right_camera_id=os.environ['RIGHT_CAMERA_ID'],
108 | wrist_camera_id=os.environ['WRIST_CAMERA_ID'],
109 | max_timesteps=200
110 | )
111 |
112 | dummy_env = DummyEnv(variant)
113 | sample_obs = add_batch_dim(dummy_env.observation_space.sample())
114 | sample_action = add_batch_dim(dummy_env.action_space.sample())
115 |     logging.info('sample obs shapes: %s', [(k, v.shape) for k, v in sample_obs.items()])
116 |     logging.info('sample action shape: %s', sample_action.shape)
117 |
118 | agent = PixelSACLearner(variant.seed, sample_obs, sample_action, **kwargs)
119 |
120 | if variant.restore_path != '':
121 |         logging.info('restoring from %s', variant.restore_path)
122 | agent.restore_checkpoint(variant.restore_path)
123 |
124 | online_buffer_size = 2 * variant.max_steps // variant.multi_grad_step
125 | online_replay_buffer = ReplayBuffer(dummy_env.observation_space, dummy_env.action_space, int(online_buffer_size))
126 | replay_buffer = online_replay_buffer
127 | replay_buffer.seed(variant.seed)
128 | trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger, shard_fn=shard_fn, agent_dp=agent_dp, robot_config=robot_config)
129 |
--------------------------------------------------------------------------------
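A minimal sketch (not part of the repository) of the sharding pattern used in `main()`, assuming the `shard_batch` helper defined above is in scope; the leading batch axis must divide evenly across the local devices.

```python
from functools import partial

import jax
import numpy as np

devices = jax.local_devices()
sharding = jax.sharding.PositionalSharding(devices)
shard_fn = partial(shard_batch, sharding=sharding)

# dummy batch; the leading (batch) axis is split evenly across devices
batch = {'actions': np.zeros((64, 32), dtype=np.float32),
         'rewards': np.zeros((64,), dtype=np.float32)}
sharded = shard_fn(batch)   # each leaf is device_put with its batch axis sharded
```
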
/jaxrl2/networks/encoders/resnet_encoderv1.py:
--------------------------------------------------------------------------------
1 | import flax.linen as nn
2 | import jax.numpy as jnp
3 |
4 | from functools import partial
5 | from typing import Any, Callable, Sequence, Tuple
6 | from jaxrl2.networks.constants import default_init, xavier_init, kaiming_init
7 | from jaxrl2.networks.encoders.spatial_softmax import SpatialSoftmax
8 | from jaxrl2.networks.encoders.cross_norm import CrossNorm
9 |
10 | ModuleDef = Any
11 |
12 |
13 | class MyGroupNorm(nn.GroupNorm):
14 |
15 | def __call__(self, x):
16 | if x.ndim == 3:
17 | x = x[jnp.newaxis]
18 | x = super().__call__(x)
19 | return x[0]
20 | else:
21 | return super().__call__(x)
22 |
23 | class ResNetBlock(nn.Module):
24 | """ResNet block."""
25 | filters: int
26 | conv: ModuleDef
27 | norm: ModuleDef
28 | act: Callable
29 | strides: Tuple[int, int] = (1, 1)
30 |
31 | @nn.compact
32 | def __call__(self, x, ):
33 | residual = x
34 | y = self.conv(self.filters, (3, 3), self.strides)(x)
35 | y = self.norm()(y)
36 | y = self.act(y)
37 | y = self.conv(self.filters, (3, 3))(y)
38 | y = self.norm()(y)
39 |
40 | if residual.shape != y.shape:
41 | residual = self.conv(self.filters, (1, 1),
42 | self.strides, name='conv_proj')(residual)
43 | residual = self.norm(name='norm_proj')(residual)
44 |
45 | return self.act(residual + y)
46 |
47 |
48 | class BottleneckResNetBlock(nn.Module):
49 | """Bottleneck ResNet block."""
50 | filters: int
51 | conv: ModuleDef
52 | norm: ModuleDef
53 | act: Callable
54 | strides: Tuple[int, int] = (1, 1)
55 |
56 | @nn.compact
57 | def __call__(self, x):
58 | residual = x
59 | y = self.conv(self.filters, (1, 1))(x)
60 | y = self.norm()(y)
61 | y = self.act(y)
62 | y = self.conv(self.filters, (3, 3), self.strides)(y)
63 | y = self.norm()(y)
64 | y = self.act(y)
65 | y = self.conv(self.filters * 4, (1, 1))(y)
66 | y = self.norm(scale_init=nn.initializers.zeros)(y)
67 |
68 | if residual.shape != y.shape:
69 | residual = self.conv(self.filters * 4, (1, 1),
70 | self.strides, name='conv_proj')(residual)
71 | residual = self.norm(name='norm_proj')(residual)
72 |
73 | return self.act(residual + y)
74 |
75 |
76 | class ResNetEncoder(nn.Module):
77 | """ResNetV1."""
78 | stage_sizes: Sequence[int]
79 | block_cls: ModuleDef
80 | num_filters: int = 64
81 | dtype: Any = jnp.float32
82 | act: Callable = nn.relu
83 | conv: ModuleDef = nn.Conv
84 | norm: str = 'batch'
85 | use_spatial_softmax: bool = True
86 | softmax_temperature: float = 1.0
87 |
88 | @nn.compact
89 | def __call__(self, observations: jnp.ndarray, train: bool = True) -> jnp.ndarray:
90 |
91 | x = observations.astype(jnp.float32) / 255.0
92 | x = jnp.reshape(x, (*x.shape[:-2], -1))
93 |
94 | conv = partial(self.conv, use_bias=False, dtype=self.dtype, kernel_init=kaiming_init())
95 | if self.norm == 'batch':
96 | norm = partial(nn.BatchNorm,
97 | use_running_average=not train,
98 | momentum=0.9,
99 | epsilon=1e-5,
100 | dtype=self.dtype)
101 | elif self.norm == 'group':
102 | norm = partial(MyGroupNorm,
103 | num_groups=4,
104 | epsilon=1e-5,
105 | dtype=self.dtype)
106 | elif self.norm == 'cross':
107 | norm = partial(CrossNorm,
108 | use_running_average=not train,
109 | momentum=0.9,
110 | epsilon=1e-5,
111 | dtype=self.dtype)
112 | elif self.norm == 'layer':
113 | norm = partial(nn.LayerNorm,
114 | epsilon=1e-5,
115 | dtype=self.dtype,
116 | )
117 | else:
118 | raise ValueError('norm not found')
119 |
120 | # print('input ', x.shape)
121 | strides = (2, 2, 2, 1, 1)
122 | x = conv(self.num_filters, (7, 7), (strides[0], strides[0]),
123 | padding=[(3, 3), (3, 3)],
124 | name='conv_init')(x)
125 | # print('post conv1', x.shape)
126 |
127 | x = norm(name='bn_init')(x)
128 | x = nn.relu(x)
129 | x = nn.max_pool(x, (3, 3), strides=(strides[1], strides[1]), padding='SAME')
130 | # print('post maxpool1', x.shape)
131 | for i, block_size in enumerate(self.stage_sizes):
132 | for j in range(block_size):
133 | stride = (strides[i + 1], strides[i + 1]) if i > 0 and j == 0 else (1, 1)
134 | x = self.block_cls(self.num_filters * 2 ** i,
135 | strides=stride,
136 | conv=conv,
137 | norm=norm,
138 | act=self.act)(x)
139 | # print('post block layer ', x.shape)
140 | # print('post block ', x.shape)
141 |
142 | if self.use_spatial_softmax:
143 | height, width, channel = x.shape[len(x.shape) - 3:]
144 | pos_x, pos_y = jnp.meshgrid(
145 | jnp.linspace(-1., 1., height),
146 | jnp.linspace(-1., 1., width)
147 | )
148 | pos_x = pos_x.reshape(height * width)
149 | pos_y = pos_y.reshape(height * width)
150 | # print('pre spatial softmax', x.shape)
151 | x = SpatialSoftmax(height, width, channel, pos_x, pos_y, self.softmax_temperature)(x)
152 | # print('post spatial softmax', x.shape)
153 | else:
154 | x = jnp.mean(x, axis=(len(x.shape) - 3,len(x.shape) - 2))
155 | # print('post flatten', x.shape)
156 | return x
157 |
158 | ResNetSmall = partial(ResNetEncoder, stage_sizes=(1, 1, 1, 1),
159 | block_cls=ResNetBlock)
160 | ResNet18 = partial(ResNetEncoder, stage_sizes=(2, 2, 2, 2),
161 | block_cls=ResNetBlock)
162 | ResNet34 = partial(ResNetEncoder, stage_sizes=(3, 4, 6, 3),
163 | block_cls=ResNetBlock)
164 |
--------------------------------------------------------------------------------
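A minimal usage sketch (not part of the repository; shapes are illustrative): applying the small ResNet with group normalization and global average pooling instead of the spatial-softmax head.

```python
import jax
import jax.numpy as jnp
from jaxrl2.networks.encoders.resnet_encoderv1 import ResNetSmall

# (batch, height, width, channels, frame_stack); the encoder folds the last two axes
# into a single channel axis before the first convolution.
obs = jnp.zeros((2, 128, 128, 3, 1), dtype=jnp.uint8)
encoder = ResNetSmall(norm='group', use_spatial_softmax=False)
variables = encoder.init(jax.random.PRNGKey(0), obs, train=False)
features = encoder.apply(variables, obs, train=False)   # (2, 512) feature vectors
```
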
/jaxrl2/agents/common.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Callable, Tuple, Any
3 |
4 | import distrax
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 |
9 | from jaxrl2.data.dataset import DatasetDict
10 | from jaxrl2.types import Params, PRNGKey
11 | import flax.linen as nn
12 | from typing import Any, Callable, Dict, Mapping, Sequence, Union  # Mapping is used by ModuleDict.__call__
13 |
14 | # Helps to minimize CPU to GPU transfer.
15 | def _unpack(batch):
16 | # Assuming that if next_observation is missing, it's combined with observation:
17 | obs_pixels = batch['observations']['pixels'][..., :-1]
18 | next_obs_pixels = batch['observations']['pixels'][..., 1:]
19 |
20 | obs = batch['observations'].copy(add_or_replace={'pixels': obs_pixels})
21 | next_obs = batch['next_observations'].copy(
22 | add_or_replace={'pixels': next_obs_pixels})
23 |
24 | batch = batch.copy(add_or_replace={
25 | 'observations': obs,
26 | 'next_observations': next_obs
27 | })
28 |
29 | return batch
30 |
31 | @partial(jax.jit, static_argnames='actor_apply_fn')
32 | def eval_log_prob_jit(actor_apply_fn: Callable[..., distrax.Distribution],
33 | actor_params: Params, actor_batch_stats: Any, batch: DatasetDict) -> float:
34 | # batch = _unpack(batch)
35 | input_collections = {'params': actor_params}
36 | if actor_batch_stats is not None:
37 | input_collections['batch_stats'] = actor_batch_stats
38 | dist = actor_apply_fn(input_collections,
39 | batch['observations'],
40 | training=False,
41 | mutable=False)
42 | log_probs = dist.log_prob(batch['actions'])
43 | return log_probs.mean()
44 |
45 | @partial(jax.jit, static_argnames='actor_apply_fn')
46 | def eval_mse_jit(actor_apply_fn: Callable[..., distrax.Distribution],
47 | actor_params: Params, actor_batch_stats: Any, batch: DatasetDict) -> float:
48 | # batch = _unpack(batch)
49 | input_collections = {'params': actor_params}
50 | if actor_batch_stats is not None:
51 | input_collections['batch_stats'] = actor_batch_stats
52 | dist = actor_apply_fn(input_collections,
53 | batch['observations'],
54 | training=False,
55 | mutable=False)
56 | mse = (dist.loc - batch['actions']) ** 2
57 | return mse.mean()
58 |
59 | def eval_reward_function_jit(actor_apply_fn: Callable[..., distrax.Distribution],
60 | actor_params: Params, actor_batch_stats: Any, batch: DatasetDict) -> float:
61 | # batch = _unpack(batch)
62 | input_collections = {'params': actor_params}
63 | if actor_batch_stats is not None:
64 | input_collections['batch_stats'] = actor_batch_stats
65 | dist = actor_apply_fn(input_collections,
66 | batch['observations'],
67 | training=False,
68 | mutable=False)
69 | pred = dist.mode().reshape(-1)
70 | loss = - (batch['rewards'] * jnp.log(1. / (1. + jnp.exp(-pred))) + (1.0 - batch['rewards']) * jnp.log(1. - 1. / (1. + jnp.exp(-pred))))
71 | return loss.mean()
72 |
73 |
74 | @partial(jax.jit, static_argnames='actor_apply_fn')
75 | def eval_actions_jit(actor_apply_fn: Callable[..., distrax.Distribution],
76 | actor_params: Params,
77 | observations: np.ndarray,
78 | actor_batch_stats: Any) -> jnp.ndarray:
79 | input_collections = {'params': actor_params}
80 | if actor_batch_stats is not None:
81 | input_collections['batch_stats'] = actor_batch_stats
82 | dist = actor_apply_fn(input_collections, observations, training=False,
83 | mutable=False)
84 | return dist.mode()
85 |
86 |
87 | @partial(jax.jit, static_argnames='actor_apply_fn')
88 | def sample_actions_jit(
89 | rng: PRNGKey, actor_apply_fn: Callable[..., distrax.Distribution],
90 | actor_params: Params,
91 | observations: np.ndarray,
92 | actor_batch_stats: Any) -> Tuple[PRNGKey, jnp.ndarray]:
93 | input_collections = {'params': actor_params}
94 | if actor_batch_stats is not None:
95 | input_collections['batch_stats'] = actor_batch_stats
96 | dist = actor_apply_fn(input_collections, observations)
97 | rng, key = jax.random.split(rng)
98 | return rng, dist.sample(seed=key)
99 |
100 |
101 | class ModuleDict(nn.Module):
102 | """
103 | from https://github.com/rail-berkeley/jaxrl_minimal/blob/main/jaxrl_m/common/common.py#L33
104 | Utility class for wrapping a dictionary of modules. This is useful when you have multiple modules that you want to
105 | initialize all at once (creating a single `params` dictionary), but you want to be able to call them separately
106 | later. As a bonus, the modules may have sub-modules nested inside them that share parameters (e.g. an image encoder)
107 | and Flax will automatically handle this without duplicating the parameters.
108 |
109 | To initialize the modules, call `init` with no `name` kwarg, and then pass the example arguments to each module as
110 | additional kwargs. To call the modules, pass the name of the module as the `name` kwarg, and then pass the arguments
111 | to the module as additional args or kwargs.
112 |
113 | Example usage:
114 | ```
115 | shared_encoder = Encoder()
116 | actor = Actor(encoder=shared_encoder)
117 | critic = Critic(encoder=shared_encoder)
118 |
119 | model_def = ModuleDict({"actor": actor, "critic": critic})
120 | params = model_def.init(rng_key, actor=example_obs, critic=(example_obs, example_action))
121 |
122 | actor_output = model_def.apply({"params": params}, example_obs, name="actor")
123 | critic_output = model_def.apply({"params": params}, example_obs, action=example_action, name="critic")
124 | ```
125 | """
126 |
127 | modules: Dict[str, nn.Module]
128 |
129 | @nn.compact
130 | def __call__(self, *args, name=None, **kwargs):
131 | if name is None:
132 | if kwargs.keys() != self.modules.keys():
133 | raise ValueError(
134 | f"When `name` is not specified, kwargs must contain the arguments for each module. "
135 | f"Got kwargs keys {kwargs.keys()} but module keys {self.modules.keys()}"
136 | )
137 | out = {}
138 | for key, value in kwargs.items():
139 | if isinstance(value, Mapping):
140 | out[key] = self.modules[key](**value)
141 | elif isinstance(value, Sequence):
142 | out[key] = self.modules[key](*value)
143 | else:
144 | out[key] = self.modules[key](value)
145 | return out
146 |
147 | return self.modules[name](*args, **kwargs)
--------------------------------------------------------------------------------
/examples/train_sim.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | import os
3 | # Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs from https://github.com/huggingface/gym-aloha/tree/main?tab=readme-ov-file#-gpu-rendering-egl
4 | xla_flags = os.environ.get('XLA_FLAGS', '')
5 | xla_flags += ' --xla_gpu_triton_gemm_any=True'
6 | os.environ['XLA_FLAGS'] = xla_flags
7 |
8 | import pathlib, copy
9 |
10 | import jax
11 | from jaxrl2.agents.pixel_sac.pixel_sac_learner import PixelSACLearner
12 | from jaxrl2.utils.general_utils import add_batch_dim
13 | import numpy as np
14 |
15 | import gymnasium as gym
16 | import gym_aloha
17 | from gym.spaces import Dict, Box
18 |
19 | from libero.libero import benchmark
20 | from libero.libero import get_libero_path
21 | from libero.libero.envs import OffScreenRenderEnv
22 |
23 | from jaxrl2.data import ReplayBuffer
24 | from jaxrl2.utils.wandb_logger import WandBLogger, create_exp_name
25 | import tempfile
26 | from functools import partial
27 | from examples.train_utils_sim import trajwise_alternating_training_loop
28 | import tensorflow as tf
29 | from jax.experimental.compilation_cache import compilation_cache
30 |
31 | from openpi.training import config as openpi_config
32 | from openpi.policies import policy_config
33 | from openpi.shared import download
34 |
35 | home_dir = os.environ['HOME']
36 | compilation_cache.initialize_cache(os.path.join(home_dir, 'jax_compilation_cache'))
37 |
38 | def _get_libero_env(task, resolution, seed):
39 | """Initializes and returns the LIBERO environment, along with the task description."""
40 | task_description = task.language
41 | task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
42 | env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
43 | env = OffScreenRenderEnv(**env_args)
44 | env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
45 | return env, task_description
46 |
47 | def shard_batch(batch, sharding):
48 | """Shards a batch across devices along its first dimension.
49 |
50 | Args:
51 | batch: A pytree of arrays.
52 | sharding: A jax Sharding object with shape (num_devices,).
53 | """
54 | return jax.tree_util.tree_map(
55 | lambda x: jax.device_put(
56 | x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1)))
57 | ),
58 | batch,
59 | )
60 |
61 |
62 | class DummyEnv(gym.ObservationWrapper):
63 |
64 | def __init__(self, variant):
65 | self.variant = variant
66 | self.image_shape = (variant.resize_image, variant.resize_image, 3 * variant.num_cameras, 1)
67 | obs_dict = {}
68 | obs_dict['pixels'] = Box(low=0, high=255, shape=self.image_shape, dtype=np.uint8)
69 | if variant.add_states:
70 | if variant.env == 'libero':
71 | state_dim = 8
72 | elif variant.env == 'aloha_cube':
73 | state_dim = 14
74 | obs_dict['state'] = Box(low=-1.0, high=1.0, shape=(state_dim, 1), dtype=np.float32)
75 | self.observation_space = Dict(obs_dict)
76 |         self.action_space = Box(low=-1, high=1, shape=(1, 32,), dtype=np.float32) # 32 is the dimension of pi0's noise action space
77 |
78 |
79 | def main(variant):
80 | devices = jax.local_devices()
81 | num_devices = len(devices)
82 | assert variant.batch_size % num_devices == 0
83 | print('num devices', num_devices)
84 | print('batch size', variant.batch_size)
85 |     # we shard the leading dimension (batch dimension) across all devices evenly
86 | sharding = jax.sharding.PositionalSharding(devices)
87 | shard_fn = partial(shard_batch, sharding=sharding)
88 |
89 | # prevent tensorflow from using GPUs
90 | tf.config.set_visible_devices([], "GPU")
91 |
92 | kwargs = variant['train_kwargs']
93 | if kwargs.pop('cosine_decay', False):
94 | kwargs['decay_steps'] = variant.max_steps
95 |
96 | if not variant.prefix:
97 | import uuid
98 | variant.prefix = str(uuid.uuid4().fields[-1])[:5]
99 |
100 | if variant.suffix:
101 | expname = create_exp_name(variant.prefix, seed=variant.seed) + f"_{variant.suffix}"
102 | else:
103 | expname = create_exp_name(variant.prefix, seed=variant.seed)
104 |
105 | outputdir = os.path.join(os.environ['EXP'], expname)
106 | variant.outputdir = outputdir
107 | if not os.path.exists(outputdir):
108 | os.makedirs(outputdir)
109 | print('writing to output dir ', outputdir)
110 |
111 | if variant.env == 'libero':
112 | benchmark_dict = benchmark.get_benchmark_dict()
113 | task_suite = benchmark_dict["libero_90"]()
114 | task_id = 57
115 | task = task_suite.get_task(task_id)
116 | env, task_description = _get_libero_env(task, 256, variant.seed)
117 | eval_env = env
118 | variant.task_description = task_description
119 | variant.env_max_reward = 1
120 | variant.max_timesteps = 400
121 | elif variant.env == 'aloha_cube':
122 | from gymnasium.envs.registration import register
123 | register(
124 | id="gym_aloha/AlohaTransferCube-v0",
125 | entry_point="gym_aloha.env:AlohaEnv",
126 | max_episode_steps=400,
127 | nondeterministic=True,
128 | kwargs={"obs_type": "pixels", "task": "transfer_cube"},
129 | )
130 | env = gym.make("gym_aloha/AlohaTransferCube-v0", obs_type="pixels_agent_pos", render_mode="rgb_array")
131 | eval_env = copy.deepcopy(env)
132 | variant.env_max_reward = 4
133 | variant.max_timesteps = 400
134 |
135 |
136 | group_name = variant.prefix + '_' + variant.launch_group_id
137 | wandb_output_dir = tempfile.mkdtemp()
138 | wandb_logger = WandBLogger(variant.prefix != '', variant, variant.wandb_project, experiment_id=expname, output_dir=wandb_output_dir, group_name=group_name)
139 |
140 | dummy_env = DummyEnv(variant)
141 | sample_obs = add_batch_dim(dummy_env.observation_space.sample())
142 | sample_action = add_batch_dim(dummy_env.action_space.sample())
143 | print('sample obs shapes', [(k, v.shape) for k, v in sample_obs.items()])
144 | print('sample action shape', sample_action.shape)
145 |
146 |
147 | if variant.env == 'libero':
148 | config = openpi_config.get_config("pi0_libero")
149 | checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_libero")
150 | elif variant.env == 'aloha_cube':
151 | config = openpi_config.get_config("pi0_aloha_sim")
152 | checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_aloha_sim")
153 | else:
154 | raise NotImplementedError()
155 | agent_dp = policy_config.create_trained_policy(config, checkpoint_dir)
156 |     print(f"Loaded pi0 policy from {checkpoint_dir}")
157 | agent = PixelSACLearner(variant.seed, sample_obs, sample_action, **kwargs)
158 |
159 | online_buffer_size = variant.max_steps // variant.multi_grad_step
160 | online_replay_buffer = ReplayBuffer(dummy_env.observation_space, dummy_env.action_space, int(online_buffer_size))
161 | replay_buffer = online_replay_buffer
162 | replay_buffer.seed(variant.seed)
163 | trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger, shard_fn=shard_fn, agent_dp=agent_dp)
164 |
--------------------------------------------------------------------------------
/jaxrl2/data/replay_buffer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from typing import Iterable, Optional
3 | import jax
4 | import gym
5 | import gym.spaces
6 | import numpy as np
7 | import pickle
8 |
9 | import copy
10 |
11 | from jaxrl2.data.dataset import Dataset, DatasetDict
12 | import collections
13 | from flax.core import frozen_dict
14 |
15 | def _init_replay_dict(obs_space: gym.Space,
16 | capacity: int) -> Union[np.ndarray, DatasetDict]:
17 | if isinstance(obs_space, gym.spaces.Box):
18 | return np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype)
19 | elif isinstance(obs_space, gym.spaces.Dict):
20 | data_dict = {}
21 | for k, v in obs_space.spaces.items():
22 | data_dict[k] = _init_replay_dict(v, capacity)
23 | return data_dict
24 | else:
25 | raise TypeError()
26 |
27 |
28 | class ReplayBuffer(Dataset):
29 |
30 | def __init__(self, observation_space: gym.Space, action_space: gym.Space, capacity: int, ):
31 | self.observation_space = observation_space
32 | self.action_space = action_space
33 | self.capacity = capacity
34 |
35 | print("making replay buffer of capacity ", self.capacity)
36 |
37 | observations = _init_replay_dict(self.observation_space, self.capacity)
38 | next_observations = _init_replay_dict(self.observation_space, self.capacity)
39 | actions = np.empty((self.capacity, *self.action_space.shape), dtype=self.action_space.dtype)
40 | next_actions = np.empty((self.capacity, *self.action_space.shape), dtype=self.action_space.dtype)
41 | rewards = np.empty((self.capacity, ), dtype=np.float32)
42 | masks = np.empty((self.capacity, ), dtype=np.float32)
43 | discount = np.empty((self.capacity, ), dtype=np.float32)
44 |
45 | self.data = {
46 | 'observations': observations,
47 | 'next_observations': next_observations,
48 | 'actions': actions,
49 | 'next_actions': next_actions,
50 | 'rewards': rewards,
51 | 'masks': masks,
52 | 'discount': discount,
53 | }
54 |
55 | self.size = 0
56 | self._traj_counter = 0
57 | self._start = 0
58 | self.traj_bounds = dict()
59 | self.streaming_buffer_size = None # this is for streaming the online data
60 |
61 | def __len__(self) -> int:
62 | return self.size
63 |
64 | def length(self) -> int:
65 | return self.size
66 |
67 | def increment_traj_counter(self):
68 | self.traj_bounds[self._traj_counter] = (self._start, self.size) # [start, end)
69 | self._start = self.size
70 | self._traj_counter += 1
71 |
72 | def get_random_trajs(self, num_trajs: int):
73 | self.which_trajs = np.random.randint(0, self._traj_counter, num_trajs)
74 | observations_list = []
75 | next_observations_list = []
76 | actions_list = []
77 | rewards_list = []
78 | terminals_list = []
79 | masks_list = []
80 | discount_list = []
81 |
82 | for i in self.which_trajs:
83 | start, end = self.traj_bounds[i]
84 |
85 | # handle this as a dictionary
86 | obs_dict_curr_traj = dict()
87 | for k in self.data['observations']:
88 | obs_dict_curr_traj[k] = self.data['observations'][k][start:end]
89 | observations_list.append(obs_dict_curr_traj)
90 |
91 | next_obs_dict_curr_traj = dict()
92 | for k in self.data['next_observations']:
93 | next_obs_dict_curr_traj[k] = self.data['next_observations'][k][start:end]
94 | next_observations_list.append(next_obs_dict_curr_traj)
95 |
96 | actions_list.append(self.data['actions'][start:end])
97 | rewards_list.append(self.data['rewards'][start:end])
98 | terminals_list.append(1-self.data['masks'][start:end])
99 | masks_list.append(self.data['masks'][start:end])
100 |
101 |
102 |
103 | batch = {
104 | 'observations': observations_list,
105 | 'next_observations': next_observations_list,
106 | 'actions': actions_list,
107 | 'rewards': rewards_list,
108 | 'terminals': terminals_list,
109 | 'masks': masks_list,
110 |
111 |
112 | }
113 | return batch
114 |
115 | def insert(self, data_dict: DatasetDict):
116 | if self.size == self.capacity:
117 | # Double the capacity
118 | observations = _init_replay_dict(self.observation_space, self.capacity)
119 | next_observations = _init_replay_dict(self.observation_space, self.capacity)
120 | actions = np.empty((self.capacity, *self.action_space.shape), dtype=self.action_space.dtype)
121 | next_actions = np.empty((self.capacity, *self.action_space.shape), dtype=self.action_space.dtype)
122 | rewards = np.empty((self.capacity, ), dtype=np.float32)
123 | masks = np.empty((self.capacity, ), dtype=np.float32)
124 | discount = np.empty((self.capacity, ), dtype=np.float32)
125 |
126 | data_new = {
127 | 'observations': observations,
128 | 'next_observations': next_observations,
129 | 'actions': actions,
130 | 'next_actions': next_actions,
131 | 'rewards': rewards,
132 | 'masks': masks,
133 | 'discount': discount,
134 | }
135 |
136 | for x in data_new:
137 | if isinstance(self.data[x], np.ndarray):
138 | self.data[x] = np.concatenate((self.data[x], data_new[x]), axis=0)
139 | elif isinstance(self.data[x], dict):
140 | for y in self.data[x]:
141 | self.data[x][y] = np.concatenate((self.data[x][y], data_new[x][y]), axis=0)
142 | else:
143 | raise TypeError()
144 | self.capacity *= 2
145 |
146 |
147 | for x in data_dict:
148 | if x in self.data:
149 | if isinstance(data_dict[x], dict):
150 | for y in data_dict[x]:
151 | self.data[x][y][self.size] = data_dict[x][y]
152 | else:
153 | self.data[x][self.size] = data_dict[x]
154 | self.size += 1
155 |
156 | def compute_action_stats(self):
157 | actions = self.data['actions']
158 | return {'mean': actions.mean(axis=0), 'std': actions.std(axis=0)}
159 |
160 | def normalize_actions(self, action_stats):
161 | # do not normalize gripper dimension (last dimension)
162 |         action_stats = copy.deepcopy(action_stats)  # copy so the caller's stats are not modified
163 | action_stats['mean'][-1] = 0
164 | action_stats['std'][-1] = 1
165 | self.data['actions'] = (self.data['actions'] - action_stats['mean']) / action_stats['std']
166 | self.data['next_actions'] = (self.data['next_actions'] - action_stats['mean']) / action_stats['std']
167 |
168 | def sample(self, batch_size: int, keys: Optional[Iterable[str]] = None, indx: Optional[np.ndarray] = None) -> frozen_dict.FrozenDict:
169 | if self.streaming_buffer_size:
170 | indices = np.random.randint(0, self.streaming_buffer_size, batch_size)
171 | else:
172 | indices = np.random.randint(0, self.size, batch_size)
173 | data_dict = {}
174 | for x in self.data:
175 | if isinstance(self.data[x], np.ndarray):
176 | data_dict[x] = self.data[x][indices]
177 | elif isinstance(self.data[x], dict):
178 | data_dict[x] = {}
179 | for y in self.data[x]:
180 | data_dict[x][y] = self.data[x][y][indices]
181 | else:
182 | raise TypeError()
183 |
184 | return frozen_dict.freeze(data_dict)
185 |
186 | def get_iterator(self, batch_size: int, keys: Optional[Iterable[str]] = None, indx: Optional[np.ndarray] = None, queue_size: int = 2):
187 | # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
188 | # queue_size = 2 should be ok for one GPU.
189 |
190 | queue = collections.deque()
191 |
192 | def enqueue(n):
193 | for _ in range(n):
194 | data = self.sample(batch_size, keys, indx)
195 | queue.append(jax.device_put(data))
196 |
197 | enqueue(queue_size)
198 | while queue:
199 | yield queue.popleft()
200 | enqueue(1)
201 |
202 |
203 | def save(self, filename):
204 | save_dict = dict(
205 | data=self.data,
206 | size = self.size,
207 | _traj_counter = self._traj_counter,
208 | _start=self._start,
209 | traj_bounds=self.traj_bounds
210 | )
211 | with open(filename, 'wb') as f:
212 | pickle.dump(save_dict, f, protocol=4)
213 |
214 |
215 | def restore(self, filename):
216 |         with open(filename, 'rb') as f:
217 |             save_dict = pickle.load(f)  # mirrors save(), which writes the dict with pickle.dump
218 | self.data = save_dict['data']
219 | self.size = save_dict['size']
220 | self._traj_counter = save_dict['_traj_counter']
221 | self._start = save_dict['_start']
222 | self.traj_bounds = save_dict['traj_bounds']
223 |
--------------------------------------------------------------------------------
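A minimal usage sketch (not part of the repository; the spaces are illustrative): inserting transitions one step at a time, marking trajectory boundaries, and sampling a training batch.

```python
import numpy as np
import gym.spaces as spaces
from jaxrl2.data.replay_buffer import ReplayBuffer

obs_space = spaces.Dict({'state': spaces.Box(-1.0, 1.0, shape=(8, 1), dtype=np.float32)})
act_space = spaces.Box(-1.0, 1.0, shape=(1, 32), dtype=np.float32)
buffer = ReplayBuffer(obs_space, act_space, capacity=1000)

obs, action = obs_space.sample(), act_space.sample()
buffer.insert(dict(observations=obs, next_observations=obs, actions=action,
                   next_actions=action, rewards=0.0, masks=1.0, discount=0.99))
buffer.increment_traj_counter()      # close the current trajectory: [start, end)
batch = buffer.sample(batch_size=8)  # FrozenDict of sampled transitions
```
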
/jaxrl2/networks/encoders/cross_norm.py:
--------------------------------------------------------------------------------
1 | """Normalization modules for Flax."""
2 |
3 | from typing import (Any, Callable, Optional, Tuple, Iterable, Union)
4 |
5 | from jax import lax
6 | from jax.nn import initializers
7 | import jax.numpy as jnp
8 |
9 | from flax.linen.module import Module, compact, merge_param
10 |
11 |
12 | PRNGKey = Any
13 | Array = Any
14 | Shape = Tuple[int]
15 | Dtype = Any # this could be a real type?
16 |
17 | Axes = Union[int, Iterable[int]]
18 |
19 | import flax.linen as nn
20 | from flax.linen.module import Module, compact, merge_param
21 |
22 | def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]:
23 | """Returns a tuple of deduplicated, sorted, and positive axes."""
24 | if not isinstance(axes, Iterable):
25 | axes = (axes,)
26 |   return tuple(sorted(set([rank + axis if axis < 0 else axis for axis in axes])))
27 |
28 |
29 | def _abs_sq(x):
30 | """Computes the elementwise square of the absolute value |x|^2."""
31 | if jnp.iscomplexobj(x):
32 | return lax.square(lax.real(x)) + lax.square(lax.imag(x))
33 | else:
34 | return lax.square(x)
35 |
36 |
37 | def _compute_stats(x: Array, axes: Axes,
38 | axis_name: Optional[str] = None,
39 | axis_index_groups: Any = None,
40 | alpha: float = 0.5):
41 | """Computes mean and variance statistics.
42 | This implementation takes care of a few important details:
43 | - Computes in float32 precision for half precision inputs
44 | - mean and variance is computable in a single XLA fusion,
45 | by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]).
46 | - Clips negative variances to zero which can happen due to
47 | roundoff errors. This avoids downstream NaNs.
48 | - Supports averaging across a parallel axis and subgroups of a parallel axis
49 | with a single `lax.pmean` call to avoid latency.
50 | Arguments:
51 | x: Input array.
52 | axes: The axes in ``x`` to compute mean and variance statistics for.
53 | axis_name: Optional name for the pmapped axis to compute mean over.
54 | axis_index_groups: Optional axis indices.
55 | Returns:
56 | A pair ``(mean, var)``.
57 | """
58 | # promote x to at least float32, this avoids half precision computation
59 | # but preserves double or complex floating points
60 | x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
61 | split_1, split_2 = jnp.split(x, 2, axis=0) # split x into two parts
62 | mean_s1 = jnp.mean(split_1, axes)
63 | mean_s2 = jnp.mean(split_2, axes)
64 |
65 | mean2_s1 = jnp.mean(_abs_sq(split_1), axes)
66 | mean2_s2 = jnp.mean(_abs_sq(split_2), axes)
67 |
68 | mean = alpha * mean_s1 + (1 - alpha) * mean_s2
69 |
70 | if axis_name is not None:
71 | concatenated_mean = jnp.concatenate([mean, mean2])
72 | mean, mean2 = jnp.split(
73 | lax.pmean(
74 | concatenated_mean,
75 | axis_name=axis_name,
76 | axis_index_groups=axis_index_groups), 2)
77 | # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
78 | # to floating point round-off errors.
79 | var_s1 = mean2_s1 - _abs_sq(mean_s1)
80 | var_s2 = mean2_s2 - _abs_sq(mean_s2)
81 | var = alpha * var_s1 + (1 - alpha) * var_s2
82 |
83 | var = jnp.maximum(0., var)
84 | return mean, var
85 |
86 |
87 | def _normalize(mdl: Module, x: Array, mean: Array, var: Array,
88 | reduction_axes: Axes, feature_axes: Axes,
89 | dtype: Dtype, param_dtype: Dtype,
90 | epsilon: float,
91 | use_bias: bool, use_scale: bool,
92 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array],
93 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array]):
94 |   """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
95 | Arguments:
96 | mdl: Module to apply the normalization in (normalization params will reside
97 | in this module).
98 | x: The input.
99 | mean: Mean to use for normalization.
100 | var: Variance to use for normalization.
101 | reduction_axes: The axes in ``x`` to reduce.
102 | feature_axes: Axes containing features. A separate bias and scale is learned
103 | for each specified feature.
104 | dtype: Dtype of the returned result.
105 | param_dtype: Dtype of the parameters.
106 | epsilon: Normalization epsilon.
107 | use_bias: If true, add a bias term to the output.
108 | use_scale: If true, scale the output.
109 | bias_init: Initialization function for the bias term.
110 | scale_init: Initialization function for the scaling function.
111 | Returns:
112 | The normalized input.
113 | """
114 | reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
115 | feature_axes = _canonicalize_axes(x.ndim, feature_axes)
116 | stats_shape = list(x.shape)
117 | for axis in reduction_axes:
118 | stats_shape[axis] = 1
119 | mean = mean.reshape(stats_shape)
120 | var = var.reshape(stats_shape)
121 | feature_shape = [1] * x.ndim
122 | reduced_feature_shape = []
123 | for ax in feature_axes:
124 | feature_shape[ax] = x.shape[ax]
125 | reduced_feature_shape.append(x.shape[ax])
126 | y = x - mean
127 | mul = lax.rsqrt(var + epsilon)
128 | if use_scale:
129 | scale = mdl.param('scale', scale_init, reduced_feature_shape,
130 | param_dtype).reshape(feature_shape)
131 | mul *= scale
132 | y *= mul
133 | if use_bias:
134 | bias = mdl.param('bias', bias_init, reduced_feature_shape,
135 | param_dtype).reshape(feature_shape)
136 | y += bias
137 | return jnp.asarray(y, dtype)
138 |
139 | class CrossNorm(Module):
140 | """CrossNorm Module.
141 | Usage Note:
142 | If we define a model with CrossNorm, for example::
143 |     BN = CrossNorm(use_running_average=False, momentum=0.9, epsilon=1e-5,
144 | dtype=jnp.float32)
145 | The initialized variables dict will contain in addition to a 'params'
146 | collection a separate 'batch_stats' collection that will contain all the
147 |     running statistics for all the CrossNorm layers in a model::
148 | vars_initialized = BN.init(key, x) # {'params': ..., 'batch_stats': ...}
149 | We then update the batch_stats during training by specifying that the
150 | `batch_stats` collection is mutable in the `apply` method for our module.::
151 | vars_in = {'params': params, 'batch_stats': old_batch_stats}
152 | y, mutated_vars = BN.apply(vars_in, x, mutable=['batch_stats'])
153 | new_batch_stats = mutated_vars['batch_stats']
154 | During eval we would define BN with `use_running_average=True` and use the
155 | batch_stats collection from training to set the statistics. In this case
156 | we are not mutating the batch statistics collection, and needn't mark it
157 | mutable::
158 | vars_in = {'params': params, 'batch_stats': training_batch_stats}
159 | y = BN.apply(vars_in, x)
160 | Attributes:
161 | use_running_average: if True, the statistics stored in batch_stats
162 | will be used instead of computing the batch statistics on the input.
163 | axis: the feature or non-batch axis of the input.
164 | momentum: decay rate for the exponential moving average of
165 | the batch statistics.
166 | epsilon: a small float added to variance to avoid dividing by zero.
167 | dtype: the dtype of the computation (default: float32).
168 | param_dtype: the dtype passed to parameter initializers (default: float32).
169 | use_bias: if True, bias (beta) is added.
170 | use_scale: if True, multiply by scale (gamma).
171 | When the next layer is linear (also e.g. nn.relu), this can be disabled
172 | since the scaling will be done by the next layer.
173 | bias_init: initializer for bias, by default, zero.
174 | scale_init: initializer for scale, by default, one.
175 | axis_name: the axis name used to combine batch statistics from multiple
176 | devices. See `jax.pmap` for a description of axis names (default: None).
177 | axis_index_groups: groups of axis indices within that named axis
178 | representing subsets of devices to reduce over (default: None). For
179 | example, `[[0, 1], [2, 3]]` would independently batch-normalize over
180 | the examples on the first two and last two devices. See `jax.lax.psum`
181 | for more details.
182 |
183 |   Note: this module is a modified version of Flax's original BatchNorm module.
184 | """
185 | use_running_average: Optional[bool] = None
186 | axis: int = -1
187 | momentum: float = 0.99
188 | epsilon: float = 1e-5
189 | dtype: Dtype = jnp.float32
190 | param_dtype: Dtype = jnp.float32
191 | use_bias: bool = True
192 | use_scale: bool = True
193 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
194 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
195 | axis_name: Optional[str] = None
196 | axis_index_groups: Any = None
197 | alpha: float = 0.5
198 |
199 | @compact
200 | def __call__(self, x, use_running_average: Optional[bool] = None):
201 | """Normalizes the input using batch statistics.
202 | NOTE:
203 | During initialization (when parameters are mutable) the running average
204 | of the batch statistics will not be updated. Therefore, the inputs
205 | fed during initialization don't need to match that of the actual input
206 | distribution and the reduction axis (set with `axis_name`) does not have
207 | to exist.
208 | Args:
209 | x: the input to be normalized.
210 | use_running_average: if true, the statistics stored in batch_stats
211 | will be used instead of computing the batch statistics on the input.
212 | Returns:
213 | Normalized inputs (the same shape as inputs).
214 | """
215 |
216 | use_running_average = merge_param(
217 | 'use_running_average', self.use_running_average, use_running_average)
218 | feature_axes = _canonicalize_axes(x.ndim, self.axis)
219 | reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
220 | feature_shape = [x.shape[ax] for ax in feature_axes]
221 |
222 | # see NOTE above on initialization behavior
223 | initializing = self.is_mutable_collection('params')
224 |
225 | ra_mean = self.variable('batch_stats', 'mean',
226 | lambda s: jnp.zeros(s, jnp.float32),
227 | feature_shape)
228 | ra_var = self.variable('batch_stats', 'var',
229 | lambda s: jnp.ones(s, jnp.float32),
230 | feature_shape)
231 |
232 | if use_running_average:
233 | mean, var = ra_mean.value, ra_var.value
234 | else:
235 | mean, var = _compute_stats(
236 | x, reduction_axes,
237 | axis_name=self.axis_name if not initializing else None,
238 | axis_index_groups=self.axis_index_groups, alpha=self.alpha)
239 |
240 | if not initializing:
241 | ra_mean.value = self.momentum * ra_mean.value + (1 -
242 | self.momentum) * mean
243 | ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
244 |
245 | return _normalize(
246 | self, x, mean, var, reduction_axes, feature_axes,
247 | self.dtype, self.param_dtype, self.epsilon,
248 | self.use_bias, self.use_scale,
249 | self.bias_init, self.scale_init)
--------------------------------------------------------------------------------
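A minimal, runnable version of the docstring's training-time usage (not part of the repository): `CrossNorm` keeps running statistics in a `batch_stats` collection that must be marked mutable during training updates; the batch is split in half internally, so its leading dimension must be even.

```python
import jax
import jax.numpy as jnp
from jaxrl2.networks.encoders.cross_norm import CrossNorm

x = jnp.ones((8, 32))                            # batch of 8 feature vectors (even batch size)
bn = CrossNorm(use_running_average=False, momentum=0.9, epsilon=1e-5)
variables = bn.init(jax.random.PRNGKey(0), x)    # {'params': ..., 'batch_stats': ...}
y, mutated = bn.apply(variables, x, mutable=['batch_stats'])
new_batch_stats = mutated['batch_stats']         # updated running mean / var
```
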
/examples/train_utils_real.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from tqdm import tqdm
4 |
5 | import numpy as np
6 | import jax
7 | import sys
8 | import select
9 | import tty
10 | import termios
11 | from openpi_client import image_tools
12 | from moviepy.editor import ImageSequenceClip
13 |
14 |
15 | def trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger,
16 | shard_fn=None, agent_dp=None, robot_config=None):
17 | replay_buffer_iterator = replay_buffer.get_iterator(variant.batch_size)
18 | if shard_fn is not None:
19 | replay_buffer_iterator = map(shard_fn, replay_buffer_iterator)
20 |
21 | i = 0
22 | total_env_steps = 0
23 | total_num_traj = 0
24 | wandb_logger.log({'num_online_samples': 0}, step=i)
25 | wandb_logger.log({'num_online_trajs': 0}, step=i)
26 | wandb_logger.log({'env_steps': 0}, step=i)
27 |
28 | with tqdm(total=variant.max_steps, initial=0) as pbar:
29 | while i <= variant.max_steps:
30 | traj = collect_traj(variant, agent, env, i, agent_dp, wandb_logger, total_num_traj, robot_config)
31 | total_num_traj += 1
32 | add_online_data_to_buffer(variant, traj, online_replay_buffer)
33 | total_env_steps += traj['env_steps']
34 | print('online buffer timesteps length:', len(online_replay_buffer))
35 | print('online buffer num traj:', total_num_traj)
36 | print('total env steps:', total_env_steps)
37 |
38 | if i == 0:
39 | num_gradsteps = 5000
40 | else:
41 | num_gradsteps = len(traj["rewards"]) * variant.multi_grad_step
42 | print(f'num_gradsteps: {num_gradsteps}')
43 | if total_num_traj >= variant.num_initial_traj_collect:
44 | for _ in range(num_gradsteps):
45 |
46 | batch = next(replay_buffer_iterator)
47 | update_info = agent.update(batch)
48 |
49 | pbar.update()
50 | i += 1
51 |
52 | if i % variant.log_interval == 0:
53 | update_info = {k: jax.device_get(v) for k, v in update_info.items()}
54 | for k, v in update_info.items():
55 | if v.ndim == 0:
56 | wandb_logger.log({f'training/{k}': v}, step=i)
57 | elif v.ndim <= 2:
58 | wandb_logger.log_histogram(f'training/{k}', v, i)
59 | wandb_logger.log({
60 | 'replay_buffer_size': len(online_replay_buffer),
61 | 'is_success (exploration)': int(traj['is_success']),
62 | }, i)
63 |
64 | if i % variant.eval_interval == 0:
65 | wandb_logger.log({'num_online_samples': len(online_replay_buffer)}, step=i)
66 | wandb_logger.log({'num_online_trajs': total_num_traj}, step=i)
67 | wandb_logger.log({'env_steps': total_env_steps}, step=i)
68 | if hasattr(agent, 'perform_eval'):
69 | agent.perform_eval(variant, i, wandb_logger, replay_buffer, replay_buffer_iterator, eval_env)
70 |
71 | if variant.checkpoint_interval != -1:
72 | if i % variant.checkpoint_interval == 0:
73 | agent.save_checkpoint(variant.outputdir, i, variant.checkpoint_interval)
74 |
75 | def add_online_data_to_buffer(variant, traj, online_replay_buffer):
76 |
77 | discount_horizon = variant.query_freq
78 | actions = np.array(traj['actions']) # (T, chunk_size, 14)
79 | episode_len = len(actions)
80 | rewards = np.array(traj['rewards'])
81 | masks = np.array(traj['masks'])
82 |
83 | for t in range(episode_len):
84 | obs = traj['observations'][t]
85 | next_obs = traj['observations'][t + 1]
86 | # remove batch dimension
87 | obs = {k: v[0] for k, v in obs.items()}
88 | next_obs = {k: v[0] for k, v in next_obs.items()}
89 | if not variant.add_states:
90 | obs.pop('state', None)
91 | next_obs.pop('state', None)
92 |
93 | insert_dict = dict(
94 | observations=obs,
95 | next_observations=next_obs,
96 | actions=actions[t],
97 | next_actions=actions[t + 1] if t < episode_len - 1 else actions[t],
98 | rewards=rewards[t],
99 | masks=masks[t],
100 | discount=variant.discount ** discount_horizon
101 | )
102 | online_replay_buffer.insert(insert_dict)
103 | online_replay_buffer.increment_traj_counter()
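    | # Note on the transition convention above: each buffer entry holds one action
    | # chunk that is executed open-loop for `query_freq` environment steps, so the
    | # per-step discount is raised to that horizon. For example (illustrative
    | # values, not defaults): discount = 0.99 with query_freq = 10 gives an
    | # effective chunk-level discount of 0.99 ** 10 ~= 0.904.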
104 |
105 | def collect_traj(variant, agent, env, i, agent_dp=None, wandb_logger=None, traj_id=None, robot_config=None):
106 | query_frequency = variant.query_freq
107 | instruction = variant.instruction
108 | max_timesteps = robot_config['max_timesteps']
109 | agent._rng, rng = jax.random.split(agent._rng)
110 | try:
111 | env.reset()
112 | except Exception as e:
113 |         print(f"Environment reset failed: {e}")
114 | import traceback
115 | traceback.print_exc()
116 | import pdb; pdb.set_trace()
117 | step_time = 1 / 15 # 15 Hz
118 | last_step_time = time.time()
119 | old_settings = termios.tcgetattr(sys.stdin)
120 |
121 | rewards = []
122 | action_list = []
123 | obs_list = []
124 | image_list = []
125 |
126 | old_settings = termios.tcgetattr(sys.stdin)
127 | try:
128 | tty.setcbreak(sys.stdin.fileno())
129 | for t in tqdm(range(max_timesteps)):
130 | # Check for keyboard input
131 | if select.select([sys.stdin], [], [], 0) == ([sys.stdin], [], []):
132 | char_input = sys.stdin.read(1)
133 | if char_input.lower() == 'q':
134 | print("'q' pressed, stopping loop.")
135 | break
136 |
137 | try:
138 | _env_obs = env.get_observation()
139 | except Exception as e:
140 |                 print(f"Environment get obs failed: {e}")
141 | import traceback
142 | traceback.print_exc()
143 | import pdb; pdb.set_trace()
144 | curr_obs = _extract_observation(
145 | robot_config,
146 | _env_obs,
147 | )
148 | image_list.append(curr_obs[robot_config['camera_to_use'] + "_image"])
149 |
150 | request_data = get_pi0_input(curr_obs, robot_config, instruction)
151 |
152 | if t % query_frequency == 0:
153 |
154 | rng, key = jax.random.split(rng)
155 |
156 | img_all = process_images(variant, curr_obs)
157 |
158 |                 # extract features from the pi0 VLM backbone and concatenate them with the qpos to form the state
159 | img_rep_pi0, _ = agent_dp.get_prefix_rep(request_data)
160 | img_rep_pi0 = img_rep_pi0[:, -1, :] # (1, 2048)
161 | qpos = np.concatenate([curr_obs["joint_position"], curr_obs["gripper_position"], img_rep_pi0.flatten()])
162 |
163 | obs_dict = {
164 | 'pixels': img_all,
165 | 'state': qpos[np.newaxis, ..., np.newaxis],
166 | }
167 | if i == 0:
168 | noise = jax.random.normal(key, (1, *agent.action_chunk_shape))
169 | noise_repeat = jax.numpy.repeat(noise[:, -1:, :], 10 - noise.shape[1], axis=1)
170 | noise = jax.numpy.concatenate([noise, noise_repeat], axis=1)
171 | actions_noise = noise[0, :agent.action_chunk_shape[0], :]
172 | else:
173 | # sac agent predicts the noise for diffusion model
174 | actions_noise = agent.sample_actions(obs_dict)
175 | actions_noise = np.reshape(actions_noise, agent.action_chunk_shape)
176 | noise = np.repeat(actions_noise[-1:, :], 10 - actions_noise.shape[0], axis=0)
177 | noise = jax.numpy.concatenate([actions_noise, noise], axis=0)[None]
178 | action_list.append(actions_noise)
179 | obs_list.append(obs_dict)
180 | action = agent_dp.infer(request_data, noise=np.asarray(noise))["actions"]
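    |                 # The SAC policy outputs a noise chunk of shape
    |                 # `agent.action_chunk_shape`; it is padded to a 10-step horizon
    |                 # by repeating its last row (the horizon this pi0 checkpoint is
    |                 # assumed to expect) and passed to `agent_dp.infer` as the
    |                 # initial diffusion noise in place of a standard Gaussian draw.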
181 |
182 | action_t = action[t % query_frequency]
183 |
184 | # binarize gripper action.
185 | if action_t[-1].item() > 0.5:
186 | action_t = np.concatenate([action_t[:-1], np.ones((1,))])
187 | else:
188 | action_t = np.concatenate([action_t[:-1], np.zeros((1,))])
189 | action_t = np.clip(action_t, -1, 1)
190 |
191 | try:
192 | env.step(action_t)
193 | except Exception as e:
194 |                 print(f"Environment step failed: {e}")
195 | import traceback
196 | traceback.print_exc() # This prints the full traceback
197 | import pdb; pdb.set_trace()
198 |
199 | now = time.time()
200 | dt = now - last_step_time
201 | if dt < step_time:
202 | time.sleep(step_time - dt)
203 | last_step_time = time.time()
204 | else:
205 | last_step_time = now
206 |
207 | print("Trial finished. Mark as (1) Success or (0) Failure:")
208 | while True:
209 | if select.select([sys.stdin], [], [], 0) == ([sys.stdin], [], []):
210 | char_input = sys.stdin.read(1)
211 | if char_input == '1':
212 | print("Trial marked as SUCCESS.")
213 | is_success = True
214 | break
215 | elif char_input == '0':
216 | print("Trial marked as FAILURE.")
217 | is_success = False
218 | break
219 | else:
220 | print("Invalid input. Please enter '1' for Success or '0' for Failure:")
221 | time.sleep(0.01) # Small sleep to prevent busy-waiting if no input
222 |
223 | try:
224 | _env_obs = env.get_observation()
225 | except Exception as e:
226 |             print(f"Environment get obs failed: {e}")
227 | import traceback
228 | traceback.print_exc()
229 | import pdb; pdb.set_trace()
230 |
231 | # add last observation
232 | curr_obs = _extract_observation(
233 | robot_config,
234 | _env_obs,
235 | )
236 | image_list.append(curr_obs[robot_config['camera_to_use'] + "_image"])
237 | request_data = get_pi0_input(curr_obs, robot_config, instruction)
238 | img_all = process_images(variant, curr_obs)
239 | img_rep_pi0, _ = agent_dp.get_prefix_rep(request_data)
240 | img_rep_pi0 = img_rep_pi0[:, -1, :] # (1, 2048)
241 | qpos = np.concatenate([curr_obs["joint_position"], curr_obs["gripper_position"], img_rep_pi0.flatten()])
242 | obs_dict = {
243 | 'pixels': img_all,
244 | 'state': qpos[np.newaxis, ..., np.newaxis],
245 | }
246 | obs_list.append(obs_dict)
247 | print(f'Rollout Done')
248 |
249 | finally:
250 | if is_success:
251 | query_steps = len(action_list)
252 | rewards = np.concatenate([-np.ones(query_steps - 1), [0]])
253 | masks = np.concatenate([np.ones(query_steps - 1), [0]])
254 | else:
255 | query_steps = len(action_list)
256 | rewards = -np.ones(query_steps)
257 | masks = np.ones(query_steps)
258 |
259 | if wandb_logger is not None:
260 | wandb_logger.log({f'is_success': int(is_success)}, step=i)
261 | wandb_logger.log({f'total_num_traj': traj_id}, step=i)
262 |
263 | video_path = os.path.join(variant.outputdir, f'video_high_{traj_id}.mp4')
264 | video = np.stack(image_list)
265 | ImageSequenceClip(list(video), fps=15).write_videofile(video_path, codec="libx264")
266 |
267 |         print("Episode Done! Press 'c' (continue) after resetting the environment")
268 | try:
269 | env.reset()
270 | except Exception as e:
271 |             print(f"Environment reset failed: {e}")
272 | import traceback
273 | traceback.print_exc() # This prints the full traceback
274 | import pdb; pdb.set_trace()
275 | import pdb; pdb.set_trace()
276 | termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
277 |
278 | traj = {
279 | 'observations': obs_list,
280 | 'actions': action_list,
281 | 'rewards': rewards,
282 | 'masks': masks,
283 | 'is_success': is_success,
284 | 'env_steps': t + 1,
285 | }
286 |
287 | return traj
288 |
289 |
290 | def _extract_observation(robot_config, obs_dict):
291 | '''
292 | from https://github.com/Physical-Intelligence/openpi/blob/main/examples/droid/main.py
293 | '''
294 | image_observations = obs_dict["image"]
295 | left_image, right_image, wrist_image = None, None, None
296 | for key in image_observations.keys():
297 | if robot_config['left_camera_id'] in key and "left" in key:
298 | left_image = image_observations[key]
299 | elif robot_config['right_camera_id'] in key and "left" in key:
300 | right_image = image_observations[key]
301 | elif robot_config['wrist_camera_id'] in key and "left" in key:
302 | wrist_image = image_observations[key]
303 |
304 | # Drop the alpha dimension
305 | left_image = left_image[..., :3]
306 | right_image = right_image[..., :3]
307 | wrist_image = wrist_image[..., :3]
308 |
309 | # Convert to RGB
310 | left_image = left_image[..., ::-1]
311 | right_image = right_image[..., ::-1]
312 | wrist_image = wrist_image[..., ::-1]
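    |     # Assumes the cameras return BGR(A) frames, as in the openpi DROID example
    |     # this helper is adapted from, so reversing the channel axis yields RGB.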
313 |
314 | # In addition to image observations, also capture the proprioceptive state
315 | robot_state = obs_dict["robot_state"]
316 | cartesian_position = np.array(robot_state["cartesian_position"])
317 | joint_position = np.array(robot_state["joint_positions"])
318 | gripper_position = np.array([robot_state["gripper_position"]])
319 |
320 | return {
321 | "left_image": left_image,
322 | "right_image": right_image,
323 | "wrist_image": wrist_image,
324 | "cartesian_position": cartesian_position,
325 | "joint_position": joint_position,
326 | "gripper_position": gripper_position,
327 | }
328 |
329 | def get_pi0_input(obs, robot_config, instruction):
330 | external_image = obs[robot_config['camera_to_use'] + "_image"]
331 | request_data = {
332 | "observation/exterior_image_1_left": image_tools.resize_with_pad(
333 | external_image, 224, 224
334 | ),
335 | "observation/wrist_image_left": image_tools.resize_with_pad(obs["wrist_image"], 224, 224),
336 | "observation/joint_position": obs["joint_position"],
337 | "observation/gripper_position": obs["gripper_position"],
338 | "prompt": instruction,
339 | }
340 | return request_data
341 |
342 |
343 | def process_images(variant, obs):
344 | '''
345 | concat the images from all cameras
346 | '''
347 | im1 = image_tools.resize_with_pad(obs["left_image"], variant.resize_image, variant.resize_image)
348 | im2 = image_tools.resize_with_pad(obs["right_image"], variant.resize_image, variant.resize_image)
349 | im3 = image_tools.resize_with_pad(obs["wrist_image"], variant.resize_image, variant.resize_image)
350 | img_all = np.concatenate([im1, im2, im3], axis=2)[np.newaxis, ..., np.newaxis]
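    |     # The three resized HxWx3 views are stacked along the channel axis and
    |     # given batch and frame-stack dims, i.e. shape (1, H, W, 9, 1) with
    |     # H = W = variant.resize_image; these per-camera 3-channel groups are what
    |     # `num_cameras` indexes in the color-jitter step of `_update_jit`.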
351 | return img_all
--------------------------------------------------------------------------------
/jaxrl2/agents/pixel_sac/pixel_sac_learner.py:
--------------------------------------------------------------------------------
1 | """Implementations of algorithms for continuous control."""
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | from flax.training import checkpoints
5 | import pathlib
6 | import matplotlib.pyplot as plt
7 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8 |
9 | import numpy as np
10 | import copy
11 | import functools
12 | from typing import Dict, Optional, Sequence, Tuple, Union
13 |
14 | import jax
15 | import jax.numpy as jnp
16 | import optax
17 | from flax.core.frozen_dict import FrozenDict
18 | from flax.training import train_state
19 | from typing import Any
20 |
21 | from jaxrl2.agents.agent import Agent
22 | from jaxrl2.data.augmentations import batched_random_crop, color_transform
23 | from jaxrl2.networks.encoders.networks import Encoder, PixelMultiplexer
24 | from jaxrl2.networks.encoders.impala_encoder import ImpalaEncoder, SmallerImpalaEncoder
25 | from jaxrl2.networks.encoders.resnet_encoderv1 import ResNet18, ResNet34, ResNetSmall
26 | from jaxrl2.networks.encoders.resnet_encoderv2 import ResNetV2Encoder
27 | from jaxrl2.agents.pixel_sac.actor_updater import update_actor
28 | from jaxrl2.agents.pixel_sac.critic_updater import update_critic
29 | from jaxrl2.agents.pixel_sac.temperature_updater import update_temperature
30 | from jaxrl2.agents.pixel_sac.temperature import Temperature
31 | from jaxrl2.data.dataset import DatasetDict
32 | from jaxrl2.networks.learned_std_normal_policy import LearnedStdTanhNormalPolicy
33 | from jaxrl2.networks.values import StateActionEnsemble
34 | from jaxrl2.types import Params, PRNGKey
35 | from jaxrl2.utils.target_update import soft_target_update
36 |
37 |
38 | class TrainState(train_state.TrainState):
39 | batch_stats: Any
40 |
41 | @functools.partial(jax.jit, static_argnames=('critic_reduction', 'color_jitter', 'aug_next', 'num_cameras'))
42 | def _update_jit(
43 | rng: PRNGKey, actor: TrainState, critic: TrainState,
44 |     target_critic_params: Params, temp: TrainState, batch: FrozenDict,
45 | discount: float, tau: float, target_entropy: float,
46 | critic_reduction: str, color_jitter: bool, aug_next: bool, num_cameras: int,
47 | ) -> Tuple[PRNGKey, TrainState, TrainState, Params, TrainState, Dict[str,float]]:
48 | aug_pixels = batch['observations']['pixels']
49 | aug_next_pixels = batch['next_observations']['pixels']
50 | if batch['observations']['pixels'].squeeze().ndim != 2:
51 | rng, key = jax.random.split(rng)
52 | aug_pixels = batched_random_crop(key, batch['observations']['pixels'])
53 |
54 | if color_jitter:
55 | rng, key = jax.random.split(rng)
56 | if num_cameras > 1:
57 | for i in range(num_cameras):
58 | aug_pixels = aug_pixels.at[:,:,:,i*3:(i+1)*3].set((color_transform(key, aug_pixels[:,:,:,i*3:(i+1)*3].astype(jnp.float32)/255.)*255).astype(jnp.uint8))
59 | else:
60 | aug_pixels = (color_transform(key, aug_pixels.astype(jnp.float32)/255.)*255).astype(jnp.uint8)
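    |             # Pixels from multiple cameras are concatenated along the channel
    |             # axis (3 channels per camera), so `color_transform`, which expects
    |             # 3-channel images, is applied to each 3-channel slice in turn.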
61 |
62 | observations = batch['observations'].copy(add_or_replace={'pixels': aug_pixels})
63 | batch = batch.copy(add_or_replace={'observations': observations})
64 |
65 | key, rng = jax.random.split(rng)
66 | if aug_next:
67 | rng, key = jax.random.split(rng)
68 | aug_next_pixels = batched_random_crop(key, batch['next_observations']['pixels'])
69 | if color_jitter:
70 | rng, key = jax.random.split(rng)
71 | if num_cameras > 1:
72 | for i in range(num_cameras):
73 | aug_next_pixels = aug_next_pixels.at[:,:,:,i*3:(i+1)*3].set((color_transform(key, aug_next_pixels[:,:,:,i*3:(i+1)*3].astype(jnp.float32)/255.)*255).astype(jnp.uint8))
74 | else:
75 | aug_next_pixels = (color_transform(key, aug_next_pixels.astype(jnp.float32)/255.)*255).astype(jnp.uint8)
76 | next_observations = batch['next_observations'].copy(
77 | add_or_replace={'pixels': aug_next_pixels})
78 | batch = batch.copy(add_or_replace={'next_observations': next_observations})
79 |
80 | key, rng = jax.random.split(rng)
81 | target_critic = critic.replace(params=target_critic_params)
82 | new_critic, critic_info = update_critic(key, actor, critic, target_critic, temp, batch, discount, critic_reduction=critic_reduction)
83 | new_target_critic_params = soft_target_update(new_critic.params, target_critic_params, tau)
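    |     # Polyak averaging of the target critic; `soft_target_update` presumably
    |     # implements target <- tau * online + (1 - tau) * target, with tau = 0.005
    |     # by default.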
84 |
85 | key, rng = jax.random.split(rng)
86 | new_actor, actor_info = update_actor(key, actor, new_critic, temp, batch, critic_reduction=critic_reduction)
87 | new_temp, alpha_info = update_temperature(temp, actor_info['entropy'], target_entropy)
88 |
89 | return rng, new_actor, new_critic, new_target_critic_params, new_temp, {
90 | **critic_info,
91 | **actor_info,
92 | **alpha_info
93 | }
94 |
95 |
96 | class PixelSACLearner(Agent):
97 |
98 | def __init__(self,
99 | seed: int,
100 | observations: Union[jnp.ndarray, DatasetDict],
101 | actions: jnp.ndarray,
102 | actor_lr: float = 3e-4,
103 | critic_lr: float = 3e-4,
104 | temp_lr: float = 3e-4,
105 | decay_steps: Optional[int] = None,
106 | hidden_dims: Sequence[int] = (256, 256),
107 | cnn_features: Sequence[int] = (32, 32, 32, 32),
108 | cnn_strides: Sequence[int] = (2, 1, 1, 1),
109 | cnn_padding: str = 'VALID',
110 | latent_dim: int = 50,
111 | discount: float = 0.99,
112 | tau: float = 0.005,
113 | critic_reduction: str = 'mean',
114 | dropout_rate: Optional[float] = None,
115 | encoder_type='resnet_34_v1',
116 | encoder_norm='group',
117 | color_jitter = True,
118 | use_spatial_softmax=True,
119 | softmax_temperature=1,
120 | aug_next=True,
121 | use_bottleneck=True,
122 | init_temperature: float = 1.0,
123 | num_qs: int = 2,
124 |                  target_entropy: Optional[float] = None,
125 | action_magnitude: float = 1.0,
126 | num_cameras: int = 1
127 | ):
128 | """
129 |         An implementation of Soft Actor-Critic (SAC), as described in https://arxiv.org/abs/1812.05905, operating on pixel (and optional state) observations.
130 | """
131 |
132 | self.aug_next=aug_next
133 | self.color_jitter = color_jitter
134 | self.num_cameras = num_cameras
135 |
136 | self.action_dim = np.prod(actions.shape[-2:])
137 | self.action_chunk_shape = actions.shape[-2:]
138 |
139 | self.tau = tau
140 | self.discount = discount
141 | self.critic_reduction = critic_reduction
142 |
143 | rng = jax.random.PRNGKey(seed)
144 | rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4)
145 |
146 | if encoder_type == 'small':
147 | encoder_def = Encoder(cnn_features, cnn_strides, cnn_padding)
148 | elif encoder_type == 'impala':
149 | print('using impala')
150 | encoder_def = ImpalaEncoder()
151 | elif encoder_type == 'impala_small':
152 | print('using impala small')
153 | encoder_def = SmallerImpalaEncoder()
154 | elif encoder_type == 'resnet_small':
155 | encoder_def = ResNetSmall(norm=encoder_norm, use_spatial_softmax=use_spatial_softmax, softmax_temperature=softmax_temperature)
156 | elif encoder_type == 'resnet_18_v1':
157 | encoder_def = ResNet18(norm=encoder_norm, use_spatial_softmax=use_spatial_softmax, softmax_temperature=softmax_temperature)
158 | elif encoder_type == 'resnet_34_v1':
159 | encoder_def = ResNet34(norm=encoder_norm, use_spatial_softmax=use_spatial_softmax, softmax_temperature=softmax_temperature)
160 | elif encoder_type == 'resnet_small_v2':
161 | encoder_def = ResNetV2Encoder(stage_sizes=(1, 1, 1, 1), norm=encoder_norm)
162 | elif encoder_type == 'resnet_18_v2':
163 | encoder_def = ResNetV2Encoder(stage_sizes=(2, 2, 2, 2), norm=encoder_norm)
164 | elif encoder_type == 'resnet_34_v2':
165 | encoder_def = ResNetV2Encoder(stage_sizes=(3, 4, 6, 3), norm=encoder_norm)
166 | else:
167 | raise ValueError('encoder type not found!')
168 |
169 | if decay_steps is not None:
170 | actor_lr = optax.cosine_decay_schedule(actor_lr, decay_steps)
171 |
172 | if len(hidden_dims) == 1:
173 | hidden_dims = (hidden_dims[0], hidden_dims[0], hidden_dims[0])
174 |
175 | policy_def = LearnedStdTanhNormalPolicy(hidden_dims, self.action_dim, dropout_rate=dropout_rate, low=-action_magnitude, high=action_magnitude)
176 |
177 | actor_def = PixelMultiplexer(encoder=encoder_def,
178 | network=policy_def,
179 | latent_dim=latent_dim,
180 | use_bottleneck=use_bottleneck
181 | )
182 | print(actor_def)
183 | actor_def_init = actor_def.init(actor_key, observations)
184 | actor_params = actor_def_init['params']
185 | actor_batch_stats = actor_def_init['batch_stats'] if 'batch_stats' in actor_def_init else None
186 |
187 | actor = TrainState.create(apply_fn=actor_def.apply,
188 | params=actor_params,
189 | tx=optax.adam(learning_rate=actor_lr),
190 | batch_stats=actor_batch_stats)
191 |
192 | critic_def = StateActionEnsemble(hidden_dims, num_qs=num_qs)
193 | critic_def = PixelMultiplexer(encoder=encoder_def,
194 | network=critic_def,
195 | latent_dim=latent_dim,
196 | use_bottleneck=use_bottleneck
197 | )
198 | print(critic_def)
199 | critic_def_init = critic_def.init(critic_key, observations, actions)
200 | self._critic_init_params = critic_def_init['params']
201 |
202 | critic_params = critic_def_init['params']
203 | critic_batch_stats = critic_def_init['batch_stats'] if 'batch_stats' in critic_def_init else None
204 | critic = TrainState.create(apply_fn=critic_def.apply,
205 | params=critic_params,
206 | tx=optax.adam(learning_rate=critic_lr),
207 | batch_stats=critic_batch_stats
208 | )
209 | target_critic_params = copy.deepcopy(critic_params)
210 |
211 | temp_def = Temperature(init_temperature)
212 | temp_params = temp_def.init(temp_key)['params']
213 | temp = TrainState.create(apply_fn=temp_def.apply,
214 | params=temp_params,
215 | tx=optax.adam(learning_rate=temp_lr),
216 | batch_stats=None)
217 |
218 |
219 | self._rng = rng
220 | self._actor = actor
221 | self._critic = critic
222 | self._target_critic_params = target_critic_params
223 | self._temp = temp
224 | if target_entropy is None or target_entropy == 'auto':
225 | self.target_entropy = -self.action_dim / 2
226 | else:
227 | self.target_entropy = float(target_entropy)
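    |         # Heuristic default: standard SAC uses -|A| as the target entropy; here
    |         # |A| is the flattened noise-chunk dimension and the target is halved,
    |         # presumably as a softer entropy constraint for chunked actions.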
228 | print(f'target_entropy: {self.target_entropy}')
229 | print(self.critic_reduction)
230 |
231 |
232 | def update(self, batch: FrozenDict) -> Dict[str, float]:
233 | new_rng, new_actor, new_critic, new_target_critic, new_temp, info = _update_jit(
234 | self._rng, self._actor, self._critic, self._target_critic_params, self._temp, batch, self.discount, self.tau, self.target_entropy, self.critic_reduction, self.color_jitter, self.aug_next, self.num_cameras
235 | )
236 |
237 | self._rng = new_rng
238 | self._actor = new_actor
239 | self._critic = new_critic
240 | self._target_critic_params = new_target_critic
241 | self._temp = new_temp
242 | return info
243 |
244 | def perform_eval(self, variant, i, wandb_logger, eval_buffer, eval_buffer_iterator, eval_env):
245 | from examples.train_utils_sim import make_multiple_value_reward_visulizations
246 | make_multiple_value_reward_visulizations(self, variant, i, eval_buffer, wandb_logger)
247 |
248 | def make_value_reward_visulization(self, variant, trajs):
249 | num_traj = len(trajs['rewards'])
250 | traj_images = []
251 |
252 | for itraj in range(num_traj):
253 | observations = trajs['observations'][itraj]
254 | next_observations = trajs['next_observations'][itraj]
255 | actions = trajs['actions'][itraj]
256 | rewards = trajs['rewards'][itraj]
257 | masks = trajs['masks'][itraj]
258 |
259 | q_pred = []
260 |
261 | for t in range(0, len(actions)):
262 | action = actions[t][None]
263 | obs_pixels = observations['pixels'][t]
264 | next_obs_pixels = next_observations['pixels'][t]
265 |
266 | obs_dict = {'pixels': obs_pixels[None]}
267 | for k, v in observations.items():
268 | if 'pixels' not in k:
269 | obs_dict[k] = v[t][None]
270 | next_obs_dict = {'pixels': next_obs_pixels[None]}
271 | for k, v in next_observations.items():
272 | if 'pixels' not in k:
273 | next_obs_dict[k] = v[t][None]
274 |
275 | q_value = get_value(action, obs_dict, self._critic)
276 | q_pred.append(q_value)
277 |
278 | traj_images.append(make_visual(q_pred, rewards, masks, observations['pixels']))
279 | print('finished reward value visuals.')
280 | return np.concatenate(traj_images, 0)
281 |
282 | @property
283 | def _save_dict(self):
284 | save_dict = {
285 | 'critic': self._critic,
286 | 'target_critic_params': self._target_critic_params,
287 | 'actor': self._actor,
288 | 'temp': self._temp
289 | }
290 | return save_dict
291 |
292 | def restore_checkpoint(self, dir):
293 | assert pathlib.Path(dir).exists(), f"Checkpoint {dir} does not exist."
294 | output_dict = checkpoints.restore_checkpoint(dir, self._save_dict)
295 | self._actor = output_dict['actor']
296 | self._critic = output_dict['critic']
297 | self._target_critic_params = output_dict['target_critic_params']
298 | self._temp = output_dict['temp']
299 | print('restored from ', dir)
300 |
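    |
    | # Illustrative usage sketch (not executed; the shapes below are assumptions
    | # chosen for illustration -- in practice they come from the replay buffer and
    | # the robot/task config):
    | #
    | #   import numpy as np
    | #   obs = {'pixels': np.zeros((1, 128, 128, 9, 1), dtype=np.uint8),   # 3 cameras
    | #          'state': np.zeros((1, 2056, 1), dtype=np.float32)}
    | #   acts = np.zeros((1, 10, 8), dtype=np.float32)                     # noise chunk
    | #   learner = PixelSACLearner(seed=0, observations=obs, actions=acts,
    | #                             num_cameras=3)
    | #   noise_chunk = learner.sample_actions(obs)   # flattened noise chunk
    | #   info = learner.update(batch)                # batch from replay_buffer.get_iterator(...)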
301 |
302 | @jax.jit
303 | def get_value(action, observation, critic):
304 | input_collections = {'params': critic.params}
305 | q_pred = critic.apply_fn(input_collections, observation, action)
306 | return q_pred
307 |
308 |
309 | def np_unstack(array, axis):
310 | arr = np.split(array, array.shape[axis], axis)
311 | arr = [a.squeeze() for a in arr]
312 | return arr
313 |
314 | def make_visual(q_estimates, rewards, masks, images):
315 |
316 | q_estimates_np = np.stack(q_estimates, 0).squeeze()
317 | fig, axs = plt.subplots(4, 1, figsize=(8, 12))
318 | canvas = FigureCanvas(fig)
319 | plt.xlim([0, len(q_estimates_np)])
320 |
321 | assert len(images.shape) == 5
322 | images = images[..., -1] # only taking the most recent image of the stack
323 | assert images.shape[-1] == 3
324 |
325 | interval = max(1, images.shape[0] // 4)
326 | sel_images = images[::interval]
327 | sel_images = np.concatenate(np_unstack(sel_images, 0), 1)
328 |
329 | axs[0].imshow(sel_images)
330 | if len(q_estimates_np.shape) == 2:
331 | for i in range(q_estimates_np.shape[1]):
332 | axs[1].plot(q_estimates_np[:, i], linestyle='--', marker='o')
333 | else:
334 | axs[1].plot(q_estimates_np, linestyle='--', marker='o')
335 | axs[1].set_ylabel('q values')
336 | axs[2].plot(rewards, linestyle='--', marker='o')
337 | axs[2].set_ylabel('rewards')
338 | axs[2].set_xlim([0, len(rewards)])
339 |
340 | axs[3].plot(masks, linestyle='--', marker='d')
341 | axs[3].set_ylabel('masks')
342 | axs[3].set_xlim([0, len(masks)])
343 |
344 | plt.tight_layout()
345 |
346 | canvas.draw() # draw the canvas, cache the renderer
347 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
348 | out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
349 |
350 | plt.close(fig)
351 | return out_image
--------------------------------------------------------------------------------
/jaxrl2/data/augmentations.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import functools
4 | from functools import partial
5 |
6 |
7 | def random_crop(key, img, padding):
8 | crop_from = jax.random.randint(key, (2, ), 0, 2 * padding + 1)
9 | crop_from = jnp.concatenate([crop_from, jnp.zeros((2, ), dtype=jnp.int32)])
10 | padded_img = jnp.pad(img, ((padding, padding), (padding, padding), (0, 0),
11 | (0, 0)),
12 | mode='edge')
13 | return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)
14 |
15 |
16 | def batched_random_crop(key, imgs, padding=4):
17 | keys = jax.random.split(key, imgs.shape[0])
18 | return jax.vmap(random_crop, (0, 0, None))(keys, imgs, padding)
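    | # The crop is a random spatial shift of up to `padding` pixels: each image is
    | # edge-padded and a window of the original size is sliced back out, so with the
    | # default padding=4 every (H, W, C, S) image in the batch receives an
    | # independent shift drawn from {-4, ..., 4} along both H and W.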
19 |
20 | @partial(jax.pmap, axis_name='pmap', static_broadcasted_argnums=(1))
21 | def batched_random_crop_parallel(key, imgs, padding):
22 | keys = jax.random.split(key, imgs.shape[0])
23 | return jax.vmap(random_crop, (0, 0, None))(keys, imgs, padding)
24 |
25 |
26 |
27 | def _maybe_apply(apply_fn, inputs, rng, apply_prob):
28 | should_apply = jax.random.uniform(rng, shape=()) <= apply_prob
29 | return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x)
30 |
31 |
32 | def _depthwise_conv2d(inputs, kernel, strides, padding):
33 | """Computes a depthwise conv2d in Jax.
34 | Args:
35 | inputs: an NHWC tensor with N=1.
36 | kernel: a [H", W", 1, C] tensor.
37 | strides: a 2d tensor.
38 | padding: "SAME" or "VALID".
39 | Returns:
40 | The depthwise convolution of inputs with kernel, as [H, W, C].
41 | """
42 | return jax.lax.conv_general_dilated(
43 | inputs,
44 | kernel,
45 | strides,
46 | padding,
47 | feature_group_count=inputs.shape[-1],
48 | dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
49 |
50 |
51 | def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
52 | """Applies gaussian blur to a single image, given as NHWC with N=1."""
53 | radius = int(kernel_size / 2)
54 | kernel_size_ = 2 * radius + 1
55 | x = jnp.arange(-radius, radius + 1).astype(jnp.float32)
56 | blur_filter = jnp.exp(-x**2 / (2. * sigma**2))
57 | blur_filter = blur_filter / jnp.sum(blur_filter)
58 | blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1])
59 | blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1])
60 | num_channels = image.shape[-1]
61 | blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels])
62 | blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
63 | expand_batch_dim = len(image.shape) == 3
64 | if expand_batch_dim:
65 | image = image[jnp.newaxis, ...]
66 | blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
67 | blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
68 | blurred = jnp.squeeze(blurred, axis=0)
69 | return blurred
70 |
71 |
72 | def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min,
73 | sigma_max, apply_prob):
74 | """Applies a random gaussian blur."""
75 | apply_rng, transform_rng = jax.random.split(rng)
76 |
77 | def _apply(image):
78 | sigma_rng, = jax.random.split(transform_rng, 1)
79 | sigma = jax.random.uniform(
80 | sigma_rng,
81 | shape=(),
82 | minval=sigma_min,
83 | maxval=sigma_max,
84 | dtype=jnp.float32)
85 | return _gaussian_blur_single_image(image, kernel_size, padding, sigma)
86 |
87 | return _maybe_apply(_apply, image, apply_rng, apply_prob)
88 |
89 |
90 | def rgb_to_hsv(r, g, b):
91 | """Converts R, G, B values to H, S, V values.
92 | Reference TF implementation:
93 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
94 | Only input values between 0 and 1 are guaranteed to work properly, but this
95 | function complies with the TF implementation outside of this range.
96 | Args:
97 | r: A tensor representing the red color component as floats.
98 | g: A tensor representing the green color component as floats.
99 | b: A tensor representing the blue color component as floats.
100 | Returns:
101 | H, S, V values, each as tensors of shape [...] (same as the input without
102 | the last dimension).
103 | """
104 | vv = jnp.maximum(jnp.maximum(r, g), b)
105 | range_ = vv - jnp.minimum(jnp.minimum(r, g), b)
106 | sat = jnp.where(vv > 0, range_ / vv, 0.)
107 | norm = jnp.where(range_ != 0, 1. / (6. * range_), 1e9)
108 |
109 | hr = norm * (g - b)
110 | hg = norm * (b - r) + 2. / 6.
111 | hb = norm * (r - g) + 4. / 6.
112 |
113 | hue = jnp.where(r == vv, hr, jnp.where(g == vv, hg, hb))
114 | hue = hue * (range_ > 0)
115 | hue = hue + (hue < 0)
116 |
117 | return hue, sat, vv
118 |
119 |
120 | def hsv_to_rgb(h, s, v):
121 | """Converts H, S, V values to an R, G, B tuple.
122 | Reference TF implementation:
123 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
124 | Only input values between 0 and 1 are guaranteed to work properly, but this
125 | function complies with the TF implementation outside of this range.
126 | Args:
127 | h: A float tensor of arbitrary shape for the hue (0-1 values).
128 | s: A float tensor of the same shape for the saturation (0-1 values).
129 | v: A float tensor of the same shape for the value channel (0-1 values).
130 | Returns:
131 | An (r, g, b) tuple, each with the same dimension as the inputs.
132 | """
133 | c = s * v
134 | m = v - c
135 | dh = (h % 1.) * 6.
136 | fmodu = dh % 2.
137 | x = c * (1 - jnp.abs(fmodu - 1))
138 | hcat = jnp.floor(dh).astype(jnp.int32)
139 | rr = jnp.where(
140 | (hcat == 0) | (hcat == 5), c, jnp.where(
141 | (hcat == 1) | (hcat == 4), x, 0)) + m
142 | gg = jnp.where(
143 | (hcat == 1) | (hcat == 2), c, jnp.where(
144 | (hcat == 0) | (hcat == 3), x, 0)) + m
145 | bb = jnp.where(
146 | (hcat == 3) | (hcat == 4), c, jnp.where(
147 | (hcat == 2) | (hcat == 5), x, 0)) + m
148 | return rr, gg, bb
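    | # Round-trip example: pure red maps to hue 0 with full saturation and value,
    | # i.e. rgb_to_hsv(1., 0., 0.) -> (0., 1., 1.) and hsv_to_rgb(0., 1., 1.) ->
    | # (1., 0., 0.), up to float precision.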
149 |
150 |
151 | def adjust_brightness(rgb_tuple, delta):
152 | return jax.tree_util.tree_map(lambda x: x + delta, rgb_tuple)
153 |
154 |
155 | def adjust_contrast(image, factor):
156 | def _adjust_contrast_channel(channel):
157 | mean = jnp.mean(channel, axis=(-2, -1), keepdims=True)
158 | return factor * (channel - mean) + mean
159 | return jax.tree_util.tree_map(_adjust_contrast_channel, image)
160 |
161 |
162 | def adjust_saturation(h, s, v, factor):
163 | return h, jnp.clip(s * factor, 0., 1.), v
164 |
165 |
166 | def adjust_hue(h, s, v, delta):
167 | # Note: this method exactly matches TF"s adjust_hue (combined with the hsv/rgb
168 | # conversions) when running on GPU. When running on CPU, the results will be
169 | # different if all RGB values for a pixel are outside of the [0, 1] range.
170 | return (h + delta) % 1.0, s, v
171 |
172 |
173 | def _random_brightness(rgb_tuple, rng, max_delta):
174 | delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
175 | return adjust_brightness(rgb_tuple, delta)
176 |
177 |
178 | def _random_contrast(rgb_tuple, rng, max_delta):
179 | factor = jax.random.uniform(
180 | rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
181 | return adjust_contrast(rgb_tuple, factor)
182 |
183 |
184 | def _random_saturation(rgb_tuple, rng, max_delta):
185 | h, s, v = rgb_to_hsv(*rgb_tuple)
186 | factor = jax.random.uniform(
187 | rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
188 | return hsv_to_rgb(*adjust_saturation(h, s, v, factor))
189 |
190 |
191 | def _random_hue(rgb_tuple, rng, max_delta):
192 | h, s, v = rgb_to_hsv(*rgb_tuple)
193 | delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
194 | return hsv_to_rgb(*adjust_hue(h, s, v, delta))
195 |
196 |
197 | def _to_grayscale(image):
198 | rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
199 | grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis]
200 | return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels.
201 |
202 |
203 | def _color_transform_single_image(image, rng, brightness, contrast, saturation,
204 | hue, to_grayscale_prob, color_jitter_prob,
205 | apply_prob, shuffle):
206 | """Applies color jittering to a single image."""
207 | apply_rng, transform_rng = jax.random.split(rng)
208 | perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
209 | transform_rng, 7)
210 |
211 | # Whether the transform should be applied at all.
212 | should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
213 | # Whether to apply grayscale transform.
214 | should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
215 | # Whether to apply color jittering.
216 | should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob
217 |
218 | # Decorator to conditionally apply fn based on an index.
219 | def _make_cond(fn, idx):
220 |
221 | def identity_fn(x, unused_rng, unused_param):
222 | return x
223 |
224 | def cond_fn(args, i):
225 | def clip(args):
226 | return jax.tree_util.tree_map(lambda arg: jnp.clip(arg, 0., 1.), args)
227 | out = jax.lax.cond(should_apply & should_apply_color & (i == idx), args,
228 | lambda a: clip(fn(*a)), args,
229 | lambda a: identity_fn(*a))
230 | return jax.lax.stop_gradient(out)
231 |
232 | return cond_fn
233 |
234 | random_brightness_cond = _make_cond(_random_brightness, idx=0)
235 | random_contrast_cond = _make_cond(_random_contrast, idx=1)
236 | random_saturation_cond = _make_cond(_random_saturation, idx=2)
237 | random_hue_cond = _make_cond(_random_hue, idx=3)
238 |
239 | def _color_jitter(x):
240 | rgb_tuple = tuple(jax.tree_util.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1)))
241 | if shuffle:
242 | order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
243 | else:
244 | order = range(4)
245 | for idx in order:
246 | if brightness > 0:
247 | rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness), idx)
248 | if contrast > 0:
249 | rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast), idx)
250 | if saturation > 0:
251 | rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation), idx)
252 | if hue > 0:
253 | rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), idx)
254 | return jnp.stack(rgb_tuple, axis=-1)
255 |
256 | out_apply = _color_jitter(image)
257 | out_apply = jax.lax.cond(should_apply & should_apply_gs, out_apply,
258 | _to_grayscale, out_apply, lambda x: x)
259 | return jnp.clip(out_apply, 0., 1.)
260 |
261 |
262 | def _random_flip_single_image(image, rng):
263 | _, flip_rng = jax.random.split(rng)
264 | should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
265 | image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x)
266 | return image
267 |
268 |
269 | def random_flip(images, rng):
270 | rngs = jax.random.split(rng, images.shape[0])
271 | return jax.vmap(_random_flip_single_image)(images, rngs)
272 |
273 |
274 | def color_transform(rng,
275 | images,
276 | brightness=0.2,
277 | contrast=0.1,
278 | saturation=0.1,
279 | hue=0.03,
280 | color_jitter_prob=0.8,
281 | to_grayscale_prob=0.0,
282 | apply_prob=1.0,
283 | shuffle=True):
284 | """Applies color jittering and/or grayscaling to a batch of images.
285 | Args:
286 | images: an NHWC tensor, with C=3.
287 | rng: a single PRNGKey.
288 | brightness: the range of jitter on brightness.
289 | contrast: the range of jitter on contrast.
290 | saturation: the range of jitter on saturation.
291 | hue: the range of jitter on hue.
292 | color_jitter_prob: the probability of applying color jittering.
293 | to_grayscale_prob: the probability of converting the image to grayscale.
294 | apply_prob: the probability of applying the transform to a batch element.
295 | shuffle: whether to apply the transforms in a random order.
296 | Returns:
297 | A NHWC tensor of the transformed images.
298 | """
299 | images = images[:, :, :, :, 0]
300 | rngs = jax.random.split(rng, images.shape[0])
301 | jitter_fn = functools.partial(
302 | _color_transform_single_image,
303 | brightness=brightness,
304 | contrast=contrast,
305 | saturation=saturation,
306 | hue=hue,
307 | color_jitter_prob=color_jitter_prob,
308 | to_grayscale_prob=to_grayscale_prob,
309 | apply_prob=apply_prob,
310 | shuffle=shuffle)
311 | augmented_images = jax.vmap(jitter_fn)(images, rngs)
312 | return augmented_images[..., jnp.newaxis]
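    | # Usage note: `images` is expected as (N, H, W, 3, 1) floats in [0, 1]; the
    | # trailing frame-stack dim is stripped before jittering and restored on the
    | # way out. A minimal sketch of how the learner applies it to uint8 pixels:
    | #
    | #   aug = (color_transform(key, pixels.astype(jnp.float32) / 255.) * 255
    | #          ).astype(jnp.uint8)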
313 |
314 | @partial(jax.pmap, axis_name='pmap')
315 | def color_transform_parallel(rng, images):
316 | """Applies color jittering and/or grayscaling to a batch of images.
317 | Args:
318 | images: an NHWC tensor, with C=3.
319 | rng: a single PRNGKey.
320 | brightness: the range of jitter on brightness.
321 | contrast: the range of jitter on contrast.
322 | saturation: the range of jitter on saturation.
323 | hue: the range of jitter on hue.
324 | color_jitter_prob: the probability of applying color jittering.
325 | to_grayscale_prob: the probability of converting the image to grayscale.
326 | apply_prob: the probability of applying the transform to a batch element.
327 | shuffle: whether to apply the transforms in a random order.
328 | Returns:
329 | A NHWC tensor of the transformed images.
330 | """
331 |     brightness = 0.2
332 |     contrast = 0.1
333 |     saturation = 0.1
334 |     hue = 0.03
335 |     color_jitter_prob = 0.8
336 |     to_grayscale_prob = 0.0
337 |     apply_prob = 1.0
338 |     shuffle = True
339 | images = images[:, :, :, :, 0]
340 | rngs = jax.random.split(rng, images.shape[0])
341 | jitter_fn = functools.partial(
342 | _color_transform_single_image,
343 | brightness=brightness,
344 | contrast=contrast,
345 | saturation=saturation,
346 | hue=hue,
347 | color_jitter_prob=color_jitter_prob,
348 | to_grayscale_prob=to_grayscale_prob,
349 | apply_prob=apply_prob,
350 | shuffle=shuffle)
351 | augmented_images = jax.vmap(jitter_fn)(images, rngs)
352 | return augmented_images[..., jnp.newaxis]
353 |
354 |
355 | def gaussian_blur(images,
356 | rng,
357 | blur_divider=10.,
358 | sigma_min=0.1,
359 | sigma_max=2.0,
360 | apply_prob=1.0):
361 | """Applies gaussian blur to a batch of images.
362 | Args:
363 | images: an NHWC tensor, with C=3.
364 | rng: a single PRNGKey.
365 | blur_divider: the blurring kernel will have size H / blur_divider.
366 | sigma_min: the minimum value for sigma in the blurring kernel.
367 | sigma_max: the maximum value for sigma in the blurring kernel.
368 | apply_prob: the probability of applying the transform to a batch element.
369 | Returns:
370 | A NHWC tensor of the blurred images.
371 | """
372 | rngs = jax.random.split(rng, images.shape[0])
373 | kernel_size = images.shape[1] / blur_divider
374 | blur_fn = functools.partial(
375 | _random_gaussian_blur,
376 | kernel_size=kernel_size,
377 | padding='SAME',
378 | sigma_min=sigma_min,
379 | sigma_max=sigma_max,
380 | apply_prob=apply_prob)
381 | return jax.vmap(blur_fn)(images, rngs)
382 |
383 |
384 | def _solarize_single_image(image, rng, threshold, apply_prob):
385 |
386 | def _apply(image):
387 | return jnp.where(image < threshold, image, 1. - image)
388 |
389 | return _maybe_apply(_apply, image, rng, apply_prob)
390 |
391 |
392 | def solarize(images, rng, threshold=0.5, apply_prob=1.0):
393 | """Applies solarization.
394 | Args:
395 | images: an NHWC tensor (with C=3).
396 | rng: a single PRNGKey.
397 | threshold: the solarization threshold.
398 | apply_prob: the probability of applying the transform to a batch element.
399 | Returns:
400 | A NHWC tensor of the transformed images.
401 | """
402 | rngs = jax.random.split(rng, images.shape[0])
403 | solarize_fn = functools.partial(
404 | _solarize_single_image, threshold=threshold, apply_prob=apply_prob)
405 | return jax.vmap(solarize_fn)(images, rngs)
406 |
--------------------------------------------------------------------------------
/examples/train_utils_sim.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 | import wandb
4 | import jax
5 | from openpi_client import image_tools
6 | import math
7 | import PIL
8 |
9 | def _quat2axisangle(quat):
10 | """
11 | Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
12 | """
13 | # clip quaternion
14 | if quat[3] > 1.0:
15 | quat[3] = 1.0
16 | elif quat[3] < -1.0:
17 | quat[3] = -1.0
18 |
19 | den = np.sqrt(1.0 - quat[3] * quat[3])
20 | if math.isclose(den, 0.0):
21 | # This is (close to) a zero degree rotation, immediately return
22 | return np.zeros(3)
23 |
24 | return (quat[:3] * 2.0 * math.acos(quat[3])) / den
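    | # Example: the identity quaternion (x, y, z, w) = (0, 0, 0, 1) has w = 1, so the
    | # denominator is ~0 and the function returns the zero rotation vector.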
25 |
26 | def obs_to_img(obs, variant):
27 | '''
28 | Convert raw observation to resized image for DSRL actor/critic
29 | '''
30 | if variant.env == 'libero':
31 | curr_image = obs["agentview_image"][::-1, ::-1]
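    |         # The [::-1, ::-1] flips the agentview render vertically and
    |         # horizontally (a 180-degree rotation); the raw LIBERO/robosuite render
    |         # appears to be upside down relative to the usual image convention.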
32 | elif variant.env == 'aloha_cube':
33 | curr_image = obs["pixels"]["top"]
34 | else:
35 | raise NotImplementedError()
36 | if variant.resize_image > 0:
37 | curr_image = np.array(PIL.Image.fromarray(curr_image).resize((variant.resize_image, variant.resize_image)))
38 | return curr_image
39 |
40 | def obs_to_pi_zero_input(obs, variant):
41 | if variant.env == 'libero':
42 | img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
43 | wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
44 | img = image_tools.convert_to_uint8(
45 | image_tools.resize_with_pad(img, 224, 224)
46 | )
47 | wrist_img = image_tools.convert_to_uint8(
48 | image_tools.resize_with_pad(wrist_img, 224, 224)
49 | )
50 |
51 | obs_pi_zero = {
52 | "observation/image": img,
53 | "observation/wrist_image": wrist_img,
54 | "observation/state": np.concatenate(
55 | (
56 | obs["robot0_eef_pos"],
57 | _quat2axisangle(obs["robot0_eef_quat"]),
58 | obs["robot0_gripper_qpos"],
59 | )
60 | ),
61 | "prompt": str(variant.task_description),
62 | }
63 | elif variant.env == 'aloha_cube':
64 | img = np.ascontiguousarray(obs["pixels"]["top"])
65 | img = image_tools.convert_to_uint8(
66 | image_tools.resize_with_pad(img, 224, 224)
67 | )
68 | obs_pi_zero = {
69 | "state": obs["agent_pos"],
70 | "images": {"cam_high": np.transpose(img, (2,0,1))}
71 | }
72 | else:
73 | raise NotImplementedError()
74 | return obs_pi_zero
75 |
76 | def obs_to_qpos(obs, variant):
77 | if variant.env == 'libero':
78 | qpos = np.concatenate(
79 | (
80 | obs["robot0_eef_pos"],
81 | _quat2axisangle(obs["robot0_eef_quat"]),
82 | obs["robot0_gripper_qpos"],
83 | )
84 | )
85 | elif variant.env == 'aloha_cube':
86 | qpos = obs["agent_pos"]
87 | else:
88 | raise NotImplementedError()
89 | return qpos
90 |
91 | def trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger,
92 | perform_control_evals=True, shard_fn=None, agent_dp=None):
93 | replay_buffer_iterator = replay_buffer.get_iterator(variant.batch_size)
94 | if shard_fn is not None:
95 | replay_buffer_iterator = map(shard_fn, replay_buffer_iterator)
96 |
97 | total_env_steps = 0
98 | i = 0
99 | wandb_logger.log({'num_online_samples': 0}, step=i)
100 | wandb_logger.log({'num_online_trajs': 0}, step=i)
101 | wandb_logger.log({'env_steps': 0}, step=i)
102 |
103 | with tqdm(total=variant.max_steps, initial=0) as pbar:
104 | while i <= variant.max_steps:
105 | traj = collect_traj(variant, agent, env, i, agent_dp)
106 | traj_id = online_replay_buffer._traj_counter
107 | add_online_data_to_buffer(variant, traj, online_replay_buffer)
108 | total_env_steps += traj['env_steps']
109 | print('online buffer timesteps length:', len(online_replay_buffer))
110 | print('online buffer num traj:', traj_id + 1)
111 | print('total env steps:', total_env_steps)
112 |
113 | if variant.get("num_online_gradsteps_batch", -1) > 0:
114 | num_gradsteps = variant.num_online_gradsteps_batch
115 | else:
116 | num_gradsteps = len(traj["rewards"])*variant.multi_grad_step
117 |
118 | if len(online_replay_buffer) > variant.start_online_updates:
119 | for _ in range(num_gradsteps):
120 |                     # perform an initial evaluation/visualization before the first update
121 | if i == 0:
122 | print('performing evaluation for initial checkpoint')
123 | if perform_control_evals:
124 | perform_control_eval(agent, eval_env, i, variant, wandb_logger, agent_dp)
125 | if hasattr(agent, 'perform_eval'):
126 | agent.perform_eval(variant, i, wandb_logger, replay_buffer, replay_buffer_iterator, eval_env)
127 |
128 |                     # perform online updates once enough online trajectories have been collected
129 | batch = next(replay_buffer_iterator)
130 | update_info = agent.update(batch)
131 |
132 | pbar.update()
133 | i += 1
134 |
135 |
136 | if i % variant.log_interval == 0:
137 | update_info = {k: jax.device_get(v) for k, v in update_info.items()}
138 | for k, v in update_info.items():
139 | if v.ndim == 0:
140 | wandb_logger.log({f'training/{k}': v}, step=i)
141 | elif v.ndim <= 2:
142 | wandb_logger.log_histogram(f'training/{k}', v, i)
143 | # wandb_logger.log({'replay_buffer_size': len(online_replay_buffer)}, i)
144 | wandb_logger.log({
145 | 'replay_buffer_size': len(online_replay_buffer),
146 | 'episode_return (exploration)': traj['episode_return'],
147 | 'is_success (exploration)': int(traj['is_success']),
148 | }, i)
149 |
150 | if i % variant.eval_interval == 0:
151 | wandb_logger.log({'num_online_samples': len(online_replay_buffer)}, step=i)
152 | wandb_logger.log({'num_online_trajs': traj_id + 1}, step=i)
153 | wandb_logger.log({'env_steps': total_env_steps}, step=i)
154 | if perform_control_evals:
155 | perform_control_eval(agent, eval_env, i, variant, wandb_logger, agent_dp)
156 | if hasattr(agent, 'perform_eval'):
157 | agent.perform_eval(variant, i, wandb_logger, replay_buffer, replay_buffer_iterator, eval_env)
158 |
159 | if variant.checkpoint_interval != -1 and i % variant.checkpoint_interval == 0:
160 | agent.save_checkpoint(variant.outputdir, i, variant.checkpoint_interval)
161 |
162 |
163 | def add_online_data_to_buffer(variant, traj, online_replay_buffer):
164 |
165 | discount_horizon = variant.query_freq
166 | actions = np.array(traj['actions']) # (T, chunk_size, action_dim )
167 | episode_len = len(actions)
168 | rewards = np.array(traj['rewards'])
169 | masks = np.array(traj['masks'])
170 |
171 | for t in range(episode_len):
172 | obs = traj['observations'][t]
173 | next_obs = traj['observations'][t + 1]
174 | # remove batch dimension
175 | obs = {k: v[0] for k, v in obs.items()}
176 | next_obs = {k: v[0] for k, v in next_obs.items()}
177 | if not variant.add_states:
178 | obs.pop('state', None)
179 | next_obs.pop('state', None)
180 |
181 | insert_dict = dict(
182 | observations=obs,
183 | next_observations=next_obs,
184 | actions=actions[t],
185 | next_actions=actions[t + 1] if t < episode_len - 1 else actions[t],
186 | rewards=rewards[t],
187 | masks=masks[t],
188 | discount=variant.discount ** discount_horizon
189 | )
190 | online_replay_buffer.insert(insert_dict)
191 | online_replay_buffer.increment_traj_counter()
192 |
193 | def collect_traj(variant, agent, env, i, agent_dp=None):
194 | query_frequency = variant.query_freq
195 | max_timesteps = variant.max_timesteps
196 | env_max_reward = variant.env_max_reward
197 |
198 | agent._rng, rng = jax.random.split(agent._rng)
199 |
200 | if 'libero' in variant.env:
201 | obs = env.reset()
202 | elif 'aloha' in variant.env:
203 | obs, _ = env.reset()
204 |
205 | image_list = [] # for visualization
206 | rewards = []
207 | action_list = []
208 | obs_list = []
209 |
210 | for t in tqdm(range(max_timesteps)):
211 | curr_image = obs_to_img(obs, variant)
212 |
213 | qpos = obs_to_qpos(obs, variant)
214 |
215 | if variant.add_states:
216 | obs_dict = {
217 | 'pixels': curr_image[np.newaxis, ..., np.newaxis],
218 | 'state': qpos[np.newaxis, ..., np.newaxis],
219 | }
220 | else:
221 | obs_dict = {
222 | 'pixels': curr_image[np.newaxis, ..., np.newaxis],
223 | }
224 |
225 | if t % query_frequency == 0:
226 |
227 | assert agent_dp is not None
228 | # we then use the noise to sample the action from diffusion model
229 | rng, key = jax.random.split(rng)
230 | obs_pi_zero = obs_to_pi_zero_input(obs, variant)
231 | if i == 0:
232 | # for initial round of data collection, we sample from standard gaussian noise
233 | noise = jax.random.normal(key, (1, *agent.action_chunk_shape))
234 | noise_repeat = jax.numpy.repeat(noise[:, -1:, :], 50 - noise.shape[1], axis=1)
235 | noise = jax.numpy.concatenate([noise, noise_repeat], axis=1)
236 | actions_noise = noise[0, :agent.action_chunk_shape[0], :]
237 | else:
238 | # sac agent predicts the noise for diffusion model
239 | actions_noise = agent.sample_actions(obs_dict)
240 | actions_noise = np.reshape(actions_noise, agent.action_chunk_shape)
241 | noise = np.repeat(actions_noise[-1:, :], 50 - actions_noise.shape[0], axis=0)
242 | noise = jax.numpy.concatenate([actions_noise, noise], axis=0)[None]
243 |
244 | actions = agent_dp.infer(obs_pi_zero, noise=noise)["actions"]
245 | action_list.append(actions_noise)
246 | obs_list.append(obs_dict)
247 |
248 | action_t = actions[t % query_frequency]
249 | if 'libero' in variant.env:
250 | obs, reward, done, _ = env.step(action_t)
251 | elif 'aloha' in variant.env:
252 | obs, reward, terminated, truncated, _ = env.step(action_t)
253 | done = terminated or truncated
254 |
255 | rewards.append(reward)
256 | image_list.append(curr_image)
257 | if done:
258 | break
259 |
260 | # add last observation
261 | curr_image = obs_to_img(obs, variant)
262 | qpos = obs_to_qpos(obs, variant)
263 | obs_dict = {
264 | 'pixels': curr_image[np.newaxis, ..., np.newaxis],
265 | 'state': qpos[np.newaxis, ..., np.newaxis],
266 | }
267 | obs_list.append(obs_dict)
268 | image_list.append(curr_image)
269 |
270 | # per episode
271 | rewards = np.array(rewards)
272 | episode_return = np.sum(rewards[rewards!=None])
273 | is_success = (reward == env_max_reward)
274 | print(f'Rollout Done: {episode_return=}, Success: {is_success}')
275 |
276 |
277 | '''
278 |     We use a sparse -1/0 reward to train the SAC agent: -1 for every action chunk until the task is solved, and 0 on the final chunk of a successful episode.
279 | '''
280 | if is_success:
281 | query_steps = len(action_list)
282 | rewards = np.concatenate([-np.ones(query_steps - 1), [0]])
283 | masks = np.concatenate([np.ones(query_steps - 1), [0]])
284 | else:
285 | query_steps = len(action_list)
286 | rewards = -np.ones(query_steps)
287 | masks = np.ones(query_steps)
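    |     # Worked example: a rollout that used 4 action chunks and succeeded gets
    |     # rewards = [-1, -1, -1, 0] and masks = [1, 1, 1, 0]; the final mask of 0
    |     # stops bootstrapping at the successful terminal transition. A failed
    |     # rollout of the same length gets rewards = [-1, -1, -1, -1] and
    |     # masks = [1, 1, 1, 1], so the critic keeps bootstrapping at the cutoff.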
288 |
289 | return {
290 | 'observations': obs_list,
291 | 'actions': action_list,
292 | 'rewards': rewards,
293 | 'masks': masks,
294 | 'is_success': is_success,
295 | 'episode_return': episode_return,
296 | 'images': image_list,
297 | 'env_steps': t + 1
298 | }
299 |
300 | def perform_control_eval(agent, env, i, variant, wandb_logger, agent_dp=None):
301 | query_frequency = variant.query_freq
302 | print('query frequency', query_frequency)
303 | max_timesteps = variant.max_timesteps
304 | env_max_reward = variant.env_max_reward
305 | episode_returns = []
306 | highest_rewards = []
307 | success_rates = []
308 | episode_lens = []
309 |
310 | rng = jax.random.PRNGKey(variant.seed+456)
311 |
312 | for rollout_id in range(variant.eval_episodes):
313 | if 'libero' in variant.env:
314 | obs = env.reset()
315 | elif 'aloha' in variant.env:
316 | obs, _ = env.reset()
317 |
318 | image_list = [] # for visualization
319 | rewards = []
320 |
321 |
322 | for t in tqdm(range(max_timesteps)):
323 | curr_image = obs_to_img(obs, variant)
324 |
325 | if t % query_frequency == 0:
326 | qpos = obs_to_qpos(obs, variant)
327 | if variant.add_states:
328 | obs_dict = {
329 | 'pixels': curr_image[np.newaxis, ..., np.newaxis],
330 | 'state': qpos[np.newaxis, ..., np.newaxis],
331 | }
332 | else:
333 | obs_dict = {
334 | 'pixels': curr_image[np.newaxis, ..., np.newaxis],
335 | }
336 |
337 | rng, key = jax.random.split(rng)
338 | assert agent_dp is not None
339 |
340 | obs_pi_zero = obs_to_pi_zero_input(obs, variant)
341 |
342 |
343 | if i == 0:
344 | # for initial evaluation, we sample from standard gaussian noise to evaluate the base policy's performance
345 | noise = jax.random.normal(rng, (1, 50, 32))
346 | else:
347 | actions_noise = agent.sample_actions(obs_dict)
348 | actions_noise = np.reshape(actions_noise, agent.action_chunk_shape)
349 | noise = np.repeat(actions_noise[-1:, :], 50 - actions_noise.shape[0], axis=0)
350 | noise = jax.numpy.concatenate([actions_noise, noise], axis=0)[None]
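    |                 # The hard-coded (1, 50, 32) Gaussian noise in the i == 0 branch
    |                 # presumably matches the base pi0 policy's 50-step action horizon
    |                 # and 32-dim padded action space; later iterations instead pad
    |                 # the SAC noise chunk up to the same 50-step horizon.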
351 |
352 | actions = agent_dp.infer(obs_pi_zero, noise=noise)["actions"]
353 |
354 | action_t = actions[t % query_frequency]
355 |
356 | if 'libero' in variant.env:
357 | obs, reward, done, _ = env.step(action_t)
358 | elif 'aloha' in variant.env:
359 | obs, reward, terminated, truncated, _ = env.step(action_t)
360 | done = terminated or truncated
361 |
362 | rewards.append(reward)
363 | image_list.append(curr_image)
364 | if done:
365 | break
366 |
367 | # per episode
368 | episode_lens.append(t + 1)
369 | rewards = np.array(rewards)
370 | episode_return = np.sum(rewards)
371 | episode_returns.append(episode_return)
372 | episode_highest_reward = np.max(rewards)
373 | highest_rewards.append(episode_highest_reward)
374 | is_success = (reward == env_max_reward)
375 | success_rates.append(is_success)
376 |
377 | print(f'Rollout {rollout_id} : {episode_return=}, Success: {is_success}')
378 | video = np.stack(image_list).transpose(0, 3, 1, 2)
379 | wandb_logger.log({f'eval_video/{rollout_id}': wandb.Video(video, fps=50)}, step=i)
380 |
381 |
382 | success_rate = np.mean(np.array(success_rates))
383 | avg_return = np.mean(episode_returns)
384 | avg_episode_len = np.mean(episode_lens)
385 | summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
386 | wandb_logger.log({'evaluation/avg_return': avg_return}, step=i)
387 | wandb_logger.log({'evaluation/success_rate': success_rate}, step=i)
388 | wandb_logger.log({'evaluation/avg_episode_len': avg_episode_len}, step=i)
389 | for r in range(env_max_reward+1):
390 | more_or_equal_r = (np.array(highest_rewards) >= r).sum()
391 | more_or_equal_r_rate = more_or_equal_r / variant.eval_episodes
392 | wandb_logger.log({f'evaluation/Reward >= {r}': more_or_equal_r_rate}, step=i)
393 | summary_str += f'Reward >= {r}: {more_or_equal_r}/{variant.eval_episodes} = {more_or_equal_r_rate*100}%\n'
394 |
395 | print(summary_str)
396 |
397 | def make_multiple_value_reward_visulizations(agent, variant, i, replay_buffer, wandb_logger):
398 | trajs = replay_buffer.get_random_trajs(3)
399 | images = agent.make_value_reward_visulization(variant, trajs)
400 | wandb_logger.log({'reward_value_images': wandb.Image(images)}, step=i)
401 |
402 |
--------------------------------------------------------------------------------