├── jaxrl2 ├── __init__.py ├── utils │ ├── __init__.py │ ├── wandb_config_example.py │ ├── target_update.py │ ├── launch_util.py │ ├── general_utils.py │ ├── visualization_utils.py │ └── wandb_logger.py ├── agents │ ├── __init__.py │ ├── pixel_sac │ │ ├── __init__.py │ │ ├── temperature.py │ │ ├── temperature_updater.py │ │ ├── critic_updater.py │ │ ├── actor_updater.py │ │ └── pixel_sac_learner.py │ ├── agent.py │ └── common.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── replay_buffer.py │ └── augmentations.py ├── networks │ ├── values │ │ ├── __init__.py │ │ ├── state_action_ensemble.py │ │ ├── state_value.py │ │ └── state_action_value.py │ ├── __init__.py │ ├── encoders │ │ ├── __init__.py │ │ ├── spatial_softmax.py │ │ ├── networks.py │ │ ├── resnet_encoderv2.py │ │ ├── impala_encoder.py │ │ ├── resnet_encoderv1.py │ │ └── cross_norm.py │ ├── constants.py │ ├── normal_policy.py │ ├── normal_tanh_policy.py │ ├── learned_std_normal_policy.py │ └── mlp.py └── types.py ├── examples ├── __init__.py ├── scripts │ ├── run_aloha.sh │ ├── run_libero.sh │ └── run_real.sh ├── launch_train_sim.py ├── launch_train_real.py ├── train_real.py ├── train_sim.py ├── train_utils_real.py └── train_utils_sim.py ├── setup.py ├── .gitmodules ├── requirements.txt ├── .gitignore └── README.md /jaxrl2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jaxrl2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /jaxrl2/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl2.agents.pixel_sac import PixelSACLearner 2 | -------------------------------------------------------------------------------- /jaxrl2/data/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl2.data.replay_buffer import ReplayBuffer 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="jaxrl2", packages=["jaxrl2"]) 4 | -------------------------------------------------------------------------------- /jaxrl2/agents/pixel_sac/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl2.agents.pixel_sac.pixel_sac_learner import PixelSACLearner 2 | -------------------------------------------------------------------------------- /jaxrl2/networks/values/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl2.networks.values.state_action_ensemble import StateActionEnsemble 2 | from jaxrl2.networks.values.state_value import StateValue 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "LIBERO"] 2 | path = LIBERO 3 | url = git@github.com:nakamotoo/LIBERO.git 4 | [submodule "openpi"] 5 | path = openpi 6 | url = git@github.com:nakamotoo/openpi.git 7 | 
-------------------------------------------------------------------------------- /jaxrl2/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl2.networks.mlp import MLP 2 | from jaxrl2.networks.normal_policy import NormalPolicy 3 | from jaxrl2.networks.normal_tanh_policy import NormalTanhPolicy 4 | -------------------------------------------------------------------------------- /jaxrl2/networks/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl2.networks.mlp import MLP 2 | from jaxrl2.networks.normal_policy import NormalPolicy 3 | from jaxrl2.networks.normal_tanh_policy import NormalTanhPolicy 4 | -------------------------------------------------------------------------------- /jaxrl2/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Union 2 | 3 | import flax 4 | import numpy as np 5 | 6 | DataType = Union[np.ndarray, Dict[str, 'DataType']] 7 | PRNGKey = Any 8 | Params = flax.core.FrozenDict[str, Any] -------------------------------------------------------------------------------- /jaxrl2/utils/wandb_config_example.py: -------------------------------------------------------------------------------- 1 | # copy this into wandb_config.py! 2 | def get_wandb_config(): 3 | return dict ( 4 | WANDB_API_KEY='your api key', 5 | WANDB_EMAIL='your email', 6 | WANDB_USERNAME='user', 7 | WANDB_TEAM='team_name_if_any' 8 | ) -------------------------------------------------------------------------------- /jaxrl2/networks/constants.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | 5 | def default_init(scale: float = jnp.sqrt(2)): 6 | return nn.initializers.orthogonal(scale) 7 | 8 | def xavier_init(): 9 | return nn.initializers.xavier_normal() 10 | 11 | def kaiming_init(): 12 | return nn.initializers.kaiming_normal() -------------------------------------------------------------------------------- /jaxrl2/agents/pixel_sac/temperature.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | 5 | class Temperature(nn.Module): 6 | initial_temperature: float = 1.0 7 | 8 | @nn.compact 9 | def __call__(self) -> jnp.ndarray: 10 | log_temp = self.param('log_temp', 11 | init_fn=lambda key: jnp.full( 12 | (), jnp.log(self.initial_temperature))) 13 | return jnp.exp(log_temp) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.33.1 2 | distrax==0.1.5 3 | dm-control==1.0.14 4 | easydict==1.13 5 | einops==0.8.1 6 | flax==0.10.2 7 | future==1.0.0 8 | gym==0.26.2 9 | gym-notices==0.0.8 10 | gymnasium==1.0.0 11 | h5py==3.13.0 12 | imageio==2.37.0 13 | jax==0.5.0 14 | ml_collections==1.0.0 15 | ml_dtypes==0.5.1 16 | optax==0.2.4 17 | optree==0.15.0 18 | orbax-checkpoint==0.11.1 19 | orderly-set==5.4.0 20 | packaging==24.2 21 | pandas==2.2.3 22 | pillow==10.4.0 23 | scipy==1.15.2 24 | tensorflow==2.19.0 25 | etils==1.12.2 26 | matplotlib==3.9.3 27 | wandb[media]==0.19.9 28 | opencv-python==4.11.0.86 29 | robosuite==1.4.1 30 | gym-aloha==0.1.1 31 | bddl==3.5.0 32 | tf_keras==2.19.0 33 | moviepy==1.0.3 34 | -------------------------------------------------------------------------------- 
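A note on the `Temperature` module in `jaxrl2/agents/pixel_sac/temperature.py` above: the SAC entropy coefficient is parameterized through its log, so the optimized parameter is unconstrained while the returned temperature stays strictly positive. Below is a minimal sketch of how such a module can be wrapped in a Flax `TrainState` so that `update_temperature` (in `temperature_updater.py`) can take gradients through it; the Adam optimizer and the 3e-4 learning rate match the `temp_lr` default in `examples/launch_train_sim.py` but are otherwise illustrative, not code taken from this repository.

```
import jax
import optax
from flax.training.train_state import TrainState

from jaxrl2.agents.pixel_sac.temperature import Temperature

# Build the module and wrap it in a TrainState (illustrative sketch).
temp_def = Temperature(initial_temperature=1.0)
temp_params = temp_def.init(jax.random.PRNGKey(0))['params']
temp = TrainState.create(apply_fn=temp_def.apply,
                         params=temp_params,
                         tx=optax.adam(learning_rate=3e-4))

# Current temperature value: exp(log_temp), always positive.
alpha = temp.apply_fn({'params': temp.params})
```

`update_temperature` then differentiates the loss `temperature * (entropy - target_entropy)` with respect to `log_temp` and applies the gradient with `temp.apply_gradients`.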
/jaxrl2/utils/target_update.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | from jaxrl2.types import Params 4 | 5 | from functools import partial 6 | 7 | def soft_target_update(critic_params: Params, target_critic_params: Params, tau: float) -> Params: 8 | new_target_params = jax.tree_util.tree_map(lambda p, tp: p * tau + tp * (1 - tau), critic_params, target_critic_params) 9 | return new_target_params 10 | 11 | @partial(jax.pmap, axis_name='pmap', static_broadcasted_argnums=(2)) 12 | def soft_target_update_parallel(critic_params: Params, target_critic_params: Params, tau: float) -> Params: 13 | new_target_params = jax.tree_util.tree_map(lambda p, tp: p * tau + tp * (1 - tau), critic_params, target_critic_params) 14 | return new_target_params 15 | -------------------------------------------------------------------------------- /jaxrl2/utils/launch_util.py: -------------------------------------------------------------------------------- 1 | from jaxrl2.utils.general_utils import AttrDict 2 | 3 | def parse_training_args(train_args_dict, parser): 4 | for k, v in train_args_dict.items(): 5 | if type(v) == tuple: 6 | parser.add_argument('--' + k, nargs="+", default=v, type=type(v[0])) 7 | elif type(v) != bool: 8 | parser.add_argument('--' + k, default=v, type=type(v)) 9 | else: 10 | parser.add_argument('--' + k, default=int(v), type=int) 11 | args = parser.parse_args() 12 | config = {} 13 | for key in train_args_dict.keys(): 14 | config[key] = getattr(args, key) 15 | variant = AttrDict(vars(args)) 16 | variant['train_kwargs'] = config 17 | return variant, args 18 | -------------------------------------------------------------------------------- /jaxrl2/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | class AttrDict(dict): 2 | __setattr__ = dict.__setitem__ 3 | 4 | def __getattr__(self, attr): 5 | # Take care that getattr() raises AttributeError, not KeyError. 6 | # Required e.g. for hasattr(), deepcopy and OrderedDict. 
7 | try: 8 | return self.__getitem__(attr) 9 | except KeyError: 10 | raise AttributeError("Attribute %r not found" % attr) 11 | 12 | def __getstate__(self): return self 13 | def __setstate__(self, d): self = d 14 | 15 | def add_batch_dim(input): 16 | if isinstance(input, dict): 17 | for k, v in input.items(): 18 | input[k] = v[None] 19 | else: 20 | input = input[None] 21 | return input 22 | -------------------------------------------------------------------------------- /jaxrl2/agents/pixel_sac/temperature_updater.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import jax 4 | from flax.training.train_state import TrainState 5 | 6 | 7 | def update_temperature( 8 | temp: TrainState, entropy: float, 9 | target_entropy: float) -> Tuple[TrainState, Dict[str, float]]: 10 | 11 | def temperature_loss_fn(temp_params): 12 | temperature = temp.apply_fn({'params': temp_params}) 13 | temp_loss = temperature * (entropy - target_entropy).mean() 14 | return temp_loss, { 15 | 'temperature': temperature, 16 | 'temperature_loss': temp_loss 17 | } 18 | 19 | grads, info = jax.grad(temperature_loss_fn, has_aux=True)(temp.params) 20 | new_temp = temp.apply_gradients(grads=grads) 21 | 22 | return new_temp, info -------------------------------------------------------------------------------- /examples/scripts/run_aloha.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | proj_name=DSRL_pi0_Aloha 3 | device_id=0 4 | 5 | export DISPLAY=:0 6 | export MUJOCO_GL=egl 7 | export MUJOCO_EGL_DEVICE_ID=$device_id 8 | 9 | export OPENPI_DATA_HOME=./openpi 10 | export EXP=./logs/$proj_name; 11 | export CUDA_VISIBLE_DEVICES=$device_id 12 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 13 | 14 | 15 | pip install mujoco==2.3.7 16 | 17 | python3 examples/launch_train_sim.py \ 18 | --algorithm pixel_sac \ 19 | --env aloha_cube \ 20 | --prefix dsrl_pi0_aloha \ 21 | --wandb_project ${proj_name} \ 22 | --batch_size 256 \ 23 | --discount 0.999 \ 24 | --seed 0 \ 25 | --max_steps 3000000 \ 26 | --eval_interval 10000 \ 27 | --log_interval 500 \ 28 | --eval_episodes 10 \ 29 | --multi_grad_step 20 \ 30 | --start_online_updates 1000 \ 31 | --resize_image 64 \ 32 | --action_magnitude 2.0 \ 33 | --query_freq 50 \ 34 | --hidden_dims 128 \ 35 | --target_entropy 0.0 -------------------------------------------------------------------------------- /examples/scripts/run_libero.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | proj_name=DSRL_pi0_Libero 3 | device_id=0 4 | 5 | export DISPLAY=:0 6 | export MUJOCO_GL=egl 7 | export PYOPENGL_PLATFORM=egl 8 | export MUJOCO_EGL_DEVICE_ID=$device_id 9 | 10 | export OPENPI_DATA_HOME=./openpi 11 | export EXP=./logs/$proj_name; 12 | export CUDA_VISIBLE_DEVICES=$device_id 13 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 14 | 15 | pip install mujoco==3.3.1 16 | 17 | python3 examples/launch_train_sim.py \ 18 | --algorithm pixel_sac \ 19 | --env libero \ 20 | --prefix dsrl_pi0_libero \ 21 | --wandb_project ${proj_name} \ 22 | --batch_size 256 \ 23 | --discount 0.999 \ 24 | --seed 0 \ 25 | --max_steps 500000 \ 26 | --eval_interval 10000 \ 27 | --log_interval 500 \ 28 | --eval_episodes 10 \ 29 | --multi_grad_step 20 \ 30 | --start_online_updates 500 \ 31 | --resize_image 64 \ 32 | --action_magnitude 1.0 \ 33 | --query_freq 20 \ 34 | --hidden_dims 128 \ -------------------------------------------------------------------------------- 
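The flags in `run_aloha.sh` and `run_libero.sh` above (and in `run_real.sh` below) are forwarded to `examples/launch_train_sim.py` or `examples/launch_train_real.py`, where `parse_training_args` from `jaxrl2/utils/launch_util.py` registers every key of `train_args_dict` as an additional CLI flag; values such as `--hidden_dims 128` and `--action_magnitude 2.0` therefore override the launcher defaults. A minimal sketch of that flow, assuming it is invoked with such flags on the command line (the two defaults below are illustrative, not the full dictionary used by the launchers):

```
import argparse
from jaxrl2.utils.launch_util import parse_training_args

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=16, type=int)

# Tuple-valued defaults become nargs="+" flags; scalars keep their own type.
train_args_dict = dict(discount=0.999, hidden_dims=(128, 128, 128))
variant, args = parse_training_args(train_args_dict, parser)

# With `--batch_size 256 --hidden_dims 128` on the command line:
#   variant.batch_size == 256
#   variant['train_kwargs']['hidden_dims'] == [128]
```

Boolean entries in `train_args_dict` are exposed as integer flags (`default=int(v), type=int`), presumably because `argparse` would not convert the string "False" into a false boolean.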
/examples/scripts/run_real.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | proj_name=DSRL_pi0_FrankaDroid 3 | device_id=0 4 | 5 | export EXP=./logs/$proj_name; 6 | export CUDA_VISIBLE_DEVICES=$device_id 7 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 8 | 9 | # Fill in Franka Droid camera IDs 10 | export LEFT_CAMERA_ID="" 11 | export RIGHT_CAMERA_ID="" 12 | export WRIST_CAMERA_ID="" 13 | 14 | # Fill in pi0 remote host and port 15 | export remote_host="" 16 | export remote_port="" 17 | 18 | 19 | python3 examples/launch_train_real.py \ 20 | --algorithm pixel_sac \ 21 | --env franka_droid \ 22 | --prefix dsrl_pi0_real \ 23 | --wandb_project ${proj_name} \ 24 | --batch_size 256 \ 25 | --discount 0.99 \ 26 | --seed 0 \ 27 | --max_steps 500000 \ 28 | --eval_interval 2000 \ 29 | --log_interval 100 \ 30 | --multi_grad_step 30 \ 31 | --resize_image 128 \ 32 | --action_magnitude 2.5 \ 33 | --query_freq 10 \ 34 | --hidden_dims 1024 \ 35 | --num_qs 2 -------------------------------------------------------------------------------- /jaxrl2/networks/values/state_action_ensemble.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Sequence 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | from jaxrl2.networks.values.state_action_value import StateActionValue 7 | 8 | 9 | class StateActionEnsemble(nn.Module): 10 | hidden_dims: Sequence[int] 11 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 12 | num_qs: int = 2 13 | use_action_sep: bool = False 14 | 15 | @nn.compact 16 | def __call__(self, states, actions, training: bool = False): 17 | 18 | # print ('Use action sep in state action ensemble: ', self.use_action_sep) 19 | VmapCritic = nn.vmap(StateActionValue, 20 | variable_axes={'params': 0}, 21 | split_rngs={'params': True}, 22 | in_axes=None, 23 | out_axes=0, 24 | axis_size=self.num_qs) 25 | qs = VmapCritic(self.hidden_dims, 26 | activations=self.activations, 27 | use_action_sep=self.use_action_sep)(states, actions, 28 | training) 29 | return qs 30 | -------------------------------------------------------------------------------- /jaxrl2/networks/values/state_value.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Sequence 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | from jaxrl2.networks.mlp import MLP 7 | 8 | 9 | class StateValue(nn.Module): 10 | hidden_dims: Sequence[int] 11 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 12 | 13 | @nn.compact 14 | def __call__(self, 15 | observations: jnp.ndarray, 16 | training: bool = False) -> jnp.ndarray: 17 | critic = MLP((*self.hidden_dims, 1), 18 | activations=self.activations)(observations, 19 | training=training) 20 | return jnp.squeeze(critic, -1) 21 | 22 | 23 | class StateValueEnsemble(nn.Module): 24 | hidden_dims: Sequence[int] 25 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 26 | num_vs: int = 2 27 | 28 | @nn.compact 29 | def __call__(self, observations, training: bool = False): 30 | VmapCritic = nn.vmap(StateValue, 31 | variable_axes={'params': 0}, 32 | split_rngs={'params': True}, 33 | in_axes=None, 34 | out_axes=0, 35 | axis_size=self.num_vs) 36 | qs = VmapCritic(self.hidden_dims, 37 | activations=self.activations)(observations, 38 | training) 39 | return qs 40 | -------------------------------------------------------------------------------- /jaxrl2/networks/normal_policy.py:
-------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | import distrax 4 | import flax.linen as nn 5 | import jax.numpy as jnp 6 | 7 | from jaxrl2.networks import MLP 8 | from jaxrl2.networks.constants import default_init, xavier_init 9 | 10 | 11 | class NormalPolicy(nn.Module): 12 | hidden_dims: Sequence[int] 13 | action_dim: int 14 | dropout_rate: Optional[float] = None 15 | std: Optional[float] = 1. 16 | init_scale: Optional[float] = 1. 17 | output_scale: Optional[float] = 1. 18 | init_method: str = 'xavier' 19 | 20 | @nn.compact 21 | def __call__(self, 22 | observations: jnp.ndarray, 23 | training: bool = False) -> distrax.Distribution: 24 | outputs = MLP(self.hidden_dims, 25 | activate_final=True, 26 | dropout_rate=self.dropout_rate, 27 | init_scale = self.init_scale 28 | )(observations, training=training) 29 | 30 | if self.init_method == 'xavier': 31 | # print('fc layer {}x{}'.format(outputs.shape, self.action_dim)) 32 | means = nn.Dense(self.action_dim, kernel_init=xavier_init())(outputs) 33 | else: 34 | means = nn.Dense(self.action_dim, kernel_init=default_init(self.init_scale))(outputs) 35 | 36 | means *= self.output_scale 37 | 38 | return distrax.MultivariateNormalDiag(loc=means, 39 | scale_diag=jnp.ones_like(means)*self.std) 40 | -------------------------------------------------------------------------------- /jaxrl2/networks/values/state_action_value.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Sequence 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | import jax 7 | 8 | from jaxrl2.networks.mlp import MLP 9 | from jaxrl2.networks.mlp import MLPActionSep 10 | from jaxrl2.networks.constants import default_init 11 | 12 | from typing import (Any, Callable, Iterable, List, Optional, Sequence, Tuple, 13 | Union) 14 | 15 | PRNGKey = Any 16 | Shape = Tuple[int, ...] 
17 | Dtype = Any 18 | Array = Any 19 | PrecisionLike = Union[None, str, jax.lax.Precision, Tuple[str, str], 20 | Tuple[jax.lax.Precision, jax.lax.Precision]] 21 | 22 | default_kernel_init = nn.initializers.lecun_normal() 23 | 24 | class StateActionValue(nn.Module): 25 | hidden_dims: Sequence[int] 26 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 27 | use_action_sep: bool = False 28 | 29 | @nn.compact 30 | def __call__(self, 31 | observations: jnp.ndarray, 32 | actions: jnp.ndarray, 33 | training: bool = False): 34 | inputs = {'states': observations, 'actions': actions} 35 | if self.use_action_sep: 36 | critic = MLPActionSep( 37 | (*self.hidden_dims, 1), 38 | activations=self.activations, 39 | use_layer_norm=True)(inputs, training=training) 40 | else: 41 | critic = MLP((*self.hidden_dims, 1), 42 | activations=self.activations, 43 | use_layer_norm=True)(inputs, training=training) 44 | return jnp.squeeze(critic, -1) 45 | -------------------------------------------------------------------------------- /jaxrl2/networks/encoders/spatial_softmax.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Sequence, Union 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | from flax.core.frozen_dict import FrozenDict 8 | 9 | from jaxrl2.networks.constants import default_init, xavier_init, kaiming_init 10 | 11 | from functools import partial 12 | from typing import Any, Callable, Sequence, Tuple 13 | import distrax 14 | import wandb 15 | 16 | ModuleDef = Any 17 | 18 | class SpatialSoftmax(nn.Module): 19 | height: int 20 | width: int 21 | channel: int 22 | pos_x: jnp.ndarray 23 | pos_y: jnp.ndarray 24 | temperature: None 25 | log_heatmap: bool = False 26 | 27 | @nn.compact 28 | def __call__(self, feature): 29 | if self.temperature == -1: 30 | from jax.nn import initializers 31 | # print("Trainable temperature parameter") 32 | temperature = self.param('softmax_temperature', initializers.ones, (1), jnp.float32) 33 | else: 34 | temperature = 1. 
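# The code below computes a per-feature-map soft-argmax: each HxW feature map
# is flattened, softmaxed with the (fixed or learnable) temperature, and used
# to take the expectation of the pos_x / pos_y coordinate grids, producing a
# [batch, 2 * num_featuremaps] vector of expected keypoint coordinates.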
35 | 36 | # print(temperature) 37 | assert len(feature.shape) == 4 38 | batch_size, num_featuremaps = feature.shape[0], feature.shape[3] 39 | feature = feature.transpose(0, 3, 1, 2).reshape(batch_size, num_featuremaps, self.height * self.width) 40 | 41 | softmax_attention = nn.softmax(feature / temperature) 42 | expected_x = jnp.sum(self.pos_x * softmax_attention, axis=2, keepdims=True).reshape(batch_size, num_featuremaps) 43 | expected_y = jnp.sum(self.pos_y * softmax_attention, axis=2, keepdims=True).reshape(batch_size, num_featuremaps) 44 | expected_xy = jnp.concatenate([expected_x, expected_y], axis=1) 45 | 46 | expected_xy = jnp.reshape(expected_xy, [batch_size, 2*num_featuremaps]) 47 | return expected_xy 48 | 49 | -------------------------------------------------------------------------------- /jaxrl2/networks/encoders/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Sequence, Union 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | from flax.core.frozen_dict import FrozenDict 7 | 8 | from jaxrl2.networks.constants import default_init, xavier_init, kaiming_init 9 | 10 | from functools import partial 11 | from typing import Any, Callable, Sequence, Tuple 12 | import distrax 13 | 14 | ModuleDef = Any 15 | 16 | class Encoder(nn.Module): 17 | features: Sequence[int] = (32, 32, 32, 32) 18 | strides: Sequence[int] = (2, 1, 1, 1) 19 | padding: str = 'VALID' 20 | 21 | @nn.compact 22 | def __call__(self, observations: jnp.ndarray, training=False) -> jnp.ndarray: 23 | assert len(self.features) == len(self.strides) 24 | 25 | x = observations.astype(jnp.float32) / 255.0 26 | x = jnp.reshape(x, (*x.shape[:-2], -1)) 27 | 28 | for features, stride in zip(self.features, self.strides): 29 | x = nn.Conv(features, 30 | kernel_size=(3, 3), 31 | strides=(stride, stride), 32 | kernel_init=default_init(), 33 | padding=self.padding)(x) 34 | x = nn.relu(x) 35 | 36 | return x.reshape((*x.shape[:-3], -1)) 37 | 38 | 39 | class PixelMultiplexer(nn.Module): 40 | encoder: Union[nn.Module, list] 41 | network: nn.Module 42 | latent_dim: int 43 | use_bottleneck: bool=True 44 | 45 | @nn.compact 46 | def __call__(self, 47 | observations: Union[FrozenDict, Dict], 48 | actions: Optional[jnp.ndarray] = None, 49 | training: bool = False): 50 | observations = FrozenDict(observations) 51 | 52 | x = self.encoder(observations['pixels'], training) 53 | if self.use_bottleneck: 54 | x = nn.Dense(self.latent_dim, kernel_init=xavier_init())(x) 55 | x = nn.LayerNorm()(x) 56 | x = nn.tanh(x) 57 | 58 | x = observations.copy(add_or_replace={'pixels': x}) 59 | 60 | # print('fully connected keys', x.keys()) 61 | if actions is None: 62 | return self.network(x, training=training) 63 | else: 64 | return self.network(x, actions, training=training) 65 | -------------------------------------------------------------------------------- /jaxrl2/agents/agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from flax.training import checkpoints 4 | import pathlib 5 | from flax.training.train_state import TrainState 6 | 7 | from jaxrl2.agents.common import (eval_actions_jit, eval_log_prob_jit, eval_mse_jit, eval_reward_function_jit, 8 | sample_actions_jit) 9 | from jaxrl2.data.dataset import DatasetDict 10 | from jaxrl2.types import PRNGKey 11 | 12 | 13 | def get_batch_stats(actor): 14 | if hasattr(actor, 'batch_stats'): 15 | return actor.batch_stats 16 | else: 17 
| return None 18 | 19 | class Agent(object): 20 | _actor: TrainState 21 | _critic: TrainState 22 | _rng: PRNGKey 23 | 24 | def eval_actions(self, observations: np.ndarray) -> np.ndarray: 25 | actions = eval_actions_jit(self._actor.apply_fn, self._actor.params, 26 | observations, get_batch_stats(self._actor)) 27 | return np.asarray(actions) 28 | 29 | def eval_log_probs(self, batch: DatasetDict) -> float: 30 | return eval_log_prob_jit(self._actor.apply_fn, self._actor.params, get_batch_stats(self._actor), 31 | batch) 32 | 33 | def eval_mse(self, batch: DatasetDict) -> float: 34 | return eval_mse_jit(self._actor.apply_fn, self._actor.params, get_batch_stats(self._actor), 35 | batch) 36 | 37 | def eval_reward_function(self, batch: DatasetDict) -> float: 38 | return eval_reward_function_jit(self._actor.apply_fn, self._actor.params, self._actor.batch_stats, 39 | batch) 40 | 41 | def sample_actions(self, observations: np.ndarray) -> np.ndarray: 42 | rng, actions = sample_actions_jit(self._rng, self._actor.apply_fn, 43 | self._actor.params, observations, get_batch_stats(self._actor)) 44 | 45 | self._rng = rng 46 | return np.asarray(actions) 47 | 48 | @property 49 | def _save_dict(self): 50 | return None 51 | 52 | def save_checkpoint(self, dir, step, keep_every_n_steps): 53 | checkpoints.save_checkpoint(dir, self._save_dict, step, prefix='checkpoint', overwrite=False, keep_every_n_steps=keep_every_n_steps) 54 | 55 | def restore_checkpoint(self, dir): 56 | raise NotImplementedError 57 | 58 | 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .nox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | *.py,cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | db.sqlite3-journal 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # pipenv 84 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 85 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 86 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 87 | # install all needed dependencies. 88 | #Pipfile.lock 89 | 90 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 91 | __pypackages__/ 92 | 93 | # Celery stuff 94 | celerybeat-schedule 95 | celerybeat.pid 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | .idea/ 128 | tmp/ 129 | wandb_config.py 130 | logs/ 131 | wandb/ -------------------------------------------------------------------------------- /jaxrl2/agents/pixel_sac/critic_updater.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from flax.training.train_state import TrainState 6 | 7 | from jaxrl2.data.dataset import DatasetDict 8 | from jaxrl2.types import Params, PRNGKey 9 | 10 | 11 | def update_critic( 12 | key: PRNGKey, actor: TrainState, critic: TrainState, 13 | target_critic: TrainState, temp: TrainState, batch: DatasetDict, 14 | discount: float, backup_entropy: bool = False, 15 | critic_reduction: str = 'min') -> Tuple[TrainState, Dict[str, float]]: 16 | dist = actor.apply_fn({'params': actor.params}, batch['next_observations']) 17 | next_actions, next_log_probs = dist.sample_and_log_prob(seed=key) 18 | next_qs = target_critic.apply_fn({'params': target_critic.params}, 19 | batch['next_observations'], next_actions) 20 | if critic_reduction == 'min': 21 | next_q = next_qs.min(axis=0) 22 | elif critic_reduction == 'mean': 23 | next_q = next_qs.mean(axis=0) 24 | else: 25 | raise NotImplemented() 26 | 27 | target_q = batch['rewards'] + batch["discount"] * batch['masks'] * next_q 28 | 29 | if backup_entropy: 30 | target_q -= batch["discount"] * batch['masks'] * temp.apply_fn( 31 | {'params': temp.params}) * next_log_probs 32 | 33 | def critic_loss_fn( 34 | critic_params: Params) -> Tuple[jnp.ndarray, Dict[str, float]]: 35 | qs = critic.apply_fn({'params': critic_params}, batch['observations'], 36 | batch['actions']) 37 | critic_loss = ((qs - target_q)**2).mean() 38 | return critic_loss, { 39 | 'critic_loss': critic_loss, 40 | 'q': qs.mean(), 41 | 'target_actor_entropy': -next_log_probs.mean(), 42 | 'next_actions_sampled': next_actions.mean(), 43 | 'next_log_probs': next_log_probs.mean(), 44 | 'next_q_pi': next_qs.mean(), 45 | 'target_q': target_q.mean(), 46 | 'next_actions_mean': next_actions.mean(), 47 | 'next_actions_std': next_actions.std(), 48 | 'next_actions_min': next_actions.min(), 49 | 'next_actions_max': next_actions.max(), 50 | 'next_log_probs': next_log_probs.mean(), 51 | 52 | } 53 | 54 | grads, info = jax.grad(critic_loss_fn, has_aux=True)(critic.params) 55 | new_critic = critic.apply_gradients(grads=grads) 56 | 57 | return new_critic, info 58 | -------------------------------------------------------------------------------- /examples/launch_train_sim.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from examples.train_sim import main 4 | from jaxrl2.utils.launch_util import parse_training_args 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--seed', default=42, help='Random seed.', type=int) 11 | 
parser.add_argument('--launch_group_id', default='', help='group id used to group runs on wandb.') 12 | parser.add_argument('--eval_episodes', default=10,help='Number of episodes used for evaluation.', type=int) 13 | parser.add_argument('--env', default='libero', help='name of environment') 14 | parser.add_argument('--log_interval', default=1000, help='Logging interval.', type=int) 15 | parser.add_argument('--eval_interval', default=5000, help='Eval interval.', type=int) 16 | parser.add_argument('--checkpoint_interval', default=-1, help='checkpoint interval.', type=int) 17 | parser.add_argument('--batch_size', default=16, help='Mini batch size.', type=int) 18 | parser.add_argument('--max_steps', default=int(1e6), help='Number of training steps.', type=int) 19 | parser.add_argument('--add_states', default=1, help='whether to add low-dim states to the obervations', type=int) 20 | parser.add_argument('--wandb_project', default='cql_sim_online', help='wandb project') 21 | parser.add_argument('--start_online_updates', default=1000, help='number of steps to collect before starting online updates', type=int) 22 | parser.add_argument('--algorithm', default='pixel_sac', help='type of algorithm') 23 | parser.add_argument('--prefix', default='', help='prefix to use for wandb') 24 | parser.add_argument('--suffix', default='', help='suffix to use for wandb') 25 | parser.add_argument('--multi_grad_step', default=1, help='Number of graident steps to take per environment step, aka UTD', type=int) 26 | parser.add_argument('--resize_image', default=-1, help='the size of image if need resizing', type=int) 27 | parser.add_argument('--query_freq', default=-1, help='query frequency', type=int) 28 | 29 | train_args_dict = dict( 30 | actor_lr=1e-4, 31 | critic_lr= 3e-4, 32 | temp_lr=3e-4, 33 | hidden_dims= (128, 128, 128), 34 | cnn_features= (32, 32, 32, 32), 35 | cnn_strides= (2, 1, 1, 1), 36 | cnn_padding= 'VALID', 37 | latent_dim= 50, 38 | discount= 0.999, 39 | tau= 0.005, 40 | critic_reduction = 'mean', 41 | dropout_rate=0.0, 42 | aug_next=1, 43 | use_bottleneck=True, 44 | encoder_type='small', 45 | encoder_norm='group', 46 | use_spatial_softmax=True, 47 | softmax_temperature=-1, 48 | target_entropy='auto', 49 | num_qs=10, 50 | action_magnitude=1.0, 51 | num_cameras=1, 52 | ) 53 | 54 | variant, args = parse_training_args(train_args_dict, parser) 55 | print(variant) 56 | main(variant) 57 | sys.exit() 58 | -------------------------------------------------------------------------------- /jaxrl2/networks/encoders/resnet_encoderv2.py: -------------------------------------------------------------------------------- 1 | # Based on: 2 | # https://github.com/google/flax/blob/main/examples/imagenet/models.py 3 | # and 4 | # https://github.com/google-research/big_transfer/blob/master/bit_jax/models.py 5 | from functools import partial 6 | from typing import Any, Callable, Sequence, Tuple 7 | 8 | import flax.linen as nn 9 | import jax.numpy as jnp 10 | from flax import linen as nn 11 | 12 | ModuleDef = Any 13 | 14 | 15 | class ResNetV2Block(nn.Module): 16 | """ResNet block.""" 17 | filters: int 18 | conv: ModuleDef 19 | norm: ModuleDef 20 | act: Callable 21 | strides: Tuple[int, int] = (1, 1) 22 | 23 | @nn.compact 24 | def __call__(self, x): 25 | residual = x 26 | y = self.norm()(x) 27 | y = self.act(y) 28 | y = self.conv(self.filters, (3, 3), self.strides)(y) 29 | y = self.norm()(y) 30 | y = self.act(y) 31 | y = self.conv(self.filters, (3, 3))(y) 32 | 33 | if residual.shape != y.shape: 34 | residual = 
self.conv(self.filters, (1, 1), self.strides)(residual) 35 | 36 | return residual + y 37 | 38 | 39 | class MyGroupNorm(nn.GroupNorm): 40 | 41 | def __call__(self, x): 42 | if x.ndim == 3: 43 | x = x[jnp.newaxis] 44 | x = super().__call__(x) 45 | return x[0] 46 | else: 47 | return super().__call__(x) 48 | 49 | 50 | class ResNetV2Encoder(nn.Module): 51 | """ResNetV2.""" 52 | stage_sizes: Sequence[int] 53 | num_filters: int = 16 54 | dtype: Any = jnp.float32 55 | act: Callable = nn.relu 56 | norm: str = 'batch' 57 | 58 | @nn.compact 59 | def __call__(self, x, train: bool = True): 60 | conv = partial(nn.Conv, use_bias=False, dtype=self.dtype) 61 | if self.norm == 'batch': 62 | norm = partial(nn.BatchNorm, 63 | use_running_average=not train, 64 | momentum=0.9, 65 | epsilon=1e-5, 66 | dtype=self.dtype) 67 | elif self.norm == 'groupnorm': 68 | norm = partial(MyGroupNorm, 69 | num_groups=4, 70 | epsilon=1e-5, 71 | dtype=self.dtype) 72 | else: 73 | raise ValueError('norm not found') 74 | 75 | x = x.astype(jnp.float32) / 255.0 76 | x = jnp.reshape(x, (*x.shape[:-2], -1)) 77 | 78 | x = conv(self.num_filters, (3, 3))(x) 79 | for i, block_size in enumerate(self.stage_sizes): 80 | for j in range(block_size): 81 | strides = (2, 2) if i > 0 and j == 0 else (1, 1) 82 | x = ResNetV2Block(self.num_filters * 2**i, 83 | strides=strides, 84 | conv=conv, 85 | norm=norm, 86 | act=self.act)(x) 87 | 88 | x = norm()(x) 89 | x = self.act(x) 90 | return x.reshape((*x.shape[:-3], -1)) 91 | -------------------------------------------------------------------------------- /examples/launch_train_real.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from examples.train_real import main 4 | from jaxrl2.utils.launch_util import parse_training_args 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--seed', default=42, help='Random seed.', type=int) 11 | parser.add_argument('--launch_group_id', default='', help='group id used to group runs on wandb.') 12 | parser.add_argument('--eval_episodes', default=10,help='Number of episodes used for evaluation.', type=int) 13 | parser.add_argument('--env', default='libero', help='name of environment') 14 | parser.add_argument('--log_interval', default=1000, help='Logging interval.', type=int) 15 | parser.add_argument('--eval_interval', default=5000, help='Eval interval.', type=int) 16 | parser.add_argument('--checkpoint_interval', default=-1, help='checkpoint interval.', type=int) 17 | parser.add_argument('--batch_size', default=16, help='Mini batch size.', type=int) 18 | parser.add_argument('--max_steps', default=int(1e6), help='Number of training steps.', type=int) 19 | parser.add_argument('--add_states', default=1, help='whether to add low-dim states to the obervations', type=int) 20 | parser.add_argument('--wandb_project', default='cql_sim_online', help='wandb project') 21 | parser.add_argument('--num_initial_traj_collect', default=1, help='number of trajectories to collect before starting online updates', type=int) 22 | parser.add_argument('--algorithm', default='pixel_sac', help='type of algorithm') 23 | parser.add_argument('--prefix', default='', help='prefix to use for wandb') 24 | parser.add_argument('--suffix', default='', help='suffix to use for wandb') 25 | parser.add_argument('--multi_grad_step', default=1, help='Number of graident steps to take per environment step, aka UTD', type=int) 26 | parser.add_argument('--resize_image', default=-1, 
help='the size of image if need resizing', type=int) 27 | parser.add_argument('--query_freq', default=-1, help='query frequency', type=int) 28 | parser.add_argument('--instruction', default='put the spoon on the plate', help='language instruction for the robot') 29 | 30 | # The hyperparameters for the real robot experiments 31 | train_args_dict = dict( 32 | actor_lr=1e-4, 33 | critic_lr= 3e-4, 34 | temp_lr=3e-4, 35 | hidden_dims= (1024, 1024, 1024), 36 | cnn_features= (32, 32, 32, 32), 37 | cnn_strides= (3, 2, 2, 2), 38 | cnn_padding= 'VALID', 39 | latent_dim= 50, 40 | discount= 0.99, 41 | tau= 0.005, 42 | critic_reduction = 'min', 43 | dropout_rate=0.0, 44 | aug_next=1, 45 | use_bottleneck=True, 46 | encoder_type='small', 47 | encoder_norm='group', 48 | use_spatial_softmax=True, 49 | softmax_temperature=-1, 50 | target_entropy=0.0, 51 | num_qs=2, 52 | action_magnitude=2.5, 53 | num_cameras=3, 54 | ) 55 | 56 | variant, args = parse_training_args(train_args_dict, parser) 57 | print(variant) 58 | main(variant) 59 | sys.exit() 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # DSRL for π₀: Diffusion Steering via Reinforcement Learning 4 | 5 | ## [[website](https://diffusion-steering.github.io)] [[paper](https://arxiv.org/abs/2506.15799)] 6 | 7 |
8 | 9 | 10 | ## Overview 11 | This repository provides the official implementation for our paper: [Steering Your Diffusion Policy with Latent Space Reinforcement Learning](https://arxiv.org/abs/2506.15799) (CoRL 2025). 12 | 13 | Specifically, it contains a JAX-based implementation of DSRL (Diffusion Steering via Reinforcement Learning) for steering a pre-trained generalist policy, [π₀](https://github.com/Physical-Intelligence/openpi), across various environments, including: 14 | 15 | - **Simulation:** Libero, Aloha 16 | - **Real Robot:** Franka 17 | 18 | If you find this repository useful for your research, please cite: 19 | 20 | ``` 21 | @article{wagenmaker2025steering, 22 | author = {Andrew Wagenmaker and Mitsuhiko Nakamoto and Yunchu Zhang and Seohong Park and Waleed Yagoub and Anusha Nagabandi and Abhishek Gupta and Sergey Levine}, 23 | title = {Steering Your Diffusion Policy with Latent Space Reinforcement Learning}, 24 | journal = {Conference on Robot Learning (CoRL)}, 25 | year = {2025}, 26 | } 27 | ``` 28 | 29 | ## Installation 30 | 1. Create a conda environment: 31 | ``` 32 | conda create -n dsrl_pi0 python=3.11.11 33 | conda activate dsrl_pi0 34 | ``` 35 | 36 | 2. Clone this repo with all submodules 37 | ``` 38 | git clone git@github.com:nakamotoo/dsrl_pi0.git --recurse-submodules 39 | cd dsrl_pi0 40 | ``` 41 | 42 | 3. Install all packages and dependencies 43 | ``` 44 | pip install -e . 45 | pip install -r requirements.txt 46 | pip install "jax[cuda12]==0.5.0" 47 | 48 | # install openpi 49 | pip install -e openpi 50 | pip install -e openpi/packages/openpi-client 51 | 52 | # install Libero 53 | pip install -e LIBERO 54 | pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu # needed for libero 55 | ``` 56 | 57 | ## Training (Simulation) 58 | Libero 59 | ``` 60 | bash examples/scripts/run_libero.sh 61 | ``` 62 | Aloha 63 | ``` 64 | bash examples/scripts/run_aloha.sh 65 | ``` 66 | ### Training Logs 67 | We provide sample W&B runs and logs: https://wandb.ai/mitsuhiko/DSRL_pi0_public 68 | 69 | ## Training (Real) 70 | For real-world experiments, we use the remote hosting feature from pi0 (see [here](https://github.com/Physical-Intelligence/openpi/blob/main/docs/remote_inference.md)) which enables us to host the pi0 model on a higher-spec remote server, in case the robot's client machine is not powerful enough. 71 | 72 | 0. Setup Franka robot and install DROID package [[link](https://github.com/droid-dataset/droid.git)] 73 | 74 | 1. [On the remote server] Host pi0 droid model on your remote server 75 | ``` 76 | cd openpi && python scripts/serve_policy.py --env=DROID 77 | ``` 78 | 2. [On your robot client machine] Run DSRL 79 | ``` 80 | bash examples/scripts/run_real.sh 81 | ``` 82 | 83 | 84 | ## Credits 85 | This repository is built upon [jaxrl2](https://github.com/ikostrikov/jaxrl2) and [PTR](https://github.com/Asap7772/PTR) repositories. 
86 | In case of any questions, bugs, suggestions or improvements, please feel free to contact me at nakamoto\[at\]berkeley\[dot\]edu 87 | -------------------------------------------------------------------------------- /jaxrl2/networks/normal_tanh_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | import distrax 4 | import flax.linen as nn 5 | import jax.numpy as jnp 6 | from tensorflow_probability.substrates import jax as tfp 7 | 8 | from jaxrl2.networks import MLP 9 | from jaxrl2.networks.constants import default_init, xavier_init 10 | 11 | 12 | class TanhMultivariateNormalDiag(distrax.Transformed): 13 | 14 | def __init__(self, 15 | loc: jnp.ndarray, 16 | scale_diag: jnp.ndarray, 17 | low: Optional[jnp.ndarray] = None, 18 | high: Optional[jnp.ndarray] = None): 19 | distribution = distrax.MultivariateNormalDiag(loc=loc, 20 | scale_diag=scale_diag) 21 | 22 | layers = [] 23 | 24 | if not (low is None or high is None): 25 | 26 | def rescale_from_tanh(x): 27 | x = (x + 1) / 2 # (-1, 1) => (0, 1) 28 | return x * (high - low) + low 29 | 30 | def forward_log_det_jacobian(x): 31 | high_ = jnp.broadcast_to(high, x.shape) 32 | low_ = jnp.broadcast_to(low, x.shape) 33 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1) 34 | 35 | layers.append( 36 | distrax.Lambda( 37 | rescale_from_tanh, 38 | forward_log_det_jacobian=forward_log_det_jacobian, 39 | event_ndims_in=1, 40 | event_ndims_out=1)) 41 | 42 | layers.append(distrax.Block(distrax.Tanh(), 1)) 43 | 44 | bijector = distrax.Chain(layers) 45 | 46 | super().__init__(distribution=distribution, bijector=bijector) 47 | 48 | def mode(self) -> jnp.ndarray: 49 | return self.bijector.forward(self.distribution.mode()) 50 | 51 | 52 | class NormalTanhPolicy(nn.Module): 53 | hidden_dims: Sequence[int] 54 | action_dim: int 55 | dropout_rate: Optional[float] = None 56 | log_std_min: Optional[float] = -20 57 | log_std_max: Optional[float] = 2 58 | low: Optional[jnp.ndarray] = None 59 | high: Optional[jnp.ndarray] = None 60 | mlp_init_scale: float = 1.0 61 | init_method: str = 'default' 62 | 63 | @nn.compact 64 | def __call__(self, 65 | observations: jnp.ndarray, 66 | training: bool = False) -> distrax.Distribution: 67 | outputs = MLP(self.hidden_dims, 68 | activate_final=True, 69 | dropout_rate=self.dropout_rate, 70 | init_scale=self.mlp_init_scale)(observations, 71 | training=training) 72 | 73 | if self.init_method == 'xavier': 74 | means = nn.Dense(self.action_dim, kernel_init=xavier_init())(outputs) 75 | log_stds = nn.Dense(self.action_dim, kernel_init=xavier_init())(outputs) 76 | else: 77 | means = nn.Dense(self.action_dim, kernel_init=default_init(self.mlp_init_scale))(outputs) 78 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init())(outputs) 79 | 80 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) 81 | 82 | return TanhMultivariateNormalDiag(loc=means, 83 | scale_diag=jnp.exp(log_stds) * self.mlp_init_scale, 84 | low=self.low, 85 | high=self.high) 86 | -------------------------------------------------------------------------------- /jaxrl2/networks/encoders/impala_encoder.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | 5 | class ResnetStack(nn.Module): 6 | num_ch: int 7 | num_blocks: int 8 | use_max_pooling: bool = True 9 | 10 | @nn.compact 11 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 12 | initializer = 
nn.initializers.xavier_uniform() 13 | conv_out = nn.Conv( 14 | features=self.num_ch, 15 | kernel_size=(3, 3), 16 | strides=1, 17 | kernel_init=initializer, 18 | padding='SAME' 19 | )(observations) 20 | 21 | if self.use_max_pooling: 22 | conv_out = nn.max_pool( 23 | conv_out, 24 | window_shape=(3, 3), 25 | padding='SAME', 26 | strides=(2, 2) 27 | ) 28 | 29 | for _ in range(self.num_blocks): 30 | block_input = conv_out 31 | conv_out = nn.relu(conv_out) 32 | conv_out = nn.Conv( 33 | features=self.num_ch, kernel_size=(3, 3), strides=1, 34 | padding='SAME', 35 | kernel_init=initializer)(conv_out) 36 | 37 | conv_out = nn.relu(conv_out) 38 | conv_out = nn.Conv( 39 | features=self.num_ch, kernel_size=(3, 3), strides=1, 40 | padding='SAME', kernel_init=initializer 41 | )(conv_out) 42 | conv_out += block_input 43 | 44 | return conv_out 45 | 46 | 47 | class ImpalaEncoder(nn.Module): 48 | nn_scale: int = 1 49 | 50 | def setup(self): 51 | stack_sizes = [16, 32, 32] 52 | self.stack_blocks = [ 53 | ResnetStack( 54 | num_ch=stack_sizes[0] * self.nn_scale, 55 | num_blocks=2), 56 | ResnetStack( 57 | num_ch=stack_sizes[1] * self.nn_scale, 58 | num_blocks=2), 59 | ResnetStack( 60 | num_ch=stack_sizes[2] * self.nn_scale, 61 | num_blocks=2), 62 | ] 63 | 64 | @nn.compact 65 | def __call__(self, x, train=True): 66 | x = x.astype(jnp.float32) / 255.0 67 | x = jnp.reshape(x, (*x.shape[:-2], -1)) 68 | 69 | conv_out = x 70 | 71 | for idx in range(len(self.stack_blocks)): 72 | conv_out = self.stack_blocks[idx](conv_out) 73 | 74 | conv_out = nn.relu(conv_out) 75 | return conv_out.reshape((*x.shape[:-3], -1)) 76 | 77 | 78 | class SmallerImpalaEncoder(nn.Module): 79 | nn_scale: int = 1 80 | 81 | def setup(self): 82 | stack_sizes = [16, 32, 32] 83 | self.stack_blocks = [ 84 | ResnetStack( 85 | num_ch=stack_sizes[0] * self.nn_scale, 86 | num_blocks=2), 87 | ResnetStack( 88 | num_ch=stack_sizes[1] * self.nn_scale, 89 | num_blocks=1), 90 | ResnetStack( 91 | num_ch=stack_sizes[2] * self.nn_scale, 92 | num_blocks=1), 93 | ] 94 | 95 | @nn.compact 96 | def __call__(self, x, train=True): 97 | x = x.astype(jnp.float32) / 255.0 98 | x = jnp.reshape(x, (*x.shape[:-2], -1)) 99 | 100 | conv_out = x 101 | 102 | for idx in range(len(self.stack_blocks)): 103 | conv_out = self.stack_blocks[idx](conv_out) 104 | 105 | conv_out = nn.relu(conv_out) 106 | return conv_out.reshape((*x.shape[:-3], -1)) 107 | -------------------------------------------------------------------------------- /jaxrl2/agents/pixel_sac/actor_updater.py: -------------------------------------------------------------------------------- 1 | from audioop import cross 2 | from typing import Dict, Tuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from flax.training.train_state import TrainState 7 | 8 | from jaxrl2.data.dataset import DatasetDict 9 | from jaxrl2.types import Params, PRNGKey 10 | 11 | 12 | def update_actor(key: PRNGKey, actor: TrainState, critic: TrainState, 13 | temp: TrainState, batch: DatasetDict, cross_norm:bool=False, critic_reduction:str='min') -> Tuple[TrainState, Dict[str, float]]: 14 | 15 | key, key_act = jax.random.split(key, num=2) 16 | 17 | def actor_loss_fn( 18 | actor_params: Params) -> Tuple[jnp.ndarray, Dict[str, float]]: 19 | if hasattr(actor, 'batch_stats') and actor.batch_stats is not None: 20 | dist, new_model_state = actor.apply_fn({'params': actor_params, 'batch_stats': actor.batch_stats}, batch['observations'], mutable=['batch_stats']) 21 | if cross_norm: 22 | next_dist = actor.apply_fn({'params': actor_params, 'batch_stats': 
actor.batch_stats}, batch['next_observations'], mutable=['batch_stats']) 23 | else: 24 | next_dist = actor.apply_fn({'params': actor_params, 'batch_stats': actor.batch_stats}, batch['next_observations']) 25 | if type(next_dist) == tuple: 26 | next_dist, new_model_state = next_dist 27 | else: 28 | dist = actor.apply_fn({'params': actor_params}, batch['observations']) 29 | next_dist = actor.apply_fn({'params': actor_params}, batch['next_observations']) 30 | new_model_state = {} 31 | 32 | # For logging only 33 | mean_dist = dist.distribution._loc 34 | std_diag_dist = dist.distribution._scale_diag 35 | mean_dist_norm = jnp.linalg.norm(mean_dist, axis=-1) 36 | std_dist_norm = jnp.linalg.norm(std_diag_dist, axis=-1) 37 | 38 | 39 | actions, log_probs = dist.sample_and_log_prob(seed=key_act) 40 | 41 | if hasattr(critic, 'batch_stats') and critic.batch_stats is not None: 42 | qs, _ = critic.apply_fn({'params': critic.params, 'batch_stats': critic.batch_stats}, batch['observations'], 43 | actions, mutable=['batch_stats']) 44 | else: 45 | qs = critic.apply_fn({'params': critic.params}, batch['observations'], actions) 46 | 47 | if critic_reduction == 'min': 48 | q = qs.min(axis=0) 49 | elif critic_reduction == 'mean': 50 | q = qs.mean(axis=0) 51 | else: 52 | raise ValueError(f"Invalid critic reduction: {critic_reduction}") 53 | actor_loss = (log_probs * temp.apply_fn({'params': temp.params}) - q).mean() 54 | 55 | things_to_log = { 56 | 'actor_loss': actor_loss, 57 | 'entropy': -log_probs.mean(), 58 | 'q_pi_in_actor': q.mean(), 59 | 'mean_pi_norm': mean_dist_norm.mean(), 60 | 'std_pi_norm': std_dist_norm.mean(), 61 | 'mean_pi_avg': mean_dist.mean(), 62 | 'mean_pi_max': mean_dist.max(), 63 | 'mean_pi_min': mean_dist.min(), 64 | 'std_pi_avg': std_diag_dist.mean(), 65 | 'std_pi_max': std_diag_dist.max(), 66 | 'std_pi_min': std_diag_dist.min(), 67 | } 68 | return actor_loss, (things_to_log, new_model_state) 69 | 70 | grads, (info, new_model_state) = jax.grad(actor_loss_fn, has_aux=True)(actor.params) 71 | 72 | if 'batch_stats' in new_model_state: 73 | new_actor = actor.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats']) 74 | else: 75 | new_actor = actor.apply_gradients(grads=grads) 76 | 77 | return new_actor, info -------------------------------------------------------------------------------- /jaxrl2/networks/learned_std_normal_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | import distrax 4 | import flax.linen as nn 5 | import jax.numpy as jnp 6 | 7 | from jaxrl2.networks import MLP 8 | from jaxrl2.networks.constants import default_init 9 | 10 | class LearnedStdNormalPolicy(nn.Module): 11 | hidden_dims: Sequence[int] 12 | action_dim: int 13 | dropout_rate: Optional[float] = None 14 | log_std_min: Optional[float] = -20 15 | log_std_max: Optional[float] = 2 16 | 17 | @nn.compact 18 | def __call__(self, 19 | observations: jnp.ndarray, 20 | training: bool = False) -> distrax.Distribution: 21 | outputs = MLP(self.hidden_dims, 22 | activate_final=True, 23 | dropout_rate=self.dropout_rate)(observations, 24 | training=training) 25 | 26 | means = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs) 27 | 28 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs) 29 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) 30 | 31 | distribution = distrax.MultivariateNormalDiag(loc=means, scale_diag=jnp.exp(log_stds)) 32 | return distribution 33 | 34 | 
class TanhMultivariateNormalDiag(distrax.Transformed): 35 | 36 | def __init__(self, 37 | loc: jnp.ndarray, 38 | scale_diag: jnp.ndarray, 39 | low: Optional[jnp.ndarray] = None, 40 | high: Optional[jnp.ndarray] = None): 41 | distribution = distrax.MultivariateNormalDiag(loc=loc, 42 | scale_diag=scale_diag) 43 | 44 | layers = [] 45 | 46 | if not (low is None or high is None): 47 | 48 | def rescale_from_tanh(x): 49 | x = (x + 1) / 2 # (-1, 1) => (0, 1) 50 | return x * (high - low) + low 51 | 52 | def forward_log_det_jacobian(x): 53 | high_ = jnp.broadcast_to(high, x.shape) 54 | low_ = jnp.broadcast_to(low, x.shape) 55 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1) 56 | 57 | layers.append( 58 | distrax.Lambda( 59 | rescale_from_tanh, 60 | forward_log_det_jacobian=forward_log_det_jacobian, 61 | event_ndims_in=1, 62 | event_ndims_out=1)) 63 | 64 | layers.append(distrax.Block(distrax.Tanh(), 1)) 65 | 66 | bijector = distrax.Chain(layers) 67 | 68 | super().__init__(distribution=distribution, bijector=bijector) 69 | 70 | def mode(self) -> jnp.ndarray: 71 | return self.bijector.forward(self.distribution.mode()) 72 | 73 | class LearnedStdTanhNormalPolicy(nn.Module): 74 | hidden_dims: Sequence[int] 75 | action_dim: int 76 | dropout_rate: Optional[float] = None 77 | log_std_min: Optional[float] = -20 78 | log_std_max: Optional[float] = 2 79 | low: Optional[float] = None 80 | high: Optional[float] = None 81 | 82 | @nn.compact 83 | def __call__(self, 84 | observations: jnp.ndarray, 85 | training: bool = False) -> distrax.Distribution: 86 | outputs = MLP(self.hidden_dims, 87 | activate_final=True, 88 | dropout_rate=self.dropout_rate)(observations, 89 | training=training) 90 | 91 | means = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs) 92 | 93 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs) 94 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) 95 | 96 | distribution = TanhMultivariateNormalDiag(loc=means, scale_diag=jnp.exp(log_stds), low=self.low, high=self.high) 97 | return distribution -------------------------------------------------------------------------------- /jaxrl2/networks/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence, Union 2 | from flax.core import frozen_dict 3 | 4 | import numpy as np 5 | import flax.linen as nn 6 | import jax.numpy as jnp 7 | from flax.core.frozen_dict import FrozenDict 8 | 9 | from jaxrl2.networks.constants import default_init 10 | 11 | 12 | def _flatten_dict(x: Union[FrozenDict, jnp.ndarray]): 13 | if hasattr(x, 'values'): 14 | obs = [] 15 | for k, v in sorted(x.items()): 16 | # if k == "actions": 17 | # v = v[:, 0:1, ...] 
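# The branches below flatten dict observations for the MLP: 'state' and any
# chunked 'prev_action' / 'actions' entries are reshaped from (..., T, D) to
# (..., T*D), nested dicts are flattened recursively, and the pieces are
# concatenated along the last axis.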
18 | if k == 'state': # flatten action chunk to 1D 19 | obs.append(jnp.reshape(v, [*v.shape[:-2], np.prod(v.shape[-2:])])) 20 | # v = jnp.reshape(v, [*v.shape[:-2], np.prod(v.shape[-2:])]) 21 | elif k == 'prev_action' or k == 'actions': 22 | if v.ndim > 2: 23 | # deal with action chunk 24 | obs.append(jnp.reshape(v, [*v.shape[:-2], np.prod(v.shape[-2:])])) 25 | else: 26 | obs.append(v) 27 | else: 28 | obs.append(_flatten_dict(v)) 29 | return jnp.concatenate(obs, -1) 30 | else: 31 | return x 32 | 33 | def _flatten_dict_special(x): 34 | if hasattr(x, 'values'): 35 | obs = [] 36 | action = None 37 | for k, v in sorted(x.items()): 38 | if k == 'state' or k == 'prev_action': 39 | obs.append(jnp.reshape(v, [*v.shape[:-2], np.prod(v.shape[-2:])])) 40 | elif k == 'actions': 41 | print ('action shape: ', v.shape) 42 | action = v 43 | else: 44 | obs.append(_flatten_dict(v)) 45 | return jnp.concatenate(obs, -1), action 46 | else: 47 | return x 48 | 49 | 50 | class MLP(nn.Module): 51 | hidden_dims: Sequence[int] 52 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 53 | activate_final: int = False 54 | dropout_rate: Optional[float] = None 55 | init_scale: Optional[float] = 1. 56 | use_layer_norm: bool = False 57 | 58 | @nn.compact 59 | def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray: 60 | x = _flatten_dict(x) 61 | # print('mlp post flatten', x.shape) 62 | 63 | for i, size in enumerate(self.hidden_dims): 64 | x = nn.Dense(size, kernel_init=default_init(self.init_scale))(x) 65 | # print('post fc size', x.shape) 66 | if i + 1 < len(self.hidden_dims) or self.activate_final: 67 | if self.dropout_rate is not None: 68 | x = nn.Dropout(rate=self.dropout_rate)( 69 | x, deterministic=not training) 70 | if self.use_layer_norm: 71 | x = nn.LayerNorm()(x) 72 | x = self.activations(x) 73 | return x 74 | 75 | 76 | class MLPActionSep(nn.Module): 77 | hidden_dims: Sequence[int] 78 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 79 | activate_final: int = False 80 | dropout_rate: Optional[float] = None 81 | init_scale: Optional[float] = 1. 
82 | use_layer_norm: bool = False 83 | @nn.compact 84 | def __call__(self, x: jnp.ndarray, training: bool = False): 85 | x, action = _flatten_dict_special(x) 86 | print ('mlp action sep state post flatten', x.shape) 87 | print ('mlp action sep action post flatten', action.shape) 88 | 89 | for i, size in enumerate(self.hidden_dims): 90 | x_used = jnp.concatenate([x, action], axis=-1) 91 | x = nn.Dense(size, kernel_init=default_init())(x_used) 92 | print ('FF layers: ', x_used.shape, x.shape) 93 | if i + 1 < len(self.hidden_dims) or self.activate_final: 94 | if self.dropout_rate is not None: 95 | x = nn.Dropout(rate=self.dropout_rate)( 96 | x, deterministic=not training) 97 | if self.use_layer_norm: 98 | x = nn.LayerNorm()(x) 99 | x = self.activations(x) 100 | return x -------------------------------------------------------------------------------- /jaxrl2/utils/visualization_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import jax.numpy as jnp 4 | import cv2 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 9 | 10 | 11 | def np_unstack(array, axis): 12 | arr = np.split(array, array.shape[axis], axis) 13 | arr = [a.squeeze() for a in arr] 14 | return arr 15 | 16 | def action2img(action, res, channels, action_scale): 17 | assert action.size == 2 # can only plot 2-dimensional actions 18 | img = np.zeros((res, res, channels), dtype=np.float32).copy() 19 | start_pt = res / 2 * np.ones((2,)) 20 | end_pt = start_pt + action * action_scale * (res / 2 - 1) * np.array([1, -1]) # swaps last dimension 21 | np2pt = lambda x: tuple(np.asarray(x, int)) 22 | img = cv2.arrowedLine(img, np2pt(start_pt), np2pt(end_pt), (255, 255, 255), 1, cv2.LINE_AA, tipLength=0.2) 23 | return img 24 | 25 | def batch_action2img(actions, res, channels, action_scale=50): 26 | batch = actions.shape[0] 27 | im = np.empty((batch, res, res, channels), dtype=np.float32) 28 | for b in range(batch): 29 | im[b] = action2img(actions[b], res, channels, action_scale) 30 | return im 31 | 32 | def visualize_image_actions(images, gtruth_actions, pred_actions): 33 | gtruth_action_row1 = batch_action2img(gtruth_actions[:, :2], 128, 3, action_scale=3) 34 | gtruth_action_row1 = np.concatenate(np_unstack(gtruth_action_row1, axis=0), axis=1) 35 | pred_action_row1 = batch_action2img(pred_actions[:, :2], 128, 3, action_scale=3) 36 | pred_action_row1 = np.concatenate(np_unstack(pred_action_row1, axis=0), axis=1) 37 | sel_image_row = np.concatenate(np_unstack(images, axis=0), axis=1) 38 | image_rows = [sel_image_row, gtruth_action_row1, pred_action_row1] 39 | out = np.concatenate(image_rows, axis=0) 40 | return out 41 | 42 | 43 | def visualize_states_rewards(states, rewards, target_point): 44 | states = states.squeeze() 45 | rewards = rewards.squeeze() 46 | 47 | fig, axs = plt.subplots(7, 1) 48 | fig.set_size_inches(5, 15) 49 | canvas = FigureCanvas(fig) 50 | plt.xlim([0, len(states)]) 51 | 52 | axs[0].plot(states[:, 0], linestyle='--', marker='o') 53 | axs[0].set_ylabel('states_x') 54 | axs[1].plot(states[:, 1], linestyle='--', marker='o') 55 | axs[1].set_ylabel('states_y') 56 | axs[2].plot(states[:, 2], linestyle='--', marker='o') 57 | axs[2].set_ylabel('states_z') 58 | 59 | axs[3].plot(np.abs(states[:, 0] - target_point[0]), linestyle='--', marker='o') 60 | axs[3].set_ylabel('norm_x') 61 | axs[4].plot(np.abs(states[:, 1] - target_point[1]), linestyle='--', 
marker='o') 62 | axs[4].set_ylabel('norm_y') 63 | axs[5].plot(np.abs(states[:, 2] - target_point[2]), linestyle='--', marker='o') 64 | axs[5].set_ylabel('norm_z') 65 | 66 | axs[6].plot(rewards, linestyle='--', marker='o') 67 | axs[6].set_ylabel('rewards') 68 | 69 | plt.tight_layout() 70 | 71 | canvas.draw() # draw the canvas, cache the renderer 72 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8') 73 | out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 74 | 75 | plt.close(fig) 76 | return out_image 77 | 78 | def add_text_to_images(img_list, string_list): 79 | from PIL import Image 80 | from PIL import ImageDraw 81 | out = [] 82 | for im, string in zip(img_list, string_list): 83 | im = Image.fromarray(np.array(im).astype(np.uint8)) 84 | draw = ImageDraw.Draw(im) 85 | draw.text((10, 10), string, fill=(255, 0, 0, 128)) 86 | out.append(np.array(im)) 87 | return out 88 | 89 | def sigmoid(x): 90 | return 1. / (1. + jnp.exp(-x)) 91 | 92 | def visualize_image_rewards(images, gtruth_rewards, pred_rewards, obs, task_id_mapping): 93 | id_task_mapping = {v : k for (k, v) in task_id_mapping.items()} 94 | sel_images = np_unstack(images, axis=0) 95 | sel_images = add_text_to_images(sel_images, ["{:.2f} \n{:.2f} \nTask {}".format(gtruth_rewards[i], sigmoid(pred_rewards[i, 0]), np.argmax(obs['task_id'][i])) for i in range(gtruth_rewards.shape[0])]) 96 | sel_image_row = np.concatenate(sel_images, axis=1) 97 | return sel_image_row 98 | -------------------------------------------------------------------------------- /jaxrl2/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterable, Optional, Tuple, Union 2 | import collections 3 | import jax 4 | import numpy as np 5 | from gym.utils import seeding 6 | import jax.numpy as jnp 7 | from jaxrl2.types import DataType 8 | 9 | DatasetDict = Dict[str, DataType] 10 | from flax.core import frozen_dict 11 | 12 | def concat_recursive(batches): 13 | new_batch = {} 14 | for k, v in batches[0].items(): 15 | if isinstance(v, frozen_dict.FrozenDict): 16 | new_batch[k] = concat_recursive([batches[0][k], batches[1][k]]) 17 | else: 18 | new_batch[k] = np.concatenate([b[k] for b in batches], 0) 19 | return new_batch 20 | 21 | def _check_lengths(dataset_dict: DatasetDict, 22 | dataset_len: Optional[int] = None) -> int: 23 | for v in dataset_dict.values(): 24 | if isinstance(v, dict): 25 | dataset_len = dataset_len or _check_lengths(v, dataset_len) 26 | elif isinstance(v, np.ndarray): 27 | item_len = len(v) 28 | dataset_len = dataset_len or item_len 29 | assert dataset_len == item_len, 'Inconsistent item lengths in the dataset.' 
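            # `dataset_len` is fixed by the first array encountered; every other
            # leaf in the (possibly nested) dict must have the same leading length.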
30 | else: 31 | raise TypeError('Unsupported type.') 32 | return dataset_len 33 | 34 | 35 | def _split(dataset_dict: DatasetDict, 36 | index: int) -> Tuple[DatasetDict, DatasetDict]: 37 | train_dataset_dict, test_dataset_dict = {}, {} 38 | for k, v in dataset_dict.items(): 39 | if isinstance(v, dict): 40 | train_v, test_v = _split(v, index) 41 | elif isinstance(v, np.ndarray): 42 | train_v, test_v = v[:index], v[index:] 43 | else: 44 | raise TypeError('Unsupported type.') 45 | train_dataset_dict[k] = train_v 46 | test_dataset_dict[k] = test_v 47 | return train_dataset_dict, test_dataset_dict 48 | 49 | 50 | def _sample(dataset_dict: Union[np.ndarray, DatasetDict], 51 | indx: np.ndarray) -> DatasetDict: 52 | if isinstance(dataset_dict, np.ndarray): 53 | return dataset_dict[indx] 54 | elif isinstance(dataset_dict, dict): 55 | batch = {} 56 | for k, v in dataset_dict.items(): 57 | batch[k] = _sample(v, indx) 58 | else: 59 | raise TypeError("Unsupported type.") 60 | return batch 61 | 62 | 63 | class Dataset(object): 64 | 65 | def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): 66 | self.dataset_dict = dataset_dict 67 | self.dataset_len = _check_lengths(dataset_dict) 68 | 69 | # Seeding similar to OpenAI Gym: 70 | # https://github.com/openai/gym/blob/master/gym/spaces/space.py#L46 71 | self._np_random = None 72 | if seed is not None: 73 | self.seed(seed) 74 | 75 | @property 76 | def np_random(self) -> np.random.RandomState: 77 | if self._np_random is None: 78 | self.seed() 79 | return self._np_random 80 | 81 | def seed(self, seed: Optional[int] = None) -> list: 82 | self._np_random, seed = seeding.np_random(seed) 83 | return [seed] 84 | 85 | def __len__(self) -> int: 86 | return self.dataset_len 87 | 88 | def sample(self, 89 | batch_size: int, 90 | keys: Optional[Iterable[str]] = None, 91 | indx: Optional[np.ndarray] = None) -> frozen_dict.FrozenDict: 92 | if indx is None: 93 | if hasattr(self.np_random, 'integers'): 94 | indx = self.np_random.integers(len(self), size=batch_size) 95 | else: 96 | indx = self.np_random.randint(len(self), size=batch_size) 97 | 98 | batch = dict() 99 | 100 | if keys is None: 101 | keys = self.dataset_dict.keys() 102 | 103 | for k in keys: 104 | if isinstance(self.dataset_dict[k], dict): 105 | batch[k] = _sample(self.dataset_dict[k], indx) 106 | else: 107 | batch[k] = self.dataset_dict[k][indx] 108 | 109 | return frozen_dict.freeze(batch) 110 | 111 | def split(self, ratio: float) -> Tuple['Dataset', 'Dataset']: 112 | assert 0 < ratio and ratio < 1 113 | index = int(self.dataset_len * ratio) 114 | train_dataset_dict, test_dataset_dict = _split(self.dataset_dict, 115 | index) 116 | return Dataset(train_dataset_dict), Dataset(test_dataset_dict) 117 | -------------------------------------------------------------------------------- /jaxrl2/utils/wandb_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import wandb 4 | import time 5 | 6 | import dateutil.tz 7 | from collections import OrderedDict 8 | import numpy as np 9 | from numbers import Number 10 | import matplotlib 11 | matplotlib.use('Agg') 12 | import matplotlib.pyplot as plt 13 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 14 | 15 | def create_exp_name(exp_prefix, exp_id=0, seed=0): 16 | """ 17 | Create a semi-unique experiment name that has a timestamp 18 | :param exp_prefix: 19 | :param exp_id: 20 | :return: 21 | """ 22 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 23 | 
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 24 | return "%s_%s_%04d--s-%d" % (exp_prefix, timestamp, exp_id, seed) 25 | 26 | 27 | def create_stats_ordered_dict( 28 | name, 29 | data, 30 | stat_prefix=None, 31 | always_show_all_stats=True, 32 | exclude_max_min=False, 33 | ): 34 | if stat_prefix is not None: 35 | name = "{}{}".format(stat_prefix, name) 36 | if isinstance(data, Number): 37 | return OrderedDict({name: data}) 38 | 39 | if len(data) == 0: 40 | return OrderedDict() 41 | 42 | if isinstance(data, tuple): 43 | ordered_dict = OrderedDict() 44 | for number, d in enumerate(data): 45 | sub_dict = create_stats_ordered_dict( 46 | "{0}_{1}".format(name, number), 47 | d, 48 | ) 49 | ordered_dict.update(sub_dict) 50 | return ordered_dict 51 | 52 | if isinstance(data, list): 53 | try: 54 | iter(data[0]) 55 | except TypeError: 56 | pass 57 | else: 58 | data = np.concatenate(data) 59 | 60 | if (isinstance(data, np.ndarray) and data.size == 1 61 | and not always_show_all_stats): 62 | return OrderedDict({name: float(data)}) 63 | try: 64 | stats = OrderedDict([ 65 | (name + ' Mean', np.mean(data)), 66 | (name + ' Std', np.std(data)), 67 | ]) 68 | except: 69 | stats = OrderedDict([ 70 | (name + ' Mean', -1), 71 | (name + ' Std', -1), 72 | ]) 73 | if not exclude_max_min: 74 | try: 75 | stats[name + ' Max'] = np.max(data) 76 | stats[name + ' Min'] = np.min(data) 77 | except: 78 | stats[name + ' Max'] = -1 79 | stats[name + ' Min'] = -1 80 | return stats 81 | 82 | class WandBLogger(object): 83 | def __init__(self, wandb_logging, variant, project, experiment_id, output_dir=None, group_name='', team=None): 84 | self.wandb_logging = wandb_logging 85 | output_dir = os.path.join(output_dir, experiment_id) 86 | os.makedirs(output_dir, exist_ok=True) 87 | if wandb_logging: 88 | print('wandb using experimentid: ', experiment_id) 89 | print('wandb using project: ', project) 90 | print('wandb using group: ', group_name) 91 | 92 | try: 93 | from jaxrl2.utils.wandb_config import get_wandb_config 94 | wandb_config = get_wandb_config() 95 | os.environ['WANDB_API_KEY'] = wandb_config['WANDB_API_KEY'] 96 | os.environ['WANDB_USER_EMAIL'] = wandb_config['WANDB_EMAIL'] 97 | os.environ['WANDB_USERNAME'] = wandb_config['WANDB_USERNAME'] 98 | team = wandb_config['WANDB_TEAM'] if wandb_config['WANDB_TEAM'] != '' else None 99 | except: 100 | print('wandb_config.py not found, using default wandb config') 101 | os.environ["WANDB_MODE"] = "run" 102 | wandb.init( 103 | config=variant, 104 | project=project, 105 | dir=output_dir, 106 | id=experiment_id, 107 | settings=wandb.Settings(start_method="thread"), 108 | group=group_name, 109 | entity=team 110 | ) 111 | self.output_dir = output_dir 112 | 113 | 114 | def log(self, *args, **kwargs): 115 | if self.wandb_logging: 116 | wandb.log(*args, **kwargs) 117 | 118 | def log_histogram(self, name, values, step): 119 | fig = plt.figure() 120 | canvas = FigureCanvas(fig) 121 | plt.tight_layout() 122 | 123 | plt.hist(np.array(values.flatten()), bins=100) 124 | # plt.show() 125 | 126 | canvas.draw() # draw the canvas, cache the renderer 127 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8') 128 | out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 129 | plt.close(fig) 130 | self.log({name: wandb.Image(out_image)}, step=step) -------------------------------------------------------------------------------- /examples/train_real.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | import os 3 | import jax 4 | from jaxrl2.agents.pixel_sac.pixel_sac_learner import PixelSACLearner 5 | from jaxrl2.utils.general_utils import add_batch_dim 6 | import numpy as np 7 | import logging 8 | 9 | import gymnasium as gym 10 | from gym.spaces import Dict, Box 11 | 12 | from jaxrl2.data import ReplayBuffer 13 | from jaxrl2.utils.wandb_logger import WandBLogger, create_exp_name 14 | import tempfile 15 | from functools import partial 16 | from examples.train_utils_real import trajwise_alternating_training_loop 17 | import tensorflow as tf 18 | from jax.experimental.compilation_cache import compilation_cache 19 | from openpi_client import websocket_client_policy as _websocket_client_policy 20 | from droid.robot_env import RobotEnv 21 | 22 | home_dir = os.environ['HOME'] 23 | compilation_cache.initialize_cache(os.path.join(home_dir, 'jax_compilation_cache')) 24 | 25 | def shard_batch(batch, sharding): 26 | """Shards a batch across devices along its first dimension. 27 | 28 | Args: 29 | batch: A pytree of arrays. 30 | sharding: A jax Sharding object with shape (num_devices,). 31 | """ 32 | return jax.tree_util.tree_map( 33 | lambda x: jax.device_put( 34 | x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1))) 35 | ), 36 | batch, 37 | ) 38 | 39 | class DummyEnv(gym.ObservationWrapper): 40 | 41 | def __init__(self, variant): 42 | self.variant = variant 43 | self.image_shape = (variant.resize_image, variant.resize_image, 3 * variant.num_cameras, 1) 44 | obs_dict = {} 45 | obs_dict['pixels'] = Box(low=0, high=255, shape=self.image_shape, dtype=np.uint8) 46 | if variant.add_states: 47 | state_dim = 8 + 2024 # 8 is the proprioceptive state's dim, 2024 is the image representation's dim 48 | obs_dict['state'] = Box(low=-1.0, high=1.0, shape=(state_dim, 1), dtype=np.float32) 49 | self.observation_space = Dict(obs_dict) 50 | self.action_space = Box(low=-1, high=1, shape=(1, 32,), dtype=np.float32) # 32 is the noise action space of pi 0 51 | 52 | def main(variant): 53 | devices = jax.local_devices() 54 | num_devices = len(devices) 55 | assert variant.batch_size % num_devices == 0 56 | logging.info('num devices', num_devices) 57 | logging.info('batch size', variant.batch_size) 58 | # we shard the leading dimension (batch dimension) accross all devices evenly 59 | sharding = jax.sharding.PositionalSharding(devices) 60 | shard_fn = partial(shard_batch, sharding=sharding) 61 | 62 | # prevent tensorflow from using GPUs 63 | tf.config.set_visible_devices([], "GPU") 64 | 65 | kwargs = variant['train_kwargs'] 66 | if kwargs.pop('cosine_decay', False): 67 | kwargs['decay_steps'] = variant.max_steps 68 | 69 | if not variant.prefix: 70 | import uuid 71 | variant.prefix = str(uuid.uuid4().fields[-1])[:5] 72 | 73 | if variant.suffix: 74 | expname = create_exp_name(variant.prefix, seed=variant.seed) + f"_{variant.suffix}" 75 | else: 76 | expname = create_exp_name(variant.prefix, seed=variant.seed) 77 | 78 | outputdir = os.path.join(os.environ['EXP'], expname) 79 | variant.outputdir = outputdir 80 | if not os.path.exists(outputdir): 81 | os.makedirs(outputdir) 82 | print('writing to output dir ', outputdir) 83 | 84 | group_name = variant.prefix + '_' + variant.launch_group_id 85 | wandb_output_dir = tempfile.mkdtemp() 86 | wandb_logger = WandBLogger(variant.prefix != '', variant, variant.wandb_project, experiment_id=expname, output_dir=wandb_output_dir, group_name=group_name) 87 | 88 | 89 | agent_dp = _websocket_client_policy.WebsocketClientPolicy( 90 | 
host=os.environ['remote_host'], 91 | port=os.environ['remote_port'] 92 | ) 93 | logging.info(f"Server metadata: {agent_dp.get_server_metadata()}") 94 | 95 | logging.info("initializing environment...") 96 | env = RobotEnv(action_space="joint_velocity", gripper_action_space="position") 97 | eval_env = env 98 | logging.info("created the droid env!") 99 | 100 | assert os.environ.get('LEFT_CAMERA_ID') is not None 101 | assert os.environ.get('RIGHT_CAMERA_ID') is not None 102 | assert os.environ.get('WRIST_CAMERA_ID') is not None 103 | 104 | robot_config = dict( 105 | camera_to_use='right', 106 | left_camera_id=os.environ['LEFT_CAMERA_ID'], 107 | right_camera_id=os.environ['RIGHT_CAMERA_ID'], 108 | wrist_camera_id=os.environ['WRIST_CAMERA_ID'], 109 | max_timesteps=200 110 | ) 111 | 112 | dummy_env = DummyEnv(variant) 113 | sample_obs = add_batch_dim(dummy_env.observation_space.sample()) 114 | sample_action = add_batch_dim(dummy_env.action_space.sample()) 115 | logging.info('sample obs shapes', [(k, v.shape) for k, v in sample_obs.items()]) 116 | logging.info('sample action shape', sample_action.shape) 117 | 118 | agent = PixelSACLearner(variant.seed, sample_obs, sample_action, **kwargs) 119 | 120 | if variant.restore_path != '': 121 | logging.info('restoring from', variant.restore_path) 122 | agent.restore_checkpoint(variant.restore_path) 123 | 124 | online_buffer_size = 2 * variant.max_steps // variant.multi_grad_step 125 | online_replay_buffer = ReplayBuffer(dummy_env.observation_space, dummy_env.action_space, int(online_buffer_size)) 126 | replay_buffer = online_replay_buffer 127 | replay_buffer.seed(variant.seed) 128 | trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger, shard_fn=shard_fn, agent_dp=agent_dp, robot_config=robot_config) 129 | -------------------------------------------------------------------------------- /jaxrl2/networks/encoders/resnet_encoderv1.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | from functools import partial 5 | from typing import Any, Callable, Sequence, Tuple 6 | from jaxrl2.networks.constants import default_init, xavier_init, kaiming_init 7 | from jaxrl2.networks.encoders.spatial_softmax import SpatialSoftmax 8 | from jaxrl2.networks.encoders.cross_norm import CrossNorm 9 | 10 | ModuleDef = Any 11 | 12 | 13 | class MyGroupNorm(nn.GroupNorm): 14 | 15 | def __call__(self, x): 16 | if x.ndim == 3: 17 | x = x[jnp.newaxis] 18 | x = super().__call__(x) 19 | return x[0] 20 | else: 21 | return super().__call__(x) 22 | 23 | class ResNetBlock(nn.Module): 24 | """ResNet block.""" 25 | filters: int 26 | conv: ModuleDef 27 | norm: ModuleDef 28 | act: Callable 29 | strides: Tuple[int, int] = (1, 1) 30 | 31 | @nn.compact 32 | def __call__(self, x, ): 33 | residual = x 34 | y = self.conv(self.filters, (3, 3), self.strides)(x) 35 | y = self.norm()(y) 36 | y = self.act(y) 37 | y = self.conv(self.filters, (3, 3))(y) 38 | y = self.norm()(y) 39 | 40 | if residual.shape != y.shape: 41 | residual = self.conv(self.filters, (1, 1), 42 | self.strides, name='conv_proj')(residual) 43 | residual = self.norm(name='norm_proj')(residual) 44 | 45 | return self.act(residual + y) 46 | 47 | 48 | class BottleneckResNetBlock(nn.Module): 49 | """Bottleneck ResNet block.""" 50 | filters: int 51 | conv: ModuleDef 52 | norm: ModuleDef 53 | act: Callable 54 | strides: Tuple[int, int] = (1, 1) 55 | 56 | @nn.compact 57 | def __call__(self, 
x): 58 | residual = x 59 | y = self.conv(self.filters, (1, 1))(x) 60 | y = self.norm()(y) 61 | y = self.act(y) 62 | y = self.conv(self.filters, (3, 3), self.strides)(y) 63 | y = self.norm()(y) 64 | y = self.act(y) 65 | y = self.conv(self.filters * 4, (1, 1))(y) 66 | y = self.norm(scale_init=nn.initializers.zeros)(y) 67 | 68 | if residual.shape != y.shape: 69 | residual = self.conv(self.filters * 4, (1, 1), 70 | self.strides, name='conv_proj')(residual) 71 | residual = self.norm(name='norm_proj')(residual) 72 | 73 | return self.act(residual + y) 74 | 75 | 76 | class ResNetEncoder(nn.Module): 77 | """ResNetV1.""" 78 | stage_sizes: Sequence[int] 79 | block_cls: ModuleDef 80 | num_filters: int = 64 81 | dtype: Any = jnp.float32 82 | act: Callable = nn.relu 83 | conv: ModuleDef = nn.Conv 84 | norm: str = 'batch' 85 | use_spatial_softmax: bool = True 86 | softmax_temperature: float = 1.0 87 | 88 | @nn.compact 89 | def __call__(self, observations: jnp.ndarray, train: bool = True) -> jnp.ndarray: 90 | 91 | x = observations.astype(jnp.float32) / 255.0 92 | x = jnp.reshape(x, (*x.shape[:-2], -1)) 93 | 94 | conv = partial(self.conv, use_bias=False, dtype=self.dtype, kernel_init=kaiming_init()) 95 | if self.norm == 'batch': 96 | norm = partial(nn.BatchNorm, 97 | use_running_average=not train, 98 | momentum=0.9, 99 | epsilon=1e-5, 100 | dtype=self.dtype) 101 | elif self.norm == 'group': 102 | norm = partial(MyGroupNorm, 103 | num_groups=4, 104 | epsilon=1e-5, 105 | dtype=self.dtype) 106 | elif self.norm == 'cross': 107 | norm = partial(CrossNorm, 108 | use_running_average=not train, 109 | momentum=0.9, 110 | epsilon=1e-5, 111 | dtype=self.dtype) 112 | elif self.norm == 'layer': 113 | norm = partial(nn.LayerNorm, 114 | epsilon=1e-5, 115 | dtype=self.dtype, 116 | ) 117 | else: 118 | raise ValueError('norm not found') 119 | 120 | # print('input ', x.shape) 121 | strides = (2, 2, 2, 1, 1) 122 | x = conv(self.num_filters, (7, 7), (strides[0], strides[0]), 123 | padding=[(3, 3), (3, 3)], 124 | name='conv_init')(x) 125 | # print('post conv1', x.shape) 126 | 127 | x = norm(name='bn_init')(x) 128 | x = nn.relu(x) 129 | x = nn.max_pool(x, (3, 3), strides=(strides[1], strides[1]), padding='SAME') 130 | # print('post maxpool1', x.shape) 131 | for i, block_size in enumerate(self.stage_sizes): 132 | for j in range(block_size): 133 | stride = (strides[i + 1], strides[i + 1]) if i > 0 and j == 0 else (1, 1) 134 | x = self.block_cls(self.num_filters * 2 ** i, 135 | strides=stride, 136 | conv=conv, 137 | norm=norm, 138 | act=self.act)(x) 139 | # print('post block layer ', x.shape) 140 | # print('post block ', x.shape) 141 | 142 | if self.use_spatial_softmax: 143 | height, width, channel = x.shape[len(x.shape) - 3:] 144 | pos_x, pos_y = jnp.meshgrid( 145 | jnp.linspace(-1., 1., height), 146 | jnp.linspace(-1., 1., width) 147 | ) 148 | pos_x = pos_x.reshape(height * width) 149 | pos_y = pos_y.reshape(height * width) 150 | # print('pre spatial softmax', x.shape) 151 | x = SpatialSoftmax(height, width, channel, pos_x, pos_y, self.softmax_temperature)(x) 152 | # print('post spatial softmax', x.shape) 153 | else: 154 | x = jnp.mean(x, axis=(len(x.shape) - 3,len(x.shape) - 2)) 155 | # print('post flatten', x.shape) 156 | return x 157 | 158 | ResNetSmall = partial(ResNetEncoder, stage_sizes=(1, 1, 1, 1), 159 | block_cls=ResNetBlock) 160 | ResNet18 = partial(ResNetEncoder, stage_sizes=(2, 2, 2, 2), 161 | block_cls=ResNetBlock) 162 | ResNet34 = partial(ResNetEncoder, stage_sizes=(3, 4, 6, 3), 163 | block_cls=ResNetBlock) 164 | 
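# A minimal usage sketch of the encoder (added for illustration; not part of the
# original file). It assumes observations arrive as uint8 images with a trailing
# frame-stack axis, e.g. (batch, H, W, C, num_frames), which the encoder folds
# into the channel dimension above. Group norm is used here only to avoid
# tracking batch_stats in the example.
if __name__ == "__main__":
    import jax
    import numpy as np

    encoder = ResNetSmall(norm='group')
    dummy_obs = np.zeros((1, 128, 128, 3, 1), dtype=np.uint8)
    variables = encoder.init(jax.random.PRNGKey(0), dummy_obs, train=False)
    features = encoder.apply(variables, dummy_obs, train=False)
    print('encoded feature shape:', features.shape)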
--------------------------------------------------------------------------------
/jaxrl2/agents/common.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Callable, Tuple, Any
3 | 
4 | import distrax
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 | 
9 | from jaxrl2.data.dataset import DatasetDict
10 | from jaxrl2.types import Params, PRNGKey
11 | import flax.linen as nn
12 | from typing import Any, Callable, Dict, Mapping, Sequence, Union
13 | 
14 | # Helps to minimize CPU to GPU transfer.
15 | def _unpack(batch):
16 |     # Assuming that if next_observation is missing, it's combined with observation:
17 |     obs_pixels = batch['observations']['pixels'][..., :-1]
18 |     next_obs_pixels = batch['observations']['pixels'][..., 1:]
19 | 
20 |     obs = batch['observations'].copy(add_or_replace={'pixels': obs_pixels})
21 |     next_obs = batch['next_observations'].copy(
22 |         add_or_replace={'pixels': next_obs_pixels})
23 | 
24 |     batch = batch.copy(add_or_replace={
25 |         'observations': obs,
26 |         'next_observations': next_obs
27 |     })
28 | 
29 |     return batch
30 | 
31 | @partial(jax.jit, static_argnames='actor_apply_fn')
32 | def eval_log_prob_jit(actor_apply_fn: Callable[..., distrax.Distribution],
33 |                       actor_params: Params, actor_batch_stats: Any, batch: DatasetDict) -> float:
34 |     # batch = _unpack(batch)
35 |     input_collections = {'params': actor_params}
36 |     if actor_batch_stats is not None:
37 |         input_collections['batch_stats'] = actor_batch_stats
38 |     dist = actor_apply_fn(input_collections,
39 |                           batch['observations'],
40 |                           training=False,
41 |                           mutable=False)
42 |     log_probs = dist.log_prob(batch['actions'])
43 |     return log_probs.mean()
44 | 
45 | @partial(jax.jit, static_argnames='actor_apply_fn')
46 | def eval_mse_jit(actor_apply_fn: Callable[..., distrax.Distribution],
47 |                  actor_params: Params, actor_batch_stats: Any, batch: DatasetDict) -> float:
48 |     # batch = _unpack(batch)
49 |     input_collections = {'params': actor_params}
50 |     if actor_batch_stats is not None:
51 |         input_collections['batch_stats'] = actor_batch_stats
52 |     dist = actor_apply_fn(input_collections,
53 |                           batch['observations'],
54 |                           training=False,
55 |                           mutable=False)
56 |     mse = (dist.loc - batch['actions']) ** 2
57 |     return mse.mean()
58 | 
59 | def eval_reward_function_jit(actor_apply_fn: Callable[..., distrax.Distribution],
60 |                              actor_params: Params, actor_batch_stats: Any, batch: DatasetDict) -> float:
61 |     # batch = _unpack(batch)
62 |     input_collections = {'params': actor_params}
63 |     if actor_batch_stats is not None:
64 |         input_collections['batch_stats'] = actor_batch_stats
65 |     dist = actor_apply_fn(input_collections,
66 |                           batch['observations'],
67 |                           training=False,
68 |                           mutable=False)
69 |     pred = dist.mode().reshape(-1)
70 |     loss = - (batch['rewards'] * jnp.log(1. / (1. + jnp.exp(-pred))) + (1.0 - batch['rewards']) * jnp.log(1. - 1. / (1.
+ jnp.exp(-pred)))) 71 | return loss.mean() 72 | 73 | 74 | @partial(jax.jit, static_argnames='actor_apply_fn') 75 | def eval_actions_jit(actor_apply_fn: Callable[..., distrax.Distribution], 76 | actor_params: Params, 77 | observations: np.ndarray, 78 | actor_batch_stats: Any) -> jnp.ndarray: 79 | input_collections = {'params': actor_params} 80 | if actor_batch_stats is not None: 81 | input_collections['batch_stats'] = actor_batch_stats 82 | dist = actor_apply_fn(input_collections, observations, training=False, 83 | mutable=False) 84 | return dist.mode() 85 | 86 | 87 | @partial(jax.jit, static_argnames='actor_apply_fn') 88 | def sample_actions_jit( 89 | rng: PRNGKey, actor_apply_fn: Callable[..., distrax.Distribution], 90 | actor_params: Params, 91 | observations: np.ndarray, 92 | actor_batch_stats: Any) -> Tuple[PRNGKey, jnp.ndarray]: 93 | input_collections = {'params': actor_params} 94 | if actor_batch_stats is not None: 95 | input_collections['batch_stats'] = actor_batch_stats 96 | dist = actor_apply_fn(input_collections, observations) 97 | rng, key = jax.random.split(rng) 98 | return rng, dist.sample(seed=key) 99 | 100 | 101 | class ModuleDict(nn.Module): 102 | """ 103 | from https://github.com/rail-berkeley/jaxrl_minimal/blob/main/jaxrl_m/common/common.py#L33 104 | Utility class for wrapping a dictionary of modules. This is useful when you have multiple modules that you want to 105 | initialize all at once (creating a single `params` dictionary), but you want to be able to call them separately 106 | later. As a bonus, the modules may have sub-modules nested inside them that share parameters (e.g. an image encoder) 107 | and Flax will automatically handle this without duplicating the parameters. 108 | 109 | To initialize the modules, call `init` with no `name` kwarg, and then pass the example arguments to each module as 110 | additional kwargs. To call the modules, pass the name of the module as the `name` kwarg, and then pass the arguments 111 | to the module as additional args or kwargs. 112 | 113 | Example usage: 114 | ``` 115 | shared_encoder = Encoder() 116 | actor = Actor(encoder=shared_encoder) 117 | critic = Critic(encoder=shared_encoder) 118 | 119 | model_def = ModuleDict({"actor": actor, "critic": critic}) 120 | params = model_def.init(rng_key, actor=example_obs, critic=(example_obs, example_action)) 121 | 122 | actor_output = model_def.apply({"params": params}, example_obs, name="actor") 123 | critic_output = model_def.apply({"params": params}, example_obs, action=example_action, name="critic") 124 | ``` 125 | """ 126 | 127 | modules: Dict[str, nn.Module] 128 | 129 | @nn.compact 130 | def __call__(self, *args, name=None, **kwargs): 131 | if name is None: 132 | if kwargs.keys() != self.modules.keys(): 133 | raise ValueError( 134 | f"When `name` is not specified, kwargs must contain the arguments for each module. " 135 | f"Got kwargs keys {kwargs.keys()} but module keys {self.modules.keys()}" 136 | ) 137 | out = {} 138 | for key, value in kwargs.items(): 139 | if isinstance(value, Mapping): 140 | out[key] = self.modules[key](**value) 141 | elif isinstance(value, Sequence): 142 | out[key] = self.modules[key](*value) 143 | else: 144 | out[key] = self.modules[key](value) 145 | return out 146 | 147 | return self.modules[name](*args, **kwargs) -------------------------------------------------------------------------------- /examples/train_sim.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | import os 3 | # Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs from https://github.com/huggingface/gym-aloha/tree/main?tab=readme-ov-file#-gpu-rendering-egl 4 | xla_flags = os.environ.get('XLA_FLAGS', '') 5 | xla_flags += ' --xla_gpu_triton_gemm_any=True' 6 | os.environ['XLA_FLAGS'] = xla_flags 7 | 8 | import pathlib, copy 9 | 10 | import jax 11 | from jaxrl2.agents.pixel_sac.pixel_sac_learner import PixelSACLearner 12 | from jaxrl2.utils.general_utils import add_batch_dim 13 | import numpy as np 14 | 15 | import gymnasium as gym 16 | import gym_aloha 17 | from gym.spaces import Dict, Box 18 | 19 | from libero.libero import benchmark 20 | from libero.libero import get_libero_path 21 | from libero.libero.envs import OffScreenRenderEnv 22 | 23 | from jaxrl2.data import ReplayBuffer 24 | from jaxrl2.utils.wandb_logger import WandBLogger, create_exp_name 25 | import tempfile 26 | from functools import partial 27 | from examples.train_utils_sim import trajwise_alternating_training_loop 28 | import tensorflow as tf 29 | from jax.experimental.compilation_cache import compilation_cache 30 | 31 | from openpi.training import config as openpi_config 32 | from openpi.policies import policy_config 33 | from openpi.shared import download 34 | 35 | home_dir = os.environ['HOME'] 36 | compilation_cache.initialize_cache(os.path.join(home_dir, 'jax_compilation_cache')) 37 | 38 | def _get_libero_env(task, resolution, seed): 39 | """Initializes and returns the LIBERO environment, along with the task description.""" 40 | task_description = task.language 41 | task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file 42 | env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution} 43 | env = OffScreenRenderEnv(**env_args) 44 | env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state 45 | return env, task_description 46 | 47 | def shard_batch(batch, sharding): 48 | """Shards a batch across devices along its first dimension. 49 | 50 | Args: 51 | batch: A pytree of arrays. 52 | sharding: A jax Sharding object with shape (num_devices,). 
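    Example (illustrative, mirroring the usage in `main` below):
      devices = jax.local_devices()
      sharding = jax.sharding.PositionalSharding(devices)
      batch = shard_batch(batch, sharding)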
53 | """ 54 | return jax.tree_util.tree_map( 55 | lambda x: jax.device_put( 56 | x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1))) 57 | ), 58 | batch, 59 | ) 60 | 61 | 62 | class DummyEnv(gym.ObservationWrapper): 63 | 64 | def __init__(self, variant): 65 | self.variant = variant 66 | self.image_shape = (variant.resize_image, variant.resize_image, 3 * variant.num_cameras, 1) 67 | obs_dict = {} 68 | obs_dict['pixels'] = Box(low=0, high=255, shape=self.image_shape, dtype=np.uint8) 69 | if variant.add_states: 70 | if variant.env == 'libero': 71 | state_dim = 8 72 | elif variant.env == 'aloha_cube': 73 | state_dim = 14 74 | obs_dict['state'] = Box(low=-1.0, high=1.0, shape=(state_dim, 1), dtype=np.float32) 75 | self.observation_space = Dict(obs_dict) 76 | self.action_space = Box(low=-1, high=1, shape=(1, 32,), dtype=np.float32) # 32 is the noise action space of pi 0 77 | 78 | 79 | def main(variant): 80 | devices = jax.local_devices() 81 | num_devices = len(devices) 82 | assert variant.batch_size % num_devices == 0 83 | print('num devices', num_devices) 84 | print('batch size', variant.batch_size) 85 | # we shard the leading dimension (batch dimension) accross all devices evenly 86 | sharding = jax.sharding.PositionalSharding(devices) 87 | shard_fn = partial(shard_batch, sharding=sharding) 88 | 89 | # prevent tensorflow from using GPUs 90 | tf.config.set_visible_devices([], "GPU") 91 | 92 | kwargs = variant['train_kwargs'] 93 | if kwargs.pop('cosine_decay', False): 94 | kwargs['decay_steps'] = variant.max_steps 95 | 96 | if not variant.prefix: 97 | import uuid 98 | variant.prefix = str(uuid.uuid4().fields[-1])[:5] 99 | 100 | if variant.suffix: 101 | expname = create_exp_name(variant.prefix, seed=variant.seed) + f"_{variant.suffix}" 102 | else: 103 | expname = create_exp_name(variant.prefix, seed=variant.seed) 104 | 105 | outputdir = os.path.join(os.environ['EXP'], expname) 106 | variant.outputdir = outputdir 107 | if not os.path.exists(outputdir): 108 | os.makedirs(outputdir) 109 | print('writing to output dir ', outputdir) 110 | 111 | if variant.env == 'libero': 112 | benchmark_dict = benchmark.get_benchmark_dict() 113 | task_suite = benchmark_dict["libero_90"]() 114 | task_id = 57 115 | task = task_suite.get_task(task_id) 116 | env, task_description = _get_libero_env(task, 256, variant.seed) 117 | eval_env = env 118 | variant.task_description = task_description 119 | variant.env_max_reward = 1 120 | variant.max_timesteps = 400 121 | elif variant.env == 'aloha_cube': 122 | from gymnasium.envs.registration import register 123 | register( 124 | id="gym_aloha/AlohaTransferCube-v0", 125 | entry_point="gym_aloha.env:AlohaEnv", 126 | max_episode_steps=400, 127 | nondeterministic=True, 128 | kwargs={"obs_type": "pixels", "task": "transfer_cube"}, 129 | ) 130 | env = gym.make("gym_aloha/AlohaTransferCube-v0", obs_type="pixels_agent_pos", render_mode="rgb_array") 131 | eval_env = copy.deepcopy(env) 132 | variant.env_max_reward = 4 133 | variant.max_timesteps = 400 134 | 135 | 136 | group_name = variant.prefix + '_' + variant.launch_group_id 137 | wandb_output_dir = tempfile.mkdtemp() 138 | wandb_logger = WandBLogger(variant.prefix != '', variant, variant.wandb_project, experiment_id=expname, output_dir=wandb_output_dir, group_name=group_name) 139 | 140 | dummy_env = DummyEnv(variant) 141 | sample_obs = add_batch_dim(dummy_env.observation_space.sample()) 142 | sample_action = add_batch_dim(dummy_env.action_space.sample()) 143 | print('sample obs shapes', [(k, v.shape) for k, v in 
sample_obs.items()]) 144 | print('sample action shape', sample_action.shape) 145 | 146 | 147 | if variant.env == 'libero': 148 | config = openpi_config.get_config("pi0_libero") 149 | checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_libero") 150 | elif variant.env == 'aloha_cube': 151 | config = openpi_config.get_config("pi0_aloha_sim") 152 | checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_aloha_sim") 153 | else: 154 | raise NotImplementedError() 155 | agent_dp = policy_config.create_trained_policy(config, checkpoint_dir) 156 | print("Loaded pi0 policy from %s", checkpoint_dir) 157 | agent = PixelSACLearner(variant.seed, sample_obs, sample_action, **kwargs) 158 | 159 | online_buffer_size = variant.max_steps // variant.multi_grad_step 160 | online_replay_buffer = ReplayBuffer(dummy_env.observation_space, dummy_env.action_space, int(online_buffer_size)) 161 | replay_buffer = online_replay_buffer 162 | replay_buffer.seed(variant.seed) 163 | trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger, shard_fn=shard_fn, agent_dp=agent_dp) 164 | -------------------------------------------------------------------------------- /jaxrl2/data/replay_buffer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from typing import Iterable, Optional 3 | import jax 4 | import gym 5 | import gym.spaces 6 | import numpy as np 7 | import pickle 8 | 9 | import copy 10 | 11 | from jaxrl2.data.dataset import Dataset, DatasetDict 12 | import collections 13 | from flax.core import frozen_dict 14 | 15 | def _init_replay_dict(obs_space: gym.Space, 16 | capacity: int) -> Union[np.ndarray, DatasetDict]: 17 | if isinstance(obs_space, gym.spaces.Box): 18 | return np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype) 19 | elif isinstance(obs_space, gym.spaces.Dict): 20 | data_dict = {} 21 | for k, v in obs_space.spaces.items(): 22 | data_dict[k] = _init_replay_dict(v, capacity) 23 | return data_dict 24 | else: 25 | raise TypeError() 26 | 27 | 28 | class ReplayBuffer(Dataset): 29 | 30 | def __init__(self, observation_space: gym.Space, action_space: gym.Space, capacity: int, ): 31 | self.observation_space = observation_space 32 | self.action_space = action_space 33 | self.capacity = capacity 34 | 35 | print("making replay buffer of capacity ", self.capacity) 36 | 37 | observations = _init_replay_dict(self.observation_space, self.capacity) 38 | next_observations = _init_replay_dict(self.observation_space, self.capacity) 39 | actions = np.empty((self.capacity, *self.action_space.shape), dtype=self.action_space.dtype) 40 | next_actions = np.empty((self.capacity, *self.action_space.shape), dtype=self.action_space.dtype) 41 | rewards = np.empty((self.capacity, ), dtype=np.float32) 42 | masks = np.empty((self.capacity, ), dtype=np.float32) 43 | discount = np.empty((self.capacity, ), dtype=np.float32) 44 | 45 | self.data = { 46 | 'observations': observations, 47 | 'next_observations': next_observations, 48 | 'actions': actions, 49 | 'next_actions': next_actions, 50 | 'rewards': rewards, 51 | 'masks': masks, 52 | 'discount': discount, 53 | } 54 | 55 | self.size = 0 56 | self._traj_counter = 0 57 | self._start = 0 58 | self.traj_bounds = dict() 59 | self.streaming_buffer_size = None # this is for streaming the online data 60 | 61 | def __len__(self) -> int: 62 | return self.size 63 | 64 | def length(self) -> int: 65 | return self.size 66 | 67 | def 
increment_traj_counter(self): 68 | self.traj_bounds[self._traj_counter] = (self._start, self.size) # [start, end) 69 | self._start = self.size 70 | self._traj_counter += 1 71 | 72 | def get_random_trajs(self, num_trajs: int): 73 | self.which_trajs = np.random.randint(0, self._traj_counter, num_trajs) 74 | observations_list = [] 75 | next_observations_list = [] 76 | actions_list = [] 77 | rewards_list = [] 78 | terminals_list = [] 79 | masks_list = [] 80 | discount_list = [] 81 | 82 | for i in self.which_trajs: 83 | start, end = self.traj_bounds[i] 84 | 85 | # handle this as a dictionary 86 | obs_dict_curr_traj = dict() 87 | for k in self.data['observations']: 88 | obs_dict_curr_traj[k] = self.data['observations'][k][start:end] 89 | observations_list.append(obs_dict_curr_traj) 90 | 91 | next_obs_dict_curr_traj = dict() 92 | for k in self.data['next_observations']: 93 | next_obs_dict_curr_traj[k] = self.data['next_observations'][k][start:end] 94 | next_observations_list.append(next_obs_dict_curr_traj) 95 | 96 | actions_list.append(self.data['actions'][start:end]) 97 | rewards_list.append(self.data['rewards'][start:end]) 98 | terminals_list.append(1-self.data['masks'][start:end]) 99 | masks_list.append(self.data['masks'][start:end]) 100 | 101 | 102 | 103 | batch = { 104 | 'observations': observations_list, 105 | 'next_observations': next_observations_list, 106 | 'actions': actions_list, 107 | 'rewards': rewards_list, 108 | 'terminals': terminals_list, 109 | 'masks': masks_list, 110 | 111 | 112 | } 113 | return batch 114 | 115 | def insert(self, data_dict: DatasetDict): 116 | if self.size == self.capacity: 117 | # Double the capacity 118 | observations = _init_replay_dict(self.observation_space, self.capacity) 119 | next_observations = _init_replay_dict(self.observation_space, self.capacity) 120 | actions = np.empty((self.capacity, *self.action_space.shape), dtype=self.action_space.dtype) 121 | next_actions = np.empty((self.capacity, *self.action_space.shape), dtype=self.action_space.dtype) 122 | rewards = np.empty((self.capacity, ), dtype=np.float32) 123 | masks = np.empty((self.capacity, ), dtype=np.float32) 124 | discount = np.empty((self.capacity, ), dtype=np.float32) 125 | 126 | data_new = { 127 | 'observations': observations, 128 | 'next_observations': next_observations, 129 | 'actions': actions, 130 | 'next_actions': next_actions, 131 | 'rewards': rewards, 132 | 'masks': masks, 133 | 'discount': discount, 134 | } 135 | 136 | for x in data_new: 137 | if isinstance(self.data[x], np.ndarray): 138 | self.data[x] = np.concatenate((self.data[x], data_new[x]), axis=0) 139 | elif isinstance(self.data[x], dict): 140 | for y in self.data[x]: 141 | self.data[x][y] = np.concatenate((self.data[x][y], data_new[x][y]), axis=0) 142 | else: 143 | raise TypeError() 144 | self.capacity *= 2 145 | 146 | 147 | for x in data_dict: 148 | if x in self.data: 149 | if isinstance(data_dict[x], dict): 150 | for y in data_dict[x]: 151 | self.data[x][y][self.size] = data_dict[x][y] 152 | else: 153 | self.data[x][self.size] = data_dict[x] 154 | self.size += 1 155 | 156 | def compute_action_stats(self): 157 | actions = self.data['actions'] 158 | return {'mean': actions.mean(axis=0), 'std': actions.std(axis=0)} 159 | 160 | def normalize_actions(self, action_stats): 161 | # do not normalize gripper dimension (last dimension) 162 | copy.deepcopy(action_stats) 163 | action_stats['mean'][-1] = 0 164 | action_stats['std'][-1] = 1 165 | self.data['actions'] = (self.data['actions'] - action_stats['mean']) / 
action_stats['std'] 166 | self.data['next_actions'] = (self.data['next_actions'] - action_stats['mean']) / action_stats['std'] 167 | 168 | def sample(self, batch_size: int, keys: Optional[Iterable[str]] = None, indx: Optional[np.ndarray] = None) -> frozen_dict.FrozenDict: 169 | if self.streaming_buffer_size: 170 | indices = np.random.randint(0, self.streaming_buffer_size, batch_size) 171 | else: 172 | indices = np.random.randint(0, self.size, batch_size) 173 | data_dict = {} 174 | for x in self.data: 175 | if isinstance(self.data[x], np.ndarray): 176 | data_dict[x] = self.data[x][indices] 177 | elif isinstance(self.data[x], dict): 178 | data_dict[x] = {} 179 | for y in self.data[x]: 180 | data_dict[x][y] = self.data[x][y][indices] 181 | else: 182 | raise TypeError() 183 | 184 | return frozen_dict.freeze(data_dict) 185 | 186 | def get_iterator(self, batch_size: int, keys: Optional[Iterable[str]] = None, indx: Optional[np.ndarray] = None, queue_size: int = 2): 187 | # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device 188 | # queue_size = 2 should be ok for one GPU. 189 | 190 | queue = collections.deque() 191 | 192 | def enqueue(n): 193 | for _ in range(n): 194 | data = self.sample(batch_size, keys, indx) 195 | queue.append(jax.device_put(data)) 196 | 197 | enqueue(queue_size) 198 | while queue: 199 | yield queue.popleft() 200 | enqueue(1) 201 | 202 | 203 | def save(self, filename): 204 | save_dict = dict( 205 | data=self.data, 206 | size = self.size, 207 | _traj_counter = self._traj_counter, 208 | _start=self._start, 209 | traj_bounds=self.traj_bounds 210 | ) 211 | with open(filename, 'wb') as f: 212 | pickle.dump(save_dict, f, protocol=4) 213 | 214 | 215 | def restore(self, filename): 216 | save_dict = np.load(filename, allow_pickle=True)[0] 217 | # todo test this: 218 | self.data = save_dict['data'] 219 | self.size = save_dict['size'] 220 | self._traj_counter = save_dict['_traj_counter'] 221 | self._start = save_dict['_start'] 222 | self.traj_bounds = save_dict['traj_bounds'] 223 | -------------------------------------------------------------------------------- /jaxrl2/networks/encoders/cross_norm.py: -------------------------------------------------------------------------------- 1 | """Normalization modules for Flax.""" 2 | 3 | from typing import (Any, Callable, Optional, Tuple, Iterable, Union) 4 | 5 | from jax import lax 6 | from jax.nn import initializers 7 | import jax.numpy as jnp 8 | 9 | from flax.linen.module import Module, compact, merge_param 10 | 11 | 12 | PRNGKey = Any 13 | Array = Any 14 | Shape = Tuple[int] 15 | Dtype = Any # this could be a real type? 16 | 17 | Axes = Union[int, Iterable[int]] 18 | 19 | import flax.linen as nn 20 | from flax.linen.module import Module, compact, merge_param 21 | 22 | def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: 23 | """Returns a tuple of deduplicated, sorted, and positive axes.""" 24 | if not isinstance(axes, Iterable): 25 | axes = (axes,) 26 | return tuple(set([rank + axis if axis < 0 else axis for axis in axes])) 27 | 28 | 29 | def _abs_sq(x): 30 | """Computes the elementwise square of the absolute value |x|^2.""" 31 | if jnp.iscomplexobj(x): 32 | return lax.square(lax.real(x)) + lax.square(lax.imag(x)) 33 | else: 34 | return lax.square(x) 35 | 36 | 37 | def _compute_stats(x: Array, axes: Axes, 38 | axis_name: Optional[str] = None, 39 | axis_index_groups: Any = None, 40 | alpha: float = 0.5): 41 | """Computes mean and variance statistics. 
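  Unlike the stock BatchNorm statistics helper, this CrossNorm variant splits the
  batch in half along axis 0 and mixes the two halves' statistics using the
  `alpha` weight.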
42 | This implementation takes care of a few important details: 43 | - Computes in float32 precision for half precision inputs 44 | - mean and variance is computable in a single XLA fusion, 45 | by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]). 46 | - Clips negative variances to zero which can happen due to 47 | roundoff errors. This avoids downstream NaNs. 48 | - Supports averaging across a parallel axis and subgroups of a parallel axis 49 | with a single `lax.pmean` call to avoid latency. 50 | Arguments: 51 | x: Input array. 52 | axes: The axes in ``x`` to compute mean and variance statistics for. 53 | axis_name: Optional name for the pmapped axis to compute mean over. 54 | axis_index_groups: Optional axis indices. 55 | Returns: 56 | A pair ``(mean, var)``. 57 | """ 58 | # promote x to at least float32, this avoids half precision computation 59 | # but preserves double or complex floating points 60 | x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) 61 | split_1, split_2 = jnp.split(x, 2, axis=0) # split x into two parts 62 | mean_s1 = jnp.mean(split_1, axes) 63 | mean_s2 = jnp.mean(split_2, axes) 64 | 65 | mean2_s1 = jnp.mean(_abs_sq(split_1), axes) 66 | mean2_s2 = jnp.mean(_abs_sq(split_2), axes) 67 | 68 | mean = alpha * mean_s1 + (1 - alpha) * mean_s2 69 | 70 | if axis_name is not None: 71 | concatenated_mean = jnp.concatenate([mean, mean2]) 72 | mean, mean2 = jnp.split( 73 | lax.pmean( 74 | concatenated_mean, 75 | axis_name=axis_name, 76 | axis_index_groups=axis_index_groups), 2) 77 | # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due 78 | # to floating point round-off errors. 79 | var_s1 = mean2_s1 - _abs_sq(mean_s1) 80 | var_s2 = mean2_s2 - _abs_sq(mean_s2) 81 | var = alpha * var_s1 + (1 - alpha) * var_s2 82 | 83 | var = jnp.maximum(0., var) 84 | return mean, var 85 | 86 | 87 | def _normalize(mdl: Module, x: Array, mean: Array, var: Array, 88 | reduction_axes: Axes, feature_axes: Axes, 89 | dtype: Dtype, param_dtype: Dtype, 90 | epsilon: float, 91 | use_bias: bool, use_scale: bool, 92 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array], 93 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array]): 94 | """"Normalizes the input of a normalization layer and optionally applies a learned scale and bias. 95 | Arguments: 96 | mdl: Module to apply the normalization in (normalization params will reside 97 | in this module). 98 | x: The input. 99 | mean: Mean to use for normalization. 100 | var: Variance to use for normalization. 101 | reduction_axes: The axes in ``x`` to reduce. 102 | feature_axes: Axes containing features. A separate bias and scale is learned 103 | for each specified feature. 104 | dtype: Dtype of the returned result. 105 | param_dtype: Dtype of the parameters. 106 | epsilon: Normalization epsilon. 107 | use_bias: If true, add a bias term to the output. 108 | use_scale: If true, scale the output. 109 | bias_init: Initialization function for the bias term. 110 | scale_init: Initialization function for the scaling function. 111 | Returns: 112 | The normalized input. 
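  In short (illustrative): y = (x - mean) * lax.rsqrt(var + epsilon) * scale + bias.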
113 | """ 114 | reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) 115 | feature_axes = _canonicalize_axes(x.ndim, feature_axes) 116 | stats_shape = list(x.shape) 117 | for axis in reduction_axes: 118 | stats_shape[axis] = 1 119 | mean = mean.reshape(stats_shape) 120 | var = var.reshape(stats_shape) 121 | feature_shape = [1] * x.ndim 122 | reduced_feature_shape = [] 123 | for ax in feature_axes: 124 | feature_shape[ax] = x.shape[ax] 125 | reduced_feature_shape.append(x.shape[ax]) 126 | y = x - mean 127 | mul = lax.rsqrt(var + epsilon) 128 | if use_scale: 129 | scale = mdl.param('scale', scale_init, reduced_feature_shape, 130 | param_dtype).reshape(feature_shape) 131 | mul *= scale 132 | y *= mul 133 | if use_bias: 134 | bias = mdl.param('bias', bias_init, reduced_feature_shape, 135 | param_dtype).reshape(feature_shape) 136 | y += bias 137 | return jnp.asarray(y, dtype) 138 | 139 | class CrossNorm(Module): 140 | """CrossNorm Module. 141 | Usage Note: 142 | If we define a model with CrossNorm, for example:: 143 | BN = nn.CrossNorm(use_running_average=False, momentum=0.9, epsilon=1e-5, 144 | dtype=jnp.float32) 145 | The initialized variables dict will contain in addition to a 'params' 146 | collection a separate 'batch_stats' collection that will contain all the 147 | running statistics for all the BatchNorm layers in a model:: 148 | vars_initialized = BN.init(key, x) # {'params': ..., 'batch_stats': ...} 149 | We then update the batch_stats during training by specifying that the 150 | `batch_stats` collection is mutable in the `apply` method for our module.:: 151 | vars_in = {'params': params, 'batch_stats': old_batch_stats} 152 | y, mutated_vars = BN.apply(vars_in, x, mutable=['batch_stats']) 153 | new_batch_stats = mutated_vars['batch_stats'] 154 | During eval we would define BN with `use_running_average=True` and use the 155 | batch_stats collection from training to set the statistics. In this case 156 | we are not mutating the batch statistics collection, and needn't mark it 157 | mutable:: 158 | vars_in = {'params': params, 'batch_stats': training_batch_stats} 159 | y = BN.apply(vars_in, x) 160 | Attributes: 161 | use_running_average: if True, the statistics stored in batch_stats 162 | will be used instead of computing the batch statistics on the input. 163 | axis: the feature or non-batch axis of the input. 164 | momentum: decay rate for the exponential moving average of 165 | the batch statistics. 166 | epsilon: a small float added to variance to avoid dividing by zero. 167 | dtype: the dtype of the computation (default: float32). 168 | param_dtype: the dtype passed to parameter initializers (default: float32). 169 | use_bias: if True, bias (beta) is added. 170 | use_scale: if True, multiply by scale (gamma). 171 | When the next layer is linear (also e.g. nn.relu), this can be disabled 172 | since the scaling will be done by the next layer. 173 | bias_init: initializer for bias, by default, zero. 174 | scale_init: initializer for scale, by default, one. 175 | axis_name: the axis name used to combine batch statistics from multiple 176 | devices. See `jax.pmap` for a description of axis names (default: None). 177 | axis_index_groups: groups of axis indices within that named axis 178 | representing subsets of devices to reduce over (default: None). For 179 | example, `[[0, 1], [2, 3]]` would independently batch-normalize over 180 | the examples on the first two and last two devices. See `jax.lax.psum` 181 | for more details. 
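    alpha: weight used when mixing the statistics of the two batch halves,
      i.e. stat = alpha * stat(first half) + (1 - alpha) * stat(second half)
      (default: 0.5).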
182 | 183 | Note modified original BatchNorm module 184 | """ 185 | use_running_average: Optional[bool] = None 186 | axis: int = -1 187 | momentum: float = 0.99 188 | epsilon: float = 1e-5 189 | dtype: Dtype = jnp.float32 190 | param_dtype: Dtype = jnp.float32 191 | use_bias: bool = True 192 | use_scale: bool = True 193 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros 194 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones 195 | axis_name: Optional[str] = None 196 | axis_index_groups: Any = None 197 | alpha: float = 0.5 198 | 199 | @compact 200 | def __call__(self, x, use_running_average: Optional[bool] = None): 201 | """Normalizes the input using batch statistics. 202 | NOTE: 203 | During initialization (when parameters are mutable) the running average 204 | of the batch statistics will not be updated. Therefore, the inputs 205 | fed during initialization don't need to match that of the actual input 206 | distribution and the reduction axis (set with `axis_name`) does not have 207 | to exist. 208 | Args: 209 | x: the input to be normalized. 210 | use_running_average: if true, the statistics stored in batch_stats 211 | will be used instead of computing the batch statistics on the input. 212 | Returns: 213 | Normalized inputs (the same shape as inputs). 214 | """ 215 | 216 | use_running_average = merge_param( 217 | 'use_running_average', self.use_running_average, use_running_average) 218 | feature_axes = _canonicalize_axes(x.ndim, self.axis) 219 | reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) 220 | feature_shape = [x.shape[ax] for ax in feature_axes] 221 | 222 | # see NOTE above on initialization behavior 223 | initializing = self.is_mutable_collection('params') 224 | 225 | ra_mean = self.variable('batch_stats', 'mean', 226 | lambda s: jnp.zeros(s, jnp.float32), 227 | feature_shape) 228 | ra_var = self.variable('batch_stats', 'var', 229 | lambda s: jnp.ones(s, jnp.float32), 230 | feature_shape) 231 | 232 | if use_running_average: 233 | mean, var = ra_mean.value, ra_var.value 234 | else: 235 | mean, var = _compute_stats( 236 | x, reduction_axes, 237 | axis_name=self.axis_name if not initializing else None, 238 | axis_index_groups=self.axis_index_groups, alpha=self.alpha) 239 | 240 | if not initializing: 241 | ra_mean.value = self.momentum * ra_mean.value + (1 - 242 | self.momentum) * mean 243 | ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var 244 | 245 | return _normalize( 246 | self, x, mean, var, reduction_axes, feature_axes, 247 | self.dtype, self.param_dtype, self.epsilon, 248 | self.use_bias, self.use_scale, 249 | self.bias_init, self.scale_init) -------------------------------------------------------------------------------- /examples/train_utils_real.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from tqdm import tqdm 4 | import time 5 | import numpy as np 6 | import jax 7 | import sys 8 | import select 9 | import tty 10 | import termios 11 | from openpi_client import image_tools 12 | from moviepy.editor import ImageSequenceClip 13 | 14 | 15 | def trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger, 16 | shard_fn=None, agent_dp=None, robot_config=None): 17 | replay_buffer_iterator = replay_buffer.get_iterator(variant.batch_size) 18 | if shard_fn is not None: 19 | replay_buffer_iterator = map(shard_fn, replay_buffer_iterator) 20 | 21 | i = 0 22 | 
total_env_steps = 0 23 | total_num_traj = 0 24 | wandb_logger.log({'num_online_samples': 0}, step=i) 25 | wandb_logger.log({'num_online_trajs': 0}, step=i) 26 | wandb_logger.log({'env_steps': 0}, step=i) 27 | 28 | with tqdm(total=variant.max_steps, initial=0) as pbar: 29 | while i <= variant.max_steps: 30 | traj = collect_traj(variant, agent, env, i, agent_dp, wandb_logger, total_num_traj, robot_config) 31 | total_num_traj += 1 32 | add_online_data_to_buffer(variant, traj, online_replay_buffer) 33 | total_env_steps += traj['env_steps'] 34 | print('online buffer timesteps length:', len(online_replay_buffer)) 35 | print('online buffer num traj:', total_num_traj) 36 | print('total env steps:', total_env_steps) 37 | 38 | if i == 0: 39 | num_gradsteps = 5000 40 | else: 41 | num_gradsteps = len(traj["rewards"]) * variant.multi_grad_step 42 | print(f'num_gradsteps: {num_gradsteps}') 43 | if total_num_traj >= variant.num_initial_traj_collect: 44 | for _ in range(num_gradsteps): 45 | 46 | batch = next(replay_buffer_iterator) 47 | update_info = agent.update(batch) 48 | 49 | pbar.update() 50 | i += 1 51 | 52 | if i % variant.log_interval == 0: 53 | update_info = {k: jax.device_get(v) for k, v in update_info.items()} 54 | for k, v in update_info.items(): 55 | if v.ndim == 0: 56 | wandb_logger.log({f'training/{k}': v}, step=i) 57 | elif v.ndim <= 2: 58 | wandb_logger.log_histogram(f'training/{k}', v, i) 59 | wandb_logger.log({ 60 | 'replay_buffer_size': len(online_replay_buffer), 61 | 'is_success (exploration)': int(traj['is_success']), 62 | }, i) 63 | 64 | if i % variant.eval_interval == 0: 65 | wandb_logger.log({'num_online_samples': len(online_replay_buffer)}, step=i) 66 | wandb_logger.log({'num_online_trajs': total_num_traj}, step=i) 67 | wandb_logger.log({'env_steps': total_env_steps}, step=i) 68 | if hasattr(agent, 'perform_eval'): 69 | agent.perform_eval(variant, i, wandb_logger, replay_buffer, replay_buffer_iterator, eval_env) 70 | 71 | if variant.checkpoint_interval != -1: 72 | if i % variant.checkpoint_interval == 0: 73 | agent.save_checkpoint(variant.outputdir, i, variant.checkpoint_interval) 74 | 75 | def add_online_data_to_buffer(variant, traj, online_replay_buffer): 76 | 77 | discount_horizon = variant.query_freq 78 | actions = np.array(traj['actions']) # (T, chunk_size, 14) 79 | episode_len = len(actions) 80 | rewards = np.array(traj['rewards']) 81 | masks = np.array(traj['masks']) 82 | 83 | for t in range(episode_len): 84 | obs = traj['observations'][t] 85 | next_obs = traj['observations'][t + 1] 86 | # remove batch dimension 87 | obs = {k: v[0] for k, v in obs.items()} 88 | next_obs = {k: v[0] for k, v in next_obs.items()} 89 | if not variant.add_states: 90 | obs.pop('state', None) 91 | next_obs.pop('state', None) 92 | 93 | insert_dict = dict( 94 | observations=obs, 95 | next_observations=next_obs, 96 | actions=actions[t], 97 | next_actions=actions[t + 1] if t < episode_len - 1 else actions[t], 98 | rewards=rewards[t], 99 | masks=masks[t], 100 | discount=variant.discount ** discount_horizon 101 | ) 102 | online_replay_buffer.insert(insert_dict) 103 | online_replay_buffer.increment_traj_counter() 104 | 105 | def collect_traj(variant, agent, env, i, agent_dp=None, wandb_logger=None, traj_id=None, robot_config=None): 106 | query_frequency = variant.query_freq 107 | instruction = variant.instruction 108 | max_timesteps = robot_config['max_timesteps'] 109 | agent._rng, rng = jax.random.split(agent._rng) 110 | try: 111 | env.reset() 112 | except Exception as e: 113 | print(f"Environment 
reset failed") 114 | import traceback 115 | traceback.print_exc() 116 | import pdb; pdb.set_trace() 117 | step_time = 1 / 15 # 15 Hz 118 | last_step_time = time.time() 119 | old_settings = termios.tcgetattr(sys.stdin) 120 | 121 | rewards = [] 122 | action_list = [] 123 | obs_list = [] 124 | image_list = [] 125 | 126 | old_settings = termios.tcgetattr(sys.stdin) 127 | try: 128 | tty.setcbreak(sys.stdin.fileno()) 129 | for t in tqdm(range(max_timesteps)): 130 | # Check for keyboard input 131 | if select.select([sys.stdin], [], [], 0) == ([sys.stdin], [], []): 132 | char_input = sys.stdin.read(1) 133 | if char_input.lower() == 'q': 134 | print("'q' pressed, stopping loop.") 135 | break 136 | 137 | try: 138 | _env_obs = env.get_observation() 139 | except Exception as e: 140 | print(f"Environment get obs failed") 141 | import traceback 142 | traceback.print_exc() 143 | import pdb; pdb.set_trace() 144 | curr_obs = _extract_observation( 145 | robot_config, 146 | _env_obs, 147 | ) 148 | image_list.append(curr_obs[robot_config['camera_to_use'] + "_image"]) 149 | 150 | request_data = get_pi0_input(curr_obs, robot_config, instruction) 151 | 152 | if t % query_frequency == 0: 153 | 154 | rng, key = jax.random.split(rng) 155 | 156 | img_all = process_images(variant, curr_obs) 157 | 158 | # extract the feature from the pi0 VLM backbone and concat with the qpos as states 159 | img_rep_pi0, _ = agent_dp.get_prefix_rep(request_data) 160 | img_rep_pi0 = img_rep_pi0[:, -1, :] # (1, 2048) 161 | qpos = np.concatenate([curr_obs["joint_position"], curr_obs["gripper_position"], img_rep_pi0.flatten()]) 162 | 163 | obs_dict = { 164 | 'pixels': img_all, 165 | 'state': qpos[np.newaxis, ..., np.newaxis], 166 | } 167 | if i == 0: 168 | noise = jax.random.normal(key, (1, *agent.action_chunk_shape)) 169 | noise_repeat = jax.numpy.repeat(noise[:, -1:, :], 10 - noise.shape[1], axis=1) 170 | noise = jax.numpy.concatenate([noise, noise_repeat], axis=1) 171 | actions_noise = noise[0, :agent.action_chunk_shape[0], :] 172 | else: 173 | # sac agent predicts the noise for diffusion model 174 | actions_noise = agent.sample_actions(obs_dict) 175 | actions_noise = np.reshape(actions_noise, agent.action_chunk_shape) 176 | noise = np.repeat(actions_noise[-1:, :], 10 - actions_noise.shape[0], axis=0) 177 | noise = jax.numpy.concatenate([actions_noise, noise], axis=0)[None] 178 | action_list.append(actions_noise) 179 | obs_list.append(obs_dict) 180 | action = agent_dp.infer(request_data, noise=np.asarray(noise))["actions"] 181 | 182 | action_t = action[t % query_frequency] 183 | 184 | # binarize gripper action. 185 | if action_t[-1].item() > 0.5: 186 | action_t = np.concatenate([action_t[:-1], np.ones((1,))]) 187 | else: 188 | action_t = np.concatenate([action_t[:-1], np.zeros((1,))]) 189 | action_t = np.clip(action_t, -1, 1) 190 | 191 | try: 192 | env.step(action_t) 193 | except Exception as e: 194 | print(f"Environment step failed") 195 | import traceback 196 | traceback.print_exc() # This prints the full traceback 197 | import pdb; pdb.set_trace() 198 | 199 | now = time.time() 200 | dt = now - last_step_time 201 | if dt < step_time: 202 | time.sleep(step_time - dt) 203 | last_step_time = time.time() 204 | else: 205 | last_step_time = now 206 | 207 | print("Trial finished. 
Mark as (1) Success or (0) Failure:") 208 | while True: 209 | if select.select([sys.stdin], [], [], 0) == ([sys.stdin], [], []): 210 | char_input = sys.stdin.read(1) 211 | if char_input == '1': 212 | print("Trial marked as SUCCESS.") 213 | is_success = True 214 | break 215 | elif char_input == '0': 216 | print("Trial marked as FAILURE.") 217 | is_success = False 218 | break 219 | else: 220 | print("Invalid input. Please enter '1' for Success or '0' for Failure:") 221 | time.sleep(0.01) # Small sleep to prevent busy-waiting if no input 222 | 223 | try: 224 | _env_obs = env.get_observation() 225 | except Exception as e: 226 | print(f"Environment get obs failed: {e}") 227 | import traceback 228 | traceback.print_exc() 229 | import pdb; pdb.set_trace() 230 | 231 | # add last observation 232 | curr_obs = _extract_observation( 233 | robot_config, 234 | _env_obs, 235 | ) 236 | image_list.append(curr_obs[robot_config['camera_to_use'] + "_image"]) 237 | request_data = get_pi0_input(curr_obs, robot_config, instruction) 238 | img_all = process_images(variant, curr_obs) 239 | img_rep_pi0, _ = agent_dp.get_prefix_rep(request_data) 240 | img_rep_pi0 = img_rep_pi0[:, -1, :] # (1, 2048) 241 | qpos = np.concatenate([curr_obs["joint_position"], curr_obs["gripper_position"], img_rep_pi0.flatten()]) 242 | obs_dict = { 243 | 'pixels': img_all, 244 | 'state': qpos[np.newaxis, ..., np.newaxis], 245 | } 246 | obs_list.append(obs_dict) 247 | print('Rollout Done') 248 | 249 | finally: 250 | if is_success: 251 | query_steps = len(action_list) 252 | rewards = np.concatenate([-np.ones(query_steps - 1), [0]]) 253 | masks = np.concatenate([np.ones(query_steps - 1), [0]]) 254 | else: 255 | query_steps = len(action_list) 256 | rewards = -np.ones(query_steps) 257 | masks = np.ones(query_steps) 258 | 259 | if wandb_logger is not None: 260 | wandb_logger.log({'is_success': int(is_success)}, step=i) 261 | wandb_logger.log({'total_num_traj': traj_id}, step=i) 262 | 263 | video_path = os.path.join(variant.outputdir, f'video_high_{traj_id}.mp4') 264 | video = np.stack(image_list) 265 | ImageSequenceClip(list(video), fps=15).write_videofile(video_path, codec="libx264") 266 | 267 | print("Episode Done!
Press c after resetting the environment") 268 | try: 269 | env.reset() 270 | except Exception as e: 271 | print(f"Environment reset failed") 272 | import traceback 273 | traceback.print_exc() # This prints the full traceback 274 | import pdb; pdb.set_trace() 275 | import pdb; pdb.set_trace() 276 | termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) 277 | 278 | traj = { 279 | 'observations': obs_list, 280 | 'actions': action_list, 281 | 'rewards': rewards, 282 | 'masks': masks, 283 | 'is_success': is_success, 284 | 'env_steps': t + 1, 285 | } 286 | 287 | return traj 288 | 289 | 290 | def _extract_observation(robot_config, obs_dict): 291 | ''' 292 | from https://github.com/Physical-Intelligence/openpi/blob/main/examples/droid/main.py 293 | ''' 294 | image_observations = obs_dict["image"] 295 | left_image, right_image, wrist_image = None, None, None 296 | for key in image_observations.keys(): 297 | if robot_config['left_camera_id'] in key and "left" in key: 298 | left_image = image_observations[key] 299 | elif robot_config['right_camera_id'] in key and "left" in key: 300 | right_image = image_observations[key] 301 | elif robot_config['wrist_camera_id'] in key and "left" in key: 302 | wrist_image = image_observations[key] 303 | 304 | # Drop the alpha dimension 305 | left_image = left_image[..., :3] 306 | right_image = right_image[..., :3] 307 | wrist_image = wrist_image[..., :3] 308 | 309 | # Convert to RGB 310 | left_image = left_image[..., ::-1] 311 | right_image = right_image[..., ::-1] 312 | wrist_image = wrist_image[..., ::-1] 313 | 314 | # In addition to image observations, also capture the proprioceptive state 315 | robot_state = obs_dict["robot_state"] 316 | cartesian_position = np.array(robot_state["cartesian_position"]) 317 | joint_position = np.array(robot_state["joint_positions"]) 318 | gripper_position = np.array([robot_state["gripper_position"]]) 319 | 320 | return { 321 | "left_image": left_image, 322 | "right_image": right_image, 323 | "wrist_image": wrist_image, 324 | "cartesian_position": cartesian_position, 325 | "joint_position": joint_position, 326 | "gripper_position": gripper_position, 327 | } 328 | 329 | def get_pi0_input(obs, robot_config, instruction): 330 | external_image = obs[robot_config['camera_to_use'] + "_image"] 331 | request_data = { 332 | "observation/exterior_image_1_left": image_tools.resize_with_pad( 333 | external_image, 224, 224 334 | ), 335 | "observation/wrist_image_left": image_tools.resize_with_pad(obs["wrist_image"], 224, 224), 336 | "observation/joint_position": obs["joint_position"], 337 | "observation/gripper_position": obs["gripper_position"], 338 | "prompt": instruction, 339 | } 340 | return request_data 341 | 342 | 343 | def process_images(variant, obs): 344 | ''' 345 | concat the images from all cameras 346 | ''' 347 | im1 = image_tools.resize_with_pad(obs["left_image"], variant.resize_image, variant.resize_image) 348 | im2 = image_tools.resize_with_pad(obs["right_image"], variant.resize_image, variant.resize_image) 349 | im3 = image_tools.resize_with_pad(obs["wrist_image"], variant.resize_image, variant.resize_image) 350 | img_all = np.concatenate([im1, im2, im3], axis=2)[np.newaxis, ..., np.newaxis] 351 | return img_all -------------------------------------------------------------------------------- /jaxrl2/agents/pixel_sac/pixel_sac_learner.py: -------------------------------------------------------------------------------- 1 | """Implementations of algorithms for continuous control.""" 2 | import matplotlib 3 | 
matplotlib.use('Agg') 4 | from flax.training import checkpoints 5 | import pathlib 6 | import matplotlib.pyplot as plt 7 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 8 | 9 | import numpy as np 10 | import copy 11 | import functools 12 | from typing import Dict, Optional, Sequence, Tuple, Union 13 | 14 | import jax 15 | import jax.numpy as jnp 16 | import optax 17 | from flax.core.frozen_dict import FrozenDict 18 | from flax.training import train_state 19 | from typing import Any 20 | 21 | from jaxrl2.agents.agent import Agent 22 | from jaxrl2.data.augmentations import batched_random_crop, color_transform 23 | from jaxrl2.networks.encoders.networks import Encoder, PixelMultiplexer 24 | from jaxrl2.networks.encoders.impala_encoder import ImpalaEncoder, SmallerImpalaEncoder 25 | from jaxrl2.networks.encoders.resnet_encoderv1 import ResNet18, ResNet34, ResNetSmall 26 | from jaxrl2.networks.encoders.resnet_encoderv2 import ResNetV2Encoder 27 | from jaxrl2.agents.pixel_sac.actor_updater import update_actor 28 | from jaxrl2.agents.pixel_sac.critic_updater import update_critic 29 | from jaxrl2.agents.pixel_sac.temperature_updater import update_temperature 30 | from jaxrl2.agents.pixel_sac.temperature import Temperature 31 | from jaxrl2.data.dataset import DatasetDict 32 | from jaxrl2.networks.learned_std_normal_policy import LearnedStdTanhNormalPolicy 33 | from jaxrl2.networks.values import StateActionEnsemble 34 | from jaxrl2.types import Params, PRNGKey 35 | from jaxrl2.utils.target_update import soft_target_update 36 | 37 | 38 | class TrainState(train_state.TrainState): 39 | batch_stats: Any 40 | 41 | @functools.partial(jax.jit, static_argnames=('critic_reduction', 'color_jitter', 'aug_next', 'num_cameras')) 42 | def _update_jit( 43 | rng: PRNGKey, actor: TrainState, critic: TrainState, 44 | target_critic_params: Params, temp: TrainState, batch: TrainState, 45 | discount: float, tau: float, target_entropy: float, 46 | critic_reduction: str, color_jitter: bool, aug_next: bool, num_cameras: int, 47 | ) -> Tuple[PRNGKey, TrainState, TrainState, Params, TrainState, Dict[str,float]]: 48 | aug_pixels = batch['observations']['pixels'] 49 | aug_next_pixels = batch['next_observations']['pixels'] 50 | if batch['observations']['pixels'].squeeze().ndim != 2: 51 | rng, key = jax.random.split(rng) 52 | aug_pixels = batched_random_crop(key, batch['observations']['pixels']) 53 | 54 | if color_jitter: 55 | rng, key = jax.random.split(rng) 56 | if num_cameras > 1: 57 | for i in range(num_cameras): 58 | aug_pixels = aug_pixels.at[:,:,:,i*3:(i+1)*3].set((color_transform(key, aug_pixels[:,:,:,i*3:(i+1)*3].astype(jnp.float32)/255.)*255).astype(jnp.uint8)) 59 | else: 60 | aug_pixels = (color_transform(key, aug_pixels.astype(jnp.float32)/255.)*255).astype(jnp.uint8) 61 | 62 | observations = batch['observations'].copy(add_or_replace={'pixels': aug_pixels}) 63 | batch = batch.copy(add_or_replace={'observations': observations}) 64 | 65 | key, rng = jax.random.split(rng) 66 | if aug_next: 67 | rng, key = jax.random.split(rng) 68 | aug_next_pixels = batched_random_crop(key, batch['next_observations']['pixels']) 69 | if color_jitter: 70 | rng, key = jax.random.split(rng) 71 | if num_cameras > 1: 72 | for i in range(num_cameras): 73 | aug_next_pixels = aug_next_pixels.at[:,:,:,i*3:(i+1)*3].set((color_transform(key, aug_next_pixels[:,:,:,i*3:(i+1)*3].astype(jnp.float32)/255.)*255).astype(jnp.uint8)) 74 | else: 75 | aug_next_pixels = (color_transform(key, 
aug_next_pixels.astype(jnp.float32)/255.)*255).astype(jnp.uint8) 76 | next_observations = batch['next_observations'].copy( 77 | add_or_replace={'pixels': aug_next_pixels}) 78 | batch = batch.copy(add_or_replace={'next_observations': next_observations}) 79 | 80 | key, rng = jax.random.split(rng) 81 | target_critic = critic.replace(params=target_critic_params) 82 | new_critic, critic_info = update_critic(key, actor, critic, target_critic, temp, batch, discount, critic_reduction=critic_reduction) 83 | new_target_critic_params = soft_target_update(new_critic.params, target_critic_params, tau) 84 | 85 | key, rng = jax.random.split(rng) 86 | new_actor, actor_info = update_actor(key, actor, new_critic, temp, batch, critic_reduction=critic_reduction) 87 | new_temp, alpha_info = update_temperature(temp, actor_info['entropy'], target_entropy) 88 | 89 | return rng, new_actor, new_critic, new_target_critic_params, new_temp, { 90 | **critic_info, 91 | **actor_info, 92 | **alpha_info 93 | } 94 | 95 | 96 | class PixelSACLearner(Agent): 97 | 98 | def __init__(self, 99 | seed: int, 100 | observations: Union[jnp.ndarray, DatasetDict], 101 | actions: jnp.ndarray, 102 | actor_lr: float = 3e-4, 103 | critic_lr: float = 3e-4, 104 | temp_lr: float = 3e-4, 105 | decay_steps: Optional[int] = None, 106 | hidden_dims: Sequence[int] = (256, 256), 107 | cnn_features: Sequence[int] = (32, 32, 32, 32), 108 | cnn_strides: Sequence[int] = (2, 1, 1, 1), 109 | cnn_padding: str = 'VALID', 110 | latent_dim: int = 50, 111 | discount: float = 0.99, 112 | tau: float = 0.005, 113 | critic_reduction: str = 'mean', 114 | dropout_rate: Optional[float] = None, 115 | encoder_type='resnet_34_v1', 116 | encoder_norm='group', 117 | color_jitter = True, 118 | use_spatial_softmax=True, 119 | softmax_temperature=1, 120 | aug_next=True, 121 | use_bottleneck=True, 122 | init_temperature: float = 1.0, 123 | num_qs: int = 2, 124 | target_entropy: float = None, 125 | action_magnitude: float = 1.0, 126 | num_cameras: int = 1 127 | ): 128 | """ 129 | An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1812.05905 130 | """ 131 | 132 | self.aug_next=aug_next 133 | self.color_jitter = color_jitter 134 | self.num_cameras = num_cameras 135 | 136 | self.action_dim = np.prod(actions.shape[-2:]) 137 | self.action_chunk_shape = actions.shape[-2:] 138 | 139 | self.tau = tau 140 | self.discount = discount 141 | self.critic_reduction = critic_reduction 142 | 143 | rng = jax.random.PRNGKey(seed) 144 | rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4) 145 | 146 | if encoder_type == 'small': 147 | encoder_def = Encoder(cnn_features, cnn_strides, cnn_padding) 148 | elif encoder_type == 'impala': 149 | print('using impala') 150 | encoder_def = ImpalaEncoder() 151 | elif encoder_type == 'impala_small': 152 | print('using impala small') 153 | encoder_def = SmallerImpalaEncoder() 154 | elif encoder_type == 'resnet_small': 155 | encoder_def = ResNetSmall(norm=encoder_norm, use_spatial_softmax=use_spatial_softmax, softmax_temperature=softmax_temperature) 156 | elif encoder_type == 'resnet_18_v1': 157 | encoder_def = ResNet18(norm=encoder_norm, use_spatial_softmax=use_spatial_softmax, softmax_temperature=softmax_temperature) 158 | elif encoder_type == 'resnet_34_v1': 159 | encoder_def = ResNet34(norm=encoder_norm, use_spatial_softmax=use_spatial_softmax, softmax_temperature=softmax_temperature) 160 | elif encoder_type == 'resnet_small_v2': 161 | encoder_def = ResNetV2Encoder(stage_sizes=(1, 1, 1, 1), norm=encoder_norm) 
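        # stage_sizes sets the ResNet-v2 depth: (2, 2, 2, 2) corresponds to ResNet-18 and (3, 4, 6, 3) to ResNet-34; normalization follows encoder_norm.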
162 | elif encoder_type == 'resnet_18_v2': 163 | encoder_def = ResNetV2Encoder(stage_sizes=(2, 2, 2, 2), norm=encoder_norm) 164 | elif encoder_type == 'resnet_34_v2': 165 | encoder_def = ResNetV2Encoder(stage_sizes=(3, 4, 6, 3), norm=encoder_norm) 166 | else: 167 | raise ValueError('encoder type not found!') 168 | 169 | if decay_steps is not None: 170 | actor_lr = optax.cosine_decay_schedule(actor_lr, decay_steps) 171 | 172 | if len(hidden_dims) == 1: 173 | hidden_dims = (hidden_dims[0], hidden_dims[0], hidden_dims[0]) 174 | 175 | policy_def = LearnedStdTanhNormalPolicy(hidden_dims, self.action_dim, dropout_rate=dropout_rate, low=-action_magnitude, high=action_magnitude) 176 | 177 | actor_def = PixelMultiplexer(encoder=encoder_def, 178 | network=policy_def, 179 | latent_dim=latent_dim, 180 | use_bottleneck=use_bottleneck 181 | ) 182 | print(actor_def) 183 | actor_def_init = actor_def.init(actor_key, observations) 184 | actor_params = actor_def_init['params'] 185 | actor_batch_stats = actor_def_init['batch_stats'] if 'batch_stats' in actor_def_init else None 186 | 187 | actor = TrainState.create(apply_fn=actor_def.apply, 188 | params=actor_params, 189 | tx=optax.adam(learning_rate=actor_lr), 190 | batch_stats=actor_batch_stats) 191 | 192 | critic_def = StateActionEnsemble(hidden_dims, num_qs=num_qs) 193 | critic_def = PixelMultiplexer(encoder=encoder_def, 194 | network=critic_def, 195 | latent_dim=latent_dim, 196 | use_bottleneck=use_bottleneck 197 | ) 198 | print(critic_def) 199 | critic_def_init = critic_def.init(critic_key, observations, actions) 200 | self._critic_init_params = critic_def_init['params'] 201 | 202 | critic_params = critic_def_init['params'] 203 | critic_batch_stats = critic_def_init['batch_stats'] if 'batch_stats' in critic_def_init else None 204 | critic = TrainState.create(apply_fn=critic_def.apply, 205 | params=critic_params, 206 | tx=optax.adam(learning_rate=critic_lr), 207 | batch_stats=critic_batch_stats 208 | ) 209 | target_critic_params = copy.deepcopy(critic_params) 210 | 211 | temp_def = Temperature(init_temperature) 212 | temp_params = temp_def.init(temp_key)['params'] 213 | temp = TrainState.create(apply_fn=temp_def.apply, 214 | params=temp_params, 215 | tx=optax.adam(learning_rate=temp_lr), 216 | batch_stats=None) 217 | 218 | 219 | self._rng = rng 220 | self._actor = actor 221 | self._critic = critic 222 | self._target_critic_params = target_critic_params 223 | self._temp = temp 224 | if target_entropy is None or target_entropy == 'auto': 225 | self.target_entropy = -self.action_dim / 2 226 | else: 227 | self.target_entropy = float(target_entropy) 228 | print(f'target_entropy: {self.target_entropy}') 229 | print(self.critic_reduction) 230 | 231 | 232 | def update(self, batch: FrozenDict) -> Dict[str, float]: 233 | new_rng, new_actor, new_critic, new_target_critic, new_temp, info = _update_jit( 234 | self._rng, self._actor, self._critic, self._target_critic_params, self._temp, batch, self.discount, self.tau, self.target_entropy, self.critic_reduction, self.color_jitter, self.aug_next, self.num_cameras 235 | ) 236 | 237 | self._rng = new_rng 238 | self._actor = new_actor 239 | self._critic = new_critic 240 | self._target_critic_params = new_target_critic 241 | self._temp = new_temp 242 | return info 243 | 244 | def perform_eval(self, variant, i, wandb_logger, eval_buffer, eval_buffer_iterator, eval_env): 245 | from examples.train_utils_sim import make_multiple_value_reward_visulizations 246 | make_multiple_value_reward_visulizations(self, variant, i, 
eval_buffer, wandb_logger) 247 | 248 | def make_value_reward_visulization(self, variant, trajs): 249 | num_traj = len(trajs['rewards']) 250 | traj_images = [] 251 | 252 | for itraj in range(num_traj): 253 | observations = trajs['observations'][itraj] 254 | next_observations = trajs['next_observations'][itraj] 255 | actions = trajs['actions'][itraj] 256 | rewards = trajs['rewards'][itraj] 257 | masks = trajs['masks'][itraj] 258 | 259 | q_pred = [] 260 | 261 | for t in range(0, len(actions)): 262 | action = actions[t][None] 263 | obs_pixels = observations['pixels'][t] 264 | next_obs_pixels = next_observations['pixels'][t] 265 | 266 | obs_dict = {'pixels': obs_pixels[None]} 267 | for k, v in observations.items(): 268 | if 'pixels' not in k: 269 | obs_dict[k] = v[t][None] 270 | next_obs_dict = {'pixels': next_obs_pixels[None]} 271 | for k, v in next_observations.items(): 272 | if 'pixels' not in k: 273 | next_obs_dict[k] = v[t][None] 274 | 275 | q_value = get_value(action, obs_dict, self._critic) 276 | q_pred.append(q_value) 277 | 278 | traj_images.append(make_visual(q_pred, rewards, masks, observations['pixels'])) 279 | print('finished reward value visuals.') 280 | return np.concatenate(traj_images, 0) 281 | 282 | @property 283 | def _save_dict(self): 284 | save_dict = { 285 | 'critic': self._critic, 286 | 'target_critic_params': self._target_critic_params, 287 | 'actor': self._actor, 288 | 'temp': self._temp 289 | } 290 | return save_dict 291 | 292 | def restore_checkpoint(self, dir): 293 | assert pathlib.Path(dir).exists(), f"Checkpoint {dir} does not exist." 294 | output_dict = checkpoints.restore_checkpoint(dir, self._save_dict) 295 | self._actor = output_dict['actor'] 296 | self._critic = output_dict['critic'] 297 | self._target_critic_params = output_dict['target_critic_params'] 298 | self._temp = output_dict['temp'] 299 | print('restored from ', dir) 300 | 301 | 302 | @functools.partial(jax.jit) 303 | def get_value(action, observation, critic): 304 | input_collections = {'params': critic.params} 305 | q_pred = critic.apply_fn(input_collections, observation, action) 306 | return q_pred 307 | 308 | 309 | def np_unstack(array, axis): 310 | arr = np.split(array, array.shape[axis], axis) 311 | arr = [a.squeeze() for a in arr] 312 | return arr 313 | 314 | def make_visual(q_estimates, rewards, masks, images): 315 | 316 | q_estimates_np = np.stack(q_estimates, 0).squeeze() 317 | fig, axs = plt.subplots(4, 1, figsize=(8, 12)) 318 | canvas = FigureCanvas(fig) 319 | plt.xlim([0, len(q_estimates_np)]) 320 | 321 | assert len(images.shape) == 5 322 | images = images[..., -1] # only taking the most recent image of the stack 323 | assert images.shape[-1] == 3 324 | 325 | interval = max(1, images.shape[0] // 4) 326 | sel_images = images[::interval] 327 | sel_images = np.concatenate(np_unstack(sel_images, 0), 1) 328 | 329 | axs[0].imshow(sel_images) 330 | if len(q_estimates_np.shape) == 2: 331 | for i in range(q_estimates_np.shape[1]): 332 | axs[1].plot(q_estimates_np[:, i], linestyle='--', marker='o') 333 | else: 334 | axs[1].plot(q_estimates_np, linestyle='--', marker='o') 335 | axs[1].set_ylabel('q values') 336 | axs[2].plot(rewards, linestyle='--', marker='o') 337 | axs[2].set_ylabel('rewards') 338 | axs[2].set_xlim([0, len(rewards)]) 339 | 340 | axs[3].plot(masks, linestyle='--', marker='d') 341 | axs[3].set_ylabel('masks') 342 | axs[3].set_xlim([0, len(masks)]) 343 | 344 | plt.tight_layout() 345 | 346 | canvas.draw() # draw the canvas, cache the renderer 347 | out_image = 
np.frombuffer(canvas.tostring_rgb(), dtype='uint8') 348 | out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 349 | 350 | plt.close(fig) 351 | return out_image -------------------------------------------------------------------------------- /jaxrl2/data/augmentations.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import functools 4 | from functools import partial 5 | 6 | 7 | def random_crop(key, img, padding): 8 | crop_from = jax.random.randint(key, (2, ), 0, 2 * padding + 1) 9 | crop_from = jnp.concatenate([crop_from, jnp.zeros((2, ), dtype=jnp.int32)]) 10 | padded_img = jnp.pad(img, ((padding, padding), (padding, padding), (0, 0), 11 | (0, 0)), 12 | mode='edge') 13 | return jax.lax.dynamic_slice(padded_img, crop_from, img.shape) 14 | 15 | 16 | def batched_random_crop(key, imgs, padding=4): 17 | keys = jax.random.split(key, imgs.shape[0]) 18 | return jax.vmap(random_crop, (0, 0, None))(keys, imgs, padding) 19 | 20 | @partial(jax.pmap, axis_name='pmap', static_broadcasted_argnums=(1)) 21 | def batched_random_crop_parallel(key, imgs, padding): 22 | keys = jax.random.split(key, imgs.shape[0]) 23 | return jax.vmap(random_crop, (0, 0, None))(keys, imgs, padding) 24 | 25 | # typing 26 | 27 | def _maybe_apply(apply_fn, inputs, rng, apply_prob): 28 | should_apply = jax.random.uniform(rng, shape=()) <= apply_prob 29 | return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x) 30 | 31 | 32 | def _depthwise_conv2d(inputs, kernel, strides, padding): 33 | """Computes a depthwise conv2d in Jax. 34 | Args: 35 | inputs: an NHWC tensor with N=1. 36 | kernel: a [H", W", 1, C] tensor. 37 | strides: a 2d tensor. 38 | padding: "SAME" or "VALID". 39 | Returns: 40 | The depthwise convolution of inputs with kernel, as [H, W, C]. 41 | """ 42 | return jax.lax.conv_general_dilated( 43 | inputs, 44 | kernel, 45 | strides, 46 | padding, 47 | feature_group_count=inputs.shape[-1], 48 | dimension_numbers=('NHWC', 'HWIO', 'NHWC')) 49 | 50 | 51 | def _gaussian_blur_single_image(image, kernel_size, padding, sigma): 52 | """Applies gaussian blur to a single image, given as NHWC with N=1.""" 53 | radius = int(kernel_size / 2) 54 | kernel_size_ = 2 * radius + 1 55 | x = jnp.arange(-radius, radius + 1).astype(jnp.float32) 56 | blur_filter = jnp.exp(-x**2 / (2. * sigma**2)) 57 | blur_filter = blur_filter / jnp.sum(blur_filter) 58 | blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1]) 59 | blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1]) 60 | num_channels = image.shape[-1] 61 | blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels]) 62 | blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels]) 63 | expand_batch_dim = len(image.shape) == 3 64 | if expand_batch_dim: 65 | image = image[jnp.newaxis, ...] 
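  # Separable Gaussian blur: convolve with the 1-D kernel horizontally, then vertically, via depthwise convolutions (equivalent to the full 2-D kernel but cheaper).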
66 | blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding) 67 | blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding) 68 | blurred = jnp.squeeze(blurred, axis=0) 69 | return blurred 70 | 71 | 72 | def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min, 73 | sigma_max, apply_prob): 74 | """Applies a random gaussian blur.""" 75 | apply_rng, transform_rng = jax.random.split(rng) 76 | 77 | def _apply(image): 78 | sigma_rng, = jax.random.split(transform_rng, 1) 79 | sigma = jax.random.uniform( 80 | sigma_rng, 81 | shape=(), 82 | minval=sigma_min, 83 | maxval=sigma_max, 84 | dtype=jnp.float32) 85 | return _gaussian_blur_single_image(image, kernel_size, padding, sigma) 86 | 87 | return _maybe_apply(_apply, image, apply_rng, apply_prob) 88 | 89 | 90 | def rgb_to_hsv(r, g, b): 91 | """Converts R, G, B values to H, S, V values. 92 | Reference TF implementation: 93 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc 94 | Only input values between 0 and 1 are guaranteed to work properly, but this 95 | function complies with the TF implementation outside of this range. 96 | Args: 97 | r: A tensor representing the red color component as floats. 98 | g: A tensor representing the green color component as floats. 99 | b: A tensor representing the blue color component as floats. 100 | Returns: 101 | H, S, V values, each as tensors of shape [...] (same as the input without 102 | the last dimension). 103 | """ 104 | vv = jnp.maximum(jnp.maximum(r, g), b) 105 | range_ = vv - jnp.minimum(jnp.minimum(r, g), b) 106 | sat = jnp.where(vv > 0, range_ / vv, 0.) 107 | norm = jnp.where(range_ != 0, 1. / (6. * range_), 1e9) 108 | 109 | hr = norm * (g - b) 110 | hg = norm * (b - r) + 2. / 6. 111 | hb = norm * (r - g) + 4. / 6. 112 | 113 | hue = jnp.where(r == vv, hr, jnp.where(g == vv, hg, hb)) 114 | hue = hue * (range_ > 0) 115 | hue = hue + (hue < 0) 116 | 117 | return hue, sat, vv 118 | 119 | 120 | def hsv_to_rgb(h, s, v): 121 | """Converts H, S, V values to an R, G, B tuple. 122 | Reference TF implementation: 123 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc 124 | Only input values between 0 and 1 are guaranteed to work properly, but this 125 | function complies with the TF implementation outside of this range. 126 | Args: 127 | h: A float tensor of arbitrary shape for the hue (0-1 values). 128 | s: A float tensor of the same shape for the saturation (0-1 values). 129 | v: A float tensor of the same shape for the value channel (0-1 values). 130 | Returns: 131 | An (r, g, b) tuple, each with the same dimension as the inputs. 132 | """ 133 | c = s * v 134 | m = v - c 135 | dh = (h % 1.) * 6. 136 | fmodu = dh % 2. 
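  # Standard HSV-to-RGB reconstruction: x is the intermediate channel value, and the hue sector index hcat (0-5) below decides which of (c, x, 0) each of R, G, B receives before the offset m is added back.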
137 | x = c * (1 - jnp.abs(fmodu - 1)) 138 | hcat = jnp.floor(dh).astype(jnp.int32) 139 | rr = jnp.where( 140 | (hcat == 0) | (hcat == 5), c, jnp.where( 141 | (hcat == 1) | (hcat == 4), x, 0)) + m 142 | gg = jnp.where( 143 | (hcat == 1) | (hcat == 2), c, jnp.where( 144 | (hcat == 0) | (hcat == 3), x, 0)) + m 145 | bb = jnp.where( 146 | (hcat == 3) | (hcat == 4), c, jnp.where( 147 | (hcat == 2) | (hcat == 5), x, 0)) + m 148 | return rr, gg, bb 149 | 150 | 151 | def adjust_brightness(rgb_tuple, delta): 152 | return jax.tree_util.tree_map(lambda x: x + delta, rgb_tuple) 153 | 154 | 155 | def adjust_contrast(image, factor): 156 | def _adjust_contrast_channel(channel): 157 | mean = jnp.mean(channel, axis=(-2, -1), keepdims=True) 158 | return factor * (channel - mean) + mean 159 | return jax.tree_util.tree_map(_adjust_contrast_channel, image) 160 | 161 | 162 | def adjust_saturation(h, s, v, factor): 163 | return h, jnp.clip(s * factor, 0., 1.), v 164 | 165 | 166 | def adjust_hue(h, s, v, delta): 167 | # Note: this method exactly matches TF"s adjust_hue (combined with the hsv/rgb 168 | # conversions) when running on GPU. When running on CPU, the results will be 169 | # different if all RGB values for a pixel are outside of the [0, 1] range. 170 | return (h + delta) % 1.0, s, v 171 | 172 | 173 | def _random_brightness(rgb_tuple, rng, max_delta): 174 | delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta) 175 | return adjust_brightness(rgb_tuple, delta) 176 | 177 | 178 | def _random_contrast(rgb_tuple, rng, max_delta): 179 | factor = jax.random.uniform( 180 | rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta) 181 | return adjust_contrast(rgb_tuple, factor) 182 | 183 | 184 | def _random_saturation(rgb_tuple, rng, max_delta): 185 | h, s, v = rgb_to_hsv(*rgb_tuple) 186 | factor = jax.random.uniform( 187 | rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta) 188 | return hsv_to_rgb(*adjust_saturation(h, s, v, factor)) 189 | 190 | 191 | def _random_hue(rgb_tuple, rng, max_delta): 192 | h, s, v = rgb_to_hsv(*rgb_tuple) 193 | delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta) 194 | return hsv_to_rgb(*adjust_hue(h, s, v, delta)) 195 | 196 | 197 | def _to_grayscale(image): 198 | rgb_weights = jnp.array([0.2989, 0.5870, 0.1140]) 199 | grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis] 200 | return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels. 201 | 202 | 203 | def _color_transform_single_image(image, rng, brightness, contrast, saturation, 204 | hue, to_grayscale_prob, color_jitter_prob, 205 | apply_prob, shuffle): 206 | """Applies color jittering to a single image.""" 207 | apply_rng, transform_rng = jax.random.split(rng) 208 | perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split( 209 | transform_rng, 7) 210 | 211 | # Whether the transform should be applied at all. 212 | should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob 213 | # Whether to apply grayscale transform. 214 | should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob 215 | # Whether to apply color jittering. 216 | should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob 217 | 218 | # Decorator to conditionally apply fn based on an index. 
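  # Each wrapped transform fires only when the overall apply/color-jitter draws succeed and its index matches the current position in the (optionally shuffled) order; otherwise the RGB tuple passes through unchanged. jax.lax.cond keeps both branches traceable.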
219 | def _make_cond(fn, idx): 220 | 221 | def identity_fn(x, unused_rng, unused_param): 222 | return x 223 | 224 | def cond_fn(args, i): 225 | def clip(args): 226 | return jax.tree_util.tree_map(lambda arg: jnp.clip(arg, 0., 1.), args) 227 | out = jax.lax.cond(should_apply & should_apply_color & (i == idx), args, 228 | lambda a: clip(fn(*a)), args, 229 | lambda a: identity_fn(*a)) 230 | return jax.lax.stop_gradient(out) 231 | 232 | return cond_fn 233 | 234 | random_brightness_cond = _make_cond(_random_brightness, idx=0) 235 | random_contrast_cond = _make_cond(_random_contrast, idx=1) 236 | random_saturation_cond = _make_cond(_random_saturation, idx=2) 237 | random_hue_cond = _make_cond(_random_hue, idx=3) 238 | 239 | def _color_jitter(x): 240 | rgb_tuple = tuple(jax.tree_util.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1))) 241 | if shuffle: 242 | order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32)) 243 | else: 244 | order = range(4) 245 | for idx in order: 246 | if brightness > 0: 247 | rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness), idx) 248 | if contrast > 0: 249 | rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast), idx) 250 | if saturation > 0: 251 | rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation), idx) 252 | if hue > 0: 253 | rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), idx) 254 | return jnp.stack(rgb_tuple, axis=-1) 255 | 256 | out_apply = _color_jitter(image) 257 | out_apply = jax.lax.cond(should_apply & should_apply_gs, out_apply, 258 | _to_grayscale, out_apply, lambda x: x) 259 | return jnp.clip(out_apply, 0., 1.) 260 | 261 | 262 | def _random_flip_single_image(image, rng): 263 | _, flip_rng = jax.random.split(rng) 264 | should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5 265 | image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x) 266 | return image 267 | 268 | 269 | def random_flip(images, rng): 270 | rngs = jax.random.split(rng, images.shape[0]) 271 | return jax.vmap(_random_flip_single_image)(images, rngs) 272 | 273 | 274 | def color_transform(rng, 275 | images, 276 | brightness=0.2, 277 | contrast=0.1, 278 | saturation=0.1, 279 | hue=0.03, 280 | color_jitter_prob=0.8, 281 | to_grayscale_prob=0.0, 282 | apply_prob=1.0, 283 | shuffle=True): 284 | """Applies color jittering and/or grayscaling to a batch of images. 285 | Args: 286 | images: an NHWC tensor, with C=3. 287 | rng: a single PRNGKey. 288 | brightness: the range of jitter on brightness. 289 | contrast: the range of jitter on contrast. 290 | saturation: the range of jitter on saturation. 291 | hue: the range of jitter on hue. 292 | color_jitter_prob: the probability of applying color jittering. 293 | to_grayscale_prob: the probability of converting the image to grayscale. 294 | apply_prob: the probability of applying the transform to a batch element. 295 | shuffle: whether to apply the transforms in a random order. 296 | Returns: 297 | A NHWC tensor of the transformed images. 
298 | """ 299 | images = images[:, :, :, :, 0] 300 | rngs = jax.random.split(rng, images.shape[0]) 301 | jitter_fn = functools.partial( 302 | _color_transform_single_image, 303 | brightness=brightness, 304 | contrast=contrast, 305 | saturation=saturation, 306 | hue=hue, 307 | color_jitter_prob=color_jitter_prob, 308 | to_grayscale_prob=to_grayscale_prob, 309 | apply_prob=apply_prob, 310 | shuffle=shuffle) 311 | augmented_images = jax.vmap(jitter_fn)(images, rngs) 312 | return augmented_images[..., jnp.newaxis] 313 | 314 | @partial(jax.pmap, axis_name='pmap') 315 | def color_transform_parallel(rng, images): 316 | """Applies color jittering and/or grayscaling to a batch of images. 317 | Args: 318 | images: an NHWC tensor, with C=3. 319 | rng: a single PRNGKey. 320 | brightness: the range of jitter on brightness. 321 | contrast: the range of jitter on contrast. 322 | saturation: the range of jitter on saturation. 323 | hue: the range of jitter on hue. 324 | color_jitter_prob: the probability of applying color jittering. 325 | to_grayscale_prob: the probability of converting the image to grayscale. 326 | apply_prob: the probability of applying the transform to a batch element. 327 | shuffle: whether to apply the transforms in a random order. 328 | Returns: 329 | A NHWC tensor of the transformed images. 330 | """ 331 | brightness=0.2, 332 | contrast=0.1, 333 | saturation=0.1, 334 | hue=0.03, 335 | color_jitter_prob=0.8, 336 | to_grayscale_prob=0.0, 337 | apply_prob=1.0, 338 | shuffle=True 339 | images = images[:, :, :, :, 0] 340 | rngs = jax.random.split(rng, images.shape[0]) 341 | jitter_fn = functools.partial( 342 | _color_transform_single_image, 343 | brightness=brightness, 344 | contrast=contrast, 345 | saturation=saturation, 346 | hue=hue, 347 | color_jitter_prob=color_jitter_prob, 348 | to_grayscale_prob=to_grayscale_prob, 349 | apply_prob=apply_prob, 350 | shuffle=shuffle) 351 | augmented_images = jax.vmap(jitter_fn)(images, rngs) 352 | return augmented_images[..., jnp.newaxis] 353 | 354 | 355 | def gaussian_blur(images, 356 | rng, 357 | blur_divider=10., 358 | sigma_min=0.1, 359 | sigma_max=2.0, 360 | apply_prob=1.0): 361 | """Applies gaussian blur to a batch of images. 362 | Args: 363 | images: an NHWC tensor, with C=3. 364 | rng: a single PRNGKey. 365 | blur_divider: the blurring kernel will have size H / blur_divider. 366 | sigma_min: the minimum value for sigma in the blurring kernel. 367 | sigma_max: the maximum value for sigma in the blurring kernel. 368 | apply_prob: the probability of applying the transform to a batch element. 369 | Returns: 370 | A NHWC tensor of the blurred images. 371 | """ 372 | rngs = jax.random.split(rng, images.shape[0]) 373 | kernel_size = images.shape[1] / blur_divider 374 | blur_fn = functools.partial( 375 | _random_gaussian_blur, 376 | kernel_size=kernel_size, 377 | padding='SAME', 378 | sigma_min=sigma_min, 379 | sigma_max=sigma_max, 380 | apply_prob=apply_prob) 381 | return jax.vmap(blur_fn)(images, rngs) 382 | 383 | 384 | def _solarize_single_image(image, rng, threshold, apply_prob): 385 | 386 | def _apply(image): 387 | return jnp.where(image < threshold, image, 1. - image) 388 | 389 | return _maybe_apply(_apply, image, rng, apply_prob) 390 | 391 | 392 | def solarize(images, rng, threshold=0.5, apply_prob=1.0): 393 | """Applies solarization. 394 | Args: 395 | images: an NHWC tensor (with C=3). 396 | rng: a single PRNGKey. 397 | threshold: the solarization threshold. 398 | apply_prob: the probability of applying the transform to a batch element. 
399 | Returns: 400 | A NHWC tensor of the transformed images. 401 | """ 402 | rngs = jax.random.split(rng, images.shape[0]) 403 | solarize_fn = functools.partial( 404 | _solarize_single_image, threshold=threshold, apply_prob=apply_prob) 405 | return jax.vmap(solarize_fn)(images, rngs) 406 | -------------------------------------------------------------------------------- /examples/train_utils_sim.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import wandb 4 | import jax 5 | from openpi_client import image_tools 6 | import math 7 | import PIL 8 | 9 | def _quat2axisangle(quat): 10 | """ 11 | Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 12 | """ 13 | # clip quaternion 14 | if quat[3] > 1.0: 15 | quat[3] = 1.0 16 | elif quat[3] < -1.0: 17 | quat[3] = -1.0 18 | 19 | den = np.sqrt(1.0 - quat[3] * quat[3]) 20 | if math.isclose(den, 0.0): 21 | # This is (close to) a zero degree rotation, immediately return 22 | return np.zeros(3) 23 | 24 | return (quat[:3] * 2.0 * math.acos(quat[3])) / den 25 | 26 | def obs_to_img(obs, variant): 27 | ''' 28 | Convert raw observation to resized image for DSRL actor/critic 29 | ''' 30 | if variant.env == 'libero': 31 | curr_image = obs["agentview_image"][::-1, ::-1] 32 | elif variant.env == 'aloha_cube': 33 | curr_image = obs["pixels"]["top"] 34 | else: 35 | raise NotImplementedError() 36 | if variant.resize_image > 0: 37 | curr_image = np.array(PIL.Image.fromarray(curr_image).resize((variant.resize_image, variant.resize_image))) 38 | return curr_image 39 | 40 | def obs_to_pi_zero_input(obs, variant): 41 | if variant.env == 'libero': 42 | img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1]) 43 | wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1]) 44 | img = image_tools.convert_to_uint8( 45 | image_tools.resize_with_pad(img, 224, 224) 46 | ) 47 | wrist_img = image_tools.convert_to_uint8( 48 | image_tools.resize_with_pad(wrist_img, 224, 224) 49 | ) 50 | 51 | obs_pi_zero = { 52 | "observation/image": img, 53 | "observation/wrist_image": wrist_img, 54 | "observation/state": np.concatenate( 55 | ( 56 | obs["robot0_eef_pos"], 57 | _quat2axisangle(obs["robot0_eef_quat"]), 58 | obs["robot0_gripper_qpos"], 59 | ) 60 | ), 61 | "prompt": str(variant.task_description), 62 | } 63 | elif variant.env == 'aloha_cube': 64 | img = np.ascontiguousarray(obs["pixels"]["top"]) 65 | img = image_tools.convert_to_uint8( 66 | image_tools.resize_with_pad(img, 224, 224) 67 | ) 68 | obs_pi_zero = { 69 | "state": obs["agent_pos"], 70 | "images": {"cam_high": np.transpose(img, (2,0,1))} 71 | } 72 | else: 73 | raise NotImplementedError() 74 | return obs_pi_zero 75 | 76 | def obs_to_qpos(obs, variant): 77 | if variant.env == 'libero': 78 | qpos = np.concatenate( 79 | ( 80 | obs["robot0_eef_pos"], 81 | _quat2axisangle(obs["robot0_eef_quat"]), 82 | obs["robot0_gripper_qpos"], 83 | ) 84 | ) 85 | elif variant.env == 'aloha_cube': 86 | qpos = obs["agent_pos"] 87 | else: 88 | raise NotImplementedError() 89 | return qpos 90 | 91 | def trajwise_alternating_training_loop(variant, agent, env, eval_env, online_replay_buffer, replay_buffer, wandb_logger, 92 | perform_control_evals=True, shard_fn=None, agent_dp=None): 93 | replay_buffer_iterator = replay_buffer.get_iterator(variant.batch_size) 94 | if shard_fn is not None: 95 | replay_buffer_iterator = map(shard_fn, 
replay_buffer_iterator) 96 | 97 | total_env_steps = 0 98 | i = 0 99 | wandb_logger.log({'num_online_samples': 0}, step=i) 100 | wandb_logger.log({'num_online_trajs': 0}, step=i) 101 | wandb_logger.log({'env_steps': 0}, step=i) 102 | 103 | with tqdm(total=variant.max_steps, initial=0) as pbar: 104 | while i <= variant.max_steps: 105 | traj = collect_traj(variant, agent, env, i, agent_dp) 106 | traj_id = online_replay_buffer._traj_counter 107 | add_online_data_to_buffer(variant, traj, online_replay_buffer) 108 | total_env_steps += traj['env_steps'] 109 | print('online buffer timesteps length:', len(online_replay_buffer)) 110 | print('online buffer num traj:', traj_id + 1) 111 | print('total env steps:', total_env_steps) 112 | 113 | if variant.get("num_online_gradsteps_batch", -1) > 0: 114 | num_gradsteps = variant.num_online_gradsteps_batch 115 | else: 116 | num_gradsteps = len(traj["rewards"])*variant.multi_grad_step 117 | 118 | if len(online_replay_buffer) > variant.start_online_updates: 119 | for _ in range(num_gradsteps): 120 | # perform first visualization before updating 121 | if i == 0: 122 | print('performing evaluation for initial checkpoint') 123 | if perform_control_evals: 124 | perform_control_eval(agent, eval_env, i, variant, wandb_logger, agent_dp) 125 | if hasattr(agent, 'perform_eval'): 126 | agent.perform_eval(variant, i, wandb_logger, replay_buffer, replay_buffer_iterator, eval_env) 127 | 128 | # online perform update once we have some amount of online trajs 129 | batch = next(replay_buffer_iterator) 130 | update_info = agent.update(batch) 131 | 132 | pbar.update() 133 | i += 1 134 | 135 | 136 | if i % variant.log_interval == 0: 137 | update_info = {k: jax.device_get(v) for k, v in update_info.items()} 138 | for k, v in update_info.items(): 139 | if v.ndim == 0: 140 | wandb_logger.log({f'training/{k}': v}, step=i) 141 | elif v.ndim <= 2: 142 | wandb_logger.log_histogram(f'training/{k}', v, i) 143 | # wandb_logger.log({'replay_buffer_size': len(online_replay_buffer)}, i) 144 | wandb_logger.log({ 145 | 'replay_buffer_size': len(online_replay_buffer), 146 | 'episode_return (exploration)': traj['episode_return'], 147 | 'is_success (exploration)': int(traj['is_success']), 148 | }, i) 149 | 150 | if i % variant.eval_interval == 0: 151 | wandb_logger.log({'num_online_samples': len(online_replay_buffer)}, step=i) 152 | wandb_logger.log({'num_online_trajs': traj_id + 1}, step=i) 153 | wandb_logger.log({'env_steps': total_env_steps}, step=i) 154 | if perform_control_evals: 155 | perform_control_eval(agent, eval_env, i, variant, wandb_logger, agent_dp) 156 | if hasattr(agent, 'perform_eval'): 157 | agent.perform_eval(variant, i, wandb_logger, replay_buffer, replay_buffer_iterator, eval_env) 158 | 159 | if variant.checkpoint_interval != -1 and i % variant.checkpoint_interval == 0: 160 | agent.save_checkpoint(variant.outputdir, i, variant.checkpoint_interval) 161 | 162 | 163 | def add_online_data_to_buffer(variant, traj, online_replay_buffer): 164 | 165 | discount_horizon = variant.query_freq 166 | actions = np.array(traj['actions']) # (T, chunk_size, action_dim ) 167 | episode_len = len(actions) 168 | rewards = np.array(traj['rewards']) 169 | masks = np.array(traj['masks']) 170 | 171 | for t in range(episode_len): 172 | obs = traj['observations'][t] 173 | next_obs = traj['observations'][t + 1] 174 | # remove batch dimension 175 | obs = {k: v[0] for k, v in obs.items()} 176 | next_obs = {k: v[0] for k, v in next_obs.items()} 177 | if not variant.add_states: 178 | obs.pop('state', None) 
179 | next_obs.pop('state', None) 180 | 181 | insert_dict = dict( 182 | observations=obs, 183 | next_observations=next_obs, 184 | actions=actions[t], 185 | next_actions=actions[t + 1] if t < episode_len - 1 else actions[t], 186 | rewards=rewards[t], 187 | masks=masks[t], 188 | discount=variant.discount ** discount_horizon 189 | ) 190 | online_replay_buffer.insert(insert_dict) 191 | online_replay_buffer.increment_traj_counter() 192 | 193 | def collect_traj(variant, agent, env, i, agent_dp=None): 194 | query_frequency = variant.query_freq 195 | max_timesteps = variant.max_timesteps 196 | env_max_reward = variant.env_max_reward 197 | 198 | agent._rng, rng = jax.random.split(agent._rng) 199 | 200 | if 'libero' in variant.env: 201 | obs = env.reset() 202 | elif 'aloha' in variant.env: 203 | obs, _ = env.reset() 204 | 205 | image_list = [] # for visualization 206 | rewards = [] 207 | action_list = [] 208 | obs_list = [] 209 | 210 | for t in tqdm(range(max_timesteps)): 211 | curr_image = obs_to_img(obs, variant) 212 | 213 | qpos = obs_to_qpos(obs, variant) 214 | 215 | if variant.add_states: 216 | obs_dict = { 217 | 'pixels': curr_image[np.newaxis, ..., np.newaxis], 218 | 'state': qpos[np.newaxis, ..., np.newaxis], 219 | } 220 | else: 221 | obs_dict = { 222 | 'pixels': curr_image[np.newaxis, ..., np.newaxis], 223 | } 224 | 225 | if t % query_frequency == 0: 226 | 227 | assert agent_dp is not None 228 | # we then use the noise to sample the action from diffusion model 229 | rng, key = jax.random.split(rng) 230 | obs_pi_zero = obs_to_pi_zero_input(obs, variant) 231 | if i == 0: 232 | # for initial round of data collection, we sample from standard gaussian noise 233 | noise = jax.random.normal(key, (1, *agent.action_chunk_shape)) 234 | noise_repeat = jax.numpy.repeat(noise[:, -1:, :], 50 - noise.shape[1], axis=1) 235 | noise = jax.numpy.concatenate([noise, noise_repeat], axis=1) 236 | actions_noise = noise[0, :agent.action_chunk_shape[0], :] 237 | else: 238 | # sac agent predicts the noise for diffusion model 239 | actions_noise = agent.sample_actions(obs_dict) 240 | actions_noise = np.reshape(actions_noise, agent.action_chunk_shape) 241 | noise = np.repeat(actions_noise[-1:, :], 50 - actions_noise.shape[0], axis=0) 242 | noise = jax.numpy.concatenate([actions_noise, noise], axis=0)[None] 243 | 244 | actions = agent_dp.infer(obs_pi_zero, noise=noise)["actions"] 245 | action_list.append(actions_noise) 246 | obs_list.append(obs_dict) 247 | 248 | action_t = actions[t % query_frequency] 249 | if 'libero' in variant.env: 250 | obs, reward, done, _ = env.step(action_t) 251 | elif 'aloha' in variant.env: 252 | obs, reward, terminated, truncated, _ = env.step(action_t) 253 | done = terminated or truncated 254 | 255 | rewards.append(reward) 256 | image_list.append(curr_image) 257 | if done: 258 | break 259 | 260 | # add last observation 261 | curr_image = obs_to_img(obs, variant) 262 | qpos = obs_to_qpos(obs, variant) 263 | obs_dict = { 264 | 'pixels': curr_image[np.newaxis, ..., np.newaxis], 265 | 'state': qpos[np.newaxis, ..., np.newaxis], 266 | } 267 | obs_list.append(obs_dict) 268 | image_list.append(curr_image) 269 | 270 | # per episode 271 | rewards = np.array(rewards) 272 | episode_return = np.sum(rewards[rewards!=None]) 273 | is_success = (reward == env_max_reward) 274 | print(f'Rollout Done: {episode_return=}, Success: {is_success}') 275 | 276 | 277 | ''' 278 | We use sparse -1/0 reward to train the SAC agent. 
279 | ''' 280 | if is_success: 281 | query_steps = len(action_list) 282 | rewards = np.concatenate([-np.ones(query_steps - 1), [0]]) 283 | masks = np.concatenate([np.ones(query_steps - 1), [0]]) 284 | else: 285 | query_steps = len(action_list) 286 | rewards = -np.ones(query_steps) 287 | masks = np.ones(query_steps) 288 | 289 | return { 290 | 'observations': obs_list, 291 | 'actions': action_list, 292 | 'rewards': rewards, 293 | 'masks': masks, 294 | 'is_success': is_success, 295 | 'episode_return': episode_return, 296 | 'images': image_list, 297 | 'env_steps': t + 1 298 | } 299 | 300 | def perform_control_eval(agent, env, i, variant, wandb_logger, agent_dp=None): 301 | query_frequency = variant.query_freq 302 | print('query frequency', query_frequency) 303 | max_timesteps = variant.max_timesteps 304 | env_max_reward = variant.env_max_reward 305 | episode_returns = [] 306 | highest_rewards = [] 307 | success_rates = [] 308 | episode_lens = [] 309 | 310 | rng = jax.random.PRNGKey(variant.seed+456) 311 | 312 | for rollout_id in range(variant.eval_episodes): 313 | if 'libero' in variant.env: 314 | obs = env.reset() 315 | elif 'aloha' in variant.env: 316 | obs, _ = env.reset() 317 | 318 | image_list = [] # for visualization 319 | rewards = [] 320 | 321 | 322 | for t in tqdm(range(max_timesteps)): 323 | curr_image = obs_to_img(obs, variant) 324 | 325 | if t % query_frequency == 0: 326 | qpos = obs_to_qpos(obs, variant) 327 | if variant.add_states: 328 | obs_dict = { 329 | 'pixels': curr_image[np.newaxis, ..., np.newaxis], 330 | 'state': qpos[np.newaxis, ..., np.newaxis], 331 | } 332 | else: 333 | obs_dict = { 334 | 'pixels': curr_image[np.newaxis, ..., np.newaxis], 335 | } 336 | 337 | rng, key = jax.random.split(rng) 338 | assert agent_dp is not None 339 | 340 | obs_pi_zero = obs_to_pi_zero_input(obs, variant) 341 | 342 | 343 | if i == 0: 344 | # for initial evaluation, we sample from standard gaussian noise to evaluate the base policy's performance 345 | noise = jax.random.normal(rng, (1, 50, 32)) 346 | else: 347 | actions_noise = agent.sample_actions(obs_dict) 348 | actions_noise = np.reshape(actions_noise, agent.action_chunk_shape) 349 | noise = np.repeat(actions_noise[-1:, :], 50 - actions_noise.shape[0], axis=0) 350 | noise = jax.numpy.concatenate([actions_noise, noise], axis=0)[None] 351 | 352 | actions = agent_dp.infer(obs_pi_zero, noise=noise)["actions"] 353 | 354 | action_t = actions[t % query_frequency] 355 | 356 | if 'libero' in variant.env: 357 | obs, reward, done, _ = env.step(action_t) 358 | elif 'aloha' in variant.env: 359 | obs, reward, terminated, truncated, _ = env.step(action_t) 360 | done = terminated or truncated 361 | 362 | rewards.append(reward) 363 | image_list.append(curr_image) 364 | if done: 365 | break 366 | 367 | # per episode 368 | episode_lens.append(t + 1) 369 | rewards = np.array(rewards) 370 | episode_return = np.sum(rewards) 371 | episode_returns.append(episode_return) 372 | episode_highest_reward = np.max(rewards) 373 | highest_rewards.append(episode_highest_reward) 374 | is_success = (reward == env_max_reward) 375 | success_rates.append(is_success) 376 | 377 | print(f'Rollout {rollout_id} : {episode_return=}, Success: {is_success}') 378 | video = np.stack(image_list).transpose(0, 3, 1, 2) 379 | wandb_logger.log({f'eval_video/{rollout_id}': wandb.Video(video, fps=50)}, step=i) 380 | 381 | 382 | success_rate = np.mean(np.array(success_rates)) 383 | avg_return = np.mean(episode_returns) 384 | avg_episode_len = np.mean(episode_lens) 385 | summary_str = 
f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n' 386 | wandb_logger.log({'evaluation/avg_return': avg_return}, step=i) 387 | wandb_logger.log({'evaluation/success_rate': success_rate}, step=i) 388 | wandb_logger.log({'evaluation/avg_episode_len': avg_episode_len}, step=i) 389 | for r in range(env_max_reward+1): 390 | more_or_equal_r = (np.array(highest_rewards) >= r).sum() 391 | more_or_equal_r_rate = more_or_equal_r / variant.eval_episodes 392 | wandb_logger.log({f'evaluation/Reward >= {r}': more_or_equal_r_rate}, step=i) 393 | summary_str += f'Reward >= {r}: {more_or_equal_r}/{variant.eval_episodes} = {more_or_equal_r_rate*100}%\n' 394 | 395 | print(summary_str) 396 | 397 | def make_multiple_value_reward_visulizations(agent, variant, i, replay_buffer, wandb_logger): 398 | trajs = replay_buffer.get_random_trajs(3) 399 | images = agent.make_value_reward_visulization(variant, trajs) 400 | wandb_logger.log({'reward_value_images': wandb.Image(images)}, step=i) 401 | 402 | --------------------------------------------------------------------------------
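The loops in examples/train_utils_real.py and examples/train_utils_sim.py share one mechanism: the SAC agent acts in the noise space of the pi0 diffusion policy. At every query step the actor emits an action-chunk-shaped noise, the chunk is padded to the policy's denoising horizon by repeating its last step, and agent_dp.infer consumes it as the initial noise and returns the denoised environment actions. The helper below is a minimal sketch of that interface, not code from the repository; select_action, the warmup flag, and the horizon default of 50 are assumptions that mirror collect_traj in examples/train_utils_sim.py (the real-robot loop uses a horizon of 10).

# Hypothetical helper (not part of the repo) sketching the noise-as-action interface
# used by collect_traj in examples/train_utils_sim.py.
import numpy as np
import jax


def select_action(agent, agent_dp, obs_dict, obs_pi_zero, rng, horizon=50, warmup=False):
    """Return (env_actions, actions_noise) for one query step."""
    if warmup:
        # Initial data collection: sample plain Gaussian noise (as when i == 0 in collect_traj).
        rng, key = jax.random.split(rng)
        noise = jax.random.normal(key, (1, *agent.action_chunk_shape))
        pad = jax.numpy.repeat(noise[:, -1:, :], horizon - noise.shape[1], axis=1)
        noise = jax.numpy.concatenate([noise, pad], axis=1)
        actions_noise = noise[0, :agent.action_chunk_shape[0], :]
    else:
        # The SAC actor predicts the noise chunk; repeat its last step to fill the horizon.
        actions_noise = np.reshape(agent.sample_actions(obs_dict), agent.action_chunk_shape)
        pad = np.repeat(actions_noise[-1:, :], horizon - actions_noise.shape[0], axis=0)
        noise = jax.numpy.concatenate([actions_noise, pad], axis=0)[None]
    # The diffusion policy denoises the chunk into environment actions.
    env_actions = agent_dp.infer(obs_pi_zero, noise=np.asarray(noise))["actions"]
    return env_actions, actions_noise

Note that it is actions_noise, not env_actions, that the training loops store in the replay buffer as the transition's action, so the critic learns Q-values over noise chunks while the diffusion policy is only queried for inference in these loops.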