├── .flake8 ├── .github └── workflows │ └── pre-commit.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs └── assets │ └── teaser.jpg ├── examples ├── 01_inference_pretrained.ipynb ├── 02_finetune_new_observation_action.py ├── 03_eval_finetuned.py ├── 04_eval_finetuned_on_robot.py ├── 05_dataloading.ipynb ├── 06_pytorch_oxe_dataloader.py ├── README.md └── envs │ ├── README.md │ ├── aloha_sim_env.py │ └── widowx_env.py ├── octo ├── __init__.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── obs_transforms.py │ ├── oxe │ │ ├── __init__.py │ │ ├── oxe_dataset_configs.py │ │ ├── oxe_dataset_mixes.py │ │ └── oxe_standardization_transforms.py │ ├── traj_transforms.py │ └── utils │ │ ├── __init__.py │ │ ├── data_utils.py │ │ ├── goal_relabeling.py │ │ ├── task_augmentation.py │ │ └── text_processing.py ├── model │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── action_heads.py │ │ ├── base.py │ │ ├── block_transformer.py │ │ ├── diffusion.py │ │ ├── film_conditioning_layer.py │ │ ├── tokenizers.py │ │ ├── transformer.py │ │ ├── unet.py │ │ └── vit_encoders.py │ ├── octo_model.py │ └── octo_module.py └── utils │ ├── __init__.py │ ├── gym_wrappers.py │ ├── jax_utils.py │ ├── spec.py │ ├── train_callbacks.py │ ├── train_utils.py │ ├── typing.py │ └── visualization_lib.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── configs │ ├── config.py │ ├── finetune_config.py │ └── octo_pretrain_config.py ├── finetune.py └── train.py ├── setup.py └── tests ├── debug_config.py └── debug_dataset └── bridge_dataset └── 1.0.0 ├── bridge_dataset-train.tfrecord-00000-of-00001 ├── bridge_dataset-val.tfrecord-00000-of-00001 ├── dataset_info.json └── features.json /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git 3 | max-line-length = 88 4 | select = E,F,W,C 5 | ignore=W503, 6 | E203, 7 | E731, 8 | E722, 9 | F841, 10 | E402, 11 | E741, 12 | E501, 13 | C406, 14 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.0 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tests/debug_dataset/bridge_dataset/1.0.0/dataset_statistics_*.json 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .ipynb_checkpoints/ 163 | wandb 164 | *.png 165 | *.sif 166 | .vscode 167 | .idea 168 | datasets/debug_dataset/bridge_dataset/1.0.0/action_* 169 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v2.3.0 6 | hooks: 7 | - id: check-yaml 8 | - id: check-ast 9 | - id: check-added-large-files 10 | exclude: ^examples/ 11 | - id: check-case-conflict 12 | - id: check-merge-conflict 13 | - id: end-of-file-fixer 14 | - id: trailing-whitespace 15 | - id: detect-private-key 16 | - id: debug-statements 17 | exclude: ^experiments/ 18 | - repo: https://github.com/psf/black 19 | rev: 22.10.0 20 | hooks: 21 | - id: black 22 | exclude: ^experiments/ 23 | - repo: https://github.com/PyCQA/flake8 24 | rev: 6.1.0 25 | hooks: 26 | - id: flake8 27 | exclude: ^experiments/|^examples/03_eval_finetuned.py 28 | - repo: https://github.com/pycqa/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | exclude: ^experiments/ 33 | args: ["--profile", "black", "--src", "octo", "--src", "experiments"] 34 | - repo: https://github.com/srstevenson/nb-clean 35 | rev: 3.1.0 36 | hooks: 37 | - id: nb-clean 38 | args: 39 | - --remove-empty-cells 40 | - --preserve-cell-outputs 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Robotic AI & Learning Lab Berkeley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Octo 2 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/octo-models/octo/blob/main/examples/01_inference_pretrained.ipynb) 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | [![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://octo-models.github.io/) 5 | ![](https://github.com/rail-berkeley/octo/workflows/run-debug/badge.svg) 6 | ![](https://github.com/rail-berkeley/octo/workflows/pre-commit/badge.svg) 7 | 8 | This repo contains code for training and finetuning Octo generalist robotic policies (GRPs). 9 | Octo models are transformer-based diffusion policies, trained on a diverse mix of 800k robot trajectories. 10 | 11 | ## Get Started 12 | 13 | Follow the installation instructions, then load a pretrained Octo model! See [examples](examples/) for guides to zero-shot evaluation and finetuning and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z0vELj_lX9OWeoMG_WvXnQs43aPOEAhz?usp=sharing) 14 | for an inference example. 15 | 16 | ```python 17 | from octo.model.octo_model import OctoModel 18 | model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5") 19 | print(model.get_pretty_spec()) 20 | ``` 21 | 22 | ![Octo model](docs/assets/teaser.jpg) 23 | 24 | Out of the box, Octo supports multiple RGB camera inputs, can control various robot arms, 25 | and can be instructed via language commands or goal images. 26 | Octo uses a modular attention structure in its transformer backbone, allowing it to be effectively finetuned 27 | to robot setups with new sensory inputs, action spaces, and morphologies, using only a small target domain 28 | dataset and accessible compute budgets. 29 | 30 | 31 | ## Installation 32 | ```bash 33 | conda create -n octo python=3.10 34 | conda activate octo 35 | pip install -e . 36 | pip install -r requirements.txt 37 | ``` 38 | For GPU: 39 | ```bash 40 | pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 41 | ``` 42 | 43 | For TPU 44 | ```bash 45 | pip install --upgrade "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 46 | ``` 47 | See the [Jax Github page](https://github.com/google/jax) for more details on installing Jax. 48 | 49 | Test the installation by finetuning on the debug dataset: 50 | ```bash 51 | python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small-1.5 --debug 52 | ``` 53 | 54 | ## Checkpoints 55 | 56 | You can find pretrained Octo checkpoints [here](https://huggingface.co/rail-berkeley). 57 | At the moment we provide the following model versions: 58 | 59 | | Model | Inference on 1x NVIDIA 4090 | Size | 60 | |---------------------------------------------------------------|-----------------------------|------------| 61 | | [Octo-Base](https://huggingface.co/rail-berkeley/octo-base) | 13 it/sec | 93M Params | 62 | | [Octo-Small](https://huggingface.co/rail-berkeley/octo-small) | 17 it/sec | 27M Params | 63 | 64 | 65 | ## Examples 66 | 67 | We provide simple [example scripts](examples) that demonstrate how to use and finetune Octo models, 68 | as well as how to use our data loader independently. 
We provide the following examples: 69 | 70 | | | | 71 | |----------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------| 72 | | [Octo Inference](examples/01_inference_pretrained.ipynb) | Minimal example for loading and running a pretrained Octo model | 73 | | [Octo Finetuning](examples/02_finetune_new_observation_action.py) | Minimal example for finetuning a pretrained Octo models on a small dataset with a new observation and action space | 74 | | [Octo Rollout](examples/03_eval_finetuned.py) | Run a rollout of a pretrained Octo policy in a Gym environment | 75 | | [Octo Robot Eval](examples/04_eval_finetuned_on_robot.py) | Evaluate a pretrained Octo model on a real WidowX robot | 76 | | [OpenX Dataloader Intro](examples/05_dataloading.ipynb) | Walkthrough of the features of our Open X-Embodiment data loader | 77 | | [OpenX PyTorch Dataloader](examples/06_pytorch_oxe_dataloader.ipynb) | Standalone Open X-Embodiment data loader in PyTorch | 78 | 79 | 80 | ## Octo Pretraining 81 | 82 | To reproduce our Octo pretraining on 800k robot trajectories, run: 83 | ```bash 84 | python scripts/train.py --config scripts/configs/octo_pretrain_config.py: --name=octo --config.dataset_kwargs.oxe_kwargs.data_dir=... --config.dataset_kwargs.oxe_kwargs.data_mix=oxe_magic_soup ... 85 | ``` 86 | 87 | To download the pretraining dataset from the [Open X-Embodiment Dataset](https://robotics-transformer-x.github.io/), 88 | install the [rlds_dataset_mod package](https://github.com/kpertsch/rlds_dataset_mod) 89 | and run the [prepare_open_x.sh script](https://github.com/kpertsch/rlds_dataset_mod/blob/main/prepare_open_x.sh). 90 | The total size of the pre-processed dataset is ~1.2TB. 91 | 92 | We run pretraining using a TPUv4-128 pod in 8 hours for the Octo-S model and in 14 hours for Octo-B. 93 | 94 | 95 | ## Octo Finetuning 96 | 97 | We provide a [minimal example](examples/02_finetune_new_observation_action.py) for finetuning with a new observation and action space. 98 | 99 | We also provide a more advanced finetuning script that allows you to change hyperparameters via a config file and logs finetuning 100 | metrics. To run advanced finetuning, use: 101 | ```bash 102 | python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small-1.5 103 | ``` 104 | 105 | We offer three finetuning modes depending on the parts of the model that are kept frozen: ```head_only```, ```head_mlp_only```, and ```full``` to finetune the full model. 106 | Additionally, one can specify the task type to finetune with: ```image_conditioned```, ```language_conditioned``` or ```multimodal``` for both. 107 | For example, to finetune the full transformer with image inputs only use: 108 | ```--config=finetune_config.py:full,image_conditioned```. 109 | 110 | 111 | ## Octo Evaluation 112 | 113 | Loading and running a trained Octo model is as easy as: 114 | ```python 115 | from octo.model import OctoModel 116 | 117 | model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5") 118 | task = model.create_tasks(texts=["pick up the spoon"]) 119 | action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0)) 120 | ``` 121 | 122 | We provide examples for evaluating Octo [in a simulated Gym environment](examples/03_eval_finetuned.py) as well 123 | as [on a real WidowX robot](examples/04_eval_finetuned_on_robot.py). 
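Both evaluation examples follow the same basic pattern: wrap the environment to handle observation history and action chunking, then query the model inside the control loop. The sketch below illustrates this pattern only and is not a drop-in script; the Gym environment id is a placeholder, and the dataset statistics key depends on which checkpoint and finetuning data you use.

```python
from functools import partial

import gym
import jax

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper
from octo.utils.train_callbacks import supply_rng

model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

env = gym.make("your-env-v0")          # placeholder: your own Gym-wrapped robot env
env = HistoryWrapper(env, horizon=2)   # stack observation history for the model
env = RHCWrapper(env, exec_horizon=4)  # execute predicted action chunks

# supply_rng passes a fresh PRNG key to sample_actions on every call
policy_fn = supply_rng(
    partial(
        model.sample_actions,
        # assumption: pick the statistics matching your checkpoint / finetuning data
        unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
    )
)

obs, info = env.reset()
task = model.create_tasks(texts=["pick up the spoon"])
actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)  # add batch dim
obs, reward, done, trunc, info = env.step(actions[0])            # step with the chunk
```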
124 | 125 | To evaluate on your own environment, simply wrap it in a Gym interface and follow the instructions in the 126 | [Eval Env README](examples/envs/README.md). 127 | 128 | 129 | ## Code Structure 130 | 131 | | | File | Description | 132 | |---------------------|---------------------------------------------------------|-------------------------------------------------------------------------------| 133 | | Hyperparameters | [config.py](scripts/configs/config.py) | Defines all hyperparameters for the training run. | 134 | | Pretraining Loop | [train.py](scripts/train.py) | Main pretraining script. | 135 | | Finetuning Loop | [finetune.py](scripts/finetune.py) | Main finetuning script. | 136 | | Datasets | [dataset.py](octo/data/dataset.py) | Functions for creating single / interleaved datasets + data augmentation. | 137 | | Tokenizers | [tokenizers.py](octo/model/components/tokenizers.py) | Tokenizers that encode image / text inputs into tokens. | 138 | | Octo Model | [octo_model.py](octo/model/octo_model.py) | Main entry point for interacting with Octo models: loading, saving, and inference. | 139 | | Model Architecture | [octo_module.py](octo/model/octo_module.py) | Combines token sequencing, transformer backbone and readout heads. | 140 | | Visualization | [visualization_lib.py](octo/utils/visualization_lib.py) | Utilities for offline qualitative & quantitative eval. | 141 | 142 | ## FAQ 143 | #### What is the `timestep_pad_mask` in the observation dictionary? 144 | The `timestep_pad_mask` indicates which observations should be attended to, which is important when using multiple timesteps of observation history. Octo was trained with a history window size of 2, meaning the model can predict an action using both the current observation and the previous observation. However, at the very beginning of the trajectory, there is no previous observation, so we need to set `timestep_pad_mask=False` at the corresponding index. If you use Octo with a window size of 1, `timestep_pad_mask` should always just be `[True]`, indicating that the one and only observation in the window should be attended to. Note that if you wrap your robot environment with the `HistoryWrapper` (see [gym_wrappers.py](octo/utils/gym_wrappers.py)), the `timestep_pad_mask` key will be added to the observation dictionary for you. 145 | #### What is `pad_mask_dict` in the observation dictionary? 146 | While `timestep_pad_mask` indicates which observations should be attended to on a timestep level, `pad_mask_dict` indicates which elements of the observation should be attended to within a single timestep. For example, for datasets without language labels, `pad_mask_dict["language_instruction"]` is set to `False`. For datasets without a wrist camera, `pad_mask_dict["image_wrist"]` is set to `False`. For convenience, if a key is missing from the observation dict, it is equivalent to setting `pad_mask_dict` to `False` for that key. 147 | #### Does `model.sample_actions([...])` return the full trajectory to solve a task? 148 | Octo was pretrained with an action chunking size of 4, meaning it predicts the next 4 actions at once. You can choose to execute all these actions before sampling new ones, or only execute the first action before sampling new ones (also known as receding horizon control). You can also do something more advanced like [temporal ensembling](octo/utils/gym_wrappers.py). 
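To make the last point concrete, here is a minimal sketch of handling the returned action chunk without the provided wrappers. It assumes `model`, `observation`, and `task` are prepared as in the inference example above, that `env` is your own Gym-style environment accepting single actions, and that the dataset statistics key matches your checkpoint.

```python
import jax
import numpy as np

actions = model.sample_actions(
    observation,
    task,
    # assumption: use the statistics matching the data your checkpoint was trained on
    unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
    rng=jax.random.PRNGKey(0),
)
actions = np.array(actions[0])  # remove batch dim -> [action_chunk, action_dim]

# Option 1: execute the full chunk before re-planning
for action in actions:
    obs, reward, done, trunc, info = env.step(action)

# Option 2: receding horizon control, i.e. execute only the first action,
# then call sample_actions again on the new observation:
# obs, reward, done, trunc, info = env.step(actions[0])
```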
149 | 150 | ## Updates for Version 1.5 151 | - Improved cross-attention between visual and language tokens by repeating language tokens at every timestep in the context window. 152 | - Augmented the language instructions in the data with rephrasings from GPT-3.5. 153 | - Bug fixes: 154 | - Turned off dropout in the diffusion head due to incompatibility with layer norm. 155 | - Fixed an off-by-one error with the attention mask. 156 | - Fixed an issue where different image augmentations did not get fresh random seeds. 157 | 158 | ## Citation 159 | 160 | ``` 161 | @inproceedings{octo_2023, 162 | title={Octo: An Open-Source Generalist Robot Policy}, 163 | author = {{Octo Model Team} and Dibya Ghosh and Homer Walke and Karl Pertsch and Kevin Black and Oier Mees and Sudeep Dasari and Joey Hejna and Charles Xu and Jianlan Luo and Tobias Kreiman and {You Liang} Tan and Pannag Sanketi and Quan Vuong and Ted Xiao and Dorsa Sadigh and Chelsea Finn and Sergey Levine}, 164 | booktitle = {Proceedings of Robotics: Science and Systems}, 165 | address = {Delft, Netherlands}, 166 | year = {2024}, 167 | } 168 | ``` 169 | -------------------------------------------------------------------------------- /docs/assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/docs/assets/teaser.jpg -------------------------------------------------------------------------------- /examples/01_inference_pretrained.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "534daf7f-4b6b-4357-9a38-9117f72ce9b4", 7 | "metadata": {}, 8 | "source": [ 9 | "# Step 1: Minimal Octo Inference Example\n", 10 | "\n", 11 | "This notebook demonstrates how to load a pre-trained / finetuned Octo checkpoint, run inference on some images, and compare the outputs to the true actions.\n", 12 | "\n", 13 | "First, let's start with a minimal example!" 
14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "bae44461", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# run this block if you're using Colab\n", 24 | "\n", 25 | "# Download repo\n", 26 | "!git clone https://github.com/octo-models/octo.git\n", 27 | "%cd octo\n", 28 | "# Install repo\n", 29 | "!pip3 install -e .\n", 30 | "!pip3 install -r requirements.txt\n", 31 | "!pip3 install --upgrade \"jax[cuda11_pip]==0.4.20\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", 32 | "!pip install numpy==1.21.1 # to fix colab AttributeError: module 'numpy' has no attribute '_no_nep50_warning', if the error still shows reload" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "7229ce10", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import os\n", 43 | "os.environ['TOKENIZERS_PARALLELISM'] = 'false'" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "83d34283", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "from octo.model.octo_model import OctoModel\n", 54 | "\n", 55 | "model = OctoModel.load_pretrained(\"hf://rail-berkeley/octo-small-1.5\")" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "15fca0dd", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from PIL import Image\n", 66 | "import requests\n", 67 | "import matplotlib.pyplot as plt\n", 68 | "import numpy as np\n", 69 | "# download one example BridgeV2 image\n", 70 | "IMAGE_URL = \"https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg\"\n", 71 | "img = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256)))\n", 72 | "plt.imshow(img)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "e669650f", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# create obs & task dict, run inference\n", 83 | "import jax\n", 84 | "# add batch + time horizon 1\n", 85 | "img = img[np.newaxis,np.newaxis,...]\n", 86 | "observation = {\"image_primary\": img, \"timestep_pad_mask\": np.array([[True]])}\n", 87 | "task = model.create_tasks(texts=[\"pick up the fork\"])\n", 88 | "action = model.sample_actions(\n", 89 | " observation, \n", 90 | " task, \n", 91 | " unnormalization_statistics=model.dataset_statistics[\"bridge_dataset\"][\"action\"], \n", 92 | " rng=jax.random.PRNGKey(0)\n", 93 | ")\n", 94 | "print(action) # [batch, action_chunk, action_dim]" 95 | ] 96 | }, 97 | { 98 | "attachments": {}, 99 | "cell_type": "markdown", 100 | "id": "b2be0d1f", 101 | "metadata": {}, 102 | "source": [ 103 | "# Step 2: Run Inference on Full Trajectories\n", 104 | "\n", 105 | "That was easy! Now let's try to run inference across a whole trajectory and visualize the results!" 
106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "a51eb166", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# Install mediapy for visualization\n", 116 | "!pip install mediapy\n", 117 | "!pip install opencv-python" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "8b0f7fd1-5b43-480f-b00f-766248d7f9af", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "import cv2\n", 128 | "import jax\n", 129 | "import tensorflow_datasets as tfds\n", 130 | "import tqdm\n", 131 | "import mediapy\n", 132 | "import numpy as np" 133 | ] 134 | }, 135 | { 136 | "attachments": {}, 137 | "cell_type": "markdown", 138 | "id": "b79053f4-316f-4d2d-81bd-e6e04cfa81bf", 139 | "metadata": {}, 140 | "source": [ 141 | "## Load Model Checkpoint\n", 142 | "First, we will load the pre-trained checkpoint using the `load_pretrained()` function. You can specify the path to a checkpoint directory or a HuggingFace path.\n", 143 | "\n", 144 | "Below, we are loading directly from HuggingFace.\n" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "42c04953-869d-48a8-a2df-e601324e97e6", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "from octo.model.octo_model import OctoModel\n", 155 | "\n", 156 | "model = OctoModel.load_pretrained(\"hf://rail-berkeley/octo-small-1.5\")" 157 | ] 158 | }, 159 | { 160 | "attachments": {}, 161 | "cell_type": "markdown", 162 | "id": "c298ac8f-da06-41d5-a4a5-145c3080231e", 163 | "metadata": {}, 164 | "source": [ 165 | "## Load Data\n", 166 | "Next, we will load a trajectory from the Bridge dataset for testing the model. We will use the publicly available copy in the Open X-Embodiment dataset bucket." 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "392bd127", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "# create RLDS dataset builder\n", 177 | "builder = tfds.builder_from_directory(builder_dir='gs://gresearch/robotics/bridge/0.1.0/')\n", 178 | "ds = builder.as_dataset(split='train[:1]')\n", 179 | "\n", 180 | "# sample episode + resize to 256x256 (default third-person cam resolution)\n", 181 | "episode = next(iter(ds))\n", 182 | "steps = list(episode['steps'])\n", 183 | "images = [cv2.resize(np.array(step['observation']['image']), (256, 256)) for step in steps]\n", 184 | "\n", 185 | "# extract goal image & language instruction\n", 186 | "goal_image = images[-1]\n", 187 | "language_instruction = steps[0]['observation']['natural_language_instruction'].numpy().decode()\n", 188 | "\n", 189 | "# visualize episode\n", 190 | "print(f'Instruction: {language_instruction}')\n", 191 | "mediapy.show_video(images, fps=10)" 192 | ] 193 | }, 194 | { 195 | "attachments": {}, 196 | "cell_type": "markdown", 197 | "id": "b37ffca5", 198 | "metadata": {}, 199 | "source": [ 200 | "## Run Inference\n", 201 | "\n", 202 | "Next, we will run inference over the images in the episode using the loaded model. \n", 203 | "Below we demonstrate setups for both goal-conditioned and language-conditioned training.\n", 204 | "Note that we need to feed inputs of the correct temporal window size." 
205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "9ad64434", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "WINDOW_SIZE = 2\n", 215 | "\n", 216 | "# create `task` dict\n", 217 | "task = model.create_tasks(goals={\"image_primary\": goal_image[None]}) # for goal-conditioned\n", 218 | "task = model.create_tasks(texts=[language_instruction]) # for language conditioned" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "id": "74d6b20f", 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "# run inference loop, this model only uses 3rd person image observations for bridge\n", 229 | "# collect predicted and true actions\n", 230 | "pred_actions, true_actions = [], []\n", 231 | "for step in tqdm.trange(len(images) - (WINDOW_SIZE - 1)):\n", 232 | " input_images = np.stack(images[step:step+WINDOW_SIZE])[None]\n", 233 | " observation = {\n", 234 | " 'image_primary': input_images,\n", 235 | " 'timestep_pad_mask': np.full((1, input_images.shape[1]), True, dtype=bool)\n", 236 | " }\n", 237 | " \n", 238 | " # this returns *normalized* actions --> we need to unnormalize using the dataset statistics\n", 239 | " actions = model.sample_actions(\n", 240 | " observation, \n", 241 | " task, \n", 242 | " unnormalization_statistics=model.dataset_statistics[\"bridge_dataset\"][\"action\"], \n", 243 | " rng=jax.random.PRNGKey(0)\n", 244 | " )\n", 245 | " actions = actions[0] # remove batch dim\n", 246 | "\n", 247 | " pred_actions.append(actions)\n", 248 | " final_window_step = step + WINDOW_SIZE - 1\n", 249 | " true_actions.append(np.concatenate(\n", 250 | " (\n", 251 | " steps[final_window_step]['action']['world_vector'], \n", 252 | " steps[final_window_step]['action']['rotation_delta'], \n", 253 | " np.array(steps[final_window_step]['action']['open_gripper']).astype(np.float32)[None]\n", 254 | " ), axis=-1\n", 255 | " ))" 256 | ] 257 | }, 258 | { 259 | "attachments": {}, 260 | "cell_type": "markdown", 261 | "id": "12a5e3f7", 262 | "metadata": {}, 263 | "source": [ 264 | "## Visualize predictions and ground-truth actions\n", 265 | "\n", 266 | "Finally, we will visualize the predicted actions in comparison to the groundtruth actions." 
267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "id": "7a79775d", 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "import matplotlib.pyplot as plt\n", 277 | "\n", 278 | "ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']\n", 279 | "\n", 280 | "# build image strip to show above actions\n", 281 | "img_strip = np.concatenate(np.array(images[::3]), axis=1)\n", 282 | "\n", 283 | "# set up plt figure\n", 284 | "figure_layout = [\n", 285 | " ['image'] * len(ACTION_DIM_LABELS),\n", 286 | " ACTION_DIM_LABELS\n", 287 | "]\n", 288 | "plt.rcParams.update({'font.size': 12})\n", 289 | "fig, axs = plt.subplot_mosaic(figure_layout)\n", 290 | "fig.set_size_inches([45, 10])\n", 291 | "\n", 292 | "# plot actions\n", 293 | "pred_actions = np.array(pred_actions).squeeze()\n", 294 | "true_actions = np.array(true_actions).squeeze()\n", 295 | "for action_dim, action_label in enumerate(ACTION_DIM_LABELS):\n", 296 | " # actions have batch, horizon, dim, in this example we just take the first action for simplicity\n", 297 | " axs[action_label].plot(pred_actions[:, 0, action_dim], label='predicted action')\n", 298 | " axs[action_label].plot(true_actions[:, action_dim], label='ground truth')\n", 299 | " axs[action_label].set_title(action_label)\n", 300 | " axs[action_label].set_xlabel('Time in one episode')\n", 301 | "\n", 302 | "axs['image'].imshow(img_strip)\n", 303 | "axs['image'].set_xlabel('Time in one episode (subsampled)')\n", 304 | "plt.legend()" 305 | ] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "Python 3 (ipykernel)", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3" 324 | } 325 | }, 326 | "nbformat": 4, 327 | "nbformat_minor": 5 328 | } 329 | -------------------------------------------------------------------------------- /examples/02_finetune_new_observation_action.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script demonstrates how to finetune Octo to a new observation space (single camera + proprio) 3 | and new action space (bimanual) using a simulated ALOHA cube handover dataset (https://tonyzhaozh.github.io/aloha/). 4 | 5 | To run this example, first download and extract the dataset from here: https://rail.eecs.berkeley.edu/datasets/example_sim_data.zip 6 | 7 | python examples/02_finetune_new_observation_action.py --pretrained_path=hf://rail-berkeley/octo-small-1.5 --data_dir=... 8 | """ 9 | from absl import app, flags, logging 10 | import flax 11 | import jax 12 | import optax 13 | import tensorflow as tf 14 | import tqdm 15 | import wandb 16 | 17 | from octo.data.dataset import make_single_dataset 18 | from octo.model.components.action_heads import L1ActionHead 19 | from octo.model.components.tokenizers import LowdimObsTokenizer 20 | from octo.model.octo_model import OctoModel 21 | from octo.utils.jax_utils import initialize_compilation_cache 22 | from octo.utils.spec import ModuleSpec 23 | from octo.utils.train_utils import ( 24 | freeze_weights, 25 | merge_params, 26 | process_text, 27 | TrainState, 28 | ) 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | flags.DEFINE_string( 33 | "pretrained_path", None, "Path to pre-trained Octo checkpoint directory." 
34 | ) 35 | flags.DEFINE_string("data_dir", None, "Path to finetuning dataset, in RLDS format.") 36 | flags.DEFINE_string("save_dir", None, "Directory for saving finetuning checkpoints.") 37 | flags.DEFINE_integer("batch_size", 128, "Batch size for finetuning.") 38 | 39 | flags.DEFINE_bool( 40 | "freeze_transformer", 41 | False, 42 | "Whether pre-trained transformer weights should be frozen.", 43 | ) 44 | 45 | 46 | def main(_): 47 | assert ( 48 | FLAGS.batch_size % jax.device_count() == 0 49 | ), "Batch size must be divisible by device count." 50 | 51 | initialize_compilation_cache() 52 | # prevent tensorflow from using GPU memory since it's only used for data loading 53 | tf.config.set_visible_devices([], "GPU") 54 | 55 | # setup wandb for logging 56 | wandb.init(name="finetune_aloha", project="octo") 57 | 58 | # load pre-trained model 59 | logging.info("Loading pre-trained model...") 60 | pretrained_model = OctoModel.load_pretrained(FLAGS.pretrained_path) 61 | 62 | # make finetuning dataset 63 | # apply Gaussian normalization, load chunks of 50 actions since we'll train with action chunking 64 | # delete goal images in the data loader since we will train a language-conditioned-only policy 65 | # TODO: directly load this from raw data to make it less opaque? 66 | logging.info("Loading finetuning dataset...") 67 | dataset = make_single_dataset( 68 | dataset_kwargs=dict( 69 | name="aloha_sim_cube_scripted_dataset", 70 | data_dir=FLAGS.data_dir, 71 | image_obs_keys={"primary": "top"}, 72 | proprio_obs_key="state", 73 | language_key="language_instruction", 74 | ), 75 | traj_transform_kwargs=dict( 76 | window_size=1, 77 | action_horizon=50, 78 | ), 79 | frame_transform_kwargs=dict( 80 | resize_size={"primary": (256, 256)}, 81 | ), 82 | train=True, 83 | ) 84 | train_data_iter = ( 85 | dataset.repeat() 86 | .unbatch() 87 | .shuffle(10000) # can reduce this if RAM consumption too high 88 | .batch(FLAGS.batch_size) 89 | .iterator() 90 | ) 91 | 92 | # run text tokenizer over batch (this needs to happen before training / sharding) + delete unused keys 93 | text_processor = pretrained_model.text_processor 94 | 95 | def process_batch(batch): 96 | batch = process_text(batch, text_processor) 97 | del batch["dataset_name"] 98 | return batch 99 | 100 | train_data_iter = map(process_batch, train_data_iter) 101 | example_batch = next(train_data_iter) 102 | 103 | # load pre-training config and modify --> remove wrist cam, add proprio input, change action head 104 | # following Zhao et al. 
we use "action chunks" of length 50 and L1 loss for ALOHA 105 | config = pretrained_model.config 106 | del config["model"]["observation_tokenizers"]["wrist"] 107 | ### 108 | config["model"]["observation_tokenizers"]["proprio"] = ModuleSpec.create( 109 | LowdimObsTokenizer, 110 | n_bins=256, 111 | bin_type="normal", 112 | low=-2.0, 113 | high=2.0, 114 | obs_keys=["proprio"], 115 | ) 116 | # Fully override the old action head with a new one (for smaller changes, you can use update_config) 117 | config["model"]["heads"]["action"] = ModuleSpec.create( 118 | L1ActionHead, 119 | action_horizon=50, 120 | action_dim=14, 121 | readout_key="readout_action", 122 | ) 123 | 124 | # initialize weights for modified Octo model, then merge in all applicable pre-trained weights 125 | # new position encodings for proprio inputs & weights for new action head will remain "from scratch" 126 | logging.info("Updating model for new observation & action space...") 127 | model = OctoModel.from_config( 128 | config, 129 | example_batch, 130 | text_processor, 131 | verbose=True, 132 | dataset_statistics=dataset.dataset_statistics, 133 | ) 134 | merged_params = merge_params(model.params, pretrained_model.params) 135 | # can perform any additional parameter surgery here... 136 | # ... 137 | model = model.replace(params=merged_params) 138 | del pretrained_model 139 | 140 | # create optimizer & train_state, optionally freeze keys for pre-trained transformer 141 | # train_state bundles parameters & optimizers 142 | learning_rate = optax.join_schedules( 143 | [optax.linear_schedule(0, 3e-5, 100), optax.constant_schedule(3e-5)], [100] 144 | ) 145 | tx = optax.adamw(learning_rate) 146 | frozen_keys = model.config["optimizer"]["frozen_keys"] 147 | if FLAGS.freeze_transformer: 148 | frozen_keys.append("BlockTransformer_0") 149 | tx = freeze_weights(tx, model.params, frozen_keys) 150 | train_state = TrainState.create( 151 | rng=jax.random.PRNGKey(1234), 152 | model=model, 153 | tx=tx, 154 | ) 155 | 156 | # define loss function and train step 157 | def loss_fn(params, batch, rng, train=True): 158 | bound_module = model.module.bind({"params": params}, rngs={"dropout": rng}) 159 | transformer_embeddings = bound_module.octo_transformer( 160 | batch["observation"], 161 | batch["task"], 162 | batch["observation"]["timestep_pad_mask"], 163 | train=train, 164 | ) 165 | action_loss, action_metrics = bound_module.heads["action"].loss( 166 | transformer_embeddings, # Action head knows to pull out the action readout_key 167 | batch["action"], 168 | batch["observation"]["timestep_pad_mask"], 169 | batch["action_pad_mask"], 170 | train=train, 171 | ) 172 | return action_loss, action_metrics 173 | 174 | @jax.jit 175 | def train_step(state, batch): 176 | rng, dropout_rng = jax.random.split(state.rng) 177 | (loss, info), grads = jax.value_and_grad(loss_fn, has_aux=True)( 178 | state.model.params, batch, dropout_rng, train=True 179 | ) 180 | new_state = state.apply_gradients(grads=grads, rng=rng) 181 | return new_state, info 182 | 183 | # run finetuning loop 184 | logging.info("Starting finetuning...") 185 | for i in tqdm.tqdm(range(5000), total=5000, dynamic_ncols=True): 186 | batch = next(train_data_iter) 187 | train_state, update_info = train_step(train_state, batch) 188 | if (i + 1) % 100 == 0: 189 | update_info = jax.device_get(update_info) 190 | wandb.log( 191 | flax.traverse_util.flatten_dict({"training": update_info}, sep="/"), 192 | step=i, 193 | ) 194 | if (i + 1) % 1000 == 0: 195 | # save checkpoint 196 | 
train_state.model.save_pretrained(step=i, checkpoint_path=FLAGS.save_dir) 197 | 198 | 199 | if __name__ == "__main__": 200 | app.run(main) 201 | -------------------------------------------------------------------------------- /examples/03_eval_finetuned.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script demonstrates how to load and rollout a finetuned Octo model. 3 | We use the Octo model finetuned on ALOHA sim data from the examples/02_finetune_new_observation_action.py script. 4 | 5 | For installing the ALOHA sim environment, clone: https://github.com/tonyzhaozh/act 6 | Then run: 7 | pip3 install opencv-python modern_robotics pyrealsense2 h5py_cache pyquaternion pyyaml rospkg pexpect mujoco==2.3.3 dm_control==1.0.9 einops packaging h5py 8 | 9 | Finally, modify the `sys.path.append` statement below to add the ACT repo to your path. 10 | If you are running this on a head-less server, start a virtual display: 11 | Xvfb :1 -screen 0 1024x768x16 & 12 | export DISPLAY=:1 13 | 14 | To run this script, run: 15 | cd examples 16 | python3 03_eval_finetuned.py --finetuned_path= 17 | """ 18 | from functools import partial 19 | import sys 20 | 21 | from absl import app, flags, logging 22 | import gym 23 | import jax 24 | import numpy as np 25 | import wandb 26 | 27 | sys.path.append("path/to/your/act") 28 | 29 | # keep this to register ALOHA sim env 30 | from envs.aloha_sim_env import AlohaGymEnv # noqa 31 | 32 | from octo.model.octo_model import OctoModel 33 | from octo.utils.gym_wrappers import HistoryWrapper, NormalizeProprio, RHCWrapper 34 | from octo.utils.train_callbacks import supply_rng 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | flags.DEFINE_string( 39 | "finetuned_path", None, "Path to finetuned Octo checkpoint directory." 40 | ) 41 | 42 | 43 | def main(_): 44 | # setup wandb for logging 45 | wandb.init(name="eval_aloha", project="octo") 46 | 47 | # load finetuned model 48 | logging.info("Loading finetuned model...") 49 | model = OctoModel.load_pretrained(FLAGS.finetuned_path) 50 | 51 | # make gym environment 52 | ################################################################################################################## 53 | # environment needs to implement standard gym interface + return observations of the following form: 54 | # obs = { 55 | # "image_primary": ... 56 | # } 57 | # it should also implement an env.get_task() function that returns a task dict with goal and/or language instruct. 58 | # task = { 59 | # "language_instruction": "some string" 60 | # "goal": { 61 | # "image_primary": ... 62 | # } 63 | # } 64 | ################################################################################################################## 65 | env = gym.make("aloha-sim-cube-v0") 66 | 67 | # wrap env to normalize proprio 68 | env = NormalizeProprio(env, model.dataset_statistics) 69 | 70 | # add wrappers for history and "receding horizon control", i.e. 
action chunking 71 | env = HistoryWrapper(env, horizon=1) 72 | env = RHCWrapper(env, exec_horizon=50) 73 | 74 | # the supply_rng wrapper supplies a new random key to sample_actions every time it's called 75 | policy_fn = supply_rng( 76 | partial( 77 | model.sample_actions, 78 | unnormalization_statistics=model.dataset_statistics["action"], 79 | ), 80 | ) 81 | 82 | # running rollouts 83 | for _ in range(3): 84 | obs, info = env.reset() 85 | 86 | # create task specification --> use model utility to create task dict with correct entries 87 | language_instruction = env.get_task()["language_instruction"] 88 | task = model.create_tasks(texts=language_instruction) 89 | 90 | # run rollout for 400 steps 91 | images = [obs["image_primary"][0]] 92 | episode_return = 0.0 93 | while len(images) < 400: 94 | # model returns actions of shape [batch, pred_horizon, action_dim] -- remove batch 95 | actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task) 96 | actions = actions[0] 97 | 98 | # step env -- info contains full "chunk" of observations for logging 99 | # obs only contains observation for final step of chunk 100 | obs, reward, done, trunc, info = env.step(actions) 101 | images.extend([o["image_primary"][0] for o in info["observations"]]) 102 | episode_return += reward 103 | if done or trunc: 104 | break 105 | print(f"Episode return: {episode_return}") 106 | 107 | # log rollout video to wandb -- subsample temporally 2x for faster logging 108 | wandb.log( 109 | {"rollout_video": wandb.Video(np.array(images).transpose(0, 3, 1, 2)[::2])} 110 | ) 111 | 112 | 113 | if __name__ == "__main__": 114 | app.run(main) 115 | -------------------------------------------------------------------------------- /examples/04_eval_finetuned_on_robot.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script shows how we evaluated a finetuned Octo model on a real WidowX robot. While the exact specifics may not 3 | be applicable to your use case, this script serves as a didactic example of how to use Octo in a real-world setting. 
4 | 5 | If you wish, you may reproduce these results by [reproducing the robot setup](https://rail-berkeley.github.io/bridgedata/) 6 | and installing [the robot controller](https://github.com/rail-berkeley/bridge_data_robot) 7 | """ 8 | 9 | from datetime import datetime 10 | from functools import partial 11 | import os 12 | import time 13 | 14 | from absl import app, flags, logging 15 | import click 16 | import cv2 17 | from envs.widowx_env import convert_obs, state_to_eep, wait_for_obs, WidowXGym 18 | import imageio 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus 23 | 24 | from octo.model.octo_model import OctoModel 25 | from octo.utils.gym_wrappers import HistoryWrapper, TemporalEnsembleWrapper 26 | from octo.utils.train_callbacks import supply_rng 27 | 28 | np.set_printoptions(suppress=True) 29 | 30 | logging.set_verbosity(logging.WARNING) 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string( 35 | "checkpoint_weights_path", None, "Path to checkpoint", required=True 36 | ) 37 | flags.DEFINE_integer("checkpoint_step", None, "Checkpoint step", required=True) 38 | 39 | # custom to bridge_data_robot 40 | flags.DEFINE_string("ip", "localhost", "IP address of the robot") 41 | flags.DEFINE_integer("port", 5556, "Port of the robot") 42 | flags.DEFINE_spaceseplist("goal_eep", [0.3, 0.0, 0.15], "Goal position") 43 | flags.DEFINE_spaceseplist("initial_eep", [0.3, 0.0, 0.15], "Initial position") 44 | flags.DEFINE_bool("blocking", False, "Use the blocking controller") 45 | 46 | 47 | flags.DEFINE_integer("im_size", None, "Image size", required=True) 48 | flags.DEFINE_string("video_save_path", None, "Path to save video") 49 | flags.DEFINE_integer("num_timesteps", 120, "num timesteps") 50 | flags.DEFINE_integer("window_size", 2, "Observation history length") 51 | flags.DEFINE_integer( 52 | "action_horizon", 4, "Length of action sequence to execute/ensemble" 53 | ) 54 | 55 | 56 | # show image flag 57 | flags.DEFINE_bool("show_image", False, "Show image") 58 | 59 | ############################################################################## 60 | 61 | STEP_DURATION_MESSAGE = """ 62 | Bridge data was collected with non-blocking control and a step duration of 0.2s. 63 | However, we relabel the actions to make it look like the data was collected with 64 | blocking control and we evaluate with blocking control. 65 | Be sure to use a step duration of 0.2 if evaluating with non-blocking control. 
66 | """ 67 | STEP_DURATION = 0.2 68 | STICKY_GRIPPER_NUM_STEPS = 1 69 | WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]] 70 | CAMERA_TOPICS = [{"name": "/blue/image_raw"}] 71 | ENV_PARAMS = { 72 | "camera_topics": CAMERA_TOPICS, 73 | "override_workspace_boundaries": WORKSPACE_BOUNDS, 74 | "move_duration": STEP_DURATION, 75 | } 76 | 77 | ############################################################################## 78 | 79 | 80 | def main(_): 81 | # set up the widowx client 82 | if FLAGS.initial_eep is not None: 83 | assert isinstance(FLAGS.initial_eep, list) 84 | initial_eep = [float(e) for e in FLAGS.initial_eep] 85 | start_state = np.concatenate([initial_eep, [0, 0, 0, 1]]) 86 | else: 87 | start_state = None 88 | 89 | env_params = WidowXConfigs.DefaultEnvParams.copy() 90 | env_params.update(ENV_PARAMS) 91 | env_params["start_state"] = list(start_state) 92 | widowx_client = WidowXClient(host=FLAGS.ip, port=FLAGS.port) 93 | widowx_client.init(env_params, image_size=FLAGS.im_size) 94 | env = WidowXGym( 95 | widowx_client, FLAGS.im_size, FLAGS.blocking, STICKY_GRIPPER_NUM_STEPS 96 | ) 97 | if not FLAGS.blocking: 98 | assert STEP_DURATION == 0.2, STEP_DURATION_MESSAGE 99 | 100 | # load models 101 | model = OctoModel.load_pretrained( 102 | FLAGS.checkpoint_weights_path, 103 | FLAGS.checkpoint_step, 104 | ) 105 | 106 | # wrap the robot environment 107 | env = HistoryWrapper(env, FLAGS.window_size) 108 | env = TemporalEnsembleWrapper(env, FLAGS.action_horizon) 109 | # switch TemporalEnsembleWrapper with RHCWrapper for receding horizon control 110 | # env = RHCWrapper(env, FLAGS.action_horizon) 111 | 112 | # create policy functions 113 | def sample_actions( 114 | pretrained_model: OctoModel, 115 | observations, 116 | tasks, 117 | rng, 118 | ): 119 | # add batch dim to observations 120 | observations = jax.tree_map(lambda x: x[None], observations) 121 | actions = pretrained_model.sample_actions( 122 | observations, 123 | tasks, 124 | rng=rng, 125 | unnormalization_statistics=pretrained_model.dataset_statistics[ 126 | "bridge_dataset" 127 | ]["action"], 128 | ) 129 | # remove batch dim 130 | return actions[0] 131 | 132 | policy_fn = supply_rng( 133 | partial( 134 | sample_actions, 135 | model, 136 | argmax=FLAGS.deterministic, 137 | temperature=FLAGS.temperature, 138 | ) 139 | ) 140 | 141 | goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) 142 | goal_instruction = "" 143 | 144 | # goal sampling loop 145 | while True: 146 | modality = click.prompt( 147 | "Language or goal image?", type=click.Choice(["l", "g"]) 148 | ) 149 | 150 | if modality == "g": 151 | if click.confirm("Take a new goal?", default=True): 152 | assert isinstance(FLAGS.goal_eep, list) 153 | _eep = [float(e) for e in FLAGS.goal_eep] 154 | goal_eep = state_to_eep(_eep, 0) 155 | widowx_client.move_gripper(1.0) # open gripper 156 | 157 | move_status = None 158 | while move_status != WidowXStatus.SUCCESS: 159 | move_status = widowx_client.move(goal_eep, duration=1.5) 160 | 161 | input("Press [Enter] when ready for taking the goal image. 
") 162 | obs = wait_for_obs(widowx_client) 163 | obs = convert_obs(obs, FLAGS.im_size) 164 | goal = jax.tree_map(lambda x: x[None], obs) 165 | 166 | # Format task for the model 167 | task = model.create_tasks(goals=goal) 168 | # For logging purposes 169 | goal_image = goal["image_primary"][0] 170 | goal_instruction = "" 171 | 172 | elif modality == "l": 173 | print("Current instruction: ", goal_instruction) 174 | if click.confirm("Take a new instruction?", default=True): 175 | text = input("Instruction?") 176 | # Format task for the model 177 | task = model.create_tasks(texts=[text]) 178 | # For logging purposes 179 | goal_instruction = text 180 | goal_image = jnp.zeros_like(goal_image) 181 | else: 182 | raise NotImplementedError() 183 | 184 | input("Press [Enter] to start.") 185 | 186 | # reset env 187 | obs, _ = env.reset() 188 | time.sleep(2.0) 189 | 190 | # do rollout 191 | last_tstep = time.time() 192 | images = [] 193 | goals = [] 194 | t = 0 195 | while t < FLAGS.num_timesteps: 196 | if time.time() > last_tstep + STEP_DURATION: 197 | last_tstep = time.time() 198 | 199 | # save images 200 | images.append(obs["image_primary"][-1]) 201 | goals.append(goal_image) 202 | 203 | if FLAGS.show_image: 204 | bgr_img = cv2.cvtColor(obs["image_primary"][-1], cv2.COLOR_RGB2BGR) 205 | cv2.imshow("img_view", bgr_img) 206 | cv2.waitKey(20) 207 | 208 | # get action 209 | forward_pass_time = time.time() 210 | action = np.array(policy_fn(obs, task), dtype=np.float64) 211 | print("forward pass time: ", time.time() - forward_pass_time) 212 | 213 | # perform environment step 214 | start_time = time.time() 215 | obs, _, _, truncated, _ = env.step(action) 216 | print("step time: ", time.time() - start_time) 217 | 218 | t += 1 219 | 220 | if truncated: 221 | break 222 | 223 | # save video 224 | if FLAGS.video_save_path is not None: 225 | os.makedirs(FLAGS.video_save_path, exist_ok=True) 226 | curr_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 227 | save_path = os.path.join( 228 | FLAGS.video_save_path, 229 | f"{curr_time}.mp4", 230 | ) 231 | video = np.concatenate([np.stack(goals), np.stack(images)], axis=1) 232 | imageio.mimsave(save_path, video, fps=1.0 / STEP_DURATION * 3) 233 | 234 | 235 | if __name__ == "__main__": 236 | app.run(main) 237 | -------------------------------------------------------------------------------- /examples/05_dataloading.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Octo Dataloading Examples\n", 8 | "\n", 9 | "This notebook will walk you through some of the primary features of the Octo dataloader. Data is, after all, the most important part of any machine learning pipeline!" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Loading Open X-Embodiment Data\n", 17 | "\n", 18 | "The [Open X-Embodiment (OXE)](https://robotics-transformer-x.github.io/) project was a massive cross-instutition data collection collaboration the likes of which robot learning has never seen before. The resulting dataset includes 22 different robots demonstrating 527 skills and totals over 1 million trajectories. However, as we found throughout the course of the Octo project, simply loading such a diverse set of robot data is no small feat. 
We hope that the `octo.data` pipeline can help kickstart anyone who hopes to take advantage of the massive collection of robot demonstrations that is OXE!\n", 19 | "\n", 20 | "### Minimum working example to load a single OXE dataset" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# minimum working example to load a single OXE dataset\n", 30 | "from octo.data.oxe import make_oxe_dataset_kwargs\n", 31 | "from octo.data.dataset import make_single_dataset\n", 32 | "\n", 33 | "dataset_kwargs = make_oxe_dataset_kwargs(\n", 34 | " # see octo/data/oxe/oxe_dataset_configs.py for available datasets\n", 35 | " # (this is a very small one for faster loading)\n", 36 | " \"austin_buds_dataset_converted_externally_to_rlds\",\n", 37 | " # can be local or on cloud storage (anything supported by TFDS)\n", 38 | " # \"/path/to/base/oxe/directory\",\n", 39 | " \"gs://gresearch/robotics\",\n", 40 | ")\n", 41 | "dataset = make_single_dataset(dataset_kwargs, train=True) # load the train split\n", 42 | "iterator = dataset.iterator()" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# make_single_dataset yields entire trajectories\n", 52 | "traj = next(iterator)\n", 53 | "print(\"Top-level keys: \", traj.keys())\n", 54 | "print(\"Observation keys: \", traj[\"observation\"].keys())\n", 55 | "print(\"Task keys: \", traj[\"task\"].keys())" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "from PIL import Image\n", 65 | "import numpy as np\n", 66 | "\n", 67 | "traj = next(iterator)\n", 68 | "images = traj[\"observation\"][\"image_primary\"]\n", 69 | "# should be: (traj_len, window_size, height, width, channels)\n", 70 | "# (window_size defaults to 1)\n", 71 | "print(images.shape) \n", 72 | "Image.fromarray(np.concatenate(images.squeeze()[-5:], axis=1))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# you should set these much higher in practice (as large as your memory can hold!)\n", 82 | "SHUFFLE_BUFFER_SIZE = 1000\n", 83 | "BATCH_SIZE = 64\n", 84 | "\n", 85 | "# turning a dataset of trajectories into a training-ready batched dataset\n", 86 | "train_dataset = (\n", 87 | " dataset.flatten() # flattens trajectories into individual frames\n", 88 | " .shuffle(SHUFFLE_BUFFER_SIZE) # shuffles the frames\n", 89 | " .batch(BATCH_SIZE) # batches the frames\n", 90 | ")\n", 91 | "batch = next(train_dataset.iterator())\n", 92 | "images = batch[\"observation\"][\"image_primary\"]\n", 93 | "# should be: (batch_size, window_size, height, width, channels)\n", 94 | "print(images.shape)\n", 95 | "Image.fromarray(np.concatenate(images.squeeze()[:5], axis=1))" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "### Loading a training-ready OXE mix\n", 103 | "\n", 104 | "In reality, you're probably going to want to mix multiple datasets together, as well as use other transformations such as resizing, augmentation, windowing, etc. This section will show you how to get a proper OXE mix up and running, as well as demonstrate additional `octo.data` features for more realistic use-cases." 
105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "from octo.data.oxe import make_oxe_dataset_kwargs_and_weights\n", 114 | "from octo.data.dataset import make_interleaved_dataset\n", 115 | "\n", 116 | "dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(\n", 117 | " # you can pass your own list of dataset names and sample weights here, but we've\n", 118 | " # also provided a few named mixes for convenience. The Octo model was trained\n", 119 | " # using the \"oxe_magic_soup\" mix.\n", 120 | " \"rtx\",\n", 121 | " # can be local or on cloud storage (anything supported by TFDS)\n", 122 | " \"gs://gresearch/robotics\",\n", 123 | " # let's get a wrist camera!\n", 124 | " load_camera_views=(\"primary\", \"wrist\"),\n", 125 | ")\n", 126 | "\n", 127 | "# see `octo.data.dataset.make_dataset_from_rlds` for the meaning of these kwargs\n", 128 | "dataset_kwargs_list[0]" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "SHUFFLE_BUFFER_SIZE = 1000\n", 138 | "BATCH_SIZE = 8\n", 139 | "\n", 140 | "# each element of `dataset_kwargs_list` can be used with `make_single_dataset`, but let's\n", 141 | "# use the more powerful `make_interleaved_dataset` to combine them for us!\n", 142 | "dataset = make_interleaved_dataset(\n", 143 | " dataset_kwargs_list,\n", 144 | " sample_weights,\n", 145 | " train=True,\n", 146 | " # unlike our manual shuffling above, `make_interleaved_dataset` will shuffle\n", 147 | " # the JPEG-encoded images, so you should be able to fit a much larger buffer size\n", 148 | " shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,\n", 149 | " batch_size=BATCH_SIZE,\n", 150 | " # see `octo.data.dataset.apply_trajectory_transforms` for full documentation\n", 151 | " # of these configuration options\n", 152 | " traj_transform_kwargs=dict(\n", 153 | " goal_relabeling_strategy=\"uniform\", # let's get some goal images\n", 154 | " window_size=2, # let's get some history\n", 155 | " action_horizon=4, # let's get some future actions for action chunking\n", 156 | " subsample_length=100, # subsampling long trajectories improves shuffling a lot\n", 157 | " ),\n", 158 | " # see `octo.data.dataset.apply_frame_transforms` for full documentation\n", 159 | " # of these configuration options\n", 160 | " frame_transform_kwargs=dict(\n", 161 | " # let's apply some basic image augmentations -- see `dlimp.transforms.augment_image`\n", 162 | " # for full documentation of these configuration options\n", 163 | " image_augment_kwargs=dict(\n", 164 | " primary=dict(\n", 165 | " augment_order=[\"random_resized_crop\", \"random_brightness\"],\n", 166 | " random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),\n", 167 | " random_brightness=[0.1],\n", 168 | " )\n", 169 | " ),\n", 170 | " # providing a `resize_size` is highly recommended for a mixed dataset, otherwise\n", 171 | " # datasets with different resolutions will cause errors\n", 172 | " resize_size=dict(\n", 173 | " primary=(256, 256),\n", 174 | " wrist=(128, 128),\n", 175 | " ),\n", 176 | " # If parallelism options are not provided, they will default to tf.Data.AUTOTUNE.\n", 177 | " # However, we would highly recommend setting them manually if you run into issues\n", 178 | " # with memory or dataloading speed. 
Frame transforms are usually the speed\n", 179 | " # bottleneck (due to image decoding, augmentation, and resizing), so you can set\n", 180 | " # this to a very high value if you have a lot of CPU cores. Keep in mind that more\n", 181 | " # parallel calls also use more memory, though.\n", 182 | " num_parallel_calls=64,\n", 183 | " ),\n", 184 | " # Same spiel as above about performance, although trajectory transforms and data reading\n", 185 | " # are usually not the speed bottleneck. One reason to manually set these is if you want\n", 186 | " # to reduce memory usage (since autotune may spawn way more threads than necessary).\n", 187 | " traj_transform_threads=16,\n", 188 | " traj_read_threads=16,\n", 189 | ")\n", 190 | "\n", 191 | "# Another performance knob to tune is the number of batches to prefetch -- again,\n", 192 | "# the default of tf.data.AUTOTUNE can sometimes use more memory than necessary.\n", 193 | "iterator = dataset.iterator(prefetch=1)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "# phew, that was a lot of configuration! Let's see what we got.\n", 203 | "batch = next(iterator)\n", 204 | "print(\"Top-level keys: \", batch.keys())\n", 205 | "# should now have \"image_primary\" and \"image_wrist\"!\n", 206 | "print(\"Observation keys: \", batch[\"observation\"].keys())\n", 207 | "# should also have \"image_primary\" and \"image_wrist\", corresponding to future goal images\n", 208 | "print(\"Task keys: \", batch[\"task\"].keys())" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "from PIL import Image\n", 218 | "import numpy as np\n", 219 | "\n", 220 | "images_primary = batch[\"observation\"][\"image_primary\"]\n", 221 | "images_wrist = batch[\"observation\"][\"image_wrist\"]\n", 222 | "# should be: (batch_size, window_size (now 2), height, width, channels)\n", 223 | "print(images_primary.shape)\n", 224 | "print(images_wrist.shape)\n", 225 | "actions = batch[\"action\"]\n", 226 | "# should be: (batch_size, window_size, action_horizon, action_dim)\n", 227 | "print(actions.shape)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "# let's visualize a window of primary images\n", 237 | "display(Image.fromarray(np.concatenate(images_primary[0], axis=1)))\n", 238 | "# now a window of wrist images -- many datasets don't have wrist images,\n", 239 | "# so this will often be black\n", 240 | "display(Image.fromarray(np.concatenate(images_wrist[0], axis=1)))\n", 241 | "# pad_mask_dict also tells you which keys should be treated as padding\n", 242 | "# (e.g., if the wrist camera is black, the corresponding pad_mask_dict entry is False)\n", 243 | "print(batch[\"observation\"][\"pad_mask_dict\"][\"image_wrist\"][0])" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "# let's take a look at the \"task\" dict: it should now have both goal\n", 253 | "# images and language instructions!\n", 254 | "goal_primary = batch[\"task\"][\"image_primary\"]\n", 255 | "goal_wrist = batch[\"task\"][\"image_wrist\"]\n", 256 | "language_instruction = batch[\"task\"][\"language_instruction\"]\n", 257 | "display(Image.fromarray(goal_primary[0]))\n", 258 | "display(Image.fromarray(goal_wrist[0]))\n", 259 | 
"print(language_instruction[0])" 260 | ] 261 | } 262 | ], 263 | "metadata": { 264 | "kernelspec": { 265 | "display_name": "octo-2", 266 | "language": "python", 267 | "name": "python3" 268 | }, 269 | "language_info": { 270 | "codemirror_mode": { 271 | "name": "ipython", 272 | "version": 3 273 | }, 274 | "file_extension": ".py", 275 | "mimetype": "text/x-python", 276 | "name": "python", 277 | "nbconvert_exporter": "python", 278 | "pygments_lexer": "ipython3" 279 | } 280 | }, 281 | "nbformat": 4, 282 | "nbformat_minor": 2 283 | } 284 | -------------------------------------------------------------------------------- /examples/06_pytorch_oxe_dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example shows how to use the `octo.data` dataloader with PyTorch by wrapping it in a simple PyTorch 3 | dataloader. The config below also happens to be our exact pretraining config (except for the batch size and 4 | shuffle buffer size, which are reduced for demonstration purposes). 5 | """ 6 | import numpy as np 7 | import tensorflow as tf 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import tqdm 11 | 12 | from octo.data.dataset import make_interleaved_dataset 13 | from octo.data.oxe import make_oxe_dataset_kwargs_and_weights 14 | 15 | DATA_PATH = "gs://rail-orca-central2/resize_256_256" 16 | 17 | tf.config.set_visible_devices([], "GPU") 18 | 19 | 20 | class TorchRLDSDataset(torch.utils.data.IterableDataset): 21 | """Thin wrapper around RLDS dataset for use with PyTorch dataloaders.""" 22 | 23 | def __init__( 24 | self, 25 | rlds_dataset, 26 | train=True, 27 | ): 28 | self._rlds_dataset = rlds_dataset 29 | self._is_train = train 30 | 31 | def __iter__(self): 32 | for sample in self._rlds_dataset.as_numpy_iterator(): 33 | yield sample 34 | 35 | def __len__(self): 36 | lengths = np.array( 37 | [ 38 | stats["num_transitions"] 39 | for stats in self._rlds_dataset.dataset_statistics 40 | ] 41 | ) 42 | if hasattr(self._rlds_dataset, "sample_weights"): 43 | lengths *= np.array(self._rlds_dataset.sample_weights) 44 | total_len = lengths.sum() 45 | if self._is_train: 46 | return int(0.95 * total_len) 47 | else: 48 | return int(0.05 * total_len) 49 | 50 | 51 | dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights( 52 | "oxe_magic_soup", 53 | DATA_PATH, 54 | load_camera_views=("primary", "wrist"), 55 | ) 56 | 57 | dataset = make_interleaved_dataset( 58 | dataset_kwargs_list, 59 | sample_weights, 60 | train=True, 61 | shuffle_buffer_size=1000, # change to 500k for training, large shuffle buffers are important, but adjust to your RAM 62 | batch_size=None, # batching will be handles in PyTorch Dataloader object 63 | balance_weights=True, 64 | traj_transform_kwargs=dict( 65 | goal_relabeling_strategy="uniform", 66 | window_size=2, 67 | action_horizon=4, 68 | subsample_length=100, 69 | ), 70 | frame_transform_kwargs=dict( 71 | image_augment_kwargs={ 72 | "primary": dict( 73 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 74 | random_brightness=[0.1], 75 | random_contrast=[0.9, 1.1], 76 | random_saturation=[0.9, 1.1], 77 | random_hue=[0.05], 78 | augment_order=[ 79 | "random_resized_crop", 80 | "random_brightness", 81 | "random_contrast", 82 | "random_saturation", 83 | "random_hue", 84 | ], 85 | ), 86 | "wrist": dict( 87 | random_brightness=[0.1], 88 | random_contrast=[0.9, 1.1], 89 | random_saturation=[0.9, 1.1], 90 | random_hue=[0.05], 91 | augment_order=[ 92 | "random_brightness", 93 | "random_contrast", 
94 | "random_saturation", 95 | "random_hue", 96 | ], 97 | ), 98 | }, 99 | resize_size=dict( 100 | primary=(256, 256), 101 | wrist=(128, 128), 102 | ), 103 | num_parallel_calls=200, 104 | ), 105 | traj_transform_threads=48, 106 | traj_read_threads=48, 107 | ) 108 | 109 | 110 | pytorch_dataset = TorchRLDSDataset(dataset) 111 | dataloader = DataLoader( 112 | pytorch_dataset, 113 | batch_size=16, 114 | num_workers=0, # important to keep this to 0 so PyTorch does not mess with the parallelism 115 | ) 116 | 117 | for i, sample in tqdm.tqdm(enumerate(dataloader)): 118 | if i == 5000: 119 | break 120 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ## Examples 2 | 3 | We provide simple [example scripts](examples) that demonstrate how to inference and finetune OCTO models, 4 | as well as how to use our data loader independently. We provide the following examples: 5 | 6 | | | | 7 | |----------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------| 8 | | [OCTO Inference](examples/01_inference_pretrained.ipynb) | Minimal example for loading and inferencing a pre-trained OCTO model | 9 | | [OCTO Finetuning](examples/02_finetune_new_observation_action.py) | Minimal example for finetuning a pre-trained OCTO models on a small dataset with new observation + action space | 10 | | [OCTO Rollout](examples/03_eval_finetuned.py) | Run a rollout of a pre-trained OCTO policy in a Gym environment | 11 | | [OCTO Robot Eval](examples/04_eval_finetuned_on_robot.py) | Evaluate a pre-trained OCTO model on a real WidowX robot | 12 | | [OpenX Dataloader Intro](examples/05_dataloading.ipynb) | Walkthrough of the features of our Open X-Embodiment data loader | 13 | | [OpenX PyTorch Dataloader](examples/06_pytorch_oxe_dataloader.ipynb) | Standalone Open X-Embodiment data loader in PyTorch | 14 | -------------------------------------------------------------------------------- /examples/envs/README.md: -------------------------------------------------------------------------------- 1 | # Octo Evaluation Environments 2 | 3 | The `step` and `reset` functions of the Gym environment should return observations with the images, depth images, and/or 4 | proprioceptive information that the model expects as input. Specifically, the returned observations should be dictionaries 5 | of the form: 6 | ``` 7 | obs = { 8 | "image_primary": ..., 9 | "image_wrist": ..., 10 | ... 11 | "depth_primary": ..., 12 | "depth_wrist": ..., 13 | ... 14 | "proprio": ..., 15 | } 16 | ``` 17 | 18 | Note that the image keys should be `image_{key}` where `key` is one of the `image_obs_keys` specified in the data loading config used to train the model (typically this is `primary` and/or `wrist`). 19 | If a key is not present in the observation dictionary, the model will substitute it with padding. 20 | 21 | Check out the example environments in this folder to help you integrate your own environment! 
22 | -------------------------------------------------------------------------------- /examples/envs/aloha_sim_env.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List 3 | 4 | import dlimp as dl 5 | import gym 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | # need to put https://github.com/tonyzhaozh/act in your PATH for this import to work 10 | from sim_env import BOX_POSE, make_sim_env 11 | 12 | 13 | class AlohaGymEnv(gym.Env): 14 | def __init__( 15 | self, 16 | env: gym.Env, 17 | camera_names: List[str], 18 | im_size: int = 256, 19 | seed: int = 1234, 20 | ): 21 | self._env = env 22 | self.observation_space = gym.spaces.Dict( 23 | { 24 | **{ 25 | f"image_{i}": gym.spaces.Box( 26 | low=np.zeros((im_size, im_size, 3)), 27 | high=255 * np.ones((im_size, im_size, 3)), 28 | dtype=np.uint8, 29 | ) 30 | for i in ["primary", "wrist"][: len(camera_names)] 31 | }, 32 | "proprio": gym.spaces.Box( 33 | low=np.ones((14,)) * -1, high=np.ones((14,)), dtype=np.float32 34 | ), 35 | } 36 | ) 37 | self.action_space = gym.spaces.Box( 38 | low=np.ones((14,)) * -1, high=np.ones((14,)), dtype=np.float32 39 | ) 40 | self.camera_names = camera_names 41 | self._im_size = im_size 42 | self._rng = np.random.default_rng(seed) 43 | 44 | def step(self, action): 45 | ts = self._env.step(action) 46 | obs, images = self.get_obs(ts) 47 | reward = ts.reward 48 | info = {"images": images} 49 | 50 | if reward == self._env.task.max_reward: 51 | self._episode_is_success = 1 52 | 53 | return obs, reward, False, False, info 54 | 55 | def reset(self, **kwargs): 56 | # sample new box pose 57 | x_range = [0.0, 0.2] 58 | y_range = [0.4, 0.6] 59 | z_range = [0.05, 0.05] 60 | ranges = np.vstack([x_range, y_range, z_range]) 61 | cube_position = self._rng.uniform(ranges[:, 0], ranges[:, 1]) 62 | cube_quat = np.array([1, 0, 0, 0]) 63 | BOX_POSE[0] = np.concatenate([cube_position, cube_quat]) 64 | 65 | ts = self._env.reset(**kwargs) 66 | obs, images = self.get_obs(ts) 67 | info = {"images": images} 68 | self._episode_is_success = 0 69 | 70 | return obs, info 71 | 72 | def get_obs(self, ts): 73 | curr_obs = {} 74 | vis_images = [] 75 | 76 | obs_img_names = ["primary", "wrist"] 77 | for i, cam_name in enumerate(self.camera_names): 78 | curr_image = ts.observation["images"][cam_name] 79 | vis_images.append(copy.deepcopy(curr_image)) 80 | curr_image = jnp.array(curr_image) 81 | curr_obs[f"image_{obs_img_names[i]}"] = curr_image 82 | curr_obs = dl.transforms.resize_images( 83 | curr_obs, match=curr_obs.keys(), size=(self._im_size, self._im_size) 84 | ) 85 | 86 | qpos_numpy = np.array(ts.observation["qpos"]) 87 | qpos = jnp.array(qpos_numpy) 88 | curr_obs["proprio"] = qpos 89 | 90 | return curr_obs, np.concatenate(vis_images, axis=-2) 91 | 92 | def get_task(self): 93 | return { 94 | "language_instruction": ["pick up the cube and hand it over"], 95 | } 96 | 97 | def get_episode_metrics(self): 98 | return { 99 | "success_rate": self._episode_is_success, 100 | } 101 | 102 | 103 | # register gym environments 104 | gym.register( 105 | "aloha-sim-cube-v0", 106 | entry_point=lambda: AlohaGymEnv( 107 | make_sim_env("sim_transfer_cube"), camera_names=["top"] 108 | ), 109 | ) 110 | -------------------------------------------------------------------------------- /examples/envs/widowx_env.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import gym 4 | import numpy as np 5 | from pyquaternion import Quaternion 6 
| from widowx_envs.widowx_env_service import WidowXClient 7 | 8 | 9 | def state_to_eep(xyz_coor, zangle: float): 10 | """ 11 | Implement the state to eep function. 12 | Refered to `bridge_data_robot`'s `widowx_controller/widowx_controller.py` 13 | return a 4x4 matrix 14 | """ 15 | assert len(xyz_coor) == 3 16 | DEFAULT_ROTATION = np.array([[0, 0, 1.0], [0, 1.0, 0], [-1.0, 0, 0]]) 17 | new_pose = np.eye(4) 18 | new_pose[:3, -1] = xyz_coor 19 | new_quat = Quaternion(axis=np.array([0.0, 0.0, 1.0]), angle=zangle) * Quaternion( 20 | matrix=DEFAULT_ROTATION 21 | ) 22 | new_pose[:3, :3] = new_quat.rotation_matrix 23 | # yaw, pitch, roll = quat.yaw_pitch_roll 24 | return new_pose 25 | 26 | 27 | def wait_for_obs(widowx_client): 28 | obs = widowx_client.get_observation() 29 | while obs is None: 30 | print("Waiting for observations...") 31 | obs = widowx_client.get_observation() 32 | time.sleep(1) 33 | return obs 34 | 35 | 36 | def convert_obs(obs, im_size): 37 | image_obs = ( 38 | obs["image"].reshape(3, im_size, im_size).transpose(1, 2, 0) * 255 39 | ).astype(np.uint8) 40 | # add padding to proprio to match training 41 | proprio = np.concatenate([obs["state"][:6], [0], obs["state"][-1:]]) 42 | # NOTE: assume image_1 is not available 43 | return { 44 | "image_primary": image_obs, 45 | } 46 | 47 | 48 | def null_obs(img_size): 49 | return { 50 | "image_primary": np.zeros((img_size, img_size, 3), dtype=np.uint8), 51 | } 52 | 53 | 54 | class WidowXGym(gym.Env): 55 | """ 56 | A Gym environment for the WidowX controller provided by: 57 | https://github.com/rail-berkeley/bridge_data_robot 58 | Needed to use Gym wrappers. 59 | """ 60 | 61 | def __init__( 62 | self, 63 | widowx_client: WidowXClient, 64 | im_size: int = 256, 65 | blocking: bool = True, 66 | sticky_gripper_num_steps: int = 1, 67 | ): 68 | self.widowx_client = widowx_client 69 | self.im_size = im_size 70 | self.blocking = blocking 71 | self.observation_space = gym.spaces.Dict( 72 | { 73 | "image_primary": gym.spaces.Box( 74 | low=np.zeros((im_size, im_size, 3)), 75 | high=255 * np.ones((im_size, im_size, 3)), 76 | dtype=np.uint8, 77 | ), 78 | "proprio": gym.spaces.Box( 79 | low=np.ones((8,)) * -1, high=np.ones((8,)), dtype=np.float64 80 | ), 81 | } 82 | ) 83 | self.action_space = gym.spaces.Box( 84 | low=np.zeros((7,)), high=np.ones((7,)), dtype=np.float64 85 | ) 86 | self.sticky_gripper_num_steps = sticky_gripper_num_steps 87 | self.is_gripper_closed = False 88 | self.num_consecutive_gripper_change_actions = 0 89 | 90 | def step(self, action): 91 | # sticky gripper logic 92 | if (action[-1] < 0.5) != self.is_gripper_closed: 93 | self.num_consecutive_gripper_change_actions += 1 94 | else: 95 | self.num_consecutive_gripper_change_actions = 0 96 | 97 | if self.num_consecutive_gripper_change_actions >= self.sticky_gripper_num_steps: 98 | self.is_gripper_closed = not self.is_gripper_closed 99 | self.num_consecutive_gripper_change_actions = 0 100 | action[-1] = 0.0 if self.is_gripper_closed else 1.0 101 | 102 | self.widowx_client.step_action(action, blocking=self.blocking) 103 | 104 | raw_obs = self.widowx_client.get_observation() 105 | 106 | truncated = False 107 | if raw_obs is None: 108 | # this indicates a loss of connection with the server 109 | # due to an exception in the last step so end the trajectory 110 | truncated = True 111 | obs = null_obs(self.im_size) # obs with all zeros 112 | else: 113 | obs = convert_obs(raw_obs, self.im_size) 114 | 115 | return obs, 0, False, truncated, {} 116 | 117 | def reset(self, seed=None, options=None): 118 
| super().reset(seed=seed) 119 | self.widowx_client.reset() 120 | 121 | self.is_gripper_closed = False 122 | self.num_consecutive_gripper_change_actions = 0 123 | 124 | raw_obs = wait_for_obs(self.widowx_client) 125 | obs = convert_obs(raw_obs, self.im_size) 126 | 127 | return obs, {} 128 | -------------------------------------------------------------------------------- /octo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/octo/__init__.py -------------------------------------------------------------------------------- /octo/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/octo/data/__init__.py -------------------------------------------------------------------------------- /octo/data/obs_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains observation-level transforms used in the octo data pipeline. These transforms operate on the 3 | "observation" dictionary, and are applied at a per-frame level. 4 | """ 5 | from typing import Mapping, Optional, Tuple, Union 6 | 7 | from absl import logging 8 | import dlimp as dl 9 | import tensorflow as tf 10 | 11 | 12 | def augment( 13 | obs: dict, seed: tf.Tensor, augment_kwargs: Union[dict, Mapping[str, dict]] 14 | ) -> dict: 15 | """Augments images, skipping padding images.""" 16 | if not hasattr(augment_kwargs, "items"): 17 | raise ValueError( 18 | "augment_kwargs must be a dict with keys corresponding to image names, or a single dict " 19 | "with an 'augment_order' key." 20 | ) 21 | image_names = {key[6:] for key in obs if key.startswith("image_")} 22 | 23 | # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed 24 | # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image 25 | # name to augmentation dict) 26 | if "augment_order" in augment_kwargs: 27 | augment_kwargs = {name: augment_kwargs for name in image_names} 28 | 29 | for i, name in enumerate(image_names): 30 | if name not in augment_kwargs: 31 | continue 32 | kwargs = augment_kwargs[name] 33 | logging.debug(f"Augmenting image_{name} with kwargs {kwargs}") 34 | obs[f"image_{name}"] = tf.cond( 35 | obs["pad_mask_dict"][f"image_{name}"], 36 | lambda: dl.transforms.augment_image( 37 | obs[f"image_{name}"], 38 | **kwargs, 39 | seed=seed + i, # augment each image differently 40 | ), 41 | lambda: obs[f"image_{name}"], # skip padding images 42 | ) 43 | 44 | return obs 45 | 46 | 47 | def image_dropout( 48 | obs: dict, 49 | seed: tf.Tensor, 50 | dropout_prob: float, 51 | always_keep_key: Optional[str] = None, 52 | ) -> dict: 53 | """Independently drops out image keys, each with probability `dropout_prob`, but always keeps at least one 54 | image present. 55 | """ 56 | image_keys = [key for key in obs if key.startswith("image_")] 57 | if not image_keys: 58 | return obs 59 | pad_mask = tf.stack([obs["pad_mask_dict"][key] for key in image_keys]) 60 | # if any non-padding images exist, pick one of them to keep no matter what 61 | shuffle_seed, seed = tf.unstack(tf.random.split(seed)) 62 | 63 | if always_keep_key: 64 | assert ( 65 | always_keep_key in image_keys 66 | ), f"Specified always_keep_key {always_keep_key} not present in image_keys: {image_keys} during dropout." 
67 | always_keep_index = tf.constant( 68 | image_keys.index(always_keep_key), dtype=tf.int64 69 | ) 70 | else: 71 | always_keep_index = tf.cond( 72 | tf.reduce_any(pad_mask), 73 | # pick a random index from the non-padding images 74 | lambda: tf.random.experimental.stateless_shuffle( 75 | tf.where(pad_mask)[:, 0], seed=shuffle_seed 76 | )[0], 77 | # all images are padding, so it doesn't matter 78 | lambda: tf.constant(0, dtype=tf.int64), 79 | ) 80 | 81 | # drop images independently, except for the one at always_keep_index 82 | rands = tf.random.stateless_uniform([len(image_keys)], seed=seed) 83 | pad_mask = tf.logical_and( 84 | pad_mask, 85 | tf.logical_or( 86 | tf.range(len(image_keys), dtype=tf.int64) == always_keep_index, 87 | rands > dropout_prob, 88 | ), 89 | ) 90 | 91 | # perform the dropout and update pad_mask_dict 92 | for i, key in enumerate(image_keys): 93 | obs["pad_mask_dict"][key] = pad_mask[i] 94 | obs[key] = tf.cond( 95 | pad_mask[i], 96 | lambda: obs[key], 97 | lambda: tf.zeros_like(obs[key]), 98 | ) 99 | return obs 100 | 101 | 102 | def decode_and_resize( 103 | obs: dict, 104 | resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]], 105 | depth_resize_size: Union[Tuple[int, int], Mapping[str, Tuple[int, int]]], 106 | ) -> dict: 107 | """Decodes images and depth images, and then optionally resizes them.""" 108 | # just gets the part after "image_" or "depth_" 109 | image_names = {key[6:] for key in obs if key.startswith("image_")} 110 | depth_names = {key[6:] for key in obs if key.startswith("depth_")} 111 | 112 | if isinstance(resize_size, tuple): 113 | resize_size = {name: resize_size for name in image_names} 114 | if isinstance(depth_resize_size, tuple): 115 | depth_resize_size = {name: depth_resize_size for name in depth_names} 116 | 117 | for name in image_names: 118 | if name not in resize_size: 119 | logging.warning( 120 | f"No resize_size was provided for image_{name}. This will result in 1x1 " 121 | "padding images, which may cause errors if you mix padding and non-padding images." 122 | ) 123 | image = obs[f"image_{name}"] 124 | if image.dtype == tf.string: 125 | if tf.strings.length(image) == 0: 126 | # this is a padding image 127 | image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8) 128 | else: 129 | image = tf.io.decode_image( 130 | image, expand_animations=False, dtype=tf.uint8 131 | ) 132 | elif image.dtype != tf.uint8: 133 | raise ValueError( 134 | f"Unsupported image dtype: found image_{name} with dtype {image.dtype}" 135 | ) 136 | if name in resize_size: 137 | image = dl.transforms.resize_image(image, size=resize_size[name]) 138 | obs[f"image_{name}"] = image 139 | 140 | for name in depth_names: 141 | if name not in depth_resize_size: 142 | logging.warning( 143 | f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 " 144 | "padding depth images, which may cause errors if you mix padding and non-padding images." 
145 | ) 146 | depth = obs[f"depth_{name}"] 147 | if depth.dtype == tf.string: 148 | if tf.strings.length(depth) == 0: 149 | # this is a padding image 150 | depth = tf.zeros( 151 | (*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32 152 | ) 153 | else: 154 | depth = tf.io.decode_image( 155 | depth, expand_animations=False, dtype=tf.float32 156 | )[..., 0] 157 | elif depth.dtype != tf.float32: 158 | raise ValueError( 159 | f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}" 160 | ) 161 | if name in depth_resize_size: 162 | depth = dl.transforms.resize_depth_image( 163 | depth, size=depth_resize_size[name] 164 | ) 165 | obs[f"depth_{name}"] = depth 166 | 167 | return obs 168 | -------------------------------------------------------------------------------- /octo/data/oxe/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from typing import Any, Dict, List, Sequence, Tuple, Union 4 | 5 | from octo.data.oxe.oxe_dataset_configs import ActionEncoding, OXE_DATASET_CONFIGS 6 | from octo.data.oxe.oxe_dataset_mixes import OXE_NAMED_MIXES 7 | from octo.data.oxe.oxe_standardization_transforms import OXE_STANDARDIZATION_TRANSFORMS 8 | from octo.data.utils.data_utils import NormalizationType 9 | from octo.utils.spec import ModuleSpec 10 | 11 | 12 | def make_oxe_dataset_kwargs( 13 | name: str, 14 | data_dir: str, 15 | load_camera_views: Sequence[str] = ("primary",), 16 | load_depth: bool = False, 17 | load_proprio: bool = False, 18 | load_language: bool = True, 19 | force_recompute_dataset_statistics: bool = False, 20 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL, 21 | ) -> Dict[str, Any]: 22 | """Generates dataset kwargs for a given dataset from Open X-Embodiment. The returned kwargs can be passed 23 | directly into `octo.data.dataset.make_dataset_from_rlds`. 24 | 25 | Args: 26 | name: Name of the dataset to load. See `oxe_dataset_configs.py` for available datasets. 27 | data_dir: Base data directory that contains the dataset. 28 | load_camera_views: Which views to load. See `oxe_dataset_configs.py` for available views. 29 | load_depth: If True, loads corresponding depth channels for each RGB channel. 30 | load_proprio: If True, loads proprioceptive information. 31 | load_language: If True, loads language instructions. 32 | force_recompute_dataset_statistics: If True, recompute dataset statistics. 33 | action_proprio_normalization_type: Normalization type to use for proprioceptive actions. 
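    Example (a minimal sketch mirroring examples/05_dataloading.ipynb; the dataset name and data
    directory below are simply the ones used there):

        from octo.data.dataset import make_single_dataset

        kwargs = make_oxe_dataset_kwargs(
            "austin_buds_dataset_converted_externally_to_rlds",
            "gs://gresearch/robotics",
        )
        dataset = make_single_dataset(kwargs, train=True)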
34 | """ 35 | dataset_kwargs = copy.deepcopy(OXE_DATASET_CONFIGS[name]) 36 | 37 | if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS: 38 | # with EEF_POS actions, the last action dimension is gripper 39 | dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] 40 | elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS: 41 | # with JOINT_POS actions, last dimension is gripper 42 | dataset_kwargs["action_normalization_mask"] = [True] * 7 + [False] 43 | elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL: 44 | # with JOINT_POS_BIMANUAL actions, 7th and 14th dimension are gripper 45 | dataset_kwargs["action_normalization_mask"] = ( 46 | [True] * 6 + [False] + [True] * 6 + [False] 47 | ) 48 | elif dataset_kwargs["action_encoding"] is ActionEncoding.NAV_2D: 49 | # with NAV_2D actions, all dimensions are deltas 50 | dataset_kwargs["action_normalization_mask"] = [True] * 2 51 | elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL_NAV: 52 | # with JOINT_POS_BIMANUAL_NAV actions, 7th and 14th dimension are gripper 53 | dataset_kwargs["action_normalization_mask"] = ( 54 | [True] * 6 + [False] + [True] * 6 + [False] + [True] * 2 55 | ) 56 | else: 57 | raise ValueError( 58 | f"Cannot load {name} with unsupported action encoding {dataset_kwargs['action_encoding']}." 59 | ) 60 | 61 | # adjust loaded camera views 62 | if missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"])): 63 | raise ValueError( 64 | f"Cannot load {name} with views {missing_keys} since they are not available." 65 | ) 66 | dataset_kwargs["image_obs_keys"] = { 67 | k: v 68 | for k, v in dataset_kwargs["image_obs_keys"].items() 69 | if k in load_camera_views 70 | } 71 | dataset_kwargs["depth_obs_keys"] = { 72 | k: v 73 | for k, v in dataset_kwargs["depth_obs_keys"].items() 74 | if k in load_camera_views 75 | } 76 | 77 | if not load_depth: 78 | dataset_kwargs.pop("depth_obs_keys") 79 | if load_proprio: 80 | dataset_kwargs["proprio_obs_key"] = "proprio" 81 | if load_language: 82 | dataset_kwargs["language_key"] = "language_instruction" 83 | 84 | dataset_kwargs[ 85 | "action_proprio_normalization_type" 86 | ] = action_proprio_normalization_type 87 | 88 | del dataset_kwargs["proprio_encoding"] 89 | del dataset_kwargs["action_encoding"] 90 | 91 | dataset_kwargs["standardize_fn"] = ModuleSpec.create( 92 | OXE_STANDARDIZATION_TRANSFORMS[name] 93 | ) 94 | 95 | if force_recompute_dataset_statistics: 96 | dataset_kwargs["force_recompute_dataset_statistics"] = True 97 | 98 | return {"name": name, "data_dir": data_dir, **dataset_kwargs} 99 | 100 | 101 | def make_oxe_dataset_kwargs_and_weights( 102 | data_mix: Union[str, Sequence[Tuple[str, float]]], 103 | data_dir: str, 104 | load_camera_views: Sequence[str] = ("primary",), 105 | load_depth: bool = False, 106 | load_proprio: bool = False, 107 | load_language: bool = True, 108 | force_recompute_dataset_statistics: bool = False, 109 | action_proprio_normalization_type: NormalizationType = NormalizationType.NORMAL, 110 | ) -> Tuple[Dict[str, Any], List[float]]: 111 | """ 112 | Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs 113 | and weights can be passed directly into `octo.data.dataset.make_interleaved_dataset`. 114 | 115 | Args: 116 | data_mix: List of (dataset name, sampling weight) tuples, or a string specifying a pre-defined mix to 117 | load from `OXE_NAMED_MIXES`. 118 | data_dir: Base data directory that contains the datasets. 
119 | load_camera_views: Which views to load. See `oxe_dataset_configs.py` for available views. 120 | load_depth: If True, loads corresponding depth channels for each RGB channel. 121 | load_proprio: If True, loads proprioceptive information. 122 | load_language: If True, loads language instructions. 123 | force_recompute_dataset_statistics: If True, recompute dataset statistics. 124 | action_proprio_normalization_type: Normalization type to use for proprioceptive actions. 125 | Returns: 126 | Tuple of (dataset_kwargs_list, sampling weights). 127 | """ 128 | if isinstance(data_mix, str): 129 | data_mix = OXE_NAMED_MIXES[data_mix] 130 | 131 | filtered_datasets, included_dataset_names = [], [] 132 | for name, weight in data_mix: 133 | if name not in included_dataset_names: 134 | filtered_datasets.append((name, weight)) 135 | included_dataset_names.append(name) 136 | else: 137 | logging.warning(f"Skipping duplicate: {(name, weight)}.") 138 | data_mix = filtered_datasets 139 | 140 | data_kwargs_list, weights = [], [] 141 | for name, weight in data_mix: 142 | try: 143 | data_kwargs_list.append( 144 | make_oxe_dataset_kwargs( 145 | name, 146 | data_dir, 147 | load_camera_views, 148 | load_depth, 149 | load_proprio, 150 | load_language, 151 | force_recompute_dataset_statistics, 152 | action_proprio_normalization_type, 153 | ) 154 | ) 155 | weights.append(weight) 156 | except ValueError as e: 157 | logging.warning(f"Skipping {name} due to error: {e}") 158 | 159 | return data_kwargs_list, weights 160 | -------------------------------------------------------------------------------- /octo/data/oxe/oxe_dataset_mixes.py: -------------------------------------------------------------------------------- 1 | """Defines dataset mixtures and weights for the Open X-Embodiment Datasets.""" 2 | 3 | 4 | BRIDGE_MIX = [ 5 | ("bridge_dataset", 1.0), 6 | ] 7 | 8 | RT_X_MIX = [ 9 | ("fractal20220817_data", 0.54087122203), 10 | ("kuka", 0.8341046294), 11 | ("bridge_dataset", 1.0), 12 | ("taco_play", 2.0), 13 | ("jaco_play", 2.0), 14 | ("berkeley_cable_routing", 3.0), 15 | ("roboturk", 1.0), 16 | ("nyu_door_opening_surprising_effectiveness", 5.0), 17 | ("viola", 2.0), 18 | ("berkeley_autolab_ur5", 1.0), 19 | ("toto", 1.0), 20 | ] 21 | 22 | 23 | OXE_FRANKA_MIX = [ 24 | ("taco_play", 1.0), 25 | ("berkeley_cable_routing", 1.0), 26 | ("viola", 1.0), 27 | ("toto", 1.0), 28 | ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), 29 | ("austin_buds_dataset_converted_externally_to_rlds", 3.0), 30 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 31 | ("maniskill_dataset_converted_externally_to_rlds", 0.1), 32 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 33 | ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0), 34 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 35 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 36 | ("berkeley_rpt_converted_externally_to_rlds", 1.0), 37 | ("kaist_nonprehensile_converted_externally_to_rlds", 3.0), 38 | ("stanford_robocook_converted_externally_to_rlds", 1.0), 39 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 40 | ("utaustin_mutex", 1.0), 41 | # ("cmu_playing_with_food", 1.0), 42 | ("cmu_play_fusion", 1.0), 43 | ] 44 | 45 | 46 | OXE_MAGIC_SOUP = [ 47 | ("fractal20220817_data", 0.54087122203), 48 | ("kuka", 0.8341046294), 49 | ("bridge_dataset", 1.0), 50 | ("taco_play", 2.0), 51 | ("jaco_play", 1.0), 52 | ("berkeley_cable_routing", 1.0), 53 | ("roboturk", 2.0), 54 | 
("nyu_door_opening_surprising_effectiveness", 1.0), 55 | ("viola", 2.0), 56 | ("berkeley_autolab_ur5", 2.0), 57 | ("toto", 1.0), 58 | ("language_table", 0.1), 59 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), 60 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 61 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 62 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 63 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), 64 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 65 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 66 | ("bc_z", 0.2), 67 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 68 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 69 | # ("uiuc_d3field", 1.0), --> somehow raw data is broken 70 | ("utaustin_mutex", 1.0), 71 | ("berkeley_fanuc_manipulation", 2.0), 72 | ("cmu_stretch", 1.0), 73 | ] 74 | 75 | 76 | OXE_FLEX_ACT_SOUP = [ 77 | ("fractal20220817_data", 0.54087122203), 78 | ("kuka", 0.8341046294), 79 | ("bridge_dataset", 1.0), 80 | ("taco_play", 2.0), 81 | ("jaco_play", 1.0), 82 | ("berkeley_cable_routing", 1.0), 83 | ("roboturk", 2.0), 84 | ("nyu_door_opening_surprising_effectiveness", 1.0), 85 | ("viola", 2.0), 86 | ("berkeley_autolab_ur5", 2.0), 87 | ("toto", 1.0), 88 | ("language_table", 0.1), 89 | ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), 90 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 91 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), 92 | ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), 93 | ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), 94 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 95 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 96 | ("bc_z", 0.2), 97 | ("berkeley_mvp_converted_externally_to_rlds", 1.0), 98 | # ("berkeley_rpt_converted_externally_to_rlds", 1.0), 99 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 100 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 101 | # ("uiuc_d3field", 1.0), --> somehow raw data is broken 102 | ("utaustin_mutex", 1.0), 103 | ("berkeley_fanuc_manipulation", 2.0), 104 | ("cmu_stretch", 1.0), 105 | ("gnm_dataset", 1.0), 106 | ("aloha_static_dataset", 3.0), 107 | # ("aloha_dagger_dataset", 1.0), 108 | ("aloha_mobile_dataset", 2.0), 109 | # ("fmb_dataset", 1.0), 110 | ("dobbe", 1.0), 111 | ("roboset", 0.5), 112 | ("rh20t", 0.5), 113 | ] 114 | 115 | 116 | OXE_FULL_MIX = [ 117 | ("fractal20220817_data", 1.0), 118 | ("kuka", 1.0), 119 | ("bridge_dataset", 1), 120 | ("taco_play", 1.0), 121 | ("jaco_play", 1.0), 122 | ("berkeley_cable_routing", 1.0), 123 | ("roboturk", 1.0), 124 | ("nyu_door_opening_surprising_effectiveness", 1.0), 125 | ("viola", 1.0), 126 | ("berkeley_autolab_ur5", 1.0), 127 | ("toto", 1.0), 128 | ("language_table", 1.0), 129 | ("columbia_cairlab_pusht_real", 1.0), 130 | ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0), 131 | ("nyu_rot_dataset_converted_externally_to_rlds", 1.0), 132 | ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), 133 | ("austin_buds_dataset_converted_externally_to_rlds", 1.0), 134 | ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0), 135 | ("maniskill_dataset_converted_externally_to_rlds", 1.0), 136 | ("furniture_bench_dataset_converted_externally_to_rlds", 1.0), 137 | ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0), 138 | 
("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0), 139 | ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0), 140 | ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), 141 | ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), 142 | ("bc_z", 1.0), 143 | ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0), 144 | ("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0), 145 | ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0), 146 | ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0), 147 | ("robo_net", 1.0), 148 | ("berkeley_mvp_converted_externally_to_rlds", 1.0), 149 | ("berkeley_rpt_converted_externally_to_rlds", 1.0), 150 | ("kaist_nonprehensile_converted_externally_to_rlds", 1.0), 151 | ("stanford_mask_vit_converted_externally_to_rlds", 1.0), 152 | ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0), 153 | ("dlr_sara_pour_converted_externally_to_rlds", 1.0), 154 | ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0), 155 | ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), 156 | ("asu_table_top_converted_externally_to_rlds", 1.0), 157 | ("stanford_robocook_converted_externally_to_rlds", 1.0), 158 | ("imperialcollege_sawyer_wrist_cam", 1.0), 159 | ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), 160 | ("uiuc_d3field", 1.0), 161 | ("utaustin_mutex", 1.0), 162 | ("berkeley_fanuc_manipulation", 1.0), 163 | ("cmu_playing_with_food", 1.0), 164 | ("cmu_play_fusion", 1.0), 165 | ("cmu_stretch", 1.0), 166 | ("gnm_dataset", 1.0), 167 | ] 168 | 169 | OXE_NAMED_MIXES = { 170 | "bridge": BRIDGE_MIX, 171 | "rtx": RT_X_MIX, 172 | "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX, 173 | "oxe_magic_soup": OXE_MAGIC_SOUP, 174 | "oxe_flex_act_soup": OXE_FLEX_ACT_SOUP, 175 | } 176 | -------------------------------------------------------------------------------- /octo/data/traj_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains trajectory transforms used in the octo data pipeline. Trajectory transforms operate on a dictionary 3 | that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory 4 | length). 5 | """ 6 | from typing import Optional 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def chunk_act_obs( 12 | traj: dict, 13 | window_size: int = 1, 14 | action_horizon: int = 1, 15 | ) -> dict: 16 | """Chunks actions and observations. 17 | 18 | "observation" keys are given a new history axis, making them of shape [traj_len, window_size, ...], 19 | containing the observation history at each timestep from `t - window_size + 1` to `t`. 20 | 21 | The "action" key is given two new axes, making it of shape [traj_len, window_size, action_horizon, 22 | action_dim]. The first two axes are the same as in the observations, i.e., an action chunk `action[t, h]` 23 | corresponds to an observation `observation[t, h]`. The third axis indexes into the action chunk, 24 | containing the current action plus `action_horizon - 1` future actions. 25 | 26 | The "action" key can also be pre-chunked coming into this function, meaning it starts with shape 27 | [traj_len, N, action_dim] instead of [traj_len, action_dim]. In this case, `N` must be larger than or 28 | equal to `action_horizon`, and only one axis will be added (the history axis). This is useful for 29 | custom chunking schemes where an action may differ depending on which observation it is paired with. 
30 | """ 31 | traj_len = tf.shape(traj["action"])[0] 32 | 33 | # chunk observations into histories 34 | history_indices = tf.range(traj_len)[:, None] + tf.range( 35 | -window_size + 1, 1 36 | ) # [traj_len, window_size] 37 | # indicates which observations at the beginning of the trajectory are padding 38 | timestep_pad_mask = history_indices >= 0 39 | # repeat the first observation at the beginning of the trajectory rather than going out of bounds 40 | history_indices = tf.maximum(history_indices, 0) 41 | # gather 42 | traj["observation"] = tf.nest.map_structure( 43 | lambda x: tf.gather(x, history_indices), traj["observation"] 44 | ) # [traj_len, window_size, ...] 45 | traj["observation"]["timestep_pad_mask"] = timestep_pad_mask 46 | 47 | # first, chunk actions into `action_horizon` current + future actions 48 | if len(traj["action"].shape) == 2: 49 | # actions are not pre-chunked 50 | action_chunk_indices = tf.range(traj_len)[:, None] + tf.range( 51 | action_horizon 52 | ) # [traj_len, action_horizon] 53 | # repeat the last action at the end of the trajectory rather than going out of bounds 54 | action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1) 55 | # gather 56 | traj["action"] = tf.gather( 57 | traj["action"], action_chunk_indices 58 | ) # [traj_len, action_horizon, action_dim] 59 | else: 60 | # actions are pre-chunked, so we don't add a new axis 61 | if traj["action"].shape[1] < action_horizon: 62 | raise ValueError( 63 | f"action_horizon ({action_horizon}) is greater than the pre-chunked action dimension ({traj['action'].shape[1]})" 64 | ) 65 | traj["action"] = traj["action"][:, :action_horizon] 66 | 67 | # then, add the history axis to actions 68 | traj["action"] = tf.gather( 69 | traj["action"], history_indices 70 | ) # [traj_len, window_size, action_horizon, action_dim] 71 | 72 | # finally, we deal with marking which actions are past the goal timestep (or final timestep if no goal) 73 | if "timestep" in traj["task"]: 74 | goal_timestep = traj["task"]["timestep"] 75 | else: 76 | goal_timestep = tf.fill([traj_len], traj_len - 1) 77 | # computes the number of timesteps away the goal is relative to a particular action 78 | t, w, h = tf.meshgrid( 79 | tf.range(traj_len), 80 | tf.range(window_size), 81 | tf.range(action_horizon), 82 | indexing="ij", 83 | ) 84 | relative_goal_timestep = goal_timestep[:, None, None] - ( 85 | t - (window_size + 1) + w + h 86 | ) # [traj_len, window_size, action_horizon] 87 | traj["observation"]["task_completed"] = relative_goal_timestep <= 0 88 | 89 | # broadcast "action_pad_mask" to the new chunked shape, and mark actions past the goal timestep as padding 90 | traj["action_pad_mask"] = tf.logical_and( 91 | # [traj_len, 1, 1, action_dim] 92 | traj["action_pad_mask"][:, None, None, :] 93 | if len(traj["action_pad_mask"].shape) == 2 94 | else traj["action_pad_mask"][:, None, :], 95 | # [traj_len, window_size, action_horizon, 1] 96 | tf.logical_not(traj["observation"]["task_completed"])[:, :, :, None], 97 | ) 98 | 99 | return traj 100 | 101 | 102 | def subsample(traj: dict, subsample_length: int) -> dict: 103 | """Subsamples trajectories to the given length.""" 104 | traj_len = tf.shape(traj["action"])[0] 105 | if traj_len > subsample_length: 106 | indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length] 107 | traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) 108 | return traj 109 | 110 | 111 | def add_pad_mask_dict(traj: dict) -> dict: 112 | """Adds a dictionary indicating which elements of the observation/task 
should be treated as padding. 113 | 114 | traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding} 115 | """ 116 | traj_len = tf.shape(traj["action"])[0] 117 | for key in ["observation", "task"]: 118 | pad_mask_dict = {} 119 | for subkey in traj[key]: 120 | if traj[key][subkey].dtype == tf.string: 121 | # handles "language_instruction", "image_*", and "depth_*" 122 | pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0 123 | else: 124 | # all other keys should not be treated as padding 125 | pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool) 126 | traj[key]["pad_mask_dict"] = pad_mask_dict 127 | return traj 128 | 129 | 130 | def pad_actions_and_proprio( 131 | traj: dict, max_action_dim: Optional[int], max_proprio_dim: Optional[int] 132 | ) -> dict: 133 | """Pads actions and proprio to a maximum number of dimensions across all datasets. 134 | 135 | Records which action dimensions are padding in "action_pad_mask". 136 | """ 137 | traj["action_pad_mask"] = tf.ones_like(traj["action"], dtype=tf.bool) 138 | if max_action_dim is not None: 139 | action_dim = traj["action"].shape[-1] 140 | if action_dim > max_action_dim: 141 | raise ValueError( 142 | f"action_dim ({action_dim}) is greater than max_action_dim ({max_action_dim})" 143 | ) 144 | for key in {"action", "action_pad_mask"}: 145 | traj[key] = tf.pad( 146 | traj[key], 147 | [ 148 | *[[0, 0]] * (len(traj[key].shape) - 1), 149 | [0, max_action_dim - action_dim], 150 | ], 151 | ) 152 | 153 | if max_proprio_dim is not None and "proprio" in traj["observation"]: 154 | proprio_dim = traj["observation"]["proprio"].shape[-1] 155 | if proprio_dim > max_proprio_dim: 156 | raise ValueError( 157 | f"proprio_dim ({proprio_dim}) is greater than max_proprio_dim ({max_proprio_dim})" 158 | ) 159 | traj["observation"]["proprio"] = tf.pad( 160 | traj["observation"]["proprio"], [[0, 0], [0, max_proprio_dim - proprio_dim]] 161 | ) 162 | return traj 163 | -------------------------------------------------------------------------------- /octo/data/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/octo/data/utils/__init__.py -------------------------------------------------------------------------------- /octo/data/utils/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. 3 | Each function should add entries to the "task" dict. 4 | """ 5 | 6 | from typing import Optional 7 | 8 | import tensorflow as tf 9 | 10 | from octo.data.utils.data_utils import tree_merge 11 | 12 | 13 | def uniform(traj: dict, max_goal_distance: Optional[int] = None) -> dict: 14 | """ 15 | Relabels with a true uniform distribution over future states. 16 | Optionally caps goal distance. 
17 | """ 18 | traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] 19 | 20 | # select a random future index for each transition i in the range [i, traj_len) 21 | rand = tf.random.uniform([traj_len]) 22 | low = tf.cast(tf.range(traj_len), tf.float32) 23 | if max_goal_distance is not None: 24 | high = tf.cast( 25 | tf.minimum(tf.range(traj_len) + max_goal_distance, traj_len), tf.float32 26 | ) 27 | else: 28 | high = tf.cast(traj_len, tf.float32) 29 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 30 | 31 | # sometimes there are floating-point errors that cause an out-of-bounds 32 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 33 | 34 | # adds keys to "task" mirroring "observation" keys (must do a tree merge to combine "pad_mask_dict" from 35 | # "observation" and "task" properly) 36 | goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) 37 | traj["task"] = tree_merge(traj["task"], goal) 38 | 39 | return traj 40 | -------------------------------------------------------------------------------- /octo/data/utils/task_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains basic logic for randomly zero-ing out keys in the task specification. 3 | """ 4 | 5 | import pickle 6 | 7 | from huggingface_hub import hf_hub_download 8 | import tensorflow as tf 9 | 10 | from octo.data.utils.data_utils import to_padding 11 | 12 | 13 | def delete_and_rephrase( 14 | traj, 15 | paraphrases_repo: str, 16 | paraphrases_filename: str, 17 | rephrase_prob: float, 18 | keep_image_prob: float, 19 | ): 20 | traj = rephrase_instruction( 21 | traj, paraphrases_repo, paraphrases_filename, rephrase_prob 22 | ) 23 | traj = delete_task_conditioning(traj, keep_image_prob) 24 | return traj 25 | 26 | 27 | class Rephraser: 28 | def create_static_hash_table(self, dictionary): 29 | """Takes a python dictionary with string keys and values and creates a tf static hash table""" 30 | keys = list(dictionary.keys()) 31 | values = list(dictionary.values()) 32 | initializer = tf.lookup.KeyValueTensorInitializer( 33 | keys, values, key_dtype=tf.string, value_dtype=tf.string 34 | ) 35 | hash_table = tf.lookup.StaticHashTable(initializer, default_value="") 36 | return hash_table 37 | 38 | def __init__(self, paraphrases_repo: str, paraphrases_filename: str): 39 | if isinstance(paraphrases_repo, str) and isinstance(paraphrases_filename, str): 40 | with open( 41 | hf_hub_download( 42 | repo_id=paraphrases_repo, 43 | filename=paraphrases_filename, 44 | repo_type="dataset", 45 | ), 46 | "rb", 47 | ) as file: 48 | lang_paraphrases = pickle.load(file) 49 | # Create StaticHashTable 50 | self.rephrase_lookup = self.create_static_hash_table(lang_paraphrases) 51 | 52 | 53 | def rephrase_instruction( 54 | traj: dict, paraphrases_repo: str, paraphrases_filename: str, rephrase_prob: float 55 | ) -> dict: 56 | """Randomly rephrases language instructions with precomputed paraphrases 57 | Args: 58 | traj: A dictionary containing trajectory data. Should have a "task" key. 59 | paraphrases_repo: The name of the HF repo containing the paraphrases file. 60 | paraphrases_filename: The name of the file containing the paraphrases. 61 | rephrase_prob: The probability of augmenting the language instruction. The probability of keeping the language 62 | instruction is 1 - rephrase_prob. 
63 | """ 64 | rephraser = Rephraser(paraphrases_repo, paraphrases_filename) 65 | 66 | if "language_instruction" not in traj["task"]: 67 | return traj 68 | original_language = traj["task"]["language_instruction"] 69 | # check the language key is not empty 70 | string_is_not_empty = tf.reduce_all(tf.strings.length(original_language) > 0) 71 | # check dict is not empty 72 | dict_is_not_empty = bool(rephraser.rephrase_lookup) 73 | if dict_is_not_empty and string_is_not_empty: 74 | rephrased_instruction = rephraser.rephrase_lookup.lookup(original_language[0]) 75 | rephrased_instruction = tf.where( 76 | tf.strings.length(rephrased_instruction) > 0, 77 | original_language[0] + "." + rephrased_instruction, 78 | original_language[0], 79 | ) 80 | split_tensor = tf.strings.split(rephrased_instruction, sep=".") 81 | num_strings = tf.cast(tf.shape(split_tensor)[0], tf.int32) 82 | random_index = tf.random.uniform( 83 | (tf.shape(original_language)[0],), 84 | minval=0, 85 | maxval=num_strings, 86 | dtype=tf.int32, 87 | ) 88 | sampled_language = tf.gather(split_tensor, random_index) 89 | rand = tf.random.uniform(shape=(), minval=0, maxval=1, dtype=tf.float32) 90 | sampled_language = tf.where( 91 | rand < rephrase_prob, 92 | sampled_language, 93 | original_language, 94 | ) 95 | traj["task"]["language_instruction"] = sampled_language 96 | return traj 97 | 98 | 99 | def delete_task_conditioning( 100 | traj: dict, 101 | keep_image_prob: float, 102 | ): 103 | """ 104 | Randomly drops out either the goal images or the language instruction. Only does something if both of 105 | these are present. 106 | 107 | Args: 108 | traj: A dictionary containing trajectory data. Should have a "task" key. 109 | keep_image_prob: The probability of keeping the goal images. The probability of keeping the language 110 | instruction is 1 - keep_image_prob. 
111 | """ 112 | if "language_instruction" not in traj["task"]: 113 | return traj 114 | 115 | image_keys = { 116 | key 117 | for key in traj["task"].keys() 118 | if key.startswith("image_") or key.startswith("depth_") 119 | } 120 | if not image_keys: 121 | return traj 122 | 123 | traj_len = tf.shape(traj["action"])[0] 124 | should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob 125 | should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] 126 | 127 | for key in image_keys | {"language_instruction"}: 128 | should_keep = should_keep_images if key in image_keys else ~should_keep_images 129 | # pad out the key 130 | traj["task"][key] = tf.where( 131 | should_keep, 132 | traj["task"][key], 133 | to_padding(traj["task"][key]), 134 | ) 135 | # zero out the pad mask dict for the key 136 | traj["task"]["pad_mask_dict"][key] = tf.where( 137 | should_keep, 138 | traj["task"]["pad_mask_dict"][key], 139 | tf.zeros_like(traj["task"]["pad_mask_dict"][key]), 140 | ) 141 | 142 | # when no goal images are present, the goal timestep becomes the final timestep 143 | traj["task"]["timestep"] = tf.where( 144 | should_keep_images, 145 | traj["task"]["timestep"], 146 | traj_len - 1, 147 | ) 148 | 149 | return traj 150 | -------------------------------------------------------------------------------- /octo/data/utils/text_processing.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Sequence 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | MULTI_MODULE = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3" 8 | 9 | 10 | class TextProcessor(ABC): 11 | """ 12 | Base class for text tokenization or text embedding. 13 | """ 14 | 15 | @abstractmethod 16 | def encode(self, strings: Sequence[str]): 17 | raise NotImplementedError 18 | 19 | 20 | class HFTokenizer(TextProcessor): 21 | def __init__( 22 | self, 23 | tokenizer_name: str, 24 | tokenizer_kwargs: Optional[dict] = { 25 | "max_length": 64, 26 | "padding": "max_length", 27 | "truncation": True, 28 | "return_tensors": "np", 29 | }, 30 | encode_with_model: bool = False, 31 | ): 32 | from transformers import AutoTokenizer, FlaxAutoModel # lazy import 33 | 34 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 35 | self.tokenizer_kwargs = tokenizer_kwargs 36 | self.encode_with_model = encode_with_model 37 | if self.encode_with_model: 38 | self.model = FlaxAutoModel.from_pretrained(tokenizer_name) 39 | 40 | def encode(self, strings: Sequence[str]): 41 | # this creates another nested layer with "input_ids", "attention_mask", etc. 
42 | inputs = self.tokenizer( 43 | strings, 44 | **self.tokenizer_kwargs, 45 | ) 46 | if self.encode_with_model: 47 | return np.array(self.model(**inputs).last_hidden_state) 48 | else: 49 | return dict(inputs) 50 | 51 | 52 | class MuseEmbedding(TextProcessor): 53 | def __init__(self): 54 | import tensorflow_hub as hub # lazy import 55 | import tensorflow_text # noqa: F401 56 | 57 | self.muse_model = hub.load(MULTI_MODULE) 58 | 59 | def encode(self, strings: Sequence[str]): 60 | with tf.device("/cpu:0"): 61 | return self.muse_model(strings).numpy() 62 | 63 | 64 | class CLIPTextProcessor(TextProcessor): 65 | def __init__( 66 | self, 67 | tokenizer_kwargs: Optional[dict] = { 68 | "max_length": 64, 69 | "padding": "max_length", 70 | "truncation": True, 71 | "return_tensors": "np", 72 | }, 73 | ): 74 | from transformers import CLIPProcessor 75 | 76 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 77 | self.kwargs = tokenizer_kwargs 78 | 79 | def encode(self, strings: Sequence[str]): 80 | inputs = self.processor( 81 | text=strings, 82 | **self.kwargs, 83 | ) 84 | inputs["position_ids"] = np.expand_dims( 85 | np.arange(inputs["input_ids"].shape[1]), axis=0 86 | ).repeat(inputs["input_ids"].shape[0], axis=0) 87 | return inputs 88 | -------------------------------------------------------------------------------- /octo/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/octo/model/__init__.py -------------------------------------------------------------------------------- /octo/model/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/octo/model/components/__init__.py -------------------------------------------------------------------------------- /octo/model/components/base.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | from octo.utils.typing import Sequence 6 | 7 | 8 | @flax.struct.dataclass 9 | class TokenGroup: 10 | """A group of tokens that have semantic meaning together (e.g. 
the tokens for a single observation) 11 | 12 | Attributes: 13 | tokens: jax.Array of shape (..., n_tokens, token_dim) 14 | mask: jax.Array of shape (..., n_tokens) indicating which tokens are valid (1) vs padding (0) 15 | """ 16 | 17 | tokens: jax.typing.ArrayLike 18 | mask: jax.typing.ArrayLike 19 | 20 | @classmethod 21 | def create( 22 | cls, tokens: jax.typing.ArrayLike, mask: jax.typing.ArrayLike = None, **kwargs 23 | ): 24 | if mask is None: 25 | mask = jnp.ones(tokens.shape[:-1]) 26 | assert mask.ndim == tokens.ndim - 1 27 | return cls(tokens, mask, **kwargs) 28 | 29 | @classmethod 30 | def concatenate(cls, group_list: Sequence["TokenGroup"], axis=-2): 31 | data = jnp.concatenate([t.tokens for t in group_list], axis=axis) 32 | mask = jnp.concatenate([t.mask for t in group_list], axis=axis + 1) 33 | return cls(data, mask) 34 | -------------------------------------------------------------------------------- /octo/model/components/diffusion.py: -------------------------------------------------------------------------------- 1 | # copied from: https://raw.githubusercontent.com/rail-berkeley/bridge_data_v2/main/jaxrl_m/networks/diffusion_nets.py 2 | import logging 3 | from typing import Callable, Optional, Sequence 4 | 5 | import flax.linen as nn 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | default_init = nn.initializers.xavier_uniform 10 | 11 | 12 | def cosine_beta_schedule(timesteps, s=0.008): 13 | """ 14 | cosine schedule 15 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 16 | """ 17 | steps = timesteps + 1 18 | t = jnp.linspace(0, timesteps, steps) / timesteps 19 | alphas_cumprod = jnp.cos((t + s) / (1 + s) * jnp.pi * 0.5) ** 2 20 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 21 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 22 | return jnp.clip(betas, 0, 0.999) 23 | 24 | 25 | class ScoreActor(nn.Module): 26 | time_preprocess: nn.Module 27 | cond_encoder: nn.Module 28 | reverse_network: nn.Module 29 | 30 | def __call__(self, obs_enc, actions, time, train=False): 31 | """ 32 | Args: 33 | obs_enc: (bd..., obs_dim) where bd... 
is broadcastable to batch_dims 34 | actions: (batch_dims..., action_dim) 35 | time: (batch_dims..., 1) 36 | """ 37 | t_ff = self.time_preprocess(time) 38 | cond_enc = self.cond_encoder(t_ff, train=train) 39 | if obs_enc.shape[:-1] != cond_enc.shape[:-1]: 40 | new_shape = cond_enc.shape[:-1] + (obs_enc.shape[-1],) 41 | logging.debug( 42 | "Broadcasting obs_enc from %s to %s", obs_enc.shape, new_shape 43 | ) 44 | obs_enc = jnp.broadcast_to(obs_enc, new_shape) 45 | 46 | reverse_input = jnp.concatenate([cond_enc, obs_enc, actions], axis=-1) 47 | eps_pred = self.reverse_network(reverse_input, train=train) 48 | return eps_pred 49 | 50 | 51 | class FourierFeatures(nn.Module): 52 | output_size: int 53 | learnable: bool = True 54 | 55 | @nn.compact 56 | def __call__(self, x: jax.Array): 57 | if self.learnable: 58 | w = self.param( 59 | "kernel", 60 | nn.initializers.normal(0.2), 61 | (self.output_size // 2, x.shape[-1]), 62 | jnp.float32, 63 | ) 64 | f = 2 * jnp.pi * x @ w.T 65 | else: 66 | half_dim = self.output_size // 2 67 | f = jnp.log(10000) / (half_dim - 1) 68 | f = jnp.exp(jnp.arange(half_dim) * -f) 69 | f = x * f 70 | return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1) 71 | 72 | 73 | class MLP(nn.Module): 74 | hidden_dims: Sequence[int] 75 | activation: Callable = nn.swish 76 | activate_final: bool = False 77 | use_layer_norm: bool = False 78 | dropout_rate: Optional[float] = None 79 | 80 | @nn.compact 81 | def __call__(self, x: jax.Array, train: bool = False) -> jax.Array: 82 | for i, size in enumerate(self.hidden_dims): 83 | x = nn.Dense(size, kernel_init=default_init())(x) 84 | 85 | if i + 1 < len(self.hidden_dims) or self.activate_final: 86 | if self.dropout_rate is not None and self.dropout_rate > 0: 87 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 88 | if self.use_layer_norm: 89 | x = nn.LayerNorm()(x) 90 | x = self.activation(x) 91 | return x 92 | 93 | 94 | class MLPResNetBlock(nn.Module): 95 | features: int 96 | act: Callable 97 | dropout_rate: float = None 98 | use_layer_norm: bool = False 99 | 100 | @nn.compact 101 | def __call__(self, x, train: bool = False): 102 | residual = x 103 | if self.dropout_rate is not None and self.dropout_rate > 0: 104 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 105 | if self.use_layer_norm: 106 | x = nn.LayerNorm()(x) 107 | x = nn.Dense(self.features * 4)(x) 108 | x = self.act(x) 109 | x = nn.Dense(self.features)(x) 110 | 111 | if residual.shape != x.shape: 112 | residual = nn.Dense(self.features)(residual) 113 | 114 | return residual + x 115 | 116 | 117 | class MLPResNet(nn.Module): 118 | num_blocks: int 119 | out_dim: int 120 | dropout_rate: float = None 121 | use_layer_norm: bool = False 122 | hidden_dim: int = 256 123 | activation: Callable = nn.swish 124 | 125 | @nn.compact 126 | def __call__(self, x: jax.typing.ArrayLike, train: bool = False) -> jax.Array: 127 | x = nn.Dense(self.hidden_dim, kernel_init=default_init())(x) 128 | for _ in range(self.num_blocks): 129 | x = MLPResNetBlock( 130 | self.hidden_dim, 131 | act=self.activation, 132 | use_layer_norm=self.use_layer_norm, 133 | dropout_rate=self.dropout_rate, 134 | )(x, train=train) 135 | 136 | x = self.activation(x) 137 | x = nn.Dense(self.out_dim, kernel_init=default_init())(x) 138 | return x 139 | 140 | 141 | def create_diffusion_model( 142 | out_dim: int, 143 | time_dim: int, 144 | num_blocks: int, 145 | dropout_rate: float, 146 | hidden_dim: int, 147 | use_layer_norm: bool, 148 | ): 149 | return ScoreActor( 150 | 
FourierFeatures(time_dim, learnable=True), 151 | MLP((2 * time_dim, time_dim)), 152 | MLPResNet( 153 | num_blocks, 154 | out_dim, 155 | dropout_rate=dropout_rate, 156 | hidden_dim=hidden_dim, 157 | use_layer_norm=use_layer_norm, 158 | ), 159 | ) 160 | -------------------------------------------------------------------------------- /octo/model/components/film_conditioning_layer.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | 7 | class FilmConditioning(nn.Module): 8 | @nn.compact 9 | def __call__(self, conv_filters: jnp.ndarray, conditioning: jnp.ndarray): 10 | """Applies FiLM conditioning to a convolutional feature map. 11 | 12 | Args: 13 | conv_filters: A tensor of shape [batch_size, height, width, channels]. 14 | conditioning: A tensor of shape [batch_size, conditioning_size]. 15 | 16 | Returns: 17 | A tensor of shape [batch_size, height, width, channels]. 18 | """ 19 | projected_cond_add = nn.Dense( 20 | features=conv_filters.shape[-1], 21 | kernel_init=nn.initializers.zeros, 22 | bias_init=nn.initializers.zeros, 23 | )(conditioning) 24 | projected_cond_mult = nn.Dense( 25 | features=conv_filters.shape[-1], 26 | kernel_init=nn.initializers.zeros, 27 | bias_init=nn.initializers.zeros, 28 | )(conditioning) 29 | 30 | projected_cond_add = projected_cond_add[:, None, None, :] 31 | projected_cond_mult = projected_cond_mult[:, None, None, :] 32 | 33 | return conv_filters * (1 + projected_cond_add) + projected_cond_mult 34 | -------------------------------------------------------------------------------- /octo/model/components/tokenizers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import Dict, Optional, Sequence 4 | 5 | import flax 6 | import flax.linen as nn 7 | import jax 8 | import jax.numpy as jnp 9 | from jax.scipy.stats import norm 10 | import numpy as np 11 | 12 | from octo.model.components.base import TokenGroup 13 | from octo.model.components.transformer import MAPHead 14 | from octo.utils.spec import ModuleSpec 15 | 16 | EPS = 1e-6 17 | 18 | 19 | def generate_proper_pad_mask( 20 | tokens: jax.Array, 21 | pad_mask_dict: Optional[Dict[str, jax.Array]], 22 | keys: Sequence[str], 23 | ) -> jax.Array: 24 | if pad_mask_dict is None: 25 | logging.warning("No pad_mask_dict found. Nothing will be masked.") 26 | return jnp.ones(tokens.shape[:-1]) 27 | if not all([key in pad_mask_dict for key in keys]): 28 | logging.warning( 29 | f"pad_mask_dict missing keys {set(keys) - set(pad_mask_dict.keys())}." 30 | "Nothing will be masked." 31 | ) 32 | return jnp.ones(tokens.shape[:-1]) 33 | 34 | pad_mask = jnp.stack([pad_mask_dict[key] for key in keys], axis=-1) 35 | pad_mask = jnp.any(pad_mask, axis=-1) 36 | pad_mask = jnp.broadcast_to(pad_mask[..., None], tokens.shape[:-1]) 37 | return pad_mask 38 | 39 | 40 | class TokenLearner(nn.Module): 41 | """ 42 | Learns to map fixed-length sequence of tokens into specified number of tokens. 43 | 44 | Args: 45 | num_tokens (int): Number of output tokens. 46 | bottleneck_dim (int): Size of the hidden layers of the mapping MLP. 47 | dropout_rate (float): Rate of dropout applied in the mapping MLP. Defaults to no dropout. 
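Example: TokenLearner(num_tokens=8) maps inputs of shape (..., n_input_tokens, d) to (..., 8, d) by adding a learned positional embedding and attention-pooling with MAPHead.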
48 | """ 49 | 50 | num_tokens: int 51 | 52 | @nn.compact 53 | def __call__(self, inputs, train: bool = True): 54 | pos_embed = self.param( 55 | "pos_embed", 56 | nn.initializers.normal(stddev=0.02), 57 | (inputs.shape[-2], inputs.shape[-1]), 58 | ) 59 | x = inputs + jnp.broadcast_to(pos_embed, inputs.shape) 60 | x = nn.LayerNorm()(x) 61 | return MAPHead(num_readouts=self.num_tokens)(x, train=train) 62 | 63 | 64 | def regex_match(regex_keys, x): 65 | return any([re.match(r_key, x) for r_key in regex_keys]) 66 | 67 | 68 | def regex_filter(regex_keys, xs): 69 | return list(filter(lambda x: regex_match(regex_keys, x), xs)) 70 | 71 | 72 | class ImageTokenizer(nn.Module): 73 | """Image tokenizer that encodes image stack into tokens with optional FiLM conditioning. 74 | 75 | Args: 76 | encoder (ModuleSpec): Encoder class. 77 | use_token_learner (bool): Whether to use token learner. Defaults to False. 78 | num_tokens (int): Number of output tokens, only enforced when use_token_learner is True. 79 | obs_stack_keys (Sequence[str]): Which spatial observation inputs get stacked for encoder input. Supports regex. 80 | task_stack_keys (Sequence[str]): Which spatial task inputs get stacked for encoder input. Supports regex. 81 | task_film_keys (Sequence[str]): Which non-spatial task keys get passed into FiLM conditioning. Supports regex. 82 | """ 83 | 84 | encoder: ModuleSpec 85 | use_token_learner: bool = False 86 | num_tokens: int = 8 87 | conditioning_type: str = "none" 88 | obs_stack_keys: Sequence[str] = ("image_.*", "depth_.*") 89 | task_stack_keys: Sequence[str] = tuple() 90 | task_film_keys: Sequence[str] = tuple() 91 | proper_pad_mask: bool = True 92 | 93 | @nn.compact 94 | def __call__( 95 | self, 96 | observations, 97 | tasks=None, 98 | train: bool = True, 99 | ): 100 | def extract_inputs(keys, inputs, check_spatial=False): 101 | extracted_outputs = [] 102 | for key in keys: 103 | if check_spatial: 104 | assert len(inputs[key].shape) >= 4 105 | extracted_outputs.append(inputs[key]) 106 | return jnp.concatenate(extracted_outputs, axis=-1) 107 | 108 | obs_stack_keys = regex_filter(self.obs_stack_keys, sorted(observations.keys())) 109 | if len(obs_stack_keys) == 0: 110 | logging.info( 111 | f"No image inputs matching {self.obs_stack_keys} were found." 112 | "Skipping tokenizer entirely." 113 | ) 114 | assert self.proper_pad_mask, "Cannot skip unless using proper_pad_mask." 115 | return None 116 | 117 | # stack all spatial observation and task inputs 118 | enc_inputs = extract_inputs(obs_stack_keys, observations, check_spatial=True) 119 | if self.task_stack_keys: 120 | needed_task_keys = regex_filter(self.task_stack_keys, observations.keys()) 121 | # if any task inputs are missing, replace with zero padding (TODO: be more flexible) 122 | for k in needed_task_keys: 123 | if k not in tasks: 124 | logging.info( 125 | f"No task inputs matching {k} were found. Replacing with zero padding." 126 | ) 127 | tasks = flax.core.copy( 128 | tasks, {k: jnp.zeros_like(observations[k][:, 0])} 129 | ) 130 | task_stack_keys = regex_filter(self.task_stack_keys, sorted(tasks.keys())) 131 | if len(task_stack_keys) == 0: 132 | raise ValueError( 133 | f"No task inputs matching {self.task_stack_keys} were found." 
134 | ) 135 | task_inputs = extract_inputs(task_stack_keys, tasks, check_spatial=True) 136 | task_inputs = task_inputs[:, None].repeat(enc_inputs.shape[1], axis=1) 137 | enc_inputs = jnp.concatenate([enc_inputs, task_inputs], axis=-1) 138 | b, t, h, w, c = enc_inputs.shape 139 | enc_inputs = jnp.reshape(enc_inputs, (b * t, h, w, c)) 140 | 141 | # extract non-spatial FiLM inputs 142 | encoder_input_kwargs = {} 143 | if self.task_film_keys: 144 | film_inputs = extract_inputs(self.task_film_keys, tasks) 145 | film_inputs = film_inputs[:, None].repeat(t, axis=1) 146 | encoder_input_kwargs.update( 147 | {"cond_var": jnp.reshape(film_inputs, (b * t, -1))} 148 | ) 149 | 150 | # run visual encoder 151 | encoder_def = ModuleSpec.instantiate(self.encoder)() 152 | image_tokens = encoder_def(enc_inputs, **encoder_input_kwargs) 153 | image_tokens = jnp.reshape(image_tokens, (b, t, -1, image_tokens.shape[-1])) 154 | 155 | if self.use_token_learner: 156 | image_tokens = TokenLearner(num_tokens=self.num_tokens)( 157 | image_tokens, train=train 158 | ) 159 | 160 | if self.proper_pad_mask: 161 | pad_mask = generate_proper_pad_mask( 162 | image_tokens, 163 | observations.get("pad_mask_dict", None), 164 | obs_stack_keys, 165 | ) 166 | else: 167 | pad_mask = jnp.ones(image_tokens.shape[:-1]) 168 | return TokenGroup(image_tokens, pad_mask) 169 | 170 | 171 | class LanguageTokenizer(nn.Module): 172 | """ 173 | Language tokenizer that embeds text input IDs into continuous language embeddings. Supports pre-trained HF models. 174 | 175 | Args: 176 | num_tokens (int): Number of output tokens (not enforced). 177 | encoder (str, optional): Optional HuggingFace AutoModel name for encoding input IDs. 178 | finetune_encoder (bool, optional): Optional finetune last layers of the language model. 179 | """ 180 | 181 | encoder: str = None 182 | finetune_encoder: bool = False 183 | proper_pad_mask: bool = True 184 | 185 | def setup(self): 186 | if self.encoder is not None: 187 | from transformers import AutoConfig, FlaxAutoModel, FlaxT5EncoderModel 188 | 189 | config = AutoConfig.from_pretrained(self.encoder) 190 | if "t5" in self.encoder: 191 | self.hf_model = FlaxT5EncoderModel(config).module 192 | else: 193 | self.hf_model = FlaxAutoModel.from_config(config).module 194 | 195 | def __call__( 196 | self, 197 | observations, 198 | tasks=None, 199 | train: bool = True, 200 | ): 201 | if "language_instruction" not in tasks: 202 | logging.warning("No language inputs found. Skipping tokenizer entirely.") 203 | assert self.proper_pad_mask, "Cannot skip unless using proper pad mask." 204 | return None 205 | 206 | if not isinstance(tasks["language_instruction"], (jax.Array, np.ndarray)): 207 | assert ( 208 | self.encoder is not None 209 | ), "Received language tokens but no encoder specified." 
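# Shapes below are illustrative: when the instruction arrives as a dict of
# token ids (e.g. produced by HFTokenizer), the HF encoder maps it to a
# last_hidden_state of shape (batch, num_lang_tokens, embed_dim); when it
# arrives as a precomputed (batch, embed_dim) embedding (e.g. from
# MuseEmbedding), the else-branch below just adds a singleton token axis.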
210 | tokens = self.hf_model(**tasks["language_instruction"]).last_hidden_state 211 | else: 212 | # add a # tokens dimension to language 213 | if tasks["language_instruction"].ndim == 2: 214 | tokens = tasks["language_instruction"][:, None, :] 215 | else: 216 | tokens = tasks["language_instruction"] 217 | 218 | if not self.finetune_encoder: 219 | tokens = jax.lax.stop_gradient(tokens) 220 | 221 | # TODO: incorporate padding info from language tokens here too 222 | if self.proper_pad_mask: 223 | pad_mask = generate_proper_pad_mask( 224 | tokens, 225 | tasks.get("pad_mask_dict", None), 226 | ("language_instruction",), 227 | ) 228 | else: 229 | pad_mask = jnp.ones(tokens.shape[:-1]) 230 | 231 | return TokenGroup(tokens, pad_mask) 232 | 233 | 234 | class BinTokenizer(nn.Module): 235 | """ 236 | Tokenizes continuous inputs via dimension-wise binning in given range. 237 | 238 | Args: 239 | n_bins (int): Number of discrete bins per dimension. 240 | bin_type (str): Type of binning. ['uniform', 'normal' = Gaussian] 241 | low (float): Lower bound for bin range. 242 | high (float): Upper bound for bin range. 243 | """ 244 | 245 | n_bins: int = 256 246 | bin_type: str = "uniform" 247 | low: float = 0 248 | high: float = 1 249 | 250 | def setup(self): 251 | if self.bin_type == "uniform": 252 | self.thresholds = jnp.linspace(self.low, self.high, self.n_bins + 1) 253 | elif self.bin_type == "normal": 254 | self.thresholds = norm.ppf(jnp.linspace(EPS, 1 - EPS, self.n_bins + 1)) 255 | else: 256 | raise ValueError( 257 | f"Binning type {self.bin_type} not supported in BinTokenizer." 258 | ) 259 | 260 | def __call__(self, inputs): 261 | if self.bin_type == "uniform": 262 | inputs = jnp.clip(inputs, self.low + EPS, self.high - EPS) 263 | inputs = inputs[..., None] 264 | token_one_hot = (inputs < self.thresholds[1:]) & ( 265 | inputs >= self.thresholds[:-1] 266 | ).astype(jnp.uint8) 267 | output_tokens = jnp.argmax(token_one_hot, axis=-1) 268 | return output_tokens 269 | 270 | def decode(self, inputs): 271 | one_hot = jax.nn.one_hot(inputs, self.n_bins) 272 | bin_avgs = (self.thresholds[1:] + self.thresholds[:-1]) / 2 273 | outputs = jnp.sum(one_hot * bin_avgs, axis=-1) 274 | return outputs 275 | 276 | 277 | class LowdimObsTokenizer(BinTokenizer): 278 | """ 279 | Tokenizer for non-spatial observations. Optionally discretizes into bins per dimension (see BinTokenizer). 280 | 281 | Args: 282 | obs_keys (Sequence[str]): List of non-spatial keys to concatenate & tokenize. Supports regex. 283 | discretize (bool): If True, discretizes inputs per dimension, see BinTokenizer. 284 | """ 285 | 286 | obs_keys: Sequence[str] = tuple() 287 | discretize: bool = False 288 | proper_pad_mask: bool = True 289 | 290 | def __call__(self, observations, *unused_args, **unused_kwargs): 291 | assert self.obs_keys, "Need to specify observation keys to tokenize." 292 | if len(regex_filter(self.obs_keys, sorted(observations.keys()))) == 0: 293 | logging.warning( 294 | f"No observation inputs matching {self.obs_keys} were found." 295 | "Skipping tokenizer entirely." 296 | ) 297 | assert self.proper_pad_mask, "Cannot skip unless using proper pad mask." 298 | return None 299 | 300 | tokenizer_inputs = [] 301 | for o_key in self.obs_keys: 302 | for key in filter(re.compile(o_key).match, sorted(observations.keys())): 303 | assert ( 304 | len(observations[key].shape) == 3 305 | ), f"Only supports non-spatial inputs but {key} has shape {observations[key].shape}." 
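# e.g. a hypothetical observations["proprio"] entry of shape
# (batch, window, proprio_dim) passes this check, whereas spatial inputs such
# as (batch, window, H, W, C) images would not.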
306 | tokenizer_inputs.append(observations[key]) 307 | tokenizer_inputs = jnp.concatenate(tokenizer_inputs, axis=-1) 308 | if self.discretize: 309 | tokenized_inputs = super().__call__(tokenizer_inputs) 310 | tokens = jax.nn.one_hot(tokenized_inputs, self.n_bins) 311 | else: 312 | tokens = tokenizer_inputs[..., None] 313 | mask = jnp.ones(tokens.shape[:-1]) 314 | return TokenGroup(tokens, mask) 315 | -------------------------------------------------------------------------------- /octo/model/components/transformer.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py 2 | from typing import Callable, Optional 3 | 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from octo.model.components.base import TokenGroup 9 | from octo.utils.typing import Dtype, PRNGKey, Shape, Union 10 | 11 | 12 | class AddPositionEmbs(nn.Module): 13 | """Adds learned positional embeddings to the inputs. 14 | 15 | Attributes: 16 | posemb_init: positional embedding initializer. 17 | """ 18 | 19 | posemb_init: Callable[[PRNGKey, Shape, Dtype], jax.Array] 20 | 21 | @nn.compact 22 | def __call__(self, inputs): 23 | """Applies the AddPositionEmbs module. 24 | 25 | Args: 26 | inputs: Inputs to the layer. 27 | 28 | Returns: 29 | Output tensor with shape `(bs, timesteps, in_dim)`. 30 | """ 31 | # inputs.shape is (batch_size, seq_len, emb_dim). 32 | assert inputs.ndim == 3, ( 33 | "Number of dimensions should be 3," " but it is: %d" % inputs.ndim 34 | ) 35 | pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) 36 | pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape) 37 | return inputs + pe 38 | 39 | 40 | class MlpBlock(nn.Module): 41 | """Transformer MLP / feed-forward block.""" 42 | 43 | mlp_dim: int 44 | dtype: Dtype = jnp.float32 45 | out_dim: Optional[int] = None 46 | dropout_rate: float = 0.1 47 | kernel_init: Callable[ 48 | [PRNGKey, Shape, Dtype], jax.Array 49 | ] = nn.initializers.xavier_uniform() 50 | bias_init: Callable[[PRNGKey, Shape, Dtype], jax.Array] = nn.initializers.normal( 51 | stddev=1e-6 52 | ) 53 | 54 | @nn.compact 55 | def __call__(self, inputs, *, deterministic): 56 | """Applies Transformer MlpBlock module.""" 57 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim 58 | x = nn.Dense( 59 | features=self.mlp_dim, 60 | dtype=self.dtype, 61 | kernel_init=self.kernel_init, 62 | bias_init=self.bias_init, 63 | )(inputs) 64 | x = nn.gelu(x) 65 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 66 | output = nn.Dense( 67 | features=actual_out_dim, 68 | dtype=self.dtype, 69 | kernel_init=self.kernel_init, 70 | bias_init=self.bias_init, 71 | )(x) 72 | output = nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic) 73 | return output 74 | 75 | 76 | class MAPHead(nn.Module): 77 | """Multihead Attention Pooling. 
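Pools a (possibly masked) token sequence into `num_readouts` output tokens: a learned probe cross-attends over the inputs, followed by a residual MLP block.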
78 | 79 | From https://github.com/google-research/big_vision/blob/main/big_vision/models/vit.py 80 | """ 81 | 82 | mlp_dim: Optional[int] = None # Defaults to 4x input dim 83 | num_heads: int = 8 84 | num_readouts: int = 1 85 | 86 | @nn.compact 87 | def __call__(self, x: Union[jax.Array, TokenGroup], train=True): 88 | if isinstance(x, TokenGroup): 89 | x, mask = x.tokens, x.mask 90 | else: 91 | mask = None 92 | 93 | *batch_dims, l, d = x.shape 94 | x = x.reshape(-1, l, d) 95 | batch_size = x.shape[0] 96 | 97 | probe = self.param( 98 | "probe", 99 | nn.initializers.xavier_uniform(), 100 | (1, self.num_readouts, d), 101 | x.dtype, 102 | ) 103 | probe = jnp.tile(probe, [batch_size, 1, 1]) 104 | 105 | if mask is not None: 106 | mask = mask.reshape(-1, l) 107 | mask = jnp.broadcast_to( 108 | mask[:, None, None, :], (batch_size, 1, self.num_readouts, l) 109 | ) 110 | 111 | out = nn.MultiHeadDotProductAttention( 112 | num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform() 113 | )(probe, x, mask=mask) 114 | 115 | # TODO: dropout on head? 116 | y = nn.LayerNorm()(out) 117 | 118 | out = out + MlpBlock(mlp_dim=nn.merge_param("mlp_dim", self.mlp_dim, 4 * d))( 119 | y, deterministic=not train 120 | ) 121 | out = out.reshape(*batch_dims, self.num_readouts, d) 122 | return out 123 | 124 | 125 | class Encoder1DBlock(nn.Module): 126 | """Transformer encoder layer. 127 | 128 | Attributes: 129 | inputs: input data. 130 | mlp_dim: dimension of the mlp on top of attention block. 131 | dtype: the dtype of the computation (default: float32). 132 | dropout_rate: dropout rate. 133 | attention_dropout_rate: dropout for attention heads. 134 | deterministic: bool, deterministic or not (to apply dropout). 135 | num_heads: Number of heads in nn.MultiHeadDotProductAttention 136 | """ 137 | 138 | mlp_dim: int 139 | num_heads: int 140 | dtype: Dtype = jnp.float32 141 | dropout_rate: float = 0.1 142 | attention_dropout_rate: float = 0.1 143 | 144 | @nn.compact 145 | def __call__(self, inputs, attention_mask, *, deterministic): 146 | """Applies Encoder1DBlock module. 147 | 148 | Args: 149 | inputs: Inputs to the layer. 150 | deterministic: Dropout will not be applied when set to true. 151 | 152 | Returns: 153 | output after transformer encoder block. 154 | """ 155 | 156 | # Attention block. 157 | assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}" 158 | x = nn.LayerNorm(dtype=self.dtype)(inputs) 159 | x = nn.MultiHeadDotProductAttention( 160 | dtype=self.dtype, 161 | kernel_init=nn.initializers.xavier_uniform(), 162 | broadcast_dropout=False, 163 | deterministic=deterministic, 164 | dropout_rate=self.attention_dropout_rate, 165 | num_heads=self.num_heads, 166 | )(x, x, mask=attention_mask) 167 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 168 | x = x + inputs 169 | 170 | # MLP block. 171 | y = nn.LayerNorm(dtype=self.dtype)(x) 172 | y = MlpBlock( 173 | mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate 174 | )(y, deterministic=deterministic) 175 | 176 | return x + y 177 | 178 | 179 | class Transformer(nn.Module): 180 | """Transformer Model Encoder for sequence to sequence translation. 181 | 182 | Attributes: 183 | num_layers: number of layers 184 | mlp_dim: dimension of the mlp on top of attention block 185 | num_heads: Number of heads in nn.MultiHeadDotProductAttention 186 | dropout_rate: dropout rate. 187 | attention_dropout_rate: dropout rate in self attention. 
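For example, common_transformer_sizes("vit_s") at the bottom of this file returns kwargs for this module with num_layers=12, mlp_dim=1536, num_attention_heads=6, dropout_rate=0.0.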
188 | """ 189 | 190 | num_layers: int 191 | mlp_dim: int 192 | num_attention_heads: int 193 | dropout_rate: float = 0.1 194 | attention_dropout_rate: float = 0.1 195 | add_position_embedding: bool = False 196 | 197 | @nn.compact 198 | def __call__(self, x, attention_mask, *, train): 199 | """Applies Transformer model on the inputs. 200 | 201 | Args: 202 | x: Inputs to the layer. 203 | train: Set to `True` when training. 204 | 205 | Returns: 206 | output of a transformer encoder. 207 | """ 208 | assert x.ndim == 3 # (batch, len, emb) 209 | 210 | if self.add_position_embedding: 211 | x = AddPositionEmbs( 212 | posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. 213 | name="posembed_input", 214 | )(x) 215 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 216 | 217 | # Input Encoder 218 | for lyr in range(self.num_layers): 219 | x = Encoder1DBlock( 220 | mlp_dim=self.mlp_dim, 221 | dropout_rate=self.dropout_rate, 222 | attention_dropout_rate=self.attention_dropout_rate, 223 | name=f"encoderblock_{lyr}", 224 | num_heads=self.num_attention_heads, 225 | )(x, attention_mask, deterministic=not train) 226 | encoded = nn.LayerNorm(name="encoder_norm")(x) 227 | 228 | return encoded 229 | 230 | 231 | def common_transformer_sizes(transformer_size: str) -> (int, dict): 232 | """ 233 | Args: 234 | transformer_size (str): The size of the transformer. One of "dummy", "vanilla", "vit_s", "vit_b", "vit_l", "vit_h" 235 | 236 | Returns: 237 | token_embedding_size (int): The size of the token embeddings 238 | transformer_kwargs (dict): The kwargs to pass to the transformer 239 | 240 | """ 241 | assert transformer_size in [ 242 | "dummy", 243 | "vanilla", 244 | "vit_t", 245 | "vit_s", 246 | "vit_b", 247 | "vit_l", 248 | "vit_h", 249 | ] 250 | default_params = { 251 | "attention_dropout_rate": 0.0, 252 | "add_position_embedding": False, 253 | } 254 | 255 | TRANSFORMER_SIZES = { 256 | "dummy": dict( 257 | num_layers=1, 258 | mlp_dim=256, 259 | num_attention_heads=2, 260 | dropout_rate=0.1, 261 | ), 262 | "vanilla": dict( 263 | num_layers=4, 264 | mlp_dim=1024, 265 | num_attention_heads=8, 266 | dropout_rate=0.1, 267 | ), 268 | "vit_t": dict( 269 | num_layers=12, 270 | mlp_dim=768, 271 | num_attention_heads=3, 272 | dropout_rate=0.0, 273 | ), 274 | "vit_s": dict( 275 | num_layers=12, 276 | mlp_dim=1536, 277 | num_attention_heads=6, 278 | dropout_rate=0.0, 279 | ), 280 | "vit_b": dict( 281 | num_layers=12, 282 | mlp_dim=3072, 283 | num_attention_heads=12, 284 | dropout_rate=0.0, 285 | ), 286 | "vit_l": dict( 287 | num_layers=24, 288 | mlp_dim=4096, 289 | num_attention_heads=16, 290 | dropout_rate=0.1, 291 | ), 292 | "vit_h": dict( 293 | num_layers=32, 294 | mlp_dim=5120, 295 | num_attention_heads=16, 296 | dropout_rate=0.1, 297 | ), 298 | } 299 | 300 | TOKEN_DIMS = { 301 | "dummy": 256, 302 | "vanilla": 256, 303 | "vit_t": 192, 304 | "vit_s": 384, 305 | "vit_b": 768, 306 | "vit_l": 1024, 307 | "vit_h": 1280, 308 | } 309 | 310 | return TOKEN_DIMS[transformer_size], { 311 | **default_params, 312 | **TRANSFORMER_SIZES[transformer_size], 313 | } 314 | -------------------------------------------------------------------------------- /octo/model/components/unet.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | default_init = nn.initializers.xavier_uniform 8 | 9 | 10 | @jax.jit 11 | def mish(x): 12 | return x * jnp.tanh(jax.nn.softplus(x)) 13 | 14 | 15 
| def unet_squaredcos_cap_v2(timesteps, s=0.008): 16 | t = jnp.linspace(0, timesteps, timesteps + 1) / timesteps 17 | alphas_cumprod = jnp.cos((t + s) / (1 + s) * jnp.pi * 0.5) ** 2 18 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 19 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 20 | return jnp.clip(betas, 0, 0.999) 21 | 22 | 23 | class SinusoidalPosEmb(nn.Module): 24 | features: int 25 | 26 | @nn.compact 27 | def __call__(self, x: jax.Array): 28 | half_features = self.features // 2 29 | emb = jnp.log(10000) / (half_features - 1) 30 | emb = jnp.exp(jnp.arange(half_features) * -emb) 31 | emb = x * emb 32 | emb = jnp.concatenate((jnp.sin(emb), jnp.cos(emb)), axis=-1) 33 | return emb 34 | 35 | 36 | class Downsample1d(nn.Module): 37 | features: int 38 | 39 | @nn.compact 40 | def __call__(self, x: jax.Array): 41 | return nn.Conv(self.features, kernel_size=(3,), strides=(2,))(x) 42 | 43 | 44 | class Upsample1d(nn.Module): 45 | features: int 46 | 47 | @nn.compact 48 | def __call__(self, x: jax.Array): 49 | return nn.ConvTranspose(self.features, kernel_size=(4,), strides=(2,))(x) 50 | 51 | 52 | class Conv1dBlock(nn.Module): 53 | """ 54 | Conv1d --> GroupNorm --> Mish 55 | """ 56 | 57 | features: int 58 | kernel_size: int 59 | n_groups: int 60 | 61 | @nn.compact 62 | def __call__(self, x: jax.Array): 63 | x = nn.Conv( 64 | self.features, 65 | kernel_size=(self.kernel_size,), 66 | strides=1, 67 | padding=self.kernel_size // 2, 68 | )(x) 69 | x = nn.GroupNorm(self.n_groups)(x) 70 | x = mish(x) 71 | return x 72 | 73 | 74 | class ConditionalResidualBlock1D(nn.Module): 75 | features: int 76 | kernel_size: int = 3 77 | n_groups: int = 8 78 | residual_proj: bool = False 79 | 80 | @nn.compact 81 | def __call__(self, x: jax.Array, cond: jax.Array): 82 | residual = x 83 | x = Conv1dBlock( 84 | self.features, kernel_size=self.kernel_size, n_groups=self.n_groups 85 | )(x) 86 | 87 | cond_features = 2 * self.features 88 | cond = nn.Dense(cond_features, kernel_init=default_init())(mish(cond)) 89 | scale, bias = jnp.split(cond, 2, axis=-1) 90 | # Scale, bias are (B, D) and x is shape (B, T, D) 91 | # We need to broadcast over time, so choose axis = -2 92 | x = x * jnp.expand_dims(scale, axis=-2) + jnp.expand_dims(bias, axis=-2) 93 | x = Conv1dBlock( 94 | self.features, kernel_size=self.kernel_size, n_groups=self.n_groups 95 | )(x) 96 | 97 | if self.residual_proj: 98 | residual = nn.Conv(self.features, kernel_size=(1,), strides=1, padding=0)( 99 | residual 100 | ) 101 | 102 | return x + residual 103 | 104 | 105 | class ConditionalUnet1D(nn.Module): 106 | down_features: Tuple[int] = (256, 512, 1024) 107 | mid_layers: int = 2 108 | kernel_size: int = 3 109 | n_groups: int = 8 110 | time_features: int = 256 111 | 112 | @nn.compact 113 | def __call__(self, obs, action, time, train: bool = False): 114 | # Embed the timestep 115 | time = SinusoidalPosEmb(self.time_features)(time) 116 | time = nn.Dense(4 * self.time_features, kernel_init=default_init())(time) 117 | time = mish(time) 118 | time = nn.Dense(self.time_features, kernel_init=default_init())(time) # (B, D) 119 | # Define conditioning as time and observation 120 | cond = jnp.concatenate((obs, time), axis=-1) 121 | 122 | # Project Down 123 | hidden_reps = [] 124 | for i, features in enumerate(self.down_features): 125 | # We always project to the dimension on the first residual connection. 
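# Illustrative shape trace (hypothetical sizes): with action of shape
# (B, T, action_dim) and the default down_features=(256, 512, 1024), this loop
# maps the feature axis to 256 -> 512 -> 1024, Downsample1d roughly halves T
# after the first two stages, and the activations stored in hidden_reps feed
# the skip connections of the up-projection path further below.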
126 | action = ConditionalResidualBlock1D( 127 | features, 128 | kernel_size=self.kernel_size, 129 | n_groups=self.n_groups, 130 | residual_proj=True, 131 | )(action, cond) 132 | action = ConditionalResidualBlock1D( 133 | features, kernel_size=self.kernel_size, n_groups=self.n_groups 134 | )(action, cond) 135 | if i != 0: 136 | hidden_reps.append(action) 137 | if i != len(self.down_features) - 1: 138 | # If we aren't the last step, downsample 139 | action = Downsample1d(features)(action) 140 | 141 | # Mid Layers 142 | for _ in range(self.mid_layers): 143 | action = ConditionalResidualBlock1D( 144 | self.down_features[-1], 145 | kernel_size=self.kernel_size, 146 | n_groups=self.n_groups, 147 | )(action, cond) 148 | 149 | # Project Up 150 | for features, hidden_rep in reversed( 151 | list(zip(self.down_features[:-1], hidden_reps, strict=False)) 152 | ): 153 | action = jnp.concatenate( 154 | (action, hidden_rep), axis=-1 155 | ) # concat on feature dim 156 | # Always project since we are adding in the hidden rep 157 | action = ConditionalResidualBlock1D( 158 | features, 159 | kernel_size=self.kernel_size, 160 | n_groups=self.n_groups, 161 | residual_proj=True, 162 | )(action, cond) 163 | action = ConditionalResidualBlock1D( 164 | features, kernel_size=self.kernel_size, n_groups=self.n_groups 165 | )(action, cond) 166 | # Upsample 167 | action = Upsample1d(features)(action) 168 | 169 | # Should be the same as the input shape 170 | action = Conv1dBlock( 171 | self.down_features[0], kernel_size=self.kernel_size, n_groups=self.n_groups 172 | )(action) 173 | return action 174 | -------------------------------------------------------------------------------- /octo/model/components/vit_encoders.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoders more suitable for ViT architectures. 3 | 4 | - PatchEncoder: Just patchifies the image 5 | - SmallStem: 3 conv layers, then patchifies the image (from xiao et al. 2021) 6 | - ViTResnet: ResNetv2, followed by patchification (from google-research/vision_transformer) 7 | """ 8 | 9 | import functools as ft 10 | from typing import Callable, Sequence, TypeVar 11 | 12 | from flax import linen as nn 13 | import jax.numpy as jnp 14 | 15 | from octo.model.components.film_conditioning_layer import FilmConditioning 16 | 17 | T = TypeVar("T") 18 | 19 | 20 | def normalize_images(img, img_norm_type="default"): 21 | if img_norm_type == "default": 22 | # put pixels in [-1, 1] 23 | return img.astype(jnp.float32) / 127.5 - 1.0 24 | elif img_norm_type == "imagenet": 25 | # put pixels in [0,1] 26 | img = img.astype(jnp.float32) / 255 27 | assert img.shape[-1] % 3 == 0, "images should have rgb channels!" 
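# e.g. an early-fusion stack with img.shape[-1] == 6 (two RGB frames
# concatenated on the channel axis) passes this check; the ImageNet mean/std
# below are then tiled once per frame so each frame is normalized with the
# same statistics.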
28 | 29 | # define pixel-wise mean/std stats calculated from ImageNet 30 | mean = jnp.array([0.485, 0.456, 0.406]).reshape((1, 1, 1, 3)) 31 | std = jnp.array([0.229, 0.224, 0.225]).reshape((1, 1, 1, 3)) 32 | 33 | # tile mean and std (to account for stacked early_fusion images) 34 | num_tile = (1, 1, 1, int(img.shape[-1] / 3)) 35 | mean_tile = jnp.tile(mean, num_tile) 36 | std_tile = jnp.tile(std, num_tile) 37 | 38 | # tile the mean/std, normalize image, and return 39 | return (img - mean_tile) / std_tile 40 | raise ValueError() 41 | 42 | 43 | def weight_standardize(w, axis, eps): 44 | """Subtracts mean and divides by standard deviation.""" 45 | w = w - jnp.mean(w, axis=axis) 46 | w = w / (jnp.std(w, axis=axis) + eps) 47 | return w 48 | 49 | 50 | class StdConv(nn.Conv): 51 | """Convolution with weight standardization.""" 52 | 53 | def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T: 54 | param = super().param(name, init_fn, *init_args) 55 | if name == "kernel": 56 | param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5) 57 | return param 58 | 59 | 60 | class PatchEncoder(nn.Module): 61 | """Takes an image and breaks it up into patches of size (patch_size x patch_size), 62 | applying a fully connected network to each patch individually. 63 | 64 | The default "encoder" used by most ViTs in practice. 65 | """ 66 | 67 | use_film: bool = False 68 | patch_size: int = 32 69 | num_features: int = 512 70 | img_norm_type: str = "default" 71 | 72 | @nn.compact 73 | def __call__(self, observations: jnp.ndarray, train: bool = True, cond_var=None): 74 | expecting_cond_var = self.use_film 75 | received_cond_var = cond_var is not None 76 | assert ( 77 | expecting_cond_var == received_cond_var 78 | ), "Only pass in cond var iff model expecting cond var" 79 | x = normalize_images(observations, self.img_norm_type) 80 | x = nn.Conv( 81 | features=self.num_features, 82 | kernel_size=(self.patch_size, self.patch_size), 83 | strides=(self.patch_size, self.patch_size), 84 | padding="VALID", 85 | name="embedding", 86 | )(x) 87 | if self.use_film: 88 | assert cond_var is not None, "Cond var is None, nothing to condition on" 89 | x = FilmConditioning()(x, cond_var) 90 | return x 91 | 92 | 93 | class SmallStem(nn.Module): 94 | """Passes the image through a few light-weight convolutional layers, 95 | before patchifying the image. Empirically useful for many computer vision tasks. 
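With the defaults below, the four stride-2 convolutions downsample by 16x and the final embedding convolution (stride patch_size // 16) brings the total to one token per patch_size x patch_size patch.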
96 | 97 | See Xiao et al: Early Convolutions Help Transformers See Better 98 | """ 99 | 100 | use_film: bool = False 101 | patch_size: int = 32 102 | kernel_sizes: tuple = (3, 3, 3, 3) 103 | strides: tuple = (2, 2, 2, 2) 104 | features: tuple = (32, 96, 192, 384) 105 | padding: tuple = (1, 1, 1, 1) 106 | num_features: int = 512 107 | img_norm_type: str = "default" 108 | 109 | @nn.compact 110 | def __call__(self, observations: jnp.ndarray, train: bool = True, cond_var=None): 111 | expecting_cond_var = self.use_film 112 | received_cond_var = cond_var is not None 113 | assert ( 114 | expecting_cond_var == received_cond_var 115 | ), "Only pass in cond var iff model expecting cond var" 116 | 117 | x = normalize_images(observations, self.img_norm_type) 118 | for n, (kernel_size, stride, features, padding) in enumerate( 119 | zip( 120 | self.kernel_sizes, 121 | self.strides, 122 | self.features, 123 | self.padding, 124 | ) 125 | ): 126 | x = StdConv( 127 | features=features, 128 | kernel_size=(kernel_size, kernel_size), 129 | strides=(stride, stride), 130 | padding=padding, 131 | )(x) 132 | x = nn.GroupNorm()(x) 133 | x = nn.relu(x) 134 | 135 | x = nn.Conv( 136 | features=self.num_features, 137 | kernel_size=(self.patch_size // 16, self.patch_size // 16), 138 | strides=(self.patch_size // 16, self.patch_size // 16), 139 | padding="VALID", 140 | name="embedding", 141 | )(x) 142 | if self.use_film: 143 | assert cond_var is not None, "Cond var is None, nothing to condition on" 144 | x = FilmConditioning()(x, cond_var) 145 | return x 146 | 147 | 148 | class ResidualUnit(nn.Module): 149 | """Bottleneck ResNet block.""" 150 | 151 | features: int 152 | strides: Sequence[int] = (1, 1) 153 | 154 | @nn.compact 155 | def __call__(self, x): 156 | needs_projection = x.shape[-1] != self.features * 4 or self.strides != (1, 1) 157 | 158 | residual = x 159 | if needs_projection: 160 | residual = StdConv( 161 | features=self.features * 4, 162 | kernel_size=(1, 1), 163 | strides=self.strides, 164 | use_bias=False, 165 | name="conv_proj", 166 | )(residual) 167 | residual = nn.GroupNorm(name="gn_proj")(residual) 168 | 169 | y = StdConv( 170 | features=self.features, kernel_size=(1, 1), use_bias=False, name="conv1" 171 | )(x) 172 | y = nn.GroupNorm(name="gn1")(y) 173 | y = nn.relu(y) 174 | y = StdConv( 175 | features=self.features, 176 | kernel_size=(3, 3), 177 | strides=self.strides, 178 | use_bias=False, 179 | name="conv2", 180 | )(y) 181 | y = nn.GroupNorm(name="gn2")(y) 182 | y = nn.relu(y) 183 | y = StdConv( 184 | features=self.features * 4, kernel_size=(1, 1), use_bias=False, name="conv3" 185 | )(y) 186 | 187 | y = nn.GroupNorm(name="gn3", scale_init=nn.initializers.zeros)(y) 188 | y = nn.relu(residual + y) 189 | return y 190 | 191 | 192 | class ResNetStage(nn.Module): 193 | """A ResNet stage.""" 194 | 195 | block_size: Sequence[int] 196 | nout: int 197 | first_stride: Sequence[int] 198 | 199 | @nn.compact 200 | def __call__(self, x): 201 | x = ResidualUnit(self.nout, strides=self.first_stride, name="unit1")(x) 202 | for i in range(1, self.block_size): 203 | x = ResidualUnit(self.nout, strides=(1, 1), name=f"unit{i + 1}")(x) 204 | return x 205 | 206 | 207 | class ViTResnet(nn.Module): 208 | """Resnet-v2 architecture used in the original ViT paper for hybrid (Resnet+ViT) architectures 209 | 210 | Mostly copied from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py 211 | 212 | There exist pre-trained parameters here: github.com/google-research/vision_transformer/ 213 | """ 214 | 
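# Illustrative configuration: ResNet26FILM at the bottom of this file is just
# ViTResnet(use_film=True, num_layers=(2, 2, 2, 2)); with width=1 the four
# stages use nout = 64, 128, 256, 512, and each bottleneck ResidualUnit
# expands its output to 4 * nout channels.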
215 | use_film: bool = False 216 | width: int = 1 217 | num_layers: tuple = tuple() 218 | img_norm_type: str = "default" 219 | 220 | @nn.compact 221 | def __call__(self, observations: jnp.ndarray, train: bool = True, cond_var=None): 222 | expecting_cond_var = self.use_film 223 | received_cond_var = cond_var is not None 224 | assert ( 225 | expecting_cond_var == received_cond_var 226 | ), "Only pass in cond var iff model expecting cond var" 227 | 228 | x = normalize_images(observations, self.img_norm_type) 229 | width = int(64 * self.width) 230 | x = StdConv( 231 | features=width, 232 | kernel_size=(7, 7), 233 | strides=(2, 2), 234 | use_bias=False, 235 | name="conv_root", 236 | )(x) 237 | x = nn.GroupNorm(name="gn_root")(x) 238 | x = nn.relu(x) 239 | x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME") 240 | 241 | if self.num_layers: 242 | x = ResNetStage( 243 | block_size=self.num_layers[0], 244 | nout=width, 245 | first_stride=(1, 1), 246 | name="block1", 247 | )(x) 248 | for i, block_size in enumerate(self.num_layers[1:], 1): 249 | x = ResNetStage( 250 | block_size=block_size, 251 | nout=width * 2**i, 252 | first_stride=(2, 2), 253 | name=f"block{i + 1}", 254 | )(x) 255 | if self.use_film: 256 | assert ( 257 | cond_var is not None 258 | ), "Cond var is None, nothing to condition on" 259 | x = FilmConditioning()(x, cond_var) 260 | else: 261 | if self.use_film: 262 | assert cond_var is not None, "Cond var is None, nothing to condition on" 263 | x = FilmConditioning()(x, cond_var) 264 | 265 | return x 266 | 267 | 268 | class SmallStem16(SmallStem): 269 | patch_size: int = 16 270 | 271 | 272 | class SmallStem32(SmallStem): 273 | patch_size: int = 32 274 | 275 | 276 | class ResNet26FILM(ViTResnet): 277 | use_film: bool = True 278 | num_layers: tuple = (2, 2, 2, 2) 279 | 280 | 281 | vit_encoder_configs = { 282 | "patchify-32-film": ft.partial( 283 | PatchEncoder, 284 | use_film=True, 285 | patch_size=32, 286 | ), 287 | "patchify-16-film": ft.partial( 288 | PatchEncoder, 289 | use_film=True, 290 | patch_size=16, 291 | ), 292 | "small-stem-8-film": ft.partial( 293 | SmallStem, 294 | use_film=True, 295 | patch_size=16, 296 | kernel_sizes=(3, 3, 3), 297 | strides=(2, 2, 2), 298 | features=(32, 96, 192), 299 | padding=(1, 1, 1), 300 | ), 301 | "small-stem-16": ft.partial( 302 | SmallStem, 303 | patch_size=16, 304 | ), 305 | "small-stem-16-film": ft.partial( 306 | SmallStem, 307 | use_film=True, 308 | patch_size=16, 309 | ), 310 | "small-stem-32-film": ft.partial( 311 | SmallStem, 312 | use_film=True, 313 | patch_size=32, 314 | ), 315 | "resnetv2-26-film": ft.partial( 316 | ViTResnet, 317 | use_film=True, 318 | num_layers=(2, 2, 2, 2), 319 | ), 320 | "resnetv2-50-film": ft.partial( 321 | ViTResnet, 322 | use_film=True, 323 | num_layers=(3, 4, 6, 3), 324 | ), 325 | } 326 | -------------------------------------------------------------------------------- /octo/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/octo/utils/__init__.py -------------------------------------------------------------------------------- /octo/utils/gym_wrappers.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import logging 3 | from typing import Dict, Optional, Sequence, Tuple 4 | 5 | import gym 6 | import gym.spaces 7 | import jax 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | 12 
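# Illustrative example for the helper defined next: with a history deque
# holding two copies of the reset observation and num_obs=1, stack_and_pad
# stacks every key to shape (2, ...) and returns timestep_pad_mask == [0., 1.],
# marking the repeated first frame as padding.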
| def stack_and_pad(history: deque, num_obs: int): 13 | """ 14 | Converts a list of observation dictionaries (`history`) into a single observation dictionary 15 | by stacking the values. Adds a padding mask to the observation that denotes which timesteps 16 | represent padding based on the number of observations seen so far (`num_obs`). 17 | """ 18 | horizon = len(history) 19 | full_obs = {k: np.stack([dic[k] for dic in history]) for k in history[0]} 20 | pad_length = horizon - min(num_obs, horizon) 21 | timestep_pad_mask = np.ones(horizon) 22 | timestep_pad_mask[:pad_length] = 0 23 | full_obs["timestep_pad_mask"] = timestep_pad_mask 24 | return full_obs 25 | 26 | 27 | def space_stack(space: gym.Space, repeat: int): 28 | """ 29 | Creates new Gym space that represents the original observation/action space 30 | repeated `repeat` times. 31 | """ 32 | 33 | if isinstance(space, gym.spaces.Box): 34 | return gym.spaces.Box( 35 | low=np.repeat(space.low[None], repeat, axis=0), 36 | high=np.repeat(space.high[None], repeat, axis=0), 37 | dtype=space.dtype, 38 | ) 39 | elif isinstance(space, gym.spaces.Discrete): 40 | return gym.spaces.MultiDiscrete([space.n] * repeat) 41 | elif isinstance(space, gym.spaces.Dict): 42 | return gym.spaces.Dict( 43 | {k: space_stack(v, repeat) for k, v in space.spaces.items()} 44 | ) 45 | else: 46 | raise ValueError(f"Space {space} is not supported by Octo Gym wrappers.") 47 | 48 | 49 | def listdict2dictlist(LD): 50 | return {k: [dic[k] for dic in LD] for k in LD[0]} 51 | 52 | 53 | def add_octo_env_wrappers( 54 | env: gym.Env, 55 | action_proprio_metadata: dict, 56 | horizon: int, 57 | exec_horizon: int, 58 | resize_size: Optional[Dict[str, Tuple]] = None, 59 | use_temp_ensembling: bool = True, 60 | ): 61 | """Adds env wrappers for proprio normalization, action prediction, 62 | image resizing, and history stacking. 63 | 64 | Arguments: 65 | env: gym Env 66 | action_proprio_metadata: dict containing proprio stats for NormalizeProprio 67 | horizon: int for HistoryWrapper 68 | exec_horizon: int for RHCWrapper or TemporalEnsembleWrapper 69 | resize_size: None or tuple or list of tuples for ResizeImageWrapper 70 | use_temp_ensembling: whether to use TemporalEnsembleWrapper or RHCWrapper 71 | """ 72 | env = NormalizeProprio(env, action_proprio_metadata) 73 | env = ResizeImageWrapper(env, resize_size) 74 | 75 | env = HistoryWrapper(env, horizon) 76 | 77 | if use_temp_ensembling: 78 | env = TemporalEnsembleWrapper(env, exec_horizon) 79 | else: 80 | env = RHCWrapper(env, exec_horizon) 81 | 82 | return env 83 | 84 | 85 | class HistoryWrapper(gym.Wrapper): 86 | """ 87 | Accumulates the observation history into `horizon` size chunks. If the length of the history 88 | is less than the length of the horizon, we pad the history to the full horizon length. 89 | A `timestep_pad_mask` key is added to the final observation dictionary that denotes which timesteps 90 | are padding. 
91 | """ 92 | 93 | def __init__(self, env: gym.Env, horizon: int): 94 | super().__init__(env) 95 | self.horizon = horizon 96 | 97 | self.history = deque(maxlen=self.horizon) 98 | self.num_obs = 0 99 | 100 | self.observation_space = space_stack(self.env.observation_space, self.horizon) 101 | 102 | def step(self, action): 103 | obs, reward, done, trunc, info = self.env.step(action) 104 | self.num_obs += 1 105 | self.history.append(obs) 106 | assert len(self.history) == self.horizon 107 | full_obs = stack_and_pad(self.history, self.num_obs) 108 | 109 | return full_obs, reward, done, trunc, info 110 | 111 | def reset(self, **kwargs): 112 | obs, info = self.env.reset(**kwargs) 113 | self.num_obs = 1 114 | self.history.extend([obs] * self.horizon) 115 | full_obs = stack_and_pad(self.history, self.num_obs) 116 | 117 | return full_obs, info 118 | 119 | 120 | class RHCWrapper(gym.Wrapper): 121 | """ 122 | Performs receding horizon control. The policy returns `pred_horizon` actions and 123 | we execute `exec_horizon` of them. 124 | """ 125 | 126 | def __init__(self, env: gym.Env, exec_horizon: int): 127 | super().__init__(env) 128 | self.exec_horizon = exec_horizon 129 | 130 | def step(self, actions): 131 | if self.exec_horizon == 1 and len(actions.shape) == 1: 132 | actions = actions[None] 133 | assert len(actions) >= self.exec_horizon 134 | rewards = [] 135 | observations = [] 136 | infos = [] 137 | 138 | for i in range(self.exec_horizon): 139 | obs, reward, done, trunc, info = self.env.step(actions[i]) 140 | observations.append(obs) 141 | rewards.append(reward) 142 | infos.append(info) 143 | 144 | if done or trunc: 145 | break 146 | 147 | infos = listdict2dictlist(infos) 148 | infos["rewards"] = rewards 149 | infos["observations"] = observations 150 | 151 | return obs, np.sum(rewards), done, trunc, infos 152 | 153 | 154 | class TemporalEnsembleWrapper(gym.Wrapper): 155 | """ 156 | Performs temporal ensembling from https://arxiv.org/abs/2304.13705 157 | At every timestep we execute an exponential weighted average of the last 158 | `pred_horizon` predictions for that timestep. 
159 | """ 160 | 161 | def __init__(self, env: gym.Env, pred_horizon: int, exp_weight: int = 0): 162 | super().__init__(env) 163 | self.pred_horizon = pred_horizon 164 | self.exp_weight = exp_weight 165 | 166 | self.act_history = deque(maxlen=self.pred_horizon) 167 | 168 | self.action_space = space_stack(self.env.action_space, self.pred_horizon) 169 | 170 | def step(self, actions): 171 | assert len(actions) >= self.pred_horizon 172 | 173 | self.act_history.append(actions[: self.pred_horizon]) 174 | num_actions = len(self.act_history) 175 | 176 | # select the predicted action for the current step from the history of action chunk predictions 177 | curr_act_preds = np.stack( 178 | [ 179 | pred_actions[i] 180 | for (i, pred_actions) in zip( 181 | range(num_actions - 1, -1, -1), self.act_history 182 | ) 183 | ] 184 | ) 185 | 186 | # more recent predictions get exponentially *less* weight than older predictions 187 | weights = np.exp(-self.exp_weight * np.arange(num_actions)) 188 | weights = weights / weights.sum() 189 | # compute the weighted average across all predictions for this timestep 190 | action = np.sum(weights[:, None] * curr_act_preds, axis=0) 191 | 192 | return self.env.step(action) 193 | 194 | def reset(self, **kwargs): 195 | self.act_history = deque(maxlen=self.pred_horizon) 196 | return self.env.reset(**kwargs) 197 | 198 | 199 | class ResizeImageWrapper(gym.ObservationWrapper): 200 | """ 201 | Resizes images from a robot environment to the size the model expects. 202 | 203 | We attempt to match the resizing operations done in the model's data pipeline. 204 | First, we resize the image using lanczos interpolation to match the resizing done 205 | when converting the raw data into RLDS. Then, we crop and resize the image with 206 | bilinear interpolation to match the average of the crop and resize image augmentation 207 | performed during training. 208 | """ 209 | 210 | def __init__( 211 | self, 212 | env: gym.Env, 213 | resize_size: Optional[Dict[str, Tuple]] = None, 214 | augmented_keys: Sequence[str] = ("image_primary",), 215 | avg_scale: float = 0.9, 216 | avg_ratio: float = 1.0, 217 | ): 218 | super().__init__(env) 219 | assert isinstance( 220 | self.observation_space, gym.spaces.Dict 221 | ), "Only Dict observation spaces are supported." 
222 | spaces = self.observation_space.spaces 223 | self.resize_size = resize_size 224 | self.augmented_keys = augmented_keys 225 | if len(self.augmented_keys) > 0: 226 | new_height = tf.clip_by_value(tf.sqrt(avg_scale / avg_ratio), 0, 1) 227 | new_width = tf.clip_by_value(tf.sqrt(avg_scale * avg_ratio), 0, 1) 228 | height_offset = (1 - new_height) / 2 229 | width_offset = (1 - new_width) / 2 230 | self.bounding_box = tf.stack( 231 | [ 232 | height_offset, 233 | width_offset, 234 | height_offset + new_height, 235 | width_offset + new_width, 236 | ], 237 | ) 238 | 239 | if resize_size is None: 240 | self.keys_to_resize = {} 241 | else: 242 | self.keys_to_resize = { 243 | f"image_{i}": resize_size[i] for i in resize_size.keys() 244 | } 245 | logging.info(f"Resizing images: {self.keys_to_resize}") 246 | for k, size in self.keys_to_resize.items(): 247 | spaces[k] = gym.spaces.Box( 248 | low=0, 249 | high=255, 250 | shape=size + (3,), 251 | dtype=np.uint8, 252 | ) 253 | self.observation_space = gym.spaces.Dict(spaces) 254 | 255 | def observation(self, observation): 256 | for k, size in self.keys_to_resize.items(): 257 | image = tf.image.resize( 258 | observation[k], size=size, method="lanczos3", antialias=True 259 | ) 260 | 261 | # if this image key was augmented with random resizes and crops, 262 | # we perform the average of the augmentation here 263 | if k in self.augmented_keys: 264 | image = tf.image.crop_and_resize( 265 | image[None], self.bounding_box[None], [0], size 266 | )[0] 267 | 268 | image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy() 269 | 270 | observation[k] = image 271 | return observation 272 | 273 | 274 | class NormalizeProprio(gym.ObservationWrapper): 275 | """ 276 | Un-normalizes the proprio. 277 | """ 278 | 279 | def __init__( 280 | self, 281 | env: gym.Env, 282 | action_proprio_metadata: dict, 283 | ): 284 | self.action_proprio_metadata = jax.tree_map( 285 | lambda x: np.array(x), 286 | action_proprio_metadata, 287 | is_leaf=lambda x: isinstance(x, list), 288 | ) 289 | super().__init__(env) 290 | 291 | def normalize(self, data, metadata): 292 | mask = metadata.get("mask", np.ones_like(metadata["mean"], dtype=bool)) 293 | return np.where( 294 | mask, 295 | (data - metadata["mean"]) / (metadata["std"] + 1e-8), 296 | data, 297 | ) 298 | 299 | def observation(self, obs): 300 | if "proprio" in self.action_proprio_metadata: 301 | obs["proprio"] = self.normalize( 302 | obs["proprio"], self.action_proprio_metadata["proprio"] 303 | ) 304 | else: 305 | assert "proprio" not in obs, "Cannot normalize proprio without metadata." 306 | return obs 307 | -------------------------------------------------------------------------------- /octo/utils/jax_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Any, Optional, Sequence 4 | 5 | import jax 6 | from jax.experimental import multihost_utils 7 | from jax.experimental.compilation_cache import compilation_cache 8 | import jax.numpy as jnp 9 | import numpy as np 10 | 11 | 12 | def host_broadcast_str(x: str) -> str: 13 | """Broadcast_one_to_all, but with a string. 
Strings should all be the same length.""" 14 | multihost_utils.assert_equal( 15 | len(x), f"String lengths are not equal: got {len(x)} for {jax.process_index()}" 16 | ) 17 | encoded = np.array([ord(c) for c in x], dtype=np.uint8) 18 | encoded = multihost_utils.broadcast_one_to_all(encoded) 19 | return "".join([chr(u) for u in encoded]) 20 | 21 | 22 | def shard_along_axis(x: Any, devices: Sequence[jax.Device], axis: int = 0) -> jax.Array: 23 | """Shard a PyTree of arrays along a given axis, putting them on device in 24 | the process. Works in multi-host setting as long as PyTrees are equal on all 25 | hosts.""" 26 | sharding = jax.sharding.NamedSharding( 27 | jax.sharding.Mesh(devices, "x"), 28 | jax.sharding.PartitionSpec(*([None] * axis + ["x"])), 29 | ) 30 | x = jax.tree_map(jnp.array, x) 31 | return jax.tree_map( 32 | lambda arr: jax.make_array_from_callback( 33 | arr.shape, sharding, lambda index: arr[index] 34 | ), 35 | x, 36 | ) 37 | 38 | 39 | def merge_along_axis(x: Any, axis: int = 0) -> jax.Array: 40 | """Convert a PyTree of host-local arrays to a global array, concatenating and sharding along 41 | `axis`.""" 42 | return multihost_utils.host_local_array_to_global_array( 43 | x, 44 | jax.sharding.Mesh(jax.devices(), "x"), 45 | jax.sharding.PartitionSpec(*([None] * axis + ["x"])), 46 | ) 47 | 48 | 49 | def split_along_axis(x: Any, axis: int = 0) -> jax.Array: 50 | """Convert a PyTree of global arrays to a host-local array, splitting along `axis`.""" 51 | return multihost_utils.global_array_to_host_local_array( 52 | x, 53 | jax.sharding.Mesh(jax.devices(), "x"), 54 | jax.sharding.PartitionSpec(*([None] * axis + ["x"])), 55 | ) 56 | 57 | 58 | def replicate(x: Any, devices: Optional[Sequence[jax.Device]] = None) -> jax.Array: 59 | """Replicate a PyTree of arrays across devices. Works in multi-host setting 60 | as long as PyTrees are equal on all hosts.""" 61 | if devices is None: 62 | devices = jax.devices() 63 | sharding = jax.sharding.PositionalSharding(devices).replicate() 64 | x = jax.tree_map(jnp.array, x) 65 | return jax.tree_map( 66 | lambda arr: jax.make_array_from_callback( 67 | arr.shape, sharding, lambda index: arr[index] 68 | ), 69 | x, 70 | ) 71 | 72 | 73 | def initialize_compilation_cache( 74 | cache_dir=os.path.expanduser("~/.jax_compilation_cache"), 75 | ): 76 | """Initializes the Jax persistent compilation cache.""" 77 | compilation_cache.initialize_cache(cache_dir) 78 | for logger in [logging.getLogger(name) for name in logging.root.manager.loggerDict]: 79 | logger.addFilter( 80 | lambda record: "Not writing persistent cache entry for" 81 | not in record.getMessage() 82 | ) 83 | -------------------------------------------------------------------------------- /octo/utils/spec.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import importlib 3 | from typing import Any, Dict, Tuple, TypedDict, Union 4 | 5 | 6 | class ModuleSpec(TypedDict): 7 | """A JSON-serializable representation of a function or class with some default args and kwargs to pass to 8 | it. Useful for specifying a particular class or function in a config file, while keeping it serializable 9 | and overridable from the command line using ml_collections. 
10 | 11 | Usage: 12 | 13 | # Preferred way to create a spec: 14 | >>> from octo.model.components.transformer import Transformer 15 | >>> spec = ModuleSpec.create(Transformer, num_layers=3) 16 | # Same as above using the fully qualified import string: 17 | >>> spec = ModuleSpec.create("octo.model.components.transformer:Transformer", num_layers=3) 18 | 19 | # Usage: 20 | >>> ModuleSpec.instantiate(spec) == partial(Transformer, num_layers=3) 21 | # can pass additional kwargs at instantiation time 22 | >>> transformer = ModuleSpec.instantiate(spec, num_heads=8) 23 | 24 | Note: ModuleSpec is just an alias for a dictionary (that is strongly typed), not a real class. So from 25 | your code's perspective, it is just a dictionary. 26 | 27 | module (str): The module the callable is located in 28 | name (str): The name of the callable in the module 29 | args (tuple): The args to pass to the callable 30 | kwargs (dict): The kwargs to pass to the callable 31 | """ 32 | 33 | module: str 34 | name: str 35 | args: Tuple[Any, ...] 36 | kwargs: Dict[str, Any] 37 | 38 | @staticmethod 39 | def create(callable_or_full_name: Union[str, callable], *args, **kwargs) -> "ModuleSpec": # type: ignore 40 | """Create a module spec from a callable or import string. 41 | 42 | Args: 43 | callable_or_full_name (str or object): Either the object itself or a fully qualified import string 44 | (e.g. "octo.model.components.transformer:Transformer") 45 | args (tuple, optional): Passed into callable upon instantiation. 46 | kwargs (dict, optional): Passed into callable upon instantiation. 47 | """ 48 | if isinstance(callable_or_full_name, str): 49 | assert callable_or_full_name.count(":") == 1, ( 50 | "If passing in a string, it must be a fully qualified import string " 51 | "(e.g. 'octo.model.components.transformer:Transformer')" 52 | ) 53 | module, name = callable_or_full_name.split(":") 54 | else: 55 | module, name = _infer_full_name(callable_or_full_name) 56 | 57 | return ModuleSpec(module=module, name=name, args=args, kwargs=kwargs) 58 | 59 | @staticmethod 60 | def instantiate(spec: "ModuleSpec"): # type: ignore 61 | if set(spec.keys()) != {"module", "name", "args", "kwargs"}: 62 | raise ValueError( 63 | f"Expected ModuleSpec, but got {spec}. " 64 | "ModuleSpec must have keys 'module', 'name', 'args', and 'kwargs'." 65 | ) 66 | cls = _import_from_string(spec["module"], spec["name"]) 67 | return partial(cls, *spec["args"], **spec["kwargs"]) 68 | 69 | @staticmethod 70 | def to_string(spec: "ModuleSpec"): # type: ignore 71 | return ( 72 | f"{spec['module']}:{spec['name']}" 73 | f"({', '.join(spec['args'])}" 74 | f"{', ' if spec['args'] and spec['kwargs'] else ''}" 75 | f"{', '.join(f'{k}={v}' for k, v in spec['kwargs'].items())})" 76 | ) 77 | 78 | 79 | def _infer_full_name(o: object): 80 | if hasattr(o, "__module__") and hasattr(o, "__name__"): 81 | return o.__module__, o.__name__ 82 | else: 83 | raise ValueError( 84 | f"Could not infer identifier for {o}. " 85 | "Please pass in a fully qualified import string instead " 86 | "e.g. 
'octo.model.components.transformer:Transformer'" 87 | ) 88 | 89 | 90 | def _import_from_string(module_string: str, name: str): 91 | try: 92 | module = importlib.import_module(module_string) 93 | return getattr(module, name) 94 | except Exception as e: 95 | raise ValueError(f"Could not import {module_string}:{name}") from e 96 | -------------------------------------------------------------------------------- /octo/utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Sequence, Union 2 | 3 | import jax 4 | 5 | PRNGKey = jax.random.KeyArray 6 | PyTree = Union[jax.typing.ArrayLike, Mapping[str, "PyTree"]] 7 | Config = Union[Any, Mapping[str, "Config"]] 8 | Params = Mapping[str, PyTree] 9 | Data = Mapping[str, PyTree] 10 | Shape = Sequence[int] 11 | Dtype = jax.typing.DTypeLike 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | # https://github.com/psf/black 3 | line-length = 88 4 | target-version = ["py310"] 5 | exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)" 6 | 7 | [tool.isort] 8 | profile = "black" 9 | line_length = 88 10 | force_sort_within_sections = "True" 11 | order_by_type = "False" 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym >= 0.26 2 | numpy == 1.24.3 3 | ml_dtypes == 0.2.0 4 | chex == 0.1.85 5 | optax == 0.1.5 6 | tensorflow_probability == 0.23.0 7 | tensorflow == 2.15.0 8 | jax == 0.4.20 9 | distrax == 0.1.5 10 | flax == 0.7.5 11 | ml_collections >= 0.1.0 12 | tqdm >= 4.60.0 13 | absl-py >= 0.12.0 14 | scipy >= 1.6.0 15 | wandb >= 0.12.14 16 | einops >= 0.6.1 17 | imageio >= 2.31.1 18 | moviepy >= 1.0.3 19 | pre-commit == 3.3.3 20 | transformers >= 4.34.1 21 | tensorflow_hub >= 0.14.0 22 | tensorflow_text >= 2.13.0 23 | tensorflow_datasets == 4.9.2 24 | tensorflow_graphics == 2021.12.3 25 | dlimp @ git+https://github.com/kvablack/dlimp@5edaa4691567873d495633f2708982b42edf1972 26 | plotly >= 5.16.1 27 | matplotlib 28 | -------------------------------------------------------------------------------- /scripts/configs/config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | from ml_collections.config_dict import FieldReference, placeholder 3 | 4 | from octo.data.utils.text_processing import MuseEmbedding 5 | from octo.model.components.action_heads import MSEActionHead 6 | from octo.model.components.tokenizers import ImageTokenizer 7 | from octo.model.components.transformer import common_transformer_sizes 8 | from octo.model.components.vit_encoders import SmallStem16 9 | from octo.utils.spec import ModuleSpec 10 | 11 | 12 | def get_config( 13 | transformer_size="vit_s", 14 | ): 15 | print("Creating config with: ", locals()) 16 | window_size = FieldReference(default=1) 17 | return ConfigDict( 18 | dict( 19 | seed=42, 20 | num_steps=2e6, 21 | save_dir=placeholder(str), 22 | model=get_model_config(transformer_size), 23 | window_size=window_size, 24 | dataset_kwargs=get_dataset_config(window_size), 25 | optimizer=dict( 26 | learning_rate=dict( 27 | name="rsqrt", 28 | init_value=0.0, 29 | peak_value=3e-4, 30 | warmup_steps=2000, 31 | timescale=10000, 32 | ), 33 | weight_decay=0.1, 34 | 
clip_gradient=1.0, 35 | frozen_keys=tuple(), 36 | ), 37 | prefetch_num_batches=0, 38 | start_step=placeholder(int), 39 | log_interval=100, 40 | eval_interval=5000, 41 | viz_interval=20000, 42 | save_interval=10000, 43 | val_kwargs=dict( 44 | val_shuffle_buffer_size=1000, 45 | num_val_batches=16, 46 | ), 47 | viz_kwargs=dict( 48 | eval_batch_size=128, 49 | trajs_for_metrics=100, 50 | trajs_for_viz=8, 51 | samples_per_state=8, 52 | ), 53 | resume_path=placeholder(str), 54 | text_processor=ModuleSpec.create(MuseEmbedding), 55 | pretrained_loaders=tuple(), 56 | wandb=dict( 57 | project="octo", 58 | group=placeholder(str), 59 | entity=placeholder(str), 60 | ), 61 | wandb_resume_id=placeholder(str), 62 | eval_datasets=(), 63 | ) 64 | ) 65 | 66 | 67 | def get_model_config(transformer_size): 68 | """ 69 | Transformer_size is one of ["dummy", "vanilla", "vit_t" "vit_s", "vit_b", "vit_l", "vit_h"] 70 | 71 | This model stacks all the images from different cameras together, and passes it through 72 | a small convolutional stem before entering the transformer. 73 | 74 | The action head pools all the observation token embeddings, and passes it through a small MLP 75 | before predicting the action using a MSE loss. 76 | """ 77 | token_embedding_size, transformer_kwargs = common_transformer_sizes( 78 | transformer_size 79 | ) 80 | return dict( 81 | observation_tokenizers=dict( 82 | image=ModuleSpec.create( 83 | ImageTokenizer, 84 | obs_stack_keys=["image_.*"], 85 | task_stack_keys=["image_.*"], 86 | task_film_keys=["language_instruction"], 87 | encoder=ModuleSpec.create(SmallStem16, use_film=True), 88 | ), 89 | ), 90 | task_tokenizers=dict(), 91 | heads=dict( 92 | action=ModuleSpec.create( 93 | MSEActionHead, 94 | action_horizon=1, 95 | action_dim=7, 96 | readout_key="obs", 97 | ), 98 | ), 99 | readouts=dict(), 100 | token_embedding_size=token_embedding_size, 101 | transformer_kwargs=transformer_kwargs, 102 | max_horizon=10, 103 | use_correct_attention=True, 104 | ) 105 | 106 | 107 | def get_dataset_config(window_size=1): 108 | task_augmentation = dict( 109 | task_augment_strategy="delete_task_conditioning", 110 | task_augment_kwargs=dict( 111 | keep_image_prob=0.5, 112 | ), 113 | ) 114 | 115 | return dict( 116 | # oxe_kwargs will generate dataset_kwargs_list and sampling weights 117 | oxe_kwargs=dict( 118 | data_mix=placeholder(str), 119 | data_dir=placeholder(str), 120 | load_camera_views=("primary", "wrist"), 121 | load_depth=False, 122 | ), 123 | traj_transform_kwargs=dict( 124 | window_size=window_size, 125 | action_horizon=1, 126 | goal_relabeling_strategy="uniform", 127 | subsample_length=100, 128 | **task_augmentation, 129 | ), 130 | frame_transform_kwargs=dict( 131 | resize_size=dict(primary=(256, 256)), 132 | image_dropout_prob=0.0, 133 | image_augment_kwargs=dict( 134 | primary=dict( 135 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 136 | random_brightness=[0.2], 137 | random_contrast=[0.8, 1.2], 138 | random_saturation=[0.8, 1.2], 139 | random_hue=[0.1], 140 | augment_order=[ 141 | "random_resized_crop", 142 | "random_brightness", 143 | "random_contrast", 144 | "random_saturation", 145 | "random_hue", 146 | ], 147 | ) 148 | ), 149 | num_parallel_calls=200, 150 | ), 151 | traj_transform_threads=48, # shared between all datasets 152 | traj_read_threads=48, # shared between all datasets 153 | shuffle_buffer_size=100000, # shared between all datasets 154 | batch_size=512, 155 | balance_weights=True, 156 | ) 157 | 
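A note on the `window_size` FieldReference used in config.py above: because the same reference object is stored both at the top level of the config and inside `dataset_kwargs.traj_transform_kwargs`, a single override (for example `--config.window_size=2` when the config is loaded through `config_flags`) propagates to every field built from it. Below is a minimal sketch of that mechanism using only `ml_collections`; it is illustrative and not part of the repo.

from ml_collections import ConfigDict, FieldReference

window_size = FieldReference(default=1)
cfg = ConfigDict(
    dict(
        window_size=window_size,
        dataset_kwargs=dict(traj_transform_kwargs=dict(window_size=window_size)),
    )
)

# Updating the shared reference updates every config field that holds it.
window_size.set(2)
assert cfg.window_size == 2
assert cfg.dataset_kwargs.traj_transform_kwargs.window_size == 2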
-------------------------------------------------------------------------------- /scripts/configs/finetune_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | from ml_collections.config_dict import FieldReference, placeholder 3 | 4 | from octo.utils.spec import ModuleSpec 5 | 6 | 7 | def get_config(config_string="full,multimodal"): 8 | mode, task = config_string.split(",") 9 | assert task in ["image_conditioned", "language_conditioned", "multimodal"] 10 | assert mode in ["full", "head_only", "head_mlp_only"] 11 | 12 | # Fill this in for your own dataset! 13 | 14 | # There should be two image keys 15 | # first image key should be the third-person view (None if not used) 16 | # and second image key should be the wrist view (None if not used) 17 | 18 | FINETUNING_KWARGS = { 19 | "name": "bridge_dataset", 20 | "data_dir": "./tests/debug_dataset", 21 | "image_obs_keys": {"primary": "image_0", "wrist": None}, 22 | "proprio_obs_key": "proprio", 23 | "language_key": "language_instruction", 24 | "action_proprio_normalization_type": "normal", 25 | # We want to avoid normalizing the gripper 26 | "action_normalization_mask": [True, True, True, True, True, True, False], 27 | # standardize_fn is dynamically loaded from a file 28 | # for example: "experiments/kevin/custom_standardization_transforms.py:aloha_dataset_transform" 29 | "standardize_fn": ModuleSpec.create( 30 | "octo.data.oxe.oxe_standardization_transforms:bridge_dataset_transform", 31 | ), 32 | # If the default data loading speed is too slow, try these: 33 | # "num_parallel_reads": 8, # for reading from disk / GCS 34 | # "num_parallel_calls": 16, # for initial dataset construction 35 | } 36 | 37 | if mode == "full": 38 | frozen_keys = None 39 | elif mode == "head_only": 40 | frozen_keys = ("octo_transformer.*",) 41 | elif mode == "head_mlp_only": 42 | frozen_keys = ( 43 | "octo_transformer.*", 44 | "heads_*.map_head.probe", 45 | "heads_*.map_head.MultiHeadDotProductAttention_0.*", 46 | ) 47 | else: 48 | raise ValueError("Invalid mode") 49 | 50 | max_steps = FieldReference(50000) 51 | window_size = FieldReference(default=1) 52 | 53 | config = dict( 54 | pretrained_path=placeholder(str), 55 | pretrained_step=placeholder(int), 56 | batch_size=256, 57 | shuffle_buffer_size=10000, 58 | num_steps=max_steps, 59 | log_interval=100, 60 | eval_interval=5000, 61 | save_interval=5000, 62 | save_dir=placeholder(str), 63 | seed=42, 64 | wandb=dict( 65 | project="octo_finetune", group=placeholder(str), entity=placeholder(str) 66 | ), 67 | dataset_kwargs=FINETUNING_KWARGS, 68 | modality=task, 69 | finetuning_mode=mode, 70 | window_size=window_size, 71 | optimizer=dict( 72 | learning_rate=dict( 73 | name="cosine", 74 | init_value=0.0, 75 | peak_value=3e-4, 76 | warmup_steps=2000, 77 | decay_steps=max_steps, 78 | end_value=0.0, 79 | ), 80 | weight_decay=0.01, 81 | clip_gradient=1.0, 82 | frozen_keys=frozen_keys, 83 | grad_accumulation_steps=None, # if you are using grad accumulation, you need to adjust max_steps accordingly 84 | ), 85 | val_kwargs=dict( 86 | val_shuffle_buffer_size=1000, 87 | num_val_batches=16, 88 | ), 89 | viz_kwargs=dict( 90 | eval_batch_size=128, 91 | trajs_for_metrics=100, 92 | trajs_for_viz=8, 93 | samples_per_state=8, 94 | ), 95 | ) 96 | 97 | if task == "image_conditioned": 98 | goal_relabeling_strategy = "uniform" 99 | keep_image_prob = 1.0 100 | elif task == "language_conditioned": 101 | goal_relabeling_strategy = None 102 | keep_image_prob = 0.0 103 
| elif task == "multimodal": 104 | goal_relabeling_strategy = "uniform" 105 | keep_image_prob = 0.5 106 | else: 107 | raise ValueError("Invalid modality") 108 | 109 | traj_transform_kwargs = dict( 110 | window_size=window_size, 111 | action_horizon=4, 112 | goal_relabeling_strategy=goal_relabeling_strategy, 113 | task_augment_strategy="delete_task_conditioning", 114 | task_augment_kwargs=dict( 115 | keep_image_prob=keep_image_prob, 116 | ), 117 | # If the default data loading speed is too slow, try these: 118 | # num_parallel_calls=16, # for less CPU-intensive ops 119 | ) 120 | workspace_augment_kwargs = dict( 121 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 122 | random_brightness=[0.1], 123 | random_contrast=[0.9, 1.1], 124 | random_saturation=[0.9, 1.1], 125 | random_hue=[0.05], 126 | augment_order=[ 127 | "random_resized_crop", 128 | "random_brightness", 129 | "random_contrast", 130 | "random_saturation", 131 | "random_hue", 132 | ], 133 | ) 134 | wrist_augment_kwargs = dict( 135 | random_brightness=[0.1], 136 | random_contrast=[0.9, 1.1], 137 | random_saturation=[0.9, 1.1], 138 | random_hue=[0.05], 139 | augment_order=[ 140 | "random_brightness", 141 | "random_contrast", 142 | "random_saturation", 143 | "random_hue", 144 | ], 145 | ) 146 | frame_transform_kwargs = dict( 147 | resize_size={ 148 | "primary": (256, 256), # workspace (3rd person) camera is at 256x256 149 | "wrist": (128, 128), # wrist camera is at 128x128 150 | }, 151 | image_augment_kwargs=dict( 152 | primary=workspace_augment_kwargs, 153 | wrist=wrist_augment_kwargs, 154 | ), 155 | ) 156 | # If the default data loading speed is too slow, try these: 157 | config[ 158 | "frame_transform_threads" 159 | ] = 16 # for the most CPU-intensive ops (decoding, resizing, augmenting) 160 | 161 | config["traj_transform_kwargs"] = traj_transform_kwargs 162 | config["frame_transform_kwargs"] = frame_transform_kwargs 163 | return ConfigDict(config) 164 | -------------------------------------------------------------------------------- /scripts/configs/octo_pretrain_config.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import imp 3 | import os 4 | 5 | from ml_collections import ConfigDict, FieldReference 6 | 7 | get_base_config = imp.load_source( 8 | "config", os.path.join(os.path.dirname(__file__), "config.py") 9 | ).get_config 10 | 11 | from octo.data.utils.text_processing import HFTokenizer 12 | from octo.model.components.action_heads import DiffusionActionHead 13 | from octo.model.components.tokenizers import ImageTokenizer, LanguageTokenizer 14 | from octo.model.components.vit_encoders import SmallStem16 15 | from octo.utils.spec import ModuleSpec 16 | from octo.utils.train_utils import hf_weights_loader 17 | 18 | 19 | def update_config(config, **kwargs): 20 | updates = ConfigDict(kwargs) 21 | new_config = deepcopy(config) 22 | new_config.update(updates) 23 | return new_config 24 | 25 | 26 | def get_config(config_string=None): 27 | config = get_base_config(config_string) 28 | 29 | action_dim = FieldReference(7) 30 | 31 | config["model"]["observation_tokenizers"] = { 32 | "primary": ModuleSpec.create( 33 | ImageTokenizer, 34 | obs_stack_keys=["image_primary"], 35 | task_stack_keys=["image_primary"], 36 | encoder=ModuleSpec.create(SmallStem16), 37 | ), 38 | "wrist": ModuleSpec.create( 39 | ImageTokenizer, 40 | obs_stack_keys=["image_wrist"], 41 | task_stack_keys=["image_wrist"], 42 | encoder=ModuleSpec.create(SmallStem16), 43 | ), 44 | } 45 | 
config["model"]["task_tokenizers"] = { 46 | "language": ModuleSpec.create( 47 | LanguageTokenizer, 48 | encoder="t5-base", 49 | finetune_encoder=False, 50 | ), 51 | } 52 | config["model"]["repeat_task_tokens"] = True 53 | config["model"]["readouts"] = {"action": 1} 54 | config["model"]["heads"]["action"] = ModuleSpec.create( 55 | DiffusionActionHead, 56 | readout_key="readout_action", 57 | use_map=False, 58 | action_horizon=4, 59 | action_dim=action_dim, 60 | n_diffusion_samples=1, 61 | dropout_rate=0.0, 62 | ) 63 | 64 | # We augment differently for the primary and wrist cameras 65 | primary_augment_kwargs = dict( 66 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 67 | random_brightness=[0.1], 68 | random_contrast=[0.9, 1.1], 69 | random_saturation=[0.9, 1.1], 70 | random_hue=[0.05], 71 | augment_order=[ 72 | "random_resized_crop", 73 | "random_brightness", 74 | "random_contrast", 75 | "random_saturation", 76 | "random_hue", 77 | ], 78 | ) 79 | wrist_augment_kwargs = dict( 80 | random_brightness=[0.1], 81 | random_contrast=[0.9, 1.1], 82 | random_saturation=[0.9, 1.1], 83 | random_hue=[0.05], 84 | augment_order=[ 85 | "random_brightness", 86 | "random_contrast", 87 | "random_saturation", 88 | "random_hue", 89 | ], 90 | ) 91 | 92 | # ML-collections complains if the type of an existing field changes 93 | # so we delete and re-add the field 94 | 95 | del config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"] 96 | del config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"] 97 | 98 | config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"] = { 99 | "primary": (256, 256), # workspace camera is at 256x256 100 | "wrist": (128, 128), # wrist camera is at 128x128 101 | } 102 | config["dataset_kwargs"]["frame_transform_kwargs"]["image_augment_kwargs"] = { 103 | "primary": primary_augment_kwargs, 104 | "wrist": wrist_augment_kwargs, 105 | } 106 | 107 | config = update_config( 108 | config, 109 | num_steps=300000, 110 | window_size=2, 111 | optimizer=dict( 112 | frozen_keys=("*hf_model*",), 113 | ), 114 | dataset_kwargs=dict( 115 | oxe_kwargs=dict( 116 | data_mix="oxe_magic_soup", 117 | data_dir="gs://rail-orca-central2/resize_256_256", 118 | load_camera_views=("primary", "wrist"), 119 | load_depth=False, 120 | force_recompute_dataset_statistics=False, 121 | ), 122 | traj_transform_kwargs=dict( 123 | action_horizon=4, 124 | max_action_dim=action_dim, 125 | task_augment_strategy="delete_and_rephrase", 126 | task_augment_kwargs=dict( 127 | paraphrases_repo="rail-berkeley/OXE_paraphrases", 128 | paraphrases_filename="paraphrases_oxe.pkl", 129 | rephrase_prob=0.5, 130 | ), 131 | ), 132 | batch_size=512, 133 | shuffle_buffer_size=500000, 134 | balance_weights=True, 135 | ), 136 | text_processor=ModuleSpec.create( 137 | HFTokenizer, 138 | tokenizer_name="t5-base", 139 | encode_with_model=False, 140 | tokenizer_kwargs={ 141 | "max_length": 16, 142 | "padding": "max_length", 143 | "truncation": True, 144 | "return_tensors": "np", 145 | }, 146 | ), 147 | pretrained_loaders=( 148 | ModuleSpec.create( 149 | hf_weights_loader, 150 | hf_model="t5-base", 151 | ), 152 | ), 153 | eval_datasets=["bridge_dataset"], 154 | ) 155 | 156 | return config 157 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | # WARNING: importing tensorflow too late can silence important logging (╯°□°)╯︵ ┻━┻ 2 | import tensorflow as tf 3 | 4 | # isort: 
split 5 | 6 | import datetime 7 | from functools import partial 8 | import os 9 | import os.path as osp 10 | 11 | from absl import app, flags, logging 12 | from flax.traverse_util import flatten_dict 13 | import jax 14 | from jax.experimental import multihost_utils 15 | from jax.sharding import Mesh, NamedSharding, PartitionSpec 16 | from ml_collections import config_flags 17 | import optax 18 | import tqdm 19 | import wandb 20 | 21 | import octo 22 | from octo.data.dataset import make_interleaved_dataset 23 | from octo.data.oxe import make_oxe_dataset_kwargs_and_weights 24 | from octo.model.octo_model import OctoModel 25 | from octo.utils import jax_utils 26 | from octo.utils.spec import ModuleSpec 27 | from octo.utils.train_callbacks import ( 28 | RolloutVisualizationCallback, 29 | SaveCallback, 30 | ValidationCallback, 31 | VisualizationCallback, 32 | ) 33 | from octo.utils.train_utils import ( 34 | create_optimizer, 35 | filter_eval_datasets, 36 | format_name_with_config, 37 | process_text, 38 | Timer, 39 | TrainState, 40 | ) 41 | from octo.utils.typing import Data 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | flags.DEFINE_string("name", "experiment", "Experiment name.") 46 | flags.DEFINE_bool("debug", False, "Debug config (no wandb logging)") 47 | 48 | config_dir = os.path.join(os.path.dirname(__file__), "configs") 49 | config_flags.DEFINE_config_file( 50 | "config", 51 | os.path.join(config_dir, "config.py:transformer_bc"), 52 | "File path to the training hyperparameter configuration.", 53 | lock_config=False, 54 | ) 55 | 56 | 57 | def main(_): 58 | jax_utils.initialize_compilation_cache() 59 | 60 | assert FLAGS.config.dataset_kwargs.batch_size % jax.device_count() == 0 61 | assert FLAGS.config.viz_kwargs.eval_batch_size % jax.device_count() == 0 62 | assert FLAGS.config.dataset_kwargs.batch_size % jax.process_count() == 0 63 | assert FLAGS.config.viz_kwargs.eval_batch_size % jax.process_count() == 0 64 | 65 | # create a 1D mesh with a single axis named "batch" 66 | mesh = Mesh(jax.devices(), axis_names="batch") 67 | # replicated sharding -- does not shard arrays 68 | replicated_sharding = NamedSharding(mesh, PartitionSpec()) 69 | # data-parallel sharding -- shards arrays along the first axis 70 | dp_sharding = NamedSharding(mesh, PartitionSpec("batch")) 71 | 72 | def shard(batch): 73 | return multihost_utils.host_local_array_to_global_array( 74 | batch, mesh, PartitionSpec("batch") 75 | ) 76 | 77 | # prevent tensorflow from using GPUs 78 | tf.config.set_visible_devices([], "GPU") 79 | 80 | # make sure each process loads different data 81 | tf.random.set_seed(FLAGS.config.seed + jax.process_index()) 82 | 83 | # set up wandb and logging 84 | if FLAGS.config.get("wandb_resume_id", None) is None: 85 | name = format_name_with_config( 86 | FLAGS.name, 87 | FLAGS.config.to_dict(), 88 | ) 89 | wandb_id = "{name}_{time}".format( 90 | name=name, 91 | time=datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), 92 | ) 93 | wandb_id = jax_utils.host_broadcast_str(wandb_id) 94 | if jax.process_index() == 0: 95 | wandb.init( 96 | config=FLAGS.config.to_dict(), 97 | id=wandb_id, 98 | name=name, 99 | mode="disabled" if FLAGS.debug else None, 100 | **FLAGS.config.wandb, 101 | ) 102 | 103 | if FLAGS.config.save_dir is not None: 104 | save_dir = tf.io.gfile.join( 105 | FLAGS.config.save_dir, 106 | FLAGS.config.wandb.project, 107 | FLAGS.config.wandb.group or "", 108 | wandb_id, 109 | ) 110 | logging.info("Saving to %s", save_dir) 111 | if jax.process_index() == 0: 112 | 
wandb.config.update(dict(save_dir=save_dir), allow_val_change=True) 113 | else: 114 | save_dir = None 115 | logging.info("save_dir not passed in, not saving checkpoints") 116 | else: 117 | # resume previous run 118 | wandb_run = wandb.Api().run(FLAGS.config.wandb_resume_id) 119 | if jax.process_index() == 0: 120 | wandb.init( 121 | project=wandb_run.project, 122 | id=wandb_run.id, 123 | entity=wandb_run.entity, 124 | resume="must", 125 | ) 126 | save_dir = wandb_run.config["save_dir"] 127 | logging.info("Resuming run %s", FLAGS.config.wandb_resume_id) 128 | save_callback = SaveCallback(save_dir) 129 | 130 | if jax.process_index() == 0: 131 | codebase_directory = osp.abspath(osp.join(osp.dirname(octo.__file__), "..")) 132 | wandb.run.log_code(codebase_directory) 133 | 134 | # set up text tokenization (this needs to happen after batching but before sharding) 135 | if FLAGS.config.text_processor is None: 136 | text_processor = None 137 | else: 138 | text_processor = ModuleSpec.instantiate(FLAGS.config.text_processor)() 139 | 140 | def process_batch(batch): 141 | batch = process_text(batch, text_processor) 142 | del batch["dataset_name"] 143 | return batch 144 | 145 | # load datasets 146 | if "oxe_kwargs" in FLAGS.config.dataset_kwargs: 147 | # create dataset_kwargs_list from oxe_kwargs 148 | ( 149 | FLAGS.config.dataset_kwargs["dataset_kwargs_list"], 150 | FLAGS.config.dataset_kwargs["sample_weights"], 151 | ) = make_oxe_dataset_kwargs_and_weights( 152 | **FLAGS.config.dataset_kwargs["oxe_kwargs"] 153 | ) 154 | del FLAGS.config.dataset_kwargs["oxe_kwargs"] 155 | 156 | FLAGS.config.dataset_kwargs.batch_size //= jax.process_count() 157 | train_data = make_interleaved_dataset(**FLAGS.config.dataset_kwargs, train=True) 158 | 159 | train_data_iter = map( 160 | shard, 161 | map( 162 | process_batch, 163 | train_data.iterator(prefetch=FLAGS.config.prefetch_num_batches), 164 | ), 165 | ) 166 | 167 | example_batch = next(train_data_iter) 168 | logging.info(f"Batch size: {example_batch['action'].shape[0]}") 169 | logging.info(f"Number of devices: {jax.device_count()}") 170 | logging.info( 171 | f"Batch size per device: {example_batch['action'].shape[0] // jax.device_count()}" 172 | ) 173 | 174 | # set up model and initialize weights 175 | rng = jax.random.PRNGKey(FLAGS.config.seed) 176 | rng, init_rng = jax.random.split(rng) 177 | model = OctoModel.from_config( 178 | FLAGS.config.to_dict(), 179 | example_batch, 180 | text_processor, 181 | verbose=True, 182 | rng=init_rng, 183 | dataset_statistics=train_data.dataset_statistics, 184 | ) 185 | 186 | # create optimizer 187 | tx, lr_callable, param_norm_callable = create_optimizer( 188 | model.params, 189 | **FLAGS.config.optimizer.to_dict(), 190 | ) 191 | 192 | # Load pretrained weights (e.g. 
text encoder) if necessary 193 | for loader in FLAGS.config.pretrained_loaders: 194 | if not callable(loader): # Means that it is a ModuleSpec 195 | loader = ModuleSpec.instantiate(loader) 196 | model = model.replace(params=loader(model.params)) 197 | 198 | # create train state 199 | train_state = TrainState.create(rng, model, tx) 200 | 201 | if FLAGS.config.get("wandb_resume_id", None) is not None: 202 | train_state = save_callback.state_checkpointer.restore( 203 | save_callback.state_checkpointer.latest_step(), items=train_state 204 | ) 205 | checkpoint_step = int(train_state.step) 206 | logging.info("Restored checkpoint from %s", save_dir) 207 | if FLAGS.config.start_step is not None: 208 | start_step = FLAGS.config.start_step # start_step overrides checkpoint 209 | else: 210 | start_step = checkpoint_step 211 | logging.info("Starting training from step %d", start_step) 212 | else: 213 | start_step = FLAGS.config.start_step or 0 214 | train_state = train_state.replace(step=start_step) 215 | 216 | # refreshes the train state so it doesn't crash w/ certain pre-trained loaders 217 | train_state = jax.device_get(train_state) 218 | 219 | def loss_fn(params, batch, rng, train=True): 220 | bound_module = model.module.bind({"params": params}, rngs={"dropout": rng}) 221 | transformer_embeddings = bound_module.octo_transformer( 222 | batch["observation"], 223 | batch["task"], 224 | batch["observation"]["timestep_pad_mask"], 225 | train=train, 226 | ) 227 | action_loss, action_metrics = bound_module.heads["action"].loss( 228 | transformer_embeddings, # action head knows to pull out the "action" readout_key 229 | batch["action"], 230 | batch["observation"]["timestep_pad_mask"], 231 | batch["action_pad_mask"], 232 | train=train, 233 | ) 234 | return action_loss, action_metrics 235 | 236 | @partial( 237 | jax.jit, 238 | # state is replicated, batch is data-parallel 239 | in_shardings=(replicated_sharding, dp_sharding), 240 | out_shardings=(replicated_sharding, replicated_sharding), 241 | # allows jax to modify `state` in-place, saving a lot of memory 242 | donate_argnums=0, 243 | ) 244 | def train_step(state: TrainState, batch: Data): 245 | rng, dropout_rng = jax.random.split(state.rng) 246 | (loss, info), grads = jax.value_and_grad(loss_fn, has_aux=True)( 247 | state.model.params, batch, dropout_rng, train=True 248 | ) 249 | grad_norm = optax.global_norm(grads) 250 | updates, _ = state.tx.update(grads, state.opt_state, state.model.params) 251 | update_norm = optax.global_norm(updates) 252 | info.update( 253 | { 254 | "grad_norm": grad_norm, 255 | "update_norm": update_norm, 256 | "param_norm": param_norm_callable(state.model.params), 257 | "learning_rate": lr_callable(state.step), 258 | } 259 | ) 260 | new_state = state.apply_gradients(grads=grads, rng=rng) 261 | return new_state, info 262 | 263 | val_datasets_kwargs_list, _ = filter_eval_datasets( 264 | FLAGS.config.dataset_kwargs["dataset_kwargs_list"], 265 | FLAGS.config.dataset_kwargs["sample_weights"], 266 | FLAGS.config.eval_datasets, 267 | ) 268 | val_callback = ValidationCallback( 269 | loss_fn=loss_fn, 270 | process_batch_fn=lambda batch: shard(process_batch(batch)), 271 | text_processor=text_processor, 272 | val_dataset_kwargs_list=val_datasets_kwargs_list, 273 | dataset_kwargs=FLAGS.config.dataset_kwargs, 274 | **FLAGS.config.val_kwargs.to_dict(), 275 | ) 276 | viz_callback = VisualizationCallback( 277 | text_processor=text_processor, 278 | val_dataset_kwargs_list=val_datasets_kwargs_list, 279 | dataset_kwargs=FLAGS.config.dataset_kwargs, 
280 | **FLAGS.config.viz_kwargs.to_dict(), 281 | ) 282 | if "rollout_kwargs" in FLAGS.config: 283 | rollout_kwargs = FLAGS.config.rollout_kwargs.to_dict() 284 | dataset_name = rollout_kwargs.pop("dataset_name") 285 | rollout_callback = RolloutVisualizationCallback( 286 | text_processor=text_processor, 287 | action_proprio_metadata=train_data.dataset_statistics[dataset_name], 288 | **rollout_kwargs, 289 | ) 290 | else: 291 | rollout_callback = None 292 | 293 | def wandb_log(info, step): 294 | if jax.process_index() == 0: 295 | wandb.log(flatten_dict(info, sep="/"), step=step) 296 | 297 | timer = Timer() 298 | for i in tqdm.tqdm( 299 | range(start_step, int(FLAGS.config.num_steps)), 300 | total=int(FLAGS.config.num_steps), 301 | initial=start_step, 302 | dynamic_ncols=True, 303 | ): 304 | timer.tick("total") 305 | 306 | with timer("dataset"): 307 | batch = next(train_data_iter) 308 | 309 | with timer("train"): 310 | train_state, update_info = train_step(train_state, batch) 311 | 312 | if (i + 1) % FLAGS.config.save_interval == 0: 313 | save_callback(train_state, i + 1) 314 | 315 | if (i + 1) % FLAGS.config.eval_interval == 0: 316 | logging.info("Evaluating...") 317 | with timer("eval"): 318 | val_metrics = val_callback(train_state, i + 1) 319 | wandb_log(val_metrics, step=i + 1) 320 | 321 | if (i + 1) % FLAGS.config.viz_interval == 0: 322 | logging.info("Visualizing...") 323 | with timer("visualize"): 324 | viz_metrics = viz_callback(train_state, i + 1) 325 | wandb_log(viz_metrics, step=i + 1) 326 | 327 | if rollout_callback is not None: 328 | with timer("rollout"): 329 | rollout_metrics = rollout_callback(train_state, i + 1) 330 | wandb_log(rollout_metrics, step=i + 1) 331 | 332 | timer.tock("total") 333 | if (i + 1) % FLAGS.config.log_interval == 0: 334 | update_info = jax.device_get(update_info) 335 | wandb_log( 336 | {"training": update_info, "timer": timer.get_average_times()}, 337 | step=i + 1, 338 | ) 339 | 340 | 341 | if __name__ == "__main__": 342 | app.run(main) 343 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup(name="octo", packages=find_packages()) 4 | -------------------------------------------------------------------------------- /tests/debug_config.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import imp 3 | import os 4 | 5 | from ml_collections import ConfigDict 6 | 7 | from octo.data.oxe.oxe_standardization_transforms import bridge_dataset_transform 8 | from octo.utils.spec import ModuleSpec 9 | 10 | get_base_config = imp.load_source( 11 | "config", os.path.join(os.path.dirname(__file__), "../scripts/configs/config.py") 12 | ).get_config 13 | 14 | 15 | def update_config(config: ConfigDict, **kwargs): 16 | assert isinstance(config, ConfigDict) 17 | updates = ConfigDict(kwargs) 18 | new_config = deepcopy(config) 19 | new_config.update(updates) 20 | return new_config 21 | 22 | 23 | def get_config(): 24 | base_config = get_base_config("dummy") 25 | del base_config["dataset_kwargs"]["oxe_kwargs"] 26 | config = update_config( 27 | base_config, 28 | num_steps=2, 29 | optimizer=dict( 30 | learning_rate=dict( 31 | warmup_steps=1, 32 | ), 33 | ), 34 | val_kwargs=dict( 35 | val_shuffle_buffer_size=1, 36 | num_val_batches=2, 37 | ), 38 | viz_kwargs=dict( 39 | eval_batch_size=2, 40 | trajs_for_metrics=4, 41 | trajs_for_viz=4, 42 
| samples_per_state=4, 43 | ), 44 | log_interval=1, 45 | eval_interval=2, 46 | viz_interval=2, 47 | save_interval=2, 48 | eval_datasets=None, 49 | dataset_kwargs={ 50 | "dataset_kwargs_list": [ 51 | { 52 | "name": "bridge_dataset", 53 | "data_dir": "./tests/debug_dataset", 54 | "image_obs_keys": {"primary": "image_0"}, 55 | "proprio_obs_key": "proprio", 56 | "language_key": "language_instruction", 57 | "standardize_fn": ModuleSpec.create(bridge_dataset_transform), 58 | }, 59 | ], 60 | "frame_transform_kwargs": { 61 | "resize_size": (128, 128), 62 | "num_parallel_calls": 4, 63 | }, 64 | "traj_transform_threads": 1, # shared between all datasets 65 | "traj_read_threads": 1, # shared between all datasets 66 | "batch_size": 64, 67 | "sample_weights": None, 68 | "shuffle_buffer_size": 1000, 69 | "balance_weights": True, 70 | }, 71 | ) 72 | return config 73 | -------------------------------------------------------------------------------- /tests/debug_dataset/bridge_dataset/1.0.0/bridge_dataset-train.tfrecord-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/tests/debug_dataset/bridge_dataset/1.0.0/bridge_dataset-train.tfrecord-00000-of-00001 -------------------------------------------------------------------------------- /tests/debug_dataset/bridge_dataset/1.0.0/bridge_dataset-val.tfrecord-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octo-models/octo/241fb3514b7c40957a86d869fecb7c7fc353f540/tests/debug_dataset/bridge_dataset/1.0.0/bridge_dataset-val.tfrecord-00000-of-00001 -------------------------------------------------------------------------------- /tests/debug_dataset/bridge_dataset/1.0.0/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "// TODO(example_dataset): BibTeX citation", 3 | "description": "TODO(example_dataset): Markdown description of your dataset.\nDescription is **formatted** as markdown.\n\nIt should also contain any processing which has been applied (if any),\n(e.g. corrupted example skipped, images cropped,...):", 4 | "fileFormat": "tfrecord", 5 | "moduleName": "bridge_dataset.bridge_dataset_dataset_builder", 6 | "name": "bridge_dataset", 7 | "releaseNotes": { 8 | "1.0.0": "Initial release." 9 | }, 10 | "splits": [ 11 | { 12 | "filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}", 13 | "name": "train", 14 | "numBytes": "3474742", 15 | "shardLengths": [ 16 | "25" 17 | ] 18 | }, 19 | { 20 | "filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}", 21 | "name": "val", 22 | "numBytes": "3474742", 23 | "shardLengths": [ 24 | "25" 25 | ] 26 | } 27 | ], 28 | "version": "1.0.0" 29 | } 30 | --------------------------------------------------------------------------------
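The debug dataset above is a fully prepared TFDS directory (tfrecord shards plus dataset_info.json and features.json), so it can be inspected directly with tensorflow_datasets. A minimal sketch of doing so follows; it is not part of the repo, and the relative path assumes you run it from the repository root.

import tensorflow_datasets as tfds

builder = tfds.builder_from_directory("tests/debug_dataset/bridge_dataset/1.0.0")
print(builder.info.splits["train"].num_examples)  # 25, matching the shard lengths above

# Each element is one trajectory (an RLDS episode); peek at its top-level structure.
ds = builder.as_dataset(split="train")
for episode in ds.take(1):
    print(list(episode.keys()))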