├── .github
│   └── FUNDING.yml
├── .gitignore
├── AgibotWorld.ipynb
├── CONTRIBUTING.md
├── README.md
├── assets
│   ├── Contact-rich_manipulation.gif
│   ├── Long-horizon_planning.gif
│   ├── Multi-robot_collaboration.gif
│   └── dataset_visualization.gif
└── scripts
    ├── convert_to_lerobot.py
    └── visualize_dataset.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
3 | patreon: # Replace with a single Patreon username
4 | open_collective: # Replace with a single Open Collective username
5 | ko_fi: # Replace with a single Ko-fi username
6 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
7 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
8 | liberapay: # Replace with a single Liberapay username
9 | issuehunt: # Replace with a single IssueHunt username
10 | otechie: # Replace with a single Otechie username
11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/AgibotWorld.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# AgiBot World Diffusion Policy Training Demo\n",
8 | "\n",
 9 |     "This notebook demonstrates how to use the **AgiBot World dataset** (converted to LeRobot format) to run an offline diffusion policy training workflow.\n",
10 | "Make sure you have installed all necessary packages before running.\n"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "# =============================================\n",
20 | "# 1. Imports and Parameter Settings\n",
21 | "# =============================================\n",
22 | "import torch\n",
23 | "import numpy as np\n",
24 | "\n",
25 | "from lerobot.common.datasets.lerobot_dataset import LeRobotDataset\n",
26 | "from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig\n",
27 | "from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy\n",
28 | "\n",
29 | "# Parameters\n",
30 | "FPS = 30\n",
31 | "TASK_ID = 352\n",
32 | "training_steps = 5000\n",
33 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
34 | "\n",
35 | "# Paths\n",
36 | "dataset_path = \"/path/to/your/AgiBotWorld/dataset\"\n",
37 | "output_path = \"/path/to/save/your/checkpoint\""
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "# =============================================\n",
47 | "# 2. Dataset Setup\n",
48 | "# =============================================\n",
49 | "observation_idx = np.array([-1, 0])\n",
50 | "action_idx = np.arange(-1, 15)\n",
51 | "repo_id = f\"agibotworld/task_{TASK_ID}\"\n",
52 | "\n",
53 | "delta_timestamps = {\n",
54 | " \"observation.images.top_head\": (observation_idx / FPS).tolist(),\n",
55 | " \"observation.state\": (observation_idx / FPS).tolist(),\n",
56 | " \"action\": (action_idx / FPS).tolist(),\n",
57 | "}\n",
58 | "\n",
59 | "dataset = LeRobotDataset(\n",
60 | " repo_id=repo_id,\n",
61 | " root=f\"{dataset_path}/{repo_id}\",\n",
62 | " delta_timestamps=delta_timestamps,\n",
63 | " local_files_only=True\n",
64 | ")\n",
65 | "\n",
66 | "dataloader = torch.utils.data.DataLoader(\n",
67 | " dataset,\n",
68 | " num_workers=0,\n",
69 | " batch_size=64,\n",
70 | " shuffle=True,\n",
71 | " pin_memory=(device.type == \"cuda\"),\n",
72 | " drop_last=True,\n",
73 | ")"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
 80 |     "If you want to train one robot policy model to master multiple distinct skills, you can use `MultiLeRobotDataset` to load the datasets of several tasks into a unified training process."
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "from pathlib import Path\n",
90 | "from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset\n",
91 | "repo_ids = [f\"agibotworld/{path.name}\" for path in Path(dataset_path).glob(\"agibotworld/task_*\")]\n",
92 | "multi_dataset = MultiLeRobotDataset(\n",
93 | " repo_ids=repo_ids,\n",
94 | " root=dataset_path,\n",
95 | " delta_timestamps=delta_timestamps,\n",
96 | " local_files_only=True\n",
97 | ")"
98 | ]
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {},
103 | "source": [
104 | "Let's kick off a simple training with Diffusion Policy:"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "# =============================================\n",
114 | "# 3. Policy Configuration and Initialization\n",
115 | "# =============================================\n",
116 | "cfg = DiffusionConfig()\n",
117 | "cfg.input_shapes = {\n",
118 | " \"observation.images.top_head\": [3, 480, 640],\n",
119 | " \"observation.state\": [20],\n",
120 | "}\n",
121 | "cfg.input_normalization_modes = {\n",
122 | " \"observation.images.top_head\": \"mean_std\",\n",
123 | " \"observation.state\": \"min_max\",\n",
124 | "}\n",
125 | "cfg.output_shapes = {\n",
126 | " \"action\": [22],\n",
127 | "}\n",
128 | "\n",
129 | "policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)\n",
130 | "#policy = DiffusionPolicy(cfg, dataset_stats=multi_dataset.stats)\n",
131 | "policy.train()\n",
132 | "policy.to(device)\n",
133 | "\n",
134 | "optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "# =============================================\n",
144 | "# 4. Training Loop\n",
145 | "# =============================================\n",
146 | "step = 0\n",
147 | "done = False\n",
148 | "\n",
149 | "while not done:\n",
150 | " for batch in dataloader:\n",
151 | " batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}\n",
152 | " output_dict = policy.forward(batch)\n",
153 | " loss = output_dict[\"loss\"]\n",
154 | " \n",
155 | " loss.backward()\n",
156 | " optimizer.step()\n",
157 | " optimizer.zero_grad()\n",
158 | " \n",
159 | " print(f\"Step {step}, Loss: {loss.item():.3f}\")\n",
160 | " step += 1\n",
161 | " \n",
162 | " if step >= training_steps:\n",
163 | " done = True\n",
164 | " break\n"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {},
171 | "outputs": [],
172 | "source": [
173 | "# =============================================\n",
174 | "# 5. Save Policy Checkpoint\n",
175 | "# =============================================\n",
176 | "policy.save_pretrained(output_path)\n",
177 | "print(f\"Model saved to {output_path}\")\n"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 |     "Congrats! Now feel free to explore AgiBot World further!"
185 | ]
186 | }
187 | ],
188 | "metadata": {
189 | "kernelspec": {
190 | "display_name": "base",
191 | "language": "python",
192 | "name": "python3"
193 | },
194 | "language_info": {
195 | "name": "python",
196 | "version": "3.10.14"
197 | }
198 | },
199 | "nbformat": 4,
200 | "nbformat_minor": 2
201 | }
202 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # AgiBot-World Contributions
4 |
5 |
6 |
7 | `AgiBot-World` is a joint effort by multiple, diverse research teams.
8 |
9 | - Here we list everyone who has contributed to building this platform from 2024 onwards.
10 | - **Note that** names marked with `*` do not appear in the [technical report](https://arxiv.org/abs/2503.06669), yet their ongoing contributions are much appreciated as the project evolves.
11 | - We encourage all future endeavors from the community!
12 |
13 | ### 🌟 Core Contributors
14 | > The whole span of the project, including data collection, algorithms, experiments, and writing
15 |
16 | - *[Qingwen Bu](https://scholar.google.com/citations?user=-JCRysgAAAAJ&hl=zh-CN), [Guanghui Ren](https://scholar.google.com/citations?hl=zh-CN&user=oqN1dA8AAAAJ), [Chiming Liu](https://scholar.google.co.uk/citations?user=VuL0zQkAAAAJ&hl=en), [Chengen Xie](https://scholar.google.com/citations?hl=zh-CN&user=-Sk1x_gAAAAJ), [Modi Shi](https://github.com/ModiShi), [Xindong He](https://scholar.google.com/citations?view_op=list_works&hl=en&user=YAuiW5MAAAAJ), [Jianheng Song](https://github.com/JianJianHeng), [Yuxiang Lu](https://scholar.google.com/citations?hl=zh-CN&user=7m-TOp8AAAAJ), [Siyuan Feng](https://github.com/Eralien)*
17 |
18 |
19 | ### 🌏 Algorithm
20 | > Technical roadmap, model training and evaluation
21 |
22 | **Roadmap and Methodology**
23 | - *[Yao Mu](https://yaomarkmu.github.io/), [Li Chen](https://ilnehc.github.io/), [Yan Ding](https://yding25.com/), [Yixuan Pan](https://lzpyx.github.io/)\**
24 |
25 | **Pre-training**
26 | - *Yi Liu, Yuxin Jiang, Xiuqi Cui*
27 |
28 | **Post-training**
29 | - *Ziyu Xiong, Xu Huang, Dafeng Wei*
30 |
31 | **Deployment & Evaluation**
32 | - *Guo Xu, [Shu Jiang](https://scholar.google.com.hk/citations?user=oPZpk1oAAAAJ&hl=zh-CN)\*, Chengshi Shi\**
33 |
34 | ### 💫 Product & Ecosystem
35 | > System architecture design, project management, community engagement
36 | - *Chengyue Zhao, Shukai Yang, [Huijie Wang](https://faikit.github.io/), Yongjian Shen, Jialu Li, Jiaqi Zhao, Jianchao Zhu, Jiaqi Shan*
37 |
38 | ### 📖 Manuscript Preparation
39 | > Manuscript outline, writing and revising
40 | - *[Jisong Cai](https://scholar.google.com/citations?hl=zh-CN&user=dTrpq94AAAAJ), [Chonghao Sima](https://scholar.google.com/citations?user=dgYJ6esAAAAJ), [Shenyuan Gao](https://scholar.google.com/citations?user=hZtOnecAAAAJ)\**
41 |
42 | ### 🦾 Data Curation
43 | > Data collection, quality check
44 | - *Cheng Ruan, Jia Zeng, Lei Yang*
45 |
46 | ### 🛠️ Hardware & Software Development
47 | > Hardware design, embedded software development
48 | - *Yuehan Niu, Cheng Jing, Mingkang Shi, Chi Zhang, Qinglin Zhang, Cunbiao Yang, [Wenhao Wang](https://hao-starrr.github.io/), [Xuan Hu](https://github.com/huxuan)*
49 |
50 | ### 🚀 Project Co-lead and Advising
51 | > Research direction, project coordination, technical advising
52 | - *Maoqing Yao, Yu Qiao, Hongyang Li, [Jianlan Luo](https://scholar.google.co.uk/citations?user=SJoRNbYAAAAJ&hl=en&oi=ao), [Jiangmiao Pang](https://scholar.google.co.uk/citations?user=ssSfKpAAAAAJ&hl=en&oi=ao), [Bin Zhao](https://scholar.google.com/citations?user=DQB0hqwAAAAJ), [Junchi Yan](https://scholar.google.co.uk/citations?user=ga230VoAAAAJ&hl=en&oi=ao), Ping Luo*
53 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # AgiBot World Colosseo
4 |
5 | Research Blog on March 10 | [Technical Report](https://arxiv.org/abs/2503.06669)
6 |
7 | [Hugging Face](https://huggingface.co/agibot-world) | [Project Website](https://agibot-world.com) | [License: CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/)
8 |
9 |
10 |
11 |
12 | AgiBot World Colosseo is a full-stack large-scale robot learning platform curated for advancing bimanual manipulation in scalable and intelligent embodied systems. It is accompanied by foundation models, benchmarks, and an ecosystem to democratize access to high-quality robot data for the academic community and the industry, paving the path towards the "ImageNet Moment" for Embodied AI.
13 |
14 | We have released:
15 | - **Task Catalog:** Reference sheet outlining the tasks in our dataset, including robot end-effector types, sample action-text descriptions and more
16 | - **AgiBot World Beta:** Our complete dataset featuring 1,003,672 trajectories (~43.8T)
17 | - **AgiBot World Alpha:** A curated subset of AgiBot World Beta, containing 92,214 trajectories (~8.5T)
18 |
19 | ## News📰
20 |
21 | - **`[2025/03/10]`** 📄 Research Blog and Technical Report released.
22 | - **`[2025/03/01]`** AgiBot World Beta released.
23 | - **`[2025/01/03]`** AgiBot World Alpha Sample Dataset released.
24 | - **`[2024/12/30]`** 🤖 AgiBot World Alpha released.
25 |
26 | ## TODO List 📅
27 |
28 | - [x] **AgiBot World Alpha**
29 | - [x] **AgiBot World Beta** (expected Q1 2025)
30 | - [x] ~1,000,000 trajectories of high-quality robot data
31 | - [ ] **AgiBot World Foundation Model: GO-1** (expected Q2 2025)
32 | - [ ] Training & inference code
33 | - [ ] Pretrained model checkpoint
34 | - [ ] **AgiBot World Colosseo** (expected 2025)
35 | - [ ] A comprehensive platform with toolkits including teleoperation, training and inference.
36 | - [ ] **2025 AgiBot World Challenge** (expected 2025)
37 |
38 | ## Key Features 🔑
39 |
40 | - **1 million+** trajectories from 100 robots.
41 | - **100+ 1:1 replicated real-life scenarios** across 5 target domains.
42 | - **Cutting-edge hardware:** visual tactile sensors / 6-DoF dexterous hands / mobile dual-arm robots
43 | - **Wide-spectrum, versatile, and challenging tasks**
44 |
45 |
46 | | Contact-rich Manipulation | Long-horizon Planning | Multi-robot Collaboration |
47 | | :---: | :---: | :---: |
48 | |  |  |  |
49 |
66 | ## Table of Contents
67 |
68 | 1. [Key Features](#keyfeatures)
69 | 2. [At a Quick Glance](#quickglance)
70 | 3. [Getting Started](#installation)
71 |    - [Installation](#installation)
72 | - [How to Get Started with Our AgiBot World Data](#preaparedata)
73 | - [Visualize Datasets](#visualizedatasets)
74 |    - [Policy Training Quickstart](#training)
75 | 4. [TODO List](#todolist)
76 | 5. [License and Citation](#liscenseandcitation)
77 |
78 | ## At a Quick Glance⬇️
79 |
80 | Follow the steps below to quickly explore and get an overview of AgiBot World with our [sample dataset](https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha/blob/main/sample_dataset.tar) (~7GB).
81 |
82 | ```bash
83 | # Installation
84 | conda create -n agibotworld python=3.10 -y
85 | conda activate agibotworld
86 | pip install git+https://github.com/huggingface/lerobot@59e275743499c5811a9f651a8947e8f881c4058c
87 | pip install matplotlib
88 | git clone https://github.com/OpenDriveLab/AgiBot-World.git
89 | cd AgiBot-World
90 |
91 | # Download the sample dataset (~7GB) from Hugging Face. Replace with your Hugging Face Access Token. You can generate an access token by following the instructions in the Hugging Face documentation from https://huggingface.co/docs/hub/security-tokens
92 | mkdir data
93 | cd data
94 | curl -L -o sample_dataset.tar -H "Authorization: Bearer " https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha/resolve/main/sample_dataset.tar
95 | tar -xvf sample_dataset.tar
96 |
97 | # Convert the sample dataset to LeRobot dataset format and visualize
98 | cd ..
99 | python scripts/convert_to_lerobot.py --src_path ./data/sample_dataset --task_id 390 --tgt_path ./data/sample_lerobot
100 | python scripts/visualize_dataset.py --task-id 390 --dataset-path ./data/sample_lerobot
101 | ```
102 |
103 | ## Getting Started 🔥
104 |
105 | #### Installation
106 |
107 | Download our source code:
108 | ```bash
109 | git clone https://github.com/OpenDriveLab/AgiBot-World.git
110 | cd AgiBot-World
111 | ```
112 |
113 | Our project is built upon the [lerobot library](https://github.com/huggingface/lerobot) (**dataset `v2.0`, commit 59e2757**).
114 | Install lerobot through:
115 | ```bash
116 | pip install git+https://github.com/huggingface/lerobot@59e275743499c5811a9f651a8947e8f881c4058c
117 | ```
118 |
119 | #### How to Get Started with Our AgiBot World Data
120 |
121 | - [OPTION 1] Download data from our [OpenDataLab](https://opendatalab.com/OpenDriveLab/AgiBot-World) page.
122 |
123 | ```bash
124 | pip install openxlab # install CLI
125 | openxlab dataset get --dataset-repo OpenDriveLab/AgiBot-World # dataset download
126 | ```
127 |
128 | - [OPTION 2] Download data from our [HuggingFace](https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha) page.
129 |
130 | ```bash
131 | huggingface-cli download --resume-download --repo-type dataset agibot-world/AgiBotWorld-Alpha --local-dir ./AgiBotWorld-Alpha
132 | ```
133 |
134 | Convert the data to **LeRobot Dataset** format.
135 |
136 | ```bash
137 | python scripts/convert_to_lerobot.py --src_path /path/to/agibotworld/alpha --task_id 390 --tgt_path /path/to/save/lerobot
138 | ```
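
As a quick sanity check (a minimal sketch, assuming the `--tgt_path` and `--task_id` used above), the converted data can be loaded directly with `LeRobotDataset`:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Placeholders: match these to the --tgt_path and --task_id passed to the converter.
repo_id = "agibotworld/task_390"
dataset = LeRobotDataset(
    repo_id,
    root=f"/path/to/save/lerobot/{repo_id}",
    local_files_only=True,
)
print(f"episodes: {dataset.num_episodes}, frames: {len(dataset)}")
```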
139 |
140 | #### Visualize Datasets
141 |
142 | We adapt and extend the dataset visualization script from the [LeRobot project](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/visualize_dataset.py):
143 |
144 | ```bash
145 | python scripts/visualize_dataset.py --task-id 390 --dataset-path /path/to/lerobot/format/dataset
146 | ```
147 |
148 | It will launch a `rerun.io` viewer and display the camera streams, robot states, and actions, like this:
149 |
150 | 
151 |
152 |
153 | #### Policy Training Quickstart
154 |
155 | Leveraging the simplicity of the [LeRobot Dataset](https://github.com/huggingface/lerobot) format, we provide a user-friendly [Jupyter Notebook](https://github.com/OpenDriveLab/AgiBot-World/blob/main/AgibotWorld.ipynb) for training a Diffusion Policy on the AgiBot World dataset; a condensed sketch of what the notebook does is shown below.
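
This outline assumes you have already converted a task (the dataset path and task id below are placeholders); see the notebook itself for the full data-loading and training loop:

```python
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

FPS = 30
TASK_ID = 390  # placeholder: any task id you have converted
dataset_path = "/path/to/lerobot/format/dataset"  # placeholder path
repo_id = f"agibotworld/task_{TASK_ID}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Observation/action windows, expressed as timestamps relative to the current frame.
delta_timestamps = {
    "observation.images.top_head": [-1 / FPS, 0.0],
    "observation.state": [-1 / FPS, 0.0],
    "action": [t / FPS for t in range(-1, 15)],
}

dataset = LeRobotDataset(
    repo_id=repo_id,
    root=f"{dataset_path}/{repo_id}",
    delta_timestamps=delta_timestamps,
    local_files_only=True,
)

# Configure a Diffusion Policy on the head camera and proprioceptive state.
cfg = DiffusionConfig()
cfg.input_shapes = {"observation.images.top_head": [3, 480, 640], "observation.state": [20]}
cfg.input_normalization_modes = {"observation.images.top_head": "mean_std", "observation.state": "min_max"}
cfg.output_shapes = {"action": [22]}

policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
policy.to(device)
```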
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 | ## License and Citation📄
165 |
166 | All the data and code within this repo are under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/).
167 |
168 | - Please consider citing our work if it helps your research.
169 | - For the full authorship and detailed contributions, please refer to [contributions](CONTRIBUTING.md).
170 | - In alphabetical order by surname:
171 | ```BibTeX
172 | @article{bu2025agibot,
173 | title={Agibot world colosseo: A large-scale manipulation platform for scalable and intelligent embodied systems},
174 | author={Bu, Qingwen and Cai, Jisong and Chen, Li and Cui, Xiuqi and Ding, Yan and Feng, Siyuan and Gao, Shenyuan and He, Xindong and Huang, Xu and Jiang, Shu and others},
175 | journal={arXiv preprint arXiv:2503.06669},
176 | year={2025}
177 | }
178 | ```
179 |
--------------------------------------------------------------------------------
/assets/Contact-rich_manipulation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/AgiBot-World/0d4b63b081c4a066f10e16901f418218f45e68bf/assets/Contact-rich_manipulation.gif
--------------------------------------------------------------------------------
/assets/Long-horizon_planning.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/AgiBot-World/0d4b63b081c4a066f10e16901f418218f45e68bf/assets/Long-horizon_planning.gif
--------------------------------------------------------------------------------
/assets/Multi-robot_collaboration.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/AgiBot-World/0d4b63b081c4a066f10e16901f418218f45e68bf/assets/Multi-robot_collaboration.gif
--------------------------------------------------------------------------------
/assets/dataset_visualization.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/AgiBot-World/0d4b63b081c4a066f10e16901f418218f45e68bf/assets/dataset_visualization.gif
--------------------------------------------------------------------------------
/scripts/convert_to_lerobot.py:
--------------------------------------------------------------------------------
1 | """
2 | This project is built upon the open-source project 🤗 LeRobot: https://github.com/huggingface/lerobot
3 |
4 | We are grateful to the LeRobot team for their outstanding work and their contributions to the community.
5 |
6 | If you find this project useful, please also consider supporting and exploring LeRobot.
7 | """
8 |
9 | import os
10 | import json
11 | import shutil
12 | import logging
13 | import argparse
14 | import gc
15 | from pathlib import Path
16 | from typing import Callable
17 | from functools import partial
18 | from math import ceil
19 | from copy import deepcopy
20 |
21 | import h5py
22 | import torch
23 | import einops
24 | import numpy as np
25 | from PIL import Image
26 | from tqdm import tqdm
27 | from pprint import pformat
28 | from tqdm.contrib.concurrent import process_map
29 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
30 | from lerobot.common.datasets.utils import (
31 | STATS_PATH,
32 | check_timestamps_sync,
33 | get_episode_data_index,
34 | serialize_dict,
35 | write_json,
36 | )
37 |
38 | HEAD_COLOR = "head_color.mp4"
39 | HAND_LEFT_COLOR = "hand_left_color.mp4"
40 | HAND_RIGHT_COLOR = "hand_right_color.mp4"
41 | HEAD_CENTER_FISHEYE_COLOR = "head_center_fisheye_color.mp4"
42 | HEAD_LEFT_FISHEYE_COLOR = "head_left_fisheye_color.mp4"
43 | HEAD_RIGHT_FISHEYE_COLOR = "head_right_fisheye_color.mp4"
44 | BACK_LEFT_FISHEYE_COLOR = "back_left_fisheye_color.mp4"
45 | BACK_RIGHT_FISHEYE_COLOR = "back_right_fisheye_color.mp4"
46 | HEAD_DEPTH = "head_depth"
47 |
48 | DEFAULT_IMAGE_PATH = (
49 | "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.jpg"
50 | )
51 |
52 | FEATURES = {
53 | "observation.images.top_head": {
54 | "dtype": "video",
55 | "shape": [480, 640, 3],
56 | "names": ["height", "width", "channel"],
57 | "video_info": {
58 | "video.fps": 30.0,
59 | "video.codec": "av1",
60 | "video.pix_fmt": "yuv420p",
61 | "video.is_depth_map": False,
62 | "has_audio": False,
63 | },
64 | },
65 | "observation.images.cam_top_depth": {
66 | "dtype": "image",
67 | "shape": [480, 640, 1],
68 | "names": ["height", "width", "channel"],
69 | },
70 | "observation.images.hand_left": {
71 | "dtype": "video",
72 | "shape": [480, 640, 3],
73 | "names": ["height", "width", "channel"],
74 | "video_info": {
75 | "video.fps": 30.0,
76 | "video.codec": "av1",
77 | "video.pix_fmt": "yuv420p",
78 | "video.is_depth_map": False,
79 | "has_audio": False,
80 | },
81 | },
82 | "observation.images.hand_right": {
83 | "dtype": "video",
84 | "shape": [480, 640, 3],
85 | "names": ["height", "width", "channel"],
86 | "video_info": {
87 | "video.fps": 30.0,
88 | "video.codec": "av1",
89 | "video.pix_fmt": "yuv420p",
90 | "video.is_depth_map": False,
91 | "has_audio": False,
92 | },
93 | },
94 | "observation.images.head_center_fisheye": {
95 | "dtype": "video",
96 | "shape": [748, 960, 3],
97 | "names": ["height", "width", "channel"],
98 | "video_info": {
99 | "video.fps": 30.0,
100 | "video.codec": "av1",
101 | "video.pix_fmt": "yuv420p",
102 | "video.is_depth_map": False,
103 | "has_audio": False,
104 | },
105 | },
106 | "observation.images.head_left_fisheye": {
107 | "dtype": "video",
108 | "shape": [748, 960, 3],
109 | "names": ["height", "width", "channel"],
110 | "video_info": {
111 | "video.fps": 30.0,
112 | "video.codec": "av1",
113 | "video.pix_fmt": "yuv420p",
114 | "video.is_depth_map": False,
115 | "has_audio": False,
116 | },
117 | },
118 | "observation.images.head_right_fisheye": {
119 | "dtype": "video",
120 | "shape": [748, 960, 3],
121 | "names": ["height", "width", "channel"],
122 | "video_info": {
123 | "video.fps": 30.0,
124 | "video.codec": "av1",
125 | "video.pix_fmt": "yuv420p",
126 | "video.is_depth_map": False,
127 | "has_audio": False,
128 | },
129 | },
130 | "observation.images.back_left_fisheye": {
131 | "dtype": "video",
132 | "shape": [748, 960, 3],
133 | "names": ["height", "width", "channel"],
134 | "video_info": {
135 | "video.fps": 30.0,
136 | "video.codec": "av1",
137 | "video.pix_fmt": "yuv420p",
138 | "video.is_depth_map": False,
139 | "has_audio": False,
140 | },
141 | },
142 | "observation.images.back_right_fisheye": {
143 | "dtype": "video",
144 | "shape": [748, 960, 3],
145 | "names": ["height", "width", "channel"],
146 | "video_info": {
147 | "video.fps": 30.0,
148 | "video.codec": "av1",
149 | "video.pix_fmt": "yuv420p",
150 | "video.is_depth_map": False,
151 | "has_audio": False,
152 | },
153 | },
154 | "observation.state": {
155 | "dtype": "float32",
156 | "shape": [20],
157 | },
158 | "action": {
159 | "dtype": "float32",
160 | "shape": [22],
161 | },
162 | "episode_index": {
163 | "dtype": "int64",
164 | "shape": [1],
165 | "names": None,
166 | },
167 | "frame_index": {
168 | "dtype": "int64",
169 | "shape": [1],
170 | "names": None,
171 | },
172 | "index": {
173 | "dtype": "int64",
174 | "shape": [1],
175 | "names": None,
176 | },
177 | "task_index": {
178 | "dtype": "int64",
179 | "shape": [1],
180 | "names": None,
181 | },
182 | }
183 |
184 |
185 | def get_stats_einops_patterns(dataset, num_workers=0):
186 | """These einops patterns will be used to aggregate batches and compute statistics.
187 |
188 | Note: We assume the images are in channel first format
189 | """
190 |
191 | dataloader = torch.utils.data.DataLoader(
192 | dataset,
193 | num_workers=num_workers,
194 | batch_size=2,
195 | shuffle=False,
196 | )
197 | batch = next(iter(dataloader))
198 |
199 | stats_patterns = {}
200 |
201 | for key in dataset.features:
202 | # sanity check that tensors are not float64
203 | assert batch[key].dtype != torch.float64
204 |
205 | # if isinstance(feats_type, (VideoFrame, Image)):
206 | if key in dataset.meta.camera_keys:
207 | # sanity check that images are channel first
208 | _, c, h, w = batch[key].shape
209 | assert (
210 | c < h and c < w
211 | ), f"expect channel first images, but instead {batch[key].shape}"
212 | assert (
213 | batch[key].dtype == torch.float32
214 | ), f"expect torch.float32, but instead {batch[key].dtype=}"
215 | # assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
216 | # assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
217 | stats_patterns[key] = "b c h w -> c 1 1"
218 | elif batch[key].ndim == 2:
219 | stats_patterns[key] = "b c -> c "
220 | elif batch[key].ndim == 1:
221 | stats_patterns[key] = "b -> 1"
222 | else:
223 | raise ValueError(f"{key}, {batch[key].shape}")
224 |
225 | return stats_patterns
226 |
227 |
228 | def compute_stats(dataset, batch_size=8, num_workers=4, max_num_samples=None):
229 | """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
230 | if max_num_samples is None:
231 | max_num_samples = len(dataset)
232 |
233 | # for more info on why we need to set the same number of workers, see `load_from_videos`
234 | stats_patterns = get_stats_einops_patterns(dataset, num_workers)
235 |
236 | # mean and std will be computed incrementally while max and min will track the running value.
237 | mean, std, max, min = {}, {}, {}, {}
238 | for key in stats_patterns:
239 | mean[key] = torch.tensor(0.0).float()
240 | std[key] = torch.tensor(0.0).float()
241 | max[key] = torch.tensor(-float("inf")).float()
242 | min[key] = torch.tensor(float("inf")).float()
243 |
244 | def create_seeded_dataloader(dataset, batch_size, seed):
245 | generator = torch.Generator()
246 | generator.manual_seed(seed)
247 | dataloader = torch.utils.data.DataLoader(
248 | dataset,
249 | num_workers=num_workers,
250 | batch_size=batch_size,
251 | shuffle=True,
252 | drop_last=False,
253 | generator=generator,
254 | )
255 | return dataloader
256 |
257 | # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
258 | # surprises when rerunning the sampler.
259 | first_batch = None
260 | running_item_count = 0 # for online mean computation
261 | dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
262 | for i, batch in enumerate(
263 | tqdm(
264 | dataloader,
265 | total=ceil(max_num_samples / batch_size),
266 | desc="Compute mean, min, max",
267 | )
268 | ):
269 | this_batch_size = len(batch["index"])
270 | running_item_count += this_batch_size
271 | if first_batch is None:
272 | first_batch = deepcopy(batch)
273 | for key, pattern in stats_patterns.items():
274 | batch[key] = batch[key].float()
275 | # Numerically stable update step for mean computation.
276 | batch_mean = einops.reduce(batch[key], pattern, "mean")
277 | # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
278 | # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
279 | # and x is the current batch mean. Some rearrangement is then required to avoid risking
280 | # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
281 | # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
282 | mean[key] = (
283 | mean[key]
284 | + this_batch_size * (batch_mean - mean[key]) / running_item_count
285 | )
286 | max[key] = torch.maximum(
287 | max[key], einops.reduce(batch[key], pattern, "max")
288 | )
289 | min[key] = torch.minimum(
290 | min[key], einops.reduce(batch[key], pattern, "min")
291 | )
292 |
293 | if i == ceil(max_num_samples / batch_size) - 1:
294 | break
295 |
296 | first_batch_ = None
297 | running_item_count = 0 # for online std computation
298 | dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
299 | for i, batch in enumerate(
300 | tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
301 | ):
302 | this_batch_size = len(batch["index"])
303 | running_item_count += this_batch_size
304 | # Sanity check to make sure the batches are still in the same order as before.
305 | if first_batch_ is None:
306 | first_batch_ = deepcopy(batch)
307 | for key in stats_patterns:
308 | assert torch.equal(first_batch_[key], first_batch[key])
309 | for key, pattern in stats_patterns.items():
310 | batch[key] = batch[key].float()
311 | # Numerically stable update step for mean computation (where the mean is over squared
312 |         # residuals). See notes in the mean computation loop above.
313 | batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
314 | std[key] = (
315 | std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
316 | )
317 |
318 | if i == ceil(max_num_samples / batch_size) - 1:
319 | break
320 |
321 | for key in stats_patterns:
322 | std[key] = torch.sqrt(std[key])
323 |
324 | stats = {}
325 | for key in stats_patterns:
326 | stats[key] = {
327 | "mean": mean[key],
328 | "std": std[key],
329 | "max": max[key],
330 | "min": min[key],
331 | }
332 | return stats
333 |
334 |
335 | class AgiBotDataset(LeRobotDataset):
336 | def __init__(
337 | self,
338 | repo_id: str,
339 | root: str | Path | None = None,
340 | episodes: list[int] | None = None,
341 | image_transforms: Callable | None = None,
342 | delta_timestamps: dict[list[float]] | None = None,
343 | tolerance_s: float = 1e-4,
344 | download_videos: bool = True,
345 | local_files_only: bool = False,
346 | video_backend: str | None = None,
347 | ):
348 | super().__init__(
349 | repo_id=repo_id,
350 | root=root,
351 | episodes=episodes,
352 | image_transforms=image_transforms,
353 | delta_timestamps=delta_timestamps,
354 | tolerance_s=tolerance_s,
355 | download_videos=download_videos,
356 | local_files_only=local_files_only,
357 | video_backend=video_backend,
358 | )
359 |
360 | def save_episode(
361 | self, task: str, episode_data: dict | None = None, videos: dict | None = None
362 | ) -> None:
363 | """
364 | We rewrite this method to copy mp4 videos to the target position
365 | """
366 | if not episode_data:
367 | episode_buffer = self.episode_buffer
368 |
369 | episode_length = episode_buffer.pop("size")
370 | episode_index = episode_buffer["episode_index"]
371 | if episode_index != self.meta.total_episodes:
372 | # TODO(aliberts): Add option to use existing episode_index
373 | raise NotImplementedError(
374 | "You might have manually provided the episode_buffer with an episode_index that doesn't "
375 | "match the total number of episodes in the dataset. This is not supported for now."
376 | )
377 |
378 | if episode_length == 0:
379 | raise ValueError(
380 | "You must add one or several frames with `add_frame` before calling `add_episode`."
381 | )
382 |
383 | task_index = self.meta.get_task_index(task)
384 |
385 | if not set(episode_buffer.keys()) == set(self.features):
386 |             raise ValueError("Keys of the episode buffer do not match the dataset features.")
387 |
388 | for key, ft in self.features.items():
389 | if key == "index":
390 | episode_buffer[key] = np.arange(
391 | self.meta.total_frames, self.meta.total_frames + episode_length
392 | )
393 | elif key == "episode_index":
394 | episode_buffer[key] = np.full((episode_length,), episode_index)
395 | elif key == "task_index":
396 | episode_buffer[key] = np.full((episode_length,), task_index)
397 | elif ft["dtype"] in ["image", "video"]:
398 | continue
399 | elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
400 | episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
401 | elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
402 | episode_buffer[key] = np.stack(episode_buffer[key])
403 | else:
404 | raise ValueError(key)
405 |
406 | self._wait_image_writer()
407 | self._save_episode_table(episode_buffer, episode_index)
408 |
409 | self.meta.save_episode(episode_index, episode_length, task, task_index)
410 | for key in self.meta.video_keys:
411 | video_path = self.root / self.meta.get_video_file_path(episode_index, key)
412 | episode_buffer[key] = video_path
413 | video_path.parent.mkdir(parents=True, exist_ok=True)
414 | shutil.copyfile(videos[key], video_path)
415 | if not episode_data: # Reset the buffer
416 | self.episode_buffer = self.create_episode_buffer()
417 | self.consolidated = False
418 |
419 | def consolidate(
420 | self, run_compute_stats: bool = True, keep_image_files: bool = False
421 | ) -> None:
422 | self.hf_dataset = self.load_hf_dataset()
423 | self.episode_data_index = get_episode_data_index(
424 | self.meta.episodes, self.episodes
425 | )
426 | check_timestamps_sync(
427 | self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s
428 | )
429 | if len(self.meta.video_keys) > 0:
430 | self.meta.write_video_info()
431 |
432 | if not keep_image_files:
433 | img_dir = self.root / "images"
434 | if img_dir.is_dir():
435 | shutil.rmtree(self.root / "images")
436 | video_files = list(self.root.rglob("*.mp4"))
437 | assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
438 |
439 | parquet_files = list(self.root.rglob("*.parquet"))
440 | assert len(parquet_files) == self.num_episodes
441 |
442 | if run_compute_stats:
443 | self.stop_image_writer()
444 | self.meta.stats = compute_stats(self)
445 | serialized_stats = serialize_dict(self.meta.stats)
446 | write_json(serialized_stats, self.root / STATS_PATH)
447 | self.consolidated = True
448 | else:
449 | logging.warning(
450 | "Skipping computation of the dataset statistics, dataset is not fully consolidated."
451 | )
452 |
453 | def add_frame(self, frame: dict) -> None:
454 | """
455 | This function only adds the frame to the episode_buffer. Apart from images — which are written in a
456 | temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
457 | then needs to be called.
458 | """
459 | # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
460 | # check the dtype and shape matches, etc.
461 |
462 | if self.episode_buffer is None:
463 | self.episode_buffer = self.create_episode_buffer()
464 |
465 | frame_index = self.episode_buffer["size"]
466 | timestamp = (
467 | frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
468 | )
469 | self.episode_buffer["frame_index"].append(frame_index)
470 | self.episode_buffer["timestamp"].append(timestamp)
471 |
472 | for key in frame:
473 | if key not in self.features:
474 | raise ValueError(key)
475 | item = (
476 | frame[key].numpy()
477 | if isinstance(frame[key], torch.Tensor)
478 | else frame[key]
479 | )
480 | self.episode_buffer[key].append(item)
481 |
482 | self.episode_buffer["size"] += 1
483 |
484 |
485 | def load_depths(root_dir: str, camera_name: str):
486 | cam_path = Path(root_dir)
487 | all_imgs = sorted(list(cam_path.glob(f"{camera_name}*")))
488 | return [np.array(Image.open(f)).astype(np.float32) / 1000 for f in all_imgs]
489 |
490 |
491 | def load_local_dataset(episode_id: int, src_path: str, task_id: int) -> list | None:
492 | """Load local dataset and return a dict with observations and actions"""
493 |
494 | ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
495 | depth_imgs = load_depths(ob_dir / "depth", HEAD_DEPTH)
496 | proprio_dir = Path(src_path) / f"proprio_stats/{task_id}/{episode_id}"
497 |
498 | with h5py.File(proprio_dir / "proprio_stats.h5") as f:
499 | state_joint = np.array(f["state/joint/position"])
500 | state_effector = np.array(f["state/effector/position"])
501 | state_head = np.array(f["state/head/position"])
502 | state_waist = np.array(f["state/waist/position"])
503 | action_joint = np.array(f["action/joint/position"])
504 | action_effector = np.array(f["action/effector/position"])
505 | action_head = np.array(f["action/head/position"])
506 | action_waist = np.array(f["action/waist/position"])
507 | action_velocity = np.array(f["action/robot/velocity"])
508 |
509 | states_value = np.hstack(
510 | [state_joint, state_effector, state_head, state_waist]
511 | ).astype(np.float32)
512 | assert (
513 | action_joint.shape[0] == action_effector.shape[0]
514 | ), f"shape of action_joint:{action_joint.shape};shape of action_effector:{action_effector.shape}"
515 | action_value = np.hstack(
516 | [action_joint, action_effector, action_head, action_waist, action_velocity]
517 | ).astype(np.float32)
518 |
519 |     assert len(depth_imgs) == len(
520 |         states_value
521 |     ), "Number of depth images and states are not equal"
522 |     assert len(depth_imgs) == len(
523 |         action_value
524 |     ), "Number of depth images and actions are not equal"
525 | frames = [
526 | {
527 | "observation.images.cam_top_depth": depth_imgs[i],
528 | "observation.state": states_value[i],
529 | "action": action_value[i],
530 | }
531 | for i in range(len(depth_imgs))
532 | ]
533 |
534 | v_path = ob_dir / "videos"
535 | videos = {
536 | "observation.images.top_head": v_path / HEAD_COLOR,
537 | "observation.images.hand_left": v_path / HAND_LEFT_COLOR,
538 | "observation.images.hand_right": v_path / HAND_RIGHT_COLOR,
539 | "observation.images.head_center_fisheye": v_path / HEAD_CENTER_FISHEYE_COLOR,
540 | "observation.images.head_left_fisheye": v_path / HEAD_LEFT_FISHEYE_COLOR,
541 | "observation.images.head_right_fisheye": v_path / HEAD_RIGHT_FISHEYE_COLOR,
542 | "observation.images.back_left_fisheye": v_path / BACK_LEFT_FISHEYE_COLOR,
543 | "observation.images.back_right_fisheye": v_path / BACK_RIGHT_FISHEYE_COLOR,
544 | }
545 | return frames, videos
546 |
547 |
548 | def get_task_instruction(task_json_path: str) -> dict:
549 | """Get task language instruction"""
550 | with open(task_json_path, "r") as f:
551 | task_info = json.load(f)
552 | task_name = task_info[0]["task_name"]
553 | task_init_scene = task_info[0]["init_scene_text"]
554 | task_instruction = f"{task_name}.{task_init_scene}"
555 | print(f"Get Task Instruction <{task_instruction}>")
556 | return task_instruction
557 |
558 |
559 | def main(
560 | src_path: str,
561 | tgt_path: str,
562 | task_id: int,
563 | repo_id: str,
564 | task_info_json: str,
565 | debug: bool = False,
566 | chunk_size: int = 10 # Add chunk size parameter
567 | ):
568 | task_name = get_task_instruction(task_info_json)
569 |
570 | dataset = AgiBotDataset.create(
571 | repo_id=repo_id,
572 | root=f"{tgt_path}/{repo_id}",
573 | fps=30,
574 | robot_type="a2d",
575 | features=FEATURES,
576 | )
577 |
578 | all_subdir = sorted(
579 | [
580 | f.as_posix()
581 | for f in Path(src_path).glob(f"observations/{task_id}/*")
582 | if f.is_dir()
583 | ]
584 | )
585 |
586 | if debug:
587 | all_subdir = all_subdir[:2]
588 |
589 | # Get all episode id
590 | all_subdir_eids = [int(Path(path).name) for path in all_subdir]
591 | all_subdir_episode_desc = [task_name] * len(all_subdir_eids)
592 |
593 | # Process in chunks to reduce memory usage
594 | for chunk_start in tqdm(range(0, len(all_subdir_eids), chunk_size), desc="Processing chunks"):
595 | chunk_end = min(chunk_start + chunk_size, len(all_subdir_eids))
596 | chunk_eids = all_subdir_eids[chunk_start:chunk_end]
597 | chunk_descs = all_subdir_episode_desc[chunk_start:chunk_end]
598 |
599 | # Process only this chunk
600 | if debug:
601 | raw_datasets_chunk = [
602 | load_local_dataset(subdir, src_path=src_path, task_id=task_id)
603 | for subdir in tqdm(chunk_eids, desc="Loading chunk data")
604 | ]
605 | else:
606 | raw_datasets_chunk = process_map(
607 | partial(load_local_dataset, src_path=src_path, task_id=task_id),
608 | chunk_eids,
609 | max_workers=os.cpu_count() // 2,
610 | desc=f"Loading chunk {chunk_start//chunk_size + 1}/{(len(all_subdir_eids) + chunk_size - 1)//chunk_size}",
611 | )
612 |
613 | # Filter out None results
614 | valid_datasets = [(ds, desc) for ds, desc in zip(raw_datasets_chunk, chunk_descs) if ds is not None]
615 |
616 | # Process each dataset in the chunk
617 | for raw_dataset, episode_desc in tqdm(valid_datasets, desc="Processing episodes in chunk"):
618 | for raw_dataset_sub in tqdm(
619 | raw_dataset[0], desc="Processing frames", leave=False
620 | ):
621 | dataset.add_frame(raw_dataset_sub)
622 | dataset.save_episode(task=episode_desc, videos=raw_dataset[1])
623 |
624 | # Clear memory after each chunk
625 | raw_datasets_chunk = None
626 | valid_datasets = None
627 | gc.collect()
628 |
629 | # Only consolidate at the end
630 | dataset.consolidate()
631 |
632 |
633 | if __name__ == "__main__":
634 | parser = argparse.ArgumentParser()
635 | parser.add_argument(
636 | "--src_path",
637 | type=str,
638 | required=True,
639 | )
640 | parser.add_argument(
641 | "--task_id",
642 | type=str,
643 | required=True,
644 | )
645 | parser.add_argument(
646 | "--tgt_path",
647 | type=str,
648 | required=True,
649 | )
650 | parser.add_argument(
651 | "--debug",
652 | action="store_true",
653 | )
654 | parser.add_argument(
655 | "--chunk_size",
656 | type=int,
657 | default=10,
658 | help="Number of episodes to process at once",
659 | )
660 | args = parser.parse_args()
661 |
662 | task_id = args.task_id
663 | json_file = f"{args.src_path}/task_info/task_{args.task_id}.json"
664 | dataset_base = f"agibotworld/task_{args.task_id}"
665 |
666 |     assert Path(json_file).exists(), f"Cannot find {json_file}."
667 | main(args.src_path, args.tgt_path, task_id, dataset_base, json_file, args.debug, args.chunk_size)
668 |
--------------------------------------------------------------------------------
/scripts/visualize_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is adapted from the Hugging Face 🤗 LeRobot project:
3 | https://github.com/huggingface/lerobot
4 |
5 | Original file:
6 | https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/visualize_dataset.py
7 |
8 | The original script was developed as part of the LeRobot project for dataset visualization.
9 | This version adds support for depth map visualization.
10 | """
11 |
12 | import argparse
13 | import gc
14 | import logging
15 | import time
16 | from pathlib import Path
17 | from typing import Iterator
18 |
19 | import numpy as np
20 | import rerun as rr
21 | import torch
22 | import torch.utils.data
23 | import tqdm
24 | import matplotlib.pyplot as plt
25 |
26 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
27 |
28 |
29 | class EpisodeSampler(torch.utils.data.Sampler):
30 | def __init__(self, dataset: LeRobotDataset, episode_index: int):
31 | from_idx = dataset.episode_data_index["from"][episode_index].item()
32 | to_idx = dataset.episode_data_index["to"][episode_index].item()
33 | self.frame_ids = range(from_idx, to_idx)
34 |
35 | def __iter__(self) -> Iterator:
36 | return iter(self.frame_ids)
37 |
38 | def __len__(self) -> int:
39 | return len(self.frame_ids)
40 |
41 |
42 | def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
43 | assert chw_float32_torch.dtype == torch.float32
44 | assert chw_float32_torch.ndim == 3
45 | c, h, w = chw_float32_torch.shape
46 | assert c < h and c < w, f"Expect channel first images, but instead {chw_float32_torch.shape}"
47 |
48 | if c == 1:
49 | # If depth image, clip and normalize the depth map just for visualization
50 | min_depth = 0.4
51 | max_depth = 3
52 | clipped_depth = torch.clamp(chw_float32_torch, min=min_depth, max=max_depth)
53 | normalized_depth = (clipped_depth-min_depth) / (max_depth-min_depth)
54 | depth_image = np.sqrt(normalized_depth.squeeze().cpu().numpy())
55 |
56 | colormap = plt.get_cmap('jet')
57 | colored_depth_image = colormap(depth_image)
58 | hwc_uint8_numpy = (colored_depth_image[:, :, :3] * 255).astype(np.uint8)
59 | else:
60 | # If RGB image
61 | hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
62 |
63 | return hwc_uint8_numpy
64 |
65 |
66 | def visualize_dataset(
67 | dataset: LeRobotDataset,
68 | episode_index: int,
69 | batch_size: int = 32,
70 | num_workers: int = 0,
71 | mode: str = "local",
72 | web_port: int = 9090,
73 | ws_port: int = 9087,
74 | save: bool = False,
75 | output_dir: Path | None = None,
76 | ) -> Path | None:
77 | if save:
78 | assert (
79 | output_dir is not None
80 | ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
81 |
82 | repo_id = dataset.repo_id
83 |
84 | logging.info("Loading dataloader")
85 | episode_sampler = EpisodeSampler(dataset, episode_index)
86 | dataloader = torch.utils.data.DataLoader(
87 | dataset,
88 | num_workers=num_workers,
89 | batch_size=batch_size,
90 | sampler=episode_sampler,
91 | )
92 |
93 | logging.info("Starting Rerun")
94 |
95 | if mode not in ["local", "distant"]:
96 | raise ValueError(mode)
97 |
98 | spawn_local_viewer = mode == "local" and not save
99 | rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
100 |
101 | # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
102 | # when iterating on a dataloader with `num_workers` > 0
103 | # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
104 | gc.collect()
105 |
106 | if mode == "distant":
107 | rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
108 |
109 | logging.info("Logging to Rerun")
110 |
111 | for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
112 | # iterate over the batch
113 | for i in range(len(batch["index"])):
114 | rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
115 | rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
116 |
117 | # display each camera image
118 | for key in dataset.meta.camera_keys:
119 | # TODO(rcadene): add `.compress()`? is it lossless?
120 | rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
121 |
122 | # display each dimension of action space (e.g. actuators command)
123 | if "action" in batch:
124 | for dim_idx, val in enumerate(batch["action"][i]):
125 | rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
126 |
127 | # display each dimension of observed state space (e.g. agent position in joint space)
128 | if "observation.state" in batch:
129 | for dim_idx, val in enumerate(batch["observation.state"][i]):
130 | rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
131 |
132 | if mode == "local" and save:
133 | # save .rrd locally
134 | output_dir = Path(output_dir)
135 | output_dir.mkdir(parents=True, exist_ok=True)
136 | repo_id_str = repo_id.replace("/", "_")
137 | rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
138 | rr.save(rrd_path)
139 | return rrd_path
140 |
141 | elif mode == "distant":
142 | # stop the process from exiting since it is serving the websocket connection
143 | try:
144 | while True:
145 | time.sleep(1)
146 | except KeyboardInterrupt:
147 | print("Ctrl-C received. Exiting.")
148 |
149 |
150 | def main():
151 | parser = argparse.ArgumentParser()
152 |
153 | parser.add_argument(
154 | "--task-id",
155 | type=int,
156 | default=None,
157 | help="Index of the AgiBot World task.",
158 | )
159 |     parser.add_argument(
160 |         "--episode-index",
161 |         type=int,
162 |         default=0,
163 |         help="Index of the episode to visualize "
164 |         "(the script renders one episode at a time; defaults to 0).",
165 |     )
166 | parser.add_argument(
167 | "--dataset-path",
168 | type=Path,
169 | default=None,
170 | help="Root directory for the converted LeRobot dataset stored locally.",
171 | )
172 | parser.add_argument(
173 | "--output-dir",
174 | type=Path,
175 | default=None,
176 | help="Directory path to write a .rrd file when `--save 1` is set.",
177 | )
178 | parser.add_argument(
179 | "--batch-size",
180 | type=int,
181 | default=32,
182 | help="Batch size loaded by DataLoader.",
183 | )
184 | parser.add_argument(
185 | "--num-workers",
186 | type=int,
187 | default=4,
188 | help="Number of processes of Dataloader for loading the data.",
189 | )
190 | parser.add_argument(
191 | "--mode",
192 | type=str,
193 | default="local",
194 | help=(
195 | "Mode of viewing between 'local' or 'distant'. "
196 | "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
197 | "'distant' creates a server on the distant machine where the data is stored. "
198 | "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
199 | ),
200 | )
201 | parser.add_argument(
202 | "--web-port",
203 | type=int,
204 | default=9090,
205 | help="Web port for rerun.io when `--mode distant` is set.",
206 | )
207 | parser.add_argument(
208 | "--ws-port",
209 | type=int,
210 | default=9087,
211 | help="Web socket port for rerun.io when `--mode distant` is set.",
212 | )
213 | parser.add_argument(
214 | "--save",
215 | type=int,
216 | default=0,
217 | help=(
218 | "Save a .rrd file in the directory provided by `--output-dir`. "
219 | "It also deactivates the spawning of a viewer. "
220 | "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
221 | ),
222 | )
223 |
224 | args = parser.parse_args()
225 | kwargs = vars(args)
226 | repo_id = f"agibotworld/task_{kwargs.pop('task_id')}"
227 | root = f"{kwargs.pop('dataset_path')}/{repo_id}"
228 |
229 | logging.info("Loading dataset")
230 | dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
231 |
232 |     visualize_dataset(dataset, **kwargs)
233 |
234 | if __name__ == "__main__":
235 | main()
--------------------------------------------------------------------------------