├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── contributing.md ├── data_processing ├── bridgedata_numpy_to_tfrecord.py └── bridgedata_raw_to_numpy.py ├── experiments ├── configs │ ├── data_config.py │ └── train_config.py ├── eval.py ├── eval_gc.py ├── eval_lc.py ├── susie │ └── calvin │ │ ├── README.md │ │ ├── calvin_gcbc.py │ │ ├── calvin_lcbc.py │ │ ├── configs │ │ ├── gcbc_data_config.py │ │ ├── gcbc_train_config.py │ │ ├── lcbc_data_config.py │ │ └── lcbc_train_config.py │ │ ├── dataset_conversion_scripts │ │ ├── goal_conditioned.py │ │ └── language_conditioned.py │ │ └── scripts │ │ ├── launch_calvin_gcbc.sh │ │ └── launch_calvin_lcbc.sh ├── train.py └── utils.py ├── jaxrl_m ├── agents │ ├── __init__.py │ └── continuous │ │ ├── bc.py │ │ ├── gc_bc.py │ │ ├── gc_ddpm_bc.py │ │ ├── gc_iql.py │ │ ├── iql.py │ │ ├── lc_bc.py │ │ └── stable_contrastive_rl.py ├── common │ ├── common.py │ ├── encoding.py │ ├── typing.py │ └── wandb.py ├── data │ ├── bridge_dataset.py │ ├── calvin_dataset.py │ ├── text_processing.py │ ├── tf_augmentations.py │ └── tf_goal_relabeling.py ├── networks │ ├── actor_critic_nets.py │ ├── diffusion_nets.py │ └── mlp.py ├── utils │ └── timer_utils.py └── vision │ ├── __init__.py │ ├── film_conditioning_layer.py │ └── resnet_v1.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | *.ipynb 163 | wandb 164 | checkpoints 165 | log 166 | render 167 | *.png 168 | 169 | *.sif 170 | 171 | # VSCode 172 | .vscode 173 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM robonet-base:latest 2 | 3 | COPY requirements.txt /tmp/requirements.txt 4 | RUN ~/myenv/bin/pip install -r /tmp/requirements.txt 5 | RUN ~/myenv/bin/pip install --upgrade "jax[cuda11_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 6 | ENV PYTHONPATH=${PYTHONPATH}:/home/robonet/code/bridge_data_v2 7 | 8 | # modify packages to work with python 3.8 (ros noetic needs python 3.8) 9 | # to avoid orbax checkpoint error, downgrade flax 10 | RUN ~/myenv/bin/pip install flax==0.6.11 11 | # to avoid typing errors, upgrade distrax 12 | RUN ~/myenv/bin/pip install distrax==0.1.3 13 | 14 | # avoid git safe directory errors 15 | RUN git config --global --add safe.directory /home/robonet/code/bridge_data_v2 16 | 17 | WORKDIR /home/robonet/code/bridge_data_v2 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Robotic AI & Learning Lab Berkeley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to 
do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Jax BC/RL Implementations for BridgeData V2 2 | 3 | This repository provides code for training on [BridgeData V2](https://rail-berkeley.github.io/bridgedata/). 4 | 5 | We provide implementations for the following subset of methods described in the paper: 6 | 7 | - Goal-conditioned BC 8 | - Goal-conditioned BC with a diffusion policy 9 | - Goal-conditioned IQL 10 | - Goal-conditioned contrastive RL 11 | - Language-conditioned BC 12 | 13 | The official implementations and papers for all the methods can be found here: 14 | - [IDQL](https://github.com/philippe-eecs/IDQL) (IQL + diffusion policy) [[Hansen-Estruch et al.](https://github.com/philippe-eecs/IDQL)] and [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/) [[Chi et al.](https://diffusion-policy.cs.columbia.edu/)] 15 | - [IQL](https://github.com/ikostrikov/implicit_q_learning) [[Kostrikov et al.](https://arxiv.org/abs/2110.06169)] 16 | - [Contrastive RL](https://chongyi-zheng.github.io/stable_contrastive_rl/) [[Zheng et al.](https://arxiv.org/abs/2306.03346), [Eysenbach et al.](https://arxiv.org/abs/2206.07568)] 17 | - [RT-1](https://github.com/google-research/robotics_transformer) [[Brohan et al.](https://arxiv.org/abs/2212.06817)] 18 | - [ACT](https://github.com/tonyzhaozh/act) [[Zhao et al.](https://arxiv.org/abs/2304.13705)] 19 | 20 | Please open a GitHub issue if you encounter problems with this code. 21 | 22 | ## Data 23 | The raw dataset (comprised of JPEGs, PNGs, and pkl files) can be downloaded [here](https://rail.eecs.berkeley.edu/datasets/bridge_release/data/). The `demos*.zip` files contain the demonstration data, and the `scripted*.zip` files contain the data collected with a scripted policy. For training, the raw data needs to be converted into a format that is compatible with a data loader. We offer two options: 24 | 25 | - A custom `tf.data` loader. This data loader is implemented in `jaxrl_m/data/bridge_dataset.py` and is used by the training script in this repo. The scripts in the `data_processing` folder convert the raw data into the format required by this data loader. First, use `bridgedata_raw_to_numpy.py` to convert the raw data into NumPy files. Then, use `bridgedata_numpy_to_tfrecord.py` to convert the NumPy files into TFRecord files. 26 | - A [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview) loader. TensorFlow Datasets is a high-level wrapper around `tf.data`. We offer a pre-processed TFDS version of the dataset (downsampled to 256x256) in the `tfds` folder [here](https://rail.eecs.berkeley.edu/datasets/bridge_release/data/). 
In the TFDS dataset, the trajectories are structured using the [RLDS](https://github.com/google-research/rlds) format. We recommend using the [Octo](https://github.com/octo-models/octo) data loader for loading the RLDS version of BridgeData. If you would like to reprocess BridgeData into RLDS (e.g. to change the resolution or add keys), you can use [this repo](https://github.com/kvablack/dlimp/tree/main/rlds_converters). 27 | 28 | ## Training 29 | 30 | To start training, run the command below. Replace `METHOD` with one of `gc_bc`, `gc_ddpm_bc`, `gc_iql`, or `contrastive_rl_td`, and replace `NAME` with a name for the run. 31 | 32 | ``` 33 | python experiments/train.py \ 34 | --config experiments/configs/train_config.py:METHOD \ 35 | --bridgedata_config experiments/configs/data_config.py:all \ 36 | --name NAME 37 | ``` 38 | 39 | Training hyperparameters can be modified in `experiments/configs/train_config.py` and data parameters (e.g. subsets to include/exclude) can be modified in `experiments/configs/data_config.py`. 40 | 41 | ## Evaluation 42 | 43 | First, set up the robot hardware according to our [guide](https://docs.google.com/document/d/1si-6cTElTWTgflwcZRPfgHU7-UwfCUkEztkH3ge5CGc/edit?usp=sharing). Install our WidowX robot controller stack from [this repo](https://github.com/rail-berkeley/bridge_data_robot). 44 | 45 | There are two ways to interface a policy with the robot controller: the docker compose service method or the server-client method. Refer to the [bridge_data_robot](https://github.com/rail-berkeley/bridge_data_robot) docs for an explanation of how to set up each method. In general, we recommend the server-client method. 46 | 47 | For the server-client method, start the server on the robot. Then run the following commands on the client. You can specify the IP of the remote server via the `--ip` flag. The default IP is `localhost` (i.e. the server and client are the same machine). 48 | 49 | ```bash 50 | # Specify the path to the downloaded checkpoints directory 51 | export CHECKPOINT_DIR=/path/to/checkpoint_dir 52 | 53 | # For GCBC 54 | python experiments/eval.py \ 55 | --checkpoint_weights_path $CHECKPOINT_DIR/checkpoint_300000 \ 56 | --checkpoint_config_path $CHECKPOINT_DIR/gcbc_256_config.json \ 57 | --im_size 256 --goal_type gc --show_image --blocking 58 | 59 | # For LCBC 60 | python experiments/eval.py \ 61 | --checkpoint_weights_path $CHECKPOINT_DIR/checkpoint_145000 \ 62 | --checkpoint_config_path $CHECKPOINT_DIR/lcbc_256_config.json \ 63 | --im_size 256 --goal_type lc --show_image --blocking 64 | ``` 65 | 66 | You can also specify an initial position for the end effector with the flag `--initial_eep`. Similarly, use the flag `--goal_eep` to specify the position of the end effector when taking a goal image. 67 | 68 | To evaluate image-conditioned or language-conditioned methods with the docker compose service method, run `eval_gc.py` or `eval_lc.py` respectively in the `bridge_data_v2` docker container. 69 | 70 | ## Provided Checkpoints 71 | 72 | Checkpoints for GCBC, LCBC, D-GCBC, GCIQL, and CRL are available [here](https://rail.eecs.berkeley.edu/datasets/bridge_release/checkpoints/). Each checkpoint has an associated JSON file with its configuration information. The name of each checkpoint indicates whether it was trained with 128x128 images or 256x256 images. 73 | 74 | We don't currently have checkpoints for ACT or RT-1 available, but may release them soon. 
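As a quick reference, below is a minimal sketch of how one of these checkpoints and its JSON config can be restored in Python. It is adapted from the `load_checkpoint` helper in `experiments/eval_gc.py`; the directory layout, file names, and 256x256 image size are the illustrative values from the evaluation commands above, and the example arrays only define shapes/dtypes for initializing the agent.

```python
import json

import jax
import numpy as np
from flax.training import checkpoints

from jaxrl_m.agents import agents
from jaxrl_m.vision import encoders

# illustrative paths; substitute your downloaded checkpoint directory
config_path = "/path/to/checkpoint_dir/gcbc_256_config.json"
weights_path = "/path/to/checkpoint_dir/checkpoint_300000"

with open(config_path, "r") as f:
    config = json.load(f)

# build the encoder and agent exactly as specified by the checkpoint's config
encoder_def = encoders[config["encoder"]](**config["encoder_kwargs"])
example_batch = {
    "observations": {"image": np.zeros((1, 256, 256, 3), dtype=np.uint8)},
    "goals": {"image": np.zeros((1, 256, 256, 3), dtype=np.uint8)},
    "actions": np.zeros((1, 7), dtype=np.float32),
}
agent = agents[config["agent"]].create(
    rng=jax.random.PRNGKey(0),
    observations=example_batch["observations"],
    goals=example_batch["goals"],
    actions=example_batch["actions"],
    encoder_def=encoder_def,
    **config["agent_kwargs"],
)

# hydrate the agent with the checkpointed parameters
agent = checkpoints.restore_checkpoint(weights_path, agent)

# note: actions sampled from the agent are normalized; rescale them with
# config["bridgedata_config"]["action_proprio_metadata"] as in the eval scripts
```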
75 | 76 | ## Environment 77 | 78 | The dependencies for this codebase can be installed in a conda environment: 79 | 80 | ```bash 81 | conda create -n jaxrl python=3.10 82 | conda activate jaxrl 83 | pip install -e . 84 | pip install -r requirements.txt 85 | ``` 86 | For GPU: 87 | ```bash 88 | pip install --upgrade "jax[cuda11_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 89 | ``` 90 | 91 | For TPU: 92 | ``` 93 | pip install --upgrade "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 94 | ``` 95 | See the [Jax GitHub page](https://github.com/google/jax) for more details on installing Jax. 96 | 97 | ## Cite 98 | 99 | This code is based on [jaxrl_m](https://github.com/dibyaghosh/jaxrl_m) from Dibya Ghosh. 100 | 101 | If you use this code and/or BridgeData V2 in your work, please cite the paper with: 102 | 103 | ``` 104 | @inproceedings{walke2023bridgedata, 105 | title={BridgeData V2: A Dataset for Robot Learning at Scale}, 106 | author={Walke, Homer and Black, Kevin and Lee, Abraham and Kim, Moo Jin and Du, Max and Zheng, Chongyi and Zhao, Tony and Hansen-Estruch, Philippe and Vuong, Quan and He, Andre and Myers, Vivek and Fang, Kuan and Finn, Chelsea and Levine, Sergey}, 107 | booktitle={Conference on Robot Learning (CoRL)}, 108 | year={2023} 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We discuss two key abstractions used heavily in this codebase: the use of `TrainState` and the expression of agents as `PytreeNodes`. 4 | 5 | ## Agents 6 | 7 | In this codebase, we represent agents as PytreeNodes (first-class Jax citizens), making them really easy to handle. The simplest working example we have in the codebase is probably `jaxrl_m/agents/continuous/bc.py`, so check that out for a concrete implementation. 8 | 9 | The general structure of an Agent is as follows: it contains some number of neural networks, some set of configuration values, and has an update function that takes in a batch and returns an agent with updated parameters after performing some gradient update. Usually there's a `sample_actions` method to sample from the resulting policy too. 10 | 11 | ```python 12 | class Agent(flax.struct.PyTreeNode): 13 | value_function: TrainState 14 | policy: TrainState 15 | config: dict = nonpytree_field() # tells Jax to not look at this (usually contains discount factor / target update speed / other hyperparams) 16 | 17 | @jax.jit 18 | def update(self, batch: Batch): 19 | ... 20 | new_value_function = ... 21 | new_policy = ... 22 | info = {'loss': 100} 23 | new_agent = self.replace(value_function=new_value_function, policy=new_policy) 24 | return new_agent, info 25 | 26 | @jax.jit 27 | def sample_actions(self, observations, *, seed): 28 | actions = ... 29 | return actions 30 | ``` 31 | 32 | ### Multiple Devices 33 | 34 | Operating on multiple GPUs / TPUs is really easy! Check out the section at the bottom of the page for how to accumulate gradients across all the GPUs. 35 | 36 | 37 | - `flax.jax_utils.replicate()`: replicates an object on all GPUs 38 | - `jaxrl_m.common.common.shard_batch`: splits a batch evenly across all the GPUs 39 | - `flax.jax_utils.unreplicate()`: brings an object back to a single GPU 40 | 41 | ```python 42 | agent = ... 43 | batch = ... 
44 | 45 | replicated_agent = replicate(agent) 46 | replicated_agent, info = replicated_agent.update(shard_batch(batch)) 47 | info = unreplicate(info) # bring info back to single device 48 | 49 | 50 | ``` 51 | ## TrainState 52 | 53 | 54 | The TrainState class (located at `jaxrl_m.common.common.TrainState`) is a fork of Flax's TrainState class with some additional syntactic features for ease of use. 55 | 56 | The TrainState class combines a neural network module (`flax.linen.Module`) with a set of parameters for this network (along with, potentially, an optimizer). 57 | 58 | ### Creating a TrainState 59 | 60 | ```python 61 | model_def = nn.Dense(10) # nn.Module 62 | params = model_def.init(rng, x)['params'] # parameters for nn.Module 63 | tx = optax.adam(1e-3) 64 | model = TrainState.create(model_def, params, tx=tx) 65 | ``` 66 | 67 | ### Running the Model 68 | 69 | ```python 70 | model = TrainState.create(...) 71 | y_pred = model(x) 72 | ``` 73 | 74 | In some cases, the neural network module may have several functions; for example, a VAE might have an `.encode(x)` function and a `.decode(z)` function. By default, the `__call__()` method is used, but this can be specified via an argument: 75 | 76 | ```python 77 | z = model(x, method='encode') 78 | x_pred = model(z, method='decode') 79 | ``` 80 | 81 | You can also run the model with a different set of parameters than the ones bound to the TrainState. This is most commonly done when taking the gradient with respect to model parameters. 82 | 83 | ```python 84 | y_pred = model(x, params=other_params) 85 | ``` 86 | 87 | ```python 88 | def loss(params): 89 | y_pred = model(x, params=params) 90 | return jnp.mean((y - y_pred) ** 2) 91 | 92 | grads = jax.grad(loss)(model.params) 93 | ``` 94 | 95 | ### Optimizing a TrainState 96 | 97 | To update a model (that has a `tx`), we provide two convenience functions: `.apply_gradients` and `.apply_loss_fn`. 98 | 99 | `model.apply_gradients` takes in a set of gradients (same shape as parameters) and computes the new set of parameters using optax. 100 | 101 | ```python 102 | def loss(params): 103 | y_pred = model(x, params=params) 104 | return jnp.mean((y - y_pred) ** 2) 105 | 106 | grads = jax.grad(loss)(model.params) 107 | new_model = model.apply_gradients(grads=grads) 108 | ``` 109 | 110 | `model.apply_loss_fn()` is a convenience method that both computes the gradients and runs `.apply_gradients()`. 111 | 112 | ```python 113 | def loss(params): 114 | y_pred = model(x, params=params) 115 | return jnp.mean((y - y_pred) ** 2) 116 | 117 | new_model = model.apply_loss_fn(loss_fn=loss) 118 | ``` 119 | 120 | If the model is being run across multiple GPUs / TPUs and we wish to aggregate gradients, this can be specified with the `pmap_axis` argument (you can always use `jax.lax.pmean` as an alternative): 121 | 122 | ```python 123 | @functools.partial(jax.pmap, axis_name='pmap') 124 | def update(model, x, y): 125 | def loss(params): 126 | y_pred = model(x, params=params) 127 | return jnp.mean((y - y_pred) ** 2) 128 | 129 | new_model = model.apply_loss_fn(loss_fn=loss, pmap_axis='pmap') 130 | return new_model 131 | ``` 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /data_processing/bridgedata_numpy_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converts data from the BridgeData numpy format to TFRecord format. 
3 | 4 | Consider the following directory structure for the input data: 5 | 6 | bridgedata_numpy/ 7 | rss/ 8 | toykitchen2/ 9 | set_table/ 10 | 00/ 11 | train/ 12 | out.npy 13 | val/ 14 | out.npy 15 | icra/ 16 | ... 17 | 18 | The --depth parameter controls how much of the data to process at the 19 | --input_path; for example, if --depth=5, then --input_path should be 20 | "bridgedata_numpy", and all data will be processed. If --depth=3, then 21 | --input_path should be "bridgedata_numpy/rss/toykitchen2", and only data 22 | under "toykitchen2" will be processed. 23 | 24 | The same directory structure will be replicated under --output_path. For 25 | example, in the second case, the output will be written to 26 | "{output_path}/set_table/00/...". 27 | 28 | Can read/write directly from/to Google Cloud Storage. 29 | 30 | Written by Kevin Black (kvablack@berkeley.edu). 31 | """ 32 | import os 33 | from multiprocessing import Pool 34 | 35 | import numpy as np 36 | import tensorflow as tf 37 | import tqdm 38 | from absl import app, flags, logging 39 | 40 | FLAGS = flags.FLAGS 41 | 42 | flags.DEFINE_string("input_path", None, "Input path", required=True) 43 | flags.DEFINE_string("output_path", None, "Output path", required=True) 44 | flags.DEFINE_integer( 45 | "depth", 46 | 5, 47 | "Number of directories deep to traverse. Looks for {input_path}/dir_1/dir_2/.../dir_{depth-1}/train/out.npy", 48 | ) 49 | flags.DEFINE_bool("overwrite", False, "Overwrite existing files") 50 | flags.DEFINE_integer("num_workers", 8, "Number of threads to use") 51 | 52 | 53 | def tensor_feature(value): 54 | return tf.train.Feature( 55 | bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()]) 56 | ) 57 | 58 | 59 | def process(path): 60 | with tf.io.gfile.GFile(path, "rb") as f: 61 | arr = np.load(f, allow_pickle=True) 62 | dirname = os.path.dirname(os.path.abspath(path)) 63 | outpath = os.path.join(FLAGS.output_path, *dirname.split(os.sep)[-FLAGS.depth :]) 64 | outpath = f"{outpath}/out.tfrecord" 65 | 66 | if tf.io.gfile.exists(outpath): 67 | if FLAGS.overwrite: 68 | logging.info(f"Deleting {outpath}") 69 | tf.io.gfile.rmtree(outpath) 70 | else: 71 | logging.info(f"Skipping {outpath}") 72 | return 73 | 74 | if len(arr) == 0: 75 | logging.info(f"Skipping {path}, empty") 76 | return 77 | 78 | tf.io.gfile.makedirs(os.path.dirname(outpath)) 79 | 80 | with tf.io.TFRecordWriter(outpath) as writer: 81 | for traj in arr: 82 | truncates = np.zeros(len(traj["actions"]), dtype=np.bool_) 83 | truncates[-1] = True 84 | example = tf.train.Example( 85 | features=tf.train.Features( 86 | feature={ 87 | "observations/images0": tensor_feature( 88 | np.array( 89 | [o["images0"] for o in traj["observations"]], 90 | dtype=np.uint8, 91 | ) 92 | ), 93 | "observations/state": tensor_feature( 94 | np.array( 95 | [o["state"] for o in traj["observations"]], 96 | dtype=np.float32, 97 | ) 98 | ), 99 | "next_observations/images0": tensor_feature( 100 | np.array( 101 | [o["images0"] for o in traj["next_observations"]], 102 | dtype=np.uint8, 103 | ) 104 | ), 105 | "next_observations/state": tensor_feature( 106 | np.array( 107 | [o["state"] for o in traj["next_observations"]], 108 | dtype=np.float32, 109 | ) 110 | ), 111 | "language": tensor_feature(traj["language"]), 112 | "actions": tensor_feature( 113 | np.array(traj["actions"], dtype=np.float32) 114 | ), 115 | "terminals": tensor_feature( 116 | np.zeros(len(traj["actions"]), dtype=np.bool_) 117 | ), 118 | "truncates": tensor_feature(truncates), 119 | } 120 | ) 121 | ) 122 | 
writer.write(example.SerializeToString()) 123 | 124 | 125 | def main(_): 126 | assert FLAGS.depth >= 1 127 | 128 | paths = tf.io.gfile.glob( 129 | tf.io.gfile.join(FLAGS.input_path, *("*" * (FLAGS.depth - 1))) 130 | ) 131 | paths = [f"{p}/train/out.npy" for p in paths] + [f"{p}/val/out.npy" for p in paths] 132 | with Pool(FLAGS.num_workers) as p: 133 | list(tqdm.tqdm(p.imap(process, paths), total=len(paths))) 134 | 135 | 136 | if __name__ == "__main__": 137 | app.run(main) 138 | -------------------------------------------------------------------------------- /data_processing/bridgedata_raw_to_numpy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converts data from the BridgeData raw format to numpy format. 3 | 4 | Consider the following directory structure for the input data: 5 | 6 | bridgedata_raw/ 7 | rss/ 8 | toykitchen2/ 9 | set_table/ 10 | 00/ 11 | 2022-01-01_00-00-00/ 12 | collection_metadata.json 13 | config.json 14 | diagnostics.png 15 | raw/ 16 | traj_group0/ 17 | traj0/ 18 | obs_dict.pkl 19 | policy_out.pkl 20 | agent_data.pkl 21 | images0/ 22 | im_0.jpg 23 | im_1.jpg 24 | ... 25 | ... 26 | ... 27 | 01/ 28 | ... 29 | 30 | The --depth parameter controls how much of the data to process at the 31 | --input_path; for example, if --depth=5, then --input_path should be 32 | "bridgedata_raw", and all data will be processed. If --depth=3, then 33 | --input_path should be "bridgedata_raw/rss/toykitchen2", and only data 34 | under "toykitchen2" will be processed. 35 | 36 | The same directory structure will be replicated under --output_path. For 37 | example, in the second case, the output will be written to 38 | "{output_path}/set_table/00/...". 39 | 40 | Squashes images to 128x128. 41 | 42 | Can write directly to Google Cloud Storage, but not read from it. 43 | 44 | Written by Kevin Black (kvablack@berkeley.edu). 45 | """ 46 | import copy 47 | import glob 48 | import os 49 | import pickle 50 | import random 51 | from collections import defaultdict 52 | from datetime import datetime 53 | from functools import partial 54 | from multiprocessing import Pool 55 | 56 | import numpy as np 57 | import tensorflow as tf 58 | import tqdm 59 | from absl import app, flags, logging 60 | from PIL import Image 61 | 62 | FLAGS = flags.FLAGS 63 | 64 | flags.DEFINE_string("input_path", None, "Input path", required=True) 65 | flags.DEFINE_string("output_path", None, "Output path", required=True) 66 | flags.DEFINE_integer( 67 | "depth", 68 | 5, 69 | "Number of directories deep to traverse to the dated directory. 
Looks for" 70 | "{input_path}/dir_1/dir_2/.../dir_{depth-1}/2022-01-01_00-00-00/...", 71 | ) 72 | flags.DEFINE_bool("overwrite", False, "Overwrite existing files") 73 | flags.DEFINE_float( 74 | "train_proportion", 0.9, "Proportion of data to use for training (rather than val)" 75 | ) 76 | flags.DEFINE_integer("num_workers", 8, "Number of threads to use") 77 | flags.DEFINE_integer("im_size", 128, "Image size") 78 | 79 | 80 | def squash(path): # squash from 480x640 to im_size 81 | im = Image.open(path) 82 | im = im.resize((FLAGS.im_size, FLAGS.im_size), Image.Resampling.LANCZOS) 83 | out = np.asarray(im).astype(np.uint8) 84 | return out 85 | 86 | 87 | def process_images(path): # processes images at a trajectory level 88 | names = sorted( 89 | [x for x in os.listdir(path) if "images" in x and not "depth" in x], 90 | key=lambda x: int(x.split("images")[1]), 91 | ) 92 | image_path = [ 93 | os.path.join(path, x) 94 | for x in os.listdir(path) 95 | if "images" in x and not "depth" in x 96 | ] 97 | image_path = sorted(image_path, key=lambda x: int(x.split("images")[1])) 98 | 99 | images_out = defaultdict(list) 100 | 101 | tlen = len(glob.glob(image_path[0] + "/im_*.jpg")) 102 | 103 | for i, name in enumerate(names): 104 | for t in range(tlen): 105 | images_out[name].append(squash(image_path[i] + "/im_{}.jpg".format(t))) 106 | 107 | images_out = dict(images_out) 108 | 109 | obs, next_obs = dict(), dict() 110 | 111 | for n in names: 112 | obs[n] = images_out[n][:-1] 113 | next_obs[n] = images_out[n][1:] 114 | return obs, next_obs 115 | 116 | 117 | def process_state(path): 118 | fp = os.path.join(path, "obs_dict.pkl") 119 | with open(fp, "rb") as f: 120 | x = pickle.load(f) 121 | return x["full_state"][:-1], x["full_state"][1:] 122 | 123 | 124 | def process_time(path): 125 | fp = os.path.join(path, "obs_dict.pkl") 126 | with open(fp, "rb") as f: 127 | x = pickle.load(f) 128 | return x["time_stamp"][:-1], x["time_stamp"][1:] 129 | 130 | 131 | def process_actions(path): # gets actions 132 | fp = os.path.join(path, "policy_out.pkl") 133 | with open(fp, "rb") as f: 134 | act_list = pickle.load(f) 135 | if isinstance(act_list[0], dict): 136 | act_list = [x["actions"] for x in act_list] 137 | return act_list 138 | 139 | 140 | # processes each data collection attempt 141 | def process_dc(path, train_ratio=0.9): 142 | # a mystery left by the greats of the past 143 | if "lmdb" in path: 144 | logging.warning(f"Skipping {path} because uhhhh lmdb?") 145 | return [], [], [], [] 146 | 147 | all_dicts_train = list() 148 | all_dicts_test = list() 149 | all_rews_train = list() 150 | all_rews_test = list() 151 | 152 | # Data collected prior to 7-23 has a delay of 1, otherwise a delay of 0 153 | date_time = datetime.strptime(path.split("/")[-1], "%Y-%m-%d_%H-%M-%S") 154 | latency_shift = date_time < datetime(2021, 7, 23) 155 | 156 | search_path = os.path.join(path, "raw", "traj_group*", "traj*") 157 | all_traj = glob.glob(search_path) 158 | if all_traj == []: 159 | logging.info(f"no trajs found in {search_path}") 160 | return [], [], [], [] 161 | 162 | random.shuffle(all_traj) 163 | 164 | num_traj = len(all_traj) 165 | for itraj, tp in tqdm.tqdm(enumerate(all_traj)): 166 | try: 167 | out = dict() 168 | 169 | ld = os.listdir(tp) 170 | 171 | assert "obs_dict.pkl" in ld, tp + ":" + str(ld) 172 | assert "policy_out.pkl" in ld, tp + ":" + str(ld) 173 | # assert "agent_data.pkl" in ld, tp + ":" + str(ld) # not used 174 | 175 | obs, next_obs = process_images(tp) 176 | acts = process_actions(tp) 177 | state, next_state = 
process_state(tp) 178 | time_stamp, next_time_stamp = process_time(tp) 179 | term = [0] * len(acts) 180 | if "lang.txt" in ld: 181 | with open(os.path.join(tp, "lang.txt")) as f: 182 | lang = list(f) 183 | lang = [l.strip() for l in lang if "confidence" not in l] 184 | else: 185 | # empty string is a placeholder for data with no language label 186 | lang = [""] 187 | 188 | out["observations"] = obs 189 | out["observations"]["state"] = state 190 | out["observations"]["time_stamp"] = time_stamp 191 | out["next_observations"] = next_obs 192 | out["next_observations"]["state"] = next_state 193 | out["next_observations"]["time_stamp"] = next_time_stamp 194 | 195 | out["observations"] = [ 196 | dict(zip(out["observations"], t)) 197 | for t in zip(*out["observations"].values()) 198 | ] 199 | out["next_observations"] = [ 200 | dict(zip(out["next_observations"], t)) 201 | for t in zip(*out["next_observations"].values()) 202 | ] 203 | 204 | out["actions"] = acts 205 | out["terminals"] = term 206 | out["language"] = lang 207 | 208 | # shift the actions according to camera latency 209 | if latency_shift: 210 | out["observations"] = out["observations"][1:] 211 | out["next_observations"] = out["next_observations"][1:] 212 | out["actions"] = out["actions"][:-1] 213 | out["terminals"] = term[:-1] 214 | 215 | labeled_rew = copy.deepcopy(out["terminals"])[:] 216 | labeled_rew[-2:] = [1, 1] 217 | 218 | traj_len = len(out["observations"]) 219 | assert len(out["next_observations"]) == traj_len 220 | assert len(out["actions"]) == traj_len 221 | assert len(out["terminals"]) == traj_len 222 | assert len(labeled_rew) == traj_len 223 | 224 | if itraj < int(num_traj * train_ratio): 225 | all_dicts_train.append(out) 226 | all_rews_train.append(labeled_rew) 227 | else: 228 | all_dicts_test.append(out) 229 | all_rews_test.append(labeled_rew) 230 | except FileNotFoundError as e: 231 | logging.error(e) 232 | continue 233 | except AssertionError as e: 234 | logging.error(e) 235 | continue 236 | 237 | return all_dicts_train, all_dicts_test, all_rews_train, all_rews_test 238 | 239 | 240 | def make_numpy(path, train_proportion): 241 | dirname = os.path.abspath(path) 242 | outpath = os.path.join( 243 | FLAGS.output_path, *dirname.split(os.sep)[-(max(FLAGS.depth - 1, 1)) :] 244 | ) 245 | 246 | if os.path.exists(outpath): 247 | if FLAGS.overwrite: 248 | logging.info(f"Deleting {outpath}") 249 | tf.io.gfile.rmtree(outpath) 250 | else: 251 | logging.info(f"Skipping {outpath}") 252 | return 253 | 254 | outpath_train = tf.io.gfile.join(outpath, "train") 255 | outpath_val = tf.io.gfile.join(outpath, "val") 256 | tf.io.gfile.makedirs(outpath_train) 257 | tf.io.gfile.makedirs(outpath_val) 258 | 259 | lst_train = [] 260 | lst_val = [] 261 | rew_train_l = [] 262 | rew_val_l = [] 263 | 264 | for dated_folder in os.listdir(path): 265 | curr_train, curr_val, rew_train, rew_val = process_dc( 266 | os.path.join(path, dated_folder), train_ratio=train_proportion 267 | ) 268 | lst_train.extend(curr_train) 269 | lst_val.extend(curr_val) 270 | rew_train_l.extend(rew_train) 271 | rew_val_l.extend(rew_val) 272 | 273 | if len(lst_train) == 0 or len(lst_val) == 0: 274 | return 275 | 276 | with tf.io.gfile.GFile(tf.io.gfile.join(outpath_train, "out.npy"), "wb") as f: 277 | np.save(f, lst_train) 278 | with tf.io.gfile.GFile(tf.io.gfile.join(outpath_val, "out.npy"), "wb") as f: 279 | np.save(f, lst_val) 280 | 281 | # doesn't seem like these are ever used anymore 282 | # np.save(os.path.join(outpath_train, "out_rew.npy"), rew_train_l) 283 | # 
np.save(os.path.join(outpath_val, "out_rew.npy"), rew_val_l) 284 | 285 | 286 | def main(_): 287 | assert FLAGS.depth >= 1 288 | 289 | # each path is a directory that contains dated directories 290 | paths = glob.glob(os.path.join(FLAGS.input_path, *("*" * (FLAGS.depth - 1)))) 291 | 292 | worker_fn = partial(make_numpy, train_proportion=FLAGS.train_proportion) 293 | 294 | with Pool(FLAGS.num_workers) as p: 295 | list(tqdm.tqdm(p.imap(worker_fn, paths), total=len(paths))) 296 | 297 | 298 | if __name__ == "__main__": 299 | app.run(main) 300 | -------------------------------------------------------------------------------- /experiments/configs/data_config.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | ACT_MEAN = [ 4 | 1.9296819e-04, 5 | 1.3667766e-04, 6 | -1.4583133e-04, 7 | -1.8390431e-04, 8 | -3.0808983e-04, 9 | 2.7425270e-04, 10 | 5.9716219e-01, 11 | ] 12 | 13 | ACT_STD = [ 14 | 0.00912848, 15 | 0.0127196, 16 | 0.01229497, 17 | 0.02606696, 18 | 0.02875283, 19 | 0.07807977, 20 | 0.48710242, 21 | ] 22 | 23 | ACTION_PROPRIO_METADATA = { 24 | "action": { 25 | "mean": ACT_MEAN, 26 | "std": ACT_STD, 27 | # TODO compute these 28 | "min": ACT_MEAN, 29 | "max": ACT_STD, 30 | }, 31 | # TODO compute these 32 | "proprio": {"mean": ACT_MEAN, "std": ACT_STD, "min": ACT_MEAN, "max": ACT_STD}, 33 | } 34 | 35 | 36 | def get_config(config_string): 37 | possible_structures = { 38 | "all": ml_collections.ConfigDict( 39 | { 40 | "include": [ 41 | [ 42 | "icra/?*/?*/?*", 43 | "flap/?*/?*/?*", 44 | "bridge_data_v1/berkeley/?*/?*", 45 | "rss/?*/?*/?*", 46 | "bridge_data_v2/?*/?*/?*", 47 | "scripted/?*", 48 | ] 49 | ], 50 | "exclude": [], 51 | "sample_weights": None, 52 | "action_proprio_metadata": ACTION_PROPRIO_METADATA, 53 | } 54 | ) 55 | } 56 | return possible_structures[config_string] 57 | -------------------------------------------------------------------------------- /experiments/configs/train_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | 4 | def get_config(config_string): 5 | base_real_config = dict( 6 | batch_size=256, 7 | num_steps=int(2e6), 8 | log_interval=100, 9 | eval_interval=5000, 10 | save_interval=5000, 11 | save_dir="path/to/save/dir", 12 | data_path="path/to/data", 13 | resume_path=None, 14 | seed=42, 15 | ) 16 | 17 | base_data_config = dict( 18 | shuffle_buffer_size=25000, 19 | augment=True, 20 | augment_next_obs_goal_differently=False, 21 | augment_kwargs=dict( 22 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 23 | random_brightness=[0.2], 24 | random_contrast=[0.8, 1.2], 25 | random_saturation=[0.8, 1.2], 26 | random_hue=[0.1], 27 | augment_order=[ 28 | "random_resized_crop", 29 | "random_brightness", 30 | "random_contrast", 31 | "random_saturation", 32 | "random_hue", 33 | ], 34 | ), 35 | ) 36 | 37 | possible_structures = { 38 | "gc_iql": ConfigDict( 39 | dict( 40 | agent="gc_iql", 41 | agent_kwargs=dict( 42 | network_kwargs=dict(hidden_dims=(256, 256, 256), dropout_rate=0.1), 43 | policy_kwargs=dict( 44 | tanh_squash_distribution=False, 45 | state_dependent_std=False, 46 | fixed_std=[1, 1, 1, 1, 1, 1, 1], 47 | ), 48 | learning_rate=3e-4, 49 | discount=0.98, 50 | expectile=0.7, 51 | temperature=1.0, 52 | target_update_rate=0.002, 53 | shared_encoder=True, 54 | early_goal_concat=True, 55 | shared_goal_encoder=True, 56 | use_proprio=False, 57 | negative_proportion=0.1, 58 | ), 59 | dataset_kwargs=dict( 60 | 
goal_relabeling_strategy="uniform", 61 | goal_relabeling_kwargs=dict(reached_proportion=0.1), 62 | relabel_actions=True, 63 | **base_data_config, 64 | ), 65 | encoder="resnetv1-34-bridge", 66 | encoder_kwargs=dict( 67 | pooling_method="avg", add_spatial_coordinates=True, act="swish" 68 | ), 69 | **base_real_config, 70 | ) 71 | ), 72 | "gc_bc": ConfigDict( 73 | dict( 74 | agent="gc_bc", 75 | agent_kwargs=dict( 76 | network_kwargs=dict(hidden_dims=(256, 256, 256), dropout_rate=0.1), 77 | policy_kwargs=dict( 78 | tanh_squash_distribution=False, 79 | fixed_std=[1, 1, 1, 1, 1, 1, 1], 80 | state_dependent_std=False, 81 | ), 82 | early_goal_concat=True, 83 | shared_goal_encoder=True, 84 | use_proprio=False, 85 | learning_rate=3e-4, 86 | warmup_steps=2000, 87 | decay_steps=int(2e6), 88 | ), 89 | dataset_kwargs=dict( 90 | goal_relabeling_strategy="uniform", 91 | goal_relabeling_kwargs=dict(reached_proportion=0.0), 92 | relabel_actions=True, 93 | **base_data_config, 94 | ), 95 | encoder="resnetv1-34-bridge", 96 | encoder_kwargs=dict( 97 | pooling_method="avg", add_spatial_coordinates=True, act="swish" 98 | ), 99 | **base_real_config, 100 | ) 101 | ), 102 | "lc_bc": ConfigDict( 103 | dict( 104 | agent="lc_bc", 105 | agent_kwargs=dict( 106 | network_kwargs=dict(hidden_dims=(256, 256, 256), dropout_rate=0.1), 107 | policy_kwargs=dict( 108 | tanh_squash_distribution=False, 109 | fixed_std=[1, 1, 1, 1, 1, 1, 1], 110 | state_dependent_std=False, 111 | ), 112 | early_goal_concat=True, 113 | shared_goal_encoder=True, 114 | use_proprio=False, 115 | learning_rate=3e-4, 116 | warmup_steps=2000, 117 | decay_steps=int(2e6), 118 | ), 119 | dataset_kwargs=dict( 120 | goal_relabeling_strategy="uniform", 121 | goal_relabeling_kwargs=dict(reached_proportion=0.0), 122 | relabel_actions=True, 123 | load_language=True, 124 | skip_unlabeled=True, 125 | **base_data_config, 126 | ), 127 | text_processor="muse_embedding", 128 | text_processor_kwargs=dict(), 129 | encoder="resnetv1-34-bridge-film", 130 | encoder_kwargs=dict( 131 | pooling_method="avg", add_spatial_coordinates=True, act="swish" 132 | ), 133 | **base_real_config, 134 | ) 135 | ), 136 | "gc_ddpm_bc": ConfigDict( 137 | dict( 138 | agent="gc_ddpm_bc", 139 | agent_kwargs=dict( 140 | score_network_kwargs=dict( 141 | time_dim=32, 142 | num_blocks=3, 143 | dropout_rate=0.1, 144 | hidden_dim=256, 145 | use_layer_norm=True, 146 | ), 147 | early_goal_concat=True, 148 | shared_goal_encoder=True, 149 | use_proprio=False, 150 | beta_schedule="cosine", 151 | diffusion_steps=20, 152 | action_samples=1, 153 | repeat_last_step=0, 154 | learning_rate=3e-4, 155 | warmup_steps=2000, 156 | actor_decay_steps=int(2e6), 157 | ), 158 | dataset_kwargs=dict( 159 | goal_relabeling_strategy="uniform", 160 | goal_relabeling_kwargs=dict(reached_proportion=0.0), 161 | relabel_actions=True, 162 | obs_horizon=1, 163 | act_pred_horizon=1, 164 | **base_data_config, 165 | ), 166 | encoder="resnetv1-34-bridge", 167 | encoder_kwargs=dict( 168 | pooling_method="avg", add_spatial_coordinates=True, act="swish" 169 | ), 170 | **base_real_config, 171 | ) 172 | ), 173 | "contrastive_rl_td": ConfigDict( 174 | dict( 175 | agent="stable_contrastive_rl", 176 | agent_kwargs=dict( 177 | critic_network_kwargs=dict( 178 | hidden_dims=(256, 256, 256), use_layer_norm=True 179 | ), 180 | critic_kwargs=dict(init_final=1e-12, repr_dim=16, twin_q=True), 181 | policy_network_kwargs=dict( 182 | hidden_dims=(256, 256, 256), dropout_rate=0.1 183 | ), 184 | policy_kwargs=dict( 185 | tanh_squash_distribution=False, 186 | 
state_dependent_std=False, 187 | fixed_std=[1, 1, 1, 1, 1, 1, 1], 188 | ), 189 | learning_rate=3e-4, 190 | warmup_steps=2000, 191 | actor_decay_steps=int(2e6), 192 | use_td=True, 193 | gcbc_coef=0.20, 194 | discount=0.98, 195 | temperature=1.0, 196 | target_update_rate=0.002, 197 | shared_encoder=False, 198 | early_goal_concat=False, 199 | shared_goal_encoder=True, 200 | use_proprio=False, 201 | ), 202 | dataset_kwargs=dict( 203 | goal_relabeling_strategy="uniform", 204 | goal_relabeling_kwargs=dict(reached_proportion=0.0), 205 | relabel_actions=True, 206 | **base_data_config, 207 | ), 208 | encoder="resnetv1-34-bridge", 209 | encoder_kwargs=dict( 210 | pooling_method="avg", add_spatial_coordinates=False, act="swish" 211 | ), 212 | **base_real_config, 213 | ) 214 | ), 215 | } 216 | 217 | return possible_structures[config_string] 218 | -------------------------------------------------------------------------------- /experiments/eval_gc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | from datetime import datetime 5 | import traceback 6 | from collections import deque 7 | import json 8 | 9 | from absl import app, flags, logging 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | import jax 15 | from PIL import Image 16 | import imageio 17 | 18 | from flax.training import checkpoints 19 | from jaxrl_m.vision import encoders 20 | from jaxrl_m.agents import agents 21 | 22 | # bridge_data_robot imports 23 | from widowx_envs.widowx_env import BridgeDataRailRLPrivateWidowX 24 | from multicam_server.topic_utils import IMTopic 25 | from utils import stack_obs 26 | 27 | np.set_printoptions(suppress=True) 28 | 29 | logging.set_verbosity(logging.WARNING) 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | flags.DEFINE_multi_string( 34 | "checkpoint_weights_path", None, "Path to checkpoint", required=True 35 | ) 36 | flags.DEFINE_multi_string( 37 | "checkpoint_config_path", None, "Path to checkpoint config JSON", required=True 38 | ) 39 | flags.DEFINE_integer("im_size", None, "Image size", required=True) 40 | flags.DEFINE_string("video_save_path", None, "Path to save video") 41 | flags.DEFINE_string("goal_image_path", None, "Path to a single goal image") 42 | flags.DEFINE_integer("num_timesteps", 120, "num timesteps") 43 | flags.DEFINE_bool("blocking", False, "Use the blocking controller") 44 | flags.DEFINE_spaceseplist("goal_eep", [0.3, 0.0, 0.15], "Goal position") 45 | flags.DEFINE_spaceseplist("initial_eep", [0.3, 0.0, 0.15], "Initial position") 46 | flags.DEFINE_integer("act_exec_horizon", 1, "Action sequence length") 47 | flags.DEFINE_bool("deterministic", True, "Whether to sample action deterministically") 48 | 49 | ############################################################################## 50 | 51 | STEP_DURATION = 0.2 52 | NO_PITCH_ROLL = False 53 | NO_YAW = False 54 | STICKY_GRIPPER_NUM_STEPS = 1 55 | WORKSPACE_BOUNDS = np.array([[0.1, -0.15, -0.1, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]]) 56 | CAMERA_TOPICS = [IMTopic("/blue/image_raw")] 57 | FIXED_STD = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) 58 | 59 | ############################################################################## 60 | 61 | def load_checkpoint(checkpoint_weights_path, checkpoint_config_path): 62 | with open(checkpoint_config_path, "r") as f: 63 | config = json.load(f) 64 | 65 | # create encoder from wandb config 66 | encoder_def = encoders[config["encoder"]](**config["encoder_kwargs"]) 67 | 68 | act_pred_horizon = 
config["dataset_kwargs"].get("act_pred_horizon") 69 | obs_horizon = config["dataset_kwargs"].get("obs_horizon") 70 | 71 | if act_pred_horizon is not None: 72 | example_actions = np.zeros((1, act_pred_horizon, 7), dtype=np.float32) 73 | else: 74 | example_actions = np.zeros((1, 7), dtype=np.float32) 75 | 76 | if obs_horizon is not None: 77 | example_obs = { 78 | "image": np.zeros( 79 | (1, obs_horizon, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8 80 | ) 81 | } 82 | else: 83 | example_obs = { 84 | "image": np.zeros((1, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) 85 | } 86 | 87 | example_batch = { 88 | "observations": example_obs, 89 | "goals": { 90 | "image": np.zeros((1, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) 91 | }, 92 | "actions": example_actions, 93 | } 94 | 95 | # create agent from wandb config 96 | rng = jax.random.PRNGKey(0) 97 | rng, construct_rng = jax.random.split(rng) 98 | agent = agents[config["agent"]].create( 99 | rng=construct_rng, 100 | observations=example_batch["observations"], 101 | goals=example_batch["goals"], 102 | actions=example_batch["actions"], 103 | encoder_def=encoder_def, 104 | **config["agent_kwargs"], 105 | ) 106 | 107 | # load action metadata from wandb 108 | action_proprio_metadata = config["bridgedata_config"]["action_proprio_metadata"] 109 | action_mean = np.array(action_proprio_metadata["action"]["mean"]) 110 | action_std = np.array(action_proprio_metadata["action"]["std"]) 111 | 112 | # hydrate agent with parameters from checkpoint 113 | agent = checkpoints.restore_checkpoint(checkpoint_weights_path, agent) 114 | 115 | def get_action(obs, goal_obs): 116 | nonlocal rng 117 | rng, key = jax.random.split(rng) 118 | action = jax.device_get( 119 | agent.sample_actions(obs, goal_obs, seed=key, argmax=FLAGS.deterministic) 120 | ) 121 | action = action * action_std + action_mean 122 | return action 123 | 124 | return get_action, obs_horizon 125 | 126 | 127 | def main(_): 128 | assert len(FLAGS.checkpoint_weights_path) == len(FLAGS.checkpoint_config_path) 129 | 130 | # policies is a dict from run_name to get_action function 131 | policies = {} 132 | for checkpoint_weights_path, checkpoint_config_path in zip( 133 | FLAGS.checkpoint_weights_path, FLAGS.checkpoint_config_path 134 | ): 135 | assert tf.io.gfile.exists(checkpoint_weights_path), checkpoint_weights_path 136 | checkpoint_num = int(checkpoint_weights_path.split("_")[-1]) 137 | run_name = checkpoint_config_path.split("/")[-1] 138 | policies[f"{run_name}-{checkpoint_num}"] = load_checkpoint( 139 | checkpoint_weights_path, checkpoint_config_path 140 | ) 141 | 142 | if FLAGS.initial_eep is not None: 143 | assert isinstance(FLAGS.initial_eep, list) 144 | initial_eep = [float(e) for e in FLAGS.initial_eep] 145 | start_state = np.concatenate([initial_eep, [0, 0, 0, 1]]) 146 | else: 147 | start_state = None 148 | 149 | # set up environment 150 | env_params = { 151 | "fix_zangle": 0.1, 152 | "move_duration": 0.2, 153 | "adaptive_wait": True, 154 | "move_to_rand_start_freq": 1, 155 | "override_workspace_boundaries": WORKSPACE_BOUNDS, 156 | "action_clipping": "xyz", 157 | "catch_environment_except": False, 158 | "start_state": start_state, 159 | "return_full_image": False, 160 | "camera_topics": CAMERA_TOPICS, 161 | } 162 | env = BridgeDataRailRLPrivateWidowX(env_params, fixed_image_size=FLAGS.im_size) 163 | 164 | # load image goal 165 | image_goal = None 166 | if FLAGS.goal_image_path is not None: 167 | image_goal = np.array(Image.open(FLAGS.goal_image_path)) 168 | 169 | # goal sampling loop 170 
| while True: 171 | # ask for new goal 172 | if image_goal is None: 173 | print("Taking a new goal...") 174 | ch = "y" 175 | else: 176 | ch = input("Taking a new goal? [y/n]") 177 | if ch == "y": 178 | if FLAGS.goal_eep is not None: 179 | assert isinstance(FLAGS.goal_eep, list) 180 | goal_eep = [float(e) for e in FLAGS.goal_eep] 181 | else: 182 | low_bound = WORKSPACE_BOUNDS[0][:3] + 0.03 183 | high_bound = WORKSPACE_BOUNDS[1][:3] - 0.03 184 | goal_eep = np.random.uniform(low_bound, high_bound) 185 | env.controller().open_gripper(True) 186 | try: 187 | env.controller().move_to_state(goal_eep, 0, duration=1.5) 188 | env._reset_previous_qpos() 189 | except Exception as e: 190 | continue 191 | input("Press [Enter] when ready for taking the goal image. ") 192 | obs = env.current_obs() 193 | image_goal = ( 194 | obs["image"].reshape(3, FLAGS.im_size, FLAGS.im_size).transpose(1, 2, 0) 195 | * 255 196 | ).astype(np.uint8) 197 | 198 | # ask for which policy to use 199 | if len(policies) == 1: 200 | policy_idx = 0 201 | input("Press [Enter] to start.") 202 | else: 203 | print("policies:") 204 | for i, name in enumerate(policies.keys()): 205 | print(f"{i}) {name}") 206 | policy_idx = int(input("select policy: ")) 207 | 208 | policy_name = list(policies.keys())[policy_idx] 209 | get_action, obs_horizon = policies[policy_name] 210 | try: 211 | env.reset() 212 | env.start() 213 | except Exception as e: 214 | continue 215 | 216 | # move to initial position 217 | try: 218 | if FLAGS.initial_eep is not None: 219 | assert isinstance(FLAGS.initial_eep, list) 220 | initial_eep = [float(e) for e in FLAGS.initial_eep] 221 | env.controller().move_to_state(initial_eep, 0, duration=1.5) 222 | env._reset_previous_qpos() 223 | except Exception as e: 224 | continue 225 | 226 | # do rollout 227 | obs = env.current_obs() 228 | last_tstep = time.time() 229 | images = [] 230 | goals = [] 231 | t = 0 232 | if obs_horizon is not None: 233 | obs_hist = deque(maxlen=obs_horizon) 234 | # keep track of our own gripper state to implement sticky gripper 235 | is_gripper_closed = False 236 | num_consecutive_gripper_change_actions = 0 237 | try: 238 | while t < FLAGS.num_timesteps: 239 | if time.time() > last_tstep + STEP_DURATION or FLAGS.blocking: 240 | image_obs = ( 241 | obs["image"] 242 | .reshape(3, FLAGS.im_size, FLAGS.im_size) 243 | .transpose(1, 2, 0) 244 | * 255 245 | ).astype(np.uint8) 246 | obs = {"image": image_obs, "proprio": obs["state"]} 247 | goal_obs = {"image": image_goal} 248 | if obs_horizon is not None: 249 | if len(obs_hist) == 0: 250 | obs_hist.extend([obs] * obs_horizon) 251 | else: 252 | obs_hist.append(obs) 253 | obs = stack_obs(obs_hist) 254 | 255 | last_tstep = time.time() 256 | 257 | actions = get_action(obs, goal_obs) 258 | if len(actions.shape) == 1: 259 | actions = actions[None] 260 | for i in range(FLAGS.act_exec_horizon): 261 | action = actions[i] 262 | action += np.random.normal(0, FIXED_STD) 263 | 264 | # sticky gripper logic 265 | if (action[-1] < 0.5) != is_gripper_closed: 266 | num_consecutive_gripper_change_actions += 1 267 | else: 268 | num_consecutive_gripper_change_actions = 0 269 | 270 | if ( 271 | num_consecutive_gripper_change_actions 272 | >= STICKY_GRIPPER_NUM_STEPS 273 | ): 274 | is_gripper_closed = not is_gripper_closed 275 | num_consecutive_gripper_change_actions = 0 276 | 277 | action[-1] = 0.0 if is_gripper_closed else 1.0 278 | 279 | # remove degrees of freedom 280 | if NO_PITCH_ROLL: 281 | action[3] = 0 282 | action[4] = 0 283 | if NO_YAW: 284 | action[5] = 0 285 | 286 | # 
perform environment step 287 | obs, _, _, _ = env.step( 288 | action, last_tstep + STEP_DURATION, blocking=FLAGS.blocking 289 | ) 290 | 291 | # save image 292 | images.append(image_obs) 293 | goals.append(image_goal) 294 | 295 | t += 1 296 | except Exception as e: 297 | print(traceback.format_exc(), file=sys.stderr) 298 | 299 | # save video 300 | if FLAGS.video_save_path is not None: 301 | os.makedirs(FLAGS.video_save_path, exist_ok=True) 302 | curr_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 303 | save_path = os.path.join( 304 | FLAGS.video_save_path, 305 | f"{curr_time}_{policy_name}_sticky_{STICKY_GRIPPER_NUM_STEPS}.mp4", 306 | ) 307 | video = np.concatenate([np.stack(goals), np.stack(images)], axis=1) 308 | imageio.mimsave(save_path, video, fps=1.0 / STEP_DURATION * 3) 309 | 310 | 311 | if __name__ == "__main__": 312 | app.run(main) -------------------------------------------------------------------------------- /experiments/eval_lc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | from datetime import datetime 5 | import traceback 6 | from collections import deque 7 | import json 8 | 9 | from absl import app, flags, logging 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | import jax 15 | import imageio 16 | 17 | from flax.training import checkpoints 18 | from jaxrl_m.vision import encoders 19 | from jaxrl_m.agents import agents 20 | from jaxrl_m.data.text_processing import text_processors 21 | 22 | # bridge_data_robot imports 23 | from widowx_envs.widowx_env import BridgeDataRailRLPrivateWidowX 24 | from multicam_server.topic_utils import IMTopic 25 | from utils import stack_obs 26 | 27 | 28 | np.set_printoptions(suppress=True) 29 | 30 | logging.set_verbosity(logging.WARNING) 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_multi_string( 35 | "checkpoint_weights_path", None, "Path to checkpoint", required=True 36 | ) 37 | flags.DEFINE_multi_string( 38 | "checkpoint_config_path", None, "Path to checkpoint config JSON", required=True 39 | ) 40 | flags.DEFINE_integer("im_size", None, "Image size", required=True) 41 | flags.DEFINE_string("video_save_path", None, "Path to save video") 42 | flags.DEFINE_integer("num_timesteps", 120, "num timesteps") 43 | flags.DEFINE_bool("blocking", False, "Use the blocking controller") 44 | flags.DEFINE_spaceseplist("initial_eep", [0.3, 0.0, 0.15], "Initial position") 45 | flags.DEFINE_integer("act_exec_horizon", 1, "Action sequence length") 46 | flags.DEFINE_bool("deterministic", True, "Whether to sample action deterministically") 47 | 48 | ############################################################################## 49 | 50 | STEP_DURATION = 0.2 51 | NO_PITCH_ROLL = False 52 | NO_YAW = False 53 | STICKY_GRIPPER_NUM_STEPS = 1 54 | WORKSPACE_BOUNDS = np.array([[0.1, -0.15, -0.1, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]]) 55 | CAMERA_TOPICS = [IMTopic("/blue/image_raw")] 56 | FIXED_STD = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) 57 | 58 | ############################################################################## 59 | 60 | def load_checkpoint(checkpoint_weights_path, checkpoint_config_path): 61 | with open(checkpoint_config_path, "r") as f: 62 | config = json.load(f) 63 | 64 | # create encoder from wandb config 65 | encoder_def = encoders[config["encoder"]](**config["encoder_kwargs"]) 66 | 67 | act_pred_horizon = config["dataset_kwargs"].get("act_pred_horizon") 68 | obs_horizon = config["dataset_kwargs"].get("obs_horizon") 69 | 70 | if 
act_pred_horizon is not None: 71 | example_actions = np.zeros((1, act_pred_horizon, 7), dtype=np.float32) 72 | else: 73 | example_actions = np.zeros((1, 7), dtype=np.float32) 74 | 75 | if obs_horizon is not None: 76 | example_obs = { 77 | "image": np.zeros( 78 | (1, obs_horizon, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8 79 | ) 80 | } 81 | else: 82 | example_obs = { 83 | "image": np.zeros((1, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) 84 | } 85 | 86 | example_batch = { 87 | "observations": example_obs, 88 | "goals": {"language": np.zeros((1, 512), dtype=np.float32)}, 89 | "actions": example_actions, 90 | } 91 | 92 | # create agent from wandb config 93 | rng = jax.random.PRNGKey(0) 94 | rng, construct_rng = jax.random.split(rng) 95 | agent = agents[config["agent"]].create( 96 | rng=construct_rng, 97 | observations=example_batch["observations"], 98 | goals=example_batch["goals"], 99 | actions=example_batch["actions"], 100 | encoder_def=encoder_def, 101 | **config["agent_kwargs"], 102 | ) 103 | 104 | # load action metadata from wandb 105 | action_proprio_metadata = config["bridgedata_config"]["action_proprio_metadata"] 106 | action_mean = np.array(action_proprio_metadata["action"]["mean"]) 107 | action_std = np.array(action_proprio_metadata["action"]["std"]) 108 | 109 | # hydrate agent with parameters from checkpoint 110 | agent = checkpoints.restore_checkpoint(checkpoint_weights_path, agent) 111 | 112 | def get_action(obs, goal_obs): 113 | nonlocal rng 114 | rng, key = jax.random.split(rng) 115 | action = jax.device_get( 116 | agent.sample_actions(obs, goal_obs, seed=key, argmax=FLAGS.deterministic) 117 | ) 118 | action = action * action_std + action_mean 119 | return action 120 | 121 | text_processor = text_processors[config["text_processor"]]( 122 | **config["text_processor_kwargs"] 123 | ) 124 | 125 | return get_action, text_processor, obs_horizon 126 | 127 | 128 | def main(_): 129 | assert len(FLAGS.checkpoint_weights_path) == len(FLAGS.checkpoint_config_path) 130 | 131 | # policies is a dict from run_name to get_action function 132 | policies = {} 133 | for checkpoint_weights_path, checkpoint_config_path in zip( 134 | FLAGS.checkpoint_weights_path, FLAGS.checkpoint_config_path 135 | ): 136 | assert tf.io.gfile.exists(checkpoint_weights_path), checkpoint_weights_path 137 | checkpoint_num = int(checkpoint_weights_path.split("_")[-1]) 138 | run_name = checkpoint_config_path.split("/")[-1] 139 | policies[f"{run_name}-{checkpoint_num}"] = load_checkpoint( 140 | checkpoint_weights_path, checkpoint_config_path 141 | ) 142 | 143 | if FLAGS.initial_eep is not None: 144 | assert isinstance(FLAGS.initial_eep, list) 145 | initial_eep = [float(e) for e in FLAGS.initial_eep] 146 | start_state = np.concatenate([initial_eep, [0, 0, 0, 1]]) 147 | else: 148 | start_state = None 149 | 150 | # set up environment 151 | env_params = { 152 | "fix_zangle": 0.1, 153 | "move_duration": 0.2, 154 | "adaptive_wait": True, 155 | "move_to_rand_start_freq": 1, 156 | "override_workspace_boundaries": WORKSPACE_BOUNDS, 157 | "action_clipping": "xyz", 158 | "catch_environment_except": False, 159 | "start_state": start_state, 160 | "return_full_image": False, 161 | "camera_topics": CAMERA_TOPICS, 162 | } 163 | env = BridgeDataRailRLPrivateWidowX(env_params, fixed_image_size=FLAGS.im_size) 164 | 165 | instruction = None 166 | 167 | # instruction sampling loop 168 | while True: 169 | # ask for which policy to use 170 | if len(policies) == 1: 171 | policy_idx = 0 172 | input("Press [Enter] to start.") 173 | else: 
174 | print("policies:") 175 | for i, name in enumerate(policies.keys()): 176 | print(f"{i}) {name}") 177 | policy_idx = int(input("select policy: ")) 178 | 179 | policy_name = list(policies.keys())[policy_idx] 180 | get_action, text_processor, obs_horizon = policies[policy_name] 181 | 182 | # ask for new instruction 183 | if instruction is None: 184 | ch = "y" 185 | else: 186 | ch = input("New instruction? [y/n]") 187 | if ch == "y": 188 | instruction = text_processor.encode(input("Instruction?")) 189 | 190 | try: 191 | env.reset() 192 | env.start() 193 | except Exception as e: 194 | continue 195 | 196 | # move to initial position 197 | try: 198 | if FLAGS.initial_eep is not None: 199 | assert isinstance(FLAGS.initial_eep, list) 200 | initial_eep = [float(e) for e in FLAGS.initial_eep] 201 | env.controller().move_to_state(initial_eep, 0, duration=1.5) 202 | env._reset_previous_qpos() 203 | except Exception as e: 204 | continue 205 | 206 | # do rollout 207 | obs = env.current_obs() 208 | last_tstep = time.time() 209 | images = [] 210 | t = 0 211 | if obs_horizon is not None: 212 | obs_hist = deque(maxlen=obs_horizon) 213 | # keep track of our own gripper state to implement sticky gripper 214 | is_gripper_closed = False 215 | num_consecutive_gripper_change_actions = 0 216 | try: 217 | while t < FLAGS.num_timesteps: 218 | if time.time() > last_tstep + STEP_DURATION or FLAGS.blocking: 219 | image_obs = ( 220 | obs["image"] 221 | .reshape(3, FLAGS.im_size, FLAGS.im_size) 222 | .transpose(1, 2, 0) 223 | * 255 224 | ).astype(np.uint8) 225 | obs = {"image": image_obs, "proprio": obs["state"]} 226 | goal_obs = {"language": instruction} 227 | if obs_horizon is not None: 228 | if len(obs_hist) == 0: 229 | obs_hist.extend([obs] * obs_horizon) 230 | else: 231 | obs_hist.append(obs) 232 | obs = stack_obs(obs_hist) 233 | 234 | last_tstep = time.time() 235 | 236 | actions = get_action(obs, goal_obs) 237 | if len(actions.shape) == 1: 238 | actions = actions[None] 239 | for i in range(FLAGS.act_exec_horizon): 240 | action = actions[i] 241 | action += np.random.normal(0, FIXED_STD) 242 | 243 | # sticky gripper logic 244 | if (action[-1] < 0.5) != is_gripper_closed: 245 | num_consecutive_gripper_change_actions += 1 246 | else: 247 | num_consecutive_gripper_change_actions = 0 248 | 249 | if ( 250 | num_consecutive_gripper_change_actions 251 | >= STICKY_GRIPPER_NUM_STEPS 252 | ): 253 | is_gripper_closed = not is_gripper_closed 254 | num_consecutive_gripper_change_actions = 0 255 | 256 | action[-1] = 0.0 if is_gripper_closed else 1.0 257 | 258 | # remove degrees of freedom 259 | if NO_PITCH_ROLL: 260 | action[3] = 0 261 | action[4] = 0 262 | if NO_YAW: 263 | action[5] = 0 264 | 265 | # perform environment step 266 | obs, _, _, _ = env.step( 267 | action, last_tstep + STEP_DURATION, blocking=FLAGS.blocking 268 | ) 269 | 270 | # save image 271 | images.append(image_obs) 272 | 273 | t += 1 274 | except Exception as e: 275 | print(traceback.format_exc(), file=sys.stderr) 276 | 277 | # save video 278 | if FLAGS.video_save_path is not None: 279 | os.makedirs(FLAGS.video_save_path, exist_ok=True) 280 | curr_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 281 | save_path = os.path.join( 282 | FLAGS.video_save_path, 283 | f"{curr_time}_{policy_name}_sticky_{STICKY_GRIPPER_NUM_STEPS}.mp4", 284 | ) 285 | imageio.mimsave(save_path, images, fps=1.0 / STEP_DURATION * 3) 286 | 287 | 288 | if __name__ == "__main__": 289 | app.run(main) 290 | -------------------------------------------------------------------------------- 
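For reference, two small pieces of state in the rollout loop above are easy to get wrong when re-implementing it: the observation-history deque, which is padded with copies of the first observation and then stacked before being passed to the policy, and the sticky-gripper counter, which only toggles the gripper after several consecutive "change" commands. The snippet below is a minimal, self-contained sketch of just those two pieces; the dummy observation, the constants, and the flat-dict `stack_obs` helper are illustrative stand-ins (the repository's own `stack_obs` lives in `experiments/utils.py`).

```python
from collections import deque

import numpy as np

STICKY_GRIPPER_NUM_STEPS = 2  # illustrative value; the eval script defines its own constant
OBS_HORIZON = 2


def stack_obs(obs_hist):
    # flat-dict simplification of experiments/utils.py: stack a list of dict
    # observations along a new leading (history) axis
    return {k: np.stack([o[k] for o in obs_hist]) for k in obs_hist[0]}


obs_hist = deque(maxlen=OBS_HORIZON)
is_gripper_closed = False
num_consecutive_gripper_change_actions = 0

for t in range(5):
    obs = {"image": np.zeros((128, 128, 3), dtype=np.uint8)}  # dummy observation
    if len(obs_hist) == 0:
        obs_hist.extend([obs] * OBS_HORIZON)  # pad the history on the first step
    else:
        obs_hist.append(obs)
    stacked = stack_obs(obs_hist)  # stacked["image"].shape == (OBS_HORIZON, 128, 128, 3)

    gripper_action = np.random.uniform()  # stand-in for the last dimension of the policy output
    # sticky gripper: only toggle after several consecutive "change" commands
    if (gripper_action < 0.5) != is_gripper_closed:
        num_consecutive_gripper_change_actions += 1
    else:
        num_consecutive_gripper_change_actions = 0
    if num_consecutive_gripper_change_actions >= STICKY_GRIPPER_NUM_STEPS:
        is_gripper_closed = not is_gripper_closed
        num_consecutive_gripper_change_actions = 0
    gripper_command = 0.0 if is_gripper_closed else 1.0  # final gripper dimension sent to the robot
```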
/experiments/susie/calvin/README.md: -------------------------------------------------------------------------------- 1 | # Goal-conditioned policy training code for SuSIE paper, specific to CALVIN dataset 2 | 3 | ## To train goal reaching or language conditioned policies, you will first need to prepare the CALVIN dataset for training 4 | 5 | 1. Download the full CALVIN dataset following instructions in https://github.com/mees/calvin 6 | 2. Run the scripts in ```dataset_conversion_scripts``` to create two versions of the dataset in TFRecord format. Appropriately set path strings 7 | 8 | ## Training GCBC and LCBC policies 9 | 10 | The GCBC policy, which is used as the low-level policy in SuSIE, and the baseline LCBC policy can be trained via the corresonding scripts in the ```scripts``` subdirectory. Make sure to first replace the placeholder path strings in the ```configs``` folder with the correct strings for your setup. 11 | -------------------------------------------------------------------------------- /experiments/susie/calvin/calvin_gcbc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import tensorflow as tf 8 | import tqdm 9 | from absl import app, flags, logging 10 | from flax.training import checkpoints 11 | from ml_collections import config_flags 12 | 13 | from jaxrl_m.agents import agents 14 | from jaxrl_m.common.common import shard_batch 15 | from jaxrl_m.common.wandb import WandBLogger 16 | from jaxrl_m.data.calvin_dataset import CalvinDataset, glob_to_path_list 17 | from jaxrl_m.utils.timer_utils import Timer 18 | from jaxrl_m.vision import encoders 19 | 20 | try: 21 | from jax_smi import initialise_tracking # type: ignore 22 | 23 | initialise_tracking() 24 | except ImportError: 25 | pass 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | flags.DEFINE_string("name", "", "Experiment name.") 30 | flags.DEFINE_bool("debug", False, "Debug config") 31 | 32 | config_flags.DEFINE_config_file( 33 | "config", 34 | None, 35 | "File path to the training hyperparameter configuration.", 36 | lock_config=False, 37 | ) 38 | 39 | config_flags.DEFINE_config_file( 40 | "calvin_dataset_config", 41 | None, 42 | "File path to the CALVIN dataset configuration.", 43 | lock_config=False, 44 | ) 45 | 46 | 47 | def main(_): 48 | devices = jax.local_devices() 49 | num_devices = len(devices) 50 | assert FLAGS.config.batch_size % num_devices == 0 51 | 52 | # prevent tensorflow from using GPUs 53 | tf.config.set_visible_devices([], "GPU") 54 | 55 | # set up wandb and logging 56 | wandb_config = WandBLogger.get_default_config() 57 | wandb_config.update( 58 | { 59 | "project": "jaxrl_m_calvin_gcbc", 60 | "exp_descriptor": FLAGS.name, 61 | } 62 | ) 63 | wandb_logger = WandBLogger( 64 | wandb_config=wandb_config, 65 | variant=FLAGS.config.to_dict(), 66 | debug=FLAGS.debug, 67 | ) 68 | 69 | save_dir = tf.io.gfile.join( 70 | FLAGS.config.save_dir, 71 | wandb_logger.config.project, 72 | f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}", 73 | ) 74 | 75 | # load datasets 76 | assert type(FLAGS.calvin_dataset_config.include[0]) == list 77 | task_paths = [ 78 | glob_to_path_list( 79 | path, prefix=FLAGS.config.data_path, exclude=FLAGS.calvin_dataset_config.exclude 80 | ) 81 | for path in FLAGS.calvin_dataset_config.include 82 | ] 83 | 84 | train_paths = [task_paths[0]] 85 | val_paths = [task_paths[1]] 86 | 87 | obs_horizon = 
FLAGS.config.get("obs_horizon") 88 | 89 | train_data = CalvinDataset( 90 | train_paths, 91 | FLAGS.config.seed, 92 | batch_size=FLAGS.config.batch_size, 93 | num_devices=num_devices, 94 | train=True, 95 | action_proprio_metadata=FLAGS.calvin_dataset_config.action_proprio_metadata, 96 | sample_weights=FLAGS.calvin_dataset_config.sample_weights, 97 | **FLAGS.config.dataset_kwargs, 98 | ) 99 | val_data = CalvinDataset( 100 | val_paths, 101 | FLAGS.config.seed, 102 | batch_size=FLAGS.config.batch_size, 103 | action_proprio_metadata=FLAGS.calvin_dataset_config.action_proprio_metadata, 104 | train=False, 105 | **FLAGS.config.dataset_kwargs, 106 | ) 107 | train_data_iter = train_data.iterator() 108 | 109 | example_batch = next(train_data_iter) 110 | logging.info(f"Batch size: {example_batch['observations']['image'].shape[0]}") 111 | logging.info(f"Number of devices: {num_devices}") 112 | logging.info( 113 | f"Batch size per device: {example_batch['observations']['image'].shape[0] // num_devices}" 114 | ) 115 | 116 | # we shard the leading dimension (batch dimension) accross all devices evenly 117 | sharding = jax.sharding.PositionalSharding(devices) 118 | example_batch = shard_batch(example_batch, sharding) 119 | 120 | # define encoder 121 | encoder_def = encoders[FLAGS.config.encoder](**FLAGS.config.encoder_kwargs) 122 | 123 | # initialize agent 124 | rng = jax.random.PRNGKey(FLAGS.config.seed) 125 | rng, construct_rng = jax.random.split(rng) 126 | agent = agents[FLAGS.config.agent].create( 127 | rng=construct_rng, 128 | observations=example_batch["observations"], 129 | goals=example_batch["goals"], 130 | actions=example_batch["actions"], 131 | encoder_def=encoder_def, 132 | **FLAGS.config.agent_kwargs, 133 | ) 134 | 135 | if FLAGS.config.resume_path is not None: 136 | agent = checkpoints.restore_checkpoint(FLAGS.config.resume_path, target=agent) 137 | # replicate agent across devices 138 | # need the jnp.array to avoid a bug where device_put doesn't recognize primitives 139 | agent = jax.device_put(jax.tree_map(jnp.array, agent), sharding.replicate()) 140 | 141 | timer = Timer() 142 | for i in tqdm.tqdm(range(int(FLAGS.config.num_steps))): 143 | timer.tick("total") 144 | 145 | timer.tick("dataset") 146 | batch = shard_batch(next(train_data_iter), sharding) 147 | timer.tock("dataset") 148 | 149 | timer.tick("train") 150 | agent, update_info = agent.update(batch) 151 | timer.tock("train") 152 | 153 | if (i + 1) % FLAGS.config.eval_interval == 0: 154 | logging.info("Evaluating...") 155 | timer.tick("val") 156 | metrics = [] 157 | for batch in val_data.iterator(): 158 | rng, val_rng = jax.random.split(rng) 159 | metrics.append(agent.get_debug_metrics(batch, seed=val_rng)) 160 | metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) 161 | wandb_logger.log({"validation": metrics}, step=i) 162 | timer.tock("val") 163 | 164 | if (i + 1) % FLAGS.config.save_interval == 0: 165 | logging.info("Saving checkpoint...") 166 | checkpoint_path = checkpoints.save_checkpoint( 167 | save_dir, agent, step=i + 1, keep=1e6 168 | ) 169 | logging.info("Saved checkpoint to %s", checkpoint_path) 170 | 171 | timer.tock("total") 172 | 173 | if (i + 1) % FLAGS.config.log_interval == 0: 174 | update_info = jax.device_get(update_info) 175 | wandb_logger.log({"training": update_info}, step=i) 176 | 177 | wandb_logger.log({"timer": timer.get_average_times()}, step=i) 178 | 179 | 180 | if __name__ == "__main__": 181 | app.run(main) 182 | -------------------------------------------------------------------------------- 
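Both CALVIN training scripts (and `experiments/train.py` later in this repository) use the same data-parallel pattern: each batch is sharded along its leading (batch) dimension across all local devices with `jax.sharding.PositionalSharding`, while the agent's parameters are replicated on every device. The snippet below is a minimal, self-contained sketch of that pattern; `shard_batch_sketch` and the toy batch are illustrative stand-ins, not the repository's `shard_batch` or data pipeline.

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()
sharding = jax.sharding.PositionalSharding(devices)


def shard_batch_sketch(batch, sharding):
    # shard every array leaf along its leading (batch) dimension only
    return jax.tree_map(
        lambda x: jax.device_put(x, sharding.reshape((-1,) + (1,) * (x.ndim - 1))),
        batch,
    )


batch_size = 8 * len(devices)  # batch size must be divisible by the number of devices
batch = {
    "observations": {"image": np.zeros((batch_size, 200, 200, 3), dtype=np.uint8)},
    "actions": np.zeros((batch_size, 7), dtype=np.float32),
}
batch = shard_batch_sketch(batch, sharding)

# parameters, by contrast, are replicated on every device
params = {"w": jnp.ones((7, 7))}
params = jax.device_put(params, sharding.replicate())
```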
/experiments/susie/calvin/calvin_lcbc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import tensorflow as tf 8 | import tqdm 9 | from absl import app, flags, logging 10 | from flax.training import checkpoints 11 | from ml_collections import config_flags 12 | 13 | from jaxrl_m.agents import agents 14 | from jaxrl_m.common.common import shard_batch 15 | from jaxrl_m.common.wandb import WandBLogger 16 | from jaxrl_m.data.calvin_dataset import CalvinDataset, glob_to_path_list 17 | from jaxrl_m.utils.timer_utils import Timer 18 | from jaxrl_m.vision import encoders 19 | from jaxrl_m.data.text_processing import text_processors 20 | 21 | try: 22 | from jax_smi import initialise_tracking # type: ignore 23 | 24 | initialise_tracking() 25 | except ImportError: 26 | pass 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("name", "", "Experiment name.") 31 | flags.DEFINE_bool("debug", False, "Debug config") 32 | 33 | config_flags.DEFINE_config_file( 34 | "config", 35 | None, 36 | "File path to the training hyperparameter configuration.", 37 | lock_config=False, 38 | ) 39 | 40 | config_flags.DEFINE_config_file( 41 | "calvin_dataset_config", 42 | None, 43 | "File path to the CALVIN dataset configuration.", 44 | lock_config=False, 45 | ) 46 | 47 | 48 | def main(_): 49 | devices = jax.local_devices() 50 | num_devices = len(devices) 51 | assert FLAGS.config.batch_size % num_devices == 0 52 | 53 | # we shard the leading dimension (batch dimension) accross all devices evenly 54 | sharding = jax.sharding.PositionalSharding(devices) 55 | shard_fn = partial(shard_batch, sharding=sharding) 56 | 57 | # prevent tensorflow from using GPUs 58 | tf.config.set_visible_devices([], "GPU") 59 | 60 | # set up wandb and logging 61 | wandb_config = WandBLogger.get_default_config() 62 | wandb_config.update( 63 | { 64 | "project": "jaxrl_m_calvin_lcbc", 65 | "exp_descriptor": FLAGS.name, 66 | } 67 | ) 68 | wandb_logger = WandBLogger( 69 | wandb_config=wandb_config, 70 | variant=FLAGS.config.to_dict(), 71 | debug=FLAGS.debug, 72 | ) 73 | 74 | save_dir = tf.io.gfile.join( 75 | FLAGS.config.save_dir, 76 | wandb_logger.config.project, 77 | f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}", 78 | ) 79 | 80 | # load datasets 81 | assert type(FLAGS.calvin_dataset_config.include[0]) == list 82 | task_paths = [ 83 | glob_to_path_list( 84 | path, prefix=FLAGS.config.data_path, exclude=FLAGS.calvin_dataset_config.exclude 85 | ) 86 | for path in FLAGS.calvin_dataset_config.include 87 | ] 88 | 89 | train_paths = [task_paths[0]] 90 | val_paths = [task_paths[1]] 91 | 92 | obs_horizon = FLAGS.config.get("obs_horizon") 93 | 94 | train_data = CalvinDataset( 95 | train_paths, 96 | FLAGS.config.seed, 97 | batch_size=FLAGS.config.batch_size, 98 | train=True, 99 | action_proprio_metadata=FLAGS.calvin_dataset_config.action_proprio_metadata, 100 | sample_weights=FLAGS.calvin_dataset_config.sample_weights, 101 | **FLAGS.config.dataset_kwargs, 102 | ) 103 | val_data = CalvinDataset( 104 | val_paths, 105 | FLAGS.config.seed, 106 | batch_size=FLAGS.config.batch_size, 107 | action_proprio_metadata=FLAGS.calvin_dataset_config.action_proprio_metadata, 108 | train=False, 109 | **FLAGS.config.dataset_kwargs, 110 | ) 111 | 112 | if FLAGS.config.text_processor is None: 113 | text_processor = None 114 | else: 115 | text_processor = 
text_processors[FLAGS.config.text_processor]( 116 | **FLAGS.config.text_processor_kwargs 117 | ) 118 | 119 | def process_text(batch): 120 | if text_processor is None: 121 | batch["goals"].pop("language") 122 | else: 123 | batch["goals"]["language"] = text_processor.encode( 124 | #[s.decode("utf-8") for s in batch["goals"]["language"]] 125 | [s for s in batch["goals"]["language"]] 126 | ) 127 | return batch 128 | train_data_iter = map(shard_fn, map(process_text, train_data.tf_dataset.as_numpy_iterator())) 129 | 130 | example_batch = next(train_data_iter) 131 | logging.info(f"Batch size: {example_batch['observations']['image'].shape[0]}") 132 | logging.info(f"Number of devices: {num_devices}") 133 | logging.info( 134 | f"Batch size per device: {example_batch['observations']['image'].shape[0] // num_devices}" 135 | ) 136 | 137 | # define encoder 138 | encoder_def = encoders[FLAGS.config.encoder](**FLAGS.config.encoder_kwargs) 139 | 140 | # initialize agent 141 | rng = jax.random.PRNGKey(FLAGS.config.seed) 142 | rng, construct_rng = jax.random.split(rng) 143 | agent = agents[FLAGS.config.agent].create( 144 | rng=construct_rng, 145 | observations=example_batch["observations"], 146 | goals=example_batch["goals"], 147 | actions=example_batch["actions"], 148 | encoder_def=encoder_def, 149 | **FLAGS.config.agent_kwargs, 150 | ) 151 | if FLAGS.config.resume_path is not None: 152 | agent = checkpoints.restore_checkpoint(FLAGS.config.resume_path, target=agent) 153 | # replicate agent across devices 154 | # need the jnp.array to avoid a bug where device_put doesn't recognize primitives 155 | agent = jax.device_put(jax.tree_map(jnp.array, agent), sharding.replicate()) 156 | 157 | timer = Timer() 158 | for i in tqdm.tqdm(range(int(FLAGS.config.num_steps))): 159 | timer.tick("total") 160 | 161 | timer.tick("dataset") 162 | batch = next(train_data_iter) 163 | timer.tock("dataset") 164 | 165 | timer.tick("train") 166 | agent, update_info = agent.update(batch) 167 | timer.tock("train") 168 | 169 | if (i + 1) % FLAGS.config.eval_interval == 0: 170 | logging.info("Evaluating...") 171 | timer.tick("val") 172 | metrics = [] 173 | val_data_iter = map(shard_fn, map(process_text, val_data.tf_dataset.as_numpy_iterator())) 174 | for _, batch in zip(range(FLAGS.config.num_val_batches), val_data_iter): 175 | rng, val_rng = jax.random.split(rng) 176 | metrics.append(agent.get_debug_metrics(batch, seed=val_rng)) 177 | metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) 178 | wandb_logger.log({"validation": metrics}, step=i) 179 | timer.tock("val") 180 | 181 | if (i + 1) % FLAGS.config.save_interval == 0: 182 | logging.info("Saving checkpoint...") 183 | checkpoint_path = checkpoints.save_checkpoint( 184 | save_dir, agent, step=i + 1, keep=1e6 185 | ) 186 | logging.info("Saved checkpoint to %s", checkpoint_path) 187 | 188 | timer.tock("total") 189 | 190 | if (i + 1) % FLAGS.config.log_interval == 0: 191 | update_info = jax.device_get(update_info) 192 | wandb_logger.log({"training": update_info}, step=i) 193 | 194 | wandb_logger.log({"timer": timer.get_average_times()}, step=i) 195 | 196 | 197 | if __name__ == "__main__": 198 | app.run(main) 199 | -------------------------------------------------------------------------------- /experiments/susie/calvin/configs/gcbc_data_config.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | ACT_MEAN = [ 4 | 2.9842544e-04, 5 | -2.6099570e-04, 6 | -1.5863389e-04, 7 | 5.8916201e-05, 8 | -4.4560504e-05, 9 | 
8.2349771e-04, 10 | 9.4075650e-02, 11 | ] 12 | 13 | ACT_STD = [ 14 | 0.27278143, 15 | 0.23548537, 16 | 0.2196189, 17 | 0.15881406, 18 | 0.17537235, 19 | 0.27875036, 20 | 1.0049515, 21 | ] 22 | 23 | PROPRIO_MEAN = [ # We don't actually use proprio so we're using dummy values for this 24 | 0.0, 25 | 0.0, 26 | 0.0, 27 | 0.0, 28 | 0.0, 29 | 0.0, 30 | 0.0, 31 | 0.0, 32 | 0.0, 33 | 0.0, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 0.0, 38 | 0.0, 39 | ] 40 | 41 | PROPRIO_STD = [ # We don't actually use proprio so we're using dummy values for this 42 | 1.0, 43 | 1.0, 44 | 1.0, 45 | 1.0, 46 | 1.0, 47 | 1.0, 48 | 1.0, 49 | 1.0, 50 | 1.0, 51 | 1.0, 52 | 1.0, 53 | 1.0, 54 | 1.0, 55 | 1.0, 56 | 1.0, 57 | ] 58 | 59 | ACTION_PROPRIO_METADATA = { 60 | "action": { 61 | "mean": ACT_MEAN, 62 | "std": ACT_STD, 63 | # TODO compute these 64 | "min": ACT_MEAN, 65 | "max": ACT_STD 66 | }, 67 | # TODO compute these 68 | "proprio": { 69 | "mean": PROPRIO_MEAN, 70 | "std": PROPRIO_STD, 71 | "min": PROPRIO_MEAN, 72 | "max": PROPRIO_STD 73 | } 74 | } 75 | 76 | 77 | def get_config(config_string): 78 | possible_structures = { 79 | "all": ml_collections.ConfigDict( 80 | { 81 | "include": [ 82 | [ 83 | "training/A/?*/?*", 84 | "training/B/?*/?*", 85 | "training/C/?*/?*" 86 | ], 87 | [ 88 | "validation/D/?*/?*", 89 | ] 90 | ], 91 | "exclude": [], 92 | "sample_weights": None, 93 | "action_proprio_metadata": ACTION_PROPRIO_METADATA 94 | } 95 | ), 96 | } 97 | return possible_structures[config_string] 98 | -------------------------------------------------------------------------------- /experiments/susie/calvin/configs/gcbc_train_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | 4 | def get_config(config_string): 5 | base_real_config = dict( 6 | batch_size=256, 7 | num_val_batches=8, 8 | num_steps=int(2e6), 9 | log_interval=1000, 10 | eval_interval=2000, 11 | save_interval=2000, 12 | save_dir="", 13 | data_path="", 14 | resume_path=None, 15 | seed=42, 16 | ) 17 | 18 | base_data_config = dict( 19 | shuffle_buffer_size=25000, 20 | prefetch_num_batches=20, 21 | augment=True, 22 | augment_next_obs_goal_differently=False, 23 | augment_kwargs=dict( 24 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 25 | random_brightness=[0.2], 26 | random_contrast=[0.8, 1.2], 27 | random_saturation=[0.8, 1.2], 28 | random_hue=[0.1], 29 | augment_order=[ 30 | "random_resized_crop", 31 | "random_brightness", 32 | "random_contrast", 33 | "random_saturation", 34 | "random_hue", 35 | ], 36 | ), 37 | ) 38 | 39 | # params that need to be specified multiple places 40 | normalization_type = "normal" 41 | 42 | possible_structures = { 43 | "gc_ddpm_bc": ConfigDict( 44 | dict( 45 | agent="gc_ddpm_bc", 46 | agent_kwargs=dict( 47 | score_network_kwargs=dict( 48 | time_dim=32, 49 | num_blocks=3, 50 | dropout_rate=0.1, 51 | hidden_dim=256, 52 | use_layer_norm=True, 53 | ), 54 | #language_conditioned=True, 55 | early_goal_concat=True, 56 | shared_goal_encoder=True, 57 | use_proprio=False, 58 | beta_schedule="cosine", 59 | diffusion_steps=20, 60 | action_samples=1, 61 | repeat_last_step=0, 62 | learning_rate=3e-4, 63 | warmup_steps=2000, 64 | actor_decay_steps=int(2e6), 65 | ), 66 | dataset_kwargs=dict( 67 | goal_relabeling_strategy="delta_goals", 68 | goal_relabeling_kwargs=dict(goal_delta=[0, 24]), 69 | #goal_relabeling_strategy="uniform", 70 | #goal_relabeling_kwargs=dict(reached_proportion=0.0), 71 | #load_language=True, 72 | #skip_unlabeled=True, 73 | 
relabel_actions=False, 74 | act_pred_horizon=4, 75 | obs_horizon=1, 76 | **base_data_config, 77 | ), 78 | #text_processor="muse_embedding", 79 | #text_processor_kwargs=dict(), 80 | encoder="resnetv1-34-bridge", 81 | encoder_kwargs=dict( 82 | pooling_method="avg", 83 | add_spatial_coordinates=True, 84 | act="swish", 85 | ), 86 | **base_real_config, 87 | ) 88 | ), 89 | } 90 | 91 | return possible_structures[config_string] 92 | -------------------------------------------------------------------------------- /experiments/susie/calvin/configs/lcbc_data_config.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | ACT_MEAN = [ 4 | 2.9842544e-04, 5 | -2.6099570e-04, 6 | -1.5863389e-04, 7 | 5.8916201e-05, 8 | -4.4560504e-05, 9 | 8.2349771e-04, 10 | 9.4075650e-02, 11 | ] 12 | 13 | ACT_STD = [ 14 | 0.27278143, 15 | 0.23548537, 16 | 0.2196189, 17 | 0.15881406, 18 | 0.17537235, 19 | 0.27875036, 20 | 1.0049515, 21 | ] 22 | 23 | PROPRIO_MEAN = [ # We don't actually use proprio so we're using dummy values for this 24 | 0.0, 25 | 0.0, 26 | 0.0, 27 | 0.0, 28 | 0.0, 29 | 0.0, 30 | 0.0, 31 | 0.0, 32 | 0.0, 33 | 0.0, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 0.0, 38 | 0.0, 39 | ] 40 | 41 | PROPRIO_STD = [ # We don't actually use proprio so we're using dummy values for this 42 | 1.0, 43 | 1.0, 44 | 1.0, 45 | 1.0, 46 | 1.0, 47 | 1.0, 48 | 1.0, 49 | 1.0, 50 | 1.0, 51 | 1.0, 52 | 1.0, 53 | 1.0, 54 | 1.0, 55 | 1.0, 56 | 1.0, 57 | ] 58 | 59 | ACTION_PROPRIO_METADATA = { 60 | "action": { 61 | "mean": ACT_MEAN, 62 | "std": ACT_STD, 63 | # TODO compute these 64 | "min": ACT_MEAN, 65 | "max": ACT_STD 66 | }, 67 | # TODO compute these 68 | "proprio": { 69 | "mean": PROPRIO_MEAN, 70 | "std": PROPRIO_STD, 71 | "min": PROPRIO_MEAN, 72 | "max": PROPRIO_STD 73 | } 74 | } 75 | 76 | 77 | def get_config(config_string): 78 | possible_structures = { 79 | "all": ml_collections.ConfigDict( 80 | { 81 | "include": [ 82 | [ 83 | "training/A/?*", 84 | "training/B/?*", 85 | "training/C/?*" 86 | ], 87 | [ 88 | "validation/D/?*", 89 | ] 90 | ], 91 | "exclude": [], 92 | "sample_weights": None, 93 | "action_proprio_metadata": ACTION_PROPRIO_METADATA 94 | } 95 | ), 96 | } 97 | return possible_structures[config_string] 98 | -------------------------------------------------------------------------------- /experiments/susie/calvin/configs/lcbc_train_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | 4 | def get_config(config_string): 5 | base_real_config = dict( 6 | batch_size=256, 7 | num_val_batches=8, 8 | num_steps=int(2e6), 9 | log_interval=1000, 10 | eval_interval=2000, 11 | save_interval=2000, 12 | save_dir="", 13 | data_path="", 14 | resume_path=None, 15 | seed=42, 16 | ) 17 | 18 | base_data_config = dict( 19 | shuffle_buffer_size=25000, 20 | prefetch_num_batches=20, 21 | augment=True, 22 | augment_next_obs_goal_differently=False, 23 | augment_kwargs=dict( 24 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 25 | random_brightness=[0.2], 26 | random_contrast=[0.8, 1.2], 27 | random_saturation=[0.8, 1.2], 28 | random_hue=[0.1], 29 | augment_order=[ 30 | "random_resized_crop", 31 | "random_brightness", 32 | "random_contrast", 33 | "random_saturation", 34 | "random_hue", 35 | ], 36 | ), 37 | ) 38 | 39 | # params that need to be specified multiple places 40 | normalization_type = "normal" 41 | 42 | possible_structures = { 43 | "lc_ddpm_bc": ConfigDict( 44 | dict( 45 | 
agent="gc_ddpm_bc", 46 | agent_kwargs=dict( 47 | score_network_kwargs=dict( 48 | time_dim=32, 49 | num_blocks=3, 50 | dropout_rate=0.1, 51 | hidden_dim=256, 52 | use_layer_norm=True, 53 | ), 54 | language_conditioned=True, 55 | early_goal_concat=None, 56 | shared_goal_encoder=None, 57 | use_proprio=False, 58 | beta_schedule="cosine", 59 | diffusion_steps=20, 60 | action_samples=1, 61 | repeat_last_step=0, 62 | learning_rate=3e-4, 63 | warmup_steps=2000, 64 | actor_decay_steps=int(2e6), 65 | ), 66 | dataset_kwargs=dict( 67 | goal_relabeling_strategy="delta_goals", 68 | goal_relabeling_kwargs=dict(goal_delta=[0, 20]), 69 | #goal_relabeling_strategy="uniform", 70 | #goal_relabeling_kwargs=dict(reached_proportion=0.0), 71 | load_language=True, 72 | skip_unlabeled=True, 73 | relabel_actions=False, 74 | act_pred_horizon=4, 75 | obs_horizon=1, 76 | **base_data_config, 77 | ), 78 | text_processor="muse_embedding", 79 | text_processor_kwargs=dict(), 80 | encoder="resnetv1-34-bridge-film", 81 | encoder_kwargs=dict( 82 | pooling_method="avg", 83 | add_spatial_coordinates=True, 84 | act="swish", 85 | ), 86 | **base_real_config, 87 | ) 88 | ), 89 | } 90 | 91 | return possible_structures[config_string] 92 | -------------------------------------------------------------------------------- /experiments/susie/calvin/dataset_conversion_scripts/goal_conditioned.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script processes the full CALVIN dataset, writing it into TFRecord format. 3 | 4 | This script does not process language annotations (i.e. the resulting 5 | dataset can only be used for goal conditioned learning). See the sister 6 | script for code that only converts the language instruction labeled portion 7 | of the dataset into TFRecord format. 8 | 9 | Written by Pranav Atreya (pranavatreya@berkeley.edu). 
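The resulting layout is <tfrecord_dataset_path>/<split>/<scene letter>/traj<N>/<segment>.tfrecord
(scenes A-D under training, scene D under validation). Each trajectory is split into
segments of at most 1000 timesteps, and each record stores the "actions",
"proprioceptive_states", and "image_states" tensors. This two-level structure is what
the "training/A/?*/?*" style include globs in configs/gcbc_data_config.py expect.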
10 | """ 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | from tqdm import tqdm 15 | import os 16 | from multiprocessing import Pool 17 | 18 | ########## Dataset paths ########### 19 | raw_dataset_path = "" 20 | tfrecord_dataset_path = "" 21 | 22 | ########## Main logic ########### 23 | if not os.path.exists(tfrecord_dataset_path): 24 | os.mkdir(tfrecord_dataset_path) 25 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training")): 26 | os.mkdir(os.path.join(tfrecord_dataset_path, "training")) 27 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "validation")): 28 | os.mkdir(os.path.join(tfrecord_dataset_path, "validation")) 29 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training/A")): 30 | os.mkdir(os.path.join(tfrecord_dataset_path, "training/A")) 31 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training/B")): 32 | os.mkdir(os.path.join(tfrecord_dataset_path, "training/B")) 33 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training/C")): 34 | os.mkdir(os.path.join(tfrecord_dataset_path, "training/C")) 35 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training/D")): 36 | os.mkdir(os.path.join(tfrecord_dataset_path, "training/D")) 37 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "validation/D")): 38 | os.mkdir(os.path.join(tfrecord_dataset_path, "validation/D")) 39 | 40 | def make_seven_characters(id): 41 | id = str(id) 42 | while len(id) < 7: 43 | id = "0" + id 44 | return id 45 | 46 | def tensor_feature(value): 47 | return tf.train.Feature( 48 | bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()]) 49 | ) 50 | 51 | def process_trajectory(function_data): 52 | global raw_dataset_path, tfrecord_dataset_path 53 | idx_range, letter, ctr, split = function_data 54 | unique_pid = split + "_" + letter + "_" + str(ctr) 55 | 56 | start_id, end_id = idx_range[0].item(), idx_range[1].item() 57 | 58 | # We will filter the keys to only include what we need 59 | # Namely "rel_actions", "robot_obs", and "rgb_static" 60 | traj_rel_actions, traj_robot_obs, traj_rgb_static = [], [], [] 61 | 62 | for ep_id in range(start_id, end_id+1): # end_id is inclusive 63 | #print(unique_pid + ": iter " + str(ep_id-start_id) + " of " + str(end_id-start_id)) 64 | 65 | ep_id = make_seven_characters(ep_id) 66 | timestep_data = np.load(os.path.join(raw_dataset_path, split, "episode_" + ep_id + ".npz")) 67 | 68 | rel_actions = timestep_data["rel_actions"] 69 | traj_rel_actions.append(rel_actions) 70 | 71 | robot_obs = timestep_data["robot_obs"] 72 | traj_robot_obs.append(robot_obs) 73 | 74 | rgb_static = timestep_data["rgb_static"] # not normalized, so we have to do normalization in another script 75 | traj_rgb_static.append(rgb_static) 76 | 77 | traj_rel_actions, traj_robot_obs, traj_rgb_static = np.array(traj_rel_actions, dtype=np.float32), np.array(traj_robot_obs, dtype=np.float32), np.array(traj_rgb_static, dtype=np.uint8) 78 | 79 | # Determine the output path 80 | write_dir = os.path.join(tfrecord_dataset_path, split, letter, "traj" + str(ctr)) 81 | if not os.path.exists(write_dir): 82 | os.mkdir(write_dir) 83 | 84 | # Split the trajectory into 1000 timestep length segments 85 | for traj_idx in range(0, len(traj_rel_actions), 1000): 86 | traj_rel_actions_segment = traj_rel_actions[traj_idx : min(traj_idx+1000, len(traj_rel_actions))] 87 | traj_robot_obs_segment = traj_robot_obs[traj_idx : min(traj_idx+1000, len(traj_robot_obs))] 88 | traj_rgb_static_segment = traj_rgb_static[traj_idx : 
min(traj_idx+1000, len(traj_rgb_static))] 89 | 90 | # Write the TFRecord 91 | output_tfrecord_path = os.path.join(write_dir, str(traj_idx // 1000) + ".tfrecord") 92 | with tf.io.TFRecordWriter(output_tfrecord_path) as writer: 93 | example = tf.train.Example( 94 | features=tf.train.Features( 95 | feature={ 96 | "actions" : tensor_feature(traj_rel_actions_segment), 97 | "proprioceptive_states" : tensor_feature(traj_robot_obs_segment), 98 | "image_states" : tensor_feature(traj_rgb_static_segment) 99 | } 100 | ) 101 | ) 102 | writer.write(example.SerializeToString()) 103 | 104 | # Let's prepare the inputs to the process_trajectory function and then parallelize execution 105 | function_inputs = [] 106 | 107 | # First let's do the train data 108 | ep_start_end_ids = np.load(os.path.join(raw_dataset_path, "training", "ep_start_end_ids.npy")) 109 | 110 | scene_info = np.load(os.path.join(raw_dataset_path, "training", "scene_info.npy"), allow_pickle=True) 111 | scene_info = scene_info.item() 112 | 113 | A_ctr, B_ctr, C_ctr, D_ctr = 0, 0, 0, 0 114 | for idx_range in ep_start_end_ids: 115 | start_idx = idx_range[0].item() 116 | if start_idx <= scene_info["calvin_scene_D"][1]: 117 | ctr = D_ctr 118 | D_ctr += 1 119 | letter = "D" 120 | elif start_idx <= scene_info["calvin_scene_B"][1]: # This is actually correct. In ascending order we have D, B, C, A 121 | ctr = B_ctr 122 | B_ctr += 1 123 | letter = "B" 124 | elif start_idx <= scene_info["calvin_scene_C"][1]: 125 | ctr = C_ctr 126 | C_ctr += 1 127 | letter = "C" 128 | else: 129 | ctr = A_ctr 130 | A_ctr += 1 131 | letter = "A" 132 | 133 | function_inputs.append((idx_range, letter, ctr, "training")) 134 | 135 | # Next let's do the validation data 136 | ep_start_end_ids = np.load(os.path.join(raw_dataset_path, "validation", "ep_start_end_ids.npy")) 137 | 138 | ctr = 0 139 | for idx_range in ep_start_end_ids: 140 | function_inputs.append((idx_range, "D", ctr, "validation")) 141 | ctr += 1 142 | 143 | with Pool(len(function_inputs)) as p: # We have one process per input because we are io bound, not cpu bound 144 | p.map(process_trajectory, function_inputs) 145 | #for function_input in tqdm(function_inputs): # If you want to process the dataset in a serialized fashion 146 | # process_trajectory(function_input) 147 | -------------------------------------------------------------------------------- /experiments/susie/calvin/dataset_conversion_scripts/language_conditioned.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script processes the language annotated portions of the CALVIN dataset, writing it into TFRecord format. 3 | 4 | The dataset constructed with this script is meant to be used to train a language conditioned policy. 5 | 6 | Written by Pranav Atreya (pranavatreya@berkeley.edu). 
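The resulting layout is <tfrecord_dataset_path>/<split>/<scene letter>/traj<N>.tfrecord
(scenes A-D under training, scene D under validation). Each record stores the "actions",
"proprioceptive_states", and "image_states" tensors plus a "language_annotation" string
feature, matching the "training/A/?*" style include globs in configs/lcbc_data_config.py.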
7 | """ 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | import os 13 | from multiprocessing import Pool 14 | 15 | ########## Dataset paths ########### 16 | raw_dataset_path = "" 17 | tfrecord_dataset_path = "" 18 | 19 | ########## Main logic ########### 20 | if not os.path.exists(tfrecord_dataset_path): 21 | os.mkdir(tfrecord_dataset_path) 22 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training")): 23 | os.mkdir(os.path.join(tfrecord_dataset_path, "training")) 24 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "validation")): 25 | os.mkdir(os.path.join(tfrecord_dataset_path, "validation")) 26 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training/A")): 27 | os.mkdir(os.path.join(tfrecord_dataset_path, "training/A")) 28 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training/B")): 29 | os.mkdir(os.path.join(tfrecord_dataset_path, "training/B")) 30 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training/C")): 31 | os.mkdir(os.path.join(tfrecord_dataset_path, "training/C")) 32 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "training/D")): 33 | os.mkdir(os.path.join(tfrecord_dataset_path, "training/D")) 34 | if not os.path.exists(os.path.join(tfrecord_dataset_path, "validation/D")): 35 | os.mkdir(os.path.join(tfrecord_dataset_path, "validation/D")) 36 | 37 | def make_seven_characters(id): 38 | id = str(id) 39 | while len(id) < 7: 40 | id = "0" + id 41 | return id 42 | 43 | def tensor_feature(value): 44 | return tf.train.Feature( 45 | bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()]) 46 | ) 47 | 48 | def string_to_feature(str_value): 49 | return tf.train.Feature( 50 | bytes_list=tf.train.BytesList(value=[str_value.encode("UTF-8")]) 51 | ) 52 | 53 | def process_trajectory(function_data): 54 | global raw_dataset_path, tfrecord_dataset_path 55 | idx_range, letter, ctr, split, lang_ann = function_data 56 | unique_pid = split + "_" + letter + "_" + str(ctr) 57 | 58 | start_id, end_id = idx_range[0], idx_range[1] 59 | 60 | # We will filter the keys to only include what we need 61 | # Namely "rel_actions", "robot_obs", and "rgb_static" 62 | traj_rel_actions, traj_robot_obs, traj_rgb_static = [], [], [] 63 | 64 | for ep_id in range(start_id, end_id+1): # end_id is inclusive 65 | #print(unique_pid + ": iter " + str(ep_id-start_id) + " of " + str(end_id-start_id)) 66 | 67 | ep_id = make_seven_characters(ep_id) 68 | timestep_data = np.load(os.path.join(raw_dataset_path, split, "episode_" + ep_id + ".npz")) 69 | 70 | rel_actions = timestep_data["rel_actions"] 71 | traj_rel_actions.append(rel_actions) 72 | 73 | robot_obs = timestep_data["robot_obs"] 74 | traj_robot_obs.append(robot_obs) 75 | 76 | rgb_static = timestep_data["rgb_static"] # not normalized, so we have to do normalization in another script 77 | traj_rgb_static.append(rgb_static) 78 | 79 | traj_rel_actions, traj_robot_obs, traj_rgb_static = np.array(traj_rel_actions, dtype=np.float32), np.array(traj_robot_obs, dtype=np.float32), np.array(traj_rgb_static, dtype=np.uint8) 80 | 81 | # Determine the output path 82 | write_dir = os.path.join(tfrecord_dataset_path, split, letter) 83 | 84 | # Write the TFRecord 85 | output_tfrecord_path = os.path.join(write_dir, "traj" + str(ctr) + ".tfrecord") 86 | with tf.io.TFRecordWriter(output_tfrecord_path) as writer: 87 | example = tf.train.Example( 88 | features=tf.train.Features( 89 | feature={ 90 | "actions" : tensor_feature(traj_rel_actions), 91 | 
"proprioceptive_states" : tensor_feature(traj_robot_obs), 92 | "image_states" : tensor_feature(traj_rgb_static), 93 | "language_annotation" : string_to_feature(lang_ann) 94 | } 95 | ) 96 | ) 97 | writer.write(example.SerializeToString()) 98 | 99 | # Let's prepare the inputs 100 | function_inputs = [] 101 | 102 | # First let's do the train data 103 | auto_lang_ann = np.load(os.path.join(raw_dataset_path, "training", "lang_annotations", "auto_lang_ann.npy"), allow_pickle=True) 104 | auto_lang_ann = auto_lang_ann.item() 105 | all_language_annotations = auto_lang_ann["language"]["ann"] 106 | idx_ranges = auto_lang_ann["info"]["indx"] 107 | 108 | scene_info = np.load(os.path.join(raw_dataset_path, "training", "scene_info.npy"), allow_pickle=True) 109 | scene_info = scene_info.item() 110 | 111 | A_ctr, B_ctr, C_ctr, D_ctr = 0, 0, 0, 0 112 | for i, idx_range in enumerate(idx_ranges): 113 | start_idx = idx_range[0] 114 | if start_idx <= scene_info["calvin_scene_D"][1]: 115 | ctr = D_ctr 116 | D_ctr += 1 117 | letter = "D" 118 | elif start_idx <= scene_info["calvin_scene_B"][1]: # This is actually correct. In ascending order we have D, B, C, A 119 | ctr = B_ctr 120 | B_ctr += 1 121 | letter = "B" 122 | elif start_idx <= scene_info["calvin_scene_C"][1]: 123 | ctr = C_ctr 124 | C_ctr += 1 125 | letter = "C" 126 | else: 127 | ctr = A_ctr 128 | A_ctr += 1 129 | letter = "A" 130 | 131 | function_inputs.append((idx_range, letter, ctr, "training", all_language_annotations[i])) 132 | 133 | # Next let's do the validation data 134 | auto_lang_ann = np.load(os.path.join(raw_dataset_path, "validation", "lang_annotations", "auto_lang_ann.npy"), allow_pickle=True) 135 | auto_lang_ann = auto_lang_ann.item() 136 | all_language_annotations = auto_lang_ann["language"]["ann"] 137 | idx_ranges = auto_lang_ann["info"]["indx"] 138 | 139 | ctr = 0 140 | for i, idx_range in enumerate(idx_ranges): 141 | function_inputs.append((idx_range, "D", ctr, "validation", all_language_annotations[i])) 142 | ctr += 1 143 | 144 | # Finally loop through and process everything 145 | for function_input in tqdm(function_inputs): 146 | process_trajectory(function_input) 147 | 148 | # You can also parallelize execution with a process pool, see end of sister script 149 | -------------------------------------------------------------------------------- /experiments/susie/calvin/scripts/launch_calvin_gcbc.sh: -------------------------------------------------------------------------------- 1 | # 2 cores per process 2 | TPU0="export TPU_VISIBLE_DEVICES=0 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 3 | TPU1="export TPU_VISIBLE_DEVICES=1 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8477 TPU_MESH_CONTROLLER_PORT=8477" 4 | TPU2="export TPU_VISIBLE_DEVICES=2 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 5 | TPU3="export TPU_VISIBLE_DEVICES=3 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8479 TPU_MESH_CONTROLLER_PORT=8479" 6 | 7 | # 4 cores per process 8 | TPU01="export TPU_VISIBLE_DEVICES=0,1 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 9 | TPU23="export TPU_VISIBLE_DEVICES=2,3 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 10 | 11 | 
NAME="gcbc_diffusion_policy" 12 | 13 | CMD="python experiments/susie/calvin/calvin_gcbc.py \ 14 | --config experiments/susie/calvin/configs/gcbc_train_config.py:gc_ddpm_bc \ 15 | --calvin_dataset_config experiments/susie/calvin/configs/gcbc_data_config.py:all \ 16 | --name $NAME" 17 | 18 | $CMD 19 | -------------------------------------------------------------------------------- /experiments/susie/calvin/scripts/launch_calvin_lcbc.sh: -------------------------------------------------------------------------------- 1 | # 2 cores per process 2 | TPU0="export TPU_VISIBLE_DEVICES=0 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 3 | TPU1="export TPU_VISIBLE_DEVICES=1 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8477 TPU_MESH_CONTROLLER_PORT=8477" 4 | TPU2="export TPU_VISIBLE_DEVICES=2 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 5 | TPU3="export TPU_VISIBLE_DEVICES=3 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8479 TPU_MESH_CONTROLLER_PORT=8479" 6 | 7 | # 4 cores per process 8 | TPU01="export TPU_VISIBLE_DEVICES=0,1 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476" 9 | TPU23="export TPU_VISIBLE_DEVICES=2,3 TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478" 10 | 11 | NAME="lcbc_diffusion_policy" 12 | 13 | $TPU01 14 | CMD="python experiments/susie/calvin/calvin_lcbc.py \ 15 | --config experiments/susie/calvin/configs/lcbc_train_config.py:lc_ddpm_bc \ 16 | --calvin_dataset_config experiments/susie/calvin/configs/lcbc_data_config.py:all \ 17 | --name $NAME" 18 | 19 | $CMD 20 | -------------------------------------------------------------------------------- /experiments/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import tensorflow as tf 8 | import tqdm 9 | from absl import app, flags, logging 10 | from flax.training import checkpoints 11 | from ml_collections import config_flags 12 | 13 | from jaxrl_m.agents import agents 14 | from jaxrl_m.common.common import shard_batch 15 | from jaxrl_m.common.wandb import WandBLogger 16 | from jaxrl_m.data.bridge_dataset import BridgeDataset, glob_to_path_list 17 | from jaxrl_m.utils.timer_utils import Timer 18 | from jaxrl_m.vision import encoders 19 | from jaxrl_m.data.text_processing import text_processors 20 | 21 | try: 22 | from jax_smi import initialise_tracking # type: ignore 23 | 24 | initialise_tracking() 25 | except ImportError: 26 | pass 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("name", "", "Experiment name.") 31 | flags.DEFINE_bool("debug", False, "Debug config") 32 | 33 | config_flags.DEFINE_config_file( 34 | "config", 35 | None, 36 | "File path to the training hyperparameter configuration.", 37 | lock_config=False, 38 | ) 39 | 40 | config_flags.DEFINE_config_file( 41 | "bridgedata_config", 42 | None, 43 | "File path to the bridgedata configuration.", 44 | lock_config=False, 45 | ) 46 | 47 | 48 | def main(_): 49 | devices = jax.local_devices() 50 | num_devices = len(devices) 51 | assert FLAGS.config.batch_size % num_devices == 0 52 | 53 | # we shard the 
leading dimension (batch dimension) accross all devices evenly 54 | sharding = jax.sharding.PositionalSharding(devices) 55 | shard_fn = partial(shard_batch, sharding=sharding) 56 | 57 | # prevent tensorflow from using GPUs 58 | tf.config.set_visible_devices([], "GPU") 59 | 60 | # set up wandb and logging 61 | wandb_config = WandBLogger.get_default_config() 62 | wandb_config.update({"project": "jaxrl_m_bridgedata", "exp_descriptor": FLAGS.name}) 63 | wandb_logger = WandBLogger( 64 | wandb_config=wandb_config, variant=FLAGS.config.to_dict(), debug=FLAGS.debug 65 | ) 66 | 67 | save_dir = tf.io.gfile.join( 68 | FLAGS.config.save_dir, 69 | wandb_logger.config.project, 70 | f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}", 71 | ) 72 | 73 | # load datasets 74 | assert type(FLAGS.bridgedata_config.include[0]) == list 75 | task_paths = [ 76 | glob_to_path_list( 77 | path, prefix=FLAGS.config.data_path, exclude=FLAGS.bridgedata_config.exclude 78 | ) 79 | for path in FLAGS.bridgedata_config.include 80 | ] 81 | 82 | train_paths = [ 83 | [os.path.join(path, "train/out.tfrecord") for path in sub_list] 84 | for sub_list in task_paths 85 | ] 86 | val_paths = [ 87 | [os.path.join(path, "val/out.tfrecord") for path in sub_list] 88 | for sub_list in task_paths 89 | ] 90 | 91 | train_data = BridgeDataset( 92 | train_paths, 93 | FLAGS.config.seed, 94 | batch_size=FLAGS.config.batch_size, 95 | train=True, 96 | action_proprio_metadata=FLAGS.bridgedata_config.action_proprio_metadata, 97 | sample_weights=FLAGS.bridgedata_config.sample_weights, 98 | **FLAGS.config.dataset_kwargs, 99 | ) 100 | val_data = BridgeDataset( 101 | val_paths, 102 | FLAGS.config.seed, 103 | batch_size=FLAGS.config.batch_size, 104 | action_proprio_metadata=FLAGS.bridgedata_config.action_proprio_metadata, 105 | train=False, 106 | **FLAGS.config.dataset_kwargs, 107 | ) 108 | 109 | if FLAGS.config.get("text_processor") is None: 110 | text_processor = None 111 | else: 112 | text_processor = text_processors[FLAGS.config.text_processor]( 113 | **FLAGS.config.text_processor_kwargs 114 | ) 115 | 116 | def process_text(batch): 117 | if text_processor is not None: 118 | batch["goals"]["language"] = text_processor.encode( 119 | [s.decode("utf-8") for s in batch["goals"]["language"]] 120 | ) 121 | return batch 122 | 123 | train_data_iter = map( 124 | shard_fn, map(process_text, train_data.tf_dataset.as_numpy_iterator()) 125 | ) 126 | 127 | example_batch = next(train_data_iter) 128 | logging.info(f"Batch size: {example_batch['observations']['image'].shape[0]}") 129 | logging.info(f"Number of devices: {num_devices}") 130 | logging.info( 131 | f"Batch size per device: {example_batch['observations']['image'].shape[0] // num_devices}" 132 | ) 133 | 134 | # define encoder 135 | encoder_def = encoders[FLAGS.config.encoder](**FLAGS.config.encoder_kwargs) 136 | 137 | # initialize agent 138 | rng = jax.random.PRNGKey(FLAGS.config.seed) 139 | rng, construct_rng = jax.random.split(rng) 140 | agent = agents[FLAGS.config.agent].create( 141 | rng=construct_rng, 142 | observations=example_batch["observations"], 143 | goals=example_batch["goals"], 144 | actions=example_batch["actions"], 145 | encoder_def=encoder_def, 146 | **FLAGS.config.agent_kwargs, 147 | ) 148 | if FLAGS.config.resume_path is not None: 149 | agent = checkpoints.restore_checkpoint(FLAGS.config.resume_path, target=agent) 150 | # replicate agent across devices 151 | # need the jnp.array to avoid a bug where device_put doesn't recognize primitives 152 | agent = 
jax.device_put(jax.tree_map(jnp.array, agent), sharding.replicate()) 153 | 154 | timer = Timer() 155 | for i in tqdm.tqdm(range(int(FLAGS.config.num_steps))): 156 | timer.tick("total") 157 | 158 | timer.tick("dataset") 159 | batch = next(train_data_iter) 160 | timer.tock("dataset") 161 | 162 | timer.tick("train") 163 | agent, update_info = agent.update(batch) 164 | timer.tock("train") 165 | 166 | if (i + 1) % FLAGS.config.eval_interval == 0: 167 | logging.info("Evaluating...") 168 | timer.tick("val") 169 | metrics = [] 170 | val_data_iter = map(shard_fn, map(process_text, val_data.iterator())) 171 | for batch in val_data_iter: 172 | rng, val_rng = jax.random.split(rng) 173 | metrics.append(agent.get_debug_metrics(batch, seed=val_rng)) 174 | metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) 175 | wandb_logger.log({"validation": metrics}, step=i) 176 | timer.tock("val") 177 | 178 | if (i + 1) % FLAGS.config.save_interval == 0: 179 | logging.info("Saving checkpoint...") 180 | checkpoint_path = checkpoints.save_checkpoint( 181 | save_dir, agent, step=i + 1, keep=1e6 182 | ) 183 | logging.info("Saved checkpoint to %s", checkpoint_path) 184 | 185 | timer.tock("total") 186 | 187 | if (i + 1) % FLAGS.config.log_interval == 0: 188 | update_info = jax.device_get(update_info) 189 | wandb_logger.log({"training": update_info}, step=i) 190 | 191 | wandb_logger.log({"timer": timer.get_average_times()}, step=i) 192 | 193 | 194 | if __name__ == "__main__": 195 | app.run(main) 196 | -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | from pyquaternion import Quaternion 4 | 5 | def stack_obs(obs): 6 | dict_list = {k: [dic[k] for dic in obs] for k in obs[0]} 7 | return jax.tree_map( 8 | lambda x: np.stack(x), dict_list, is_leaf=lambda x: type(x) == list 9 | ) 10 | 11 | def state_to_eep(xyz_coor, zangle: float): 12 | """ 13 | Implement the state to eep function. 
14 | Refered to `bridge_data_robot`'s `widowx_controller/widowx_controller.py` 15 | return a 4x4 matrix 16 | """ 17 | assert len(xyz_coor) == 3 18 | DEFAULT_ROTATION = np.array([[0 , 0, 1.0], 19 | [0, 1.0, 0], 20 | [-1.0, 0, 0]]) 21 | new_pose = np.eye(4) 22 | new_pose[:3, -1] = xyz_coor 23 | new_quat = Quaternion(axis=np.array([0.0, 0.0, 1.0]), angle=zangle) \ 24 | * Quaternion(matrix=DEFAULT_ROTATION) 25 | new_pose[:3, :3] = new_quat.rotation_matrix 26 | # yaw, pitch, roll = quat.yaw_pitch_roll 27 | return new_pose 28 | 29 | 30 | def mat_to_xyzrpy(mat: np.ndarray): 31 | """return a 6-dim vector with xyz and rpy""" 32 | assert mat.shape == (4, 4), "mat must be a 4x4 matrix" 33 | xyz = mat[:3, -1] 34 | quat = Quaternion(matrix=mat[:3, :3]) 35 | yaw, pitch, roll = quat.yaw_pitch_roll 36 | return np.concatenate([xyz, [roll, pitch, yaw]]) 37 | -------------------------------------------------------------------------------- /jaxrl_m/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .continuous.bc import BCAgent 2 | from .continuous.gc_bc import GCBCAgent 3 | from .continuous.gc_ddpm_bc import GCDDPMBCAgent 4 | from .continuous.gc_iql import GCIQLAgent 5 | from .continuous.iql import IQLAgent 6 | from .continuous.lc_bc import LCBCAgent 7 | from .continuous.stable_contrastive_rl import StableContrastiveRLAgent 8 | 9 | agents = { 10 | "gc_bc": GCBCAgent, 11 | "lc_bc": LCBCAgent, 12 | "gc_iql": GCIQLAgent, 13 | "gc_ddpm_bc": GCDDPMBCAgent, 14 | "bc": BCAgent, 15 | "iql": IQLAgent, 16 | "stable_contrastive_rl": StableContrastiveRLAgent, 17 | } 18 | -------------------------------------------------------------------------------- /jaxrl_m/agents/continuous/bc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any 3 | import jax 4 | import jax.numpy as jnp 5 | from jaxrl_m.common.encoding import EncodingWrapper 6 | import numpy as np 7 | import flax 8 | import flax.linen as nn 9 | import optax 10 | 11 | from flax.core import FrozenDict 12 | from jaxrl_m.common.typing import Batch 13 | from jaxrl_m.common.typing import PRNGKey 14 | from jaxrl_m.common.common import JaxRLTrainState, ModuleDict, nonpytree_field 15 | from jaxrl_m.networks.actor_critic_nets import Policy 16 | from jaxrl_m.networks.mlp import MLP 17 | 18 | 19 | class BCAgent(flax.struct.PyTreeNode): 20 | state: JaxRLTrainState 21 | lr_schedule: Any = nonpytree_field() 22 | 23 | @partial(jax.jit, static_argnames="pmap_axis") 24 | def update(self, batch: Batch, pmap_axis: str = None): 25 | def loss_fn(params, rng): 26 | rng, key = jax.random.split(rng) 27 | dist = self.state.apply_fn( 28 | {"params": params}, 29 | batch["observations"], 30 | temperature=1.0, 31 | train=True, 32 | rngs={"dropout": key}, 33 | name="actor", 34 | ) 35 | pi_actions = dist.mode() 36 | log_probs = dist.log_prob(batch["actions"]) 37 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 38 | actor_loss = -(log_probs).mean() 39 | actor_std = dist.stddev().mean(axis=1) 40 | 41 | return ( 42 | actor_loss, 43 | { 44 | "actor_loss": actor_loss, 45 | "mse": mse.mean(), 46 | "log_probs": log_probs, 47 | "pi_actions": pi_actions, 48 | "mean_std": actor_std.mean(), 49 | "max_std": actor_std.max(), 50 | }, 51 | ) 52 | 53 | # compute gradients and update params 54 | new_state, info = self.state.apply_loss_fns( 55 | loss_fn, pmap_axis=pmap_axis, has_aux=True 56 | ) 57 | 58 | # log learning rates 59 | info["lr"] = 
self.lr_schedule(self.state.step) 60 | 61 | return self.replace(state=new_state), info 62 | 63 | @partial(jax.jit, static_argnames="argmax") 64 | def sample_actions( 65 | self, 66 | observations: np.ndarray, 67 | *, 68 | seed: PRNGKey, 69 | temperature: float = 1.0, 70 | argmax=False 71 | ) -> jnp.ndarray: 72 | dist = self.state.apply_fn( 73 | {"params": self.state.params}, 74 | observations, 75 | temperature=temperature, 76 | name="actor", 77 | ) 78 | if argmax: 79 | actions = dist.mode() 80 | else: 81 | actions = dist.sample(seed=seed) 82 | return actions 83 | 84 | @jax.jit 85 | def get_debug_metrics(self, batch, **kwargs): 86 | dist = self.state.apply_fn( 87 | {"params": self.state.params}, 88 | batch["observations"], 89 | temperature=1.0, 90 | name="actor", 91 | ) 92 | pi_actions = dist.mode() 93 | log_probs = dist.log_prob(batch["actions"]) 94 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 95 | 96 | return {"mse": mse, "log_probs": log_probs, "pi_actions": pi_actions} 97 | 98 | @classmethod 99 | def create( 100 | cls, 101 | rng: PRNGKey, 102 | observations: FrozenDict, 103 | actions: jnp.ndarray, 104 | # Model architecture 105 | encoder_def: nn.Module, 106 | use_proprio: bool = False, 107 | network_kwargs: dict = {"hidden_dims": [256, 256]}, 108 | policy_kwargs: dict = { 109 | "tanh_squash_distribution": False, 110 | "state_dependent_std": False, 111 | "dropout": 0.0, 112 | }, 113 | # Optimizer 114 | learning_rate: float = 3e-4, 115 | warmup_steps: int = 1000, 116 | decay_steps: int = 1000000, 117 | ): 118 | encoder_def = EncodingWrapper( 119 | encoder=encoder_def, use_proprio=use_proprio, stop_gradient=False 120 | ) 121 | 122 | network_kwargs["activate_final"] = True 123 | networks = { 124 | "actor": Policy( 125 | encoder_def, 126 | MLP(**network_kwargs), 127 | action_dim=actions.shape[-1], 128 | **policy_kwargs 129 | ) 130 | } 131 | 132 | model_def = ModuleDict(networks) 133 | 134 | lr_schedule = optax.warmup_cosine_decay_schedule( 135 | init_value=0.0, 136 | peak_value=learning_rate, 137 | warmup_steps=warmup_steps, 138 | decay_steps=decay_steps, 139 | end_value=0.0, 140 | ) 141 | tx = optax.adam(lr_schedule) 142 | 143 | rng, init_rng = jax.random.split(rng) 144 | params = model_def.init(init_rng, actor=observations)["params"] 145 | 146 | rng, create_rng = jax.random.split(rng) 147 | state = JaxRLTrainState.create( 148 | apply_fn=model_def.apply, 149 | params=params, 150 | txs=tx, 151 | target_params=params, 152 | rng=create_rng, 153 | ) 154 | 155 | return cls(state, lr_schedule) 156 | -------------------------------------------------------------------------------- /jaxrl_m/agents/continuous/gc_bc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Any 4 | import jax 5 | import jax.numpy as jnp 6 | from jaxrl_m.common.encoding import GCEncodingWrapper 7 | import numpy as np 8 | import flax 9 | import flax.linen as nn 10 | import optax 11 | 12 | from flax.core import FrozenDict 13 | from jaxrl_m.common.typing import Batch 14 | from jaxrl_m.common.typing import PRNGKey 15 | from jaxrl_m.common.common import JaxRLTrainState, ModuleDict, nonpytree_field 16 | from jaxrl_m.networks.actor_critic_nets import Policy 17 | from jaxrl_m.networks.mlp import MLP 18 | 19 | 20 | class GCBCAgent(flax.struct.PyTreeNode): 21 | state: JaxRLTrainState 22 | lr_schedule: Any = nonpytree_field() 23 | 24 | @partial(jax.jit, static_argnames="pmap_axis") 25 | def update(self, batch: Batch, 
pmap_axis: str = None): 26 | def loss_fn(params, rng): 27 | rng, key = jax.random.split(rng) 28 | dist = self.state.apply_fn( 29 | {"params": params}, 30 | (batch["observations"], batch["goals"]), 31 | temperature=1.0, 32 | train=True, 33 | rngs={"dropout": key}, 34 | name="actor", 35 | ) 36 | pi_actions = dist.mode() 37 | log_probs = dist.log_prob(batch["actions"]) 38 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 39 | actor_loss = -(log_probs).mean() 40 | actor_std = dist.stddev().mean(axis=1) 41 | 42 | return ( 43 | actor_loss, 44 | { 45 | "actor_loss": actor_loss, 46 | "mse": mse.mean(), 47 | "log_probs": log_probs.mean(), 48 | "pi_actions": pi_actions.mean(), 49 | "mean_std": actor_std.mean(), 50 | "max_std": actor_std.max(), 51 | }, 52 | ) 53 | 54 | # compute gradients and update params 55 | new_state, info = self.state.apply_loss_fns( 56 | loss_fn, pmap_axis=pmap_axis, has_aux=True 57 | ) 58 | 59 | # log learning rates 60 | info["lr"] = self.lr_schedule(self.state.step) 61 | 62 | return self.replace(state=new_state), info 63 | 64 | @partial(jax.jit, static_argnames="argmax") 65 | def sample_actions( 66 | self, 67 | observations: np.ndarray, 68 | goals: np.ndarray, 69 | *, 70 | seed: PRNGKey, 71 | temperature: float = 1.0, 72 | argmax=False 73 | ) -> jnp.ndarray: 74 | dist = self.state.apply_fn( 75 | {"params": self.state.params}, 76 | (observations, goals), 77 | temperature=temperature, 78 | name="actor", 79 | ) 80 | if argmax: 81 | actions = dist.mode() 82 | else: 83 | actions = dist.sample(seed=seed) 84 | return actions 85 | 86 | @jax.jit 87 | def get_debug_metrics(self, batch, **kwargs): 88 | dist = self.state.apply_fn( 89 | {"params": self.state.params}, 90 | (batch["observations"], batch["goals"]), 91 | temperature=1.0, 92 | name="actor", 93 | ) 94 | pi_actions = dist.mode() 95 | log_probs = dist.log_prob(batch["actions"]) 96 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 97 | 98 | return {"mse": mse, "log_probs": log_probs, "pi_actions": pi_actions} 99 | 100 | @classmethod 101 | def create( 102 | cls, 103 | rng: PRNGKey, 104 | observations: FrozenDict, 105 | actions: jnp.ndarray, 106 | goals: FrozenDict, 107 | # Model architecture 108 | encoder_def: nn.Module, 109 | shared_goal_encoder: bool = True, 110 | early_goal_concat: bool = False, 111 | use_proprio: bool = False, 112 | network_kwargs: dict = {"hidden_dims": [256, 256]}, 113 | policy_kwargs: dict = { 114 | "tanh_squash_distribution": False, 115 | "state_dependent_std": False, 116 | "dropout": 0.0, 117 | }, 118 | # Optimizer 119 | learning_rate: float = 3e-4, 120 | warmup_steps: int = 1000, 121 | decay_steps: int = 1000000, 122 | ): 123 | if early_goal_concat: 124 | # passing None as the goal encoder causes early goal concat 125 | goal_encoder_def = None 126 | else: 127 | if shared_goal_encoder: 128 | goal_encoder_def = encoder_def 129 | else: 130 | goal_encoder_def = copy.deepcopy(encoder_def) 131 | 132 | encoder_def = GCEncodingWrapper( 133 | encoder=encoder_def, 134 | goal_encoder=goal_encoder_def, 135 | use_proprio=use_proprio, 136 | stop_gradient=False, 137 | ) 138 | 139 | network_kwargs["activate_final"] = True 140 | networks = { 141 | "actor": Policy( 142 | encoder_def, 143 | MLP(**network_kwargs), 144 | action_dim=actions.shape[-1], 145 | **policy_kwargs 146 | ) 147 | } 148 | 149 | model_def = ModuleDict(networks) 150 | 151 | lr_schedule = optax.warmup_cosine_decay_schedule( 152 | init_value=0.0, 153 | peak_value=learning_rate, 154 | warmup_steps=warmup_steps, 155 | decay_steps=decay_steps, 156 | 
end_value=0.0, 157 | ) 158 | tx = optax.adam(lr_schedule) 159 | 160 | rng, init_rng = jax.random.split(rng) 161 | params = model_def.init(init_rng, actor=[(observations, goals)])["params"] 162 | 163 | rng, create_rng = jax.random.split(rng) 164 | state = JaxRLTrainState.create( 165 | apply_fn=model_def.apply, 166 | params=params, 167 | txs=tx, 168 | target_params=params, 169 | rng=create_rng, 170 | ) 171 | 172 | return cls(state, lr_schedule) 173 | -------------------------------------------------------------------------------- /jaxrl_m/agents/continuous/gc_ddpm_bc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import flax 9 | import flax.linen as nn 10 | import optax 11 | 12 | from flax.core import FrozenDict 13 | from jaxrl_m.common.typing import Batch 14 | from jaxrl_m.common.typing import PRNGKey 15 | from jaxrl_m.common.common import JaxRLTrainState, ModuleDict, nonpytree_field 16 | from jaxrl_m.common.encoding import GCEncodingWrapper 17 | 18 | from jaxrl_m.networks.diffusion_nets import ( 19 | FourierFeatures, 20 | cosine_beta_schedule, 21 | vp_beta_schedule, 22 | ScoreActor, 23 | ) 24 | from jaxrl_m.networks.mlp import MLP, MLPResNet 25 | 26 | 27 | def ddpm_bc_loss(noise_prediction, noise): 28 | ddpm_loss = jnp.square(noise_prediction - noise).sum(-1) 29 | 30 | return ( 31 | ddpm_loss.mean(), 32 | {"ddpm_loss": ddpm_loss, "ddpm_loss_mean": ddpm_loss.mean()}, 33 | ) 34 | 35 | 36 | class GCDDPMBCAgent(flax.struct.PyTreeNode): 37 | """ 38 | Models action distribution with a diffusion model. 39 | 40 | Assumes observation histories as input and action sequences as output. 
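    Example usage (a minimal sketch; `encoder` and `example_batch` are
    placeholders for an image encoder and a batch with the shapes noted
    below, not names defined in this module):
    ```
    agent = GCDDPMBCAgent.create(
        rng=jax.random.PRNGKey(0),
        observations=example_batch["observations"],  # images: (B, obs_horizon, H, W, C)
        goals=example_batch["goals"],
        actions=example_batch["actions"],            # chunks: (B, pred_horizon, action_dim)
        encoder_def=encoder,
        score_network_kwargs={
            "time_dim": 32,
            "num_blocks": 3,
            "dropout_rate": 0.1,
            "hidden_dim": 256,
            "use_layer_norm": False,  # read by the MLPResNet score network below
        },
    )
    agent, update_info = agent.update(example_batch)
    action_chunk = agent.sample_actions(
        example_batch["observations"], example_batch["goals"],
        seed=jax.random.PRNGKey(1),
    )
    ```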
41 | """ 42 | 43 | state: JaxRLTrainState 44 | config: dict = nonpytree_field() 45 | lr_schedules: dict = nonpytree_field() 46 | 47 | @partial(jax.jit, static_argnames="pmap_axis") 48 | def update(self, batch: Batch, pmap_axis: str = None): 49 | def actor_loss_fn(params, rng): 50 | key, rng = jax.random.split(rng) 51 | time = jax.random.randint( 52 | key, (batch["actions"].shape[0],), 0, self.config["diffusion_steps"] 53 | ) 54 | key, rng = jax.random.split(rng) 55 | noise_sample = jax.random.normal(key, batch["actions"].shape) 56 | 57 | alpha_hats = self.config["alpha_hats"][time] 58 | time = time[:, None] 59 | alpha_1 = jnp.sqrt(alpha_hats)[:, None, None] 60 | alpha_2 = jnp.sqrt(1 - alpha_hats)[:, None, None] 61 | 62 | noisy_actions = alpha_1 * batch["actions"] + alpha_2 * noise_sample 63 | 64 | rng, key = jax.random.split(rng) 65 | noise_pred = self.state.apply_fn( 66 | {"params": params}, # gradient flows through here 67 | (batch["observations"], batch["goals"]), 68 | noisy_actions, 69 | time, 70 | train=True, 71 | rngs={"dropout": key}, 72 | name="actor", 73 | ) 74 | 75 | return ddpm_bc_loss(noise_pred, noise_sample) 76 | 77 | loss_fns = {"actor": actor_loss_fn} 78 | 79 | # compute gradients and update params 80 | new_state, info = self.state.apply_loss_fns( 81 | loss_fns, pmap_axis=pmap_axis, has_aux=True 82 | ) 83 | 84 | # update the target params 85 | new_state = new_state.target_update(self.config["target_update_rate"]) 86 | 87 | # log learning rates 88 | info["actor_lr"] = self.lr_schedules["actor"](self.state.step) 89 | 90 | return self.replace(state=new_state), info 91 | 92 | @partial(jax.jit, static_argnames="argmax") 93 | def sample_actions( 94 | self, 95 | observations: np.ndarray, 96 | goals: np.ndarray, 97 | *, 98 | seed: PRNGKey = None, 99 | temperature: float = 1.0, 100 | argmax: bool = False, 101 | clip_sampler: bool = True, 102 | ) -> jnp.ndarray: 103 | assert len(observations["image"].shape) > 3, "Must use observation histories" 104 | 105 | def fn(input_tuple, time): 106 | current_x, rng = input_tuple 107 | input_time = jnp.broadcast_to(time, (current_x.shape[0], 1)) 108 | 109 | eps_pred = self.state.apply_fn( 110 | {"params": self.state.target_params}, 111 | (observations, goals), 112 | current_x, 113 | input_time, 114 | name="actor", 115 | ) 116 | 117 | alpha_1 = 1 / jnp.sqrt(self.config["alphas"][time]) 118 | alpha_2 = (1 - self.config["alphas"][time]) / ( 119 | jnp.sqrt(1 - self.config["alpha_hats"][time]) 120 | ) 121 | current_x = alpha_1 * (current_x - alpha_2 * eps_pred) 122 | 123 | rng, key = jax.random.split(rng) 124 | z = jax.random.normal(key, shape=current_x.shape) 125 | z_scaled = temperature * z 126 | current_x = current_x + (time > 0) * ( 127 | jnp.sqrt(self.config["betas"][time]) * z_scaled 128 | ) 129 | 130 | if clip_sampler: 131 | current_x = jnp.clip( 132 | current_x, self.config["action_min"], self.config["action_max"] 133 | ) 134 | 135 | return (current_x, rng), () 136 | 137 | key, rng = jax.random.split(seed) 138 | 139 | if len(observations["image"].shape) == 4: 140 | # unbatched input from evaluation 141 | batch_size = 1 142 | observations = jax.tree_map(lambda x: x[None], observations) 143 | goals = jax.tree_map(lambda x: x[None], goals) 144 | else: 145 | batch_size = observations["image"].shape[0] 146 | 147 | input_tuple, () = jax.lax.scan( 148 | fn, 149 | (jax.random.normal(key, (batch_size, *self.config["action_dim"])), rng), 150 | jnp.arange(self.config["diffusion_steps"] - 1, -1, -1), 151 | ) 152 | 153 | for _ in 
range(self.config["repeat_last_step"]): 154 | input_tuple, () = fn(input_tuple, 0) 155 | 156 | action_0, rng = input_tuple 157 | 158 | if batch_size == 1: 159 | # this is an evaluation call so unbatch 160 | return action_0[0] 161 | else: 162 | return action_0 163 | 164 | @jax.jit 165 | def get_debug_metrics(self, batch, seed, gripper_close_val=None): 166 | actions = self.sample_actions( 167 | observations=batch["observations"], goals=batch["goals"], seed=seed 168 | ) 169 | 170 | metrics = {"mse": ((actions - batch["actions"]) ** 2).sum((-2, -1)).mean()} 171 | 172 | return metrics 173 | 174 | @classmethod 175 | def create( 176 | cls, 177 | rng: PRNGKey, 178 | observations: FrozenDict, 179 | goals: FrozenDict, 180 | actions: jnp.ndarray, 181 | # Model architecture 182 | encoder_def: nn.Module, 183 | shared_goal_encoder: bool = True, 184 | early_goal_concat: bool = False, 185 | use_proprio: bool = False, 186 | score_network_kwargs: dict = { 187 | "time_dim": 32, 188 | "num_blocks": 3, 189 | "dropout_rate": 0.1, 190 | "hidden_dim": 256, 191 | }, 192 | # Optimizer 193 | learning_rate: float = 3e-4, 194 | warmup_steps: int = 2000, 195 | actor_decay_steps: Optional[int] = None, 196 | # Algorithm config 197 | beta_schedule: str = "cosine", 198 | diffusion_steps: int = 25, 199 | action_samples: int = 1, 200 | repeat_last_step: int = 0, 201 | target_update_rate=0.002, 202 | dropout_target_networks=True, 203 | ): 204 | assert len(actions.shape) > 1, "Must use action chunking" 205 | assert len(observations["image"].shape) > 3, "Must use observation histories" 206 | 207 | if early_goal_concat: 208 | # passing None as the goal encoder causes early goal concat 209 | goal_encoder_def = None 210 | else: 211 | if shared_goal_encoder: 212 | goal_encoder_def = encoder_def 213 | else: 214 | goal_encoder_def = copy.deepcopy(encoder_def) 215 | 216 | encoder_def = GCEncodingWrapper( 217 | encoder=encoder_def, 218 | goal_encoder=goal_encoder_def, 219 | use_proprio=use_proprio, 220 | stop_gradient=False, 221 | ) 222 | 223 | networks = { 224 | "actor": ScoreActor( 225 | encoder_def, 226 | FourierFeatures(score_network_kwargs["time_dim"], learnable=True), 227 | MLP( 228 | ( 229 | 2 * score_network_kwargs["time_dim"], 230 | score_network_kwargs["time_dim"], 231 | ) 232 | ), 233 | MLPResNet( 234 | score_network_kwargs["num_blocks"], 235 | actions.shape[-2] * actions.shape[-1], 236 | dropout_rate=score_network_kwargs["dropout_rate"], 237 | use_layer_norm=score_network_kwargs["use_layer_norm"], 238 | ), 239 | ) 240 | } 241 | 242 | model_def = ModuleDict(networks) 243 | 244 | rng, init_rng = jax.random.split(rng) 245 | if len(actions.shape) == 3: 246 | example_time = jnp.zeros((actions.shape[0], 1)) 247 | else: 248 | example_time = jnp.zeros((1,)) 249 | params = model_def.init( 250 | init_rng, actor=[(observations, goals), actions, example_time] 251 | )["params"] 252 | 253 | # no decay 254 | lr_schedule = optax.warmup_cosine_decay_schedule( 255 | init_value=0.0, 256 | peak_value=learning_rate, 257 | warmup_steps=warmup_steps, 258 | decay_steps=warmup_steps + 1, 259 | end_value=learning_rate, 260 | ) 261 | lr_schedules = {"actor": lr_schedule} 262 | if actor_decay_steps is not None: 263 | lr_schedules["actor"] = optax.warmup_cosine_decay_schedule( 264 | init_value=0.0, 265 | peak_value=learning_rate, 266 | warmup_steps=warmup_steps, 267 | decay_steps=actor_decay_steps, 268 | end_value=0.0, 269 | ) 270 | txs = {k: optax.adam(v) for k, v in lr_schedules.items()} 271 | 272 | rng, create_rng = jax.random.split(rng) 273 | state 
= JaxRLTrainState.create( 274 | apply_fn=model_def.apply, 275 | params=params, 276 | txs=txs, 277 | target_params=params, 278 | rng=create_rng, 279 | ) 280 | 281 | if beta_schedule == "cosine": 282 | betas = jnp.array(cosine_beta_schedule(diffusion_steps)) 283 | elif beta_schedule == "linear": 284 | betas = jnp.linspace(1e-4, 2e-2, diffusion_steps) 285 | elif beta_schedule == "vp": 286 | betas = jnp.array(vp_beta_schedule(diffusion_steps)) 287 | 288 | alphas = 1 - betas 289 | alpha_hat = jnp.array( 290 | [jnp.prod(alphas[: i + 1]) for i in range(diffusion_steps)] 291 | ) 292 | 293 | config = flax.core.FrozenDict( 294 | dict( 295 | target_update_rate=target_update_rate, 296 | dropout_target_networks=dropout_target_networks, 297 | action_dim=actions.shape[-2:], 298 | action_max=2.0, 299 | action_min=-2.0, 300 | betas=betas, 301 | alphas=alphas, 302 | alpha_hats=alpha_hat, 303 | diffusion_steps=diffusion_steps, 304 | action_samples=action_samples, 305 | repeat_last_step=repeat_last_step, 306 | ) 307 | ) 308 | return cls(state, config, lr_schedules) 309 | -------------------------------------------------------------------------------- /jaxrl_m/agents/continuous/iql.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import optax 8 | from flax.core import FrozenDict 9 | from jaxrl_m.common.typing import Batch, PRNGKey 10 | from jaxrl_m.common.common import ModuleDict, JaxRLTrainState, nonpytree_field 11 | from jaxrl_m.common.encoding import EncodingWrapper 12 | from jaxrl_m.networks.actor_critic_nets import ValueCritic 13 | from jaxrl_m.networks.actor_critic_nets import Policy 14 | from jaxrl_m.networks.actor_critic_nets import Critic 15 | from jaxrl_m.networks.mlp import MLP 16 | import numpy as np 17 | 18 | import flax 19 | import flax.linen as nn 20 | 21 | 22 | def expectile_loss(diff, expectile=0.5): 23 | weight = jnp.where(diff > 0, expectile, (1 - expectile)) 24 | return weight * (diff ** 2) 25 | 26 | 27 | def iql_value_loss(q, v, expectile): 28 | value_loss = expectile_loss(q - v, expectile) 29 | return ( 30 | value_loss.mean(), 31 | { 32 | "value_loss": value_loss.mean(), 33 | "uncentered_loss": jnp.mean((q - v) ** 2), 34 | "v": v.mean(), 35 | }, 36 | ) 37 | 38 | 39 | def iql_critic_loss(q, q_target): 40 | critic_loss = jnp.square(q - q_target) 41 | return critic_loss.mean(), {"td_loss": critic_loss.mean(), "q": q.mean()} 42 | 43 | 44 | def iql_actor_loss(q, v, dist, actions, temperature=1.0, adv_clip_max=100.0, mask=None): 45 | adv = q - v 46 | 47 | exp_adv = jnp.exp(adv / temperature) 48 | exp_adv = jnp.minimum(exp_adv, adv_clip_max) 49 | 50 | log_probs = dist.log_prob(actions) 51 | actor_loss = -(exp_adv * log_probs) 52 | 53 | if mask is not None: 54 | actor_loss *= mask 55 | actor_loss = jnp.sum(actor_loss) / jnp.sum(mask) 56 | else: 57 | actor_loss = jnp.mean(actor_loss) 58 | 59 | behavior_mse = jnp.square(dist.mode() - actions).sum(-1) 60 | 61 | return ( 62 | actor_loss, 63 | { 64 | "actor_loss": actor_loss, 65 | "behavior_logprob": log_probs.mean(), 66 | "behavior_mse": behavior_mse.mean(), 67 | "adv_mean": adv.mean(), 68 | "adv_max": adv.max(), 69 | "adv_min": adv.min(), 70 | }, 71 | ) 72 | 73 | 74 | class IQLAgent(flax.struct.PyTreeNode): 75 | state: JaxRLTrainState 76 | config: dict = nonpytree_field() 77 | lr_schedules: dict = nonpytree_field() 78 | 79 | @partial(jax.jit, static_argnames="pmap_axis") 80 | def 
update(self, batch: Batch, pmap_axis: str = None): 81 | new_rng, dropout_rng = jax.random.split(self.state.rng) 82 | 83 | def critic_loss_fn(params): 84 | next_v = self.state.apply_fn( 85 | {"params": self.state.target_params}, 86 | batch["next_observations"], 87 | name="value", 88 | ) 89 | target_q = ( 90 | batch["rewards"] + self.config["discount"] * next_v * batch["masks"] 91 | ) 92 | q = self.state.apply_fn( 93 | {"params": params}, 94 | batch["observations"], 95 | batch["actions"], 96 | name="critic", 97 | ) 98 | return iql_critic_loss(q, target_q) 99 | 100 | def value_loss_fn(params): 101 | q = self.state.apply_fn( 102 | {"params": self.state.params}, # no gradient flows through here 103 | batch["observations"], 104 | batch["actions"], 105 | name="critic", 106 | ) 107 | v = self.state.apply_fn( 108 | {"params": params}, # gradient flows through here 109 | batch["observations"], 110 | name="value", 111 | ) 112 | return iql_value_loss(q, v, self.config["expectile"]) 113 | 114 | def actor_loss_fn(params): 115 | next_v = self.state.apply_fn( 116 | {"params": self.state.target_params}, 117 | batch["next_observations"], 118 | name="value", 119 | ) 120 | target_q = ( 121 | batch["rewards"] + self.config["discount"] * next_v * batch["masks"] 122 | ) 123 | 124 | v = self.state.apply_fn( 125 | {"params": self.state.params}, # no gradient flows through here 126 | batch["observations"], 127 | name="value", 128 | ) 129 | dist = self.state.apply_fn( 130 | {"params": params}, # gradient flows through here 131 | batch["observations"], 132 | train=True, 133 | rngs={"dropout": dropout_rng}, 134 | name="actor", 135 | ) 136 | mask = batch.get("actor_loss_mask", None) 137 | return iql_actor_loss( 138 | target_q, 139 | v, 140 | dist, 141 | batch["actions"], 142 | self.config["temperature"], 143 | mask=mask, 144 | ) 145 | 146 | loss_fns = { 147 | "critic": critic_loss_fn, 148 | "value": value_loss_fn, 149 | "actor": actor_loss_fn, 150 | } 151 | 152 | # compute gradients and update params 153 | new_state, info = self.state.apply_loss_fns( 154 | loss_fns, pmap_axis=pmap_axis, has_aux=True 155 | ) 156 | 157 | # update the target params 158 | new_state = new_state.target_update(self.config["target_update_rate"]) 159 | 160 | # update rng 161 | new_state = new_state.replace(rng=new_rng) 162 | 163 | # log learning rates 164 | info["actor_lr"] = self.lr_schedules["actor"](self.state.step) 165 | 166 | return self.replace(state=new_state), info 167 | 168 | @partial(jax.jit, static_argnames="argmax") 169 | def sample_actions( 170 | self, 171 | observations: np.ndarray, 172 | *, 173 | seed: PRNGKey, 174 | temperature: float = 1.0, 175 | argmax=False, 176 | ) -> jnp.ndarray: 177 | dist = self.state.apply_fn(observations, temperature=temperature, name="actor") 178 | if argmax: 179 | actions = dist.mode() 180 | else: 181 | actions = dist.sample(seed=seed) 182 | return actions 183 | 184 | @jax.jit 185 | def get_debug_metrics(self, batch, gripper_close_val=None, **kwargs): 186 | dist = self.state.apply_fn( 187 | {"params": self.state.params}, 188 | batch["observations"], 189 | temperature=1.0, 190 | name="actor", 191 | ) 192 | pi_actions = dist.mode() 193 | log_probs = dist.log_prob(batch["actions"]) 194 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 195 | 196 | v = self.state.apply_fn( 197 | {"params": self.state.params}, batch["observations"], name="value" 198 | ) 199 | next_v = self.state.apply_fn( 200 | {"params": self.state.target_params}, 201 | batch["next_observations"], 202 | name="value", 203 | ) 204 | 
target_q = batch["rewards"] + self.config["discount"] * next_v * batch["masks"] 205 | q = self.state.apply_fn( 206 | {"params": self.state.params}, 207 | batch["observations"], 208 | batch["actions"], 209 | name="critic", 210 | ) 211 | 212 | metrics = { 213 | "log_probs": log_probs, 214 | "mse": ((dist.mode() - batch["actions"]) ** 2).sum(-1), 215 | "pi_actions": pi_actions, 216 | "online_v": v, 217 | "online_q": q, 218 | "target_q": target_q, 219 | "value_err": expectile_loss(target_q - v, self.config["expectile"]), 220 | "td_err": jnp.square(target_q - q), 221 | "advantage": target_q - v, 222 | "qf_advantage": q - v, 223 | } 224 | 225 | if gripper_close_val is not None: 226 | gripper_close_q = self.state.apply_fn( 227 | {"params": self.state.params}, 228 | batch["observations"], 229 | jnp.broadcast_to(gripper_close_val, batch["actions"].shape), 230 | name="critic", 231 | ) 232 | metrics.update( 233 | { 234 | "gripper_close_q": gripper_close_q, 235 | "gripper_close_adv": gripper_close_q - v, 236 | } 237 | ) 238 | 239 | return metrics 240 | 241 | @classmethod 242 | def create( 243 | cls, 244 | rng: PRNGKey, 245 | observations: FrozenDict, 246 | actions: jnp.ndarray, 247 | # Model architecture 248 | encoder_def: nn.Module, 249 | shared_encoder: bool = True, 250 | use_proprio: bool = False, 251 | network_kwargs: dict = {"hidden_dims": [256, 256]}, 252 | policy_kwargs: dict = { 253 | "tanh_squash_distribution": False, 254 | "state_dependent_std": False, 255 | "dropout": 0.0, 256 | }, 257 | # Optimizer 258 | learning_rate: float = 3e-4, 259 | warmup_steps: int = 2000, 260 | actor_decay_steps: Optional[int] = None, 261 | # Algorithm config 262 | discount=0.95, 263 | expectile=0.9, 264 | temperature=1.0, 265 | target_update_rate=0.002, 266 | ): 267 | encoder_def = EncodingWrapper( 268 | encoder=encoder_def, use_proprio=use_proprio, stop_gradient=False 269 | ) 270 | 271 | if shared_encoder: 272 | encoders = { 273 | "actor": encoder_def, 274 | "value": encoder_def, 275 | "critic": encoder_def, 276 | } 277 | else: 278 | encoders = { 279 | "actor": encoder_def, 280 | "value": copy.deepcopy(encoder_def), 281 | "critic": copy.deepcopy(encoder_def), 282 | } 283 | 284 | network_kwargs["activate_final"] = True 285 | networks = { 286 | "actor": Policy( 287 | encoders["actor"], 288 | MLP(**network_kwargs), 289 | action_dim=actions.shape[-1], 290 | **policy_kwargs, 291 | ), 292 | "value": ValueCritic(encoders["value"], MLP(**network_kwargs)), 293 | "critic": Critic(encoders["critic"], MLP(**network_kwargs)), 294 | } 295 | 296 | model_def = ModuleDict(networks) 297 | 298 | # no decay 299 | lr_schedule = optax.warmup_cosine_decay_schedule( 300 | init_value=0.0, 301 | peak_value=learning_rate, 302 | warmup_steps=warmup_steps, 303 | decay_steps=warmup_steps + 1, 304 | end_value=learning_rate, 305 | ) 306 | lr_schedules = { 307 | "actor": lr_schedule, 308 | "value": lr_schedule, 309 | "critic": lr_schedule, 310 | } 311 | if actor_decay_steps is not None: 312 | lr_schedules["actor"] = optax.warmup_cosine_decay_schedule( 313 | init_value=0.0, 314 | peak_value=learning_rate, 315 | warmup_steps=warmup_steps, 316 | decay_steps=actor_decay_steps, 317 | end_value=0.0, 318 | ) 319 | txs = {k: optax.adam(v) for k, v in lr_schedules.items()} 320 | 321 | rng, init_rng = jax.random.split(rng) 322 | params = model_def.init( 323 | init_rng, 324 | actor=observations, 325 | value=observations, 326 | critic=[observations, actions], 327 | )["params"] 328 | 329 | rng, create_rng = jax.random.split(rng) 330 | state = 
JaxRLTrainState.create( 331 | apply_fn=model_def.apply, 332 | params=params, 333 | txs=txs, 334 | target_params=params, 335 | rng=create_rng, 336 | ) 337 | 338 | config = flax.core.FrozenDict( 339 | dict( 340 | discount=discount, 341 | temperature=temperature, 342 | target_update_rate=target_update_rate, 343 | expectile=expectile, 344 | ) 345 | ) 346 | return cls(state, config) 347 | -------------------------------------------------------------------------------- /jaxrl_m/agents/continuous/lc_bc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any 3 | 4 | import flax 5 | import flax.linen as nn 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | from flax.core import FrozenDict 11 | 12 | from jaxrl_m.common.common import JaxRLTrainState, ModuleDict, nonpytree_field 13 | from jaxrl_m.common.encoding import LCEncodingWrapper 14 | from jaxrl_m.common.typing import Batch, PRNGKey 15 | from jaxrl_m.networks.actor_critic_nets import Policy 16 | from jaxrl_m.networks.mlp import MLP 17 | 18 | 19 | class LCBCAgent(flax.struct.PyTreeNode): 20 | state: JaxRLTrainState 21 | lr_schedule: Any = nonpytree_field() 22 | 23 | @partial(jax.jit, static_argnames="pmap_axis") 24 | def update(self, batch: Batch, pmap_axis: str = None): 25 | def loss_fn(params, rng): 26 | rng, key = jax.random.split(rng) 27 | dist = self.state.apply_fn( 28 | {"params": params}, 29 | (batch["observations"], batch["goals"]), 30 | temperature=1.0, 31 | train=True, 32 | rngs={"dropout": key}, 33 | name="actor", 34 | ) 35 | pi_actions = dist.mode() 36 | log_probs = dist.log_prob(batch["actions"]) 37 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 38 | actor_loss = -(log_probs).mean() 39 | actor_std = dist.stddev().mean(axis=1) 40 | 41 | return ( 42 | actor_loss, 43 | { 44 | "actor_loss": actor_loss, 45 | "mse": mse.mean(), 46 | "log_probs": log_probs.mean(), 47 | "pi_actions": pi_actions.mean(), 48 | "mean_std": actor_std.mean(), 49 | "max_std": actor_std.max(), 50 | }, 51 | ) 52 | 53 | # compute gradients and update params 54 | new_state, info = self.state.apply_loss_fns( 55 | loss_fn, pmap_axis=pmap_axis, has_aux=True 56 | ) 57 | 58 | # log learning rates 59 | info["lr"] = self.lr_schedule(self.state.step) 60 | 61 | return self.replace(state=new_state), info 62 | 63 | @partial(jax.jit, static_argnames="argmax") 64 | def sample_actions( 65 | self, 66 | observations: np.ndarray, 67 | goals: np.ndarray, 68 | *, 69 | seed: PRNGKey, 70 | temperature: float = 1.0, 71 | argmax=False 72 | ) -> jnp.ndarray: 73 | dist = self.state.apply_fn( 74 | {"params": self.state.params}, 75 | (observations, goals), 76 | temperature=temperature, 77 | name="actor", 78 | ) 79 | if argmax: 80 | actions = dist.mode() 81 | else: 82 | actions = dist.sample(seed=seed) 83 | return actions 84 | 85 | @jax.jit 86 | def get_debug_metrics(self, batch, **kwargs): 87 | dist = self.state.apply_fn( 88 | {"params": self.state.params}, 89 | (batch["observations"], batch["goals"]), 90 | temperature=1.0, 91 | name="actor", 92 | ) 93 | pi_actions = dist.mode() 94 | log_probs = dist.log_prob(batch["actions"]) 95 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 96 | 97 | return {"mse": mse, "log_probs": log_probs, "pi_actions": pi_actions} 98 | 99 | @classmethod 100 | def create( 101 | cls, 102 | rng: PRNGKey, 103 | observations: FrozenDict, 104 | actions: jnp.ndarray, 105 | goals: FrozenDict, 106 | # Model architecture 107 | encoder_def: 
nn.Module, 108 | shared_goal_encoder: bool = True, 109 | early_goal_concat: bool = False, 110 | use_proprio: bool = False, 111 | network_kwargs: dict = {"hidden_dims": [256, 256]}, 112 | policy_kwargs: dict = { 113 | "tanh_squash_distribution": False, 114 | "state_dependent_std": False, 115 | "dropout": 0.0, 116 | }, 117 | # Optimizer 118 | learning_rate: float = 3e-4, 119 | warmup_steps: int = 1000, 120 | decay_steps: int = 1000000, 121 | ): 122 | 123 | encoder_def = LCEncodingWrapper( 124 | encoder=encoder_def, use_proprio=use_proprio, stop_gradient=False 125 | ) 126 | 127 | network_kwargs["activate_final"] = True 128 | networks = { 129 | "actor": Policy( 130 | encoder_def, 131 | MLP(**network_kwargs), 132 | action_dim=actions.shape[-1], 133 | **policy_kwargs 134 | ) 135 | } 136 | 137 | model_def = ModuleDict(networks) 138 | 139 | lr_schedule = optax.warmup_cosine_decay_schedule( 140 | init_value=0.0, 141 | peak_value=learning_rate, 142 | warmup_steps=warmup_steps, 143 | decay_steps=decay_steps, 144 | end_value=0.0, 145 | ) 146 | tx = optax.adam(lr_schedule) 147 | 148 | rng, init_rng = jax.random.split(rng) 149 | params = model_def.init(init_rng, actor=[(observations, goals)])["params"] 150 | 151 | rng, create_rng = jax.random.split(rng) 152 | state = JaxRLTrainState.create( 153 | apply_fn=model_def.apply, 154 | params=params, 155 | txs=tx, 156 | target_params=params, 157 | rng=create_rng, 158 | ) 159 | 160 | return cls(state, lr_schedule) 161 | -------------------------------------------------------------------------------- /jaxrl_m/common/common.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple, Union, Mapping, Sequence, Callable 2 | import flax 3 | import flax.linen as nn 4 | from flax import struct 5 | import jax 6 | import jax.numpy as jnp 7 | import optax 8 | import functools 9 | 10 | from jaxrl_m.common.typing import PRNGKey, Params 11 | 12 | nonpytree_field = functools.partial(flax.struct.field, pytree_node=False) 13 | 14 | default_init = nn.initializers.xavier_uniform 15 | 16 | 17 | def shard_batch(batch, sharding): 18 | """Shards a batch across devices along its first dimension. 19 | 20 | Args: 21 | batch: A pytree of arrays. 22 | sharding: A jax Sharding object with shape (num_devices,). 23 | """ 24 | return jax.tree_map( 25 | lambda x: jax.device_put( 26 | x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1))) 27 | ), 28 | batch, 29 | ) 30 | 31 | 32 | class ModuleDict(nn.Module): 33 | """ 34 | Utility class for wrapping a dictionary of modules. This is useful when you have multiple modules that you want to 35 | initialize all at once (creating a single `params` dictionary), but you want to be able to call them separately 36 | later. As a bonus, the modules may have sub-modules nested inside them that share parameters (e.g. an image encoder) 37 | and Flax will automatically handle this without duplicating the parameters. 38 | 39 | To initialize the modules, call `init` with no `name` kwarg, and then pass the example arguments to each module as 40 | additional kwargs. To call the modules, pass the name of the module as the `name` kwarg, and then pass the arguments 41 | to the module as additional args or kwargs. 
42 | 43 | Example usage: 44 | ``` 45 | shared_encoder = Encoder() 46 | actor = Actor(encoder=shared_encoder) 47 | critic = Critic(encoder=shared_encoder) 48 | 49 | model_def = ModuleDict({"actor": actor, "critic": critic}) 50 | params = model_def.init(rng_key, actor=example_obs, critic=(example_obs, example_action)) 51 | 52 | actor_output = model_def.apply({"params": params}, example_obs, name="actor") 53 | critic_output = model_def.apply({"params": params}, example_obs, action=example_action, name="critic") 54 | ``` 55 | """ 56 | 57 | modules: Dict[str, nn.Module] 58 | 59 | @nn.compact 60 | def __call__(self, *args, name=None, **kwargs): 61 | if name is None: 62 | if kwargs.keys() != self.modules.keys(): 63 | raise ValueError( 64 | f"When `name` is not specified, kwargs must contain the arguments for each module. " 65 | f"Got kwargs keys {kwargs.keys()} but module keys {self.modules.keys()}" 66 | ) 67 | out = {} 68 | for key, value in kwargs.items(): 69 | if isinstance(value, Mapping): 70 | out[key] = self.modules[key](**value) 71 | elif isinstance(value, Sequence): 72 | out[key] = self.modules[key](*value) 73 | else: 74 | out[key] = self.modules[key](value) 75 | return out 76 | 77 | return self.modules[name](*args, **kwargs) 78 | 79 | 80 | class JaxRLTrainState(struct.PyTreeNode): 81 | """ 82 | Custom TrainState class to replace `flax.training.train_state.TrainState`. 83 | 84 | Adds support for holding target params and updating them via polyak 85 | averaging. Adds the ability to hold an rng key for dropout. 86 | 87 | Also generalizes the TrainState to support an arbitrary pytree of 88 | optimizers, `txs`. When `apply_gradients()` is called, the `grads` argument 89 | must have `txs` as a prefix. This is backwards-compatible, meaning `txs` can 90 | be a single optimizer and `grads` can be a single tree with the same 91 | structure as `self.params`. 92 | 93 | Also adds a convenience method `apply_loss_fns` that takes a pytree of loss 94 | functions with the same structure as `txs`, computes gradients, and applies 95 | them using `apply_gradients`. 96 | 97 | Attributes: 98 | step: The current training step. 99 | apply_fn: The function used to apply the model. 100 | params: The model parameters. 101 | target_params: The target model parameters. 102 | txs: The optimizer or pytree of optimizers. 103 | opt_states: The optimizer state or pytree of optimizer states. 104 | rng: The internal rng state. 105 | """ 106 | 107 | step: int 108 | apply_fn: Callable = struct.field(pytree_node=False) 109 | params: Params 110 | target_params: Params 111 | txs: Any = struct.field(pytree_node=False) 112 | opt_states: Any 113 | rng: PRNGKey 114 | 115 | @staticmethod 116 | def _tx_tree_map(*args, **kwargs): 117 | return jax.tree_map( 118 | *args, 119 | is_leaf=lambda x: isinstance(x, optax.GradientTransformation), 120 | **kwargs, 121 | ) 122 | 123 | def target_update(self, tau: float) -> "JaxRLTrainState": 124 | """ 125 | Performs an update of the target params via polyak averaging. The new 126 | target params are given by: 127 | 128 | new_target_params = tau * params + (1 - tau) * target_params 129 | """ 130 | new_target_params = jax.tree_map( 131 | lambda p, tp: p * tau + tp * (1 - tau), self.params, self.target_params 132 | ) 133 | return self.replace(target_params=new_target_params) 134 | 135 | def apply_gradients(self, *, grads: Any) -> "JaxRLTrainState": 136 | """ 137 | Only difference from flax's TrainState is that `grads` must have 138 | `self.txs` as a tree prefix (i.e. 
where `self.txs` has a leaf, `grads` 139 | has a subtree with the same structure as `self.params`.) 140 | """ 141 | updates_and_new_states = self._tx_tree_map( 142 | lambda tx, opt_state, grad: tx.update(grad, opt_state, self.params), 143 | self.txs, 144 | self.opt_states, 145 | grads, 146 | ) 147 | updates = self._tx_tree_map(lambda _, x: x[0], self.txs, updates_and_new_states) 148 | new_opt_states = self._tx_tree_map( 149 | lambda _, x: x[1], self.txs, updates_and_new_states 150 | ) 151 | 152 | # not the cleanest, I know, but this flattens the leaves of `updates` 153 | # into a list where leaves are defined by `self.txs` 154 | updates_flat = [] 155 | self._tx_tree_map( 156 | lambda _, update: updates_flat.append(update), self.txs, updates 157 | ) 158 | 159 | # apply all the updates additively 160 | updates_acc = jax.tree_map( 161 | lambda *xs: jnp.sum(jnp.array(xs), axis=0), *updates_flat 162 | ) 163 | new_params = optax.apply_updates(self.params, updates_acc) 164 | 165 | return self.replace( 166 | step=self.step + 1, params=new_params, opt_states=new_opt_states 167 | ) 168 | 169 | def apply_loss_fns( 170 | self, loss_fns: Any, pmap_axis: str = None, has_aux: bool = False 171 | ) -> Union["JaxRLTrainState", Tuple["JaxRLTrainState", Any]]: 172 | """ 173 | Convenience method to compute gradients based on `self.params` and apply 174 | them using `apply_gradients`. `loss_fns` must have the same structure as 175 | `txs`, and each leaf must be a function that takes two arguments: 176 | `params` and `rng`. 177 | 178 | This method automatically provides fresh rng to each loss function and 179 | updates this train state's internal rng key. 180 | 181 | Args: 182 | loss_fns: loss function or pytree of loss functions with same 183 | structure as `self.txs`. Each loss function must take `params` 184 | as the first argument and `rng` as the second argument, and return 185 | a scalar value. 186 | pmap_axis: if not None, gradients (and optionally auxiliary values) 187 | will be averaged over this axis 188 | has_aux: if True, each `loss_fn` returns a tuple of (loss, aux) where 189 | `aux` is a pytree of auxiliary values to be returned by this 190 | method. 191 | 192 | Returns: 193 | If `has_aux` is True, returns a tuple of (new_train_state, aux). 194 | Otherwise, returns the new train state. 195 | """ 196 | # create a pytree of rngs with the same structure as `loss_fns` 197 | treedef = jax.tree_util.tree_structure(loss_fns) 198 | new_rng, *rngs = jax.random.split(self.rng, treedef.num_leaves + 1) 199 | rngs = jax.tree_util.tree_unflatten(treedef, rngs) 200 | 201 | # compute gradients 202 | grads_and_aux = jax.tree_map( 203 | lambda loss_fn, rng: jax.grad(loss_fn, has_aux=has_aux)(self.params, rng), 204 | loss_fns, 205 | rngs, 206 | ) 207 | 208 | # update rng state 209 | self = self.replace(rng=new_rng) 210 | 211 | # average across devices if necessary 212 | if pmap_axis is not None: 213 | grads_and_aux = jax.lax.pmean(grads_and_aux, axis_name=pmap_axis) 214 | 215 | if has_aux: 216 | grads = jax.tree_map(lambda _, x: x[0], loss_fns, grads_and_aux) 217 | aux = jax.tree_map(lambda _, x: x[1], loss_fns, grads_and_aux) 218 | return self.apply_gradients(grads=grads), aux 219 | else: 220 | return self.apply_gradients(grads=grads_and_aux) 221 | 222 | @classmethod 223 | def create( 224 | cls, *, apply_fn, params, txs, target_params=None, rng=jax.random.PRNGKey(0) 225 | ): 226 | """ 227 | Initializes a new train state. 
228 | 229 | Args: 230 | apply_fn: The function used to apply the model, typically `model_def.apply`. 231 | params: The model parameters, typically from `model_def.init`. 232 | txs: The optimizer or pytree of optimizers. 233 | target_params: The target model parameters. 234 | rng: The rng key used to initialize the rng chain for `apply_loss_fns`. 235 | """ 236 | return cls( 237 | step=0, 238 | apply_fn=apply_fn, 239 | params=params, 240 | target_params=target_params, 241 | txs=txs, 242 | opt_states=cls._tx_tree_map(lambda tx: tx.init(params), txs), 243 | rng=rng, 244 | ) 245 | -------------------------------------------------------------------------------- /jaxrl_m/common/encoding.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | from einops import rearrange, repeat 7 | 8 | 9 | class EncodingWrapper(nn.Module): 10 | """ 11 | Encodes observations into a single flat encoding, adding additional 12 | functionality for adding proprioception and stopping the gradient. 13 | 14 | Args: 15 | encoder: The encoder network. 16 | use_proprio: Whether to concatenate proprioception (after encoding). 17 | stop_gradient: Whether to stop the gradient after the encoder. 18 | """ 19 | 20 | encoder: nn.Module 21 | use_proprio: bool 22 | stop_gradient: bool 23 | 24 | def __call__(self, observations: Dict[str, jnp.ndarray]) -> jnp.ndarray: 25 | encoding = self.encoder(observations["image"]) 26 | if self.use_proprio: 27 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 28 | if self.stop_gradient: 29 | encoding = jax.lax.stop_gradient(encoding) 30 | return encoding 31 | 32 | 33 | class GCEncodingWrapper(nn.Module): 34 | """ 35 | Encodes observations and goals into a single flat encoding. Handles all the 36 | logic about when/how to combine observations and goals. 37 | 38 | Takes a tuple (observations, goals) as input. 39 | 40 | Args: 41 | encoder: The encoder network for observations. 42 | goal_encoder: The encoder to use for goals (optional). If None, early 43 | goal concatenation is used, i.e. the goal is concatenated to the 44 | observation channel-wise before passing it through the encoder. 45 | use_proprio: Whether to concatenate proprioception (after encoding). 46 | stop_gradient: Whether to stop the gradient after the encoder. 
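    Example usage (a minimal sketch; `resnet`, `obs`, and `goals` are
    placeholders: `resnet` is any image encoder module, and `obs` and `goals`
    are dicts with an "image" entry):
    ```
    wrapper = GCEncodingWrapper(
        encoder=resnet,
        goal_encoder=None,  # None triggers early goal concatenation
        use_proprio=False,
        stop_gradient=False,
    )
    variables = wrapper.init(jax.random.PRNGKey(0), (obs, goals))
    encoding = wrapper.apply(variables, (obs, goals))
    ```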
47 | """ 48 | 49 | encoder: nn.Module 50 | goal_encoder: Optional[nn.Module] 51 | use_proprio: bool 52 | stop_gradient: bool 53 | 54 | def __call__( 55 | self, 56 | observations_and_goals: Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]], 57 | ) -> jnp.ndarray: 58 | observations, goals = observations_and_goals 59 | 60 | if len(observations["image"].shape) == 5: 61 | # obs history case 62 | batch_size, obs_horizon = observations["image"].shape[:2] 63 | # fold batch_size into obs_horizon to encode each frame separately 64 | obs_image = rearrange(observations["image"], "B T H W C -> (B T) H W C") 65 | # repeat goals so that there's a goal for each frame 66 | goal_image = repeat( 67 | goals["image"], "B H W C -> (B repeat) H W C", repeat=obs_horizon 68 | ) 69 | else: 70 | obs_image = observations["image"] 71 | goal_image = goals["image"] 72 | 73 | if self.goal_encoder is None: 74 | # early goal concat 75 | encoder_inputs = jnp.concatenate([obs_image, goal_image], axis=-1) 76 | encoding = self.encoder(encoder_inputs) 77 | else: 78 | # late fusion 79 | encoding = self.encoder(obs_image) 80 | goal_encoding = self.goal_encoder(goals["image"]) 81 | encoding = jnp.concatenate([encoding, goal_encoding], axis=-1) 82 | 83 | if len(observations["image"].shape) == 5: 84 | # unfold obs_horizon from batch_size 85 | encoding = rearrange( 86 | encoding, "(B T) F -> B (T F)", B=batch_size, T=obs_horizon 87 | ) 88 | 89 | if self.use_proprio: 90 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 91 | 92 | if self.stop_gradient: 93 | encoding = jax.lax.stop_gradient(encoding) 94 | 95 | return encoding 96 | 97 | 98 | class LCEncodingWrapper(nn.Module): 99 | """ 100 | Encodes observations and language instructions into a single flat encoding. 101 | 102 | Takes a tuple (observations, goals) as input, where goals contains the language instruction. 103 | 104 | Args: 105 | encoder: The encoder network for observations. 106 | use_proprio: Whether to concatenate proprioception (after encoding). 107 | stop_gradient: Whether to stop the gradient after the encoder. 
108 | """ 109 | 110 | encoder: nn.Module 111 | use_proprio: bool 112 | stop_gradient: bool 113 | 114 | def __call__( 115 | self, 116 | observations_and_goals: Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]], 117 | ) -> jnp.ndarray: 118 | observations, goals = observations_and_goals 119 | 120 | if len(observations["image"].shape) == 5: 121 | # obs history case 122 | batch_size, obs_horizon = observations["image"].shape[:2] 123 | # fold batch_size into obs_horizon to encode each frame separately 124 | obs_image = rearrange(observations["image"], "B T H W C -> (B T) H W C") 125 | # repeat language so that there's an instruction for each frame 126 | language = repeat( 127 | goals["language"], "B E -> (B repeat) E", repeat=obs_horizon 128 | ) 129 | else: 130 | obs_image = observations["image"] 131 | language = goals["language"] 132 | 133 | encoding = self.encoder(obs_image, cond_var=language) 134 | 135 | if len(observations["image"].shape) == 5: 136 | # unfold obs_horizon from batch_size 137 | encoding = rearrange( 138 | encoding, "(B T) F -> B (T F)", B=batch_size, T=obs_horizon 139 | ) 140 | 141 | if self.use_proprio: 142 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 143 | 144 | if self.stop_gradient: 145 | encoding = jax.lax.stop_gradient(encoding) 146 | 147 | return encoding 148 | -------------------------------------------------------------------------------- /jaxrl_m/common/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Sequence, Union 2 | 3 | import numpy as np 4 | import jax.numpy as jnp 5 | import flax 6 | import tensorflow as tf 7 | 8 | 9 | PRNGKey = Any 10 | Params = flax.core.FrozenDict[str, Any] 11 | Shape = Sequence[int] 12 | Dtype = Any # this could be a real type? 
13 | InfoDict = Dict[str, float] 14 | Array = Union[np.ndarray, jnp.ndarray, tf.Tensor] 15 | Data = Union[Array, Dict[str, "Data"]] 16 | Batch = Dict[str, Data] 17 | # A method to be passed into TrainState.__call__ 18 | ModuleMethod = Union[str, Callable, None] 19 | -------------------------------------------------------------------------------- /jaxrl_m/common/wandb.py: -------------------------------------------------------------------------------- 1 | import absl.flags as flags 2 | import datetime 3 | import tempfile 4 | from copy import copy 5 | from socket import gethostname 6 | 7 | import ml_collections 8 | import wandb 9 | 10 | 11 | def _recursive_flatten_dict(d: dict): 12 | keys, values = [], [] 13 | for key, value in d.items(): 14 | if isinstance(value, dict): 15 | sub_keys, sub_values = _recursive_flatten_dict(value) 16 | keys += [f"{key}/{k}" for k in sub_keys] 17 | values += sub_values 18 | else: 19 | keys.append(key) 20 | values.append(value) 21 | return keys, values 22 | 23 | 24 | class WandBLogger(object): 25 | @staticmethod 26 | def get_default_config(): 27 | config = ml_collections.ConfigDict() 28 | config.project = "jaxrl_m" # WandB Project Name 29 | config.entity = ml_collections.config_dict.FieldReference(None, field_type=str) 30 | # Which entity to log as (default: your own user) 31 | config.exp_descriptor = "" # Run name (doesn't have to be unique) 32 | # Unique identifier for run (will be automatically generated unless 33 | # provided) 34 | config.unique_identifier = "" 35 | return config 36 | 37 | def __init__(self, wandb_config, variant, wandb_output_dir=None, debug=False): 38 | self.config = wandb_config 39 | if self.config.unique_identifier == "": 40 | self.config.unique_identifier = datetime.datetime.now().strftime( 41 | "%Y%m%d_%H%M%S" 42 | ) 43 | 44 | self.config.experiment_id = ( 45 | self.experiment_id 46 | ) = f"{self.config.exp_descriptor}_{self.config.unique_identifier}" # NOQA 47 | 48 | print(self.config) 49 | 50 | if wandb_output_dir is None: 51 | wandb_output_dir = tempfile.mkdtemp() 52 | 53 | self._variant = copy(variant) 54 | 55 | if "hostname" not in self._variant: 56 | self._variant["hostname"] = gethostname() 57 | 58 | if debug: 59 | mode = "disabled" 60 | else: 61 | mode = "online" 62 | 63 | self.run = wandb.init( 64 | config=self._variant, 65 | project=self.config.project, 66 | entity=self.config.entity, 67 | dir=wandb_output_dir, 68 | id=self.config.experiment_id, 69 | save_code=True, 70 | mode=mode, 71 | ) 72 | 73 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS} 74 | for k in flag_dict: 75 | if isinstance(flag_dict[k], ml_collections.ConfigDict): 76 | flag_dict[k] = flag_dict[k].to_dict() 77 | wandb.config.update(flag_dict) 78 | 79 | def log(self, data: dict, step: int = None): 80 | data_flat = _recursive_flatten_dict(data) 81 | data = {k: v for k, v in zip(*data_flat)} 82 | wandb.log(data, step=step) 83 | -------------------------------------------------------------------------------- /jaxrl_m/data/text_processing.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import tensorflow as tf 6 | from flax.core import FrozenDict 7 | 8 | MULTI_MODULE = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3" 9 | 10 | 11 | class TextProcessor: 12 | """ 13 | Base class for text tokenization or text embedding. 
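    Subclasses implement `encode`, which maps a list of strings either to a
    ready-to-use embedding array or to tokenized model inputs. A minimal
    sketch (using the MuseEmbedding subclass defined below; the instruction
    string is just an illustrative placeholder):
    ```
    processor = MuseEmbedding()
    embeddings = processor.encode(["put the spoon on the towel"])  # (1, embed_dim) array
    ```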
14 | """ 15 | 16 | def encode(self, strings): 17 | pass 18 | 19 | 20 | class HFTokenizer(TextProcessor): 21 | def __init__( 22 | self, 23 | tokenizer_name: str, 24 | tokenizer_kwargs: Optional[dict] = { 25 | "max_length": 64, 26 | "padding": "max_length", 27 | "truncation": True, 28 | "return_tensors": "np", 29 | }, 30 | encode_with_model: bool = False, 31 | ): 32 | from transformers import AutoTokenizer, FlaxAutoModel # lazy import 33 | 34 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 35 | self.tokenizer_kwargs = tokenizer_kwargs 36 | self.encode_with_model = encode_with_model 37 | if self.encode_with_model: 38 | self.model = FlaxAutoModel.from_pretrained(tokenizer_name) 39 | 40 | def encode(self, strings): 41 | # this creates another nested layer with "input_ids", "attention_mask", etc. 42 | inputs = self.tokenizer(strings, **self.tokenizer_kwargs) 43 | if self.encode_with_model: 44 | return np.array(self.model(**inputs).last_hidden_state) 45 | else: 46 | return FrozenDict(inputs) 47 | 48 | 49 | class MuseEmbedding(TextProcessor): 50 | def __init__(self): 51 | import tensorflow_hub as hub # lazy import 52 | import tensorflow_text # required for muse 53 | 54 | self.muse_model = hub.load(MULTI_MODULE) 55 | 56 | def encode(self, strings): 57 | with tf.device("/cpu:0"): 58 | return self.muse_model(strings).numpy() 59 | 60 | 61 | class CLIPTextProcessor(TextProcessor): 62 | def __init__( 63 | self, 64 | tokenizer_kwargs: Optional[dict] = { 65 | "max_length": 64, 66 | "padding": "max_length", 67 | "truncation": True, 68 | "return_tensors": "np", 69 | }, 70 | ): 71 | from transformers import CLIPProcessor 72 | 73 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 74 | self.kwargs = tokenizer_kwargs 75 | 76 | def encode(self, strings): 77 | inputs = self.processor(text=strings, **self.kwargs) 78 | inputs["position_ids"] = jnp.expand_dims( 79 | jnp.arange(inputs["input_ids"].shape[1]), axis=0 80 | ).repeat(inputs["input_ids"].shape[0], axis=0) 81 | return FrozenDict(inputs) 82 | 83 | 84 | text_processors = { 85 | "hf_tokenizer": HFTokenizer, 86 | "muse_embedding": MuseEmbedding, 87 | "clip_processor": CLIPTextProcessor, 88 | } 89 | -------------------------------------------------------------------------------- /jaxrl_m/data/tf_augmentations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from collections.abc import Mapping 3 | from ml_collections import ConfigDict 4 | 5 | 6 | def random_resized_crop(image, scale, ratio, seed, batched=False): 7 | if not batched: 8 | image = tf.expand_dims(image, axis=0) 9 | batch_size = tf.shape(image)[0] 10 | # taken from https://keras.io/examples/vision/nnclr/#random-resized-crops 11 | log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1])) 12 | height = tf.shape(image)[-3] 13 | width = tf.shape(image)[-2] 14 | 15 | random_scales = tf.random.stateless_uniform((batch_size,), seed, scale[0], scale[1]) 16 | random_ratios = tf.exp( 17 | tf.random.stateless_uniform((batch_size,), seed, log_ratio[0], log_ratio[1]) 18 | ) 19 | 20 | new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1) 21 | new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1) 22 | height_offsets = tf.random.stateless_uniform( 23 | (batch_size,), seed, 0, 1 - new_heights 24 | ) 25 | width_offsets = tf.random.stateless_uniform((batch_size,), seed, 0, 1 - new_widths) 26 | 27 | bounding_boxes = tf.stack( 28 | [ 29 | height_offsets, 30 | 
width_offsets, 31 | height_offsets + new_heights, 32 | width_offsets + new_widths, 33 | ], 34 | axis=1, 35 | ) 36 | 37 | if len(tf.shape(image)) == 5: 38 | obs_horizon = tf.shape(image)[1] 39 | # fold obs_horizon dimension into batch dimension 40 | image = tf.reshape(image, [batch_size * obs_horizon, height, width, -1]) 41 | # repeat bounding_boxes so each obs history is augmented the same 42 | bounding_boxes = tf.repeat(bounding_boxes, obs_horizon, axis=0) 43 | image = tf.image.crop_and_resize( 44 | image, bounding_boxes, tf.range(batch_size * obs_horizon), (height, width) 45 | ) 46 | image = tf.reshape(image, [batch_size, obs_horizon, height, width, -1]) 47 | else: 48 | image = tf.image.crop_and_resize( 49 | image, bounding_boxes, tf.range(batch_size), (height, width) 50 | ) 51 | 52 | if not batched: 53 | return image[0] 54 | else: 55 | return image 56 | 57 | 58 | AUGMENT_OPS = { 59 | "random_resized_crop": random_resized_crop, 60 | "random_brightness": tf.image.stateless_random_brightness, 61 | "random_contrast": tf.image.stateless_random_contrast, 62 | "random_saturation": tf.image.stateless_random_saturation, 63 | "random_hue": tf.image.stateless_random_hue, 64 | "random_flip_left_right": tf.image.stateless_random_flip_left_right, 65 | } 66 | 67 | 68 | def augment(image, seed, **augment_kwargs): 69 | image = tf.cast(image, tf.float32) / 255 # convert images to [0, 1] 70 | for op in augment_kwargs["augment_order"]: 71 | if op in augment_kwargs: 72 | if isinstance(augment_kwargs[op], Mapping) or isinstance( 73 | augment_kwargs[op], ConfigDict 74 | ): 75 | image = AUGMENT_OPS[op](image, seed=seed, **augment_kwargs[op]) 76 | else: 77 | image = AUGMENT_OPS[op](image, seed=seed, *augment_kwargs[op]) 78 | else: 79 | image = AUGMENT_OPS[op](image, seed=seed) 80 | image = tf.clip_by_value(image, 0, 1) 81 | image = tf.cast(image * 255, tf.uint8) 82 | return image 83 | 84 | 85 | def augment_batch(images, seed, **augment_kwargs): 86 | # we shouldn't need this anymore 87 | raise NotImplementedError 88 | batch_size = tf.shape(images)[0] 89 | sub_seeds = [seed] 90 | for _ in range(batch_size): 91 | sub_seeds.append( 92 | tf.random.stateless_uniform( 93 | [2], seed=sub_seeds[-1], minval=None, maxval=None, dtype=tf.int32 94 | ) 95 | ) 96 | images = tf.cast(images, tf.float32) / 255 # convert images to [0, 1] 97 | for op in augment_kwargs["augment_order"]: 98 | if op in augment_kwargs: 99 | if isinstance(augment_kwargs[op], Mapping) or isinstance( 100 | augment_kwargs[op], ConfigDict 101 | ): 102 | # this is random_resized_crop which can handle batches 103 | assert op == "random_resized_crop" 104 | images = AUGMENT_OPS[op]( 105 | images, seed=seed, batched=True, **augment_kwargs[op] 106 | ) 107 | else: 108 | images_list = [] 109 | for i in range(batch_size): 110 | images_list.append( 111 | AUGMENT_OPS[op]( 112 | images[i], seed=sub_seeds[i], *augment_kwargs[op] 113 | ) 114 | ) 115 | images = tf.stack(images_list) 116 | else: 117 | images_list = [] 118 | for i in range(batch_size): 119 | images_list.append(AUGMENT_OPS[op](images[i], seed=sub_seeds[i])) 120 | images = tf.stack(images_list) 121 | images = tf.clip_by_value(images, 0, 1) 122 | images = tf.cast(images * 255, tf.uint8) 123 | return images 124 | -------------------------------------------------------------------------------- /jaxrl_m/data/tf_goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains goal relabeling and reward logic written in TensorFlow. 
3 | 4 | Each relabeling function takes a trajectory with keys "observations", 5 | "next_observations", and "terminals". It returns a new trajectory with the added 6 | keys "goals", "rewards", and "masks". Keep in mind that "observations" and 7 | "next_observations" may themselves be dictionaries, and "goals" must match their 8 | structure. 9 | 10 | "masks" determines when the next Q-value is masked out. Typically this is NOT(terminals). 11 | """ 12 | 13 | import tensorflow as tf 14 | 15 | 16 | def uniform(traj, *, reached_proportion): 17 | """ 18 | Relabels with a true uniform distribution over future states. With 19 | probability reached_proportion, observations[i] gets a goal 20 | equal to next_observations[i]. In this case, the reward is 0. Otherwise, 21 | observations[i] gets a goal sampled uniformly from the set 22 | next_observations[i + 1:], and the reward is -1. 23 | """ 24 | traj_len = tf.shape(traj["terminals"])[0] 25 | 26 | # select a random future index for each transition i in the range [i + 1, traj_len) 27 | rand = tf.random.uniform([traj_len]) 28 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 29 | high = tf.cast(traj_len, tf.float32) 30 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 31 | 32 | # TODO(kvablack): don't know how I got an out-of-bounds during training, 33 | # could not reproduce, trying to patch it for now 34 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 35 | 36 | # select a random proportion of transitions to relabel with the next observation 37 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 38 | 39 | # the last transition must be goal-reaching 40 | goal_reached_mask = tf.logical_or( 41 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 42 | ) 43 | 44 | # make goal-reaching transitions have an offset of 0 45 | goal_idxs = tf.where(goal_reached_mask, tf.range(traj_len), goal_idxs) 46 | 47 | # select goals 48 | traj["goals"] = tf.nest.map_structure( 49 | lambda x: tf.gather(x, goal_idxs), 50 | traj["next_observations"], 51 | ) 52 | 53 | # reward is 0 for goal-reaching transitions, -1 otherwise 54 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 55 | 56 | # add masks 57 | traj["masks"] = tf.logical_not(traj["terminals"]) 58 | 59 | return traj 60 | 61 | 62 | def last_state_upweighted(traj, *, reached_proportion): 63 | """ 64 | A weird relabeling scheme where the last state gets upweighted. For each 65 | transition i, a uniform random number is generated in the range [i + 1, i + 66 | traj_len). It then gets clipped to be less than traj_len. Therefore, the 67 | first transition (i = 0) gets a goal sampled uniformly from the future, but 68 | for i > 0 the last state gets more and more upweighted. 
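    For example, with traj_len = 5 and i = 2, the offset is drawn uniformly
    from {1, 2, 3, 4}, giving candidate indices {3, 4, 5, 6}; after clipping
    to traj_len - 1 = 4 these become {3, 4, 4, 4}, so the last state is
    selected with probability 3/4 (before the reached_proportion relabeling
    is applied).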
69 | """ 70 | traj_len = tf.shape(traj["terminals"])[0] 71 | 72 | # select a random future index for each transition 73 | offsets = tf.random.uniform( 74 | [traj_len], 75 | minval=1, 76 | maxval=traj_len, 77 | dtype=tf.int32, 78 | ) 79 | 80 | # select random transitions to relabel as goal-reaching 81 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 82 | # last transition is always goal-reaching 83 | goal_reached_mask = tf.logical_or( 84 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 85 | ) 86 | 87 | # the goal will come from the current transition if the goal was reached 88 | offsets = tf.where(goal_reached_mask, 0, offsets) 89 | 90 | # convert from relative to absolute indices 91 | indices = tf.range(traj_len) + offsets 92 | 93 | # clamp out of bounds indices to the last transition 94 | indices = tf.minimum(indices, traj_len - 1) 95 | 96 | # select goals 97 | traj["goals"] = tf.nest.map_structure( 98 | lambda x: tf.gather(x, indices), 99 | traj["next_observations"], 100 | ) 101 | 102 | # reward is 0 for goal-reaching transitions, -1 otherwise 103 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 104 | 105 | # add masks 106 | traj["masks"] = tf.logical_not(traj["terminals"]) 107 | 108 | return traj 109 | 110 | 111 | def geometric(traj, *, reached_proportion, discount): 112 | """ 113 | Relabels with a geometric distribution over future states. With 114 | probability reached_proportion, observations[i] gets a goal 115 | equal to next_observations[i]. In this case, the reward is 0. Otherwise, 116 | observations[i] gets a goal sampled geometrically from the set 117 | next_observations[i + 1:], and the reward is -1. 118 | """ 119 | traj_len = tf.shape(traj["terminals"])[0] 120 | 121 | # geometrically select a future index for each transition i in the range [i + 1, traj_len) 122 | arange = tf.range(traj_len) 123 | is_future_mask = tf.cast(arange[:, None] < arange[None], tf.float32) 124 | d = discount ** tf.cast(arange[None] - arange[:, None], tf.float32) 125 | 126 | probs = is_future_mask * d 127 | # The indexing changes the shape from [seq_len, 1] to [seq_len] 128 | goal_idxs = tf.random.categorical( 129 | logits=tf.math.log(probs), num_samples=1, dtype=tf.int32 130 | )[:, 0] 131 | 132 | # select a random proportion of transitions to relabel with the next observation 133 | goal_reached_mask = tf.random.uniform([traj_len]) < reached_proportion 134 | 135 | # the last transition must be goal-reaching 136 | goal_reached_mask = tf.logical_or( 137 | goal_reached_mask, tf.range(traj_len) == traj_len - 1 138 | ) 139 | 140 | # make goal-reaching transitions have an offset of 0 141 | goal_idxs = tf.where(goal_reached_mask, tf.range(traj_len), goal_idxs) 142 | 143 | # select goals 144 | traj["goals"] = tf.nest.map_structure( 145 | lambda x: tf.gather(x, goal_idxs), 146 | traj["next_observations"], 147 | ) 148 | 149 | # reward is 0 for goal-reaching transitions, -1 otherwise 150 | traj["rewards"] = tf.cast(tf.where(goal_reached_mask, 0, -1), tf.int32) 151 | 152 | # add masks 153 | traj["masks"] = tf.logical_not(traj["terminals"]) 154 | 155 | return traj 156 | 157 | 158 | def delta_goals(traj, *, goal_delta): 159 | """ 160 | Relabels with a uniform distribution over future states in the range [i + 161 | goal_delta[0], min{traj_len, i + goal_delta[1]}). Truncates trajectories to 162 | have length traj_len - goal_delta[0]. Not suitable for RL (does not add 163 | terminals or rewards). 
164 | """ 165 | traj_len = tf.shape(traj["terminals"])[0] 166 | 167 | # add the last observation (which only exists in next_observations) to get 168 | # all the observations 169 | all_obs = tf.nest.map_structure( 170 | lambda obs, next_obs: tf.concat([obs, next_obs[-1:]], axis=0), 171 | traj["observations"], 172 | traj["next_observations"], 173 | ) 174 | all_obs_len = traj_len + 1 175 | 176 | # current obs should only come from [0, traj_len - goal_delta[0]) 177 | curr_idxs = tf.range(traj_len - goal_delta[0]) 178 | 179 | # select a random future index for each transition i in the range [i + goal_delta[0], min{all_obs_len, i + goal_delta[1]}) 180 | rand = tf.random.uniform([traj_len - goal_delta[0]]) 181 | low = tf.cast(curr_idxs + goal_delta[0], tf.float32) 182 | high = tf.cast(tf.minimum(all_obs_len, curr_idxs + goal_delta[1]), tf.float32) 183 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 184 | 185 | # very rarely, floating point errors can cause goal_idxs to be out of bounds 186 | goal_idxs = tf.minimum(goal_idxs, all_obs_len - 1) 187 | 188 | traj_truncated = tf.nest.map_structure( 189 | lambda x: tf.gather(x, curr_idxs), 190 | traj, 191 | ) 192 | 193 | # select goals 194 | traj_truncated["goals"] = tf.nest.map_structure( 195 | lambda x: tf.gather(x, goal_idxs), 196 | all_obs, 197 | ) 198 | 199 | traj_truncated["goal_dists"] = goal_idxs - curr_idxs 200 | 201 | return traj_truncated 202 | 203 | 204 | GOAL_RELABELING_FUNCTIONS = { 205 | "uniform": uniform, 206 | "last_state_upweighted": last_state_upweighted, 207 | "geometric": geometric, 208 | "delta_goals": delta_goals, 209 | } 210 | -------------------------------------------------------------------------------- /jaxrl_m/networks/actor_critic_nets.py: -------------------------------------------------------------------------------- 1 | import distrax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | 5 | from typing import Optional 6 | from jaxrl_m.common.common import default_init 7 | from jaxrl_m.networks.mlp import MLP 8 | from functools import partial 9 | 10 | 11 | class ValueCritic(nn.Module): 12 | encoder: nn.Module 13 | network: nn.Module 14 | init_final: Optional[float] = None 15 | 16 | @nn.compact 17 | def __call__(self, observations: jnp.ndarray, train: bool = False) -> jnp.ndarray: 18 | outputs = self.network(self.encoder(observations), train=train) 19 | if self.init_final is not None: 20 | value = nn.Dense( 21 | 1, 22 | kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), 23 | )(outputs) 24 | else: 25 | value = nn.Dense(1, kernel_init=default_init())(outputs) 26 | return jnp.squeeze(value, -1) 27 | 28 | 29 | class Critic(nn.Module): 30 | encoder: nn.Module 31 | network: nn.Module 32 | init_final: Optional[float] = None 33 | 34 | @nn.compact 35 | def __call__( 36 | self, observations: jnp.ndarray, actions: jnp.ndarray, train: bool = False 37 | ) -> jnp.ndarray: 38 | obs_enc = self.encoder(observations) 39 | inputs = jnp.concatenate([obs_enc, actions], -1) 40 | outputs = self.network(inputs, train=train) 41 | if self.init_final is not None: 42 | value = nn.Dense( 43 | 1, 44 | kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), 45 | )(outputs) 46 | else: 47 | value = nn.Dense(1, kernel_init=default_init())(outputs) 48 | return jnp.squeeze(value, -1) 49 | 50 | 51 | class ContrastiveCritic(nn.Module): 52 | encoder: nn.Module 53 | sa_net: nn.Module 54 | g_net: nn.Module 55 | repr_dim: int = 16 56 | twin_q: bool = True 57 | sa_net2: Optional[nn.Module] = None 58 | g_net2: 
Optional[nn.Module] = None 59 | init_final: Optional[float] = None 60 | 61 | @nn.compact 62 | def __call__( 63 | self, observations: jnp.ndarray, actions: jnp.ndarray, train: bool = False 64 | ) -> jnp.ndarray: 65 | obs_goal_encoding = self.encoder(observations) 66 | encoding_dim = obs_goal_encoding.shape[-1] // 2 67 | obs_encoding, goal_encoding = ( 68 | obs_goal_encoding[..., :encoding_dim], 69 | obs_goal_encoding[..., encoding_dim:], 70 | ) 71 | 72 | if self.init_final is not None: 73 | kernel_init = partial( 74 | nn.initializers.uniform, -self.init_final, self.init_final 75 | ) 76 | else: 77 | kernel_init = default_init 78 | 79 | sa_inputs = jnp.concatenate([obs_encoding, actions], -1) 80 | sa_repr = self.sa_net(sa_inputs, train=train) 81 | sa_repr = nn.Dense(self.repr_dim, kernel_init=kernel_init())(sa_repr) 82 | g_repr = self.g_net(goal_encoding, train=train) 83 | g_repr = nn.Dense(self.repr_dim, kernel_init=kernel_init())(g_repr) 84 | outer = jnp.einsum("ik,jk->ij", sa_repr, g_repr) 85 | 86 | if self.twin_q: 87 | sa_repr2 = self.sa_net2(sa_inputs, train=train) 88 | sa_repr2 = nn.Dense(self.repr_dim, kernel_init=kernel_init())(sa_repr2) 89 | g_repr2 = self.g_net2(goal_encoding, train=train) 90 | g_repr2 = nn.Dense(self.repr_dim, kernel_init=kernel_init())(g_repr2) 91 | outer2 = jnp.einsum("ik,jk->ij", sa_repr2, g_repr2) 92 | 93 | outer = jnp.stack([outer, outer2], axis=-1) 94 | 95 | return outer 96 | 97 | 98 | def ensemblize(cls, num_qs, out_axes=0): 99 | return nn.vmap( 100 | cls, 101 | variable_axes={"params": 0}, 102 | split_rngs={"params": True}, 103 | in_axes=None, 104 | out_axes=out_axes, 105 | axis_size=num_qs, 106 | ) 107 | 108 | 109 | class Policy(nn.Module): 110 | encoder: nn.Module 111 | network: nn.Module 112 | action_dim: int 113 | init_final: Optional[float] = None 114 | log_std_min: Optional[float] = -20 115 | log_std_max: Optional[float] = 2 116 | tanh_squash_distribution: bool = False 117 | fixed_std: Optional[jnp.ndarray] = None 118 | state_dependent_std: bool = True 119 | 120 | @nn.compact 121 | def __call__( 122 | self, observations: jnp.ndarray, temperature: float = 1.0, train: bool = False 123 | ) -> distrax.Distribution: 124 | outputs = self.network(self.encoder(observations), train=train) 125 | 126 | means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs) 127 | if self.fixed_std is None: 128 | if self.state_dependent_std: 129 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init())( 130 | outputs 131 | ) 132 | else: 133 | log_stds = self.param( 134 | "log_stds", nn.initializers.zeros, (self.action_dim,) 135 | ) 136 | else: 137 | log_stds = jnp.log(jnp.array(self.fixed_std)) 138 | 139 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) / temperature 140 | 141 | if self.tanh_squash_distribution: 142 | distribution = TanhMultivariateNormalDiag( 143 | loc=means, scale_diag=jnp.exp(log_stds) 144 | ) 145 | else: 146 | distribution = distrax.MultivariateNormalDiag( 147 | loc=means, scale_diag=jnp.exp(log_stds) 148 | ) 149 | 150 | return distribution 151 | 152 | 153 | class TanhMultivariateNormalDiag(distrax.Transformed): 154 | def __init__( 155 | self, 156 | loc: jnp.ndarray, 157 | scale_diag: jnp.ndarray, 158 | low: Optional[jnp.ndarray] = None, 159 | high: Optional[jnp.ndarray] = None, 160 | ): 161 | distribution = distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag) 162 | 163 | layers = [] 164 | 165 | if not (low is None or high is None): 166 | 167 | def rescale_from_tanh(x): 168 | x = (x + 1) / 2 # (-1, 1) => (0, 
1) 169 | return x * (high - low) + low 170 | 171 | def forward_log_det_jacobian(x): 172 | high_ = jnp.broadcast_to(high, x.shape) 173 | low_ = jnp.broadcast_to(low, x.shape) 174 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1) 175 | 176 | layers.append( 177 | distrax.Lambda( 178 | rescale_from_tanh, 179 | forward_log_det_jacobian=forward_log_det_jacobian, 180 | event_ndims_in=1, 181 | event_ndims_out=1, 182 | ) 183 | ) 184 | 185 | layers.append(distrax.Block(distrax.Tanh(), 1)) 186 | 187 | bijector = distrax.Chain(layers) 188 | 189 | super().__init__(distribution=distribution, bijector=bijector) 190 | 191 | def mode(self) -> jnp.ndarray: 192 | return self.bijector.forward(self.distribution.mode()) 193 | 194 | def stddev(self) -> jnp.ndarray: 195 | return self.bijector.forward(self.distribution.stddev()) 196 | -------------------------------------------------------------------------------- /jaxrl_m/networks/diffusion_nets.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import flax.linen as nn 3 | from jaxrl_m.common.typing import Dict 4 | 5 | 6 | def cosine_beta_schedule(timesteps, s=0.008): 7 | """ 8 | cosine schedule 9 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 10 | """ 11 | steps = timesteps + 1 12 | t = jnp.linspace(0, timesteps, steps) / timesteps 13 | alphas_cumprod = jnp.cos((t + s) / (1 + s) * jnp.pi * 0.5) ** 2 14 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 15 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 16 | return jnp.clip(betas, 0, 0.999) 17 | 18 | 19 | def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=2e-2): 20 | betas = jnp.linspace(beta_start, beta_end, timesteps) 21 | return betas 22 | 23 | 24 | def vp_beta_schedule(timesteps): 25 | t = jnp.arange(1, timesteps + 1) 26 | T = timesteps 27 | b_max = 10.0 28 | b_min = 0.1 29 | alpha = jnp.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2) 30 | betas = 1 - alpha 31 | return betas 32 | 33 | 34 | class ScoreActor(nn.Module): 35 | encoder: nn.Module 36 | time_preprocess: nn.Module 37 | cond_encoder: nn.Module 38 | reverse_network: nn.Module 39 | 40 | def __call__(self, observations, actions, time, train=False): 41 | # flatten actions 42 | flat_actions = actions.reshape([actions.shape[0], -1]) 43 | 44 | t_ff = self.time_preprocess(time) 45 | cond_enc = self.cond_encoder(t_ff, train=train) 46 | obs_enc = self.encoder(observations) 47 | reverse_input = jnp.concatenate([cond_enc, obs_enc, flat_actions], axis=-1) 48 | eps_pred = self.reverse_network(reverse_input, train=train) 49 | 50 | # un-flatten pred sequence 51 | return eps_pred.reshape(actions.shape) 52 | 53 | 54 | class FourierFeatures(nn.Module): 55 | output_size: int 56 | learnable: bool = True 57 | 58 | @nn.compact 59 | def __call__(self, x: jnp.ndarray): 60 | if self.learnable: 61 | w = self.param( 62 | "kernel", 63 | nn.initializers.normal(0.2), 64 | (self.output_size // 2, x.shape[-1]), 65 | jnp.float32, 66 | ) 67 | f = 2 * jnp.pi * x @ w.T 68 | else: 69 | half_dim = self.output_size // 2 70 | f = jnp.log(10000) / (half_dim - 1) 71 | f = jnp.exp(jnp.arange(half_dim) * -f) 72 | f = x * f 73 | return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1) 74 | -------------------------------------------------------------------------------- /jaxrl_m/networks/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence 2 | import flax.linen as nn 3 | import jax.numpy as 
jnp 4 | from jaxrl_m.common.common import default_init 5 | 6 | 7 | class MLP(nn.Module): 8 | hidden_dims: Sequence[int] 9 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish 10 | activate_final: bool = False 11 | use_layer_norm: bool = False 12 | dropout_rate: Optional[float] = None 13 | 14 | @nn.compact 15 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray: 16 | for i, size in enumerate(self.hidden_dims): 17 | x = nn.Dense(size, kernel_init=default_init())(x) 18 | 19 | if i + 1 < len(self.hidden_dims) or self.activate_final: 20 | if self.dropout_rate is not None and self.dropout_rate > 0: 21 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 22 | if self.use_layer_norm: 23 | x = nn.LayerNorm()(x) 24 | x = self.activations(x) 25 | return x 26 | 27 | 28 | class MLPResNetBlock(nn.Module): 29 | features: int 30 | act: Callable 31 | dropout_rate: float = None 32 | use_layer_norm: bool = False 33 | 34 | @nn.compact 35 | def __call__(self, x, train: bool = False): 36 | residual = x 37 | if self.dropout_rate is not None and self.dropout_rate > 0: 38 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 39 | if self.use_layer_norm: 40 | x = nn.LayerNorm()(x) 41 | x = nn.Dense(self.features * 4)(x) 42 | x = self.act(x) 43 | x = nn.Dense(self.features)(x) 44 | 45 | if residual.shape != x.shape: 46 | residual = nn.Dense(self.features)(residual) 47 | 48 | return residual + x 49 | 50 | 51 | class MLPResNet(nn.Module): 52 | num_blocks: int 53 | out_dim: int 54 | dropout_rate: float = None 55 | use_layer_norm: bool = False 56 | hidden_dim: int = 256 57 | activations: Callable = nn.swish 58 | 59 | @nn.compact 60 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray: 61 | x = nn.Dense(self.hidden_dim, kernel_init=default_init())(x) 62 | for _ in range(self.num_blocks): 63 | x = MLPResNetBlock( 64 | self.hidden_dim, 65 | act=self.activations, 66 | use_layer_norm=self.use_layer_norm, 67 | dropout_rate=self.dropout_rate, 68 | )(x, train=train) 69 | 70 | x = self.activations(x) 71 | x = nn.Dense(self.out_dim, kernel_init=default_init())(x) 72 | return x 73 | -------------------------------------------------------------------------------- /jaxrl_m/utils/timer_utils.py: -------------------------------------------------------------------------------- 1 | """Timer utility.""" 2 | 3 | from collections import defaultdict 4 | import time 5 | 6 | 7 | class Timer: 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.counts = defaultdict(int) 13 | self.times = defaultdict(float) 14 | self.start_times = {} 15 | 16 | def tick(self, key): 17 | if key in self.start_times: 18 | raise ValueError(f"Timer is already ticking for key: {key}") 19 | self.start_times[key] = time.time() 20 | 21 | def tock(self, key): 22 | if key not in self.start_times: 23 | raise ValueError(f"Timer is not ticking for key: {key}") 24 | self.counts[key] += 1 25 | self.times[key] += time.time() - self.start_times[key] 26 | del self.start_times[key] 27 | 28 | def get_average_times(self, reset=True): 29 | ret = {key: self.times[key] / self.counts[key] for key in self.counts} 30 | if reset: 31 | self.reset() 32 | return ret 33 | -------------------------------------------------------------------------------- /jaxrl_m/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl_m.vision.resnet_v1 import resnetv1_configs 2 | 3 | encoders = dict() 4 | encoders.update(resnetv1_configs) 5 | 
-------------------------------------------------------------------------------- /jaxrl_m/vision/film_conditioning_layer.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py 2 | import flax.linen as nn 3 | import jax.numpy as jnp 4 | 5 | 6 | class FilmConditioning(nn.Module): 7 | @nn.compact 8 | def __call__(self, conv_filters: jnp.ndarray, conditioning: jnp.ndarray): 9 | """Applies FiLM conditioning to a convolutional feature map. 10 | 11 | Args: 12 | conv_filters: A tensor of shape [batch_size, height, width, channels]. 13 | conditioning: A tensor of shape [batch_size, conditioning_size]. 14 | 15 | Returns: 16 | A tensor of shape [batch_size, height, width, channels]. 17 | """ 18 | projected_cond_add = nn.Dense( 19 | features=conv_filters.shape[-1], 20 | kernel_init=nn.initializers.zeros, 21 | bias_init=nn.initializers.zeros, 22 | )(conditioning) 23 | projected_cond_mult = nn.Dense( 24 | features=conv_filters.shape[-1], 25 | kernel_init=nn.initializers.zeros, 26 | bias_init=nn.initializers.zeros, 27 | )(conditioning) 28 | 29 | projected_cond_add = projected_cond_add[..., None, None, :] 30 | projected_cond_mult = projected_cond_mult[..., None, None, :] 31 | 32 | return conv_filters * (1 + projected_cond_add) + projected_cond_mult 33 | -------------------------------------------------------------------------------- /jaxrl_m/vision/resnet_v1.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | from functools import partial 3 | from typing import Any, Callable, Sequence, Tuple 4 | 5 | import flax.linen as nn 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | from jaxrl_m.vision.film_conditioning_layer import FilmConditioning 10 | 11 | ModuleDef = Any 12 | 13 | 14 | class AddSpatialCoordinates(nn.Module): 15 | dtype: Any = jnp.float32 16 | 17 | @nn.compact 18 | def __call__(self, x): 19 | grid = jnp.array( 20 | np.stack( 21 | np.meshgrid(*[np.arange(s) / (s - 1) * 2 - 1 for s in x.shape[-3:-1]]), 22 | axis=-1, 23 | ), 24 | dtype=self.dtype, 25 | ).transpose(1, 0, 2) 26 | 27 | if x.ndim == 4: 28 | grid = jnp.broadcast_to(grid, [x.shape[0], *grid.shape]) 29 | 30 | return jnp.concatenate([x, grid], axis=-1) 31 | 32 | 33 | class SpatialSoftmax(nn.Module): 34 | height: int 35 | width: int 36 | channel: int 37 | pos_x: jnp.ndarray 38 | pos_y: jnp.ndarray 39 | temperature: None 40 | log_heatmap: bool = False 41 | 42 | @nn.compact 43 | def __call__(self, features): 44 | if self.temperature == -1: 45 | from jax.nn import initializers 46 | 47 | temperature = self.param( 48 | "softmax_temperature", initializers.ones, (1), jnp.float32 49 | ) 50 | else: 51 | temperature = 1.0 52 | 53 | # add batch dim if missing 54 | no_batch_dim = len(features.shape) < 4 55 | if no_batch_dim: 56 | features = features[None] 57 | 58 | assert len(features.shape) == 4 59 | batch_size, num_featuremaps = features.shape[0], features.shape[3] 60 | features = features.transpose(0, 3, 1, 2).reshape( 61 | batch_size, num_featuremaps, self.height * self.width 62 | ) 63 | 64 | softmax_attention = nn.softmax(features / temperature) 65 | expected_x = jnp.sum( 66 | self.pos_x * softmax_attention, axis=2, keepdims=True 67 | ).reshape(batch_size, num_featuremaps) 68 | expected_y = jnp.sum( 69 | self.pos_y * softmax_attention, axis=2, keepdims=True 70 | ).reshape(batch_size, num_featuremaps) 71 | expected_xy 
= jnp.concatenate([expected_x, expected_y], axis=1) 72 | 73 | expected_xy = jnp.reshape(expected_xy, [batch_size, 2 * num_featuremaps]) 74 | 75 | if no_batch_dim: 76 | expected_xy = expected_xy[0] 77 | return expected_xy 78 | 79 | 80 | class SpatialLearnedEmbeddings(nn.Module): 81 | height: int 82 | width: int 83 | channel: int 84 | num_features: int = 5 85 | kernel_init: Callable = nn.initializers.lecun_normal() 86 | param_dtype: Any = jnp.float32 87 | 88 | @nn.compact 89 | def __call__(self, features): 90 | """ 91 | features is B x H x W X C 92 | """ 93 | kernel = self.param( 94 | "kernel", 95 | self.kernel_init, 96 | (self.height, self.width, self.channel, self.num_features), 97 | self.param_dtype, 98 | ) 99 | 100 | # add batch dim if missing 101 | no_batch_dim = len(features.shape) < 4 102 | if no_batch_dim: 103 | features = features[None] 104 | 105 | batch_size = features.shape[0] 106 | assert len(features.shape) == 4 107 | features = jnp.sum( 108 | jnp.expand_dims(features, -1) * jnp.expand_dims(kernel, 0), axis=(1, 2) 109 | ) 110 | features = jnp.reshape(features, [batch_size, -1]) 111 | 112 | if no_batch_dim: 113 | features = features[0] 114 | 115 | return features 116 | 117 | 118 | class MyGroupNorm(nn.GroupNorm): 119 | def __call__(self, x): 120 | if x.ndim == 3: 121 | x = x[jnp.newaxis] 122 | x = super().__call__(x) 123 | return x[0] 124 | else: 125 | return super().__call__(x) 126 | 127 | 128 | class ResNetBlock(nn.Module): 129 | """ResNet block.""" 130 | 131 | filters: int 132 | conv: ModuleDef 133 | norm: ModuleDef 134 | act: Callable 135 | strides: Tuple[int, int] = (1, 1) 136 | 137 | @nn.compact 138 | def __call__(self, x): 139 | residual = x 140 | y = self.conv(self.filters, (3, 3), self.strides)(x) 141 | y = self.norm()(y) 142 | y = self.act(y) 143 | y = self.conv(self.filters, (3, 3))(y) 144 | y = self.norm()(y) 145 | 146 | if residual.shape != y.shape: 147 | residual = self.conv(self.filters, (1, 1), self.strides, name="conv_proj")( 148 | residual 149 | ) 150 | residual = self.norm(name="norm_proj")(residual) 151 | 152 | return self.act(residual + y) 153 | 154 | 155 | class BottleneckResNetBlock(nn.Module): 156 | """Bottleneck ResNet block.""" 157 | 158 | filters: int 159 | conv: ModuleDef 160 | norm: ModuleDef 161 | act: Callable 162 | strides: Tuple[int, int] = (1, 1) 163 | 164 | @nn.compact 165 | def __call__(self, x): 166 | residual = x 167 | y = self.conv(self.filters, (1, 1))(x) 168 | y = self.norm()(y) 169 | y = self.act(y) 170 | y = self.conv(self.filters, (3, 3), self.strides)(y) 171 | y = self.norm()(y) 172 | y = self.act(y) 173 | y = self.conv(self.filters * 4, (1, 1))(y) 174 | y = self.norm(scale_init=nn.initializers.zeros)(y) 175 | 176 | if residual.shape != y.shape: 177 | residual = self.conv( 178 | self.filters * 4, (1, 1), self.strides, name="conv_proj" 179 | )(residual) 180 | residual = self.norm(name="norm_proj")(residual) 181 | 182 | return self.act(residual + y) 183 | 184 | 185 | class ResNetEncoder(nn.Module): 186 | """ResNetV1.""" 187 | 188 | stage_sizes: Sequence[int] 189 | block_cls: ModuleDef 190 | num_filters: int = 64 191 | dtype: Any = jnp.float32 192 | act: str = "relu" 193 | conv: ModuleDef = nn.Conv 194 | norm: str = "group" 195 | add_spatial_coordinates: bool = False 196 | pooling_method: str = "avg" 197 | use_spatial_softmax: bool = False 198 | softmax_temperature: float = 1.0 199 | use_multiplicative_cond: bool = False 200 | num_spatial_blocks: int = 8 201 | use_film: bool = False 202 | 203 | @nn.compact 204 | def __call__(self, 
observations: jnp.ndarray, train: bool = True, cond_var=None): 205 | # put inputs in [-1, 1] 206 | x = observations.astype(jnp.float32) / 127.5 - 1.0 207 | 208 | if self.add_spatial_coordinates: 209 | x = AddSpatialCoordinates(dtype=self.dtype)(x) 210 | 211 | conv = partial( 212 | self.conv, 213 | use_bias=False, 214 | dtype=self.dtype, 215 | kernel_init=nn.initializers.kaiming_normal(), 216 | ) 217 | if self.norm == "batch": 218 | raise NotImplementedError 219 | elif self.norm == "group": 220 | norm = partial(MyGroupNorm, num_groups=4, epsilon=1e-5, dtype=self.dtype) 221 | elif self.norm == "layer": 222 | norm = partial(nn.LayerNorm, epsilon=1e-5, dtype=self.dtype) 223 | else: 224 | raise ValueError("norm not found") 225 | 226 | act = getattr(nn, self.act) 227 | 228 | x = conv( 229 | self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name="conv_init" 230 | )(x) 231 | 232 | x = norm(name="norm_init")(x) 233 | x = act(x) 234 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") 235 | for i, block_size in enumerate(self.stage_sizes): 236 | for j in range(block_size): 237 | stride = (2, 2) if i > 0 and j == 0 else (1, 1) 238 | x = self.block_cls( 239 | self.num_filters * 2 ** i, 240 | strides=stride, 241 | conv=conv, 242 | norm=norm, 243 | act=act, 244 | )(x) 245 | if self.use_film: 246 | assert ( 247 | cond_var is not None 248 | ), "Cond var is None, nothing to condition on" 249 | x = FilmConditioning()(x, cond_var) 250 | if self.use_multiplicative_cond: 251 | assert ( 252 | cond_var is not None 253 | ), "Cond var is None, nothing to condition on" 254 | cond_out = nn.Dense( 255 | x.shape[-1], kernel_init=nn.initializers.xavier_normal() 256 | )(cond_var) 257 | x_mult = jnp.expand_dims(jnp.expand_dims(cond_out, 1), 1) 258 | x = x * x_mult 259 | 260 | if self.pooling_method == "spatial_learned_embeddings": 261 | height, width, channel = x.shape[-3:] 262 | x = SpatialLearnedEmbeddings( 263 | height=height, 264 | width=width, 265 | channel=channel, 266 | num_features=self.num_spatial_blocks, 267 | )(x) 268 | elif self.pooling_method == "spatial_softmax": 269 | height, width, channel = x.shape[-3:] 270 | pos_x, pos_y = jnp.meshgrid( 271 | jnp.linspace(-1.0, 1.0, height), jnp.linspace(-1.0, 1.0, width) 272 | ) 273 | pos_x = pos_x.reshape(height * width) 274 | pos_y = pos_y.reshape(height * width) 275 | x = SpatialSoftmax( 276 | height, width, channel, pos_x, pos_y, self.softmax_temperature 277 | )(x) 278 | elif self.pooling_method == "avg": 279 | x = jnp.mean(x, axis=(-3, -2)) 280 | elif self.pooling_method == "max": 281 | x = jnp.max(x, axis=(-3, -2)) 282 | elif self.pooling_method == "none": 283 | pass 284 | else: 285 | raise ValueError("pooling method not found") 286 | 287 | return x 288 | 289 | 290 | resnetv1_configs = { 291 | "resnetv1-18": ft.partial( 292 | ResNetEncoder, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock 293 | ), 294 | "resnetv1-34": ft.partial( 295 | ResNetEncoder, stage_sizes=(3, 4, 6, 3), block_cls=ResNetBlock 296 | ), 297 | "resnetv1-50": ft.partial( 298 | ResNetEncoder, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock 299 | ), 300 | "resnetv1-18-deeper": ft.partial( 301 | ResNetEncoder, stage_sizes=(3, 3, 3, 3), block_cls=ResNetBlock 302 | ), 303 | "resnetv1-18-deepest": ft.partial( 304 | ResNetEncoder, stage_sizes=(4, 4, 4, 4), block_cls=ResNetBlock 305 | ), 306 | "resnetv1-18-bridge": ft.partial( 307 | ResNetEncoder, 308 | stage_sizes=(2, 2, 2, 2), 309 | block_cls=ResNetBlock, 310 | num_spatial_blocks=8, 311 | ), 312 | "resnetv1-34-bridge": 
ft.partial( 313 | ResNetEncoder, 314 | stage_sizes=(3, 4, 6, 3), 315 | block_cls=ResNetBlock, 316 | num_spatial_blocks=8, 317 | ), 318 | "resnetv1-34-bridge-film": ft.partial( 319 | ResNetEncoder, 320 | stage_sizes=(3, 4, 6, 3), 321 | block_cls=ResNetBlock, 322 | num_spatial_blocks=8, 323 | use_film=True, 324 | ), 325 | "resnetv1-50-bridge": ft.partial( 326 | ResNetEncoder, 327 | stage_sizes=(3, 4, 6, 3), 328 | block_cls=BottleneckResNetBlock, 329 | num_spatial_blocks=8, 330 | ), 331 | "resnetv1-50-bridge-film": ft.partial( 332 | ResNetEncoder, 333 | stage_sizes=(3, 4, 6, 3), 334 | block_cls=BottleneckResNetBlock, 335 | num_spatial_blocks=8, 336 | use_film=True, 337 | ), 338 | } 339 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym >= 0.26 2 | numpy==1.24.3 3 | jax==0.4.13 4 | distrax==0.1.2 5 | flax==0.7.0 6 | ml_collections >= 0.1.0 7 | tqdm >= 4.60.0 8 | chex==0.1.6 9 | optax==0.1.5 10 | absl-py >= 0.12.0 11 | scipy >= 1.6.0 12 | wandb >= 0.12.14 13 | tensorflow==2.13.0 14 | tensorflow_probability==0.21 15 | tensorflow_hub 16 | tensorflow_text 17 | einops >= 0.6.1 18 | imageio >= 2.31.1 19 | moviepy >= 1.0.3 20 | orbax 21 | matplotlib 22 | pyquaternion 23 | opencv-python 24 | opencv-contrib-python 25 | funcsigs 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name='jaxrl_m', 5 | version='0.0.2', 6 | packages=setuptools.find_packages(), 7 | license='MIT License', 8 | long_description=open('README.md').read(), 9 | ) 10 | --------------------------------------------------------------------------------
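Usage sketch (editor's addition, not part of the repository): the snippet below shows one plausible way the pieces above fit together -- an image encoder looked up from the `encoders` registry in `jaxrl_m/vision/__init__.py`, an `MLP` trunk, and a `Policy` head producing a distrax action distribution. The encoder name, image resolution, and `action_dim=7` are illustrative assumptions; the actual agents in `jaxrl_m/agents/continuous/` wire these modules together differently (e.g. with goal or language conditioning), so treat this only as a minimal end-to-end check, not the training pipeline.

import jax
import jax.numpy as jnp

from jaxrl_m.vision import encoders
from jaxrl_m.networks.mlp import MLP
from jaxrl_m.networks.actor_critic_nets import Policy

# assumptions: a ResNet-34 "bridge" encoder with default average pooling,
# a 128x128 RGB observation, and a 7-DoF action space
encoder = encoders["resnetv1-34-bridge"]()
actor = Policy(encoder=encoder, network=MLP((256, 256)), action_dim=7)

rng = jax.random.PRNGKey(0)
obs = jnp.zeros((1, 128, 128, 3), dtype=jnp.uint8)  # dummy image batch

params = actor.init(rng, obs)    # initialize Flax parameters
dist = actor.apply(params, obs)  # distrax.MultivariateNormalDiag over actions
actions = dist.sample(seed=rng)  # sampled actions, shape (1, 7)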