├── .github └── FUNDING.yml ├── .gitignore ├── AgibotWorld.ipynb ├── CONTRIBUTING.md ├── README.md ├── assets ├── Contact-rich_manipulation.gif ├── Long-horizon_planning.gif ├── Multi-robot_collaboration.gif └── dataset_visualization.gif └── scripts ├── convert_to_lerobot.py └── visualize_dataset.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 3 | patreon: # Replace with a single Patreon username 4 | open_collective: # Replace with a single Open Collective username 5 | ko_fi: # Replace with a single Ko-fi username 6 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 7 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 8 | liberapay: # Replace with a single Liberapay username 9 | issuehunt: # Replace with a single IssueHunt username 10 | otechie: # Replace with a single Otechie username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /AgibotWorld.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# AgiBot World Diffusion Policy Training Demo\n", 8 | "\n", 9 | "This notebook demonstrates how to use **AgiBotWorldDataset** to run an offline training workflow.\n", 10 | "Make sure you have installed all necessary packages before running.\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "# =============================================\n", 20 | "# 1. Imports and Parameter Settings\n", 21 | "# =============================================\n", 22 | "import torch\n", 23 | "import numpy as np\n", 24 | "\n", 25 | "from lerobot.common.datasets.lerobot_dataset import LeRobotDataset\n", 26 | "from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig\n", 27 | "from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy\n", 28 | "\n", 29 | "# Parameters\n", 30 | "FPS = 30\n", 31 | "TASK_ID = 352\n", 32 | "training_steps = 5000\n", 33 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 34 | "\n", 35 | "# Paths\n", 36 | "dataset_path = \"/path/to/your/AgiBotWorld/dataset\"\n", 37 | "output_path = \"/path/to/save/your/checkpoint\"" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# =============================================\n", 47 | "# 2. 
Dataset Setup\n", 48 | "# =============================================\n", 49 | "observation_idx = np.array([-1, 0])\n", 50 | "action_idx = np.arange(-1, 15)\n", 51 | "repo_id = f\"agibotworld/task_{TASK_ID}\"\n", 52 | "\n", 53 | "delta_timestamps = {\n", 54 | " \"observation.images.top_head\": (observation_idx / FPS).tolist(),\n", 55 | " \"observation.state\": (observation_idx / FPS).tolist(),\n", 56 | " \"action\": (action_idx / FPS).tolist(),\n", 57 | "}\n", 58 | "\n", 59 | "dataset = LeRobotDataset(\n", 60 | " repo_id=repo_id,\n", 61 | " root=f\"{dataset_path}/{repo_id}\",\n", 62 | " delta_timestamps=delta_timestamps,\n", 63 | " local_files_only=True\n", 64 | ")\n", 65 | "\n", 66 | "dataloader = torch.utils.data.DataLoader(\n", 67 | " dataset,\n", 68 | " num_workers=0,\n", 69 | " batch_size=64,\n", 70 | " shuffle=True,\n", 71 | " pin_memory=(device.type == \"cuda\"),\n", 72 | " drop_last=True,\n", 73 | ")" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "If you want to train one robot policy model to master multiple distinct skills, you can use ’MultiLeRobotDataset‘ to load datasets for various tasks into a unified training process." 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "from pathlib import Path\n", 90 | "from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset\n", 91 | "repo_ids = [f\"agibotworld/{path.name}\" for path in Path(dataset_path).glob(\"agibotworld/task_*\")]\n", 92 | "multi_dataset = MultiLeRobotDataset(\n", 93 | " repo_ids=repo_ids,\n", 94 | " root=dataset_path,\n", 95 | " delta_timestamps=delta_timestamps,\n", 96 | " local_files_only=True\n", 97 | ")" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "Let's kick off a simple training with Diffusion Policy:" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# =============================================\n", 114 | "# 3. Policy Configuration and Initialization\n", 115 | "# =============================================\n", 116 | "cfg = DiffusionConfig()\n", 117 | "cfg.input_shapes = {\n", 118 | " \"observation.images.top_head\": [3, 480, 640],\n", 119 | " \"observation.state\": [20],\n", 120 | "}\n", 121 | "cfg.input_normalization_modes = {\n", 122 | " \"observation.images.top_head\": \"mean_std\",\n", 123 | " \"observation.state\": \"min_max\",\n", 124 | "}\n", 125 | "cfg.output_shapes = {\n", 126 | " \"action\": [22],\n", 127 | "}\n", 128 | "\n", 129 | "policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)\n", 130 | "#policy = DiffusionPolicy(cfg, dataset_stats=multi_dataset.stats)\n", 131 | "policy.train()\n", 132 | "policy.to(device)\n", 133 | "\n", 134 | "optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# =============================================\n", 144 | "# 4. 
Training Loop\n", 145 | "# =============================================\n", 146 | "step = 0\n", 147 | "done = False\n", 148 | "\n", 149 | "while not done:\n", 150 | " for batch in dataloader:\n", 151 | " batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}\n", 152 | " output_dict = policy.forward(batch)\n", 153 | " loss = output_dict[\"loss\"]\n", 154 | " \n", 155 | " loss.backward()\n", 156 | " optimizer.step()\n", 157 | " optimizer.zero_grad()\n", 158 | " \n", 159 | " print(f\"Step {step}, Loss: {loss.item():.3f}\")\n", 160 | " step += 1\n", 161 | " \n", 162 | " if step >= training_steps:\n", 163 | " done = True\n", 164 | " break\n" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "# =============================================\n", 174 | "# 5. Save Policy Checkpoint\n", 175 | "# =============================================\n", 176 | "policy.save_pretrained(output_path)\n", 177 | "print(f\"Model saved to {output_path}\")\n" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Congrats! Now please feel free to explore the AgiBot World!" 185 | ] 186 | } 187 | ], 188 | "metadata": { 189 | "kernelspec": { 190 | "display_name": "base", 191 | "language": "python", 192 | "name": "python3" 193 | }, 194 | "language_info": { 195 | "name": "python", 196 | "version": "3.10.14" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 2 201 | } 202 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # AgiBot-World Contributions 4 | 5 |
6 | 7 | `AgiBot-World` is a joint effort from multiple, diverse research teams. 8 | 9 | - Here we list all contributors to the construction of this platform from 2024 onwards. 10 | - **Note that** names marked with `*` do not appear in the [technical report](https://arxiv.org/abs/2503.06669), yet their contributions are much appreciated as the project evolves. 11 | - We encourage all future endeavors from the community! 12 | 13 | ### 🌟 Core Contributors 14 | > The whole span of the project, including data collection, algorithms, experiments, and writing 15 | 16 | - *[Qingwen Bu](https://scholar.google.com/citations?user=-JCRysgAAAAJ&hl=zh-CN), [Guanghui Ren](https://scholar.google.com/citations?hl=zh-CN&user=oqN1dA8AAAAJ), [Chiming Liu](https://scholar.google.co.uk/citations?user=VuL0zQkAAAAJ&hl=en), [Chengen Xie](https://scholar.google.com/citations?hl=zh-CN&user=-Sk1x_gAAAAJ), [Modi Shi](https://github.com/ModiShi), [Xindong He](https://scholar.google.com/citations?view_op=list_works&hl=en&user=YAuiW5MAAAAJ), [Jianheng Song](https://github.com/JianJianHeng), [Yuxiang Lu](https://scholar.google.com/citations?hl=zh-CN&user=7m-TOp8AAAAJ), [Siyuan Feng](https://github.com/Eralien)* 17 | 18 | 19 | ### 🌏 Algorithm 20 | > Technical roadmap, model training and evaluation
21 | 22 | **Roadmap and Methodology**
23 | - *[Yao Mu](https://yaomarkmu.github.io/), [Li Chen](https://ilnehc.github.io/), [Yan Ding](https://yding25.com/), [Yixuan Pan](https://lzpyx.github.io/)\**
24 | 25 | **Pre-training**
26 | - *Yi Liu, Yuxin Jiang, Xiuqi Cui*
27 | 28 | **Post-training**
29 | - *Ziyu Xiong, Xu Huang, Dafeng Wei*
30 | 31 | **Deployment & Evaluation**
32 | - *Guo Xu, [Shu Jiang](https://scholar.google.com.hk/citations?user=oPZpk1oAAAAJ&hl=zh-CN)\*, Chengshi Shi\**
33 | 34 | ### 💫 Product & Ecosystem 35 | > System architecture design, project management, community engagement 36 | - *Chengyue Zhao, Shukai Yang, [Huijie Wang](https://faikit.github.io/), Yongjian Shen, Jialu Li, Jiaqi Zhao, Jianchao Zhu, Jiaqi Shan* 37 | 38 | ### 📖 Manuscript Preparation 39 | > Manuscript outline, writing and revising 40 | - *[Jisong Cai](https://scholar.google.com/citations?hl=zh-CN&user=dTrpq94AAAAJ), [Chonghao Sima](https://scholar.google.com/citations?user=dgYJ6esAAAAJ), [Shenyuan Gao](https://scholar.google.com/citations?user=hZtOnecAAAAJ)\** 41 | 42 | ### 🦾 Data Curation 43 | > Data collection, quality check 44 | - *Cheng Ruan, Jia Zeng, Lei Yang* 45 | 46 | ### 🛠️ Hardware & Software Development 47 | > Hardware design, embedded software development 48 | - *Yuehan Niu, Cheng Jing, Mingkang Shi, Chi Zhang, Qinglin Zhang, Cunbiao Yang, [Wenhao Wang](https://hao-starrr.github.io/), [Xuan Hu](https://github.com/huxuan)* 49 | 50 | ### 🚀 Project Co-lead and Advising 51 | > Research direction, project coordination, technical advising 52 | - *Maoqing Yao, Yu Qiao, Hongyang Li, [Jianlan Luo](https://scholar.google.co.uk/citations?user=SJoRNbYAAAAJ&hl=en&oi=ao), [Jiangmiao Pang](https://scholar.google.co.uk/citations?user=ssSfKpAAAAAJ&hl=en&oi=ao), [Bin Zhao](https://scholar.google.com/citations?user=DQB0hqwAAAAJ), [Junchi Yan](https://scholar.google.co.uk/citations?user=ga230VoAAAAJ&hl=en&oi=ao), Ping Luo* 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ![agibot_world](https://github.com/user-attachments/assets/df64b543-db82-41ee-adda-799970e8a198) 4 | 5 | Research Blog on March 10 | Technical Report 6 | 7 | [![Static Badge](https://img.shields.io/badge/Download-grey?style=plastic&logo=huggingface&logoColor=yellow)](https://huggingface.co/agibot-world) [![Static Badge](https://img.shields.io/badge/Project%20Page-blue?style=plastic)](https://agibot-world.com) [![License](https://img.shields.io/badge/License-CC_%20_BY--NC--SA_4.0-blue.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/) 8 | Document Badge 9 | 10 |
11 | 12 | AgiBot World Colosseo is a full-stack, large-scale robot learning platform curated for advancing bimanual manipulation in scalable and intelligent embodied systems. It is accompanied by foundation models, benchmarks, and an ecosystem to democratize access to high-quality robot data for the academic community and industry, paving the way towards the "ImageNet Moment" for Embodied AI. 13 | 14 | We have released: 15 | - **Task Catalog:** Reference sheet outlining the tasks in our dataset, including robot end-effector types, sample action-text descriptions, and more 16 | - **AgiBot World Beta:** Our complete dataset featuring 1,003,672 trajectories (~43.8 TB) 17 | - **AgiBot World Alpha:** A curated subset of AgiBot World Beta, containing 92,214 trajectories (~8.5 TB) 18 | 19 | ## News📰 20 | 21 | - **`[2025/03/10]`** 📄 Research Blog and Technical Report released. 22 | - **`[2025/03/01]`** AgiBot World Beta released. 23 | - **`[2025/01/03]`** AgiBot World Alpha Sample Dataset released. 24 | - **`[2024/12/30]`** 🤖 AgiBot World Alpha released. 25 | 26 | ## TODO List 📅 27 | 28 | - [x] **AgiBot World Alpha** 29 | - [x] **AgiBot World Beta** (expected Q1 2025) 30 | - [x] ~1,000,000 trajectories of high-quality robot data 31 | - [ ] **AgiBot World Foundation Model: GO-1** (expected Q2 2025) 32 | - [ ] Training & inference code 33 | - [ ] Pretrained model checkpoint 34 | - [ ] **AgiBot World Colosseo** (expected 2025) 35 | - [ ] A comprehensive platform with toolkits including teleoperation, training, and inference. 36 | - [ ] **2025 AgiBot World Challenge** (expected 2025) 37 | 38 | ## Key Features 🔑 39 | 40 | - **1 million+** trajectories from 100 robots. 41 | - **100+ 1:1 replicated real-life scenarios** across 5 target domains. 42 | - **Cutting-edge hardware:** visual tactile sensors / 6-DoF dexterous hands / mobile dual-arm robots 43 | - **Wide-spectrum, versatile and challenging tasks** 44 | 45 |
46 | <table>
47 |   <tr>
48 |     <td>
49 |       <img src="assets/Contact-rich_manipulation.gif" alt="Contact-rich Manipulation" width="100%">
50 |       <p align="center">Contact-rich Manipulation</p>
51 |     </td>
52 |     <td>
53 |       <img src="assets/Long-horizon_planning.gif" alt="Long-horizon Planning" width="100%">
54 |       <p align="center">Long-horizon Planning</p>
55 |     </td>
56 |     <td>
57 |       <img src="assets/Multi-robot_collaboration.gif" alt="Multi-robot Collaboration" width="100%">
58 |       <p align="center">Multi-robot Collaboration</p>
59 |     </td>
60 |   </tr>
61 | </table>
62 | 
63 | 64 | 65 | 66 | ## Table of Contents 67 | 68 | 1. [Key Features](#keyfeatures) 69 | 2. [At a Quick Glance](#quickglance) 70 | 3. [Getting Started](#installation) 71 | - [Installation](#training) 72 | - [How to Get Started with Our AgiBot World Data](#preaparedata) 73 | - [Visualize Datasets](#visualizedatasets) 74 | - [Policy Learning Quickstart](#training) 75 | 4. [TODO List](#todolist) 76 | 5. [License and Citation](#liscenseandcitation) 77 | 78 | ## At a Quick Glance⬇️ 79 | 80 | Follow the steps below to quickly explore and get an overview of AgiBot World with our [sample dataset](https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha/blob/main/sample_dataset.tar) (~7GB). 81 | 82 | ```bash 83 | # Installation 84 | conda create -n agibotworld python=3.10 -y 85 | conda activate agibotworld 86 | pip install git+https://github.com/huggingface/lerobot@59e275743499c5811a9f651a8947e8f881c4058c 87 | pip install matplotlib 88 | git clone https://github.com/OpenDriveLab/AgiBot-World.git 89 | cd AgiBot-World 90 | 91 | # Download the sample dataset (~7GB) from Hugging Face. Replace with your Hugging Face Access Token. You can generate an access token by following the instructions in the Hugging Face documentation from https://huggingface.co/docs/hub/security-tokens 92 | mkdir data 93 | cd data 94 | curl -L -o sample_dataset.tar -H "Authorization: Bearer " https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha/resolve/main/sample_dataset.tar 95 | tar -xvf sample_dataset.tar 96 | 97 | # Convert the sample dataset to LeRobot dataset format and visualize 98 | cd .. 99 | python scripts/convert_to_lerobot.py --src_path ./data/sample_dataset --task_id 390 --tgt_path ./data/sample_lerobot 100 | python scripts/visualize_dataset.py --task-id 390 --dataset-path ./data/sample_lerobot 101 | ``` 102 | 103 | ## Getting started 🔥 104 | 105 | #### Installation 106 | 107 | Download our source code: 108 | ```bash 109 | git clone https://github.com/OpenDriveLab/AgiBot-World.git 110 | cd AgiBot-World 111 | ``` 112 | 113 | Our project is built upon the [lerobot library](https://github.com/huggingface/lerobot) (**dataset `v2.0`, commit 59e2757**), 114 | install lerobot through 115 | ```bash 116 | pip install git+https://github.com/huggingface/lerobot@59e275743499c5811a9f651a8947e8f881c4058c 117 | ``` 118 | 119 | #### How to Get Started with Our AgiBot World Data 120 | 121 | - [OPTION 1] Download data from our [OpenDataLab](https://opendatalab.com/OpenDriveLab/AgiBot-World) page. 122 | 123 | ```bash 124 | pip install openxlab # install CLI 125 | openxlab dataset get --dataset-repo OpenDriveLab/AgiBot-World # dataset download 126 | ``` 127 | 128 | - [OPTION 2] Download data from our [HuggingFace](https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha) page. 129 | 130 | ```bash 131 | huggingface-cli download --resume-download --repo-type dataset agibot-world/AgiBotWorld-Alpha --local-dir ./AgiBotWorld-Alpha 132 | ``` 133 | 134 | Convert the data to **LeRobot Dataset** format. 
135 | 136 | ```bash 137 | python scripts/convert_to_lerobot.py --src_path /path/to/agibotworld/alpha --task_id 390 --tgt_path /path/to/save/lerobot 138 | ``` 139 | 140 | #### Visualize Datasets 141 | 142 | We adapt and extend the dataset visualization script from the [LeRobot Project](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/visualize_dataset.py). 143 | 144 | ```bash 145 | python scripts/visualize_dataset.py --task-id 390 --dataset-path /path/to/lerobot/format/dataset 146 | ``` 147 | 148 | It will open the `rerun.io` viewer and display the camera streams, robot states and actions, like this: 149 |
150 | <img src="assets/dataset_visualization.gif" alt="dataset visualization" width="100%"> 151 |
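
You can also export an episode to a `.rrd` file and inspect it offline. The sketch below only uses flags already defined in `scripts/visualize_dataset.py`; the output directory and the saved file name are placeholders:

```bash
# Write a .rrd recording instead of spawning the local viewer
python scripts/visualize_dataset.py --task-id 390 --dataset-path /path/to/lerobot/format/dataset --save 1 --output-dir ./rrd_logs

# Later, open the saved recording with the Rerun viewer
rerun ./rrd_logs/<saved_file>.rrd
```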
152 | 153 | #### Policy Training Quickstart 154 | 155 | Leveraging the simplicity of [LeRobot Dataset](https://github.com/huggingface/lerobot), we provide a user-friendly [Jupyter Notebook](https://github.com/OpenDriveLab/AgiBot-World/blob/main/AgibotWorld.ipynb) for training diffusion policy on AgiBot World Dataset. 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | ## License and Citation📄 165 | 166 | All the data and code within this repo are under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). 167 | 168 | - Please consider citing our work if it helps your research. 169 | - For the full authorship and detailed contributions, please refer to [contributions](CONTRIBUTING.md). 170 | - In alphabetical order by surname: 171 | ```BibTeX 172 | @article{bu2025agibot, 173 | title={Agibot world colosseo: A large-scale manipulation platform for scalable and intelligent embodied systems}, 174 | author={Bu, Qingwen and Cai, Jisong and Chen, Li and Cui, Xiuqi and Ding, Yan and Feng, Siyuan and Gao, Shenyuan and He, Xindong and Huang, Xu and Jiang, Shu and others}, 175 | journal={arXiv preprint arXiv:2503.06669}, 176 | year={2025} 177 | } 178 | ``` 179 | -------------------------------------------------------------------------------- /assets/Contact-rich_manipulation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/AgiBot-World/0d4b63b081c4a066f10e16901f418218f45e68bf/assets/Contact-rich_manipulation.gif -------------------------------------------------------------------------------- /assets/Long-horizon_planning.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/AgiBot-World/0d4b63b081c4a066f10e16901f418218f45e68bf/assets/Long-horizon_planning.gif -------------------------------------------------------------------------------- /assets/Multi-robot_collaboration.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/AgiBot-World/0d4b63b081c4a066f10e16901f418218f45e68bf/assets/Multi-robot_collaboration.gif -------------------------------------------------------------------------------- /assets/dataset_visualization.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/AgiBot-World/0d4b63b081c4a066f10e16901f418218f45e68bf/assets/dataset_visualization.gif -------------------------------------------------------------------------------- /scripts/convert_to_lerobot.py: -------------------------------------------------------------------------------- 1 | """ 2 | This project is built upon the open-source project 🤗 LeRobot: https://github.com/huggingface/lerobot 3 | 4 | We are grateful to the LeRobot team for their outstanding work and their contributions to the community. 5 | 6 | If you find this project useful, please also consider supporting and exploring LeRobot. 
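
Example usage (the paths below are placeholders; see the project README for details):

    python scripts/convert_to_lerobot.py --src_path /path/to/agibotworld/alpha --task_id 390 --tgt_path /path/to/save/lerobot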
7 | """ 8 | 9 | import os 10 | import json 11 | import shutil 12 | import logging 13 | import argparse 14 | import gc 15 | from pathlib import Path 16 | from typing import Callable 17 | from functools import partial 18 | from math import ceil 19 | from copy import deepcopy 20 | 21 | import h5py 22 | import torch 23 | import einops 24 | import numpy as np 25 | from PIL import Image 26 | from tqdm import tqdm 27 | from pprint import pformat 28 | from tqdm.contrib.concurrent import process_map 29 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset 30 | from lerobot.common.datasets.utils import ( 31 | STATS_PATH, 32 | check_timestamps_sync, 33 | get_episode_data_index, 34 | serialize_dict, 35 | write_json, 36 | ) 37 | 38 | HEAD_COLOR = "head_color.mp4" 39 | HAND_LEFT_COLOR = "hand_left_color.mp4" 40 | HAND_RIGHT_COLOR = "hand_right_color.mp4" 41 | HEAD_CENTER_FISHEYE_COLOR = "head_center_fisheye_color.mp4" 42 | HEAD_LEFT_FISHEYE_COLOR = "head_left_fisheye_color.mp4" 43 | HEAD_RIGHT_FISHEYE_COLOR = "head_right_fisheye_color.mp4" 44 | BACK_LEFT_FISHEYE_COLOR = "back_left_fisheye_color.mp4" 45 | BACK_RIGHT_FISHEYE_COLOR = "back_right_fisheye_color.mp4" 46 | HEAD_DEPTH = "head_depth" 47 | 48 | DEFAULT_IMAGE_PATH = ( 49 | "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.jpg" 50 | ) 51 | 52 | FEATURES = { 53 | "observation.images.top_head": { 54 | "dtype": "video", 55 | "shape": [480, 640, 3], 56 | "names": ["height", "width", "channel"], 57 | "video_info": { 58 | "video.fps": 30.0, 59 | "video.codec": "av1", 60 | "video.pix_fmt": "yuv420p", 61 | "video.is_depth_map": False, 62 | "has_audio": False, 63 | }, 64 | }, 65 | "observation.images.cam_top_depth": { 66 | "dtype": "image", 67 | "shape": [480, 640, 1], 68 | "names": ["height", "width", "channel"], 69 | }, 70 | "observation.images.hand_left": { 71 | "dtype": "video", 72 | "shape": [480, 640, 3], 73 | "names": ["height", "width", "channel"], 74 | "video_info": { 75 | "video.fps": 30.0, 76 | "video.codec": "av1", 77 | "video.pix_fmt": "yuv420p", 78 | "video.is_depth_map": False, 79 | "has_audio": False, 80 | }, 81 | }, 82 | "observation.images.hand_right": { 83 | "dtype": "video", 84 | "shape": [480, 640, 3], 85 | "names": ["height", "width", "channel"], 86 | "video_info": { 87 | "video.fps": 30.0, 88 | "video.codec": "av1", 89 | "video.pix_fmt": "yuv420p", 90 | "video.is_depth_map": False, 91 | "has_audio": False, 92 | }, 93 | }, 94 | "observation.images.head_center_fisheye": { 95 | "dtype": "video", 96 | "shape": [748, 960, 3], 97 | "names": ["height", "width", "channel"], 98 | "video_info": { 99 | "video.fps": 30.0, 100 | "video.codec": "av1", 101 | "video.pix_fmt": "yuv420p", 102 | "video.is_depth_map": False, 103 | "has_audio": False, 104 | }, 105 | }, 106 | "observation.images.head_left_fisheye": { 107 | "dtype": "video", 108 | "shape": [748, 960, 3], 109 | "names": ["height", "width", "channel"], 110 | "video_info": { 111 | "video.fps": 30.0, 112 | "video.codec": "av1", 113 | "video.pix_fmt": "yuv420p", 114 | "video.is_depth_map": False, 115 | "has_audio": False, 116 | }, 117 | }, 118 | "observation.images.head_right_fisheye": { 119 | "dtype": "video", 120 | "shape": [748, 960, 3], 121 | "names": ["height", "width", "channel"], 122 | "video_info": { 123 | "video.fps": 30.0, 124 | "video.codec": "av1", 125 | "video.pix_fmt": "yuv420p", 126 | "video.is_depth_map": False, 127 | "has_audio": False, 128 | }, 129 | }, 130 | "observation.images.back_left_fisheye": { 131 | "dtype": "video", 132 | "shape": 
[748, 960, 3], 133 | "names": ["height", "width", "channel"], 134 | "video_info": { 135 | "video.fps": 30.0, 136 | "video.codec": "av1", 137 | "video.pix_fmt": "yuv420p", 138 | "video.is_depth_map": False, 139 | "has_audio": False, 140 | }, 141 | }, 142 | "observation.images.back_right_fisheye": { 143 | "dtype": "video", 144 | "shape": [748, 960, 3], 145 | "names": ["height", "width", "channel"], 146 | "video_info": { 147 | "video.fps": 30.0, 148 | "video.codec": "av1", 149 | "video.pix_fmt": "yuv420p", 150 | "video.is_depth_map": False, 151 | "has_audio": False, 152 | }, 153 | }, 154 | "observation.state": { 155 | "dtype": "float32", 156 | "shape": [20], 157 | }, 158 | "action": { 159 | "dtype": "float32", 160 | "shape": [22], 161 | }, 162 | "episode_index": { 163 | "dtype": "int64", 164 | "shape": [1], 165 | "names": None, 166 | }, 167 | "frame_index": { 168 | "dtype": "int64", 169 | "shape": [1], 170 | "names": None, 171 | }, 172 | "index": { 173 | "dtype": "int64", 174 | "shape": [1], 175 | "names": None, 176 | }, 177 | "task_index": { 178 | "dtype": "int64", 179 | "shape": [1], 180 | "names": None, 181 | }, 182 | } 183 | 184 | 185 | def get_stats_einops_patterns(dataset, num_workers=0): 186 | """These einops patterns will be used to aggregate batches and compute statistics. 187 | 188 | Note: We assume the images are in channel first format 189 | """ 190 | 191 | dataloader = torch.utils.data.DataLoader( 192 | dataset, 193 | num_workers=num_workers, 194 | batch_size=2, 195 | shuffle=False, 196 | ) 197 | batch = next(iter(dataloader)) 198 | 199 | stats_patterns = {} 200 | 201 | for key in dataset.features: 202 | # sanity check that tensors are not float64 203 | assert batch[key].dtype != torch.float64 204 | 205 | # if isinstance(feats_type, (VideoFrame, Image)): 206 | if key in dataset.meta.camera_keys: 207 | # sanity check that images are channel first 208 | _, c, h, w = batch[key].shape 209 | assert ( 210 | c < h and c < w 211 | ), f"expect channel first images, but instead {batch[key].shape}" 212 | assert ( 213 | batch[key].dtype == torch.float32 214 | ), f"expect torch.float32, but instead {batch[key].dtype=}" 215 | # assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}" 216 | # assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}" 217 | stats_patterns[key] = "b c h w -> c 1 1" 218 | elif batch[key].ndim == 2: 219 | stats_patterns[key] = "b c -> c " 220 | elif batch[key].ndim == 1: 221 | stats_patterns[key] = "b -> 1" 222 | else: 223 | raise ValueError(f"{key}, {batch[key].shape}") 224 | 225 | return stats_patterns 226 | 227 | 228 | def compute_stats(dataset, batch_size=8, num_workers=4, max_num_samples=None): 229 | """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset.""" 230 | if max_num_samples is None: 231 | max_num_samples = len(dataset) 232 | 233 | # for more info on why we need to set the same number of workers, see `load_from_videos` 234 | stats_patterns = get_stats_einops_patterns(dataset, num_workers) 235 | 236 | # mean and std will be computed incrementally while max and min will track the running value. 
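    # (Statistics are gathered in two passes: the loop below accumulates the running mean
    # together with min/max, and a second loop afterwards computes the std against that final mean.)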
237 | mean, std, max, min = {}, {}, {}, {} 238 | for key in stats_patterns: 239 | mean[key] = torch.tensor(0.0).float() 240 | std[key] = torch.tensor(0.0).float() 241 | max[key] = torch.tensor(-float("inf")).float() 242 | min[key] = torch.tensor(float("inf")).float() 243 | 244 | def create_seeded_dataloader(dataset, batch_size, seed): 245 | generator = torch.Generator() 246 | generator.manual_seed(seed) 247 | dataloader = torch.utils.data.DataLoader( 248 | dataset, 249 | num_workers=num_workers, 250 | batch_size=batch_size, 251 | shuffle=True, 252 | drop_last=False, 253 | generator=generator, 254 | ) 255 | return dataloader 256 | 257 | # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get 258 | # surprises when rerunning the sampler. 259 | first_batch = None 260 | running_item_count = 0 # for online mean computation 261 | dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) 262 | for i, batch in enumerate( 263 | tqdm( 264 | dataloader, 265 | total=ceil(max_num_samples / batch_size), 266 | desc="Compute mean, min, max", 267 | ) 268 | ): 269 | this_batch_size = len(batch["index"]) 270 | running_item_count += this_batch_size 271 | if first_batch is None: 272 | first_batch = deepcopy(batch) 273 | for key, pattern in stats_patterns.items(): 274 | batch[key] = batch[key].float() 275 | # Numerically stable update step for mean computation. 276 | batch_mean = einops.reduce(batch[key], pattern, "mean") 277 | # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents 278 | # the update step, N is the running item count, B is this batch size, x̄ is the running mean, 279 | # and x is the current batch mean. Some rearrangement is then required to avoid risking 280 | # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields 281 | # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ 282 | mean[key] = ( 283 | mean[key] 284 | + this_batch_size * (batch_mean - mean[key]) / running_item_count 285 | ) 286 | max[key] = torch.maximum( 287 | max[key], einops.reduce(batch[key], pattern, "max") 288 | ) 289 | min[key] = torch.minimum( 290 | min[key], einops.reduce(batch[key], pattern, "min") 291 | ) 292 | 293 | if i == ceil(max_num_samples / batch_size) - 1: 294 | break 295 | 296 | first_batch_ = None 297 | running_item_count = 0 # for online std computation 298 | dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) 299 | for i, batch in enumerate( 300 | tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") 301 | ): 302 | this_batch_size = len(batch["index"]) 303 | running_item_count += this_batch_size 304 | # Sanity check to make sure the batches are still in the same order as before. 305 | if first_batch_ is None: 306 | first_batch_ = deepcopy(batch) 307 | for key in stats_patterns: 308 | assert torch.equal(first_batch_[key], first_batch[key]) 309 | for key, pattern in stats_patterns.items(): 310 | batch[key] = batch[key].float() 311 | # Numerically stable update step for mean computation (where the mean is over squared 312 | # residuals).See notes in the mean computation loop above. 
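            # (Applying the same running-mean update to the squared residuals yields the population
            # variance; the square root is taken after this loop to obtain the std.)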
313 | batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") 314 | std[key] = ( 315 | std[key] + this_batch_size * (batch_std - std[key]) / running_item_count 316 | ) 317 | 318 | if i == ceil(max_num_samples / batch_size) - 1: 319 | break 320 | 321 | for key in stats_patterns: 322 | std[key] = torch.sqrt(std[key]) 323 | 324 | stats = {} 325 | for key in stats_patterns: 326 | stats[key] = { 327 | "mean": mean[key], 328 | "std": std[key], 329 | "max": max[key], 330 | "min": min[key], 331 | } 332 | return stats 333 | 334 | 335 | class AgiBotDataset(LeRobotDataset): 336 | def __init__( 337 | self, 338 | repo_id: str, 339 | root: str | Path | None = None, 340 | episodes: list[int] | None = None, 341 | image_transforms: Callable | None = None, 342 | delta_timestamps: dict[list[float]] | None = None, 343 | tolerance_s: float = 1e-4, 344 | download_videos: bool = True, 345 | local_files_only: bool = False, 346 | video_backend: str | None = None, 347 | ): 348 | super().__init__( 349 | repo_id=repo_id, 350 | root=root, 351 | episodes=episodes, 352 | image_transforms=image_transforms, 353 | delta_timestamps=delta_timestamps, 354 | tolerance_s=tolerance_s, 355 | download_videos=download_videos, 356 | local_files_only=local_files_only, 357 | video_backend=video_backend, 358 | ) 359 | 360 | def save_episode( 361 | self, task: str, episode_data: dict | None = None, videos: dict | None = None 362 | ) -> None: 363 | """ 364 | We rewrite this method to copy mp4 videos to the target position 365 | """ 366 | if not episode_data: 367 | episode_buffer = self.episode_buffer 368 | 369 | episode_length = episode_buffer.pop("size") 370 | episode_index = episode_buffer["episode_index"] 371 | if episode_index != self.meta.total_episodes: 372 | # TODO(aliberts): Add option to use existing episode_index 373 | raise NotImplementedError( 374 | "You might have manually provided the episode_buffer with an episode_index that doesn't " 375 | "match the total number of episodes in the dataset. This is not supported for now." 376 | ) 377 | 378 | if episode_length == 0: 379 | raise ValueError( 380 | "You must add one or several frames with `add_frame` before calling `add_episode`." 
381 | ) 382 | 383 | task_index = self.meta.get_task_index(task) 384 | 385 | if not set(episode_buffer.keys()) == set(self.features): 386 | raise ValueError() 387 | 388 | for key, ft in self.features.items(): 389 | if key == "index": 390 | episode_buffer[key] = np.arange( 391 | self.meta.total_frames, self.meta.total_frames + episode_length 392 | ) 393 | elif key == "episode_index": 394 | episode_buffer[key] = np.full((episode_length,), episode_index) 395 | elif key == "task_index": 396 | episode_buffer[key] = np.full((episode_length,), task_index) 397 | elif ft["dtype"] in ["image", "video"]: 398 | continue 399 | elif len(ft["shape"]) == 1 and ft["shape"][0] == 1: 400 | episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"]) 401 | elif len(ft["shape"]) == 1 and ft["shape"][0] > 1: 402 | episode_buffer[key] = np.stack(episode_buffer[key]) 403 | else: 404 | raise ValueError(key) 405 | 406 | self._wait_image_writer() 407 | self._save_episode_table(episode_buffer, episode_index) 408 | 409 | self.meta.save_episode(episode_index, episode_length, task, task_index) 410 | for key in self.meta.video_keys: 411 | video_path = self.root / self.meta.get_video_file_path(episode_index, key) 412 | episode_buffer[key] = video_path 413 | video_path.parent.mkdir(parents=True, exist_ok=True) 414 | shutil.copyfile(videos[key], video_path) 415 | if not episode_data: # Reset the buffer 416 | self.episode_buffer = self.create_episode_buffer() 417 | self.consolidated = False 418 | 419 | def consolidate( 420 | self, run_compute_stats: bool = True, keep_image_files: bool = False 421 | ) -> None: 422 | self.hf_dataset = self.load_hf_dataset() 423 | self.episode_data_index = get_episode_data_index( 424 | self.meta.episodes, self.episodes 425 | ) 426 | check_timestamps_sync( 427 | self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s 428 | ) 429 | if len(self.meta.video_keys) > 0: 430 | self.meta.write_video_info() 431 | 432 | if not keep_image_files: 433 | img_dir = self.root / "images" 434 | if img_dir.is_dir(): 435 | shutil.rmtree(self.root / "images") 436 | video_files = list(self.root.rglob("*.mp4")) 437 | assert len(video_files) == self.num_episodes * len(self.meta.video_keys) 438 | 439 | parquet_files = list(self.root.rglob("*.parquet")) 440 | assert len(parquet_files) == self.num_episodes 441 | 442 | if run_compute_stats: 443 | self.stop_image_writer() 444 | self.meta.stats = compute_stats(self) 445 | serialized_stats = serialize_dict(self.meta.stats) 446 | write_json(serialized_stats, self.root / STATS_PATH) 447 | self.consolidated = True 448 | else: 449 | logging.warning( 450 | "Skipping computation of the dataset statistics, dataset is not fully consolidated." 451 | ) 452 | 453 | def add_frame(self, frame: dict) -> None: 454 | """ 455 | This function only adds the frame to the episode_buffer. Apart from images — which are written in a 456 | temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method 457 | then needs to be called. 458 | """ 459 | # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch, 460 | # check the dtype and shape matches, etc. 
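        # (If the caller does not provide a timestamp, it defaults to frame_index / fps below.)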
461 | 462 | if self.episode_buffer is None: 463 | self.episode_buffer = self.create_episode_buffer() 464 | 465 | frame_index = self.episode_buffer["size"] 466 | timestamp = ( 467 | frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps 468 | ) 469 | self.episode_buffer["frame_index"].append(frame_index) 470 | self.episode_buffer["timestamp"].append(timestamp) 471 | 472 | for key in frame: 473 | if key not in self.features: 474 | raise ValueError(key) 475 | item = ( 476 | frame[key].numpy() 477 | if isinstance(frame[key], torch.Tensor) 478 | else frame[key] 479 | ) 480 | self.episode_buffer[key].append(item) 481 | 482 | self.episode_buffer["size"] += 1 483 | 484 | 485 | def load_depths(root_dir: str, camera_name: str): 486 | cam_path = Path(root_dir) 487 | all_imgs = sorted(list(cam_path.glob(f"{camera_name}*"))) 488 | return [np.array(Image.open(f)).astype(np.float32) / 1000 for f in all_imgs] 489 | 490 | 491 | def load_local_dataset(episode_id: int, src_path: str, task_id: int) -> list | None: 492 | """Load local dataset and return a dict with observations and actions""" 493 | 494 | ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}" 495 | depth_imgs = load_depths(ob_dir / "depth", HEAD_DEPTH) 496 | proprio_dir = Path(src_path) / f"proprio_stats/{task_id}/{episode_id}" 497 | 498 | with h5py.File(proprio_dir / "proprio_stats.h5") as f: 499 | state_joint = np.array(f["state/joint/position"]) 500 | state_effector = np.array(f["state/effector/position"]) 501 | state_head = np.array(f["state/head/position"]) 502 | state_waist = np.array(f["state/waist/position"]) 503 | action_joint = np.array(f["action/joint/position"]) 504 | action_effector = np.array(f["action/effector/position"]) 505 | action_head = np.array(f["action/head/position"]) 506 | action_waist = np.array(f["action/waist/position"]) 507 | action_velocity = np.array(f["action/robot/velocity"]) 508 | 509 | states_value = np.hstack( 510 | [state_joint, state_effector, state_head, state_waist] 511 | ).astype(np.float32) 512 | assert ( 513 | action_joint.shape[0] == action_effector.shape[0] 514 | ), f"shape of action_joint:{action_joint.shape};shape of action_effector:{action_effector.shape}" 515 | action_value = np.hstack( 516 | [action_joint, action_effector, action_head, action_waist, action_velocity] 517 | ).astype(np.float32) 518 | 519 | assert len(depth_imgs) == len( 520 | states_value 521 | ), f"Number of images and states are not equal" 522 | assert len(depth_imgs) == len( 523 | action_value 524 | ), f"Number of images and actions are not equal" 525 | frames = [ 526 | { 527 | "observation.images.cam_top_depth": depth_imgs[i], 528 | "observation.state": states_value[i], 529 | "action": action_value[i], 530 | } 531 | for i in range(len(depth_imgs)) 532 | ] 533 | 534 | v_path = ob_dir / "videos" 535 | videos = { 536 | "observation.images.top_head": v_path / HEAD_COLOR, 537 | "observation.images.hand_left": v_path / HAND_LEFT_COLOR, 538 | "observation.images.hand_right": v_path / HAND_RIGHT_COLOR, 539 | "observation.images.head_center_fisheye": v_path / HEAD_CENTER_FISHEYE_COLOR, 540 | "observation.images.head_left_fisheye": v_path / HEAD_LEFT_FISHEYE_COLOR, 541 | "observation.images.head_right_fisheye": v_path / HEAD_RIGHT_FISHEYE_COLOR, 542 | "observation.images.back_left_fisheye": v_path / BACK_LEFT_FISHEYE_COLOR, 543 | "observation.images.back_right_fisheye": v_path / BACK_RIGHT_FISHEYE_COLOR, 544 | } 545 | return frames, videos 546 | 547 | 548 | def get_task_instruction(task_json_path: str) -> 
dict: 549 | """Get task language instruction""" 550 | with open(task_json_path, "r") as f: 551 | task_info = json.load(f) 552 | task_name = task_info[0]["task_name"] 553 | task_init_scene = task_info[0]["init_scene_text"] 554 | task_instruction = f"{task_name}.{task_init_scene}" 555 | print(f"Get Task Instruction <{task_instruction}>") 556 | return task_instruction 557 | 558 | 559 | def main( 560 | src_path: str, 561 | tgt_path: str, 562 | task_id: int, 563 | repo_id: str, 564 | task_info_json: str, 565 | debug: bool = False, 566 | chunk_size: int = 10 # Add chunk size parameter 567 | ): 568 | task_name = get_task_instruction(task_info_json) 569 | 570 | dataset = AgiBotDataset.create( 571 | repo_id=repo_id, 572 | root=f"{tgt_path}/{repo_id}", 573 | fps=30, 574 | robot_type="a2d", 575 | features=FEATURES, 576 | ) 577 | 578 | all_subdir = sorted( 579 | [ 580 | f.as_posix() 581 | for f in Path(src_path).glob(f"observations/{task_id}/*") 582 | if f.is_dir() 583 | ] 584 | ) 585 | 586 | if debug: 587 | all_subdir = all_subdir[:2] 588 | 589 | # Get all episode id 590 | all_subdir_eids = [int(Path(path).name) for path in all_subdir] 591 | all_subdir_episode_desc = [task_name] * len(all_subdir_eids) 592 | 593 | # Process in chunks to reduce memory usage 594 | for chunk_start in tqdm(range(0, len(all_subdir_eids), chunk_size), desc="Processing chunks"): 595 | chunk_end = min(chunk_start + chunk_size, len(all_subdir_eids)) 596 | chunk_eids = all_subdir_eids[chunk_start:chunk_end] 597 | chunk_descs = all_subdir_episode_desc[chunk_start:chunk_end] 598 | 599 | # Process only this chunk 600 | if debug: 601 | raw_datasets_chunk = [ 602 | load_local_dataset(subdir, src_path=src_path, task_id=task_id) 603 | for subdir in tqdm(chunk_eids, desc="Loading chunk data") 604 | ] 605 | else: 606 | raw_datasets_chunk = process_map( 607 | partial(load_local_dataset, src_path=src_path, task_id=task_id), 608 | chunk_eids, 609 | max_workers=os.cpu_count() // 2, 610 | desc=f"Loading chunk {chunk_start//chunk_size + 1}/{(len(all_subdir_eids) + chunk_size - 1)//chunk_size}", 611 | ) 612 | 613 | # Filter out None results 614 | valid_datasets = [(ds, desc) for ds, desc in zip(raw_datasets_chunk, chunk_descs) if ds is not None] 615 | 616 | # Process each dataset in the chunk 617 | for raw_dataset, episode_desc in tqdm(valid_datasets, desc="Processing episodes in chunk"): 618 | for raw_dataset_sub in tqdm( 619 | raw_dataset[0], desc="Processing frames", leave=False 620 | ): 621 | dataset.add_frame(raw_dataset_sub) 622 | dataset.save_episode(task=episode_desc, videos=raw_dataset[1]) 623 | 624 | # Clear memory after each chunk 625 | raw_datasets_chunk = None 626 | valid_datasets = None 627 | gc.collect() 628 | 629 | # Only consolidate at the end 630 | dataset.consolidate() 631 | 632 | 633 | if __name__ == "__main__": 634 | parser = argparse.ArgumentParser() 635 | parser.add_argument( 636 | "--src_path", 637 | type=str, 638 | required=True, 639 | ) 640 | parser.add_argument( 641 | "--task_id", 642 | type=str, 643 | required=True, 644 | ) 645 | parser.add_argument( 646 | "--tgt_path", 647 | type=str, 648 | required=True, 649 | ) 650 | parser.add_argument( 651 | "--debug", 652 | action="store_true", 653 | ) 654 | parser.add_argument( 655 | "--chunk_size", 656 | type=int, 657 | default=10, 658 | help="Number of episodes to process at once", 659 | ) 660 | args = parser.parse_args() 661 | 662 | task_id = args.task_id 663 | json_file = f"{args.src_path}/task_info/task_{args.task_id}.json" 664 | dataset_base = 
f"agibotworld/task_{args.task_id}" 665 | 666 | assert Path(json_file).exists(), f"Cannot find {json_file}." 667 | main(args.src_path, args.tgt_path, task_id, dataset_base, json_file, args.debug, args.chunk_size) 668 | -------------------------------------------------------------------------------- /scripts/visualize_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is adapted from the Hugging Face 🤗 LeRobot project: 3 | https://github.com/huggingface/lerobot 4 | 5 | Original file: 6 | https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/visualize_dataset.py 7 | 8 | The original script was developed as part of the LeRobot project for dataset visualization. 9 | This version adds support for depth map visualization. 10 | """ 11 | 12 | import argparse 13 | import gc 14 | import logging 15 | import time 16 | from pathlib import Path 17 | from typing import Iterator 18 | 19 | import numpy as np 20 | import rerun as rr 21 | import torch 22 | import torch.utils.data 23 | import tqdm 24 | import matplotlib.pyplot as plt 25 | 26 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset 27 | 28 | 29 | class EpisodeSampler(torch.utils.data.Sampler): 30 | def __init__(self, dataset: LeRobotDataset, episode_index: int): 31 | from_idx = dataset.episode_data_index["from"][episode_index].item() 32 | to_idx = dataset.episode_data_index["to"][episode_index].item() 33 | self.frame_ids = range(from_idx, to_idx) 34 | 35 | def __iter__(self) -> Iterator: 36 | return iter(self.frame_ids) 37 | 38 | def __len__(self) -> int: 39 | return len(self.frame_ids) 40 | 41 | 42 | def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: 43 | assert chw_float32_torch.dtype == torch.float32 44 | assert chw_float32_torch.ndim == 3 45 | c, h, w = chw_float32_torch.shape 46 | assert c < h and c < w, f"Expect channel first images, but instead {chw_float32_torch.shape}" 47 | 48 | if c == 1: 49 | # If depth image, clip and normalize the depth map just for visualization 50 | min_depth = 0.4 51 | max_depth = 3 52 | clipped_depth = torch.clamp(chw_float32_torch, min=min_depth, max=max_depth) 53 | normalized_depth = (clipped_depth-min_depth) / (max_depth-min_depth) 54 | depth_image = np.sqrt(normalized_depth.squeeze().cpu().numpy()) 55 | 56 | colormap = plt.get_cmap('jet') 57 | colored_depth_image = colormap(depth_image) 58 | hwc_uint8_numpy = (colored_depth_image[:, :, :3] * 255).astype(np.uint8) 59 | else: 60 | # If RGB image 61 | hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() 62 | 63 | return hwc_uint8_numpy 64 | 65 | 66 | def visualize_dataset( 67 | dataset: LeRobotDataset, 68 | episode_index: int, 69 | batch_size: int = 32, 70 | num_workers: int = 0, 71 | mode: str = "local", 72 | web_port: int = 9090, 73 | ws_port: int = 9087, 74 | save: bool = False, 75 | output_dir: Path | None = None, 76 | ) -> Path | None: 77 | if save: 78 | assert ( 79 | output_dir is not None 80 | ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." 
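    # (When saving to .rrd, no local viewer is spawned; see `spawn_local_viewer` below.)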
81 | 82 | repo_id = dataset.repo_id 83 | 84 | logging.info("Loading dataloader") 85 | episode_sampler = EpisodeSampler(dataset, episode_index) 86 | dataloader = torch.utils.data.DataLoader( 87 | dataset, 88 | num_workers=num_workers, 89 | batch_size=batch_size, 90 | sampler=episode_sampler, 91 | ) 92 | 93 | logging.info("Starting Rerun") 94 | 95 | if mode not in ["local", "distant"]: 96 | raise ValueError(mode) 97 | 98 | spawn_local_viewer = mode == "local" and not save 99 | rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) 100 | 101 | # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush 102 | # when iterating on a dataloader with `num_workers` > 0 103 | # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix 104 | gc.collect() 105 | 106 | if mode == "distant": 107 | rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port) 108 | 109 | logging.info("Logging to Rerun") 110 | 111 | for batch in tqdm.tqdm(dataloader, total=len(dataloader)): 112 | # iterate over the batch 113 | for i in range(len(batch["index"])): 114 | rr.set_time_sequence("frame_index", batch["frame_index"][i].item()) 115 | rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) 116 | 117 | # display each camera image 118 | for key in dataset.meta.camera_keys: 119 | # TODO(rcadene): add `.compress()`? is it lossless? 120 | rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) 121 | 122 | # display each dimension of action space (e.g. actuators command) 123 | if "action" in batch: 124 | for dim_idx, val in enumerate(batch["action"][i]): 125 | rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) 126 | 127 | # display each dimension of observed state space (e.g. agent position in joint space) 128 | if "observation.state" in batch: 129 | for dim_idx, val in enumerate(batch["observation.state"][i]): 130 | rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) 131 | 132 | if mode == "local" and save: 133 | # save .rrd locally 134 | output_dir = Path(output_dir) 135 | output_dir.mkdir(parents=True, exist_ok=True) 136 | repo_id_str = repo_id.replace("/", "_") 137 | rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd" 138 | rr.save(rrd_path) 139 | return rrd_path 140 | 141 | elif mode == "distant": 142 | # stop the process from exiting since it is serving the websocket connection 143 | try: 144 | while True: 145 | time.sleep(1) 146 | except KeyboardInterrupt: 147 | print("Ctrl-C received. Exiting.") 148 | 149 | 150 | def main(): 151 | parser = argparse.ArgumentParser() 152 | 153 | parser.add_argument( 154 | "--task-id", 155 | type=int, 156 | default=None, 157 | help="Index of the AgiBot World task.", 158 | ) 159 | parser.add_argument( 160 | "--episode-index", 161 | type=int, 162 | nargs="*", 163 | default=None, 164 | help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). 
By default loads all episodes.", 165 | ) 166 | parser.add_argument( 167 | "--dataset-path", 168 | type=Path, 169 | default=None, 170 | help="Root directory for the converted LeRobot dataset stored locally.", 171 | ) 172 | parser.add_argument( 173 | "--output-dir", 174 | type=Path, 175 | default=None, 176 | help="Directory path to write a .rrd file when `--save 1` is set.", 177 | ) 178 | parser.add_argument( 179 | "--batch-size", 180 | type=int, 181 | default=32, 182 | help="Batch size loaded by DataLoader.", 183 | ) 184 | parser.add_argument( 185 | "--num-workers", 186 | type=int, 187 | default=4, 188 | help="Number of processes of Dataloader for loading the data.", 189 | ) 190 | parser.add_argument( 191 | "--mode", 192 | type=str, 193 | default="local", 194 | help=( 195 | "Mode of viewing between 'local' or 'distant'. " 196 | "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " 197 | "'distant' creates a server on the distant machine where the data is stored. " 198 | "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." 199 | ), 200 | ) 201 | parser.add_argument( 202 | "--web-port", 203 | type=int, 204 | default=9090, 205 | help="Web port for rerun.io when `--mode distant` is set.", 206 | ) 207 | parser.add_argument( 208 | "--ws-port", 209 | type=int, 210 | default=9087, 211 | help="Web socket port for rerun.io when `--mode distant` is set.", 212 | ) 213 | parser.add_argument( 214 | "--save", 215 | type=int, 216 | default=0, 217 | help=( 218 | "Save a .rrd file in the directory provided by `--output-dir`. " 219 | "It also deactivates the spawning of a viewer. " 220 | "Visualize the data by running `rerun path/to/file.rrd` on your local machine." 221 | ), 222 | ) 223 | 224 | args = parser.parse_args() 225 | kwargs = vars(args) 226 | repo_id = f"agibotworld/task_{kwargs.pop('task_id')}" 227 | root = f"{kwargs.pop('dataset_path')}/{repo_id}" 228 | 229 | logging.info("Loading dataset") 230 | dataset = LeRobotDataset(repo_id, root=root, local_files_only=True) 231 | 232 | visualize_dataset(dataset, **vars(args)) 233 | 234 | if __name__ == "__main__": 235 | main() --------------------------------------------------------------------------------