├── utils ├── __init__.py ├── bridge_test.py ├── normalizers.py ├── server.py └── dataset.py ├── .python-version ├── conversion_scripts ├── __init__.py ├── conversion_utils.py └── new_modeling_pi0.py ├── pi0 ├── __init__.py ├── utils.py ├── paligemma_with_expert.py ├── modeling_pi0.py └── modeling_pi0fast.py ├── .gitignore ├── 1_e2e_inference.py ├── pyproject.toml ├── 2_test_pi0_on_libero.py ├── so100_client.py ├── 3_test_pi0fast_on_libero.py ├── train.py ├── so100_train.py ├── so100_eval.py ├── so100_train_fast.py ├── so100_eval_fast.py ├── README.md ├── convert_pi0_to_hf_lerobot.py └── convert_pi0fast_to_hf_lerobot.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /conversion_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pi0/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_pi0 import PI0Policy 2 | from .modeling_pi0fast import PI0FASTPolicy 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.png 3 | *.mp4 4 | *.txt 5 | *.jpg 6 | lightning_logs/ 7 | assets/ 8 | *.out 9 | bridgedata.py 10 | *.yaml 11 | *.json 12 | 13 | __pycache__/ -------------------------------------------------------------------------------- /1_e2e_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pi0 import PI0FASTPolicy, PI0Policy 4 | 5 | PATH_TO_PI_MODEL = ( 6 | "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_libero_pytorch" 7 | ) 8 | PATH_TO_PI_FAST_MODEL = ( 9 | "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_fast_libero_pytorch" 10 | ) 11 | model_type = "pi0fast" # or "pi0fast" 12 | 13 | # load model 14 | if model_type == "pi0": 15 | policy = PI0Policy.from_pretrained(PATH_TO_PI_MODEL) 16 | else: 17 | policy = PI0FASTPolicy.from_pretrained(PATH_TO_PI_FAST_MODEL) 18 | 19 | # create pseudo observation 20 | # check the comment in `PI0Policy.select_action` for the expected observation format 21 | # let's assume we have the following observation 22 | device = policy.config.device 23 | observation = { 24 | "image": { 25 | "base_0_rgb": torch.randint( 26 | 0, 256, (1, 3, 224, 224), dtype=torch.uint8, device=device 27 | ), 28 | # "left_wrist_0_rgb": ..., Suppose we don't have this view 29 | # "right_wrist_0_rgb": ..., Suppose we don't have this view 30 | }, 31 | "state": torch.randn(1, 8, device=device) * 0.2, 32 | "prompt": ["do something"], 33 | } 34 | 35 | # select action 36 | # let's assume the `action_dim` is 7 37 | action = policy.select_action(observation)[0, :, :7] 38 | print(action) 39 | -------------------------------------------------------------------------------- /utils/bridge_test.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import tensorflow_datasets as tfds 6 | import tqdm 7 | from lerobot.common.datasets.lerobot_dataset import 
LeRobotDataset 8 | 9 | path = Path("/mnt/20T/datasets/bridgev2/tf") 10 | 11 | 12 | dataset = LeRobotDataset.create( 13 | repo_id="ZibinDong/bridgedatav2_train", 14 | root="/mnt/20T/datasets/bridgev2/lerobot/ZibinDong/bridgedatav2_train", 15 | robot_type="WidowX250", 16 | fps=5, 17 | features={ 18 | "image": { 19 | "dtype": "video", 20 | "shape": (256, 256, 3), 21 | "names": ["height", "width", "channel"], 22 | }, 23 | "state": { 24 | "dtype": "float32", 25 | "shape": (7,), 26 | "names": ["state"], 27 | }, 28 | "action": { 29 | "dtype": "float32", 30 | "shape": (7,), 31 | "names": ["action"], 32 | }, 33 | "traj_idx": { 34 | "dtype": "int64", 35 | "shape": (1,), 36 | "names": ["traj_idx"], 37 | }, 38 | }, 39 | image_writer_threads=10, 40 | image_writer_processes=5, 41 | ) 42 | 43 | builder = tfds.builder_from_directory(path) 44 | ds = builder.as_dataset(split="train") 45 | 46 | 47 | for idx, episode in tqdm.tqdm(enumerate(tfds.as_numpy(ds))): 48 | images = [[], [], [], []] 49 | actions, states = [], [] 50 | has_image = [False, False, False, False] 51 | 52 | for i, step in enumerate(episode["steps"]): 53 | if i == 0: 54 | language_instruction = step.get("language_instruction", b"").decode("utf-8") 55 | for img_idx in range(4): 56 | if step["observation"][f"image_{img_idx}"].mean() > 0: 57 | has_image[img_idx] = True 58 | if not any(has_image): 59 | break 60 | 61 | for img_idx in range(4): 62 | if has_image[img_idx]: 63 | images[img_idx].append(step["observation"][f"image_{img_idx}"]) 64 | 65 | actions.append(step["action"]) 66 | states.append(step["observation"]["state"]) 67 | 68 | if not any(has_image): 69 | continue 70 | 71 | for img_idx in range(4): 72 | if has_image[img_idx]: 73 | for image, action, state in zip(images[img_idx], actions, states): 74 | dataset.add_frame( 75 | { 76 | "image": image, 77 | "action": action, 78 | "state": state, 79 | "traj_idx": np.array([idx], dtype=np.int64), 80 | }, 81 | task=language_instruction, 82 | ) 83 | dataset.save_episode() 84 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "openpi-pytorch-env" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "torch==2.2.2", 9 | "torchvision==0.17.2", 10 | "torchaudio==2.2.2", 11 | "numpy<2.0", 12 | # "einops==0.4.1", 13 | "mujoco_py>=2.0", 14 | "gym>=0.13,<0.24", 15 | "dm_control>=1.0.3,<=1.0.20", 16 | "cython==3.0.0", 17 | "numba<0.60.0", 18 | "matplotlib==3.5.3", 19 | "six", 20 | "hydra-core==1.2.0", 21 | "zarr<2.17", 22 | # "wandb==0.13.1", 23 | "dill", 24 | "av", 25 | "pygame", 26 | "pymunk", 27 | "shapely<2.0.0", 28 | "scikit-image<0.23.0", 29 | "opencv-python==4.6.0.66", 30 | "imagecodecs", 31 | "mujoco<=3.1.6", 32 | "easydict==1.9", 33 | "transformers==4.48.0", 34 | "robomimic==0.2.0", 35 | "thop==0.1.1-2209072238", 36 | "robosuite==1.4.0", 37 | "bddl==1.0.1", 38 | "future==0.18.2", 39 | "cloudpickle==2.1.0", 40 | # Delete the [gym==0.25.2] requirements in "LIBERO" 41 | 42 | "datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency 43 | "diffusers>=0.27.2", 44 | "huggingface-hub[hf-transfer,cli]>=0.34.2", 45 | 46 | # Core dependencies 47 | "cmake>=3.29.0.1", 48 | "einops>=0.8.0", 49 | "opencv-python-headless>=4.9.0", 50 | "av>=14.2.0", 51 | "jsonlines>=4.0.0", 52 | "packaging>=24.2", 53 | "pynput>=1.7.7", 54 | "pyserial>=3.5", 55 | "wandb>=0.20.0", 56 | 57 | # 
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency 58 | # "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency 59 | # "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency 60 | "torchcodec", 61 | 62 | "draccus==0.10.0", # TODO: Remove == 63 | "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency 64 | "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency 65 | 66 | # Support dependencies 67 | "deepdiff>=7.0.1,<9.0.0", 68 | "flask>=3.0.3,<4.0.0", 69 | "imageio[ffmpeg]>=2.34.0,<3.0.0", 70 | "termcolor>=2.4.0,<4.0.0", 71 | ] 72 | 73 | [tool.uv.sources] 74 | torch = { index = "pytorch-cu121" } 75 | torchvision = { index = "pytorch-cu121" } 76 | torchaudio = { index = "pytorch-cu121" } 77 | 78 | [[tool.uv.index]] 79 | name = "pytorch-cu121" 80 | url = "https://download.pytorch.org/whl/cu121" 81 | explicit = true 82 | 83 | # 1. Creating a new environment by Conda/Mamba 84 | # 2. Installing `mamba install ffmpeg=7.1.1 uv ipython` 85 | # 3. Using `uv lock` to solve the environment 86 | # 4. Using `uv export --format requirements-txt > requirements.txt ` to generate the requirements 87 | # 5. Using `uv pip install --system -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match` to install it -------------------------------------------------------------------------------- /conversion_scripts/conversion_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from transformers import GemmaConfig, PaliGemmaConfig 16 | 17 | 18 | def get_paligemma_config(precision: str): 19 | config = { 20 | "image_token_index": None, 21 | "pad_token_id": 0, 22 | "bos_token_id": 2, 23 | "eos_token_id": 1, 24 | } 25 | 26 | # image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896} 27 | 28 | image_size = 224 # image_sizes[variant] 29 | patch_size = 14 30 | num_image_tokens = (image_size**2) // (patch_size**2) 31 | 32 | config["image_token_index"] = 257152 33 | text_config = { 34 | "vocab_size": 257152, 35 | "num_hidden_layers": 18, 36 | "num_key_value_heads": 1, 37 | "head_dim": 256, 38 | "torch_dtype": precision, 39 | "hidden_size": 2048, 40 | "hidden_activation": "gelu_pytorch_tanh", 41 | "num_attention_heads": 8, 42 | "intermediate_size": 16384, 43 | "is_encoder_decoder": False, 44 | } 45 | vision_config = { 46 | "torch_dtype": precision, 47 | "image_size": image_size, 48 | "patch_size": patch_size, 49 | "num_image_tokens": num_image_tokens, 50 | "hidden_size": 1152, 51 | "intermediate_size": 4304, 52 | "num_hidden_layers": 27, 53 | "num_attention_heads": 16, 54 | "projector_hidden_act": "gelu_fast", 55 | "vision_use_head": False, 56 | } 57 | final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) 58 | return final_config 59 | 60 | 61 | def get_gemma_config(precision: str): 62 | config = { 63 | "image_token_index": None, 64 | "pad_token_id": 0, 65 | "bos_token_id": 2, 66 | "eos_token_id": 1, 67 | } 68 | 69 | config["image_token_index"] = 257152 70 | text_config = { 71 | "vocab_size": 257152, 72 | "num_hidden_layers": 18, 73 | "num_key_value_heads": 1, 74 | "head_dim": 256, 75 | "torch_dtype": precision, 76 | "hidden_size": 1024, 77 | "hidden_activation": "gelu_pytorch_tanh", 78 | "num_attention_heads": 8, 79 | "intermediate_size": 4096, 80 | "is_encoder_decoder": False, 81 | } 82 | final_config = GemmaConfig() 83 | final_config.update(text_config) 84 | return final_config -------------------------------------------------------------------------------- /2_test_pi0_on_libero.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import gym 5 | import imageio 6 | import numpy as np 7 | import robosuite.utils.transform_utils as T 8 | import torch 9 | from cleandiffuser.env import libero # noqa: F401 10 | from termcolor import cprint 11 | 12 | from pi0 import PI0Policy 13 | 14 | PATH_TO_PI_MODEL = ( 15 | "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_libero_pytorch" 16 | ) 17 | PATH_TO_JAX_PI_MODEL = "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_libero" 18 | 19 | # load model 20 | cprint("Loading PI0 model...", "green") 21 | policy = PI0Policy.from_pretrained(PATH_TO_PI_MODEL) 22 | 23 | # load normalization stats 24 | device = policy.config.device 25 | norm_stats_path = ( 26 | Path(PATH_TO_JAX_PI_MODEL) / "assets/physical-intelligence/libero/norm_stats.json" 27 | ) 28 | with open(norm_stats_path) as f: 29 | norm_stats = json.load(f) 30 | state_mean = np.array(norm_stats["norm_stats"]["state"]["mean"][:8], dtype=np.float32) 31 | state_std = np.array(norm_stats["norm_stats"]["state"]["std"][:8], dtype=np.float32) 32 | action_mean = np.array( 33 | norm_stats["norm_stats"]["actions"]["mean"][:7], dtype=np.float32 34 | ) 35 | action_std = np.array(norm_stats["norm_stats"]["actions"]["std"][:7], dtype=np.float32) 36 | 37 | # create environment 38 | # ** Change `env_name` and `task_id` to test different 
environments and tasks ** 39 | cprint("Creating Libero environment...", "green") 40 | env = gym.make( 41 | "libero-goal-v0", # from ["libero-goal-v0", "libero-object-v0", "libero-spatial-v0", "libero-10-v0", "libero-90-v0"], 42 | task_id=2, # task id from 0 to 9 43 | image_size=224, # image size (height, width) 44 | camera_names=["agentview", "robot0_eye_in_hand"], # camera names 45 | seed=0, # random seed 46 | ) 47 | 48 | # reset environment 49 | o = env.reset() 50 | # important: do some `dummy` steps because the simulator drops object at the beginning 51 | dummy_action = np.array([0, 0, 0, 0, 0, 0, -1]) 52 | for _ in range(20): 53 | o, r, d, i = env.step(dummy_action) 54 | 55 | frames = [] 56 | cprint("Starting demo...", "green") 57 | while not d: 58 | unnorm_state = np.concatenate( 59 | [ 60 | o["robot0_eef_pos"], 61 | T.quat2axisangle(o["robot0_eef_quat"]), 62 | o["robot0_gripper_qpos"], 63 | ], 64 | dtype=np.float32, 65 | ) 66 | state = (unnorm_state - state_mean) / (state_std + 1e-6) 67 | base_0_rgb = o["agentview_image"][:, :, ::-1].copy() 68 | left_wrist_0_rgb = o["robot0_eye_in_hand_image"][:, :, ::-1].copy() 69 | 70 | observation = { 71 | "image": { 72 | "base_0_rgb": torch.from_numpy(base_0_rgb).to(device)[None], 73 | "left_wrist_0_rgb": torch.from_numpy(left_wrist_0_rgb).to(device)[None], 74 | }, 75 | "state": torch.from_numpy(state).to(device)[None], 76 | "prompt": [env.task_description], 77 | } 78 | action = policy.select_action(observation)[0, :, :7] 79 | action = action.cpu().numpy() 80 | action = action * (action_std + 1e-6) + action_mean 81 | action[:, :6] += unnorm_state[None, :6] 82 | for i in range(50): 83 | o, r, d, _ = env.step(action[i, :7]) 84 | frames.append(o["agentview_image"][:, :, ::-1].transpose(1, 2, 0).copy()) 85 | if d: 86 | break 87 | 88 | # save video 89 | writer = imageio.get_writer("pi0_libero_demo.mp4", fps=30) 90 | for frame in frames: 91 | writer.append_data(frame) 92 | writer.close() 93 | -------------------------------------------------------------------------------- /so100_client.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import gym 5 | import numpy as np 6 | from lerobot.common.cameras.opencv import OpenCVCameraConfig 7 | from lerobot.common.robots.so100_follower import SO100Follower, SO100FollowerConfig 8 | from lerobot.common.utils.robot_utils import busy_wait 9 | 10 | from utils.server import PolicyClient 11 | 12 | 13 | class SO100Env(gym.Env): 14 | def __init__( 15 | self, 16 | robot_id: str, 17 | port: str, 18 | cameras: dict[str, dict], 19 | ): 20 | super().__init__() 21 | 22 | cameras_config = {} 23 | for k, v in cameras.items(): 24 | cameras_config[k] = OpenCVCameraConfig(**v) 25 | 26 | robot_config = SO100FollowerConfig( 27 | port=port, 28 | id=robot_id, 29 | cameras=cameras_config 30 | ) 31 | self.robot = SO100Follower(robot_config) 32 | self.camera_names = list(cameras.keys()) 33 | 34 | self.init_pos = None 35 | 36 | def get_observation(self): 37 | raw_observation = self.robot.get_observation() 38 | state = np.array([ 39 | raw_observation['shoulder_pan.pos'], 40 | raw_observation['shoulder_lift.pos'], 41 | raw_observation['elbow_flex.pos'], 42 | raw_observation['wrist_flex.pos'], 43 | raw_observation['wrist_roll.pos'], 44 | raw_observation['gripper.pos'], 45 | ]) 46 | observation = {'state': state} 47 | for camera_name in self.camera_names: 48 | _image = cv2.resize(raw_observation[camera_name], (224, 224)) 49 | observation[camera_name] = _image 50 | 
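        # frames are resized to 224x224 above, the input resolution the PI0 policies in this repo expect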
return observation 51 | 52 | def reset(self): 53 | if self.init_pos is None: 54 | self.robot.connect() 55 | init_observation = self.get_observation() 56 | self.init_pos = init_observation['state'] 57 | else: 58 | self.robot.send_action(self.init_pos) 59 | init_observation = self.get_observation() 60 | return init_observation 61 | 62 | def step(self, action): 63 | self.robot.send_action(action) 64 | return self.get_observation(), 0, False, {} 65 | 66 | def close(self): 67 | self.robot.disconnect() 68 | 69 | 70 | client = PolicyClient(server_host='localhost', server_port=12346) 71 | 72 | env = SO100Env( 73 | robot_id="so100_follower", 74 | port="/dev/ttyACM0", 75 | cameras={ 76 | "base": {"fps": 25, "width": 640, "height": 480, "index_or_path": 0}, 77 | "wrist": {"fps": 25, "width": 640, "height": 480, "index_or_path": 2}, 78 | } 79 | ) 80 | 81 | o = env.reset() 82 | o['prompt'] = ["grab the screwdriver and put it to the right"] 83 | busy_wait(1) 84 | 85 | for _ in range(10): 86 | o = env.get_observation() 87 | o['prompt'] = ["grab the screwdriver and put it to the front"] 88 | action = client.get_action(o)[0] 89 | 90 | names = [ 91 | "shoulder_pan.pos", 92 | "shoulder_lift.pos", 93 | "elbow_flex.pos", 94 | "wrist_flex.pos", 95 | "wrist_roll.pos", 96 | "gripper.pos" 97 | ] 98 | 99 | fps = 30 100 | start_episode_t = time.perf_counter() 101 | for i in range(0, 50): 102 | start_loop_t = time.perf_counter() 103 | 104 | act = {} 105 | for j, name in enumerate(names): 106 | act[name] = action[i, j] 107 | 108 | o, _, _, _ = env.step(act) 109 | 110 | if fps is not None: 111 | dt_s = time.perf_counter() - start_loop_t 112 | busy_wait(1 / fps - dt_s) 113 | 114 | dt_s = time.perf_counter() - start_loop_t 115 | 116 | env.close() 117 | 118 | -------------------------------------------------------------------------------- /3_test_pi0fast_on_libero.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import gym 5 | import imageio 6 | import numpy as np 7 | import robosuite.utils.transform_utils as T 8 | import torch 9 | from cleandiffuser.env import libero # noqa: F401 10 | from termcolor import cprint 11 | 12 | from pi0 import PI0FASTPolicy 13 | 14 | PATH_TO_PI_MODEL = ( 15 | "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_fast_libero_pytorch" 16 | ) 17 | PATH_TO_JAX_PI_MODEL = ( 18 | "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_fast_libero" 19 | ) 20 | 21 | # load model 22 | cprint("Loading PI0 fast model...", "green") 23 | policy = PI0FASTPolicy.from_pretrained(PATH_TO_PI_MODEL) 24 | policy.model.action_dim = 7 25 | 26 | # load normalization stats 27 | device = policy.config.device 28 | norm_stats_path = ( 29 | Path(PATH_TO_JAX_PI_MODEL) / "assets/physical-intelligence/libero/norm_stats.json" 30 | ) 31 | with open(norm_stats_path) as f: 32 | norm_stats = json.load(f) 33 | state_mean = np.array(norm_stats["norm_stats"]["state"]["mean"][:8], dtype=np.float32) 34 | state_std = np.array(norm_stats["norm_stats"]["state"]["std"][:8], dtype=np.float32) 35 | action_mean = np.array( 36 | norm_stats["norm_stats"]["actions"]["mean"][:7], dtype=np.float32 37 | ) 38 | action_std = np.array(norm_stats["norm_stats"]["actions"]["std"][:7], dtype=np.float32) 39 | 40 | # create environment 41 | # ** Change `env_name` and `task_id` to test different environments and tasks ** 42 | cprint("Creating Libero environment...", "green") 43 | env = gym.make( 44 | "libero-goal-v0", # from ["libero-goal-v0", "libero-object-v0", 
"libero-spatial-v0", "libero-10-v0", "libero-90-v0"], 45 | task_id=0, # task id from 0 to 9 46 | image_size=224, # image size (height, width) 47 | camera_names=["agentview", "robot0_eye_in_hand"], # camera names 48 | seed=0, # random seed 49 | max_episode_steps=300, 50 | ) 51 | 52 | # reset environment 53 | o = env.reset() 54 | # important: do some `dummy` steps because the simulator drops object at the beginning 55 | dummy_action = np.array([0, 0, 0, 0, 0, 0, -1]) 56 | for _ in range(10): 57 | o, r, d, i = env.step(dummy_action) 58 | 59 | frames = [] 60 | cprint("Starting demo...", "green") 61 | while not d: 62 | unnorm_state = np.concatenate( 63 | [ 64 | o["robot0_eef_pos"], 65 | T.quat2axisangle(o["robot0_eef_quat"]), 66 | o["robot0_gripper_qpos"], 67 | ], 68 | dtype=np.float32, 69 | ) 70 | state = (unnorm_state - state_mean) / (state_std + 1e-6) 71 | base_0_rgb = o["agentview_image"][:, :, ::-1].copy() 72 | left_wrist_0_rgb = o["robot0_eye_in_hand_image"][:, :, ::-1].copy() 73 | 74 | observation = { 75 | "image": { 76 | "base_0_rgb": torch.from_numpy(base_0_rgb).to(device)[None], 77 | "left_wrist_0_rgb": torch.from_numpy(left_wrist_0_rgb).to(device)[None], 78 | # "right_wrist_0_rgb": torch.from_numpy(np.zeros_like(left_wrist_0_rgb)).to( 79 | # "cuda:1" 80 | # )[None], 81 | }, 82 | "state": torch.from_numpy(state).to(device)[None], 83 | "prompt": [env.language], 84 | } 85 | # action = policy.select_action(observation)[0, :, :7] 86 | action = policy.select_action(observation)[0] 87 | action = action.cpu().numpy() 88 | action = action * (action_std + 1e-6) + action_mean 89 | action[:, :6] += unnorm_state[None, :6] 90 | for i in range(5): 91 | o, r, d, _ = env.step(action[i]) 92 | frames.append(o["agentview_image"][:, :, ::-1].transpose(1, 2, 0).copy()) 93 | if d: 94 | break 95 | 96 | # save video 97 | writer = imageio.get_writer("pi0fast_libero_demo.mp4", fps=30) 98 | for frame in frames: 99 | writer.append_data(frame) 100 | writer.close() 101 | -------------------------------------------------------------------------------- /utils/normalizers.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def dict_apply(func, d): 8 | """ 9 | Apply a function to all values in a dictionary recursively. 10 | If the value is a dictionary, it will apply the function to its values. 
11 | """ 12 | for key, value in d.items(): 13 | if isinstance(value, dict): 14 | dict_apply(func, value) 15 | else: 16 | d[key] = func(value) 17 | return d 18 | 19 | 20 | class Normalizer: 21 | def __init__( 22 | self, 23 | norm_stats: Dict[str, Dict[str, np.ndarray]], 24 | norm_type: Dict[str, str] | None = None, 25 | ): 26 | self.norm_stats = dict_apply(lambda x: x.astype(np.float32), norm_stats) 27 | self.norm_type = norm_type or {} 28 | 29 | def normalize(self, data: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]: 30 | normalized_data = {} 31 | for key, value in data.items(): 32 | if key in self.norm_stats: 33 | norm_type = self.norm_type.get(key, "identity") 34 | if norm_type == "meanstd": 35 | mean = self.norm_stats[key]["mean"] 36 | std = self.norm_stats[key]["std"] 37 | normalized_value = (value - mean) / (std + 1e-6) 38 | elif norm_type == "std": 39 | std = self.norm_stats[key]["std"] 40 | normalized_value = value / (std + 1e-6) 41 | elif norm_type == "minmax": 42 | min_val = self.norm_stats[key]["min"] 43 | max_val = self.norm_stats[key]["max"] 44 | normalized_value = (value - min_val) / ( 45 | max_val - min_val + 1e-6 46 | ) * 2 - 1 47 | elif norm_type == "identity": 48 | normalized_value = value 49 | else: 50 | raise ValueError( 51 | f"Unknown normalization type: {norm_type}. Supported types are 'meanstd', 'minmax', and 'identity'." 52 | ) 53 | normalized_data[key] = normalized_value 54 | else: 55 | # If the key is not in norm_stats, we assume no normalization is needed 56 | normalized_data[key] = value 57 | return normalized_data 58 | 59 | def unnormalize(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 60 | unnormalized_data = {} 61 | for key, value in data.items(): 62 | if key in self.norm_stats: 63 | norm_type = self.norm_type.get(key, "identity") 64 | if norm_type == "meanstd": 65 | mean = self.norm_stats[key]["mean"] 66 | std = self.norm_stats[key]["std"] 67 | unnormalized_value = value * (std + 1e-6) + mean 68 | elif norm_type == "std": 69 | std = self.norm_stats[key]["std"] 70 | unnormalized_value = value * (std + 1e-6) 71 | elif norm_type == "minmax": 72 | min_val = self.norm_stats[key]["min"] 73 | max_val = self.norm_stats[key]["max"] 74 | unnormalized_value = (value + 1) / 2 * ( 75 | max_val - min_val + 1e-6 76 | ) + min_val 77 | elif norm_type == "identity": 78 | unnormalized_value = value 79 | else: 80 | raise ValueError( 81 | f"Unknown normalization type: {norm_type}. Supported types are 'meanstd', 'minmax', and 'identity'." 82 | ) 83 | unnormalized_data[key] = unnormalized_value 84 | else: 85 | # If the key is not in norm_stats, we assume no unnormalization is needed 86 | unnormalized_data[key] = value 87 | return unnormalized_data 88 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # ! In development, not ready for use. 
2 | 3 | import pytorch_lightning as L 4 | import torch 5 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset 6 | from lerobot.configs.policies import PreTrainedConfig 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from torch.utils.data import DataLoader, Dataset 9 | from torchvision.transforms.v2 import ColorJitter, Compose, RandomCrop, Resize 10 | 11 | from pi0.new_modeling_pi0 import PI0Policy 12 | from utils.normalizers import Normalizer 13 | from utils.schedulers import CosineDecaySchedule 14 | 15 | 16 | def to_device_dtype(d, device, dtype): 17 | for key, value in d.items(): 18 | if isinstance(value, dict): 19 | to_device_dtype(value, device, dtype) 20 | elif isinstance(value, torch.Tensor): 21 | if key not in ["action_is_pad"]: 22 | d[key] = value.to(device=device, dtype=dtype) 23 | else: 24 | d[key] = value.to(device=device) 25 | else: 26 | pass 27 | return d 28 | 29 | 30 | class PI0BridgeDataset(Dataset): 31 | def __init__( 32 | self, 33 | repo_id="ZibinDong/bridgedatav2_val", 34 | root="/mnt/20T/datasets/bridgev2/lerobot/ZibinDong/bridgedatav2_val", 35 | ): 36 | image_transforms = Compose( 37 | [RandomCrop(243), Resize(224), ColorJitter(0.3, 0.4, 0.5)] 38 | ) 39 | delta_timestamps = { 40 | "image": [0], 41 | "state": [0], 42 | "action": [i / 5 for i in range(50)], 43 | } 44 | self.dataset = LeRobotDataset( 45 | repo_id=repo_id, 46 | root=root, 47 | image_transforms=image_transforms, 48 | delta_timestamps=delta_timestamps, 49 | ) 50 | 51 | self.normalizer = Normalizer( 52 | norm_stats=self.dataset.meta.stats, 53 | norm_type={"image": "identity", "state": "meanstd", "action": "meanstd"}, 54 | ) 55 | 56 | def __len__(self): 57 | return len(self.dataset) 58 | 59 | def __getitem__(self, idx): 60 | item = self.dataset[idx] 61 | normalized_item = self.normalizer.normalize(item) 62 | image = (normalized_item["image"] * 255).to(torch.uint8) 63 | return { 64 | "image": {"base_0_rgb": image}, 65 | "state": normalized_item["state"][0], 66 | "action": normalized_item["action"], 67 | "action_is_pad": normalized_item["action_is_pad"], 68 | "prompt": item["task"], 69 | } 70 | 71 | 72 | class LightningTrainingWrapper(L.LightningModule): 73 | def __init__(self, policy: PI0Policy): 74 | super().__init__() 75 | self.policy = policy 76 | 77 | self.lr_scheduler = CosineDecaySchedule( 78 | warmup_steps=1000, 79 | peak_lr=5e-5, 80 | decay_steps=1_000_000, 81 | decay_lr=5e-5, 82 | ) 83 | 84 | def forward(self, batch): 85 | return self.policy(batch)[0] 86 | 87 | def training_step(self, batch, batch_idx): 88 | loss = self.policy(batch)[0] 89 | self.log("train_loss", loss) 90 | return loss 91 | 92 | def configure_optimizers(self): 93 | optimizer = torch.optim.AdamW( 94 | self.policy.get_optim_params(), lr=5e-5, weight_decay=1e-2 95 | ) 96 | scheduler = LambdaLR(optimizer, lr_lambda=lambda step: self.lr_scheduler(step)) 97 | return { 98 | "optimizer": optimizer, 99 | "lr_scheduler": { 100 | "scheduler": scheduler, 101 | "interval": "step", 102 | "frequency": 1, 103 | }, 104 | } 105 | 106 | 107 | config = PreTrainedConfig.from_pretrained( 108 | "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch" 109 | ) 110 | config.device = "cpu" 111 | config.freeze_vision_encoder = True 112 | config.train_expert_only = True 113 | config.train_state_proj = True 114 | policy = PI0Policy(config) 115 | training_policy = LightningTrainingWrapper(policy) 116 | 117 | dataset = PI0BridgeDataset() 118 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2) 119 | 120 | 
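# 4 GPUs x batch size 4 x accumulate_grad_batches 2 gives an effective batch size of 32 per optimizer step; "bf16-true" runs the whole model in bfloat16.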
trainer = L.Trainer( 121 | devices=[0, 1, 2, 3], 122 | strategy="ddp_find_unused_parameters_true", 123 | max_epochs=1, 124 | enable_progress_bar=True, 125 | gradient_clip_val=1.0, 126 | precision="bf16-true", 127 | accumulate_grad_batches=2, 128 | ) 129 | 130 | trainer.fit(training_policy, dataloader) 131 | -------------------------------------------------------------------------------- /utils/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import json 4 | import pickle 5 | from typing import Any, Dict, List, Union 6 | 7 | import numpy as np 8 | import websockets 9 | 10 | 11 | class PolicyServer: 12 | def __init__(self, policy, host="0.0.0.0", port=12345): 13 | self.policy = policy 14 | self.host = host 15 | self.port = port 16 | 17 | def serialize_obs(self, obs: Dict[str, Union[np.ndarray, List[str]]]) -> str: 18 | serialized = {} 19 | for key, value in obs.items(): 20 | if isinstance(value, np.ndarray): 21 | serialized[key] = { 22 | "type": "numpy", 23 | "data": base64.b64encode(pickle.dumps(value)).decode("utf-8"), 24 | "dtype": str(value.dtype), 25 | "shape": value.shape, 26 | } 27 | elif isinstance(value, list): 28 | serialized[key] = {"type": "list", "data": value} 29 | else: 30 | raise ValueError(f"Unsupported data type for key {key}: {type(value)}") 31 | 32 | return json.dumps(serialized) 33 | 34 | def deserialize_obs(self, data: str) -> Dict[str, Union[np.ndarray, List[str]]]: 35 | serialized = json.loads(data) 36 | obs = {} 37 | 38 | for key, value in serialized.items(): 39 | if value["type"] == "numpy": 40 | array_data = pickle.loads(base64.b64decode(value["data"])) 41 | obs[key] = array_data 42 | elif value["type"] == "list": 43 | obs[key] = value["data"] 44 | else: 45 | raise ValueError(f"Unknown data type: {value['type']}") 46 | 47 | return obs 48 | 49 | def serialize_action(self, action: Any) -> str: 50 | if isinstance(action, np.ndarray): 51 | return json.dumps( 52 | { 53 | "type": "numpy", 54 | "data": base64.b64encode(pickle.dumps(action)).decode("utf-8"), 55 | } 56 | ) 57 | else: 58 | return json.dumps( 59 | { 60 | "type": "other", 61 | "data": base64.b64encode(pickle.dumps(action)).decode("utf-8"), 62 | } 63 | ) 64 | 65 | async def handle_client(self, websocket): 66 | print(f"Client connected from {websocket.remote_address}") 67 | try: 68 | async for message in websocket: 69 | try: 70 | obs = self.deserialize_obs(message) 71 | print(f"Received obs with keys: {list(obs.keys())}") 72 | 73 | action = self.policy(obs) 74 | print(f"Computed action: {type(action)}") 75 | 76 | action_data = self.serialize_action(action) 77 | await websocket.send(action_data) 78 | 79 | except Exception as e: 80 | error_msg = json.dumps({"error": str(e)}) 81 | await websocket.send(error_msg) 82 | print(f"Error processing request: {e}") 83 | 84 | except websockets.exceptions.ConnectionClosed: 85 | print("Client disconnected") 86 | except Exception as e: 87 | print(f"Error in handle_client: {e}") 88 | 89 | async def start_server(self): 90 | print(f"Starting Policy server on {self.host}:{self.port}") 91 | 92 | server = await websockets.serve( 93 | self.handle_client, self.host, self.port, ping_interval=20, ping_timeout=10 94 | ) 95 | 96 | print("Policy server is running and waiting for connections...") 97 | await server.wait_closed() 98 | 99 | def run(self): 100 | asyncio.run(self.start_server()) 101 | 102 | 103 | class PolicyClient: 104 | def __init__(self, server_host='localhost', server_port=12345): 105 | 
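        # Synchronous wrapper around the websocket connection: an asyncio event loop runs in a background daemon thread so `get_action` can be called from ordinary blocking robot-control code.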
self.server_host = server_host 106 | self.server_port = server_port 107 | self.loop = None 108 | self.client = None 109 | self._setup_loop() 110 | 111 | def _setup_loop(self): 112 | import threading 113 | self.loop = asyncio.new_event_loop() 114 | self.thread = threading.Thread(target=self._run_loop, daemon=True) 115 | self.thread.start() 116 | import time 117 | time.sleep(0.1) 118 | 119 | def _run_loop(self): 120 | asyncio.set_event_loop(self.loop) 121 | self.loop.run_forever() 122 | 123 | def _run_async(self, coro): 124 | future = asyncio.run_coroutine_threadsafe(coro, self.loop) 125 | return future.result(timeout=30) 126 | 127 | def connect(self): 128 | if self.client is None: 129 | self.client = PolicyClient(self.server_host, self.server_port) 130 | return self._run_async(self.client.connect()) 131 | 132 | def disconnect(self): 133 | if self.client: 134 | self._run_async(self.client.disconnect()) 135 | 136 | def get_action(self, obs: Dict[str, Union[np.ndarray, List[str]]]) -> Any: 137 | if self.client is None: 138 | self.connect() 139 | return self._run_async(self.client.get_action_async(obs)) 140 | 141 | def __enter__(self): 142 | self.connect() 143 | return self 144 | 145 | def __exit__(self, exc_type, exc_val, exc_tb): 146 | self.disconnect() 147 | if self.loop and self.loop.is_running(): 148 | self.loop.call_soon_threadsafe(self.loop.stop) 149 | 150 | -------------------------------------------------------------------------------- /so100_train.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset 4 | from lerobot.configs.policies import PreTrainedConfig 5 | from lightning.pytorch.callbacks import ModelCheckpoint 6 | from torch.optim.lr_scheduler import LinearLR 7 | from torch.utils.data import DataLoader, Dataset 8 | from torchvision.transforms.v2 import Resize 9 | 10 | from pi0.modeling_pi0 import PI0Policy 11 | from utils.normalizers import Normalizer 12 | 13 | 14 | def to_device_dtype(d, device, dtype): 15 | for key, value in d.items(): 16 | if isinstance(value, dict): 17 | to_device_dtype(value, device, dtype) 18 | elif isinstance(value, torch.Tensor): 19 | if key not in ["action_is_pad"]: 20 | d[key] = value.to(device=device, dtype=dtype) 21 | else: 22 | d[key] = value.to(device=device) 23 | else: 24 | pass 25 | return d 26 | 27 | 28 | class PI0SO100Dataset(Dataset): 29 | def __init__( 30 | self, 31 | repo_id="ZibinDong/so100_grab_screwdriver", 32 | ): 33 | image_transforms = Resize((224, 224)) 34 | 35 | # [i / 30 for i in range(50)] represents action chunks in 50 steps at 30 FPS. 36 | # The timestamps are set to 0 for the images and state, as we only use current obs. 
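    # LeRobotDataset pads action chunks that run past the end of an episode and flags those steps in `action_is_pad`, which is forwarded to the policy in `__getitem__` below.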
37 | delta_timestamps = { 38 | "observation.images.base": [0], 39 | "observation.images.wrist": [0], 40 | "observation.state": [0], 41 | "action": [i / 30 for i in range(50)], 42 | } 43 | 44 | self.dataset = LeRobotDataset( 45 | repo_id=repo_id, 46 | image_transforms=image_transforms, 47 | delta_timestamps=delta_timestamps, 48 | ) 49 | self.normalizer = Normalizer( 50 | norm_stats=self.dataset.meta.stats, 51 | norm_type={ 52 | "observation.images.base": "identity", 53 | "observation.images.wrist": "identity", 54 | "observation.state": "meanstd", 55 | "action": "std", 56 | }, 57 | ) 58 | 59 | def __len__(self): 60 | return len(self.dataset) 61 | 62 | def __getitem__(self, idx): 63 | item = self.dataset[idx] 64 | # we use relative action, so we need to subtract the state from the action 65 | item["action"] = item["action"] - item["observation.state"] 66 | normalized_item = self.normalizer.normalize(item) 67 | base_image = (normalized_item["observation.images.base"] * 255).to(torch.uint8) 68 | wrist_image = (normalized_item["observation.images.wrist"] * 255).to( 69 | torch.uint8 70 | ) 71 | return { 72 | "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": wrist_image}, 73 | "state": normalized_item["observation.state"][0], 74 | "action": normalized_item["action"], 75 | "action_is_pad": normalized_item["action_is_pad"], 76 | "prompt": item["task"], 77 | } 78 | 79 | 80 | class LightningTrainingWrapper(L.LightningModule): 81 | def __init__(self, config, ckpt_path): 82 | super().__init__() 83 | # load model in `configure_model` to accelerate model loading 84 | self.policy = None 85 | self.config = config 86 | self.ckpt_path = ckpt_path 87 | 88 | def configure_model(self): 89 | if self.policy is None: 90 | self.policy = PI0Policy.from_pretrained(self.ckpt_path, config=self.config) 91 | 92 | def forward(self, batch): 93 | return self.policy(batch)[0] 94 | 95 | def training_step(self, batch, batch_idx): 96 | loss = self.policy(batch)[0] 97 | self.log("train_loss", loss, prog_bar=True) 98 | return loss 99 | 100 | def configure_optimizers(self): 101 | optimizer = torch.optim.AdamW( 102 | self.policy.get_optim_params(), lr=5e-5, weight_decay=1e-2, eps=1e-6 103 | ) 104 | scheduler = LinearLR( 105 | optimizer, 106 | start_factor=0.01, 107 | end_factor=1.0, 108 | total_iters=100, 109 | ) 110 | return { 111 | "optimizer": optimizer, 112 | "lr_scheduler": { 113 | "scheduler": scheduler, 114 | "interval": "step", 115 | "frequency": 1, 116 | }, 117 | } 118 | 119 | 120 | dataset = PI0SO100Dataset("ZibinDong/so100_grab_screwdriver") 121 | dataloader = DataLoader( 122 | dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True 123 | ) 124 | 125 | callback = ModelCheckpoint( 126 | dirpath="/mnt/20T/dzb/pi0_so100_checkpoints", # where you want to save the checkpoints 127 | filename="{epoch}-{step}", 128 | save_top_k=-1, # save all checkpoints 129 | every_n_epochs=4, # save every 4 epochs 130 | ) 131 | 132 | trainer = L.Trainer( 133 | accelerator="cuda", 134 | devices=4, 135 | strategy="ddp_find_unused_parameters_true", 136 | max_epochs=50, 137 | enable_progress_bar=True, 138 | gradient_clip_val=1.0, 139 | precision="bf16-mixed", 140 | accumulate_grad_batches=4, 141 | callbacks=[callback], 142 | ) 143 | 144 | with trainer.init_module(): 145 | ckpt_path = "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch" 146 | config = PreTrainedConfig.from_pretrained(ckpt_path) 147 | config.device = "cpu" 148 | config.freeze_vision_encoder = True 149 | config.train_expert_only = True 150 | 
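# freeze_vision_encoder / train_expert_only keep the PaliGemma VLM backbone frozen and train only the action expert; train_state_proj (next line) additionally trains the state projection layer.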
config.train_state_proj = True 151 | training_policy = LightningTrainingWrapper(config, ckpt_path) 152 | 153 | 154 | trainer.fit(training_policy, dataloader) 155 | -------------------------------------------------------------------------------- /so100_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytorch_lightning as L 3 | import torch 4 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset 5 | from lerobot.configs.policies import PreTrainedConfig 6 | from termcolor import cprint 7 | from torch.utils.data import Dataset 8 | from torchvision.transforms.v2 import Resize 9 | 10 | from pi0.modeling_pi0 import PI0Policy 11 | from utils.normalizers import Normalizer 12 | from utils.server import PolicyServer 13 | 14 | 15 | def to_device_dtype(d, device, dtype): 16 | for key, value in d.items(): 17 | if isinstance(value, dict): 18 | to_device_dtype(value, device, dtype) 19 | elif isinstance(value, torch.Tensor): 20 | if key not in ["action_is_pad"]: 21 | d[key] = value.to(device=device, dtype=dtype) 22 | else: 23 | d[key] = value.to(device=device) 24 | else: 25 | pass 26 | return d 27 | 28 | 29 | class LightningTrainingWrapper(L.LightningModule): 30 | def __init__(self, config, ckpt_path): 31 | super().__init__() 32 | # load model in `configure_model` to accelerate model loading 33 | self.policy = None 34 | self.config = config 35 | self.ckpt_path = ckpt_path 36 | 37 | def configure_model(self): 38 | if self.policy is None: 39 | self.policy = PI0Policy.from_pretrained(self.ckpt_path, config=self.config) 40 | 41 | def forward(self, batch): 42 | return self.policy(batch)[0] 43 | 44 | def training_step(self, batch, batch_idx): 45 | loss = self.policy(batch)[0] 46 | self.log("train_loss", loss, prog_bar=True) 47 | return loss 48 | 49 | 50 | class SO100Policy: 51 | def __init__( 52 | self, 53 | ckpt_path: str, 54 | pi0_ckpt_path: str, 55 | repo_id: str = None, 56 | device: str = "cuda:0", 57 | dtype: torch.dtype = torch.bfloat16, 58 | ): 59 | self.device = device 60 | self.dtype = dtype 61 | 62 | # load policy 63 | cprint("Loading SO100 Policy...", "yellow") 64 | config = PreTrainedConfig.from_pretrained(pi0_ckpt_path) 65 | training_policy = LightningTrainingWrapper(config, pi0_ckpt_path) 66 | training_policy.configure_model()  # build the policy before loading the fine-tuned weights 67 | training_policy.load_state_dict( 68 | torch.load(ckpt_path, map_location="cpu")["state_dict"] 69 | ) 70 | self.policy = training_policy.policy.to(device=device, dtype=dtype).eval() 71 | cprint("SO100 Policy loaded successfully!", "green") 72 | cprint("Preparing norm stats...", "yellow") 73 | dataset = LeRobotDataset(repo_id=repo_id) 74 | self.normalizer = Normalizer( 75 | norm_stats=dataset.meta.stats, 76 | norm_type={ 77 | "observation.images.base": "identity", 78 | "observation.images.wrist": "identity", 79 | "observation.state": "meanstd", 80 | "action": "std", 81 | }, 82 | ) 83 | cprint("Norm stats prepared successfully!", "green") 84 | 85 | self.resize = Resize((224, 224)) 86 | 87 | cprint("Ready to use SO100 Policy!", "green") 88 | 89 | @torch.no_grad() 90 | def act(self, obs: np.ndarray): 91 | """ 92 | obs: { 93 | "base": uint8 (H, W, C), 94 | "wrist": uint8 (H, W, C), 95 | "state": float32 (state_dim,), 96 | "prompt": str 97 | } 98 | """ 99 | obs = self.normalizer.normalize( 100 | { 101 | "observation.images.base": obs["base"], 102 | "observation.images.wrist": obs["wrist"], 103 | "observation.state": obs["state"], 104 | "prompt": obs["prompt"], 105 | } 106 | ) 107 | 108 | base_image = torch.tensor( 109 |
obs["observation.images.base"], dtype=torch.uint8, device=self.device 110 | ) 111 | wrist_image = torch.tensor( 112 | obs["observation.images.wrist"], dtype=torch.uint8, device=self.device 113 | ) 114 | base_image = base_image.permute(2, 0, 1)[None] 115 | wrist_image = wrist_image.permute(2, 0, 1)[None] 116 | base_image = self.resize(base_image) 117 | wrist_image = self.resize(wrist_image) 118 | state = torch.tensor( 119 | obs["observation.state"], dtype=self.dtype, device=self.device 120 | )[None] 121 | prompt = obs["prompt"] 122 | action = self.policy.select_action( 123 | { 124 | "image": { 125 | "base_0_rgb": base_image, 126 | "left_wrist_0_rgb": wrist_image, 127 | }, 128 | "state": state, 129 | "prompt": prompt, 130 | } 131 | ) 132 | action = action[:, :, :6] 133 | action = action.float().cpu().numpy() 134 | state = state.float().cpu().numpy() 135 | state_action = self.normalizer.unnormalize( 136 | {"observation.state": state, "action": action} 137 | ) 138 | state = state_action["observation.state"] 139 | action = state_action["action"] 140 | action = action + state 141 | return action 142 | 143 | def __call__(self, obs: np.ndarray): 144 | return self.act(obs) 145 | 146 | 147 | if __name__ == "__main__": 148 | policy = SO100Policy( 149 | ckpt_path="/mnt/20T/dzb/pi0_so100_checkpoints/epoch=39-step=29760.ckpt", 150 | pi0_ckpt_path="/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch", 151 | repo_ids="ZibinDong/so100_grab_screwdriver", 152 | device="cuda:0", 153 | dtype=torch.bfloat16, 154 | ) 155 | server = PolicyServer(policy, host="0.0.0.0", port=12346) 156 | server.run() 157 | -------------------------------------------------------------------------------- /so100_train_fast.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset 4 | from lerobot.configs.policies import PreTrainedConfig 5 | from lightning.pytorch.callbacks import ModelCheckpoint 6 | from peft import LoraConfig, TaskType, get_peft_model 7 | from torch.optim.lr_scheduler import LinearLR 8 | from torch.utils.data import DataLoader, Dataset 9 | from torchvision.transforms.v2 import Resize 10 | 11 | from pi0 import PI0FASTPolicy 12 | from utils.normalizers import Normalizer 13 | 14 | 15 | def to_device_dtype(d, device, dtype): 16 | for key, value in d.items(): 17 | if isinstance(value, dict): 18 | to_device_dtype(value, device, dtype) 19 | elif isinstance(value, torch.Tensor): 20 | if key not in ["action_is_pad"]: 21 | d[key] = value.to(device=device, dtype=dtype) 22 | else: 23 | d[key] = value.to(device=device) 24 | else: 25 | pass 26 | return d 27 | 28 | 29 | class PI0SO100Dataset(Dataset): 30 | def __init__(self, repo_ids): 31 | image_transforms = Resize((224, 224)) 32 | 33 | # [i / 25 for i in range(50)] represents action chunks in 50 steps at 25 FPS. 34 | # The timestamps are set to 0 for the images and state, as we only use current obs. 
35 | delta_timestamps = { 36 | "observation.images.base": [0], 37 | "observation.images.wrist": [0], 38 | "observation.state": [0], 39 | "action": [i / 25 for i in range(50)], 40 | } 41 | self.dataset = MultiLeRobotDataset( 42 | repo_ids=repo_ids, 43 | image_transforms=image_transforms, 44 | delta_timestamps=delta_timestamps, 45 | ) 46 | self.normalizer = Normalizer( 47 | norm_stats=self.dataset.stats, 48 | norm_type={ 49 | "observation.images.base": "identity", 50 | "observation.images.wrist": "identity", 51 | "observation.state": "minmax", 52 | "action": "minmax", 53 | }, 54 | ) 55 | 56 | def __len__(self): 57 | return len(self.dataset) 58 | 59 | def __getitem__(self, idx): 60 | item = self.dataset[idx] 61 | normalized_item = self.normalizer.normalize(item) 62 | base_image = (normalized_item["observation.images.base"] * 255).to(torch.uint8) 63 | wrist_image = (normalized_item["observation.images.wrist"] * 255).to( 64 | torch.uint8 65 | ) 66 | return { 67 | "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": wrist_image}, 68 | "state": normalized_item["observation.state"][0], 69 | "action": normalized_item["action"], 70 | "action_is_pad": normalized_item["action_is_pad"], 71 | "prompt": item["task"], 72 | } 73 | 74 | 75 | class LightningTrainingWrapper(L.LightningModule): 76 | def __init__(self, config, ckpt_path): 77 | super().__init__() 78 | # load model in `configure_model` to accelerate model loading 79 | self.policy = None 80 | self.config = config 81 | self.ckpt_path = ckpt_path 82 | 83 | def configure_model(self): 84 | if self.policy is None: 85 | policy = PI0FASTPolicy.from_pretrained(self.ckpt_path, config=self.config) 86 | # add lora to pi0_paligemma model 87 | model = policy.model.pi0_paligemma 88 | peft_config = LoraConfig( 89 | r=16, 90 | lora_alpha=32, 91 | task_type=TaskType.CAUSAL_LM, 92 | target_modules="all-linear", 93 | ) 94 | policy.model.pi0_paligemma = get_peft_model(model, peft_config) 95 | self.policy = policy 96 | 97 | def forward(self, batch): 98 | return self.policy(batch)[0] 99 | 100 | def training_step(self, batch, batch_idx): 101 | loss, loss_dict = self.policy(batch) 102 | self.log("train_loss", loss, prog_bar=True) 103 | self.log("acc", loss_dict["acc"], prog_bar=True) 104 | return loss 105 | 106 | def configure_optimizers(self): 107 | optimizer = torch.optim.AdamW( 108 | self.policy.get_optim_params(), lr=5e-5, weight_decay=1e-2, eps=1e-6 109 | ) 110 | scheduler = LinearLR( 111 | optimizer, 112 | start_factor=0.01, 113 | end_factor=1.0, 114 | total_iters=100, 115 | ) 116 | return { 117 | "optimizer": optimizer, 118 | "lr_scheduler": { 119 | "scheduler": scheduler, 120 | "interval": "step", 121 | "frequency": 1, 122 | }, 123 | } 124 | 125 | 126 | dataset = PI0SO100Dataset( 127 | [f"ZibinDong/so100_play_screwdriver_0{i + 1}" for i in range(6)] 128 | ) 129 | dataloader = DataLoader( 130 | dataset, batch_size=2, shuffle=True, num_workers=2, persistent_workers=True 131 | ) 132 | 133 | callback = ModelCheckpoint( 134 | dirpath="/mnt/20T/dzb/pi0fast_so100_checkpoints", # where you want to save the checkpoints 135 | filename="{epoch}-{step}", 136 | save_top_k=-1, # save all checkpoints 137 | every_n_epochs=4, # save all checkpoints 138 | ) 139 | 140 | trainer = L.Trainer( 141 | accelerator="cuda", 142 | devices=4, 143 | strategy="ddp", 144 | max_epochs=50, 145 | enable_progress_bar=True, 146 | gradient_clip_val=1.0, 147 | precision="bf16-mixed", 148 | accumulate_grad_batches=4, 149 | callbacks=[callback], 150 | ) 151 | 152 | with trainer.init_module(): 153 | 
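    # init_module() creates the weights inside the trainer's device/precision context, avoiding a slow full-precision CPU initialization before training starts.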
ckpt_path = ( 154 | "/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_fast_base_pytorch" 155 | ) 156 | config = PreTrainedConfig.from_pretrained(ckpt_path) 157 | config.max_action_dim = 6 # set action dimension to 6 158 | config.chunk_size = 50 # set action chunk length to 50 159 | training_policy = LightningTrainingWrapper(config, ckpt_path) 160 | 161 | trainer.fit(training_policy, dataloader) 162 | -------------------------------------------------------------------------------- /so100_eval_fast.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import numpy as np 3 | import torch 4 | from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset 5 | from lerobot.configs.policies import PreTrainedConfig 6 | from peft import LoraConfig, TaskType, get_peft_model 7 | from termcolor import cprint 8 | from torchvision.transforms.v2 import Resize 9 | 10 | from pi0 import PI0FASTPolicy 11 | from utils.normalizers import Normalizer 12 | from utils.server import PolicyServer 13 | 14 | 15 | def to_device_dtype(d, device, dtype): 16 | for key, value in d.items(): 17 | if isinstance(value, dict): 18 | to_device_dtype(value, device, dtype) 19 | elif isinstance(value, torch.Tensor): 20 | if key not in ["action_is_pad"]: 21 | d[key] = value.to(device=device, dtype=dtype) 22 | else: 23 | d[key] = value.to(device=device) 24 | else: 25 | pass 26 | return d 27 | 28 | 29 | class LightningTrainingWrapper(L.LightningModule): 30 | def __init__(self, config, ckpt_path): 31 | super().__init__() 32 | # load model in `configure_model` to accelerate model loading 33 | self.policy = None 34 | self.config = config 35 | self.ckpt_path = ckpt_path 36 | 37 | def configure_model(self): 38 | if self.policy is None: 39 | policy = PI0FASTPolicy.from_pretrained(self.ckpt_path, config=self.config) 40 | # add lora to pi0_paligemma model 41 | model = policy.model.pi0_paligemma 42 | peft_config = LoraConfig( 43 | r=16, 44 | lora_alpha=32, 45 | task_type=TaskType.CAUSAL_LM, 46 | target_modules="all-linear", 47 | ) 48 | policy.model.pi0_paligemma = get_peft_model(model, peft_config) 49 | self.policy = policy 50 | 51 | def forward(self, batch): 52 | return self.policy(batch)[0] 53 | 54 | def training_step(self, batch, batch_idx): 55 | loss, loss_dict = self.policy(batch) 56 | self.log("train_loss", loss, prog_bar=True) 57 | self.log("acc", loss_dict["acc"], prog_bar=True) 58 | return loss 59 | 60 | 61 | class SO100Policy: 62 | def __init__( 63 | self, 64 | ckpt_path: str, 65 | pi0fast_ckpt_path: str, 66 | repo_ids: list[str] = None, 67 | device: str = "cuda:0", 68 | dtype: torch.dtype = torch.bfloat16, 69 | ): 70 | self.device = device 71 | self.dtype = dtype 72 | 73 | # load policy 74 | cprint("Loading SO100 Policy...", "yellow") 75 | config = PreTrainedConfig.from_pretrained(pi0fast_ckpt_path) 76 | config.max_action_dim = 6 # set action dimension to 6 77 | config.chunk_size = 50 # set action chunk length to 50 78 | training_policy = LightningTrainingWrapper(config, pi0fast_ckpt_path) 79 | training_policy.configure_model() 80 | training_policy.load_state_dict( 81 | torch.load(ckpt_path, map_location=device)["state_dict"] 82 | ) 83 | self.policy = training_policy.policy.to(device=device, dtype=dtype).eval() 84 | cprint("SO100 Policy loaded successfully!", "green") 85 | 86 | cprint("Prepareing norm stats...", "yellow") 87 | dataset = MultiLeRobotDataset(repo_ids) 88 | self.normalizer = Normalizer( 89 | norm_stats=dataset.stats, 90 | norm_type={ 91 | 
"observation.images.base": "identity", 92 | "observation.images.wrist": "identity", 93 | "observation.state": "minmax", 94 | "action": "minmax", 95 | }, 96 | ) 97 | print(self.normalizer.norm_stats) 98 | cprint("Norm stats prepared successfully!", "green") 99 | 100 | self.resize = Resize((224, 224)) 101 | 102 | cprint("Ready to use SO100 Policy!", "green") 103 | 104 | @torch.no_grad() 105 | def act(self, obs: np.ndarray): 106 | """ 107 | obs: { 108 | "base": uint8 (H, W, C), 109 | "wrist": uint8 (H, W, C), 110 | "state": float32 (state_dim,), 111 | "prompt": str 112 | } 113 | """ 114 | obs = self.normalizer.normalize( 115 | { 116 | "observation.images.base": obs["base"], 117 | "observation.images.wrist": obs["wrist"], 118 | "observation.state": obs["state"], 119 | "prompt": obs["prompt"], 120 | } 121 | ) 122 | 123 | base_image = torch.tensor( 124 | obs["observation.images.base"], dtype=torch.uint8, device=self.device 125 | ) 126 | wrist_image = torch.tensor( 127 | obs["observation.images.wrist"], dtype=torch.uint8, device=self.device 128 | ) 129 | base_image = base_image.permute(2, 0, 1)[None] 130 | wrist_image = wrist_image.permute(2, 0, 1)[None] 131 | base_image = self.resize(base_image) 132 | wrist_image = self.resize(wrist_image) 133 | state = torch.tensor( 134 | obs["observation.state"], dtype=self.dtype, device=self.device 135 | )[None] 136 | prompt = obs["prompt"] 137 | action = self.policy.select_action( 138 | { 139 | "image": { 140 | "base_0_rgb": base_image, 141 | "left_wrist_0_rgb": wrist_image, 142 | }, 143 | "state": state, 144 | "prompt": prompt, 145 | } 146 | ) 147 | action = action[:, :, :6] 148 | action = action.float().cpu().numpy() 149 | action = self.normalizer.unnormalize({"action": action})["action"] 150 | return action 151 | 152 | def __call__(self, obs: np.ndarray): 153 | return self.act(obs) 154 | 155 | 156 | if __name__ == "__main__": 157 | policy = SO100Policy( 158 | ckpt_path="/mnt/20T/dzb/pi0fast_so100_checkpoints/epoch=39-step=29760.ckpt", 159 | pi0fast_ckpt_path="/home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_fast_base_pytorch", 160 | repo_ids=[f"ZibinDong/so100_play_screwdriver_0{i + 1}" for i in range(6)], 161 | device="cuda:0", 162 | dtype=torch.bfloat16, 163 | ) 164 | server = PolicyServer(policy, host="0.0.0.0", port=12346) 165 | server.run() 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simplified LeRobot's Pytorch PI0 & PI0-fast Implementation 2 | 3 | [The LeRobot team](https://github.com/huggingface/lerobot/tree/main) has made a substantial contribution to the community through their diligent efforts in converting the PI0 and PI0-fast VLA models to PyTorch. This was an impressive undertaking. However, the original release included only limited usage instructions and examples, making it challenging for users to get the models running correctly by simply following the provided guidance. 4 | 5 | This repository addresses those issues by introducing numerous fixes and removing redundant code and functionalities. Furthermore, it now includes comprehensive usage documentation, enabling users to seamlessly deploy official [OpenPI](https://github.com/Physical-Intelligence/openpi/tree/main?tab=readme-ov-file) checkpoints and fine-tune their own models with ease. 6 | 7 | ## 1. 
Installation 8 | 9 | If you only need to use the VLA models, you'll just need to install [LeRobot](https://github.com/huggingface/lerobot/tree/main) and [PyTorch](https://pytorch.org/). If you plan to run Libero's test scripts (not necessary for VLA), you'll also need to install [CleanDiffuser's Libero support](https://github.com/CleanDiffuserTeam/CleanDiffuser/tree/lightning/cleandiffuser/env/libero). 10 | 11 | --- 12 | 13 | ### 1.1 You need to create TWO environments. 14 | - The first one is for downloading the JAX model and converting it to a PyTorch one 15 | - The second one is for training and evaluating the model with PyTorch 16 | 17 | Now, let's begin the installation. It is fairly involved, so go through it step by step. 18 | 19 | ### 1.2 Installing the first environment 20 | 21 | **1) Create a virtual environment** 22 | 23 | You can use conda or mamba to create an environment: 24 | ``` 25 | mamba create -n pi0_jax python=3.11 26 | conda activate pi0_jax 27 | ``` 28 | 29 | Then, install `uv` in it: 30 | ``` 31 | mamba install uv 32 | ``` 33 | 34 | **2) Install OpenPI** 35 | 36 | First, download it: 37 | ``` 38 | git clone --recurse-submodules https://github.com/Physical-Intelligence/openpi.git 39 | ``` 40 | 41 | Then, use `uv` to install the package: 42 | ``` 43 | GIT_LFS_SKIP_SMUDGE=1 uv pip install . --system 44 | ``` 45 | 46 | Now, the first environment is fully installed. 47 | 48 | 49 | ### 1.3 Installing the second environment 50 | 51 | **1) Create a virtual environment** 52 | 53 | You can use conda or mamba to create an environment: 54 | ``` 55 | mamba create -n pi0_torch python=3.10 56 | conda activate pi0_torch 57 | ``` 58 | Then, install `uv`, `ipython` and `ffmpeg` in it: 59 | ``` 60 | mamba install ffmpeg=7.1.1 uv ipython 61 | ``` 62 | 63 | **2) Install the base packages** 64 | 65 | First, use `uv lock` to resolve the environment. 66 | 67 | Second, generate the `requirements.txt`: 68 | ``` 69 | uv export --format requirements-txt > requirements.txt 70 | ``` 71 | 72 | Finally, install all the packages: 73 | ``` 74 | uv pip install --system -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match 75 | ``` 76 | 77 | > This environment contains all the base packages used by `CleanDiffuser`, `lerobot`, `LIBERO` and `openpi_pytorch`. 78 | 79 | 80 | **3) Install `CleanDiffuser`** 81 | 82 | First, download it with `git`. Be careful: the `lightning` branch is the only branch that works here. 83 | 84 | ``` 85 | git clone -b lightning git@github.com:CleanDiffuserTeam/CleanDiffuser.git 86 | ``` 87 | 88 | Then, go into the `CleanDiffuser` folder: 89 | ``` 90 | cd CleanDiffuser 91 | ``` 92 | 93 | Then, edit the `pyproject.toml` and delete all the dependencies in the file: 94 | ``` 95 | dependencies = [ 96 | 97 | ] 98 | ``` 99 | 100 | > We already installed these dependencies in step 2; deleting them makes sure `pip` will not change the versions of the base packages. 101 | 102 | Now, install it: 103 | ``` 104 | pip install . 105 | ``` 106 | 107 | 108 | **4) Install `lerobot`** 109 | 110 | First, download it: 111 | 112 | ``` 113 | git clone https://github.com/huggingface/lerobot.git 114 | ``` 115 | 116 | Then, go into the folder, edit the `pyproject.toml`, and delete all the dependencies. 117 | 118 | Then, install it: 119 | ``` 120 | pip install .
121 | ``` 122 | 123 | 124 | **5) Install `LIBERO`** 125 | 126 | First, download it: 127 | 128 | ``` 129 | git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git 130 | ``` 131 | 132 | Then, go into the folder, edit `requirements.txt`, and delete all of its lines. 133 | 134 | Then, install it: 135 | ``` 136 | pip install . 137 | ``` 138 | 139 | Now, all the packages have been successfully installed. 140 | 141 | 142 | 143 | 144 | 145 | ## 2. Usage 146 | 147 | ### 2.1. Converting OpenPI Checkpoints 148 | 149 | > For this step, use the `pi0_jax` environment. 150 | 151 | You can directly use the checkpoints LeRobot has uploaded to [HuggingFace](https://huggingface.co/lerobot/pi0): 152 | 153 | ```python 154 | from pi0 import PI0Policy 155 | policy = PI0Policy.from_pretrained("lerobot/pi0") 156 | ``` 157 | 158 | LeRobot has only uploaded the `pi0_base` model. However, OpenPI provides a [**list of checkpoints**](https://github.com/Physical-Intelligence/openpi?tab=readme-ov-file#model-checkpoints) for inference or fine-tuning, so I highly recommend using the conversion script to **flexibly obtain various OpenPI checkpoints**. 159 | 160 | First, you'll need to install [OpenPI](https://github.com/Physical-Intelligence/openpi/tree/main?tab=readme-ov-file) and download an official JAX checkpoint. Let's take `pi0_libero` as an example: 161 | 162 | ```python 163 | from openpi.shared import download 164 | checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_libero", anon=True) 165 | ``` 166 | 167 | This will store the downloaded checkpoint in `"/home/username/.cache/openpi/openpi-assets/checkpoints/pi0_libero"` if you're using Ubuntu. Then, you can run the conversion script by simply providing the JAX checkpoint path and the desired PyTorch checkpoint path: 168 | 169 | ```bash 170 | python convert_pi0_to_hf_lerobot.py \ 171 | --checkpoint_dir /home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_libero/params \ 172 | --output_path /home/dzb/.cache/openpi/openpi-assets/checkpoints/pi0_libero_pytorch 173 | ``` 174 | 175 | **Note:** After completing this step, **do not delete the JAX checkpoint**. This folder contains crucial `norm_stats` parameters, which are essential if you plan to use the model for inference. 176 | 177 | ### 2.2. Try Inference Code 178 | 179 | > For this step, use the `pi0_torch` environment. 180 | 181 | Please see `1_e2e_inference.py`. 
-------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import h5py 6 | import libero 7 | import numpy as np 8 | import tqdm 9 | from lerobot.common.datasets.lerobot_dataset import LeRobotDataset 10 | from libero.libero import benchmark, get_libero_path 11 | from libero.libero.envs import OffScreenRenderEnv 12 | 13 | # --- Config --- 14 | 15 | argparser = argparse.ArgumentParser() 16 | argparser.add_argument("--image_size", type=int, default=224) 17 | argparser.add_argument("--benchmark", type=str, default="libero_goal") 18 | args = argparser.parse_args() 19 | 20 | IMAGE_SIZE = args.image_size 21 | LIBERO_PATH = Path(os.path.dirname(libero.libero.__file__)).parents[0] 22 | DATASET_PATH = LIBERO_PATH / "datasets" 23 | BENCHMARKS = [args.benchmark] 24 | 25 | # benchmark for suite 26 | benchmark_dict = benchmark.get_benchmark_dict() 27 | 28 | # Total number of tasks 29 | num_tasks = 0 30 | for bm in BENCHMARKS: 31 | benchmark_path = DATASET_PATH / bm 32 | num_tasks += len(list(benchmark_path.glob("*.hdf5"))) 33 | 34 | tasks_stored = 0 35 | for bm in BENCHMARKS: 36 | print(f"############################# {bm} #############################") 37 | benchmark_path = DATASET_PATH / bm 38 | 39 | # Init env benchmark suite 40 | task_suite = benchmark_dict[bm]() 41 | 42 | # Init lerobot dataset 43 | dataset = LeRobotDataset.create( 44 | repo_id=f"ZibinDong/{bm}", 45 | robot_type="panda", 46 | fps=20, 47 | features={ 48 | "base_0_rgb": { 49 | "dtype": "video", 50 | "shape": (224, 224, 3), 51 | "names": ["height", "width", "channel"], 52 | }, 53 | "right_wrist_0_rgb": { 54 | "dtype": "video", 55 | "shape": (224, 224, 3), 56 | "names": ["height", "width", "channel"], 57 | }, 58 | "state": { 59 | "dtype": "float32", 60 | "shape": (9,), 61 | "names": ["state"], 62 | }, 63 | "actions": { 64 | "dtype": "float32", 65 | "shape": (7,), 66 | "names": ["actions"], 67 | }, 68 | }, 69 | image_writer_threads=10, 70 | image_writer_processes=5, 71 | ) 72 | 73 | for task_file in benchmark_path.glob("*.hdf5"): 74 | print(f"Processing {tasks_stored + 1}/{num_tasks}: {task_file}") 75 | data = h5py.File(task_file, "r")["data"] 76 | 77 | # Init env 78 | task_name = str(task_file).split("/")[-1][:-10] 79 | # get task id from list of task names 80 | task_id = task_suite.get_task_names().index(task_name) 81 | # create environment 82 | task = task_suite.get_task(task_id) 83 | task_name = task.name 84 | task_bddl_file = os.path.join( 85 | get_libero_path("bddl_files"), task.problem_folder, task.bddl_file 86 | ) 87 | env_args = { 88 | "bddl_file_name": task_bddl_file, 89 | "camera_heights": IMAGE_SIZE, 90 | "camera_widths": IMAGE_SIZE, 91 | } 92 | env = OffScreenRenderEnv(**env_args) 93 | 94 | obs = env.reset() 95 | 96 | states = [] 97 | actions = [] 98 | rewards = [] 99 | episode_ends = [] 100 | 101 | for demo in tqdm.tqdm(data.keys()): 102 | print(f"Processing demo {demo}") 103 | demo_data = data[demo] 104 | 105 | colors, colors_ego = [], [] 106 | joint_states, eef_states, gripper_states = [], [], [] 107 | 108 | for i in range(len(demo_data["states"])): 109 | obs = env.regenerate_obs_from_state(demo_data["states"][i]) 110 | 111 | # get RGBD 112 | color = obs["agentview_image"][::-1] 113 | color_ego = obs["robot0_eye_in_hand_image"][::-1] 114 | eef_state = np.concatenate( 115 | [obs["robot0_eef_pos"], obs["robot0_eef_quat"]] 116 | ) 117 
| gripper_state = obs["robot0_gripper_qpos"] 118 | 119 | dataset.add_frame( 120 | { 121 | "base_0_rgb": color, 122 | "right_wrist_0_rgb": color_ego, 123 | "state": np.concatenate([gripper_state, eef_state]).astype(np.float32), 124 | "actions": demo_data["actions"][i].astype(np.float32), 125 | }, 126 | task=env.language_instruction, 127 | ) 128 | dataset.save_episode() 129 | 130 | print(f"{env.language_instruction}: Finish!") 131 | tasks_stored += 1 132 | 133 | # dataset = LeRobotDataset("ZibinDong/libero_goal") 134 | # dataset.push_to_hub(tags="libero", private=True) 135 | 136 | """ 137 | - data 138 | - demo_0 139 | - actions float64(n, 7) 140 | - dones uint8(n,) 141 | - rewards uint8(n,) 142 | - robot_states float64(n, 9) # (gripper_states(2), ee_pos(3), ee_quad(4)) 143 | - states float64(n, 79) 144 | - obs 145 | - agentview_rgb uint8(n, 128, 128, 3) 146 | - eye_in_hand_rgb uint8(n, 128, 128, 3) 147 | - ee_ori float64(n, 3) 148 | - ee_pos float64(n, 3) 149 | - ee_states float64(n, 6) 150 | - gripper_states float64(n, 2) 151 | - joint_states float64(n, 7) 152 | - demo_1 153 | - ... 154 | """ 155 | 156 | 157 | # REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub 158 | # RAW_DATASET_NAMES = [ 159 | # "libero_10_no_noops", 160 | # "libero_goal_no_noops", 161 | # "libero_object_no_noops", 162 | # "libero_spatial_no_noops", 163 | # ] # For simplicity we will combine multiple Libero datasets into one training dataset 164 | 165 | 166 | # def main(data_dir: str, *, push_to_hub: bool = False): 167 | # # Clean up any existing dataset in the output directory 168 | # output_path = LEROBOT_HOME / REPO_NAME 169 | # if output_path.exists(): 170 | # shutil.rmtree(output_path) 171 | 172 | # # Create LeRobot dataset, define features to store 173 | # # OpenPi assumes that proprio is stored in `state` and actions in `action` 174 | # # LeRobot assumes that dtype of image data is `image` 175 | # dataset = LeRobotDataset.create( 176 | # repo_id=REPO_NAME, 177 | # robot_type="panda", 178 | # fps=10, 179 | # features={ 180 | # "image": { 181 | # "dtype": "image", 182 | # "shape": (256, 256, 3), 183 | # "names": ["height", "width", "channel"], 184 | # }, 185 | # "wrist_image": { 186 | # "dtype": "image", 187 | # "shape": (256, 256, 3), 188 | # "names": ["height", "width", "channel"], 189 | # }, 190 | # "state": { 191 | # "dtype": "float32", 192 | # "shape": (8,), 193 | # "names": ["state"], 194 | # }, 195 | # "actions": { 196 | # "dtype": "float32", 197 | # "shape": (7,), 198 | # "names": ["actions"], 199 | # }, 200 | # }, 201 | # image_writer_threads=10, 202 | # image_writer_processes=5, 203 | # ) 204 | 205 | # # Loop over raw Libero datasets and write episodes to the LeRobot dataset 206 | # # You can modify this for your own data format 207 | # for raw_dataset_name in RAW_DATASET_NAMES: 208 | # raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train") 209 | # for episode in raw_dataset: 210 | # for step in episode["steps"].as_numpy_iterator(): 211 | # dataset.add_frame( 212 | # { 213 | # "image": step["observation"]["image"], 214 | # "wrist_image": step["observation"]["wrist_image"], 215 | # "state": step["observation"]["state"], 216 | # "actions": step["action"], 217 | # } 218 | # ) 219 | # dataset.save_episode(task=step["language_instruction"].decode()) 220 | 221 | # # Consolidate the dataset, skip computing stats since we will do that later 222 | # dataset.consolidate(run_compute_stats=False) 223 | 224 | # # Optionally push to the Hugging 
Face Hub 225 | # if push_to_hub: 226 | # dataset.push_to_hub( 227 | # tags=["libero", "panda", "rlds"], 228 | # private=False, 229 | # push_videos=True, 230 | # license="apache-2.0", 231 | # ) 232 | 233 | 234 | # if __name__ == "__main__": 235 | # tyro.cli(main) 236 | -------------------------------------------------------------------------------- /pi0/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import einops 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | 10 | # from xformers.ops import memory_efficient_attention 11 | 12 | 13 | def find_next_divisible_by_8_numpy(n: np.ndarray) -> np.ndarray: 14 | """ 15 | Finds the smallest integers greater than each element in a NumPy array 'n' 16 | that are divisible by 8. Assumes non-negative integers. 17 | 18 | Args: 19 | n: A NumPy array of integers. 20 | 21 | Returns: 22 | A NumPy array containing the smallest integers greater than each input element 23 | that are divisible by 8. 24 | """ 25 | remainder = n % 8 26 | # Calculate the amount to add: 0 if already divisible, otherwise 8 - remainder 27 | # np.where is efficient for conditional operations on arrays 28 | amount_to_add = np.where(remainder == 0, 8, 8 - remainder) 29 | return n + amount_to_add 30 | 31 | 32 | def create_sinusoidal_pos_embedding( 33 | time: torch.tensor, 34 | dimension: int, 35 | min_period: float, 36 | max_period: float, 37 | device="cpu", 38 | ) -> Tensor: 39 | """Computes sine-cosine positional embedding vectors for scalar positions.""" 40 | if dimension % 2 != 0: 41 | raise ValueError(f"dimension ({dimension}) must be divisible by 2") 42 | 43 | if time.ndim != 1: 44 | raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") 45 | 46 | fraction = torch.linspace( 47 | 0.0, 1.0, dimension // 2, dtype=torch.float32, device=device 48 | ) 49 | period = min_period * (max_period / min_period) ** fraction 50 | 51 | # Compute the outer product 52 | scaling_factor = 1.0 / period * 2 * math.pi 53 | sin_input = scaling_factor[None, :] * time[:, None] 54 | pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) 55 | return pos_emb 56 | 57 | 58 | def sample_beta(alpha, beta, bsize, device): 59 | gamma1 = torch.rand((bsize,), device=device).pow(1 / alpha) 60 | gamma2 = torch.rand((bsize,), device=device).pow(1 / beta) 61 | return gamma1 / (gamma1 + gamma2) 62 | 63 | 64 | def make_att_2d_masks(pad_masks, att_masks): 65 | """Copied from big_vision. 66 | 67 | Tokens can attend to valid inputs tokens which have a cumulative mask_ar 68 | smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to 69 | setup several types of attention, for example: 70 | 71 | [[1 1 1 1 1 1]]: pure causal attention. 72 | 73 | [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between 74 | themselves and the last 3 tokens have a causal attention. The first 75 | entry could also be a 1 without changing behaviour. 76 | 77 | [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a 78 | block can attend all previous blocks and all tokens on the same block. 79 | 80 | Args: 81 | input_mask: bool[B, N] true if its part of the input, false if padding. 82 | mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on 83 | it and 0 where it shares the same attention mask as the previous token. 
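Illustrative example (in this port, `input_mask` corresponds to `pad_masks` and `mask_ar` to `att_masks`): with pad_masks = [[1 1 1 1]] and att_masks = [[0 0 1 0]], the cumulative sums are [0 0 1 1], so the first two (prefix) tokens attend only to each other, while the last two tokens attend to all four.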
84 | """ 85 | if att_masks.ndim != 2: 86 | raise ValueError(att_masks.ndim) 87 | if pad_masks.ndim != 2: 88 | raise ValueError(pad_masks.ndim) 89 | 90 | cumsum = torch.cumsum(att_masks, dim=1) 91 | att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] 92 | pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] 93 | att_2d_masks = att_2d_masks & pad_2d_masks 94 | return att_2d_masks 95 | 96 | 97 | def resize_with_pad(img, width, height, pad_value=-1): 98 | # assume no-op when width height fits already 99 | if img.ndim != 4: 100 | raise ValueError(f"(b,c,h,w) expected, but {img.shape}") 101 | 102 | cur_height, cur_width = img.shape[2:] 103 | 104 | ratio = max(cur_width / width, cur_height / height) 105 | resized_height = int(cur_height / ratio) 106 | resized_width = int(cur_width / ratio) 107 | resized_img = F.interpolate( 108 | img, size=(resized_height, resized_width), mode="bilinear", align_corners=False 109 | ) 110 | 111 | pad_height = max(0, int(height - resized_height)) 112 | pad_width = max(0, int(width - resized_width)) 113 | 114 | # pad on left and top of image 115 | padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) 116 | return padded_img 117 | 118 | 119 | def eager_attention_forward( 120 | query_states: torch.Tensor, 121 | key_states: torch.Tensor, 122 | value_states: torch.Tensor, 123 | attention_mask: torch.Tensor, 124 | ): 125 | """ 126 | Performs eager attention, optimized with torch.einsum. 127 | 128 | Args: 129 | query_states: Query tensor of shape [batch_size, seq_len, num_attention_heads, head_dim]. 130 | key_states: Key tensor of shape [batch_size, seq_len, num_key_value_heads, head_dim]. 131 | value_states: Value tensor of shape [batch_size, seq_len, num_key_value_heads, head_dim]. 132 | attention_mask: Attention mask tensor, typically [batch_size, 1, seq_len, seq_len] or [batch_size, seq_len, seq_len]. 133 | 134 | Returns: 135 | Output tensor of shape [batch_size, seq_len, num_attention_heads * head_dim]. 
136 | """ 137 | bsize, seq_len, num_att_heads, head_dim = query_states.shape 138 | num_key_value_heads = key_states.shape[2] 139 | num_key_value_groups = num_att_heads // num_key_value_heads 140 | 141 | key_states = einops.repeat( 142 | key_states, "b l h d -> b l (h g) d", g=num_key_value_groups 143 | ) 144 | value_states = einops.repeat( 145 | value_states, "b l h d -> b l (h g) d", g=num_key_value_groups 146 | ) 147 | 148 | query_states_permuted = torch.einsum("blhd->bhld", query_states) 149 | key_states_permuted = torch.einsum("blhd->bhld", key_states) 150 | 151 | att_weights = torch.einsum( 152 | "bhqd,bhkd->bhqk", query_states_permuted, key_states_permuted 153 | ) 154 | att_weights *= head_dim**-0.5 155 | 156 | big_neg = -2.3819763e38 157 | masked_att_weights = torch.where( 158 | attention_mask[:, None, :, :], att_weights, big_neg 159 | ) 160 | 161 | probs = nn.functional.softmax(masked_att_weights, dim=-1) 162 | probs = probs.to(dtype=value_states.dtype) 163 | 164 | value_states_permuted = torch.einsum("blhd->bhld", value_states) # [B, H, L_v, D] 165 | att_output = torch.einsum( 166 | "bhqk,bhkv->bhqv", probs, value_states_permuted 167 | ) # [B, H, L_q, D] 168 | att_output = torch.einsum("bhld->blhd", att_output) # [B, L, H, D] 169 | att_output = att_output.reshape(bsize, seq_len, num_att_heads * head_dim) 170 | 171 | return att_output 172 | 173 | 174 | # def xformer_attention_forward(query_states, key_states, value_states, attention_mask): 175 | # bsize, seq_len, num_att_heads, head_dim = query_states.shape 176 | # num_key_value_heads = key_states.shape[2] 177 | # num_key_value_groups = num_att_heads // num_key_value_heads 178 | 179 | # query_states = einops.rearrange( 180 | # query_states, "b l (h g) d -> b l h g d", g=num_key_value_groups 181 | # ) 182 | # key_states = einops.repeat( 183 | # key_states, "b l h d -> b l h g d", g=num_key_value_groups 184 | # ) 185 | # value_states = einops.repeat( 186 | # value_states, "b l h d -> b l h g d", g=num_key_value_groups 187 | # ) 188 | # aligned_attention_mask = torch.zeros( 189 | # (bsize, seq_len, find_next_divisible_by_8_numpy(seq_len).item()), 190 | # dtype=query_states.dtype, 191 | # device=attention_mask.device, 192 | # ) 193 | # big_neg = -2.3819763e38 194 | # aligned_attention_mask[:, :, :seq_len] = ~attention_mask * big_neg 195 | # aligned_attention_mask = einops.repeat( 196 | # aligned_attention_mask, 197 | # "b l s -> b h g l s", 198 | # h=num_key_value_heads, 199 | # g=num_key_value_groups, 200 | # )[:, :, :, :, :seq_len] 201 | 202 | # att_output = memory_efficient_attention( 203 | # query=query_states, 204 | # key=key_states, 205 | # value=value_states, 206 | # attn_bias=aligned_attention_mask, 207 | # ) 208 | # att_output = att_output.reshape(bsize, seq_len, -1) 209 | 210 | # return att_output 211 | 212 | 213 | @torch.jit.script 214 | def apply_rope( 215 | x: torch.Tensor, 216 | positions: torch.Tensor, 217 | max_wavelength: float = 10_000.0, 218 | dtype: torch.dtype = torch.float32, 219 | ) -> torch.Tensor: 220 | """Applies RoPE positions [B, L] to x [B, L, H, D].""" 221 | original_dtype = x.dtype 222 | d = x.shape[-1] 223 | d_half = d // 2 224 | device = x.device 225 | 226 | # Cast input to compute_dtype for all internal operations 227 | x_casted = x.to(dtype) 228 | positions_casted = positions.to(dtype) 229 | 230 | freq_exponents = (2.0 / d) * torch.arange(d_half, dtype=dtype, device=device) 231 | timescale = max_wavelength**freq_exponents 232 | radians = torch.einsum("bl,h->blh", positions_casted, 1.0 / timescale) 233 
| 234 | radians = radians[..., None, :] # [B, L, 1, D_half] 235 | 236 | sin = torch.sin(radians) 237 | cos = torch.cos(radians) 238 | 239 | x1, x2 = x_casted.split(d_half, dim=-1) 240 | 241 | res = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) 242 | 243 | return res.to(original_dtype) 244 | -------------------------------------------------------------------------------- /pi0/paligemma_with_expert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Optional, Union 16 | 17 | import torch 18 | from pytest import Cache 19 | from transformers import ( 20 | AutoConfig, 21 | GemmaForCausalLM, 22 | PaliGemmaForConditionalGeneration, 23 | PretrainedConfig, 24 | PreTrainedModel, 25 | ) 26 | from transformers.models.auto import CONFIG_MAPPING 27 | 28 | from .utils import apply_rope, eager_attention_forward 29 | 30 | 31 | class PaliGemmaWithExpertConfig(PretrainedConfig): 32 | model_type = "PaliGemmaWithExpertModel" 33 | sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig} 34 | 35 | def __init__( 36 | self, 37 | paligemma_config: dict | None = None, 38 | gemma_expert_config: dict | None = None, 39 | freeze_vision_encoder: bool = True, 40 | train_expert_only: bool = True, 41 | attention_implementation: str = "eager", 42 | **kwargs, 43 | ): 44 | self.freeze_vision_encoder = freeze_vision_encoder 45 | self.train_expert_only = train_expert_only 46 | self.attention_implementation = attention_implementation 47 | 48 | if paligemma_config is None: 49 | # Default config from Pi0 50 | self.paligemma_config = CONFIG_MAPPING["paligemma"]( 51 | transformers_version="4.48.1", 52 | _vocab_size=257152, 53 | bos_token_id=2, 54 | eos_token_id=1, 55 | hidden_size=2048, 56 | image_token_index=257152, 57 | model_type="paligemma", 58 | pad_token_id=0, 59 | projection_dim=2048, 60 | text_config={ 61 | "hidden_activation": "gelu_pytorch_tanh", 62 | "hidden_size": 2048, 63 | "intermediate_size": 16384, 64 | "model_type": "gemma", 65 | "num_attention_heads": 8, 66 | "num_hidden_layers": 18, 67 | "num_image_tokens": 256, 68 | "num_key_value_heads": 1, 69 | "torch_dtype": "float32", 70 | "vocab_size": 257152, 71 | }, 72 | vision_config={ 73 | "hidden_size": 1152, 74 | "intermediate_size": 4304, 75 | "model_type": "siglip_vision_model", 76 | "num_attention_heads": 16, 77 | "num_hidden_layers": 27, 78 | "num_image_tokens": 256, 79 | "patch_size": 14, 80 | "projection_dim": 2048, 81 | "projector_hidden_act": "gelu_fast", 82 | "torch_dtype": "float32", 83 | "vision_use_head": False, 84 | }, 85 | ) 86 | elif isinstance(self.paligemma_config, dict): 87 | # Override Pi0 default config for PaliGemma 88 | if "model_type" not in gemma_expert_config: 89 | paligemma_config["model_type"] = "paligemma" 90 | 91 | cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] 92 | 
self.paligemma_config = cfg_cls(**paligemma_config) 93 | 94 | if gemma_expert_config is None: 95 | # Default config from Pi0 96 | self.gemma_expert_config = CONFIG_MAPPING["gemma"]( 97 | attention_bias=False, 98 | attention_dropout=0.0, 99 | bos_token_id=2, 100 | eos_token_id=1, 101 | head_dim=256, 102 | hidden_act="gelu_pytorch_tanh", 103 | hidden_activation="gelu_pytorch_tanh", 104 | hidden_size=1024, 105 | initializer_range=0.02, 106 | intermediate_size=4096, 107 | max_position_embeddings=8192, 108 | model_type="gemma", 109 | num_attention_heads=8, 110 | num_hidden_layers=18, 111 | num_key_value_heads=1, 112 | pad_token_id=0, 113 | rms_norm_eps=1e-06, 114 | rope_theta=10000.0, 115 | torch_dtype="float32", 116 | transformers_version="4.48.1", 117 | use_cache=True, 118 | vocab_size=257152, 119 | ) 120 | elif isinstance(self.gemma_expert_config, dict): 121 | # Override Pi0 default config for Gemma Expert 122 | if "model_type" not in gemma_expert_config: 123 | gemma_expert_config["model_type"] = "gemma" 124 | 125 | cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] 126 | self.gemma_expert_config = cfg_cls(**gemma_expert_config) 127 | 128 | super().__init__(**kwargs) 129 | 130 | def __post_init__(self): 131 | super().__post_init__() 132 | if self.train_expert_only and not self.freeze_vision_encoder: 133 | raise ValueError( 134 | "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible." 135 | ) 136 | 137 | if self.attention_implementation not in ["eager", "fa2", "flex"]: 138 | raise ValueError( 139 | f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." 140 | ) 141 | 142 | 143 | class PaliGemmaWithExpertModel(PreTrainedModel): 144 | config_class = PaliGemmaWithExpertConfig 145 | 146 | def __init__(self, config: PaliGemmaWithExpertConfig): 147 | super().__init__(config=config) 148 | self.config = config 149 | self.paligemma = PaliGemmaForConditionalGeneration( 150 | config=config.paligemma_config 151 | ) 152 | self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) 153 | # Remove unused embed_tokens 154 | self.gemma_expert.model.embed_tokens = None 155 | 156 | self.attention_interface = self.get_attention_interface() 157 | 158 | # self.to_bfloat16_like_physical_intelligence() 159 | self.set_requires_grad() 160 | 161 | def set_requires_grad(self): 162 | """sets the requires_grad attribute of the model parameters based on the configuration. 163 | If `freeze_vision_encoder` is True, the vision tower parameters are frozen. 164 | If `train_expert_only` is True, the entire PaliGemma model is frozen. 165 | """ 166 | if self.config.freeze_vision_encoder: 167 | self.paligemma.vision_tower.eval() 168 | for params in self.paligemma.vision_tower.parameters(): 169 | params.requires_grad = False 170 | 171 | if self.config.train_expert_only: 172 | self.paligemma.eval() 173 | for params in self.paligemma.parameters(): 174 | params.requires_grad = False 175 | 176 | def train(self, mode: bool = True): 177 | super().train(mode) 178 | if self.config.freeze_vision_encoder: 179 | self.paligemma.vision_tower.eval() 180 | if self.config.train_expert_only: 181 | self.paligemma.eval() 182 | 183 | def to_bfloat16_like_physical_intelligence(self): 184 | """casts the model to bfloat16. 
185 | 186 | Modules not casted to bfloat16: 187 | - paligemma.language_model.model.embed_tokens.weight 188 | - paligemma.language_model.model.norm.weight 189 | - gemma_expert.model.norm.weight 190 | - gemma_expert.lm_head.weight 191 | """ 192 | self.paligemma = self.paligemma.to(dtype=torch.bfloat16) 193 | 194 | params_to_change_dtype = [ 195 | "language_model.model.layers", 196 | "gemma_expert.model.layers", 197 | "vision_tower", 198 | "multi_modal", 199 | ] 200 | for name, param in self.named_parameters(): 201 | if any(selector in name for selector in params_to_change_dtype): 202 | param.data = param.data.to(dtype=torch.bfloat16) 203 | 204 | def embed_image(self, image: torch.Tensor): 205 | return self.paligemma.get_image_features(image) 206 | 207 | def embed_language_tokens(self, tokens: torch.Tensor): 208 | return self.paligemma.language_model.model.embed_tokens(tokens) 209 | 210 | def handle_kv_cache( 211 | self, 212 | key_states: torch.Tensor, 213 | value_states: torch.Tensor, 214 | layer_idx: int, 215 | past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, 216 | use_cache: Optional[bool] = None, 217 | fill_kv_cache: Optional[bool] = None, 218 | ): 219 | if use_cache: 220 | if past_key_values is None: 221 | past_key_values = {} 222 | 223 | if fill_kv_cache: 224 | past_key_values[layer_idx] = { 225 | "key_states": key_states, 226 | "value_states": value_states, 227 | } 228 | else: 229 | key_states = torch.cat( 230 | [past_key_values[layer_idx]["key_states"], key_states], dim=1 231 | ) 232 | value_states = torch.cat( 233 | [past_key_values[layer_idx]["value_states"], value_states], 234 | dim=1, 235 | ) 236 | return key_states, value_states, past_key_values 237 | 238 | def forward( 239 | self, 240 | attention_mask: Optional[torch.Tensor] = None, 241 | position_ids: Optional[torch.LongTensor] = None, 242 | past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, 243 | inputs_embeds: List[torch.FloatTensor] = None, 244 | use_cache: Optional[bool] = None, 245 | fill_kv_cache: Optional[bool] = None, 246 | ): 247 | """ 248 | Args: 249 | attention_mask (Optional[torch.Tensor], optional): 250 | Attention mask with shape (b, seq_len, seq_len). Defaults to None. 251 | position_ids (Optional[torch.LongTensor], optional): 252 | Position indices for applying RoPE. Defaults to None. 253 | past_key_values (Optional[Union[List[torch.FloatTensor], Cache]], optional): 254 | Optional kv cache. Defaults to None. 255 | inputs_embeds (List[torch.FloatTensor], optional): 256 | Input embeddings. Defaults to None. 257 | use_cache (Optional[bool], optional): 258 | Whether to use kv cache. Defaults to None. 259 | fill_kv_cache (Optional[bool], optional): 260 | Whether to return kv tensors in this forward pass as cache. Defaults to None. 261 | 262 | Returns: 263 | outputs_embeds (torch.Tensor): Output embeddings. 264 | past_key_values (Optional[Union[List[torch.FloatTensor], Cache]]): 265 | Optional kv cache. 
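Note: `inputs_embeds` is a two-element list [prefix embeddings for the PaliGemma language model, suffix embeddings for the Gemma action expert]; either entry may be None to skip that stream (e.g. [prefix_embs, None] when pre-filling the kv cache, [None, suffix_embs] during the denoising steps).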
266 | """ 267 | models = [self.paligemma.language_model.model, self.gemma_expert.model] 268 | 269 | # RMSNorm 270 | num_layers = self.paligemma.config.text_config.num_hidden_layers 271 | for layer_idx in range(num_layers): 272 | query_states = [] 273 | key_states = [] 274 | value_states = [] 275 | for i, hidden_states in enumerate(inputs_embeds): 276 | if hidden_states is None: 277 | continue 278 | 279 | layer = models[i].layers[layer_idx] 280 | hidden_states = layer.input_layernorm(hidden_states) 281 | hidden_shape = (*hidden_states.shape[:-1], -1, layer.self_attn.head_dim) 282 | 283 | query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) 284 | key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) 285 | value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) 286 | 287 | query_states.append(query_state) 288 | key_states.append(key_state) 289 | value_states.append(value_state) 290 | 291 | # B,L,H,D with L sequence length, H number of heads, D head dim 292 | # concatenate on the number of embeddings/tokens 293 | query_states = torch.cat(query_states, dim=1) 294 | key_states = torch.cat(key_states, dim=1) 295 | value_states = torch.cat(value_states, dim=1) 296 | 297 | query_states = apply_rope(query_states, position_ids) 298 | key_states = apply_rope(key_states, position_ids) 299 | 300 | key_states, value_states, past_key_values = self.handle_kv_cache( 301 | key_states, 302 | value_states, 303 | layer_idx, 304 | past_key_values=past_key_values, 305 | use_cache=use_cache, 306 | fill_kv_cache=fill_kv_cache, 307 | ) 308 | 309 | att_output = self.attention_interface( 310 | query_states, key_states, value_states, attention_mask 311 | ) 312 | 313 | # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) 314 | outputs_embeds = [] 315 | start = 0 316 | for i, hidden_states in enumerate(inputs_embeds): 317 | layer = models[i].layers[layer_idx] 318 | 319 | if hidden_states is not None: 320 | end = start + hidden_states.shape[1] 321 | 322 | if att_output.dtype != layer.self_attn.o_proj.weight.dtype: 323 | att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) 324 | out_emb = layer.self_attn.o_proj(att_output[:, start:end]) 325 | 326 | # first residual 327 | out_emb += hidden_states 328 | after_first_residual = out_emb.clone() 329 | 330 | out_emb = layer.post_attention_layernorm(out_emb) 331 | out_emb = layer.mlp(out_emb) 332 | 333 | # second residual 334 | out_emb += after_first_residual 335 | outputs_embeds.append(out_emb) 336 | 337 | start = end 338 | else: 339 | outputs_embeds.append(None) 340 | 341 | inputs_embeds = outputs_embeds 342 | 343 | # final norm 344 | outputs_embeds = [] 345 | for i, hidden_states in enumerate(inputs_embeds): 346 | if hidden_states is not None: 347 | out_emb = models[i].norm(hidden_states) 348 | outputs_embeds.append(out_emb) 349 | else: 350 | outputs_embeds.append(None) 351 | 352 | return outputs_embeds, past_key_values 353 | 354 | def get_attention_interface(self): 355 | if self.config.attention_implementation == "fa2": 356 | raise NotImplementedError("FA2 is not implemented (yet)") 357 | elif self.config.attention_implementation == "flex": 358 | # attention_interface = flex_attention_forward 359 | raise NotImplementedError("Flex attention is not implemented (yet)") 360 | elif self.config.attention_implementation == "eager": 361 | attention_interface = eager_attention_forward 362 | elif self.config.attention_implementation == "xformer": 363 | # attention_interface = 
xformer_attention_forward 364 | raise NotImplementedError("Xformer attention is not implemented (yet)") 365 | else: 366 | raise ValueError( 367 | f"Invalid attention implementation: {self.config.attention_implementation}. " 368 | "Expected one of ['fa2', 'flex', 'eager', 'xformer']." 369 | ) 370 | return attention_interface 371 | -------------------------------------------------------------------------------- /pi0/modeling_pi0.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | # In pi0_jax envrionment 6 | from lerobot.common.policies.pi0.configuration_pi0 import PI0Config 7 | from lerobot.common.policies.pretrained import PreTrainedPolicy 8 | # In pi0_torch envrionment 9 | # from lerobot.policies.pi0.configuration_pi0 import PI0Config 10 | # from lerobot.policies.pretrained import PreTrainedPolicy 11 | from torch import Tensor, nn 12 | from transformers import AutoTokenizer 13 | 14 | from .paligemma_with_expert import PaliGemmaWithExpertConfig, PaliGemmaWithExpertModel 15 | from .utils import ( 16 | create_sinusoidal_pos_embedding, 17 | make_att_2d_masks, 18 | resize_with_pad, 19 | sample_beta, 20 | ) 21 | 22 | IMAGE_KEYS = ( 23 | "base_0_rgb", 24 | "left_wrist_0_rgb", 25 | "right_wrist_0_rgb", 26 | ) 27 | 28 | 29 | class PI0Policy(PreTrainedPolicy): 30 | config_class = PI0Config 31 | name = "torch_pi0" 32 | 33 | def __init__( 34 | self, 35 | config: PI0Config, 36 | tokenizer_path: str = "google/paligemma-3b-pt-224", 37 | ): 38 | """ 39 | Args: 40 | config: Policy configuration class instance or None, in which case the default instantiation of 41 | the configuration class is used. 42 | """ 43 | 44 | super().__init__(config) 45 | self.config = config 46 | self.language_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 47 | self.model = PI0FlowMatching(config) 48 | self.reset() 49 | 50 | def reset(self): 51 | return None 52 | 53 | def get_optim_params(self) -> dict: 54 | return self.parameters() 55 | 56 | @torch.no_grad 57 | def select_action( 58 | self, observation: dict[str, Tensor], noise: Tensor | None = None 59 | ): 60 | """ 61 | Observation: { 62 | "image": { 63 | "base_0_rgb": (*b, c, h, w), # uint8 [0, 255] 64 | ... 65 | }, 66 | "state": float32 [*b, s], 67 | "prompt": List[str], 68 | 69 | "lang_tokens": float32 [*b, l], 70 | "lang_masks": float32 [*b, l], 71 | } 72 | either provide `prompt` or (`lang_tokens`, `lang_masks`). 73 | """ 74 | self.eval() 75 | 76 | images, img_masks = self.prepare_images(observation) 77 | state = self.prepare_state(observation) 78 | lang_tokens, lang_masks = self.prepare_language(observation) 79 | actions = self.model.sample_actions( 80 | images, img_masks, lang_tokens, lang_masks, state, noise=noise 81 | ) 82 | return actions 83 | 84 | def forward( 85 | self, batch: dict[str, Tensor], noise=None, time=None 86 | ) -> tuple[Tensor, dict[str, Tensor]]: 87 | """Do a full training forward pass to compute the loss 88 | 89 | batch: { 90 | "image": { 91 | "base_0_rgb": (*b, c, h, w), # uint8 [0, 255] 92 | ... 
93 | }, 94 | "state": float32 [*b, s], 95 | "lang_tokens": float32 [*b, l], 96 | "lang_masks": float32 [*b, l], 97 | "action": float32 [*b, ha, da] 98 | } 99 | """ 100 | images, img_masks = self.prepare_images(batch) 101 | state = self.prepare_state(batch) 102 | lang_tokens, lang_masks = self.prepare_language(batch) 103 | actions, action_dim = self.prepare_action(batch) 104 | noise = batch.get("noise", None) 105 | time = batch.get("time", None) 106 | 107 | loss_dict = {} 108 | losses = self.model.forward( 109 | images, img_masks, lang_tokens, lang_masks, state, actions, noise, time 110 | ) 111 | 112 | actions_is_pad = batch.get("action_is_pad", None) 113 | if actions_is_pad is not None: 114 | in_episode_bound = ~actions_is_pad 115 | losses = losses * in_episode_bound.unsqueeze(-1) 116 | loss_dict["losses_after_in_ep_bound"] = losses.clone() 117 | 118 | # Remove padding 119 | losses = losses[:, :, :action_dim] 120 | loss_dict["losses"] = losses.clone() 121 | 122 | # For backward pass 123 | loss = losses.mean() 124 | # For logging 125 | loss_dict["l2_loss"] = loss.item() 126 | 127 | return loss, loss_dict 128 | 129 | def prepare_images(self, observation: dict[str, Tensor]): 130 | """Normalize, resize, and pad images and stack them into a tensor. 131 | 132 | Args: 133 | observation (dict[str, Tensor]) 134 | 135 | Returns: 136 | images (torch.Tensor): (*b, n, c, h, w) images in range [-1.0, 1.0] 137 | img_masks (torch.Tensor): (*b, n) masks for images, True if image is present, False if missing 138 | """ 139 | dtype = observation["state"].dtype 140 | bsize = observation["state"].shape[0] 141 | images, img_masks = [], [] 142 | present_img_keys = [key for key in IMAGE_KEYS if key in observation["image"]] 143 | missing_img_keys = [key for key in IMAGE_KEYS if key not in present_img_keys] 144 | 145 | for key in present_img_keys: 146 | # resize, pad, and normalize 147 | img = observation["image"][key] 148 | img = img.to(dtype) / 127.5 - 1.0 149 | img = resize_with_pad( 150 | img, *self.config.resize_imgs_with_padding, pad_value=-1.0 151 | ) 152 | images.append(img) 153 | img_masks.append(torch.ones((bsize,), dtype=torch.bool, device=img.device)) 154 | 155 | for key in missing_img_keys: 156 | # zero padding 157 | img = torch.full_like(img, fill_value=-1.0) 158 | images.append(img) 159 | img_masks.append(torch.zeros((bsize,), dtype=torch.bool, device=img.device)) 160 | 161 | images = torch.stack(images, dim=1) # (*b, n, c, h, w) 162 | img_masks = torch.stack(img_masks, dim=1) # (*b, n) 163 | 164 | return images, img_masks 165 | 166 | def prepare_state(self, observation: dict[str, Tensor]): 167 | """Pad the state to the maximum state dimension. 168 | 169 | Args: 170 | observation (dict[str, Tensor]) 171 | 172 | Returns: 173 | state (torch.Tensor): (*b, max_state_dim) padded state tensor 174 | """ 175 | state = observation["state"] 176 | state = F.pad(state, (0, self.config.max_state_dim - state.shape[1])) 177 | return state 178 | 179 | def prepare_action(self, observation: dict[str, Tensor]): 180 | """Pad the action to the maximum action dimension. 
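For example, a 7-dimensional action chunk of shape (*b, n, 7) is zero-padded on the last axis up to `config.max_action_dim` (32 in the default π0 configuration), and the original width (7) is returned as `action_dim` so that the loss can later be restricted to the real action dimensions.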
181 | 182 | Args: 183 | observation (dict[str, Tensor]) 184 | 185 | Returns: 186 | action (torch.Tensor): (*b, n, max_action_dim) padded action tensor 187 | action_dim (int): the actual dimension of the action before padding 188 | """ 189 | action = observation["action"] 190 | action_dim = action.shape[-1] 191 | action = F.pad(action, (0, self.config.max_action_dim - action_dim)) 192 | return action, action_dim 193 | 194 | def prepare_language(self, observation: dict[str, Tensor]): 195 | """If `prompt` is provided, modify it to PaliGemma format and tokenize it. 196 | If `lang_tokens` and `lang_masks` are provided, use them directly. 197 | 198 | PaliGemma expects prefix prompts to be formatted as: 199 | .... prompt , where uses `\\n`. 200 | So here we format the prompt to start with `` and end with `\\n`. 201 | Later, we will concatenate the images and language tokens into a single sequence. 202 | 203 | Args: 204 | observation (dict[str, Tensor]) 205 | 206 | Returns: 207 | lang_tokens (torch.Tensor): (*b, l) language tokens 208 | lang_masks (torch.Tensor): (*b, l) masks for language tokens, True if token is present, False if missing 209 | """ 210 | lang_tokens = observation.get("lang_tokens", None) 211 | lang_masks = observation.get("lang_masks", None) 212 | prompt = observation.get("prompt", None) 213 | 214 | # either provide `prompt` or (`lang_tokens`, `lang_masks`) 215 | if prompt is None and (lang_tokens is None or lang_masks is None): 216 | raise ValueError( 217 | "Either 'prompt' or ('lang_tokens', 'lang_masks') must be provided in the observation." 218 | ) 219 | 220 | device = observation["state"].device 221 | if prompt is not None and (lang_tokens is None or lang_masks is None): 222 | prompt = [p if p.startswith("") else f"{p}" for p in prompt] 223 | prompt = [p if p.endswith("\n") else f"{p}\n" for p in prompt] 224 | tokenized_prompt = self.language_tokenizer.__call__( 225 | prompt, 226 | padding="max_length", 227 | padding_side="right", 228 | max_length=self.config.tokenizer_max_length, 229 | return_tensors="pt", 230 | ) 231 | lang_tokens = tokenized_prompt["input_ids"].to(device=device) 232 | lang_masks = tokenized_prompt["attention_mask"].to( 233 | device=device, dtype=torch.bool 234 | ) 235 | else: 236 | lang_tokens = observation["lang_tokens"].to(device=device) 237 | lang_masks = observation["lang_masks"].to(device=device, dtype=torch.bool) 238 | 239 | return lang_tokens, lang_masks 240 | 241 | 242 | class PI0FlowMatching(nn.Module): 243 | """ 244 | π0: A Vision-Language-Action Flow Model for General Robot Control 245 | 246 | [Paper](https://www.physicalintelligence.company/download/pi0.pdf) 247 | [Jax code](https://github.com/Physical-Intelligence/openpi) 248 | 249 | Designed by Physical Intelligence. Ported from Jax by Hugging Face. 
250 | ┌──────────────────────────────┐ 251 | │ actions │ 252 | │ ▲ │ 253 | │ ┌┴─────┐ │ 254 | │ kv cache │Gemma │ │ 255 | │ ┌──────────►│Expert│ │ 256 | │ │ │ │ │ 257 | │ ┌┴────────┐ │x 10 │ │ 258 | │ │ │ └▲──▲──┘ │ 259 | │ │PaliGemma│ │ │ │ 260 | │ │ │ │ robot state │ 261 | │ │ │ noise │ 262 | │ └▲──▲─────┘ │ 263 | │ │ │ │ 264 | │ │ image(s) │ 265 | │ language tokens │ 266 | └──────────────────────────────┘ 267 | 268 | """ 269 | 270 | def __init__(self, config): 271 | super().__init__() 272 | self.config = config 273 | 274 | # paligemma with action expert 275 | paligemma_with_export_config = PaliGemmaWithExpertConfig( 276 | freeze_vision_encoder=self.config.freeze_vision_encoder, 277 | train_expert_only=self.config.train_expert_only, 278 | attention_implementation=self.config.attention_implementation, 279 | ) 280 | self.paligemma_with_expert = PaliGemmaWithExpertModel( 281 | paligemma_with_export_config 282 | ) 283 | 284 | # projection layers 285 | self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width) 286 | self.action_in_proj = nn.Linear( 287 | self.config.max_action_dim, self.config.proj_width 288 | ) 289 | self.action_out_proj = nn.Linear( 290 | self.config.proj_width, self.config.max_action_dim 291 | ) 292 | 293 | self.action_time_mlp_in = nn.Linear( 294 | self.config.proj_width * 2, self.config.proj_width 295 | ) 296 | self.action_time_mlp_out = nn.Linear( 297 | self.config.proj_width, self.config.proj_width 298 | ) 299 | 300 | self.set_requires_grad() 301 | 302 | def set_requires_grad(self): 303 | for params in self.state_proj.parameters(): 304 | params.requires_grad = self.config.train_state_proj 305 | 306 | def sample_time(self, bsize, device): 307 | time_beta = sample_beta(1.5, 1.0, bsize, device) 308 | time = time_beta * 0.999 + 0.001 309 | return time.to(dtype=torch.float32, device=device) 310 | 311 | def embed_prefix( 312 | self, images, img_masks, lang_tokens, lang_masks 313 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 314 | """Embed images with SigLIP and language tokens with embedding layer to prepare 315 | for PaliGemma transformer processing. 316 | 317 | Args: 318 | images (torch.Tensor): float (*b, n, c, h, w) images in range [-1.0, 1.0] 319 | img_masks (torch.Tensor): bool (*b, n) masks for images 320 | lang_tokens (torch.Tensor): int (*b, l) language tokens 321 | lang_masks (torch.Tensor): bool (*b, l) masks for language tokens 322 | """ 323 | bsize = images.shape[0] 324 | device = images.device 325 | dtype = images.dtype 326 | 327 | # embed image 328 | images = einops.rearrange(images, "b n c h w -> (b n) c h w") 329 | img_emb = self.paligemma_with_expert.embed_image(images) 330 | num_patch = img_emb.shape[1] 331 | img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize) 332 | img_emb = img_emb.to(dtype=dtype) * (img_emb.shape[-1] ** 0.5) 333 | num_img_embs = img_emb.shape[1] 334 | img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch) 335 | 336 | # embed language 337 | lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) 338 | num_lang_embs = lang_emb.shape[1] 339 | lang_emb = lang_emb.to(dtype=dtype) * np.sqrt(lang_emb.shape[-1]) 340 | 341 | # assemble embeddings 342 | embs = torch.cat([img_emb, lang_emb], dim=1) 343 | pad_masks = torch.cat([img_masks, lang_masks], dim=1) 344 | 345 | # PaliGemma uses bidirectional attention for prefix tokens, 346 | # so we set 1D `att_masks` to zeros. 
347 | # (see `make_att_2d_masks` to understand why zeros means bidirection) 348 | att_masks = torch.zeros( 349 | (bsize, num_img_embs + num_lang_embs), device=device, dtype=torch.bool 350 | ) 351 | return embs, pad_masks, att_masks 352 | 353 | def embed_suffix(self, state, noisy_actions, timestep): 354 | """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing. 355 | 356 | Args: 357 | state (torch.Tensor): float32 (*b, s) robot state 358 | noisy_actions (torch.Tensor): float32 (*b, n, m) noisy actions 359 | timestep (torch.Tensor): float32 (*b,) timestep in [0, 1] range 360 | """ 361 | bsize = state.shape[0] 362 | device = state.device 363 | dtype = state.dtype 364 | 365 | # embed state 366 | state_emb = self.state_proj(state) 367 | 368 | # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] 369 | time_emb = create_sinusoidal_pos_embedding( 370 | timestep, 371 | self.config.proj_width, 372 | min_period=4e-3, 373 | max_period=4.0, 374 | device=device, 375 | ) 376 | time_emb = time_emb.type(dtype=dtype) 377 | 378 | # Fuse timestep + action information using an MLP 379 | action_emb = self.action_in_proj(noisy_actions) 380 | time_emb = einops.repeat(time_emb, "b d -> b n d", n=action_emb.shape[1]) 381 | action_time_emb = torch.cat([action_emb, time_emb], dim=-1) 382 | 383 | action_time_emb = self.action_time_mlp_in(action_time_emb) 384 | action_time_emb = F.silu(action_time_emb) # swish == silu 385 | action_time_emb = self.action_time_mlp_out(action_time_emb) 386 | action_time_dim = action_time_emb.shape[1] 387 | 388 | # Add to input tokens 389 | embs = torch.cat([state_emb[:, None], action_time_emb], dim=1) 390 | pad_masks = torch.ones( 391 | (bsize, action_time_dim + 1), device=device, dtype=torch.bool 392 | ) 393 | 394 | # Set attention masks for suffix tokens so that prefix tokens cannot attend to suffix tokens. 395 | # And state token cannot attend action tokens. 396 | # Action tokens use a bidirectional attention. 
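# Illustrative example (hypothetical sizes): with 1 state token and 3 action tokens the 1D pattern below is [1, 1, 0, 0]; under `make_att_2d_masks`, the state token opens a new block after the prefix, the first action token opens another, and the remaining action tokens share that block (i.e. attend bidirectionally among themselves).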
397 | att_masks = torch.zeros( 398 | (bsize, action_time_dim + 1), device=device, dtype=torch.bool 399 | ) 400 | att_masks[:, :2] = True 401 | 402 | return embs, pad_masks, att_masks 403 | 404 | def forward( 405 | self, 406 | images, 407 | img_masks, 408 | lang_tokens, 409 | lang_masks, 410 | state, 411 | actions, 412 | noise=None, 413 | time=None, 414 | ) -> Tensor: 415 | bsize = state.shape[0] 416 | dtype = state.dtype 417 | device = state.device 418 | 419 | """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" 420 | if noise is None: 421 | actions_shape = ( 422 | bsize, 423 | self.config.n_action_steps, 424 | self.config.max_action_dim, 425 | ) 426 | noise = torch.randn(actions_shape, device=device, dtype=dtype) 427 | 428 | if time is None: 429 | time = self.sample_time(bsize, device).to(dtype) 430 | 431 | time_expanded = time[:, None, None] 432 | x_t = time_expanded * noise + (1 - time_expanded) * actions 433 | u_t = noise - actions 434 | 435 | prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( 436 | images, img_masks, lang_tokens, lang_masks 437 | ) 438 | suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( 439 | state, x_t, time 440 | ) 441 | 442 | pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) 443 | att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) 444 | 445 | att_2d_masks = make_att_2d_masks(pad_masks, att_masks) 446 | position_ids = torch.cumsum(pad_masks, dim=1) - 1 447 | 448 | (_, suffix_out), _ = self.paligemma_with_expert.forward( 449 | attention_mask=att_2d_masks, 450 | position_ids=position_ids, 451 | past_key_values=None, 452 | inputs_embeds=[prefix_embs, suffix_embs], 453 | use_cache=False, 454 | fill_kv_cache=False, 455 | ) 456 | suffix_out = suffix_out[:, -self.config.n_action_steps :] 457 | v_t = self.action_out_proj(suffix_out) 458 | losses = F.mse_loss(u_t, v_t, reduction="none") 459 | return losses 460 | 461 | def sample_actions( 462 | self, images, img_masks, lang_tokens, lang_masks, state, noise=None 463 | ) -> Tensor: 464 | """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" 465 | bsize = state.shape[0] 466 | device = state.device 467 | dtype = state.dtype 468 | 469 | if noise is None: 470 | actions_shape = ( 471 | bsize, 472 | self.config.n_action_steps, 473 | self.config.max_action_dim, 474 | ) 475 | noise = torch.randn(actions_shape, device=device, dtype=dtype) 476 | 477 | prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( 478 | images, img_masks, lang_tokens, lang_masks 479 | ) 480 | prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) 481 | prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 482 | 483 | # Compute image and language key value cache 484 | _, past_key_values = self.paligemma_with_expert.forward( 485 | attention_mask=prefix_att_2d_masks, 486 | position_ids=prefix_position_ids, 487 | past_key_values=None, 488 | inputs_embeds=[prefix_embs, None], 489 | use_cache=self.config.use_cache, 490 | fill_kv_cache=True, 491 | ) 492 | 493 | dt = torch.tensor(-1.0 / self.config.num_steps, dtype=dtype, device=device) 494 | x_t = noise 495 | time = torch.tensor(1.0, dtype=dtype, device=device) 496 | while time >= -dt / 2: 497 | expanded_time = time.expand(bsize) 498 | 499 | v_t = self.predict_velocity( 500 | state, prefix_pad_masks, past_key_values, x_t, expanded_time 501 | ) 502 | 503 | # Euler step 504 | x_t += dt * v_t 505 | time += dt 506 | 507 | 
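# Since dt is negative (-1/num_steps), `time` has been stepped from 1.0 down to ~0.0, so x_t, which started as pure noise at t=1, now approximates the denoised action chunk at t=0 (cf. `forward`, where x_t = time * noise + (1 - time) * actions).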
return x_t 508 | 509 | def predict_velocity(self, state, prefix_pad_masks, past_key_values, x_t, timestep): 510 | """predict velocity at time t using the suffix model.""" 511 | suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( 512 | state, x_t, timestep 513 | ) 514 | 515 | suffix_len = suffix_pad_masks.shape[1] 516 | batch_size = prefix_pad_masks.shape[0] 517 | prefix_len = prefix_pad_masks.shape[1] 518 | prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand( 519 | batch_size, suffix_len, prefix_len 520 | ) 521 | 522 | suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) 523 | 524 | full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) 525 | 526 | prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] 527 | position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 528 | 529 | outputs_embeds, _ = self.paligemma_with_expert.forward( 530 | attention_mask=full_att_2d_masks, 531 | position_ids=position_ids, 532 | past_key_values=past_key_values, 533 | inputs_embeds=[None, suffix_embs], 534 | use_cache=self.config.use_cache, 535 | fill_kv_cache=False, 536 | ) 537 | suffix_out = outputs_embeds[1] 538 | suffix_out = suffix_out[:, -self.config.n_action_steps :] 539 | v_t = self.action_out_proj(suffix_out) 540 | return v_t 541 | -------------------------------------------------------------------------------- /conversion_scripts/new_modeling_pi0.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import einops 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from lerobot.common.policies.pretrained import PreTrainedPolicy 8 | from torch import Tensor, nn 9 | from transformers import AutoTokenizer 10 | 11 | from .configuration_pi0 import TorchPI0Config 12 | from .paligemma_with_expert import PaliGemmaWithExpertConfig, PaliGemmaWithExpertModel 13 | from .utils import ( 14 | create_sinusoidal_pos_embedding, 15 | make_att_2d_masks, 16 | resize_with_pad, 17 | sample_beta, 18 | ) 19 | 20 | IMAGE_KEYS = ( 21 | "base_0_rgb", 22 | "left_wrist_0_rgb", 23 | "right_wrist_0_rgb", 24 | ) 25 | 26 | 27 | class PI0Policy(PreTrainedPolicy): 28 | """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot.""" 29 | 30 | config_class = TorchPI0Config 31 | name = "torch_pi0" 32 | 33 | def __init__( 34 | self, 35 | config: TorchPI0Config, 36 | tokenizer_path: str = "/home/dzb/pretrained/paligemma3b", 37 | ): 38 | """ 39 | Args: 40 | config: Policy configuration class instance or None, in which case the default instantiation of 41 | the configuration class is used. 42 | """ 43 | 44 | super().__init__(config) 45 | self.config = config 46 | self.language_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 47 | self.model = PI0FlowMatching(config) 48 | self.reset() 49 | 50 | def reset(self): 51 | """This should be called whenever the environment is reset.""" 52 | # self._action_queue = deque([], maxlen=self.config.n_action_steps) 53 | pass 54 | 55 | def get_optim_params(self) -> dict: 56 | return self.parameters() 57 | 58 | @torch.no_grad 59 | def select_action( 60 | self, observation: dict[str, Tensor], noise: Tensor | None = None 61 | ): 62 | """ 63 | Observation: { 64 | "image": { 65 | "base_0_rgb": (*b, c, h, w), # uint8 [0, 255] 66 | ... 
67 | }, 68 | "state": float32 [*b, s], 69 | "prompt": List[str], 70 | 71 | "lang_tokens": float32 [*b, l], 72 | "lang_masks": float32 [*b, l], 73 | } 74 | """ 75 | self.eval() 76 | 77 | images, img_masks = self.prepare_images(observation) 78 | state = self.prepare_state(observation) 79 | lang_tokens, lang_masks = self.prepare_language(observation) 80 | actions = self.model.sample_actions( 81 | images, img_masks, lang_tokens, lang_masks, state, noise=noise 82 | ) 83 | return actions 84 | 85 | def forward( 86 | self, batch: dict[str, Tensor], noise=None, time=None 87 | ) -> tuple[Tensor, dict[str, Tensor]]: 88 | """Do a full training forward pass to compute the loss 89 | 90 | batch: { 91 | "image": { 92 | "base_0_rgb": (*b, c, h, w), # uint8 [0, 255] 93 | ... 94 | }, 95 | "state": float32 [*b, s], 96 | "lang_tokens": float32 [*b, l], 97 | "lang_masks": float32 [*b, l], 98 | "action": float32 [*b, ha, da] 99 | } 100 | """ 101 | images, img_masks = self.prepare_images(batch) 102 | state = self.prepare_state(batch) 103 | lang_tokens, lang_masks = self.prepare_language(batch) 104 | actions, action_dim = self.prepare_action(batch) 105 | noise = batch.get("noise", None) 106 | time = batch.get("time", None) 107 | 108 | loss_dict = {} 109 | losses = self.model.forward( 110 | images, img_masks, lang_tokens, lang_masks, state, actions, noise, time 111 | ) 112 | 113 | actions_is_pad = batch.get("action_is_pad", None) 114 | if actions_is_pad is not None: 115 | in_episode_bound = ~actions_is_pad 116 | losses = losses * in_episode_bound.unsqueeze(-1) 117 | loss_dict["losses_after_in_ep_bound"] = losses.clone() 118 | 119 | # Remove padding 120 | losses = losses[:, :, :action_dim] 121 | loss_dict["losses"] = losses.clone() 122 | 123 | # For backward pass 124 | loss = losses.mean() 125 | # For logging 126 | loss_dict["l2_loss"] = loss.item() 127 | 128 | return loss, loss_dict 129 | 130 | def prepare_images(self, observation: dict[str, Tensor]): 131 | """Normalize, resize, and pad images and stack them into a tensor. 132 | 133 | Args: 134 | observation (dict[str, Tensor]) 135 | 136 | Returns: 137 | images (torch.Tensor): (*b, n, c, h, w) images in range [-1.0, 1.0] 138 | img_masks (torch.Tensor): (*b, n) masks for images, True if image is present, False if missing 139 | """ 140 | dtype = observation["state"].dtype 141 | bsize = observation["state"].shape[0] 142 | images, img_masks = [], [] 143 | present_img_keys = [key for key in IMAGE_KEYS if key in observation["image"]] 144 | missing_img_keys = [key for key in IMAGE_KEYS if key not in present_img_keys] 145 | 146 | for key in present_img_keys: 147 | # resize, pad, and normalize 148 | img = observation["image"][key] 149 | img = img.to(dtype) / 127.5 - 1.0 150 | img = resize_with_pad( 151 | img, *self.config.resize_imgs_with_padding, pad_value=-1.0 152 | ) 153 | images.append(img) 154 | img_masks.append(torch.ones((bsize,), dtype=torch.bool, device=img.device)) 155 | 156 | for key in missing_img_keys: 157 | # zero padding 158 | img = torch.full_like(img, fill_value=-1.0) 159 | images.append(img) 160 | img_masks.append(torch.zeros((bsize,), dtype=torch.bool, device=img.device)) 161 | 162 | images = torch.stack(images, dim=1) # (*b, n, c, h, w) 163 | img_masks = torch.stack(img_masks, dim=1) # (*b, n) 164 | 165 | return images, img_masks 166 | 167 | def prepare_state(self, observation: dict[str, Tensor]): 168 | """Pad the state to the maximum state dimension. 
169 | 170 | Args: 171 | observation (dict[str, Tensor]) 172 | 173 | Returns: 174 | state (torch.Tensor): (*b, max_state_dim) padded state tensor 175 | """ 176 | state = observation["state"] 177 | state = F.pad(state, (0, self.config.max_state_dim - state.shape[1])) 178 | return state 179 | 180 | def prepare_action(self, observation: dict[str, Tensor]): 181 | """Pad the action to the maximum action dimension. 182 | 183 | Args: 184 | observation (dict[str, Tensor]) 185 | 186 | Returns: 187 | action (torch.Tensor): (*b, n, max_action_dim) padded action tensor 188 | action_dim (int): the actual dimension of the action before padding 189 | """ 190 | action = observation["action"] 191 | action_dim = action.shape[-1] 192 | action = F.pad(action, (0, self.config.max_action_dim - action_dim)) 193 | return action, action_dim 194 | 195 | def prepare_language(self, observation: dict[str, Tensor]): 196 | """If `prompt` is provided, modify it to PaliGemma format and tokenize it. 197 | If `lang_tokens` and `lang_masks` are provided, use them directly. 198 | 199 | PaliGemma expects prefix prompts to be formatted as: 200 | .... prompt , where uses `\\n`. 201 | So here we format the prompt to start with `` and end with `\\n`. 202 | Later, we will concatenate the images and language tokens into a single sequence. 203 | 204 | Args: 205 | observation (dict[str, Tensor]) 206 | 207 | Returns: 208 | lang_tokens (torch.Tensor): (*b, l) language tokens 209 | lang_masks (torch.Tensor): (*b, l) masks for language tokens, True if token is present, False if missing 210 | """ 211 | lang_tokens = observation.get("lang_tokens", None) 212 | lang_masks = observation.get("lang_masks", None) 213 | prompt = observation.get("prompt", None) 214 | 215 | # either provide `prompt` or (`lang_tokens`, `lang_masks`) 216 | if prompt is None and (lang_tokens is None or lang_masks is None): 217 | raise ValueError( 218 | "Either 'prompt' or ('lang_tokens', 'lang_masks') must be provided in the observation." 219 | ) 220 | 221 | device = observation["state"].device 222 | if prompt is not None and (lang_tokens is None or lang_masks is None): 223 | prompt = [p if p.startswith("") else f"{p}" for p in prompt] 224 | prompt = [p if p.endswith("\n") else f"{p}\n" for p in prompt] 225 | tokenized_prompt = self.language_tokenizer.__call__( 226 | prompt, 227 | padding="max_length", 228 | padding_side="right", 229 | max_length=self.config.tokenizer_max_length, 230 | return_tensors="pt", 231 | ) 232 | lang_tokens = tokenized_prompt["input_ids"].to(device=device) 233 | lang_masks = tokenized_prompt["attention_mask"].to( 234 | device=device, dtype=torch.bool 235 | ) 236 | else: 237 | lang_tokens = observation["lang_tokens"].to(device=device) 238 | lang_masks = observation["lang_masks"].to(device=device, dtype=torch.bool) 239 | 240 | return lang_tokens, lang_masks 241 | 242 | 243 | class PI0FlowMatching(nn.Module): 244 | """ 245 | π0: A Vision-Language-Action Flow Model for General Robot Control 246 | 247 | [Paper](https://www.physicalintelligence.company/download/pi0.pdf) 248 | [Jax code](https://github.com/Physical-Intelligence/openpi) 249 | 250 | Designed by Physical Intelligence. Ported from Jax by Hugging Face. 
251 | ┌──────────────────────────────┐ 252 | │ actions │ 253 | │ ▲ │ 254 | │ ┌┴─────┐ │ 255 | │ kv cache │Gemma │ │ 256 | │ ┌──────────►│Expert│ │ 257 | │ │ │ │ │ 258 | │ ┌┴────────┐ │x 10 │ │ 259 | │ │ │ └▲──▲──┘ │ 260 | │ │PaliGemma│ │ │ │ 261 | │ │ │ │ robot state │ 262 | │ │ │ noise │ 263 | │ └▲──▲─────┘ │ 264 | │ │ │ │ 265 | │ │ image(s) │ 266 | │ language tokens │ 267 | └──────────────────────────────┘ 268 | 269 | """ 270 | 271 | def __init__(self, config): 272 | super().__init__() 273 | self.config = config 274 | 275 | # paligemma with action expert 276 | paligemma_with_export_config = PaliGemmaWithExpertConfig( 277 | freeze_vision_encoder=self.config.freeze_vision_encoder, 278 | train_expert_only=self.config.train_expert_only, 279 | attention_implementation=self.config.attention_implementation, 280 | ) 281 | self.paligemma_with_expert = PaliGemmaWithExpertModel( 282 | paligemma_with_export_config 283 | ) 284 | 285 | # projection layers 286 | self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width) 287 | self.action_in_proj = nn.Linear( 288 | self.config.max_action_dim, self.config.proj_width 289 | ) 290 | self.action_out_proj = nn.Linear( 291 | self.config.proj_width, self.config.max_action_dim 292 | ) 293 | 294 | self.action_time_mlp_in = nn.Linear( 295 | self.config.proj_width * 2, self.config.proj_width 296 | ) 297 | self.action_time_mlp_out = nn.Linear( 298 | self.config.proj_width, self.config.proj_width 299 | ) 300 | 301 | self.set_requires_grad() 302 | 303 | def set_requires_grad(self): 304 | for params in self.state_proj.parameters(): 305 | params.requires_grad = self.config.train_state_proj 306 | 307 | def sample_time(self, bsize, device): 308 | time_beta = sample_beta(1.5, 1.0, bsize, device) 309 | time = time_beta * 0.999 + 0.001 310 | return time.to(dtype=torch.float32, device=device) 311 | 312 | def embed_prefix( 313 | self, images, img_masks, lang_tokens, lang_masks 314 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 315 | """Embed images with SigLIP and language tokens with embedding layer to prepare 316 | for PaliGemma transformer processing. 317 | 318 | Args: 319 | images (torch.Tensor): float (*b, n, c, h, w) images in range [-1.0, 1.0] 320 | img_masks (torch.Tensor): bool (*b, n) masks for images 321 | lang_tokens (torch.Tensor): int (*b, l) language tokens 322 | lang_masks (torch.Tensor): bool (*b, l) masks for language tokens 323 | """ 324 | bsize = images.shape[0] 325 | device = images.device 326 | dtype = images.dtype 327 | 328 | # embed image 329 | images = einops.rearrange(images, "b n c h w -> (b n) c h w") 330 | img_emb = self.paligemma_with_expert.embed_image(images) 331 | num_patch = img_emb.shape[1] 332 | img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize) 333 | img_emb = img_emb.to(dtype=dtype) * (img_emb.shape[-1] ** 0.5) 334 | num_img_embs = img_emb.shape[1] 335 | img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch) 336 | 337 | # embed language 338 | lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) 339 | num_lang_embs = lang_emb.shape[1] 340 | lang_emb = lang_emb.to(dtype=dtype) * np.sqrt(lang_emb.shape[-1]) 341 | 342 | # assemble embeddings 343 | embs = torch.cat([img_emb, lang_emb], dim=1) 344 | pad_masks = torch.cat([img_masks, lang_masks], dim=1) 345 | 346 | # PaliGemma uses bidirectional attention for prefix tokens, 347 | # so we set 1D `att_masks` to zeros. 
348 | # (see `make_att_2d_masks` to understand why zeros means bidirection) 349 | att_masks = torch.zeros( 350 | (bsize, num_img_embs + num_lang_embs), device=device, dtype=torch.bool 351 | ) 352 | return embs, pad_masks, att_masks 353 | 354 | def embed_suffix(self, state, noisy_actions, timestep): 355 | """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing. 356 | 357 | Args: 358 | state (torch.Tensor): float32 (*b, s) robot state 359 | noisy_actions (torch.Tensor): float32 (*b, n, m) noisy actions 360 | timestep (torch.Tensor): float32 (*b,) timestep in [0, 1] range 361 | """ 362 | bsize = state.shape[0] 363 | device = state.device 364 | dtype = state.dtype 365 | 366 | # embed state 367 | state_emb = self.state_proj(state) 368 | 369 | # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] 370 | time_emb = create_sinusoidal_pos_embedding( 371 | timestep, 372 | self.config.proj_width, 373 | min_period=4e-3, 374 | max_period=4.0, 375 | device=device, 376 | ) 377 | time_emb = time_emb.type(dtype=dtype) 378 | 379 | # Fuse timestep + action information using an MLP 380 | action_emb = self.action_in_proj(noisy_actions) 381 | time_emb = einops.repeat(time_emb, "b d -> b n d", n=action_emb.shape[1]) 382 | action_time_emb = torch.cat([action_emb, time_emb], dim=-1) 383 | 384 | action_time_emb = self.action_time_mlp_in(action_time_emb) 385 | action_time_emb = F.silu(action_time_emb) # swish == silu 386 | action_time_emb = self.action_time_mlp_out(action_time_emb) 387 | action_time_dim = action_time_emb.shape[1] 388 | 389 | # Add to input tokens 390 | embs = torch.cat([state_emb[:, None], action_time_emb], dim=1) 391 | pad_masks = torch.ones( 392 | (bsize, action_time_dim + 1), device=device, dtype=torch.bool 393 | ) 394 | 395 | # Set attention masks for suffix tokens so that prefix tokens cannot attend to suffix tokens. 396 | # And state token cannot attend action tokens. 397 | # Action tokens use a bidirectional attention. 
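        # How the 1D convention is consumed (a sketch of the rule assumed by `make_att_2d_masks`,
        # which compares cumulative sums of `att_masks`): token i may attend token j iff
        # cumsum(att_masks)[j] <= cumsum(att_masks)[i]. An entry of 0 keeps a token inside the
        # current bidirectional block, while an entry of 1 starts a new block that is causal with
        # respect to everything before it. Marking the first two suffix entries (state token and
        # first action token) as True therefore (a) hides the whole suffix from the prefix,
        # (b) hides the actions from the state token, and (c) lets the action tokens attend to
        # each other bidirectionally.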
398 | att_masks = torch.zeros( 399 | (bsize, action_time_dim + 1), device=device, dtype=torch.bool 400 | ) 401 | att_masks[:, :2] = True 402 | 403 | return embs, pad_masks, att_masks 404 | 405 | def forward( 406 | self, 407 | images, 408 | img_masks, 409 | lang_tokens, 410 | lang_masks, 411 | state, 412 | actions, 413 | noise=None, 414 | time=None, 415 | ) -> Tensor: 416 | bsize = state.shape[0] 417 | dtype = state.dtype 418 | device = state.device 419 | 420 | """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" 421 | if noise is None: 422 | actions_shape = ( 423 | bsize, 424 | self.config.n_action_steps, 425 | self.config.max_action_dim, 426 | ) 427 | noise = torch.randn(actions_shape, device=device, dtype=dtype) 428 | 429 | if time is None: 430 | time = self.sample_time(bsize, device).to(dtype) 431 | 432 | time_expanded = time[:, None, None] 433 | x_t = time_expanded * noise + (1 - time_expanded) * actions 434 | u_t = noise - actions 435 | 436 | prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( 437 | images, img_masks, lang_tokens, lang_masks 438 | ) 439 | suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( 440 | state, x_t, time 441 | ) 442 | 443 | pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) 444 | att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) 445 | 446 | att_2d_masks = make_att_2d_masks(pad_masks, att_masks) 447 | position_ids = torch.cumsum(pad_masks, dim=1) - 1 448 | 449 | (_, suffix_out), _ = self.paligemma_with_expert.forward( 450 | attention_mask=att_2d_masks, 451 | position_ids=position_ids, 452 | past_key_values=None, 453 | inputs_embeds=[prefix_embs, suffix_embs], 454 | use_cache=False, 455 | fill_kv_cache=False, 456 | ) 457 | suffix_out = suffix_out[:, -self.config.n_action_steps :] 458 | # Original openpi code, upcast attention output 459 | suffix_out = suffix_out.to(dtype=torch.float32) 460 | v_t = self.action_out_proj(suffix_out) 461 | 462 | losses = F.mse_loss(u_t, v_t, reduction="none") 463 | return losses 464 | 465 | def sample_actions( 466 | self, images, img_masks, lang_tokens, lang_masks, state, noise=None 467 | ) -> Tensor: 468 | """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" 469 | bsize = state.shape[0] 470 | device = state.device 471 | dtype = state.dtype 472 | 473 | if noise is None: 474 | actions_shape = ( 475 | bsize, 476 | self.config.n_action_steps, 477 | self.config.max_action_dim, 478 | ) 479 | noise = torch.randn(actions_shape, device=device, dtype=dtype) 480 | 481 | prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( 482 | images, img_masks, lang_tokens, lang_masks 483 | ) 484 | prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) 485 | prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 486 | 487 | # Compute image and language key value cache 488 | _, past_key_values = self.paligemma_with_expert.forward( 489 | attention_mask=prefix_att_2d_masks, 490 | position_ids=prefix_position_ids, 491 | past_key_values=None, 492 | inputs_embeds=[prefix_embs, None], 493 | use_cache=self.config.use_cache, 494 | fill_kv_cache=True, 495 | ) 496 | 497 | dt = torch.tensor(-1.0 / self.config.num_steps, dtype=dtype, device=device) 498 | x_t = noise 499 | time = torch.tensor(1.0, dtype=dtype, device=device) 500 | while time >= -dt / 2: 501 | expanded_time = time.expand(bsize) 502 | 503 | v_t = self.predict_velocity( 504 | state, # (*b, state_dim) 505 | 
prefix_pad_masks, # (*b, l) 506 | past_key_values, 507 | x_t, # (*b, ha, da) 508 | expanded_time, # (*b,) 509 | ) 510 | 511 | # Euler step 512 | x_t += dt * v_t 513 | time += dt 514 | 515 | return x_t 516 | 517 | def predict_velocity( 518 | self, 519 | state, 520 | prefix_pad_masks, 521 | past_key_values, 522 | x_t, 523 | timestep, 524 | ): 525 | """predict velocity at time t using the suffix model.""" 526 | suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( 527 | state, x_t, timestep 528 | ) 529 | 530 | suffix_len = suffix_pad_masks.shape[1] 531 | batch_size = prefix_pad_masks.shape[0] 532 | prefix_len = prefix_pad_masks.shape[1] 533 | prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand( 534 | batch_size, suffix_len, prefix_len 535 | ) 536 | 537 | suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) 538 | 539 | full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) 540 | 541 | prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] 542 | position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 543 | 544 | outputs_embeds, _ = self.paligemma_with_expert.forward( 545 | attention_mask=full_att_2d_masks, 546 | position_ids=position_ids, 547 | past_key_values=past_key_values, 548 | inputs_embeds=[None, suffix_embs], 549 | use_cache=self.config.use_cache, 550 | fill_kv_cache=False, 551 | ) 552 | suffix_out = outputs_embeds[1] 553 | suffix_out = suffix_out[:, -self.config.n_action_steps :] 554 | suffix_out = suffix_out.to(dtype=torch.float32) 555 | v_t = self.action_out_proj(suffix_out) 556 | return v_t 557 | -------------------------------------------------------------------------------- /convert_pi0_to_hf_lerobot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Convert pi0 parameters from Jax to Pytorch 17 | 18 | Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment 19 | and install the required libraries. 
20 | 21 | ```bash 22 | cd ~/code/openpi 23 | source .venv/bin/activate 24 | ``` 25 | 26 | Example downloading parameters: 27 | ```bash 28 | python 29 | >>> import openpi.shared.download as download 30 | >>> path='s3://openpi-assets/checkpoints/pi0_base/params' 31 | >>> download.maybe_download(path) 32 | ``` 33 | 34 | Converting pi0_base: 35 | ```python 36 | python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \ 37 | --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \ 38 | --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch 39 | ``` 40 | 41 | ```python 42 | python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \ 43 | --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \ 44 | --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch 45 | ``` 46 | """ 47 | 48 | import argparse 49 | import pathlib 50 | 51 | import jax 52 | import numpy as np 53 | import orbax.checkpoint as ocp 54 | import torch 55 | from jax.sharding import SingleDeviceSharding 56 | from lerobot.common.policies.pi0.configuration_pi0 import PI0Config 57 | 58 | from conversion_scripts.conversion_utils import ( 59 | get_gemma_config, 60 | get_paligemma_config, 61 | ) 62 | from pi0.modeling_pi0 import PI0Policy 63 | 64 | PRECISIONS = { 65 | "bfloat16": torch.bfloat16, 66 | "float32": torch.float32, 67 | "float16": torch.float16, 68 | } 69 | 70 | 71 | def slice_paligemma_state_dict(state_dict, config): 72 | suffix = "/value" if "img/embedding/kernel/value" in state_dict else "" 73 | 74 | # fmt: off 75 | # patch embeddings 76 | state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose( 77 | 3, 2, 0, 1 78 | ) 79 | state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}") 80 | # positional embeddings 81 | state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape( 82 | -1, config.vision_config.hidden_size 83 | ) 84 | 85 | # extract vision layers to be sliced at index 0. There are 27 layers in the base model. 
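    # In the Orbax/Flax checkpoint the 27 SigLIP encoder blocks are stored stacked along a
    # leading axis, so each `pop` below returns a single array with a leading dimension of 27
    # per parameter; the per-layer loop further down slices it with [i] and rearranges kernels
    # into the (out_features, in_features) layout expected by `nn.Linear`.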
86 | encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}") 87 | encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}") 88 | encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}") 89 | encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}") 90 | 91 | encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}") 92 | encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}") 93 | encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}") 94 | encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}") 95 | 96 | encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}") 97 | encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}") 98 | encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}") 99 | encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}") 100 | encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}") 101 | encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}") 102 | encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}") 103 | encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}") 104 | 105 | for i in range(config.vision_config.num_hidden_layers): 106 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() 107 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] 108 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() 109 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] 110 | 111 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() 112 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] 113 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() 114 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] 115 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() 116 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, 
config.vision_config.hidden_size).reshape(-1) 117 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() 118 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) 119 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() 120 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) 121 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() 122 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) 123 | 124 | state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose() 125 | state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}") 126 | 127 | # multimodal projector 128 | 129 | state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose() 130 | state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}") 131 | 132 | # text decoder (gemma) 133 | embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}") 134 | state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector 135 | 136 | # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 
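    # The Gemma decoder weights are likewise stacked over the 18 layers and kept in their fused
    # einsum layout; e.g. llm_attention_q_einsum[i] has shape (num_heads, hidden_size, head_dim)
    # = (8, 2048, 256), whereas a PyTorch `nn.Linear` weight is (out_features, in_features) =
    # (num_heads * head_dim, hidden_size). The loop below performs that rearrangement with
    # transpose(0, 2, 1) followed by reshape.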
137 | 138 | llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") 139 | llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") 140 | llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") 141 | 142 | llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") 143 | llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") 144 | 145 | llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") 146 | llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}") 147 | 148 | for i in range(config.text_config.num_hidden_layers): 149 | # llm_attention_q_einsum[i].shape = (8, 2048, 256) 150 | q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) 151 | 152 | state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped 153 | 154 | # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) 155 | k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() 156 | state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped 157 | # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) 158 | v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() 159 | state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped 160 | 161 | # output projection. 162 | 163 | # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048) 164 | o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) 165 | 166 | state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped 167 | # mlp layers 168 | gate_proj_weight = llm_mlp_gating_einsum[i, 0] 169 | state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() 170 | up_proj_weight = llm_mlp_gating_einsum[i, 1] 171 | state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() 172 | state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() 173 | state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] 174 | state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] 175 | 176 | state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}") 177 | state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied. 
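    # Keys carrying a "_1" suffix (final_norm_1, attn_vec_einsum_1, mlp_1, ...) belong to the
    # action expert rather than to PaliGemma's language model. They are filtered out in the loop
    # just below and returned in `expert_dict`, which `slice_gemma_state_dict` converts separately.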
178 | 179 | # fmt: on 180 | expert_dict = {} 181 | final_state_dict = {} 182 | for key, value in state_dict.items(): 183 | if key not in [ 184 | f"llm/final_norm_1/scale{suffix}", 185 | f"llm/layers/attn/attn_vec_einsum_1/w{suffix}", 186 | f"llm/layers/attn/kv_einsum_1/w{suffix}", 187 | f"llm/layers/attn/q_einsum_1/w{suffix}", 188 | f"llm/layers/mlp_1/gating_einsum{suffix}", 189 | f"llm/layers/mlp_1/linear{suffix}", 190 | f"llm/layers/pre_attention_norm_1/scale{suffix}", 191 | f"llm/layers/pre_ffw_norm_1/scale{suffix}", 192 | ]: 193 | final_state_dict[key] = torch.from_numpy(value) 194 | else: 195 | expert_dict[key] = value 196 | 197 | return final_state_dict, expert_dict 198 | 199 | 200 | def slice_gemma_state_dict(state_dict, config, num_expert=1): 201 | # fmt: off 202 | # text decoder (gemma) 203 | # no embedding vector, the expert just has the decoder layers 204 | 205 | embedding_vector = torch.zeros([config.vocab_size, config.hidden_size]) 206 | state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector 207 | 208 | # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 209 | 210 | suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else "" 211 | 212 | llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}") 213 | llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}") 214 | llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}") 215 | 216 | llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}") 217 | llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}") 218 | 219 | llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") 220 | llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") 221 | 222 | for i in range(config.num_hidden_layers): 223 | q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size) 224 | 225 | state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped 226 | 227 | k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() 228 | state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped 229 | v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() 230 | state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped 231 | 232 | # output projection. 
233 | 234 | # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024) 235 | o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0) 236 | 237 | state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped 238 | # mlp layers 239 | gate_proj_weight = llm_mlp_gating_einsum[i, 0] 240 | state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() 241 | up_proj_weight = llm_mlp_gating_einsum[i, 1] 242 | state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() 243 | state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() 244 | state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] 245 | state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] 246 | 247 | state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}") 248 | state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here) 249 | 250 | # fmt: on 251 | final_state_dict = {} 252 | for key, value in state_dict.items(): 253 | if not isinstance(value, torch.Tensor): 254 | final_state_dict[key] = torch.from_numpy(value) 255 | else: 256 | final_state_dict[key] = value 257 | return final_state_dict 258 | 259 | 260 | def flatten_for_memory(tree, parent_key=""): 261 | out = {} 262 | for k, v in tree.items(): 263 | new_key = f"{parent_key}/{k}" if parent_key else k 264 | if isinstance(v, dict): 265 | out.update(flatten_for_memory(v, new_key)) 266 | else: 267 | out[new_key] = np.array(v) # Ensure conversion to np.array for consistency 268 | return out 269 | 270 | 271 | def flatten_for_npz(tree, parent_key=""): 272 | out = {} 273 | for k, v in tree.items(): 274 | new_key = f"{parent_key}/{k}" if parent_key else k 275 | if isinstance(v, dict): 276 | out.update(flatten_for_npz(v, new_key)) 277 | else: 278 | # bf16/f32 here? 
279 | out[new_key] = np.array(v) 280 | return out 281 | 282 | 283 | def slice_initial_orbax_checkpoint(checkpoint_dir: str): 284 | params_path = pathlib.Path(checkpoint_dir).resolve() 285 | checkpointer = ocp.PyTreeCheckpointer() 286 | 287 | metadata = checkpointer.metadata(params_path) 288 | print("Metadata keys:", list(metadata.keys())) 289 | 290 | params_name = "params" 291 | 292 | item = {params_name: metadata[params_name]} 293 | device = jax.local_devices()[0] # Use the first local device 294 | sharding = SingleDeviceSharding(device) 295 | restored = checkpointer.restore( 296 | params_path, 297 | ocp.args.PyTreeRestore( 298 | item=item, 299 | restore_args=jax.tree_util.tree_map( 300 | lambda _: ocp.ArrayRestoreArgs( 301 | restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it 302 | sharding=sharding, 303 | ), 304 | item, 305 | ), 306 | transforms={}, 307 | ), 308 | ) 309 | params = restored[params_name] 310 | 311 | # get params for PaliGemma 312 | pali_params = params["PaliGemma"] 313 | del params["PaliGemma"] 314 | pali_params_flat = flatten_for_npz(pali_params) 315 | return {"paligemma_params": pali_params_flat, "projection_params": params} 316 | 317 | 318 | def update_keys_with_prefix(d: dict, prefix: str) -> dict: 319 | """Update dictionary keys by adding a prefix.""" 320 | return {f"{prefix}{key}": value for key, value in d.items()} 321 | 322 | 323 | def convert_pi0_checkpoint( 324 | checkpoint_dir: str, 325 | precision: str, 326 | tokenizer_id: str, 327 | output_path: str, 328 | torch_device: str, 329 | ): 330 | # Break down orbax ckpts - they are in OCDBT 331 | initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir) 332 | # process projection params 333 | keys = [ 334 | "state_proj", 335 | "action_in_proj", 336 | "action_out_proj", 337 | "action_time_mlp_in", 338 | "action_time_mlp_out", 339 | ] 340 | 341 | projection_params = {} 342 | for key in keys: 343 | kernel_params = initial_params["projection_params"][key]["kernel"] 344 | bias_params = initial_params["projection_params"][key]["bias"] 345 | if isinstance(kernel_params, dict): 346 | weight = kernel_params["value"] 347 | bias = bias_params["value"] 348 | else: 349 | weight = kernel_params 350 | bias = bias_params 351 | projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T 352 | projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias)) 353 | 354 | # Process PaliGemma weights 355 | paligemma_config = get_paligemma_config(precision) 356 | paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict( 357 | initial_params["paligemma_params"], paligemma_config 358 | ) 359 | 360 | # Process Gemma weights (at this stage they are unused) 361 | gemma_config = get_gemma_config(precision) 362 | gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config) 363 | 364 | # Instantiate model from configs 365 | 366 | if "pi0_aloha_sim" in checkpoint_dir: 367 | pi0_config = PI0Config( 368 | empty_cameras=2, 369 | adapt_to_pi_aloha=True, 370 | use_delta_joint_actions_aloha=False, 371 | ) 372 | elif "pi0_aloha_towel" in checkpoint_dir: 373 | pi0_config = PI0Config( 374 | adapt_to_pi_aloha=True, 375 | use_delta_joint_actions_aloha=True, 376 | ) 377 | elif "pi0_base" in checkpoint_dir: 378 | pi0_config = PI0Config( 379 | empty_cameras=0, 380 | adapt_to_pi_aloha=False, 381 | use_delta_joint_actions_aloha=False, 382 | ) 383 | elif "pi0_libero" in checkpoint_dir: 384 | pi0_config = PI0Config( 385 | empty_cameras=0, 386 | adapt_to_pi_aloha=False, 387 | 
use_delta_joint_actions_aloha=False, 388 | ) 389 | else: 390 | raise ValueError() 391 | 392 | pi0_config.device = torch_device 393 | 394 | # gemma_config=gemma_config, paligemma_config=paligemma_config) 395 | pi0_model = PI0Policy(pi0_config) 396 | 397 | paligemma_params = update_keys_with_prefix( 398 | paligemma_params, "model.paligemma_with_expert." 399 | ) 400 | gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.") 401 | projection_params = update_keys_with_prefix(projection_params, "model.") 402 | 403 | # load state dict 404 | torch_dtype = PRECISIONS[precision] 405 | pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params}) 406 | pi0_model = pi0_model.to(torch_dtype) 407 | # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) 408 | 409 | pi0_model.save_pretrained(output_path, safe_serialization=True) 410 | # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype) 411 | 412 | # assert that model loads properly 413 | del pi0_model 414 | PI0Policy.from_pretrained(output_path) 415 | 416 | 417 | if __name__ == "__main__": 418 | parser = argparse.ArgumentParser() 419 | parser.add_argument( 420 | "--checkpoint_dir", 421 | default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params", 422 | type=str, 423 | help="Path to the ocdbt checkpoint", 424 | ) 425 | 426 | parser.add_argument( 427 | "--precision", 428 | choices=["float32", "bfloat16", "float16"], 429 | default="bfloat16", 430 | type=str, 431 | help="Precision identifier for model conversion - should match the base checkpoint precision.", 432 | ) 433 | # tokenizer is identical to paligemma, it appears 434 | 435 | parser.add_argument( 436 | "--tokenizer_hub_id", 437 | default="/home/dzb/pretrained/paligemma3b", 438 | type=str, 439 | help="Hub path to the tokenizer to save", 440 | ) 441 | 442 | parser.add_argument( 443 | "--output_path", 444 | required=True, 445 | type=str, 446 | help="Path to save converted weights to", 447 | ) 448 | 449 | parser.add_argument( 450 | "--torch_device", 451 | default="cuda:0", 452 | type=str, 453 | help="Torch device to use for conversion", 454 | ) 455 | 456 | args = parser.parse_args() 457 | convert_pi0_checkpoint( 458 | checkpoint_dir=args.checkpoint_dir, 459 | precision=args.precision, 460 | tokenizer_id=args.tokenizer_hub_id, 461 | output_path=args.output_path, 462 | torch_device=args.torch_device, 463 | ) 464 | -------------------------------------------------------------------------------- /convert_pi0fast_to_hf_lerobot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Convert pi0 parameters from Jax to Pytorch 17 | 18 | Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment 19 | and install the required libraries. 
20 | 21 | ```bash 22 | cd ~/code/openpi 23 | source .venv/bin/activate 24 | ``` 25 | 26 | Example downloading parameters: 27 | ```bash 28 | python 29 | >>> import openpi.shared.download as download 30 | >>> path='s3://openpi-assets/checkpoints/pi0_base/params' 31 | >>> download.maybe_download(path) 32 | ``` 33 | 34 | Converting pi0_base: 35 | ```python 36 | python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \ 37 | --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \ 38 | --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch 39 | ``` 40 | 41 | ```python 42 | python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \ 43 | --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \ 44 | --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch 45 | ``` 46 | """ 47 | 48 | import argparse 49 | import pathlib 50 | 51 | import jax 52 | import numpy as np 53 | import orbax.checkpoint as ocp 54 | import torch 55 | from jax.sharding import SingleDeviceSharding 56 | from lerobot.common.policies.pi0.configuration_pi0 import PI0Config 57 | from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig 58 | 59 | from pi0.modeling_pi0 import PI0Policy 60 | from pi0.modeling_pi0fast import PI0FASTPolicy 61 | 62 | from conversion_scripts.conversion_utils import get_gemma_config, get_paligemma_config 63 | 64 | PRECISIONS = { 65 | "bfloat16": torch.bfloat16, 66 | "float32": torch.float32, 67 | "float16": torch.float16, 68 | } 69 | 70 | 71 | def slice_paligemma_state_dict(state_dict, config): 72 | suffix = "/value" if "img/embedding/kernel/value" in state_dict else "" 73 | 74 | # fmt: off 75 | # patch embeddings 76 | state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose( 77 | 3, 2, 0, 1 78 | ) 79 | state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}") 80 | # positional embeddings 81 | state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape( 82 | -1, config.vision_config.hidden_size 83 | ) 84 | 85 | # extract vision layers to be sliced at index 0. There are 27 layers in the base model. 
86 | encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}") 87 | encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}") 88 | encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}") 89 | encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}") 90 | 91 | encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}") 92 | encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}") 93 | encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}") 94 | encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}") 95 | 96 | encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}") 97 | encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}") 98 | encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}") 99 | encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}") 100 | encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}") 101 | encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}") 102 | encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}") 103 | encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}") 104 | 105 | for i in range(config.vision_config.num_hidden_layers): 106 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() 107 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] 108 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() 109 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] 110 | 111 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() 112 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] 113 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() 114 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] 115 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() 116 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, 
config.vision_config.hidden_size).reshape(-1) 117 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() 118 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) 119 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() 120 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) 121 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() 122 | state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) 123 | 124 | state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose() 125 | state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}") 126 | 127 | # multimodal projector 128 | 129 | state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose() 130 | state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}") 131 | 132 | # text decoder (gemma) 133 | embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}") 134 | state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector 135 | 136 | # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 
137 | 138 | llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") 139 | llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") 140 | llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") 141 | 142 | llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") 143 | llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") 144 | # TODO verify correctness of layer norm loading 145 | 146 | llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") 147 | llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}") 148 | 149 | for i in range(config.text_config.num_hidden_layers): 150 | # llm_attention_q_einsum[i].shape = (8, 2048, 256) 151 | q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) 152 | 153 | state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped 154 | 155 | # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) 156 | k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() 157 | state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped 158 | # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) 159 | v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() 160 | state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped 161 | 162 | # output projection. 163 | 164 | # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048) 165 | o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) 166 | 167 | state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped 168 | # mlp layers 169 | gate_proj_weight = llm_mlp_gating_einsum[i, 0] 170 | state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() 171 | up_proj_weight = llm_mlp_gating_einsum[i, 1] 172 | state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() 173 | state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() 174 | state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] 175 | state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] 176 | 177 | state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}") 178 | state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied. 
179 | 180 | # fmt: on 181 | expert_dict = {} 182 | final_state_dict = {} 183 | for key, value in state_dict.items(): 184 | if key not in [ 185 | f"llm/final_norm_1/scale{suffix}", 186 | f"llm/layers/attn/attn_vec_einsum_1/w{suffix}", 187 | f"llm/layers/attn/kv_einsum_1/w{suffix}", 188 | f"llm/layers/attn/q_einsum_1/w{suffix}", 189 | f"llm/layers/mlp_1/gating_einsum{suffix}", 190 | f"llm/layers/mlp_1/linear{suffix}", 191 | f"llm/layers/pre_attention_norm_1/scale{suffix}", 192 | f"llm/layers/pre_ffw_norm_1/scale{suffix}", 193 | ]: 194 | final_state_dict[key] = torch.from_numpy(value) 195 | else: 196 | expert_dict[key] = value 197 | 198 | return final_state_dict, expert_dict 199 | 200 | 201 | def slice_gemma_state_dict(state_dict, config, num_expert=1): 202 | # fmt: off 203 | # text decoder (gemma) 204 | # no embedding vector, the expert just has the decoder layers 205 | 206 | embedding_vector = torch.zeros([config.vocab_size, config.hidden_size]) 207 | state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector 208 | 209 | # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 210 | 211 | suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else "" 212 | 213 | llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}") 214 | llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}") 215 | llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}") 216 | 217 | llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}") 218 | llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}") 219 | # TODO verify correctness of layer norm loading 220 | 221 | llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") 222 | llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") 223 | 224 | for i in range(config.num_hidden_layers): 225 | q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size) 226 | 227 | state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped 228 | 229 | k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() 230 | state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped 231 | v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() 232 | state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped 233 | 234 | # output projection. 
235 | 236 | # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024) 237 | o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0) 238 | 239 | state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped 240 | # mlp layers 241 | gate_proj_weight = llm_mlp_gating_einsum[i, 0] 242 | state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() 243 | up_proj_weight = llm_mlp_gating_einsum[i, 1] 244 | state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() 245 | state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() 246 | state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] 247 | state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] 248 | 249 | state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}") 250 | state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here) 251 | 252 | # fmt: on 253 | final_state_dict = {} 254 | for key, value in state_dict.items(): 255 | if not isinstance(value, torch.Tensor): 256 | final_state_dict[key] = torch.from_numpy(value) 257 | else: 258 | final_state_dict[key] = value 259 | return final_state_dict 260 | 261 | 262 | def flatten_for_memory(tree, parent_key=""): 263 | out = {} 264 | for k, v in tree.items(): 265 | new_key = f"{parent_key}/{k}" if parent_key else k 266 | if isinstance(v, dict): 267 | out.update(flatten_for_memory(v, new_key)) 268 | else: 269 | out[new_key] = np.array(v) # Ensure conversion to np.array for consistency 270 | return out 271 | 272 | 273 | def flatten_for_npz(tree, parent_key=""): 274 | out = {} 275 | for k, v in tree.items(): 276 | new_key = f"{parent_key}/{k}" if parent_key else k 277 | if isinstance(v, dict): 278 | out.update(flatten_for_npz(v, new_key)) 279 | else: 280 | # bf16/f32 here? 
281 | out[new_key] = np.array(v) 282 | return out 283 | 284 | 285 | def slice_initial_orbax_checkpoint(checkpoint_dir: str): 286 | params_path = pathlib.Path(checkpoint_dir).resolve() 287 | checkpointer = ocp.PyTreeCheckpointer() 288 | 289 | metadata = checkpointer.metadata(params_path) 290 | print("Metadata keys:", list(metadata.keys())) 291 | 292 | params_name = "params" 293 | 294 | item = {params_name: metadata[params_name]} 295 | device = jax.local_devices()[0] # Use the first local device 296 | sharding = SingleDeviceSharding(device) 297 | restored = checkpointer.restore( 298 | params_path, 299 | ocp.args.PyTreeRestore( 300 | item=item, 301 | restore_args=jax.tree_util.tree_map( 302 | lambda _: ocp.ArrayRestoreArgs( 303 | restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it 304 | sharding=sharding, 305 | ), 306 | item, 307 | ), 308 | transforms={}, 309 | ), 310 | ) 311 | params = restored[params_name] 312 | 313 | # get params for PaliGemma 314 | pali_params = params["PaliGemma"] 315 | del params["PaliGemma"] 316 | pali_params_flat = flatten_for_npz(pali_params) 317 | return {"paligemma_params": pali_params_flat, "projection_params": params} 318 | 319 | 320 | def update_keys_with_prefix(d: dict, prefix: str) -> dict: 321 | """Update dictionary keys by adding a prefix.""" 322 | return {f"{prefix}{key}": value for key, value in d.items()} 323 | 324 | 325 | def convert_pi0_checkpoint( 326 | checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str 327 | ): 328 | # Break down orbax ckpts - they are in OCDBT 329 | initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir) 330 | # process projection params 331 | keys = [ 332 | "state_proj", 333 | "action_in_proj", 334 | "action_out_proj", 335 | "action_time_mlp_in", 336 | "action_time_mlp_out", 337 | ] 338 | 339 | # projection_params = {} 340 | # for key in keys: 341 | # kernel_params = initial_params["projection_params"][key]["kernel"] 342 | # bias_params = initial_params["projection_params"][key]["bias"] 343 | # if isinstance(kernel_params, dict): 344 | # weight = kernel_params["value"] 345 | # bias = bias_params["value"] 346 | # else: 347 | # weight = kernel_params 348 | # bias = bias_params 349 | # projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T 350 | # projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias)) 351 | 352 | # Process PaliGemma weights 353 | paligemma_config = get_paligemma_config(precision) 354 | paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict( 355 | initial_params["paligemma_params"], paligemma_config 356 | ) 357 | 358 | # # Process Gemma weights (at this stage they are unused) 359 | # gemma_config = get_gemma_config(precision) 360 | # gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config) 361 | 362 | # Instantiate model from configs 363 | 364 | if "pi0_aloha_sim" in checkpoint_dir: 365 | pi0_config = PI0Config( 366 | empty_cameras=2, 367 | adapt_to_pi_aloha=True, 368 | use_delta_joint_actions_aloha=False, 369 | ) 370 | elif "pi0_aloha_towel" in checkpoint_dir: 371 | pi0_config = PI0Config( 372 | adapt_to_pi_aloha=True, 373 | use_delta_joint_actions_aloha=True, 374 | ) 375 | elif "pi0_base" in checkpoint_dir: 376 | pi0_config = PI0Config( 377 | empty_cameras=0, 378 | adapt_to_pi_aloha=False, 379 | use_delta_joint_actions_aloha=False, 380 | ) 381 | elif "pi0_libero" in checkpoint_dir: 382 | pi0_config = PI0Config( 383 | empty_cameras=0, 384 | adapt_to_pi_aloha=False, 385 | 
use_delta_joint_actions_aloha=False, 386 | ) 387 | elif "pi0_fast_libero" in checkpoint_dir: 388 | pi0_config = PI0FASTConfig( 389 | adapt_to_pi_aloha=False, 390 | ) 391 | elif "pi0_fast_base" in checkpoint_dir: 392 | pi0_config = PI0FASTConfig( 393 | adapt_to_pi_aloha=False, 394 | ) 395 | else: 396 | raise ValueError() 397 | 398 | pi0_config.device = "cuda:1" 399 | 400 | # gemma_config=gemma_config, paligemma_config=paligemma_config) 401 | pi0_model = PI0FASTPolicy(pi0_config) 402 | 403 | paligemma_params = update_keys_with_prefix(paligemma_params, "model.pi0_") 404 | # gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.") 405 | # projection_params = update_keys_with_prefix(projection_params, "model.") 406 | 407 | # load state dict 408 | torch_dtype = PRECISIONS[precision] 409 | print(paligemma_params.keys()) 410 | pi0_model.load_state_dict({**paligemma_params}) 411 | pi0_model = pi0_model.to(torch_dtype) 412 | # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) 413 | 414 | pi0_model.save_pretrained(output_path, safe_serialization=True) 415 | # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype) 416 | 417 | # assert that model loads properly 418 | del pi0_model 419 | PI0FASTPolicy.from_pretrained(output_path) 420 | 421 | 422 | if __name__ == "__main__": 423 | parser = argparse.ArgumentParser() 424 | parser.add_argument( 425 | "--checkpoint_dir", 426 | default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params", 427 | type=str, 428 | help="Path to the ocdbt checkpoint", 429 | ) 430 | 431 | parser.add_argument( 432 | "--precision", 433 | choices=["float32", "bfloat16", "float16"], 434 | default="bfloat16", 435 | type=str, 436 | help="Precision identifier for model conversion - should match the base checkpoint precision.", 437 | ) 438 | # tokenizer is identical to paligemma, it appears 439 | 440 | parser.add_argument( 441 | "--tokenizer_hub_id", 442 | default="/home/dzb/pretrained/paligemma3b", 443 | type=str, 444 | help="Hub path to the tokenizer to save", 445 | ) 446 | 447 | parser.add_argument( 448 | "--output_path", 449 | required=True, 450 | type=str, 451 | help="Path to save converted weights to", 452 | ) 453 | 454 | args = parser.parse_args() 455 | convert_pi0_checkpoint( 456 | checkpoint_dir=args.checkpoint_dir, 457 | precision=args.precision, 458 | tokenizer_id=args.tokenizer_hub_id, 459 | output_path=args.output_path, 460 | ) 461 | -------------------------------------------------------------------------------- /pi0/modeling_pi0fast.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import einops 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F # noqa: N812 7 | from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig 8 | from lerobot.common.policies.pretrained import PreTrainedPolicy 9 | from PIL import Image 10 | from scipy.fft import idct 11 | from termcolor import cprint 12 | from torch import Tensor, nn 13 | from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration 14 | from transformers.cache_utils import HybridCache, StaticCache 15 | from transformers.models.auto import CONFIG_MAPPING 16 | 17 | IMAGE_KEYS = ( 18 | "base_0_rgb", 19 | "left_wrist_0_rgb", 20 | "right_wrist_0_rgb", 21 | ) 22 | 23 | PRECISION = { 24 | "float16": torch.float16, 25 | "float32": torch.float32, 26 | "bfloat16": torch.bfloat16, 27 | } 28 | 29 | 30 | def 
get_paligemma_config(torch_dtype, **kwargs): 31 | return CONFIG_MAPPING["paligemma"]( 32 | transformers_version="4.48.1", 33 | _vocab_size=257152, 34 | bos_token_id=2, 35 | eos_token_id=1, 36 | hidden_size=2048, 37 | image_token_index=257152, 38 | model_type="paligemma", 39 | pad_token_id=0, 40 | projection_dim=2048, 41 | text_config={ 42 | "hidden_activation": "gelu_pytorch_tanh", 43 | "hidden_size": 2048, 44 | "intermediate_size": 16384, 45 | "model_type": "gemma", 46 | "num_attention_heads": 8, 47 | "num_hidden_layers": 18, 48 | "num_image_tokens": 256, 49 | "num_key_value_heads": 1, 50 | "torch_dtype": torch_dtype, 51 | "vocab_size": 257152, 52 | "_attn_implementation": "eager", 53 | }, 54 | vision_config={ 55 | "hidden_size": 1152, 56 | "intermediate_size": 4304, 57 | "model_type": "siglip_vision_model", 58 | "num_attention_heads": 16, 59 | "num_hidden_layers": 27, 60 | "num_image_tokens": 256, 61 | "patch_size": 14, 62 | "projection_dim": 2048, 63 | "projector_hidden_act": "gelu_pytorch_tanh", 64 | "torch_dtype": torch_dtype, 65 | "vision_use_head": False, 66 | }, 67 | **kwargs, 68 | ) 69 | 70 | 71 | class PI0FASTPolicy(PreTrainedPolicy): 72 | config_class = PI0FASTConfig 73 | name = "torch_pi0fast" 74 | 75 | def __init__( 76 | self, 77 | config: PI0FASTConfig, 78 | tokenizer_path: str = "google/paligemma-3b-pt-224", 79 | ): 80 | """ 81 | Args: 82 | config: Policy configuration class instance or None, in which case the default instantiation of 83 | the configuration class is used. 84 | dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected 85 | that they will be passed with a call to `load_state_dict` before the policy is used. 86 | """ 87 | 88 | super().__init__(config) 89 | self.config = config 90 | self.language_tokenizer = AutoProcessor.from_pretrained(tokenizer_path) 91 | self.model = PI0FAST(config) 92 | self.reset() 93 | 94 | def reset(self): 95 | return None 96 | 97 | def get_optim_params(self) -> dict: 98 | return self.parameters() 99 | 100 | @torch.no_grad 101 | def select_action(self, observation: dict[str, Tensor]) -> Tensor: 102 | """ 103 | Observation: { 104 | "image": { 105 | "base_0_rgb": (*b, c, h, w), # uint8 [0, 255] 106 | ... 107 | }, 108 | "state": float32 [*b, s], 109 | "prompt": List[str], 110 | 111 | "lang_tokens": float32 [*b, l], 112 | "lang_masks": float32 [*b, l], 113 | } 114 | """ 115 | self.eval() 116 | 117 | actions = self.model.generate_actions(observation) 118 | return actions 119 | 120 | def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: 121 | loss_dict = self.model.forward(batch) 122 | return loss_dict["loss"], loss_dict 123 | 124 | 125 | def block_causal_update_causal_mask( 126 | attention_mask, 127 | token_type_ids=None, 128 | past_key_values=None, 129 | cache_position=None, 130 | input_tensor=None, 131 | attn_implementation: str = "eager", 132 | dtype: torch.dtype = "float32", 133 | ): 134 | """ 135 | Update the causal mask during training and generation. It can be customized to different attention masks. 
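
    The following is an added illustrative sketch (not part of the original
    docstring) of the intended block structure, assuming `token_type_ids`
    marks prefix tokens with 0 and suffix (action) tokens with 1:

        token_type_ids = torch.tensor([[0, 0, 1, 1]])
        cumsum = token_type_ids.cumsum(dim=1)                         # [[0, 0, 1, 2]]
        block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]  # [batch, query, key]
        # queries 0-1 (the prefix) attend to keys 0-1 bidirectionally,
        # query 2 attends to keys 0-2, and query 3 attends to keys 0-3:
        # a fully visible prefix followed by causally masked suffix tokens.
        # `torch.where(block_causal_mask, 0.0, causal_mask)` then clears the
        # mask (0.0 = attend) at allowed positions and leaves `min_dtype`
        # (blocked) at the rest.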
136 | """ 137 | if attn_implementation == "flash_attention_2": 138 | if attention_mask is not None and 0.0 in attention_mask: 139 | return attention_mask 140 | return None 141 | using_static_cache = isinstance(past_key_values, StaticCache) 142 | min_dtype = torch.finfo(dtype).min 143 | 144 | if input_tensor is None: 145 | input_tensor = attention_mask 146 | 147 | inputs_lead_dim, sequence_length = input_tensor.shape[:2] 148 | 149 | if using_static_cache or isinstance(past_key_values, HybridCache): 150 | target_length = past_key_values.get_max_cache_shape() 151 | else: 152 | target_length = ( 153 | attention_mask.shape[-1] 154 | if isinstance(attention_mask, torch.Tensor) 155 | else cache_position[0] + sequence_length + 1 156 | ) 157 | 158 | # Handle precomputed attention masks 159 | if attention_mask is not None and attention_mask.dim() == 4: 160 | return attention_mask 161 | 162 | # Causal mask initialization 163 | causal_mask = torch.full( 164 | (sequence_length, target_length), 165 | fill_value=min_dtype, 166 | dtype=dtype, 167 | device=cache_position.device, 168 | ) 169 | 170 | # Standard causal masking (triu ensures tokens can only attend to past) 171 | if sequence_length != 1: 172 | causal_mask = torch.triu(causal_mask, diagonal=1) 173 | 174 | # Apply block causal mask 175 | if token_type_ids is not None: 176 | token_type_ids = token_type_ids.to(causal_mask.device).bool() 177 | cumsum = torch.cumsum(token_type_ids, dim=1) 178 | block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None] 179 | 180 | # Combine causal_mask with block-wise attention mask 181 | causal_mask = torch.where(block_causal_mask, 0.0, causal_mask) 182 | causal_mask = causal_mask[:, None, :, :] 183 | else: 184 | # Apply past cache position constraint 185 | causal_mask *= torch.arange( 186 | target_length, device=cache_position.device 187 | ) > cache_position.reshape(-1, 1) 188 | causal_mask = causal_mask[None, None, :, :].expand( 189 | inputs_lead_dim, 1, -1, -1 190 | ) 191 | else: 192 | # Apply past cache position constraint 193 | causal_mask *= torch.arange( 194 | target_length, device=cache_position.device 195 | ) > cache_position.reshape(-1, 1) 196 | causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) 197 | 198 | if attention_mask is not None: 199 | causal_mask = ( 200 | causal_mask.clone() 201 | ) # Copy to contiguous memory for in-place edits 202 | mask_length = attention_mask.shape[-1] 203 | 204 | # Apply padding mask 205 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ 206 | :, None, None, : 207 | ].to(causal_mask.device) 208 | padding_mask = padding_mask == 0 209 | causal_mask[:, :, :, :mask_length] = causal_mask[ 210 | :, :, :, :mask_length 211 | ].masked_fill(padding_mask, min_dtype) 212 | 213 | return causal_mask 214 | 215 | 216 | def prepare_inputs_for_generation( 217 | # self, 218 | input_ids, 219 | past_key_values=None, 220 | inputs_embeds=None, 221 | cache_position=None, 222 | position_ids=None, 223 | pixel_values=None, 224 | attention_mask=None, 225 | token_type_ids=None, 226 | use_cache=True, 227 | num_logits_to_keep=None, 228 | labels=None, 229 | self=None, 230 | **kwargs, 231 | ): 232 | # create block causal attention 233 | if cache_position[0] > 0 and input_ids.shape[1] > 0: 234 | input_tensor = input_ids[:, -1:] 235 | new_positions = ( 236 | torch.ones( 237 | (position_ids.shape[0], input_ids.shape[1]), 238 | dtype=position_ids.dtype, 239 | device=position_ids.device, 240 | ).cumsum(-1) 241 | + position_ids[:, -1:] 242 | ) 243 | position_ids = 
torch.cat([position_ids, new_positions], dim=-1) 244 | else: 245 | input_tensor = inputs_embeds 246 | attention_mask = block_causal_update_causal_mask( 247 | attention_mask=attention_mask, 248 | past_key_values=past_key_values, 249 | cache_position=cache_position, 250 | input_tensor=input_tensor, 251 | token_type_ids=token_type_ids, 252 | dtype=self.dtype, 253 | attn_implementation=self.config.text_config._attn_implementation, 254 | ) 255 | # Overwritten -- custom `position_ids` and `pixel_values` handling 256 | model_inputs = self.language_model.prepare_inputs_for_generation( 257 | input_ids, 258 | past_key_values=past_key_values, 259 | inputs_embeds=inputs_embeds, 260 | attention_mask=attention_mask, 261 | position_ids=position_ids, 262 | cache_position=cache_position, 263 | use_cache=use_cache, 264 | num_logits_to_keep=num_logits_to_keep, 265 | token_type_ids=token_type_ids, 266 | **kwargs, 267 | ) 268 | 269 | # Position_ids in Paligemma are 1-indexed 270 | if model_inputs.get("position_ids") is not None: 271 | model_inputs["position_ids"] += 1 272 | # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore 273 | # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always 274 | if cache_position[0] == 0: 275 | model_inputs["pixel_values"] = pixel_values 276 | is_training = token_type_ids is not None and labels is not None 277 | if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): 278 | input_tensor = inputs_embeds if inputs_embeds is not None else input_ids 279 | causal_mask = self._update_causal_mask( 280 | attention_mask, 281 | token_type_ids, 282 | past_key_values, 283 | cache_position, 284 | input_tensor, 285 | is_training, 286 | ) 287 | model_inputs["attention_mask"] = causal_mask 288 | 289 | return model_inputs 290 | 291 | 292 | class PI0FAST(nn.Module): 293 | def __init__(self, config: PI0FASTConfig): 294 | super().__init__() 295 | self.config = config 296 | 297 | fast_tokenizer_path = "physical-intelligence/fast" 298 | pi0_paligemma_path = "google/paligemma-3b-pt-224" 299 | self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path) 300 | self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path) 301 | self.fast_tokenizer = AutoProcessor.from_pretrained( 302 | fast_tokenizer_path, trust_remote_code=True 303 | ) 304 | self.fast_skip_tokens = self.config.fast_skip_tokens 305 | self.max_input_seq_len = self.config.max_input_seq_len 306 | self.action_horizon = self.config.chunk_size 307 | self.action_dim = self.config.max_action_dim 308 | precision = config.precision 309 | torch_precision = PRECISION.get(precision, torch.float32) 310 | 311 | self.pad_token_id = ( 312 | self.paligemma_tokenizer.pad_token_id 313 | if hasattr(self.paligemma_tokenizer, "pad_token_id") 314 | else self.paligemma_tokenizer.eos_token_id 315 | ) 316 | 317 | paligemma_config = get_paligemma_config(torch_dtype=torch_precision) 318 | self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config) 319 | 320 | self.pi0_paligemma.prepare_inputs_for_generation = partial( 321 | prepare_inputs_for_generation, self=self.pi0_paligemma 322 | ) 323 | # change important stuff in bf16 324 | params_to_change_dtype = [ 325 | "language_model", 326 | "vision_tower", 327 | "multi_modal", 328 | ] 329 | for name, param in self.pi0_paligemma.named_parameters(): 330 | if any(selector in name for selector in params_to_change_dtype): 331 | param.data = 
param.data.to(dtype=torch_precision)
332 |         self.set_requires_grad()
333 |         self.image_keys = self.config.image_features.keys()
334 |         self.ignore_index = self.pi0_paligemma.config.ignore_index
335 |         self.padding_side = self.config.padding_side
336 | 
337 |     def set_requires_grad(self):
338 |         if self.config.freeze_vision_encoder:
339 |             self.pi0_paligemma.vision_tower.eval()
340 |             for params in self.pi0_paligemma.vision_tower.parameters():
341 |                 params.requires_grad = False
342 |         # To avoid unused params issue with distributed training
343 |         if self.config.freeze_lm_head:
344 |             for name, params in self.pi0_paligemma.named_parameters():
345 |                 if "embed_tokens" in name:  # lm heads and embedding layer are tied
346 |                     params.requires_grad = False
347 | 
348 |     def embed_tokens(self, tokens: torch.Tensor):
349 |         return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
350 | 
351 |     def prepare_inputs_for_generation(self, *args, **kwargs):
352 |         return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
353 | 
354 |     def prepare_images(self, observation: dict[str, Tensor]):
355 |         """Resize, pad, and normalize the available views; missing views become blank images that are not masked out."""
356 |         dtype = observation["state"].dtype
357 |         bsize = observation["state"].shape[0]
358 |         images, img_masks = [], []
359 |         for key in IMAGE_KEYS:
360 |             if key in observation["image"]:
361 |                 # resize, pad, and normalize
362 |                 img = observation["image"][key]
363 |                 img = img.to(dtype) / 127.5 - 1.0
364 |                 img = resize_with_pad(
365 |                     img, *self.config.resize_imgs_with_padding, pad_value=-1.0
366 |                 )
367 |                 images.append(img)
368 |                 img_masks.append(
369 |                     torch.ones((bsize,), dtype=torch.bool, device=img.device)
370 |                 )
371 |             else:
372 |                 # Missing view: use a blank (all -1.0) image of the target size.
373 |                 width, height = self.config.resize_imgs_with_padding
374 |                 img = torch.full((bsize, 3, height, width), -1.0, dtype=dtype, device=observation["state"].device)
375 |                 images.append(img)
376 |                 img_masks.append(torch.ones((bsize,), dtype=torch.bool, device=img.device))
377 |         images = torch.stack(images, dim=1)  # (*b, n, c, h, w)
378 |         img_masks = torch.stack(img_masks, dim=1)  # (*b, n)
379 |         return images, img_masks
380 | 
381 |     def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
382 |         out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
383 |         return out
384 | 
385 |     def fast_tokenizer_wrapper(self, actions_norm):
386 |         actions_norm = actions_norm.to(torch.float32)
387 |         batch_tokens = self.fast_tokenizer(actions_norm)
388 |         fast_out = self.processor.tokenizer.pad(
389 |             {"input_ids": batch_tokens}, return_tensors="pt"
390 |         )
391 |         return fast_out
392 | 
393 |     def create_token_type_ids(
394 |         self, padded_mask: torch.Tensor, prefix_len: torch.Tensor
395 |     ) -> torch.Tensor:
396 |         token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
397 |         # Compute cumulative sum mask
398 |         cumsum_mask = (padded_mask != 0).cumsum(dim=1)
399 |         # Suffix block (everything after prefix_len)
400 |         suffix_mask = cumsum_mask > prefix_len
401 |         token_type_ids = suffix_mask
402 |         return token_type_ids
403 | 
404 |     def create_input_tokens(self, state, lang_text, actions=None):
405 |         bsize = state.shape[0]
406 |         device = state.device
407 | 
408 |         # Note that `state` is expected to be normalized to the [-1, 1] range.
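        # Added illustration (not part of the original code): the binning below
        # maps each state dimension to an integer bucket in [0, 255], which is
        # then written into the text prefix. For a hypothetical value of 0.0:
        #   bins = torch.linspace(-1, 1, 257)[:-1]         # 256 left bin edges
        #   torch.bucketize(torch.tensor(0.0), bins) - 1   # -> 127 (middle bucket)
        # so the prompt would read something like "State: 127 ...".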
409 | bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1] 410 | discretized = torch.bucketize(state, bins) - 1 411 | 412 | prefix_texts = [] 413 | for txt, disc in zip(lang_text, discretized, strict=False): 414 | cleaned = txt.lower().strip().replace("_", " ") 415 | state_str = " ".join(str(val.item()) for val in disc) 416 | prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n") 417 | 418 | # tokenizer automatically adds token 419 | prefix_out = self.paligemma_tokenizer( 420 | prefix_texts, 421 | add_special_tokens=True, 422 | return_tensors="pt", 423 | padding="longest", 424 | truncation=False, 425 | ) 426 | prefix_ids = prefix_out["input_ids"].to(device) 427 | prefix_mask = prefix_out["attention_mask"].to(device) 428 | prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu() 429 | 430 | if actions is not None: 431 | # see https://github.com/Physical-Intelligence/openpi/blob/0992224b1cf89d0fe282ac381596c1048b766adc/src/openpi/models/tokenizer.py#L39 432 | # JAX OpenPI does not: 433 | # 1. normalize action before passing to fast tokenizer 434 | # 2. pad action to 32dim 435 | # 3. replace action token with 0 to pad tokens 436 | # And it does: 437 | # 1. add "|" after action tokens 438 | 439 | fast_out = self.fast_tokenizer_wrapper(actions.cpu()) 440 | act_ids = fast_out["input_ids"] 441 | act_mask = fast_out["attention_mask"].to(device) 442 | act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device) 443 | act_ids[torch.where(1 - act_mask)] = self.paligemma_tokenizer.pad_token_id 444 | bos = self.paligemma_tokenizer( 445 | "Action: ", add_special_tokens=False, return_tensors="pt" 446 | ) 447 | eos = self.paligemma_tokenizer( 448 | "|", add_special_tokens=False, return_tensors="pt" 449 | ) 450 | bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device) 451 | bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device) 452 | eos_token = eos["input_ids"].expand(act_ids.shape[0], -1).to(device) 453 | eos_mask = eos["attention_mask"].expand(act_ids.shape[0], -1).to(device) 454 | act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) 455 | act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) 456 | act_mask = act_mask.to(device) 457 | else: 458 | act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device) 459 | act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device) 460 | 461 | final_ids = torch.cat([prefix_ids, act_ids], dim=1) 462 | final_mask = torch.cat([prefix_mask, act_mask], dim=1) 463 | batch_inputs = { 464 | "input_ids": final_ids.tolist(), 465 | "attention_mask": final_mask.tolist(), 466 | } 467 | 468 | # Use tokenizer pad function 469 | padded_output = self.paligemma_tokenizer.pad( 470 | batch_inputs, padding="longest", max_length=180, return_tensors="pt" 471 | ) 472 | padded_mask = padded_output["attention_mask"] 473 | 474 | # define tensor of padding lengths 475 | att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens 476 | 477 | token_type_ids = self.create_token_type_ids( 478 | padded_mask=padded_mask, prefix_len=prefix_lens 479 | ) 480 | 481 | padded_output["padded_mask"] = padded_output.pop("attention_mask") 482 | padded_output["attention_mask"] = att_mask 483 | # loss is computed not on prefix, and not on padding 484 | padded_output["loss_mask"] = att_mask & padded_output["padded_mask"] 485 | padded_output["token_type_ids"] = token_type_ids 486 | return padded_output 487 | 488 | def shift_padding_side( 489 | self, 490 | tokens: torch.Tensor, 491 | ar_mask: torch.Tensor, 492 | padding_mask: torch.Tensor, 493 | 
loss_mask: torch.Tensor, 494 | targets: torch.Tensor, 495 | token_type_ids: torch.Tensor, 496 | padding_side: str = "right", 497 | ) -> tuple[torch.Tensor]: 498 | if padding_side not in ["right", "left"]: 499 | return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids 500 | 501 | new_tokens = torch.empty_like(tokens) 502 | new_ar_masks = torch.empty_like(ar_mask) 503 | new_padding_mask = torch.empty_like(padding_mask) 504 | new_loss_mask = torch.empty_like(loss_mask) 505 | new_targets = torch.empty_like(targets) 506 | new_token_type_ids = torch.empty_like(token_type_ids) 507 | batch_size = tokens.shape[0] 508 | for i in range(batch_size): 509 | padding_indices = torch.where(padding_mask[i] == 0)[0] 510 | non_padding_indices = torch.where(padding_mask[i] == 1)[0] 511 | if padding_side == "left": 512 | new_indices = torch.cat((padding_indices, non_padding_indices), dim=0) 513 | else: 514 | new_indices = torch.cat((non_padding_indices, padding_indices), dim=0) 515 | new_tokens[i] = tokens[i].index_select(0, new_indices) 516 | new_ar_masks[i] = ar_mask[i].index_select(0, new_indices) 517 | new_padding_mask[i] = padding_mask[i].index_select(0, new_indices) 518 | new_loss_mask[i] = loss_mask[i].index_select(0, new_indices) 519 | new_targets[i] = targets[i].index_select(0, new_indices) 520 | new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices) 521 | 522 | return ( 523 | new_tokens, 524 | new_ar_masks, 525 | new_padding_mask, 526 | new_loss_mask, 527 | new_targets, 528 | new_token_type_ids, 529 | ) 530 | 531 | def forward(self, batch: dict[str, Tensor]): 532 | device = batch["state"].device 533 | images, img_masks = self.prepare_images(batch) 534 | padded_outs = self.create_input_tokens( 535 | state=batch["state"], 536 | lang_text=batch["prompt"], 537 | actions=batch["action"], 538 | ) 539 | embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs( 540 | images, 541 | img_masks, 542 | padded_outs["input_ids"], 543 | padded_outs["padded_mask"], 544 | padded_outs["attention_mask"], 545 | padded_outs["loss_mask"], 546 | padded_outs["token_type_ids"], 547 | padding_side=self.padding_side, 548 | ) 549 | position_ids = torch.cumsum(pad_masks, dim=1) - 1 550 | token_type_ids = token_type_ids.to(dtype=torch.int64) 551 | past_seen_tokens = 0 552 | cache_position = torch.arange( 553 | past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device 554 | ) 555 | pad_masks = block_causal_update_causal_mask( 556 | attention_mask=pad_masks, 557 | past_key_values=None, 558 | cache_position=cache_position, 559 | input_tensor=embs, 560 | token_type_ids=token_type_ids, 561 | dtype=self.pi0_paligemma.dtype, 562 | attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation, 563 | ) 564 | outputs = self.pi0_paligemma.forward( 565 | input_ids=None, 566 | token_type_ids=None, 567 | attention_mask=pad_masks, 568 | position_ids=position_ids, 569 | past_key_values=None, 570 | inputs_embeds=embs, 571 | use_cache=False, 572 | labels=None, 573 | ) 574 | 575 | logits = outputs.logits 576 | 577 | loss_fct = nn.CrossEntropyLoss(reduction="none") 578 | 579 | # Shift left for next-step prediction 580 | logits = logits[:, 588:-1, :] 581 | targets = targets[:, 588 + 1 :].to(device) # Shift targets 582 | loss_mask = loss_mask[:, 588 + 1 :].to(device) # Ensure correct shape 583 | 584 | # Compute per-token loss 585 | token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)) 586 | 587 | # Apply loss mask 588 | token_loss = token_loss * 
loss_mask.reshape(-1) 589 | 590 | # Compute final loss 591 | loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1) 592 | 593 | # accuracy 594 | with torch.no_grad(): 595 | acc = (logits.argmax(-1) == targets)[loss_mask].float().mean() 596 | 597 | # Return loss dictionary 598 | loss_dict = {"ce_loss": loss.item(), "loss": loss, "acc": acc.item()} 599 | return loss_dict 600 | 601 | def decode_actions_with_fast( 602 | self, 603 | tokens: list[list[int]], 604 | *, 605 | time_horizon: int | None = None, 606 | action_dim: int | None = None, 607 | relaxed_decoding: bool = True, 608 | ) -> np.array: 609 | """ 610 | Adapt original decoding in FAST to always return actions instead of zeros. 611 | """ 612 | self.time_horizon = ( 613 | time_horizon 614 | or self.fast_tokenizer.time_horizon 615 | or self.fast_tokenizer.called_time_horizon 616 | ) 617 | self.action_dim = ( 618 | action_dim 619 | or self.fast_tokenizer.action_dim 620 | or self.fast_tokenizer.called_action_dim 621 | ) 622 | 623 | # Cache the time horizon and action dimension for the next call 624 | self.called_time_horizon = self.time_horizon 625 | self.called_action_dim = self.action_dim 626 | 627 | assert self.time_horizon is not None and self.action_dim is not None, ( 628 | "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." 629 | ) 630 | 631 | decoded_actions = [] 632 | for token in tokens: 633 | try: 634 | decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token) 635 | decoded_dct_coeff = ( 636 | np.array(list(map(ord, decoded_tokens))) 637 | + self.fast_tokenizer.min_token 638 | ) 639 | if relaxed_decoding: 640 | # Expected sequence length 641 | expected_seq_len = self.time_horizon * self.action_dim 642 | diff = expected_seq_len - decoded_dct_coeff.shape[0] 643 | # Apply truncation if too long 644 | if diff < 0: 645 | decoded_dct_coeff = decoded_dct_coeff[ 646 | :expected_seq_len 647 | ] # Truncate on the right 648 | cprint( 649 | f"Relaxed decoding: expected sequence length {expected_seq_len}, got {decoded_dct_coeff.shape[0]}. ", 650 | "yellow", 651 | ) 652 | # Apply padding if too short 653 | elif diff > 0: 654 | decoded_dct_coeff = np.pad( 655 | decoded_dct_coeff, 656 | (0, diff), 657 | mode="constant", 658 | constant_values=0, 659 | ) 660 | cprint( 661 | f"Relaxed decoding: expected sequence length {expected_seq_len}, got {decoded_dct_coeff.shape[0]}. ", 662 | "yellow", 663 | ) 664 | 665 | decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim) 666 | assert decoded_dct_coeff.shape == ( 667 | self.time_horizon, 668 | self.action_dim, 669 | ), ( 670 | f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" 671 | ) 672 | except Exception as e: 673 | print(f"Error decoding tokens: {e}") 674 | print(f"Tokens: {token}") 675 | decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim)) 676 | decoded_actions.append( 677 | idct( 678 | decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho" 679 | ) 680 | ) 681 | return np.stack(decoded_actions) 682 | 683 | def extract_actions( 684 | self, tokens: torch.Tensor, action_horizon: int, action_dim: int 685 | ) -> torch.Tensor: 686 | """ 687 | Extracts actions from predicted output tokens using the FAST model. 688 | 689 | Args: 690 | tokens (torch.Tensor): The input tensor of tokenized outputs. 691 | action_horizon (int): The number of timesteps for actions. 692 | action_dim (int): The dimensionality of each action. 
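
        Added note (a sketch of the decoding path, not a formal spec): the
        generated ids are decoded to text, the span after "Action:" and before
        the trailing "|" is re-tokenized with the PaliGemma tokenizer, mapped
        back to FAST token ids via `_act_tokens_to_paligemma_tokens` (the
        mapping is its own inverse), and finally passed to
        `decode_actions_with_fast`, which recovers the action chunk through an
        inverse DCT.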
693 | 694 | Returns: 695 | torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim). 696 | """ 697 | # Decode predicted output tokens 698 | decoded_tokens = self.paligemma_tokenizer.batch_decode( 699 | tokens, skip_special_tokens=True 700 | ) 701 | cleaned_tokens = [ 702 | tokens_sequence.replace("Action:", "") 703 | .replace(":", "") 704 | .strip() 705 | .split("|")[0] 706 | .strip() 707 | for tokens_sequence in decoded_tokens 708 | ] 709 | raw_action_tokens = [ 710 | self.processor.tokenizer.encode( 711 | sample_tokens, return_tensors="pt", padding=False 712 | ) 713 | for sample_tokens in cleaned_tokens 714 | ] # something like this should be robust #looks good 715 | action_tokens = [ 716 | self._act_tokens_to_paligemma_tokens(raw_action_token) 717 | for raw_action_token in raw_action_tokens 718 | ] 719 | # returns the tensor of decoded actions per sample in a list 720 | decoded_actions = [ 721 | torch.tensor( 722 | self.decode_actions_with_fast( 723 | tok.tolist(), 724 | time_horizon=action_horizon, 725 | action_dim=action_dim, 726 | relaxed_decoding=self.config.relaxed_action_decoding, 727 | ), 728 | device=tokens.device, 729 | ).squeeze(0) 730 | for tok in action_tokens 731 | ] 732 | 733 | return torch.stack( 734 | decoded_actions, 735 | dim=0, 736 | ) 737 | 738 | def generate_actions(self, batch: dict[str, Tensor]): 739 | # normalze, resize, pad, and stack images 740 | images, img_masks = self.prepare_images(batch) 741 | 742 | # create input tokens from state and prompt 743 | padded_outs = self.create_input_tokens( 744 | state=batch["state"], lang_text=batch["prompt"], actions=None 745 | ) 746 | 747 | # embed inputs 748 | tokens = padded_outs["input_ids"] 749 | pad_mask = padded_outs["padded_mask"] 750 | ar_mask = padded_outs["attention_mask"] 751 | loss_mask = padded_outs["loss_mask"] 752 | token_type_ids = padded_outs["token_type_ids"] 753 | embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = ( 754 | self.embed_inputs( 755 | images, 756 | img_masks, 757 | tokens, 758 | pad_mask, 759 | ar_mask, 760 | loss_mask, 761 | token_type_ids, 762 | padding_side="left", 763 | ) 764 | ) 765 | 766 | # generate actions 767 | token_type_ids = token_type_ids.to(dtype=torch.int64) 768 | prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1 769 | output_tokens = self.pi0_paligemma.generate( 770 | input_ids=None, 771 | attention_mask=pad_masks, 772 | position_ids=prefix_position_ids, 773 | past_key_values=None, 774 | inputs_embeds=embs, 775 | use_cache=self.config.use_cache, 776 | max_new_tokens=self.config.max_decoding_steps, 777 | do_sample=False, 778 | num_beams=1, 779 | token_type_ids=token_type_ids, 780 | ) 781 | 782 | # decode actions from output tokens 783 | actions = self.extract_actions( 784 | output_tokens, self.action_horizon, self.action_dim 785 | ) 786 | return actions 787 | 788 | def embed_image(self, image: torch.Tensor): 789 | # Handle different transformers versions 790 | if hasattr(self.pi0_paligemma, "get_image_features"): 791 | return self.pi0_paligemma.get_image_features(image) 792 | else: 793 | return self.pi0_paligemma.model.get_image_features(image) 794 | 795 | def embed_inputs( 796 | self, 797 | images, 798 | img_masks, 799 | tokens, 800 | pad_mask, 801 | ar_mask, 802 | loss_mask, 803 | token_type_ids, 804 | padding_side: str = "right", 805 | ): 806 | bsize = images.shape[0] 807 | device = images.device 808 | 809 | # embed image 810 | images = einops.rearrange(images, "b n c h w -> (b n) c h w") 811 | img_emb = 
self.embed_image(images) 812 | img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize) 813 | num_img_embs = img_emb.shape[1] 814 | img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_img_embs // 3) 815 | img_tgt_tokens = ( 816 | torch.ones_like(img_masks, dtype=torch.long) * self.pad_token_id 817 | ) 818 | img_loss_mask = torch.zeros_like(img_masks, dtype=torch.long) 819 | 820 | # embed language and state 821 | tokens_emb = self.embed_tokens(tokens.to(device)) 822 | num_tokens_embs = tokens_emb.shape[1] 823 | 824 | embs = torch.cat([img_emb, tokens_emb], dim=1) 825 | pad_masks = torch.empty( 826 | (bsize, num_img_embs + num_tokens_embs), device=device, dtype=torch.bool 827 | ) 828 | att_masks = torch.zeros( 829 | (bsize, num_img_embs + num_tokens_embs), device=device, dtype=torch.bool 830 | ) 831 | loss_masks = torch.empty( 832 | (bsize, num_img_embs + num_tokens_embs), device=device, dtype=torch.bool 833 | ) 834 | 835 | pad_masks[:, :num_img_embs] = img_masks 836 | pad_masks[:, num_img_embs:] = pad_mask 837 | att_masks[:, num_img_embs:] = ar_mask 838 | loss_masks[:, :num_img_embs] = img_loss_mask 839 | loss_masks[:, num_img_embs:] = loss_mask 840 | 841 | targets = torch.cat([img_tgt_tokens.to(device), tokens.to(device)], dim=1) 842 | token_type_ids = torch.cat( 843 | [img_loss_mask.to(device), token_type_ids.to(device)], dim=1 844 | ) 845 | 846 | # Shift pad tokens to the left (.generate()) or right (.train()) 847 | embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = ( 848 | self.shift_padding_side( 849 | embs, 850 | att_masks, 851 | pad_masks, 852 | loss_masks, 853 | targets, 854 | token_type_ids, 855 | padding_side=padding_side, 856 | ) 857 | ) 858 | 859 | targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets) 860 | return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids 861 | 862 | 863 | def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True): 864 | # assume no-op when width height fits already 865 | if img.ndim != 4: 866 | raise ValueError(f"(b,c,h,w) expected, but {img.shape}") 867 | 868 | cur_height, cur_width = img.shape[2:] 869 | 870 | ratio = max(cur_width / width, cur_height / height) 871 | resized_height = int(cur_height / ratio) 872 | resized_width = int(cur_width / ratio) 873 | 874 | if interpolate_like_pi: 875 | img = (img * 255.0).to(dtype=torch.uint8) 876 | img = img.permute(0, 2, 3, 1) 877 | original_device = img.device 878 | img = img.to(device="cpu").numpy() 879 | imgs = [] 880 | for sub_img in img: 881 | sub_img = Image.fromarray(sub_img) 882 | resized_img = sub_img.resize((resized_width, resized_height), resample=2) 883 | resized_img = torch.from_numpy(np.array(resized_img)) 884 | imgs.append(resized_img) 885 | img = torch.stack(imgs, dim=0) 886 | img = img.permute(0, 3, 1, 2) 887 | resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0 888 | else: 889 | resized_img = F.interpolate( 890 | img, 891 | size=(resized_height, resized_width), 892 | mode="bilinear", 893 | align_corners=False, 894 | ) 895 | 896 | pad_height = max(0, int(height - resized_height)) 897 | pad_width = max(0, int(width - resized_width)) 898 | 899 | # pad on left and top of image 900 | padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) 901 | return padded_img 902 | --------------------------------------------------------------------------------