├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── data_collection ├── README.md ├── config │ ├── README.md │ ├── berkeley_robot_0 │ │ ├── README.md │ │ ├── cogvlm_server.yaml │ │ ├── config.yaml │ │ ├── gc_bc.yaml │ │ ├── general_params.yaml │ │ ├── lc_bc.yaml │ │ ├── reset_detector.yaml │ │ ├── subgoal_predictor.yaml │ │ ├── success_detector.yaml │ │ ├── task_definition.yaml │ │ └── task_proposer.yaml │ ├── berkeley_robot_1 │ │ ├── cogvlm_server.yaml │ │ ├── config.yaml │ │ ├── gc_bc.yaml │ │ ├── general_params.yaml │ │ ├── lc_bc.yaml │ │ ├── reset_detector.yaml │ │ ├── subgoal_predictor.yaml │ │ ├── success_detector.yaml │ │ ├── task_definition.yaml │ │ └── task_proposer.yaml │ ├── berkeley_robot_2 │ │ ├── cogvlm_server.yaml │ │ ├── config.yaml │ │ ├── gc_bc.yaml │ │ ├── general_params.yaml │ │ ├── reset_detector.yaml │ │ ├── subgoal_predictor.yaml │ │ ├── success_detector.yaml │ │ ├── task_definition.yaml │ │ └── task_proposer.yaml │ ├── berkeley_robot_3 │ │ ├── cogvlm_server.yaml │ │ ├── config.yaml │ │ ├── gc_bc.yaml │ │ ├── general_params.yaml │ │ ├── lc_bc.yaml │ │ ├── reset_detector.yaml │ │ ├── subgoal_predictor.yaml │ │ ├── success_detector.yaml │ │ ├── task_definition.yaml │ │ └── task_proposer.yaml │ └── berkeley_robot_4 │ │ ├── cogvlm_server.yaml │ │ ├── config.yaml │ │ ├── diffusion_policy.yaml │ │ ├── gc_bc.yaml │ │ ├── general_params.yaml │ │ ├── rcsl.yaml │ │ ├── reset_detector.yaml │ │ ├── subgoal_predictor.yaml │ │ ├── success_detector.yaml │ │ ├── task_definition.yaml │ │ └── task_proposer.yaml ├── orchestrator │ ├── cogvlm_server │ │ ├── batched.py │ │ ├── forwarding_server.py │ │ ├── main.py │ │ ├── test-img.png │ │ ├── test.py │ │ └── test_batched.py │ ├── robot │ │ ├── gc_policy.py │ │ ├── logger.py │ │ ├── main.py │ │ ├── reset_detector.py │ │ ├── subgoal_predictor.py │ │ ├── task_proposer.py │ │ ├── task_success_predictor.py │ │ └── utils.py │ ├── set_workspace_bounds │ │ └── teleop.py │ ├── susie_server │ │ └── main.py │ └── web_viewer │ │ ├── app.py │ │ ├── ros_client │ │ └── run_client.py │ │ └── templates │ │ └── index.html ├── requirements.txt ├── setup.py ├── ssh_port_forward.sh └── start_robot.sh ├── media ├── autonomous_data_collection.png ├── soar_logo.jpeg └── soar_teaser.png ├── model_training ├── README.md ├── experiments │ ├── configs │ │ ├── data_config.py │ │ └── train_config.py │ ├── scripts │ │ └── launch.sh │ └── train.py ├── jaxrl_m │ ├── agents │ │ ├── __init__.py │ │ └── continuous │ │ │ └── gc_bc.py │ ├── common │ │ ├── common.py │ │ ├── encoding.py │ │ ├── optimizers.py │ │ ├── typing.py │ │ └── wandb.py │ ├── data │ │ ├── dataset.py │ │ ├── text_processing.py │ │ ├── tf_augmentations.py │ │ └── tf_goal_relabeling.py │ ├── networks │ │ ├── actor_critic_nets.py │ │ └── mlp.py │ ├── utils │ │ ├── jax_utils.py │ │ ├── timer_utils.py │ │ └── train_utils.py │ └── vision │ │ ├── __init__.py │ │ ├── film_conditioning_layer.py │ │ └── resnet_v1.py ├── requirements.txt └── setup.py ├── rlds_converter ├── README.md ├── dataset_builder.py ├── requirements.txt ├── setup.py └── soar_dataset │ ├── __init__.py │ └── soar_dataset_dataset_builder.py └── soar_data ├── download_dataset.sh ├── fetch_urls.sh ├── load_soar_data.ipynb ├── test_dataset_urls.txt └── urls.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | 
build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | checkpoints/ 165 | goal.png 166 | obs.png 167 | data_collection_logs/ 168 | video_logs 169 | orchestrator/web_viewer/uploads -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | # - id: check-yaml 6 | - id: check-ast 7 | - id: check-added-large-files 8 | args: ['--maxkb=2000'] 9 | - id: check-case-conflict 10 | - id: check-merge-conflict 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | - id: detect-private-key 14 | - id: debug-statements 15 | exclude: ^model_training/experiments/ 16 | - repo: https://github.com/psf/black 17 | rev: 22.10.0 18 | hooks: 19 | - id: black 20 | exclude: ^model_training/experiments/ 21 | - repo: https://github.com/pycqa/isort 22 | rev: 5.12.0 23 | hooks: 24 | - id: isort 25 | exclude: ^model_training/experiments/ 26 | args: [ 27 | "--profile", "black", 28 | "--src", "model_training/jaxrl_m", 29 | "--src", "model_training/experiments", 30 | "--src", "data_collection/orchestrator", 31 | ] 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Robotic AI & Learning Lab Berkeley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SOAR: Autonomous Improvement of Instruction Following Skills via Foundation Models 2 | [](media/soar_logo.jpeg) 3 | [![arXiv](https://img.shields.io/badge/arXiv-2407.20635-df2a2a.svg)](https://arxiv.org/pdf/2407.20635) 4 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/rail-berkeley/soar/blob/main/soar_data/load_soar_data.ipynb) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | [![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://auto-improvement.github.io/) 7 | 8 | [Zhiyuan Zhou](https://zhouzypaul.github.io/), [Pranav Atreya](https://pranavatreya.github.io/), [Abraham Lee](https://www.linkedin.com/in/abraham-lee-4a0497242?original_referer=https%3A%2F%2Fwww.google.com%2F), [Homer Walke](https://homerwalke.com/), [Oier Mees](https://www.oiermees.com/), [Sergey Levine](https://people.eecs.berkeley.edu/~svlevine/) 9 |
10 | 
11 | We present SOAR, an approach to autonomously improve instruction following policies by leveraging 
12 | foundation models. SOAR breaks down the autonomous improvement problem into components that import 
13 | Internet-scale knowledge from VLMs and a component that learns from autonomous data with a purely self-supervised objective. 
14 | 
15 | ![](media/soar_teaser.png) 
16 | 
17 | This repository contains three components: (1) the VLM-powered, semantics-aware autonomous data collection pipeline, (2) code for converting the collected raw data into the RLDS format, and (3) Jax/Flax code for training the policies used in the paper. 
18 | 
19 | ## Using SOAR-Data 
20 | 
21 | We have released SOAR-Data for public access [here](https://rail.eecs.berkeley.edu/datasets/soar_release/1.0.0/) in RLDS format. (The raw data is also available [here](https://rail.eecs.berkeley.edu/datasets/soar_release/numpy_source/soar-dataset-local/).) 
22 | We also provide a download script for the dataset in RLDS format, which requires 136 GB of disk space. 
23 | In this directory, run 
24 | ```bash 
25 | bash soar_data/download_dataset.sh 
26 | ``` 
27 | This script should take around 20 minutes if you use the parallel download option, and we recommend running it inside a tmux session. 
28 | 
29 | To load the dataset for training and other downstream use cases, we have provided a minimal example [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/rail-berkeley/soar/blob/main/soar_data/load_soar_data.ipynb) that loads the dataset and visualizes it. 
30 | 
31 | ## Installation 
32 | ```bash 
33 | conda create -n soar python=3.10 
34 | conda activate soar 
35 | 
36 | # model training requirements 
37 | pip install -e model_training 
38 | pip install -r model_training/requirements.txt 
39 | 
40 | # data collection requirements (you also need the jaxrl_m library above) 
41 | pip install -e data_collection 
42 | pip install -r data_collection/requirements.txt 
43 | 
44 | # rlds conversion requirements 
45 | pip install -e rlds_converter 
46 | pip install -r rlds_converter/requirements.txt 
47 | ``` 
48 | 
49 | If you would like to train models with Jax, also install the accelerator-specific build. 
50 | For GPU: 
51 | ```bash 
52 | pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 
53 | ``` 
54 | 
55 | For TPU: 
56 | ```bash 
57 | pip install --upgrade "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 
58 | ``` 
59 | 
60 | 
61 | ## Autonomous Data Collection 
62 | 
63 | We provide a ready-to-use implementation of autonomous data collection on a fleet of WidowX robot arms. This data collection system is designed around deploying instruction following policies at scale to collect autonomous datasets that are semantically relevant, diverse, and large. Special care is taken to minimize human supervision during data collection, with features like automatic reset detection (and subsequent Slack notification). 
64 | 
65 | ![](media/autonomous_data_collection.png) 
66 | 
67 | Run autonomous data collection on the robot with: 
68 | ``` 
69 | python data_collection/orchestrator/robot/main.py --config_dir config/<robot_name> 
70 | ``` 
71 | 
72 | See [data_collection/README.md](data_collection/README.md) for more information on the setup required before running data collection. 
73 | 
74 | ## Model Training 
75 | The `model_training` directory contains a self-contained Python project for training goal-conditioned and language-conditioned policies on Bridge and on SOAR-Data. 
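
If you plan to train on SOAR-Data, it can be worth sanity-checking that the downloaded RLDS shards load correctly before launching a long run. Below is a minimal sketch using `tensorflow_datasets`; the dataset path is a placeholder for wherever `soar_data/download_dataset.sh` put the data, and it assumes the standard `train` split without relying on any particular feature names.

```python
# Minimal sanity check that a local RLDS copy of SOAR-Data is readable.
# DATASET_DIR is a placeholder -- point it at the directory that contains
# the dataset_info.json and tfrecord shards you downloaded.
import tensorflow_datasets as tfds

DATASET_DIR = "/path/to/soar_release/1.0.0"  # adjust to your download location

builder = tfds.builder_from_directory(DATASET_DIR)
print(builder.info.features)  # inspect the episode/step structure before training

ds = builder.as_dataset(split="train")
episode = next(iter(ds.take(1)))
# RLDS stores the transitions of an episode under the "steps" key.
num_steps = sum(1 for _ in episode["steps"])
print(f"first episode contains {num_steps} steps")
```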
76 | 
77 | To launch a training run, run:
78 | ```bash
79 | cd model_training
80 | bash experiments/scripts/launch.sh
81 | ```
82 | This will launch [train.py](model_training/experiments/train.py) with the default arguments specified in [train_config.py](model_training/experiments/configs/train_config.py) and [data_config.py](model_training/experiments/configs/data_config.py).
83 | 
84 | ## RLDS Data Conversion
85 | We convert the raw data logged in the `data_collection/*` directories into the commonly used RLDS format. The conversion code is
86 | located in the `rlds_converter` directory. See [rlds_converter/README.md](rlds_converter/README.md) for more information.
87 | 
88 | To build the SOAR dataset:
89 | ```bash
90 | cd rlds_converter/soar_dataset
91 | CUDA_VISIBLE_DEVICES="" tfds build --manual_dir <path_to_raw_data>
92 | ```
93 | 
94 | ## Citation
95 | ```
96 | @article{zhou2024autonomous,
97 | title={Autonomous Improvement of Instruction Following Skills via Foundation Models},
98 | author={Zhiyuan Zhou and Pranav Atreya and Abraham Lee and Homer Walke and Oier Mees and Sergey Levine},
99 | journal = {arXiv preprint arXiv:2407.20635},
100 | year={2024},
101 | }
102 | ```
103 | 
104 | ## Contributing
105 | We welcome pull requests and bug reports for this repo.
106 | 
107 | To enable code checks and auto-formatting, please install pre-commit hooks (run this in the root directory):
108 | ```bash
109 | pre-commit install
110 | ```
111 | The hooks should now run before every commit. If files are modified during the checks, you'll need to re-stage them and commit again.
112 | 
-------------------------------------------------------------------------------- /data_collection/README.md: --------------------------------------------------------------------------------
1 | # Autonomous data collection with VLMs on WidowX robots
2 | 
3 | ## Installation
4 | In this directory, run
5 | ```bash
6 | pip install -e .
7 | pip install -r requirements.txt
8 | ```
9 | 
10 | ## VLM Server Hosting
11 | You have the option to either host CogVLM on a local server for inference, or use GPT-4V/o. This specification can be made in the configs (see `README.md` under `data_collection/config`). If you are running autonomous data collection with multiple robots, you can host the VLM just once, and all data collection scripts will query this server.
12 | 
13 | If you are hosting CogVLM, make sure the port specified in the last line of the file `data_collection/orchestrator/cogvlm_server/main.py` matches the port specified in `data_collection/config/<robot_name>/cogvlm_server.yaml`. Then, on the machine where you want to host the VLM server, run `python main.py` from the `data_collection/orchestrator/cogvlm_server` directory. The VLM requires around 48 GB of memory.
14 | 
15 | We provide convenience scripts for testing that the VLM has been hosted correctly and for setting up a proxy server (e.g., to get around firewalls); these are located in the same directory.
16 | 
17 | ## OpenAI API Key
18 | 
19 | Make sure to specify your OpenAI API key as an environment variable with the name `OPENAI_API_KEY`. It is likely convenient to include this specification in your `.bashrc` file.
20 | 
21 | ## SuSIE Server
22 | 
23 | Similar to the VLM server, you will need to host the SuSIE model on a machine accessible to the robot machines. The memory requirement is much more modest, taking up around 6 GB.
Make sure the port specified in the last line of the file `data_collection/orchestrator/susie_server/main.py` matches the port specified in the config `data_collection/config/<robot_name>/subgoal_predictor.yaml`. To launch the SuSIE server, run `python orchestrator/susie_server/main.py --config_dir config/<robot_name>` from the `data_collection` directory, specifying the path to the folder containing your robot's configs.
24 | 
25 | ## Web Viewer
26 | 
27 | To make it convenient to monitor your robots from anywhere, we include a Flask web server with a simple front-end displaying video streamed by your robots. It is mandatory to launch the web server. There are two parts to launching this web viewer: (1) launch the Flask server on a central machine, and (2) launch the data streaming RosPy script on each of your robots.
28 | 
29 | To launch the Flask web server, run `python app.py` from the directory `data_collection/orchestrator/web_viewer`. The default port for the web server is `5000`, which can be adjusted in the last line of the file `app.py`.
30 | 
31 | Separately, on your robot machine (the machine where you are running the docker container and action server from `bridge_data_robot`), launch the script `python orchestrator/web_viewer/ros_client/run_client.py --config_dir config/<robot_name>` from the `data_collection` directory, making sure to specify the path to the appropriate configuration directory. This command should be run after the docker container and action server from `bridge_data_robot` have been launched (see the README in the `bridge_data_robot` repo for more instructions).
32 | 
33 | ## Pre-data collection: Setting Workspace Boundaries for Robot
34 | 
35 | The final step before launching data collection is to specify the workspace boundaries for your robot. Specifying workspace boundaries (as the dimensions of an invisible rectangular prism the end-effector is forced to stay inside of) helps with safe robot operation and minimizes the chances that the robot will do something requiring a manual environment reset.
36 | 
37 | Run the script `python orchestrator/set_workspace_bounds/teleop.py --config_dir config/<robot_name>` from the `data_collection` directory. This will start a keyboard teleop script (the key controls will be printed once you run the script). You should then teleop the end-effector to the extremes of your workspace. Hitting `q` will terminate the script and print out the minimum and maximum `x`, `y`, and `z` values defining the invisible rectangular prism boundary. You should enter these values in your robot's `general_params` config file: `data_collection/config/<robot_name>/general_params.yaml`.
38 | 
39 | ## Running the Robot
40 | 
41 | Finally, you are ready to run autonomous data collection on the robot! Simply run the following script:
42 | ```
43 | python orchestrator/robot/main.py --config_dir config/<robot_name>
44 | ```
45 | from the `data_collection` directory. The script `main.py` contains the code for iterating through the full autonomous data collection loop: querying the VLM for which task to command, querying the SuSIE server for a subgoal image, rolling out the policy, querying the VLM for success determination, and logging (an illustrative sketch of this loop is given at the end of this README). You should be able to keep this script and the robot running for many hours at a time, periodically resetting a fallen object in the robot's environment if needed.
46 | 
47 | 1. See `checkpoints` for instructions on downloading pre-trained models.
48 | 2. See `config` for instructions on parameter specification for setting up autonomous data collection on your robot.
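
For orientation, here is an illustrative sketch of one data collection episode as described in the "Running the Robot" section above. All of the names in it (`env`, `vlm`, `susie`, `gc_policy`, `logger`, and their methods) are hypothetical placeholders; the actual loop lives in `orchestrator/robot/main.py` and the modules it imports.

```python
# Illustrative outline of one autonomous data-collection episode, mirroring the
# loop described in "Running the Robot". Every name below is a hypothetical
# placeholder -- see orchestrator/robot/main.py for the real implementation.

def collect_one_episode(env, vlm, susie, gc_policy, logger,
                        max_subgoals=5, rollout_timesteps=20):
    obs = env.reset()                                # move to the initial pose, grab a camera image
    task = vlm.propose_task(obs["image"])            # VLM picks a task from task_definition.yaml
    trajectory = []
    for _ in range(max_subgoals):
        subgoal = susie.generate_subgoal(obs["image"], task)   # SuSIE server returns a goal image
        for _ in range(rollout_timesteps):
            action = gc_policy.predict(obs["image"], subgoal)  # goal-conditioned low-level policy
            obs = env.step(action)                             # blocking control (move_duration seconds)
            trajectory.append((obs, action))
    success = vlm.predict_success(obs["image"], task)          # VLM success detector
    logger.save(trajectory, task=task, success=success)        # written under video_save_path
    return success
```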
49 | 
-------------------------------------------------------------------------------- /data_collection/config/README.md: --------------------------------------------------------------------------------
1 | # Robot Data Collection Configuration Files
2 | 
3 | This directory contains the configuration files for the five data collection WidowX robots used in the paper. Aside from some per-robot idiosyncratic changes, the configs are identical across robots.
4 | 
5 | When setting up a robot for autonomous data collection, all the changes you need to make should be localized to the config subdirectory (i.e., you should not need to make any code changes). We have provided a detailed README under `berkeley_robot_0` with explanations of what the various configurations control and what needs to be modified for your setup.
-------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/README.md: --------------------------------------------------------------------------------
1 | # Robot Configuration
2 | 
3 | This directory contains a set of configuration files that specify things related to connecting to the robot, logging, image subgoal generation model parameters, goal-conditioned policy architecture, VLM task proposing and success detection, etc.
4 | 
5 | ## config.yaml
6 | 
7 | This is the top-level configuration file that points to other more specific files for all of the aforementioned pipeline components. Make sure to specify configuration files here for the following:
8 | 
9 | - general_params
10 | - task_proposer_params
11 | - gc_policy_params
12 | - subgoal_predictor_params
13 | - reset_detector_params
14 | - success_detector_params
15 | - cogvlm_server_params
16 | - task_definition_params
17 | 
18 | ## general_params.yaml
19 | This file contains configuration parameters for connecting to the robot, streaming the video feed of the robot's third-person camera to a web viewer, workspace boundaries, and logging.
20 | 
21 | 1. Specify the IP address and port for the machine plugged into the robot. On this machine, you will need to run the bridge data robot server (see installation instructions from https://github.com/rail-berkeley/bridge_data_robot), and the IP address and port should match that of the server.
22 | 2. For convenience, we also provide a web UI that allows you to visualize what your robot(s) is/are doing from anywhere. The web UI is a Flask server that you will need to host on your machine. See instructions in the top-level README for how to do this. `web_viewer_ip` and `web_viewer_port` should be set corresponding to your hosted web server.
23 | 3. If you want to run simultaneous autonomous data collection on an arm farm (as we did in the paper) to scale data collection, then each robot will have a unique ID which you must specify.
24 | 4. Ignore `override_workspace_boundaries` (this is deprecated).
25 | 5. `move_duration` is the time allotted for the robot to execute one action. We collect data and evaluate the WidowX robots with blocking control, in which action prediction for the next timestep occurs only after the robot has been given `move_duration` seconds to execute the previous action. A longer `move_duration` leads to more precise execution of commanded actions, but slower data collection.
26 | 6. `video_save_path` is a local directory where all your collected robot data will be logged. The logged data include image states and robot actions, generated subgoals, task and robot metadata, etc.
Data is saved as numpy arrays; if you want to subsequently use it for training, you need to convert it to the RLDS format (see instructions in the top-level README for how to do this conversion).
27 | 7. Specifying `manual_workspace_bounds` allows you to define a rectangular prism in which the end-effector is forced to remain. It is a good idea to specify this, as it enables much safer robot behavior during autonomous data collection. The two numbers for `x`, `y`, and `z` correspond to the minimum and maximum allowed values for each dimension. See instructions in the top-level README for a convenience script that allows you to easily determine what these boundary values are.
28 | 
29 | ## task_proposer.yaml
30 | 
31 | There are multiple types of task proposers implemented in this code base (see orchestrator/robot/task_proposer.py for the implementations). They include VLM task proposers using CogVLM or GPT-4V/o, in which the VLM looks at the current observation and chooses a task to command from a list of possible tasks; a cycling task proposer, which for a two-task setup (e.g., opening and closing a drawer) simply alternates between commanding the two tasks; and a human task proposer, which periodically queries you to enter a task via the command line.
32 | 
33 | The VLM task proposer, in addition to considering the environment affordances indicated by the current image observation, also considers how many times each of the tasks has been attempted. If two tasks are viable to command according to the VLM and one has been attempted more than the other, the less attempted task will be commanded. To enable this, the `reuse_task_statistics` flag controls whether or not to load previous trajectories in the logging folder to compute the number of times each task has been attempted. If set to false, attempt counters for each of the tasks will be initialized to zero.
34 | 
35 | `rand_selection_prob` specifies a probability with which to ignore the task proposed by the VLM and instead propose a random task. Setting this to a nonzero number can be useful in cases where the VLM is less accurate.
36 | 
37 | ## gc_bc.yaml
38 | 
39 | 1. Specify `checkpoint_path` to be the directory containing the policy checkpoint you trained or downloaded from huggingface (for the latter, see instructions in the `checkpoints` folder).
40 | 2. `rollout_timesteps` controls how many timesteps to roll out the goal-conditioned policy before a new image subgoal is synthesized.
41 | 3. Adding noise diversifies the collected robot data and facilitates exploration. `exploration` controls the parameters of the added noise (see the illustrative sketch at the end of this section).
42 | 4. `open_gripper_if_nothing_grasped`, if set to true, will force the gripper open (even if the policy is commanding it to be closed) when the encoders on the gripper read that nothing has been grasped. We leave this set to false in all our experiments.
43 | 5. `restrict_action_space` restricts the action space to 5 dimensions, assigning the pitch and yaw dimensions to known good values. In particularly challenging environments this can help exploration if it is known that these two angle dimensions are not needed to complete the desired tasks. This variable is set to false in all of our experiments.
44 | 6. `agent_kwargs` controls the configuration of the trained policy. See the source code under the `model_training` directory for what these configurations control.
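
To make the role of the `exploration` block concrete, here is a rough sketch of how its parameters could be applied when sampling actions during data collection. The interfaces are placeholders (not the actual code in this repo); see the policy wrapper under `orchestrator/robot/` for the real implementation.

```python
# Rough sketch of how the `exploration` parameters in gc_bc.yaml could be used
# when sampling actions. All interfaces here are placeholders.
import numpy as np

def trajectory_is_deterministic(exploration, rng):
    # A whole trajectory is made deterministic with probability
    # make_traj_deterministic_prob (no sampling noise for any of its steps).
    return rng.random() < exploration["make_traj_deterministic_prob"]

def sample_noisy_action(policy_dist, exploration, rng, deterministic_traj):
    if deterministic_traj:
        action = policy_dist.mode()
    else:
        # Temperature-scaled sampling diversifies the collected data.
        action = policy_dist.sample(temperature=exploration["sampling_temperature"])
    action = np.asarray(action)

    # Occasionally force the gripper open or closed to encourage grasp exploration
    # (assuming the last action dimension is the gripper, with 1 = open).
    if rng.random() < exploration["gripper_open_prob"]:
        action[-1] = 1.0
    elif rng.random() < exploration["gripper_close_prob"]:
        action[-1] = 0.0
    return action
```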
45 | 
46 | ## subgoal_predictor.yaml
47 | 
48 | This config file controls the image generation parameters for SuSIE, the InstructPix2Pix-style diffusion model we use for subgoal generation. The defaults should be fine. The only parameters that need to be set are `susie_server_ip` and `susie_server_port`, which should be set appropriately depending on where you host the SuSIE server (see instructions in the top-level README for hosting this server).
49 | 
50 | ## reset_detector.yaml
51 | 
52 | During autonomous data collection, it is possible that objects may fall out of the workspace boundaries. To prevent the user from having to monitor the robots 24/7, we have included automatic reset detection functionality, whereby the VLM is used to determine if objects are missing from the workspace, and if so, a Slack message is sent to a channel you create. Specify `slack_token` and `channel_id` appropriately for the channel you create. `which_vlm` controls which VLM is used for reset detection.
53 | 
54 | ## success_detector.yaml
55 | 
56 | `which_vlm` controls which VLM you would like to use for determining task success at the end of every trajectory. We used CogVLM for our experiments.
57 | 
58 | ## cogvlm_server.yaml
59 | 
60 | Specify here the IP address and port of the machine you are hosting CogVLM on. CogVLM requires around 48 GB of memory for batched
61 | inference.
62 | 
63 | ## task_definition.yaml
64 | 
65 | When setting up your robot in a new environment, make sure to specify which objects are present in the scene and all of
66 | the language tasks you want to run during data collection. In addition to guiding the VLM for task selection, the information here is logged with every collected trajectory.
-------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/cogvlm_server.yaml: --------------------------------------------------------------------------------
1 | # CogVLM server config parameters
2 | cogvlm_server_ip: "localhost"
3 | cogvlm_server_port: 6000
4 | 
-------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/config.yaml: --------------------------------------------------------------------------------
1 | ---
2 | general_params: !include general_params.yaml
3 | task_proposer_params: !include task_proposer.yaml
4 | gc_policy_params: !include gc_bc.yaml
5 | subgoal_predictor_params: !include subgoal_predictor.yaml
6 | reset_detector_params: !include reset_detector.yaml
7 | success_detector_params: !include success_detector.yaml
8 | cogvlm_server_params: !include cogvlm_server.yaml
9 | task_definition_params: !include task_definition.yaml
-------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/gc_bc.yaml: --------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level goal conditioned policy
3 | policy_class: "gc_bc"
4 | checkpoint_path: "checkpoints/iterated_gcbc_round_1_policy"
5 | image_size: 256
6 | rollout_timesteps: 20
7 | exploration:
8 |   make_traj_deterministic_prob: 0.2
9 |   sampling_temperature: 0.2
10 |   gripper_open_prob: 0.005
11 |   gripper_close_prob: 0.01
12 | open_gripper_if_nothing_grasped: false
13 | restrict_action_space: false
14 | ACT_MEAN:
15 | - 1.9296819e-04
16 | - 1.3667766e-04
17 | - -1.4583133e-04
18 | - -1.8390431e-04
19 | - -3.0808983e-04
20 | - 2.7425270e-04
21 | - 5.9716219e-01
22 | ACT_STD:
23 | - 0.00912848
24 | - 0.0127196 25 | - 0.01229497 26 | - 0.02606696 27 | - 0.02875283 28 | - 0.07807977 29 | - 0.48710242 30 | agent_kwargs: 31 | policy_kwargs: 32 | fixed_std: 33 | - 1.0 34 | - 1.0 35 | - 1.0 36 | - 1.0 37 | - 1.0 38 | - 1.0 39 | - 0.1 40 | std_parameterization: "fixed" 41 | tanh_squash_distribution: false 42 | early_goal_concat: true 43 | shared_goal_encoder: true 44 | use_proprio: false 45 | network_kwargs: 46 | hidden_dims: 47 | - 256 48 | - 256 49 | - 256 50 | dropout_rate: 0.1 51 | encoder: 52 | type: "resnetv1-34-bridge" 53 | config: 54 | act: "swish" 55 | add_spatial_coordinates: false 56 | pooling_method: "avg" 57 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/general_params.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # IP address and port of the robot 3 | ip: "128.32.175.236" 4 | port: 5556 5 | 6 | # IP address of web viewer 7 | web_viewer_ip: "128.32.175.81" 8 | web_viewer_port: 5000 9 | 10 | # Robot ID 11 | robot_id: 0 12 | 13 | # General config parameters 14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper 15 | env_params: 16 | camera_topics: 17 | - name: "/blue/image_raw" 18 | flip: false 19 | override_workspace_boundaries: 20 | - - -20.0 21 | - -20.0 22 | - -20.0 23 | - -1.57 24 | - 0 25 | - - 20.0 26 | - 20.0 27 | - 20.0 28 | - 1.57 29 | - 0 30 | move_duration: 0.3 31 | video_save_path: "video_logs" 32 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam 33 | initial_eep: 34 | - 0.3 35 | - 0.0 36 | - 0.15 37 | - 0 38 | - 0 39 | - 0 40 | - 1 41 | manual_workspace_bounds: 42 | x: 43 | - 0.15603437 44 | - 0.42324517 45 | y: 46 | - -0.20489213 47 | - 0.28275232 48 | z: 49 | - 0.02985591 50 | - 0.16494011 51 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/lc_bc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level language conditioned policy 3 | policy_class: "lc_bc" 4 | checkpoint_path: "checkpoints/lcbc/lcbc_bridge_v1" 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 20 8 | exploration: 9 | make_traj_deterministic_prob: 0.5 10 | sampling_temperature: 0.2 11 | gripper_open_prob: 0.005 12 | gripper_close_prob: 0.01 13 | open_gripper_if_nothing_grasped: false 14 | restrict_action_space: false 15 | dataset_kwargs: 16 | shuffle_buffer_size: 25000 17 | augment: true 18 | augment_next_obs_goal_differently: false 19 | augment_kwargs: 20 | random_resized_crop: 21 | scale: 22 | - 0.8 23 | - 1.0 24 | ratio: 25 | - 0.9 26 | - 1.1 27 | random_brightness: 28 | - 0.2 29 | random_contrast: 30 | - 0.8 31 | - 1.2 32 | random_saturation: 33 | - 0.8 34 | - 1.2 35 | random_hue: 36 | - 0.1 37 | augment_order: 38 | - "random_resized_crop" 39 | - "random_brightness" 40 | - "random_contrast" 41 | - "random_saturation" 42 | - "random_hue" 43 | goal_relabeling_strategy: "uniform" 44 | goal_relabeling_kwargs: 45 | reached_proportion: 0.0 46 | relabel_actions: true 47 | normalization_type: "normal" 48 | load_language: true 49 | skip_unlabeled: true 50 | ACT_MEAN: 51 | - 1.9296819e-04 52 | - 1.3667766e-04 53 | - -1.4583133e-04 54 | - -1.8390431e-04 55 | - -3.0808983e-04 56 | - 2.7425270e-04 57 | - 5.9716219e-01 58 | ACT_STD: 59 | - 0.00912848 60 | - 0.0127196 61 | - 0.01229497 62 | - 0.02606696 63 | - 0.02875283 
64 | - 0.07807977 65 | - 0.48710242 66 | agent_kwargs: 67 | policy_kwargs: 68 | fixed_std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | - 1.0 73 | - 1.0 74 | - 1.0 75 | - 0.1 76 | state_dependent_std: false 77 | tanh_squash_distribution: false 78 | early_goal_concat: true 79 | shared_goal_encoder: true 80 | use_proprio: false 81 | network_kwargs: 82 | hidden_dims: 83 | - 256 84 | - 256 85 | - 256 86 | dropout_rate: 0.1 87 | encoder: 88 | type: "resnetv1-34-bridge-film" 89 | config: 90 | act: "swish" 91 | add_spatial_coordinates: true 92 | pooling_method: "avg" 93 | text_processor: "muse_embedding" 94 | text_processor_kwargs: {} 95 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/reset_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Reset detector config params 3 | slack_token: 4 | channel_id: 5 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none" 6 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/subgoal_predictor.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of high-level subgoal predictor policy 3 | diffusion_checkpoint: "kvablack/susie" 4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m" 5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax" 6 | diffusion_num_steps: 50 7 | prompt_w: 7.5 8 | context_w: 2.0 9 | image_size: 256 10 | max_subgoals: 5 11 | susie_server_ip: "localhost" 12 | susie_server_port: 7000 -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/success_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of success detector 3 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none" 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/task_definition.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # task definition config params 3 | object_list: 4 | - "mushroom" 5 | - "blue bowl" 6 | task_list: 7 | - "remove the mushroom from the blue bowl and put it on the table" 8 | - "put the mushroom in the blue bowl" -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_0/task_proposer.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of task proposer 3 | reuse_task_statistics: false 4 | rand_selection_prob: 0.2 5 | zone_center: 0.3 6 | ucb_weight: 100 # high value like 100 ignores dist_to_zone 7 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none", "human", "cycling" 8 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/cogvlm_server.yaml: -------------------------------------------------------------------------------- 1 | # CogVLM server config parameters 2 | cogvlm_server_ip: "localhost" 3 | cogvlm_server_port: 6000 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | general_params: !include general_params.yaml 3 | 
task_proposer_params: !include task_proposer.yaml 4 | gc_policy_params: !include gc_bc.yaml 5 | subgoal_predictor_params: !include subgoal_predictor.yaml 6 | reset_detector_params: !include reset_detector.yaml 7 | success_detector_params: !include success_detector.yaml 8 | cogvlm_server_params: !include cogvlm_server.yaml 9 | task_definition_params: !include task_definition.yaml 10 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/gc_bc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level goal conditioned policy 3 | policy_class: "gc_bc" 4 | checkpoint_path: "checkpoints/iterated_gcbc_round_1_policy" 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 20 8 | exploration: 9 | make_traj_deterministic_prob: 0.5 10 | sampling_temperature: 0.2 11 | gripper_open_prob: 0.005 12 | gripper_close_prob: 0.01 13 | open_gripper_if_nothing_grasped: false 14 | restrict_action_space: false 15 | dataset_kwargs: 16 | shuffle_buffer_size: 25000 17 | augment: true 18 | augment_next_obs_goal_differently: false 19 | augment_kwargs: 20 | random_resized_crop: 21 | scale: 22 | - 0.8 23 | - 1.0 24 | ratio: 25 | - 0.9 26 | - 1.1 27 | random_brightness: 28 | - 0.2 29 | random_contrast: 30 | - 0.8 31 | - 1.2 32 | random_saturation: 33 | - 0.8 34 | - 1.2 35 | random_hue: 36 | - 0.1 37 | augment_order: 38 | - "random_resized_crop" 39 | - "random_brightness" 40 | - "random_contrast" 41 | - "random_saturation" 42 | - "random_hue" 43 | goal_relabeling_strategy: "geometric" 44 | goal_relabeling_kwargs: 45 | reached_proportion: 0.0 46 | discount: 0.98 47 | relabel_actions: true 48 | dataset_contains_commanded_goals: false 49 | normalization_type: "normal" 50 | ACT_MEAN: 51 | - 1.9296819e-04 52 | - 1.3667766e-04 53 | - -1.4583133e-04 54 | - -1.8390431e-04 55 | - -3.0808983e-04 56 | - 2.7425270e-04 57 | - 5.9716219e-01 58 | ACT_STD: 59 | - 0.00912848 60 | - 0.0127196 61 | - 0.01229497 62 | - 0.02606696 63 | - 0.02875283 64 | - 0.07807977 65 | - 0.48710242 66 | agent_kwargs: 67 | policy_kwargs: 68 | fixed_std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | - 1.0 73 | - 1.0 74 | - 1.0 75 | - 0.1 76 | std_parameterization: "fixed" 77 | tanh_squash_distribution: false 78 | early_goal_concat: true 79 | shared_goal_encoder: true 80 | use_proprio: false 81 | network_kwargs: 82 | hidden_dims: 83 | - 256 84 | - 256 85 | - 256 86 | dropout_rate: 0.1 87 | encoder: 88 | type: "resnetv1-34-bridge" 89 | config: 90 | act: "swish" 91 | add_spatial_coordinates: false 92 | pooling_method: "avg" 93 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/general_params.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # IP address and port of the robot 3 | ip: "128.32.175.102" 4 | port: 5556 5 | 6 | # IP address of web viewer 7 | web_viewer_ip: "128.32.175.81" 8 | web_viewer_port: 5000 9 | 10 | # Robot ID 11 | robot_id: 1 12 | 13 | # General config parameters 14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper 15 | env_params: 16 | camera_topics: 17 | - name: "/blue/image_raw" 18 | flip: false 19 | override_workspace_boundaries: 20 | - - -20.0 21 | - -20.0 22 | - -20.0 23 | - -1.57 24 | - 0 25 | - - 20.0 26 | - 20.0 27 | - 20.0 28 | - 1.57 29 | - 0 30 | move_duration: 0.3 31 | video_save_path: 
"video_logs" 32 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam 33 | initial_eep: 34 | - 0.3 35 | - 0.0 36 | - 0.15 37 | - 0 38 | - 0 39 | - 0 40 | - 1 41 | # manual_workspace_bounds: # minsky table height 27 42 | # x: 43 | # - 0.17827454 44 | # - 0.42494287 45 | # y: 46 | # - -0.22023482 47 | # - 0.18838036 48 | # z: 49 | # - 0.02200321 50 | # - 0.23297783 51 | manual_workspace_bounds: # minsky table height 27 with barrier (left side closest to table) 52 | x: 53 | - 0.17376936 54 | - 0.36731001 55 | y: 56 | - -0.15287904 57 | - 0.20850995 58 | z: 59 | - 0.01916022 60 | - 0.2381686 61 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/lc_bc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level language conditioned policy 3 | policy_class: "lc_bc" 4 | checkpoint_path: "checkpoints/lcbc/lcbc_bridge_v1" 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 25 8 | exploration: 9 | make_traj_deterministic_prob: 1.0 10 | sampling_temperature: 0.2 11 | gripper_open_prob: 0.005 12 | gripper_close_prob: 0.01 13 | open_gripper_if_nothing_grasped: false 14 | restrict_action_space: false 15 | dataset_kwargs: 16 | shuffle_buffer_size: 25000 17 | augment: true 18 | augment_next_obs_goal_differently: false 19 | augment_kwargs: 20 | random_resized_crop: 21 | scale: 22 | - 0.8 23 | - 1.0 24 | ratio: 25 | - 0.9 26 | - 1.1 27 | random_brightness: 28 | - 0.2 29 | random_contrast: 30 | - 0.8 31 | - 1.2 32 | random_saturation: 33 | - 0.8 34 | - 1.2 35 | random_hue: 36 | - 0.1 37 | augment_order: 38 | - "random_resized_crop" 39 | - "random_brightness" 40 | - "random_contrast" 41 | - "random_saturation" 42 | - "random_hue" 43 | goal_relabeling_strategy: "uniform" 44 | goal_relabeling_kwargs: 45 | reached_proportion: 0.0 46 | relabel_actions: true 47 | normalization_type: "normal" 48 | load_language: true 49 | skip_unlabeled: true 50 | ACT_MEAN: 51 | - 1.9296819e-04 52 | - 1.3667766e-04 53 | - -1.4583133e-04 54 | - -1.8390431e-04 55 | - -3.0808983e-04 56 | - 2.7425270e-04 57 | - 5.9716219e-01 58 | ACT_STD: 59 | - 0.00912848 60 | - 0.0127196 61 | - 0.01229497 62 | - 0.02606696 63 | - 0.02875283 64 | - 0.07807977 65 | - 0.48710242 66 | agent_kwargs: 67 | policy_kwargs: 68 | fixed_std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | - 1.0 73 | - 1.0 74 | - 1.0 75 | - 0.1 76 | state_dependent_std: false 77 | tanh_squash_distribution: false 78 | early_goal_concat: true 79 | shared_goal_encoder: true 80 | use_proprio: false 81 | network_kwargs: 82 | hidden_dims: 83 | - 256 84 | - 256 85 | - 256 86 | dropout_rate: 0.1 87 | encoder: 88 | type: "resnetv1-34-bridge-film" 89 | config: 90 | act: "swish" 91 | add_spatial_coordinates: true 92 | pooling_method: "avg" 93 | text_processor: "muse_embedding" 94 | text_processor_kwargs: {} 95 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/reset_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Reset detector config params 3 | slack_token: 4 | channel_id: C069AT564KC 5 | which_vlm: "none" # choices: "cogvlm", "gpt4v", "none" 6 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/subgoal_predictor.yaml: -------------------------------------------------------------------------------- 1 | --- 
2 | # Config parameters of high-level subgoal predictor policy 3 | diffusion_checkpoint: "kvablack/susie" 4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m" 5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax" 6 | diffusion_num_steps: 50 7 | prompt_w: 7.5 8 | context_w: 5.0 9 | image_size: 256 10 | max_subgoals: 5 11 | susie_server_ip: "localhost" 12 | susie_server_port: 7000 -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/success_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of success detector 3 | which_vlm: "none" # choices: "cogvlm", "gpt4v", "none" 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/task_definition.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # task definition config params 3 | object_list: 4 | - "green object" 5 | - "wooden bowl" 6 | task_list: 7 | - "remove the green block from inside the brown bowl and put it on the table" 8 | - "put the green block in the brown bowl" 9 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_1/task_proposer.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of task proposer 3 | reuse_task_statistics: false 4 | rand_selection_prob: 0.2 5 | zone_center: 0.3 6 | ucb_weight: 100 # high value like 100 ignores dist_to_zone 7 | which_vlm: "human" # choices: "cogvlm", "gpt4v", "none", "human" 8 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/cogvlm_server.yaml: -------------------------------------------------------------------------------- 1 | # CogVLM server config parameters 2 | cogvlm_server_ip: "localhost" 3 | cogvlm_server_port: 6000 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | general_params: !include general_params.yaml 3 | task_proposer_params: !include task_proposer.yaml 4 | gc_policy_params: !include gc_bc.yaml 5 | subgoal_predictor_params: !include subgoal_predictor.yaml 6 | reset_detector_params: !include reset_detector.yaml 7 | success_detector_params: !include success_detector.yaml 8 | cogvlm_server_params: !include cogvlm_server.yaml 9 | task_definition_params: !include task_definition.yaml -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/gc_bc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level goal conditioned policy 3 | policy_class: "gc_bc" 4 | checkpoint_path: "checkpoints/iterated-gcbc/baseline/checkpoint_75000" 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 20 8 | exploration: 9 | make_traj_deterministic_prob: 0.5 10 | sampling_temperature: 0.2 11 | gripper_open_prob: 0.005 12 | gripper_close_prob: 0.01 13 | open_gripper_if_nothing_grasped: false 14 | restrict_action_space: false 15 | dataset_kwargs: 16 | shuffle_buffer_size: 25000 17 | augment: true 18 | augment_next_obs_goal_differently: false 19 | augment_kwargs: 20 | random_resized_crop: 21 
| scale: 22 | - 0.8 23 | - 1.0 24 | ratio: 25 | - 0.9 26 | - 1.1 27 | random_brightness: 28 | - 0.2 29 | random_contrast: 30 | - 0.8 31 | - 1.2 32 | random_saturation: 33 | - 0.8 34 | - 1.2 35 | random_hue: 36 | - 0.1 37 | augment_order: 38 | - "random_resized_crop" 39 | - "random_brightness" 40 | - "random_contrast" 41 | - "random_saturation" 42 | - "random_hue" 43 | goal_relabeling_strategy: "geometric" 44 | goal_relabeling_kwargs: 45 | reached_proportion: 0.0 46 | discount: 0.98 47 | relabel_actions: true 48 | dataset_contains_commanded_goals: false 49 | normalization_type: "normal" 50 | ACT_MEAN: 51 | - 1.9296819e-04 52 | - 1.3667766e-04 53 | - -1.4583133e-04 54 | - -1.8390431e-04 55 | - -3.0808983e-04 56 | - 2.7425270e-04 57 | - 5.9716219e-01 58 | ACT_STD: 59 | - 0.00912848 60 | - 0.0127196 61 | - 0.01229497 62 | - 0.02606696 63 | - 0.02875283 64 | - 0.07807977 65 | - 0.48710242 66 | agent_kwargs: 67 | policy_kwargs: 68 | fixed_std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | - 1.0 73 | - 1.0 74 | - 1.0 75 | - 0.1 76 | std_parameterization: "fixed" 77 | tanh_squash_distribution: false 78 | early_goal_concat: true 79 | shared_goal_encoder: true 80 | use_proprio: false 81 | network_kwargs: 82 | hidden_dims: 83 | - 256 84 | - 256 85 | - 256 86 | dropout_rate: 0.1 87 | encoder: 88 | type: "resnetv1-34-bridge" 89 | config: 90 | act: "swish" 91 | add_spatial_coordinates: false 92 | pooling_method: "avg" 93 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/general_params.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # IP address and port of the robot 3 | ip: "128.32.175.217" 4 | port: 5556 5 | 6 | # IP address of web viewer 7 | web_viewer_ip: "128.32.175.81" 8 | web_viewer_port: 5000 9 | 10 | # Robot ID 11 | robot_id: 2 12 | 13 | # General config parameters 14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper 15 | env_params: 16 | camera_topics: 17 | - name: "/blue/image_raw" 18 | flip: false 19 | override_workspace_boundaries: 20 | - - -20.0 21 | - -20.0 22 | - -20.0 23 | - -1.57 24 | - 0 25 | - - 20.0 26 | - 20.0 27 | - 20.0 28 | - 1.57 29 | - 0 30 | move_duration: 0.3 31 | video_save_path: "data_collection_logs/berkeley_robot_2/mushroom_spoon_0530" 32 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam 33 | initial_eep: 34 | - 0.3 35 | - 0.0 36 | - 0.15 37 | - 0 38 | - 0 39 | - 0 40 | - 1 41 | manual_workspace_bounds: 42 | x: 43 | - 0.21887143 44 | - 0.41317162 45 | y: 46 | - -0.18368285 47 | - 0.18275802 48 | z: 49 | - 0.03712492 50 | - 0.18284303 51 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/reset_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Reset detector config params 3 | slack_token: 4 | channel_id: C069AT564KC 5 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none" 6 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/subgoal_predictor.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of high-level subgoal predictor policy 3 | diffusion_checkpoint: "kvablack/susie" 4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m" 5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax" 6 | 
diffusion_num_steps: 50 7 | prompt_w: 7.5 8 | context_w: 2.0 9 | image_size: 256 10 | max_subgoals: 5 11 | susie_server_ip: "localhost" 12 | susie_server_port: 7000 -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/success_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of success detector 3 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none" 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/task_definition.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # task definition config params 3 | object_list: 4 | - "silver pot" 5 | - "mushroom" 6 | - "green spoon" 7 | task_list: 8 | - "put mushroom into the silver pot" 9 | - "remove mushroom from inside the silver pot and place it on the table" 10 | - "put the green spoon into the silver pot" 11 | - "remove the green spoon from inside the silver pot and place it on the table" 12 | - "move green spoon to the right side of the table" 13 | - "move green spoon to the left side of the table" 14 | - "move mushroom to the right side of the table" 15 | - "move mushroom to the left side of the table" 16 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_2/task_proposer.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of task proposer 3 | reuse_task_statistics: false 4 | rand_selection_prob: 1.0 5 | zone_center: 0.3 6 | ucb_weight: 5.0 7 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none", "human" 8 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/cogvlm_server.yaml: -------------------------------------------------------------------------------- 1 | # CogVLM server config parameters 2 | cogvlm_server_ip: "localhost" 3 | cogvlm_server_port: 6000 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | general_params: !include general_params.yaml 3 | task_proposer_params: !include task_proposer.yaml 4 | gc_policy_params: !include gc_bc.yaml 5 | subgoal_predictor_params: !include subgoal_predictor.yaml 6 | reset_detector_params: !include reset_detector.yaml 7 | success_detector_params: !include success_detector.yaml 8 | cogvlm_server_params: !include cogvlm_server.yaml 9 | task_definition_params: !include task_definition.yaml -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/gc_bc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level goal conditioned policy 3 | policy_class: "gc_bc" 4 | checkpoint_path: "checkpoints/iterated_gcbc_round_1_policy" 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 20 8 | exploration: 9 | make_traj_deterministic_prob: 0.5 10 | sampling_temperature: 0.2 11 | gripper_open_prob: 0.005 12 | gripper_close_prob: 0.01 13 | open_gripper_if_nothing_grasped: false 14 | restrict_action_space: false 15 | dataset_kwargs: 16 | shuffle_buffer_size: 25000 17 | augment: true 18 | 
augment_next_obs_goal_differently: false 19 | augment_kwargs: 20 | random_resized_crop: 21 | scale: 22 | - 0.8 23 | - 1.0 24 | ratio: 25 | - 0.9 26 | - 1.1 27 | random_brightness: 28 | - 0.2 29 | random_contrast: 30 | - 0.8 31 | - 1.2 32 | random_saturation: 33 | - 0.8 34 | - 1.2 35 | random_hue: 36 | - 0.1 37 | augment_order: 38 | - "random_resized_crop" 39 | - "random_brightness" 40 | - "random_contrast" 41 | - "random_saturation" 42 | - "random_hue" 43 | goal_relabeling_strategy: "geometric" 44 | goal_relabeling_kwargs: 45 | reached_proportion: 0.0 46 | discount: 0.98 47 | relabel_actions: true 48 | dataset_contains_commanded_goals: false 49 | normalization_type: "normal" 50 | ACT_MEAN: 51 | - 1.9296819e-04 52 | - 1.3667766e-04 53 | - -1.4583133e-04 54 | - -1.8390431e-04 55 | - -3.0808983e-04 56 | - 2.7425270e-04 57 | - 5.9716219e-01 58 | ACT_STD: 59 | - 0.00912848 60 | - 0.0127196 61 | - 0.01229497 62 | - 0.02606696 63 | - 0.02875283 64 | - 0.07807977 65 | - 0.48710242 66 | agent_kwargs: 67 | policy_kwargs: 68 | fixed_std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | - 1.0 73 | - 1.0 74 | - 1.0 75 | - 0.1 76 | std_parameterization: "fixed" 77 | tanh_squash_distribution: false 78 | early_goal_concat: true 79 | shared_goal_encoder: true 80 | use_proprio: false 81 | network_kwargs: 82 | hidden_dims: 83 | - 256 84 | - 256 85 | - 256 86 | dropout_rate: 0.1 87 | encoder: 88 | type: "resnetv1-34-bridge" 89 | config: 90 | act: "swish" 91 | add_spatial_coordinates: false 92 | pooling_method: "avg" 93 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/general_params.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # IP address and port of the robot 3 | ip: "128.32.175.186" 4 | port: 5556 5 | 6 | # IP address of web viewer 7 | web_viewer_ip: "128.32.175.81" 8 | web_viewer_port: 5000 9 | 10 | # Robot ID 11 | robot_id: 3 12 | 13 | # General config parameters 14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper 15 | env_params: 16 | camera_topics: 17 | - name: "/blue/image_raw" 18 | flip: false 19 | override_workspace_boundaries: 20 | - - -20.0 21 | - -20.0 22 | - -20.0 23 | - -1.57 24 | - 0 25 | - - 20.0 26 | - 20.0 27 | - 20.0 28 | - 1.57 29 | - 0 30 | move_duration: 0.3 31 | video_save_path: "data_collection_logs/berkeley_robot_3/gc_bc_eval" 32 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam 33 | initial_eep: 34 | - 0.3 35 | - 0.0 36 | - 0.15 37 | - 0 38 | - 0 39 | - 0 40 | - 1 41 | manual_workspace_bounds: 42 | x: 43 | - 0.25376111 44 | - 0.44158756 45 | y: 46 | - -0.17615086 47 | - 0.18689521 48 | z: 49 | - 0.0187575 50 | - 0.13504307 51 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/lc_bc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level language conditioned policy 3 | policy_class: "lc_bc" 4 | checkpoint_path: "checkpoints/lcbc/lcbc_bridge_v1" 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 25 8 | exploration: 9 | make_traj_deterministic_prob: 1.0 10 | sampling_temperature: 0.2 11 | gripper_open_prob: 0.005 12 | gripper_close_prob: 0.01 13 | open_gripper_if_nothing_grasped: false 14 | restrict_action_space: false 15 | dataset_kwargs: 16 | shuffle_buffer_size: 25000 17 | augment: true 18 | 
augment_next_obs_goal_differently: false 19 | augment_kwargs: 20 | random_resized_crop: 21 | scale: 22 | - 0.8 23 | - 1.0 24 | ratio: 25 | - 0.9 26 | - 1.1 27 | random_brightness: 28 | - 0.2 29 | random_contrast: 30 | - 0.8 31 | - 1.2 32 | random_saturation: 33 | - 0.8 34 | - 1.2 35 | random_hue: 36 | - 0.1 37 | augment_order: 38 | - "random_resized_crop" 39 | - "random_brightness" 40 | - "random_contrast" 41 | - "random_saturation" 42 | - "random_hue" 43 | goal_relabeling_strategy: "uniform" 44 | goal_relabeling_kwargs: 45 | reached_proportion: 0.0 46 | relabel_actions: true 47 | normalization_type: "normal" 48 | load_language: true 49 | skip_unlabeled: true 50 | ACT_MEAN: 51 | - 1.9296819e-04 52 | - 1.3667766e-04 53 | - -1.4583133e-04 54 | - -1.8390431e-04 55 | - -3.0808983e-04 56 | - 2.7425270e-04 57 | - 5.9716219e-01 58 | ACT_STD: 59 | - 0.00912848 60 | - 0.0127196 61 | - 0.01229497 62 | - 0.02606696 63 | - 0.02875283 64 | - 0.07807977 65 | - 0.48710242 66 | agent_kwargs: 67 | policy_kwargs: 68 | fixed_std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | - 1.0 73 | - 1.0 74 | - 1.0 75 | - 0.1 76 | state_dependent_std: false 77 | tanh_squash_distribution: false 78 | early_goal_concat: true 79 | shared_goal_encoder: true 80 | use_proprio: false 81 | network_kwargs: 82 | hidden_dims: 83 | - 256 84 | - 256 85 | - 256 86 | dropout_rate: 0.1 87 | encoder: 88 | type: "resnetv1-34-bridge-film" 89 | config: 90 | act: "swish" 91 | add_spatial_coordinates: true 92 | pooling_method: "avg" 93 | text_processor: "muse_embedding" 94 | text_processor_kwargs: {} 95 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/reset_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Reset detector config params 3 | slack_token: 4 | channel_id: C069AT564KC 5 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none" 6 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/subgoal_predictor.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of high-level subgoal predictor policy 3 | diffusion_checkpoint: "kvablack/susie" 4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m" 5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax" 6 | diffusion_num_steps: 50 7 | prompt_w: 7.5 8 | context_w: 2.0 9 | image_size: 256 10 | max_subgoals: 5 11 | susie_server_ip: "localhost" 12 | susie_server_port: 7000 -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/success_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of success detector 3 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none" 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/task_definition.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # task definition config params 3 | object_list: 4 | - "carrot" 5 | - "lemon" 6 | - "purple eggplant" 7 | - "blue tray" 8 | task_list: 9 | - "remove the lemon from the blue tray and put it on the table" 10 | - "move the carrot to the right side" 11 | - "move the carrot to the left side" 12 | - "put the purple eggplant on the blue tray" 13 | - "put the lemon on the blue tray" 14 | - 
"remove the carrot from the blue tray and put it on the table" 15 | - "remove the purple eggplant from the blue tray and put it on the table" 16 | - "put the carrot on the blue tray" 17 | 18 | 19 | # put the pink spoon on the blue tray 20 | # remove the pink spoon from the blue tray and put it on the table 21 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_3/task_proposer.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of task proposer 3 | reuse_task_statistics: false 4 | rand_selection_prob: 0.333 5 | zone_center: 0.3 6 | ucb_weight: 100 # high value like 100 ignores dist_to_zone 7 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none", "human" 8 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/cogvlm_server.yaml: -------------------------------------------------------------------------------- 1 | # CogVLM server config parameters 2 | cogvlm_server_ip: "localhost" 3 | cogvlm_server_port: 6000 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | general_params: !include general_params.yaml 3 | task_proposer_params: !include task_proposer.yaml 4 | gc_policy_params: !include gc_bc.yaml 5 | subgoal_predictor_params: !include subgoal_predictor.yaml 6 | reset_detector_params: !include reset_detector.yaml 7 | success_detector_params: !include success_detector.yaml 8 | cogvlm_server_params: !include cogvlm_server.yaml 9 | task_definition_params: !include task_definition.yaml -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/diffusion_policy.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level goal conditioned policy 3 | policy_class: "gc_ddpm_bc" 4 | checkpoint_path: "checkpoints/red_object_metal_pot/dgcbc_95_4p5_0p5_successes_1800000" # iterated_dgcbc_80_20_230000 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 25 8 | exploration: 9 | make_traj_deterministic_prob: 1.0 10 | sampling_temperature: 2.0 11 | gripper_open_prob: 0.01 12 | gripper_close_prob: 0.02 13 | exploration_noise: 0.000 14 | dataset_kwargs: 15 | shuffle_buffer_size: 25000 16 | augment: true 17 | augment_next_obs_goal_differently: false 18 | augment_kwargs: 19 | random_resized_crop: 20 | scale: 21 | - 0.8 22 | - 1.0 23 | ratio: 24 | - 0.9 25 | - 1.1 26 | random_brightness: 27 | - 0.2 28 | random_contrast: 29 | - 0.8 30 | - 1.2 31 | random_saturation: 32 | - 0.8 33 | - 1.2 34 | random_hue: 35 | - 0.1 36 | augment_order: 37 | - "random_resized_crop" 38 | - "random_brightness" 39 | - "random_contrast" 40 | - "random_saturation" 41 | - "random_hue" 42 | goal_relabeling_strategy: "geometric" 43 | goal_relabeling_kwargs: 44 | reached_proportion: 0.0 45 | discount: 0.98 46 | relabel_actions: true 47 | act_pred_horizon: 4 48 | obs_horizon: 1 49 | ACT_MEAN: 50 | - 1.9296819e-04 51 | - 1.3667766e-04 52 | - -1.4583133e-04 53 | - -1.8390431e-04 54 | - -3.0808983e-04 55 | - 2.7425270e-04 56 | - 5.9716219e-01 57 | ACT_STD: 58 | - 0.00912848 59 | - 0.0127196 60 | - 0.01229497 61 | - 0.02606696 62 | - 0.02875283 63 | - 0.07807977 64 | - 0.48710242 65 | PROPRIO_MEAN: 66 | 
- 0.29730073 67 | - 0.02986212 68 | - 0.06420159 69 | - -0.00201155 70 | - -0.07586625 71 | - 0.159071 72 | - 0.75686556 73 | PROPRIO_STD: 74 | - 0.05918062 75 | - 0.09581848 76 | - 0.05275392 77 | - 0.13922517 78 | - 0.16974117 79 | - 0.6555491 80 | - 0.3397966 81 | agent_kwargs: 82 | score_network_kwargs: 83 | time_dim: 32 84 | num_blocks: 3 85 | dropout_rate: 0.1 86 | hidden_dim: 256 87 | use_layer_norm: true 88 | early_goal_concat: true 89 | shared_goal_encoder: true 90 | use_proprio: false 91 | beta_schedule: "cosine" 92 | diffusion_steps: 20 93 | action_samples: 1 94 | repeat_last_step: 0 95 | learning_rate: 3.0e-4 96 | warmup_steps: 2000 97 | actor_decay_steps: 2000000 98 | encoder: 99 | type: "resnetv1-34-bridge" 100 | config: 101 | act: "swish" 102 | add_spatial_coordinates: true 103 | pooling_method: "avg" 104 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/gc_bc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level goal conditioned policy 3 | policy_class: "gc_bc" 4 | checkpoint_path: "checkpoints/gcbc_just_bridge_vision_backbone_1_200000" 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 20 8 | exploration: 9 | make_traj_deterministic_prob: 0.5 10 | sampling_temperature: 0.2 11 | gripper_open_prob: 0.005 12 | gripper_close_prob: 0.01 13 | open_gripper_if_nothing_grasped: false 14 | restrict_action_space: false 15 | dataset_kwargs: 16 | shuffle_buffer_size: 25000 17 | augment: true 18 | augment_next_obs_goal_differently: false 19 | augment_kwargs: 20 | random_resized_crop: 21 | scale: 22 | - 0.8 23 | - 1.0 24 | ratio: 25 | - 0.9 26 | - 1.1 27 | random_brightness: 28 | - 0.2 29 | random_contrast: 30 | - 0.8 31 | - 1.2 32 | random_saturation: 33 | - 0.8 34 | - 1.2 35 | random_hue: 36 | - 0.1 37 | augment_order: 38 | - "random_resized_crop" 39 | - "random_brightness" 40 | - "random_contrast" 41 | - "random_saturation" 42 | - "random_hue" 43 | goal_relabeling_strategy: "geometric" 44 | goal_relabeling_kwargs: 45 | reached_proportion: 0.0 46 | discount: 0.98 47 | relabel_actions: true 48 | dataset_contains_commanded_goals: false 49 | normalization_type: "normal" 50 | ACT_MEAN: 51 | - 1.9296819e-04 52 | - 1.3667766e-04 53 | - -1.4583133e-04 54 | - -1.8390431e-04 55 | - -3.0808983e-04 56 | - 2.7425270e-04 57 | - 5.9716219e-01 58 | ACT_STD: 59 | - 0.00912848 60 | - 0.0127196 61 | - 0.01229497 62 | - 0.02606696 63 | - 0.02875283 64 | - 0.07807977 65 | - 0.48710242 66 | agent_kwargs: 67 | policy_kwargs: 68 | fixed_std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | - 1.0 73 | - 1.0 74 | - 1.0 75 | - 0.1 76 | std_parameterization: "fixed" 77 | tanh_squash_distribution: false 78 | early_goal_concat: true 79 | shared_goal_encoder: true 80 | use_proprio: false 81 | network_kwargs: 82 | hidden_dims: 83 | - 256 84 | - 256 85 | - 256 86 | dropout_rate: 0.1 87 | encoder: 88 | type: "resnetv1-34-bridge" 89 | config: 90 | act: "swish" 91 | add_spatial_coordinates: false 92 | pooling_method: "avg" 93 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/general_params.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # IP address and port of the robot 3 | ip: "128.32.175.227" 4 | port: 5556 5 | 6 | # IP address of web viewer 7 | web_viewer_ip: "128.32.175.81" 8 | web_viewer_port: 5000 9 | 10 | # Robot ID 11 | 
robot_id: 4 12 | 13 | # General config parameters 14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper 15 | joints_reboot_interval: 1 16 | env_params: 17 | camera_topics: 18 | - name: "/blue/image_raw" 19 | flip: false 20 | override_workspace_boundaries: 21 | - - -20.0 22 | - -20.0 23 | - -20.0 24 | - -1.57 25 | - 0 26 | - - 20.0 27 | - 20.0 28 | - 20.0 29 | - 1.57 30 | - 0 31 | move_duration: 0.3 32 | video_save_path: "data_collection_logs/berkeley_robot_4/drawer_iql" 33 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam 34 | initial_eep: 35 | - 0.3 36 | - 0.0 37 | - 0.15 38 | - 0 39 | - 0 40 | - 0 41 | - 1 42 | manual_workspace_bounds: 43 | x: 44 | - 0.22816383 45 | - 0.40895109 46 | y: 47 | - -0.11804535 48 | - 0.04906207 49 | z: 50 | - 0.04494154 51 | - 0.13566369 52 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/rcsl.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of low-level goal conditioned policy 3 | policy_class: "gc_ddpm_bc" 4 | checkpoint_path: "checkpoints/valentines_day_evals/rcsl_120k" 5 | mini_dataset_path: "mini_dataset" 6 | image_size: 256 7 | rollout_timesteps: 30 8 | exploration_noise: 0.000 9 | dataset_kwargs: 10 | shuffle_buffer_size: 25000 11 | augment: true 12 | augment_next_obs_goal_differently: false 13 | augment_kwargs: 14 | random_resized_crop: 15 | scale: 16 | - 0.8 17 | - 1.0 18 | ratio: 19 | - 0.9 20 | - 1.1 21 | random_brightness: 22 | - 0.2 23 | random_contrast: 24 | - 0.8 25 | - 1.2 26 | random_saturation: 27 | - 0.8 28 | - 1.2 29 | random_hue: 30 | - 0.1 31 | augment_order: 32 | - "random_resized_crop" 33 | - "random_brightness" 34 | - "random_contrast" 35 | - "random_saturation" 36 | - "random_hue" 37 | goal_relabeling_strategy: "geometric" 38 | goal_relabeling_kwargs: 39 | reached_proportion: 0.0 40 | discount: 0.98 41 | relabel_actions: true 42 | act_pred_horizon: 4 43 | obs_horizon: 1 44 | normalization_type: "tanh_normal" 45 | dataset_contains_commanded_goals: false 46 | ACT_MEAN: 47 | - 1.9296819e-04 48 | - 1.3667766e-04 49 | - -1.4583133e-04 50 | - -1.8390431e-04 51 | - -3.0808983e-04 52 | - 2.7425270e-04 53 | - 5.9716219e-01 54 | ACT_STD: 55 | - 0.00912848 56 | - 0.0127196 57 | - 0.01229497 58 | - 0.02606696 59 | - 0.02875283 60 | - 0.07807977 61 | - 0.48710242 62 | agent_kwargs: 63 | score_network_kwargs: 64 | time_dim: 32 65 | num_blocks: 3 66 | dropout_rate: 0.1 67 | hidden_dim: 256 68 | use_layer_norm: true 69 | early_goal_concat: true 70 | shared_goal_encoder: true 71 | use_proprio: false 72 | beta_schedule: "cosine" 73 | diffusion_steps: 20 74 | action_samples: 1 75 | repeat_last_step: 0 76 | learning_rate: 3.0e-4 77 | warmup_steps: 2000 78 | actor_decay_steps: 2000000 79 | encoder: 80 | type: "resnetv1-34-bridge" 81 | config: 82 | act: "swish" 83 | add_spatial_coordinates: true 84 | pooling_method: "avg" -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/reset_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Reset detector config params 3 | slack_token: 4 | channel_id: C069AT564KC 5 | which_vlm: "none" # choices: "cogvlm", "gpt4v", "none" 6 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/subgoal_predictor.yaml: 
-------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of high-level subgoal predictor policy 3 | diffusion_checkpoint: "kvablack/susie" 4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m" 5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax" 6 | diffusion_num_steps: 50 7 | prompt_w: 7.5 8 | context_w: 2.0 9 | image_size: 256 10 | max_subgoals: 5 11 | susie_server_ip: "localhost" 12 | susie_server_port: 7000 -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/success_detector.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of success detector 3 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none" 4 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/task_definition.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # task definition config params 3 | object_list: 4 | - "drawer" 5 | task_list: 6 | - "open the drawer" 7 | - "close the drawer" 8 | -------------------------------------------------------------------------------- /data_collection/config/berkeley_robot_4/task_proposer.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Config parameters of task proposer 3 | reuse_task_statistics: true 4 | rand_selection_prob: 0.5 5 | zone_center: 0.3 6 | ucb_weight: 100 # high value like 100 ignores dist_to_zone 7 | which_vlm: "cycling" # choices: "cogvlm", "gpt4v", "none", "human" 8 | -------------------------------------------------------------------------------- /data_collection/orchestrator/cogvlm_server/batched.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from PIL import Image 4 | from transformers import AutoModelForCausalLM, LlamaTokenizer 5 | import json 6 | import numpy as np 7 | from tqdm import tqdm 8 | from flask import Flask 9 | from flask import Flask, request, jsonify 10 | import copy 11 | import io 12 | 13 | MODEL_PATH = "THUDM/cogvlm-chat-hf" 14 | TOKENIZER_PATH = "lmsys/vicuna-7b-v1.5" 15 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH) 18 | torch_type = torch.bfloat16 19 | 20 | print("Loading model to GPU mem") 21 | model = ( 22 | AutoModelForCausalLM.from_pretrained( 23 | MODEL_PATH, torch_dtype=torch_type, load_in_4bit=False, trust_remote_code=True 24 | ) 25 | .to(DEVICE) 26 | .eval() 27 | ) 28 | 29 | app = Flask(__name__) 30 | 31 | 32 | @app.route("/query", methods=["POST"]) 33 | def query_model(): 34 | global tokenizer, model, torch_type, DEVICE 35 | 36 | image_files = request.files.getlist("image") 37 | prompts = request.form.getlist("prompt") 38 | 39 | if not image_files or not prompts or len(image_files) != len(prompts): 40 | return jsonify({"error": "Missing image(s) or prompt(s)"}), 400 41 | 42 | model_inputs = [] 43 | max_len = -1 44 | for i, file in enumerate(image_files): 45 | image = Image.open(file.stream) 46 | image = np.array(image)[:, :, :3] 47 | image = Image.fromarray(image).resize((490, 490), Image.LANCZOS) 48 | 49 | input_by_model = model.build_conversation_input_ids( 50 | tokenizer, query=prompts[i], history=[], images=[image] 51 | ) 52 | inputs = { 53 | "input_ids": 
input_by_model["input_ids"].unsqueeze(0).to(DEVICE), 54 | "token_type_ids": input_by_model["token_type_ids"].unsqueeze(0).to(DEVICE), 55 | "attention_mask": input_by_model["attention_mask"].unsqueeze(0).to(DEVICE), 56 | "images": [[input_by_model["images"][0].to(DEVICE).to(torch_type)]], 57 | } 58 | model_inputs.append(inputs) 59 | max_len = max(max_len, inputs["input_ids"].shape[1]) 60 | concatenated_images = [] 61 | for i in range(len(model_inputs)): 62 | tensor_shape = model_inputs[i]["input_ids"].shape[1] 63 | model_inputs[i]["input_ids"] = torch.cat( 64 | [ 65 | torch.zeros( 66 | (1, max_len - tensor_shape), dtype=torch.long, device=DEVICE 67 | ), 68 | model_inputs[i]["input_ids"], 69 | ], 70 | dim=1, 71 | ) 72 | model_inputs[i]["token_type_ids"] = torch.cat( 73 | [ 74 | torch.zeros( 75 | (1, max_len - tensor_shape), dtype=torch.long, device=DEVICE 76 | ), 77 | model_inputs[i]["token_type_ids"], 78 | ], 79 | dim=1, 80 | ) 81 | model_inputs[i]["attention_mask"] = torch.cat( 82 | [ 83 | torch.zeros( 84 | (1, max_len - tensor_shape), dtype=torch.long, device=DEVICE 85 | ), 86 | model_inputs[i]["attention_mask"], 87 | ], 88 | dim=1, 89 | ) 90 | concatenated_images.append(model_inputs[i]["images"][0]) 91 | combined_inputs = { 92 | "input_ids": torch.cat([inputs["input_ids"] for inputs in model_inputs], dim=0), 93 | "token_type_ids": torch.cat( 94 | [inputs["token_type_ids"] for inputs in model_inputs], dim=0 95 | ), 96 | "attention_mask": torch.cat( 97 | [inputs["attention_mask"] for inputs in model_inputs], dim=0 98 | ), 99 | "images": concatenated_images, 100 | } 101 | 102 | gen_kwargs = {"max_length": 2048, "temperature": 0.0, "do_sample": False} 103 | 104 | with torch.no_grad(): 105 | outputs = model.generate(**combined_inputs, **gen_kwargs) 106 | generation_strings = [] 107 | for i in range(len(outputs)): 108 | output_tokens = outputs[i][max_len:] 109 | generation = tokenizer.decode(output_tokens) 110 | generation = generation[: generation.index("</s>")] 111 | generation_strings.append(generation) 112 | 113 | return ( 114 | jsonify( 115 | { 116 | "response": generation_strings, 117 | } 118 | ), 119 | 200, 120 | ) 121 | 122 | 123 | if __name__ == "__main__": 124 | app.run(debug=False, host="0.0.0.0", port=5000) 125 | -------------------------------------------------------------------------------- /data_collection/orchestrator/cogvlm_server/forwarding_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from PIL import Image 3 | import json 4 | import numpy as np 5 | from tqdm import tqdm 6 | from flask import Flask, request, jsonify 7 | import copy 8 | import io 9 | from io import BytesIO 10 | import requests 11 | 12 | app = Flask(__name__) 13 | 14 | 15 | @app.route("/query", methods=["POST"]) 16 | def forward_request(): 17 | print("received request") 18 | image_files = request.files.getlist("image") 19 | prompts = request.form.getlist("prompt") 20 | if not image_files or not prompts or len(image_files) != len(prompts): 21 | return jsonify({"error": "Missing image(s) or prompt(s)"}), 400 22 | numpy_images, prompt_strings = [], [] 23 | for i, file in enumerate(image_files): 24 | image = Image.open(file.stream) 25 | image = np.array(image)[:, :, :3] 26 | numpy_images.append(image) 27 | prompt_strings.append(prompts[i]) 28 | files = [] 29 | for i, (numpy_image, prompt) in enumerate(zip(numpy_images, prompt_strings)): 30 | pil_image = Image.fromarray(numpy_image.astype("uint8")) 31 | img_byte_arr = io.BytesIO() 32 |
pil_image.save(img_byte_arr, format="PNG") # can be JPEG or other formats 33 | img_byte_arr.seek(0) 34 | 35 | # Append the image file 36 | files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) 37 | 38 | # Append the corresponding prompt 39 | files.append((f"prompt", (None, prompt))) 40 | url = "http://localhost:7000/query" 41 | response = requests.post(url, files=files) 42 | if response.status_code == 200: 43 | json_response = response.json() 44 | return jsonify(json_response), 200 45 | else: 46 | return jsonify({"error": "something went wrong"}), 500 47 | 48 | 49 | if __name__ == "__main__": 50 | app.run(debug=False, host="0.0.0.0", port=6000) 51 | -------------------------------------------------------------------------------- /data_collection/orchestrator/cogvlm_server/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from PIL import Image 4 | from transformers import AutoModelForCausalLM, LlamaTokenizer 5 | import json 6 | import numpy as np 7 | from tqdm import tqdm 8 | from flask import Flask 9 | from flask import Flask, request, jsonify 10 | import copy 11 | import io 12 | 13 | MODEL_PATH = "THUDM/cogvlm-chat-hf" 14 | TOKENIZER_PATH = "lmsys/vicuna-7b-v1.5" 15 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH) 18 | torch_type = torch.bfloat16 19 | 20 | print("Loading model to GPU mem") 21 | model = ( 22 | AutoModelForCausalLM.from_pretrained( 23 | MODEL_PATH, torch_dtype=torch_type, load_in_4bit=False, trust_remote_code=True 24 | ) 25 | .to(DEVICE) 26 | .eval() 27 | ) 28 | 29 | app = Flask(__name__) 30 | 31 | 32 | @app.route("/query", methods=["POST"]) 33 | def query_model(): 34 | global tokenizer, model, torch_type, DEVICE 35 | 36 | if "image" not in request.files or "prompt" not in request.form: 37 | return jsonify({"error": "Missing image or prompt"}), 400 38 | 39 | prompt = request.form["prompt"] 40 | 41 | file = request.files["image"] 42 | image = Image.open(file.stream) 43 | 44 | # Resize image to 490x490 45 | image = np.array(image)[:, :, :3] 46 | image = Image.fromarray(image).resize((490, 490), Image.LANCZOS) 47 | 48 | input_by_model = model.build_conversation_input_ids( 49 | tokenizer, query=prompt, history=[], images=[image] 50 | ) 51 | inputs = { 52 | "input_ids": input_by_model["input_ids"].unsqueeze(0).to(DEVICE), 53 | "token_type_ids": input_by_model["token_type_ids"].unsqueeze(0).to(DEVICE), 54 | "attention_mask": input_by_model["attention_mask"].unsqueeze(0).to(DEVICE), 55 | "images": [[input_by_model["images"][0].to(DEVICE).to(torch_type)]], 56 | } 57 | 58 | gen_kwargs = {"max_length": 2048, "temperature": 0.0, "do_sample": False} 59 | 60 | with torch.no_grad(): 61 | outputs = model.generate(**inputs, **gen_kwargs) 62 | outputs = outputs[:, inputs["input_ids"].shape[1] :] 63 | response = tokenizer.decode(outputs[0]) 64 | response = response.split("</s>")[0] 65 | 66 | return ( 67 | jsonify( 68 | { 69 | "response": response, 70 | } 71 | ), 72 | 200, 73 | ) 74 | 75 | 76 | if __name__ == "__main__": 77 | app.run(debug=False, host="0.0.0.0", port=5000) 78 | -------------------------------------------------------------------------------- /data_collection/orchestrator/cogvlm_server/test-img.png:
https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/data_collection/orchestrator/cogvlm_server/test-img.png -------------------------------------------------------------------------------- /data_collection/orchestrator/cogvlm_server/test.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from pathlib import Path 3 | 4 | # Your Flask API endpoint URL 5 | url = "http://localhost:5001/query" # We're testing out ssh port forwarding 6 | 7 | image_path = "test-img.png" 8 | prompt_text = "Describe the image." 9 | 10 | # Make sure the image path is valid 11 | if not Path(image_path).is_file(): 12 | print("The image file does not exist.") 13 | exit() 14 | 15 | # Open the image in binary mode 16 | with open(image_path, "rb") as image_file: 17 | # Prepare the data for the POST request 18 | payload = {"prompt": (None, prompt_text)} 19 | files = {"image": (image_path, image_file, "multipart/form-data")} 20 | 21 | # Send the POST request to the Flask API endpoint 22 | response = requests.post(url, files=files, data=payload) 23 | 24 | # Check the response status code 25 | if response.ok: 26 | print(response.json()) 27 | else: 28 | print(f"Error: {response.status_code}") 29 | print(response.text) 30 | -------------------------------------------------------------------------------- /data_collection/orchestrator/cogvlm_server/test_batched.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from PIL import Image 3 | import io 4 | import numpy as np 5 | 6 | img = Image.open("test-img.png") 7 | 8 | # Your images as numpy arrays 9 | numpy_images = [np.array(img), np.array(img)] 10 | prompts = [ 11 | "On which side of the metal tray is the coke can?", 12 | "On which side of the metal tray is the capsicum?", 13 | ] 14 | 15 | numpy_images = 5 * numpy_images 16 | prompts = 5 * prompts 17 | 18 | # The server endpoint 19 | url = "http://localhost:6000/query" 20 | 21 | files = [] 22 | for i, (numpy_image, prompt) in enumerate(zip(numpy_images, prompts)): 23 | pil_image = Image.fromarray(numpy_image.astype("uint8")) 24 | img_byte_arr = io.BytesIO() 25 | pil_image.save(img_byte_arr, format="PNG") # can be JPEG or other formats 26 | img_byte_arr.seek(0) 27 | 28 | # Append the image file 29 | files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) 30 | 31 | # Append the corresponding prompt 32 | files.append((f"prompt", (None, prompt))) 33 | 34 | # Perform the request 35 | response = requests.post(url, files=files) 36 | 37 | # Response handling 38 | if response.status_code == 200: 39 | print("Success:", response.json()) 40 | else: 41 | print("Error:", response.status_code, response.text) 42 | -------------------------------------------------------------------------------- /data_collection/orchestrator/robot/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import os 6 | import subprocess 7 | import json 8 | 9 | 10 | class Logger: 11 | def __init__(self, config): 12 | self.robot_id = str(config["general_params"]["robot_id"]) 13 | self.video_save_path = config["general_params"]["video_save_path"] 14 | # make sure video_save_path exists 15 | if not os.path.exists(self.video_save_path): 16 | os.makedirs(self.video_save_path) 17 | 18 | # Find the highest index trajectory already saved 19 | # in the save folder, and start counting trajectories 20 | # with one 
above that number 21 | indices = [ 22 | int(name[4:]) 23 | for name in os.listdir(self.video_save_path) 24 | if os.path.isdir(os.path.join(self.video_save_path, name)) 25 | ] 26 | self.traj_idx = max(indices) + 1 if len(indices) != 0 else 0 27 | 28 | # Get image sizes from the config 29 | self.obs_image_size = config["gc_policy_params"]["image_size"] 30 | self.goal_image_size = config["subgoal_predictor_params"]["image_size"] 31 | 32 | # Initialize data structures for things we are logging 33 | self.obs_images = [] 34 | self.goal_images = [] 35 | self.actions = [] 36 | self.poses = [] 37 | 38 | # We will also log the scene information for each trajectory 39 | self.object_list = config["task_definition_params"]["object_list"] 40 | self.task_list = config["task_definition_params"]["task_list"] 41 | 42 | def log_obs(self, image: np.ndarray): 43 | assert image.shape == ( 44 | self.obs_image_size, 45 | self.obs_image_size, 46 | 3, 47 | ), "Cannot log incorrectly shaped obs image" 48 | self.obs_images.append(image) 49 | 50 | def log_goal(self, image: np.ndarray): 51 | assert image.shape == ( 52 | self.goal_image_size, 53 | self.goal_image_size, 54 | 3, 55 | ), "Cannot log incorrectly shaped goal image" 56 | self.goal_images.append(image) 57 | 58 | def log_action(self, action: np.ndarray): 59 | assert action.shape == (7,), "Action should have 7 dimensions" 60 | self.actions.append(action) 61 | 62 | def log_pose(self, pose: np.ndarray): 63 | """ 64 | This method logs the pose of the robot before the action is taken 65 | """ 66 | assert pose.shape == (7,), "Robot pose should have 7 dimensions" 67 | self.poses.append(pose) 68 | 69 | def reset(self): 70 | self.obs_images = [] 71 | self.goal_images = [] 72 | self.actions = [] 73 | self.poses = [] 74 | 75 | def flush_trajectory( 76 | self, commanded_task: str, success: bool, log_combined: bool = True 77 | ): 78 | subdir_path = os.path.join(self.video_save_path, "traj" + str(self.traj_idx)) 79 | if not os.path.exists(subdir_path): 80 | os.makedirs(subdir_path) 81 | 82 | # Log the language task 83 | with open(os.path.join(subdir_path, "language_task.txt"), "w") as f: 84 | f.write(commanded_task) 85 | 86 | # Log the success information 87 | with open(os.path.join(subdir_path, "success.txt"), "w") as f: 88 | f.write(str(success)) 89 | 90 | # Log the actions 91 | np.save(os.path.join(subdir_path, "actions.npy"), np.array(self.actions)) 92 | 93 | # Log the robot poses 94 | np.save(os.path.join(subdir_path, "eef_poses.npy"), np.array(self.poses)) 95 | 96 | # Log the observation video 97 | size = (self.obs_image_size, self.obs_image_size) 98 | out = cv2.VideoWriter( 99 | os.path.join(subdir_path, "trajectory.mp4"), 100 | cv2.VideoWriter_fourcc(*"DIVX"), 101 | 15, 102 | size, 103 | ) 104 | for i in range(len(self.obs_images)): 105 | rgb_img = cv2.cvtColor(self.obs_images[i], cv2.COLOR_RGB2BGR) 106 | out.write(rgb_img) 107 | out.release() 108 | 109 | # Log the goals video 110 | size = (self.goal_image_size, self.goal_image_size) 111 | out = cv2.VideoWriter( 112 | os.path.join(subdir_path, "goals.mp4"), 113 | cv2.VideoWriter_fourcc(*"DIVX"), 114 | 15, 115 | size, 116 | ) 117 | for i in range(len(self.goal_images)): 118 | rgb_img = cv2.cvtColor(self.goal_images[i], cv2.COLOR_RGB2BGR) 119 | out.write(rgb_img) 120 | out.release() 121 | 122 | # Log the combined image 123 | if log_combined: 124 | assert ( 125 | self.obs_image_size == self.goal_image_size 126 | ), "To log combined video obs and goal images must be the same size" 127 | assert len(self.obs_images) == 
len( 128 | self.goal_images 129 | ), "To log combined video there must be equal number of obs and goal images" 130 | size = (self.obs_image_size + self.goal_image_size, self.obs_image_size) 131 | out = cv2.VideoWriter( 132 | os.path.join(subdir_path, "combined.mp4"), 133 | cv2.VideoWriter_fourcc(*"DIVX"), 134 | 15, 135 | size, 136 | ) 137 | for i in range(len(self.goal_images)): 138 | combined_image = np.concatenate( 139 | [self.obs_images[i], self.goal_images[i]], axis=1 140 | ) 141 | rgb_img = cv2.cvtColor(combined_image, cv2.COLOR_RGB2BGR) 142 | out.write(rgb_img) 143 | out.release() 144 | 145 | # Log the scene information 146 | obj_list_dest = os.path.join(subdir_path, "object_list.txt") 147 | task_list_dest = os.path.join(subdir_path, "task_list.txt") 148 | time_dest = os.path.join(subdir_path, "time.txt") 149 | robot_id_dest = os.path.join(subdir_path, "robot_id.txt") 150 | with open(obj_list_dest, "w") as f: 151 | json.dump(self.object_list, f, indent=4) 152 | with open(task_list_dest, "w") as f: 153 | json.dump(self.task_list, f, indent=4) 154 | time = subprocess.check_output("date", shell=True).decode("utf-8").strip() 155 | with open(time_dest, "w") as f: 156 | f.write(time) 157 | robot_id = self.robot_id 158 | with open(robot_id_dest, "w") as f: 159 | f.write(robot_id) 160 | 161 | # Reset variables 162 | self.obs_images = [] 163 | self.goal_images = [] 164 | self.actions = [] 165 | self.poses = [] 166 | self.traj_idx += 1 167 | -------------------------------------------------------------------------------- /data_collection/orchestrator/robot/reset_detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from slack_sdk import WebClient 3 | from slack_sdk.errors import SlackApiError 4 | import utils 5 | from nltk.corpus import cmudict 6 | 7 | 8 | def send_slack_message(client, channel_id, message): 9 | try: 10 | response = client.chat_postMessage(channel=channel_id, text=message) 11 | assert response["message"]["text"] == message 12 | except SlackApiError as e: 13 | assert e.response["ok"] is False 14 | # str like 'invalid_auth', 'channel_not_found' 15 | assert e.response["error"] 16 | print(f"Got an error: {e.response['error']}") 17 | 18 | 19 | class ResetDetector: 20 | """ 21 | Uses a VLM to predict whether an object is missing (and thus a reset is required). 22 | If so, the Slack API is used to send a notification 23 | """ 24 | 25 | def __init__(self, config): 26 | self.config = config 27 | 28 | # Load object list 29 | self.objects = self.config["task_definition_params"]["object_list"] 30 | 31 | # Prepare prompt strings 32 | self.prompt_components = [ 33 | "Is there ", 34 | " present in the image?", 35 | "I gave a vision-language model an image and asked if there was ", 36 | " present in the image. The following was the model's response:\n\n######\n", 37 | "\n######\n\nIf the model thought the object in question was present, return just the word 'true'. 
Else return 'false'.", 38 | ] 39 | 40 | # connecting to slack 41 | slack_token = self.config["reset_detector_params"]["slack_token"] 42 | self.slack_client = WebClient(token=slack_token) 43 | 44 | # cache for anything you want 45 | self.cache = {} 46 | 47 | # cache the status of each object so we don't send slack messages every time 48 | self.cache["object_status"] = {obj: True for obj in self.objects} 49 | 50 | def prepend_article(self, object_name): 51 | first_word = object_name.split()[0] 52 | 53 | def starts_with_vowel_sound(word, pronunciations=cmudict.dict()): 54 | for syllables in pronunciations.get(word, []): 55 | return syllables[0][-1].isdigit() # use only the first one 56 | 57 | if starts_with_vowel_sound(first_word): 58 | return "an " + object_name 59 | return "a " + object_name 60 | 61 | def detect(self, image: np.ndarray): 62 | objects_present = [] 63 | if self.config["reset_detector_params"]["which_vlm"] != "none": 64 | 65 | # prompts for the VLM for each object 66 | image_prompts = [] 67 | for object in self.objects: 68 | image_prompt = ( 69 | self.prompt_components[0] 70 | + self.prepend_article(object) 71 | + self.prompt_components[1] 72 | ) 73 | image_prompts.append(image_prompt) 74 | 75 | # image for the VLM for each object 76 | images = [image] * len(self.objects) 77 | 78 | if self.config["reset_detector_params"]["which_vlm"] == "cogvlm": 79 | vlm_output = utils.ask_cogvlm_batched( 80 | images, image_prompts, self.config 81 | ) 82 | elif self.config["reset_detector_params"]["which_vlm"] == "gpt4v": 83 | vlm_output = utils.ask_gpt4v_batched(images, image_prompts) 84 | else: 85 | vlm_output = ( 86 | "Yes, " + self.prepend_article(object) + " is present in the image." 87 | ) 88 | print("vlm output:", vlm_output) 89 | 90 | # decoding prompts for the LLM 91 | decoding_prompts = [ 92 | self.prompt_components[2] 93 | + self.prepend_article(self.objects[i]) 94 | + self.prompt_components[3] 95 | + vlm_output[i].strip() 96 | + self.prompt_components[4] 97 | for i in range(len(self.objects)) 98 | ] 99 | print("decoding prompts:", decoding_prompts) 100 | 101 | llm_output = utils.ask_gpt4_batched(decoding_prompts, cache=self.cache) 102 | for q, r in zip(decoding_prompts, llm_output): 103 | self.cache[q] = r 104 | 105 | # object presence info 106 | updated_object_status = {} 107 | for llm_answer, obj in zip(llm_output, self.objects): 108 | in_scene = llm_answer.strip().lower() == "true" 109 | try: 110 | assert in_scene in (True, False) 111 | except AssertionError: 112 | in_scene = False 113 | updated_object_status[obj] = in_scene 114 | print(f"Reset Detector LLM said {obj} is present:", in_scene) 115 | objects_present.append(in_scene) 116 | 117 | else: 118 | objects_present = [True] * len(self.objects) 119 | updated_object_status = {obj: True for obj in self.objects} 120 | 121 | objects_not_present = [] 122 | for i in range(len(objects_present)): 123 | if not objects_present[i]: 124 | objects_not_present.append(self.objects[i]) 125 | 126 | # missing objects are those in objects_not_present, but have previous status as True 127 | # i.e. they were present the last time ResetDetector was run but not this time 128 | missing_objects = [ 129 | obj for obj in objects_not_present if self.cache["object_status"][obj] 130 | ] 131 | if len(missing_objects) != 0: 132 | # Reset required, send slack message 133 | message = f"Hey! 
Robot {self.config['general_params']['robot_id']} is missing objects: " 134 | message += ", ".join(missing_objects) 135 | 136 | channel_id = self.config["reset_detector_params"]["channel_id"] 137 | send_slack_message(self.slack_client, channel_id, message) 138 | 139 | # update object status 140 | self.cache["object_status"] = updated_object_status 141 | 142 | # Prepare return dictionary 143 | to_return = {} 144 | for i in range(len(objects_present)): 145 | to_return[self.objects[i]] = objects_present[ 146 | i 147 | ] # we're commenting this out bc the VLM can be unreliable, and we don't want to not propose tasks we can actually execute 148 | return to_return 149 | -------------------------------------------------------------------------------- /data_collection/orchestrator/robot/subgoal_predictor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import requests 3 | from PIL import Image 4 | import io 5 | 6 | 7 | class SubgoalPredictor: 8 | def __init__(self, config): 9 | diffusion_config = config["subgoal_predictor_params"] 10 | self.image_size = diffusion_config["image_size"] 11 | self.url = ( 12 | "http://" 13 | + diffusion_config["susie_server_ip"] 14 | + ":" 15 | + str(diffusion_config["susie_server_port"]) 16 | + "/generate_subgoal" 17 | ) 18 | 19 | def numpy_to_image(self, np_array): 20 | """Convert a NumPy array to a PIL Image.""" 21 | return Image.fromarray(np.uint8(np_array)) 22 | 23 | def image_to_numpy(self, image): 24 | """Convert a PIL Image to a NumPy array.""" 25 | return np.array(image) 26 | 27 | def send_image_and_text(self, url, np_image, text): 28 | """Send a NumPy image array and text to the specified URL.""" 29 | # Convert NumPy array to PIL Image 30 | image = self.numpy_to_image(np_image) 31 | 32 | # Save the PIL Image to a bytes buffer 33 | img_buffer = io.BytesIO() 34 | image.save(img_buffer, format="JPEG") 35 | img_buffer.seek(0) 36 | 37 | # Prepare files and data for the request 38 | files = {"image": ("image.jpg", img_buffer, "image/jpeg")} 39 | data = {"text": text} 40 | 41 | # Send POST request 42 | response = requests.post(url, files=files, data=data) 43 | return response 44 | 45 | def __call__(self, image_obs: np.ndarray, prompt: str): 46 | assert image_obs.shape == ( 47 | self.image_size, 48 | self.image_size, 49 | 3, 50 | ), "Bad input image shape" 51 | 52 | response = self.send_image_and_text(self.url, image_obs, prompt) 53 | if response.status_code == 200: 54 | # Convert the response content back to a NumPy array 55 | image = Image.open(io.BytesIO(response.content)) 56 | output_np_image = self.image_to_numpy(image) 57 | else: 58 | print("Failed to process image", response.status_code, response.text) 59 | return None 60 | 61 | return output_np_image 62 | -------------------------------------------------------------------------------- /data_collection/orchestrator/robot/task_success_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | import numpy as np 5 | 6 | import utils 7 | 8 | 9 | class SuccessPredictor: 10 | """ 11 | Uses a VLM to predict whether a given task has been completed 12 | """ 13 | 14 | def __init__(self, config): 15 | self.config = config 16 | 17 | # Prompt to convert task description into VQA style question that can be used to assess task completion 18 | self.task_to_vqa_prompt = """I have a robot arm manipulator in a lab setting that can perform many manipulation tasks. 
I commanded a task to the robot to perform and I now want to assess whether the robot was successful in completing this task.\n\nTo determine whether or not the robot successfully completed the task, I have access to a vision-language model (VLM) that can answer questions for me when I provide it an image of the lab environment the robot is operating in. Since I can only ask it simple questions about the image, I need to convert the task description into a question that I can feed into the VLM.\n\nHere are several examples:\n\nTask: put the yellow ball on the blue plate\nVLM question: Is the yellow ball on the blue plate?\n\nTask: move the yellow ball from the blue plate to the table\nVLM question: Is the yellow ball on the blue plate or the table?\n\nTask: move the yellow ball to the left side of the table\nVLM question: Is the yellow ball on the left side of the table?\n\nTask: move the yellow ball to the right side of the table\nVLM question: Is the yellow ball on the right side of the table?\n\nTask: put the orange crayon on the blue plate\nVLM question: Is the orange crayon on the blue plate?\n\nTask: move the orange crayon from the blue plate to the table\nVLM question: Is the orange crayon on the blue plate or the table?\n\nTask: put the orange crayon on the cloth\nVLM question: Is the orange crayon on top of the cloth?\n\nTask: move the orange crayon from the cloth to the table\nVLM question: Is the orange crayon on the cloth or the table?\n\nTask: move the orange crayon from the cloth to the blue plate\nVLM question: Is the orange crayon on the cloth or the blue plate?\n\nTask: move the orange crayon from the blue plate to the cloth\nVLM question: Is the orange crayon on the blue plate or the cloth?\n\nTask: move the orange crayon to the right side of the table\nVLM question: Is the orange crayon on the right side of the table?\n\nTask: move the orange crayon to the left side of the table\nVLM question: Is the orange crayon on the left side of the table?\n\nTask: move the red object from the blue plate to the table\nVLM question: Is the red object on the blue plate or the table?\n\nTask: put the red object on the blue plate\nVLM question: Is the red object on the blue plate?\n\nTask: move the red object to the left side of the table\nVLM question: Is the red object on the left side of the table?\n\nTask: move the red object to the right side of the table\nVLM question: Is the red object on the right side of the table?\n\nTask: put the red object on the cloth\nVLM question: Is the red object on the cloth?\n\nTask: move the red object from the cloth to the table\nVLM question: Is the red object on the cloth or the table?\n\nTask: move the red object from the cloth to the blue plate\nVLM question: Is the red object on the cloth or the blue plate?\n\nTask: move the red object from the blue plate to the cloth\nVLM question: Is the red object on the blue plate or the cloth?\n\nTask: move the cloth to the right side of the table\nVLM question: Is the cloth on the right side of the table?\n\nTask: move the cloth to the left side of the table\nVLM question: Is the cloth on the left side of the table?\n\nFollowing the format of these examples, give me the VLM question for the following task:\n\nTask: """ 19 | 20 | # Prompt to decode VLM output into true/false 21 | self.prompt_to_parse_vlm_output = [ 22 | "I have a robot arm manipulator in a lab setting. I commanded it to complete the following task:\n\n", 23 | "\n\nI want to assess whether the robot arm successfully completed the task. 
To do so, I prompted a vision-language model (VLM) with an image of the current robot workspace and the following question:\n\n", 24 | "\n\nIn response, the VLM answered the following:\n\n", 25 | "\n\nBased on the task commanded and the VLM's response to the question, determine if the robot successfully completed the commanded task or not. If it did successfully complete the task, return just the word true. Otherwise return the word false. If for some reason the answer is neither true nor false, return false.", 26 | ] 27 | 28 | # Use for caching anything you want 29 | self.cache = {} 30 | 31 | # to record the success rates 32 | self.task_success_record = ( 33 | self.init_previous_task_stats() 34 | ) # task -> list of bools 35 | 36 | def init_previous_task_stats(self): 37 | task_success_record = {} 38 | 39 | if self.config["task_proposer_params"]["reuse_task_statistics"]: 40 | trajectory_log_dir = self.config["general_params"]["video_save_path"] 41 | logged_trajs = [ 42 | traj 43 | for traj in os.listdir(trajectory_log_dir) 44 | if os.path.isdir(os.path.join(trajectory_log_dir, traj)) 45 | ] 46 | for traj in tqdm(logged_trajs): 47 | traj_path = os.path.join(trajectory_log_dir, traj) 48 | with open(os.path.join(traj_path, "language_task.txt")) as f: 49 | traj_task = f.read().strip().lower() 50 | with open(os.path.join(traj_path, "success.txt")) as f: 51 | traj_success = f.read().strip().lower() 52 | if traj_task not in task_success_record: 53 | task_success_record[traj_task] = [] 54 | task_success_record[traj_task].append(traj_success == "true") 55 | 56 | return task_success_record 57 | 58 | def record_task_success(self, task_str, success): 59 | if task_str not in self.task_success_record: 60 | self.task_success_record[task_str] = [] 61 | self.task_success_record[task_str].append(success) 62 | 63 | def get_success_rate(self, n_most_recent=None): 64 | success_rates = {} 65 | for task_str, success_list in self.task_success_record.items(): 66 | if n_most_recent is not None: 67 | success_list = success_list[-n_most_recent:] 68 | success_rates[task_str] = sum(success_list) / len(success_list) 69 | return success_rates 70 | 71 | def predict_outcome(self, image: np.ndarray, task_str: str, log_metrics=True): 72 | # convert the task_str into a VQA style question 73 | vqa_style_q_unparsed = utils.ask_gpt4( 74 | self.task_to_vqa_prompt + task_str, cache=self.cache 75 | ) 76 | 77 | # add response to cache 78 | self.cache[self.task_to_vqa_prompt + task_str] = vqa_style_q_unparsed 79 | 80 | vqa_style_q_unparsed = vqa_style_q_unparsed.strip() 81 | if ":" in vqa_style_q_unparsed: 82 | vqa_style_q = vqa_style_q_unparsed[vqa_style_q_unparsed.index(":") + 2 :] 83 | else: 84 | vqa_style_q = vqa_style_q_unparsed 85 | print("vqa_style_q:", vqa_style_q) 86 | 87 | # ask the VLM 88 | if self.config["success_detector_params"]["which_vlm"] == "gpt4v": 89 | vlm_output = utils.ask_gpt4v(image, vqa_style_q) 90 | elif self.config["success_detector_params"]["which_vlm"] == "cogvlm": 91 | vlm_output = utils.ask_cogvlm(image, vqa_style_q, self.config) 92 | else: 93 | # If there's no VLM success detector, we will conservatively assume the task failed 94 | return False 95 | print("vlm_output:", vlm_output) 96 | 97 | # parse the output 98 | decoding_prompt = ( 99 | self.prompt_to_parse_vlm_output[0] 100 | + task_str 101 | + self.prompt_to_parse_vlm_output[1] 102 | + vqa_style_q 103 | + self.prompt_to_parse_vlm_output[2] 104 | + vlm_output 105 | + self.prompt_to_parse_vlm_output[3] 106 | ) 107 | print("decoding_prompt:", 
decoding_prompt) 108 | parsed_vlm_output = utils.ask_gpt4(decoding_prompt, cache=self.cache) 109 | 110 | # add response to cache 111 | self.cache[decoding_prompt] = parsed_vlm_output 112 | 113 | print("parsed_vlm_output:", parsed_vlm_output) 114 | 115 | success = parsed_vlm_output.strip().lower() == "true" 116 | try: 117 | assert success in (True, False) 118 | except AssertionError: 119 | print( 120 | "Error: VLM output was neither 'true' nor 'false'. Assuming task failed." 121 | ) 122 | success = False 123 | 124 | if log_metrics: 125 | self.record_task_success(task_str, success) 126 | 127 | return success 128 | -------------------------------------------------------------------------------- /data_collection/orchestrator/robot/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyquaternion import Quaternion 3 | from PIL import Image 4 | import os 5 | import requests 6 | from typing import List 7 | from openai import OpenAI 8 | import io 9 | import base64 10 | from tqdm import tqdm 11 | from concurrent.futures import ThreadPoolExecutor 12 | from multiprocessing import Pool 13 | 14 | 15 | def state_to_eep(xyz_coor, zangle: float): 16 | """ 17 | Implement the state to eep function. 18 | Refered to `bridge_data_robot`'s `widowx_controller/widowx_controller.py` 19 | return a 4x4 matrix 20 | """ 21 | assert len(xyz_coor) == 3 22 | DEFAULT_ROTATION = np.array([[0, 0, 1.0], [0, 1.0, 0], [-1.0, 0, 0]]) 23 | new_pose = np.eye(4) 24 | new_pose[:3, -1] = xyz_coor 25 | new_quat = Quaternion(axis=np.array([0.0, 0.0, 1.0]), angle=zangle) * Quaternion( 26 | matrix=DEFAULT_ROTATION 27 | ) 28 | new_pose[:3, :3] = new_quat.rotation_matrix 29 | # yaw, pitch, roll = quat.yaw_pitch_roll 30 | return new_pose 31 | 32 | 33 | def get_observation(widowx_client, config): 34 | while True: 35 | obs = widowx_client.get_observation() 36 | if obs is None: 37 | print("WARNING: failed to get robot observation, retrying...") 38 | else: 39 | break 40 | obs["image"] = ( 41 | obs["image"] 42 | .reshape( 43 | 3, 44 | config["general_params"]["shoulder_camera_image_size"], 45 | config["general_params"]["shoulder_camera_image_size"], 46 | ) 47 | .transpose(1, 2, 0) 48 | * 255 49 | ).astype(np.uint8) 50 | return obs 51 | 52 | 53 | def encode_image_np(image_np: np.ndarray): 54 | # Ensure the NumPy array is in uint8 55 | if image_np.dtype != np.uint8: 56 | image_np = image_np.astype(np.uint8) 57 | # Convert the NumPy array to a PIL Image 58 | pil_image = Image.fromarray(image_np) 59 | # Create a buffer to hold the bytes 60 | buffer = io.BytesIO() 61 | # Save the image to the buffer in PNG format (or JPEG or any other format) 62 | pil_image.save(buffer, format="PNG") 63 | # Encode the buffer's content in base64 64 | base64_encoded = base64.b64encode(buffer.getvalue()).decode("utf-8") 65 | return base64_encoded 66 | 67 | 68 | def ask_gpt4v(image: np.ndarray, prompt: str): 69 | # We will first resize the image to 512x512 70 | if not image.shape == (512, 512, 3): 71 | image_pil = Image.fromarray(image) 72 | resized_pil = image_pil.resize((512, 512), Image.ANTIALIAS) 73 | image = np.array(resized_pil) 74 | 75 | # Prepare jsons for openai api requests 76 | headers = { 77 | "Content-Type": "application/json", 78 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}", 79 | } 80 | payload = { 81 | "model": "gpt-4-turbo", # gpt4o 82 | "messages": [ 83 | { 84 | "role": "user", 85 | "content": [ 86 | {"type": "text", "text": "###############"}, 87 | { 88 | "type": 
"image_url", 89 | "image_url": {"url": "###############", "detail": "low"}, 90 | }, 91 | ], 92 | } 93 | ], 94 | "max_tokens": 2000, 95 | "temperature": 0.0, 96 | } 97 | 98 | base64_image = encode_image_np(image) 99 | payload["messages"][0]["content"][1]["image_url"][ 100 | "url" 101 | ] = f"data:image/jpeg;base64,{base64_image}" 102 | payload["messages"][0]["content"][0]["text"] = prompt 103 | 104 | while True: 105 | response = requests.post( 106 | "https://api.openai.com/v1/chat/completions", headers=headers, json=payload 107 | ).json() 108 | if "error" in response: 109 | continue 110 | assistant_message = response["choices"][0]["message"]["content"] 111 | break 112 | 113 | return assistant_message 114 | 115 | 116 | def ask_gpt4v_batched(images: List[np.ndarray], prompts: List[str]): 117 | assert len(images) == len(prompts) 118 | 119 | # TODO: implement real batching 120 | assistant_messages = [] 121 | print("querying gpt4v batched") 122 | for i in tqdm(range(len(images))): 123 | assistant_messages.append(ask_gpt4v(images[i], prompts[i])) 124 | 125 | return assistant_messages 126 | 127 | 128 | def ask_gpt4(prompt, cache=None): 129 | if type(prompt) == tuple: 130 | prompt, cache = prompt 131 | 132 | # check if cache contains the answer 133 | if cache is not None and prompt in cache: 134 | return cache[prompt] 135 | 136 | # Prepare jsons for openai api requests 137 | headers = { 138 | "Content-Type": "application/json", 139 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}", 140 | } 141 | payload = { 142 | "model": "gpt-3.5-turbo-1106", 143 | "messages": [ 144 | { 145 | "role": "user", 146 | "content": [ 147 | {"type": "text", "text": "###############"}, 148 | ], 149 | } 150 | ], 151 | "max_tokens": 2000, 152 | "temperature": 0.0, 153 | } 154 | 155 | payload["messages"][0]["content"][0]["text"] = prompt 156 | while True: 157 | try: 158 | response = requests.post( 159 | "https://api.openai.com/v1/chat/completions", 160 | headers=headers, 161 | json=payload, 162 | ).json() 163 | except: 164 | # sometime we get requests.exceptions.JSONDecodeError 165 | continue 166 | if "error" in response: 167 | continue 168 | assistant_message = response["choices"][0]["message"]["content"] 169 | break 170 | 171 | return assistant_message 172 | 173 | 174 | def ask_gpt4_batched(prompts, cache=None): 175 | if prompts is None or len(prompts) == 0: 176 | return [] 177 | if cache is not None: 178 | # zip cache with the prompts list 179 | prompts = [(prompt, cache) for prompt in prompts] 180 | with Pool(len(prompts)) as p: 181 | assistant_messages = p.map(ask_gpt4, prompts) 182 | return assistant_messages 183 | 184 | 185 | # def ask_gpt4_batched(prompts): 186 | # assistant_messages = [] 187 | # for prompt in tqdm(prompts): 188 | # assistant_messages.append(ask_gpt4(prompt)) 189 | # return assistant_messages 190 | 191 | 192 | def ask_cogvlm(image: np.ndarray, prompt: str, config): 193 | image_list, prompt_list = [image], [prompt] 194 | return ask_cogvlm_batched(image_list, prompt_list, config)[0] 195 | 196 | 197 | def ask_cogvlm_batched(images: List[np.ndarray], prompts: List[str], config): 198 | 199 | def _ask_cogvlm_batched_helper(): 200 | assert len(images) == len(prompts) 201 | 202 | files = [] 203 | for i, (numpy_image, prompt) in enumerate(zip(images, prompts)): 204 | pil_image = Image.fromarray(numpy_image.astype("uint8")) 205 | img_byte_arr = io.BytesIO() 206 | pil_image.save(img_byte_arr, format="PNG") # can be JPEG or other formats 207 | img_byte_arr.seek(0) 208 | 209 | # Append the image file 210 
| files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) 211 | 212 | # Append the corresponding prompt 213 | files.append((f"prompt", (None, prompt))) 214 | 215 | url = ( 216 | "http://" 217 | + config["cogvlm_server_params"]["cogvlm_server_ip"] 218 | + ":" 219 | + str(config["cogvlm_server_params"]["cogvlm_server_port"]) 220 | + "/query" 221 | ) 222 | response = requests.post(url, files=files) 223 | 224 | return response.json()["response"] 225 | 226 | # repeat the query if it fails 227 | while True: 228 | try: 229 | response = _ask_cogvlm_batched_helper() 230 | break 231 | except: 232 | continue 233 | 234 | return response 235 | -------------------------------------------------------------------------------- /data_collection/orchestrator/set_workspace_bounds/teleop.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from yamlinclude import YamlIncludeConstructor 3 | from absl import app, flags 4 | import numpy as np 5 | from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs 6 | import os 7 | import sys 8 | import tty 9 | import termios 10 | import time 11 | 12 | FLAGS = flags.FLAGS 13 | flags.DEFINE_string( 14 | "config_dir", 15 | None, 16 | "Path to config directory", 17 | required=True, 18 | ) 19 | 20 | print_yellow = lambda x: print("\033[93m {}\033[00m".format(x)) 21 | 22 | 23 | def print_help(): 24 | print_yellow(" Teleop Controls:") 25 | print_yellow(" w, s : move forward/backward") 26 | print_yellow(" a, d : move left/right") 27 | print_yellow(" z, c : move up/down") 28 | print_yellow(" i, k: rotate yaw") 29 | print_yellow(" j, l: rotate pitch") 30 | print_yellow(" n m: rotate roll") 31 | print_yellow(" space: toggle gripper") 32 | print_yellow(" r: reset robot") 33 | print_yellow(" q: quit") 34 | 35 | 36 | def main(_): 37 | YamlIncludeConstructor.add_to_loader_class( 38 | loader_class=yaml.FullLoader, base_dir=FLAGS.config_dir 39 | ) 40 | with open(os.path.join(FLAGS.config_dir, "config.yaml")) as f: 41 | config = yaml.load(f, Loader=yaml.FullLoader) 42 | 43 | env_params = WidowXConfigs.DefaultEnvParams.copy() 44 | env_params.update({"action_clipping": None}) 45 | client = WidowXClient( 46 | host=config["general_params"]["ip"], port=config["general_params"]["port"] 47 | ) 48 | client.init(env_params) 49 | client.reset() 50 | 51 | # Save the terminal settings 52 | fd = sys.stdin.fileno() 53 | old_settings = termios.tcgetattr(fd) 54 | 55 | print_help() 56 | is_open = 1 57 | running = True 58 | xyz_min, xyz_max = None, None 59 | while running: 60 | # Check for key press 61 | try: 62 | # Set the terminal to raw mode to read a single key press 63 | tty.setraw(sys.stdin.fileno()) 64 | key = sys.stdin.read(1) 65 | finally: 66 | # Restore the terminal to its original settings 67 | termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) 68 | 69 | # escape key to quit 70 | if key == "q": 71 | print("Quitting teleoperation.") 72 | running = False 73 | continue 74 | 75 | # Handle key press for robot control 76 | # translation 77 | if key == "w": 78 | client.step_action(np.array([0.01, 0, 0, 0, 0, 0, is_open])) 79 | elif key == "s": 80 | client.step_action(np.array([-0.01, 0, 0, 0, 0, 0, is_open])) 81 | elif key == "a": 82 | client.step_action(np.array([0, 0.01, 0, 0, 0, 0, is_open])) 83 | elif key == "d": 84 | client.step_action(np.array([0, -0.01, 0, 0, 0, 0, is_open])) 85 | elif key == "z": 86 | client.step_action(np.array([0, 0, 0.01, 0, 0, 0, is_open])) 87 | elif key == "c": 88 | 
client.step_action(np.array([0, 0, -0.01, 0, 0, 0, is_open])) 89 | 90 | # rotation 91 | elif key == "i": 92 | client.step_action(np.array([0, 0, 0, 0.01, 0, 0, is_open])) 93 | elif key == "k": 94 | client.step_action(np.array([0, 0, 0, -0.01, 0, 0, is_open])) 95 | elif key == "j": 96 | client.step_action(np.array([0, 0, 0, 0, 0.01, 0, is_open])) 97 | elif key == "l": 98 | client.step_action(np.array([0, 0, 0, 0, -0.01, 0, is_open])) 99 | elif key == "n": 100 | client.step_action(np.array([0, 0, 0, 0, 0, 0.01, is_open])) 101 | elif key == "m": 102 | client.step_action(np.array([0, 0, 0, 0, 0, -0.01, is_open])) 103 | 104 | # space bar to change gripper state 105 | elif key == " ": 106 | is_open = 1 - is_open 107 | print("Gripper is now: ", is_open) 108 | client.step_action(np.array([0, 0, 0, 0, 0, 0, is_open])) 109 | elif key == "r": 110 | print("Resetting robot...") 111 | client.reset() 112 | print_help() 113 | 114 | # Get the end-effector position after taking action 115 | obs = client.get_observation() 116 | eef_pose = obs["state"] 117 | if xyz_min is None or xyz_max is None: 118 | xyz_min = eef_pose[:3] 119 | xyz_max = eef_pose[:3] 120 | xyz_min = np.minimum(xyz_min, eef_pose[:3]) 121 | xyz_max = np.maximum(xyz_max, eef_pose[:3]) 122 | print("robot pose:", eef_pose) 123 | 124 | client.stop() # Properly stop the client 125 | print("Teleoperation ended.") 126 | 127 | print() 128 | print("XYZ Min:", xyz_min) 129 | print("XYZ Max:", xyz_max) 130 | 131 | 132 | if __name__ == "__main__": 133 | app.run(main) 134 | -------------------------------------------------------------------------------- /data_collection/orchestrator/susie_server/main.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from yamlinclude import YamlIncludeConstructor 3 | import argparse 4 | import os 5 | import numpy as np 6 | from susie.model import create_sample_fn 7 | from flask import Flask, request, send_file 8 | from PIL import Image 9 | import io 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--config_dir", required=True) 13 | args = parser.parse_args() 14 | 15 | YamlIncludeConstructor.add_to_loader_class( 16 | loader_class=yaml.FullLoader, base_dir=args.config_dir 17 | ) 18 | with open(os.path.join(args.config_dir, "config.yaml")) as f: 19 | config = yaml.load(f, Loader=yaml.FullLoader) 20 | 21 | 22 | class SubgoalPredictor: 23 | def __init__(self, config): 24 | diffusion_config = config["subgoal_predictor_params"] 25 | self.diffusion_sample_func = create_sample_fn( 26 | diffusion_config["diffusion_checkpoint"], 27 | diffusion_config["diffusion_wandb"], 28 | diffusion_config["diffusion_num_steps"], 29 | diffusion_config["prompt_w"], 30 | diffusion_config["context_w"], 31 | 0.0, 32 | diffusion_config["diffusion_pretrained_path"], 33 | ) 34 | self.image_size = diffusion_config["image_size"] 35 | 36 | def __call__(self, image_obs: np.ndarray, prompt: str): 37 | assert image_obs.shape == ( 38 | self.image_size, 39 | self.image_size, 40 | 3, 41 | ), "Bad input image shape" 42 | return self.diffusion_sample_func(image_obs, prompt) 43 | 44 | 45 | subgoal_predictor = SubgoalPredictor(config) 46 | 47 | app = Flask(__name__) 48 | 49 | 50 | @app.route("/generate_subgoal", methods=["POST"]) 51 | def process_image(): 52 | global subgoal_predictor 53 | 54 | # Check if the request contains the 'image' file 55 | if "image" not in request.files: 56 | return "No image part", 400 57 | file = request.files["image"] 58 | 59 | # Check if the request contains the 
'text' part 60 | if "text" not in request.form: 61 | return "No text provided", 400 62 | text = request.form["text"] 63 | 64 | # Read the image file 65 | image = Image.open(file.stream) 66 | image = np.array(image) 67 | 68 | generated_subgoal = subgoal_predictor(image_obs=image, prompt=text) 69 | 70 | # Save the image to a binary stream 71 | generated_subgoal = Image.fromarray(generated_subgoal) 72 | img_io = io.BytesIO() 73 | generated_subgoal.save(img_io, "JPEG", quality=70) 74 | img_io.seek(0) 75 | 76 | return send_file(img_io, mimetype="image/jpeg") 77 | 78 | 79 | if __name__ == "__main__": 80 | app.run(debug=False, host="0.0.0.0", port=7000) 81 | -------------------------------------------------------------------------------- /data_collection/orchestrator/web_viewer/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, render_template, request, jsonify, send_file 2 | import os 3 | import portalocker 4 | 5 | app = Flask(__name__) 6 | app.config["UPLOAD_FOLDER"] = "./uploads" 7 | app.config["MAX_CONTENT_LENGTH"] = 16 * 1024 * 1024 # 16MB upload limit 8 | 9 | robot_video_feeds = { 10 | "feed0": { 11 | "observation": None, 12 | "goal": None, 13 | "status": { 14 | "commanded_task": "N/A", 15 | "subgoal": 0, 16 | "timestep": 0, 17 | "task_success": "N/A", 18 | }, 19 | }, 20 | "feed1": { 21 | "observation": None, 22 | "goal": None, 23 | "status": { 24 | "commanded_task": "N/A", 25 | "subgoal": 0, 26 | "timestep": 0, 27 | "task_success": "N/A", 28 | }, 29 | }, 30 | "feed2": { 31 | "observation": None, 32 | "goal": None, 33 | "status": { 34 | "commanded_task": "N/A", 35 | "subgoal": 0, 36 | "timestep": 0, 37 | "task_success": "N/A", 38 | }, 39 | }, 40 | "feed3": { 41 | "observation": None, 42 | "goal": None, 43 | "status": { 44 | "commanded_task": "N/A", 45 | "subgoal": 0, 46 | "timestep": 0, 47 | "task_success": "N/A", 48 | }, 49 | }, 50 | "feed4": { 51 | "observation": None, 52 | "goal": None, 53 | "status": { 54 | "commanded_task": "N/A", 55 | "subgoal": 0, 56 | "timestep": 0, 57 | "task_success": "N/A", 58 | }, 59 | }, 60 | "feed5": { 61 | "observation": None, 62 | "goal": None, 63 | "status": { 64 | "commanded_task": "N/A", 65 | "subgoal": 0, 66 | "timestep": 0, 67 | "task_success": "N/A", 68 | }, 69 | }, 70 | "feed6": { 71 | "observation": None, 72 | "goal": None, 73 | "status": { 74 | "commanded_task": "N/A", 75 | "subgoal": 0, 76 | "timestep": 0, 77 | "task_success": "N/A", 78 | }, 79 | }, 80 | "feed7": { 81 | "observation": None, 82 | "goal": None, 83 | "status": { 84 | "commanded_task": "N/A", 85 | "subgoal": 0, 86 | "timestep": 0, 87 | "task_success": "N/A", 88 | }, 89 | }, 90 | } 91 | 92 | 93 | @app.route("/") 94 | def index(): 95 | return render_template("index.html") 96 | 97 | 98 | @app.route("/upload/", methods=["POST"]) 99 | def upload_image(robot_idx): 100 | robot_idx = int(robot_idx) 101 | image_type = request.args.get("type", "") 102 | if image_type not in ["observation", "goal"]: 103 | return jsonify({"error": "Invalid image type"}), 400 104 | 105 | file = request.files.get("file") 106 | if file and file.filename: 107 | # Save the file as 'observation.jpg' or 'goal.jpg' in the uploads folder 108 | filename = f"{image_type}_{robot_idx}.jpg" 109 | file_path = os.path.join(app.config["UPLOAD_FOLDER"], filename) 110 | with open(file_path, "wb") as f: 111 | portalocker.lock(f, portalocker.LOCK_EX) 112 | f.write(file.read()) 113 | portalocker.unlock(f) 114 | return ( 115 | jsonify( 116 | { 117 | "message": "Image 
uploaded successfully", 118 | "type": image_type, 119 | "robot_index": robot_idx, 120 | } 121 | ), 122 | 200, 123 | ) 124 | else: 125 | return jsonify({"error": "No file part"}), 400 126 | 127 | 128 | @app.route("/images//", methods=["GET"]) 129 | def get_latest_image(robot_idx, image_type): 130 | robot_idx = int(robot_idx) 131 | if image_type not in ["observation", "goal"]: 132 | return jsonify({"error": "Invalid image type"}), 400 133 | 134 | image_path = os.path.join( 135 | app.config["UPLOAD_FOLDER"], f"{image_type}_{robot_idx}.jpg" 136 | ) 137 | if os.path.exists(image_path): 138 | return send_file(image_path) 139 | else: 140 | return jsonify({"error": "Image not found"}), 404 141 | 142 | 143 | @app.route("/update_status/", methods=["POST"]) 144 | def update_status(robot_idx): 145 | data = request.get_json() 146 | robot_video_feeds["feed" + robot_idx]["status"] = { 147 | "commanded_task": data.get("commanded_task", ""), 148 | "subgoal": data.get("subgoal", 0), 149 | "timestep": data.get("timestep", 0), 150 | "task_success": data.get("task_success", ""), 151 | } 152 | return jsonify({"message": "Status updated successfully"}) 153 | 154 | 155 | @app.route("/get_status_data/") 156 | def get_status_data(robot_idx): 157 | return jsonify(robot_video_feeds["feed" + robot_idx]["status"]) 158 | 159 | 160 | if __name__ == "__main__": 161 | os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True) 162 | app.run(debug=True, host="0.0.0.0", port=5000) 163 | -------------------------------------------------------------------------------- /data_collection/orchestrator/web_viewer/ros_client/run_client.py: -------------------------------------------------------------------------------- 1 | import rospy 2 | from sensor_msgs.msg import Image 3 | import time 4 | import numpy as np 5 | from PIL import Image as PILImage 6 | import requests 7 | import io 8 | import yaml 9 | from yamlinclude import YamlIncludeConstructor 10 | from absl import app, flags 11 | import os 12 | 13 | FLAGS = flags.FLAGS 14 | flags.DEFINE_string( 15 | "config_dir", 16 | None, 17 | "Path to config directory", 18 | required=True, 19 | ) 20 | 21 | 22 | def run_ros_subscriber(_): 23 | YamlIncludeConstructor.add_to_loader_class( 24 | loader_class=yaml.FullLoader, base_dir=FLAGS.config_dir 25 | ) 26 | with open(os.path.join(FLAGS.config_dir, "config.yaml")) as f: 27 | config = yaml.load(f, Loader=yaml.FullLoader) 28 | 29 | rospy.init_node("image_listener", anonymous=True) 30 | 31 | # Function to handle the image message once received 32 | def image_callback(message): 33 | rospy.loginfo("Received an image!") 34 | str_message = str(message) 35 | image_data = str_message[str_message.index("[") + 1 : str_message.index("]")] 36 | image_data = image_data.replace(",", "").split() 37 | image_data = [int(elem) for elem in image_data] 38 | image = np.array(image_data, dtype=np.uint8) 39 | 40 | # The shape of the image is 480 x 640 41 | image = image.reshape(480, 640, 3) 42 | image = image[:, :, ::-1] # convert from BGR to RGB 43 | 44 | img = PILImage.fromarray(image) 45 | img = img.resize((512, 512), PILImage.LANCZOS) # make image square 512x512 46 | 47 | buffer = io.BytesIO() 48 | img.save(buffer, format="JPEG") 49 | buffer.seek(0) 50 | files = {"file": ("image.jpg", buffer.getvalue(), "image/jpeg")} 51 | 52 | # We're sending this to our main web server 53 | url = ( 54 | "http://" 55 | + config["general_params"]["web_viewer_ip"] 56 | + ":" 57 | + str(config["general_params"]["web_viewer_port"]) 58 | + "/upload/" 59 | + 
str(config["general_params"]["robot_id"]) 60 | + "?type=observation" 61 | ) 62 | response = requests.post(url, files=files) 63 | 64 | rospy.Subscriber("/blue/image_raw", Image, image_callback) 65 | 66 | rospy.spin() # Spin until the node is shut down 67 | 68 | 69 | if __name__ == "__main__": 70 | app.run(run_ros_subscriber) 71 | -------------------------------------------------------------------------------- /data_collection/orchestrator/web_viewer/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Image Feeds 6 | 86 | 87 | 88 |
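The web viewer above accepts observation/goal frames at `/upload/<robot_idx>` (multipart field `file`, query parameter `type`), JSON status updates at `/update_status/<robot_idx>`, and serves the latest status at `/get_status_data/<robot_idx>`, all on port 5000. A minimal client sketch mirroring what `run_client.py` sends for observations; the host and robot index are placeholders, and in practice the address comes from `general_params` in the robot config:

```python
# Sketch: push one frame and a status update to the web viewer, then read the
# status back. Host and robot index are placeholders.
import io

import numpy as np
import requests
from PIL import Image

viewer = "http://localhost:5000"  # placeholder; use web_viewer_ip/web_viewer_port
robot_idx = 0

# Upload a (placeholder) 512x512 observation frame as JPEG.
frame = np.zeros((512, 512, 3), dtype=np.uint8)
buf = io.BytesIO()
Image.fromarray(frame).save(buf, format="JPEG")
files = {"file": ("image.jpg", buf.getvalue(), "image/jpeg")}
requests.post(f"{viewer}/upload/{robot_idx}?type=observation", files=files)

# Post a status update and read it back.
requests.post(
    f"{viewer}/update_status/{robot_idx}",
    json={
        "commanded_task": "put the spoon on the towel",
        "subgoal": 1,
        "timestep": 10,
        "task_success": "N/A",
    },
)
print(requests.get(f"{viewer}/get_status_data/{robot_idx}").json())
```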
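The SuSIE subgoal server shown earlier (`orchestrator/susie_server/main.py`) exposes a single `/generate_subgoal` route on port 7000 that takes a multipart POST with an `image` file and a `text` field and returns the generated subgoal as a JPEG. A minimal client sketch; the host and the 256x256 image size are assumptions (the server checks the image against `subgoal_predictor_params["image_size"]` from the config):

```python
# Sketch: request a subgoal image from the SuSIE server. Host, prompt, and
# the 256x256 size are placeholders.
import io

import numpy as np
import requests
from PIL import Image

image = np.zeros((256, 256, 3), dtype=np.uint8)  # placeholder observation
buf = io.BytesIO()
Image.fromarray(image).save(buf, format="PNG")
buf.seek(0)

resp = requests.post(
    "http://localhost:7000/generate_subgoal",  # placeholder host
    files={"image": ("obs.png", buf, "image/png")},
    data={"text": "put the spoon on the towel"},
)
subgoal = np.array(Image.open(io.BytesIO(resp.content)))  # decoded JPEG subgoal
print(subgoal.shape)
```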
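The VLM query helpers listed earlier (`ask_gpt4v`, `ask_gpt4_batched`, `ask_cogvlm_batched`) are plain functions that take numpy images and prompt strings; `ask_cogvlm_batched` posts all image/prompt pairs to the CogVLM server's `/query` endpoint in one multipart request and retries until it gets a response. A usage sketch; the module path (`orchestrator.robot.utils`) and the server address are assumptions, and the GPT helpers additionally require `OPENAI_API_KEY` and network access:

```python
# Sketch of calling the VLM query helpers. Module path and server address are
# assumptions; substitute the values from your robot config.
import numpy as np

from orchestrator.robot.utils import ask_cogvlm_batched, ask_gpt4_batched  # assumed path

if __name__ == "__main__":
    config = {
        "cogvlm_server_params": {
            "cogvlm_server_ip": "localhost",  # placeholder
            "cogvlm_server_port": 5000,       # placeholder
        }
    }
    images = [np.zeros((480, 640, 3), dtype=np.uint8)] * 2
    prompts = ["Describe the objects on the table."] * 2
    # One multipart POST to /query; the helper retries until the server answers.
    print(ask_cogvlm_batched(images, prompts, config))

    # ask_gpt4_batched fans prompts out over a multiprocessing Pool.
    print(ask_gpt4_batched(["List three tabletop manipulation tasks."]))
```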
95 | 96 | 170 | 171 | -------------------------------------------------------------------------------- /data_collection/requirements.txt: -------------------------------------------------------------------------------- 1 | distrax==0.1.2 2 | flax==0.7.0 3 | transformers==4.33.1 4 | dlimp @ git+https://github.com/zhouzypaul/dlimp@2df85dc98a7d564fc9c7c5ff2bfda26361526483 5 | einops==0.6.1 6 | wandb==0.15.5 7 | tensorflow==2.15.0.post1 8 | tensorflow-probability==0.23.0 9 | ml-collections==0.1.0 10 | jax==0.4.20 11 | jaxlib==0.4.20 12 | optax==0.1.5 13 | diffusers==0.18.2 14 | ml-dtypes==0.2.0 15 | pyyaml-include==1.3.2 16 | slack_sdk 17 | Pillow==10.1 18 | rospkg==1.5.0 19 | pyyaml-include==1.3.2 20 | opencv-python==4.9.0.80 21 | funcsigs==1.0.2 22 | pyquaternion==0.9.9 23 | openai==1.14.1 24 | nltk==3.8.1 25 | edgeml @ git+https://github.com/youliangtan/edgeml.git -------------------------------------------------------------------------------- /data_collection/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="orchestrator", 5 | packages=find_packages(), 6 | version="0.0.1", 7 | install_requires=[ 8 | "absl-py", 9 | "diffusers[flax]", 10 | "ml_collections", 11 | "tensorflow", 12 | "wandb", 13 | "einops", 14 | ], 15 | ) 16 | -------------------------------------------------------------------------------- /data_collection/ssh_port_forward.sh: -------------------------------------------------------------------------------- 1 | ssh -L 6000:localhost:5000 -N -f -C -o ExitOnForwardFailure=yes $USER@128.32.162.191 2 | ssh -L 7000:localhost:6000 -N -f -C -o ExitOnForwardFailure=yes $USER@128.32.162.191 3 | -------------------------------------------------------------------------------- /data_collection/start_robot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # check that OPENAI_API_KEY is set 4 | if [ -z "$OPENAI_API_KEY" ]; then 5 | echo "OPENAI_API_KEY is not set. Please set it before running the script." 6 | exit 1 7 | fi 8 | 9 | # Check if the first argument is provided 10 | if [ -z "$1" ]; then 11 | echo "No argument provided. Please provide a number between 0 and 4." 12 | exit 1 13 | fi 14 | 15 | # Check if the first argument is a valid number between 0 and 4 16 | if ! [[ "$1" =~ ^[0-4]$ ]]; then 17 | echo "Invalid argument. Please provide a number between 0 and 4." 
18 | exit 1 19 | fi 20 | 21 | # Extract the first argument and the rest of the arguments 22 | n="$1" 23 | shift 24 | additional_args="$@" 25 | 26 | # If reset flag is true, execute the reset commands 27 | # echo "Resetting ports 7000 and 6000" 28 | # lsof -ti:7000 | xargs kill -9 29 | # lsof -ti:6000 | xargs kill -9 30 | 31 | # Execute the ssh port forwarding script 32 | echo "Executing: bash ssh_port_forward.sh" 33 | bash ssh_port_forward.sh 34 | 35 | # Construct the command based on the first argument 36 | config_dir="config/berkeley_robot_$n" 37 | command="python orchestrator/robot/main.py --config_dir $config_dir $additional_args" 38 | 39 | # Execute the Python command 40 | echo "Executing: $command" 41 | $command 42 | -------------------------------------------------------------------------------- /media/autonomous_data_collection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/media/autonomous_data_collection.png -------------------------------------------------------------------------------- /media/soar_logo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/media/soar_logo.jpeg -------------------------------------------------------------------------------- /media/soar_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/media/soar_teaser.png -------------------------------------------------------------------------------- /model_training/README.md: -------------------------------------------------------------------------------- 1 | # Model training code 2 | 3 | This directory contains a self-contained python project for training goal-conditioned and language conditioned policies on BridgeData and on Soar-Data. 4 | 5 | ## Installation 6 | Run in the current directory 7 | ```bash 8 | pip install -e . 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ## Structure 13 | - `experiments/`: Contains the main training script `train.py` and the configuration files `train_config.py` and `data_config.py`. 14 | - 'jaxrl_m`: the main library for training models with Jax. 15 | - `jaxrl_m/agents/`: Contains the implementation of the agents. 16 | - `jaxrl_m/data/`: Contains the data processing and data loading code. 17 | 18 | ## Training 19 | In the current directory, run 20 | ```bash 21 | bash experiments/scripts/launch.sh 22 | ``` 23 | This will launch [train.py](experiments/train.py) with the default arguments specified in [train_config.py](experiments/configs/train_config.py) and [data_config.py](experiments/configs/data_config.py). 
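Because `launch.sh` forwards any extra command-line arguments and `train.py` registers its configs through `ml_collections` config flags, individual fields can typically be overridden at launch time (for example `--config.batch_size=256`). The resolved defaults can also be inspected directly; a minimal sketch, assuming it is run from the `model_training/` directory so `experiments/` is importable:

```python
# Sketch: load the default train/data configs that launch.sh feeds to train.py.
from experiments.configs import data_config, train_config

cfg = train_config.get_config("gc_bc_offline_bridge")
dcfg = data_config.get_config("all")

print(cfg.agent, cfg.batch_size, cfg.encoder)
print(dcfg.sampling_weights)  # pretraining vs. autonomous success/failure mix
print(cfg.dataset_kwargs.goal_relabeling_strategy)
```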
24 | -------------------------------------------------------------------------------- /model_training/experiments/configs/data_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import ml_collections 5 | 6 | 7 | # BridgeDatav2 metadata 8 | 9 | ACT_MEAN = [ 10 | 1.9296819e-04, 11 | 1.3667766e-04, 12 | -1.4583133e-04, 13 | -1.8390431e-04, 14 | -3.0808983e-04, 15 | 2.7425270e-04, 16 | 5.9716219e-01, 17 | ] 18 | 19 | ACT_STD = [ 20 | 0.00912848, 21 | 0.0127196, 22 | 0.01229497, 23 | 0.02606696, 24 | 0.02875283, 25 | 0.07807977, 26 | 0.48710242, 27 | ] 28 | 29 | ACT_MIN = [ 30 | -0.0437546, 31 | -0.052831028, 32 | -0.035931006, 33 | -0.14489305, 34 | -0.15591072, 35 | -0.26039174, 36 | -0.780331, 37 | ] # 0.1% quantile 38 | 39 | ACT_MAX = [ 40 | 0.04158026, 41 | 0.05223833, 42 | 0.05382493, 43 | 0.15559858, 44 | 0.142592, 45 | 0.25956747, 46 | 0.79311615, 47 | ] # 99.9% quantile 48 | 49 | ACTION_PROPRIO_METADATA = { 50 | "action": { 51 | "mean": np.array(ACT_MEAN), 52 | "std": np.array(ACT_STD), 53 | "min": np.array(ACT_MIN), 54 | "max": np.array(ACT_MAX), 55 | }, 56 | # TODO compute these 57 | "proprio": { 58 | "mean": np.array(ACT_MEAN), 59 | "std": np.array(ACT_STD), 60 | "min": np.array(ACT_MIN), 61 | "max": np.array(ACT_MAX), 62 | }, 63 | } 64 | 65 | 66 | def get_config(config_string): 67 | possible_structures = { 68 | "all": ml_collections.ConfigDict( 69 | { 70 | "pretraining_data": [ 71 | "gs://gresearch/robotics/bridge/0.1.0/" 72 | ], 73 | "autonomous_data": [ 74 | os.path.expanduser("~/tensorflow_datasets/soar_dataset/1.0.0"), 75 | ], 76 | "exclude": [], 77 | "sampling_weights": { 78 | "pretraining_data": 0.8, 79 | "autonomous_data_successes": 0.2, 80 | "autonomous_data_failures": 0.0, 81 | }, 82 | } 83 | ), 84 | } 85 | return possible_structures[config_string] 86 | -------------------------------------------------------------------------------- /model_training/experiments/configs/train_config.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | 3 | 4 | def get_config(config_string): 5 | base_real_config = dict( 6 | batch_size=512, 7 | num_steps=int(1001000), 8 | log_interval=1000, 9 | eval_interval=25000, 10 | save_interval=25000, 11 | num_val_trajs=8, 12 | num_val_batches=8, 13 | save_dir="~/jaxrl_log", 14 | resume_path="", 15 | seed=42, 16 | ) 17 | 18 | base_data_config = dict( 19 | # action_merge_horizon=2, 20 | shuffle_buffer_size=25000, 21 | augment=True, 22 | augment_next_obs_goal_differently=False, 23 | augment_kwargs=dict( 24 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]), 25 | random_brightness=[0.2], 26 | random_contrast=[0.8, 1.2], 27 | random_saturation=[0.8, 1.2], 28 | random_hue=[0.1], 29 | augment_order=[ 30 | "random_resized_crop", 31 | "random_brightness", 32 | "random_contrast", 33 | "random_saturation", 34 | "random_hue", 35 | ], 36 | ), 37 | ) 38 | 39 | possible_structures = { 40 | "gc_bc_offline_bridge": ConfigDict( 41 | dict( 42 | agent="gc_bc", 43 | agent_kwargs=dict( 44 | network_kwargs=dict( 45 | hidden_dims=( 46 | 256, 47 | 256, 48 | 256, 49 | ), 50 | dropout_rate=0.1, 51 | ), 52 | policy_kwargs=dict( 53 | tanh_squash_distribution=True, 54 | std_parameterization="fixed", 55 | fixed_std=[1, 1, 1, 1, 1, 1, 0.1], 56 | ), 57 | early_goal_concat=True, 58 | shared_goal_encoder=True, 59 | use_proprio=False, 60 | learning_rate=3e-4, 61 | warmup_steps=2000, 62 | decay_steps=int(2e6), 63 | ), 
64 | dataset_kwargs=dict( 65 | goal_relabeling_strategy="geometric", 66 | goal_relabeling_kwargs=dict(reached_proportion=0.0, discount=0.98), 67 | normalization_type="normal", 68 | **base_data_config, 69 | ), 70 | encoder="resnetv1-34-bridge", # diff: bridge release use resnet-50 71 | encoder_kwargs=dict( 72 | pooling_method="avg", 73 | add_spatial_coordinates=True, 74 | act="swish", 75 | ), 76 | **base_real_config, 77 | ) 78 | ), 79 | } 80 | 81 | return possible_structures[config_string] 82 | -------------------------------------------------------------------------------- /model_training/experiments/scripts/launch.sh: -------------------------------------------------------------------------------- 1 | CMD="python experiments/train.py \ 2 | --config experiments/configs/train_config.py:gc_bc_offline_bridge \ 3 | --data_config experiments/configs/data_config.py:all \ 4 | " 5 | 6 | # take in command line args too 7 | $CMD $@ 8 | -------------------------------------------------------------------------------- /model_training/experiments/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import traceback 3 | from functools import partial 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import tensorflow as tf 9 | import tqdm 10 | import wandb 11 | from absl import app, flags, logging 12 | from flax.training import checkpoints 13 | from ml_collections import config_flags 14 | 15 | from jaxrl_m.agents import agents 16 | from jaxrl_m.common.common import shard_batch 17 | from jaxrl_m.common.wandb import WandBLogger 18 | from jaxrl_m.data.dataset import WidowXDataset 19 | from jaxrl_m.utils.timer_utils import Timer 20 | from jaxrl_m.vision import encoders 21 | 22 | try: 23 | from jax_smi import initialise_tracking # type: ignore 24 | 25 | initialise_tracking() 26 | except ImportError: 27 | pass 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("exp_name", "", "Experiment name.") 32 | flags.DEFINE_list("tag", list(), "Name of experiment") 33 | flags.DEFINE_string("group", None, "Group of the wandb experiments") 34 | flags.DEFINE_bool("debug", False, "Debug config") 35 | 36 | config_flags.DEFINE_config_file( 37 | "config", 38 | None, 39 | "File path to the training hyperparameter configuration.", 40 | lock_config=False, 41 | ) 42 | 43 | config_flags.DEFINE_config_file( 44 | "data_config", 45 | None, 46 | "File path to the bridgedata configuration.", 47 | lock_config=False, 48 | ) 49 | 50 | 51 | def main(_): 52 | devices = jax.local_devices() 53 | num_devices = len(devices) 54 | assert FLAGS.config.batch_size % num_devices == 0 55 | 56 | # we shard the leading dimension (batch dimension) accross all devices evenly 57 | sharding = jax.sharding.PositionalSharding(devices) 58 | shard_fn = partial(shard_batch, sharding=sharding) 59 | 60 | # prevent tensorflow from using GPUs 61 | tf.config.set_visible_devices([], "GPU") 62 | 63 | # set up wandb and logging 64 | wandb_config = WandBLogger.get_default_config() 65 | wandb_config.update( 66 | { 67 | "project": f"jaxrl_{FLAGS.config.agent}_autonomous_data", 68 | "exp_descriptor": FLAGS.exp_name, 69 | "tag": FLAGS.tag, 70 | "group": FLAGS.group, 71 | } 72 | ) 73 | wandb_logger = WandBLogger( 74 | wandb_config=wandb_config, 75 | variant=FLAGS.config.to_dict(), 76 | debug=FLAGS.debug, 77 | ) 78 | 79 | save_dir = tf.io.gfile.join( 80 | FLAGS.config.save_dir, 81 | wandb_logger.config.project, 82 | 
f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}", 83 | ) 84 | 85 | # load datasets 86 | random.seed(FLAGS.config.seed) 87 | train_paths = [] 88 | if FLAGS.data_config.sampling_weights.pretraining_data > 0: 89 | train_paths += [FLAGS.data_config.pretraining_data] 90 | if FLAGS.data_config.sampling_weights.autonomous_data_successes > 0: 91 | train_paths += [FLAGS.data_config.autonomous_data] 92 | if FLAGS.data_config.sampling_weights.autonomous_data_failures > 0: 93 | train_paths += [FLAGS.data_config.autonomous_data] 94 | 95 | # create sample weights for training 96 | train_sample_weights = [ 97 | FLAGS.data_config.sampling_weights["pretraining_data"], 98 | FLAGS.data_config.sampling_weights["autonomous_data_successes"], 99 | FLAGS.data_config.sampling_weights["autonomous_data_failures"], 100 | ] 101 | train_sample_weights = [ 102 | weight for weight in train_sample_weights if weight > 0 103 | ] # remove 0s from the sample weights 104 | assert ( 105 | sum(train_sample_weights) == 1.0 106 | ), f"Sample weights must sum to 1.0, got {sum(train_sample_weights)}" 107 | 108 | # pick out the splits needed from the dataset 109 | train_data_splits = [] 110 | if FLAGS.data_config.sampling_weights.pretraining_data > 0: 111 | train_data_splits.append("train") 112 | if FLAGS.data_config.sampling_weights.autonomous_data_successes > 0: 113 | train_data_splits.append("success") 114 | if FLAGS.data_config.sampling_weights.autonomous_data_failures > 0: 115 | train_data_splits.append("failure") 116 | 117 | train_data = WidowXDataset( 118 | train_paths, 119 | FLAGS.config.seed, 120 | batch_size=FLAGS.config.batch_size, 121 | train=True, 122 | sample_weights=train_sample_weights, 123 | data_splits=train_data_splits, 124 | **FLAGS.config.dataset_kwargs, 125 | ) 126 | val_data = WidowXDataset( 127 | FLAGS.data_config.pretraining_data, 128 | FLAGS.config.seed, 129 | batch_size=FLAGS.config.batch_size, 130 | train=False, 131 | sample_weights=None, 132 | data_splits=["val"], 133 | **FLAGS.config.dataset_kwargs, 134 | ) 135 | 136 | train_data_iter = map(shard_fn, train_data.iterator()) 137 | 138 | example_batch = next(train_data_iter) 139 | logging.info(f"Batch size: {example_batch['observations']['image'].shape[0]}") 140 | logging.info(f"Number of devices: {num_devices}") 141 | logging.info( 142 | f"Batch size per device: {example_batch['observations']['image'].shape[0] // num_devices}" 143 | ) 144 | 145 | # define encoder 146 | encoder_def = encoders[FLAGS.config.encoder](**FLAGS.config.encoder_kwargs) 147 | 148 | # initialize agent 149 | rng = jax.random.PRNGKey(FLAGS.config.seed) 150 | rng, construct_rng = jax.random.split(rng) 151 | agent = agents[FLAGS.config.agent].create( 152 | rng=construct_rng, 153 | observations=example_batch["observations"], 154 | goals=example_batch["goals"], 155 | actions=example_batch["actions"], 156 | encoder_def=encoder_def, 157 | **FLAGS.config.agent_kwargs, 158 | ) 159 | if FLAGS.config.get("resume_path", "") != "": 160 | agent = checkpoints.restore_checkpoint(FLAGS.config.resume_path, target=agent) 161 | # replicate agent across devices 162 | # need the jnp.array to avoid a bug where device_put doesn't recognize primitives 163 | agent = jax.device_put(jax.tree_map(jnp.array, agent), sharding.replicate()) 164 | 165 | timer = Timer() 166 | for i in tqdm.tqdm(range(int(FLAGS.config.num_steps - agent.state.step))): 167 | try: 168 | timer.tick("total") 169 | 170 | timer.tick("dataset") 171 | batch = shard_batch(next(train_data_iter), sharding) 172 | 
timer.tock("dataset") 173 | 174 | timer.tick("train") 175 | agent, update_info = agent.update(batch) 176 | timer.tock("train") 177 | 178 | if agent.state.step % FLAGS.config.eval_interval == 0: 179 | logging.info("Validation...") 180 | timer.tick("val") 181 | 182 | # plot debug metrics of validation data 183 | val_metrics = [] 184 | j = 0 185 | val_iter = map(shard_fn, val_data.iterator()) 186 | for val_batch in val_iter: 187 | rng, val_rng = jax.random.split(rng) 188 | val_metrics.append(agent.get_debug_metrics(val_batch, seed=val_rng)) 189 | j += 1 190 | if j >= FLAGS.config.num_val_batches: 191 | break 192 | val_metrics = jax.tree_map(lambda *xs: np.mean(xs), *val_metrics) 193 | wandb_logger.log({"validation": val_metrics}, step=agent.state.step) 194 | 195 | timer.tock("val") 196 | 197 | if agent.state.step % FLAGS.config.save_interval == 0: 198 | logging.info("Saving checkpoint...") 199 | checkpoint_path = checkpoints.save_checkpoint( 200 | save_dir, agent, step=agent.state.step, keep=1e6 201 | ) 202 | logging.info("Saved checkpoint to %s", checkpoint_path) 203 | 204 | timer.tock("total") 205 | 206 | if agent.state.step % FLAGS.config.log_interval == 0: 207 | update_info = jax.device_get(update_info) 208 | wandb_logger.log({"training": update_info}, step=agent.state.step) 209 | 210 | wandb_logger.log( 211 | {"timer": timer.get_average_times()}, step=agent.state.step 212 | ) 213 | except tf.errors.OpError as e: 214 | # sometimes tfds will have trouble communicating with cloud storage bucket for some reason... 215 | print(f"Error in iteration {i}: {e}") 216 | print("Skipping to next iteration...") 217 | traceback.print_exc() 218 | 219 | # avoid timer tock errors 220 | timer.force_tock_everything() 221 | 222 | continue 223 | 224 | 225 | if __name__ == "__main__": 226 | app.run(main) 227 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .continuous.gc_bc import GCBCAgent 2 | 3 | agents = { 4 | "gc_bc": GCBCAgent, 5 | } 6 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/agents/continuous/gc_bc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Any, Optional 4 | 5 | import flax 6 | import flax.linen as nn 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | from flax.core import FrozenDict 11 | 12 | from jaxrl_m.common.common import JaxRLTrainState, ModuleDict, nonpytree_field 13 | from jaxrl_m.common.optimizers import make_optimizer 14 | from jaxrl_m.common.encoding import GCEncodingWrapper, LCEncodingWrapper 15 | from jaxrl_m.common.optimizers import make_optimizer 16 | from jaxrl_m.common.typing import Batch, PRNGKey 17 | from jaxrl_m.networks.actor_critic_nets import Policy 18 | from jaxrl_m.networks.mlp import MLP 19 | 20 | 21 | class GCBCAgent(flax.struct.PyTreeNode): 22 | state: JaxRLTrainState 23 | lr_schedule: Any = nonpytree_field() 24 | 25 | @partial(jax.jit, static_argnames="pmap_axis") 26 | def update(self, batch: Batch, pmap_axis: str = None): 27 | def loss_fn(params, rng): 28 | rng, key = jax.random.split(rng) 29 | dist = self.state.apply_fn( 30 | {"params": params}, 31 | (batch["observations"], batch["goals"]), 32 | temperature=1.0, 33 | train=True, 34 | rngs={"dropout": key}, 35 | name="actor", 36 | ) 37 | pi_actions = dist.mode() 
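The lines just below compute the dataset-action log-probabilities, an MSE diagnostic, and the actor loss, i.e. the negative log-likelihood of the demonstrated actions under the predicted distribution. For a plain fixed-std diagonal Gaussian (ignoring the optional tanh squashing), this objective is an std-weighted MSE up to additive constants; a standalone toy check using the `fixed_std` values from `train_config.py`, not part of the agent code:

```python
# Toy check, independent of the agent: for a fixed-std diagonal Gaussian,
# -log_prob(a) equals 0.5 * sum(((a - mean) / std) ** 2) plus a constant.
import distrax
import jax.numpy as jnp

mean = jnp.zeros(7)
std = jnp.array([1, 1, 1, 1, 1, 1, 0.1])
dist = distrax.MultivariateNormalDiag(loc=mean, scale_diag=std)

action = 0.05 * jnp.ones(7)
nll = -dist.log_prob(action)
weighted_mse = 0.5 * jnp.sum(((action - mean) / std) ** 2)
const = jnp.sum(jnp.log(std)) + 0.5 * 7 * jnp.log(2 * jnp.pi)
print(nll, weighted_mse + const)  # equal up to floating-point error
```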
38 | log_probs = dist.log_prob(batch["actions"]) 39 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 40 | actor_loss = -(log_probs).mean() 41 | actor_std = dist.stddev().mean(axis=1) 42 | 43 | return actor_loss, { 44 | "actor_loss": actor_loss, 45 | "mse": mse.mean(), 46 | "log_probs": log_probs.mean(), 47 | "pi_actions": pi_actions.mean(), 48 | "mean_std": actor_std.mean(), 49 | "max_std": actor_std.max(), 50 | } 51 | 52 | # compute gradients and update params 53 | new_state, info = self.state.apply_loss_fns( 54 | loss_fn, 55 | pmap_axis=pmap_axis, 56 | has_aux=True, 57 | ) 58 | 59 | # log learning rates 60 | info["lr"] = self.lr_schedule(self.state.step) 61 | 62 | return self.replace(state=new_state), info 63 | 64 | @partial(jax.jit, static_argnames="argmax") 65 | def sample_actions( 66 | self, 67 | observations: np.ndarray, 68 | goals: np.ndarray, 69 | *, 70 | seed: Optional[PRNGKey] = None, 71 | temperature: float = 1.0, 72 | argmax=False, 73 | ) -> jnp.ndarray: 74 | dist = self.state.apply_fn( 75 | {"params": self.state.params}, 76 | (observations, goals), 77 | temperature=temperature, 78 | name="actor", 79 | ) 80 | if argmax: 81 | actions = dist.mode() 82 | else: 83 | actions = dist.sample(seed=seed) 84 | return actions, dist.mode() 85 | 86 | @jax.jit 87 | def get_debug_metrics(self, batch, **kwargs): 88 | dist = self.state.apply_fn( 89 | {"params": self.state.params}, 90 | (batch["observations"], batch["goals"]), 91 | temperature=1.0, 92 | name="actor", 93 | ) 94 | pi_actions = dist.mode() 95 | log_probs = dist.log_prob(batch["actions"]) 96 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 97 | 98 | return { 99 | "mse": mse, 100 | "log_probs": log_probs, 101 | "pi_actions": pi_actions, 102 | } 103 | 104 | @classmethod 105 | def create( 106 | cls, 107 | rng: PRNGKey, 108 | # example arrays for model init 109 | observations: FrozenDict, 110 | actions: jnp.ndarray, 111 | goals: FrozenDict, 112 | # agent config 113 | encoder_def: nn.Module, 114 | language_conditioned: bool = False, 115 | # should only be set if not language conditioned 116 | shared_goal_encoder: Optional[bool] = None, 117 | early_goal_concat: Optional[bool] = None, 118 | # other shared network config 119 | use_proprio: bool = False, 120 | network_kwargs: dict = { 121 | "hidden_dims": [256, 256], 122 | }, 123 | policy_kwargs: dict = { 124 | "tanh_squash_distribution": False, 125 | "std_parameterization": "exp", 126 | }, 127 | # optimizer config 128 | learning_rate: float = 3e-4, 129 | warmup_steps: int = 1000, 130 | decay_steps: int = 1000000, 131 | freeze_encoder: bool = False, 132 | ): 133 | if not language_conditioned: 134 | if shared_goal_encoder is None or early_goal_concat is None: 135 | raise ValueError( 136 | "If not language conditioned, shared_goal_encoder and early_goal_concat must be set" 137 | ) 138 | 139 | if early_goal_concat: 140 | # passing None as the goal encoder causes early goal concat 141 | goal_encoder_def = None 142 | else: 143 | if shared_goal_encoder: 144 | goal_encoder_def = encoder_def 145 | else: 146 | goal_encoder_def = copy.deepcopy(encoder_def) 147 | 148 | encoder_def = GCEncodingWrapper( 149 | encoder=encoder_def, 150 | goal_encoder=goal_encoder_def, 151 | use_proprio=use_proprio, 152 | stop_gradient=freeze_encoder, 153 | ) 154 | else: 155 | if shared_goal_encoder is not None or early_goal_concat is not None: 156 | raise ValueError( 157 | "If language conditioned, shared_goal_encoder and early_goal_concat must not be set" 158 | ) 159 | encoder_def = LCEncodingWrapper( 160 | 
encoder=encoder_def, 161 | use_proprio=use_proprio, 162 | stop_gradient=freeze_encoder, 163 | ) 164 | 165 | network_kwargs["activate_final"] = True 166 | networks = { 167 | "actor": Policy( 168 | encoder_def, 169 | MLP(**network_kwargs), 170 | action_dim=actions.shape[-1], 171 | **policy_kwargs, 172 | ) 173 | } 174 | 175 | model_def = ModuleDict(networks) 176 | 177 | # create optimizer 178 | tx, lr_schedule = make_optimizer( 179 | learning_rate=learning_rate, 180 | warmup_steps=warmup_steps, 181 | cosine_decay_steps=decay_steps if decay_steps is not None else None, 182 | weight_decay=0.001, 183 | beta2=0.98, 184 | clip_grad_norm=1.0, 185 | return_lr_schedule=True, 186 | ) 187 | 188 | rng, init_rng = jax.random.split(rng) 189 | params = jax.jit(model_def.init)(init_rng, actor=[(observations, goals)])[ 190 | "params" 191 | ] 192 | 193 | rng, create_rng = jax.random.split(rng) 194 | state = JaxRLTrainState.create( 195 | apply_fn=model_def.apply, 196 | params=params, 197 | txs=tx, 198 | rng=create_rng, 199 | ) 200 | 201 | return cls(state, lr_schedule) 202 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/common/encoding.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple 2 | 3 | import flax 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | from einops import rearrange, repeat 8 | 9 | 10 | class EncodingWrapper(nn.Module): 11 | """ 12 | Encodes observations into a single flat encoding, adding additional 13 | functionality for adding proprioception and stopping the gradient. 14 | 15 | Args: 16 | encoder: The encoder network. 17 | use_proprio: Whether to concatenate proprioception (after encoding). 18 | stop_gradient: Whether to stop the gradient after the encoder. 19 | """ 20 | 21 | encoder: nn.Module 22 | use_proprio: bool 23 | stop_gradient: bool 24 | enable_stacking: bool = False 25 | 26 | def __call__(self, observations: Dict[str, jnp.ndarray]) -> jnp.ndarray: 27 | if isinstance(observations, flax.core.FrozenDict) or isinstance( 28 | observations, dict 29 | ): 30 | obs = observations["image"] 31 | if self.enable_stacking: 32 | # Combine stacking and channels into a single dimension 33 | if len(obs.shape) == 4: 34 | obs = rearrange(obs, "T H W C -> H W (T C)") 35 | if len(obs.shape) == 5: 36 | obs = rearrange(obs, "B T H W C -> B H W (T C)") 37 | 38 | else: 39 | obs = observations 40 | 41 | encoding = self.encoder(obs) 42 | 43 | if self.use_proprio: 44 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 45 | if self.stop_gradient: 46 | encoding = jax.lax.stop_gradient(encoding) 47 | return encoding 48 | 49 | 50 | class GCEncodingWrapper(nn.Module): 51 | """ 52 | Encodes observations and goals into a single flat encoding. Handles all the 53 | logic about when/how to combine observations and goals. 54 | 55 | Takes a tuple (observations, goals) as input. 56 | 57 | Args: 58 | encoder: The encoder network for observations. 59 | goal_encoder: The encoder to use for goals (optional). If None, early 60 | goal concatenation is used, i.e. the goal is concatenated to the 61 | observation channel-wise before passing it through the encoder. 62 | use_proprio: Whether to concatenate proprioception (after encoding). 63 | stop_gradient: Whether to stop the gradient after the encoder. 
64 | """ 65 | 66 | encoder: nn.Module 67 | goal_encoder: Optional[nn.Module] 68 | use_proprio: bool 69 | stop_gradient: bool 70 | 71 | def __call__( 72 | self, 73 | observations_and_goals: Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]], 74 | ) -> jnp.ndarray: 75 | observations, goals = observations_and_goals 76 | 77 | if len(observations["image"].shape) == 5: 78 | # obs history case 79 | batch_size, obs_horizon = observations["image"].shape[:2] 80 | # fold batch_size into obs_horizon to encode each frame separately 81 | obs_image = rearrange(observations["image"], "B T H W C -> (B T) H W C") 82 | # repeat goals so that there's a goal for each frame 83 | goal_image = repeat( 84 | goals["image"], "B H W C -> (B repeat) H W C", repeat=obs_horizon 85 | ) 86 | else: 87 | obs_image = observations["image"] 88 | goal_image = goals["image"] 89 | 90 | if self.goal_encoder is None: 91 | # early goal concat 92 | encoder_inputs = jnp.concatenate([obs_image, goal_image], axis=-1) 93 | encoding = self.encoder(encoder_inputs) 94 | else: 95 | # late fusion 96 | encoding = self.encoder(obs_image) 97 | goal_encoding = self.goal_encoder(goals["image"]) 98 | encoding = jnp.concatenate([encoding, goal_encoding], axis=-1) 99 | 100 | if len(observations["image"].shape) == 5: 101 | # unfold obs_horizon from batch_size 102 | encoding = rearrange( 103 | encoding, "(B T) F -> B (T F)", B=batch_size, T=obs_horizon 104 | ) 105 | 106 | if self.use_proprio: 107 | if len(encoding.shape) == 2 and len(observations["proprio"].shape) == 3: 108 | # edge case 109 | encoding = jnp.concatenate( 110 | [encoding, observations["proprio"][:, 0, :]], axis=-1 111 | ) 112 | else: 113 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 114 | 115 | if self.stop_gradient: 116 | encoding = jax.lax.stop_gradient(encoding) 117 | 118 | return encoding 119 | 120 | 121 | class LCEncodingWrapper(nn.Module): 122 | """ 123 | Encodes observations and language instructions into a single flat encoding. 124 | 125 | Takes a tuple (observations, goals) as input, where goals contains the language instruction. 126 | 127 | Args: 128 | encoder: The encoder network for observations. 129 | use_proprio: Whether to concatenate proprioception (after encoding). 130 | stop_gradient: Whether to stop the gradient after the encoder. 
131 | """ 132 | 133 | encoder: nn.Module 134 | use_proprio: bool 135 | stop_gradient: bool 136 | 137 | def __call__( 138 | self, 139 | observations_and_goals: Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]], 140 | ) -> jnp.ndarray: 141 | observations, goals = observations_and_goals 142 | 143 | if len(observations["image"].shape) == 5: 144 | # obs history case 145 | batch_size, obs_horizon = observations["image"].shape[:2] 146 | # fold batch_size into obs_horizon to encode each frame separately 147 | obs_image = rearrange(observations["image"], "B T H W C -> (B T) H W C") 148 | # repeat language so that there's an instruction for each frame 149 | language = repeat( 150 | goals["language"], "B E -> (B repeat) E", repeat=obs_horizon 151 | ) 152 | else: 153 | obs_image = observations["image"] 154 | language = goals["language"] 155 | 156 | encoding = self.encoder(obs_image, cond_var=language) 157 | 158 | if len(observations["image"].shape) == 5: 159 | # unfold obs_horizon from batch_size 160 | encoding = rearrange( 161 | encoding, "(B T) F -> B (T F)", B=batch_size, T=obs_horizon 162 | ) 163 | 164 | if self.use_proprio: 165 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1) 166 | 167 | if self.stop_gradient: 168 | encoding = jax.lax.stop_gradient(encoding) 169 | 170 | return encoding 171 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/common/optimizers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import optax 4 | 5 | 6 | def make_optimizer( 7 | learning_rate: float = 3e-4, 8 | warmup_steps: int = 0, 9 | cosine_decay_steps: Optional[int] = None, 10 | weight_decay: Optional[float] = None, 11 | beta2: Optional[float] = None, 12 | clip_grad_norm: Optional[float] = None, 13 | return_lr_schedule: bool = False, 14 | ) -> optax.GradientTransformation: 15 | if cosine_decay_steps is not None: 16 | learning_rate_schedule = optax.warmup_cosine_decay_schedule( 17 | init_value=0.0, 18 | peak_value=learning_rate, 19 | warmup_steps=warmup_steps, 20 | decay_steps=cosine_decay_steps, 21 | end_value=0.0, 22 | ) 23 | else: 24 | learning_rate_schedule = optax.join_schedules( 25 | [ 26 | optax.linear_schedule(0.0, learning_rate, warmup_steps), 27 | optax.constant_schedule(learning_rate), 28 | ], 29 | [warmup_steps], 30 | ) 31 | 32 | # Define optimizers 33 | @optax.inject_hyperparams 34 | def optimizer(learning_rate: float, weight_decay: Optional[float]): 35 | optimizer_stages = [] 36 | 37 | if clip_grad_norm is not None: 38 | print("Info: turning on gradient clipping") 39 | optimizer_stages.append(optax.clip_by_global_norm(clip_grad_norm)) 40 | 41 | if weight_decay is not None: 42 | optimizer_stages.append( 43 | optax.adamw( 44 | learning_rate=learning_rate, 45 | weight_decay=weight_decay, 46 | b2=0.999 if beta2 is None else beta2, 47 | ) 48 | ) 49 | else: 50 | optimizer_stages.append( 51 | optax.adam( 52 | learning_rate=learning_rate, b2=0.999 if beta2 is None else beta2 53 | ) 54 | ) 55 | 56 | return optax.chain(*optimizer_stages) 57 | 58 | if return_lr_schedule: 59 | return ( 60 | optimizer(learning_rate=learning_rate_schedule, weight_decay=weight_decay), 61 | learning_rate_schedule, 62 | ) 63 | else: 64 | return optimizer( 65 | learning_rate=learning_rate_schedule, weight_decay=weight_decay 66 | ) 67 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/common/typing.py: 
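Returning to `make_optimizer` in `optimizers.py` above: it chains optional global-norm gradient clipping with Adam or AdamW and a linear-warmup schedule (cosine decay when `cosine_decay_steps` is set), and with `return_lr_schedule=True` also returns the schedule so callers such as `GCBCAgent` can log the current learning rate. A minimal usage sketch with the same values the GC-BC agent passes in:

```python
# Sketch: build the AdamW + warmup-cosine optimizer used by the GC-BC agent
# and inspect the learning-rate schedule at a few steps.
from jaxrl_m.common.optimizers import make_optimizer

tx, lr_schedule = make_optimizer(
    learning_rate=3e-4,
    warmup_steps=2000,
    cosine_decay_steps=2_000_000,
    weight_decay=0.001,
    beta2=0.98,
    clip_grad_norm=1.0,
    return_lr_schedule=True,
)
# Linear warmup to 3e-4 over 2k steps, then cosine decay toward 0.
print([float(lr_schedule(step)) for step in (0, 2_000, 1_000_000, 2_000_000)])
```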
-------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Sequence, Union 2 | 3 | import flax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | PRNGKey = Any 9 | Params = flax.core.FrozenDict[str, Any] 10 | Shape = Sequence[int] 11 | Dtype = Any # this could be a real type? 12 | InfoDict = Dict[str, float] 13 | Array = Union[np.ndarray, jnp.ndarray, tf.Tensor] 14 | Data = Union[Array, Dict[str, "Data"]] 15 | Batch = Dict[str, Data] 16 | # A method to be passed into TrainState.__call__ 17 | ModuleMethod = Union[str, Callable, None] 18 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/common/wandb.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import tempfile 3 | from copy import copy 4 | from socket import gethostname 5 | 6 | import absl.flags as flags 7 | import ml_collections 8 | import wandb 9 | 10 | 11 | def _recursive_flatten_dict(d: dict): 12 | keys, values = [], [] 13 | for key, value in d.items(): 14 | if isinstance(value, dict): 15 | sub_keys, sub_values = _recursive_flatten_dict(value) 16 | keys += [f"{key}/{k}" for k in sub_keys] 17 | values += sub_values 18 | else: 19 | keys.append(key) 20 | values.append(value) 21 | return keys, values 22 | 23 | 24 | class WandBLogger(object): 25 | @staticmethod 26 | def get_default_config(): 27 | config = ml_collections.ConfigDict() 28 | config.project = "jaxrl_m" # WandB Project Name 29 | config.entity = ml_collections.config_dict.FieldReference(None, field_type=str) 30 | # Which entity to log as (default: your own user) 31 | config.exp_descriptor = "" # Run name (doesn't have to be unique) 32 | # Unique identifier for run (will be automatically generated unless 33 | # provided) 34 | config.unique_identifier = "" 35 | config.group = None 36 | return config 37 | 38 | def __init__( 39 | self, 40 | wandb_config, 41 | variant, 42 | wandb_output_dir=None, 43 | debug=False, 44 | ): 45 | self.config = wandb_config 46 | if self.config.unique_identifier == "": 47 | self.config.unique_identifier = datetime.datetime.now().strftime( 48 | "%Y%m%d_%H%M%S" 49 | ) 50 | 51 | self.config.experiment_id = self.experiment_id = ( 52 | f"{self.config.exp_descriptor}_{self.config.unique_identifier}" # NOQA 53 | ) 54 | 55 | print(self.config) 56 | 57 | if wandb_output_dir is None: 58 | wandb_output_dir = tempfile.mkdtemp() 59 | 60 | self._variant = copy(variant) 61 | 62 | if "hostname" not in self._variant: 63 | self._variant["hostname"] = gethostname() 64 | 65 | if debug: 66 | mode = "disabled" 67 | else: 68 | mode = "online" 69 | 70 | self.run = wandb.init( 71 | config=self._variant, 72 | project=self.config.project, 73 | entity=self.config.entity, 74 | group=self.config.group, 75 | tags=self.config.tag, 76 | dir=wandb_output_dir, 77 | id=self.config.experiment_id, 78 | save_code=True, 79 | mode=mode, 80 | ) 81 | 82 | if flags.FLAGS.is_parsed(): 83 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS} 84 | else: 85 | flag_dict = {} 86 | for k in flag_dict: 87 | if isinstance(flag_dict[k], ml_collections.ConfigDict): 88 | flag_dict[k] = flag_dict[k].to_dict() 89 | wandb.config.update(flag_dict) 90 | 91 | def log(self, data: dict, step: int = None): 92 | data_flat = _recursive_flatten_dict(data) 93 | data = {k: v for k, v in zip(*data_flat)} 94 | wandb.log(data, step=step) 95 | 
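`WandBLogger` above wraps `wandb.init` with a project/descriptor/unique-identifier naming scheme and flattens nested metric dicts (e.g. `{"training": {...}}` becomes `training/...` keys) before logging. A minimal stand-alone sketch; `debug=True` maps to wandb's disabled mode so it runs without a W&B account, and a `tag` entry must be added to the default config because `__init__` forwards it to `wandb.init(tags=...)`:

```python
# Sketch: stand-alone use of WandBLogger with wandb disabled (debug=True).
from jaxrl_m.common.wandb import WandBLogger

cfg = WandBLogger.get_default_config()
cfg.update(
    {
        "project": "jaxrl_m_demo",  # placeholder project name
        "exp_descriptor": "smoke_test",
        "tag": [],                  # forwarded to wandb.init(tags=...)
    }
)
logger = WandBLogger(wandb_config=cfg, variant={"lr": 3e-4}, debug=True)
logger.log({"training": {"actor_loss": 0.12, "mse": 0.03}}, step=0)
```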
-------------------------------------------------------------------------------- /model_training/jaxrl_m/data/text_processing.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import tensorflow as tf 6 | from flax.core import FrozenDict 7 | 8 | MULTI_MODULE = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3" 9 | 10 | 11 | class TextProcessor: 12 | """ 13 | Base class for text tokenization or text embedding. 14 | """ 15 | 16 | def encode(self, strings): 17 | pass 18 | 19 | 20 | class HFTokenizer(TextProcessor): 21 | def __init__( 22 | self, 23 | tokenizer_name: str, 24 | tokenizer_kwargs: Optional[dict] = { 25 | "max_length": 64, 26 | "padding": "max_length", 27 | "truncation": True, 28 | "return_tensors": "np", 29 | }, 30 | encode_with_model: bool = False, 31 | ): 32 | from transformers import AutoTokenizer, FlaxAutoModel # lazy import 33 | 34 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 35 | self.tokenizer_kwargs = tokenizer_kwargs 36 | self.encode_with_model = encode_with_model 37 | if self.encode_with_model: 38 | self.model = FlaxAutoModel.from_pretrained(tokenizer_name) 39 | 40 | def encode(self, strings): 41 | # this creates another nested layer with "input_ids", "attention_mask", etc. 42 | inputs = self.tokenizer( 43 | strings, 44 | **self.tokenizer_kwargs, 45 | ) 46 | if self.encode_with_model: 47 | return np.array(self.model(**inputs).last_hidden_state) 48 | else: 49 | return FrozenDict(inputs) 50 | 51 | 52 | class MuseEmbedding(TextProcessor): 53 | def __init__(self): 54 | import tensorflow_hub as hub # lazy import 55 | import tensorflow_text # required for muse 56 | 57 | self.muse_model = hub.load(MULTI_MODULE) 58 | 59 | def encode(self, strings): 60 | with tf.device("/cpu:0"): 61 | return self.muse_model(strings).numpy() 62 | 63 | 64 | class CLIPTextProcessor(TextProcessor): 65 | def __init__( 66 | self, 67 | tokenizer_kwargs: Optional[dict] = { 68 | "max_length": 64, 69 | "padding": "max_length", 70 | "truncation": True, 71 | "return_tensors": "np", 72 | }, 73 | ): 74 | from transformers import CLIPProcessor 75 | 76 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 77 | self.kwargs = tokenizer_kwargs 78 | 79 | def encode(self, strings): 80 | inputs = self.processor( 81 | text=strings, 82 | **self.kwargs, 83 | ) 84 | inputs["position_ids"] = jnp.expand_dims( 85 | jnp.arange(inputs["input_ids"].shape[1]), axis=0 86 | ).repeat(inputs["input_ids"].shape[0], axis=0) 87 | return FrozenDict(inputs) 88 | 89 | 90 | text_processors = { 91 | "hf_tokenizer": HFTokenizer, 92 | "muse_embedding": MuseEmbedding, 93 | "clip_processor": CLIPTextProcessor, 94 | } 95 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/data/tf_augmentations.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | import tensorflow as tf 4 | from ml_collections import ConfigDict 5 | 6 | 7 | def random_resized_crop(image, scale, ratio, seed, batched=False): 8 | if not batched: 9 | image = tf.expand_dims(image, axis=0) 10 | batch_size = tf.shape(image)[0] 11 | # taken from https://keras.io/examples/vision/nnclr/#random-resized-crops 12 | log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1])) 13 | height = tf.shape(image)[-3] 14 | width = tf.shape(image)[-2] 15 | 16 | random_scales = 
tf.random.stateless_uniform((batch_size,), seed, scale[0], scale[1]) 17 | random_ratios = tf.exp( 18 | tf.random.stateless_uniform((batch_size,), seed, log_ratio[0], log_ratio[1]) 19 | ) 20 | 21 | new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1) 22 | new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1) 23 | height_offsets = tf.random.stateless_uniform( 24 | (batch_size,), seed, 0, 1 - new_heights 25 | ) 26 | width_offsets = tf.random.stateless_uniform((batch_size,), seed, 0, 1 - new_widths) 27 | 28 | bounding_boxes = tf.stack( 29 | [ 30 | height_offsets, 31 | width_offsets, 32 | height_offsets + new_heights, 33 | width_offsets + new_widths, 34 | ], 35 | axis=1, 36 | ) 37 | 38 | if len(tf.shape(image)) == 5: 39 | obs_horizon = tf.shape(image)[1] 40 | # fold obs_horizon dimension into batch dimension 41 | image = tf.reshape(image, [batch_size * obs_horizon, height, width, -1]) 42 | # repeat bounding_boxes so each obs history is augmented the same 43 | bounding_boxes = tf.repeat(bounding_boxes, obs_horizon, axis=0) 44 | image = tf.image.crop_and_resize( 45 | image, bounding_boxes, tf.range(batch_size * obs_horizon), (height, width) 46 | ) 47 | image = tf.reshape(image, [batch_size, obs_horizon, height, width, -1]) 48 | else: 49 | image = tf.image.crop_and_resize( 50 | image, bounding_boxes, tf.range(batch_size), (height, width) 51 | ) 52 | 53 | if not batched: 54 | return image[0] 55 | else: 56 | return image 57 | 58 | 59 | AUGMENT_OPS = { 60 | "random_resized_crop": random_resized_crop, 61 | "random_brightness": tf.image.stateless_random_brightness, 62 | "random_contrast": tf.image.stateless_random_contrast, 63 | "random_saturation": tf.image.stateless_random_saturation, 64 | "random_hue": tf.image.stateless_random_hue, 65 | "random_flip_left_right": tf.image.stateless_random_flip_left_right, 66 | } 67 | 68 | 69 | def augment(image, seed, **augment_kwargs): 70 | image = tf.cast(image, tf.float32) / 255 # convert images to [0, 1] 71 | for op in augment_kwargs["augment_order"]: 72 | if op in augment_kwargs: 73 | if isinstance(augment_kwargs[op], Mapping) or isinstance( 74 | augment_kwargs[op], ConfigDict 75 | ): 76 | image = AUGMENT_OPS[op](image, seed=seed, **augment_kwargs[op]) 77 | else: 78 | image = AUGMENT_OPS[op](image, seed=seed, *augment_kwargs[op]) 79 | else: 80 | image = AUGMENT_OPS[op](image, seed=seed) 81 | image = tf.clip_by_value(image, 0, 1) 82 | image = tf.cast(image * 255, tf.uint8) 83 | return image 84 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/networks/actor_critic_nets.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import distrax 4 | import flax.linen as nn 5 | import jax.numpy as jnp 6 | 7 | from jaxrl_m.common.common import default_init 8 | 9 | 10 | class ValueCritic(nn.Module): 11 | encoder: nn.Module 12 | network: nn.Module 13 | init_final: Optional[float] = None 14 | 15 | @nn.compact 16 | def __call__(self, observations: jnp.ndarray, train: bool = False) -> jnp.ndarray: 17 | outputs = self.network(self.encoder(observations), train=train) 18 | if self.init_final is not None: 19 | value = nn.Dense( 20 | 1, 21 | kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), 22 | )(outputs) 23 | else: 24 | value = nn.Dense(1, kernel_init=default_init())(outputs) 25 | return jnp.squeeze(value, -1) 26 | 27 | 28 | class Critic(nn.Module): 29 | encoder: Optional[nn.Module] 30 | 
network: nn.Module 31 | init_final: Optional[float] = None 32 | network_separate_action_input: bool = False # for PTR, input action to every layer 33 | 34 | @nn.compact 35 | def __call__( 36 | self, observations: jnp.ndarray, actions: jnp.ndarray, train: bool = False 37 | ) -> jnp.ndarray: 38 | if self.encoder is None: 39 | obs_enc = observations 40 | else: 41 | obs_enc = self.encoder(observations) 42 | 43 | if self.network_separate_action_input: 44 | outputs = self.network(obs_enc, actions, train=train) 45 | else: 46 | inputs = jnp.concatenate([obs_enc, actions], -1) 47 | outputs = self.network(inputs, train=train) 48 | if self.init_final is not None: 49 | value = nn.Dense( 50 | 1, 51 | kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), 52 | )(outputs) 53 | else: 54 | value = nn.Dense(1, kernel_init=default_init())(outputs) 55 | return jnp.squeeze(value, -1) 56 | 57 | 58 | def ensemblize(cls, num_qs, out_axes=0): 59 | return nn.vmap( 60 | cls, 61 | variable_axes={"params": 0}, 62 | split_rngs={"params": True}, 63 | in_axes=None, 64 | out_axes=out_axes, 65 | axis_size=num_qs, 66 | ) 67 | 68 | 69 | class Policy(nn.Module): 70 | encoder: Optional[nn.Module] 71 | network: nn.Module 72 | action_dim: int 73 | init_final: Optional[float] = None 74 | std_parameterization: str = "exp" # "exp", "softplus", "fixed", or "uniform" 75 | std_min: Optional[float] = 1e-5 76 | std_max: Optional[float] = 10.0 77 | tanh_squash_distribution: bool = False 78 | fixed_std: Optional[jnp.ndarray] = None 79 | 80 | @nn.compact 81 | def __call__( 82 | self, observations: jnp.ndarray, temperature: float = 1.0, train: bool = False 83 | ) -> distrax.Distribution: 84 | if self.encoder is None: 85 | obs_enc = observations 86 | else: 87 | obs_enc = self.encoder(observations) 88 | 89 | outputs = self.network(obs_enc, train=train) 90 | 91 | means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs) 92 | if self.fixed_std is None: 93 | if self.std_parameterization == "exp": 94 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init())( 95 | outputs 96 | ) 97 | stds = jnp.exp(log_stds) 98 | elif self.std_parameterization == "softplus": 99 | stds = nn.Dense(self.action_dim, kernel_init=default_init())(outputs) 100 | stds = nn.softplus(stds) 101 | elif self.std_parameterization == "uniform": 102 | log_stds = self.param( 103 | "log_stds", nn.initializers.zeros, (self.action_dim,) 104 | ) 105 | stds = jnp.exp(log_stds) 106 | else: 107 | raise ValueError( 108 | f"Invalid std_parameterization: {self.std_parameterization}" 109 | ) 110 | else: 111 | assert self.std_parameterization == "fixed" 112 | stds = jnp.array(self.fixed_std) 113 | 114 | # Clip stds to avoid numerical instability 115 | # For a normal distribution under MaxEnt, optimal std scales with sqrt(temperature) 116 | stds = jnp.clip(stds, self.std_min, self.std_max) * jnp.sqrt(temperature) 117 | # stds = jnp.concatenate([stds[:, :6], jnp.ones((len(stds), 1)) * jnp.log(0.3)], axis=-1) 118 | 119 | if self.tanh_squash_distribution: 120 | distribution = TanhMultivariateNormalDiag( 121 | loc=means, 122 | scale_diag=stds, 123 | ) 124 | else: 125 | distribution = distrax.MultivariateNormalDiag( 126 | loc=means, 127 | scale_diag=stds, 128 | ) 129 | 130 | return distribution 131 | 132 | 133 | class TanhMultivariateNormalDiag(distrax.Transformed): 134 | def __init__( 135 | self, 136 | loc: jnp.ndarray, 137 | scale_diag: jnp.ndarray, 138 | low: Optional[jnp.ndarray] = None, 139 | high: Optional[jnp.ndarray] = None, 140 | ): 141 | distribution 
= distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag) 142 | 143 | layers = [] 144 | 145 | if not (low is None or high is None): 146 | 147 | def rescale_from_tanh(x): 148 | x = (x + 1) / 2 # (-1, 1) => (0, 1) 149 | return x * (high - low) + low 150 | 151 | def forward_log_det_jacobian(x): 152 | high_ = jnp.broadcast_to(high, x.shape) 153 | low_ = jnp.broadcast_to(low, x.shape) 154 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1) 155 | 156 | layers.append( 157 | distrax.Lambda( 158 | rescale_from_tanh, 159 | forward_log_det_jacobian=forward_log_det_jacobian, 160 | event_ndims_in=1, 161 | event_ndims_out=1, 162 | ) 163 | ) 164 | 165 | layers.append(distrax.Block(distrax.Tanh(), 1)) 166 | 167 | bijector = distrax.Chain(layers) 168 | 169 | super().__init__(distribution=distribution, bijector=bijector) 170 | 171 | def mode(self) -> jnp.ndarray: 172 | return self.bijector.forward(self.distribution.mode()) 173 | 174 | def stddev(self) -> jnp.ndarray: 175 | return self.bijector.forward(self.distribution.stddev()) 176 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/networks/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | from jaxrl_m.common.common import default_init 7 | 8 | 9 | class MLP(nn.Module): 10 | hidden_dims: Sequence[int] 11 | activations: Callable[[jnp.ndarray], jnp.ndarray] | str = nn.swish 12 | activate_final: bool = False 13 | use_layer_norm: bool = False 14 | use_group_norm: bool = False 15 | dropout_rate: Optional[float] = None 16 | 17 | def setup(self): 18 | assert not (self.use_layer_norm and self.use_group_norm) 19 | 20 | @nn.compact 21 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray: 22 | activations = self.activations 23 | if isinstance(activations, str): 24 | activations = getattr(nn, activations) 25 | 26 | for i, size in enumerate(self.hidden_dims): 27 | x = nn.Dense(size, kernel_init=default_init())(x) 28 | 29 | if i + 1 < len(self.hidden_dims) or self.activate_final: 30 | if self.dropout_rate is not None and self.dropout_rate > 0: 31 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 32 | if self.use_layer_norm: 33 | x = nn.LayerNorm()(x) 34 | elif self.use_group_norm: 35 | x = nn.GroupNorm()(x) 36 | x = activations(x) 37 | return x 38 | 39 | 40 | class LayerInputMLP(MLP): 41 | """ 42 | MLP, but each layer takes in an additional input as well 43 | such as the critic network in PTR 44 | """ 45 | 46 | @nn.compact 47 | def __call__( 48 | self, x: jnp.ndarray, layer_input: jnp.ndarray, train: bool = False 49 | ) -> jnp.ndarray: 50 | activations = self.activations 51 | if isinstance(activations, str): 52 | activations = getattr(nn, activations) 53 | 54 | for i, size in enumerate(self.hidden_dims): 55 | x = jnp.concatenate([x, layer_input], axis=-1) # difference from MLP 56 | x = nn.Dense(size, kernel_init=default_init())(x) 57 | 58 | if i + 1 < len(self.hidden_dims) or self.activate_final: 59 | if self.dropout_rate is not None and self.dropout_rate > 0: 60 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 61 | if self.use_layer_norm: 62 | x = nn.LayerNorm()(x) 63 | elif self.use_group_norm: 64 | x = nn.GroupNorm()(x) 65 | x = activations(x) 66 | return x 67 | 68 | 69 | class MLPResNetBlock(nn.Module): 70 | features: int 71 | act: Callable 72 | dropout_rate: float = None 73 | 
use_layer_norm: bool = False 74 | 75 | @nn.compact 76 | def __call__(self, x, train: bool = False): 77 | residual = x 78 | if self.dropout_rate is not None and self.dropout_rate > 0: 79 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 80 | if self.use_layer_norm: 81 | x = nn.LayerNorm()(x) 82 | x = nn.Dense(self.features * 4)(x) 83 | x = self.act(x) 84 | x = nn.Dense(self.features)(x) 85 | 86 | if residual.shape != x.shape: 87 | residual = nn.Dense(self.features)(residual) 88 | 89 | return residual + x 90 | 91 | 92 | class MLPResNet(nn.Module): 93 | num_blocks: int 94 | out_dim: int 95 | dropout_rate: float = None 96 | use_layer_norm: bool = False 97 | hidden_dim: int = 256 98 | activations: Callable = nn.swish 99 | 100 | @nn.compact 101 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray: 102 | x = nn.Dense(self.hidden_dim, kernel_init=default_init())(x) 103 | for _ in range(self.num_blocks): 104 | x = MLPResNetBlock( 105 | self.hidden_dim, 106 | act=self.activations, 107 | use_layer_norm=self.use_layer_norm, 108 | dropout_rate=self.dropout_rate, 109 | )(x, train=train) 110 | 111 | x = self.activations(x) 112 | x = nn.Dense(self.out_dim, kernel_init=default_init())(x) 113 | return x 114 | 115 | 116 | class Scalar(nn.Module): 117 | init_value: float 118 | 119 | def setup(self): 120 | self.value = self.param("value", lambda x: self.init_value) 121 | 122 | def __call__(self): 123 | return self.value 124 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/utils/jax_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | 4 | @jax.jit 5 | def batch_to_jax(batch): 6 | return jax.tree_util.tree_map(jax.device_put, batch) 7 | 8 | 9 | class JaxRNG(object): 10 | """A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside 11 | pure function. 
12 | """ 13 | 14 |     @classmethod 15 |     def from_seed(cls, seed): 16 |         return cls(jax.random.PRNGKey(seed)) 17 | 18 |     def __init__(self, rng): 19 |         self.rng = rng 20 | 21 |     def __call__(self, keys=None): 22 |         if keys is None: 23 |             self.rng, split_rng = jax.random.split(self.rng) 24 |             return split_rng 25 |         elif isinstance(keys, int): 26 |             split_rngs = jax.random.split(self.rng, num=keys + 1) 27 |             self.rng = split_rngs[0] 28 |             return tuple(split_rngs[1:]) 29 |         else: 30 |             split_rngs = jax.random.split(self.rng, num=len(keys) + 1) 31 |             self.rng = split_rngs[0] 32 |             return {key: val for key, val in zip(keys, split_rngs[1:])} 33 | 34 | 35 | def wrap_function_with_rng(rng): 36 |     """To be used as a decorator; automatically bookkeeps an RNG for the wrapped function.""" 37 | 38 |     def wrap_function(function): 39 |         def wrapped(*args, **kwargs): 40 |             nonlocal rng 41 |             rng, split_rng = jax.random.split(rng) 42 |             return function(split_rng, *args, **kwargs) 43 | 44 |         return wrapped 45 | 46 |     return wrap_function 47 | 48 | 49 | def init_rng(seed): 50 |     global jax_utils_rng 51 |     jax_utils_rng = JaxRNG.from_seed(seed) 52 | 53 | 54 | def next_rng(*args, **kwargs): 55 |     global jax_utils_rng 56 |     return jax_utils_rng(*args, **kwargs) 57 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/utils/timer_utils.py: -------------------------------------------------------------------------------- 1 | """Timer utility.""" 2 | 3 | import time 4 | from collections import defaultdict 5 | 6 | 7 | class _TimerContextManager: 8 |     def __init__(self, timer: "Timer", key: str): 9 |         self.timer = timer 10 |         self.key = key 11 | 12 |     def __enter__(self): 13 |         self.timer.tick(self.key) 14 | 15 |     def __exit__(self, exc_type, exc_value, exc_traceback): 16 |         self.timer.tock(self.key) 17 | 18 | 19 | class Timer: 20 |     def __init__(self): 21 |         self.reset() 22 | 23 |     def reset(self): 24 |         self.counts = defaultdict(int) 25 |         self.times = defaultdict(float) 26 |         self.start_times = {} 27 | 28 |     def tick(self, key): 29 |         if key in self.start_times: 30 |             raise ValueError(f"Timer is already ticking for key: {key}") 31 |         self.start_times[key] = time.time() 32 | 33 |     def tock(self, key): 34 |         if key not in self.start_times: 35 |             raise ValueError(f"Timer is not ticking for key: {key}") 36 |         self.counts[key] += 1 37 |         self.times[key] += time.time() - self.start_times[key] 38 |         del self.start_times[key] 39 | 40 |     def force_tock_everything(self): 41 |         for key in list(self.start_times):  # copy keys: tock() deletes from start_times while we iterate 42 |             self.tock(key) 43 | 44 |     def context(self, key): 45 |         """ 46 |         Use this like: 47 | 48 |         with timer.context("key"): 49 |             # do stuff 50 | 51 |         Then timer.tock("key") will be called automatically. 52 |         """ 53 |         return _TimerContextManager(self, key) 54 | 55 |     def get_average_times(self, reset=True): 56 |         ret = {key: self.times[key] / self.counts[key] for key in self.counts} 57 |         if reset: 58 |             self.reset() 59 |         return ret 60 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | import imageio 4 | import numpy as np 5 | import tensorflow as tf 6 | import wandb 7 | from flax.core import frozen_dict 8 | 9 | 10 | def concatenate_batches(batches): 11 |     concatenated = {} 12 |     for key in batches[0].keys(): 13 |         if isinstance(batches[0][key], Mapping): 14 |             # to concatenate batch["observations"]["image"], etc.
15 | concatenated[key] = concatenate_batches([batch[key] for batch in batches]) 16 | else: 17 | concatenated[key] = np.concatenate( 18 | [batch[key] for batch in batches], axis=0 19 | ).astype(np.float32) 20 | return concatenated 21 | 22 | 23 | def index_batch(batch, indices): 24 | indexed = {} 25 | for key in batch.keys(): 26 | if isinstance(batch[key], Mapping): 27 | # to index into batch["observations"]["image"], etc. 28 | indexed[key] = index_batch(batch[key], indices) 29 | else: 30 | indexed[key] = batch[key][indices, ...] 31 | return indexed 32 | 33 | 34 | def subsample_batch(batch, size): 35 | indices = np.random.randint(batch["rewards"].shape[0], size=size) 36 | return index_batch(batch, indices) 37 | 38 | 39 | def load_recorded_video( 40 | video_path: str, 41 | ): 42 | with tf.io.gfile.GFile(video_path, "rb") as f: 43 | video = np.array(imageio.mimread(f, "MP4")).transpose((0, 3, 1, 2)) 44 | assert video.shape[1] == 3, "Numpy array should be (T, C, H, W)" 45 | 46 | return wandb.Video(video, fps=20) 47 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from jaxrl_m.vision.resnet_v1 import resnetv1_configs 2 | 3 | encoders = dict() 4 | encoders.update(resnetv1_configs) 5 | -------------------------------------------------------------------------------- /model_training/jaxrl_m/vision/film_conditioning_layer.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py 2 | import flax.linen as nn 3 | import jax.numpy as jnp 4 | 5 | 6 | class FilmConditioning(nn.Module): 7 | @nn.compact 8 | def __call__(self, conv_filters: jnp.ndarray, conditioning: jnp.ndarray): 9 | """Applies FiLM conditioning to a convolutional feature map. 10 | 11 | Args: 12 | conv_filters: A tensor of shape [batch_size, height, width, channels]. 13 | conditioning: A tensor of shape [batch_size, conditioning_size]. 14 | 15 | Returns: 16 | A tensor of shape [batch_size, height, width, channels]. 
17 |         """ 18 |         projected_cond_add = nn.Dense( 19 |             features=conv_filters.shape[-1], 20 |             kernel_init=nn.initializers.zeros, 21 |             bias_init=nn.initializers.zeros, 22 |         )(conditioning) 23 |         projected_cond_mult = nn.Dense( 24 |             features=conv_filters.shape[-1], 25 |             kernel_init=nn.initializers.zeros, 26 |             bias_init=nn.initializers.zeros, 27 |         )(conditioning) 28 | 29 |         projected_cond_add = projected_cond_add[..., None, None, :] 30 |         projected_cond_mult = projected_cond_mult[..., None, None, :] 31 | 32 |         return conv_filters * (1 + projected_cond_add) + projected_cond_mult 33 | 34 | 35 | if __name__ == "__main__": 36 |     import jax 37 |     import jax.numpy as jnp 38 | 39 |     key = jax.random.PRNGKey(0) 40 |     key, subkey = jax.random.split(key) 41 |     x = jax.random.normal(subkey, (1, 32, 32, 3)) 42 |     x = jnp.array(x) 43 | 44 |     z = jnp.ones((1, 64)) 45 |     film = FilmConditioning() 46 |     params = film.init(key, x, z) 47 |     y = film.apply(params, x, z) 48 | 49 |     print(y.shape) 50 | -------------------------------------------------------------------------------- /model_training/requirements.txt: -------------------------------------------------------------------------------- 1 | gym >= 0.26 2 | numpy==1.24.3 3 | jax==0.4.20 4 | jaxlib==0.4.20 5 | distrax==0.1.5 6 | flax==0.7.5 7 | orbax-checkpoint==0.3.5 8 | ml_collections >= 0.1.0 9 | tqdm >= 4.60.0 10 | chex==0.1.85 11 | optax==0.1.5 12 | absl-py >= 0.12.0 13 | scipy == 1.12.0 14 | wandb >= 0.12.14 15 | tensorflow==2.15.0 16 | einops >= 0.6.1 17 | imageio >= 2.31.1 18 | moviepy >= 1.0.3 19 | pre-commit == 3.3.3 20 | overrides==7.7.0 21 | tensorflow_probability==0.23.0 22 | tensorflow_text 23 | dlimp @ git+https://github.com/zhouzypaul/dlimp@2df85dc98a7d564fc9c7c5ff2bfda26361526483 24 | octo @ git+https://github.com/octo-models/octo.git 25 | -------------------------------------------------------------------------------- /model_training/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="jaxrl_m", packages=["jaxrl_m"]) 4 | -------------------------------------------------------------------------------- /rlds_converter/README.md: -------------------------------------------------------------------------------- 1 | # Data RLDS conversion code 2 | 3 | This directory contains the code that converts the raw robot data, logged using the `data_collection` dir, into 4 | the [RLDS format](https://github.com/google-research/rlds), a specification on top of the [TFDS](https://www.tensorflow.org/datasets) (TensorFlow Datasets) format, which is itself largely built on top of the TFRecord format. 5 | 6 | This code is sourced from [Zhiyuan Zhou's implementation](https://github.com/zhouzypaul/dlimp), which borrows heavily from Kevin Black's [dlimp](https://github.com/kvablack/dlimp) library. 7 | 8 | 9 | ## Usage 10 | Install the requirements with: 11 | ```bash 12 | pip install -r requirements.txt 13 | pip install -e . 14 | ``` 15 | 16 | To build the SOAR dataset: 17 | ```bash 18 | cd soar_dataset 19 | CUDA_VISIBLE_DEVICES="" tfds build --manual_dir 20 | ``` 21 | You can modify settings inside the `soar_dataset/soar_dataset_dataset_builder.py` file (e.g., `NUM_WORKERS` and `CHUNKSIZE`).
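For reference, `NUM_WORKERS` and `CHUNKSIZE` are plain class attributes on the multithreaded builder base class defined in `rlds_converter/dataset_builder.py`, so tuning them is just a matter of overriding them on the dataset builder class. Below is a minimal, illustrative sketch of that kind of override; the class name `SoarDataset`, the import path, and the method bodies are assumptions for illustration (and the required `_info` feature spec is omitted), not the actual contents of `soar_dataset_dataset_builder.py`:

```python
# Hypothetical sketch: tuning the parallel-conversion settings by overriding the
# class attributes of the multithreaded builder base class. Names are illustrative.
import tensorflow_datasets as tfds

from dataset_builder import MultiThreadedDatasetBuilder  # import path is an assumption


class SoarDataset(MultiThreadedDatasetBuilder):  # illustrative name
    VERSION = tfds.core.Version("1.0.0")

    NUM_WORKERS = 8  # parallel conversion workers (base-class default: 16)
    CHUNKSIZE = 500  # examples processed in memory per chunk before writing (default: 1000)

    @classmethod
    def _process_example(cls, example_input):
        # Heavy per-trajectory work (decoding videos, assembling the episode dict)
        # belongs here; the worker expects a (key, feature_dict) pair back.
        raise NotImplementedError

    def _split_generators(self, dl_manager):
        # Return {split_name: generator of (key, example_input) tuples},
        # e.g. separate "success" and "failure" splits of the raw trajectories.
        raise NotImplementedError
```

Lowering `CHUNKSIZE` reduces how many serialized examples sit in memory between writer handoffs, and `NUM_WORKERS` can be reduced on machines with few cores.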
22 | 23 | This data builder assumes your raw data is organized into the following structure: 24 | ``` 25 | manual_dir/robot_id/scene_id/policy_type/date/success/trajectory_{i}/* 26 | manual_dir/robot_id/scene_id/policy_type/date/failure/trajectory_{i}/* 27 | ``` 28 | Each `trajectory_{i}` directory will contain the following files, as logged by the code in the `data_collection` dir: 29 | - actions.npy 30 | - eef_poses.npy 31 | - language_task.txt 32 | - robot_id.txt 33 | - task_list.txt 34 | - trajectory.mp4 35 | - combined.mp4 36 | - goals.mp4 37 | - object_list.txt 38 | - success.txt 39 | - time.txt 40 | 41 | The RLDS dataset will be automatically saved under `~/tensorflow_datasets/soar_dataset/`. 42 | -------------------------------------------------------------------------------- /rlds_converter/dataset_builder.py: -------------------------------------------------------------------------------- 1 | """Inspired by https://github.com/kpertsch/bridge_rlds_builder/blob/f0d16c5a8384c1476aa1c274a9aef3a5f76cbada/bridge_dataset/conversion_utils.py""" 2 | 3 | import abc 4 | import itertools 5 | import multiprocessing as mp 6 | from typing import Any, Callable, Dict, Iterable, Tuple, Union 7 | 8 | import tensorflow_datasets as tfds 9 | from absl import logging 10 | from tensorflow_datasets.core import ( 11 |     dataset_builder, 12 |     download, 13 |     example_serializer, 14 |     file_adapters, 15 |     naming, 16 | ) 17 | from tensorflow_datasets.core import split_builder as split_builder_lib 18 | from tensorflow_datasets.core import splits as splits_lib 19 | from tensorflow_datasets.core import writer as writer_lib 20 | from tqdm import tqdm 21 | 22 | Key = Union[str, int] 23 | Example = Dict[str, Any] 24 | ExampleInput = Any 25 | 26 | 27 | class MultiThreadedSplitBuilder(split_builder_lib.SplitBuilder): 28 |     """Multithreaded version of tfds.core.SplitBuilder. Removes Apache Beam support, only supporting Python generators.""" 29 | 30 |     def __init__( 31 |         self, 32 |         process_fn: Callable[[ExampleInput], Tuple[Key, Example]], 33 |         num_workers: int, 34 |         chunksize: int, 35 |         *args, 36 |         **kwargs, 37 |     ): 38 |         super().__init__(*args, **kwargs) 39 |         self._process_fn = process_fn 40 |         self.num_workers = num_workers 41 |         self.chunksize = chunksize 42 | 43 |     def submit_split_generation( 44 |         self, 45 |         split_name: splits_lib.Split, 46 |         generator: Iterable[Tuple[Key, ExampleInput]], 47 |         filename_template: naming.ShardedFileTemplate, 48 |         disable_shuffling: bool = False, 49 |     ) -> splits_lib.SplitInfo: 50 |         if self._max_examples_per_split is not None: 51 |             logging.warning( 52 |                 "Splits capped at %s examples max.", self._max_examples_per_split 53 |             ) 54 |             generator = itertools.islice(generator, self._max_examples_per_split) 55 |             total_num_examples = self._max_examples_per_split 56 |         else: 57 |             # If dataset info has been pre-downloaded from the internet, 58 |             # we can use the pre-computed number of examples for the progress bar.
59 |             split_info = self._split_dict.get(split_name) 60 |             if split_info and split_info.num_examples: 61 |                 total_num_examples = split_info.num_examples 62 |             else: 63 |                 total_num_examples = None 64 | 65 |         serialized_info = self._features.get_serialized_info() 66 |         writer = writer_lib.Writer( 67 |             serializer=example_serializer.ExampleSerializer(serialized_info), 68 |             filename_template=filename_template, 69 |             hash_salt=split_name, 70 |             disable_shuffling=disable_shuffling, 71 |             file_format=self._file_format, 72 |             shard_config=self._shard_config, 73 |         ) 74 |         pbar = tqdm( 75 |             total=total_num_examples, 76 |             desc=f"Generating {split_name} examples...", 77 |             unit=" examples", 78 |             dynamic_ncols=True, 79 |             miniters=1, 80 |         ) 81 |         with mp.Pool( 82 |             self.num_workers, 83 |             initializer=MultiThreadedSplitBuilder._worker_init, 84 |             initargs=(self._process_fn, self._features), 85 |         ) as pool: 86 |             logging.info( 87 |                 "Using %d workers with chunksize %d.", self.num_workers, self.chunksize 88 |             ) 89 |             while True: 90 |                 curr = pbar.n 91 |                 iterator = itertools.islice(generator, self.chunksize) 92 |                 results = pool.map(MultiThreadedSplitBuilder._worker_fn, iterator) 93 |                 for key, example in results: 94 |                     writer._shuffler.add(key, example) 95 |                     writer._num_examples += 1 96 |                     pbar.update(1) 97 |                 if pbar.n == curr: 98 |                     break 99 |         shard_lengths, total_size = writer.finalize() 100 | 101 |         return splits_lib.SplitInfo( 102 |             name=split_name, 103 |             shard_lengths=shard_lengths, 104 |             num_bytes=total_size, 105 |             filename_template=filename_template, 106 |         ) 107 | 108 |     @staticmethod 109 |     def _worker_init( 110 |         process_fn: Callable[[ExampleInput], Tuple[Key, Example]], 111 |         features: tfds.features.FeaturesDict, 112 |     ): 113 |         global __process_fn 114 |         global __features 115 |         global __serializer 116 |         __process_fn = process_fn 117 |         __features = features 118 |         __serializer = example_serializer.ExampleSerializer( 119 |             features.get_serialized_info() 120 |         ) 121 | 122 |     @staticmethod 123 |     def _worker_fn(example_input): 124 |         global __process_fn 125 |         global __features 126 |         global __serializer 127 |         key, example = __process_fn(example_input) 128 |         return key, __serializer.serialize_example(__features.encode_example(example)) 129 | 130 | 131 | class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder): 132 |     """Multithreaded version of tfds.core.GeneratorBasedBuilder.""" 133 | 134 |     # Defaults can be overridden by subclasses. 135 |     NUM_WORKERS = 16  # number of parallel workers 136 |     CHUNKSIZE = 1000  # number of examples to process in memory before writing to disk 137 | 138 |     @classmethod 139 |     @abc.abstractmethod 140 |     def _process_example(cls, example_input: ExampleInput) -> Tuple[Key, Example]: 141 |         """Process a single example. 142 | 143 |         This is the function that will be parallelized, so it should contain any heavy computation and I/O. It 144 |         should return a `(key, feature_dict)` tuple, where the feature dictionary is compatible with `self.info.features` (see the 145 |         FeatureConnector documentation) and is ready to be encoded and serialized. 146 |         """ 147 |         raise NotImplementedError() 148 | 149 |     @abc.abstractmethod 150 |     def _split_generators( 151 |         self, 152 |         dl_manager: download.DownloadManager, 153 |     ) -> Dict[splits_lib.Split, Iterable[Tuple[Key, ExampleInput]]]: 154 |         """Same as GeneratorBasedBuilder._split_generators, but returns generators of tuples (key, 155 |         example_input) rather than (key, example). `example_input` will be passed to 156 |         `_process_example` for further processing.
157 |         """ 158 |         raise NotImplementedError() 159 | 160 |     def _generate_examples(self, *args, **kwargs): 161 |         """This is not actually called from TFDS code. I believe they left it in for legacy reasons. However, 162 |         it must be overridden for TFDS to recognize the class as a valid dataset builder. 163 |         """ 164 |         raise RuntimeError() 165 | 166 |     def _download_and_prepare( 167 |         self, 168 |         dl_manager: download.DownloadManager, 169 |         download_config: download.DownloadConfig, 170 |     ) -> None: 171 |         """Same as superclass `_download_and_prepare`, but removes the Apache Beam code paths and uses 172 |         MultiThreadedSplitBuilder instead of SplitBuilder. 173 |         """ 174 |         split_builder = MultiThreadedSplitBuilder( 175 |             process_fn=type(self)._process_example, 176 |             num_workers=self.NUM_WORKERS, 177 |             chunksize=self.CHUNKSIZE, 178 |             split_dict=self.info.splits, 179 |             features=self.info.features, 180 |             dataset_size=self.info.dataset_size, 181 |             max_examples_per_split=download_config.max_examples_per_split, 182 |             beam_options=download_config.beam_options, 183 |             beam_runner=download_config.beam_runner, 184 |             file_format=self.info.file_format, 185 |             shard_config=download_config.get_shard_config(), 186 |         ) 187 | 188 |         split_generators = self._split_generators(dl_manager) 189 |         dataset_builder._check_split_names(split_generators.keys()) 190 | 191 |         # The writer fails if the number of examples yielded is `0`, so we return early here. 192 |         if download_config.max_examples_per_split == 0: 193 |             return 194 | 195 |         # Start generating data for all splits 196 |         path_suffix = file_adapters.ADAPTER_FOR_FORMAT[ 197 |             self.info.file_format 198 |         ].FILE_SUFFIX 199 | 200 |         split_infos = [] 201 |         for split_name, generator in split_generators.items(): 202 |             filename_template = naming.ShardedFileTemplate( 203 |                 split=split_name, 204 |                 dataset_name=self.name, 205 |                 data_dir=self.data_path, 206 |                 filetype_suffix=path_suffix, 207 |             ) 208 |             split_info = split_builder.submit_split_generation( 209 |                 split_name=split_name, 210 |                 generator=generator, 211 |                 filename_template=filename_template, 212 |                 disable_shuffling=self.info.disable_shuffling, 213 |             ) 214 |             split_infos.append(split_info) 215 | 216 |         # Update the info object with the splits.
217 |         split_dict = splits_lib.SplitDict(split_infos) 218 |         self.info.set_splits(split_dict) 219 | -------------------------------------------------------------------------------- /rlds_converter/requirements.txt: -------------------------------------------------------------------------------- 1 | pillow == 10.2.0 2 | -------------------------------------------------------------------------------- /rlds_converter/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 |     name="dlimp-dataset-builder", 5 |     python_requires=">=3.10", 6 |     install_requires=[ 7 |         "tensorflow_datasets==4.9.4", 8 |         "opencv-python", 9 |         "apache_beam", 10 |         "tensorflow", 11 |     ], 12 | ) 13 | -------------------------------------------------------------------------------- /rlds_converter/soar_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/rlds_converter/soar_dataset/__init__.py -------------------------------------------------------------------------------- /soar_data/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Allow overriding variables from the command line 4 | SAVE_DIR="${SAVE_DIR:-$HOME/soar_data}"  # use $HOME: a literal ~ would not expand inside quotes 5 | REQUIRED_SPACE_GB="${REQUIRED_SPACE_GB:-140}" 6 | URL_FILE="${URL_FILE:-soar_data/urls.txt}" 7 | 8 | # Function to check if enough disk space is available 9 | check_disk_space() { 10 |     local available_space 11 |     available_space=$(df --output=avail -BG . | tail -1 | tr -d 'G') 12 |     if [ "$available_space" -lt "$REQUIRED_SPACE_GB" ]; then 13 |         echo "Warning: You need at least $REQUIRED_SPACE_GB GB of free space to proceed." 14 |         read -p "Do you want to continue anyway? (y/n) " REPLY 15 |         echo 16 |         if [[ ! $REPLY =~ ^[Yy]$ ]]; then 17 |             echo "Aborting download." 18 |             exit 1 19 |         fi 20 |     fi 21 | } 22 | 23 | # Check disk space 24 | check_disk_space 25 | 26 | # Check if the url file exists 27 | if [ ! -f "$URL_FILE" ]; then 28 |     echo "Error: URLs file '${URL_FILE}' not found. Please run from the root directory of the repository." 29 |     echo "URL file should be found at soar_data/urls.txt" 30 |     echo "Aborting download." 31 |     exit 1 32 | fi 33 | 34 | # Inform saving directory 35 | echo "Saving files to $SAVE_DIR" 36 | 37 | # Count the number of files 38 | TOTAL_FILES=$(wc -l < "$URL_FILE") 39 | CURRENT_FILE=0 40 | 41 | # Function to print the progress bar 42 | print_progress() { 43 |     local PROGRESS=$(( ($CURRENT_FILE * 100) / $TOTAL_FILES )) 44 |     local FILLED=$(( $PROGRESS / 2 )) 45 |     local EMPTY=$(( 50 - $FILLED )) 46 |     printf "\rDownloading files: [" 47 |     printf "%0.s#" $(seq 1 $FILLED) 48 |     printf "%0.s " $(seq 1 $EMPTY) 49 |     printf "] %d%%" $PROGRESS 50 | } 51 | 52 | # Function to download using wget without parallel 53 | download_without_parallel() { 54 |     while IFS= read -r url; do 55 |         wget -P "$SAVE_DIR" "$url" 56 |         ((CURRENT_FILE++)) 57 |         print_progress 58 |     done < "$URL_FILE" 59 |     echo 60 | } 61 | 62 | download_with_parallel() { 63 |     cat "$URL_FILE" | parallel -j 4 wget -P "$SAVE_DIR" {} 64 | } 65 | 66 | # Check if parallel is installed 67 | if ! command -v parallel &> /dev/null; then 68 |     echo "GNU parallel is not installed." 69 |     read -p "Do you want to install GNU parallel? 
(y/n) " REPLY 70 | echo 71 | if [[ $REPLY =~ ^[Yy]$ ]]; then 72 | # Try to install parallel 73 | if command -v apt-get &> /dev/null; then 74 | sudo apt-get update && sudo apt-get install -y parallel 75 | elif command -v yum &> /dev/null; then 76 | sudo yum install -y parallel 77 | elif command -v brew &> /dev/null; then 78 | brew install parallel 79 | else 80 | echo "Package manager not found. Please install GNU parallel manually." 81 | download_without_parallel 82 | exit 1 83 | fi 84 | else 85 | echo "Downloading files without parallelism..." 86 | download_without_parallel 87 | exit 0 88 | fi 89 | fi 90 | 91 | download_with_parallel 92 | 93 | # Initialize the progress bar 94 | CURRENT_FILE=0 95 | print_progress 96 | 97 | # Monitor the progress of parallel downloads 98 | while IFS= read -r url; do 99 | if [ -f "$SAVE_DIR/$(basename $url)" ]; then 100 | ((CURRENT_FILE++)) 101 | print_progress 102 | fi 103 | done < $URL_FILE 104 | echo 105 | -------------------------------------------------------------------------------- /soar_data/fetch_urls.sh: -------------------------------------------------------------------------------- 1 | # This script will take a long time to run, and the cached result in already saved in urls.txt 2 | # You shouldn't need to rerun this 3 | 4 | # List files, filter by patterns, remove duplicates, and save to urls.txt 5 | BASE_URL="https://rail.eecs.berkeley.edu/datasets/soar_release/1.0.0/" 6 | 7 | echo "Fetching all datafile URLs..." 8 | wget --spider -r -nd -np $BASE_URL 2>&1 | grep '^--' | awk '{ print $3 }' | grep -E '\.json$|tfrecord' | sort | uniq > urls.txt 9 | echo "Finished fetching URLs." 10 | -------------------------------------------------------------------------------- /soar_data/load_soar_data.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","metadata":{},"source":["# Load SOAR Data (for training)\n","This notebook is a minimal example of how to load the RLDS SOAR data (e.g. for downstream training)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","Make sure the Colab has access to the SOAR repo\n","\"\"\"\n","!git clone https://github.com/rail-berkeley/soar.git"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","\"\"\"\n","1. Download a minimal SOAR dataset that's small to be used for testing\n","\n","In this notebook we load a small dummy dataset for speed. If you wish to load the full dataset, \n","use the download script in this directory to download the full dataset. Then it can be loaded\n","in the same way, changing the path to the saved dataset.\n","\"\"\"\n","SAVE_DIR = \"dummy_soar_data\"\n","!cat soar/soar_data/test_dataset_urls.txt | while read url; do wget -P \"dummy_soar_data\" \"$url\"; done"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","2. 
Import the Dataloader class\n","\"\"\"\n","import subprocess\n","\n","# install jaxrl_m if it it not already installed\n","# the package is located in model_training/jaxrl_m\n","try:\n"," import jaxrl_m\n","except ImportError:\n"," print(\"local jaxrl_m package not installed, trying to install now\")\n"," package_path = 'soar/model_training'\n","\n"," # install jaxrl_m\n"," result = subprocess.run(['pip', 'install', '-e', package_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n"," # Print the standard output and error\n"," print(\"Pip install Output:\\n\", result.stdout)\n"," if result.stderr:\n"," print(\"Pip install Error:\\n\", result.stderr)\n","\n"," # add the package path to sys.path (for ipynb)\n"," import sys\n"," if package_path not in sys.path:\n"," sys.path.append(package_path)\n"," \n"," # install the requirements for the package\n"," result = subprocess.run(['pip', 'install', '-r', f\"{package_path}/requirements.txt\"])\n"," # Print the standard output and error\n"," if result.stderr:\n"," print(\"Pip install Error:\\n\", result.stderr)\n"," \n","# check that installation was successful\n","try:\n"," import jaxrl_m\n","except ImportError:\n"," print(\"Failed to correctly install jaxrl_m package\")\n"," print(\"Please manually install the package with `pip install -e soar/model_training`\")\n"," raise\n","\n","# import dataloader class\n","from jaxrl_m.data.dataset import WidowXDataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","3. Load the dataset\n","\"\"\"\n","train_data = WidowXDataset(\n"," [SAVE_DIR],\n"," seed=0,\n"," batch_size=16,\n"," train=True,\n"," load_language=True,\n"," goal_relabeling_strategy=\"uniform\", # specify goal relabeling to load languages\n"," goal_relabeling_kwargs={\"reached_proportion\": 0.5, \"discount\": 0.98},\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","4. Inspect an example batch\n","\"\"\"\n","!pip install matplotlib\n","import matplotlib.pyplot as plt\n","\n","train_data_iter = train_data.iterator()\n","example_batch = next(train_data_iter)\n","\n","print(f\"Example batch keys: {example_batch.keys()}\")\n","print(f\"Actions shape: {example_batch['actions'].shape}, which is (batch_size, action_dim)\")\n","print(f\"Observations shape: {example_batch['observations']['image'].shape}, which is (batch_size, observation_dim)\")\n","print(f\"Proprio shape: {example_batch['observations']['proprio'].shape}, which is (batch_size, proprio_dim)\")\n","\n","language = example_batch['goals']['language']\n","language = [str(l) for l in language]\n","\n","plt.figure(figsize=(10, 10))\n","for i in range(16):\n"," plt.subplot(4, 4, i+1)\n"," plt.imshow(example_batch['observations']['image'][i])\n"," plt.title(language[i] if len(language[i]) <= 30 else language[i][:30] + '...', fontsize=6)\n"," plt.axis('off')\n","plt.show()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","5. Load only the success/failure split of the SOAR-data\n","If you wish, you could only load certain splits of the dataset. 
\n","\"\"\"\n","success_data = WidowXDataset(\n"," [SAVE_DIR],\n"," data_splits=[\"success\"],\n"," seed=0,\n"," batch_size=16,\n"," train=True,\n",")\n","\n","failure_data = WidowXDataset(\n"," [SAVE_DIR],\n"," data_splits=[\"failure\"],\n"," seed=0,\n"," batch_size=16,\n"," train=True,\n",")"]},{"cell_type":"markdown","metadata":{},"source":["## More Advanced Usage\n","For more advanced usage, check out the arguments of the `BridgeDataset` class at [model_training/jaxrl_m/data/dataset.py](https://github.com/rail-berkeley/soar/blob/main/model_training/jaxrl_m/data/dataset.py).\n","\n","An example of how this dataset is used is in `model_training/experiments/train.py`, and the configuration and arguments of the datasets are in `model_training/experiments/configs/train_config.py` and `model_training/experiments/configs/data_config.py`."]}],"metadata":{"kernelspec":{"display_name":"widowx","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.14"}},"nbformat":4,"nbformat_minor":2} 2 | -------------------------------------------------------------------------------- /soar_data/test_dataset_urls.txt: -------------------------------------------------------------------------------- 1 | https://rail.eecs.berkeley.edu/datasets/soar_release/test/dataset_info.json 2 | https://rail.eecs.berkeley.edu/datasets/soar_release/test/features.json 3 | https://rail.eecs.berkeley.edu/datasets/soar_release/test/soar_dataset-failure.tfrecord-00000-of-00001 4 | https://rail.eecs.berkeley.edu/datasets/soar_release/test/soar_dataset-success.tfrecord-00000-of-00001 5 | --------------------------------------------------------------------------------