├── setup.py ├── experiments ├── scripts │ ├── launch_contrastive.sh │ ├── launch_contrastive_gpu.sh │ ├── eval_policy.sh │ ├── eval_policy_clip.sh │ ├── eval_all.sh │ ├── launch_sim.sh │ ├── .nfs0000000005fc462300000001 │ ├── launch_bridge.sh │ └── eval.sh ├── normalize_actions.py ├── create_eval_goals.py ├── eval_visualize_roboverse.py ├── eval_checkpoints.py ├── configs │ └── offline_contrastive_config.py └── sim_offline_gc.py ├── jaxrl_m ├── agents │ ├── __init__.py │ ├── discrete │ │ ├── bc.py │ │ ├── cql.py │ │ ├── gc_cql.py │ │ ├── gc_iql.py │ │ └── iql.py │ └── continuous │ │ ├── affordance.py │ │ ├── bc.py │ │ └── gc_bc.py ├── envs │ └── wrappers │ │ ├── action_norm.py │ │ ├── mujoco.py │ │ ├── dmcgym.py │ │ ├── roboverse.py │ │ └── video_recorder.py ├── common │ ├── typing.py │ ├── wandb.py │ ├── evaluation.py │ ├── encoding.py │ └── common.py ├── vision │ ├── __init__.py │ ├── cvae.py │ ├── small_encoders.py │ ├── bigvision_common.py │ └── impala.py ├── utils │ ├── timer_utils.py │ ├── train_utils.py │ └── sim_utils.py ├── data │ ├── ss2_language.py │ ├── sgl_dataset.py │ ├── replay_buffer.py │ ├── tf_augmentations.py │ ├── language.py │ ├── ego4d.py │ ├── dataset.py │ └── ss2.py └── networks │ ├── discrete_nets.py │ └── actor_critic_nets.py ├── README.md ├── environment_tpu.yml ├── environment_cuda.yml ├── Dockerfile ├── scripts ├── ss2_frames_to_tfrecord.py └── bridgedata_numpy_to_tfrecord.py └── contributing.md /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="jaxrl_m", packages=["jaxrl_m"]) 4 | -------------------------------------------------------------------------------- /experiments/scripts/launch_contrastive.sh: -------------------------------------------------------------------------------- 1 | CONFIG=${1-contrastive_tpu} 2 | NAME="all_v1_task_$CONFIG" 3 | 4 | CMD="python experiments/andre/contrastive.py \ 5 | --config experiments/andre/configs/offline_contrastive_config.py:$CONFIG \ 6 | --bridgedata_config experiments/andre/configs/bridgedata_config.py:all \ 7 | --name $NAME \ 8 | $LAUNCH_FLAGS" 9 | 10 | $CMD 11 | -------------------------------------------------------------------------------- /experiments/scripts/launch_contrastive_gpu.sh: -------------------------------------------------------------------------------- 1 | CONFIG=${1-contrastive_gpu} 2 | NAME="all_v1_task_$CONFIG" 3 | 4 | CMD="python experiments/andre/ego4d_contrastive.py \ 5 | --config experiments/andre/configs/offline_contrastive_config.py:$CONFIG \ 6 | --bridgedata_config experiments/andre/configs/bridgedata_config.py:all \ 7 | --name $NAME $LAUNCH_FLAGS" 8 | 9 | $CMD 10 | -------------------------------------------------------------------------------- /jaxrl_m/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .continuous.bc import BCAgent 2 | from .continuous.gc_bc import GCBCAgent 3 | from .continuous.gc_iql import GCIQLAgent 4 | from .continuous.iql import IQLAgent 5 | from .continuous.multimodal import MultimodalAgent 6 | 7 | agents = { 8 | "gc_bc": GCBCAgent, 9 | "gc_iql": GCIQLAgent, 10 | "bc": BCAgent, 11 | "iql": IQLAgent, 12 | "multimodal": MultimodalAgent, 13 | } 14 | -------------------------------------------------------------------------------- /experiments/scripts/eval_policy.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 2 | NAME="$1" 3 | 
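# (Added note, not part of the original script) Positional arguments:
#   $1 = run name, $2 = checkpoint step to load, $3 = suffix of the rail-tpus GCS bucket.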
CHECK="$2" 4 | 5 | CMD="python experiments/vivek/eval_policy.py \ 6 | --num_timesteps 60 \ 7 | --video_save_path /home/robonet \ 8 | --checkpoint_path gs://rail-tpus-$3/jaxrl_m_bridgedata/$NAME/checkpoint_$CHECK \ 9 | --wandb_run_name widowx-gcrl/jaxrl_m_bridgedata/$NAME" 10 | 11 | $CMD --goal_eep "0.3 0.0 0.1" --initial_eep "0.3 0.0 0.1" 12 | 13 | -------------------------------------------------------------------------------- /experiments/scripts/eval_policy_clip.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 2 | NAME="$1" 3 | CHECK="$2" 4 | 5 | CMD="python experiments/vivek/eval_policy_clip.py \ 6 | --num_timesteps 60 \ 7 | --video_save_path /home/robonet \ 8 | --checkpoint_path gs://rail-tpus-$3/jaxrl_m_bridgedata/$NAME/checkpoint_$CHECK \ 9 | --wandb_run_name widowx-gcrl/jaxrl_m_bridgedata/$NAME" 10 | 11 | $CMD --goal_eep "0.3 0.0 0.1" --initial_eep "0.3 0.0 0.1" 12 | 13 | -------------------------------------------------------------------------------- /jaxrl_m/envs/wrappers/action_norm.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class UnnormalizeAction(gym.ActionWrapper): 5 | def __init__(self, env: gym.Env, action_metadata: dict): 6 | self.action_mean = action_metadata["mean"] 7 | self.action_std = action_metadata["std"] 8 | super().__init__(env) 9 | 10 | def action(self, action): 11 | """ 12 | Un-normalizes the action 13 | """ 14 | action = (action * self.action_std) + self.action_mean 15 | return action 16 | -------------------------------------------------------------------------------- /jaxrl_m/common/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union 2 | 3 | import numpy as np 4 | import jax.numpy as jnp 5 | import flax 6 | import tensorflow as tf 7 | 8 | 9 | PRNGKey = Any 10 | Params = flax.core.FrozenDict[str, Any] 11 | Shape = Sequence[int] 12 | Dtype = Any # this could be a real type? 
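# (Added clarifying comment, not part of the original module.) These aliases are
# documentation-only: `Params` is the FrozenDict returned by
# `module.init(...)["params"]`, `PRNGKey` is the value produced by
# `jax.random.PRNGKey(seed)`, and the `Data`/`Batch` aliases below describe the
# arbitrarily nested dicts of arrays yielded by the data loaders.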
13 | InfoDict = Dict[str, float] 14 | Array = Union[np.ndarray, jnp.ndarray, tf.Tensor] 15 | Data = Union[Array, Dict[str, "Data"]] 16 | Batch = Dict[str, Data] 17 | # A method to be passed into TrainState.__call__ 18 | ModuleMethod = Union[str, Callable, None] 19 | -------------------------------------------------------------------------------- /jaxrl_m/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl_m.vision.impala import impala_configs 2 | from jaxrl_m.vision.bigvision_resnetv2 import resnetv2_configs 3 | from jaxrl_m.vision.small_encoders import small_configs 4 | from jaxrl_m.vision.resnet_v1 import resnetv1_configs 5 | from jaxrl_m.vision.clip import * 6 | 7 | encoders = dict() 8 | encoders.update(impala_configs) 9 | encoders.update(resnetv2_configs) 10 | encoders.update(resnetv1_configs) 11 | encoders.update(small_configs) 12 | encoders["clip_text_with_projection"] = CLIPTextEncoderWithProjection 13 | encoders["clip_vision_with_projection"] = CLIPVisionEncoderWithProjection 14 | encoders["clip_text_with_ftmap"] = CLIPTextEncoderWithFtMap 15 | encoders["clip_vision_with_ftmap"] = CLIPVisionEncoderWithFtMap 16 | encoders["muse"] = MUSEPlaceHolder 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Goal Representations for Instruction Following (GRIF) 2 | 3 | This is the code repository for the paper *Goal Representations for Instruction Following: A Semi-Supervised Language Interface to Control* [[arXiv](https://arxiv.org/abs/2307.00117), [website](https://rail-berkeley.github.io/grif/)]. 4 | 5 | Based on [dibyaghosh/jaxrl_minimal](https://github.com/dibyaghosh/jaxrl_minimal). 6 | 7 | 8 | ## Environment 9 | For GPU: 10 | ``` 11 | conda env create -f environment_cuda.yml 12 | ``` 13 | 14 | For TPU: 15 | ``` 16 | conda env create -f environment_tpu.yml 17 | ``` 18 | 19 | See the [Jax Github page](https://github.com/google/jax) for more details on installing Jax. 
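After creating an environment, a quick sanity check (a suggested step, not an official one from this repo) is to confirm that JAX can see your accelerator:
```
python -c "import jax; print(jax.devices())"
```
This should list GPU or TPU devices rather than only the CPU.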
20 | 21 | ## Running 22 | 23 | To train GRIF, run 24 | ``` 25 | bash experiments/scripts/launch_bridge.sh GRIF 26 | ``` 27 | -------------------------------------------------------------------------------- /jaxrl_m/utils/timer_utils.py: -------------------------------------------------------------------------------- 1 | """Timer utility.""" 2 | 3 | from collections import defaultdict 4 | import time 5 | 6 | 7 | class Timer: 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.counts = defaultdict(int) 13 | self.times = defaultdict(float) 14 | self.start_times = {} 15 | 16 | def tick(self, key): 17 | if key in self.start_times: 18 | raise ValueError(f"Timer is already ticking for key: {key}") 19 | self.start_times[key] = time.time() 20 | 21 | def tock(self, key): 22 | if key not in self.start_times: 23 | raise ValueError(f"Timer is not ticking for key: {key}") 24 | self.counts[key] += 1 25 | self.times[key] += time.time() - self.start_times[key] 26 | del self.start_times[key] 27 | 28 | def get_average_times(self, reset=True): 29 | ret = {key: self.times[key] / self.counts[key] for key in self.counts} 30 | if reset: 31 | self.reset() 32 | return ret 33 | -------------------------------------------------------------------------------- /jaxrl_m/data/ss2_language.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import json 3 | 4 | lang_to_code = {} 5 | code_to_lang = {} 6 | 7 | 8 | def load_mapping(path): 9 | global lang_to_code, code_to_lang 10 | 11 | for split in ["train", "validation"]: 12 | labels_path = tf.io.gfile.join(path, f"{split}.json") 13 | labels = json.loads(tf.io.gfile.GFile(labels_path, "r").read()) 14 | 15 | for label in labels: 16 | code = int(label["id"]) 17 | caption = label["label"] 18 | lang_to_code[caption] = code 19 | code_to_lang[code] = caption 20 | 21 | 22 | def lang_encode(lang): 23 | if not lang: 24 | return -1 25 | elif lang in lang_to_code: 26 | return lang_to_code[lang] 27 | else: 28 | raise ValueError(f"Language {lang} not found in mapping") 29 | 30 | 31 | def lang_decode(code): 32 | if code in code_to_lang: 33 | return code_to_lang[code] 34 | else: 35 | raise ValueError(f"Code {code} not found in mapping") 36 | 37 | 38 | def get_encodings(): 39 | return lang_to_code, code_to_lang 40 | -------------------------------------------------------------------------------- /experiments/normalize_actions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates a metadata.npy file with the mean and standard deviation of the actions 3 | in a dataset. Saves the file to the same folder as the data. This file is read 4 | during training and used to normalize the actions in the data loader. 
5 | """ 6 | 7 | import numpy as np 8 | from absl import flags, app 9 | from jaxrl_m.utils.sim_utils import load_tf_dataset 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | 13 | FLAGS = flags.FLAGS 14 | flags.DEFINE_string("data_path", None, "Location of dataset", required=True) 15 | 16 | 17 | def main(_): 18 | dataset = load_tf_dataset(FLAGS.data_path) 19 | actions = [] 20 | for f in tqdm(iter(dataset)): 21 | actions.append(f["actions"][:]) 22 | actions = np.concatenate(actions) 23 | metadata = {} 24 | metadata["mean"] = np.mean(actions, axis=0) 25 | metadata["std"] = np.std(actions, axis=0) 26 | # don't normalize gripper 27 | metadata["mean"][6] = 0 28 | metadata["std"][6] = 1 29 | np.save(tf.io.gfile.join(FLAGS.data_path, "metadata.npy"), metadata) 30 | 31 | 32 | if __name__ == "__main__": 33 | app.run(main) 34 | -------------------------------------------------------------------------------- /environment_tpu.yml: -------------------------------------------------------------------------------- 1 | name: grif 2 | channels: 3 | - conda-forge 4 | - defaults 5 | - nvidia 6 | dependencies: 7 | - python=3.9 8 | - pip=21.0 9 | - setuptools==65.5.0 10 | - numpy<=1.23 11 | - scipy>=1.6.0 12 | - matplotlib=3.8.2 13 | - tqdm>=4.60.0 14 | - absl-py>=0.12.0 15 | - wandb>=0.12.14 16 | - moviepy>=1.0.3 17 | - google-auth==2.7.0 18 | - lockfile=0.12 19 | - imageio=2.19 20 | - ml-collections=0.1 21 | - distrax<2 22 | - imageio[ffmpeg] 23 | - mesalib 24 | - glew 25 | - cudatoolkit 26 | - cudnn 27 | - cuda-nvcc 28 | - patchelf 29 | - optax==0.1.7 30 | - pip: 31 | - jax[tpu]==0.4.14 32 | - flax==0.6.11 33 | - chex==0.1.7 34 | - tensorflow-cpu==2.11 35 | - tensorflow-hub==0.12 36 | - tensorflow-text==2.11 37 | - tensorflow-probability==0.19 38 | - tensorflow-datasets==4.9 39 | - tensorflow-estimator==2.11 40 | - tensorboard==2.11 41 | - transformers==4.25 42 | - gym==0.23 43 | - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 44 | - -e . 
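  # (Added note) The -f line above is the extra index where pip finds the libtpu
  # wheels that jax[tpu] depends on.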
45 | variables: 46 | TF_FORCE_GPU_ALLOW_GROWTH: "true" 47 | XLA_PYTHON_CLIENT_PREALLOCATE: "false" 48 | 49 | -------------------------------------------------------------------------------- /experiments/scripts/eval_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Steps array 4 | steps=() 5 | 6 | # Generate the steps array 7 | for i in {1..5}; do 8 | steps+=($(($i * 5000))) 9 | done 10 | 11 | # Loop over the array 12 | for step in "${steps[@]}" 13 | do 14 | # Substitute the step into the command 15 | LAUNCH_FLAGS="--dataset_ bridgedata --name test_${step} --split_strategy_ task \ 16 | --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/bridge_muse_do_20230502_011720 --lang_encoder_ muse \ 17 | --resume_step_ ${step} --eval_checkpoint --dropout_rate_ 0.3" sh experiments/andre/scripts/launch_contrastive.sh contrastive_gpu 18 | done 19 | 20 | # # Generate the steps array 21 | # for i in {1..19}; do 22 | # step=$(($i * 5000)) 23 | 24 | # # Substitute the step into the command 25 | # LAUNCH_FLAGS="--dataset_ bridgedata --name fs_${step} --split_strategy_ task \ 26 | # --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/new_labels_task_20230429_072500 \ 27 | # --resume_step_ ${step} --eval_checkpoint" bash experiments/andre/scripts/launch_contrastive.sh contrastive_gpu 28 | # done 29 | 30 | -------------------------------------------------------------------------------- /environment_cuda.yml: -------------------------------------------------------------------------------- 1 | name: grif 2 | channels: 3 | - conda-forge 4 | - defaults 5 | - nvidia 6 | dependencies: 7 | - python=3.9 8 | - pip=21.0 9 | - setuptools==65.5.0 10 | - numpy<=1.23 11 | - scipy>=1.6.0 12 | - matplotlib=3.8.2 13 | - tqdm>=4.60.0 14 | - absl-py>=0.12.0 15 | - wandb>=0.12.14 16 | - moviepy>=1.0.3 17 | - google-auth==2.7.0 18 | - lockfile=0.12 19 | - imageio=2.19 20 | - ml-collections=0.1 21 | - distrax<2 22 | - imageio[ffmpeg] 23 | - mesalib 24 | - glew 25 | - cudatoolkit 26 | - cudnn 27 | - cuda-nvcc 28 | - patchelf 29 | - jax=0.4.14 30 | - optax==0.1.7 31 | - pip: 32 | - jaxlib==0.4.14+cuda11.cudnn86 33 | - flax==0.6.11 34 | - chex==0.1.7 35 | - tensorflow-cpu==2.11 36 | - tensorflow-hub==0.12 37 | - tensorflow-text==2.11 38 | - tensorflow-probability==0.19 39 | - tensorflow-datasets==4.9 40 | - tensorflow-estimator==2.11 41 | - tensorboard==2.11 42 | - transformers==4.25 43 | - gym==0.23 44 | - -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 45 | - -e . 
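# (Added note) The variables below keep TensorFlow and JAX from grabbing all GPU
# memory at startup: TF_FORCE_GPU_ALLOW_GROWTH makes TensorFlow allocate lazily,
# and XLA_PYTHON_CLIENT_PREALLOCATE=false disables JAX's default preallocation.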
46 | variables: 47 | TF_FORCE_GPU_ALLOW_GROWTH: "true" 48 | XLA_PYTHON_CLIENT_PREALLOCATE: "false" 49 | 50 | -------------------------------------------------------------------------------- /jaxrl_m/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import imageio 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import wandb 7 | from flax.core import frozen_dict 8 | import tensorflow as tf 9 | from jaxrl_m.data.bridge_dataset import BridgeDataset 10 | 11 | 12 | def concat_batches(offline_batch, online_batch, axis=1): 13 | batch = defaultdict(list) 14 | 15 | if type(offline_batch) != dict: 16 | offline_batch = offline_batch.unfreeze() 17 | 18 | if type(online_batch) != dict: 19 | online_batch = online_batch.unfreeze() 20 | 21 | for k, v in offline_batch.items(): 22 | if type(v) is dict: 23 | batch[k] = concat_batches(offline_batch[k], online_batch[k], axis=axis) 24 | else: 25 | batch[k] = jnp.concatenate((offline_batch[k], online_batch[k]), axis=axis) 26 | 27 | return frozen_dict.freeze(batch) 28 | 29 | 30 | def load_recorded_video( 31 | video_path: str, 32 | ): 33 | with tf.io.gfile.GFile(video_path, "rb") as f: 34 | video = np.array(imageio.mimread(f, "MP4")).transpose((0, 3, 1, 2)) 35 | assert video.shape[1] == 3, "Numpy array should be (T, C, H, W)" 36 | 37 | return wandb.Video(video, fps=20) 38 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM robonet-base:latest 2 | 3 | RUN ~/myenv/bin/pip install tensorflow jax[cuda11_cudnn82] flax distrax ml_collections h5py wandb tensorflow-text tensorflow-hub\ 4 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 5 | 6 | RUN ~/myenv/bin/pip install einops 7 | ENV PYTHONPATH=${PYTHONPATH}:/home/robonet/code/jaxrl_minimal:/home/robonet/code/denoising-diffusion-flax 8 | 9 | # Downloading gcloud package 10 | RUN curl https://dl.google.com/dl/cloudsdk/release/google-cloud-sdk.tar.gz > /tmp/google-cloud-sdk.tar.gz 11 | 12 | # Installing the package 13 | RUN sudo mkdir -p /usr/local/gcloud \ 14 | && sudo tar -C /usr/local/gcloud -xvf /tmp/google-cloud-sdk.tar.gz \ 15 | && sudo /usr/local/gcloud/google-cloud-sdk/install.sh --quiet 16 | 17 | # Adding the package path to local 18 | ENV PATH $PATH:/usr/local/gcloud/google-cloud-sdk/bin 19 | 20 | # avoid git safe directory errors 21 | RUN git config --global --add safe.directory /home/robonet/code/jaxrl_minimal 22 | 23 | # activate gcloud credentials (requires them to be mounted at /tmp/gcloud_key.json through docker-compose) 24 | RUN echo "gcloud auth activate-service-account --key-file /tmp/gcloud_key.json" >> ~/.bashrc 25 | ENV GOOGLE_APPLICATION_CREDENTIALS=/tmp/gcloud_key.json 26 | 27 | WORKDIR /home/robonet/code/jaxrl_minimal 28 | -------------------------------------------------------------------------------- /experiments/scripts/launch_sim.sh: -------------------------------------------------------------------------------- 1 | # 2 cores per process 2 | TPU0="export TPU_VISIBLE_DEVICES=0 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 3 | TPU1="export TPU_VISIBLE_DEVICES=1 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8477 TPU_MESH_CONTROLLER_PORT=8477" 4 | TPU2="export TPU_VISIBLE_DEVICES=2 
TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 5 | TPU3="export TPU_VISIBLE_DEVICES=3 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8479 TPU_MESH_CONTROLLER_PORT=8479" 6 | 7 | # 4 cores per process 8 | TPU01="export TPU_VISIBLE_DEVICES=0,1 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 9 | TPU23="export TPU_VISIBLE_DEVICES=2,3 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 10 | 11 | export MUJOCO_GL=osmesa 12 | 13 | NAME="test" 14 | 15 | $TPU23 16 | CMD="python experiments/vivek/sim_offline_gc.py \ 17 | --config experiments/vivek/configs/offline_pixels_config.py:sim_gc_iql \ 18 | --name $NAME" 19 | 20 | $CMD 21 | -------------------------------------------------------------------------------- /experiments/scripts/.nfs0000000005fc462300000001: -------------------------------------------------------------------------------- 1 | # 2 cores per process 2 | TPU0="export TPU_VISIBLE_DEVICES=0 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 3 | TPU1="export TPU_VISIBLE_DEVICES=1 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8477 TPU_MESH_CONTROLLER_PORT=8477" 4 | TPU2="export TPU_VISIBLE_DEVICES=2 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 5 | TPU3="export TPU_VISIBLE_DEVICES=3 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8479 TPU_MESH_CONTROLLER_PORT=8479" 6 | 7 | # 4 cores per process 8 | TPU01="export TPU_VISIBLE_DEVICES=0,1 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 9 | TPU23="export TPU_VISIBLE_DEVICES=2,3 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 10 | 11 | NAME="all_bc_finetuning_85k" 12 | 13 | CMD="python experiments/homer/bridgedata_offline_gc.py \ 14 | --config experiments/homer/configs/offline_pixels_config.py:gc_bc \ 15 | --bridgedata_config experiments/homer/configs/bridgedata_config.py:all_finetune \ 16 | --name $NAME" 17 | 18 | $CMD -------------------------------------------------------------------------------- /jaxrl_m/networks/discrete_nets.py: -------------------------------------------------------------------------------- 1 | from jaxrl_m.common.typing import * 2 | import jax.numpy as jnp 3 | 4 | import flax.linen as nn 5 | 6 | from jaxrl_m.common.common import MLP, default_init 7 | from jaxrl_m.networks.actor_critic_nets import get_encoding 8 | import distrax 9 | 10 | 11 | class DiscreteQ(nn.Module): 12 | encoder: nn.Module 13 | network: nn.Module 14 | 15 | def __call__(self, observations): 16 | latents = get_encoding(self.encoder, observations) 17 | return self.network(latents) 18 | 19 | 20 | class DiscreteCriticHead(nn.Module): 21 | hidden_dims: Sequence[int] 22 | n_actions: int 23 | activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish 24 | 25 | def setup(self): 26 | self.q = MLP((*self.hidden_dims, self.n_actions), activation=self.activation) 27 | 28 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 29 | return self.q(observations) 30 | 31 | 32 | 
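# Usage sketch (added for illustration; the latent size is made up). Both heads
# operate on an already-encoded latent vector: DiscreteCriticHead above maps it
# to one Q-value per action, while DiscretePolicy below maps it to a
# distrax.Categorical over actions.
#
#   import jax
#   import jax.numpy as jnp
#
#   latent = jnp.zeros((1, 512))  # e.g. flattened output of a vision encoder
#   head = DiscreteCriticHead(hidden_dims=(256, 256), n_actions=6)
#   params = head.init(jax.random.PRNGKey(0), latent)["params"]
#   q_values = head.apply({"params": params}, latent)  # shape (1, 6)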
class DiscretePolicy(nn.Module): 33 | hidden_dims: Sequence[int] 34 | n_actions: int 35 | activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish 36 | 37 | @nn.compact 38 | def __call__(self, observations: jnp.ndarray, temperature=1.0) -> jnp.ndarray: 39 | logits = MLP((*self.hidden_dims, self.n_actions), activation=self.activation)( 40 | observations 41 | ) 42 | return distrax.Categorical(logits=logits / temperature) 43 | -------------------------------------------------------------------------------- /jaxrl_m/vision/cvae.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | import jax.random as random 4 | 5 | from typing import Any 6 | from typing import Sequence 7 | 8 | from jaxrl_m.common.common import MLP 9 | from jaxlib.xla_extension import DeviceArray 10 | 11 | ModuleDef = Any 12 | 13 | 14 | class Encoder(nn.Module): 15 | hidden_dims: Sequence[int] 16 | latent_dim: int 17 | 18 | @nn.compact 19 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 20 | h = MLP(self.hidden_dims)(observations) 21 | mean = MLP([self.latent_dim])(h) 22 | logvar = MLP([self.latent_dim])(h) 23 | return mean, logvar 24 | 25 | 26 | class Decoder(nn.Module): 27 | hidden_dims: Sequence[int] 28 | output_dim: int 29 | 30 | @nn.compact 31 | def __call__(self, latent: jnp.ndarray) -> jnp.ndarray: 32 | output = MLP((*self.hidden_dims, self.output_dim))(latent) 33 | return output 34 | 35 | 36 | class CVAE(nn.Module): 37 | encoder_hidden_dims: Sequence[int] 38 | latent_dim: int 39 | decoder_hidden_dims: Sequence[int] 40 | output_dim: int 41 | 42 | def setup(self): 43 | self.encoder = Encoder(self.encoder_hidden_dims, self.latent_dim) 44 | self.decoder = Decoder(self.decoder_hidden_dims, self.output_dim) 45 | 46 | def __call__( 47 | self, observations: jnp.ndarray, goals: jnp.ndarray, seed: DeviceArray 48 | ): 49 | rets = dict() 50 | 51 | combined = jnp.concatenate([observations, goals], axis=-1) 52 | rets["mean"], rets["logvar"] = self.encoder(combined) 53 | stds = jnp.exp(0.5 * rets["logvar"]) 54 | z = rets["mean"] + stds * random.normal(seed, rets["mean"].shape) 55 | rets["reconstruction"] = self.decoder(z) 56 | 57 | return rets 58 | -------------------------------------------------------------------------------- /experiments/create_eval_goals.py: -------------------------------------------------------------------------------- 1 | """ 2 | Samples final states from succesful trajectories in a validation dataset to use 3 | as goals for evaluation. Logs these goals to an eval_goals.npy file in the same 4 | folder as the dataset. Takes an argument to specify which info key to use as 5 | the success condition. 
6 | """ 7 | 8 | import string 9 | import numpy as np 10 | from absl import flags, app, logging 11 | from jaxrl_m.utils.sim_utils import load_tf_dataset 12 | import tensorflow as tf 13 | import jax 14 | 15 | FLAGS = flags.FLAGS 16 | flags.DEFINE_string("data_path", None, "Location of dataset", required=True) 17 | flags.DEFINE_integer("num_goals", None, "Number of goals", required=True) 18 | flags.DEFINE_string("accept_trajectory_key", None, "Success key", required=True) 19 | 20 | 21 | def main(_): 22 | dataset = load_tf_dataset(FLAGS.data_path) 23 | data = [] 24 | for traj in iter(dataset): 25 | if traj["infos"][FLAGS.accept_trajectory_key][-1]: 26 | data.append(traj) 27 | 28 | logging.info(f"Number of successful trajectories: {len(data)}") 29 | data = np.random.choice(data, size=FLAGS.num_goals, replace=False) 30 | 31 | # turn list of dicts into dict of lists, selecting first element of each trajectory 32 | data_first = jax.tree_map(lambda *xs: np.array(xs)[:, 0], *data) 33 | 34 | # turn list of dicts into dict of lists, selecting last element of each trajectory 35 | data = jax.tree_map(lambda *xs: np.array(xs)[:, -1], *data) 36 | 37 | data["infos"]["initial_positions"] = data_first["infos"]["object_positions"] 38 | 39 | # decode strings 40 | data["infos"]["object_names"] = [ 41 | [s.decode("UTF-8") for s in names] for names in data["infos"]["object_names"] 42 | ] 43 | 44 | with tf.io.gfile.GFile( 45 | tf.io.gfile.join(FLAGS.data_path, "eval_goals.npy"), "wb" 46 | ) as f: 47 | np.save(f, data) 48 | 49 | 50 | if __name__ == "__main__": 51 | app.run(main) 52 | -------------------------------------------------------------------------------- /experiments/scripts/launch_bridge.sh: -------------------------------------------------------------------------------- 1 | # 2 cores per process 2 | TPU0="export TPU_VISIBLE_DEVICES=0 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 3 | TPU1="export TPU_VISIBLE_DEVICES=1 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8477 TPU_MESH_CONTROLLER_PORT=8477" 4 | TPU2="export TPU_VISIBLE_DEVICES=2 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 5 | TPU3="export TPU_VISIBLE_DEVICES=3 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8479 TPU_MESH_CONTROLLER_PORT=8479" 6 | 7 | # 4 cores per process 8 | TPU01="export TPU_VISIBLE_DEVICES=0,1 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 9 | TPU23="export TPU_VISIBLE_DEVICES=2,3 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 10 | 11 | TPU0123="export TPU_VISIBLE_DEVICES=0,1,2,3 TPU_CHIPS_PER_HOST_BOUNDS=1,4,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 12 | 13 | if [[ $# -ne 2 ]] && [[ $# -ne 1 ]]; then 14 | echo "Usage: bash experiments/scripts/launch_bridge.sh [config] TPU[0123]+" && 15 | exit 1 16 | fi 17 | 18 | CONFIG="$1" 19 | NAME="all_multimodal_$CONFIG" 20 | 21 | case "$2" in 22 | "TPU0") eval $TPU0 ;; 23 | "TPU1") eval $TPU1 ;; 24 | "TPU2") eval $TPU2 ;; 25 | "TPU3") eval $TPU3 ;; 26 | "TPU01") eval $TPU01 ;; 27 | "TPU23") eval $TPU23 ;; 28 | "TPU0123") eval $TPU0123 ;; 29 | "") ;; 30 | *) echo "Invalid TPU argument" && exit 1 ;; 31 | esac 32 | 33 | 
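# (Added note) The TPU* strings above set the environment variables that pin this
# process to a subset of the host's TPU cores, so several jobs can share one
# multi-chip TPU VM; the case statement picks the subset from the second CLI
# argument, and omitting it leaves all cores visible.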
CMD="python experiments/bridgedata_offline_gc.py \ 34 | --config experiments/configs/offline_multimodal_config.py:$CONFIG \ 35 | --bridgedata_config experiments/configs/bridgedata_config.py:all \ 36 | --name $NAME \ 37 | $LAUNCH_FLAGS" 38 | 39 | $CMD 40 | -------------------------------------------------------------------------------- /experiments/scripts/eval.sh: -------------------------------------------------------------------------------- 1 | LAUNCH_FLAGS='--dataset_ bridgedata --name pt_5k --split_strategy_ task \ 2 | --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/new_labels_task_pt_20230429_072651 \ 3 | --resume_step_ 5000 --eval_checkpoint' \ 4 | sh experiments/andre/scripts/launch_contrastive.sh contrastive_tpu 5 | 6 | LAUNCH_FLAGS='--dataset_ bridgedata --name pt_50k --split_strategy_ task \ 7 | --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/new_labels_task_pt_20230429_072651 \ 8 | --resume_step_ 50000 --eval_checkpoint' \ 9 | sh experiments/andre/scripts/launch_contrastive.sh contrastive_tpu 10 | 11 | LAUNCH_FLAGS='--dataset_ bridgedata --name pt_95k --split_strategy_ task \ 12 | --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/new_labels_task_pt_20230429_072651 \ 13 | --resume_step_ 95000 --eval_checkpoint' \ 14 | sh experiments/andre/scripts/launch_contrastive.sh contrastive_tpu 15 | 16 | LAUNCH_FLAGS='--dataset_ bridgedata --name fs_5k --split_strategy_ task \ 17 | --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/new_labels_task_20230429_072500 \ 18 | --resume_step_ 5000 --eval_checkpoint' \ 19 | sh experiments/andre/scripts/launch_contrastive.sh contrastive_tpu 20 | 21 | LAUNCH_FLAGS='--dataset_ bridgedata --name fs_50k --split_strategy_ task \ 22 | --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/new_labels_task_20230429_072500 \ 23 | --resume_step_ 50000 --eval_checkpoint' \ 24 | sh experiments/andre/scripts/launch_contrastive.sh contrastive_tpu 25 | 26 | LAUNCH_FLAGS='--dataset_ bridgedata --name fs_95k --split_strategy_ task \ 27 | --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/new_labels_task_20230429_072500 \ 28 | --resume_step_ 950000 --eval_checkpoint' \ 29 | sh experiments/andre/scripts/launch_contrastive.sh contrastive_tpu 30 | 31 | LAUNCH_FLAGS='--dataset_ bridgedata --name zs --split_strategy_ task \ 32 | ---resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/all_v1_task_contrastive_tpu_20230416_084330 \ 33 | --eval_checkpoint' \ 34 | sh experiments/andre/scripts/launch_contrastive.sh contrastive_tpu 35 | 36 | 37 | # --resume_path_ gs://rail-tpus-andre/logs/jaxrl_m_bridgedata/all_v1_task_contrastive_tpu_20230416_084330 38 | -------------------------------------------------------------------------------- /jaxrl_m/data/sgl_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Union, Optional, Iterable 2 | import tensorflow as tf 3 | import os 4 | from jaxrl_m.data.bridge_dataset import BridgeDataset 5 | 6 | 7 | # this is for the manually collected validation sets 8 | class SGLDataset(BridgeDataset): 9 | def __init__( 10 | self, 11 | root_data_path: str, 12 | ): 13 | tf_paths = tf.io.gfile.glob(os.path.join(root_data_path, "data.tfrecord")) 14 | caption_path = tf_paths[0].replace("data.tfrecord", "captions.txt") 15 | 16 | # read captions 17 | with tf.io.gfile.GFile(caption_path, "r") as f: 18 | content = f.read() 19 | lines = content.split("\n") 20 | captions = [line.split(",")[2].strip() for line in lines] 21 | 
self.captions = captions 22 | dataset = self._construct_tf_dataset(tf_paths) 23 | dataset = dataset.batch(100, drop_remainder=False) 24 | self.tf_dataset = dataset 25 | 26 | def _construct_tf_dataset(self, paths): 27 | dataset = tf.data.Dataset.from_tensor_slices(paths) 28 | dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE) 29 | dataset = dataset.map( 30 | self._decode_example_sgl, num_parallel_calls=tf.data.AUTOTUNE 31 | ) 32 | return dataset 33 | 34 | PROTO_TYPE_SPEC = { 35 | "start_frame": tf.uint8, 36 | "stop_frame": tf.uint8, 37 | "label": tf.uint32, 38 | } 39 | 40 | def _decode_example_sgl(self, example_proto): 41 | features = { 42 | key: tf.io.FixedLenFeature([], tf.string) 43 | for key in self.PROTO_TYPE_SPEC.keys() 44 | } 45 | parsed_features = tf.io.parse_single_example(example_proto, features) 46 | parsed_tensors = { 47 | key: tf.io.parse_tensor(parsed_features[key], dtype) 48 | for key, dtype in self.PROTO_TYPE_SPEC.items() 49 | } 50 | 51 | return { 52 | "observations": { 53 | "image": parsed_tensors["start_frame"], 54 | }, 55 | "next_observations": { 56 | "image": parsed_tensors["stop_frame"], 57 | }, 58 | "goals": { 59 | "image": parsed_tensors["stop_frame"], 60 | "language": parsed_tensors["label"], 61 | }, 62 | } 63 | 64 | def decode_lang(self, i): 65 | return self.captions[i] 66 | 67 | 68 | if __name__ == "__main__": 69 | dataset = SGLDataset("gs://rail-tpus-andre/bridge_validation/scene1") 70 | data_iter = dataset.get_iterator() 71 | batch = next(data_iter) 72 | print(batch) 73 | -------------------------------------------------------------------------------- /jaxrl_m/data/replay_buffer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import gym 4 | import gym.spaces 5 | import numpy as np 6 | 7 | from jaxrl_m.data.dataset import Dataset, DatasetDict 8 | 9 | 10 | def _init_replay_dict( 11 | obs_space: gym.Space, capacity: int 12 | ) -> Union[np.ndarray, DatasetDict]: 13 | if isinstance(obs_space, gym.spaces.Box): 14 | return np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype) 15 | elif isinstance(obs_space, gym.spaces.Dict): 16 | data_dict = {} 17 | for k, v in obs_space.spaces.items(): 18 | data_dict[k] = _init_replay_dict(v, capacity) 19 | return data_dict 20 | else: 21 | raise TypeError() 22 | 23 | 24 | def _insert_recursively( 25 | dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int 26 | ): 27 | if isinstance(dataset_dict, np.ndarray): 28 | dataset_dict[insert_index] = data_dict 29 | elif isinstance(dataset_dict, dict): 30 | assert dataset_dict.keys() == data_dict.keys() 31 | for k in dataset_dict.keys(): 32 | _insert_recursively(dataset_dict[k], data_dict[k], insert_index) 33 | else: 34 | raise TypeError() 35 | 36 | 37 | class ReplayBuffer(Dataset): 38 | def __init__( 39 | self, 40 | observation_space: gym.Space, 41 | action_space: gym.Space, 42 | capacity: int, 43 | next_observation_space: Optional[gym.Space] = None, 44 | goal_space: Optional[gym.Space] = None, 45 | seed: Optional[int] = None, 46 | ): 47 | if next_observation_space is None: 48 | next_observation_space = observation_space 49 | 50 | observation_data = _init_replay_dict(observation_space, capacity) 51 | next_observation_data = _init_replay_dict(next_observation_space, capacity) 52 | dataset_dict = dict( 53 | observations=observation_data, 54 | next_observations=next_observation_data, 55 | actions=np.empty((capacity, *action_space.shape), dtype=action_space.dtype), 
56 | rewards=np.empty((capacity,), dtype=np.float32), 57 | masks=np.empty((capacity,), dtype=bool), 58 | ) 59 | 60 | if goal_space is not None: 61 | goal_space = _init_replay_dict(goal_space, capacity) 62 | dataset_dict["goals"] = goal_space 63 | 64 | super().__init__(dataset_dict, seed) 65 | 66 | self._size = 0 67 | self._capacity = capacity 68 | self._insert_index = 0 69 | 70 | def __len__(self) -> int: 71 | return self._size 72 | 73 | def insert(self, data_dict: DatasetDict): 74 | _insert_recursively(self.dataset_dict, data_dict, self._insert_index) 75 | 76 | self._insert_index = (self._insert_index + 1) % self._capacity 77 | self._size = min(self._size + 1, self._capacity) 78 | -------------------------------------------------------------------------------- /jaxrl_m/common/wandb.py: -------------------------------------------------------------------------------- 1 | import absl.flags as flags 2 | import datetime 3 | import tempfile 4 | from copy import copy 5 | from socket import gethostname 6 | 7 | import ml_collections 8 | import wandb 9 | 10 | 11 | def _recursive_flatten_dict(d: dict): 12 | keys, values = [], [] 13 | for key, value in d.items(): 14 | if isinstance(value, dict): 15 | sub_keys, sub_values = _recursive_flatten_dict(value) 16 | keys += [f"{key}/{k}" for k in sub_keys] 17 | values += sub_values 18 | else: 19 | keys.append(key) 20 | values.append(value) 21 | return keys, values 22 | 23 | 24 | class WandBLogger(object): 25 | @staticmethod 26 | def get_default_config(): 27 | config = ml_collections.ConfigDict() 28 | config.project = "jaxrl_m" # WandB Project Name 29 | config.entity = ml_collections.config_dict.FieldReference(None, field_type=str) 30 | # Which entity to log as (default: your own user) 31 | config.exp_descriptor = "" # Run name (doesn't have to be unique) 32 | # Unique identifier for run (will be automatically generated unless 33 | # provided) 34 | config.unique_identifier = "" 35 | return config 36 | 37 | def __init__(self, wandb_config, variant, wandb_output_dir=None, debug=False): 38 | self.config = wandb_config 39 | if self.config.unique_identifier == "": 40 | self.config.unique_identifier = datetime.datetime.now().strftime( 41 | "%Y%m%d_%H%M%S" 42 | ) 43 | 44 | self.config.experiment_id = ( 45 | self.experiment_id 46 | ) = f"{self.config.exp_descriptor}_{self.config.unique_identifier}" # NOQA 47 | 48 | print(self.config) 49 | 50 | if wandb_output_dir is None: 51 | wandb_output_dir = tempfile.mkdtemp() 52 | 53 | self._variant = copy(variant) 54 | 55 | if "hostname" not in self._variant: 56 | self._variant["hostname"] = gethostname() 57 | 58 | if debug: 59 | mode = "disabled" 60 | else: 61 | mode = "online" 62 | 63 | self.run = wandb.init( 64 | config=self._variant, 65 | project=self.config.project, 66 | entity=self.config.entity, 67 | dir=wandb_output_dir, 68 | id=self.config.experiment_id, 69 | save_code=True, 70 | mode=mode, 71 | ) 72 | 73 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS} 74 | for k in flag_dict: 75 | if isinstance(flag_dict[k], ml_collections.ConfigDict): 76 | flag_dict[k] = flag_dict[k].to_dict() 77 | wandb.config.update(flag_dict) 78 | 79 | def log(self, data: dict, step: int = None): 80 | data_flat = _recursive_flatten_dict(data) 81 | data = {k: v for k, v in zip(*data_flat)} 82 | wandb.log(data, step=step) 83 | -------------------------------------------------------------------------------- /jaxrl_m/agents/discrete/bc.py: -------------------------------------------------------------------------------- 1 | 
"""Implementations of behavioral cloning in discrete action spaces.""" 2 | import functools 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import flax 7 | import flax.linen as nn 8 | import optax 9 | import distrax 10 | 11 | from jaxrl_m.common.typing import * 12 | from jaxrl_m.common.common import TrainState, nonpytree_field 13 | from jaxrl_m.networks.discrete_nets import DiscreteQ, DiscreteCriticHead 14 | 15 | import ml_collections 16 | 17 | 18 | class BCAgent(flax.struct.PyTreeNode): 19 | model: TrainState 20 | method: ModuleMethod = nonpytree_field(default=None) 21 | 22 | @functools.partial(jax.pmap, axis_name="pmap") 23 | def update(agent, batch: Batch): 24 | def loss_fn(params): 25 | logits = agent.model( 26 | batch["observations"], params=params, method=agent.method 27 | ) 28 | dist = distrax.Categorical(logits=logits) 29 | probs = jax.nn.softmax(logits) 30 | accuracy = jnp.mean(jnp.equal(jnp.argmax(probs, axis=1), batch["actions"])) 31 | actor_loss = -1 * dist.log_prob(batch["actions"]).mean() 32 | 33 | return actor_loss, { 34 | "actor_loss": actor_loss, 35 | "accuracy": accuracy, 36 | "entropy": dist.entropy().mean(), 37 | } 38 | 39 | new_model, info = agent.model.apply_loss_fn( 40 | loss_fn=loss_fn, pmap_axis="pmap", has_aux=True 41 | ) 42 | return agent.replace(model=new_model), info 43 | 44 | @functools.partial(jax.jit, static_argnames=("argmax")) 45 | def sample_actions(agent, observations, *, seed, temperature=1.0, argmax=False): 46 | logits = agent.model(observations, method=agent.method) 47 | dist = distrax.Categorical(logits=logits / temperature) 48 | 49 | if argmax: 50 | return dist.mode() 51 | else: 52 | return dist.sample(seed=seed) 53 | 54 | 55 | def create_bc_learner( 56 | seed: int, 57 | observations: jnp.ndarray, 58 | n_actions: int, 59 | encoder_def: nn.Module, 60 | network_kwargs: dict = { 61 | "hidden_dims": [256, 256], 62 | }, 63 | optim_kwargs: dict = { 64 | "learning_rate": 6e-5, 65 | }, 66 | **kwargs 67 | ): 68 | 69 | print("Extra kwargs:", kwargs) 70 | 71 | rng = jax.random.PRNGKey(seed) 72 | 73 | network_def = DiscreteCriticHead(n_actions=n_actions, **network_kwargs) 74 | model_def = DiscreteQ( 75 | encoder=encoder_def, 76 | network=network_def, 77 | ) 78 | tx = optax.adam(**optim_kwargs) 79 | 80 | params = model_def.init(rng, observations)["params"] 81 | model = TrainState.create(model_def, params, tx=tx) 82 | 83 | return BCAgent(model) 84 | 85 | 86 | def get_default_config(): 87 | config = ml_collections.ConfigDict( 88 | { 89 | "algo": "bc", 90 | "optim_kwargs": { 91 | "learning_rate": 6e-5, 92 | }, 93 | "network_kwargs": { 94 | "hidden_dims": (256, 256), 95 | }, 96 | } 97 | ) 98 | 99 | return config 100 | -------------------------------------------------------------------------------- /jaxrl_m/envs/wrappers/mujoco.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | import gym 3 | import numpy as np 4 | 5 | 6 | def convert_obs(obs): 7 | return { 8 | "image": obs["pixels"].astype(np.uint8), 9 | "proprio": np.concatenate( 10 | [ 11 | obs["end_effector_pos"], 12 | obs["end_effector_quat"], 13 | obs["right_finger_qpos"], 14 | obs["left_finger_qpos"], 15 | ] 16 | ), 17 | } 18 | 19 | 20 | def filter_info_keys(info): 21 | keep_keys = [ 22 | "place_success", 23 | ] 24 | return {k: v for k, v in info.items() if k in keep_keys} 25 | 26 | 27 | class GCMujocoWrapper(gym.Wrapper): 28 | """ 29 | Goal-conditioned wrapper for Mujoco sim environments. 
30 | 31 | When reset is called, a new goal is sampled from the goal_sampler. The 32 | goal_sampler can either be a set of evaluation goals or a function that 33 | returns a goal (e.g an affordance model). 34 | """ 35 | 36 | def __init__( 37 | self, 38 | env: gym.Env, 39 | goal_sampler: Union[np.ndarray, Callable], 40 | ): 41 | super().__init__(env) 42 | self.env = env 43 | self.observation_space = gym.spaces.Dict( 44 | { 45 | "image": gym.spaces.Box( 46 | low=np.zeros((128, 128, 3)), 47 | high=255 * np.ones((128, 128, 3)), 48 | dtype=np.uint8, 49 | ), 50 | "proprio": gym.spaces.Box( 51 | low=np.zeros((8,)), 52 | high=np.ones((8,)), 53 | dtype=np.uint8, 54 | ), 55 | } 56 | ) 57 | self.current_goal = None 58 | self.goal_sampler = goal_sampler 59 | 60 | def step(self, *args): 61 | obs, reward, done, trunc, info = self.env.step(*args) 62 | info = filter_info_keys(info) 63 | return ( 64 | convert_obs(obs), 65 | reward, 66 | done, 67 | trunc, 68 | {"goal": self.current_goal, **info}, 69 | ) 70 | 71 | def reset(self, **kwargs): 72 | if not callable(self.goal_sampler): 73 | idx = np.random.randint(len(self.goal_sampler["observations"]["image"])) 74 | goal_image = self.goal_sampler["observations"]["image"][idx] 75 | original_object_positions = self.goal_sampler["infos"]["initial_positions"][ 76 | idx 77 | ] 78 | # original_object_quats = self.goal_sampler["infos"]["initial_quats"][idx] 79 | target_position = self.goal_sampler["infos"]["target_position"][idx] 80 | object_names = self.goal_sampler["infos"]["object_names"][idx] 81 | target_object = self.goal_sampler["infos"]["target_object"][idx] 82 | self.env.task.change_props(object_names) 83 | self.env.task.init_prop_poses = original_object_positions 84 | self.env.task.target_pos = target_position 85 | self.env.target_obj = target_object 86 | obs, info = self.env.reset() 87 | obs = convert_obs(obs) 88 | else: 89 | obs, info = self.env.reset() 90 | obs = convert_obs(obs) 91 | goal_image = self.goal_sampler(obs) 92 | 93 | goal = {"image": goal_image} 94 | 95 | self.current_goal = goal 96 | 97 | info = filter_info_keys(info) 98 | return obs, {"goal": goal, **info} 99 | -------------------------------------------------------------------------------- /jaxrl_m/vision/small_encoders.py: -------------------------------------------------------------------------------- 1 | """From https://raw.githubusercontent.com/google/flax/main/examples/ppo/models.py""" 2 | 3 | from multiprocessing import pool 4 | from flax import linen as nn 5 | import jax.numpy as jnp 6 | from typing import Sequence, Union, Tuple, Optional 7 | from jaxrl_m.common.common import orthogonal_init 8 | 9 | 10 | class AtariEncoder(nn.Module): 11 | """Class defining the actor-critic model.""" 12 | 13 | @nn.compact 14 | def __call__(self, x): 15 | """Define the convolutional network architecture. 16 | 17 | Architecture originates from "Human-level control through deep reinforcement 18 | learning.", Nature 518, no. 7540 (2015): 529-533. 19 | Note that this is different than the one from "Playing atari with deep 20 | reinforcement learning." 
arxiv.org/abs/1312.5602 (2013) 21 | 22 | Network is used to both estimate policy (logits) and expected state value; 23 | in other words, hidden layers' params are shared between policy and value 24 | networks, see e.g.: 25 | github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py 26 | """ 27 | dtype = jnp.float32 28 | x = x.astype(dtype) / 255.0 29 | x = nn.Conv( 30 | features=32, kernel_size=(8, 8), strides=(4, 4), name="conv1", dtype=dtype 31 | )(x) 32 | x = nn.relu(x) 33 | x = nn.Conv( 34 | features=64, kernel_size=(4, 4), strides=(2, 2), name="conv2", dtype=dtype 35 | )(x) 36 | x = nn.relu(x) 37 | x = nn.Conv( 38 | features=64, kernel_size=(3, 3), strides=(1, 1), name="conv3", dtype=dtype 39 | )(x) 40 | x = nn.relu(x) 41 | x = x.reshape((*x.shape[:-3], -1)) # flatten 42 | return x 43 | 44 | 45 | class SmallEncoder(nn.Module): 46 | features: Sequence[int] = (16, 16, 16) 47 | kernel_sizes: Sequence[int] = (3, 3, 3) 48 | strides: Sequence[int] = (1, 1, 1) 49 | padding: Union[Sequence[int], str] = (1, 1, 1) 50 | pool_method: Optional[str] = "max" 51 | pool_sizes: Sequence[int] = (2, 2, 1) 52 | pool_strides: Sequence[int] = (2, 2, 1) 53 | pool_padding: Sequence[int] = (0, 0, 0) 54 | bottleneck_dim: Optional[int] = None 55 | 56 | @nn.compact 57 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 58 | assert len(self.features) == len(self.strides) 59 | 60 | x = observations.astype(jnp.float32) / 255.0 61 | 62 | for i in range(len(self.features)): 63 | 64 | if isinstance(self.padding, str): 65 | padding = self.padding 66 | else: 67 | padding = self.padding[i] 68 | 69 | x = nn.Conv( 70 | self.features[i], 71 | kernel_size=(self.kernel_sizes[i], self.kernel_sizes[i]), 72 | strides=(self.strides[i], self.strides[i]), 73 | kernel_init=orthogonal_init(), 74 | padding=padding, 75 | )(x) 76 | if self.pool_method is not None: 77 | if self.pool_method == "avg": 78 | pool_func = nn.avg_pool 79 | elif self.pool_method == "max": 80 | pool_func = nn.max_pool 81 | x = pool_func( 82 | x, 83 | window_shape=(self.pool_sizes[i], self.pool_sizes[i]), 84 | strides=(self.pool_strides[i], self.pool_strides[i]), 85 | padding=((self.pool_padding[i], self.pool_padding[i]),) * 2, 86 | ) 87 | x = nn.relu(x) 88 | 89 | if self.bottleneck_dim is not None: 90 | x = nn.Dense(self.bottleneck_dim, kernel_init=orthogonal_init())(x) 91 | x = nn.LayerNorm()(x) 92 | x = nn.tanh(x) 93 | 94 | return x.reshape((*x.shape[:-3], -1)) 95 | 96 | 97 | small_configs = {"atari": AtariEncoder, "small": SmallEncoder} 98 | -------------------------------------------------------------------------------- /jaxrl_m/vision/bigvision_common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Big Vision Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
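# Usage sketch for `merge_params` below (added commentary; the parameter names are
# made up, and the "/"-joined naming follows big_vision's tree_flatten_with_names
# convention):
#
#   loaded = {"encoder": {"kernel": ckpt_kernel}}                       # from a checkpoint
#   inited = {"encoder": {"kernel": init_kernel}, "head": {"kernel": init_head}}
#   merged = merge_params(loaded, inited, dont_load=("head/.*",))
#   # -> "encoder/kernel" is taken from the checkpoint, while "head/kernel" keeps
#   #    its init value instead of raising a mismatch error.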
14 | 15 | """Utilities shared across models.""" 16 | 17 | from absl import logging 18 | import jaxrl_m.vision.bigvision_utils as u 19 | import flax.linen as nn 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | 24 | def merge_params(loaded, inited, dont_load=()): 25 | """Makes `loaded` pytree match `init`, warning or failing on mismatch. 26 | 27 | Args: 28 | loaded: pytree of parameters, typically loaded from a checkpoint. 29 | inited: pytree of parameter, typically coming from model init. 30 | dont_load: List of regexes for parameters which shall not be taken 31 | from `loaded`, either because they should remain at their init value, 32 | or because they are missing on either side. 33 | 34 | Returns: 35 | If successful, a new pytree which matches the structure of `init` 36 | but contains values from `loaded`, except for `dont_load`. 37 | 38 | If structures don't match and mismatches are not covered by regexes in 39 | `dont_load` argument, then raises an exception with more information. 40 | """ 41 | dont_load = u.check_and_compile_patterns(dont_load) 42 | 43 | def should_merge(name): 44 | return not any(pattern.fullmatch(name) for pattern in dont_load) 45 | 46 | loaded_flat, _ = u.tree_flatten_with_names(loaded) 47 | inited_flat, _ = u.tree_flatten_with_names(inited) 48 | loaded_flat = {k: v for k, v in loaded_flat} 49 | inited_flat = {k: v for k, v in inited_flat} 50 | 51 | # Let's first build the pytree from all common keys. 52 | merged = {} 53 | for name, init_val in inited_flat.items(): 54 | # param is present in both. Load or ignore it! 55 | if name in loaded_flat and should_merge(name): 56 | merged[name] = loaded_flat[name] 57 | else: 58 | logging.info("Ignoring checkpoint and using init value for %s", name) 59 | merged[name] = init_val 60 | 61 | def pp(title, names, indent=" "): # Just pretty-printing 62 | if names: 63 | return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names)) 64 | else: 65 | return "" 66 | 67 | # Now, if there are keys that only exist in inited or loaded, be helpful: 68 | not_in_loaded = inited_flat.keys() - loaded_flat.keys() 69 | not_in_inited = loaded_flat.keys() - inited_flat.keys() 70 | logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded)) 71 | logging.info(pp("Parameters in checkpoint but not in model", not_in_inited)) 72 | 73 | # And now see if any of them are not explicitly ignored => an error 74 | not_in_loaded = {k for k in not_in_loaded if should_merge(k)} 75 | not_in_inited = {k for k in not_in_inited if should_merge(k)} 76 | 77 | if not_in_loaded or not_in_inited: 78 | raise ValueError( 79 | pp("Params in checkpoint", loaded_flat.keys()) 80 | + "\n" 81 | + pp("Params in model (code)", inited_flat.keys()) 82 | + "\n" 83 | + pp( 84 | "Params in model (code) but not in checkpoint and not `dont_load`ed", 85 | not_in_loaded, 86 | indent=" - ", 87 | ) 88 | + "\n" 89 | + pp( # Special indent for tests. 90 | "Params in checkpoint but not in model (code) and not `dont_load`ed", 91 | not_in_inited, 92 | indent=" + ", 93 | ) 94 | ) # Special indent for tests. 95 | 96 | return u.recover_tree(merged.keys(), merged.values()) 97 | -------------------------------------------------------------------------------- /jaxrl_m/envs/wrappers/dmcgym.py: -------------------------------------------------------------------------------- 1 | # Taken from 2 | # https://github.com/denisyarats/dmc2gym 3 | # and modified to exclude duplicated code. 
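# Usage sketch (added for illustration; assumes dm_control is installed, which is
# not pinned in the conda environments in this repo):
#
#   from dm_control import suite
#   env = DMCGYM(suite.load("cartpole", "swingup"))
#   obs, info = env.reset()
#   obs, reward, done, trunc, info = env.step(env.action_space.sample())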
4 | 5 | import copy 6 | from typing import OrderedDict 7 | 8 | import dm_env 9 | import gym 10 | import numpy as np 11 | from gym import spaces 12 | 13 | 14 | def filter_info(obs): 15 | new_obs = {} 16 | info = {} 17 | for k, v in obs.items(): 18 | if "info" in k: 19 | info_key = k.split("/")[-1] 20 | info[info_key] = v 21 | else: 22 | new_obs[k] = v 23 | return new_obs, info 24 | 25 | 26 | def dmc_spec2gym_space(spec): 27 | if isinstance(spec, OrderedDict) or isinstance(spec, dict): 28 | new_spec = OrderedDict() 29 | for k, v in spec.items(): 30 | if "info" in k: 31 | continue 32 | new_spec[k] = dmc_spec2gym_space(v) 33 | return spaces.Dict(new_spec) 34 | elif isinstance(spec, dm_env.specs.BoundedArray): 35 | low = np.broadcast_to(spec.minimum, spec.shape) 36 | high = np.broadcast_to(spec.maximum, spec.shape) 37 | return spaces.Box(low=low, high=high, shape=spec.shape, dtype=spec.dtype) 38 | elif isinstance(spec, dm_env.specs.Array): 39 | if np.issubdtype(spec.dtype, np.integer): 40 | low = np.iinfo(spec.dtype).min 41 | high = np.iinfo(spec.dtype).max 42 | elif np.issubdtype(spec.dtype, np.inexact): 43 | low = float("-inf") 44 | high = float("inf") 45 | else: 46 | raise ValueError() 47 | 48 | return spaces.Box(low=low, high=high, shape=spec.shape, dtype=spec.dtype) 49 | else: 50 | raise NotImplementedError 51 | 52 | 53 | def dmc_obs2gym_obs(obs): 54 | if isinstance(obs, OrderedDict) or isinstance(obs, dict): 55 | obs = copy.copy(obs) 56 | for k, v in obs.items(): 57 | obs[k] = dmc_obs2gym_obs(v) 58 | return obs 59 | else: 60 | return np.asarray(obs) 61 | 62 | 63 | class DMCGYM(gym.core.Env): 64 | metadata = {"render_modes": ["rgb_array"]} 65 | 66 | def __init__(self, env: dm_env.Environment): 67 | self._env = env 68 | 69 | self.action_space = dmc_spec2gym_space(self._env.action_spec()) 70 | 71 | self.observation_space = dmc_spec2gym_space(self._env.observation_spec()) 72 | 73 | self.viewer = None 74 | 75 | def _get_viewer(self): 76 | if self.viewer is None: 77 | from gym.envs.mujoco.mujoco_rendering import Viewer 78 | 79 | self.viewer = Viewer( 80 | self._env.physics.model.ptr, self._env.physics.data.ptr 81 | ) 82 | return self.viewer 83 | 84 | def __getattr__(self, name): 85 | return getattr(self._env, name) 86 | 87 | def seed(self, seed: int): 88 | if hasattr(self._env, "random_state"): 89 | self._env.random_state.seed(seed) 90 | else: 91 | self._env.task.random.seed(seed) 92 | 93 | def step(self, action: np.ndarray): 94 | action = np.clip(action, self.action_space.low, self.action_space.high) 95 | 96 | time_step = self._env.step(action) 97 | reward = time_step.reward 98 | done = time_step.last() 99 | obs, info = filter_info(time_step.observation) 100 | 101 | trunc = done and time_step.discount == 1.0 102 | 103 | return dmc_obs2gym_obs(obs), reward, done, trunc, info 104 | 105 | def reset(self): 106 | time_step = self._env.reset() 107 | obs, info = filter_info(time_step.observation) 108 | return dmc_obs2gym_obs(obs), info 109 | 110 | def render( 111 | self, mode="rgb_array", height: int = 128, width: int = 128, camera_id: int = 0 112 | ): 113 | assert mode in ["human", "rgb_array"], ( 114 | "only support rgb_array and human mode, given %s" % mode 115 | ) 116 | if mode == "rgb_array": 117 | return self._env.physics.render( 118 | height=height, width=width, camera_id=camera_id 119 | ) 120 | elif mode == "human": 121 | self._get_viewer().render() 122 | -------------------------------------------------------------------------------- /jaxrl_m/vision/impala.py: 
-------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | 5 | def default_init(scale: float = jnp.sqrt(2)): 6 | return nn.initializers.orthogonal(scale) 7 | 8 | 9 | def xavier_init(): 10 | return nn.initializers.xavier_normal() 11 | 12 | 13 | def kaiming_init(): 14 | return nn.initializers.kaiming_normal() 15 | 16 | 17 | class ResnetStack(nn.Module): 18 | num_ch: int 19 | num_blocks: int 20 | use_max_pooling: bool = True 21 | 22 | @nn.compact 23 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 24 | initializer = nn.initializers.xavier_uniform() 25 | conv_out = nn.Conv( 26 | features=self.num_ch, 27 | kernel_size=(3, 3), 28 | strides=1, 29 | kernel_init=initializer, 30 | padding="SAME", 31 | )(observations) 32 | 33 | if self.use_max_pooling: 34 | conv_out = nn.max_pool( 35 | conv_out, window_shape=(3, 3), padding="SAME", strides=(2, 2) 36 | ) 37 | 38 | for _ in range(self.num_blocks): 39 | block_input = conv_out 40 | conv_out = nn.relu(conv_out) 41 | conv_out = nn.Conv( 42 | features=self.num_ch, 43 | kernel_size=(3, 3), 44 | strides=1, 45 | padding="SAME", 46 | kernel_init=initializer, 47 | )(conv_out) 48 | 49 | conv_out = nn.relu(conv_out) 50 | conv_out = nn.Conv( 51 | features=self.num_ch, 52 | kernel_size=(3, 3), 53 | strides=1, 54 | padding="SAME", 55 | kernel_init=initializer, 56 | )(conv_out) 57 | conv_out += block_input 58 | 59 | return conv_out 60 | 61 | 62 | class ImpalaEncoder(nn.Module): 63 | width: int = 1 64 | use_multiplicative_cond: bool = False 65 | stack_sizes: tuple = (16, 32, 32) 66 | num_blocks: int = 2 67 | dropout_rate: float = None 68 | 69 | def setup(self): 70 | stack_sizes = self.stack_sizes 71 | self.stack_blocks = [ 72 | ResnetStack(num_ch=stack_sizes[i] * self.width, num_blocks=self.num_blocks) 73 | for i in range(len(stack_sizes)) 74 | ] 75 | if self.dropout_rate is not None: 76 | self.dropout = nn.Dropout(rate=self.dropout_rate) 77 | 78 | @nn.compact 79 | def __call__(self, x, train=True, cond_var=None): 80 | x = x.astype(jnp.float32) / 255.0 81 | # x = jnp.reshape(x, (*x.shape[:-2], -1)) 82 | 83 | conv_out = x 84 | 85 | for idx in range(len(self.stack_blocks)): 86 | conv_out = self.stack_blocks[idx](conv_out) 87 | if self.dropout_rate is not None: 88 | conv_out = self.dropout(conv_out, deterministic=not train) 89 | if self.use_multiplicative_cond: 90 | assert cond_var is not None, "Cond var shouldn't be done when using it" 91 | print("Using Multiplicative Cond!") 92 | temp_out = nn.Dense(conv_out.shape[-1], kernel_init=xavier_init())( 93 | cond_var 94 | ) 95 | x_mult = jnp.expand_dims(jnp.expand_dims(temp_out, 1), 1) 96 | print("x_mult shape in IMPALA:", x_mult.shape, conv_out.shape) 97 | conv_out = conv_out * x_mult 98 | 99 | conv_out = nn.relu(conv_out) 100 | # print(conv_out.shape, conv_out.reshape((*x.shape[:-3], -1)).shape) 101 | return conv_out.reshape((*x.shape[:-3], -1)) 102 | 103 | 104 | import functools as ft 105 | 106 | impala_configs = { 107 | "impala": ImpalaEncoder, 108 | "impala_large": ft.partial(ImpalaEncoder, stack_sizes=(16, 32, 32, 32)), 109 | "impala_larger": ft.partial(ImpalaEncoder, stack_sizes=(16, 32, 32, 32, 32)), 110 | "impala_largest": ft.partial(ImpalaEncoder, stack_sizes=(16, 32, 32, 32, 32, 32)), 111 | "impala_wider": ft.partial(ImpalaEncoder, width=2), 112 | "impala_widest": ft.partial(ImpalaEncoder, width=4), 113 | "impala_deeper": ft.partial(ImpalaEncoder, num_blocks=4), 114 | "impala_deepest": ft.partial(ImpalaEncoder, 
num_blocks=8), 115 | } 116 | -------------------------------------------------------------------------------- /experiments/eval_visualize_roboverse.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from absl import app, flags 3 | from jaxrl_m.agents.continuous.gc_iql import create_iql_learner 4 | from jaxrl_m.vision import encoders as vision_encoders 5 | from jaxrl_m.envs.bridge import visualization 6 | from flax.training import checkpoints 7 | import numpy as np 8 | import wandb 9 | import matplotlib.pyplot as plt 10 | import os 11 | import gym 12 | import roboverse 13 | import glob 14 | import h5py 15 | from collections import defaultdict 16 | 17 | FLAGS = flags.FLAGS 18 | flags.DEFINE_string( 19 | "checkpoint_path", None, "Path to checkpoint to resume from.", required=True 20 | ) 21 | flags.DEFINE_string( 22 | "wandb_run_name", None, "Name of wandb run to resume from.", required=True 23 | ) 24 | flags.DEFINE_string("data_path", None, "Location of dataset", required=True) 25 | flags.DEFINE_integer("demo_id", None, "ID of demo to visualize.", required=True) 26 | 27 | 28 | def get_demo(): 29 | paths = glob.glob(f"{FLAGS.data_path}/train/*.hdf5") 30 | path = paths[0] 31 | specs = defaultdict(list) 32 | with h5py.File(path, "r") as f: 33 | file_len = len(f["actions"]) 34 | start = FLAGS.demo_id - (49 - f["steps_remaining"][FLAGS.demo_id]) 35 | end = FLAGS.demo_id + f["steps_remaining"][FLAGS.demo_id] + 1 36 | action_metadata = np.load( 37 | os.path.join(FLAGS.data_path, "train/metadata.npy"), allow_pickle=True 38 | ).item() 39 | actions = (f["actions"][start:end] - action_metadata["mean"]) / action_metadata[ 40 | "std" 41 | ] 42 | demo_batched = { 43 | "observations": { 44 | "image": f["observations/images0"][start:end], 45 | "proprio": f["observations/state"][start:end].astype(np.float32), 46 | }, 47 | "next_observations": { 48 | "image": f["next_observations/images0"][start:end], 49 | "proprio": f["next_observations/state"][start:end].astype(np.float32), 50 | }, 51 | "actions": actions, 52 | "goals": { 53 | "image": np.array( 54 | [f["observations/images0"][end - 1] for _ in range(len(actions))] 55 | ), 56 | }, 57 | "rewards": np.array( 58 | [0 if i == len(actions) - 1 else -1 for i in range(len(actions))] 59 | ), 60 | "masks": np.ones(len(actions)), 61 | } 62 | 63 | return demo_batched 64 | 65 | 66 | def main(_): 67 | # restore agent 68 | api = wandb.Api() 69 | run = api.run(FLAGS.wandb_run_name) 70 | 71 | assert run.config["model_constructor"] == "create_iql_learner" 72 | model_config = run.config["model_config"] 73 | encoder_def = vision_encoders[model_config["encoder"]]( 74 | **model_config["encoder_kwargs"] 75 | ) 76 | agent = create_iql_learner( 77 | seed=0, 78 | encoder_def=encoder_def, 79 | observations=np.zeros((128, 128, 3), dtype=np.uint8), 80 | goals=np.zeros((128, 128, 3), dtype=np.uint8), 81 | actions=np.zeros(7, dtype=np.float32), 82 | **model_config["agent_kwargs"], 83 | ) 84 | params = checkpoints.restore_checkpoint(FLAGS.checkpoint_path, target=None)[ 85 | "model" 86 | ]["params"] 87 | agent = agent.replace(model=agent.model.replace(params=params)) 88 | 89 | demo_batched = get_demo() 90 | 91 | # run inference 92 | metrics = agent.get_debug_metrics(demo_batched) 93 | 94 | # create visualization 95 | what_to_visualize = [ 96 | partial(visualization.visualize_metric, metric_name="mse"), 97 | partial(visualization.visualize_metric, metric_name="v"), 98 | partial(visualization.visualize_metric, 
metric_name="target_q"), 99 | partial(visualization.visualize_metric, metric_name="advantage"), 100 | partial(visualization.visualize_metric, metric_name="value_err"), 101 | partial(visualization.visualize_metric, metric_name="td_err"), 102 | ] 103 | image = visualization.make_visual( 104 | demo_batched["observations"]["image"], 105 | metrics, 106 | what_to_visualize=what_to_visualize, 107 | ) 108 | plt.imsave("visualization.png", image) 109 | 110 | 111 | if __name__ == "__main__": 112 | app.run(main) 113 | -------------------------------------------------------------------------------- /jaxrl_m/data/tf_augmentations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from collections.abc import Mapping 3 | from ml_collections import ConfigDict 4 | 5 | 6 | def random_resized_crop(image, scale, ratio, seed): 7 | if len(tf.shape(image)) == 3: 8 | image = tf.expand_dims(image, axis=0) 9 | batch_size = tf.shape(image)[0] 10 | # taken from https://keras.io/examples/vision/nnclr/#random-resized-crops 11 | log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1])) 12 | height = tf.shape(image)[1] 13 | width = tf.shape(image)[2] 14 | 15 | random_scales = tf.random.stateless_uniform((batch_size,), seed, scale[0], scale[1]) 16 | random_ratios = tf.exp( 17 | tf.random.stateless_uniform((batch_size,), seed, log_ratio[0], log_ratio[1]) 18 | ) 19 | 20 | new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1) 21 | new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1) 22 | height_offsets = tf.random.stateless_uniform( 23 | (batch_size,), seed, 0, 1 - new_heights 24 | ) 25 | width_offsets = tf.random.stateless_uniform((batch_size,), seed, 0, 1 - new_widths) 26 | 27 | bounding_boxes = tf.stack( 28 | [ 29 | height_offsets, 30 | width_offsets, 31 | height_offsets + new_heights, 32 | width_offsets + new_widths, 33 | ], 34 | axis=1, 35 | ) 36 | 37 | image = tf.image.crop_and_resize( 38 | image, bounding_boxes, tf.range(batch_size), (height, width) 39 | ) 40 | 41 | return tf.squeeze(image) 42 | 43 | 44 | AUGMENT_OPS = { 45 | "random_resized_crop": random_resized_crop, 46 | "random_brightness": tf.image.stateless_random_brightness, 47 | "random_contrast": tf.image.stateless_random_contrast, 48 | "random_saturation": tf.image.stateless_random_saturation, 49 | "random_hue": tf.image.stateless_random_hue, 50 | "random_flip_left_right": tf.image.stateless_random_flip_left_right, 51 | } 52 | 53 | 54 | def augment(image, seed, **augment_kwargs): 55 | image = tf.cast(image, tf.float32) / 255 # convert images to [0, 1] 56 | for op in augment_kwargs["augment_order"]: 57 | if op in augment_kwargs: 58 | if isinstance(augment_kwargs[op], Mapping) or isinstance( 59 | augment_kwargs[op], ConfigDict 60 | ): 61 | image = AUGMENT_OPS[op](image, seed=seed, **augment_kwargs[op]) 62 | else: 63 | image = AUGMENT_OPS[op](image, seed=seed, *augment_kwargs[op]) 64 | else: 65 | image = AUGMENT_OPS[op](image, seed=seed) 66 | image = tf.clip_by_value(image, 0, 1) 67 | image = tf.cast(image * 255, tf.uint8) 68 | return image 69 | 70 | 71 | def augment_batch(images, seed, **augment_kwargs): 72 | batch_size = tf.shape(images)[0] 73 | sub_seeds = [seed] 74 | for _ in range(batch_size): 75 | sub_seeds.append( 76 | tf.random.stateless_uniform( 77 | [2], 78 | seed=sub_seeds[-1], 79 | minval=None, 80 | maxval=None, 81 | dtype=tf.int32, 82 | ) 83 | ) 84 | images = tf.cast(images, tf.float32) / 255 # convert images to [0, 1] 85 | for op in 
augment_kwargs["augment_order"]: 86 | if op in augment_kwargs: 87 | if isinstance(augment_kwargs[op], Mapping) or isinstance( 88 | augment_kwargs[op], ConfigDict 89 | ): 90 | # this is random_resized_crop which can handle batches 91 | assert op == "random_resized_crop" 92 | images = AUGMENT_OPS[op](images, seed=seed, **augment_kwargs[op]) 93 | else: 94 | images_list = [] 95 | for i in range(batch_size): 96 | images_list.append( 97 | AUGMENT_OPS[op]( 98 | images[i], seed=sub_seeds[i], *augment_kwargs[op] 99 | ) 100 | ) 101 | images = tf.stack(images_list) 102 | else: 103 | images_list = [] 104 | for i in range(batch_size): 105 | images_list.append(AUGMENT_OPS[op](images[i], seed=sub_seeds[i])) 106 | images = tf.stack(images_list) 107 | images = tf.clip_by_value(images, 0, 1) 108 | images = tf.cast(images * 255, tf.uint8) 109 | return images 110 | -------------------------------------------------------------------------------- /jaxrl_m/data/language.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import json 3 | 4 | lang_to_code = {} 5 | code_to_lang = {} 6 | NONE = -1 7 | 8 | 9 | def load_mapping(path, constructor=dict, augmented=False): 10 | global lang_to_code, code_to_lang 11 | if not augmented: 12 | encode_path = tf.io.gfile.join(path, "language_encodings.json") 13 | decode_path = tf.io.gfile.join(path, "language_decodings.json") 14 | else: 15 | encode_path = tf.io.gfile.join(path, "language_encodings_aug.json") 16 | decode_path = tf.io.gfile.join(path, "language_decodings_aug.json") 17 | lang_to_code = constructor(json.loads(tf.io.gfile.GFile(encode_path, "r").read())) 18 | code_to_lang = constructor( 19 | { 20 | int(k): v 21 | for k, v in json.loads(tf.io.gfile.GFile(decode_path, "r").read()).items() 22 | } 23 | ) 24 | 25 | 26 | def flush_mapping(path): 27 | encode_path = tf.io.gfile.join(path, "language_encodings.json") 28 | decode_path = tf.io.gfile.join(path, "language_decodings.json") 29 | with tf.io.gfile.GFile(encode_path, "w") as f: 30 | f.write(json.dumps(dict(lang_to_code))) 31 | with tf.io.gfile.GFile(decode_path, "w") as f: 32 | f.write(json.dumps(dict(code_to_lang))) 33 | 34 | 35 | def lang_encode(lang): 36 | if not lang: 37 | return NONE 38 | elif lang in lang_to_code: 39 | return lang_to_code[lang] 40 | else: 41 | code = len(lang_to_code) 42 | lang_to_code[lang] = code 43 | code_to_lang[code] = lang 44 | return code 45 | 46 | 47 | import numpy as np 48 | 49 | rng = np.random.RandomState(0) 50 | 51 | 52 | def lang_decode(code): 53 | global rng 54 | if code == NONE: 55 | return None 56 | text = code_to_lang[code] 57 | choices = text.split("\n") 58 | return rng.choice(choices) if rng else choices[0] 59 | 60 | 61 | def lang_encodings(): 62 | return code_to_lang 63 | 64 | 65 | # This will query gpt to generate more paraphrases 66 | if __name__ == "__main__": 67 | # parse args 68 | import argparse 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--num_paraphrases", type=int, default=5) 72 | parser.add_argument("--path", type=str, default="gs://rail-tpus-andre/new_tf") 73 | args = parser.parse_args() 74 | print(args) 75 | 76 | # load mapping 77 | path = args.path 78 | load_mapping(path) 79 | has_variants = [lang for lang in lang_to_code if "\n" in lang] 80 | print( 81 | f"Loaded {len(lang_encodings())} languages, {len(has_variants)} have variants" 82 | ) 83 | 84 | PROMPT = "Generate %d variations of the following command: %s\nNumber them like 1. 2. 
3.\nBe concise and use synonyms.\n" 85 | import openai 86 | import tqdm 87 | 88 | new_code_to_lang = {} 89 | new_lang_to_code = {} 90 | for code, lang in tqdm.tqdm(code_to_lang.items()): 91 | langs = lang.split("\n") 92 | langs = [l if l.endswith(".") else l + "." for l in langs] 93 | prompt = PROMPT % (5, langs[0]) 94 | if len(langs) > 1: 95 | prompt += "Start with:\n" 96 | for i, l in enumerate(langs): 97 | prompt += f"{i+1}. {l}\n" 98 | # print('\n\n') 99 | # print(prompt) 100 | 101 | response = openai.ChatCompletion.create( 102 | model="gpt-3.5-turbo", 103 | messages=[ 104 | {"role": "system", "content": prompt}, 105 | ], 106 | ) 107 | # print('\n') 108 | response = response["choices"][0]["message"]["content"] 109 | # print(response) 110 | # break 111 | try: 112 | new_langs = response.split("\n") 113 | new_langs = [l[3:] for l in new_langs] 114 | except: 115 | print("Error parsing response") 116 | new_langs = [] 117 | 118 | new_lang = "\n".join(langs + new_langs) 119 | print("all variations:") 120 | print(new_lang) 121 | 122 | new_code_to_lang[code] = new_lang 123 | new_lang_to_code[new_lang] = code 124 | 125 | encode_path = tf.io.gfile.join(path, "language_encodings_aug.json") 126 | decode_path = tf.io.gfile.join(path, "language_decodings_aug.json") 127 | with tf.io.gfile.GFile(encode_path, "w") as f: 128 | f.write(json.dumps(dict(new_lang_to_code))) 129 | with tf.io.gfile.GFile(decode_path, "w") as f: 130 | f.write(json.dumps(dict(new_code_to_lang))) 131 | -------------------------------------------------------------------------------- /scripts/ss2_frames_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | from absl import app, flags, logging 2 | from PIL import Image 3 | from multiprocessing import Pool, Manager 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import glob 8 | import os 9 | import json 10 | import tqdm 11 | 12 | FLAGS = flags.FLAGS 13 | flags.DEFINE_string( 14 | "input_path", "/nfs/kun2/users/andrehe/SS2", "Path to input directory" 15 | ) 16 | flags.DEFINE_string( 17 | "labels_path", "/nfs/kun2/users/andrehe/SS2_labels", "Path to labels json files" 18 | ) 19 | flags.DEFINE_string( 20 | "output_path", 21 | "gs://rail-tpus-andre/something-something/tf_fixed", 22 | "Path to output directory", 23 | ) 24 | flags.DEFINE_bool("overwrite", True, "Overwrite existing files") 25 | flags.DEFINE_integer("num_workers", 8, "Number of threads to use") 26 | flags.DEFINE_integer("chunk_size", 100, "Number of videos per tfrecord") 27 | 28 | 29 | def tensor_feature(value): 30 | return tf.train.Feature( 31 | bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()]) 32 | ) 33 | 34 | 35 | def process(args): 36 | video_infos, split, action_class = args 37 | 38 | outpath = os.path.join(FLAGS.output_path, split, action_class.replace(" ", "_")) 39 | # outpath = f"{outpath}/" 40 | 41 | if tf.io.gfile.exists(outpath): 42 | if FLAGS.overwrite: 43 | logging.info(f"Deleting {outpath}") 44 | try: 45 | tf.io.gfile.rmtree(outpath) 46 | except: 47 | pass 48 | else: 49 | logging.info(f"Skipping {outpath}") 50 | return 51 | 52 | tf.io.gfile.makedirs(os.path.dirname(outpath)) 53 | 54 | for i in range(0, len(video_infos), FLAGS.chunk_size): 55 | chunk = video_infos[i : min(i + FLAGS.chunk_size, len(video_infos))] 56 | tf_path = os.path.join(outpath, f"{i // FLAGS.chunk_size}.tfrecord") 57 | with tf.io.TFRecordWriter(tf_path) as writer: 58 | for video_info in chunk: 59 | video_path = os.path.join(FLAGS.input_path, 
video_info["id"]) 60 | 61 | frame_paths = tf.io.gfile.glob(os.path.join(video_path, "*.png")) 62 | if len(frame_paths) == 0: 63 | logging.info(f"Skipping {video_path}, empty") 64 | continue 65 | 66 | try: 67 | # read frames 68 | frames = [ 69 | np.array(Image.open(frame_path).resize((128, 128))) 70 | for frame_path in frame_paths 71 | ] 72 | except: 73 | logging.info(f"Skipping {video_path}, error reading frames") 74 | continue 75 | 76 | # use this to find caption 77 | video_id = int(video_info["id"]) 78 | 79 | example = tf.train.Example( 80 | features=tf.train.Features( 81 | feature={ 82 | "frames": tensor_feature(np.array(frames, dtype=np.uint8)), 83 | "task_id": tensor_feature( 84 | (np.ones(len(frames)) * video_id).astype(np.uint32) 85 | ), 86 | } 87 | ) 88 | ) 89 | writer.write(example.SerializeToString()) 90 | logging.info(f"Processed {video_path}") 91 | 92 | 93 | def main(_): 94 | with open(os.path.join(FLAGS.labels_path, "labels.json")) as f: 95 | labels = json.load(f) 96 | classes = labels.keys() 97 | 98 | map_inputs = [] 99 | for split in ["train", "validation"]: 100 | # for split in ["validation"]: 101 | with open(os.path.join(FLAGS.labels_path, f"{split}.json")) as f: 102 | metadata = json.load(f) 103 | 104 | for action_class in classes: 105 | video_infos = [ 106 | vid_info 107 | for vid_info in metadata 108 | if vid_info["template"].replace("[", "").replace("]", "") 109 | == action_class 110 | ] 111 | map_inputs.append((video_infos, split, action_class)) 112 | 113 | with Pool(FLAGS.num_workers) as p: 114 | list( 115 | tqdm.tqdm( 116 | p.imap(process, map_inputs), 117 | total=len(map_inputs), 118 | ) 119 | ) 120 | 121 | 122 | if __name__ == "__main__": 123 | app.run(main) 124 | -------------------------------------------------------------------------------- /jaxrl_m/envs/wrappers/roboverse.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | import gym 3 | import numpy as np 4 | 5 | 6 | def convert_obs(obs, img_dim): 7 | return { 8 | "image": (obs["image"].reshape(img_dim, img_dim, 3) * 255).astype(np.uint8), 9 | "proprio": obs["state"].astype(np.float32), 10 | } 11 | 12 | 13 | class RoboverseWrapper(gym.Wrapper): 14 | def __init__(self, env: gym.Env): 15 | super().__init__(env) 16 | self.env = env 17 | self.observation_space = gym.spaces.Dict( 18 | { 19 | "image": gym.spaces.Box( 20 | low=np.zeros((env.observation_img_dim, env.observation_img_dim, 3)), 21 | high=255 22 | * np.ones((env.observation_img_dim, env.observation_img_dim, 3)), 23 | dtype=np.uint8, 24 | ), 25 | "proprio": gym.spaces.Box( 26 | low=np.zeros((10,)), 27 | high=np.ones((10,)), 28 | dtype=np.uint8, 29 | ), 30 | } 31 | ) 32 | 33 | def step(self, *args): 34 | obs, reward, done, info = self.env.step(*args) 35 | return ( 36 | convert_obs(obs, self.env.observation_img_dim), 37 | reward, 38 | False, 39 | done, 40 | info, 41 | ) 42 | 43 | def seed(self, seed): 44 | pass 45 | 46 | def render(self, *args, **kwargs): 47 | return self.env.render_obs() 48 | 49 | def reset(self, **kwargs): 50 | obs = convert_obs(self.env.reset(), self.env.observation_img_dim) 51 | return obs, self.env.get_info() 52 | 53 | 54 | class GCRoboverseWrapper(gym.Wrapper): 55 | def __init__( 56 | self, 57 | env: gym.Env, 58 | goal_sampler: Union[np.ndarray, Callable], 59 | ): 60 | 61 | super().__init__(env) 62 | self.env = env 63 | self.observation_space = gym.spaces.Dict( 64 | { 65 | "image": gym.spaces.Box( 66 | low=np.zeros((env.observation_img_dim, 
env.observation_img_dim, 3)), 67 | high=255 68 | * np.ones((env.observation_img_dim, env.observation_img_dim, 3)), 69 | dtype=np.uint8, 70 | ), 71 | "proprio": gym.spaces.Box( 72 | low=np.zeros((10,)), 73 | high=np.ones((10,)), 74 | dtype=np.uint8, 75 | ), 76 | } 77 | ) 78 | self.current_goal = None 79 | self.goal_sampler = goal_sampler 80 | 81 | def step(self, *args): 82 | obs, reward, done, info = self.env.step(*args) 83 | return ( 84 | convert_obs(obs, self.env.observation_img_dim), 85 | reward, 86 | False, 87 | done, 88 | {"goal": self.current_goal, **info}, 89 | ) 90 | 91 | def seed(self, seed): 92 | pass 93 | 94 | def render(self, *args, **kwargs): 95 | return self.env.render_obs() 96 | 97 | def reset(self, **kwargs): 98 | 99 | if not callable(self.goal_sampler): 100 | idx = np.random.randint(len(self.goal_sampler["observations"]["image"])) 101 | goal_image = self.goal_sampler["observations"]["image"][idx] 102 | original_object_positions = self.goal_sampler["infos"]["initial_positions"][ 103 | idx 104 | ] 105 | original_object_quats = self.goal_sampler["infos"]["initial_quats"][idx] 106 | target_position = self.goal_sampler["infos"]["target_position"][idx] 107 | object_names = self.goal_sampler["infos"]["object_names"][idx] 108 | target_object = self.goal_sampler["infos"]["target_object"][idx] 109 | obs = self.env.reset( 110 | original_object_positions=original_object_positions, 111 | original_object_quats=original_object_quats, 112 | target_position=target_position, 113 | object_names=object_names, 114 | target_object=target_object, 115 | ) 116 | obs = convert_obs(obs, self.env.observation_img_dim) 117 | else: 118 | obs = self.env.reset() 119 | obs = convert_obs(obs, self.env.observation_img_dim) 120 | goal_image = self.goal_sampler(obs) 121 | 122 | goal = {"image": goal_image} 123 | 124 | self.current_goal = goal 125 | 126 | return obs, {"goal": goal, **self.env.get_info()} 127 | -------------------------------------------------------------------------------- /experiments/eval_checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jaxrl_m.common.evaluation import supply_rng, evaluate_gc 3 | from jaxrl_m.common.wandb import WandBLogger 4 | from jaxrl_m.envs.wrappers.roboverse import GCRoboverseWrapper 5 | from jaxrl_m.envs.wrappers.action_norm import UnnormalizeAction 6 | from jaxrl_m.envs.wrappers.video_recorder import VideoRecorder 7 | from jaxrl_m.vision import encoders as vision_encoders 8 | from jaxrl_m.agents.continuous.gc_bc import GCBCAgent, GCActor, GCAdaptor, Policy 9 | from jaxrl_m.common.common import TrainState 10 | 11 | import gym 12 | import tqdm 13 | import glob 14 | from absl import app, flags 15 | import numpy as np 16 | from flax.training import checkpoints 17 | import roboverse 18 | import wandb 19 | from functools import partial 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | flags.DEFINE_string("checkpoint_dir", "./log/", "Dir. 
with checkpoints") 24 | flags.DEFINE_string("wandb_run", "run", "Name of wandb run to get config") 25 | flags.DEFINE_string("data_path", None, "Location of dataset", required=True) 26 | flags.DEFINE_integer("seed", 42, "Random seed.") 27 | flags.DEFINE_integer("eval_episodes", 10, "Number of episodes used for evaluation.") 28 | flags.DEFINE_integer("eval_interval", 1000, "Eval interval.") 29 | flags.DEFINE_integer("max_steps", int(5e5), "Number of steps.") 30 | flags.DEFINE_boolean("tqdm", True, "Use tqdm progress bar.") 31 | flags.DEFINE_boolean("deterministic_eval", True, "Take mode of action dist. for eval") 32 | flags.DEFINE_boolean("save_video", False, "Save videos during evaluation.") 33 | flags.DEFINE_string("save_dir", "./log/", "Video/buffer logging dir.") 34 | 35 | 36 | def wrap(env: gym.Env, action_metadata: dict): 37 | eval_goals = np.load( 38 | os.path.join(FLAGS.data_path, "val/eval_goals.npy"), allow_pickle=True 39 | ).item() 40 | env = GCRoboverseWrapper(env, eval_goals) 41 | env = UnnormalizeAction(env, action_metadata) 42 | env = gym.wrappers.TimeLimit(env, max_episode_steps=45) 43 | return env 44 | 45 | 46 | def main(_): 47 | api = wandb.Api() 48 | run = api.run(FLAGS.wandb_run) 49 | wandb_config = WandBLogger.get_default_config() 50 | wandb_config.update( 51 | { 52 | "project": "jaxrl_m_roboverse", 53 | "exp_prefix": "gc_roboverse_offline", 54 | "exp_descriptor": f"{run.config['env_name']}", 55 | } 56 | ) 57 | wandb_logger = WandBLogger( 58 | wandb_config=wandb_config, 59 | variant=run.config["model_config"], 60 | ) 61 | 62 | FLAGS.save_dir = os.path.join( 63 | FLAGS.save_dir, 64 | wandb_logger.config.project, 65 | wandb_logger.config.exp_prefix, 66 | f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}", 67 | ) 68 | 69 | action_metadata = np.load( 70 | os.path.join(FLAGS.data_path, "train/metadata.npy"), allow_pickle=True 71 | ).item() 72 | 73 | eval_env = roboverse.make(run.config["env_name"], transpose_image=False) 74 | eval_env = wrap(eval_env, action_metadata) 75 | if FLAGS.save_video: 76 | eval_env = VideoRecorder( 77 | eval_env, os.path.join(FLAGS.save_dir, "videos"), goal_conditioned=True 78 | ) 79 | eval_env.reset(seed=FLAGS.seed + 42) 80 | 81 | encoder_def = vision_encoders[run.config["encoder"]]( 82 | **run.config["model_config"]["encoder_kwargs"] 83 | ) 84 | 85 | encoders = {"actor": encoder_def} 86 | goal_encoders = encoders 87 | 88 | model_def = GCActor( 89 | encoders=encoders, 90 | goal_encoders=goal_encoders, 91 | networks={ 92 | "actor": GCAdaptor( 93 | Policy( 94 | action_dim=7, 95 | **run.config["model_config"]["agent_kwargs"]["actor_kwargs"], 96 | ) 97 | ) 98 | }, 99 | ) 100 | 101 | for i in tqdm.tqdm( 102 | range(0, FLAGS.max_steps, FLAGS.eval_interval), 103 | smoothing=0.1, 104 | disable=not FLAGS.tqdm, 105 | ): 106 | agent = GCBCAgent(TrainState.create(model_def, None, None)) 107 | agent = checkpoints.restore_checkpoint( 108 | os.path.join(FLAGS.checkpoint_dir, f"checkpoint_{i}"), target=agent 109 | ) 110 | 111 | policy_fn = supply_rng( 112 | partial(agent.sample_actions, argmax=FLAGS.deterministic_eval) 113 | ) 114 | eval_info = evaluate_gc(policy_fn, eval_env, num_episodes=FLAGS.eval_episodes) 115 | for k, v in eval_info.items(): 116 | wandb_logger.log({f"evaluation/{k}": v}, step=i) 117 | 118 | 119 | if __name__ == "__main__": 120 | app.run(main) 121 | -------------------------------------------------------------------------------- /jaxrl_m/common/evaluation.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import jax 3 | import gym 4 | import numpy as np 5 | from collections import defaultdict 6 | 7 | 8 | def supply_rng(f, rng=jax.random.PRNGKey(0)): 9 | def wrapped(*args, **kwargs): 10 | nonlocal rng 11 | rng, key = jax.random.split(rng) 12 | return f(*args, seed=key, **kwargs) 13 | 14 | return wrapped 15 | 16 | 17 | def flatten(d, parent_key="", sep="."): 18 | items = [] 19 | for k, v in d.items(): 20 | new_key = parent_key + sep + k if parent_key else k 21 | if hasattr(v, "items"): 22 | items.extend(flatten(v, new_key, sep=sep).items()) 23 | else: 24 | items.append((new_key, v)) 25 | return dict(items) 26 | 27 | 28 | def filter_info(info): 29 | filter_keys = [ 30 | "object_names", 31 | "target_object", 32 | "initial_positions", 33 | "target_position", 34 | "goal", 35 | ] 36 | for k in filter_keys: 37 | if k in info: 38 | del info[k] 39 | return info 40 | 41 | 42 | def add_to(dict_of_lists, single_dict): 43 | for k, v in single_dict.items(): 44 | dict_of_lists[k].append(v) 45 | 46 | 47 | def evaluate(policy_fn, env: gym.Env, num_episodes: int) -> Dict[str, float]: 48 | stats = defaultdict(list) 49 | for _ in range(num_episodes): 50 | observation, info = env.reset() 51 | add_to(stats, flatten(info)) 52 | done = False 53 | while not done: 54 | action = policy_fn(observation) 55 | observation, _, terminated, truncated, info = env.step(action) 56 | done = terminated or truncated 57 | add_to(stats, flatten(info)) 58 | add_to(stats, flatten(info, parent_key="final")) 59 | 60 | for k, v in stats.items(): 61 | stats[k] = np.mean(v) 62 | return stats 63 | 64 | 65 | def evaluate_with_trajectories( 66 | policy_fn, env: gym.Env, num_episodes: int 67 | ) -> Dict[str, float]: 68 | 69 | trajectories = [] 70 | stats = defaultdict(list) 71 | 72 | for _ in range(num_episodes): 73 | trajectory = defaultdict(list) 74 | observation, info = env.reset() 75 | add_to(stats, flatten(info)) 76 | done = False 77 | while not done: 78 | action = policy_fn(observation) 79 | next_observation, r, terminated, truncated, info = env.step(action) 80 | done = terminated or truncated 81 | transition = dict( 82 | observation=observation, 83 | next_observation=next_observation, 84 | action=action, 85 | reward=r, 86 | done=done, 87 | info=info, 88 | ) 89 | add_to(trajectory, transition) 90 | add_to(stats, flatten(info)) 91 | observation = next_observation 92 | add_to(stats, flatten(info, parent_key="final")) 93 | trajectories.append(trajectory) 94 | 95 | for k, v in stats.items(): 96 | stats[k] = np.mean(v) 97 | return stats, trajectories 98 | 99 | 100 | def evaluate_gc( 101 | policy_fn, 102 | env: gym.Env, 103 | num_episodes: int, 104 | return_trajectories: bool = False, 105 | ) -> Dict[str, float]: 106 | 107 | stats = defaultdict(list) 108 | 109 | if return_trajectories: 110 | trajectories = [] 111 | 112 | for _ in range(num_episodes): 113 | if return_trajectories: 114 | trajectory = defaultdict(list) 115 | 116 | observation, info = env.reset() 117 | goal = info["goal"] 118 | add_to(stats, flatten(filter_info(info))) 119 | done = False 120 | 121 | while not done: 122 | action = policy_fn(observation, goal) 123 | next_observation, r, terminated, truncated, info = env.step(action) 124 | goal = info["goal"] 125 | done = terminated or truncated 126 | transition = dict( 127 | observation=observation, 128 | next_observation=next_observation, 129 | goal=goal, 130 | action=action, 131 | reward=r, 132 | done=done, 133 | 
info=info, 134 | ) 135 | 136 | add_to(stats, flatten(filter_info(info))) 137 | 138 | if return_trajectories: 139 | add_to(trajectory, transition) 140 | 141 | observation = next_observation 142 | 143 | add_to(stats, flatten(filter_info(info), parent_key="final")) 144 | if return_trajectories: 145 | trajectory["steps_remaining"] = list( 146 | np.arange(len(trajectory["action"]))[::-1] 147 | ) 148 | trajectories.append(trajectory) 149 | 150 | stats = {k: np.mean(v) for k, v in stats.items() if not isinstance(v[0], str)} 151 | 152 | if return_trajectories: 153 | return stats, trajectories 154 | else: 155 | return stats 156 | -------------------------------------------------------------------------------- /contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We discuss two key abstractions used heavily in this codebase: the use of `TrainState` and the expression of agents as `PytreeNodes` 4 | 5 | ## Agents 6 | 7 | In this codebase, we represent agents as PytreeNodes (first-class Jax citizens), making them really easy to handle. The simplest working example we have in the codebase is probably `jaxrl_m/agents/discrete/bc.py`, so check that out for a concrete implementation. 8 | 9 | The general structure of an Agent is as follows: it contains some number of neural networks, some set of configuration values, and has an update function that takes in a batch and returns a agent with updated parameters after performing some gradient update. Usually there's a `sample_actions` to sample from the resulting policy too. 10 | 11 | ```python 12 | class Agent(flax.struct.PyTreeNode): 13 | value_function: TrainState 14 | policy: TrainState 15 | config: dict = nonpytree_field() # tells Jax to not look at this (usually contains discount factor / target update speed / other hyperparams) 16 | 17 | @jax.jit 18 | def update(self, batch: Batch): 19 | ... 20 | new_value_function = ... 21 | new_policy = ... 22 | info = {'loss': 100} 23 | new_agent = self.replace(value_function=value_function, policy=new_policy) 24 | return new_agent, info 25 | 26 | @jax.jit 27 | def sample_actions(self, observations, *, seed): 28 | actions = ... 29 | return actions 30 | ``` 31 | 32 | ### Multiple Devices 33 | 34 | Operating on multiple GPUs / TPUs is really easy! Check out the section at the bottom of the page as to how to accumulate gradients across all the GPUs. 35 | 36 | 37 | - `flax.jax_utils.replicate()`: replicates an object on all GPUs 38 | - `jaxrl_m.common.common.shard_batch`: splits an batch evenly across all the GPUs 39 | - `flax.jax_utils.unreplicate()` brings back to single GPU 40 | 41 | ```python 42 | agent = ... 43 | batch = ... 44 | 45 | replicated_agent = replicate(agent) 46 | replicated_agent, info = replicated_agent.update(shard_batch(batch)) 47 | info = unreplicate(info) # bring info back to single device 48 | 49 | 50 | ``` 51 | ## TrainState 52 | 53 | 54 | The TrainState class (located at `jaxrl_m.common.common.TrainState`) is a fork of Flax's TrainState class with some additional syntactic features for ease of use. 
55 | 56 | The TrainState class combines a neural network module (`flax.linen.Module`) with a set of parameters for this network (alongside with potentially an optimizer) 57 | 58 | ### Creating a TrainState 59 | 60 | ```python 61 | model_def = nn.Dense(10) # nn.Module 62 | params = model_def.init(rng, x)['params'] # parameters for nn.Module 63 | tx = optax.adam(1e-3) 64 | model = TrainState.create(model_def, params, tx=tx) 65 | ``` 66 | 67 | ### Running the Model 68 | 69 | ```python 70 | model = TrainState.create(...) 71 | y_pred = model(x) 72 | ``` 73 | 74 | In some cases, the neural network module may have several functions; for example, a VAE might have an `.encode(x)` function and a `.decode(z)` function. By default, the `__call__()` method is used, but this can be specified via an argument: 75 | 76 | ```python 77 | z = model(x, method='encode') 78 | x_pred = model(z, method='decode') 79 | ``` 80 | 81 | You can also run the model with a different set of parameters than that bound to the TrainState. This is most commonly done when taking the gradient with respect to model parameters. 82 | 83 | ```python 84 | y_pred = model(x, params=other_params) 85 | ``` 86 | 87 | ```python 88 | def loss(params): 89 | y_pred = model(x, params=params) 90 | return jnp.mean((y - y_pred) ** 2) 91 | 92 | grads = jax.grad(loss)(model.params) 93 | ``` 94 | 95 | ### Optimizing a TrainState 96 | 97 | To update a model (that has a `tx`), we provide two convenience functions: `.apply_gradients` and `.apply_loss_fn` 98 | 99 | `model.apply_gradients` takes in a set of gradients (same shape as parameters) and computes the new set of parameters using optax. 100 | 101 | ```python 102 | def loss(params): 103 | y_pred = model(x, params=params) 104 | return jnp.mean((y - y_pred) ** 2) 105 | 106 | grads = jax.grad(loss)(model.params) 107 | new_model = model.apply_gradients(grads=grads) 108 | ``` 109 | 110 | `model.apply_loss_fn()` is a convenience method that both computes the gradients and runs `.apply_gradients()`. 
111 | 112 | ```python 113 | def loss(params): 114 | y_pred = model(x, params=params) 115 | return jnp.mean((y - y_pred) ** 2) 116 | 117 | new_model = model.apply_loss_fn(loss_fn=loss) 118 | ``` 119 | 120 | If the model is being run across multiple GPUs / TPUs and we wish to aggregate gradients, this can be specified with the `pmap_axis` argument (you can always use jax.lax.pmean as an alternative): 121 | 122 | ```python 123 | @functools.partial(jax.pmap, axis_name='pmap') 124 | def update(model, x, y): 125 | def loss(params): 126 | y_pred = model(x, params=params) 127 | return jnp.mean((y - y_pred) ** 2) 128 | 129 | new_model = model.apply_loss_fn(loss_fn=loss, pmap_axis='pmap') 130 | return new_model 131 | ``` 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /jaxrl_m/utils/sim_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | import gym 3 | import numpy as np 4 | import roboverse 5 | from mujoco_manipulation import tasks 6 | import tensorflow as tf 7 | from jaxrl_m.envs.wrappers.roboverse import GCRoboverseWrapper 8 | from jaxrl_m.envs.wrappers.mujoco import GCMujocoWrapper 9 | from jaxrl_m.envs.wrappers.dmcgym import DMCGYM 10 | from jaxrl_m.envs.wrappers.action_norm import UnnormalizeAction 11 | from jaxrl_m.envs.wrappers.video_recorder import VideoRecorder 12 | 13 | 14 | def wrap_mujoco_gc_env( 15 | env, 16 | max_episode_steps: int, 17 | action_metadata: dict, 18 | goal_sampler: Union[np.ndarray, Callable], 19 | ): 20 | env = DMCGYM(env) 21 | env = GCMujocoWrapper(env, goal_sampler) 22 | env = UnnormalizeAction(env, action_metadata) 23 | env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps) 24 | return env 25 | 26 | 27 | def make_mujoco_gc_env( 28 | env_name: str, 29 | max_episode_steps: int, 30 | action_metadata: dict, 31 | save_video: bool, 32 | save_video_dir: str, 33 | save_video_prefix: str, 34 | goals: Union[np.ndarray, Callable], 35 | ): 36 | env = tasks.load(env_name) 37 | env = wrap_mujoco_gc_env(env, max_episode_steps, action_metadata, goals) 38 | 39 | if save_video: 40 | env = VideoRecorder( 41 | env, 42 | save_folder=save_video_dir, 43 | save_prefix=save_video_prefix, 44 | goal_conditioned=True, 45 | ) 46 | 47 | env.reset() 48 | 49 | return env 50 | 51 | 52 | def wrap_roboverse_gc_env( 53 | env: gym.Env, 54 | max_episode_steps: int, 55 | action_metadata: dict, 56 | goal_sampler: Union[np.ndarray, Callable], 57 | ): 58 | env = GCRoboverseWrapper(env, goal_sampler) 59 | env = UnnormalizeAction(env, action_metadata) 60 | env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps) 61 | return env 62 | 63 | 64 | def make_roboverse_gc_env( 65 | env_name: str, 66 | max_episode_steps: int, 67 | action_metadata: dict, 68 | save_video: bool, 69 | save_video_dir: str, 70 | save_video_prefix: str, 71 | goals: Union[np.ndarray, Callable], 72 | ): 73 | env = roboverse.make(env_name, transpose_image=False) 74 | env = wrap_roboverse_gc_env(env, max_episode_steps, action_metadata, goals) 75 | 76 | if save_video: 77 | env = VideoRecorder( 78 | env, 79 | save_folder=save_video_dir, 80 | save_prefix=save_video_prefix, 81 | goal_conditioned=True, 82 | ) 83 | 84 | env.reset() 85 | 86 | return env 87 | 88 | 89 | PROTO_TYPE_SPEC = { 90 | "observations/images0": tf.uint8, 91 | "observations/state": tf.float32, 92 | "next_observations/images0": tf.uint8, 93 | "next_observations/state": tf.float32, 94 | "actions": tf.float32, 95 | 
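    # Per-step episode-boundary flags (environment termination vs. time-limit truncation).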
"terminals": tf.bool, 96 | "truncates": tf.bool, 97 | "info/place_success": tf.bool, 98 | "info/target_object": tf.uint8, 99 | "info/object_positions": tf.float32, 100 | "info/target_position": tf.float32, 101 | "info/object_names": tf.string, 102 | } 103 | 104 | 105 | def _decode_example(example_proto): 106 | # decode the example proto according to PROTO_TYPE_SPEC 107 | features = { 108 | key: tf.io.FixedLenFeature([], tf.string) for key in PROTO_TYPE_SPEC.keys() 109 | } 110 | parsed_features = tf.io.parse_single_example(example_proto, features) 111 | parsed_tensors = { 112 | key: tf.io.parse_tensor(parsed_features[key], dtype) 113 | for key, dtype in PROTO_TYPE_SPEC.items() 114 | } 115 | 116 | return { 117 | "observations": { 118 | "image": parsed_tensors["observations/images0"], 119 | "proprio": parsed_tensors["observations/state"], 120 | }, 121 | "next_observations": { 122 | "image": parsed_tensors["next_observations/images0"], 123 | "proprio": parsed_tensors["next_observations/state"], 124 | }, 125 | "actions": parsed_tensors["actions"], 126 | "terminals": parsed_tensors["terminals"], 127 | "truncates": parsed_tensors["truncates"], 128 | "infos": { 129 | "place_success": parsed_tensors["info/place_success"], 130 | "object_positions": parsed_tensors["info/object_positions"], 131 | "target_position": parsed_tensors["info/target_position"], 132 | "target_object": parsed_tensors["info/target_object"], 133 | "object_names": parsed_tensors["info/object_names"], 134 | }, 135 | } 136 | 137 | 138 | def load_tf_dataset(data_path): 139 | """Load a sim dataset in TFRecord format.""" 140 | data_paths = tf.io.gfile.glob(tf.io.gfile.join(data_path, "*.tfrecord")) 141 | 142 | # shuffle again using the dataset API so the files are read in a 143 | # different order every epoch 144 | dataset = tf.data.Dataset.from_tensor_slices(data_paths) 145 | 146 | # yields raw serialized examples 147 | dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE) 148 | 149 | # yields trajectories 150 | dataset = dataset.map(_decode_example, num_parallel_calls=tf.data.AUTOTUNE) 151 | 152 | return dataset 153 | -------------------------------------------------------------------------------- /jaxrl_m/agents/continuous/affordance.py: -------------------------------------------------------------------------------- 1 | """Affordance model.""" 2 | import functools 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import flax 7 | import flax.linen as nn 8 | import optax 9 | 10 | from jaxrl_m.common.typing import Batch 11 | from jaxrl_m.common.typing import Dict 12 | from flax.core import FrozenDict 13 | from jaxrl_m.common.common import TrainState 14 | from jaxrl_m.common.common import nonpytree_field 15 | from jaxrl_m.networks.actor_critic_nets import get_encoding 16 | from jaxrl_m.vision.cvae import CVAE 17 | from jaxlib.xla_extension import DeviceArray 18 | 19 | 20 | def elbo_loss(reconstruction, goal, mean, logvar, affordance_beta=0.02): 21 | pred_loss = jnp.mean(jnp.square(reconstruction - goal), axis=-1) 22 | kld = -0.5 * jnp.sum(1.0 + logvar - mean**2.0 - jnp.exp(logvar), axis=-1) 23 | assert pred_loss.shape == kld.shape 24 | elbo_loss = pred_loss + affordance_beta * kld 25 | return elbo_loss.mean(), { 26 | "pred_loss": pred_loss.mean(), 27 | "kld": kld.mean(), 28 | "elbo_loss": elbo_loss.mean(), 29 | } 30 | 31 | 32 | class Affordance(nn.Module): 33 | networks: Dict[str, nn.Module] 34 | 35 | def affordance(self, observations, goals, seed, return_latents=False): 36 | latents = get_encoding( 37 | 
self.networks["encoder"], 38 | observations, 39 | use_proprio=False, 40 | early_goal_concat=False, 41 | goals=None, 42 | stop_gradient=True, 43 | ) 44 | goal_latents = get_encoding( 45 | self.networks["encoder"], 46 | goals, 47 | use_proprio=False, 48 | early_goal_concat=False, 49 | goals=None, 50 | stop_gradient=True, 51 | ) 52 | reconstruction = self.networks["affordance"](latents, goal_latents, seed) 53 | 54 | if return_latents: 55 | info = {"latents": latents, "goal_latents": goal_latents} 56 | return reconstruction, info 57 | else: 58 | return reconstruction 59 | 60 | def __call__(self, observations, goals, seed): 61 | rets = dict() 62 | rets["affordance"] = self.affordance(observations, goals, seed=seed) 63 | return rets 64 | 65 | 66 | class AffordanceAgent(flax.struct.PyTreeNode): 67 | model: TrainState 68 | config: dict = nonpytree_field() 69 | 70 | @functools.partial(jax.pmap, axis_name="pmap") 71 | def update(agent, batch: Batch, seed: DeviceArray): 72 | def loss_fn(params): 73 | rets, latent_dict = agent.model( 74 | batch["observations"], 75 | batch["goals"], 76 | seed=seed, 77 | return_latents=True, 78 | params=params, 79 | method="affordance", 80 | ) 81 | loss, info = elbo_loss( 82 | rets["reconstruction"], 83 | latent_dict["goal_latents"], 84 | rets["mean"], 85 | rets["logvar"], 86 | affordance_beta=agent.config["affordance_beta"], 87 | ) 88 | return loss, info 89 | 90 | new_model, info = agent.model.apply_loss_fn( 91 | loss_fn=loss_fn, has_aux=True, pmap_axis="pmap" 92 | ) 93 | 94 | return agent.replace(model=new_model), info 95 | 96 | @jax.jit 97 | def get_debug_metrics(agent, batch, seed): 98 | rets, latent_dict = agent.model( 99 | batch["observations"], 100 | batch["goals"], 101 | seed=seed, 102 | return_latents=True, 103 | # params=params, 104 | method="affordance", 105 | ) 106 | loss, info = elbo_loss( 107 | rets["reconstruction"], 108 | latent_dict["goal_latents"], 109 | rets["mean"], 110 | rets["logvar"], 111 | affordance_beta=agent.config["affordance_beta"], 112 | ) 113 | 114 | return info 115 | 116 | 117 | def create_affordance_learner( 118 | seed: int, 119 | observations: FrozenDict, 120 | goals: FrozenDict, 121 | encoder_def: nn.Module, 122 | # Model architecture 123 | affordance_kwargs: dict = { 124 | "encoder_hidden_dims": [256, 256, 256], 125 | "latent_dim": 8, 126 | "decoder_hidden_dims": [256, 256, 256], 127 | }, 128 | # Optimizer 129 | optim_kwargs: dict = { 130 | # 'learning_rate': 6e-5, # TODO(kuanfang): Was this tuned? 
131 | "learning_rate": 3e-4, 132 | }, 133 | # Algorithm config 134 | affordance_beta=0.02, 135 | **kwargs 136 | ): 137 | print("Extra kwargs:", kwargs) 138 | 139 | rng = jax.random.PRNGKey(seed) 140 | affordance_def = CVAE(**affordance_kwargs) 141 | 142 | networks = { 143 | "encoder": encoder_def, 144 | "affordance": affordance_def, 145 | } 146 | 147 | model_def = Affordance( 148 | networks=networks, 149 | ) 150 | 151 | tx = optax.adam(**optim_kwargs) 152 | 153 | params = model_def.init(rng, observations, goals, rng)["params"] 154 | 155 | model = TrainState.create(model_def, params, tx=tx) 156 | 157 | config = flax.core.FrozenDict(dict(affordance_beta=affordance_beta)) 158 | 159 | return AffordanceAgent(model, config) 160 | -------------------------------------------------------------------------------- /jaxrl_m/agents/discrete/cql.py: -------------------------------------------------------------------------------- 1 | """Implementations of CQL in discrete action spaces.""" 2 | import functools 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import flax 7 | import flax.linen as nn 8 | import optax 9 | import distrax 10 | 11 | from jaxrl_m.common.typing import * 12 | from jaxrl_m.common.common import TrainState, nonpytree_field, target_update 13 | from jaxrl_m.networks.discrete_nets import DiscreteQ, DiscreteCriticHead 14 | 15 | import ml_collections 16 | 17 | 18 | def cql_loss_fn(q, q_pred, q_target, cql_temperature=1.0, cql_alpha=1.0): 19 | td_loss = jnp.square(q_pred - q_target) 20 | cql_loss = ( 21 | jax.scipy.special.logsumexp(q / cql_temperature, axis=-1) 22 | - q_pred / cql_temperature 23 | ) 24 | critic_loss = td_loss + cql_alpha * cql_loss 25 | 26 | dist = distrax.Categorical(logits=q / cql_temperature) 27 | q_sorted = jnp.sort(q, axis=-1) 28 | 29 | return critic_loss.mean(), { 30 | "critic_loss": critic_loss.mean(), 31 | "td_loss": td_loss.mean(), 32 | "cql_loss": cql_loss.mean(), 33 | "td_loss max": td_loss.max(), 34 | "td_loss min": td_loss.min(), 35 | "entropy": dist.entropy().mean(), 36 | "q": q_pred.mean(), 37 | "q_pi": jnp.max(q, axis=-1).mean(), 38 | "target_q": q_target.mean(), 39 | "q_gap": jnp.mean(q_sorted[:, -1] - q_sorted[:, -2]), 40 | "q_gap max": jnp.max(q_sorted[:, -1] - q_sorted[:, -2]), 41 | "q_gap min": jnp.min(q_sorted[:, -1] - q_sorted[:, -2]), 42 | } 43 | 44 | 45 | class CQLAgent(flax.struct.PyTreeNode): 46 | model: TrainState 47 | target_model: TrainState 48 | config: dict = nonpytree_field() 49 | # What method to call to get Q-values (default is model_def.__call__) 50 | method: ModuleMethod = nonpytree_field(default=None) 51 | 52 | @functools.partial(jax.pmap, axis_name="pmap") 53 | def update(agent, batch: Batch): 54 | def loss_fn(params): 55 | # Target Q 56 | nq = agent.target_model(batch["next_observations"], method=agent.method) 57 | nv = jnp.max(nq, axis=-1) 58 | q_target = batch["rewards"] + agent.config["discount"] * nv * batch["masks"] 59 | 60 | # Current Q 61 | q = agent.model(batch["observations"], params=params, method=agent.method) 62 | q_pred = q[jnp.arange(len(batch["actions"])), batch["actions"]] 63 | 64 | # CQL Loss 65 | critic_loss, info = cql_loss_fn( 66 | q, 67 | q_pred, 68 | q_target, 69 | cql_temperature=agent.config["temperature"], 70 | cql_alpha=agent.config["cql_alpha"], 71 | ) 72 | return critic_loss, info 73 | 74 | new_model, info = agent.model.apply_loss_fn( 75 | loss_fn=loss_fn, pmap_axis="pmap", has_aux=True 76 | ) 77 | 78 | new_target_model = target_update( 79 | agent.model, agent.target_model, 
agent.config["target_update_rate"] 80 | ) 81 | 82 | return agent.replace(model=new_model, target_model=new_target_model), info 83 | 84 | @functools.partial(jax.jit, static_argnames=("argmax")) 85 | def sample_actions(agent, observations, *, seed, temperature=1.0, argmax=False): 86 | logits = agent.model(observations, method=agent.method) 87 | dist = distrax.Categorical(logits=logits / temperature) 88 | 89 | if argmax: 90 | return dist.mode() 91 | else: 92 | return dist.sample(seed=seed) 93 | 94 | 95 | def create_cql_learner( 96 | seed: int, 97 | observations: jnp.ndarray, 98 | n_actions: int, 99 | # Model architecture 100 | encoder_def: nn.Module, 101 | network_kwargs: dict = { 102 | "hidden_dims": [256, 256], 103 | }, 104 | optim_kwargs: dict = { 105 | "learning_rate": 6e-5, 106 | }, 107 | # Algorithm config 108 | discount=0.95, 109 | cql_alpha=1.0, 110 | temperature=1.0, 111 | target_update_rate=0.002, 112 | **kwargs 113 | ): 114 | 115 | print("Extra kwargs:", kwargs) 116 | 117 | rng = jax.random.PRNGKey(seed) 118 | 119 | network_def = DiscreteCriticHead(n_actions=n_actions, **network_kwargs) 120 | model_def = DiscreteQ(encoder=encoder_def, network=network_def) 121 | 122 | tx = optax.adam(**optim_kwargs) 123 | params = model_def.init(rng, observations)["params"] 124 | 125 | model = TrainState.create(model_def, params, tx=tx) 126 | target_model = TrainState.create(model_def, params) 127 | 128 | config = flax.core.FrozenDict( 129 | dict( 130 | discount=discount, 131 | cql_alpha=cql_alpha, 132 | temperature=temperature, 133 | target_update_rate=target_update_rate, 134 | ) 135 | ) 136 | return CQLAgent(model, target_model, config) 137 | 138 | 139 | def get_default_config(): 140 | config = ml_collections.ConfigDict( 141 | { 142 | "algo": "cql", 143 | "optim_kwargs": {"learning_rate": 6e-5, "eps": 0.00015}, 144 | "network_kwargs": { 145 | "hidden_dims": (256, 256), 146 | }, 147 | "discount": 0.95, 148 | "cql_alpha": 0.5, 149 | "temperature": 1.0, 150 | "target_update_rate": 0.002, 151 | } 152 | ) 153 | return config 154 | -------------------------------------------------------------------------------- /jaxrl_m/agents/continuous/bc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Any 4 | import jax 5 | import jax.numpy as jnp 6 | from jaxrl_m.common.encoding import EncodingWrapper 7 | import numpy as np 8 | import flax 9 | import flax.linen as nn 10 | import optax 11 | 12 | from flax.core import FrozenDict 13 | from jaxrl_m.common.typing import Batch 14 | from jaxrl_m.common.typing import PRNGKey 15 | from jaxrl_m.common.common import JaxRLTrainState, nonpytree_field 16 | from jaxrl_m.networks.actor_critic_nets import Policy 17 | from jaxrl_m.networks.actor_critic_nets import ActorCriticWrapper 18 | 19 | 20 | class BCAgent(flax.struct.PyTreeNode): 21 | state: JaxRLTrainState 22 | lr_schedule: Any = nonpytree_field() 23 | 24 | @partial(jax.jit, static_argnames="pmap_axis") 25 | def update(self, batch: Batch, pmap_axis: str = None): 26 | new_rng, dropout_rng = jax.random.split(self.state.rng) 27 | 28 | def loss_fn(params): 29 | dist = self.state.apply_fn( 30 | {"params": params}, 31 | batch["observations"], 32 | temperature=1.0, 33 | train=True, 34 | rngs={"dropout": dropout_rng}, 35 | method="actor", 36 | ) 37 | pi_actions = dist.mode() 38 | log_probs = dist.log_prob(batch["actions"]) 39 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 40 | actor_loss = -(log_probs).mean() 41 | 
actor_std = dist.stddev().mean(axis=1) 42 | 43 | return actor_loss, { 44 | "actor_loss": actor_loss, 45 | "mse": mse.mean(), 46 | "log_probs": log_probs, 47 | "pi_actions": pi_actions, 48 | "mean_std": actor_std.mean(), 49 | "max_std": actor_std.max(), 50 | } 51 | 52 | # compute gradients and update params 53 | new_state, info = self.state.apply_loss_fns( 54 | loss_fn, pmap_axis=pmap_axis, has_aux=True 55 | ) 56 | # update rng 57 | new_state = new_state.replace(rng=new_rng) 58 | # log learning rates 59 | info["lr"] = self.lr_schedule(self.state.step) 60 | 61 | return self.replace(state=new_state), info 62 | 63 | @partial(jax.jit, static_argnames="argmax") 64 | def sample_actions( 65 | self, 66 | observations: np.ndarray, 67 | *, 68 | seed: PRNGKey, 69 | temperature: float = 1.0, 70 | argmax=False 71 | ) -> jnp.ndarray: 72 | dist = self.state.apply_fn( 73 | {"params": self.state.params}, 74 | observations, 75 | temperature=temperature, 76 | method="actor", 77 | ) 78 | if argmax: 79 | actions = dist.mode() 80 | else: 81 | actions = dist.sample(seed=seed) 82 | return actions 83 | 84 | @jax.jit 85 | def get_debug_metrics(self, batch): 86 | dist = self.state.apply_fn( 87 | {"params": self.state.params}, 88 | batch["observations"], 89 | temperature=1.0, 90 | method="actor", 91 | ) 92 | pi_actions = dist.mode() 93 | log_probs = dist.log_prob(batch["actions"]) 94 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 95 | 96 | return { 97 | "mse": mse, 98 | "log_probs": log_probs, 99 | "pi_actions": pi_actions, 100 | } 101 | 102 | @classmethod 103 | def create( 104 | cls, 105 | rng: PRNGKey, 106 | observations: FrozenDict, 107 | actions: jnp.ndarray, 108 | # Model architecture 109 | encoder_def: nn.Module, 110 | use_proprio: bool = False, 111 | network_kwargs: dict = { 112 | "hidden_dims": [256, 256], 113 | }, 114 | policy_kwargs: dict = { 115 | "tanh_squash_distribution": False, 116 | "state_dependent_std": False, 117 | "dropout": 0.0, 118 | }, 119 | # Optimizer 120 | learning_rate: float = 3e-4, 121 | warmup_steps: int = 1000, 122 | decay_steps: int = 1000000, 123 | ): 124 | encoder_def = EncodingWrapper( 125 | encoder=encoder_def, 126 | use_proprio=use_proprio, 127 | stop_gradient=False, 128 | ) 129 | 130 | encoders = {"actor": encoder_def} 131 | networks = { 132 | "actor": Policy( 133 | action_dim=actions.shape[-1], **network_kwargs, **policy_kwargs 134 | ) 135 | } 136 | 137 | model_def = ActorCriticWrapper( 138 | encoders=encoders, 139 | networks=networks, 140 | ) 141 | 142 | lr_schedule = optax.warmup_cosine_decay_schedule( 143 | init_value=0.0, 144 | peak_value=learning_rate, 145 | warmup_steps=warmup_steps, 146 | decay_steps=decay_steps, 147 | end_value=0.0, 148 | ) 149 | tx = optax.adam(lr_schedule) 150 | 151 | rng, init_rng = jax.random.split(rng) 152 | params = model_def.init(init_rng, observations, actions)["params"] 153 | 154 | rng, create_rng = jax.random.split(rng) 155 | state = JaxRLTrainState.create( 156 | apply_fn=model_def.apply, 157 | params=params, 158 | txs=tx, 159 | target_params=params, 160 | rng=create_rng, 161 | ) 162 | 163 | return cls(state, lr_schedule) 164 | -------------------------------------------------------------------------------- /jaxrl_m/agents/discrete/gc_cql.py: -------------------------------------------------------------------------------- 1 | """Implementations of goal-conditioned CQL in discrete action spaces.""" 2 | import functools 3 | import copy 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import flax 9 | import 
flax.linen as nn 10 | import optax 11 | import distrax 12 | 13 | from jaxrl_m.common.typing import * 14 | from jaxrl_m.common.common import TrainState, nonpytree_field, target_update 15 | from jaxrl_m.networks.discrete_nets import DiscreteCriticHead 16 | from jaxrl_m.networks.actor_critic_nets import get_encoding 17 | from jaxrl_m.agents.discrete.cql import target_update, cql_loss_fn 18 | 19 | import ml_collections 20 | 21 | 22 | class DiscreteGCQ(nn.Module): 23 | encoder: nn.Module 24 | goal_encoder: nn.Module 25 | network: nn.Module 26 | 27 | def __call__(self, observations, goals): 28 | latents = get_encoding(self.encoder, observations) 29 | goal_latents = get_encoding(self.goal_encoder, goals) 30 | return self.network(latents, goal_latents) 31 | 32 | 33 | class GCAdaptor(nn.Module): 34 | network: nn.Module 35 | 36 | def __call__(self, observations, goals): 37 | combined = jnp.concatenate([observations, goals], axis=-1) 38 | return self.network(combined) 39 | 40 | 41 | class GoalConditionedCQLAgent(flax.struct.PyTreeNode): 42 | model: TrainState 43 | target_model: TrainState 44 | config: dict = nonpytree_field() 45 | 46 | # What method to call to get goal-conditioned Q-values (default is model_def.__call__) 47 | method: ModuleMethod = nonpytree_field(default=None) 48 | 49 | @functools.partial(jax.pmap, axis_name="pmap") 50 | def update(agent, batch: Batch): 51 | def loss_fn(params): 52 | # Target Q 53 | nq = agent.target_model( 54 | batch["next_observations"], batch["goals"], method=agent.method 55 | ) 56 | nv = jnp.max(nq, axis=-1) 57 | q_target = batch["rewards"] + agent.config["discount"] * nv * batch["masks"] 58 | 59 | # Current Q 60 | q = agent.model( 61 | batch["observations"], 62 | batch["goals"], 63 | params=params, 64 | method=agent.method, 65 | ) 66 | q_pred = q[jnp.arange(len(batch["actions"])), batch["actions"]] 67 | 68 | # CQL Loss 69 | critic_loss, info = cql_loss_fn( 70 | q, 71 | q_pred, 72 | q_target, 73 | cql_temperature=agent.config["temperature"], 74 | cql_alpha=agent.config["cql_alpha"], 75 | ) 76 | return critic_loss, info 77 | 78 | new_model, info = agent.model.apply_loss_fn( 79 | loss_fn=loss_fn, pmap_axis="pmap", has_aux=True 80 | ) 81 | new_target_model = target_update( 82 | agent.model, agent.target_model, agent.config["target_update_rate"] 83 | ) 84 | 85 | return agent.replace(model=new_model, target_model=new_target_model), info 86 | 87 | @functools.partial(jax.jit, static_argnames=("argmax")) 88 | def sample_actions( 89 | agent, observations, goals, *, seed, temperature=1.0, argmax=False 90 | ): 91 | logits = agent.model(observations, goals, method=agent.method) 92 | if argmax: 93 | return jnp.argmax(logits, axis=-1) 94 | else: 95 | dist = distrax.Categorical(logits=logits / temperature) 96 | return dist.sample(seed=seed) 97 | 98 | 99 | def create_cql_learner( 100 | seed: int, 101 | observations: jnp.ndarray, 102 | goals: jnp.ndarray, 103 | n_actions: int, 104 | # Model architecture 105 | encoder_def: nn.Module, 106 | shared_goal_encoder: bool = False, 107 | network_kwargs: dict = { 108 | "hidden_dims": [256, 256], 109 | }, 110 | optim_kwargs: dict = { 111 | "learning_rate": 6e-5, 112 | }, 113 | # Algorithm config 114 | discount=0.95, 115 | cql_alpha=1.0, 116 | temperature=1.0, 117 | target_update_rate=0.002, 118 | **kwargs 119 | ): 120 | 121 | print("Extra kwargs:", kwargs) 122 | 123 | rng = jax.random.PRNGKey(seed) 124 | rng, model_key = jax.random.split(rng) 125 | 126 | if network_def is None: 127 | network_def = GCAdaptor( 128 | 
DiscreteCriticHead(n_actions=n_actions, **network_kwargs) 129 | ) 130 | 131 | if shared_goal_encoder: 132 | goal_encoder_def = encoder_def 133 | else: 134 | goal_encoder_def = copy.deepcopy(encoder_def) 135 | 136 | model_def = DiscreteGCQ( 137 | encoder=encoder_def, 138 | goal_encoder=goal_encoder_def, 139 | network=network_def, 140 | ) 141 | 142 | print(model_def) 143 | 144 | if tx is None: 145 | tx = optax.adam(**optim_kwargs) 146 | 147 | params = model_def.init(model_key, observations, goals)["params"] 148 | 149 | model = TrainState.create(model_def, params, tx=tx) 150 | target_model = TrainState.create(model_def, params=params) 151 | 152 | config = flax.core.FrozenDict( 153 | dict( 154 | discount=discount, 155 | cql_alpha=cql_alpha, 156 | temperature=temperature, 157 | target_update_rate=target_update_rate, 158 | ) 159 | ) 160 | return GoalConditionedCQLAgent(rng, model, target_model, config) 161 | 162 | 163 | def get_default_config(): 164 | config = ml_collections.ConfigDict( 165 | { 166 | "algo": "gccql", 167 | "shared_goal_encoder": False, 168 | "optim_kwargs": {"learning_rate": 6e-5, "eps": 0.00015}, 169 | "network_kwargs": { 170 | "hidden_dims": (256, 256), 171 | }, 172 | "discount": 0.95, 173 | "cql_alpha": 0.5, 174 | "temperature": 1.0, 175 | "target_update_rate": 0.002, 176 | } 177 | ) 178 | return config 179 | -------------------------------------------------------------------------------- /jaxrl_m/agents/continuous/gc_bc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Any 4 | import jax 5 | import jax.numpy as jnp 6 | from jaxrl_m.common.encoding import GCEncodingWrapper 7 | import numpy as np 8 | import flax 9 | import flax.linen as nn 10 | import optax 11 | 12 | from flax.core import FrozenDict 13 | from jaxrl_m.common.typing import Batch 14 | from jaxrl_m.common.typing import PRNGKey 15 | from jaxrl_m.common.common import JaxRLTrainState, nonpytree_field 16 | from jaxrl_m.networks.actor_critic_nets import Policy 17 | from jaxrl_m.networks.actor_critic_nets import ActorCriticWrapper 18 | 19 | 20 | class GCBCAgent(flax.struct.PyTreeNode): 21 | state: JaxRLTrainState 22 | lr_schedule: Any = nonpytree_field() 23 | 24 | @partial(jax.jit, static_argnames="pmap_axis") 25 | def update(self, batch: Batch, pmap_axis: str = None): 26 | def loss_fn(params, rng): 27 | rng, key = jax.random.split(rng) 28 | dist = self.state.apply_fn( 29 | {"params": params}, 30 | (batch["observations"], batch["goals"]), 31 | temperature=1.0, 32 | train=True, 33 | rngs={"dropout": key}, 34 | method="actor", 35 | ) 36 | pi_actions = dist.mode() 37 | log_probs = dist.log_prob(batch["actions"]) 38 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 39 | actor_loss = -(log_probs).mean() 40 | actor_std = dist.stddev().mean(axis=1) 41 | 42 | return actor_loss, { 43 | "actor_loss": actor_loss, 44 | "mse": mse.mean(), 45 | "log_probs": log_probs.mean(), 46 | "pi_actions": pi_actions.mean(), 47 | "mean_std": actor_std.mean(), 48 | "max_std": actor_std.max(), 49 | } 50 | 51 | # compute gradients and update params 52 | new_state, info = self.state.apply_loss_fns( 53 | loss_fn, pmap_axis=pmap_axis, has_aux=True 54 | ) 55 | 56 | # log learning rates 57 | info["lr"] = self.lr_schedule(self.state.step) 58 | 59 | return self.replace(state=new_state), info 60 | 61 | @partial(jax.jit, static_argnames="argmax") 62 | def sample_actions( 63 | self, 64 | observations: np.ndarray, 65 | goals: np.ndarray, 66 | *, 67 
| seed: PRNGKey, 68 | temperature: float = 1.0, 69 | argmax=False 70 | ) -> jnp.ndarray: 71 | dist = self.state.apply_fn( 72 | {"params": self.state.params}, 73 | (observations, goals), 74 | temperature=temperature, 75 | method="actor", 76 | ) 77 | if argmax: 78 | actions = dist.mode() 79 | else: 80 | actions = dist.sample(seed=seed) 81 | return actions 82 | 83 | @jax.jit 84 | def get_debug_metrics(self, batch): 85 | dist = self.state.apply_fn( 86 | {"params": self.state.params}, 87 | (batch["observations"], batch["goals"]), 88 | temperature=1.0, 89 | method="actor", 90 | ) 91 | pi_actions = dist.mode() 92 | log_probs = dist.log_prob(batch["actions"]) 93 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 94 | 95 | return { 96 | "mse": mse, 97 | "log_probs": log_probs, 98 | "pi_actions": pi_actions, 99 | } 100 | 101 | @classmethod 102 | def create( 103 | cls, 104 | rng: PRNGKey, 105 | observations: FrozenDict, 106 | actions: jnp.ndarray, 107 | goals: FrozenDict, 108 | # Model architecture 109 | encoder_def: nn.Module, 110 | shared_goal_encoder: bool = True, 111 | early_goal_concat: bool = False, 112 | use_proprio: bool = False, 113 | network_kwargs: dict = { 114 | "hidden_dims": [256, 256], 115 | }, 116 | policy_kwargs: dict = { 117 | "tanh_squash_distribution": False, 118 | "state_dependent_std": False, 119 | "dropout": 0.0, 120 | }, 121 | # Optimizer 122 | learning_rate: float = 3e-4, 123 | warmup_steps: int = 1000, 124 | decay_steps: int = 1000000, 125 | ): 126 | if early_goal_concat: 127 | # passing None as the goal encoder causes early goal concat 128 | goal_encoder_def = None 129 | else: 130 | if shared_goal_encoder: 131 | goal_encoder_def = encoder_def 132 | else: 133 | goal_encoder_def = copy.deepcopy(encoder_def) 134 | 135 | encoder_def = GCEncodingWrapper( 136 | encoder=encoder_def, 137 | goal_encoder=goal_encoder_def, 138 | use_proprio=use_proprio, 139 | stop_gradient=False, 140 | ) 141 | 142 | encoders = {"actor": encoder_def} 143 | networks = { 144 | "actor": Policy( 145 | action_dim=actions.shape[-1], **network_kwargs, **policy_kwargs 146 | ) 147 | } 148 | 149 | model_def = ActorCriticWrapper( 150 | encoders=encoders, 151 | networks=networks, 152 | ) 153 | 154 | lr_schedule = optax.warmup_cosine_decay_schedule( 155 | init_value=0.0, 156 | peak_value=learning_rate, 157 | warmup_steps=warmup_steps, 158 | decay_steps=decay_steps, 159 | end_value=0.0, 160 | ) 161 | tx = optax.adam(lr_schedule) 162 | 163 | rng, init_rng = jax.random.split(rng) 164 | params = model_def.init(init_rng, (observations, goals), actions)["params"] 165 | 166 | rng, create_rng = jax.random.split(rng) 167 | state = JaxRLTrainState.create( 168 | apply_fn=model_def.apply, 169 | params=params, 170 | txs=tx, 171 | target_params=params, 172 | rng=create_rng, 173 | ) 174 | 175 | return cls(state, lr_schedule) 176 | -------------------------------------------------------------------------------- /jaxrl_m/data/ego4d.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | PROTO_TYPE_SPEC = {"images": tf.string, "text": tf.string} 5 | 6 | 7 | def get_ego4d_dataloader( 8 | path, 9 | batch_size, 10 | shuffle_buffer_size=25000, 11 | cache=False, 12 | take_every_n=1, 13 | val_split=0.1, 14 | seed=42, 15 | train_file_idxs=None, 16 | val_file_idxs=None, 17 | ): 18 | 19 | # get the tfrecord files 20 | dataset = tf.data.Dataset.list_files(tf.io.gfile.join(path, "*.tfrecord")) 21 | 22 | # at this point the cardinality is still known, so we split it into 
train and val 23 | num_files = dataset.cardinality().numpy() 24 | num_val_files = int(num_files * val_split) 25 | num_train_files = num_files - num_val_files 26 | 27 | # shuffle the dataset 28 | dataset = dataset.shuffle(num_files, seed=seed) 29 | datasets = {} 30 | if train_file_idxs is None: 31 | # split into train and val at the file level 32 | datasets["train"] = dataset.take(num_train_files) 33 | datasets["val"] = dataset.skip(num_train_files) 34 | else: 35 | train_filter_func = lambda i, data: tf.reduce_any(tf.math.equal(tf.math.mod(i, 64), train_file_idxs)) 36 | datasets["train"] = dataset.enumerate().filter(train_filter_func).map(lambda i, data: data) 37 | val_filter_func = lambda i, data: tf.reduce_any(tf.math.equal(tf.math.mod(i, 64), val_file_idxs)) 38 | datasets["val"] = dataset.enumerate().filter(val_filter_func).map(lambda i, data: data) 39 | 40 | for split, dataset in datasets.items(): 41 | # read them 42 | dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE) 43 | 44 | # decode the examples (yields videos) 45 | dataset = dataset.map(_decode_example, num_parallel_calls=tf.data.AUTOTUNE) 46 | 47 | # cache all the dataloading 48 | if cache: 49 | dataset = dataset.cache() 50 | 51 | # add goals (yields videos) 52 | dataset = dataset.map(_add_goals, num_parallel_calls=tf.data.AUTOTUNE) 53 | 54 | # unbatch to get individual frames 55 | dataset = dataset.unbatch() 56 | 57 | if take_every_n > 1: 58 | dataset = dataset.shard(take_every_n, index=0) 59 | 60 | 61 | # process each frame 62 | dataset = dataset.map(_process_frame, num_parallel_calls=tf.data.AUTOTUNE) 63 | 64 | # shuffle the dataset 65 | dataset = dataset.shuffle(shuffle_buffer_size, seed=seed) 66 | 67 | # batch the dataset 68 | dataset = dataset.batch( 69 | batch_size, num_parallel_calls=tf.data.AUTOTUNE, drop_remainder=True 70 | ) 71 | 72 | # restructure the batches 73 | dataset = dataset.map(_restructure_batch, num_parallel_calls=tf.data.AUTOTUNE) 74 | 75 | # always prefetch last 76 | dataset = dataset.prefetch(tf.data.AUTOTUNE) 77 | 78 | # repeat the dataset 79 | if split == "train": 80 | dataset = dataset.repeat() 81 | 82 | datasets[split] = dataset 83 | 84 | return datasets 85 | 86 | def _restructure_batch(batch): 87 | # batch is a dict with keys "images", "text", "goals", and "frame_indices" 88 | # "images" is a tensor of shape [batch_size, 224, 224, 3] 89 | # "text" is a tensor of shape [batch_size] 90 | # "goals" is a tensor of shape [batch_size] 91 | 92 | return { 93 | "observations": { 94 | "image": batch["images"], 95 | }, 96 | "goals": { 97 | "image": batch["goals"], 98 | "language": batch["text"], 99 | }, 100 | "actions": tf.zeros([tf.shape(batch["images"])[0], 10], dtype=tf.float32), 101 | } 102 | 103 | 104 | def _decode_example(example_proto): 105 | # decode the example proto according to PROTO_TYPE_SPEC 106 | features = { 107 | key: tf.io.FixedLenFeature([], tf.string) for key in PROTO_TYPE_SPEC.keys() 108 | } 109 | parsed_features = tf.io.parse_single_example(example_proto, features) 110 | parsed_tensors = { 111 | key: tf.io.parse_tensor(parsed_features[key], dtype) 112 | for key, dtype in PROTO_TYPE_SPEC.items() 113 | } 114 | 115 | return parsed_tensors 116 | 117 | 118 | def _add_goals(video): 119 | # video is a dict with keys "images" and "text" 120 | # "images" is a tensor of shape [n_frames, 224, 224, 3] 121 | # "text" is a tensor of shape [n_frames] 122 | 123 | # for now: for frame i, select a goal uniformly from the range [i, n_frames) 124 | num_frames = 
tf.shape(video["images"])[0] 125 | rand = tf.random.uniform(shape=[num_frames], minval=0, maxval=1, dtype=tf.float32) 126 | offsets = tf.cast( 127 | tf.floor(rand * tf.cast(tf.range(num_frames)[::-1], tf.float32)), tf.int32 128 | ) 129 | indices = tf.range(num_frames) + offsets 130 | video["goals"] = tf.gather(video["images"], indices) 131 | 132 | # for now: just get rid of text 133 | video["text"] = tf.tile(tf.expand_dims(video["text"], 0), [num_frames]) 134 | 135 | return video 136 | 137 | 138 | def _process_frame(frame): 139 | for key in ["images", "goals"]: 140 | frame[key] = tf.io.decode_jpeg(frame[key]) 141 | # this will throw an error if any images aren't 224x224x3 142 | frame[key] = tf.ensure_shape(frame[key], [224, 224, 3]) 143 | # may want to think more carefully about the resize method 144 | # frame[key] = tf.image.resize(frame[key], [128, 128], method="lanczos3") 145 | # normalize to [-1, 1] 146 | # frame[key] = (frame[key] / 127.5) - 1 147 | # convert to float32 148 | # frame[key] = tf.cast(frame[key], tf.float32) 149 | 150 | return frame 151 | 152 | 153 | if __name__ == "__main__": 154 | import tqdm 155 | 156 | datasets = get_ego4d_dataloader( 157 | "gs://rail-tpus-kevin/ego4d-tfrecord", 256, 100, cache=False, take_every_n=50 158 | ) 159 | 160 | 161 | with tqdm.tqdm() as pbar: 162 | for batch in datasets['train']: 163 | pbar.update(1) 164 | print(batch["goals"]["language"]) 165 | -------------------------------------------------------------------------------- /jaxrl_m/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterable, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | from flax.core import frozen_dict 5 | from gym.utils import seeding 6 | 7 | from jaxrl_m.common.typing import Data 8 | 9 | DatasetDict = Dict[str, Data] 10 | 11 | 12 | def _check_lengths(dataset_dict: DatasetDict, dataset_len: Optional[int] = None) -> int: 13 | for v in dataset_dict.values(): 14 | if isinstance(v, dict): 15 | dataset_len = dataset_len or _check_lengths(v, dataset_len) 16 | elif isinstance(v, np.ndarray): 17 | item_len = len(v) 18 | dataset_len = dataset_len or item_len 19 | assert dataset_len == item_len, "Inconsistent item lengths in the dataset." 
20 | else: 21 | raise TypeError("Unsupported type.") 22 | return dataset_len 23 | 24 | 25 | def _subselect(dataset_dict: DatasetDict, index: np.ndarray) -> DatasetDict: 26 | new_dataset_dict = {} 27 | for k, v in dataset_dict.items(): 28 | if isinstance(v, dict): 29 | new_v = _subselect(v, index) 30 | elif isinstance(v, np.ndarray): 31 | new_v = v[index] 32 | else: 33 | raise TypeError("Unsupported type.") 34 | new_dataset_dict[k] = new_v 35 | return new_dataset_dict 36 | 37 | 38 | def _sample( 39 | dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray 40 | ) -> DatasetDict: 41 | if isinstance(dataset_dict, np.ndarray): 42 | return dataset_dict[indx] 43 | elif isinstance(dataset_dict, dict): 44 | batch = {} 45 | for k, v in dataset_dict.items(): 46 | batch[k] = _sample(v, indx) 47 | else: 48 | raise TypeError("Unsupported type.") 49 | return batch 50 | 51 | 52 | class Dataset(object): 53 | def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): 54 | self.dataset_dict = dataset_dict 55 | self.dataset_len = _check_lengths(dataset_dict) 56 | 57 | # Seeding similar to OpenAI Gym 58 | self._np_random = None 59 | if seed is not None: 60 | self.seed(seed) 61 | 62 | @property 63 | def np_random(self) -> np.random.RandomState: 64 | if self._np_random is None: 65 | self.seed() 66 | return self._np_random 67 | 68 | def seed(self, seed: Optional[int] = None) -> list: 69 | self._np_random, seed = seeding.np_random(seed) 70 | return [seed] 71 | 72 | def __len__(self) -> int: 73 | return self.dataset_len 74 | 75 | def sample( 76 | self, 77 | batch_size: int, 78 | keys: Optional[Iterable[str]] = None, 79 | indx: Optional[np.ndarray] = None, 80 | ) -> frozen_dict.FrozenDict: 81 | if indx is None: 82 | if hasattr(self.np_random, "integers"): 83 | indx = self.np_random.integers(len(self), size=batch_size) 84 | else: 85 | indx = self.np_random.randint(len(self), size=batch_size) 86 | 87 | batch = dict() 88 | 89 | if keys is None: 90 | keys = self.dataset_dict.keys() 91 | 92 | for k in keys: 93 | if isinstance(self.dataset_dict[k], dict): 94 | batch[k] = _sample(self.dataset_dict[k], indx) 95 | else: 96 | batch[k] = self.dataset_dict[k][indx] 97 | 98 | return frozen_dict.freeze(batch) 99 | 100 | def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: 101 | assert 0 < ratio and ratio < 1 102 | train_index = np.index_exp[: int(self.dataset_len * ratio)] 103 | test_index = np.index_exp[int(self.dataset_len * ratio) :] 104 | 105 | index = np.arange(len(self), dtype=np.int32) 106 | self.np_random.shuffle(index) 107 | train_index = index[: int(self.dataset_len * ratio)] 108 | test_index = index[int(self.dataset_len * ratio) :] 109 | 110 | train_dataset_dict = _subselect(self.dataset_dict, train_index) 111 | test_dataset_dict = _subselect(self.dataset_dict, test_index) 112 | return Dataset(train_dataset_dict), Dataset(test_dataset_dict) 113 | 114 | def _trajectory_boundaries_and_returns(self) -> Tuple[list, list, list]: 115 | episode_starts = [0] 116 | episode_ends = [] 117 | 118 | episode_return = 0 119 | episode_returns = [] 120 | 121 | for i in range(len(self)): 122 | episode_return += self.dataset_dict["rewards"][i] 123 | 124 | if self.dataset_dict["dones"][i]: 125 | episode_returns.append(episode_return) 126 | episode_ends.append(i + 1) 127 | if i + 1 < len(self): 128 | episode_starts.append(i + 1) 129 | episode_return = 0.0 130 | 131 | return episode_starts, episode_ends, episode_returns 132 | 133 | def filter( 134 | self, percentile: Optional[float] = None, threshold: 
Optional[float] = None 135 | ): 136 | assert (percentile is None and threshold is not None) or ( 137 | percentile is not None and threshold is None 138 | ) 139 | 140 | ( 141 | episode_starts, 142 | episode_ends, 143 | episode_returns, 144 | ) = self._trajectory_boundaries_and_returns() 145 | 146 | if percentile is not None: 147 | threshold = np.percentile(episode_returns, 100 - percentile) 148 | 149 | bool_indx = np.full((len(self),), False, dtype=bool) 150 | 151 | for i in range(len(episode_returns)): 152 | if episode_returns[i] >= threshold: 153 | bool_indx[episode_starts[i] : episode_ends[i]] = True 154 | 155 | self.dataset_dict = _subselect(self.dataset_dict, bool_indx) 156 | 157 | self.dataset_len = _check_lengths(self.dataset_dict) 158 | 159 | def normalize_returns(self, scaling: float = 1000): 160 | (_, _, episode_returns) = self._trajectory_boundaries_and_returns() 161 | self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min( 162 | episode_returns 163 | ) 164 | self.dataset_dict["rewards"] *= scaling 165 | -------------------------------------------------------------------------------- /jaxrl_m/envs/wrappers/video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import gym 5 | import imageio 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from typing import Optional 10 | 11 | # Take from 12 | # https://github.com/denisyarats/pytorch_sac/ 13 | 14 | 15 | def compose_frames( 16 | all_frames: List[np.ndarray], 17 | num_videos_per_row: int, 18 | margin: int = 4, 19 | ): 20 | num_episodes = len(all_frames) 21 | 22 | if num_videos_per_row is None: 23 | num_videos_per_row = num_episodes 24 | 25 | t = 0 26 | end_of_all_epidoes = False 27 | frames_to_save = [] 28 | while not end_of_all_epidoes: 29 | frames_t = [] 30 | 31 | for i in range(num_episodes): 32 | # If the episode is shorter, repeat the last frame. 33 | t_ = min(t, len(all_frames[i]) - 1) 34 | frame_i_t = all_frames[i][t_] 35 | 36 | # Add the lines. 37 | frame_i_t = np.pad( 38 | frame_i_t, 39 | [[margin, margin], [margin, margin], [0, 0]], 40 | "constant", 41 | constant_values=0, 42 | ) 43 | 44 | frames_t.append(frame_i_t) 45 | 46 | # Arrange the videos based on num_videos_per_row. 
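# (illustrative) with 6 recorded episodes and num_videos_per_row=3, each timestep
# yields a 2x3 grid: each row is concatenated along axis=1 (width) and the rows are
# stacked along axis=0 (height); episodes that do not fill a complete row are
# dropped by the while loop below.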
47 | frame_t = None 48 | while len(frames_t) >= num_videos_per_row: 49 | frames_t_this_row = frames_t[:num_videos_per_row] 50 | frames_t = frames_t[num_videos_per_row:] 51 | 52 | frame_t_this_row = np.concatenate(frames_t_this_row, axis=1) 53 | if frame_t is None: 54 | frame_t = frame_t_this_row 55 | else: 56 | frame_t = np.concatenate([frame_t, frame_t_this_row], axis=0) 57 | 58 | frames_to_save.append(frame_t) 59 | t += 1 60 | end_of_all_epidoes = all([len(all_frames[i]) <= t for i in range(num_episodes)]) 61 | 62 | return frames_to_save 63 | 64 | 65 | class VideoRecorder(gym.Wrapper): 66 | def __init__( 67 | self, 68 | env: gym.Env, 69 | save_folder: str = "", 70 | save_prefix: str = None, 71 | height: int = 128, 72 | width: int = 128, 73 | fps: int = 30, 74 | camera_id: int = 0, 75 | goal_conditioned: bool = False, 76 | ): 77 | super().__init__(env) 78 | 79 | self.save_folder = save_folder 80 | self.save_prefix = save_prefix 81 | self.height = height 82 | self.width = width 83 | self.fps = fps 84 | self.camera_id = camera_id 85 | self.frames = [] 86 | self.goal_conditioned = goal_conditioned 87 | 88 | if not tf.io.gfile.exists(save_folder): 89 | tf.io.gfile.makedirs(save_folder) 90 | 91 | self.num_record_episodes = -1 92 | 93 | self.num_videos = 0 94 | 95 | # self.all_save_paths = None 96 | self.current_save_path = None 97 | 98 | def start_recording(self, num_episodes: int = None, num_videos_per_row: int = None): 99 | if num_videos_per_row is not None and num_episodes is not None: 100 | assert num_episodes >= num_videos_per_row 101 | 102 | self.num_record_episodes = num_episodes 103 | self.num_videos_per_row = num_videos_per_row 104 | 105 | # self.all_save_paths = [] 106 | self.all_frames = [] 107 | 108 | def stop_recording(self): 109 | self.num_record_episodes = None 110 | 111 | def step(self, action: np.ndarray): # NOQA 112 | 113 | if self.num_record_episodes is None or self.num_record_episodes == 0: 114 | observation, reward, terminated, truncated, info = self.env.step(action) 115 | 116 | elif self.num_record_episodes > 0: 117 | frame = self.env.render( 118 | height=self.height, width=self.width, camera_id=self.camera_id 119 | ) 120 | 121 | if frame is None: 122 | try: 123 | frame = self.sim.render( 124 | width=self.width, height=self.height, mode="offscreen" 125 | ) 126 | frame = np.flipud(frame) 127 | except Exception: 128 | raise NotImplementedError("Rendering is not implemented.") 129 | 130 | self.frames.append(frame.astype(np.uint8)) 131 | 132 | observation, reward, terminated, truncated, info = self.env.step(action) 133 | 134 | if terminated or truncated: 135 | if self.goal_conditioned: 136 | frames = [ 137 | np.concatenate([self.env.current_goal["image"], frame], axis=0) 138 | for frame in self.frames 139 | ] 140 | else: 141 | frames = self.frames 142 | 143 | self.all_frames.append(frames) 144 | 145 | if self.num_record_episodes > 0: 146 | self.num_record_episodes -= 1 147 | 148 | if self.num_record_episodes is None: 149 | # Plot one episode per file. 150 | frames_to_save = frames 151 | should_save = True 152 | elif self.num_record_episodes == 0: 153 | # Plot all episodes in one file. 
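# compose_frames repeats the last frame of shorter episodes, so the composed
# grid video runs until the longest recorded episode has finished.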
154 | frames_to_save = compose_frames( 155 | self.all_frames, self.num_videos_per_row 156 | ) 157 | should_save = True 158 | else: 159 | should_save = False 160 | 161 | if should_save: 162 | filename = "%08d.mp4" % (self.num_videos) 163 | if self.save_prefix is not None and self.save_prefix != "": 164 | filename = f"{self.save_prefix}_{filename}" 165 | self.current_save_path = tf.io.gfile.join( 166 | self.save_folder, filename 167 | ) 168 | 169 | with tf.io.gfile.GFile(self.current_save_path, "wb") as f: 170 | imageio.mimsave(f, frames_to_save, "MP4", fps=self.fps) 171 | 172 | self.num_videos += 1 173 | 174 | self.frames = [] 175 | 176 | else: 177 | raise ValueError("Do not forget to call start_recording.") 178 | 179 | return observation, reward, terminated, truncated, info 180 | -------------------------------------------------------------------------------- /jaxrl_m/common/encoding.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | from typing import Dict, Optional, Tuple 3 | import jax 4 | import jax.numpy as jnp 5 | import tensorflow as tf 6 | import tensorflow_hub as hub 7 | import tensorflow_text 8 | import functools 9 | import inspect 10 | 11 | 12 | class EncodingWrapper(nn.Module): 13 | """ 14 | Encodes observations into a single flat encoding, adding additional 15 | functionality for adding proprioception and stopping the gradient. 16 | 17 | Args: 18 | encoder: The encoder network. 19 | use_proprio: Whether to concatenate proprioception (after encoding). 20 | stop_gradient: Whether to stop the gradient after the encoder. 21 | """ 22 | 23 | encoder: nn.Module 24 | use_proprio: bool 25 | stop_gradient: bool 26 | 27 | def __call__(self, observations: Dict[str, jnp.ndarray]) -> jnp.ndarray: 28 | encoding = self.encoder(observations["image"]) 29 | if self.use_proprio: 30 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 31 | if self.stop_gradient: 32 | encoding = jax.lax.stop_gradient(encoding) 33 | return encoding 34 | 35 | 36 | class GCEncodingWrapper(nn.Module): 37 | """ 38 | Encodes observations and goals into a single flat encoding. Handles all the 39 | logic about when/how to combine observations and goals. 40 | 41 | Takes a tuple (observations, goals) as input. 42 | 43 | Args: 44 | encoder: The encoder network for observations. 45 | goal_encoder: The encoder to use for goals (optional). If None, early 46 | goal concatenation is used, i.e. the goal is concatenated to the 47 | observation channel-wise before passing it through the encoder. 48 | use_proprio: Whether to concatenate proprioception (after encoding). 49 | stop_gradient: Whether to stop the gradient after the encoder. 
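Example (illustrative; `obs_encoder` stands in for any image encoder module):

    wrapper = GCEncodingWrapper(
        encoder=obs_encoder,
        goal_encoder=None,   # None -> early goal concat (channel-wise)
        use_proprio=False,
        stop_gradient=False,
    )
    # inside a parent module: encoding = wrapper((observations, goals))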
50 | """ 51 | 52 | encoder: nn.Module 53 | goal_encoder: Optional[nn.Module] 54 | use_proprio: bool 55 | stop_gradient: bool 56 | 57 | def __call__( 58 | self, 59 | observations_and_goals: Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]], 60 | ) -> jnp.ndarray: 61 | observations, goals = observations_and_goals 62 | if self.goal_encoder is None: 63 | # early goal concat 64 | encoder_inputs = jnp.concatenate( 65 | [observations["image"], goals["image"]], axis=-1 66 | ) 67 | encoding = self.encoder(encoder_inputs) 68 | else: 69 | # late fusion 70 | encoding = self.encoder(observations["image"]) 71 | goal_encoding = self.goal_encoder(goals["image"]) 72 | encoding = jnp.concatenate([encoding, goal_encoding], axis=-1) 73 | 74 | if self.use_proprio: 75 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 76 | 77 | if self.stop_gradient: 78 | encoding = jax.lax.stop_gradient(encoding) 79 | 80 | return encoding 81 | 82 | 83 | def task_expand(task, obs): 84 | if len(task.shape) + 2 == len(obs.shape): 85 | x = jnp.array(task) 86 | x = jnp.expand_dims(x, (-2, -3)) 87 | x = jnp.repeat(x, obs.shape[-2], axis=-2) 88 | x = jnp.repeat(x, obs.shape[-3], axis=-3) 89 | elif len(task.shape) == len(obs.shape): 90 | x = task 91 | else: 92 | assert False, (task.shape, obs.shape) 93 | return jnp.concatenate([obs, x], axis=-1) 94 | 95 | 96 | class MultimodalEncodingWrapper(nn.Module): 97 | task_encoder: nn.Module 98 | modality: str 99 | early_fusion: bool 100 | use_proprio: bool 101 | stop_gradient: bool 102 | early_fuse_initial_obs: bool 103 | no_initial: bool 104 | 105 | def __call__( 106 | self, 107 | observations_and_goals: Tuple[ 108 | Dict[str, jnp.ndarray], Dict[str, jnp.ndarray], Dict[str, jnp.ndarray] 109 | ], 110 | encoder=None, 111 | override_modality=None, 112 | ) -> jnp.ndarray: 113 | if override_modality is not None: 114 | modality = override_modality 115 | else: 116 | modality = self.modality 117 | 118 | observations, goals, initial_obs = observations_and_goals 119 | 120 | if modality == "image" and "image_embed" in goals: 121 | task_embed = goals["image_embed"] 122 | task_embed = task_embed / jnp.linalg.norm( 123 | task_embed, axis=-1, keepdims=True 124 | ) 125 | elif modality == "language" and "text_embed" in goals: 126 | task_embed = goals["text_embed"] 127 | task_embed = task_embed / jnp.linalg.norm( 128 | task_embed, axis=-1, keepdims=True 129 | ) 130 | else: 131 | if self.no_initial: 132 | init_obs = observations["image"] 133 | if "raw" in inspect.signature(self.task_encoder).parameters: 134 | init_obs_unprocessed = observations["unprocessed_image"] 135 | else: 136 | init_obs = initial_obs["image"] 137 | if "raw" in inspect.signature(self.task_encoder).parameters: 138 | init_obs_unprocessed = initial_obs["unprocessed_image"] 139 | 140 | if "raw" in inspect.signature(self.task_encoder).parameters: 141 | task_embed = self.task_encoder( 142 | init_obs, goals[modality], raw=init_obs_unprocessed 143 | ) 144 | else: 145 | task_embed = self.task_encoder(init_obs, goals[modality]) 146 | 147 | if encoder == None: 148 | return None, task_embed 149 | 150 | # images are processed for CLIP, but the resnet should get unprocessed images 151 | # TODO handle this in eval script 152 | if "unprocessed_image" in observations: 153 | obs_img = observations["unprocessed_image"] 154 | else: 155 | obs_img = observations["image"] 156 | 157 | if "unprocessed_image" in initial_obs: 158 | init_obs_img = initial_obs["unprocessed_image"] 159 | else: 160 | init_obs_img = initial_obs["image"] 161 | 
162 | if self.early_fuse_initial_obs: 163 | obs = jnp.concatenate([init_obs_img, obs_img], axis=-1) 164 | else: 165 | obs = obs_img 166 | 167 | if self.early_fusion: 168 | encoding = encoder(obs, task_embed) 169 | else: 170 | encoding = encoder(obs) 171 | encoding = jnp.concatenate([encoding, task_embed], axis=-1) 172 | 173 | if self.use_proprio: 174 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 175 | 176 | if self.stop_gradient: 177 | encoding = jax.lax.stop_gradient(encoding) 178 | 179 | return encoding, task_embed 180 | -------------------------------------------------------------------------------- /jaxrl_m/agents/discrete/gc_iql.py: -------------------------------------------------------------------------------- 1 | """Implementations of goal-conditioned IQL (w/ no Q function) in discrete action spaces.""" 2 | 3 | import functools 4 | import copy 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import flax 9 | import flax.linen as nn 10 | import optax 11 | import distrax 12 | 13 | from jaxrl_m.common.typing import * 14 | from jaxrl_m.common.common import TrainState, nonpytree_field, target_update 15 | from jaxrl_m.networks.discrete_nets import DiscreteCriticHead 16 | from jaxrl_m.networks.actor_critic_nets import get_encoding, ValueCritic 17 | from jaxrl_m.agents.discrete.iql import iql_value_loss, iql_critic_loss, iql_actor_loss 18 | import ml_collections 19 | 20 | 21 | class DiscreteGCIQLMultiplexer(nn.Module): 22 | encoder: nn.Module 23 | goal_encoder: nn.Module 24 | networks: Dict[str, nn.Module] 25 | 26 | def get_sg_latent(self, observations, goals): 27 | latents = get_encoding(self.encoder, observations) 28 | goal_latents = get_encoding(self.goal_encoder, goals) 29 | return jnp.concatenate([latents, goal_latents], axis=-1) 30 | 31 | def __call__(self, observations, goals): 32 | latents = self.get_sg_latent(observations, goals) 33 | return {k: net(latents) for k, net in self.networks.items()} 34 | 35 | def actor(self, observations, goals): 36 | latents = self.get_sg_latent(observations, goals) 37 | return self.networks["actor"](latents) 38 | 39 | def value(self, observations, goals): 40 | latents = self.get_sg_latent(observations, goals) 41 | return self.networks["value"](latents) 42 | 43 | def critic(self, observations, goals): 44 | latents = self.get_sg_latent(observations, goals) 45 | return self.networks["critic"](latents) 46 | 47 | 48 | class GoalConditionedIQLAgent(flax.struct.PyTreeNode): 49 | model: TrainState 50 | target_model: TrainState 51 | config: dict = nonpytree_field() 52 | 53 | # What model method to call to get values / distribution 54 | value_method: ModuleMethod = nonpytree_field(default="value") 55 | actor_method: ModuleMethod = nonpytree_field(default="actor") 56 | critic_method: ModuleMethod = nonpytree_field( 57 | default="critic" 58 | ) # Unused in current update 59 | 60 | @functools.partial(jax.pmap, axis_name="pmap") 61 | def update(agent, batch: Batch): 62 | def value_loss_fn(params): 63 | nv = agent.target_model( 64 | batch["next_observations"], batch["goals"], method=agent.value_method 65 | ) 66 | target_q = batch["rewards"] + agent.config["discount"] * nv * batch["masks"] 67 | v = agent.model( 68 | batch["observations"], 69 | batch["goals"], 70 | params=params, 71 | method=agent.value_method, 72 | ) 73 | return iql_value_loss(target_q, v, agent.config["expectile"]) 74 | 75 | def actor_loss_fn(params): 76 | nv = agent.target_model( 77 | batch["next_observations"], batch["goals"], method=agent.value_method 78 | ) 79 | 
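# The actor objective below is advantage-weighted regression: iql_actor_loss
# weights log pi(a | s, g) by exp((target_q - v) / temperature), clipped at 100.
# Illustrative numbers: with target_q = 1.0, v = 0.2, temperature = 1.0 the
# weight is exp(0.8) ~= 2.2, so that action is up-weighted relative to an
# action with zero advantage (weight 1.0).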
target_q = batch["rewards"] + agent.config["discount"] * nv * batch["masks"] 80 | 81 | v = agent.model( 82 | batch["observations"], batch["goals"], method=agent.value_method 83 | ) 84 | 85 | logits = agent.model( 86 | batch["observations"], 87 | batch["goals"], 88 | params=params, 89 | method=agent.actor_method, 90 | ) 91 | dist = distrax.Categorical(logits=logits) 92 | 93 | return iql_actor_loss( 94 | target_q, v, dist, batch["actions"], agent.config["temperature"] 95 | ) 96 | 97 | def loss_fn(params): 98 | value_loss, value_info = value_loss_fn(params) 99 | actor_loss, actor_info = actor_loss_fn(params) 100 | 101 | return value_loss + actor_loss, {**value_info, **actor_info} 102 | 103 | new_model, info = agent.model.apply_loss_fn( 104 | loss_fn=loss_fn, has_aux=True, pmap_axis="pmap" 105 | ) 106 | new_target_model = target_update( 107 | agent.model, agent.target_model, agent.config["target_update_rate"] 108 | ) 109 | 110 | return agent.replace(model=new_model, target_model=new_target_model), info 111 | 112 | @functools.partial(jax.jit, static_argnames=("argmax")) 113 | def sample_actions( 114 | agent, observations, goals, *, seed, temperature=1.0, argmax=False 115 | ): 116 | dist = agent.model( 117 | observations, goals, temperature=temperature, method=agent.actor_method 118 | ) 119 | if argmax: 120 | return dist.mode() 121 | else: 122 | return dist.sample(seed=seed) 123 | 124 | 125 | def create_discrete_iql_learner( 126 | seed: int, 127 | observations: jnp.ndarray, 128 | goals: jnp.ndarray, 129 | n_actions: int, 130 | # Model architecture 131 | encoder_def: nn.Module, 132 | shared_goal_encoder: bool = False, 133 | network_kwargs: dict = { 134 | "hidden_dims": [256, 256], 135 | }, 136 | # Optimizer 137 | optim_kwargs: dict = { 138 | "learning_rate": 6e-5, 139 | }, 140 | # Algorithm config 141 | discount=0.95, 142 | expectile=0.9, 143 | temperature=1.0, 144 | target_update_rate=0.002, 145 | **kwargs 146 | ): 147 | 148 | print("Extra kwargs:", kwargs) 149 | 150 | rng = jax.random.PRNGKey(seed) 151 | 152 | if shared_goal_encoder: 153 | goal_encoder_def = encoder_def 154 | else: 155 | goal_encoder_def = copy.deepcopy(encoder_def) 156 | 157 | model_def = DiscreteGCIQLMultiplexer( 158 | encoder=encoder_def, 159 | goal_encoder=goal_encoder_def, 160 | networks={ 161 | "actor": DiscreteCriticHead(n_actions=n_actions, **network_kwargs), 162 | "value": ValueCritic(**network_kwargs), 163 | }, 164 | ) 165 | tx = optax.adam(**optim_kwargs) 166 | 167 | params = model_def.init(rng, observations, goals)["params"] 168 | model = TrainState.create(model_def, params, tx=tx) 169 | target_model = TrainState.create(model_def, params) 170 | 171 | config = flax.core.FrozenDict( 172 | dict( 173 | discount=discount, 174 | temperature=temperature, 175 | target_update_rate=target_update_rate, 176 | expectile=expectile, 177 | ) 178 | ) 179 | return GoalConditionedIQLAgent(model, target_model, config) 180 | 181 | 182 | def get_default_config(): 183 | config = ml_collections.ConfigDict( 184 | { 185 | "algo": "gc_iql", 186 | "optim_kwargs": {"learning_rate": 6e-5, "eps": 0.00015}, 187 | "network_kwargs": { 188 | "hidden_dims": (256, 256), 189 | }, 190 | "discount": 0.95, 191 | "expectile": 0.9, 192 | "temperature": 1.0, 193 | "target_update_rate": 0.002, 194 | "shared_goal_encoder": False, 195 | } 196 | ) 197 | return config 198 | -------------------------------------------------------------------------------- /jaxrl_m/agents/discrete/iql.py: 
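A standalone sketch of the asymmetric expectile weighting that `expectile_loss` below implements (NumPy, illustrative values; not part of the module):

    import numpy as np

    def expectile_loss(diff, expectile=0.8):
        # diff = q - v: under-estimates of V (diff > 0) are weighted by `expectile`,
        # over-estimates (diff < 0) by `1 - expectile`.
        weight = np.where(diff > 0, expectile, 1 - expectile)
        return weight * diff**2

    print(expectile_loss(np.array([1.0, -1.0])))  # [0.8 0.2] -> V is pushed toward an upper expectile of Q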
-------------------------------------------------------------------------------- 1 | """Implementations of IQL (w/ no Q function) in discrete action spaces.""" 2 | import functools 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import flax 7 | import flax.linen as nn 8 | import optax 9 | import distrax 10 | 11 | from jaxrl_m.common.typing import * 12 | from jaxrl_m.common.common import TrainState, nonpytree_field, target_update 13 | from jaxrl_m.networks.discrete_nets import DiscreteCriticHead 14 | from jaxrl_m.networks.actor_critic_nets import get_latent, ValueCritic 15 | 16 | import ml_collections 17 | 18 | 19 | def expectile_loss(diff, expectile=0.8): 20 | weight = jnp.where(diff > 0, expectile, (1 - expectile)) 21 | return weight * (diff**2) 22 | 23 | 24 | def iql_value_loss(q, v, expectile): 25 | value_loss = expectile_loss(q - v, expectile) 26 | return value_loss.mean(), { 27 | "value_loss": value_loss.mean(), 28 | "v": v.mean(), 29 | } 30 | 31 | 32 | def iql_critic_loss(q, q_target): 33 | critic_loss = jnp.square(q - q_target) 34 | return critic_loss.mean(), { 35 | "critic_loss": critic_loss.mean(), 36 | "q": q.mean(), 37 | } 38 | 39 | 40 | def iql_actor_loss(q, v, dist, actions, temperature=1.0): 41 | a = q - v 42 | 43 | exp_a = jnp.exp(a / temperature) 44 | exp_a = jnp.minimum(exp_a, 100.0) 45 | 46 | log_probs = dist.log_prob(actions) 47 | actor_loss = -(exp_a * log_probs).mean() 48 | 49 | behavior_accuracy = jnp.mean(jnp.equal(dist.mode(), actions)) 50 | 51 | return actor_loss, { 52 | "actor_loss": actor_loss, 53 | "behavior_logprob": log_probs.mean(), 54 | "behavior_accuracy": behavior_accuracy, 55 | "mean a": a.mean(), 56 | "max a": a.max(), 57 | "min a": a.min(), 58 | } 59 | 60 | 61 | class DiscreteIQLMultiplexer(nn.Module): 62 | encoder: nn.Module 63 | networks: Dict[str, nn.Module] 64 | 65 | def __call__(self, observations): 66 | latents = get_latent(self.encoder, observations) 67 | return {k: net(latents) for k, net in self.networks.items()} 68 | 69 | def actor(self, observations): 70 | latents = get_latent(self.encoder, observations) 71 | return self.networks["actor"](latents) 72 | 73 | def value(self, observations): 74 | latents = get_latent(self.encoder, observations) 75 | return self.networks["value"](latents) 76 | 77 | def critic(self, observations): 78 | latents = get_latent(self.encoder, observations) 79 | return self.networks["critic"](latents) 80 | 81 | 82 | class IQLAgent(flax.struct.PyTreeNode): 83 | model: TrainState 84 | target_model: TrainState 85 | config: dict = nonpytree_field() 86 | 87 | value_method: ModuleMethod = nonpytree_field(default="value") 88 | actor_method: ModuleMethod = nonpytree_field(default="actor") 89 | critic_method: ModuleMethod = nonpytree_field( 90 | default="critic" 91 | ) # Unused in current update 92 | 93 | @functools.partial(jax.pmap, axis_name="pmap") 94 | def update(agent, batch: Batch): 95 | def value_loss_fn(params): 96 | nv = agent.target_model( 97 | batch["next_observations"], method=agent.value_method 98 | ) 99 | target_q = batch["rewards"] + agent.config["discount"] * nv * batch["masks"] 100 | v = agent.model( 101 | batch["observations"], params=params, method=agent.value_method 102 | ) 103 | return iql_value_loss(target_q, v, agent.config["expectile"]) 104 | 105 | def actor_loss_fn(params): 106 | nv = agent.target_model( 107 | batch["next_observations"], method=agent.value_method 108 | ) 109 | target_q = batch["rewards"] + agent.config["discount"] * nv * batch["masks"] 110 | 111 | v = 
agent.model(batch["observations"], method=agent.value_method) 112 | 113 | logits = agent.model( 114 | batch["observations"], params=params, method=agent.actor_method 115 | ) 116 | dist = distrax.Categorical(logits=logits) 117 | 118 | return iql_actor_loss( 119 | target_q, v, dist, batch["actions"], agent.config["temperature"] 120 | ) 121 | 122 | def loss_fn(params): 123 | value_loss, value_info = value_loss_fn(params) 124 | actor_loss, actor_info = actor_loss_fn(params) 125 | 126 | return value_loss + actor_loss, {**value_info, **actor_info} 127 | 128 | new_model, info = agent.model.apply_loss_fn( 129 | loss_fn=loss_fn, has_aux=True, pmap_axis="pmap" 130 | ) 131 | new_target_model = target_update( 132 | agent.model, agent.target_model, agent.config["target_update_rate"] 133 | ) 134 | 135 | return agent.replace(model=new_model, target_model=new_target_model), info 136 | 137 | @functools.partial(jax.jit, static_argnames=("argmax")) 138 | def sample_actions(agent, observations, *, seed, temperature=1.0, argmax=False): 139 | logits = agent.model(observations, method=agent.actor_method) 140 | dist = distrax.Categorical(logits=logits / temperature) 141 | 142 | if argmax: 143 | return dist.mode() 144 | else: 145 | return dist.sample(seed=seed) 146 | 147 | 148 | def create_iql_learner( 149 | seed: int, 150 | observations: jnp.ndarray, 151 | n_actions: int, 152 | # Model architecture 153 | encoder_def: nn.Module, 154 | network_kwargs: dict = { 155 | "hidden_dims": [256, 256], 156 | }, 157 | # Optimizer 158 | optim_kwargs: dict = { 159 | "learning_rate": 6e-5, 160 | }, 161 | # Algorithm config 162 | discount=0.95, 163 | expectile=0.9, 164 | temperature=1.0, 165 | target_update_rate=0.002, 166 | **kwargs 167 | ): 168 | 169 | print("Extra kwargs:", kwargs) 170 | 171 | rng = jax.random.PRNGKey(seed) 172 | 173 | model_def = DiscreteIQLMultiplexer( 174 | encoder=encoder_def, 175 | networks={ 176 | "critic": DiscreteCriticHead(n_actions=n_actions, **network_kwargs), 177 | "actor": DiscreteCriticHead(n_actions=n_actions, **network_kwargs), 178 | "value": ValueCritic(**network_kwargs), 179 | }, 180 | ) 181 | 182 | tx = optax.adam(**optim_kwargs) 183 | 184 | params = model_def.init(rng, observations)["params"] 185 | model = TrainState.create(model_def, params, tx=tx) 186 | target_model = TrainState.create(model_def, params) 187 | 188 | config = flax.core.FrozenDict( 189 | dict( 190 | discount=discount, 191 | temperature=temperature, 192 | target_update_rate=target_update_rate, 193 | expectile=expectile, 194 | ) 195 | ) 196 | 197 | return IQLAgent(model, target_model, config) 198 | 199 | 200 | def get_default_config(): 201 | config = ml_collections.ConfigDict( 202 | { 203 | "algo": "iql", 204 | "optim_kwargs": {"learning_rate": 6e-5, "eps": 0.00015}, 205 | "network_kwargs": { 206 | "hidden_dims": (256, 256), 207 | }, 208 | "discount": 0.95, 209 | "expectile": 0.9, 210 | "temperature": 1.0, 211 | "target_update_rate": 0.002, 212 | } 213 | ) 214 | return config 215 | -------------------------------------------------------------------------------- /experiments/configs/offline_contrastive_config.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import flax.linen as nn 3 | 4 | multimodal_config_proto = dict( 5 | device="gpu", # tpu 6 | resume_path=None, 7 | resume_clip_parts_path=None, 8 | resume_step=None, 9 | seed=42, 10 | num_steps=5000, 11 | log_interval=25, 12 | eval_interval=100, 13 | save_interval=100, 14 | 
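# NOTE: save_dir and the dataset paths below point at project GCS buckets;
# override them when running outside that environment.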
save_dir="gs://rail-tpus-andre/logs", 15 | dataset="bridgedata", # "ego4d" or "bridgedata" 16 | lang_encoder=dict( 17 | # type="muse", 18 | type="clip", 19 | clip_variant="openai/clip-vit-base-patch32", 20 | # type="pretrained", 21 | # name="distilbert-base-uncased", 22 | kwargs=dict(mlp_kwargs=None, freeze_encoder=False), 23 | ), 24 | image_encoder=dict( 25 | type="clip", # "encoders", 26 | clip_variant="openai/clip-vit-base-patch32", 27 | clip_use_pretrained_params=True, 28 | # name="resnetv1-34-bridge", 29 | # kwargs=dict( 30 | # pooling_method="avg", 31 | # add_spatial_coordinates=True, 32 | # act="swish", 33 | # ), 34 | ), 35 | agent_kwargs=dict( 36 | learning_rate=3e-5, 37 | text_learning_rate=3e-6, 38 | warmup_steps=2000, 39 | decay_steps=int(2e6), 40 | dropout_rate=0.0, 41 | mlp_kwargs=None, # dict( 42 | # hidden_dims=(512, ), #512, 512), 43 | # activation=nn.relu, 44 | # activate_final=False, 45 | # ), 46 | ), 47 | ego4d_kwargs=dict( 48 | path="gs://rail-tpus-kevin/ego4d-tfrecord", 49 | batch_size=64, 50 | shuffle_buffer_size=10, 51 | take_every_n=50, 52 | cache=False, 53 | seed=32, 54 | train_file_idxs=list(range(32)), 55 | val_file_idxs=list(range(32, 35)), 56 | ), 57 | bridge_data_path="gs://rail-tpus-andre/new_tf", 58 | bridge_batch_size=256, 59 | bridge_val_batch_size=64, 60 | bridge_split_strategy="task", 61 | bridge_split_prop=0.2, 62 | bridge_augment_tasks=True, 63 | dataset_kwargs=dict( 64 | shuffle_buffer_size=25000, 65 | labeled_ony=True, 66 | simple_goal=True, 67 | prefetch_num_batches=20, 68 | relabel_actions=True, 69 | goal_relabel_reached_proportion=0.0, 70 | augment=True, 71 | augment_next_obs_goal_differently=False, 72 | augment_kwargs=dict( 73 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 74 | random_brightness=[0.2], 75 | random_contrast=[0.8, 1.2], 76 | random_saturation=[0.8, 1.2], 77 | random_hue=[0.1], 78 | augment_order=[ 79 | "random_resized_crop", 80 | "random_brightness", 81 | "random_contrast", 82 | "random_saturation", 83 | "random_hue", 84 | ], 85 | ), 86 | ), 87 | ss2_train_path="gs://rail-tpus-andre/something-something/tf_fixed/train/", 88 | ss2_val_path="gs://rail-tpus-andre/something-something/tf_fixed/validation/", 89 | ss2_labels_path="gs://rail-tpus-andre/something-something/labels/", 90 | ss2_batch_size=256, 91 | ss2_val_batch_size=256, 92 | ss2_dataset_kwargs=dict( 93 | shuffle_buffer_size=25000, 94 | prefetch_num_batches=20, 95 | augment=True, 96 | augment_next_obs_goal_differently=False, 97 | augment_kwargs=dict( 98 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 99 | random_brightness=[0.2], 100 | random_contrast=[0.8, 1.2], 101 | random_saturation=[0.8, 1.2], 102 | random_hue=[0.1], 103 | augment_order=[ 104 | "random_resized_crop", 105 | "random_brightness", 106 | "random_contrast", 107 | "random_saturation", 108 | "random_hue", 109 | ], 110 | ), 111 | ), 112 | val_set_path="gs://rail-tpus-andre/bridge_validation/scene1", 113 | # "gs://rail-tpus-andre/bridge_validation/scene2", 114 | # "gs://rail-tpus-andre/bridge_validation/scene3", 115 | # "gs://rail-tpus-andre/bridge_validation/scene4", 116 | # "gs://rail-tpus-andre/bridge_validation/transfer", 117 | ) 118 | 119 | 120 | def update_config(proto=multimodal_config_proto, **kwargs): 121 | result = dict(proto).copy() 122 | for key, value in kwargs.items(): 123 | if type(result.get(key)) == dict: 124 | value = dict(update_config(proto=result[key], **kwargs[key])) 125 | result[key] = value 126 | return ml_collections.ConfigDict(result) 127 | 128 | 129 | def 
get_config(config_string): 130 | possible_structures = { 131 | "lcbc": update_config( 132 | task_encoders=dict(language="resnetv1-18-bridge-task"), 133 | ), 134 | "sg_sl": update_config( 135 | task_encoders=dict( 136 | image="resnetv1-18-bridge", language="resnetv1-18-bridge-task" 137 | ), 138 | ), 139 | "sg_sl_align": update_config( 140 | task_encoders=dict( 141 | image="resnetv1-18-bridge", language="resnetv1-18-bridge-task" 142 | ), 143 | agent_kwargs=dict( 144 | alignment=1.0, 145 | ), 146 | ), 147 | "tiny": update_config( 148 | task_encoders=dict(image="resnetv1-18-bridge", language=""), 149 | agent_kwargs=dict( 150 | alignment=1.0, 151 | ), 152 | pretrained_encoder="distilbert-base-uncased", 153 | ), 154 | "contrastive_tpu": update_config( 155 | device="tpu", 156 | ), 157 | "contrastive_gpu": update_config( 158 | device="gpu", 159 | bridge_batch_size=64, 160 | bridge_val_batch_size=64, 161 | dataset_kwargs=dict( 162 | shuffle_buffer_size=500, 163 | ), 164 | ), 165 | } 166 | 167 | possible_structures["contrastive_tpu_muse"] = update_config( 168 | possible_structures["contrastive_tpu"], 169 | lang_encoder=dict(type="muse"), 170 | ) 171 | 172 | possible_structures["contrastive_tpu_resnet_muse"] = update_config( 173 | possible_structures["contrastive_tpu"], 174 | lang_encoder=dict(type="muse"), 175 | image_encoder=dict( 176 | type="encoders", 177 | name="resnetv1-18-bridge", 178 | kwargs=dict( 179 | pooling_method="avg", 180 | add_spatial_coordinates=True, 181 | act="swish", 182 | ), 183 | mlp_kwargs=dict( 184 | hidden_dims=(512, 512), 185 | activation=nn.relu, 186 | activate_final=False, 187 | ), 188 | ), 189 | ) 190 | 191 | possible_structures["contrastive_tpu_resnet_clip_lang"] = update_config( 192 | possible_structures["contrastive_tpu"], 193 | lang_encoder=dict( 194 | type="clip", 195 | clip_variant="openai/clip-vit-base-patch32", 196 | kwargs=dict(mlp_kwargs=None, freeze_encoder=True), 197 | ), 198 | image_encoder=dict( 199 | type="encoders", 200 | name="resnetv1-18-bridge", 201 | kwargs=dict( 202 | pooling_method="avg", 203 | add_spatial_coordinates=True, 204 | act="swish", 205 | ), 206 | mlp_kwargs=dict( 207 | hidden_dims=(512, 512), 208 | activation=nn.relu, 209 | activate_final=False, 210 | ), 211 | ), 212 | ) 213 | 214 | return possible_structures[config_string] 215 | -------------------------------------------------------------------------------- /jaxrl_m/data/ss2.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Union, Optional, Iterable 2 | 3 | import tensorflow as tf 4 | import os 5 | from jaxrl_m.data.bridge_dataset import BridgeDataset 6 | 7 | EXCLUDE_KEYWORDS = ["Pretending", "Failing"] 8 | INCLUDE_TASKS_ = [ 9 | "Pushing [something] from right to left", 10 | "Moving [something] up", 11 | "Taking [something] out of [something]", 12 | "Pulling [something] from right to left", 13 | "Pushing [something] off of [something]", 14 | "Moving [something] down", 15 | "Pulling [something] out of [something]", 16 | "Pushing [something] from left to right", 17 | "Moving [something] closer to [something]", 18 | "Opening [something]", 19 | "Pulling [something] from left to right", 20 | "Moving [something] and [something] away from each other", 21 | "Putting [something] behind [something]", 22 | "Pushing [something] onto [something]", 23 | ] 24 | 25 | INCLUDE_TASKS = [ 26 | task.replace("[", "").replace("]", "").replace(" ", "_") for task in INCLUDE_TASKS_ 27 | ] 28 | 29 | 30 | class SS2Dataset(BridgeDataset): 31 
| """ 32 | Dataloader for Something-Something V2 dataset. 33 | Imitating BridgeDataset. 34 | """ 35 | 36 | def __init__( 37 | self, 38 | root_data_path: str, 39 | seed: int, 40 | batch_size: int, 41 | shuffle_buffer_size: int = 10000, 42 | prefetch_num_batches: int = 5, 43 | train: bool = True, 44 | augment: bool = False, 45 | augment_next_obs_goal_differently: bool = False, 46 | augment_kwargs: Optional[dict] = None, 47 | clip_preprocessing: bool = False, 48 | ): 49 | self.augment_kwargs = augment_kwargs or {} 50 | self.augment_next_obs_goal_differently = augment_next_obs_goal_differently 51 | 52 | # sub_folders = tf.io.gfile.glob(os.path.join(root_data_path, "*")) 53 | # sub_folders = [ 54 | # sub_folder 55 | # for sub_folder in sub_folders 56 | # if not any([keyword in sub_folder for keyword in EXCLUDE_KEYWORDS]) 57 | # ] 58 | # data_paths = [ 59 | # tf.io.gfile.glob(os.path.join(sub_folder, "*.tfrecord")) 60 | # for sub_folder in sub_folders 61 | # ] 62 | 63 | # datasets = [] 64 | # sizes = [] 65 | # for sub_data_paths in data_paths: 66 | # sub_data_paths = [p for p in sub_data_paths if tf.io.gfile.exists(p)] 67 | # if len(sub_data_paths) == 0: 68 | # continue 69 | # print(f"Found {len(sub_data_paths)} tfrecords in {sub_data_paths[0]}") 70 | # datasets.append(self._construct_tf_dataset(sub_data_paths, seed)) 71 | # sizes.append(len(sub_data_paths)) 72 | 73 | data_paths = tf.io.gfile.glob(os.path.join(root_data_path, "*/*.tfrecord")) 74 | data_paths = [ 75 | data_path 76 | for data_path in data_paths 77 | if not any([keyword in data_path for keyword in EXCLUDE_KEYWORDS]) 78 | ] 79 | data_paths = [ 80 | data_path 81 | for data_path in data_paths 82 | if any([task in data_path for task in INCLUDE_TASKS]) 83 | ] 84 | 85 | # shuffle data paths 86 | data_paths = tf.random.shuffle(data_paths, seed=seed).numpy().tolist() 87 | print(f"Found {len(data_paths)} tfrecords in {root_data_path}") 88 | datasets = [self._construct_tf_dataset(data_paths, seed)] 89 | sizes = [len(data_paths)] 90 | 91 | total_size = sum(sizes) 92 | sample_weights = [size / total_size for size in sizes] 93 | 94 | if train: 95 | for i in range(len(datasets)): 96 | datasets[i] = ( 97 | datasets[i] 98 | .repeat() 99 | .shuffle( 100 | max(1, int(shuffle_buffer_size * sample_weights[i])), seed + i 101 | ) 102 | ) 103 | else: 104 | # TODO ? 105 | for i in range(len(datasets)): 106 | datasets[i] = datasets[i].shuffle( 107 | max(1, int(shuffle_buffer_size * sample_weights[i])), seed + i 108 | ) 109 | 110 | dataset = tf.data.Dataset.sample_from_datasets( 111 | datasets, 112 | sample_weights, 113 | seed=seed, 114 | stop_on_empty_dataset=train, 115 | ) 116 | 117 | if train and augment: 118 | dataset = dataset.enumerate(start=seed) 119 | dataset = dataset.map(self._augment, num_parallel_calls=tf.data.AUTOTUNE) 120 | 121 | dataset = dataset.batch( 122 | batch_size, 123 | num_parallel_calls=tf.data.AUTOTUNE, 124 | drop_remainder=True, 125 | deterministic=not train, 126 | ) 127 | 128 | if clip_preprocessing: 129 | dataset = dataset.map( 130 | self._clip_preprocess, num_parallel_calls=tf.data.AUTOTUNE 131 | ) 132 | 133 | dataset = dataset.prefetch(prefetch_num_batches) 134 | self.tf_dataset = dataset 135 | 136 | def _construct_tf_dataset(self, paths: List[str], seed: int) -> tf.data.Dataset: 137 | """ 138 | Constructs a tf.data.Dataset from a list of paths. 139 | The dataset yields a dictionary of tensors for each transition. 
140 | """ 141 | dataset = tf.data.Dataset.from_tensor_slices(paths).shuffle( 142 | len(paths), seed=seed 143 | ) 144 | dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE) 145 | dataset = dataset.map( 146 | self._decode_example_sgl, num_parallel_calls=tf.data.AUTOTUNE 147 | ) 148 | return dataset 149 | 150 | PROTO_TYPE_SPEC = { 151 | "frames": tf.uint8, 152 | "task_id": tf.uint32, 153 | } 154 | 155 | def _decode_example_sgl(self, example_proto): 156 | features = { 157 | key: tf.io.FixedLenFeature([], tf.string) 158 | for key in self.PROTO_TYPE_SPEC.keys() 159 | } 160 | parsed_features = tf.io.parse_single_example(example_proto, features) 161 | parsed_tensors = { 162 | key: tf.io.parse_tensor(parsed_features[key], dtype) 163 | for key, dtype in self.PROTO_TYPE_SPEC.items() 164 | } 165 | 166 | # only care about initial and final frames 167 | # TODO add goal labeling scheme 168 | # TODO could simplify bridge dataset this way; probably reduces duplicates too 169 | return { 170 | "observations": { 171 | "image": parsed_tensors["frames"][0], 172 | }, 173 | "next_observations": { 174 | "image": parsed_tensors["frames"][-1], 175 | }, 176 | "goals": { 177 | "image": parsed_tensors["frames"][-1], 178 | "language": parsed_tensors["task_id"][0], 179 | }, 180 | "initial_obs": { 181 | "image": parsed_tensors["frames"][0], 182 | }, 183 | } 184 | 185 | 186 | if __name__ == "__main__": 187 | root_data_path = "gs://rail-tpus-andre/something-something/tf_fixed/train/" 188 | seed = 42 189 | batch_size = 64 190 | shuffle_buffer_size = 10 191 | 192 | dataset = SS2Dataset( 193 | root_data_path=root_data_path, 194 | seed=seed, 195 | batch_size=batch_size, 196 | shuffle_buffer_size=shuffle_buffer_size, 197 | train=True, 198 | augment=False, 199 | # augment_kwargs= 200 | ) 201 | 202 | data_iter = dataset.get_iterator() 203 | batch = next(data_iter) 204 | 205 | import matplotlib.pyplot as plt 206 | 207 | plt.figure() 208 | plt.imshow(batch["observations"]["image"][0]) 209 | plt.savefig("obs.png") 210 | print(batch) 211 | 212 | labels_path = "gs://rail-tpus-andre/something-something/tf/labels/" 213 | from jaxrl_m.data.ss2_language import load_mapping, get_encodings 214 | 215 | load_mapping(labels_path) 216 | lang_to_code, code_to_lang = get_encodings() 217 | 218 | print([(k, v) for k, v in lang_to_code.items()][:10]) 219 | print([(k, v) for k, v in code_to_lang.items()][:10]) 220 | -------------------------------------------------------------------------------- /experiments/sim_offline_gc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from jaxrl_m.common.wandb import WandBLogger 3 | from jaxrl_m.common.common import shard_batch 4 | from jaxrl_m.data.bridge_dataset import BridgeDataset 5 | from jaxrl_m.utils.timer_utils import Timer 6 | from jaxrl_m.vision import encoders 7 | from jaxrl_m.agents import agents 8 | from jaxrl_m.utils.train_utils import load_recorded_video 9 | from jaxrl_m.utils.sim_utils import make_mujoco_gc_env 10 | from jaxrl_m.common.evaluation import evaluate_gc, supply_rng 11 | import tensorflow as tf 12 | import tqdm 13 | import jax 14 | import jax.numpy as jnp 15 | from absl import app, flags, logging 16 | from ml_collections import config_flags 17 | import numpy as np 18 | from flax.training import checkpoints 19 | 20 | try: 21 | from jax_smi import initialise_tracking # type: ignore 22 | 23 | initialise_tracking() 24 | except ImportError: 25 | pass 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | 
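# Example invocation (config paths and names below are illustrative only):
#   python experiments/sim_offline_gc.py \
#       --config path/to/train_config.py:my_variant \
#       --bridgedata_config path/to/bridgedata_config.py:all \
#       --name my_sim_run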
flags.DEFINE_string("name", "", "Experiment name.") 30 | flags.DEFINE_bool("debug", False, "Debug config") 31 | 32 | config_flags.DEFINE_config_file( 33 | "config", 34 | None, 35 | "File path to the training hyperparameter configuration.", 36 | lock_config=False, 37 | ) 38 | 39 | config_flags.DEFINE_config_file( 40 | "bridgedata_config", 41 | None, 42 | "File path to the bridgedata configuration.", 43 | lock_config=False, 44 | ) 45 | 46 | 47 | def main(_): 48 | devices = jax.local_devices() 49 | num_devices = len(devices) 50 | assert FLAGS.config.batch_size % num_devices == 0 51 | 52 | # prevent tensorflow from using GPUs 53 | tf.config.set_visible_devices([], "GPU") 54 | 55 | # set up wandb and logging 56 | wandb_config = WandBLogger.get_default_config() 57 | wandb_config.update( 58 | { 59 | "project": "jaxrl_m_sim", 60 | "exp_descriptor": FLAGS.name, 61 | } 62 | ) 63 | wandb_logger = WandBLogger( 64 | wandb_config=wandb_config, 65 | variant=FLAGS.config.to_dict(), 66 | debug=FLAGS.debug, 67 | ) 68 | 69 | save_dir = tf.io.gfile.join( 70 | FLAGS.config.save_dir, 71 | wandb_logger.config.project, 72 | f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}", 73 | ) 74 | 75 | # load action metadata 76 | with tf.io.gfile.GFile( 77 | tf.io.gfile.join(FLAGS.config.data_path, "train/metadata.npy"), "rb" 78 | ) as f: 79 | action_metadata = np.load(f, allow_pickle=True).item() 80 | 81 | # load eval goals 82 | with tf.io.gfile.GFile( 83 | tf.io.gfile.join(FLAGS.config.data_path, "val/eval_goals.npy"), "rb" 84 | ) as f: 85 | eval_goals = np.load(f, allow_pickle=True).item() 86 | 87 | # create sim environment 88 | eval_env = make_mujoco_gc_env( 89 | env_name=FLAGS.config.env_name, 90 | max_episode_steps=FLAGS.config.max_episode_steps, 91 | action_metadata=action_metadata, 92 | save_video=FLAGS.config.save_video, 93 | save_video_dir=tf.io.gfile.join(save_dir, "videos"), 94 | save_video_prefix="eval", 95 | goals=eval_goals, 96 | ) 97 | 98 | # load datasets 99 | train_paths = tf.io.gfile.glob(f"{FLAGS.config.data_path}/train/*.tfrecord") 100 | val_paths = tf.io.gfile.glob(f"{FLAGS.config.data_path}/val/*.tfrecord") 101 | train_data = BridgeDataset( 102 | train_paths, 103 | FLAGS.config.seed, 104 | batch_size=FLAGS.config.batch_size, 105 | num_devices=num_devices, 106 | train=True, 107 | action_metadata=action_metadata, 108 | relabel_actions=False, 109 | **FLAGS.config.dataset_kwargs, 110 | ) 111 | val_data = BridgeDataset( 112 | val_paths, 113 | FLAGS.config.seed, 114 | batch_size=FLAGS.config.batch_size, 115 | action_metadata=action_metadata, 116 | relabel_actions=False, 117 | train=False, 118 | **FLAGS.config.dataset_kwargs, 119 | ) 120 | train_data_iter = train_data.get_iterator() 121 | 122 | example_batch = next(train_data_iter) 123 | logging.info(f"Batch size: {example_batch['observations']['image'].shape[0]}") 124 | logging.info(f"Number of devices: {num_devices}") 125 | logging.info( 126 | f"Batch size per device: {example_batch['observations']['image'].shape[0] // num_devices}" 127 | ) 128 | 129 | # we shard the leading dimension (batch dimension) accross all devices evenly 130 | sharding = jax.sharding.PositionalSharding(devices) 131 | example_batch = shard_batch(example_batch, sharding) 132 | 133 | # define encoder 134 | encoder_def = encoders[FLAGS.config.encoder](**FLAGS.config.encoder_kwargs) 135 | 136 | # initialize agent 137 | rng = jax.random.PRNGKey(FLAGS.config.seed) 138 | rng, construct_rng = jax.random.split(rng) 139 | agent = agents[FLAGS.config.agent].create( 
140 | rng=construct_rng, 141 | observations=example_batch["observations"], 142 | goals=example_batch["goals"], 143 | actions=example_batch["actions"], 144 | encoder_def=encoder_def, 145 | **FLAGS.config.agent_kwargs, 146 | ) 147 | if (resume_path := FLAGS.config.get("resume_path", None)) is not None: 148 | agent = checkpoints.restore_checkpoint(resume_path, target=agent) 149 | # replicate agent across devices 150 | # need the jnp.array to avoid a bug where device_put doesn't recognize primitives 151 | agent = jax.device_put(jax.tree_map(jnp.array, agent), sharding.replicate()) 152 | 153 | timer = Timer() 154 | for i in tqdm.tqdm(range(int(FLAGS.config.num_steps))): 155 | timer.tick("total") 156 | 157 | timer.tick("dataset") 158 | batch = shard_batch(next(train_data_iter), sharding) 159 | timer.tock("dataset") 160 | 161 | timer.tick("train") 162 | agent, update_info = agent.update(batch) 163 | timer.tock("train") 164 | 165 | if i % FLAGS.config.eval_interval == 0: 166 | logging.info("Validation...") 167 | timer.tick("val") 168 | metrics = [] 169 | val_batch_count = 0 170 | for batch in val_data.get_iterator(): 171 | metrics.append(agent.get_debug_metrics(batch)) 172 | val_batch_count += 1 173 | if val_batch_count >= FLAGS.config.num_val_batches: 174 | break 175 | metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) 176 | metrics["log_probs"] *= -1 177 | wandb_logger.log({"validation": metrics}, step=i) 178 | timer.tock("val") 179 | 180 | rng, policy_key = jax.random.split(rng) 181 | policy_fn = supply_rng( 182 | partial(agent.sample_actions, argmax=FLAGS.config.deterministic_eval), 183 | rng=policy_key, 184 | ) 185 | 186 | logging.info("Evaluating...") 187 | timer.tick("evaluation") 188 | eval_env.goal_sampler = eval_goals 189 | eval_env.start_recording( 190 | FLAGS.config.num_episodes_per_video, FLAGS.config.num_episodes_per_row 191 | ) 192 | eval_info = evaluate_gc( 193 | policy_fn, 194 | eval_env, 195 | num_episodes=FLAGS.config.eval_episodes, 196 | return_trajectories=False, 197 | ) 198 | wandb_logger.log({"evaluation": eval_info}, step=i) 199 | if FLAGS.config.save_video: 200 | eval_video = load_recorded_video(video_path=eval_env.current_save_path) 201 | wandb_logger.log({"evaluation/video": eval_video}, step=i) 202 | timer.tock("evaluation") 203 | 204 | if i % FLAGS.config.save_interval == 0: 205 | logging.info("Saving checkpoint...") 206 | checkpoint_path = checkpoints.save_checkpoint( 207 | save_dir, agent, step=i, keep=1e6 208 | ) 209 | logging.info("Saved checkpoint to %s", checkpoint_path) 210 | 211 | timer.tock("total") 212 | 213 | if i % FLAGS.config.log_interval == 0: 214 | update_info = jax.device_get(update_info) 215 | wandb_logger.log({"training": update_info}, step=i) 216 | 217 | wandb_logger.log({"timer": timer.get_average_times()}, step=i) 218 | 219 | 220 | if __name__ == "__main__": 221 | app.run(main) 222 | -------------------------------------------------------------------------------- /scripts/bridgedata_numpy_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converts data from the BridgeData numpy format to TFRecord format. 3 | 4 | Consider the following directory structure for the input data: 5 | 6 | bridgedata_numpy/ 7 | rss/ 8 | toykitchen2/ 9 | set_table/ 10 | 00/ 11 | train/ 12 | out.npy 13 | val/ 14 | out.npy 15 | icra/ 16 | ... 17 | 18 | The --depth parameter controls how much of the data to process at the 19 | --input_path; for example, if --depth=5, then --input_path should be 20 | "bridgedata_numpy", and all data will be processed.
If --depth=3, then 21 | --input_path should be "bridgedata_numpy/rss/toykitchen2", and only data 22 | under "toykitchen2" will be processed. 23 | 24 | The same directory structure will be replicated under --output_path. For 25 | example, in the second case, the output will be written to 26 | "{output_path}/set_table/00/...". 27 | 28 | Can read/write directly from/to Google Cloud Storage. 29 | 30 | Written by Kevin Black (kvablack@berkeley.edu). 31 | """ 32 | from absl import app, flags, logging 33 | import numpy as np 34 | import os 35 | import tqdm 36 | import tensorflow as tf 37 | from multiprocessing import Pool, Manager 38 | from jaxrl_m.data.language import lang_encode, load_mapping, flush_mapping 39 | from jaxrl_m.vision.clip import create_from_checkpoint, process_image, process_text 40 | 41 | FLAGS = flags.FLAGS 42 | 43 | flags.DEFINE_string("input_path", None, "Input path", required=True) 44 | flags.DEFINE_string("output_path", None, "Output path", required=True) 45 | flags.DEFINE_integer( 46 | "depth", 47 | 5, 48 | "Number of directories deep to traverse. Looks for {input_path}/dir_1/dir_2/.../dir_{depth-1}/train/out.npy", 49 | ) 50 | flags.DEFINE_bool("overwrite", True, "Overwrite existing files") 51 | flags.DEFINE_integer("num_workers", 8, "Number of threads to use") 52 | flags.DEFINE_string( 53 | "model_ckpt", None, "Path to model checkpoint for writing embeddings" 54 | ) 55 | 56 | 57 | def tensor_feature(value): 58 | return tf.train.Feature( 59 | bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()]) 60 | ) 61 | 62 | 63 | def process(path): 64 | global clip 65 | 66 | with tf.io.gfile.GFile(path, "rb") as f: 67 | arr = np.load(f, allow_pickle=True) 68 | dirname = os.path.dirname(os.path.abspath(path)) 69 | outpath = os.path.join(FLAGS.output_path, *dirname.split(os.sep)[-FLAGS.depth :]) 70 | outpath = f"{outpath}/out.tfrecord" 71 | 72 | if tf.io.gfile.exists(outpath): 73 | if FLAGS.overwrite: 74 | logging.info(f"Deleting {outpath}") 75 | try: 76 | tf.io.gfile.rmtree(outpath) 77 | except: 78 | pass 79 | else: 80 | logging.info(f"Skipping {outpath}") 81 | return 82 | 83 | if len(arr) == 0: 84 | logging.info(f"Skipping {path}, empty") 85 | return 86 | 87 | tf.io.gfile.makedirs(os.path.dirname(outpath)) 88 | 89 | # get rid of the confidence scores 90 | def clean_lang(lang): 91 | if lang is None: 92 | lang = "" 93 | lang = lang.strip() 94 | lines = lang.split("\n") 95 | lines = [l for l in lines if not "confidence" in l] 96 | # print("\n".join(lines)) 97 | return "\n".join(lines) 98 | 99 | with tf.io.TFRecordWriter(outpath) as writer: 100 | text_batch = [] 101 | img_batch = [] 102 | if FLAGS.model_ckpt: 103 | for traj in arr: 104 | text = traj["language"][0] 105 | text = clean_lang(text) 106 | if text is None: 107 | text = "placeholder" 108 | else: 109 | text = text.split("\n")[0] # multiple labels just take first 110 | 111 | s0 = traj["observations"][0]["images0"] 112 | g = traj["observations"][-1]["images0"] 113 | s0 = process_image(s0) 114 | g = process_image(g) 115 | img = np.concatenate([s0, g], axis=-1) 116 | 117 | text_batch.append(text) 118 | img_batch.append(img) 119 | text_batch = process_text(text_batch) 120 | img_batch = np.concatenate(img_batch, axis=0) 121 | clip_output = clip(pixel_values=img_batch, **text_batch) 122 | text_embeds = clip_output["text_embeds"] 123 | image_embeds = clip_output["image_embeds"] 124 | 125 | for i, traj in enumerate(arr): 126 | if FLAGS.model_ckpt: 127 | traj_text_embed = text_embeds[i : i + 1] 128 | traj_text_embed = 
np.repeat( 129 | traj_text_embed, len(traj["actions"]), axis=0 130 | ) 131 | traj_image_embed = image_embeds[i : i + 1] 132 | traj_image_embed = np.repeat( 133 | traj_image_embed, len(traj["actions"]), axis=0 134 | ) 135 | 136 | with lock: 137 | encoded_language = tensor_feature( 138 | np.array( 139 | [ 140 | lang_encode(clean_lang(x) if x else None) 141 | for x in traj["language"] 142 | ] 143 | ) 144 | ) 145 | truncates = np.zeros(len(traj["actions"]), dtype=np.bool_) 146 | truncates[-1] = True 147 | feature = { 148 | "observations/images0": tensor_feature( 149 | np.array( 150 | [o["images0"] for o in traj["observations"]], 151 | dtype=np.uint8, 152 | ) 153 | ), 154 | "observations/state": tensor_feature( 155 | np.array( 156 | [o["state"] for o in traj["observations"]], 157 | dtype=np.float32, 158 | ) 159 | ), 160 | "next_observations/images0": tensor_feature( 161 | np.array( 162 | [o["images0"] for o in traj["next_observations"]], 163 | dtype=np.uint8, 164 | ) 165 | ), 166 | "next_observations/state": tensor_feature( 167 | np.array( 168 | [o["state"] for o in traj["next_observations"]], 169 | dtype=np.float32, 170 | ) 171 | ), 172 | "actions": tensor_feature(np.array(traj["actions"], dtype=np.float32)), 173 | "terminals": tensor_feature( 174 | np.zeros(len(traj["actions"]), dtype=np.bool_) 175 | ), 176 | "truncates": tensor_feature(truncates), 177 | "language": encoded_language, 178 | } 179 | if FLAGS.model_ckpt: 180 | feature["text_embed"] = tensor_feature(traj_text_embed) 181 | feature["image_embed"] = tensor_feature(traj_image_embed) 182 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 183 | writer.write(example.SerializeToString()) 184 | 185 | 186 | def main(_): 187 | global clip 188 | if FLAGS.model_ckpt: 189 | clip = create_from_checkpoint(FLAGS.model_ckpt) 190 | 191 | global lock 192 | assert FLAGS.depth >= 1 193 | manager = Manager() 194 | 195 | lock = manager.Lock() 196 | flush_mapping(FLAGS.output_path) 197 | load_mapping(FLAGS.output_path, manager.dict) 198 | paths = tf.io.gfile.glob( 199 | tf.io.gfile.join(FLAGS.input_path, *(["*?"] * (FLAGS.depth - 1))) 200 | ) 201 | code_path = tf.io.gfile.join(FLAGS.output_path, "language_encodings.json") 202 | paths = [os.path.join(p, "train/out.npy") for p in paths] + [ 203 | os.path.join(p, "val/out.npy") for p in paths 204 | ] 205 | # with Pool(FLAGS.num_workers) as p: 206 | # list(tqdm.tqdm(p.imap(process, paths), total=len(paths))) 207 | error_paths = [] 208 | for path in tqdm.tqdm(paths): 209 | #try: 210 | process(path) 211 | # break 212 | #except Exception as e: 213 | # error_paths.append(path) 214 | # print("Error on path", path) 215 | 216 | flush_mapping(FLAGS.output_path) 217 | print(error_paths) 218 | with tf.io.gfile.GFile( 219 | os.path.join(FLAGS.output_path, "error_paths.txt"), "w" 220 | ) as f: 221 | f.write("\n".join(error_paths)) 222 | 223 | 224 | if __name__ == "__main__": 225 | app.run(main) 226 | -------------------------------------------------------------------------------- /jaxrl_m/networks/actor_critic_nets.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import distrax 4 | import jax 5 | import jax.numpy as jnp 6 | import flax.linen as nn 7 | 8 | from jaxrl_m.common.typing import Callable 9 | from jaxrl_m.common.typing import Dict 10 | from jaxrl_m.common.typing import Tuple 11 | from jaxrl_m.common.typing import Optional 12 | from jaxrl_m.common.typing import Sequence 13 | from jaxrl_m.common.common import 
default_init 14 | from jaxrl_m.common.common import MLP 15 | 16 | 17 | class ActorCriticWrapper(nn.Module): 18 | """ 19 | Generic wrapper for all the networks involved in actor-critic type methods. 20 | This includes an actor, a critic (optional), and a value network (optional). 21 | The critic network takes in observations and actions, while the actor and 22 | value networks only take in observations. 23 | 24 | During initialization, the default `__call__` method can be used, which runs 25 | all of the networks. Later on, the networks can be called individually using 26 | `module_def.apply({"params": params}, ..., method="...")`. 27 | """ 28 | 29 | encoders: Dict[str, nn.Module] 30 | networks: Dict[str, nn.Module] 31 | 32 | def setup(self): 33 | assert self.encoders.keys() == self.networks.keys() 34 | assert self.networks.keys() in [ 35 | {"actor"}, 36 | {"actor", "critic"}, 37 | {"actor", "critic", "value"}, 38 | ] 39 | 40 | def actor(self, observations, temperature=1.0, train=False): 41 | return self.networks["actor"]( 42 | self.encoders["actor"](observations), temperature=temperature, train=train 43 | ) 44 | 45 | def critic(self, observations, actions, train=False): 46 | return self.networks["critic"]( 47 | self.encoders["critic"](observations), 48 | actions, 49 | train=train, 50 | ) 51 | 52 | def value(self, observations, train=False): 53 | return self.networks["value"](self.encoders["value"](observations), train=train) 54 | 55 | def __call__(self, observations, actions): 56 | rets = {"actor": self.actor(observations)} 57 | if "critic" in self.networks: 58 | rets["critic"] = self.critic(observations, actions) 59 | if "value" in self.networks: 60 | rets["value"] = self.value(observations) 61 | return rets 62 | 63 | 64 | class MultimodalActorCriticWrapper(nn.Module): 65 | """ 66 | Actor-critic wrapper that maintains different encoders for different modalities. 
67 | """ 68 | 69 | encoders: Dict[str, Tuple[nn.Module, Dict[str, nn.Module]]] 70 | networks: Dict[str, nn.Module] 71 | share_encoders: bool = False 72 | 73 | def setup(self): 74 | assert self.encoders.keys() == self.networks.keys() 75 | assert self.networks.keys() in [ 76 | {"actor"}, 77 | {"actor", "critic"}, 78 | {"actor", "critic", "value"}, 79 | ] 80 | self.contrastive_temp = self.param( 81 | "contrastive_temp", nn.initializers.constant(jnp.log(0.07)), () 82 | ) 83 | 84 | def actor(self, observations, modality, temperature=1.0, freeze_mask=None, train=False): 85 | # TODO ugly 86 | if not self.share_encoders: 87 | enc_modality = modality 88 | else: 89 | enc_modality = "image" 90 | action_enc, task_enc = self.encoders["actor"][1][enc_modality]( 91 | observations, self.encoders["actor"][0], override_modality=modality 92 | ) 93 | if action_enc is not None: 94 | detach_action_enc = jax.lax.stop_gradient(action_enc) 95 | if freeze_mask is None: 96 | mask = 0 97 | else: 98 | mask = freeze_mask[:, None] 99 | action_enc = mask * detach_action_enc + (1 - mask) * action_enc 100 | return ( 101 | self.networks["actor"]( 102 | action_enc, temperature=temperature, train=train 103 | ), 104 | task_enc, 105 | ) 106 | else: 107 | return None, task_enc 108 | 109 | def critic(self, observations, actions, modality, train=False): 110 | return self.networks["critic"]( 111 | self.encoders["critic"][modality](observations), 112 | actions, 113 | train=train, 114 | ) 115 | 116 | def value(self, observations, modality, train=False): 117 | return self.networks["value"]( 118 | self.encoders["value"][modality](observations), train=train 119 | ) 120 | 121 | def __call__(self, observations, actions, modalities, **kwargs): 122 | rets = {} 123 | for modality in modalities: 124 | rets[modality] = {"actor": self.actor(observations, modality, **kwargs)} 125 | # if "critic" in self.networks: 126 | # rets[modality]["critic"] = self.critic(observations, actions, modality) 127 | # if "value" in self.networks: 128 | # rets[modality]["value"] = self.value(observations, modality) 129 | return rets 130 | 131 | 132 | class ValueCritic(nn.Module): 133 | hidden_dims: Sequence[int] 134 | dropout: float = 0.0 135 | 136 | @nn.compact 137 | def __call__(self, observations: jnp.ndarray, train: bool = False) -> jnp.ndarray: 138 | critic = MLP( 139 | (*self.hidden_dims, 1), 140 | activate_final=False, 141 | small_init_final=False, 142 | dropout=self.dropout, 143 | )(observations, train=train) 144 | return jnp.squeeze(critic, -1) 145 | 146 | 147 | class Critic(nn.Module): 148 | hidden_dims: Sequence[int] 149 | activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish 150 | dropout: float = 0.0 151 | 152 | @nn.compact 153 | def __call__( 154 | self, observations: jnp.ndarray, actions: jnp.ndarray, train: bool = False 155 | ) -> jnp.ndarray: 156 | inputs = jnp.concatenate([observations, actions], -1) 157 | critic = MLP( 158 | (*self.hidden_dims, 1), 159 | activate_final=False, 160 | small_init_final=False, 161 | activation=self.activation, 162 | dropout=self.dropout, 163 | )(inputs, train=train) 164 | return jnp.squeeze(critic, -1) 165 | 166 | 167 | def ensemblize(cls, num_qs, out_axes=0): 168 | return nn.vmap( 169 | cls, 170 | variable_axes={"params": 0}, 171 | split_rngs={"params": True}, 172 | in_axes=None, 173 | out_axes=out_axes, 174 | axis_size=num_qs, 175 | ) 176 | 177 | 178 | class Policy(nn.Module): 179 | hidden_dims: Sequence[int] 180 | action_dim: int 181 | log_std_min: Optional[float] = -20 182 | log_std_max: Optional[float] = 2 
183 | tanh_squash_distribution: bool = False 184 | fixed_std: Optional[jnp.ndarray] = None 185 | state_dependent_std: bool = True 186 | dropout: float = 0.0 187 | 188 | @nn.compact 189 | def __call__( 190 | self, observations: jnp.ndarray, temperature: float = 1.0, train: bool = False 191 | ) -> distrax.Distribution: 192 | outputs = MLP( 193 | self.hidden_dims, 194 | activate_final=True, 195 | small_init_final=False, 196 | dropout=self.dropout, 197 | )(observations, train=train) 198 | 199 | means = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))(outputs) 200 | if self.fixed_std is None: 201 | if self.state_dependent_std: 202 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init(1e-2))( 203 | outputs 204 | ) 205 | else: 206 | log_stds = self.param( 207 | "log_stds", nn.initializers.zeros, (self.action_dim,) 208 | ) 209 | else: 210 | log_stds = jnp.log(jnp.array(self.fixed_std)) 211 | 212 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) / temperature 213 | 214 | if self.tanh_squash_distribution: 215 | distribution = TanhMultivariateNormalDiag( 216 | loc=means, scale_diag=jnp.exp(log_stds) 217 | ) 218 | else: 219 | distribution = distrax.MultivariateNormalDiag( 220 | loc=means, scale_diag=jnp.exp(log_stds) 221 | ) 222 | 223 | return distribution 224 | 225 | 226 | class TanhMultivariateNormalDiag(distrax.Transformed): 227 | def __init__( 228 | self, 229 | loc: jnp.ndarray, 230 | scale_diag: jnp.ndarray, 231 | low: Optional[jnp.ndarray] = None, 232 | high: Optional[jnp.ndarray] = None, 233 | ): 234 | distribution = distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag) 235 | 236 | layers = [] 237 | 238 | if not (low is None or high is None): 239 | 240 | def rescale_from_tanh(x): 241 | x = (x + 1) / 2 # (-1, 1) => (0, 1) 242 | return x * (high - low) + low 243 | 244 | def forward_log_det_jacobian(x): 245 | high_ = jnp.broadcast_to(high, x.shape) 246 | low_ = jnp.broadcast_to(low, x.shape) 247 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1) 248 | 249 | layers.append( 250 | distrax.Lambda( 251 | rescale_from_tanh, 252 | forward_log_det_jacobian=forward_log_det_jacobian, 253 | event_ndims_in=1, 254 | event_ndims_out=1, 255 | ) 256 | ) 257 | 258 | layers.append(distrax.Block(distrax.Tanh(), 1)) 259 | 260 | bijector = distrax.Chain(layers) 261 | 262 | super().__init__(distribution=distribution, bijector=bijector) 263 | 264 | def mode(self) -> jnp.ndarray: 265 | return self.bijector.forward(self.distribution.mode()) 266 | -------------------------------------------------------------------------------- /jaxrl_m/common/common.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple, Union 2 | import flax 3 | from flax.core import FrozenDict 4 | import flax.linen as nn 5 | from flax import struct 6 | import jax 7 | import jax.numpy as jnp 8 | import optax 9 | import functools 10 | 11 | from jaxrl_m.common.typing import Callable, PRNGKey 12 | from jaxrl_m.common.typing import Optional 13 | from jaxrl_m.common.typing import Params 14 | from jaxrl_m.common.typing import Sequence 15 | 16 | 17 | nonpytree_field = functools.partial(flax.struct.field, pytree_node=False) 18 | 19 | 20 | def shard_batch(batch, sharding): 21 | """Shards a batch across devices along its first dimension. 22 | 23 | Args: 24 | batch: A pytree of arrays. 25 | sharding: A jax Sharding object with shape (num_devices,). 
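Example (a minimal sketch mirroring how the training scripts use this helper; variable names are illustrative): sharding = jax.sharding.PositionalSharding(jax.local_devices()); sharded = shard_batch(batch, sharding). Each device then holds an equal slice of the leading (batch) dimension of every array in the pytree.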
26 | """ 27 | return jax.tree_map( 28 | lambda x: jax.device_put( 29 | x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1))) 30 | ), 31 | batch, 32 | ) 33 | 34 | 35 | def default_init(scale: Optional[float] = 1.0): 36 | return nn.initializers.variance_scaling(scale, "fan_avg", "uniform") 37 | 38 | 39 | def orthogonal_init(scale: float = jnp.sqrt(2)): 40 | return nn.initializers.orthogonal(scale) 41 | 42 | 43 | def final_layer_init(init_w: float = 1e-3): 44 | return nn.initializers.uniform(-init_w, init_w) 45 | 46 | 47 | class MLP(nn.Module): 48 | hidden_dims: Sequence[int] 49 | activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish 50 | activate_final: bool = False 51 | small_init_final: bool = False 52 | dropout: float = 0.0 53 | 54 | def setup(self): 55 | self.dropout_layer = nn.Dropout(rate=self.dropout) 56 | 57 | layers = [ 58 | nn.Dense(size, kernel_init=default_init()) for size in self.hidden_dims[:-1] 59 | ] 60 | if self.small_init_final: 61 | layers.append( 62 | nn.Dense( 63 | self.hidden_dims[-1], 64 | kernel_init=final_layer_init(), 65 | bias_init=nn.initializers.constant(0), 66 | ) 67 | ) 68 | else: 69 | layers.append(nn.Dense(self.hidden_dims[-1], kernel_init=default_init())) 70 | self.layers = layers 71 | 72 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray: 73 | for i, size in enumerate(self.hidden_dims): 74 | x = self.layers[i](x) 75 | if i + 1 < len(self.hidden_dims) or self.activate_final: 76 | x = self.activation(x) 77 | x = self.dropout_layer(x, deterministic=not train) 78 | return x 79 | 80 | 81 | class TaskMLP(MLP): 82 | def __call__(self, observations: jnp.ndarray, task: jnp.ndarray, train: bool=False): 83 | return super().__call__(task, train) 84 | 85 | class JaxRLTrainState(struct.PyTreeNode): 86 | """ 87 | Custom TrainState class to replace `flax.training.train_state.TrainState`. 88 | 89 | Adds support for holding target params and updating them via polyak 90 | averaging. Adds the ability to hold an rng key for dropout. 91 | 92 | Also generalizes the TrainState to support an arbitrary pytree of 93 | optimizers, `txs`. When `apply_gradients()` is called, the `grads` argument 94 | must have `txs` as a prefix. This is backwards-compatible, meaning `txs` can 95 | be a single optimizer and `grads` can be a single tree with the same 96 | structure as `self.params`. 97 | 98 | Also adds a convenience method `apply_loss_fns` that takes a pytree of loss 99 | functions with the same structure as `txs`, computes gradients, and applies 100 | them using `apply_gradients`. 101 | 102 | Attributes: 103 | step: The current training step. 104 | apply_fn: The function used to apply the model. 105 | params: The model parameters. 106 | target_params: The target model parameters. 107 | txs: The optimizer or pytree of optimizers. 108 | opt_states: The optimizer state or pytree of optimizer states. 109 | rng: The internal rng state. 110 | """ 111 | 112 | step: int 113 | apply_fn: Callable = struct.field(pytree_node=False) 114 | params: Params 115 | target_params: Params 116 | txs: Any = struct.field(pytree_node=False) 117 | opt_states: Any 118 | rng: PRNGKey 119 | 120 | @staticmethod 121 | def _tx_tree_map(*args, **kwargs): 122 | return jax.tree_map( 123 | *args, 124 | is_leaf=lambda x: isinstance(x, optax.GradientTransformation), 125 | **kwargs, 126 | ) 127 | 128 | def target_update(self, tau: float) -> "JaxRLTrainState": 129 | """ 130 | Performs an update of the target params via polyak averaging. 
The new 131 | target params are given by: 132 | 133 | new_target_params = tau * params + (1 - tau) * target_params 134 | """ 135 | new_target_params = jax.tree_map( 136 | lambda p, tp: p * tau + tp * (1 - tau), self.params, self.target_params 137 | ) 138 | return self.replace(target_params=new_target_params) 139 | 140 | def apply_gradients(self, *, grads: Any) -> "JaxRLTrainState": 141 | """ 142 | Only difference from flax's TrainState is that `grads` must have 143 | `self.txs` as a tree prefix (i.e. where `self.txs` has a leaf, `grads` 144 | has a subtree with the same structure as `self.params`.) 145 | """ 146 | updates_and_new_states = self._tx_tree_map( 147 | lambda tx, opt_state, grad: tx.update(grad, opt_state, self.params), 148 | self.txs, 149 | self.opt_states, 150 | grads, 151 | ) 152 | updates = self._tx_tree_map(lambda _, x: x[0], self.txs, updates_and_new_states) 153 | new_opt_states = self._tx_tree_map( 154 | lambda _, x: x[1], self.txs, updates_and_new_states 155 | ) 156 | 157 | # not the cleanest, I know, but this flattens the leaves of `updates` 158 | # into a list where leaves are defined by `self.txs` 159 | updates_flat = [] 160 | self._tx_tree_map( 161 | lambda _, update: updates_flat.append(update), self.txs, updates 162 | ) 163 | 164 | # apply all the updates additively 165 | updates_acc = jax.tree_map( 166 | lambda *xs: jnp.sum(jnp.array(xs), axis=0), *updates_flat 167 | ) 168 | new_params = optax.apply_updates(self.params, updates_acc) 169 | 170 | return self.replace( 171 | step=self.step + 1, params=new_params, opt_states=new_opt_states 172 | ) 173 | 174 | def apply_loss_fns( 175 | self, loss_fns: Any, pmap_axis: str = None, has_aux: bool = False 176 | ) -> Union["JaxRLTrainState", Tuple["JaxRLTrainState", Any]]: 177 | """ 178 | Convenience method to compute gradients based on `self.params` and apply 179 | them using `apply_gradients`. `loss_fns` must have the same structure as 180 | `txs`, and each leaf must be a function that takes two arguments: 181 | `params` and `rng`. 182 | 183 | This method automatically provides fresh rng to each loss function and 184 | updates this train state's internal rng key. 185 | 186 | Args: 187 | loss_fns: loss function or pytree of loss functions with same 188 | structure as `self.txs`. Each loss function must take `params` 189 | as the first argument and `rng` as the second argument, and return 190 | a scalar value. 191 | pmap_axis: if not None, gradients (and optionally auxiliary values) 192 | will be averaged over this axis 193 | has_aux: if True, each `loss_fn` returns a tuple of (loss, aux) where 194 | `aux` is a pytree of auxiliary values to be returned by this 195 | method. 196 | 197 | Returns: 198 | If `has_aux` is True, returns a tuple of (new_train_state, aux). 199 | Otherwise, returns the new train state. 
200 | """ 201 | # create a pytree of rngs with the same structure as `loss_fns` 202 | treedef = jax.tree_util.tree_structure(loss_fns) 203 | new_rng, *rngs = jax.random.split(self.rng, treedef.num_leaves + 1) 204 | rngs = jax.tree_util.tree_unflatten(treedef, rngs) 205 | 206 | # compute gradients 207 | grads_and_aux = jax.tree_map( 208 | lambda loss_fn, rng: jax.grad(loss_fn, has_aux=has_aux)(self.params, rng), 209 | loss_fns, 210 | rngs, 211 | ) 212 | 213 | # update rng state 214 | self = self.replace(rng=new_rng) 215 | 216 | # average across devices if necessary 217 | if pmap_axis is not None: 218 | grads_and_aux = jax.lax.pmean(grads_and_aux, axis_name=pmap_axis) 219 | 220 | if has_aux: 221 | grads = jax.tree_map(lambda _, x: x[0], loss_fns, grads_and_aux) 222 | aux = jax.tree_map(lambda _, x: x[1], loss_fns, grads_and_aux) 223 | return self.apply_gradients(grads=grads), aux 224 | else: 225 | return self.apply_gradients(grads=grads_and_aux) 226 | 227 | @classmethod 228 | def create( 229 | cls, *, apply_fn, params, txs, target_params=None, rng=jax.random.PRNGKey(0) 230 | ): 231 | """ 232 | Initializes a new train state. 233 | 234 | Args: 235 | apply_fn: The function used to apply the model, typically `model_def.apply`. 236 | params: The model parameters, typically from `model_def.init`. 237 | txs: The optimizer or pytree of optimizers. 238 | target_params: The target model parameters. 239 | rng: The rng key used to initialize the rng chain for `apply_loss_fns`. 240 | """ 241 | return cls( 242 | step=0, 243 | apply_fn=apply_fn, 244 | params=params, 245 | target_params=target_params, 246 | txs=txs, 247 | opt_states=cls._tx_tree_map(lambda tx: tx.init(params), txs), 248 | rng=rng, 249 | ) 250 | --------------------------------------------------------------------------------