├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── data_collection
│   ├── README.md
│   ├── config
│   │   ├── README.md
│   │   ├── berkeley_robot_0
│   │   │   ├── README.md
│   │   │   ├── cogvlm_server.yaml
│   │   │   ├── config.yaml
│   │   │   ├── gc_bc.yaml
│   │   │   ├── general_params.yaml
│   │   │   ├── lc_bc.yaml
│   │   │   ├── reset_detector.yaml
│   │   │   ├── subgoal_predictor.yaml
│   │   │   ├── success_detector.yaml
│   │   │   ├── task_definition.yaml
│   │   │   └── task_proposer.yaml
│   │   ├── berkeley_robot_1
│   │   │   ├── cogvlm_server.yaml
│   │   │   ├── config.yaml
│   │   │   ├── gc_bc.yaml
│   │   │   ├── general_params.yaml
│   │   │   ├── lc_bc.yaml
│   │   │   ├── reset_detector.yaml
│   │   │   ├── subgoal_predictor.yaml
│   │   │   ├── success_detector.yaml
│   │   │   ├── task_definition.yaml
│   │   │   └── task_proposer.yaml
│   │   ├── berkeley_robot_2
│   │   │   ├── cogvlm_server.yaml
│   │   │   ├── config.yaml
│   │   │   ├── gc_bc.yaml
│   │   │   ├── general_params.yaml
│   │   │   ├── reset_detector.yaml
│   │   │   ├── subgoal_predictor.yaml
│   │   │   ├── success_detector.yaml
│   │   │   ├── task_definition.yaml
│   │   │   └── task_proposer.yaml
│   │   ├── berkeley_robot_3
│   │   │   ├── cogvlm_server.yaml
│   │   │   ├── config.yaml
│   │   │   ├── gc_bc.yaml
│   │   │   ├── general_params.yaml
│   │   │   ├── lc_bc.yaml
│   │   │   ├── reset_detector.yaml
│   │   │   ├── subgoal_predictor.yaml
│   │   │   ├── success_detector.yaml
│   │   │   ├── task_definition.yaml
│   │   │   └── task_proposer.yaml
│   │   └── berkeley_robot_4
│   │       ├── cogvlm_server.yaml
│   │       ├── config.yaml
│   │       ├── diffusion_policy.yaml
│   │       ├── gc_bc.yaml
│   │       ├── general_params.yaml
│   │       ├── rcsl.yaml
│   │       ├── reset_detector.yaml
│   │       ├── subgoal_predictor.yaml
│   │       ├── success_detector.yaml
│   │       ├── task_definition.yaml
│   │       └── task_proposer.yaml
│   ├── orchestrator
│   │   ├── cogvlm_server
│   │   │   ├── batched.py
│   │   │   ├── forwarding_server.py
│   │   │   ├── main.py
│   │   │   ├── test-img.png
│   │   │   ├── test.py
│   │   │   └── test_batched.py
│   │   ├── robot
│   │   │   ├── gc_policy.py
│   │   │   ├── logger.py
│   │   │   ├── main.py
│   │   │   ├── reset_detector.py
│   │   │   ├── subgoal_predictor.py
│   │   │   ├── task_proposer.py
│   │   │   ├── task_success_predictor.py
│   │   │   └── utils.py
│   │   ├── set_workspace_bounds
│   │   │   └── teleop.py
│   │   ├── susie_server
│   │   │   └── main.py
│   │   └── web_viewer
│   │       ├── app.py
│   │       ├── ros_client
│   │       │   └── run_client.py
│   │       └── templates
│   │           └── index.html
│   ├── requirements.txt
│   ├── setup.py
│   ├── ssh_port_forward.sh
│   └── start_robot.sh
├── media
│   ├── autonomous_data_collection.png
│   ├── soar_logo.jpeg
│   └── soar_teaser.png
├── model_training
│   ├── README.md
│   ├── experiments
│   │   ├── configs
│   │   │   ├── data_config.py
│   │   │   └── train_config.py
│   │   ├── scripts
│   │   │   └── launch.sh
│   │   └── train.py
│   ├── jaxrl_m
│   │   ├── agents
│   │   │   ├── __init__.py
│   │   │   └── continuous
│   │   │       └── gc_bc.py
│   │   ├── common
│   │   │   ├── common.py
│   │   │   ├── encoding.py
│   │   │   ├── optimizers.py
│   │   │   ├── typing.py
│   │   │   └── wandb.py
│   │   ├── data
│   │   │   ├── dataset.py
│   │   │   ├── text_processing.py
│   │   │   ├── tf_augmentations.py
│   │   │   └── tf_goal_relabeling.py
│   │   ├── networks
│   │   │   ├── actor_critic_nets.py
│   │   │   └── mlp.py
│   │   ├── utils
│   │   │   ├── jax_utils.py
│   │   │   ├── timer_utils.py
│   │   │   └── train_utils.py
│   │   └── vision
│   │       ├── __init__.py
│   │       ├── film_conditioning_layer.py
│   │       └── resnet_v1.py
│   ├── requirements.txt
│   └── setup.py
├── rlds_converter
│   ├── README.md
│   ├── dataset_builder.py
│   ├── requirements.txt
│   ├── setup.py
│   └── soar_dataset
│       ├── __init__.py
│       └── soar_dataset_dataset_builder.py
└── soar_data
    ├── download_dataset.sh
    ├── fetch_urls.sh
    ├── load_soar_data.ipynb
    ├── test_dataset_urls.txt
    └── urls.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | checkpoints/
165 | goal.png
166 | obs.png
167 | data_collection_logs/
168 | video_logs
169 | orchestrator/web_viewer/uploads
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | # - id: check-yaml
6 | - id: check-ast
7 | - id: check-added-large-files
8 | args: ['--maxkb=2000']
9 | - id: check-case-conflict
10 | - id: check-merge-conflict
11 | - id: end-of-file-fixer
12 | - id: trailing-whitespace
13 | - id: detect-private-key
14 | - id: debug-statements
15 | exclude: ^model_training/experiments/
16 | - repo: https://github.com/psf/black
17 | rev: 22.10.0
18 | hooks:
19 | - id: black
20 | exclude: ^model_training/experiments/
21 | - repo: https://github.com/pycqa/isort
22 | rev: 5.12.0
23 | hooks:
24 | - id: isort
25 | exclude: ^model_training/experiments/
26 | args: [
27 | "--profile", "black",
28 | "--src", "model_training/jaxrl_m",
29 | "--src", "model_training/experiments",
30 | "--src", "data_collection/orchestrator",
31 | ]
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Robotic AI & Learning Lab Berkeley
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SOAR: Autonomous Improvement of Instruction Following Skills via Foundation Models
2 | ![SOAR logo](media/soar_logo.jpeg)
3 | [Paper](https://arxiv.org/pdf/2407.20635)
4 | [Open in Colab](https://githubtocolab.com/rail-berkeley/soar/blob/main/soar_data/load_soar_data.ipynb)
5 | [License: MIT](https://opensource.org/licenses/MIT)
6 | [Project Website](https://auto-improvement.github.io/)
7 |
8 | [Zhiyuan Zhou](https://zhouzypaul.github.io/), [Pranav Atreya](https://pranavatreya.github.io/), [Abraham Lee](https://www.linkedin.com/in/abraham-lee-4a0497242?original_referer=https%3A%2F%2Fwww.google.com%2F), [Homer Walke](https://homerwalke.com/), [Oier Mees](https://www.oiermees.com/), [Sergey Levine](https://people.eecs.berkeley.edu/~svlevine/)
9 |
10 |
11 | We present SOAR, an approach to autonomously improve instruction following policies leveraging
12 | foundation models. SOAR breaks down the autonomous improvement problem into components that import
13 | Internet-scale knowledge from VLMs and a component that learns from autonomous data with a purely self-supervised objective.
14 |
15 | ![SOAR teaser](media/soar_teaser.png)
16 |
17 | This repository contains three components: (1) the VLM-powered, semantics-aware autonomous data collection pipeline, (2) tooling for converting the collected raw data into the RLDS format, and (3) Jax/Flax code for training the policies used in the paper.
18 |
19 | ## Using SOAR-Data
20 |
21 | We have released SOAR-Data for public access [here](https://rail.eecs.berkeley.edu/datasets/soar_release/1.0.0/) in RLDS format. (The raw data is also available [here](https://rail.eecs.berkeley.edu/datasets/soar_release/numpy_source/soar-dataset-local/))
22 | We also provide a download script to fetch the dataset in RLDS format, which requires 136 GB of disk space.
23 | In this directory, run
24 | ```bash
25 | bash soar_data/download_dataset.sh
26 | ```
27 | The download should take around 20 minutes if you use the parallel download option; we recommend running it inside a tmux session.
28 |
29 | To load the dataset for training and other downstream use cases, we have provided a minimal example ([open in Colab](https://githubtocolab.com/rail-berkeley/soar/blob/main/soar_data/load_soar_data.ipynb)) that loads the dataset and visualizes it.
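
For quick orientation outside the notebook, the sketch below loads the downloaded data with `tensorflow_datasets`. The local `builder_dir` path and the `train` split name are assumptions about where the download script placed the data; adjust them to your setup, and treat the Colab notebook above as the canonical example.

```python
import tensorflow_datasets as tfds

# Assumed local path created by the download script; adjust as needed.
builder = tfds.builder_from_directory("soar_dataset/1.0.0")
ds = builder.as_dataset(split="train")

# RLDS stores one episode per example; each episode holds a nested
# "steps" dataset of per-timestep observations and actions.
for episode in ds.take(1):
    for step in episode["steps"].take(3):
        print(step.keys())  # inspect the available observation/action keys
```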
30 |
31 | ## Installation
32 | ```bash
33 | conda create -n soar python=3.10
34 | conda activate soar
35 |
36 | # model training requirements
37 | pip install -e model_training
38 | pip install -r model_training/requirements.txt
39 |
40 | # data collection requirements (you also need the jaxrl_m library above)
41 | pip install -e data_collection
42 | pip install -r data_collection/requirements.txt
43 |
44 | # rlds conversion requirements
45 | pip install -e rlds_converter
46 | pip install -r rlds_converter/requirements.txt
47 | ```
48 |
49 | If you would like to train models with Jax, also install the accelerator-specific Jax build.
50 | For GPU:
51 | ```bash
52 | pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
53 | ```
54 |
55 | For TPU:
56 | ```bash
57 | pip install --upgrade "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
58 | ```
59 |
60 |
61 | ## Autonomous Data Collection
62 |
63 | We provide a ready-to-use implementation of autonomous data collection on a fleet of WidowX robot arms. This data collection system is designed around deploying instruction following policies at scale to collect autonomous datasets that are semantically relevant, diverse, and large. Special care is taken to minimize human supervision during data collection, with features like automatic reset detection (and subsequent Slack notification).
64 |
65 | ![Autonomous data collection on a fleet of WidowX robots](media/autonomous_data_collection.png)
66 |
67 | Run autonomous data collection on the robot with:
68 | ```
69 | python data_collection/orchestrator/robot/main.py --config_dir config/<robot_name>
70 | ```
71 |
72 | See [data_collection/README.md](data_collection/README.md) for more information on the setup required before running data collection.
73 |
74 | ## Model Training
75 | The `model_training` directory contains a self-contained Python project for training goal-conditioned and language-conditioned policies on Bridge data and on SOAR-Data.
76 |
77 | To launch a training run, run:
78 | ```bash
79 | cd model_training
80 | bash experiments/scripts/launch.sh
81 | ```
82 | This will launch [train.py](model_training/experiments/train.py) with the default arguments specified in [train_config.py](model_training/experiments/configs/train_config.py) and [data_config.py](model_training/experiments/configs/data_config.py).
83 |
84 | ## RLDS Data Conversion
85 | We convert the raw data logged in the `data_collection/*` directories into the commonly used RLDS format. The conversion code is
86 | located in the `rlds_converter` directory. See [rlds_converter/README.md](rlds_converter/README.md) for more information.
87 |
88 | To build the SOAR dataset:
89 | ```bash
90 | cd rlds_converter/soar_dataset
91 | CUDA_VISIBLE_DEVICES="" tfds build --manual_dir <path_to_raw_data>
92 | ```
93 |
94 | ## Citation
95 | ```
96 | @article{zhou2024autonomous,
97 | title={Autonomous Improvement of Instruction Following Skills via Foundation Models},
98 | author={Zhiyuan Zhou and Pranav Atreya and Abraham Lee and Homer Walke and Oier Mees and Sergey Levine},
99 | journal = {arXiv preprint arXiv:2407.20635},
100 | year={2024},
101 | }
102 | ```
103 |
104 | ## Contributing
105 | We welcome pull requests and bug reports to this repo.
106 |
107 | To enable code checks and auto-formatting, please install pre-commit hooks (run this in the root directory):
108 | ```bash
109 | pre-commit install
110 | ```
111 | The hooks should now run before every commit. If files are modified during the checks, you'll need to re-stage them and commit again.
112 |
--------------------------------------------------------------------------------
/data_collection/README.md:
--------------------------------------------------------------------------------
1 | # Autonomous data collection with VLMs on WidowX robots
2 |
3 | ## Installation
4 | Run in the current directory
5 | ```bash
6 | pip install -e .
7 | pip install -r requirements.txt
8 | ```
9 |
10 | ## VLM Server Hosting
11 | You have the option to either host CogVLM on a local server for inference, or use GPT-4V/o. This specification can be made in the configs (see `README.md` under `data_collection/config`). If you are running autonomous data collection with multiple robots, you can host the VLM just once, and all data collection scripts will query this server.
12 |
13 | If you are hosting CogVLM, make sure the port specified in the last line of the file `data_collection/orchestrator/cogvlm_server/main.py` matches the port specified in `data_collection/config/<robot_name>/cogvlm_server.yaml`. Then, on the machine where you want to host the VLM server, run `python main.py` from the `data_collection/orchestrator/cogvlm_server` directory. The VLM requires around 48 GB of memory.
14 |
15 | We provide convenience scripts, located in the same directory, for testing that the VLM has been hosted correctly and for setting up a proxy server (e.g., to get around firewalls).
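
As a quick sanity check before starting data collection, a small script along the following lines (not part of the repo) can confirm that the address in your robot's `cogvlm_server.yaml` actually points at a listening server; the config path below is just an example.

```python
import socket
import yaml

# Example config path; point this at your own robot's config directory.
config_path = "data_collection/config/berkeley_robot_0/cogvlm_server.yaml"
with open(config_path) as f:
    cfg = yaml.safe_load(f)

ip, port = cfg["cogvlm_server_ip"], cfg["cogvlm_server_port"]
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.settimeout(3)
    reachable = s.connect_ex((ip, port)) == 0  # 0 means the TCP connect succeeded
print(f"CogVLM server at {ip}:{port} reachable: {reachable}")
```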
16 |
17 | ## OpenAI API Key
18 |
19 | Make sure to specify your OpenAI API key as an environment variable with the name `OPENAI_API_KEY`. It is likely convenient to include this specification in your `.bashrc` file.
20 |
21 | ## SuSIE Server
22 |
23 | Similar to the VLM server, you will need to host the SuSIE model on a machine accessible to the robot machines. The memory requirement is much more modest, taking up around 6 GB. Make sure the port specified in the last line of the file `data_collection/orchestrator/susie_server/main.py` matches the port specified in the config `data_collection/config/<robot_name>/subgoal_predictor.yaml`. To launch the SuSIE server, run `python orchestrator/susie_server/main.py --config_dir config/<robot_name>` from the `data_collection` directory, specifying the path to the folder containing your robot's configs.
24 |
25 | ## Web Viewer
26 |
27 | To make it convenient to monitor your robots from anywhere, we include a Flask web server with a simple front-end displaying video streamed by your robots. It is mandatory to launch the web server. There are two parts to launching this web viewer: (1) launch the Flask server on a central machine, and (2) launch the data streaming RosPy script on each of your robots.
28 |
29 | To launch the Flask web server, run `python app.py` from the directory `data_collection/orchestrator/web_viewer`. The default port for the web server is `5000`, which can be adjusted in the last line of the file `app.py`.
30 |
31 | Separately, on your robot machine (the machine where you are running the docker container and action server from `bridge_data_robot`), launch the script `python orchestrator/web_viewer/ros_client/run_client.py --config_dir config/<robot_name>` from the `data_collection` directory, making sure to specify the path to the appropriate configuration directory. This command should be run after the docker container and action server from `bridge_data_robot` have been launched (see the README in the `bridge_data_robot` repo for more instructions).
32 |
33 | ## Pre-data collection: Setting Workspace Boundaries for Robot
34 |
35 | The final step before launching data collection is to specify the workspace boundaries for your robot. Specifying workspace boundaries (as the dimensions of an invisible rectangular prism the end-effector is forced to stay inside of) helps with safe robot operation and minimizes the chances that the robot will do something requiring a manual environment reset.
36 |
37 | Run the script `python orchestrator/set_workspace_bounds/teleop.py --config_dir config/<robot_name>` from the `data_collection` directory. This will start a keyboard teleop script (the key controls are printed once you run it). You should then teleop the end-effector to the extremes of your workspace. Hitting `q` will terminate the script and print out the minimum and maximum `x`, `y`, and `z` values defining the invisible rectangular prism boundary. Enter these values in your robot's `general_params` config file: `data_collection/config/<robot_name>/general_params.yaml`.
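
To make the bound semantics concrete, here is an illustrative sketch (not the repo's actual enforcement code) of how the `manual_workspace_bounds` values can be read from `general_params.yaml` and applied as a clip on a commanded end-effector position; the config path and target position are placeholders.

```python
import numpy as np
import yaml

# Example config path; each dimension holds a [min, max] pair from teleop.py.
with open("data_collection/config/berkeley_robot_0/general_params.yaml") as f:
    params = yaml.safe_load(f)

bounds = params["manual_workspace_bounds"]
low = np.array([bounds["x"][0], bounds["y"][0], bounds["z"][0]])
high = np.array([bounds["x"][1], bounds["y"][1], bounds["z"][1]])

commanded_xyz = np.array([0.50, 0.0, 0.01])   # hypothetical target position
safe_xyz = np.clip(commanded_xyz, low, high)  # stays inside the rectangular prism
print(safe_xyz)
```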
38 |
39 | ## Running the Robot
40 |
41 | Finally you are ready to run autonomous data collection on the robot! Simply run the following script:
42 | ```
43 | python orchestrator/robot/main.py --config_dir config/<robot_name>
44 | ```
45 | from the `data_collection` directory. The script `main.py` contains the code for iterating through the full autonomous data collection loop: querying the VLM for which task to command, querying the SuSIE server for a subgoal image, rolling out the policy, querying the VLM for success determination, and logging. You should be able to keep this script and the robot running for many hours at a time, only occasionally resetting a fallen object in the robot's environment.
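
For orientation, the control flow of `main.py` roughly resembles the toy sketch below. Every component here is a dummy stand-in for the real modules under `orchestrator/robot/`, and all names and numbers are hypothetical placeholders, not the repo's actual API.

```python
import random

task_list = ["put the mushroom in the blue bowl",
             "remove the mushroom from the blue bowl and put it on the table"]

def propose_task(obs):            # stands in for the VLM task proposer
    return random.choice(task_list)

def generate_subgoal(obs, task):  # stands in for the SuSIE subgoal server
    return f"subgoal image for: {task}"

def policy(obs, subgoal):         # stands in for the goal-conditioned policy
    return [0.0] * 7              # 7-DoF action (xyz, rpy, gripper)

def is_success(obs, task):        # stands in for the VLM success detector
    return random.random() < 0.5

obs = "initial camera image"
for episode in range(2):          # the real loop runs indefinitely
    task = propose_task(obs)
    for _ in range(5):            # cf. max_subgoals in subgoal_predictor.yaml
        subgoal = generate_subgoal(obs, task)
        for _ in range(20):       # cf. rollout_timesteps in gc_bc.yaml
            action = policy(obs, subgoal)
            # on the robot, env.step(action) would execute this on the WidowX
        if is_success(obs, task):
            break
    print(f"episode {episode}: task={task!r}")  # logging stands in for logger.py
```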
46 |
47 | 1. See `checkpoints` for instructions on downloading pre-trained models
48 | 2. See `config` for instructions on parameter specification for setting up autonomous data collection on your robot.
49 |
--------------------------------------------------------------------------------
/data_collection/config/README.md:
--------------------------------------------------------------------------------
1 | # Robot Data Collection Configuration Files
2 |
3 | This directory contains the configuration files for the five data collection WidowX robots used in the paper. Minus some per-robot idiosyncratic changes, the configs are identical across robots.
4 |
5 | When setting up a robot for autonomous data collection, all of the changes you need to make should be localized to this config subdirectory (i.e., you should not need to make any code changes). We have provided a detailed README under `berkeley_robot_0` explaining what the various configurations control and what needs to be modified for your setup.
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/README.md:
--------------------------------------------------------------------------------
1 | # Robot Configuration
2 |
3 | This directory contains a set of configuration files that specify things related to connecting to the robot, logging, image subgoal generation model parameters, goal-conditioned policy architecture, VLM task proposing and success detection, etc.
4 |
5 | ## config.yaml
6 |
7 | This is the top level configuration file that points to other more specific files for all of the aforementioned pipeline components. Make sure to specify configuration files here for the following:
8 |
9 | - general_params
10 | - task_proposer_params
11 | - gc_policy_params
12 | - subgoal_predictor_params
13 | - reset_detector_params
14 | - success_detector_params
15 | - cogvlm_server_params
16 | - task_definition_params
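
Note that the `!include` tag used in `config.yaml` is not standard YAML; it is resolved by the repo's config loader. Purely as an illustration of how such a tag can be handled with PyYAML (not necessarily the exact loader used here):

```python
import os
import yaml

def make_include_loader(base_dir):
    """Build a PyYAML loader class that resolves `!include other.yaml` tags."""
    class IncludeLoader(yaml.SafeLoader):
        pass

    def _include(loader, node):
        rel_path = loader.construct_scalar(node)
        with open(os.path.join(base_dir, rel_path)) as f:
            return yaml.load(f, Loader=IncludeLoader)

    IncludeLoader.add_constructor("!include", _include)
    return IncludeLoader

# Example usage: load the top-level config for one robot.
config_dir = "data_collection/config/berkeley_robot_0"
with open(os.path.join(config_dir, "config.yaml")) as f:
    config = yaml.load(f, Loader=make_include_loader(config_dir))
print(config["general_params"]["ip"])
```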
17 |
18 | ## general_params.yaml
19 | This file contains configuration parameters for connecting to the robot, streaming video feed of the robot's third-person camera to a web viewer, workspace boundaries, and logging.
20 |
21 | 1. Specify the IP address and port for the machine plugged into the robot. On this machine, you will need to run the bridge data robot server (see installation instructions from https://github.com/rail-berkeley/bridge_data_robot), and the IP address and port should match that of the server.
22 | 2. We also provide for convenience a web UI that allows you to visualize what your robot(s) is/are doing from anywhere. The web UI is a Flask server that you will need to host on your machine. See instructions in the top-level README for how to do this. `web_viewer_ip` and `web_viewer_port` should be set corresponding to your hosted web server.
23 | 3. If you want to run simultaneous autonomous data collection on an arm farm (as we did in the paper) to scale data collection, then each robot will have a unique id which you must specify.
24 | 4. Ignore `override_workspace_boundaries` (this is deprecated)
25 | 5. `move_duration` is the time allotted for the robot to execute one action. We collect data and evaluate the WidowX robots with blocking control, in which action prediction for the next timestep occurs only after the robot has been given `move_duration` seconds to execute the previous action. A longer `move_duration` leads to more precise execution of commanded actions, but slower data collection.
26 | 6. `video_save_path` is a local directory where all of your collected robot data will be logged. The logged data includes image observations and robot actions, generated subgoals, task and robot metadata, etc. Data is saved as numpy arrays; if you want to subsequently use it for training, you need to convert it to the RLDS format (see instructions in the top-level README for how to do this conversion).
27 | 7. Specifying `manual_workspace_bounds` allows you to define a rectangular prism in which the end-effector is forced to remain. It is a good idea to specify this, as it enables much safer robot behavior during autonomous data collection. The two numbers for `x`, `y`, and `z` correspond to the minimum and maximum allowed values for each dimension. See instructions in the top-level README for a convenience script that allows you to easily determine these boundary values.
28 |
29 | ## task_proposer.yaml
30 |
31 | There are multiple types of task proposers implemented in this codebase (see orchestrator/robot/task_proposer.py for the implementations). They include VLM task proposers using CogVLM or GPT-4V/o, in which the VLM looks at the current observation and chooses a task to command from a list of possible tasks; a cycling task proposer, which for a two-task setup (e.g., opening and closing a drawer) simply alternates between commanding the two tasks; and a human task proposer, which periodically queries you to enter a task via the command line.
32 |
33 | The VLM task proposer, in addition to considering environment affordances prescribed by the current image observation, also considers how many times each of the tasks has been attempted. If two tasks are viable to command according to the VLM and one has been attempted more than the other, the less attempted task will be commanded. To enable this, the `reuse_task_statistics` flag controls whether or not to load previous trajectories in the logging folder to compute the number of times each task has been attempted. If set to false, attempt counters for each of the tasks will be initialized to zero.
34 |
35 | `rand_selection_prob` specifies a probability with which to ignore the task proposed by the VLM and instead propose a random task. Setting this to a nonzero number can be useful in cases where the VLM is less accurate.
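
Concretely, the selection behavior described above amounts to something like the following sketch; this is an illustration, not the repo's implementation, and `viable_tasks`, `attempt_counts`, and the example counts are made up.

```python
import random

def select_task(viable_tasks, attempt_counts, all_tasks, rand_selection_prob=0.2):
    # With probability rand_selection_prob, ignore the VLM and pick randomly;
    # otherwise pick the least-attempted task among those the VLM deems viable.
    if random.random() < rand_selection_prob or not viable_tasks:
        return random.choice(all_tasks)
    return min(viable_tasks, key=lambda t: attempt_counts.get(t, 0))

# Example usage with the berkeley_robot_0 task list:
tasks = ["put the mushroom in the blue bowl",
         "remove the mushroom from the blue bowl and put it on the table"]
counts = {tasks[0]: 12, tasks[1]: 7}
print(select_task(viable_tasks=tasks, attempt_counts=counts, all_tasks=tasks))
```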
36 |
37 | ## gc_bc.yaml
38 |
39 | 1. Specify `checkpoint_path` to be the directory containing the policy checkpoint you trained or downloaded from huggingface (for the latter see instructions in `checkpoints` folder).
40 | 2. `rollout_timesteps` controls how many timesteps to roll out the goal-conditioned policy before a new image subgoal is synthesized.
41 | 3. Adding noise diversifies the collected robot data and facilitates exploration. `exploration` controls the parameters of the added noise.
42 | 4. `open_gripper_if_nothing_grasped`, if set to true, will force the gripper open (even if the policy is commanding it to be closed) if the encoders on the gripper read that nothing has been grasped. We leave this set to false in all our experiments.
43 | 5. `restrict_action_space` restricts the action space to 5 dimensions, assigning the pitch and yaw dimensions to known good values. In particularly challenging environments, this can help exploration if it is known that these two angle dimensions are not needed to complete the desired tasks. This variable is set to false in all of our experiments.
44 | 6. `agent_kwargs` controls the configuration of the trained policy. See the source code under the `model_training` for what these configurations control.
45 |
46 | ## subgoal_predictor.yaml
47 |
48 | This config file controls the image generation parameters for SuSIE, the InstructPix2Pix style diffusion model we use for subgoal generation. The defaults should be fine. The only parameters that need to be set are `susie_server_ip` and `susie_server_port`, which should be set appropriately depending on where you host the SuSIE server (see instructions in top-level README for hosting this server).
49 |
50 | ## reset_detector.yaml
51 |
52 | During autonomous data collection, it is possible that objects will fall outside the workspace boundaries. To prevent the user from having to monitor the robots 24/7, we include an automatic reset detection functionality, whereby the VLM is used to determine if objects are missing from the workspace and, if so, a Slack message is sent to a channel you create. Specify `slack_token` and `channel_id` appropriately for that channel. `which_vlm` controls which VLM is used for reset detection.
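
For reference, a reset notification of the kind described above can be sent with the official `slack_sdk` client roughly as follows. This is a hypothetical sketch (the repo's `reset_detector.py` is the actual implementation), and the config path is just an example.

```python
import yaml
from slack_sdk import WebClient

# Example config path; slack_token and channel_id come from reset_detector.yaml.
with open("data_collection/config/berkeley_robot_0/reset_detector.yaml") as f:
    cfg = yaml.safe_load(f)

def notify_reset_needed(robot_id: int, missing_objects: list[str]) -> None:
    client = WebClient(token=cfg["slack_token"])
    client.chat_postMessage(
        channel=cfg["channel_id"],
        text=f"Robot {robot_id} needs a reset; missing objects: {missing_objects}",
    )

# notify_reset_needed(robot_id=0, missing_objects=["mushroom"])
```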
53 |
54 | ## success_detector.yaml
55 |
56 | `which_vlm` controls which VLM you would like to use for determining task success at the end of every trajectory. We used CogVLM for our experiments.
57 |
58 | ## cogvlm_server.yaml
59 |
60 | Specify here the IP address and port of the machine you are hosting CogVLM on. CogVLM requires around 48 GB of memory for batched
61 | inference.
62 |
63 | ## task_definition.yaml
64 |
65 | When setting up your robot in a new environment, make sure to specify which objects are present in the scene and all of
66 | the language tasks you want to run during data collection. In addition to guiding the VLM for task selection, the information here is logged with every collected trajectory.
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/cogvlm_server.yaml:
--------------------------------------------------------------------------------
1 | # CogVLM server config parameters
2 | cogvlm_server_ip: "localhost"
3 | cogvlm_server_port: 6000
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | general_params: !include general_params.yaml
3 | task_proposer_params: !include task_proposer.yaml
4 | gc_policy_params: !include gc_bc.yaml
5 | subgoal_predictor_params: !include subgoal_predictor.yaml
6 | reset_detector_params: !include reset_detector.yaml
7 | success_detector_params: !include success_detector.yaml
8 | cogvlm_server_params: !include cogvlm_server.yaml
9 | task_definition_params: !include task_definition.yaml
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/gc_bc.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level goal conditioned policy
3 | policy_class: "gc_bc"
4 | checkpoint_path: "checkpoints/iterated_gcbc_round_1_policy"
5 | image_size: 256
6 | rollout_timesteps: 20
7 | exploration:
8 | make_traj_deterministic_prob: 0.2
9 | sampling_temperature: 0.2
10 | gripper_open_prob: 0.005
11 | gripper_close_prob: 0.01
12 | open_gripper_if_nothing_grasped: false
13 | restrict_action_space: false
14 | ACT_MEAN:
15 | - 1.9296819e-04
16 | - 1.3667766e-04
17 | - -1.4583133e-04
18 | - -1.8390431e-04
19 | - -3.0808983e-04
20 | - 2.7425270e-04
21 | - 5.9716219e-01
22 | ACT_STD:
23 | - 0.00912848
24 | - 0.0127196
25 | - 0.01229497
26 | - 0.02606696
27 | - 0.02875283
28 | - 0.07807977
29 | - 0.48710242
30 | agent_kwargs:
31 | policy_kwargs:
32 | fixed_std:
33 | - 1.0
34 | - 1.0
35 | - 1.0
36 | - 1.0
37 | - 1.0
38 | - 1.0
39 | - 0.1
40 | std_parameterization: "fixed"
41 | tanh_squash_distribution: false
42 | early_goal_concat: true
43 | shared_goal_encoder: true
44 | use_proprio: false
45 | network_kwargs:
46 | hidden_dims:
47 | - 256
48 | - 256
49 | - 256
50 | dropout_rate: 0.1
51 | encoder:
52 | type: "resnetv1-34-bridge"
53 | config:
54 | act: "swish"
55 | add_spatial_coordinates: false
56 | pooling_method: "avg"
57 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/general_params.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # IP address and port of the robot
3 | ip: "128.32.175.236"
4 | port: 5556
5 |
6 | # IP address of web viewer
7 | web_viewer_ip: "128.32.175.81"
8 | web_viewer_port: 5000
9 |
10 | # Robot ID
11 | robot_id: 0
12 |
13 | # General config parameters
14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper
15 | env_params:
16 | camera_topics:
17 | - name: "/blue/image_raw"
18 | flip: false
19 | override_workspace_boundaries:
20 | - - -20.0
21 | - -20.0
22 | - -20.0
23 | - -1.57
24 | - 0
25 | - - 20.0
26 | - 20.0
27 | - 20.0
28 | - 1.57
29 | - 0
30 | move_duration: 0.3
31 | video_save_path: "video_logs"
32 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam
33 | initial_eep:
34 | - 0.3
35 | - 0.0
36 | - 0.15
37 | - 0
38 | - 0
39 | - 0
40 | - 1
41 | manual_workspace_bounds:
42 | x:
43 | - 0.15603437
44 | - 0.42324517
45 | y:
46 | - -0.20489213
47 | - 0.28275232
48 | z:
49 | - 0.02985591
50 | - 0.16494011
51 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/lc_bc.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level language conditioned policy
3 | policy_class: "lc_bc"
4 | checkpoint_path: "checkpoints/lcbc/lcbc_bridge_v1"
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 20
8 | exploration:
9 | make_traj_deterministic_prob: 0.5
10 | sampling_temperature: 0.2
11 | gripper_open_prob: 0.005
12 | gripper_close_prob: 0.01
13 | open_gripper_if_nothing_grasped: false
14 | restrict_action_space: false
15 | dataset_kwargs:
16 | shuffle_buffer_size: 25000
17 | augment: true
18 | augment_next_obs_goal_differently: false
19 | augment_kwargs:
20 | random_resized_crop:
21 | scale:
22 | - 0.8
23 | - 1.0
24 | ratio:
25 | - 0.9
26 | - 1.1
27 | random_brightness:
28 | - 0.2
29 | random_contrast:
30 | - 0.8
31 | - 1.2
32 | random_saturation:
33 | - 0.8
34 | - 1.2
35 | random_hue:
36 | - 0.1
37 | augment_order:
38 | - "random_resized_crop"
39 | - "random_brightness"
40 | - "random_contrast"
41 | - "random_saturation"
42 | - "random_hue"
43 | goal_relabeling_strategy: "uniform"
44 | goal_relabeling_kwargs:
45 | reached_proportion: 0.0
46 | relabel_actions: true
47 | normalization_type: "normal"
48 | load_language: true
49 | skip_unlabeled: true
50 | ACT_MEAN:
51 | - 1.9296819e-04
52 | - 1.3667766e-04
53 | - -1.4583133e-04
54 | - -1.8390431e-04
55 | - -3.0808983e-04
56 | - 2.7425270e-04
57 | - 5.9716219e-01
58 | ACT_STD:
59 | - 0.00912848
60 | - 0.0127196
61 | - 0.01229497
62 | - 0.02606696
63 | - 0.02875283
64 | - 0.07807977
65 | - 0.48710242
66 | agent_kwargs:
67 | policy_kwargs:
68 | fixed_std:
69 | - 1.0
70 | - 1.0
71 | - 1.0
72 | - 1.0
73 | - 1.0
74 | - 1.0
75 | - 0.1
76 | state_dependent_std: false
77 | tanh_squash_distribution: false
78 | early_goal_concat: true
79 | shared_goal_encoder: true
80 | use_proprio: false
81 | network_kwargs:
82 | hidden_dims:
83 | - 256
84 | - 256
85 | - 256
86 | dropout_rate: 0.1
87 | encoder:
88 | type: "resnetv1-34-bridge-film"
89 | config:
90 | act: "swish"
91 | add_spatial_coordinates: true
92 | pooling_method: "avg"
93 | text_processor: "muse_embedding"
94 | text_processor_kwargs: {}
95 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/reset_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Reset detector config params
3 | slack_token:
4 | channel_id:
5 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none"
6 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/subgoal_predictor.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of high-level subgoal predictor policy
3 | diffusion_checkpoint: "kvablack/susie"
4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m"
5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax"
6 | diffusion_num_steps: 50
7 | prompt_w: 7.5
8 | context_w: 2.0
9 | image_size: 256
10 | max_subgoals: 5
11 | susie_server_ip: "localhost"
12 | susie_server_port: 7000
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/success_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of success detector
3 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none"
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/task_definition.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # task definition config params
3 | object_list:
4 | - "mushroom"
5 | - "blue bowl"
6 | task_list:
7 | - "remove the mushroom from the blue bowl and put it on the table"
8 | - "put the mushroom in the blue bowl"
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_0/task_proposer.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of task proposer
3 | reuse_task_statistics: false
4 | rand_selection_prob: 0.2
5 | zone_center: 0.3
6 | ucb_weight: 100 # high value like 100 ignores dist_to_zone
7 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none", "human", "cycling"
8 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/cogvlm_server.yaml:
--------------------------------------------------------------------------------
1 | # CogVLM server config parameters
2 | cogvlm_server_ip: "localhost"
3 | cogvlm_server_port: 6000
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | general_params: !include general_params.yaml
3 | task_proposer_params: !include task_proposer.yaml
4 | gc_policy_params: !include gc_bc.yaml
5 | subgoal_predictor_params: !include subgoal_predictor.yaml
6 | reset_detector_params: !include reset_detector.yaml
7 | success_detector_params: !include success_detector.yaml
8 | cogvlm_server_params: !include cogvlm_server.yaml
9 | task_definition_params: !include task_definition.yaml
10 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/gc_bc.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level goal conditioned policy
3 | policy_class: "gc_bc"
4 | checkpoint_path: "checkpoints/iterated_gcbc_round_1_policy"
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 20
8 | exploration:
9 | make_traj_deterministic_prob: 0.5
10 | sampling_temperature: 0.2
11 | gripper_open_prob: 0.005
12 | gripper_close_prob: 0.01
13 | open_gripper_if_nothing_grasped: false
14 | restrict_action_space: false
15 | dataset_kwargs:
16 | shuffle_buffer_size: 25000
17 | augment: true
18 | augment_next_obs_goal_differently: false
19 | augment_kwargs:
20 | random_resized_crop:
21 | scale:
22 | - 0.8
23 | - 1.0
24 | ratio:
25 | - 0.9
26 | - 1.1
27 | random_brightness:
28 | - 0.2
29 | random_contrast:
30 | - 0.8
31 | - 1.2
32 | random_saturation:
33 | - 0.8
34 | - 1.2
35 | random_hue:
36 | - 0.1
37 | augment_order:
38 | - "random_resized_crop"
39 | - "random_brightness"
40 | - "random_contrast"
41 | - "random_saturation"
42 | - "random_hue"
43 | goal_relabeling_strategy: "geometric"
44 | goal_relabeling_kwargs:
45 | reached_proportion: 0.0
46 | discount: 0.98
47 | relabel_actions: true
48 | dataset_contains_commanded_goals: false
49 | normalization_type: "normal"
50 | ACT_MEAN:
51 | - 1.9296819e-04
52 | - 1.3667766e-04
53 | - -1.4583133e-04
54 | - -1.8390431e-04
55 | - -3.0808983e-04
56 | - 2.7425270e-04
57 | - 5.9716219e-01
58 | ACT_STD:
59 | - 0.00912848
60 | - 0.0127196
61 | - 0.01229497
62 | - 0.02606696
63 | - 0.02875283
64 | - 0.07807977
65 | - 0.48710242
66 | agent_kwargs:
67 | policy_kwargs:
68 | fixed_std:
69 | - 1.0
70 | - 1.0
71 | - 1.0
72 | - 1.0
73 | - 1.0
74 | - 1.0
75 | - 0.1
76 | std_parameterization: "fixed"
77 | tanh_squash_distribution: false
78 | early_goal_concat: true
79 | shared_goal_encoder: true
80 | use_proprio: false
81 | network_kwargs:
82 | hidden_dims:
83 | - 256
84 | - 256
85 | - 256
86 | dropout_rate: 0.1
87 | encoder:
88 | type: "resnetv1-34-bridge"
89 | config:
90 | act: "swish"
91 | add_spatial_coordinates: false
92 | pooling_method: "avg"
93 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/general_params.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # IP address and port of the robot
3 | ip: "128.32.175.102"
4 | port: 5556
5 |
6 | # IP address of web viewer
7 | web_viewer_ip: "128.32.175.81"
8 | web_viewer_port: 5000
9 |
10 | # Robot ID
11 | robot_id: 1
12 |
13 | # General config parameters
14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper
15 | env_params:
16 | camera_topics:
17 | - name: "/blue/image_raw"
18 | flip: false
19 | override_workspace_boundaries:
20 | - - -20.0
21 | - -20.0
22 | - -20.0
23 | - -1.57
24 | - 0
25 | - - 20.0
26 | - 20.0
27 | - 20.0
28 | - 1.57
29 | - 0
30 | move_duration: 0.3
31 | video_save_path: "video_logs"
32 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam
33 | initial_eep:
34 | - 0.3
35 | - 0.0
36 | - 0.15
37 | - 0
38 | - 0
39 | - 0
40 | - 1
41 | # manual_workspace_bounds: # minsky table height 27
42 | # x:
43 | # - 0.17827454
44 | # - 0.42494287
45 | # y:
46 | # - -0.22023482
47 | # - 0.18838036
48 | # z:
49 | # - 0.02200321
50 | # - 0.23297783
51 | manual_workspace_bounds: # minsky table height 27 with barrier (left side closest to table)
52 | x:
53 | - 0.17376936
54 | - 0.36731001
55 | y:
56 | - -0.15287904
57 | - 0.20850995
58 | z:
59 | - 0.01916022
60 | - 0.2381686
61 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/lc_bc.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level language conditioned policy
3 | policy_class: "lc_bc"
4 | checkpoint_path: "checkpoints/lcbc/lcbc_bridge_v1"
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 25
8 | exploration:
9 | make_traj_deterministic_prob: 1.0
10 | sampling_temperature: 0.2
11 | gripper_open_prob: 0.005
12 | gripper_close_prob: 0.01
13 | open_gripper_if_nothing_grasped: false
14 | restrict_action_space: false
15 | dataset_kwargs:
16 | shuffle_buffer_size: 25000
17 | augment: true
18 | augment_next_obs_goal_differently: false
19 | augment_kwargs:
20 | random_resized_crop:
21 | scale:
22 | - 0.8
23 | - 1.0
24 | ratio:
25 | - 0.9
26 | - 1.1
27 | random_brightness:
28 | - 0.2
29 | random_contrast:
30 | - 0.8
31 | - 1.2
32 | random_saturation:
33 | - 0.8
34 | - 1.2
35 | random_hue:
36 | - 0.1
37 | augment_order:
38 | - "random_resized_crop"
39 | - "random_brightness"
40 | - "random_contrast"
41 | - "random_saturation"
42 | - "random_hue"
43 | goal_relabeling_strategy: "uniform"
44 | goal_relabeling_kwargs:
45 | reached_proportion: 0.0
46 | relabel_actions: true
47 | normalization_type: "normal"
48 | load_language: true
49 | skip_unlabeled: true
50 | ACT_MEAN:
51 | - 1.9296819e-04
52 | - 1.3667766e-04
53 | - -1.4583133e-04
54 | - -1.8390431e-04
55 | - -3.0808983e-04
56 | - 2.7425270e-04
57 | - 5.9716219e-01
58 | ACT_STD:
59 | - 0.00912848
60 | - 0.0127196
61 | - 0.01229497
62 | - 0.02606696
63 | - 0.02875283
64 | - 0.07807977
65 | - 0.48710242
66 | agent_kwargs:
67 | policy_kwargs:
68 | fixed_std:
69 | - 1.0
70 | - 1.0
71 | - 1.0
72 | - 1.0
73 | - 1.0
74 | - 1.0
75 | - 0.1
76 | state_dependent_std: false
77 | tanh_squash_distribution: false
78 | early_goal_concat: true
79 | shared_goal_encoder: true
80 | use_proprio: false
81 | network_kwargs:
82 | hidden_dims:
83 | - 256
84 | - 256
85 | - 256
86 | dropout_rate: 0.1
87 | encoder:
88 | type: "resnetv1-34-bridge-film"
89 | config:
90 | act: "swish"
91 | add_spatial_coordinates: true
92 | pooling_method: "avg"
93 | text_processor: "muse_embedding"
94 | text_processor_kwargs: {}
95 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/reset_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Reset detector config params
3 | slack_token:
4 | channel_id: C069AT564KC
5 | which_vlm: "none" # choices: "cogvlm", "gpt4v", "none"
6 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/subgoal_predictor.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of high-level subgoal predictor policy
3 | diffusion_checkpoint: "kvablack/susie"
4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m"
5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax"
6 | diffusion_num_steps: 50
7 | prompt_w: 7.5
8 | context_w: 5.0
9 | image_size: 256
10 | max_subgoals: 5
11 | susie_server_ip: "localhost"
12 | susie_server_port: 7000
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/success_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of success detector
3 | which_vlm: "none" # choices: "cogvlm", "gpt4v", "none"
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/task_definition.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # task definition config params
3 | object_list:
4 | - "green object"
5 | - "wooden bowl"
6 | task_list:
7 | - "remove the green block from inside the brown bowl and put it on the table"
8 | - "put the green block in the brown bowl"
9 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_1/task_proposer.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of task proposer
3 | reuse_task_statistics: false
4 | rand_selection_prob: 0.2
5 | zone_center: 0.3
6 | ucb_weight: 100 # high value like 100 ignores dist_to_zone
7 | which_vlm: "human" # choices: "cogvlm", "gpt4v", "none", "human"
8 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/cogvlm_server.yaml:
--------------------------------------------------------------------------------
1 | # CogVLM server config parameters
2 | cogvlm_server_ip: "localhost"
3 | cogvlm_server_port: 6000
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | general_params: !include general_params.yaml
3 | task_proposer_params: !include task_proposer.yaml
4 | gc_policy_params: !include gc_bc.yaml
5 | subgoal_predictor_params: !include subgoal_predictor.yaml
6 | reset_detector_params: !include reset_detector.yaml
7 | success_detector_params: !include success_detector.yaml
8 | cogvlm_server_params: !include cogvlm_server.yaml
9 | task_definition_params: !include task_definition.yaml
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/gc_bc.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level goal conditioned policy
3 | policy_class: "gc_bc"
4 | checkpoint_path: "checkpoints/iterated-gcbc/baseline/checkpoint_75000"
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 20
8 | exploration:
9 | make_traj_deterministic_prob: 0.5
10 | sampling_temperature: 0.2
11 | gripper_open_prob: 0.005
12 | gripper_close_prob: 0.01
13 | open_gripper_if_nothing_grasped: false
14 | restrict_action_space: false
15 | dataset_kwargs:
16 | shuffle_buffer_size: 25000
17 | augment: true
18 | augment_next_obs_goal_differently: false
19 | augment_kwargs:
20 | random_resized_crop:
21 | scale:
22 | - 0.8
23 | - 1.0
24 | ratio:
25 | - 0.9
26 | - 1.1
27 | random_brightness:
28 | - 0.2
29 | random_contrast:
30 | - 0.8
31 | - 1.2
32 | random_saturation:
33 | - 0.8
34 | - 1.2
35 | random_hue:
36 | - 0.1
37 | augment_order:
38 | - "random_resized_crop"
39 | - "random_brightness"
40 | - "random_contrast"
41 | - "random_saturation"
42 | - "random_hue"
43 | goal_relabeling_strategy: "geometric"
44 | goal_relabeling_kwargs:
45 | reached_proportion: 0.0
46 | discount: 0.98
47 | relabel_actions: true
48 | dataset_contains_commanded_goals: false
49 | normalization_type: "normal"
50 | ACT_MEAN:
51 | - 1.9296819e-04
52 | - 1.3667766e-04
53 | - -1.4583133e-04
54 | - -1.8390431e-04
55 | - -3.0808983e-04
56 | - 2.7425270e-04
57 | - 5.9716219e-01
58 | ACT_STD:
59 | - 0.00912848
60 | - 0.0127196
61 | - 0.01229497
62 | - 0.02606696
63 | - 0.02875283
64 | - 0.07807977
65 | - 0.48710242
66 | agent_kwargs:
67 | policy_kwargs:
68 | fixed_std:
69 | - 1.0
70 | - 1.0
71 | - 1.0
72 | - 1.0
73 | - 1.0
74 | - 1.0
75 | - 0.1
76 | std_parameterization: "fixed"
77 | tanh_squash_distribution: false
78 | early_goal_concat: true
79 | shared_goal_encoder: true
80 | use_proprio: false
81 | network_kwargs:
82 | hidden_dims:
83 | - 256
84 | - 256
85 | - 256
86 | dropout_rate: 0.1
87 | encoder:
88 | type: "resnetv1-34-bridge"
89 | config:
90 | act: "swish"
91 | add_spatial_coordinates: false
92 | pooling_method: "avg"
93 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/general_params.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # IP address and port of the robot
3 | ip: "128.32.175.217"
4 | port: 5556
5 |
6 | # IP address of web viewer
7 | web_viewer_ip: "128.32.175.81"
8 | web_viewer_port: 5000
9 |
10 | # Robot ID
11 | robot_id: 2
12 |
13 | # General config parameters
14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper
15 | env_params:
16 | camera_topics:
17 | - name: "/blue/image_raw"
18 | flip: false
19 | override_workspace_boundaries:
20 | - - -20.0
21 | - -20.0
22 | - -20.0
23 | - -1.57
24 | - 0
25 | - - 20.0
26 | - 20.0
27 | - 20.0
28 | - 1.57
29 | - 0
30 | move_duration: 0.3
31 | video_save_path: "data_collection_logs/berkeley_robot_2/mushroom_spoon_0530"
32 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam
33 | initial_eep:
34 | - 0.3
35 | - 0.0
36 | - 0.15
37 | - 0
38 | - 0
39 | - 0
40 | - 1
41 | manual_workspace_bounds:
42 | x:
43 | - 0.21887143
44 | - 0.41317162
45 | y:
46 | - -0.18368285
47 | - 0.18275802
48 | z:
49 | - 0.03712492
50 | - 0.18284303
51 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/reset_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Reset detector config params
3 | slack_token:
4 | channel_id: C069AT564KC
5 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none"
6 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/subgoal_predictor.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of high-level subgoal predictor policy
3 | diffusion_checkpoint: "kvablack/susie"
4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m"
5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax"
6 | diffusion_num_steps: 50
7 | prompt_w: 7.5
8 | context_w: 2.0
9 | image_size: 256
10 | max_subgoals: 5
11 | susie_server_ip: "localhost"
12 | susie_server_port: 7000
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/success_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of success detector
3 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none"
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/task_definition.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # task definition config params
3 | object_list:
4 | - "silver pot"
5 | - "mushroom"
6 | - "green spoon"
7 | task_list:
8 | - "put mushroom into the silver pot"
9 | - "remove mushroom from inside the silver pot and place it on the table"
10 | - "put the green spoon into the silver pot"
11 | - "remove the green spoon from inside the silver pot and place it on the table"
12 | - "move green spoon to the right side of the table"
13 | - "move green spoon to the left side of the table"
14 | - "move mushroom to the right side of the table"
15 | - "move mushroom to the left side of the table"
16 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_2/task_proposer.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of task proposer
3 | reuse_task_statistics: false
4 | rand_selection_prob: 1.0
5 | zone_center: 0.3
6 | ucb_weight: 5.0
7 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none", "human"
8 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/cogvlm_server.yaml:
--------------------------------------------------------------------------------
1 | # CogVLM server config parameters
2 | cogvlm_server_ip: "localhost"
3 | cogvlm_server_port: 6000
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | general_params: !include general_params.yaml
3 | task_proposer_params: !include task_proposer.yaml
4 | gc_policy_params: !include gc_bc.yaml
5 | subgoal_predictor_params: !include subgoal_predictor.yaml
6 | reset_detector_params: !include reset_detector.yaml
7 | success_detector_params: !include success_detector.yaml
8 | cogvlm_server_params: !include cogvlm_server.yaml
9 | task_definition_params: !include task_definition.yaml
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/gc_bc.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level goal conditioned policy
3 | policy_class: "gc_bc"
4 | checkpoint_path: "checkpoints/iterated_gcbc_round_1_policy"
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 20
8 | exploration:
9 | make_traj_deterministic_prob: 0.5
10 | sampling_temperature: 0.2
11 | gripper_open_prob: 0.005
12 | gripper_close_prob: 0.01
13 | open_gripper_if_nothing_grasped: false
14 | restrict_action_space: false
15 | dataset_kwargs:
16 | shuffle_buffer_size: 25000
17 | augment: true
18 | augment_next_obs_goal_differently: false
19 | augment_kwargs:
20 | random_resized_crop:
21 | scale:
22 | - 0.8
23 | - 1.0
24 | ratio:
25 | - 0.9
26 | - 1.1
27 | random_brightness:
28 | - 0.2
29 | random_contrast:
30 | - 0.8
31 | - 1.2
32 | random_saturation:
33 | - 0.8
34 | - 1.2
35 | random_hue:
36 | - 0.1
37 | augment_order:
38 | - "random_resized_crop"
39 | - "random_brightness"
40 | - "random_contrast"
41 | - "random_saturation"
42 | - "random_hue"
43 | goal_relabeling_strategy: "geometric"
44 | goal_relabeling_kwargs:
45 | reached_proportion: 0.0
46 | discount: 0.98
47 | relabel_actions: true
48 | dataset_contains_commanded_goals: false
49 | normalization_type: "normal"
50 | ACT_MEAN:
51 | - 1.9296819e-04
52 | - 1.3667766e-04
53 | - -1.4583133e-04
54 | - -1.8390431e-04
55 | - -3.0808983e-04
56 | - 2.7425270e-04
57 | - 5.9716219e-01
58 | ACT_STD:
59 | - 0.00912848
60 | - 0.0127196
61 | - 0.01229497
62 | - 0.02606696
63 | - 0.02875283
64 | - 0.07807977
65 | - 0.48710242
66 | agent_kwargs:
67 | policy_kwargs:
68 | fixed_std:
69 | - 1.0
70 | - 1.0
71 | - 1.0
72 | - 1.0
73 | - 1.0
74 | - 1.0
75 | - 0.1
76 | std_parameterization: "fixed"
77 | tanh_squash_distribution: false
78 | early_goal_concat: true
79 | shared_goal_encoder: true
80 | use_proprio: false
81 | network_kwargs:
82 | hidden_dims:
83 | - 256
84 | - 256
85 | - 256
86 | dropout_rate: 0.1
87 | encoder:
88 | type: "resnetv1-34-bridge"
89 | config:
90 | act: "swish"
91 | add_spatial_coordinates: false
92 | pooling_method: "avg"
93 |
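ACT_MEAN and ACT_STD above are per-dimension statistics of the 7-DoF action (xyz translation delta, rpy rotation delta, gripper). A hedged sketch of what `normalization_type: "normal"` implies; the helper names are illustrative, not functions from the repo:

import numpy as np

# Values copied from the config above.
ACT_MEAN = np.array([1.9296819e-04, 1.3667766e-04, -1.4583133e-04,
                     -1.8390431e-04, -3.0808983e-04, 2.7425270e-04, 5.9716219e-01])
ACT_STD = np.array([0.00912848, 0.0127196, 0.01229497, 0.02606696,
                    0.02875283, 0.07807977, 0.48710242])

def normalize_action(action: np.ndarray) -> np.ndarray:
    """Standardize a raw 7-DoF action with the dataset statistics."""
    return (action - ACT_MEAN) / ACT_STD

def unnormalize_action(action: np.ndarray) -> np.ndarray:
    """Map a policy output back to raw robot units before execution."""
    return action * ACT_STD + ACT_MEAN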
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/general_params.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # IP address and port of the robot
3 | ip: "128.32.175.186"
4 | port: 5556
5 |
6 | # IP address of web viewer
7 | web_viewer_ip: "128.32.175.81"
8 | web_viewer_port: 5000
9 |
10 | # Robot ID
11 | robot_id: 3
12 |
13 | # General config parameters
14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper
15 | env_params:
16 | camera_topics:
17 | - name: "/blue/image_raw"
18 | flip: false
19 | override_workspace_boundaries:
20 | - - -20.0
21 | - -20.0
22 | - -20.0
23 | - -1.57
24 | - 0
25 | - - 20.0
26 | - 20.0
27 | - 20.0
28 | - 1.57
29 | - 0
30 | move_duration: 0.3
31 | video_save_path: "data_collection_logs/berkeley_robot_3/gc_bc_eval"
32 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam
33 | initial_eep:
34 | - 0.3
35 | - 0.0
36 | - 0.15
37 | - 0
38 | - 0
39 | - 0
40 | - 1
41 | manual_workspace_bounds:
42 | x:
43 | - 0.25376111
44 | - 0.44158756
45 | y:
46 | - -0.17615086
47 | - 0.18689521
48 | z:
49 | - 0.0187575
50 | - 0.13504307
51 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/lc_bc.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level language conditioned policy
3 | policy_class: "lc_bc"
4 | checkpoint_path: "checkpoints/lcbc/lcbc_bridge_v1"
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 25
8 | exploration:
9 | make_traj_deterministic_prob: 1.0
10 | sampling_temperature: 0.2
11 | gripper_open_prob: 0.005
12 | gripper_close_prob: 0.01
13 | open_gripper_if_nothing_grasped: false
14 | restrict_action_space: false
15 | dataset_kwargs:
16 | shuffle_buffer_size: 25000
17 | augment: true
18 | augment_next_obs_goal_differently: false
19 | augment_kwargs:
20 | random_resized_crop:
21 | scale:
22 | - 0.8
23 | - 1.0
24 | ratio:
25 | - 0.9
26 | - 1.1
27 | random_brightness:
28 | - 0.2
29 | random_contrast:
30 | - 0.8
31 | - 1.2
32 | random_saturation:
33 | - 0.8
34 | - 1.2
35 | random_hue:
36 | - 0.1
37 | augment_order:
38 | - "random_resized_crop"
39 | - "random_brightness"
40 | - "random_contrast"
41 | - "random_saturation"
42 | - "random_hue"
43 | goal_relabeling_strategy: "uniform"
44 | goal_relabeling_kwargs:
45 | reached_proportion: 0.0
46 | relabel_actions: true
47 | normalization_type: "normal"
48 | load_language: true
49 | skip_unlabeled: true
50 | ACT_MEAN:
51 | - 1.9296819e-04
52 | - 1.3667766e-04
53 | - -1.4583133e-04
54 | - -1.8390431e-04
55 | - -3.0808983e-04
56 | - 2.7425270e-04
57 | - 5.9716219e-01
58 | ACT_STD:
59 | - 0.00912848
60 | - 0.0127196
61 | - 0.01229497
62 | - 0.02606696
63 | - 0.02875283
64 | - 0.07807977
65 | - 0.48710242
66 | agent_kwargs:
67 | policy_kwargs:
68 | fixed_std:
69 | - 1.0
70 | - 1.0
71 | - 1.0
72 | - 1.0
73 | - 1.0
74 | - 1.0
75 | - 0.1
76 | state_dependent_std: false
77 | tanh_squash_distribution: false
78 | early_goal_concat: true
79 | shared_goal_encoder: true
80 | use_proprio: false
81 | network_kwargs:
82 | hidden_dims:
83 | - 256
84 | - 256
85 | - 256
86 | dropout_rate: 0.1
87 | encoder:
88 | type: "resnetv1-34-bridge-film"
89 | config:
90 | act: "swish"
91 | add_spatial_coordinates: true
92 | pooling_method: "avg"
93 | text_processor: "muse_embedding"
94 | text_processor_kwargs: {}
95 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/reset_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Reset detector config params
3 | slack_token:
4 | channel_id: C069AT564KC
5 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none"
6 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/subgoal_predictor.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of high-level subgoal predictor policy
3 | diffusion_checkpoint: "kvablack/susie"
4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m"
5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax"
6 | diffusion_num_steps: 50
7 | prompt_w: 7.5
8 | context_w: 2.0
9 | image_size: 256
10 | max_subgoals: 5
11 | susie_server_ip: "localhost"
12 | susie_server_port: 7000
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/success_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of success detector
3 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none"
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/task_definition.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # task definition config params
3 | object_list:
4 | - "carrot"
5 | - "lemon"
6 | - "purple eggplant"
7 | - "blue tray"
8 | task_list:
9 | - "remove the lemon from the blue tray and put it on the table"
10 | - "move the carrot to the right side"
11 | - "move the carrot to the left side"
12 | - "put the purple eggplant on the blue tray"
13 | - "put the lemon on the blue tray"
14 | - "remove the carrot from the blue tray and put it on the table"
15 | - "remove the purple eggplant from the blue tray and put it on the table"
16 | - "put the carrot on the blue tray"
17 |
18 |
19 | # put the pink spoon on the blue tray
20 | # remove the pink spoon from the blue tray and put it on the table
21 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_3/task_proposer.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of task proposer
3 | reuse_task_statistics: false
4 | rand_selection_prob: 0.333
5 | zone_center: 0.3
6 | ucb_weight: 100 # high value like 100 ignores dist_to_zone
7 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none", "human"
8 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/cogvlm_server.yaml:
--------------------------------------------------------------------------------
1 | # CogVLM server config parameters
2 | cogvlm_server_ip: "localhost"
3 | cogvlm_server_port: 6000
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | general_params: !include general_params.yaml
3 | task_proposer_params: !include task_proposer.yaml
4 | gc_policy_params: !include gc_bc.yaml
5 | subgoal_predictor_params: !include subgoal_predictor.yaml
6 | reset_detector_params: !include reset_detector.yaml
7 | success_detector_params: !include success_detector.yaml
8 | cogvlm_server_params: !include cogvlm_server.yaml
9 | task_definition_params: !include task_definition.yaml
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/diffusion_policy.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level goal conditioned policy
3 | policy_class: "gc_ddpm_bc"
4 | checkpoint_path: "checkpoints/red_object_metal_pot/dgcbc_95_4p5_0p5_successes_1800000" # iterated_dgcbc_80_20_230000
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 25
8 | exploration:
9 | make_traj_deterministic_prob: 1.0
10 | sampling_temperature: 2.0
11 | gripper_open_prob: 0.01
12 | gripper_close_prob: 0.02
13 | exploration_noise: 0.000
14 | dataset_kwargs:
15 | shuffle_buffer_size: 25000
16 | augment: true
17 | augment_next_obs_goal_differently: false
18 | augment_kwargs:
19 | random_resized_crop:
20 | scale:
21 | - 0.8
22 | - 1.0
23 | ratio:
24 | - 0.9
25 | - 1.1
26 | random_brightness:
27 | - 0.2
28 | random_contrast:
29 | - 0.8
30 | - 1.2
31 | random_saturation:
32 | - 0.8
33 | - 1.2
34 | random_hue:
35 | - 0.1
36 | augment_order:
37 | - "random_resized_crop"
38 | - "random_brightness"
39 | - "random_contrast"
40 | - "random_saturation"
41 | - "random_hue"
42 | goal_relabeling_strategy: "geometric"
43 | goal_relabeling_kwargs:
44 | reached_proportion: 0.0
45 | discount: 0.98
46 | relabel_actions: true
47 | act_pred_horizon: 4
48 | obs_horizon: 1
49 | ACT_MEAN:
50 | - 1.9296819e-04
51 | - 1.3667766e-04
52 | - -1.4583133e-04
53 | - -1.8390431e-04
54 | - -3.0808983e-04
55 | - 2.7425270e-04
56 | - 5.9716219e-01
57 | ACT_STD:
58 | - 0.00912848
59 | - 0.0127196
60 | - 0.01229497
61 | - 0.02606696
62 | - 0.02875283
63 | - 0.07807977
64 | - 0.48710242
65 | PROPRIO_MEAN:
66 | - 0.29730073
67 | - 0.02986212
68 | - 0.06420159
69 | - -0.00201155
70 | - -0.07586625
71 | - 0.159071
72 | - 0.75686556
73 | PROPRIO_STD:
74 | - 0.05918062
75 | - 0.09581848
76 | - 0.05275392
77 | - 0.13922517
78 | - 0.16974117
79 | - 0.6555491
80 | - 0.3397966
81 | agent_kwargs:
82 | score_network_kwargs:
83 | time_dim: 32
84 | num_blocks: 3
85 | dropout_rate: 0.1
86 | hidden_dim: 256
87 | use_layer_norm: true
88 | early_goal_concat: true
89 | shared_goal_encoder: true
90 | use_proprio: false
91 | beta_schedule: "cosine"
92 | diffusion_steps: 20
93 | action_samples: 1
94 | repeat_last_step: 0
95 | learning_rate: 3.0e-4
96 | warmup_steps: 2000
97 | actor_decay_steps: 2000000
98 | encoder:
99 | type: "resnetv1-34-bridge"
100 | config:
101 | act: "swish"
102 | add_spatial_coordinates: true
103 | pooling_method: "avg"
104 |
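Relative to the gc_bc configs, this diffusion-policy (gc_ddpm_bc) config adds `act_pred_horizon`, `obs_horizon`, and proprio statistics. A hedged reading of the two horizon fields; the authoritative rollout behavior lives in orchestrator/robot/gc_policy.py and the jaxrl_m agent, neither of which is reproduced in this section:

import numpy as np

obs_horizon = 1        # observation frames stacked per policy query
act_pred_horizon = 4   # actions denoised per query (action chunking)
act_dim = 7            # xyz delta, rpy delta, gripper

# Shapes implied by the config (illustrative only):
observation_stack = np.zeros((obs_horizon, 256, 256, 3), dtype=np.uint8)
action_chunk = np.zeros((act_pred_horizon, act_dim), dtype=np.float32)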
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/gc_bc.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level goal conditioned policy
3 | policy_class: "gc_bc"
4 | checkpoint_path: "checkpoints/gcbc_just_bridge_vision_backbone_1_200000"
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 20
8 | exploration:
9 | make_traj_deterministic_prob: 0.5
10 | sampling_temperature: 0.2
11 | gripper_open_prob: 0.005
12 | gripper_close_prob: 0.01
13 | open_gripper_if_nothing_grasped: false
14 | restrict_action_space: false
15 | dataset_kwargs:
16 | shuffle_buffer_size: 25000
17 | augment: true
18 | augment_next_obs_goal_differently: false
19 | augment_kwargs:
20 | random_resized_crop:
21 | scale:
22 | - 0.8
23 | - 1.0
24 | ratio:
25 | - 0.9
26 | - 1.1
27 | random_brightness:
28 | - 0.2
29 | random_contrast:
30 | - 0.8
31 | - 1.2
32 | random_saturation:
33 | - 0.8
34 | - 1.2
35 | random_hue:
36 | - 0.1
37 | augment_order:
38 | - "random_resized_crop"
39 | - "random_brightness"
40 | - "random_contrast"
41 | - "random_saturation"
42 | - "random_hue"
43 | goal_relabeling_strategy: "geometric"
44 | goal_relabeling_kwargs:
45 | reached_proportion: 0.0
46 | discount: 0.98
47 | relabel_actions: true
48 | dataset_contains_commanded_goals: false
49 | normalization_type: "normal"
50 | ACT_MEAN:
51 | - 1.9296819e-04
52 | - 1.3667766e-04
53 | - -1.4583133e-04
54 | - -1.8390431e-04
55 | - -3.0808983e-04
56 | - 2.7425270e-04
57 | - 5.9716219e-01
58 | ACT_STD:
59 | - 0.00912848
60 | - 0.0127196
61 | - 0.01229497
62 | - 0.02606696
63 | - 0.02875283
64 | - 0.07807977
65 | - 0.48710242
66 | agent_kwargs:
67 | policy_kwargs:
68 | fixed_std:
69 | - 1.0
70 | - 1.0
71 | - 1.0
72 | - 1.0
73 | - 1.0
74 | - 1.0
75 | - 0.1
76 | std_parameterization: "fixed"
77 | tanh_squash_distribution: false
78 | early_goal_concat: true
79 | shared_goal_encoder: true
80 | use_proprio: false
81 | network_kwargs:
82 | hidden_dims:
83 | - 256
84 | - 256
85 | - 256
86 | dropout_rate: 0.1
87 | encoder:
88 | type: "resnetv1-34-bridge"
89 | config:
90 | act: "swish"
91 | add_spatial_coordinates: false
92 | pooling_method: "avg"
93 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/general_params.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # IP address and port of the robot
3 | ip: "128.32.175.227"
4 | port: 5556
5 |
6 | # IP address of web viewer
7 | web_viewer_ip: "128.32.175.81"
8 | web_viewer_port: 5000
9 |
10 | # Robot ID
11 | robot_id: 4
12 |
13 | # General config parameters
14 | sticky_gripper_num_steps: 1 # I'm thinking that for online improvement, we should turn off sticky gripper
15 | joints_reboot_interval: 1
16 | env_params:
17 | camera_topics:
18 | - name: "/blue/image_raw"
19 | flip: false
20 | override_workspace_boundaries:
21 | - - -20.0
22 | - -20.0
23 | - -20.0
24 | - -1.57
25 | - 0
26 | - - 20.0
27 | - 20.0
28 | - 20.0
29 | - 1.57
30 | - 0
31 | move_duration: 0.3
32 | video_save_path: "data_collection_logs/berkeley_robot_4/drawer_iql"
33 | shoulder_camera_image_size: 256 # size of image returned by shoulder cam
34 | initial_eep:
35 | - 0.3
36 | - 0.0
37 | - 0.15
38 | - 0
39 | - 0
40 | - 0
41 | - 1
42 | manual_workspace_bounds:
43 | x:
44 | - 0.22816383
45 | - 0.40895109
46 | y:
47 | - -0.11804535
48 | - 0.04906207
49 | z:
50 | - 0.04494154
51 | - 0.13566369
52 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/rcsl.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of low-level goal conditioned policy
3 | policy_class: "gc_ddpm_bc"
4 | checkpoint_path: "checkpoints/valentines_day_evals/rcsl_120k"
5 | mini_dataset_path: "mini_dataset"
6 | image_size: 256
7 | rollout_timesteps: 30
8 | exploration_noise: 0.000
9 | dataset_kwargs:
10 | shuffle_buffer_size: 25000
11 | augment: true
12 | augment_next_obs_goal_differently: false
13 | augment_kwargs:
14 | random_resized_crop:
15 | scale:
16 | - 0.8
17 | - 1.0
18 | ratio:
19 | - 0.9
20 | - 1.1
21 | random_brightness:
22 | - 0.2
23 | random_contrast:
24 | - 0.8
25 | - 1.2
26 | random_saturation:
27 | - 0.8
28 | - 1.2
29 | random_hue:
30 | - 0.1
31 | augment_order:
32 | - "random_resized_crop"
33 | - "random_brightness"
34 | - "random_contrast"
35 | - "random_saturation"
36 | - "random_hue"
37 | goal_relabeling_strategy: "geometric"
38 | goal_relabeling_kwargs:
39 | reached_proportion: 0.0
40 | discount: 0.98
41 | relabel_actions: true
42 | act_pred_horizon: 4
43 | obs_horizon: 1
44 | normalization_type: "tanh_normal"
45 | dataset_contains_commanded_goals: false
46 | ACT_MEAN:
47 | - 1.9296819e-04
48 | - 1.3667766e-04
49 | - -1.4583133e-04
50 | - -1.8390431e-04
51 | - -3.0808983e-04
52 | - 2.7425270e-04
53 | - 5.9716219e-01
54 | ACT_STD:
55 | - 0.00912848
56 | - 0.0127196
57 | - 0.01229497
58 | - 0.02606696
59 | - 0.02875283
60 | - 0.07807977
61 | - 0.48710242
62 | agent_kwargs:
63 | score_network_kwargs:
64 | time_dim: 32
65 | num_blocks: 3
66 | dropout_rate: 0.1
67 | hidden_dim: 256
68 | use_layer_norm: true
69 | early_goal_concat: true
70 | shared_goal_encoder: true
71 | use_proprio: false
72 | beta_schedule: "cosine"
73 | diffusion_steps: 20
74 | action_samples: 1
75 | repeat_last_step: 0
76 | learning_rate: 3.0e-4
77 | warmup_steps: 2000
78 | actor_decay_steps: 2000000
79 | encoder:
80 | type: "resnetv1-34-bridge"
81 | config:
82 | act: "swish"
83 | add_spatial_coordinates: true
84 | pooling_method: "avg"
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/reset_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Reset detector config params
3 | slack_token:
4 | channel_id: C069AT564KC
5 | which_vlm: "none" # choices: "cogvlm", "gpt4v", "none"
6 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/subgoal_predictor.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of high-level subgoal predictor policy
3 | diffusion_checkpoint: "kvablack/susie"
4 | diffusion_wandb: "kvablack/dlimp-diffusion/9n9ped8m"
5 | diffusion_pretrained_path: "runwayml/stable-diffusion-v1-5:flax"
6 | diffusion_num_steps: 50
7 | prompt_w: 7.5
8 | context_w: 2.0
9 | image_size: 256
10 | max_subgoals: 5
11 | susie_server_ip: "localhost"
12 | susie_server_port: 7000
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/success_detector.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of success detector
3 | which_vlm: "cogvlm" # choices: "cogvlm", "gpt4v", "none"
4 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/task_definition.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # task definition config params
3 | object_list:
4 | - "drawer"
5 | task_list:
6 | - "open the drawer"
7 | - "close the drawer"
8 |
--------------------------------------------------------------------------------
/data_collection/config/berkeley_robot_4/task_proposer.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Config parameters of task proposer
3 | reuse_task_statistics: true
4 | rand_selection_prob: 0.5
5 | zone_center: 0.3
6 | ucb_weight: 100 # high value like 100 ignores dist_to_zone
7 | which_vlm: "cycling" # choices: "cogvlm", "gpt4v", "none", "human"
8 |
--------------------------------------------------------------------------------
/data_collection/orchestrator/cogvlm_server/batched.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from PIL import Image
4 | from transformers import AutoModelForCausalLM, LlamaTokenizer
5 | import json
6 | import numpy as np
7 | from tqdm import tqdm
8 | from flask import Flask
9 | from flask import request, jsonify
10 | import copy
11 | import io
12 |
13 | MODEL_PATH = "THUDM/cogvlm-chat-hf"
14 | TOKENIZER_PATH = "lmsys/vicuna-7b-v1.5"
15 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16 |
17 | tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
18 | torch_type = torch.bfloat16
19 |
20 | print("Loading model to GPU mem")
21 | model = (
22 | AutoModelForCausalLM.from_pretrained(
23 | MODEL_PATH, torch_dtype=torch_type, load_in_4bit=False, trust_remote_code=True
24 | )
25 | .to(DEVICE)
26 | .eval()
27 | )
28 |
29 | app = Flask(__name__)
30 |
31 |
32 | @app.route("/query", methods=["POST"])
33 | def query_model():
34 | global tokenizer, model, torch_type, DEVICE
35 |
36 | image_files = request.files.getlist("image")
37 | prompts = request.form.getlist("prompt")
38 |
39 | if not image_files or not prompts or len(image_files) != len(prompts):
40 | return jsonify({"error": "Missing image(s) or prompt(s)"}), 400
41 |
42 | model_inputs = []
43 | max_len = -1
44 | for i, file in enumerate(image_files):
45 | image = Image.open(file.stream)
46 | image = np.array(image)[:, :, :3]
47 | image = Image.fromarray(image).resize((490, 490), Image.LANCZOS)
48 |
49 | input_by_model = model.build_conversation_input_ids(
50 | tokenizer, query=prompts[i], history=[], images=[image]
51 | )
52 | inputs = {
53 | "input_ids": input_by_model["input_ids"].unsqueeze(0).to(DEVICE),
54 | "token_type_ids": input_by_model["token_type_ids"].unsqueeze(0).to(DEVICE),
55 | "attention_mask": input_by_model["attention_mask"].unsqueeze(0).to(DEVICE),
56 | "images": [[input_by_model["images"][0].to(DEVICE).to(torch_type)]],
57 | }
58 | model_inputs.append(inputs)
59 | max_len = max(max_len, inputs["input_ids"].shape[1])
60 | concatenated_images = []
61 | for i in range(len(model_inputs)):
62 | tensor_shape = model_inputs[i]["input_ids"].shape[1]
63 | model_inputs[i]["input_ids"] = torch.cat(
64 | [
65 | torch.zeros(
66 | (1, max_len - tensor_shape), dtype=torch.long, device=DEVICE
67 | ),
68 | model_inputs[i]["input_ids"],
69 | ],
70 | dim=1,
71 | )
72 | model_inputs[i]["token_type_ids"] = torch.cat(
73 | [
74 | torch.zeros(
75 | (1, max_len - tensor_shape), dtype=torch.long, device=DEVICE
76 | ),
77 | model_inputs[i]["token_type_ids"],
78 | ],
79 | dim=1,
80 | )
81 | model_inputs[i]["attention_mask"] = torch.cat(
82 | [
83 | torch.zeros(
84 | (1, max_len - tensor_shape), dtype=torch.long, device=DEVICE
85 | ),
86 | model_inputs[i]["attention_mask"],
87 | ],
88 | dim=1,
89 | )
90 | concatenated_images.append(model_inputs[i]["images"][0])
91 | combined_inputs = {
92 | "input_ids": torch.cat([inputs["input_ids"] for inputs in model_inputs], dim=0),
93 | "token_type_ids": torch.cat(
94 | [inputs["token_type_ids"] for inputs in model_inputs], dim=0
95 | ),
96 | "attention_mask": torch.cat(
97 | [inputs["attention_mask"] for inputs in model_inputs], dim=0
98 | ),
99 | "images": concatenated_images,
100 | }
101 |
102 | gen_kwargs = {"max_length": 2048, "temperature": 0.0, "do_sample": False}
103 |
104 | with torch.no_grad():
105 | outputs = model.generate(**combined_inputs, **gen_kwargs)
106 | generation_strings = []
107 | for i in range(len(outputs)):
108 | output_tokens = outputs[i][max_len:]
109 | generation = tokenizer.decode(output_tokens)
110 | generation = generation.split("</s>")[0]
111 | generation_strings.append(generation)
112 |
113 | return (
114 | jsonify(
115 | {
116 | "response": generation_strings,
117 | }
118 | ),
119 | 200,
120 | )
121 |
122 |
123 | if __name__ == "__main__":
124 | app.run(debug=False, host="0.0.0.0", port=5000)
125 |
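The server above left-pads every prompt to the longest prompt in the batch, which is what lets `outputs[i][max_len:]` isolate only the newly generated tokens for every batch element. A toy, self-contained illustration of that invariant (not part of the server):

# Two toy prompts of different lengths, padded on the left to a common length.
prompt_ids = [[7, 8, 9], [5, 6]]
max_len = max(len(p) for p in prompt_ids)
padded = [[0] * (max_len - len(p)) + p for p in prompt_ids]

# Pretend the model appended two generated tokens to each padded prompt.
outputs = [p + [42, 43] for p in padded]
assert all(row[max_len:] == [42, 43] for row in outputs)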
--------------------------------------------------------------------------------
/data_collection/orchestrator/cogvlm_server/forwarding_server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from PIL import Image
3 | import json
4 | import numpy as np
5 | from tqdm import tqdm
6 | from flask import Flask, request, jsonify
7 | import copy
8 | import io
9 | from io import BytesIO
10 | import requests
11 |
12 | app = Flask(__name__)
13 |
14 |
15 | @app.route("/query", methods=["POST"])
16 | def forward_request():
17 | print("recieved request")
18 | image_files = request.files.getlist("image")
19 | prompts = request.form.getlist("prompt")
20 | if not image_files or not prompts or len(image_files) != len(prompts):
21 | return jsonify({"error": "Missing image(s) or prompt(s)"}), 400
22 | numpy_images, prompt_strings = [], []
23 | for i, file in enumerate(image_files):
24 | image = Image.open(file.stream)
25 | image = np.array(image)[:, :, :3]
26 | numpy_images.append(image)
27 | prompt_strings.append(prompts[i])
28 | files = []
29 | for i, (numpy_image, prompt) in enumerate(zip(numpy_images, prompt_strings)):
30 | pil_image = Image.fromarray(numpy_image.astype("uint8"))
31 | img_byte_arr = io.BytesIO()
32 | pil_image.save(img_byte_arr, format="PNG") # can be JPEG or other formats
33 | img_byte_arr.seek(0)
34 |
35 | # Append the image file
36 | files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
37 |
38 | # Append the corresponding prompt
39 | files.append((f"prompt", (None, prompt)))
40 | url = "http://localhost:7000/query"
41 | response = requests.post(url, files=files)
42 | if response.status_code == 200:
43 | json_response = response.json()
44 | return jsonify(json_response), 200
45 | else:
46 | return jsonify({"error": "something went wrong"}), 500
47 |
48 |
49 | if __name__ == "__main__":
50 | app.run(debug=False, host="0.0.0.0", port=6000)
51 |
--------------------------------------------------------------------------------
/data_collection/orchestrator/cogvlm_server/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from PIL import Image
4 | from transformers import AutoModelForCausalLM, LlamaTokenizer
5 | import json
6 | import numpy as np
7 | from tqdm import tqdm
8 | from flask import Flask
9 | from flask import request, jsonify
10 | import copy
11 | import io
12 |
13 | MODEL_PATH = "THUDM/cogvlm-chat-hf"
14 | TOKENIZER_PATH = "lmsys/vicuna-7b-v1.5"
15 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16 |
17 | tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
18 | torch_type = torch.bfloat16
19 |
20 | print("Loading model to GPU mem")
21 | model = (
22 | AutoModelForCausalLM.from_pretrained(
23 | MODEL_PATH, torch_dtype=torch_type, load_in_4bit=False, trust_remote_code=True
24 | )
25 | .to(DEVICE)
26 | .eval()
27 | )
28 |
29 | app = Flask(__name__)
30 |
31 |
32 | @app.route("/query", methods=["POST"])
33 | def query_model():
34 | global tokenizer, model, torch_type, DEVICE
35 |
36 | if "image" not in request.files or "prompt" not in request.form:
37 | return jsonify({"error": "Missing image or prompt"}), 400
38 |
39 | prompt = request.form["prompt"]
40 |
41 | file = request.files["image"]
42 | image = Image.open(file.stream)
43 |
44 | # Resize image to 490x490
45 | image = np.array(image)[:, :, :3]
46 | image = Image.fromarray(image).resize((490, 490), Image.LANCZOS)
47 |
48 | input_by_model = model.build_conversation_input_ids(
49 | tokenizer, query=prompt, history=[], images=[image]
50 | )
51 | inputs = {
52 | "input_ids": input_by_model["input_ids"].unsqueeze(0).to(DEVICE),
53 | "token_type_ids": input_by_model["token_type_ids"].unsqueeze(0).to(DEVICE),
54 | "attention_mask": input_by_model["attention_mask"].unsqueeze(0).to(DEVICE),
55 | "images": [[input_by_model["images"][0].to(DEVICE).to(torch_type)]],
56 | }
57 |
58 | gen_kwargs = {"max_length": 2048, "temperature": 0.0, "do_sample": False}
59 |
60 | with torch.no_grad():
61 | outputs = model.generate(**inputs, **gen_kwargs)
62 | outputs = outputs[:, inputs["input_ids"].shape[1] :]
63 | response = tokenizer.decode(outputs[0])
64 | response = response.split("</s>")[0]
65 |
66 | return (
67 | jsonify(
68 | {
69 | "response": response,
70 | }
71 | ),
72 | 200,
73 | )
74 |
75 |
76 | if __name__ == "__main__":
77 | app.run(debug=False, host="0.0.0.0", port=5000)
78 |
--------------------------------------------------------------------------------
/data_collection/orchestrator/cogvlm_server/test-img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/data_collection/orchestrator/cogvlm_server/test-img.png
--------------------------------------------------------------------------------
/data_collection/orchestrator/cogvlm_server/test.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from pathlib import Path
3 |
4 | # Your Flask API endpoint URL
5 | url = "http://localhost:5001/query" # We're testing out ssh port forwarding
6 |
7 | image_path = "test-img.png"
8 | prompt_text = "Describe the image."
9 |
10 | # Make sure the image path is valid
11 | if not Path(image_path).is_file():
12 | print("The image file does not exist.")
13 | exit()
14 |
15 | # Open the image in binary mode
16 | with open(image_path, "rb") as image_file:
17 | # Prepare the data for the POST request
18 | payload = {"prompt": (None, prompt_text)}
19 | files = {"image": (image_path, image_file, "multipart/form-data")}
20 |
21 | # Send the POST request to the Flask API endpoint
22 | response = requests.post(url, files=files, data=payload)
23 |
24 | # Check the response status code
25 | if response.ok:
26 | print(response.json())
27 | else:
28 | print(f"Error: {response.status_code}")
29 | print(response.text)
30 |
--------------------------------------------------------------------------------
/data_collection/orchestrator/cogvlm_server/test_batched.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from PIL import Image
3 | import io
4 | import numpy as np
5 |
6 | img = Image.open("test-img.png")
7 |
8 | # Your images as numpy arrays
9 | numpy_images = [np.array(img), np.array(img)]
10 | prompts = [
11 | "On which side of the metal tray is the coke can?",
12 | "On which side of the metal tray is the capsicum?",
13 | ]
14 |
15 | numpy_images = 5 * numpy_images
16 | prompts = 5 * prompts
17 |
18 | # The server endpoint
19 | url = "http://localhost:6000/query"
20 |
21 | files = []
22 | for i, (numpy_image, prompt) in enumerate(zip(numpy_images, prompts)):
23 | pil_image = Image.fromarray(numpy_image.astype("uint8"))
24 | img_byte_arr = io.BytesIO()
25 | pil_image.save(img_byte_arr, format="PNG") # can be JPEG or other formats
26 | img_byte_arr.seek(0)
27 |
28 | # Append the image file
29 | files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
30 |
31 | # Append the corresponding prompt
32 | files.append((f"prompt", (None, prompt)))
33 |
34 | # Perform the request
35 | response = requests.post(url, files=files)
36 |
37 | # Response handling
38 | if response.status_code == 200:
39 | print("Success:", response.json())
40 | else:
41 | print("Error:", response.status_code, response.text)
42 |
--------------------------------------------------------------------------------
/data_collection/orchestrator/robot/logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import subprocess
4 |
5 | import cv2
6 | import numpy as np
7 |
8 |
9 |
10 | class Logger:
11 | def __init__(self, config):
12 | self.robot_id = str(config["general_params"]["robot_id"])
13 | self.video_save_path = config["general_params"]["video_save_path"]
14 | # make sure video_save_path exists
15 | if not os.path.exists(self.video_save_path):
16 | os.makedirs(self.video_save_path)
17 |
18 | # Find the highest index trajectory already saved
19 | # in the save folder, and start counting trajectories
20 | # with one above that number
21 | indices = [
22 | int(name[4:])
23 | for name in os.listdir(self.video_save_path)
24 | if os.path.isdir(os.path.join(self.video_save_path, name))
25 | ]
26 | self.traj_idx = max(indices) + 1 if len(indices) != 0 else 0
27 |
28 | # Get image sizes from the config
29 | self.obs_image_size = config["gc_policy_params"]["image_size"]
30 | self.goal_image_size = config["subgoal_predictor_params"]["image_size"]
31 |
32 | # Initialize data structures for things we are logging
33 | self.obs_images = []
34 | self.goal_images = []
35 | self.actions = []
36 | self.poses = []
37 |
38 | # We will also log the scene information for each trajectory
39 | self.object_list = config["task_definition_params"]["object_list"]
40 | self.task_list = config["task_definition_params"]["task_list"]
41 |
42 | def log_obs(self, image: np.ndarray):
43 | assert image.shape == (
44 | self.obs_image_size,
45 | self.obs_image_size,
46 | 3,
47 | ), "Cannot log incorrectly shaped obs image"
48 | self.obs_images.append(image)
49 |
50 | def log_goal(self, image: np.ndarray):
51 | assert image.shape == (
52 | self.goal_image_size,
53 | self.goal_image_size,
54 | 3,
55 | ), "Cannot log incorrectly shaped goal image"
56 | self.goal_images.append(image)
57 |
58 | def log_action(self, action: np.ndarray):
59 | assert action.shape == (7,), "Action should have 7 dimensions"
60 | self.actions.append(action)
61 |
62 | def log_pose(self, pose: np.ndarray):
63 | """
64 | This method logs the pose of the robot before the action is taken
65 | """
66 | assert pose.shape == (7,), "Robot pose should have 7 dimensions"
67 | self.poses.append(pose)
68 |
69 | def reset(self):
70 | self.obs_images = []
71 | self.goal_images = []
72 | self.actions = []
73 | self.poses = []
74 |
75 | def flush_trajectory(
76 | self, commanded_task: str, success: bool, log_combined: bool = True
77 | ):
78 | subdir_path = os.path.join(self.video_save_path, "traj" + str(self.traj_idx))
79 | if not os.path.exists(subdir_path):
80 | os.makedirs(subdir_path)
81 |
82 | # Log the language task
83 | with open(os.path.join(subdir_path, "language_task.txt"), "w") as f:
84 | f.write(commanded_task)
85 |
86 | # Log the success information
87 | with open(os.path.join(subdir_path, "success.txt"), "w") as f:
88 | f.write(str(success))
89 |
90 | # Log the actions
91 | np.save(os.path.join(subdir_path, "actions.npy"), np.array(self.actions))
92 |
93 | # Log the robot poses
94 | np.save(os.path.join(subdir_path, "eef_poses.npy"), np.array(self.poses))
95 |
96 | # Log the observation video
97 | size = (self.obs_image_size, self.obs_image_size)
98 | out = cv2.VideoWriter(
99 | os.path.join(subdir_path, "trajectory.mp4"),
100 | cv2.VideoWriter_fourcc(*"DIVX"),
101 | 15,
102 | size,
103 | )
104 | for i in range(len(self.obs_images)):
105 | rgb_img = cv2.cvtColor(self.obs_images[i], cv2.COLOR_RGB2BGR)
106 | out.write(rgb_img)
107 | out.release()
108 |
109 | # Log the goals video
110 | size = (self.goal_image_size, self.goal_image_size)
111 | out = cv2.VideoWriter(
112 | os.path.join(subdir_path, "goals.mp4"),
113 | cv2.VideoWriter_fourcc(*"DIVX"),
114 | 15,
115 | size,
116 | )
117 | for i in range(len(self.goal_images)):
118 | rgb_img = cv2.cvtColor(self.goal_images[i], cv2.COLOR_RGB2BGR)
119 | out.write(rgb_img)
120 | out.release()
121 |
122 | # Log the combined obs + goal video
123 | if log_combined:
124 | assert (
125 | self.obs_image_size == self.goal_image_size
126 | ), "To log combined video obs and goal images must be the same size"
127 | assert len(self.obs_images) == len(
128 | self.goal_images
129 | ), "To log combined video there must be equal number of obs and goal images"
130 | size = (self.obs_image_size + self.goal_image_size, self.obs_image_size)
131 | out = cv2.VideoWriter(
132 | os.path.join(subdir_path, "combined.mp4"),
133 | cv2.VideoWriter_fourcc(*"DIVX"),
134 | 15,
135 | size,
136 | )
137 | for i in range(len(self.goal_images)):
138 | combined_image = np.concatenate(
139 | [self.obs_images[i], self.goal_images[i]], axis=1
140 | )
141 | rgb_img = cv2.cvtColor(combined_image, cv2.COLOR_RGB2BGR)
142 | out.write(rgb_img)
143 | out.release()
144 |
145 | # Log the scene information
146 | obj_list_dest = os.path.join(subdir_path, "object_list.txt")
147 | task_list_dest = os.path.join(subdir_path, "task_list.txt")
148 | time_dest = os.path.join(subdir_path, "time.txt")
149 | robot_id_dest = os.path.join(subdir_path, "robot_id.txt")
150 | with open(obj_list_dest, "w") as f:
151 | json.dump(self.object_list, f, indent=4)
152 | with open(task_list_dest, "w") as f:
153 | json.dump(self.task_list, f, indent=4)
154 | time = subprocess.check_output("date", shell=True).decode("utf-8").strip()
155 | with open(time_dest, "w") as f:
156 | f.write(time)
157 | robot_id = self.robot_id
158 | with open(robot_id_dest, "w") as f:
159 | f.write(robot_id)
160 |
161 | # Reset variables
162 | self.obs_images = []
163 | self.goal_images = []
164 | self.actions = []
165 | self.poses = []
166 | self.traj_idx += 1
167 |
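A hedged usage sketch of Logger, assuming `config` is the dict loaded from one of the config.yaml files above and that images, poses, and actions come from the robot loop in orchestrator/robot/main.py (not shown in this section):

logger = Logger(config)
logger.log_obs(obs_image)      # (image_size, image_size, 3) uint8 observation
logger.log_goal(goal_image)    # subgoal image of the same size
logger.log_pose(eef_pose)      # 7-D end-effector pose recorded before the action
logger.log_action(action)      # 7-D commanded action
logger.flush_trajectory("put the carrot on the blue tray", success=True)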
--------------------------------------------------------------------------------
/data_collection/orchestrator/robot/reset_detector.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from slack_sdk import WebClient
3 | from slack_sdk.errors import SlackApiError
4 | import utils
5 | from nltk.corpus import cmudict
6 |
7 |
8 | def send_slack_message(client, channel_id, message):
9 | try:
10 | response = client.chat_postMessage(channel=channel_id, text=message)
11 | assert response["message"]["text"] == message
12 | except SlackApiError as e:
13 | assert e.response["ok"] is False
14 | # str like 'invalid_auth', 'channel_not_found'
15 | assert e.response["error"]
16 | print(f"Got an error: {e.response['error']}")
17 |
18 |
19 | class ResetDetector:
20 | """
21 | Uses a VLM to predict whether an object is missing (and thus a reset is required).
22 | If so, the Slack API is used to send a notification
23 | """
24 |
25 | def __init__(self, config):
26 | self.config = config
27 |
28 | # Load object list
29 | self.objects = self.config["task_definition_params"]["object_list"]
30 |
31 | # Prepare prompt strings
32 | self.prompt_components = [
33 | "Is there ",
34 | " present in the image?",
35 | "I gave a vision-language model an image and asked if there was ",
36 | " present in the image. The following was the model's response:\n\n######\n",
37 | "\n######\n\nIf the model thought the object in question was present, return just the word 'true'. Else return 'false'.",
38 | ]
39 |
40 | # connecting to slack
41 | slack_token = self.config["reset_detector_params"]["slack_token"]
42 | self.slack_client = WebClient(token=slack_token)
43 |
44 | # cache for anything you want
45 | self.cache = {}
46 |
47 | # cache the status of each object so we don't send slack messages every time
48 | self.cache["object_status"] = {obj: True for obj in self.objects}
49 |
50 | def prepend_article(self, object_name):
51 | first_word = object_name.split()[0]
52 |
53 | def starts_with_vowel_sound(word, pronunciations=cmudict.dict()):
54 | for syllables in pronunciations.get(word, []):
55 | return syllables[0][-1].isdigit() # use only the first one
56 |
57 | if starts_with_vowel_sound(first_word):
58 | return "an " + object_name
59 | return "a " + object_name
60 |
61 | def detect(self, image: np.ndarray):
62 | objects_present = []
63 | if self.config["reset_detector_params"]["which_vlm"] != "none":
64 |
65 | # prompts for the VLM for each object
66 | image_prompts = []
67 | for object in self.objects:
68 | image_prompt = (
69 | self.prompt_components[0]
70 | + self.prepend_article(object)
71 | + self.prompt_components[1]
72 | )
73 | image_prompts.append(image_prompt)
74 |
75 | # image for the VLM for each object
76 | images = [image] * len(self.objects)
77 |
78 | if self.config["reset_detector_params"]["which_vlm"] == "cogvlm":
79 | vlm_output = utils.ask_cogvlm_batched(
80 | images, image_prompts, self.config
81 | )
82 | elif self.config["reset_detector_params"]["which_vlm"] == "gpt4v":
83 | vlm_output = utils.ask_gpt4v_batched(images, image_prompts)
84 | else:
85 | vlm_output = [
86 | "Yes, " + self.prepend_article(obj) + " is present in the image." for obj in self.objects
87 | ]
88 | print("vlm output:", vlm_output)
89 |
90 | # decoding prompts for the LLM
91 | decoding_prompts = [
92 | self.prompt_components[2]
93 | + self.prepend_article(self.objects[i])
94 | + self.prompt_components[3]
95 | + vlm_output[i].strip()
96 | + self.prompt_components[4]
97 | for i in range(len(self.objects))
98 | ]
99 | print("decoding prompts:", decoding_prompts)
100 |
101 | llm_output = utils.ask_gpt4_batched(decoding_prompts, cache=self.cache)
102 | for q, r in zip(decoding_prompts, llm_output):
103 | self.cache[q] = r
104 |
105 | # object presence info
106 | updated_object_status = {}
107 | for llm_answer, obj in zip(llm_output, self.objects):
108 | in_scene = llm_answer.strip().lower() == "true"
109 | try:
110 | assert in_scene in (True, False)
111 | except AssertionError:
112 | in_scene = False
113 | updated_object_status[obj] = in_scene
114 | print(f"Reset Detector LLM said {obj} is present:", in_scene)
115 | objects_present.append(in_scene)
116 |
117 | else:
118 | objects_present = [True] * len(self.objects)
119 | updated_object_status = {obj: True for obj in self.objects}
120 |
121 | objects_not_present = []
122 | for i in range(len(objects_present)):
123 | if not objects_present[i]:
124 | objects_not_present.append(self.objects[i])
125 |
126 | # missing objects are those in objects_not_present, but have previous status as True
127 | # i.e. they were present the last time ResetDetector was run but not this time
128 | missing_objects = [
129 | obj for obj in objects_not_present if self.cache["object_status"][obj]
130 | ]
131 | if len(missing_objects) != 0:
132 | # Reset required, send slack message
133 | message = f"Hey! Robot {self.config['general_params']['robot_id']} is missing objects: "
134 | message += ", ".join(missing_objects)
135 |
136 | channel_id = self.config["reset_detector_params"]["channel_id"]
137 | send_slack_message(self.slack_client, channel_id, message)
138 |
139 | # update object status
140 | self.cache["object_status"] = updated_object_status
141 |
142 | # Prepare return dictionary
143 | to_return = {}
144 | for i in range(len(objects_present)):
145 | to_return[self.objects[i]] = objects_present[
146 | i
147 | ]  # the VLM can be unreliable, and we don't want to skip proposing tasks we can actually execute
148 | return to_return
149 |
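A hedged usage sketch of ResetDetector, assuming `config` comes from config.yaml and the image from `utils.get_observation`:

detector = ResetDetector(config)
presence = detector.detect(obs["image"])   # e.g. {"carrot": True, "lemon": False, ...}
missing = [obj for obj, present in presence.items() if not present]
if missing:
    print("Waiting for a human reset; missing objects:", missing)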
--------------------------------------------------------------------------------
/data_collection/orchestrator/robot/subgoal_predictor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import requests
3 | from PIL import Image
4 | import io
5 |
6 |
7 | class SubgoalPredictor:
8 | def __init__(self, config):
9 | diffusion_config = config["subgoal_predictor_params"]
10 | self.image_size = diffusion_config["image_size"]
11 | self.url = (
12 | "http://"
13 | + diffusion_config["susie_server_ip"]
14 | + ":"
15 | + str(diffusion_config["susie_server_port"])
16 | + "/generate_subgoal"
17 | )
18 |
19 | def numpy_to_image(self, np_array):
20 | """Convert a NumPy array to a PIL Image."""
21 | return Image.fromarray(np.uint8(np_array))
22 |
23 | def image_to_numpy(self, image):
24 | """Convert a PIL Image to a NumPy array."""
25 | return np.array(image)
26 |
27 | def send_image_and_text(self, url, np_image, text):
28 | """Send a NumPy image array and text to the specified URL."""
29 | # Convert NumPy array to PIL Image
30 | image = self.numpy_to_image(np_image)
31 |
32 | # Save the PIL Image to a bytes buffer
33 | img_buffer = io.BytesIO()
34 | image.save(img_buffer, format="JPEG")
35 | img_buffer.seek(0)
36 |
37 | # Prepare files and data for the request
38 | files = {"image": ("image.jpg", img_buffer, "image/jpeg")}
39 | data = {"text": text}
40 |
41 | # Send POST request
42 | response = requests.post(url, files=files, data=data)
43 | return response
44 |
45 | def __call__(self, image_obs: np.ndarray, prompt: str):
46 | assert image_obs.shape == (
47 | self.image_size,
48 | self.image_size,
49 | 3,
50 | ), "Bad input image shape"
51 |
52 | response = self.send_image_and_text(self.url, image_obs, prompt)
53 | if response.status_code == 200:
54 | # Convert the response content back to a NumPy array
55 | image = Image.open(io.BytesIO(response.content))
56 | output_np_image = self.image_to_numpy(image)
57 | else:
58 | print("Failed to process image", response.status_code, response.text)
59 | return None
60 |
61 | return output_np_image
62 |
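A hedged usage sketch of SubgoalPredictor, assuming the SuSIE Flask server (orchestrator/susie_server/main.py) is reachable at the ip/port set in subgoal_predictor.yaml:

predictor = SubgoalPredictor(config)
goal_image = predictor(image_obs, "put the carrot on the blue tray")
if goal_image is None:
    print("subgoal server request failed")   # __call__ returns None on HTTP errors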
--------------------------------------------------------------------------------
/data_collection/orchestrator/robot/task_success_predictor.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from tqdm import tqdm
4 | import numpy as np
5 |
6 | import utils
7 |
8 |
9 | class SuccessPredictor:
10 | """
11 | Uses a VLM to predict whether a given task has been completed
12 | """
13 |
14 | def __init__(self, config):
15 | self.config = config
16 |
17 | # Prompt to convert task description into VQA style question that can be used to assess task completion
18 | self.task_to_vqa_prompt = """I have a robot arm manipulator in a lab setting that can perform many manipulation tasks. I commanded a task to the robot to perform and I now want to assess whether the robot was successful in completing this task.\n\nTo determine whether or not the robot successfully completed the task, I have access to a vision-language model (VLM) that can answer questions for me when I provide it an image of the lab environment the robot is operating in. Since I can only ask it simple questions about the image, I need to convert the task description into a question that I can feed into the VLM.\n\nHere are several examples:\n\nTask: put the yellow ball on the blue plate\nVLM question: Is the yellow ball on the blue plate?\n\nTask: move the yellow ball from the blue plate to the table\nVLM question: Is the yellow ball on the blue plate or the table?\n\nTask: move the yellow ball to the left side of the table\nVLM question: Is the yellow ball on the left side of the table?\n\nTask: move the yellow ball to the right side of the table\nVLM question: Is the yellow ball on the right side of the table?\n\nTask: put the orange crayon on the blue plate\nVLM question: Is the orange crayon on the blue plate?\n\nTask: move the orange crayon from the blue plate to the table\nVLM question: Is the orange crayon on the blue plate or the table?\n\nTask: put the orange crayon on the cloth\nVLM question: Is the orange crayon on top of the cloth?\n\nTask: move the orange crayon from the cloth to the table\nVLM question: Is the orange crayon on the cloth or the table?\n\nTask: move the orange crayon from the cloth to the blue plate\nVLM question: Is the orange crayon on the cloth or the blue plate?\n\nTask: move the orange crayon from the blue plate to the cloth\nVLM question: Is the orange crayon on the blue plate or the cloth?\n\nTask: move the orange crayon to the right side of the table\nVLM question: Is the orange crayon on the right side of the table?\n\nTask: move the orange crayon to the left side of the table\nVLM question: Is the orange crayon on the left side of the table?\n\nTask: move the red object from the blue plate to the table\nVLM question: Is the red object on the blue plate or the table?\n\nTask: put the red object on the blue plate\nVLM question: Is the red object on the blue plate?\n\nTask: move the red object to the left side of the table\nVLM question: Is the red object on the left side of the table?\n\nTask: move the red object to the right side of the table\nVLM question: Is the red object on the right side of the table?\n\nTask: put the red object on the cloth\nVLM question: Is the red object on the cloth?\n\nTask: move the red object from the cloth to the table\nVLM question: Is the red object on the cloth or the table?\n\nTask: move the red object from the cloth to the blue plate\nVLM question: Is the red object on the cloth or the blue plate?\n\nTask: move the red object from the blue plate to the cloth\nVLM question: Is the red object on the blue plate or the cloth?\n\nTask: move the cloth to the right side of the table\nVLM question: Is the cloth on the right side of the table?\n\nTask: move the cloth to the left side of the table\nVLM question: Is the cloth on the left side of the table?\n\nFollowing the format of these examples, give me the VLM question for the following task:\n\nTask: """
19 |
20 | # Prompt to decode VLM output into true/false
21 | self.prompt_to_parse_vlm_output = [
22 | "I have a robot arm manipulator in a lab setting. I commanded it to complete the following task:\n\n",
23 | "\n\nI want to assess whether the robot arm successfully completed the task. To do so, I prompted a vision-language model (VLM) with an image of the current robot workspace and the following question:\n\n",
24 | "\n\nIn response, the VLM answered the following:\n\n",
25 | "\n\nBased on the task commanded and the VLM's response to the question, determine if the robot successfully completed the commanded task or not. If it did successfully complete the task, return just the word true. Otherwise return the word false. If for some reason the answer is neither true nor false, return false.",
26 | ]
27 |
28 | # Use for caching anything you want
29 | self.cache = {}
30 |
31 | # to record the success rates
32 | self.task_success_record = (
33 | self.init_previous_task_stats()
34 | ) # task -> list of bools
35 |
36 | def init_previous_task_stats(self):
37 | task_success_record = {}
38 |
39 | if self.config["task_proposer_params"]["reuse_task_statistics"]:
40 | trajectory_log_dir = self.config["general_params"]["video_save_path"]
41 | logged_trajs = [
42 | traj
43 | for traj in os.listdir(trajectory_log_dir)
44 | if os.path.isdir(os.path.join(trajectory_log_dir, traj))
45 | ]
46 | for traj in tqdm(logged_trajs):
47 | traj_path = os.path.join(trajectory_log_dir, traj)
48 | with open(os.path.join(traj_path, "language_task.txt")) as f:
49 | traj_task = f.read().strip().lower()
50 | with open(os.path.join(traj_path, "success.txt")) as f:
51 | traj_success = f.read().strip().lower()
52 | if traj_task not in task_success_record:
53 | task_success_record[traj_task] = []
54 | task_success_record[traj_task].append(traj_success == "true")
55 |
56 | return task_success_record
57 |
58 | def record_task_success(self, task_str, success):
59 | if task_str not in self.task_success_record:
60 | self.task_success_record[task_str] = []
61 | self.task_success_record[task_str].append(success)
62 |
63 | def get_success_rate(self, n_most_recent=None):
64 | success_rates = {}
65 | for task_str, success_list in self.task_success_record.items():
66 | if n_most_recent is not None:
67 | success_list = success_list[-n_most_recent:]
68 | success_rates[task_str] = sum(success_list) / len(success_list)
69 | return success_rates
70 |
71 | def predict_outcome(self, image: np.ndarray, task_str: str, log_metrics=True):
72 | # convert the task_str into a VQA style question
73 | vqa_style_q_unparsed = utils.ask_gpt4(
74 | self.task_to_vqa_prompt + task_str, cache=self.cache
75 | )
76 |
77 | # add response to cache
78 | self.cache[self.task_to_vqa_prompt + task_str] = vqa_style_q_unparsed
79 |
80 | vqa_style_q_unparsed = vqa_style_q_unparsed.strip()
81 | if ":" in vqa_style_q_unparsed:
82 | vqa_style_q = vqa_style_q_unparsed[vqa_style_q_unparsed.index(":") + 2 :]
83 | else:
84 | vqa_style_q = vqa_style_q_unparsed
85 | print("vqa_style_q:", vqa_style_q)
86 |
87 | # ask the VLM
88 | if self.config["success_detector_params"]["which_vlm"] == "gpt4v":
89 | vlm_output = utils.ask_gpt4v(image, vqa_style_q)
90 | elif self.config["success_detector_params"]["which_vlm"] == "cogvlm":
91 | vlm_output = utils.ask_cogvlm(image, vqa_style_q, self.config)
92 | else:
93 | # If there's no VLM success detector, we will conservatively assume the task failed
94 | return False
95 | print("vlm_output:", vlm_output)
96 |
97 | # parse the output
98 | decoding_prompt = (
99 | self.prompt_to_parse_vlm_output[0]
100 | + task_str
101 | + self.prompt_to_parse_vlm_output[1]
102 | + vqa_style_q
103 | + self.prompt_to_parse_vlm_output[2]
104 | + vlm_output
105 | + self.prompt_to_parse_vlm_output[3]
106 | )
107 | print("decoding_prompt:", decoding_prompt)
108 | parsed_vlm_output = utils.ask_gpt4(decoding_prompt, cache=self.cache)
109 |
110 | # add response to cache
111 | self.cache[decoding_prompt] = parsed_vlm_output
112 |
113 | print("parsed_vlm_output:", parsed_vlm_output)
114 |
115 | success = parsed_vlm_output.strip().lower() == "true"
116 | try:
117 | assert success in (True, False)
118 | except AssertionError:
119 | print(
120 | "Error: VLM output was neither 'true' nor 'false'. Assuming task failed."
121 | )
122 | success = False
123 |
124 | if log_metrics:
125 | self.record_task_success(task_str, success)
126 |
127 | return success
128 |
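A hedged usage sketch of SuccessPredictor, assuming `final_image` is the last observation of a rollout and `config` is the usual composed config dict:

predictor = SuccessPredictor(config)
success = predictor.predict_outcome(final_image, "open the drawer")
print("task succeeded:", success)
print(predictor.get_success_rate(n_most_recent=20))  # per-task rolling success rates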
--------------------------------------------------------------------------------
/data_collection/orchestrator/robot/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pyquaternion import Quaternion
3 | from PIL import Image
4 | import os
5 | import requests
6 | from typing import List
7 | from openai import OpenAI
8 | import io
9 | import base64
10 | from tqdm import tqdm
11 | from concurrent.futures import ThreadPoolExecutor
12 | from multiprocessing import Pool
13 |
14 |
15 | def state_to_eep(xyz_coor, zangle: float):
16 | """
17 | Convert an xyz position and a z-axis rotation angle into an end-effector pose.
18 | Referred to `bridge_data_robot`'s `widowx_controller/widowx_controller.py`.
19 | Returns a 4x4 homogeneous transform matrix.
20 | """
21 | assert len(xyz_coor) == 3
22 | DEFAULT_ROTATION = np.array([[0, 0, 1.0], [0, 1.0, 0], [-1.0, 0, 0]])
23 | new_pose = np.eye(4)
24 | new_pose[:3, -1] = xyz_coor
25 | new_quat = Quaternion(axis=np.array([0.0, 0.0, 1.0]), angle=zangle) * Quaternion(
26 | matrix=DEFAULT_ROTATION
27 | )
28 | new_pose[:3, :3] = new_quat.rotation_matrix
29 | # yaw, pitch, roll = quat.yaw_pitch_roll
30 | return new_pose
31 |
32 |
33 | def get_observation(widowx_client, config):
34 | while True:
35 | obs = widowx_client.get_observation()
36 | if obs is None:
37 | print("WARNING: failed to get robot observation, retrying...")
38 | else:
39 | break
40 | obs["image"] = (
41 | obs["image"]
42 | .reshape(
43 | 3,
44 | config["general_params"]["shoulder_camera_image_size"],
45 | config["general_params"]["shoulder_camera_image_size"],
46 | )
47 | .transpose(1, 2, 0)
48 | * 255
49 | ).astype(np.uint8)
50 | return obs
51 |
52 |
53 | def encode_image_np(image_np: np.ndarray):
54 | # Ensure the NumPy array is in uint8
55 | if image_np.dtype != np.uint8:
56 | image_np = image_np.astype(np.uint8)
57 | # Convert the NumPy array to a PIL Image
58 | pil_image = Image.fromarray(image_np)
59 | # Create a buffer to hold the bytes
60 | buffer = io.BytesIO()
61 | # Save the image to the buffer in PNG format (or JPEG or any other format)
62 | pil_image.save(buffer, format="PNG")
63 | # Encode the buffer's content in base64
64 | base64_encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
65 | return base64_encoded
66 |
67 |
68 | def ask_gpt4v(image: np.ndarray, prompt: str):
69 | # We will first resize the image to 512x512
70 | if not image.shape == (512, 512, 3):
71 | image_pil = Image.fromarray(image)
72 | resized_pil = image_pil.resize((512, 512), Image.LANCZOS)
73 | image = np.array(resized_pil)
74 |
75 | # Prepare jsons for openai api requests
76 | headers = {
77 | "Content-Type": "application/json",
78 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
79 | }
80 | payload = {
81 | "model": "gpt-4-turbo", # gpt4o
82 | "messages": [
83 | {
84 | "role": "user",
85 | "content": [
86 | {"type": "text", "text": "###############"},
87 | {
88 | "type": "image_url",
89 | "image_url": {"url": "###############", "detail": "low"},
90 | },
91 | ],
92 | }
93 | ],
94 | "max_tokens": 2000,
95 | "temperature": 0.0,
96 | }
97 |
98 | base64_image = encode_image_np(image)
99 | payload["messages"][0]["content"][1]["image_url"][
100 | "url"
101 | ] = f"data:image/jpeg;base64,{base64_image}"
102 | payload["messages"][0]["content"][0]["text"] = prompt
103 |
104 | while True:
105 | response = requests.post(
106 | "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
107 | ).json()
108 | if "error" in response:
109 | continue
110 | assistant_message = response["choices"][0]["message"]["content"]
111 | break
112 |
113 | return assistant_message
114 |
115 |
116 | def ask_gpt4v_batched(images: List[np.ndarray], prompts: List[str]):
117 | assert len(images) == len(prompts)
118 |
119 | # TODO: implement real batching
120 | assistant_messages = []
121 | print("querying gpt4v batched")
122 | for i in tqdm(range(len(images))):
123 | assistant_messages.append(ask_gpt4v(images[i], prompts[i]))
124 |
125 | return assistant_messages
126 |
127 |
128 | def ask_gpt4(prompt, cache=None):
129 | if type(prompt) == tuple:
130 | prompt, cache = prompt
131 |
132 | # check if cache contains the answer
133 | if cache is not None and prompt in cache:
134 | return cache[prompt]
135 |
136 | # Prepare jsons for openai api requests
137 | headers = {
138 | "Content-Type": "application/json",
139 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
140 | }
141 | payload = {
142 | "model": "gpt-3.5-turbo-1106",
143 | "messages": [
144 | {
145 | "role": "user",
146 | "content": [
147 | {"type": "text", "text": "###############"},
148 | ],
149 | }
150 | ],
151 | "max_tokens": 2000,
152 | "temperature": 0.0,
153 | }
154 |
155 | payload["messages"][0]["content"][0]["text"] = prompt
156 | while True:
157 | try:
158 | response = requests.post(
159 | "https://api.openai.com/v1/chat/completions",
160 | headers=headers,
161 | json=payload,
162 | ).json()
163 |         except (requests.exceptions.RequestException, ValueError):
164 |             # sometimes the request fails or the body is not valid JSON (JSONDecodeError)
165 |             continue
166 | if "error" in response:
167 | continue
168 | assistant_message = response["choices"][0]["message"]["content"]
169 | break
170 |
171 | return assistant_message
172 |
173 |
174 | def ask_gpt4_batched(prompts, cache=None):
175 | if prompts is None or len(prompts) == 0:
176 | return []
177 | if cache is not None:
178 | # zip cache with the prompts list
179 | prompts = [(prompt, cache) for prompt in prompts]
180 | with Pool(len(prompts)) as p:
181 | assistant_messages = p.map(ask_gpt4, prompts)
182 | return assistant_messages
183 |
184 |
185 | # def ask_gpt4_batched(prompts):
186 | # assistant_messages = []
187 | # for prompt in tqdm(prompts):
188 | # assistant_messages.append(ask_gpt4(prompt))
189 | # return assistant_messages
190 |
191 |
192 | def ask_cogvlm(image: np.ndarray, prompt: str, config):
193 | image_list, prompt_list = [image], [prompt]
194 | return ask_cogvlm_batched(image_list, prompt_list, config)[0]
195 |
196 |
197 | def ask_cogvlm_batched(images: List[np.ndarray], prompts: List[str], config):
198 |
199 | def _ask_cogvlm_batched_helper():
200 | assert len(images) == len(prompts)
201 |
202 | files = []
203 | for i, (numpy_image, prompt) in enumerate(zip(images, prompts)):
204 | pil_image = Image.fromarray(numpy_image.astype("uint8"))
205 | img_byte_arr = io.BytesIO()
206 | pil_image.save(img_byte_arr, format="PNG") # can be JPEG or other formats
207 | img_byte_arr.seek(0)
208 |
209 | # Append the image file
210 | files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
211 |
212 | # Append the corresponding prompt
213 | files.append((f"prompt", (None, prompt)))
214 |
215 | url = (
216 | "http://"
217 | + config["cogvlm_server_params"]["cogvlm_server_ip"]
218 | + ":"
219 | + str(config["cogvlm_server_params"]["cogvlm_server_port"])
220 | + "/query"
221 | )
222 | response = requests.post(url, files=files)
223 |
224 | return response.json()["response"]
225 |
226 | # repeat the query if it fails
227 | while True:
228 | try:
229 | response = _ask_cogvlm_batched_helper()
230 | break
231 |         except (requests.exceptions.RequestException, ValueError, KeyError):
232 |             continue  # network error, bad JSON, or missing "response" key; retry
233 |
234 | return response
235 |
--------------------------------------------------------------------------------
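The VLM helpers above are driven by the nested config dictionary loaded from `config.yaml`. A minimal sketch of querying a running CogVLM server through `ask_cogvlm`; the import path, server address, and image below are placeholder assumptions, not values taken from the repository configs:

```python
import numpy as np

from orchestrator.robot.utils import ask_cogvlm  # assumed import path for the file above

# Placeholder config mirroring only the keys that ask_cogvlm_batched reads;
# the real values come from the robot's config directory.
config = {
    "cogvlm_server_params": {
        "cogvlm_server_ip": "127.0.0.1",  # placeholder
        "cogvlm_server_port": 5000,       # placeholder
    }
}

image = np.zeros((512, 512, 3), dtype=np.uint8)  # dummy RGB observation
answer = ask_cogvlm(image, "List the objects visible on the table.", config)
print(answer)
```

Note that `ask_cogvlm_batched` retries forever on failure, so this call blocks until the server responds.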
/data_collection/orchestrator/set_workspace_bounds/teleop.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from yamlinclude import YamlIncludeConstructor
3 | from absl import app, flags
4 | import numpy as np
5 | from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs
6 | import os
7 | import sys
8 | import tty
9 | import termios
10 | import time
11 |
12 | FLAGS = flags.FLAGS
13 | flags.DEFINE_string(
14 | "config_dir",
15 | None,
16 | "Path to config directory",
17 | required=True,
18 | )
19 |
20 | print_yellow = lambda x: print("\033[93m {}\033[00m".format(x))
21 |
22 |
23 | def print_help():
24 | print_yellow(" Teleop Controls:")
25 | print_yellow(" w, s : move forward/backward")
26 | print_yellow(" a, d : move left/right")
27 | print_yellow(" z, c : move up/down")
28 | print_yellow(" i, k: rotate yaw")
29 | print_yellow(" j, l: rotate pitch")
30 | print_yellow(" n m: rotate roll")
31 | print_yellow(" space: toggle gripper")
32 | print_yellow(" r: reset robot")
33 | print_yellow(" q: quit")
34 |
35 |
36 | def main(_):
37 | YamlIncludeConstructor.add_to_loader_class(
38 | loader_class=yaml.FullLoader, base_dir=FLAGS.config_dir
39 | )
40 | with open(os.path.join(FLAGS.config_dir, "config.yaml")) as f:
41 | config = yaml.load(f, Loader=yaml.FullLoader)
42 |
43 | env_params = WidowXConfigs.DefaultEnvParams.copy()
44 | env_params.update({"action_clipping": None})
45 | client = WidowXClient(
46 | host=config["general_params"]["ip"], port=config["general_params"]["port"]
47 | )
48 | client.init(env_params)
49 | client.reset()
50 |
51 | # Save the terminal settings
52 | fd = sys.stdin.fileno()
53 | old_settings = termios.tcgetattr(fd)
54 |
55 | print_help()
56 | is_open = 1
57 | running = True
58 | xyz_min, xyz_max = None, None
59 | while running:
60 | # Check for key press
61 | try:
62 | # Set the terminal to raw mode to read a single key press
63 | tty.setraw(sys.stdin.fileno())
64 | key = sys.stdin.read(1)
65 | finally:
66 | # Restore the terminal to its original settings
67 | termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
68 |
69 |         # 'q' key to quit
70 | if key == "q":
71 | print("Quitting teleoperation.")
72 | running = False
73 | continue
74 |
75 | # Handle key press for robot control
76 | # translation
77 | if key == "w":
78 | client.step_action(np.array([0.01, 0, 0, 0, 0, 0, is_open]))
79 | elif key == "s":
80 | client.step_action(np.array([-0.01, 0, 0, 0, 0, 0, is_open]))
81 | elif key == "a":
82 | client.step_action(np.array([0, 0.01, 0, 0, 0, 0, is_open]))
83 | elif key == "d":
84 | client.step_action(np.array([0, -0.01, 0, 0, 0, 0, is_open]))
85 | elif key == "z":
86 | client.step_action(np.array([0, 0, 0.01, 0, 0, 0, is_open]))
87 | elif key == "c":
88 | client.step_action(np.array([0, 0, -0.01, 0, 0, 0, is_open]))
89 |
90 | # rotation
91 | elif key == "i":
92 | client.step_action(np.array([0, 0, 0, 0.01, 0, 0, is_open]))
93 | elif key == "k":
94 | client.step_action(np.array([0, 0, 0, -0.01, 0, 0, is_open]))
95 | elif key == "j":
96 | client.step_action(np.array([0, 0, 0, 0, 0.01, 0, is_open]))
97 | elif key == "l":
98 | client.step_action(np.array([0, 0, 0, 0, -0.01, 0, is_open]))
99 | elif key == "n":
100 | client.step_action(np.array([0, 0, 0, 0, 0, 0.01, is_open]))
101 | elif key == "m":
102 | client.step_action(np.array([0, 0, 0, 0, 0, -0.01, is_open]))
103 |
104 | # space bar to change gripper state
105 | elif key == " ":
106 | is_open = 1 - is_open
107 | print("Gripper is now: ", is_open)
108 | client.step_action(np.array([0, 0, 0, 0, 0, 0, is_open]))
109 | elif key == "r":
110 | print("Resetting robot...")
111 | client.reset()
112 | print_help()
113 |
114 | # Get the end-effector position after taking action
115 | obs = client.get_observation()
116 | eef_pose = obs["state"]
117 | if xyz_min is None or xyz_max is None:
118 | xyz_min = eef_pose[:3]
119 | xyz_max = eef_pose[:3]
120 | xyz_min = np.minimum(xyz_min, eef_pose[:3])
121 | xyz_max = np.maximum(xyz_max, eef_pose[:3])
122 | print("robot pose:", eef_pose)
123 |
124 | client.stop() # Properly stop the client
125 | print("Teleoperation ended.")
126 |
127 | print()
128 | print("XYZ Min:", xyz_min)
129 | print("XYZ Max:", xyz_max)
130 |
131 |
132 | if __name__ == "__main__":
133 | app.run(main)
134 |
--------------------------------------------------------------------------------
/data_collection/orchestrator/susie_server/main.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from yamlinclude import YamlIncludeConstructor
3 | import argparse
4 | import os
5 | import numpy as np
6 | from susie.model import create_sample_fn
7 | from flask import Flask, request, send_file
8 | from PIL import Image
9 | import io
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--config_dir", required=True)
13 | args = parser.parse_args()
14 |
15 | YamlIncludeConstructor.add_to_loader_class(
16 | loader_class=yaml.FullLoader, base_dir=args.config_dir
17 | )
18 | with open(os.path.join(args.config_dir, "config.yaml")) as f:
19 | config = yaml.load(f, Loader=yaml.FullLoader)
20 |
21 |
22 | class SubgoalPredictor:
23 | def __init__(self, config):
24 | diffusion_config = config["subgoal_predictor_params"]
25 | self.diffusion_sample_func = create_sample_fn(
26 | diffusion_config["diffusion_checkpoint"],
27 | diffusion_config["diffusion_wandb"],
28 | diffusion_config["diffusion_num_steps"],
29 | diffusion_config["prompt_w"],
30 | diffusion_config["context_w"],
31 | 0.0,
32 | diffusion_config["diffusion_pretrained_path"],
33 | )
34 | self.image_size = diffusion_config["image_size"]
35 |
36 | def __call__(self, image_obs: np.ndarray, prompt: str):
37 | assert image_obs.shape == (
38 | self.image_size,
39 | self.image_size,
40 | 3,
41 | ), "Bad input image shape"
42 | return self.diffusion_sample_func(image_obs, prompt)
43 |
44 |
45 | subgoal_predictor = SubgoalPredictor(config)
46 |
47 | app = Flask(__name__)
48 |
49 |
50 | @app.route("/generate_subgoal", methods=["POST"])
51 | def process_image():
52 | global subgoal_predictor
53 |
54 | # Check if the request contains the 'image' file
55 | if "image" not in request.files:
56 | return "No image part", 400
57 | file = request.files["image"]
58 |
59 | # Check if the request contains the 'text' part
60 | if "text" not in request.form:
61 | return "No text provided", 400
62 | text = request.form["text"]
63 |
64 | # Read the image file
65 | image = Image.open(file.stream)
66 | image = np.array(image)
67 |
68 | generated_subgoal = subgoal_predictor(image_obs=image, prompt=text)
69 |
70 | # Save the image to a binary stream
71 | generated_subgoal = Image.fromarray(generated_subgoal)
72 | img_io = io.BytesIO()
73 | generated_subgoal.save(img_io, "JPEG", quality=70)
74 | img_io.seek(0)
75 |
76 | return send_file(img_io, mimetype="image/jpeg")
77 |
78 |
79 | if __name__ == "__main__":
80 | app.run(debug=False, host="0.0.0.0", port=7000)
81 |
--------------------------------------------------------------------------------
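A hedged client-side sketch of the `/generate_subgoal` endpoint above: it uploads a PNG-encoded observation together with a language instruction and decodes the JPEG subgoal returned by the server. The host/port, instruction, and image size here are placeholders; the server asserts that the input matches `subgoal_predictor_params["image_size"]`.

```python
import io

import numpy as np
import requests
from PIL import Image

SUSIE_URL = "http://localhost:7000/generate_subgoal"  # app.run(..., port=7000) above

# Dummy observation; the server asserts that it matches the configured image_size.
obs = np.zeros((256, 256, 3), dtype=np.uint8)

buf = io.BytesIO()
Image.fromarray(obs).save(buf, format="PNG")
buf.seek(0)

resp = requests.post(
    SUSIE_URL,
    files={"image": ("obs.png", buf, "image/png")},
    data={"text": "put the spoon on the towel"},
)
subgoal = np.array(Image.open(io.BytesIO(resp.content)))  # decoded JPEG subgoal image
print(subgoal.shape)
```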
/data_collection/orchestrator/web_viewer/app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, render_template, request, jsonify, send_file
2 | import os
3 | import portalocker
4 |
5 | app = Flask(__name__)
6 | app.config["UPLOAD_FOLDER"] = "./uploads"
7 | app.config["MAX_CONTENT_LENGTH"] = 16 * 1024 * 1024 # 16MB upload limit
8 |
9 | robot_video_feeds = {
10 | "feed0": {
11 | "observation": None,
12 | "goal": None,
13 | "status": {
14 | "commanded_task": "N/A",
15 | "subgoal": 0,
16 | "timestep": 0,
17 | "task_success": "N/A",
18 | },
19 | },
20 | "feed1": {
21 | "observation": None,
22 | "goal": None,
23 | "status": {
24 | "commanded_task": "N/A",
25 | "subgoal": 0,
26 | "timestep": 0,
27 | "task_success": "N/A",
28 | },
29 | },
30 | "feed2": {
31 | "observation": None,
32 | "goal": None,
33 | "status": {
34 | "commanded_task": "N/A",
35 | "subgoal": 0,
36 | "timestep": 0,
37 | "task_success": "N/A",
38 | },
39 | },
40 | "feed3": {
41 | "observation": None,
42 | "goal": None,
43 | "status": {
44 | "commanded_task": "N/A",
45 | "subgoal": 0,
46 | "timestep": 0,
47 | "task_success": "N/A",
48 | },
49 | },
50 | "feed4": {
51 | "observation": None,
52 | "goal": None,
53 | "status": {
54 | "commanded_task": "N/A",
55 | "subgoal": 0,
56 | "timestep": 0,
57 | "task_success": "N/A",
58 | },
59 | },
60 | "feed5": {
61 | "observation": None,
62 | "goal": None,
63 | "status": {
64 | "commanded_task": "N/A",
65 | "subgoal": 0,
66 | "timestep": 0,
67 | "task_success": "N/A",
68 | },
69 | },
70 | "feed6": {
71 | "observation": None,
72 | "goal": None,
73 | "status": {
74 | "commanded_task": "N/A",
75 | "subgoal": 0,
76 | "timestep": 0,
77 | "task_success": "N/A",
78 | },
79 | },
80 | "feed7": {
81 | "observation": None,
82 | "goal": None,
83 | "status": {
84 | "commanded_task": "N/A",
85 | "subgoal": 0,
86 | "timestep": 0,
87 | "task_success": "N/A",
88 | },
89 | },
90 | }
91 |
92 |
93 | @app.route("/")
94 | def index():
95 | return render_template("index.html")
96 |
97 |
98 | @app.route("/upload/", methods=["POST"])
99 | def upload_image(robot_idx):
100 | robot_idx = int(robot_idx)
101 | image_type = request.args.get("type", "")
102 | if image_type not in ["observation", "goal"]:
103 | return jsonify({"error": "Invalid image type"}), 400
104 |
105 | file = request.files.get("file")
106 | if file and file.filename:
107 | # Save the file as 'observation.jpg' or 'goal.jpg' in the uploads folder
108 | filename = f"{image_type}_{robot_idx}.jpg"
109 | file_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
110 | with open(file_path, "wb") as f:
111 | portalocker.lock(f, portalocker.LOCK_EX)
112 | f.write(file.read())
113 | portalocker.unlock(f)
114 | return (
115 | jsonify(
116 | {
117 | "message": "Image uploaded successfully",
118 | "type": image_type,
119 | "robot_index": robot_idx,
120 | }
121 | ),
122 | 200,
123 | )
124 | else:
125 | return jsonify({"error": "No file part"}), 400
126 |
127 |
128 | @app.route("/images//", methods=["GET"])
129 | def get_latest_image(robot_idx, image_type):
130 | robot_idx = int(robot_idx)
131 | if image_type not in ["observation", "goal"]:
132 | return jsonify({"error": "Invalid image type"}), 400
133 |
134 | image_path = os.path.join(
135 | app.config["UPLOAD_FOLDER"], f"{image_type}_{robot_idx}.jpg"
136 | )
137 | if os.path.exists(image_path):
138 | return send_file(image_path)
139 | else:
140 | return jsonify({"error": "Image not found"}), 404
141 |
142 |
143 | @app.route("/update_status/", methods=["POST"])
144 | def update_status(robot_idx):
145 | data = request.get_json()
146 | robot_video_feeds["feed" + robot_idx]["status"] = {
147 | "commanded_task": data.get("commanded_task", ""),
148 | "subgoal": data.get("subgoal", 0),
149 | "timestep": data.get("timestep", 0),
150 | "task_success": data.get("task_success", ""),
151 | }
152 | return jsonify({"message": "Status updated successfully"})
153 |
154 |
155 | @app.route("/get_status_data/")
156 | def get_status_data(robot_idx):
157 | return jsonify(robot_video_feeds["feed" + robot_idx]["status"])
158 |
159 |
160 | if __name__ == "__main__":
161 | os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
162 | app.run(debug=True, host="0.0.0.0", port=5000)
163 |
--------------------------------------------------------------------------------
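A hedged sketch of the client side of the viewer API above: uploading an observation frame and a status update for robot 0, then reading the status back. The viewer address and status values are placeholders; `run_client.py` below performs the image-upload half of this from a ROS image subscriber.

```python
import io

import numpy as np
import requests
from PIL import Image

VIEWER = "http://localhost:5000"  # placeholder; app.py serves on port 5000
ROBOT_IDX = 0

# Upload the latest observation image for robot 0.
frame = np.zeros((512, 512, 3), dtype=np.uint8)  # dummy camera frame
buf = io.BytesIO()
Image.fromarray(frame).save(buf, format="JPEG")
requests.post(
    f"{VIEWER}/upload/{ROBOT_IDX}?type=observation",
    files={"file": ("image.jpg", buf.getvalue(), "image/jpeg")},
)

# Push the rollout status displayed next to the feed.
requests.post(
    f"{VIEWER}/update_status/{ROBOT_IDX}",
    json={
        "commanded_task": "put the spoon on the towel",
        "subgoal": 1,
        "timestep": 42,
        "task_success": "N/A",
    },
)

# Read the status back.
print(requests.get(f"{VIEWER}/get_status_data/{ROBOT_IDX}").json())
```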
/data_collection/orchestrator/web_viewer/ros_client/run_client.py:
--------------------------------------------------------------------------------
1 | import rospy
2 | from sensor_msgs.msg import Image
3 | import time
4 | import numpy as np
5 | from PIL import Image as PILImage
6 | import requests
7 | import io
8 | import yaml
9 | from yamlinclude import YamlIncludeConstructor
10 | from absl import app, flags
11 | import os
12 |
13 | FLAGS = flags.FLAGS
14 | flags.DEFINE_string(
15 | "config_dir",
16 | None,
17 | "Path to config directory",
18 | required=True,
19 | )
20 |
21 |
22 | def run_ros_subscriber(_):
23 | YamlIncludeConstructor.add_to_loader_class(
24 | loader_class=yaml.FullLoader, base_dir=FLAGS.config_dir
25 | )
26 | with open(os.path.join(FLAGS.config_dir, "config.yaml")) as f:
27 | config = yaml.load(f, Loader=yaml.FullLoader)
28 |
29 | rospy.init_node("image_listener", anonymous=True)
30 |
31 | # Function to handle the image message once received
32 | def image_callback(message):
33 | rospy.loginfo("Received an image!")
34 | str_message = str(message)
35 | image_data = str_message[str_message.index("[") + 1 : str_message.index("]")]
36 | image_data = image_data.replace(",", "").split()
37 | image_data = [int(elem) for elem in image_data]
38 | image = np.array(image_data, dtype=np.uint8)
39 |
40 | # The shape of the image is 480 x 640
41 | image = image.reshape(480, 640, 3)
42 | image = image[:, :, ::-1] # convert from BGR to RGB
43 |
44 | img = PILImage.fromarray(image)
45 | img = img.resize((512, 512), PILImage.LANCZOS) # make image square 512x512
46 |
47 | buffer = io.BytesIO()
48 | img.save(buffer, format="JPEG")
49 | buffer.seek(0)
50 | files = {"file": ("image.jpg", buffer.getvalue(), "image/jpeg")}
51 |
52 | # We're sending this to our main web server
53 | url = (
54 | "http://"
55 | + config["general_params"]["web_viewer_ip"]
56 | + ":"
57 | + str(config["general_params"]["web_viewer_port"])
58 | + "/upload/"
59 | + str(config["general_params"]["robot_id"])
60 | + "?type=observation"
61 | )
62 | response = requests.post(url, files=files)
63 |
64 | rospy.Subscriber("/blue/image_raw", Image, image_callback)
65 |
66 | rospy.spin() # Spin until the node is shut down
67 |
68 |
69 | if __name__ == "__main__":
70 | app.run(run_ros_subscriber)
71 |
--------------------------------------------------------------------------------
/data_collection/orchestrator/web_viewer/templates/index.html:
--------------------------------------------------------------------------------
[HTML template content not recoverable from this extraction: the markup tags of the 171-line page were stripped, leaving only the page title "Image Feeds". This file is the browser dashboard rendered by web_viewer/app.py's index() route.]
--------------------------------------------------------------------------------
/data_collection/requirements.txt:
--------------------------------------------------------------------------------
1 | distrax==0.1.2
2 | flax==0.7.0
3 | transformers==4.33.1
4 | dlimp @ git+https://github.com/zhouzypaul/dlimp@2df85dc98a7d564fc9c7c5ff2bfda26361526483
5 | einops==0.6.1
6 | wandb==0.15.5
7 | tensorflow==2.15.0.post1
8 | tensorflow-probability==0.23.0
9 | ml-collections==0.1.0
10 | jax==0.4.20
11 | jaxlib==0.4.20
12 | optax==0.1.5
13 | diffusers==0.18.2
14 | ml-dtypes==0.2.0
15 | pyyaml-include==1.3.2
16 | slack_sdk
17 | Pillow==10.1
18 | rospkg==1.5.0
19 | pyyaml-include==1.3.2
20 | opencv-python==4.9.0.80
21 | funcsigs==1.0.2
22 | pyquaternion==0.9.9
23 | openai==1.14.1
24 | nltk==3.8.1
25 | edgeml @ git+https://github.com/youliangtan/edgeml.git
--------------------------------------------------------------------------------
/data_collection/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="orchestrator",
5 | packages=find_packages(),
6 | version="0.0.1",
7 | install_requires=[
8 | "absl-py",
9 | "diffusers[flax]",
10 | "ml_collections",
11 | "tensorflow",
12 | "wandb",
13 | "einops",
14 | ],
15 | )
16 |
--------------------------------------------------------------------------------
/data_collection/ssh_port_forward.sh:
--------------------------------------------------------------------------------
1 | ssh -L 6000:localhost:5000 -N -f -C -o ExitOnForwardFailure=yes $USER@128.32.162.191
2 | ssh -L 7000:localhost:6000 -N -f -C -o ExitOnForwardFailure=yes $USER@128.32.162.191
3 |
--------------------------------------------------------------------------------
/data_collection/start_robot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # check that OPENAI_API_KEY is set
4 | if [ -z "$OPENAI_API_KEY" ]; then
5 | echo "OPENAI_API_KEY is not set. Please set it before running the script."
6 | exit 1
7 | fi
8 |
9 | # Check if the first argument is provided
10 | if [ -z "$1" ]; then
11 | echo "No argument provided. Please provide a number between 0 and 4."
12 | exit 1
13 | fi
14 |
15 | # Check if the first argument is a valid number between 0 and 4
16 | if ! [[ "$1" =~ ^[0-4]$ ]]; then
17 | echo "Invalid argument. Please provide a number between 0 and 4."
18 | exit 1
19 | fi
20 |
21 | # Extract the first argument and the rest of the arguments
22 | n="$1"
23 | shift
24 | additional_args="$@"
25 |
26 | # If reset flag is true, execute the reset commands
27 | # echo "Resetting ports 7000 and 6000"
28 | # lsof -ti:7000 | xargs kill -9
29 | # lsof -ti:6000 | xargs kill -9
30 |
31 | # Execute the ssh port forwarding script
32 | echo "Executing: bash ssh_port_forward.sh"
33 | bash ssh_port_forward.sh
34 |
35 | # Construct the command based on the first argument
36 | config_dir="config/berkeley_robot_$n"
37 | command="python orchestrator/robot/main.py --config_dir $config_dir $additional_args"
38 |
39 | # Execute the Python command
40 | echo "Executing: $command"
41 | $command
42 |
--------------------------------------------------------------------------------
/media/autonomous_data_collection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/media/autonomous_data_collection.png
--------------------------------------------------------------------------------
/media/soar_logo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/media/soar_logo.jpeg
--------------------------------------------------------------------------------
/media/soar_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/media/soar_teaser.png
--------------------------------------------------------------------------------
/model_training/README.md:
--------------------------------------------------------------------------------
1 | # Model training code
2 |
3 | This directory contains a self-contained Python project for training goal-conditioned and language-conditioned policies on BridgeData and on Soar-Data.
4 |
5 | ## Installation
6 | Run the following in the current directory:
7 | ```bash
8 | pip install -e .
9 | pip install -r requirements.txt
10 | ```
11 |
12 | ## Structure
13 | - `experiments/`: Contains the main training script `train.py` and the configuration files `train_config.py` and `data_config.py`.
14 | - `jaxrl_m/`: The main library for training models with JAX.
15 | - `jaxrl_m/agents/`: Contains the implementation of the agents.
16 | - `jaxrl_m/data/`: Contains the data processing and data loading code.
17 |
18 | ## Training
19 | In the current directory, run
20 | ```bash
21 | bash experiments/scripts/launch.sh
22 | ```
23 | This will launch [train.py](experiments/train.py) with the default arguments specified in [train_config.py](experiments/configs/train_config.py) and [data_config.py](experiments/configs/data_config.py).
24 |
--------------------------------------------------------------------------------
/model_training/experiments/configs/data_config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import ml_collections
5 |
6 |
7 | # BridgeDatav2 metadata
8 |
9 | ACT_MEAN = [
10 | 1.9296819e-04,
11 | 1.3667766e-04,
12 | -1.4583133e-04,
13 | -1.8390431e-04,
14 | -3.0808983e-04,
15 | 2.7425270e-04,
16 | 5.9716219e-01,
17 | ]
18 |
19 | ACT_STD = [
20 | 0.00912848,
21 | 0.0127196,
22 | 0.01229497,
23 | 0.02606696,
24 | 0.02875283,
25 | 0.07807977,
26 | 0.48710242,
27 | ]
28 |
29 | ACT_MIN = [
30 | -0.0437546,
31 | -0.052831028,
32 | -0.035931006,
33 | -0.14489305,
34 | -0.15591072,
35 | -0.26039174,
36 | -0.780331,
37 | ] # 0.1% quantile
38 |
39 | ACT_MAX = [
40 | 0.04158026,
41 | 0.05223833,
42 | 0.05382493,
43 | 0.15559858,
44 | 0.142592,
45 | 0.25956747,
46 | 0.79311615,
47 | ] # 99.9% quantile
48 |
49 | ACTION_PROPRIO_METADATA = {
50 | "action": {
51 | "mean": np.array(ACT_MEAN),
52 | "std": np.array(ACT_STD),
53 | "min": np.array(ACT_MIN),
54 | "max": np.array(ACT_MAX),
55 | },
56 | # TODO compute these
57 | "proprio": {
58 | "mean": np.array(ACT_MEAN),
59 | "std": np.array(ACT_STD),
60 | "min": np.array(ACT_MIN),
61 | "max": np.array(ACT_MAX),
62 | },
63 | }
64 |
65 |
66 | def get_config(config_string):
67 | possible_structures = {
68 | "all": ml_collections.ConfigDict(
69 | {
70 | "pretraining_data": [
71 | "gs://gresearch/robotics/bridge/0.1.0/"
72 | ],
73 | "autonomous_data": [
74 | os.path.expanduser("~/tensorflow_datasets/soar_dataset/1.0.0"),
75 | ],
76 | "exclude": [],
77 | "sampling_weights": {
78 | "pretraining_data": 0.8,
79 | "autonomous_data_successes": 0.2,
80 | "autonomous_data_failures": 0.0,
81 | },
82 | }
83 | ),
84 | }
85 | return possible_structures[config_string]
86 |
--------------------------------------------------------------------------------
/model_training/experiments/configs/train_config.py:
--------------------------------------------------------------------------------
1 | from ml_collections import ConfigDict
2 |
3 |
4 | def get_config(config_string):
5 | base_real_config = dict(
6 | batch_size=512,
7 | num_steps=int(1001000),
8 | log_interval=1000,
9 | eval_interval=25000,
10 | save_interval=25000,
11 | num_val_trajs=8,
12 | num_val_batches=8,
13 | save_dir="~/jaxrl_log",
14 | resume_path="",
15 | seed=42,
16 | )
17 |
18 | base_data_config = dict(
19 | # action_merge_horizon=2,
20 | shuffle_buffer_size=25000,
21 | augment=True,
22 | augment_next_obs_goal_differently=False,
23 | augment_kwargs=dict(
24 | random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
25 | random_brightness=[0.2],
26 | random_contrast=[0.8, 1.2],
27 | random_saturation=[0.8, 1.2],
28 | random_hue=[0.1],
29 | augment_order=[
30 | "random_resized_crop",
31 | "random_brightness",
32 | "random_contrast",
33 | "random_saturation",
34 | "random_hue",
35 | ],
36 | ),
37 | )
38 |
39 | possible_structures = {
40 | "gc_bc_offline_bridge": ConfigDict(
41 | dict(
42 | agent="gc_bc",
43 | agent_kwargs=dict(
44 | network_kwargs=dict(
45 | hidden_dims=(
46 | 256,
47 | 256,
48 | 256,
49 | ),
50 | dropout_rate=0.1,
51 | ),
52 | policy_kwargs=dict(
53 | tanh_squash_distribution=True,
54 | std_parameterization="fixed",
55 | fixed_std=[1, 1, 1, 1, 1, 1, 0.1],
56 | ),
57 | early_goal_concat=True,
58 | shared_goal_encoder=True,
59 | use_proprio=False,
60 | learning_rate=3e-4,
61 | warmup_steps=2000,
62 | decay_steps=int(2e6),
63 | ),
64 | dataset_kwargs=dict(
65 | goal_relabeling_strategy="geometric",
66 | goal_relabeling_kwargs=dict(reached_proportion=0.0, discount=0.98),
67 | normalization_type="normal",
68 | **base_data_config,
69 | ),
70 | encoder="resnetv1-34-bridge", # diff: bridge release use resnet-50
71 | encoder_kwargs=dict(
72 | pooling_method="avg",
73 | add_spatial_coordinates=True,
74 | act="swish",
75 | ),
76 | **base_real_config,
77 | )
78 | ),
79 | }
80 |
81 | return possible_structures[config_string]
82 |
--------------------------------------------------------------------------------
/model_training/experiments/scripts/launch.sh:
--------------------------------------------------------------------------------
1 | CMD="python experiments/train.py \
2 | --config experiments/configs/train_config.py:gc_bc_offline_bridge \
3 | --data_config experiments/configs/data_config.py:all \
4 | "
5 |
6 | # take in command line args too
7 | $CMD $@
8 |
--------------------------------------------------------------------------------
/model_training/experiments/train.py:
--------------------------------------------------------------------------------
1 | import random
2 | import traceback
3 | from functools import partial
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 | import tensorflow as tf
9 | import tqdm
10 | import wandb
11 | from absl import app, flags, logging
12 | from flax.training import checkpoints
13 | from ml_collections import config_flags
14 |
15 | from jaxrl_m.agents import agents
16 | from jaxrl_m.common.common import shard_batch
17 | from jaxrl_m.common.wandb import WandBLogger
18 | from jaxrl_m.data.dataset import WidowXDataset
19 | from jaxrl_m.utils.timer_utils import Timer
20 | from jaxrl_m.vision import encoders
21 |
22 | try:
23 | from jax_smi import initialise_tracking # type: ignore
24 |
25 | initialise_tracking()
26 | except ImportError:
27 | pass
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | flags.DEFINE_string("exp_name", "", "Experiment name.")
32 | flags.DEFINE_list("tag", list(), "Name of experiment")
33 | flags.DEFINE_string("group", None, "Group of the wandb experiments")
34 | flags.DEFINE_bool("debug", False, "Debug config")
35 |
36 | config_flags.DEFINE_config_file(
37 | "config",
38 | None,
39 | "File path to the training hyperparameter configuration.",
40 | lock_config=False,
41 | )
42 |
43 | config_flags.DEFINE_config_file(
44 | "data_config",
45 | None,
46 | "File path to the bridgedata configuration.",
47 | lock_config=False,
48 | )
49 |
50 |
51 | def main(_):
52 | devices = jax.local_devices()
53 | num_devices = len(devices)
54 | assert FLAGS.config.batch_size % num_devices == 0
55 |
56 |     # we shard the leading dimension (batch dimension) across all devices evenly
57 | sharding = jax.sharding.PositionalSharding(devices)
58 | shard_fn = partial(shard_batch, sharding=sharding)
59 |
60 | # prevent tensorflow from using GPUs
61 | tf.config.set_visible_devices([], "GPU")
62 |
63 | # set up wandb and logging
64 | wandb_config = WandBLogger.get_default_config()
65 | wandb_config.update(
66 | {
67 | "project": f"jaxrl_{FLAGS.config.agent}_autonomous_data",
68 | "exp_descriptor": FLAGS.exp_name,
69 | "tag": FLAGS.tag,
70 | "group": FLAGS.group,
71 | }
72 | )
73 | wandb_logger = WandBLogger(
74 | wandb_config=wandb_config,
75 | variant=FLAGS.config.to_dict(),
76 | debug=FLAGS.debug,
77 | )
78 |
79 | save_dir = tf.io.gfile.join(
80 | FLAGS.config.save_dir,
81 | wandb_logger.config.project,
82 | f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}",
83 | )
84 |
85 | # load datasets
86 | random.seed(FLAGS.config.seed)
87 | train_paths = []
88 | if FLAGS.data_config.sampling_weights.pretraining_data > 0:
89 | train_paths += [FLAGS.data_config.pretraining_data]
90 | if FLAGS.data_config.sampling_weights.autonomous_data_successes > 0:
91 | train_paths += [FLAGS.data_config.autonomous_data]
92 | if FLAGS.data_config.sampling_weights.autonomous_data_failures > 0:
93 | train_paths += [FLAGS.data_config.autonomous_data]
94 |
95 | # create sample weights for training
96 | train_sample_weights = [
97 | FLAGS.data_config.sampling_weights["pretraining_data"],
98 | FLAGS.data_config.sampling_weights["autonomous_data_successes"],
99 | FLAGS.data_config.sampling_weights["autonomous_data_failures"],
100 | ]
101 | train_sample_weights = [
102 | weight for weight in train_sample_weights if weight > 0
103 | ] # remove 0s from the sample weights
104 | assert (
105 | sum(train_sample_weights) == 1.0
106 | ), f"Sample weights must sum to 1.0, got {sum(train_sample_weights)}"
107 |
108 | # pick out the splits needed from the dataset
109 | train_data_splits = []
110 | if FLAGS.data_config.sampling_weights.pretraining_data > 0:
111 | train_data_splits.append("train")
112 | if FLAGS.data_config.sampling_weights.autonomous_data_successes > 0:
113 | train_data_splits.append("success")
114 | if FLAGS.data_config.sampling_weights.autonomous_data_failures > 0:
115 | train_data_splits.append("failure")
116 |
117 | train_data = WidowXDataset(
118 | train_paths,
119 | FLAGS.config.seed,
120 | batch_size=FLAGS.config.batch_size,
121 | train=True,
122 | sample_weights=train_sample_weights,
123 | data_splits=train_data_splits,
124 | **FLAGS.config.dataset_kwargs,
125 | )
126 | val_data = WidowXDataset(
127 | FLAGS.data_config.pretraining_data,
128 | FLAGS.config.seed,
129 | batch_size=FLAGS.config.batch_size,
130 | train=False,
131 | sample_weights=None,
132 | data_splits=["val"],
133 | **FLAGS.config.dataset_kwargs,
134 | )
135 |
136 | train_data_iter = map(shard_fn, train_data.iterator())
137 |
138 | example_batch = next(train_data_iter)
139 | logging.info(f"Batch size: {example_batch['observations']['image'].shape[0]}")
140 | logging.info(f"Number of devices: {num_devices}")
141 | logging.info(
142 | f"Batch size per device: {example_batch['observations']['image'].shape[0] // num_devices}"
143 | )
144 |
145 | # define encoder
146 | encoder_def = encoders[FLAGS.config.encoder](**FLAGS.config.encoder_kwargs)
147 |
148 | # initialize agent
149 | rng = jax.random.PRNGKey(FLAGS.config.seed)
150 | rng, construct_rng = jax.random.split(rng)
151 | agent = agents[FLAGS.config.agent].create(
152 | rng=construct_rng,
153 | observations=example_batch["observations"],
154 | goals=example_batch["goals"],
155 | actions=example_batch["actions"],
156 | encoder_def=encoder_def,
157 | **FLAGS.config.agent_kwargs,
158 | )
159 | if FLAGS.config.get("resume_path", "") != "":
160 | agent = checkpoints.restore_checkpoint(FLAGS.config.resume_path, target=agent)
161 | # replicate agent across devices
162 | # need the jnp.array to avoid a bug where device_put doesn't recognize primitives
163 | agent = jax.device_put(jax.tree_map(jnp.array, agent), sharding.replicate())
164 |
165 | timer = Timer()
166 | for i in tqdm.tqdm(range(int(FLAGS.config.num_steps - agent.state.step))):
167 | try:
168 | timer.tick("total")
169 |
170 | timer.tick("dataset")
171 | batch = shard_batch(next(train_data_iter), sharding)
172 | timer.tock("dataset")
173 |
174 | timer.tick("train")
175 | agent, update_info = agent.update(batch)
176 | timer.tock("train")
177 |
178 | if agent.state.step % FLAGS.config.eval_interval == 0:
179 | logging.info("Validation...")
180 | timer.tick("val")
181 |
182 | # plot debug metrics of validation data
183 | val_metrics = []
184 | j = 0
185 | val_iter = map(shard_fn, val_data.iterator())
186 | for val_batch in val_iter:
187 | rng, val_rng = jax.random.split(rng)
188 | val_metrics.append(agent.get_debug_metrics(val_batch, seed=val_rng))
189 | j += 1
190 | if j >= FLAGS.config.num_val_batches:
191 | break
192 | val_metrics = jax.tree_map(lambda *xs: np.mean(xs), *val_metrics)
193 | wandb_logger.log({"validation": val_metrics}, step=agent.state.step)
194 |
195 | timer.tock("val")
196 |
197 | if agent.state.step % FLAGS.config.save_interval == 0:
198 | logging.info("Saving checkpoint...")
199 | checkpoint_path = checkpoints.save_checkpoint(
200 | save_dir, agent, step=agent.state.step, keep=1e6
201 | )
202 | logging.info("Saved checkpoint to %s", checkpoint_path)
203 |
204 | timer.tock("total")
205 |
206 | if agent.state.step % FLAGS.config.log_interval == 0:
207 | update_info = jax.device_get(update_info)
208 | wandb_logger.log({"training": update_info}, step=agent.state.step)
209 |
210 | wandb_logger.log(
211 | {"timer": timer.get_average_times()}, step=agent.state.step
212 | )
213 | except tf.errors.OpError as e:
214 |             # sometimes tfds has trouble communicating with the cloud storage bucket
215 | print(f"Error in iteration {i}: {e}")
216 | print("Skipping to next iteration...")
217 | traceback.print_exc()
218 |
219 | # avoid timer tock errors
220 | timer.force_tock_everything()
221 |
222 | continue
223 |
224 |
225 | if __name__ == "__main__":
226 | app.run(main)
227 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .continuous.gc_bc import GCBCAgent
2 |
3 | agents = {
4 | "gc_bc": GCBCAgent,
5 | }
6 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/agents/continuous/gc_bc.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from functools import partial
3 | from typing import Any, Optional
4 |
5 | import flax
6 | import flax.linen as nn
7 | import jax
8 | import jax.numpy as jnp
9 | import numpy as np
10 | from flax.core import FrozenDict
11 |
12 | from jaxrl_m.common.common import JaxRLTrainState, ModuleDict, nonpytree_field
13 | from jaxrl_m.common.optimizers import make_optimizer
14 | from jaxrl_m.common.encoding import GCEncodingWrapper, LCEncodingWrapper
15 | from jaxrl_m.common.optimizers import make_optimizer
16 | from jaxrl_m.common.typing import Batch, PRNGKey
17 | from jaxrl_m.networks.actor_critic_nets import Policy
18 | from jaxrl_m.networks.mlp import MLP
19 |
20 |
21 | class GCBCAgent(flax.struct.PyTreeNode):
22 | state: JaxRLTrainState
23 | lr_schedule: Any = nonpytree_field()
24 |
25 | @partial(jax.jit, static_argnames="pmap_axis")
26 | def update(self, batch: Batch, pmap_axis: str = None):
27 | def loss_fn(params, rng):
28 | rng, key = jax.random.split(rng)
29 | dist = self.state.apply_fn(
30 | {"params": params},
31 | (batch["observations"], batch["goals"]),
32 | temperature=1.0,
33 | train=True,
34 | rngs={"dropout": key},
35 | name="actor",
36 | )
37 | pi_actions = dist.mode()
38 | log_probs = dist.log_prob(batch["actions"])
39 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1)
40 | actor_loss = -(log_probs).mean()
41 | actor_std = dist.stddev().mean(axis=1)
42 |
43 | return actor_loss, {
44 | "actor_loss": actor_loss,
45 | "mse": mse.mean(),
46 | "log_probs": log_probs.mean(),
47 | "pi_actions": pi_actions.mean(),
48 | "mean_std": actor_std.mean(),
49 | "max_std": actor_std.max(),
50 | }
51 |
52 | # compute gradients and update params
53 | new_state, info = self.state.apply_loss_fns(
54 | loss_fn,
55 | pmap_axis=pmap_axis,
56 | has_aux=True,
57 | )
58 |
59 | # log learning rates
60 | info["lr"] = self.lr_schedule(self.state.step)
61 |
62 | return self.replace(state=new_state), info
63 |
64 | @partial(jax.jit, static_argnames="argmax")
65 | def sample_actions(
66 | self,
67 | observations: np.ndarray,
68 | goals: np.ndarray,
69 | *,
70 | seed: Optional[PRNGKey] = None,
71 | temperature: float = 1.0,
72 | argmax=False,
73 | ) -> jnp.ndarray:
74 | dist = self.state.apply_fn(
75 | {"params": self.state.params},
76 | (observations, goals),
77 | temperature=temperature,
78 | name="actor",
79 | )
80 | if argmax:
81 | actions = dist.mode()
82 | else:
83 | actions = dist.sample(seed=seed)
84 | return actions, dist.mode()
85 |
86 | @jax.jit
87 | def get_debug_metrics(self, batch, **kwargs):
88 | dist = self.state.apply_fn(
89 | {"params": self.state.params},
90 | (batch["observations"], batch["goals"]),
91 | temperature=1.0,
92 | name="actor",
93 | )
94 | pi_actions = dist.mode()
95 | log_probs = dist.log_prob(batch["actions"])
96 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1)
97 |
98 | return {
99 | "mse": mse,
100 | "log_probs": log_probs,
101 | "pi_actions": pi_actions,
102 | }
103 |
104 | @classmethod
105 | def create(
106 | cls,
107 | rng: PRNGKey,
108 | # example arrays for model init
109 | observations: FrozenDict,
110 | actions: jnp.ndarray,
111 | goals: FrozenDict,
112 | # agent config
113 | encoder_def: nn.Module,
114 | language_conditioned: bool = False,
115 | # should only be set if not language conditioned
116 | shared_goal_encoder: Optional[bool] = None,
117 | early_goal_concat: Optional[bool] = None,
118 | # other shared network config
119 | use_proprio: bool = False,
120 | network_kwargs: dict = {
121 | "hidden_dims": [256, 256],
122 | },
123 | policy_kwargs: dict = {
124 | "tanh_squash_distribution": False,
125 | "std_parameterization": "exp",
126 | },
127 | # optimizer config
128 | learning_rate: float = 3e-4,
129 | warmup_steps: int = 1000,
130 | decay_steps: int = 1000000,
131 | freeze_encoder: bool = False,
132 | ):
133 | if not language_conditioned:
134 | if shared_goal_encoder is None or early_goal_concat is None:
135 | raise ValueError(
136 | "If not language conditioned, shared_goal_encoder and early_goal_concat must be set"
137 | )
138 |
139 | if early_goal_concat:
140 | # passing None as the goal encoder causes early goal concat
141 | goal_encoder_def = None
142 | else:
143 | if shared_goal_encoder:
144 | goal_encoder_def = encoder_def
145 | else:
146 | goal_encoder_def = copy.deepcopy(encoder_def)
147 |
148 | encoder_def = GCEncodingWrapper(
149 | encoder=encoder_def,
150 | goal_encoder=goal_encoder_def,
151 | use_proprio=use_proprio,
152 | stop_gradient=freeze_encoder,
153 | )
154 | else:
155 | if shared_goal_encoder is not None or early_goal_concat is not None:
156 | raise ValueError(
157 | "If language conditioned, shared_goal_encoder and early_goal_concat must not be set"
158 | )
159 | encoder_def = LCEncodingWrapper(
160 | encoder=encoder_def,
161 | use_proprio=use_proprio,
162 | stop_gradient=freeze_encoder,
163 | )
164 |
165 | network_kwargs["activate_final"] = True
166 | networks = {
167 | "actor": Policy(
168 | encoder_def,
169 | MLP(**network_kwargs),
170 | action_dim=actions.shape[-1],
171 | **policy_kwargs,
172 | )
173 | }
174 |
175 | model_def = ModuleDict(networks)
176 |
177 | # create optimizer
178 | tx, lr_schedule = make_optimizer(
179 | learning_rate=learning_rate,
180 | warmup_steps=warmup_steps,
181 | cosine_decay_steps=decay_steps if decay_steps is not None else None,
182 | weight_decay=0.001,
183 | beta2=0.98,
184 | clip_grad_norm=1.0,
185 | return_lr_schedule=True,
186 | )
187 |
188 | rng, init_rng = jax.random.split(rng)
189 | params = jax.jit(model_def.init)(init_rng, actor=[(observations, goals)])[
190 | "params"
191 | ]
192 |
193 | rng, create_rng = jax.random.split(rng)
194 | state = JaxRLTrainState.create(
195 | apply_fn=model_def.apply,
196 | params=params,
197 | txs=tx,
198 | rng=create_rng,
199 | )
200 |
201 | return cls(state, lr_schedule)
202 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/common/encoding.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Tuple
2 |
3 | import flax
4 | import flax.linen as nn
5 | import jax
6 | import jax.numpy as jnp
7 | from einops import rearrange, repeat
8 |
9 |
10 | class EncodingWrapper(nn.Module):
11 | """
12 | Encodes observations into a single flat encoding, adding additional
13 | functionality for adding proprioception and stopping the gradient.
14 |
15 | Args:
16 | encoder: The encoder network.
17 | use_proprio: Whether to concatenate proprioception (after encoding).
18 | stop_gradient: Whether to stop the gradient after the encoder.
19 | """
20 |
21 | encoder: nn.Module
22 | use_proprio: bool
23 | stop_gradient: bool
24 | enable_stacking: bool = False
25 |
26 | def __call__(self, observations: Dict[str, jnp.ndarray]) -> jnp.ndarray:
27 | if isinstance(observations, flax.core.FrozenDict) or isinstance(
28 | observations, dict
29 | ):
30 | obs = observations["image"]
31 | if self.enable_stacking:
32 | # Combine stacking and channels into a single dimension
33 | if len(obs.shape) == 4:
34 | obs = rearrange(obs, "T H W C -> H W (T C)")
35 | if len(obs.shape) == 5:
36 | obs = rearrange(obs, "B T H W C -> B H W (T C)")
37 |
38 | else:
39 | obs = observations
40 |
41 | encoding = self.encoder(obs)
42 |
43 | if self.use_proprio:
44 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1)
45 | if self.stop_gradient:
46 | encoding = jax.lax.stop_gradient(encoding)
47 | return encoding
48 |
49 |
50 | class GCEncodingWrapper(nn.Module):
51 | """
52 | Encodes observations and goals into a single flat encoding. Handles all the
53 | logic about when/how to combine observations and goals.
54 |
55 | Takes a tuple (observations, goals) as input.
56 |
57 | Args:
58 | encoder: The encoder network for observations.
59 | goal_encoder: The encoder to use for goals (optional). If None, early
60 | goal concatenation is used, i.e. the goal is concatenated to the
61 | observation channel-wise before passing it through the encoder.
62 | use_proprio: Whether to concatenate proprioception (after encoding).
63 | stop_gradient: Whether to stop the gradient after the encoder.
64 | """
65 |
66 | encoder: nn.Module
67 | goal_encoder: Optional[nn.Module]
68 | use_proprio: bool
69 | stop_gradient: bool
70 |
71 | def __call__(
72 | self,
73 | observations_and_goals: Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]],
74 | ) -> jnp.ndarray:
75 | observations, goals = observations_and_goals
76 |
77 | if len(observations["image"].shape) == 5:
78 | # obs history case
79 | batch_size, obs_horizon = observations["image"].shape[:2]
80 | # fold batch_size into obs_horizon to encode each frame separately
81 | obs_image = rearrange(observations["image"], "B T H W C -> (B T) H W C")
82 | # repeat goals so that there's a goal for each frame
83 | goal_image = repeat(
84 | goals["image"], "B H W C -> (B repeat) H W C", repeat=obs_horizon
85 | )
86 | else:
87 | obs_image = observations["image"]
88 | goal_image = goals["image"]
89 |
90 | if self.goal_encoder is None:
91 | # early goal concat
92 | encoder_inputs = jnp.concatenate([obs_image, goal_image], axis=-1)
93 | encoding = self.encoder(encoder_inputs)
94 | else:
95 | # late fusion
96 | encoding = self.encoder(obs_image)
97 | goal_encoding = self.goal_encoder(goals["image"])
98 | encoding = jnp.concatenate([encoding, goal_encoding], axis=-1)
99 |
100 | if len(observations["image"].shape) == 5:
101 | # unfold obs_horizon from batch_size
102 | encoding = rearrange(
103 | encoding, "(B T) F -> B (T F)", B=batch_size, T=obs_horizon
104 | )
105 |
106 | if self.use_proprio:
107 | if len(encoding.shape) == 2 and len(observations["proprio"].shape) == 3:
108 | # edge case
109 | encoding = jnp.concatenate(
110 | [encoding, observations["proprio"][:, 0, :]], axis=-1
111 | )
112 | else:
113 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1)
114 |
115 | if self.stop_gradient:
116 | encoding = jax.lax.stop_gradient(encoding)
117 |
118 | return encoding
119 |
120 |
121 | class LCEncodingWrapper(nn.Module):
122 | """
123 | Encodes observations and language instructions into a single flat encoding.
124 |
125 | Takes a tuple (observations, goals) as input, where goals contains the language instruction.
126 |
127 | Args:
128 | encoder: The encoder network for observations.
129 | use_proprio: Whether to concatenate proprioception (after encoding).
130 | stop_gradient: Whether to stop the gradient after the encoder.
131 | """
132 |
133 | encoder: nn.Module
134 | use_proprio: bool
135 | stop_gradient: bool
136 |
137 | def __call__(
138 | self,
139 | observations_and_goals: Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]],
140 | ) -> jnp.ndarray:
141 | observations, goals = observations_and_goals
142 |
143 | if len(observations["image"].shape) == 5:
144 | # obs history case
145 | batch_size, obs_horizon = observations["image"].shape[:2]
146 | # fold batch_size into obs_horizon to encode each frame separately
147 | obs_image = rearrange(observations["image"], "B T H W C -> (B T) H W C")
148 | # repeat language so that there's an instruction for each frame
149 | language = repeat(
150 | goals["language"], "B E -> (B repeat) E", repeat=obs_horizon
151 | )
152 | else:
153 | obs_image = observations["image"]
154 | language = goals["language"]
155 |
156 | encoding = self.encoder(obs_image, cond_var=language)
157 |
158 | if len(observations["image"].shape) == 5:
159 | # unfold obs_horizon from batch_size
160 | encoding = rearrange(
161 | encoding, "(B T) F -> B (T F)", B=batch_size, T=obs_horizon
162 | )
163 |
164 | if self.use_proprio:
165 | encoding = jnp.concatenate([encoding, observations["proprio"]], axis=-1)
166 |
167 | if self.stop_gradient:
168 | encoding = jax.lax.stop_gradient(encoding)
169 |
170 | return encoding
171 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/common/optimizers.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import optax
4 |
5 |
6 | def make_optimizer(
7 | learning_rate: float = 3e-4,
8 | warmup_steps: int = 0,
9 | cosine_decay_steps: Optional[int] = None,
10 | weight_decay: Optional[float] = None,
11 | beta2: Optional[float] = None,
12 | clip_grad_norm: Optional[float] = None,
13 | return_lr_schedule: bool = False,
14 | ) -> optax.GradientTransformation:
15 | if cosine_decay_steps is not None:
16 | learning_rate_schedule = optax.warmup_cosine_decay_schedule(
17 | init_value=0.0,
18 | peak_value=learning_rate,
19 | warmup_steps=warmup_steps,
20 | decay_steps=cosine_decay_steps,
21 | end_value=0.0,
22 | )
23 | else:
24 | learning_rate_schedule = optax.join_schedules(
25 | [
26 | optax.linear_schedule(0.0, learning_rate, warmup_steps),
27 | optax.constant_schedule(learning_rate),
28 | ],
29 | [warmup_steps],
30 | )
31 |
32 | # Define optimizers
33 | @optax.inject_hyperparams
34 | def optimizer(learning_rate: float, weight_decay: Optional[float]):
35 | optimizer_stages = []
36 |
37 | if clip_grad_norm is not None:
38 | print("Info: turning on gradient clipping")
39 | optimizer_stages.append(optax.clip_by_global_norm(clip_grad_norm))
40 |
41 | if weight_decay is not None:
42 | optimizer_stages.append(
43 | optax.adamw(
44 | learning_rate=learning_rate,
45 | weight_decay=weight_decay,
46 | b2=0.999 if beta2 is None else beta2,
47 | )
48 | )
49 | else:
50 | optimizer_stages.append(
51 | optax.adam(
52 | learning_rate=learning_rate, b2=0.999 if beta2 is None else beta2
53 | )
54 | )
55 |
56 | return optax.chain(*optimizer_stages)
57 |
58 | if return_lr_schedule:
59 | return (
60 | optimizer(learning_rate=learning_rate_schedule, weight_decay=weight_decay),
61 | learning_rate_schedule,
62 | )
63 | else:
64 | return optimizer(
65 | learning_rate=learning_rate_schedule, weight_decay=weight_decay
66 | )
67 |
--------------------------------------------------------------------------------
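A small hedged sketch of `make_optimizer` with the same settings the GC-BC agent uses: linear warmup into cosine decay, AdamW with weight decay, and global-norm gradient clipping. The toy parameters and gradients are placeholders; `return_lr_schedule=True` also returns the schedule so the current learning rate can be logged.

```python
import jax.numpy as jnp
import optax

from jaxrl_m.common.optimizers import make_optimizer

# Mirrors the values passed in GCBCAgent.create / train_config.py.
tx, lr_schedule = make_optimizer(
    learning_rate=3e-4,
    warmup_steps=2000,
    cosine_decay_steps=int(2e6),
    weight_decay=0.001,
    beta2=0.98,
    clip_grad_norm=1.0,
    return_lr_schedule=True,
)

params = {"w": jnp.zeros((3, 3))}  # toy parameter pytree
opt_state = tx.init(params)

grads = {"w": jnp.ones((3, 3))}    # toy gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

print(lr_schedule(0), lr_schedule(2000))  # 0.0 at step 0, peak LR right after warmup
```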
/model_training/jaxrl_m/common/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Sequence, Union
2 |
3 | import flax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | PRNGKey = Any
9 | Params = flax.core.FrozenDict[str, Any]
10 | Shape = Sequence[int]
11 | Dtype = Any # this could be a real type?
12 | InfoDict = Dict[str, float]
13 | Array = Union[np.ndarray, jnp.ndarray, tf.Tensor]
14 | Data = Union[Array, Dict[str, "Data"]]
15 | Batch = Dict[str, Data]
16 | # A method to be passed into TrainState.__call__
17 | ModuleMethod = Union[str, Callable, None]
18 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/common/wandb.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import tempfile
3 | from copy import copy
4 | from socket import gethostname
5 |
6 | import absl.flags as flags
7 | import ml_collections
8 | import wandb
9 |
10 |
11 | def _recursive_flatten_dict(d: dict):
12 | keys, values = [], []
13 | for key, value in d.items():
14 | if isinstance(value, dict):
15 | sub_keys, sub_values = _recursive_flatten_dict(value)
16 | keys += [f"{key}/{k}" for k in sub_keys]
17 | values += sub_values
18 | else:
19 | keys.append(key)
20 | values.append(value)
21 | return keys, values
22 |
23 |
24 | class WandBLogger(object):
25 | @staticmethod
26 | def get_default_config():
27 | config = ml_collections.ConfigDict()
28 | config.project = "jaxrl_m" # WandB Project Name
29 | config.entity = ml_collections.config_dict.FieldReference(None, field_type=str)
30 | # Which entity to log as (default: your own user)
31 | config.exp_descriptor = "" # Run name (doesn't have to be unique)
32 | # Unique identifier for run (will be automatically generated unless
33 | # provided)
34 | config.unique_identifier = ""
35 | config.group = None
36 | return config
37 |
38 | def __init__(
39 | self,
40 | wandb_config,
41 | variant,
42 | wandb_output_dir=None,
43 | debug=False,
44 | ):
45 | self.config = wandb_config
46 | if self.config.unique_identifier == "":
47 | self.config.unique_identifier = datetime.datetime.now().strftime(
48 | "%Y%m%d_%H%M%S"
49 | )
50 |
51 | self.config.experiment_id = self.experiment_id = (
52 | f"{self.config.exp_descriptor}_{self.config.unique_identifier}" # NOQA
53 | )
54 |
55 | print(self.config)
56 |
57 | if wandb_output_dir is None:
58 | wandb_output_dir = tempfile.mkdtemp()
59 |
60 | self._variant = copy(variant)
61 |
62 | if "hostname" not in self._variant:
63 | self._variant["hostname"] = gethostname()
64 |
65 | if debug:
66 | mode = "disabled"
67 | else:
68 | mode = "online"
69 |
70 | self.run = wandb.init(
71 | config=self._variant,
72 | project=self.config.project,
73 | entity=self.config.entity,
74 | group=self.config.group,
75 | tags=self.config.tag,
76 | dir=wandb_output_dir,
77 | id=self.config.experiment_id,
78 | save_code=True,
79 | mode=mode,
80 | )
81 |
82 | if flags.FLAGS.is_parsed():
83 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS}
84 | else:
85 | flag_dict = {}
86 | for k in flag_dict:
87 | if isinstance(flag_dict[k], ml_collections.ConfigDict):
88 | flag_dict[k] = flag_dict[k].to_dict()
89 | wandb.config.update(flag_dict)
90 |
91 | def log(self, data: dict, step: int = None):
92 | data_flat = _recursive_flatten_dict(data)
93 | data = {k: v for k, v in zip(*data_flat)}
94 | wandb.log(data, step=step)
95 |
--------------------------------------------------------------------------------
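A minimal hedged sketch of using `WandBLogger` outside of `train.py`; with `debug=True` wandb runs in disabled mode, so nothing is uploaded and no account is needed. The project name and logged values are placeholders.

```python
from jaxrl_m.common.wandb import WandBLogger

config = WandBLogger.get_default_config()
config.update(
    {
        "project": "jaxrl_m_sandbox",    # placeholder project name
        "exp_descriptor": "smoke_test",  # run name
        "tag": [],                       # wandb tags
        "group": None,
    }
)

logger = WandBLogger(
    wandb_config=config,
    variant={"batch_size": 512},  # hyperparameters recorded with the run
    debug=True,                   # disabled mode: nothing is uploaded
)

# Nested dicts are flattened to keys like "training/actor_loss".
logger.log({"training": {"actor_loss": 0.1, "mse": 0.02}}, step=0)
```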
/model_training/jaxrl_m/data/text_processing.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import jax.numpy as jnp
4 | import numpy as np
5 | import tensorflow as tf
6 | from flax.core import FrozenDict
7 |
8 | MULTI_MODULE = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
9 |
10 |
11 | class TextProcessor:
12 | """
13 | Base class for text tokenization or text embedding.
14 | """
15 |
16 | def encode(self, strings):
17 | pass
18 |
19 |
20 | class HFTokenizer(TextProcessor):
21 | def __init__(
22 | self,
23 | tokenizer_name: str,
24 | tokenizer_kwargs: Optional[dict] = {
25 | "max_length": 64,
26 | "padding": "max_length",
27 | "truncation": True,
28 | "return_tensors": "np",
29 | },
30 | encode_with_model: bool = False,
31 | ):
32 | from transformers import AutoTokenizer, FlaxAutoModel # lazy import
33 |
34 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
35 | self.tokenizer_kwargs = tokenizer_kwargs
36 | self.encode_with_model = encode_with_model
37 | if self.encode_with_model:
38 | self.model = FlaxAutoModel.from_pretrained(tokenizer_name)
39 |
40 | def encode(self, strings):
41 | # this creates another nested layer with "input_ids", "attention_mask", etc.
42 | inputs = self.tokenizer(
43 | strings,
44 | **self.tokenizer_kwargs,
45 | )
46 | if self.encode_with_model:
47 | return np.array(self.model(**inputs).last_hidden_state)
48 | else:
49 | return FrozenDict(inputs)
50 |
51 |
52 | class MuseEmbedding(TextProcessor):
53 | def __init__(self):
54 | import tensorflow_hub as hub # lazy import
55 | import tensorflow_text # required for muse
56 |
57 | self.muse_model = hub.load(MULTI_MODULE)
58 |
59 | def encode(self, strings):
60 | with tf.device("/cpu:0"):
61 | return self.muse_model(strings).numpy()
62 |
63 |
64 | class CLIPTextProcessor(TextProcessor):
65 | def __init__(
66 | self,
67 | tokenizer_kwargs: Optional[dict] = {
68 | "max_length": 64,
69 | "padding": "max_length",
70 | "truncation": True,
71 | "return_tensors": "np",
72 | },
73 | ):
74 | from transformers import CLIPProcessor
75 |
76 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
77 | self.kwargs = tokenizer_kwargs
78 |
79 | def encode(self, strings):
80 | inputs = self.processor(
81 | text=strings,
82 | **self.kwargs,
83 | )
84 | inputs["position_ids"] = jnp.expand_dims(
85 | jnp.arange(inputs["input_ids"].shape[1]), axis=0
86 | ).repeat(inputs["input_ids"].shape[0], axis=0)
87 | return FrozenDict(inputs)
88 |
89 |
90 | text_processors = {
91 | "hf_tokenizer": HFTokenizer,
92 | "muse_embedding": MuseEmbedding,
93 | "clip_processor": CLIPTextProcessor,
94 | }
95 |
--------------------------------------------------------------------------------
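A hedged sketch of the text processors above. The MUSE variant needs `tensorflow_hub` and `tensorflow_text` installed and downloads the universal-sentence-encoder on first use; the Hugging Face tokenizer name is an illustrative example, not a value taken from the repo configs.

```python
from jaxrl_m.data.text_processing import text_processors

# MUSE sentence embeddings: (batch, 512) numpy array.
muse = text_processors["muse_embedding"]()
embeddings = muse.encode(["put the spoon on the towel", "open the drawer"])
print(embeddings.shape)  # (2, 512)

# Hugging Face tokenizer (example model name): returns a FrozenDict with
# "input_ids", "attention_mask", ... padded to the default max_length of 64.
hf = text_processors["hf_tokenizer"](tokenizer_name="distilbert-base-uncased")
tokens = hf.encode(["put the spoon on the towel"])
print(tokens["input_ids"].shape)  # (1, 64)
```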
/model_training/jaxrl_m/data/tf_augmentations.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping
2 |
3 | import tensorflow as tf
4 | from ml_collections import ConfigDict
5 |
6 |
7 | def random_resized_crop(image, scale, ratio, seed, batched=False):
8 | if not batched:
9 | image = tf.expand_dims(image, axis=0)
10 | batch_size = tf.shape(image)[0]
11 | # taken from https://keras.io/examples/vision/nnclr/#random-resized-crops
12 | log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1]))
13 | height = tf.shape(image)[-3]
14 | width = tf.shape(image)[-2]
15 |
16 | random_scales = tf.random.stateless_uniform((batch_size,), seed, scale[0], scale[1])
17 | random_ratios = tf.exp(
18 | tf.random.stateless_uniform((batch_size,), seed, log_ratio[0], log_ratio[1])
19 | )
20 |
21 | new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1)
22 | new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1)
23 | height_offsets = tf.random.stateless_uniform(
24 | (batch_size,), seed, 0, 1 - new_heights
25 | )
26 | width_offsets = tf.random.stateless_uniform((batch_size,), seed, 0, 1 - new_widths)
27 |
28 | bounding_boxes = tf.stack(
29 | [
30 | height_offsets,
31 | width_offsets,
32 | height_offsets + new_heights,
33 | width_offsets + new_widths,
34 | ],
35 | axis=1,
36 | )
37 |
38 | if len(tf.shape(image)) == 5:
39 | obs_horizon = tf.shape(image)[1]
40 | # fold obs_horizon dimension into batch dimension
41 | image = tf.reshape(image, [batch_size * obs_horizon, height, width, -1])
42 | # repeat bounding_boxes so each obs history is augmented the same
43 | bounding_boxes = tf.repeat(bounding_boxes, obs_horizon, axis=0)
44 | image = tf.image.crop_and_resize(
45 | image, bounding_boxes, tf.range(batch_size * obs_horizon), (height, width)
46 | )
47 | image = tf.reshape(image, [batch_size, obs_horizon, height, width, -1])
48 | else:
49 | image = tf.image.crop_and_resize(
50 | image, bounding_boxes, tf.range(batch_size), (height, width)
51 | )
52 |
53 | if not batched:
54 | return image[0]
55 | else:
56 | return image
57 |
58 |
59 | AUGMENT_OPS = {
60 | "random_resized_crop": random_resized_crop,
61 | "random_brightness": tf.image.stateless_random_brightness,
62 | "random_contrast": tf.image.stateless_random_contrast,
63 | "random_saturation": tf.image.stateless_random_saturation,
64 | "random_hue": tf.image.stateless_random_hue,
65 | "random_flip_left_right": tf.image.stateless_random_flip_left_right,
66 | }
67 |
68 |
69 | def augment(image, seed, **augment_kwargs):
70 | image = tf.cast(image, tf.float32) / 255 # convert images to [0, 1]
71 | for op in augment_kwargs["augment_order"]:
72 | if op in augment_kwargs:
73 | if isinstance(augment_kwargs[op], Mapping) or isinstance(
74 | augment_kwargs[op], ConfigDict
75 | ):
76 | image = AUGMENT_OPS[op](image, seed=seed, **augment_kwargs[op])
77 | else:
78 | image = AUGMENT_OPS[op](image, seed=seed, *augment_kwargs[op])
79 | else:
80 | image = AUGMENT_OPS[op](image, seed=seed)
81 | image = tf.clip_by_value(image, 0, 1)
82 | image = tf.cast(image * 255, tf.uint8)
83 | return image
84 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/networks/actor_critic_nets.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import distrax
4 | import flax.linen as nn
5 | import jax.numpy as jnp
6 |
7 | from jaxrl_m.common.common import default_init
8 |
9 |
10 | class ValueCritic(nn.Module):
11 | encoder: nn.Module
12 | network: nn.Module
13 | init_final: Optional[float] = None
14 |
15 | @nn.compact
16 | def __call__(self, observations: jnp.ndarray, train: bool = False) -> jnp.ndarray:
17 | outputs = self.network(self.encoder(observations), train=train)
18 | if self.init_final is not None:
19 | value = nn.Dense(
20 | 1,
21 | kernel_init=nn.initializers.uniform(-self.init_final, self.init_final),
22 | )(outputs)
23 | else:
24 | value = nn.Dense(1, kernel_init=default_init())(outputs)
25 | return jnp.squeeze(value, -1)
26 |
27 |
28 | class Critic(nn.Module):
29 | encoder: Optional[nn.Module]
30 | network: nn.Module
31 | init_final: Optional[float] = None
32 | network_separate_action_input: bool = False # for PTR, input action to every layer
33 |
34 | @nn.compact
35 | def __call__(
36 | self, observations: jnp.ndarray, actions: jnp.ndarray, train: bool = False
37 | ) -> jnp.ndarray:
38 | if self.encoder is None:
39 | obs_enc = observations
40 | else:
41 | obs_enc = self.encoder(observations)
42 |
43 | if self.network_separate_action_input:
44 | outputs = self.network(obs_enc, actions, train=train)
45 | else:
46 | inputs = jnp.concatenate([obs_enc, actions], -1)
47 | outputs = self.network(inputs, train=train)
48 | if self.init_final is not None:
49 | value = nn.Dense(
50 | 1,
51 | kernel_init=nn.initializers.uniform(-self.init_final, self.init_final),
52 | )(outputs)
53 | else:
54 | value = nn.Dense(1, kernel_init=default_init())(outputs)
55 | return jnp.squeeze(value, -1)
56 |
57 |
58 | def ensemblize(cls, num_qs, out_axes=0):
59 | return nn.vmap(
60 | cls,
61 | variable_axes={"params": 0},
62 | split_rngs={"params": True},
63 | in_axes=None,
64 | out_axes=out_axes,
65 | axis_size=num_qs,
66 | )
67 |
68 |
69 | class Policy(nn.Module):
70 | encoder: Optional[nn.Module]
71 | network: nn.Module
72 | action_dim: int
73 | init_final: Optional[float] = None
74 | std_parameterization: str = "exp" # "exp", "softplus", "fixed", or "uniform"
75 | std_min: Optional[float] = 1e-5
76 | std_max: Optional[float] = 10.0
77 | tanh_squash_distribution: bool = False
78 | fixed_std: Optional[jnp.ndarray] = None
79 |
80 | @nn.compact
81 | def __call__(
82 | self, observations: jnp.ndarray, temperature: float = 1.0, train: bool = False
83 | ) -> distrax.Distribution:
84 | if self.encoder is None:
85 | obs_enc = observations
86 | else:
87 | obs_enc = self.encoder(observations)
88 |
89 | outputs = self.network(obs_enc, train=train)
90 |
91 | means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
92 | if self.fixed_std is None:
93 | if self.std_parameterization == "exp":
94 | log_stds = nn.Dense(self.action_dim, kernel_init=default_init())(
95 | outputs
96 | )
97 | stds = jnp.exp(log_stds)
98 | elif self.std_parameterization == "softplus":
99 | stds = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
100 | stds = nn.softplus(stds)
101 | elif self.std_parameterization == "uniform":
102 | log_stds = self.param(
103 | "log_stds", nn.initializers.zeros, (self.action_dim,)
104 | )
105 | stds = jnp.exp(log_stds)
106 | else:
107 | raise ValueError(
108 | f"Invalid std_parameterization: {self.std_parameterization}"
109 | )
110 | else:
111 | assert self.std_parameterization == "fixed"
112 | stds = jnp.array(self.fixed_std)
113 |
114 | # Clip stds to avoid numerical instability
115 | # For a normal distribution under MaxEnt, optimal std scales with sqrt(temperature)
116 | stds = jnp.clip(stds, self.std_min, self.std_max) * jnp.sqrt(temperature)
117 | # stds = jnp.concatenate([stds[:, :6], jnp.ones((len(stds), 1)) * jnp.log(0.3)], axis=-1)
118 |
119 | if self.tanh_squash_distribution:
120 | distribution = TanhMultivariateNormalDiag(
121 | loc=means,
122 | scale_diag=stds,
123 | )
124 | else:
125 | distribution = distrax.MultivariateNormalDiag(
126 | loc=means,
127 | scale_diag=stds,
128 | )
129 |
130 | return distribution
131 |
132 |
133 | class TanhMultivariateNormalDiag(distrax.Transformed):
134 | def __init__(
135 | self,
136 | loc: jnp.ndarray,
137 | scale_diag: jnp.ndarray,
138 | low: Optional[jnp.ndarray] = None,
139 | high: Optional[jnp.ndarray] = None,
140 | ):
141 | distribution = distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)
142 |
143 | layers = []
144 |
145 | if not (low is None or high is None):
146 |
147 | def rescale_from_tanh(x):
148 | x = (x + 1) / 2 # (-1, 1) => (0, 1)
149 | return x * (high - low) + low
150 |
151 | def forward_log_det_jacobian(x):
152 | high_ = jnp.broadcast_to(high, x.shape)
153 | low_ = jnp.broadcast_to(low, x.shape)
154 | return jnp.sum(jnp.log(0.5 * (high_ - low_)), -1)
155 |
156 | layers.append(
157 | distrax.Lambda(
158 | rescale_from_tanh,
159 | forward_log_det_jacobian=forward_log_det_jacobian,
160 | event_ndims_in=1,
161 | event_ndims_out=1,
162 | )
163 | )
164 |
165 | layers.append(distrax.Block(distrax.Tanh(), 1))
166 |
167 | bijector = distrax.Chain(layers)
168 |
169 | super().__init__(distribution=distribution, bijector=bijector)
170 |
171 | def mode(self) -> jnp.ndarray:
172 | return self.bijector.forward(self.distribution.mode())
173 |
174 | def stddev(self) -> jnp.ndarray:
175 | return self.bijector.forward(self.distribution.stddev())
176 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/networks/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Sequence
2 |
3 | import flax.linen as nn
4 | import jax.numpy as jnp
5 |
6 | from jaxrl_m.common.common import default_init
7 |
8 |
9 | class MLP(nn.Module):
10 | hidden_dims: Sequence[int]
11 | activations: Callable[[jnp.ndarray], jnp.ndarray] | str = nn.swish
12 | activate_final: bool = False
13 | use_layer_norm: bool = False
14 | use_group_norm: bool = False
15 | dropout_rate: Optional[float] = None
16 |
17 | def setup(self):
18 | assert not (self.use_layer_norm and self.use_group_norm)
19 |
20 | @nn.compact
21 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
22 | activations = self.activations
23 | if isinstance(activations, str):
24 | activations = getattr(nn, activations)
25 |
26 | for i, size in enumerate(self.hidden_dims):
27 | x = nn.Dense(size, kernel_init=default_init())(x)
28 |
29 | if i + 1 < len(self.hidden_dims) or self.activate_final:
30 | if self.dropout_rate is not None and self.dropout_rate > 0:
31 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
32 | if self.use_layer_norm:
33 | x = nn.LayerNorm()(x)
34 | elif self.use_group_norm:
35 | x = nn.GroupNorm()(x)
36 | x = activations(x)
37 | return x
38 |
39 |
40 | class LayerInputMLP(MLP):
41 | """
42 | MLP, but each layer takes in an additional input as well
43 | such as the critic network in PTR
44 | """
45 |
46 | @nn.compact
47 | def __call__(
48 | self, x: jnp.ndarray, layer_input: jnp.ndarray, train: bool = False
49 | ) -> jnp.ndarray:
50 | activations = self.activations
51 | if isinstance(activations, str):
52 | activations = getattr(nn, activations)
53 |
54 | for i, size in enumerate(self.hidden_dims):
55 | x = jnp.concatenate([x, layer_input], axis=-1) # difference from MLP
56 | x = nn.Dense(size, kernel_init=default_init())(x)
57 |
58 | if i + 1 < len(self.hidden_dims) or self.activate_final:
59 | if self.dropout_rate is not None and self.dropout_rate > 0:
60 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
61 | if self.use_layer_norm:
62 | x = nn.LayerNorm()(x)
63 | elif self.use_group_norm:
64 | x = nn.GroupNorm()(x)
65 | x = activations(x)
66 | return x
67 |
68 |
69 | class MLPResNetBlock(nn.Module):
70 | features: int
71 | act: Callable
72 |     dropout_rate: Optional[float] = None
73 | use_layer_norm: bool = False
74 |
75 | @nn.compact
76 | def __call__(self, x, train: bool = False):
77 | residual = x
78 | if self.dropout_rate is not None and self.dropout_rate > 0:
79 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
80 | if self.use_layer_norm:
81 | x = nn.LayerNorm()(x)
82 | x = nn.Dense(self.features * 4)(x)
83 | x = self.act(x)
84 | x = nn.Dense(self.features)(x)
85 |
86 | if residual.shape != x.shape:
87 | residual = nn.Dense(self.features)(residual)
88 |
89 | return residual + x
90 |
91 |
92 | class MLPResNet(nn.Module):
93 | num_blocks: int
94 | out_dim: int
95 |     dropout_rate: Optional[float] = None
96 | use_layer_norm: bool = False
97 | hidden_dim: int = 256
98 | activations: Callable = nn.swish
99 |
100 | @nn.compact
101 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
102 | x = nn.Dense(self.hidden_dim, kernel_init=default_init())(x)
103 | for _ in range(self.num_blocks):
104 | x = MLPResNetBlock(
105 | self.hidden_dim,
106 | act=self.activations,
107 | use_layer_norm=self.use_layer_norm,
108 | dropout_rate=self.dropout_rate,
109 | )(x, train=train)
110 |
111 | x = self.activations(x)
112 | x = nn.Dense(self.out_dim, kernel_init=default_init())(x)
113 | return x
114 |
115 |
116 | class Scalar(nn.Module):
117 | init_value: float
118 |
119 | def setup(self):
120 | self.value = self.param("value", lambda x: self.init_value)
121 |
122 | def __call__(self):
123 | return self.value
124 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/utils/jax_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 |
3 |
4 | @jax.jit
5 | def batch_to_jax(batch):
6 | return jax.tree_util.tree_map(jax.device_put, batch)
7 |
8 |
9 | class JaxRNG(object):
10 | """A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
11 | pure function.
12 | """
13 |
14 | @classmethod
15 | def from_seed(cls, seed):
16 | return cls(jax.random.PRNGKey(seed))
17 |
18 | def __init__(self, rng):
19 | self.rng = rng
20 |
21 | def __call__(self, keys=None):
22 | if keys is None:
23 | self.rng, split_rng = jax.random.split(self.rng)
24 | return split_rng
25 | elif isinstance(keys, int):
26 | split_rngs = jax.random.split(self.rng, num=keys + 1)
27 | self.rng = split_rngs[0]
28 | return tuple(split_rngs[1:])
29 | else:
30 | split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
31 | self.rng = split_rngs[0]
32 | return {key: val for key, val in zip(keys, split_rngs[1:])}
33 |
34 |
35 | def wrap_function_with_rng(rng):
36 | """To be used as decorator, automatically bookkeep a RNG for the wrapped function."""
37 |
38 | def wrap_function(function):
39 | def wrapped(*args, **kwargs):
40 | nonlocal rng
41 | rng, split_rng = jax.random.split(rng)
42 | return function(split_rng, *args, **kwargs)
43 |
44 | return wrapped
45 |
46 | return wrap_function
47 |
48 |
49 | def init_rng(seed):
50 | global jax_utils_rng
51 | jax_utils_rng = JaxRNG.from_seed(seed)
52 |
53 |
54 | def next_rng(*args, **kwargs):
55 | global jax_utils_rng
56 | return jax_utils_rng(*args, **kwargs)
57 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/utils/timer_utils.py:
--------------------------------------------------------------------------------
1 | """Timer utility."""
2 |
3 | import time
4 | from collections import defaultdict
5 |
6 |
7 | class _TimerContextManager:
8 | def __init__(self, timer: "Timer", key: str):
9 | self.timer = timer
10 | self.key = key
11 |
12 | def __enter__(self):
13 | self.timer.tick(self.key)
14 |
15 | def __exit__(self, exc_type, exc_value, exc_traceback):
16 | self.timer.tock(self.key)
17 |
18 |
19 | class Timer:
20 | def __init__(self):
21 | self.reset()
22 |
23 | def reset(self):
24 | self.counts = defaultdict(int)
25 | self.times = defaultdict(float)
26 | self.start_times = {}
27 |
28 | def tick(self, key):
29 | if key in self.start_times:
30 | raise ValueError(f"Timer is already ticking for key: {key}")
31 | self.start_times[key] = time.time()
32 |
33 | def tock(self, key):
34 | if key not in self.start_times:
35 | raise ValueError(f"Timer is not ticking for key: {key}")
36 | self.counts[key] += 1
37 | self.times[key] += time.time() - self.start_times[key]
38 | del self.start_times[key]
39 |
40 | def force_tock_everything(self):
41 |         for key in list(self.start_times):  # copy keys; tock() mutates start_times
42 | self.tock(key)
43 |
44 | def context(self, key):
45 | """
46 | Use this like:
47 |
48 | with timer.context("key"):
49 | # do stuff
50 |
51 | Then timer.tock("key") will be called automatically.
52 | """
53 | return _TimerContextManager(self, key)
54 |
55 | def get_average_times(self, reset=True):
56 | ret = {key: self.times[key] / self.counts[key] for key in self.counts}
57 | if reset:
58 | self.reset()
59 | return ret
60 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping
2 |
3 | import imageio
4 | import numpy as np
5 | import tensorflow as tf
6 | import wandb
7 | from flax.core import frozen_dict
8 |
9 |
10 | def concatenate_batches(batches):
11 | concatenated = {}
12 | for key in batches[0].keys():
13 | if isinstance(batches[0][key], Mapping):
14 | # to concatenate batch["observations"]["image"], etc.
15 | concatenated[key] = concatenate_batches([batch[key] for batch in batches])
16 | else:
17 | concatenated[key] = np.concatenate(
18 | [batch[key] for batch in batches], axis=0
19 | ).astype(np.float32)
20 | return concatenated
21 |
22 |
23 | def index_batch(batch, indices):
24 | indexed = {}
25 | for key in batch.keys():
26 | if isinstance(batch[key], Mapping):
27 | # to index into batch["observations"]["image"], etc.
28 | indexed[key] = index_batch(batch[key], indices)
29 | else:
30 | indexed[key] = batch[key][indices, ...]
31 | return indexed
32 |
33 |
34 | def subsample_batch(batch, size):
35 | indices = np.random.randint(batch["rewards"].shape[0], size=size)
36 | return index_batch(batch, indices)
37 |
38 |
39 | def load_recorded_video(
40 | video_path: str,
41 | ):
42 | with tf.io.gfile.GFile(video_path, "rb") as f:
43 | video = np.array(imageio.mimread(f, "MP4")).transpose((0, 3, 1, 2))
44 | assert video.shape[1] == 3, "Numpy array should be (T, C, H, W)"
45 |
46 | return wandb.Video(video, fps=20)
47 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from jaxrl_m.vision.resnet_v1 import resnetv1_configs
2 |
3 | encoders = dict()
4 | encoders.update(resnetv1_configs)
5 |
--------------------------------------------------------------------------------
/model_training/jaxrl_m/vision/film_conditioning_layer.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py
2 | import flax.linen as nn
3 | import jax.numpy as jnp
4 |
5 |
6 | class FilmConditioning(nn.Module):
7 | @nn.compact
8 | def __call__(self, conv_filters: jnp.ndarray, conditioning: jnp.ndarray):
9 | """Applies FiLM conditioning to a convolutional feature map.
10 |
11 | Args:
12 | conv_filters: A tensor of shape [batch_size, height, width, channels].
13 | conditioning: A tensor of shape [batch_size, conditioning_size].
14 |
15 | Returns:
16 | A tensor of shape [batch_size, height, width, channels].
17 | """
18 | projected_cond_add = nn.Dense(
19 | features=conv_filters.shape[-1],
20 | kernel_init=nn.initializers.zeros,
21 | bias_init=nn.initializers.zeros,
22 | )(conditioning)
23 | projected_cond_mult = nn.Dense(
24 | features=conv_filters.shape[-1],
25 | kernel_init=nn.initializers.zeros,
26 | bias_init=nn.initializers.zeros,
27 | )(conditioning)
28 |
29 | projected_cond_add = projected_cond_add[..., None, None, :]
30 | projected_cond_mult = projected_cond_mult[..., None, None, :]
31 |
32 | return conv_filters * (1 + projected_cond_add) + projected_cond_mult
33 |
34 |
35 | if __name__ == "__main__":
36 | import jax
37 | import jax.numpy as jnp
38 |
39 | key = jax.random.PRNGKey(0)
40 | key, subkey = jax.random.split(key)
41 | x = jax.random.normal(subkey, (1, 32, 32, 3))
42 | x = jnp.array(x)
43 |
44 | z = jnp.ones((1, 64))
45 | film = FilmConditioning()
46 | params = film.init(key, x, z)
47 | y = film.apply(params, x, z)
48 |
49 | print(y.shape)
50 |
--------------------------------------------------------------------------------
/model_training/requirements.txt:
--------------------------------------------------------------------------------
1 | gym >= 0.26
2 | numpy==1.24.3
3 | jax==0.4.20
4 | jaxlib==0.4.20
5 | distrax==0.1.5
6 | flax==0.7.5
7 | orbax-checkpoint==0.3.5
8 | ml_collections >= 0.1.0
9 | tqdm >= 4.60.0
10 | chex==0.1.85
11 | optax==0.1.5
12 | absl-py >= 0.12.0
13 | scipy == 1.12.0
14 | wandb >= 0.12.14
15 | tensorflow==2.15.0
16 | einops >= 0.6.1
17 | imageio >= 2.31.1
18 | moviepy >= 1.0.3
19 | pre-commit == 3.3.3
20 | overrides==7.7.0
21 | tensorflow_probability==0.23.0
22 | tensorflow_text
23 | dlimp @ git+https://github.com/zhouzypaul/dlimp@2df85dc98a7d564fc9c7c5ff2bfda26361526483
24 | octo @ git+https://github.com/octo-models/octo.git
25 |
--------------------------------------------------------------------------------
/model_training/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name="jaxrl_m", packages=["jaxrl_m"])
4 |
--------------------------------------------------------------------------------
/rlds_converter/README.md:
--------------------------------------------------------------------------------
1 | # Data RLDS conversion code
2 |
3 | This directory contains the code that converts the raw robot data logged by the `data_collection` dir into
4 | the [RLDS format](https://github.com/google-research/rlds), a specification on top of the [TFDS](https://www.tensorflow.org/datasets) (TensorFlow Datasets) format, which is in turn largely built on top of the TFRecord format.
5 |
6 | This is sourced from [Zhiyuan Zhou's implementation](https://github.com/zhouzypaul/dlimp), which borrows heavily from Kevin Black's [dlimp](https://github.com/kvablack/dlimp) library.
7 |
8 |
9 | ## Usage
10 | Install the requirements with
11 | ```bash
12 | pip install -r requirements.txt
13 | pip install -e .
14 | ```
15 |
16 | To build the SOAR dataset
17 | ```bash
18 | cd soar_dataset
19 | CUDA_VISIBLE_DEVICES="" tfds build --manual_dir <path_to_manual_dir>
20 | ```
21 | You can modify settings inside the `soar_dataset/soar_dataset_dataset_builder.py` file (e.g., `NUM_WORKERS` and `CHUNKSIZE`).
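For reference, `NUM_WORKERS` and `CHUNKSIZE` are plain class attributes on the `MultiThreadedDatasetBuilder` base class in `dataset_builder.py` (its defaults can be overridden by subclasses), so tuning them only requires editing the builder's class body. A minimal sketch of such an override (the class name below is illustrative, not necessarily the exact name used in the file):

```python
class SoarDataset(MultiThreadedDatasetBuilder):
    # Override the parallelism defaults inherited from MultiThreadedDatasetBuilder.
    NUM_WORKERS = 8    # number of parallel worker processes
    CHUNKSIZE = 500    # examples processed in memory before writing to disk
```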
22 |
23 | This data builder assumes your raw data is organized into the following structure:
24 | ```
25 | manual_dir/robot_id/scene_id/policy_type/date/success/trajectory_{i}/*
26 | manual_dir/robot_id/scene_id/policy_type/date/failure/trajectory_{i}/*
27 | ```
28 | Each `trajectory_{i}` directory will contain the following files, as logged by the code in the `data_collection` dir:
29 | - actions.npy
30 | - eef_poses.npy
31 | - language_task.txt
32 | - robot_id.txt
33 | - task_list.txt
34 | - trajectory.mp4
35 | - combined.mp4
36 | - goals.mp4
37 | - object_list.txt
38 | - success.txt
39 | - time.txt
40 |
41 | The RLDS dataset will be automatically saved under `~/tensorflow_datasets/soar_dataset/`
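Once built, the dataset can be read back with standard TFDS APIs. A minimal sketch, assuming the default TFDS data directory and a `1.0.0` version folder (adjust the path to whatever `tfds build` actually wrote); the `success` and `failure` splits mirror the raw directory layout above:

```python
import os

import tensorflow_datasets as tfds

# Point TFDS at the directory produced by `tfds build`.
# The version folder here is an assumption; use whatever appears on disk.
builder = tfds.builder_from_directory(
    os.path.expanduser("~/tensorflow_datasets/soar_dataset/1.0.0")
)
ds = builder.as_dataset(split="success")  # or "failure"
for episode in ds.take(1):
    print(list(episode.keys()))
```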
42 |
--------------------------------------------------------------------------------
/rlds_converter/dataset_builder.py:
--------------------------------------------------------------------------------
1 | """Inspired by https://github.com/kpertsch/bridge_rlds_builder/blob/f0d16c5a8384c1476aa1c274a9aef3a5f76cbada/bridge_dataset/conversion_utils.py"""
2 |
3 | import abc
4 | import itertools
5 | import multiprocessing as mp
6 | from typing import Any, Callable, Dict, Iterable, Tuple, Union
7 |
8 | import tensorflow_datasets as tfds
9 | from absl import logging
10 | from tensorflow_datasets.core import (
11 | dataset_builder,
12 | download,
13 | example_serializer,
14 | file_adapters,
15 | naming,
16 | )
17 | from tensorflow_datasets.core import split_builder as split_builder_lib
18 | from tensorflow_datasets.core import splits as splits_lib
19 | from tensorflow_datasets.core import writer as writer_lib
20 | from tqdm import tqdm
21 |
22 | Key = Union[str, int]
23 | Example = Dict[str, Any]
24 | ExampleInput = Any
25 |
26 |
27 | class MultiThreadedSplitBuilder(split_builder_lib.SplitBuilder):
28 | """Multithreaded version of tfds.core.SplitBuilder. Removes Apache Beam support, only supporting Python generators."""
29 |
30 | def __init__(
31 | self,
32 | process_fn: Callable[[ExampleInput], Example],
33 | num_workers: int,
34 | chunksize: int,
35 | *args,
36 | **kwargs,
37 | ):
38 | super().__init__(*args, **kwargs)
39 | self._process_fn = process_fn
40 | self.num_workers = num_workers
41 | self.chunksize = chunksize
42 |
43 | def submit_split_generation(
44 | self,
45 | split_name: splits_lib.Split,
46 | generator: Iterable[Tuple[Key, ExampleInput]],
47 | filename_template: naming.ShardedFileTemplate,
48 | disable_shuffling: bool = False,
49 | ) -> splits_lib.SplitInfo:
50 | if self._max_examples_per_split is not None:
51 | logging.warning(
52 | "Splits capped at %s examples max.", self._max_examples_per_split
53 | )
54 | generator = itertools.islice(generator, self._max_examples_per_split)
55 | total_num_examples = self._max_examples_per_split
56 | else:
57 | # If dataset info has been pre-downloaded from the internet,
58 |             # we can use the pre-computed number of examples for the progress bar.
59 | split_info = self._split_dict.get(split_name)
60 | if split_info and split_info.num_examples:
61 | total_num_examples = split_info.num_examples
62 | else:
63 | total_num_examples = None
64 |
65 | serialized_info = self._features.get_serialized_info()
66 | writer = writer_lib.Writer(
67 | serializer=example_serializer.ExampleSerializer(serialized_info),
68 | filename_template=filename_template,
69 | hash_salt=split_name,
70 | disable_shuffling=disable_shuffling,
71 | file_format=self._file_format,
72 | shard_config=self._shard_config,
73 | )
74 | pbar = tqdm(
75 | total=total_num_examples,
76 | desc=f"Generating {split_name} examples...",
77 | unit=" examples",
78 | dynamic_ncols=True,
79 | miniters=1,
80 | )
81 | with mp.Pool(
82 | self.num_workers,
83 | initializer=MultiThreadedSplitBuilder._worker_init,
84 | initargs=(self._process_fn, self._features),
85 | ) as pool:
86 | logging.info(
87 | "Using %d workers with chunksize %d.", self.num_workers, self.chunksize
88 | )
89 | while True:
90 | curr = pbar.n
91 | iterator = itertools.islice(generator, self.chunksize)
92 | results = pool.map(MultiThreadedSplitBuilder._worker_fn, iterator)
93 | for key, example in results:
94 | writer._shuffler.add(key, example)
95 | writer._num_examples += 1
96 | pbar.update(1)
97 | if pbar.n == curr:
98 | break
99 | shard_lengths, total_size = writer.finalize()
100 |
101 | return splits_lib.SplitInfo(
102 | name=split_name,
103 | shard_lengths=shard_lengths,
104 | num_bytes=total_size,
105 | filename_template=filename_template,
106 | )
107 |
108 | @staticmethod
109 | def _worker_init(
110 | process_fn: Callable[[ExampleInput], Example],
111 | features: tfds.features.FeaturesDict,
112 | ):
113 | global __process_fn
114 | global __features
115 | global __serializer
116 | __process_fn = process_fn
117 | __features = features
118 | __serializer = example_serializer.ExampleSerializer(
119 | features.get_serialized_info()
120 | )
121 |
122 | @staticmethod
123 | def _worker_fn(example_input):
124 | global __process_fn
125 | global __features
126 | global __serializer
127 | key, example = __process_fn(example_input)
128 | return key, __serializer.serialize_example(__features.encode_example(example))
129 |
130 |
131 | class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder):
132 | """Multithreaded version of tfds.core.GeneratorBasedBuilder."""
133 |
134 | # Defaults can be overridden by subclasses.
135 | NUM_WORKERS = 16 # number of parallel workers
136 | CHUNKSIZE = 1000 # number of examples to process in memory before writing to disk
137 |
138 | @classmethod
139 | @abc.abstractmethod
140 | def _process_example(cls, example_input: ExampleInput) -> Example:
141 | """Process a single example.
142 |
143 | This is the function that will be parallelized, so it should contain any heavy computation and I/O. It
144 | should return a feature dictionary compatible with `self.info.features` (see the FeatureConnector
145 |         documentation) that is ready to be encoded and serialized.
146 | """
147 | raise NotImplementedError()
148 |
149 | @abc.abstractmethod
150 | def _split_generators(
151 | self,
152 | dl_manager: download.DownloadManager,
153 | ) -> Dict[splits_lib.Split, Iterable[Tuple[Key, ExampleInput]]]:
154 | """Same as GeneratorBasedBuilder._split_generators, but returns generators of tuples (key,
155 | example_input) rather than (key, example). `example_input` will be passed to
156 | `_process_example` for further processing.
157 | """
158 | raise NotImplementedError()
159 |
160 | def _generate_examples(self, *args, **kwargs):
161 | """This is not actually called from TFDS code. I believe they left it in for legacy reasons. However,
162 | it must be overridden for TFDS to recognize the class as a valid dataset builder.
163 | """
164 | raise RuntimeError()
165 |
166 | def _download_and_prepare(
167 | self,
168 | dl_manager: download.DownloadManager,
169 | download_config: download.DownloadConfig,
170 | ) -> None:
171 | """Same as superclass `_download_and_prepare`, but removes Apache Beam stuff and uses
172 | MultiThreadedSplitBuilder instead of SplitBuilder.
173 | """
174 | split_builder = MultiThreadedSplitBuilder(
175 | process_fn=type(self)._process_example,
176 | num_workers=self.NUM_WORKERS,
177 | chunksize=self.CHUNKSIZE,
178 | split_dict=self.info.splits,
179 | features=self.info.features,
180 | dataset_size=self.info.dataset_size,
181 | max_examples_per_split=download_config.max_examples_per_split,
182 | beam_options=download_config.beam_options,
183 | beam_runner=download_config.beam_runner,
184 | file_format=self.info.file_format,
185 | shard_config=download_config.get_shard_config(),
186 | )
187 |
188 | split_generators = self._split_generators(dl_manager)
189 | dataset_builder._check_split_names(split_generators.keys())
190 |
191 |         # The writer fails if the number of examples yielded is `0`, so we return early here.
192 | if download_config.max_examples_per_split == 0:
193 | return
194 |
195 | # Start generating data for all splits
196 | path_suffix = file_adapters.ADAPTER_FOR_FORMAT[
197 | self.info.file_format
198 | ].FILE_SUFFIX
199 |
200 | split_infos = []
201 | for split_name, generator in split_generators.items():
202 | filename_template = naming.ShardedFileTemplate(
203 | split=split_name,
204 | dataset_name=self.name,
205 | data_dir=self.data_path,
206 | filetype_suffix=path_suffix,
207 | )
208 | split_info = split_builder.submit_split_generation(
209 | split_name=split_name,
210 | generator=generator,
211 | filename_template=filename_template,
212 | disable_shuffling=self.info.disable_shuffling,
213 | )
214 | split_infos.append(split_info)
215 |
216 | # Update the info object with the splits.
217 | split_dict = splits_lib.SplitDict(split_infos)
218 | self.info.set_splits(split_dict)
219 |
--------------------------------------------------------------------------------
/rlds_converter/requirements.txt:
--------------------------------------------------------------------------------
1 | pillow == 10.2.0
2 |
--------------------------------------------------------------------------------
/rlds_converter/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="dlimp-dataset-builder",
5 | python_requires=">=3.10",
6 | install_requires=[
7 | "tensorflow_datasets==4.9.4",
8 | "opencv-python",
9 | "apache_beam",
10 | "tensorflow",
11 | ],
12 | )
13 |
--------------------------------------------------------------------------------
/rlds_converter/soar_dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rail-berkeley/soar/1195ab7b46cd0df1be30bcbfe280605374c22190/rlds_converter/soar_dataset/__init__.py
--------------------------------------------------------------------------------
/soar_data/download_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Allow overriding variables from the command line
4 | SAVE_DIR="${SAVE_DIR:-$HOME/soar_data}"  # $HOME rather than ~ so the default expands inside quotes
5 | REQUIRED_SPACE_GB="${REQUIRED_SPACE_GB:-140}"
6 | URL_FILE="${URL_FILE:-soar_data/urls.txt}"
7 |
8 | # Function to check if enough disk space is available
9 | check_disk_space() {
10 | local available_space
11 | available_space=$(df --output=avail -BG . | tail -1 | tr -d 'G')
12 | if [ "$available_space" -lt "$REQUIRED_SPACE_GB" ]; then
13 | echo "Warning: You need at least $REQUIRED_SPACE_GB GB of free space to proceed."
14 | read -p "Do you want to continue anyway? (y/n) " REPLY
15 | echo
16 | if [[ ! $REPLY =~ ^[Yy]$ ]]; then
17 | echo "Aborting download."
18 | exit 1
19 | fi
20 | fi
21 | }
22 |
23 | # Check disk space
24 | check_disk_space
25 |
26 | # Check if the url file exists
27 | if [ ! -f $URL_FILE ]; then
28 | echo "Error: URLs file '${URL_FILE}' not found. Please run from the root directory of the repository."
29 | echo "URL file should be found at soar_data/urls.txt"
30 | echo "Aborting download."
31 | exit 1
32 | fi
33 |
34 | # Inform saving directory
35 | echo "Saving files to $SAVE_DIR"
36 |
37 | # Count the number of files
38 | TOTAL_FILES=$(wc -l < $URL_FILE)
39 | CURRENT_FILE=0
40 |
41 | # Function to print the progress bar
42 | print_progress() {
43 | local PROGRESS=$(( ($CURRENT_FILE * 100) / $TOTAL_FILES ))
44 | local FILLED=$(( $PROGRESS / 2 ))
45 | local EMPTY=$(( 50 - $FILLED ))
46 | printf "\rDownloading files: ["
47 | printf "%0.s#" $(seq 1 $FILLED)
48 | printf "%0.s " $(seq 1 $EMPTY)
49 | printf "] %d%%" $PROGRESS
50 | }
51 |
52 | # Function to download using wget without parallel
53 | download_without_parallel() {
54 | while IFS= read -r url; do
55 | wget -P "$SAVE_DIR" "$url"
56 | ((CURRENT_FILE++))
57 | print_progress
58 | done < $URL_FILE
59 | echo
60 | }
61 |
62 | download_with_parallel() {
63 | cat $URL_FILE | parallel -j 4 wget -P "$SAVE_DIR" {}
64 | }
65 |
66 | # Check if parallel is installed
67 | if ! command -v parallel &> /dev/null; then
68 | echo "GNU parallel is not installed."
69 | read -p "Do you want to install GNU parallel? (y/n) " REPLY
70 | echo
71 | if [[ $REPLY =~ ^[Yy]$ ]]; then
72 | # Try to install parallel
73 | if command -v apt-get &> /dev/null; then
74 | sudo apt-get update && sudo apt-get install -y parallel
75 | elif command -v yum &> /dev/null; then
76 | sudo yum install -y parallel
77 | elif command -v brew &> /dev/null; then
78 | brew install parallel
79 | else
80 | echo "Package manager not found. Please install GNU parallel manually."
81 | download_without_parallel
82 |             exit 0  # downloads completed via the non-parallel fallback
83 | fi
84 | else
85 | echo "Downloading files without parallelism..."
86 | download_without_parallel
87 | exit 0
88 | fi
89 | fi
90 |
91 | download_with_parallel
92 |
93 | # Initialize the progress bar
94 | CURRENT_FILE=0
95 | print_progress
96 |
97 | # Tally completed files and display the final progress
98 | while IFS= read -r url; do
99 | if [ -f "$SAVE_DIR/$(basename $url)" ]; then
100 | ((CURRENT_FILE++))
101 | print_progress
102 | fi
103 | done < $URL_FILE
104 | echo
105 |
--------------------------------------------------------------------------------
/soar_data/fetch_urls.sh:
--------------------------------------------------------------------------------
1 | # This script will take a long time to run, and the cached result is already saved in urls.txt
2 | # You shouldn't need to rerun this
3 |
4 | # List files, filter by patterns, remove duplicates, and save to urls.txt
5 | BASE_URL="https://rail.eecs.berkeley.edu/datasets/soar_release/1.0.0/"
6 |
7 | echo "Fetching all datafile URLs..."
8 | wget --spider -r -nd -np $BASE_URL 2>&1 | grep '^--' | awk '{ print $3 }' | grep -E '\.json$|tfrecord' | sort | uniq > urls.txt
9 | echo "Finished fetching URLs."
10 |
--------------------------------------------------------------------------------
/soar_data/load_soar_data.ipynb:
--------------------------------------------------------------------------------
1 | {"cells":[{"cell_type":"markdown","metadata":{},"source":["# Load SOAR Data (for training)\n","This notebook is a minimal example of how to load the RLDS SOAR data (e.g. for downstream training)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","Make sure the Colab has access to the SOAR repo\n","\"\"\"\n","!git clone https://github.com/rail-berkeley/soar.git"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","\"\"\"\n","1. Download a minimal SOAR dataset that's small to be used for testing\n","\n","In this notebook we load a small dummy dataset for speed. If you wish to load the full dataset, \n","use the download script in this directory to download the full dataset. Then it can be loaded\n","in the same way, changing the path to the saved dataset.\n","\"\"\"\n","SAVE_DIR = \"dummy_soar_data\"\n","!cat soar/soar_data/test_dataset_urls.txt | while read url; do wget -P \"dummy_soar_data\" \"$url\"; done"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","2. Import the Dataloader class\n","\"\"\"\n","import subprocess\n","\n","# install jaxrl_m if it it not already installed\n","# the package is located in model_training/jaxrl_m\n","try:\n"," import jaxrl_m\n","except ImportError:\n"," print(\"local jaxrl_m package not installed, trying to install now\")\n"," package_path = 'soar/model_training'\n","\n"," # install jaxrl_m\n"," result = subprocess.run(['pip', 'install', '-e', package_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n"," # Print the standard output and error\n"," print(\"Pip install Output:\\n\", result.stdout)\n"," if result.stderr:\n"," print(\"Pip install Error:\\n\", result.stderr)\n","\n"," # add the package path to sys.path (for ipynb)\n"," import sys\n"," if package_path not in sys.path:\n"," sys.path.append(package_path)\n"," \n"," # install the requirements for the package\n"," result = subprocess.run(['pip', 'install', '-r', f\"{package_path}/requirements.txt\"])\n"," # Print the standard output and error\n"," if result.stderr:\n"," print(\"Pip install Error:\\n\", result.stderr)\n"," \n","# check that installation was successful\n","try:\n"," import jaxrl_m\n","except ImportError:\n"," print(\"Failed to correctly install jaxrl_m package\")\n"," print(\"Please manually install the package with `pip install -e soar/model_training`\")\n"," raise\n","\n","# import dataloader class\n","from jaxrl_m.data.dataset import WidowXDataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","3. Load the dataset\n","\"\"\"\n","train_data = WidowXDataset(\n"," [SAVE_DIR],\n"," seed=0,\n"," batch_size=16,\n"," train=True,\n"," load_language=True,\n"," goal_relabeling_strategy=\"uniform\", # specify goal relabeling to load languages\n"," goal_relabeling_kwargs={\"reached_proportion\": 0.5, \"discount\": 0.98},\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","4. 
Inspect an example batch\n","\"\"\"\n","!pip install matplotlib\n","import matplotlib.pyplot as plt\n","\n","train_data_iter = train_data.iterator()\n","example_batch = next(train_data_iter)\n","\n","print(f\"Example batch keys: {example_batch.keys()}\")\n","print(f\"Actions shape: {example_batch['actions'].shape}, which is (batch_size, action_dim)\")\n","print(f\"Observations shape: {example_batch['observations']['image'].shape}, which is (batch_size, observation_dim)\")\n","print(f\"Proprio shape: {example_batch['observations']['proprio'].shape}, which is (batch_size, proprio_dim)\")\n","\n","language = example_batch['goals']['language']\n","language = [str(l) for l in language]\n","\n","plt.figure(figsize=(10, 10))\n","for i in range(16):\n"," plt.subplot(4, 4, i+1)\n"," plt.imshow(example_batch['observations']['image'][i])\n"," plt.title(language[i] if len(language[i]) <= 30 else language[i][:30] + '...', fontsize=6)\n"," plt.axis('off')\n","plt.show()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","5. Load only the success/failure split of the SOAR-data\n","If you wish, you could only load certain splits of the dataset. \n","\"\"\"\n","success_data = WidowXDataset(\n"," [SAVE_DIR],\n"," data_splits=[\"success\"],\n"," seed=0,\n"," batch_size=16,\n"," train=True,\n",")\n","\n","failure_data = WidowXDataset(\n"," [SAVE_DIR],\n"," data_splits=[\"failure\"],\n"," seed=0,\n"," batch_size=16,\n"," train=True,\n",")"]},{"cell_type":"markdown","metadata":{},"source":["## More Advanced Usage\n","For more advanced usage, check out the arguments of the `BridgeDataset` class at [model_training/jaxrl_m/data/dataset.py](https://github.com/rail-berkeley/soar/blob/main/model_training/jaxrl_m/data/dataset.py).\n","\n","An example of how this dataset is used is in `model_training/experiments/train.py`, and the configuration and arguments of the datasets are in `model_training/experiments/configs/train_config.py` and `model_training/experiments/configs/data_config.py`."]}],"metadata":{"kernelspec":{"display_name":"widowx","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.14"}},"nbformat":4,"nbformat_minor":2}
2 |
--------------------------------------------------------------------------------
/soar_data/test_dataset_urls.txt:
--------------------------------------------------------------------------------
1 | https://rail.eecs.berkeley.edu/datasets/soar_release/test/dataset_info.json
2 | https://rail.eecs.berkeley.edu/datasets/soar_release/test/features.json
3 | https://rail.eecs.berkeley.edu/datasets/soar_release/test/soar_dataset-failure.tfrecord-00000-of-00001
4 | https://rail.eecs.berkeley.edu/datasets/soar_release/test/soar_dataset-success.tfrecord-00000-of-00001
5 |
--------------------------------------------------------------------------------