├── wall_x ├── __init__.py ├── data │ ├── __init__.py │ └── config.py ├── fusions │ ├── __init__.py │ └── backend.py ├── model │ ├── __init__.py │ └── qwen2_5_based │ │ ├── __init__.py │ │ └── configuration_qwen2_5_vl.py ├── trainer │ └── __init__.py ├── utils │ ├── __init__.py │ └── constant.py └── serving │ ├── policy │ ├── __init__.py │ ├── wall_x_policy.py │ └── utils.py │ ├── __init__.py │ ├── websocket_policy_server.py │ ├── launch_serving.py │ ├── README.md │ └── client.py ├── pyproject.toml ├── .gitattributes ├── assets ├── QRcode_community.jpg └── cot_example_frame.png ├── .gitmodules ├── csrc ├── rot_pos.h ├── window_index.h ├── dual_asym_grouped_gemm.h ├── rope.h ├── rope_index.h ├── ops.cu ├── permute.h ├── README.md ├── window_index.cu ├── rot_pos.cu └── dual_asym_grouped_gemm.cu ├── workspace ├── lerobot_example │ ├── evaluation │ │ └── lerobot_openloop.png │ ├── run.sh │ ├── qwen25_config.json │ ├── libero │ │ └── config_qact_libero_from_vlm.yml │ ├── config_qact_from_vlm.yml │ └── config_qact.yml └── README.md ├── requirements.txt ├── CONTRIBUTING.md ├── .github ├── workflows │ └── lint.yml └── ISSUE_TEMPLATE │ └── bug_report.md ├── .pre-commit-config.yaml ├── scripts ├── merge_tokenizer.py ├── compute_norm_stats.py ├── fake_inference.py ├── vqa_inference.py ├── draw_openloop_plot.py ├── normalize.py └── merge_sharded_weights.py ├── setup.py ├── train_qact.py ├── .gitignore └── README.md /wall_x/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wall_x/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wall_x/fusions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wall_x/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wall_x/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wall_x/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | per-file-ignores = { "__init__.py" = ["F401", "E402"] } 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | workspace/lerobot_example/evaluation/lerobot_openloop.png filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /assets/QRcode_community.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-Square-Robot/wall-x/HEAD/assets/QRcode_community.jpg -------------------------------------------------------------------------------- /assets/cot_example_frame.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/X-Square-Robot/wall-x/HEAD/assets/cot_example_frame.png -------------------------------------------------------------------------------- /wall_x/serving/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from .wall_x_policy import WallXPolicy 2 | 3 | __all__ = ["WallXPolicy"] 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/cutlass"] 2 | path = 3rdparty/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /wall_x/serving/__init__.py: -------------------------------------------------------------------------------- 1 | from .websocket_policy_server import WebsocketPolicyServer, BasePolicy 2 | 3 | __all__ = ["WebsocketPolicyServer", "BasePolicy"] 4 | -------------------------------------------------------------------------------- /csrc/rot_pos.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor fused_rot_pos_emb_cuda( 4 | torch::Tensor inv_freq, 5 | torch::Tensor grid_thw, 6 | int spatial_merge_size); 7 | -------------------------------------------------------------------------------- /workspace/lerobot_example/evaluation/lerobot_openloop.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:aae8646e566b79b64a669956d6d2c779d72809010c802f75b2623ee371444b47 3 | size 985979 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.6.0 2 | torchvision==0.21.0 3 | torchaudio==2.6.0 4 | transformers==4.49.0 5 | accelerate==1.10.1 6 | peft==0.17.1 7 | scipy==1.15.3 8 | torchdiffeq==0.2.5 9 | qwen_vl_utils==0.0.11 10 | -------------------------------------------------------------------------------- /wall_x/model/qwen2_5_based/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_qwen2_5_vl_act import Qwen2_5_VLMoEModel, Qwen2_5_VLMoEForAction 2 | from .configuration_qwen2_5_vl import Qwen2_5_VLConfig 3 | 4 | __all__ = [ 5 | "Qwen2_5_VLMoEModel", 6 | "Qwen2_5_VLMoEForAction", 7 | "Qwen2_5_VLConfig", 8 | ] 9 | -------------------------------------------------------------------------------- /csrc/window_index.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | std::tuple get_window_index_cuda( 5 | torch::Tensor grid_thw, 6 | int spatial_merge_size, 7 | int vit_merger_window_size, 8 | int patch_size, 9 | int spatial_merge_unit); 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Wall-x 2 | 3 | ## Submit a Pull Request 4 | 5 | Before opening a pull request, please make sure your code passes the lint checks. 
6 | 7 | ```bash 8 | # Install pre-commit hooks (run once) 9 | pre-commit install 10 | ``` 11 | 12 | Or 13 | 14 | ```bash 15 | # Manually run all checks 16 | pre-commit run --all-files 17 | ``` 18 | -------------------------------------------------------------------------------- /csrc/dual_asym_grouped_gemm.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void AsymmetricDualExpertGemm( 4 | torch::Tensor input_expert0, 5 | torch::Tensor input_expert1, 6 | torch::Tensor weight_expert0, 7 | torch::Tensor weight_expert1, 8 | torch::Tensor output_expert0, 9 | torch::Tensor output_expert1, 10 | bool trans_a, bool trans_b); 11 | -------------------------------------------------------------------------------- /csrc/rope.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void launch_multimodal_rope_forward( 4 | torch::Tensor q, torch::Tensor k, torch::Tensor cos, torch::Tensor sin, 5 | torch::Tensor q_out, torch::Tensor k_out, 6 | std::vector mrope_section_doubled 7 | ); 8 | 9 | void launch_multimodal_rope_backward( 10 | torch::Tensor grad_q_out, torch::Tensor grad_k_out, 11 | torch::Tensor q, torch::Tensor k, torch::Tensor cos, torch::Tensor sin, 12 | torch::Tensor grad_q, torch::Tensor grad_k, 13 | std::vector mrope_section_doubled 14 | ); 15 | -------------------------------------------------------------------------------- /csrc/rope_index.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | std::tuple get_rope_index( 4 | const torch::optional &input_ids, 5 | const torch::optional &image_grid_thw, 6 | const torch::optional &video_grid_thw, 7 | const torch::optional &second_per_grid_ts, 8 | const torch::optional &attention_mask, 9 | int spatial_merge_size, 10 | int image_token_id, 11 | int video_token_id, 12 | int vision_start_token_id, 13 | float tokens_per_second); 14 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Pre-commit 2 | 3 | on: 4 | push: 5 | branches: [ main, master, develop ] 6 | pull_request: 7 | branches: [ main, master, develop ] 8 | 9 | jobs: 10 | pre-commit: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: '3.11' 21 | 22 | - name: Install pre-commit 23 | run: pip install pre-commit 24 | 25 | - name: Run pre-commit 26 | run: pre-commit run --all-files 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | --- 11 | name: 'Bug Report (English)' 12 | about: Report a bug encountered while using or reproducing the Wall-X model 13 | title: '[Bug] ' 14 | labels: 'bug, needs-triage' 15 | --- 16 | 17 | **Describe the bug** 18 | A clear and concise description of what the bug is. 19 | 20 | **To Reproduce** 21 | Steps to reproduce the behavior: 22 | 1. Go to '...' 23 | 2. Run command '....' 24 | 3. See error 25 | 26 | **Expected behavior** 27 | A clear and concise description of what you expected to happen. 
28 | 29 | **Logs & Screenshots** 30 | If applicable, add the complete error message (traceback) and screenshots to help explain your problem. 31 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | exclude: ".git" 4 | 5 | repos: 6 | - repo: https://github.com/astral-sh/ruff-pre-commit 7 | rev: v0.2.2 8 | hooks: 9 | - id: ruff 10 | args: [ --fix, --exit-non-zero-on-fix ] 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 24.2.0 14 | hooks: 15 | - id: black 16 | 17 | - repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: v4.5.0 19 | hooks: 20 | - id: check-added-large-files 21 | - id: check-ast 22 | - id: check-case-conflict 23 | - id: check-merge-conflict 24 | - id: check-toml 25 | - id: check-yaml 26 | - id: end-of-file-fixer 27 | - id: trailing-whitespace 28 | -------------------------------------------------------------------------------- /workspace/lerobot_example/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 3 | NUM_GPUS=$(echo $CUDA_VISIBLE_DEVICES | tr ',' '\n' | wc -l) 4 | 5 | # print current time 6 | echo "[current time: $(date +'%Y-%m-%d %H:%M:%S')]" 7 | 8 | code_dir="/path/to/wall-x" 9 | config_path="/path/to/wall-x/workspace/lerobot_example" 10 | 11 | # Use a fixed port instead of a random one 12 | export PORT=$((21000 + $RANDOM % 30000)) 13 | 14 | MASTER_PORT=10239 # use 5 digits ports 15 | 16 | export LAUNCHER="accelerate launch --num_processes=$NUM_GPUS --main_process_port=$PORT" 17 | 18 | export SCRIPT="${code_dir}/train_qact.py" 19 | export SCRIPT_ARGS="--config ${config_path}/config_qact.yml --seed $MASTER_PORT" 20 | 21 | echo "Running command: $LAUNCHER $SCRIPT $SCRIPT_ARGS" 22 | 23 | $LAUNCHER $SCRIPT $SCRIPT_ARGS 24 | -------------------------------------------------------------------------------- /csrc/ops.cu: -------------------------------------------------------------------------------- 1 | #include "dual_asym_grouped_gemm.h" 2 | #include "permute.h" 3 | #include "rope.h" 4 | #include "rope_index.h" 5 | #include "rot_pos.h" 6 | #include "window_index.h" 7 | 8 | #include 9 | 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("asym_dual_gmm", &AsymmetricDualExpertGemm, "Asymmetric Dual Expert Grouped GEMM."); 13 | m.def("permute", &moe_permute_topK_op, "Token permutation kernel"); 14 | m.def("unpermute", &moe_recover_topK_op, "Token un-permutation kernel"); 15 | m.def("unpermute_bwd", &moe_recover_topK_bwd_op, "Token un-permutation backward kernel"); 16 | m.def("rope", &launch_multimodal_rope_forward, "Multimodal RoPE forward kernel"); 17 | m.def("rope_bwd", &launch_multimodal_rope_backward, "Multimodal RoPE backward kernel"); 18 | m.def("rope_index", &get_rope_index, "Get RoPE index kernel"); 19 | m.def("rot_pos_emb", &fused_rot_pos_emb_cuda, "Fused Rotary Position Embedding kernel"); 20 | m.def("get_window_index", &get_window_index_cuda, "Get window index kernel"); 21 | } 22 | -------------------------------------------------------------------------------- /csrc/permute.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | using torch::Tensor; 12 | 13 | std::tuple> moe_permute_topK_op( 14 | Tensor input, 15 | Tensor indices, 16 | int64_t num_out_tokens, 17 | std::vector workspace, 18 | int64_t max_expanded_token_num); 19 | 20 | torch::Tensor moe_recover_topK_op( 21 | torch::Tensor input, 22 | torch::Tensor row_id_map, 23 | torch::Tensor prob_opt, 24 | int64_t num_tokens, 25 | int64_t num_topK); 26 | 27 | std::tuple moe_recover_topK_bwd_op( 28 | Tensor input_bwd, 29 | Tensor input_fwd, 30 | Tensor row_id_map, 31 | Tensor prob); 32 | -------------------------------------------------------------------------------- /scripts/merge_tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoProcessor 2 | import os 3 | 4 | processor_path = "/path/to/Qwen2.5-VL-3B-Instruct" 5 | action_tokenizer_path = "/path/to/fast" 6 | use_fast_tokenizer = True 7 | 8 | processor = AutoProcessor.from_pretrained(processor_path, use_fast=True) 9 | processor.tokenizer.padding_side = "left" 10 | 11 | action_tokenizer = AutoProcessor.from_pretrained( 12 | action_tokenizer_path, trust_remote_code=True 13 | ) 14 | 15 | new_tokens = ["<|propri|>", "<|action|>"] 16 | new_tokens += [f"<|action_token_{i}|>" for i in range(action_tokenizer.vocab_size)] 17 | num_added_tokens = processor.tokenizer.add_tokens(new_tokens) 18 | 19 | begin_idx_token = "<|action_token_0|>" 20 | token_id = processor.tokenizer.convert_tokens_to_ids(begin_idx_token) 21 | processor.tokenizer.init_kwargs["action_token_start_index"] = token_id 22 | processor.tokenizer.init_kwargs["action_token_vocab_size"] = action_tokenizer.vocab_size 23 | 24 | new_tokenizer_dir = "/path/to/new_tokenizer" 25 | os.makedirs(new_tokenizer_dir, exist_ok=True) 26 | processor.save_pretrained(new_tokenizer_dir) 27 | -------------------------------------------------------------------------------- /csrc/README.md: -------------------------------------------------------------------------------- 1 | # Fusion Operators (CSRC) 2 | 3 | High-performance CUDA kernels for accelerating model training, with specialized support for multimodal and MoE architectures. 4 | 5 | ## Operators 6 | 7 | ### Asymmetric Dual Expert GEMM 8 | - `asym_dual_gmm`: Simultaneous matrix multiplication for two experts 9 | - Supports all transpose combinations (NN, TN, NT, TT) 10 | 11 | ### Token Permutation 12 | - `permute`: Token permutation for MoE routing 13 | - `unpermute`: Token recovery after expert computation 14 | - `unpermute_bwd`: Backward pass for token recovery 15 | 16 | ### Multimodal RoPE 17 | - `rope`: Rotary Position Embedding forward pass 18 | - `rope_bwd`: RoPE backward pass 19 | - `rope_index`: Generates position indices for multimodal RoPE 20 | - `rot_pos_emb`: Fused rotary position embedding computation 21 | 22 | ### Vision Transformer Optimization 23 | - `get_window_index`: Window attention index generation 24 | 25 | 26 | ## Acknowledgments 27 | 28 | The `permute` and `unpermute` operators are adapted from [fanshiqing/grouped_gemm](https://github.com/fanshiqing/grouped_gemm). Thanks for their open-source contributions. 
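## Python Usage

The kernels listed above are compiled into a single `wallx_csrc` extension (see `csrc/ops.cu` for the pybind11 registrations and `setup.py` for the build). Below is a minimal sanity-check sketch; it assumes the extension has been built and installed on a CUDA machine, e.g. via `pip install -e .`:

```python
# Verify the wallx_csrc extension loads and every registered kernel is bound.
import torch
import wallx_csrc  # extension name comes from setup.py's CUDAExtension("wallx_csrc", ...)

expected_ops = [
    "asym_dual_gmm", "permute", "unpermute", "unpermute_bwd",
    "rope", "rope_bwd", "rope_index", "rot_pos_emb", "get_window_index",
]
missing = [op for op in expected_ops if not hasattr(wallx_csrc, op)]
assert not missing, f"Missing kernels: {missing}"
print(f"wallx_csrc loaded; CUDA available: {torch.cuda.is_available()}")
```

Argument shapes and dtypes for the individual kernels follow the signatures declared in the corresponding headers (`rope.h`, `rope_index.h`, `rot_pos.h`, `window_index.h`, `permute.h`, `dual_asym_grouped_gemm.h`).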
29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from pathlib import Path 4 | from setuptools import setup, find_packages 5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 6 | 7 | cwd = Path(os.path.dirname(os.path.abspath(__file__))) 8 | 9 | nvcc_flags = [ 10 | "-std=c++17", # NOTE: CUTLASS requires c++17 11 | "-DENABLE_BF16", # Enable BF16 for cuda_version >= 11 12 | ] 13 | 14 | env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 15 | 16 | if env_arch_list: 17 | # Let PyTorch builder to choose device to target for. 18 | device_capability = "" 19 | else: 20 | device_capability = torch.cuda.get_device_capability() 21 | device_capability = f"{device_capability[0]}{device_capability[1]}" 22 | 23 | if device_capability: 24 | nvcc_flags.extend( 25 | [ 26 | f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}", 27 | f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}", 28 | ] 29 | ) 30 | 31 | ext_modules = [ 32 | CUDAExtension( 33 | "wallx_csrc", 34 | [ 35 | "csrc/ops.cu", 36 | "csrc/dual_asym_grouped_gemm.cu", 37 | "csrc/permute.cu", 38 | "csrc/rope.cu", 39 | "csrc/rope_index.cu", 40 | "csrc/rot_pos.cu", 41 | "csrc/window_index.cu", 42 | ], 43 | include_dirs=[f"{cwd}/3rdparty/cutlass/include/", f"{cwd}/csrc"], 44 | extra_compile_args={ 45 | "cxx": ["-fopenmp", "-fPIC", "-Wno-strict-aliasing"], 46 | "nvcc": nvcc_flags, 47 | }, 48 | ) 49 | ] 50 | 51 | setup( 52 | name="wall_x", 53 | version="1.0.1", 54 | author="X2Robot Team", 55 | classifiers=[ 56 | "Programming Language :: Python :: 3", 57 | "License :: OSI Approved :: BSD License", 58 | "Operating System :: Unix", 59 | ], 60 | packages=find_packages(), 61 | ext_modules=ext_modules, 62 | cmdclass={"build_ext": BuildExtension}, 63 | ) 64 | -------------------------------------------------------------------------------- /scripts/compute_norm_stats.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import tqdm 4 | from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata 5 | from wall_x.data.load_lerobot_dataset import KEY_MAPPINGS 6 | import normalize 7 | import numpy as np 8 | import argparse 9 | 10 | 11 | def load_config(config_path): 12 | """Load configuration from YAML file.""" 13 | with open(config_path, "r") as f: 14 | config = yaml.load(f, Loader=yaml.FullLoader) 15 | 16 | config["data"]["model_type"] = config.get("model_type") 17 | 18 | return config 19 | 20 | 21 | def load_lerobot_dataset(repo_id, root, action_horizon, args): 22 | dataset_meta = LeRobotDatasetMetadata(repo_id) 23 | dataset = LeRobotDataset( 24 | repo_id, 25 | root=root, 26 | delta_timestamps={ 27 | key: [t / dataset_meta.fps for t in range(action_horizon)] 28 | for key in [KEY_MAPPINGS[repo_id]["action"]] 29 | }, 30 | video_backend="pyav", 31 | ) 32 | num_batches = len(dataset) // args.batch_size 33 | generator = torch.Generator() 34 | generator.manual_seed(args.seed) 35 | data_loader = torch.utils.data.DataLoader( 36 | dataset, 37 | batch_size=args.batch_size, 38 | shuffle=False, 39 | drop_last=True, 40 | generator=generator, 41 | num_workers=args.num_workers, 42 | persistent_workers=True if args.num_workers > 0 else False, 43 | ) 44 | return data_loader, num_batches 45 | 46 | 47 | if __name__ == "__main__": 48 | # set args 49 | parser = 
argparse.ArgumentParser() 50 | parser.add_argument("--batch_size", type=int, default=256) 51 | parser.add_argument("--num_workers", type=int, default=2) 52 | parser.add_argument("--seed", type=int, default=0) 53 | args = parser.parse_args() 54 | 55 | # Configs 56 | path = "/path/to/config.yml" 57 | output_path = "/path/to/output" 58 | config = load_config(path) 59 | lerobot_config = config["data"]["lerobot_config"] 60 | repo_id = lerobot_config.get("repo_id", None) 61 | root = lerobot_config.get("root", None) 62 | assert repo_id is not None, "repo id is required" 63 | action_horizon = config["data"].get("action_horizon", 32) 64 | 65 | data_loader, num_batches = load_lerobot_dataset(repo_id, root, action_horizon, args) 66 | 67 | keys = ["state", "action"] 68 | stats = {key: normalize.RunningStats() for key in keys} 69 | for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"): 70 | for key in keys: 71 | stats[key].update(np.asarray(batch[KEY_MAPPINGS[repo_id][key]])) 72 | norm_stats = { 73 | KEY_MAPPINGS[repo_id][key]: stats.get_statistics() 74 | for key, stats in stats.items() 75 | } 76 | 77 | output_path = output_path + "/" + repo_id 78 | print(f"Writing stats to: {output_path}") 79 | normalize.save(output_path, norm_stats) 80 | -------------------------------------------------------------------------------- /workspace/lerobot_example/qwen25_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "Qwen2_5_VLForConditionalGeneration" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 151643, 7 | "eos_token_id": 151645, 8 | "vision_start_token_id": 151652, 9 | "vision_end_token_id": 151653, 10 | "vision_token_id": 151654, 11 | "image_token_id": 151655, 12 | "video_token_id": 151656, 13 | "hidden_act": "silu", 14 | "hidden_size": 2048, 15 | "initializer_range": 0.02, 16 | "intermediate_size": 11008, 17 | "max_position_embeddings": 128000, 18 | "max_window_layers": 70, 19 | "model_type": "qwen2_5_vl", 20 | "num_attention_heads": 16, 21 | "num_hidden_layers": 36, 22 | "num_key_value_heads": 2, 23 | "rms_norm_eps": 1e-06, 24 | "rope_theta": 1000000.0, 25 | "sliding_window": 32768, 26 | "tie_word_embeddings": true, 27 | "torch_dtype": "bfloat16", 28 | "transformers_version": "4.41.2", 29 | "_attn_implementation": "flash_attention_2", 30 | "use_cache": true, 31 | "use_sliding_window": false, 32 | "vision_config": { 33 | "depth": 32, 34 | "hidden_act": "silu", 35 | "hidden_size": 1280, 36 | "intermediate_size": 3420, 37 | "num_heads": 16, 38 | "in_chans": 3, 39 | "out_hidden_size": 2048, 40 | "patch_size": 14, 41 | "spatial_merge_size": 2, 42 | "spatial_patch_size": 14, 43 | "window_size": 112, 44 | "fullatt_block_indexes": [ 45 | 7, 46 | 15, 47 | 23, 48 | 31 49 | ], 50 | "tokens_per_second": 2, 51 | "temporal_patch_size": 2 52 | }, 53 | "rope_scaling": { 54 | "type": "mrope", 55 | "mrope_section": [ 56 | 16, 57 | 24, 58 | 24 59 | ] 60 | }, 61 | "vocab_size": 151936, 62 | "num_experts": 2, 63 | "experts":[ 64 | { 65 | "hidden_size": 2048, 66 | "intermediate_size": 11008, 67 | "hidden_act": "silu" 68 | }, 69 | { 70 | "hidden_size": 2048, 71 | "intermediate_size": 2048, 72 | "hidden_act": "silu" 73 | } 74 | ], 75 | "dof_config": { 76 | "follow_left_ee_cartesian_pos": 3, 77 | "follow_left_ee_rotation": 3, 78 | "follow_left_gripper": 1, 79 | "follow_right_ee_cartesian_pos": 3, 80 | "follow_right_ee_rotation": 3, 81 | "follow_right_gripper": 1, 82 | "head_actions": 2, 83 | "height": 1, 84 | "car_pose": 3 
85 | }, 86 | "agent_pos_config": { 87 | "follow_left_ee_cartesian_pos": 3, 88 | "follow_left_ee_rotation": 3, 89 | "follow_left_gripper": 1, 90 | "follow_right_ee_cartesian_pos": 3, 91 | "follow_right_ee_rotation": 3, 92 | "follow_right_gripper": 1, 93 | "head_actions": 2, 94 | "height": 1, 95 | "car_pose": 3 96 | }, 97 | "noise_scheduler": { 98 | "beta_alpha": 1.5, 99 | "beta_beta": 1.0, 100 | "s": 0.999, 101 | "num_inference_timesteps": 5 102 | }, 103 | "dim_inputs": [2048,2048], 104 | "attention_moe": false, 105 | "mlp_moe": true 106 | } 107 | -------------------------------------------------------------------------------- /scripts/fake_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wall_x.model.qwen2_5_based.modeling_qwen2_5_vl_act import Qwen2_5_VLMoEForAction 3 | 4 | model_path = "/path/to/model" 5 | model = Qwen2_5_VLMoEForAction.from_pretrained(model_path) 6 | model.eval() 7 | 8 | # Gen Fake data 9 | batch_size = 1 10 | seq_length = 50 11 | 12 | torch.manual_seed(0) 13 | fake_input_ids = torch.randint( 14 | 0, len(model.processor.tokenizer), (batch_size, seq_length), dtype=torch.long 15 | ) 16 | fake_attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) 17 | fake_moe_token_types = torch.zeros((batch_size, seq_length), dtype=torch.long) 18 | fake_position_ids = ( 19 | torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) 20 | ) 21 | fake_proprioception = torch.randn((batch_size, 1, 20), dtype=torch.float32) 22 | fake_agent_pos_mask = torch.ones((batch_size, 1, 20), dtype=torch.float32) 23 | fake_dof_mask = torch.ones((batch_size, 32, 20), dtype=torch.float32) 24 | fake_dataset_names = ["x2_normal"] 25 | 26 | 27 | device = "cuda" 28 | 29 | model = model.to(device) 30 | model = model.bfloat16() 31 | 32 | fake_input_ids = fake_input_ids.to(device) 33 | fake_attention_mask = fake_attention_mask.to(device) 34 | fake_moe_token_types = fake_moe_token_types.to(device) 35 | fake_position_ids = fake_position_ids.to(device) 36 | fake_proprioception = fake_proprioception.to(device).bfloat16() 37 | fake_agent_pos_mask = fake_agent_pos_mask.to(device).bfloat16() 38 | fake_dof_mask = fake_dof_mask.to(device).bfloat16() 39 | 40 | try: 41 | with torch.no_grad(): 42 | outputs = model( 43 | input_ids=fake_input_ids, 44 | attention_mask=fake_attention_mask, 45 | moe_token_types=fake_moe_token_types, 46 | position_ids=fake_position_ids, 47 | proprioception=fake_proprioception, 48 | agent_pos_mask=fake_agent_pos_mask, 49 | dof_mask=fake_dof_mask, 50 | dataset_names=fake_dataset_names, 51 | mode="validate", 52 | ) 53 | 54 | print("✅ Fake inference test successful!") 55 | print(f"Output logits shape: {outputs.logits.shape}") 56 | print(f"Output logits dtype: {outputs.logits.dtype}") 57 | print(f"Output logits device: {outputs.logits.device}") 58 | 59 | # Check if output is reasonable 60 | if outputs.logits.shape == (batch_size, seq_length, model.config.vocab_size): 61 | print("✅ Output shape correct") 62 | else: 63 | print("❌ Output shape incorrect") 64 | 65 | if not torch.isnan(outputs.logits).any(): 66 | print("✅ Output contains no NaN values") 67 | else: 68 | print("❌ Output contains NaN values") 69 | 70 | if not torch.isinf(outputs.logits).any(): 71 | print("✅ Output contains no infinity values") 72 | else: 73 | print("❌ Output contains infinity values") 74 | 75 | print("Output logits statistics:") 76 | print(f" Min value: {outputs.logits.min().item():.4f}") 77 | print(f" Max value: 
{outputs.logits.max().item():.4f}") 78 | print(f" Mean: {outputs.logits.mean().item():.4f}") 79 | print(f" Standard deviation: {outputs.logits.std().item():.4f}") 80 | 81 | except Exception as e: 82 | print(f"❌ Fake inference test failed: {e}") 83 | import traceback 84 | 85 | traceback.print_exc() 86 | -------------------------------------------------------------------------------- /scripts/vqa_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from transformers import AutoProcessor 4 | import yaml 5 | 6 | from wall_x.model.qwen2_5_based.modeling_qwen2_5_vl_act import Qwen2_5_VLMoEForAction 7 | 8 | 9 | class VQAWrapper(object): 10 | def __init__(self, model_path: str, train_config: dict): 11 | self.device = self._setup_device() 12 | self.processor = self._load_processor(model_path) 13 | self.model = self._load_model(model_path, train_config) 14 | 15 | def _setup_device(self) -> str: 16 | if torch.cuda.is_available(): 17 | return "cuda" 18 | else: 19 | return "cpu" 20 | 21 | def _load_processor(self, model_path: str) -> AutoProcessor: 22 | return AutoProcessor.from_pretrained(model_path, trust_remote_code=True) 23 | 24 | def _load_model( 25 | self, model_path: str, train_config: dict 26 | ) -> Qwen2_5_VLMoEForAction: 27 | model = Qwen2_5_VLMoEForAction.from_pretrained( 28 | model_path, train_config=train_config 29 | ) 30 | if self.device == "cuda": 31 | model = model.to(self.device, dtype=torch.bfloat16) 32 | else: 33 | model.to(self.device) 34 | model.eval() 35 | return model 36 | 37 | def generate(self, image: Image.Image, text: str, **kwargs) -> str: 38 | messages = [ 39 | { 40 | "role": "user", 41 | "content": [{"type": "image"}, {"type": "text", "text": text}], 42 | } 43 | ] 44 | text_prompt = self.processor.apply_chat_template( 45 | messages, tokenize=False, add_generation_prompt=True 46 | ) 47 | inputs = self.processor(text=[text_prompt], images=[image], return_tensors="pt") 48 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 49 | 50 | generation_params = { 51 | "max_new_tokens": 1024, # default value, can be overridden by kwargs 52 | "do_sample": False, 53 | "eos_token_id": self.processor.tokenizer.eos_token_id, 54 | "pad_token_id": self.processor.tokenizer.pad_token_id, 55 | **kwargs, 56 | } 57 | 58 | with torch.no_grad(): 59 | generated_ids = self.model.generate(**inputs, **generation_params) 60 | 61 | generated_ids = [ 62 | output_ids[len(input_ids) :] 63 | for input_ids, output_ids in zip(inputs["input_ids"], generated_ids) 64 | ] 65 | response = self.processor.batch_decode( 66 | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False 67 | )[0] 68 | return response 69 | 70 | 71 | if __name__ == "__main__": 72 | MODEL_PATH_FOR_MODULE_TEST = "/path/to/model" 73 | train_config_path = "/path/to/config.yaml" 74 | with open(train_config_path, "r") as f: 75 | train_config = yaml.load(f, Loader=yaml.FullLoader) 76 | wrapper = VQAWrapper( 77 | model_path=MODEL_PATH_FOR_MODULE_TEST, train_config=train_config 78 | ) 79 | 80 | try: 81 | test_question = "To move the red block in the plate with same color, what should you do next? Think step by step." 
82 | 83 | # Local Image 84 | img = Image.open("/path/to/wall-x/assets/cot_example_frame.png").convert("RGB") 85 | # Internet Image 86 | # import requests 87 | # test_image_url = "https://www.ilankelman.org/stopsigns/australia.jpg" 88 | # img = Image.open(requests.get(test_image_url, stream=True).raw).convert("RGB") 89 | 90 | answer = wrapper.generate(img, test_question) 91 | 92 | print("model answer:", answer) 93 | except Exception as e: 94 | print(f"model answer fail: {e}") 95 | -------------------------------------------------------------------------------- /wall_x/data/config.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | from dataclasses import dataclass, field 3 | from qwen_vl_utils.vision_process import MIN_PIXELS, MAX_PIXELS, IMAGE_FACTOR 4 | 5 | 6 | # Tactile sensor file mapping for data processing 7 | TACTILE_FILE_MAPPING = { 8 | "tactile_data_left": "left_tactile", 9 | "tactile_data_right": "right_tactile", 10 | } 11 | 12 | # Supported action datasets 13 | ACTION_DATASET_NAMES = [ 14 | "x2_normal", 15 | "agibotworld_alpha", 16 | "droid", 17 | "fractal", 18 | "bridge_data_v2", 19 | "DobbE", 20 | "RH20T", 21 | "UMI-biarm", 22 | "austin_buds", 23 | "austin_sailor", 24 | "austin_sirius", 25 | "bc_z", 26 | "berkeley_autolab_ur5", 27 | "berkeley_cable_routing", 28 | "berkeley_fanuc_manipulation", 29 | "dlr_edan_shared_control", 30 | "fmb", 31 | "furniture_bench", 32 | "jaco_play", 33 | "nyu_rot", 34 | "stanford_hydra", 35 | "stanford_kuka_multimodal", 36 | "taco_play", 37 | "utaustin_mutex", 38 | "viola", 39 | "physical-intelligence/libero", 40 | "lerobot/aloha_mobile_cabinet", 41 | ] 42 | 43 | # Supported multimodal datasets 44 | MULTIMODAL_DATASET_NAMES = [ 45 | "x2_multimodal_from_action", 46 | "x2_multimodal", 47 | "x2_subtask_generation", 48 | "multimodal_CapsFusion", 49 | "multimodal_Robo2VLM", 50 | "multimodal_RoboPoint", 51 | "multimodal_EQA", 52 | "multimodal_Cambrian", 53 | "multimodal_pixmo", 54 | "multimodal_VQAv2", 55 | "multimodal_COCO", 56 | ] 57 | 58 | 59 | @dataclass 60 | class X2RDataProcessingConfig: 61 | """Configuration class for X2R data processing pipeline. 62 | 63 | This class contains all the necessary parameters for processing robotic data 64 | including camera mappings, tactile sensor configurations, action predictions, 65 | and various processing options. 
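    Example (illustrative sketch only -- the keys and values below are
    placeholders, not a recommended configuration):

        config = X2RDataProcessingConfig(
            predict_action_keys=["follow_left_gripper"],
            obs_action_keys=["follow_left_gripper"],
            train_test_split=0.9,
        )
        config = config.update(split_seed=7)
        assert config.as_dict()["split_seed"] == 7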
66 | """ 67 | 68 | # Action prediction configuration 69 | predict_action_keys: List[str] = field(default_factory=list) 70 | obs_action_keys: List[str] = field(default_factory=list) 71 | 72 | # Image resolution settings for different views 73 | resolution: Dict[str, int] = field( 74 | default_factory=lambda: { 75 | "face_view": -1, 76 | "left_wrist_view": 128, 77 | "right_wrist_view": 128, 78 | } 79 | ) 80 | 81 | # Dataset splitting 82 | train_test_split: float = 0.9 83 | split_seed: int = 42 84 | 85 | # Instruction handling 86 | priority_order: Optional[Dict[str, float]] = None 87 | 88 | # Vision model parameters 89 | model_type: str = "qwen2_5" 90 | max_pixels: int = MAX_PIXELS 91 | min_pixels: int = MIN_PIXELS 92 | image_factor: int = IMAGE_FACTOR 93 | 94 | generate_subtask_ratio: float = 0.0 95 | 96 | def __post_init__(self): 97 | """Post-initialization validation and setup.""" 98 | # Validate train/test split 99 | if not 0 < self.train_test_split < 1: 100 | raise ValueError( 101 | f"train_test_split must be between 0 and 1, got {self.train_test_split}" 102 | ) 103 | 104 | def as_dict(self) -> Dict: 105 | """Convert configuration to dictionary format. 106 | 107 | Returns: 108 | Dict: Configuration as dictionary 109 | """ 110 | return self.__dict__ 111 | 112 | def update(self, **kwargs) -> "X2RDataProcessingConfig": 113 | """Update configuration parameters. 114 | 115 | Args: 116 | **kwargs: Key-value pairs to update 117 | 118 | Returns: 119 | X2RDataProcessingConfig: Updated configuration instance 120 | """ 121 | for key, value in kwargs.items(): 122 | if hasattr(self, key): 123 | setattr(self, key, value) 124 | else: 125 | raise ValueError(f"Unknown configuration parameter: {key}") 126 | return self 127 | -------------------------------------------------------------------------------- /workspace/lerobot_example/libero/config_qact_libero_from_vlm.yml: -------------------------------------------------------------------------------- 1 | # Training Configuration for Wall-X Robotic Multi-Modal Learning 2 | # This configuration supports multi-modal learning with vision, language, and action data 3 | 4 | # Model and paths configuration 5 | log_name: "opensource_training" 6 | log_project: "libero" 7 | model_type: qwen2_5 8 | use_fast_tokenizer: true 9 | pretrained_wallx_path: "/path/to/qwen/" 10 | action_tokenizer_path: "/path/to/fast/" 11 | qwen_vl_act_config_path: "/path/to/qwen25_config.json" 12 | 13 | save_path: "/path/to/save" 14 | # Torch Profile 15 | profile: False 16 | profile_save_path: /path/to/profile/ 17 | profile_wait_iters: 10 18 | profile_warmup_iters: 5 19 | profile_active_iters: 2 20 | 21 | # Training hyperparameters 22 | num_warmup_steps: 100 23 | num_training_steps: 64000000 24 | learning_rate: 0.00005 25 | min_lr: 0.00005 26 | num_epoch: 100 27 | gradient_accumulation_steps: 1 28 | batch_size_per_gpu: 8 29 | padding_side: left 30 | epoch_save_interval: 1 31 | 32 | # Robot configuration - Define degrees of freedom for each component 33 | dof_config: 34 | follow_left_ee_cartesian_pos: 3 # Left end-effector Cartesian position 35 | follow_left_ee_rotation: 3 # Left end-effector rotation 36 | follow_left_gripper: 1 # Left gripper control 37 | follow_right_ee_cartesian_pos: 3 # Right end-effector Cartesian position 38 | follow_right_ee_rotation: 3 # Right end-effector rotation 39 | follow_right_gripper: 1 # Right gripper control 40 | head_actions: 2 # Head/camera movement 41 | height: 1 # Mobile base height control 42 | car_pose: 3 # Mobile base pose (x, y, theta) 43 | 44 | # 
Agent proprioception configuration (typically matches DOF config) 45 | agent_pos_config: 46 | follow_left_ee_cartesian_pos: 3 47 | follow_left_ee_rotation: 3 48 | follow_left_gripper: 1 49 | follow_right_ee_cartesian_pos: 3 50 | follow_right_ee_rotation: 3 51 | follow_right_gripper: 1 52 | head_actions: 2 53 | height: 1 54 | car_pose: 3 55 | 56 | norm_stats_path: "wall-x/workspace/lerobot_example/libero/libero_norm_stats.json" 57 | 58 | enable_customized_robot_config: true 59 | customized_robot_config: 60 | name: "physical-intelligence/libero" 61 | customized_dof_config: 62 | "panda_action_eef_with_gripper": 7 63 | 64 | customized_agent_pos_config: 65 | "panda_state_eef_with_gripper": 8 66 | 67 | # Checkpoint resuming configuration 68 | # resume: 69 | # ckpt: "/path/to/ckpt" 70 | # load_ckpt_only: false 71 | 72 | # Data configuration 73 | data: 74 | use_lerobot: true 75 | 76 | # LeRobot dataset configuration 77 | lerobot_config: 78 | repo_id: "physical-intelligence/libero" 79 | root: null 80 | episodes: null 81 | image_transforms: null 82 | delta_timestamps: null 83 | tolerance_s: 1e-4 84 | revision: null 85 | force_cache_sync: false 86 | download_videos: true 87 | video_backend: null 88 | 89 | action_horizon: 32 90 | train_test_split: 0.95 91 | 92 | # Action keys for observation and prediction 93 | obs_action_keys: 94 | - follow_left_ee_cartesian_pos 95 | - follow_left_ee_rotation 96 | - follow_left_gripper 97 | - follow_right_ee_cartesian_pos 98 | - follow_right_ee_rotation 99 | - follow_right_gripper 100 | - head_actions 101 | - height 102 | - car_pose 103 | 104 | predict_action_keys: 105 | - follow_left_ee_cartesian_pos 106 | - follow_left_ee_rotation 107 | - follow_left_gripper 108 | - follow_right_ee_cartesian_pos 109 | - follow_right_ee_rotation 110 | - follow_right_gripper 111 | - head_actions 112 | - height 113 | - car_pose 114 | 115 | # Image resolution configuration for different camera views 116 | resolution: 117 | face_view: 256 118 | left_wrist_view: 256 119 | right_wrist_view: 256 120 | move1_view: 256 121 | move2_view: 256 122 | top_view: 256 123 | wall_view: 256 124 | multi_modal: 256 125 | -------------------------------------------------------------------------------- /train_qact.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import yaml 5 | import wandb 6 | import accelerate 7 | from argparse import ArgumentParser 8 | from accelerate import ( 9 | Accelerator, 10 | DistributedDataParallelKwargs, 11 | DataLoaderConfiguration, 12 | ) 13 | 14 | from wall_x.trainer.qwen_vl_act_trainer import QwenVlAct_Trainer 15 | 16 | 17 | def setup_environment(): 18 | """Set up environment variables for training.""" 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 20 | 21 | 22 | def load_config(config_path): 23 | """Load configuration from YAML file.""" 24 | with open(config_path, "r") as f: 25 | config = yaml.load(f, Loader=yaml.FullLoader) 26 | 27 | # Set model_type in data config if not already set 28 | config["data"]["model_type"] = config.get("model_type") 29 | 30 | return config 31 | 32 | 33 | def setup_accelerator(config): 34 | """Initialize and configure the accelerator for distributed training.""" 35 | print( 36 | f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] Preparing accelerator" 37 | ) 38 | 39 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 40 | accelerator_dataloader_config = DataLoaderConfiguration(dispatch_batches=False) 41 | 42 | if 
config.get("FSDP2", False): 43 | # Use Fully Sharded Data Parallel (FSDP) version 2 44 | fsdp_plugin = accelerate.utils.dataclasses.FullyShardedDataParallelPlugin( 45 | fsdp_version=2, reshard_after_forward=True 46 | ) 47 | print("[INFO] Using FSDP version 2 for distributed training") 48 | else: 49 | fsdp_plugin = None 50 | 51 | if config.get("torch_compile", False): 52 | # Use Torch Dynamo for compilation 53 | dynamo_plugin = accelerate.utils.TorchDynamoPlugin( 54 | backend="inductor", 55 | mode="default", 56 | fullgraph=False, 57 | dynamic=False, 58 | ) 59 | print("[INFO] Using Torch Dynamo for compilation") 60 | else: 61 | dynamo_plugin = None 62 | 63 | accelerator = Accelerator( 64 | kwargs_handlers=[ddp_kwargs], 65 | mixed_precision="bf16", 66 | fsdp_plugin=fsdp_plugin, 67 | dynamo_plugin=dynamo_plugin, 68 | dataloader_config=accelerator_dataloader_config, 69 | gradient_accumulation_steps=config.get("gradient_accumulation_steps", 1), 70 | ) 71 | 72 | print( 73 | f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] Accelerator initialization complete" 74 | ) 75 | 76 | return accelerator 77 | 78 | 79 | def setup_logging(config, accelerator): 80 | """Set up logging with wandb for the main process.""" 81 | if not accelerator.is_main_process: 82 | return None 83 | 84 | # Create save directory if it doesn't exist 85 | save_path = config["save_path"] 86 | if not os.path.exists(save_path): 87 | print(f"Save path {save_path} does not exist, creating directory.") 88 | os.makedirs(save_path, exist_ok=True) 89 | 90 | print("Configuration:") 91 | print("=" * 50) 92 | print(json.dumps(config, indent=2, ensure_ascii=False)) 93 | print("=" * 50) 94 | 95 | # Initialize wandb logger 96 | logger = wandb.init( 97 | project=config["log_project"], 98 | name=config["log_name"], 99 | save_code=False, 100 | force=False, 101 | ) 102 | 103 | return logger 104 | 105 | 106 | def main(args): 107 | """Main training function.""" 108 | setup_environment() 109 | 110 | # Load configuration 111 | config = load_config(args.config) 112 | 113 | # Set up accelerator 114 | accelerator = setup_accelerator(config) 115 | 116 | # Set up logging 117 | logger = setup_logging(config, accelerator) 118 | 119 | # Initialize trainer 120 | trainer = QwenVlAct_Trainer( 121 | config=config, 122 | logger=logger, 123 | accelerator=accelerator, 124 | seed=args.seed, 125 | data_config_path=args.config, 126 | ) 127 | 128 | # Start training 129 | trainer.fit() 130 | 131 | 132 | if __name__ == "__main__": 133 | parser = ArgumentParser(description="Training script for Wall-X model") 134 | parser.add_argument( 135 | "--config", type=str, required=True, help="Path to configuration YAML file" 136 | ) 137 | parser.add_argument( 138 | "--seed", type=int, default=42, help="Random seed for reproducibility" 139 | ) 140 | 141 | args = parser.parse_args() 142 | main(args) 143 | -------------------------------------------------------------------------------- /scripts/draw_openloop_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import argparse 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | from wall_x.model.qwen2_5_based.modeling_qwen2_5_vl_act import Qwen2_5_VLMoEForAction 8 | from wall_x.data.load_lerobot_dataset import load_test_dataset, get_data_configs 9 | 10 | 11 | def load_config(config_path): 12 | """Load configuration from YAML file.""" 13 | with open(config_path, "r") as f: 14 | config = yaml.load(f, Loader=yaml.FullLoader) 15 | 16 
| config["data"]["model_type"] = config.get("model_type") 17 | 18 | return config 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--pred_horizon", type=int, default=32) 24 | parser.add_argument("--origin_action_dim", type=int, default=7) 25 | args = parser.parse_args() 26 | 27 | origin_action_dim = args.origin_action_dim 28 | pred_horizon = args.pred_horizon 29 | 30 | # get train config 31 | model_path = "/path/to/model" 32 | action_tokenizer_path = "/path/to/action/tokenizer" 33 | save_dir = "/path/to/save/dir" 34 | path = "/path/to/train/config" 35 | config = load_config(path) 36 | 37 | # load model with customized robot config 38 | model = Qwen2_5_VLMoEForAction.from_pretrained( 39 | model_path, train_config=config, action_tokenizer_path=action_tokenizer_path 40 | ) 41 | model.eval() 42 | model = model.to("cuda") 43 | model = model.bfloat16() 44 | 45 | # get test dataloader 46 | dataload_config = get_data_configs(config["data"]) 47 | lerobot_config = dataload_config.get("lerobot_config", {}) 48 | dataset = load_test_dataset(config, lerobot_config, seed=42) 49 | dataloader = dataset.get_dataloader() 50 | 51 | total_frames = len(dataloader) 52 | 53 | predict_mode = "fast" if config.get("use_fast_tokenizer", False) else "diffusion" 54 | action_dim = 20 if predict_mode == "diffusion" else origin_action_dim 55 | gt_traj = torch.zeros((total_frames, origin_action_dim)) 56 | pred_traj = torch.zeros((total_frames, origin_action_dim)) 57 | 58 | # use tqdm to show the progress 59 | for idx, batch in tqdm( 60 | enumerate(dataloader), total=total_frames, desc="predicting" 61 | ): 62 | if idx % pred_horizon == 0 and idx + pred_horizon < total_frames: 63 | batch = batch.to("cuda") 64 | with torch.no_grad(): 65 | outputs = model( 66 | **batch, 67 | action_dim=action_dim, 68 | pred_horizon=pred_horizon, 69 | mode="predict", 70 | predict_mode=predict_mode, 71 | ) 72 | pred_traj[idx : idx + pred_horizon] = ( 73 | outputs["predict_action"][:, :, :origin_action_dim] 74 | .detach() 75 | .cpu() 76 | .squeeze(0) 77 | ) 78 | 79 | # Denormalize ground truth actions 80 | gt_action_chunk = batch["action_chunk"][:, :, :origin_action_dim] 81 | dof_mask = batch["dof_mask"].to(gt_action_chunk.dtype) 82 | denormalized_gt = ( 83 | model.action_preprocessor.normalizer_action.unnormalize_data( 84 | gt_action_chunk, 85 | [lerobot_config.get("repo_id", "physical-intelligence/libero")], 86 | dof_mask, 87 | ).squeeze(0) 88 | ) 89 | gt_traj[idx : idx + pred_horizon] = denormalized_gt.detach().cpu() 90 | 91 | gt_traj_np = gt_traj.numpy() 92 | pred_traj_np = pred_traj.numpy() 93 | 94 | timesteps = gt_traj.shape[0] 95 | 96 | fig, axs = plt.subplots( 97 | origin_action_dim, 1, figsize=(15, 5 * origin_action_dim), sharex=True 98 | ) 99 | fig.suptitle("Action Comparison for lerobot", fontsize=16) 100 | 101 | for i in range(origin_action_dim): 102 | axs[i].plot(range(timesteps), gt_traj_np[:, i], label="Ground Truth") 103 | axs[i].plot(range(timesteps), pred_traj_np[:, i], label="Prediction") 104 | axs[i].set_ylabel(f"Action Dim {i+1}") 105 | axs[i].legend() 106 | axs[i].grid(True) 107 | 108 | axs[-1].set_xlabel("Timestep") 109 | plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 110 | os.makedirs(save_dir, exist_ok=True) 111 | save_path = os.path.join(save_dir, "lerobot_comparison.png") 112 | plt.savefig(save_path) 113 | print(f"Saved plot to {save_path}") 114 | plt.close() 115 | -------------------------------------------------------------------------------- 
/wall_x/serving/websocket_policy_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import http 3 | import logging 4 | import time 5 | import traceback 6 | from typing import Any, Dict, Optional 7 | 8 | try: 9 | import msgpack 10 | import msgpack_numpy as m 11 | 12 | m.patch() 13 | except ImportError: 14 | logging.warning( 15 | "msgpack-numpy not installed. Install with: pip install msgpack-numpy" 16 | ) 17 | msgpack = None 18 | 19 | import websockets.asyncio.server as _server 20 | import websockets.frames 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class BasePolicy: 26 | """Base class for policies that can be served.""" 27 | 28 | def infer(self, obs: Dict) -> Dict: 29 | """Infer actions from observations.""" 30 | raise NotImplementedError 31 | 32 | def reset(self) -> None: 33 | """Reset the policy to its initial state.""" 34 | pass 35 | 36 | @property 37 | def metadata(self) -> Dict[str, Any]: 38 | """Return metadata about the policy.""" 39 | return {} 40 | 41 | 42 | class WebsocketPolicyServer: 43 | """Serves a policy using the websocket protocol. 44 | 45 | Implements a websocket server that: 46 | 1. Sends policy metadata on connection 47 | 2. Receives observations 48 | 3. Returns predicted actions 49 | 4. Tracks timing information 50 | """ 51 | 52 | def __init__( 53 | self, 54 | policy: BasePolicy, 55 | host: str = "0.0.0.0", 56 | port: int = 8000, 57 | metadata: Optional[Dict] = None, 58 | ) -> None: 59 | self._policy = policy 60 | self._host = host 61 | self._port = port 62 | self._metadata = metadata or {} 63 | logging.getLogger("websockets.server").setLevel(logging.INFO) 64 | 65 | def serve_forever(self) -> None: 66 | asyncio.run(self.run()) 67 | 68 | async def run(self): 69 | async with _server.serve( 70 | self._handler, 71 | self._host, 72 | self._port, 73 | compression=None, 74 | max_size=None, 75 | ping_interval=None, # Disable automatic ping for long-running inference 76 | ping_timeout=None, # Disable ping timeout 77 | process_request=_health_check, 78 | ) as server: 79 | logger.info(f"Server started on {self._host}:{self._port}") 80 | await server.serve_forever() 81 | 82 | async def _handler(self, websocket: _server.ServerConnection): 83 | logger.info(f"Connection from {websocket.remote_address} opened") 84 | 85 | if msgpack is None: 86 | await websocket.close( 87 | code=websockets.frames.CloseCode.INTERNAL_ERROR, 88 | reason="msgpack-numpy not installed on server", 89 | ) 90 | return 91 | 92 | # Send metadata to client 93 | await websocket.send(msgpack.packb(self._metadata)) 94 | 95 | prev_total_time = None 96 | while True: 97 | try: 98 | start_time = time.monotonic() 99 | obs = msgpack.unpackb(await websocket.recv()) 100 | 101 | infer_time = time.monotonic() 102 | action = self._policy.infer(obs) 103 | infer_time = time.monotonic() - infer_time 104 | 105 | action["server_timing"] = { 106 | "infer_ms": infer_time * 1000, 107 | } 108 | if prev_total_time is not None: 109 | action["server_timing"]["prev_total_ms"] = prev_total_time * 1000 110 | 111 | await websocket.send(msgpack.packb(action)) 112 | prev_total_time = time.monotonic() - start_time 113 | 114 | except websockets.ConnectionClosed: 115 | logger.info(f"Connection from {websocket.remote_address} closed") 116 | break 117 | except Exception as e: 118 | logger.error(f"Error handling request: {e}") 119 | await websocket.send(traceback.format_exc()) 120 | await websocket.close( 121 | code=websockets.frames.CloseCode.INTERNAL_ERROR, 122 | 
reason="Internal server error. Traceback included in previous frame.", 123 | ) 124 | raise 125 | 126 | 127 | def _health_check( 128 | connection: _server.ServerConnection, request: _server.Request 129 | ) -> Optional[_server.Response]: 130 | if request.path == "/healthz": 131 | return connection.respond(http.HTTPStatus.OK, "OK\n") 132 | return None 133 | -------------------------------------------------------------------------------- /workspace/lerobot_example/config_qact_from_vlm.yml: -------------------------------------------------------------------------------- 1 | # Train from Qwen-2.5-VL 2 | 3 | # Model and paths configuration 4 | log_name: "robotic_training" 5 | log_project: "vla_training" 6 | model_type: qwen2_5 7 | pretrained_wallx_path: "/path/to/wallx_model/" # Must set 8 | save_path: "/path/to/workspace/" # Must set 9 | use_fast_tokenizer: True # True: train FAST, False: train Flow 10 | action_tokenizer_path: "/path/to/fast/" # Must set if use_fast_tokenizer is true 11 | qwen_vl_act_config_path: "wall-x/workspace/lerobot_example/qwen25_config.json" 12 | 13 | 14 | # Torch Profile 15 | profile: False 16 | profile_save_path: /path/to/profile/ 17 | profile_wait_iters: 10 18 | profile_warmup_iters: 5 19 | profile_active_iters: 2 20 | 21 | # Training hyperparameters 22 | num_warmup_steps: 100 23 | num_training_steps: 64000000 24 | learning_rate: 0.00009 25 | min_lr: 0.00005 26 | num_epoch: 100 27 | gradient_accumulation_steps: 1 28 | batch_size_per_gpu: 8 29 | padding_side: left 30 | epoch_save_interval: 10 31 | 32 | # Training optimization settings 33 | FSDP2: True 34 | torch_compile: False 35 | 36 | # Robot configuration - Define degrees of freedom for each component 37 | dof_config: 38 | follow_left_ee_cartesian_pos: 3 # Left end-effector Cartesian position 39 | follow_left_ee_rotation: 3 # Left end-effector rotation 40 | follow_left_gripper: 1 # Left gripper control 41 | follow_right_ee_cartesian_pos: 3 # Right end-effector Cartesian position 42 | follow_right_ee_rotation: 3 # Right end-effector rotation 43 | follow_right_gripper: 1 # Right gripper control 44 | head_actions: 2 # Head/camera movement 45 | height: 1 # Mobile base height control 46 | car_pose: 3 # Mobile base pose (x, y, theta) 47 | 48 | # Agent proprioception configuration (typically matches DOF config) 49 | agent_pos_config: 50 | follow_left_ee_cartesian_pos: 3 51 | follow_left_ee_rotation: 3 52 | follow_left_gripper: 1 53 | follow_right_ee_cartesian_pos: 3 54 | follow_right_ee_rotation: 3 55 | follow_right_gripper: 1 56 | head_actions: 2 57 | height: 1 58 | car_pose: 3 59 | 60 | # # Checkpoint resuming configuration 61 | # resume: 62 | # ckpt: "/path/to/resume_model/" 63 | # load_ckpt_only: true 64 | 65 | norm_stats_path: "/path/to/norm_stats.json" 66 | 67 | enable_customized_robot_config: true 68 | customized_robot_config: 69 | name: "physical-intelligence/libero" 70 | customized_dof_config: 71 | "action_left_shoulder" : 1 72 | "action_left_elbow" : 1 73 | "action_left_forearm_roll" : 1 74 | "action_left_wrist_angle" : 1 75 | "action_left_wrist_rotate" : 1 76 | "action_left_gripper" : 1 77 | "action_right_waist" : 1 78 | "action_right_shoulder" : 1 79 | "action_right_elbow" : 1 80 | "action_right_forearm_roll" : 1 81 | "action_right_wrist_angle" : 1 82 | "action_right_wrist_rotate" : 1 83 | "action_right_gripper" : 1 84 | 85 | customized_agent_pos_config: 86 | "state_left_shoulder" : 1 87 | "state_left_elbow" : 1 88 | "state_left_forearm_roll" : 1 89 | "state_left_wrist_angle" : 1 90 | "state_left_wrist_rotate" 
: 1 91 | "state_left_gripper" : 1 92 | "state_right_waist" : 1 93 | "state_right_shoulder" : 1 94 | "state_right_elbow" : 1 95 | "state_right_forearm_roll" : 1 96 | "state_right_wrist_angle" : 1 97 | "state_right_wrist_rotate" : 1 98 | "state_right_gripper" : 1 99 | 100 | # Data configuration 101 | data: 102 | use_lerobot: true 103 | 104 | # LeRobot dataset configuration 105 | lerobot_config: 106 | repo_id: "lerobot/aloha_mobile_cabinet" 107 | root: null 108 | episodes: null 109 | image_transforms: null 110 | delta_timestamps: null 111 | tolerance_s: 1e-4 112 | revision: null 113 | force_cache_sync: false 114 | download_videos: true 115 | video_backend: null 116 | 117 | action_horizon: 32 118 | train_test_split: 0.95 119 | 120 | # Action keys for observation and prediction 121 | obs_action_keys: 122 | - follow_left_ee_cartesian_pos 123 | - follow_left_ee_rotation 124 | - follow_left_gripper 125 | - follow_right_ee_cartesian_pos 126 | - follow_right_ee_rotation 127 | - follow_right_gripper 128 | - head_actions 129 | - height 130 | - car_pose 131 | 132 | predict_action_keys: 133 | - follow_left_ee_cartesian_pos 134 | - follow_left_ee_rotation 135 | - follow_left_gripper 136 | - follow_right_ee_cartesian_pos 137 | - follow_right_ee_rotation 138 | - follow_right_gripper 139 | - head_actions 140 | - height 141 | - car_pose 142 | 143 | # Image resolution configuration for different camera views 144 | resolution: 145 | face_view: 256 146 | left_wrist_view: 256 147 | right_wrist_view: 256 148 | move1_view: 256 149 | move2_view: 256 150 | top_view: 256 151 | wall_view: 256 152 | multi_modal: 256 153 | -------------------------------------------------------------------------------- /workspace/lerobot_example/config_qact.yml: -------------------------------------------------------------------------------- 1 | # Training Configuration for Wall-X Robotic Multi-Modal Learning 2 | # This configuration supports multi-modal learning with vision, language, and action data 3 | 4 | # Model and paths configuration 5 | log_name: "robotic_training" 6 | log_project: "vla_training" 7 | model_type: wall-oss 8 | pretrained_wallx_path: "/path/to/wallx_model/" # Must set 9 | save_path: "/path/to/workspace/" # Must set 10 | use_fast_tokenizer: False # True: train FAST, False: train Flow 11 | action_tokenizer_path: "/path/to/fast/" # Must set if use_fast_tokenizer is true 12 | 13 | # Torch Profile 14 | profile: False 15 | profile_save_path: /path/to/profile/ 16 | profile_wait_iters: 10 17 | profile_warmup_iters: 5 18 | profile_active_iters: 2 19 | 20 | # Training hyperparameters 21 | num_warmup_steps: 100 22 | num_training_steps: 64000000 23 | learning_rate: 0.00005 24 | min_lr: 0.00005 25 | num_epoch: 100 26 | gradient_accumulation_steps: 32 27 | batch_size_per_gpu: 8 28 | padding_side: left 29 | epoch_save_interval: 10 30 | 31 | # Training optimization settings 32 | FSDP2: True 33 | torch_compile: False 34 | 35 | # Robot configuration - Define degrees of freedom for each component 36 | dof_config: 37 | follow_left_ee_cartesian_pos: 3 # Left end-effector Cartesian position 38 | follow_left_ee_rotation: 3 # Left end-effector rotation 39 | follow_left_gripper: 1 # Left gripper control 40 | follow_right_ee_cartesian_pos: 3 # Right end-effector Cartesian position 41 | follow_right_ee_rotation: 3 # Right end-effector rotation 42 | follow_right_gripper: 1 # Right gripper control 43 | head_actions: 2 # Head/camera movement 44 | height: 1 # Mobile base height control 45 | car_pose: 3 # Mobile base pose (x, y, theta) 46 | 47 | 
# Agent proprioception configuration (typically matches DOF config) 48 | agent_pos_config: 49 | follow_left_ee_cartesian_pos: 3 50 | follow_left_ee_rotation: 3 51 | follow_left_gripper: 1 52 | follow_right_ee_cartesian_pos: 3 53 | follow_right_ee_rotation: 3 54 | follow_right_gripper: 1 55 | head_actions: 2 56 | height: 1 57 | car_pose: 3 58 | 59 | # # Checkpoint resuming configuration 60 | # resume: 61 | # ckpt: "/path/to/resume_model/" 62 | # load_ckpt_only: true 63 | 64 | norm_stats_path: "/path/to/norm_stats.json" 65 | 66 | enable_customized_robot_config: true 67 | customized_robot_config: 68 | name: "lerobot/aloha_mobile_cabinet" 69 | customized_dof_config: 70 | "action_left_shoulder" : 1 71 | "action_left_elbow" : 1 72 | "action_left_forearm_roll" : 1 73 | "action_left_wrist_angle" : 1 74 | "action_left_wrist_rotate" : 1 75 | "action_left_gripper" : 1 76 | "action_right_waist" : 1 77 | "action_right_shoulder" : 1 78 | "action_right_elbow" : 1 79 | "action_right_forearm_roll" : 1 80 | "action_right_wrist_angle" : 1 81 | "action_right_wrist_rotate" : 1 82 | "action_right_gripper" : 1 83 | 84 | customized_agent_pos_config: 85 | "state_left_shoulder" : 1 86 | "state_left_elbow" : 1 87 | "state_left_forearm_roll" : 1 88 | "state_left_wrist_angle" : 1 89 | "state_left_wrist_rotate" : 1 90 | "state_left_gripper" : 1 91 | "state_right_waist" : 1 92 | "state_right_shoulder" : 1 93 | "state_right_elbow" : 1 94 | "state_right_forearm_roll" : 1 95 | "state_right_wrist_angle" : 1 96 | "state_right_wrist_rotate" : 1 97 | "state_right_gripper" : 1 98 | 99 | # Data configuration 100 | data: 101 | use_lerobot: true 102 | 103 | # LeRobot dataset configuration 104 | lerobot_config: 105 | repo_id: "lerobot/aloha_mobile_cabinet" 106 | root: null 107 | episodes: null 108 | image_transforms: null 109 | delta_timestamps: null 110 | tolerance_s: 1e-4 111 | revision: null 112 | force_cache_sync: false 113 | download_videos: true 114 | video_backend: null 115 | 116 | action_horizon: 32 117 | train_test_split: 0.95 118 | 119 | # Action keys for observation and prediction 120 | obs_action_keys: 121 | - follow_left_ee_cartesian_pos 122 | - follow_left_ee_rotation 123 | - follow_left_gripper 124 | - follow_right_ee_cartesian_pos 125 | - follow_right_ee_rotation 126 | - follow_right_gripper 127 | - head_actions 128 | - height 129 | - car_pose 130 | 131 | predict_action_keys: 132 | - follow_left_ee_cartesian_pos 133 | - follow_left_ee_rotation 134 | - follow_left_gripper 135 | - follow_right_ee_cartesian_pos 136 | - follow_right_ee_rotation 137 | - follow_right_gripper 138 | - head_actions 139 | - height 140 | - car_pose 141 | 142 | # Image resolution configuration for different camera views 143 | resolution: 144 | face_view: 256 145 | left_wrist_view: 256 146 | right_wrist_view: 256 147 | move1_view: 256 148 | move2_view: 256 149 | top_view: 256 150 | wall_view: 256 151 | multi_modal: 256 152 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ============================================================================= 2 | # OPERATING SYSTEM FILES 3 | # ============================================================================= 4 | 5 | # macOS 6 | .DS_Store 7 | .DS_Store? 
8 | ._* 9 | .Spotlight-V100 10 | .Trashes 11 | ehthumbs.db 12 | Thumbs.db 13 | 14 | # Windows 15 | Thumbs.db 16 | ehthumbs.db 17 | Desktop.ini 18 | $RECYCLE.BIN/ 19 | *.cab 20 | *.msi 21 | *.msix 22 | *.msm 23 | *.msp 24 | 25 | # Linux 26 | *~ 27 | 28 | # ============================================================================= 29 | # PYTHON 30 | # ============================================================================= 31 | 32 | # Byte-compiled / optimized / DLL files 33 | __pycache__/ 34 | *.py[cod] 35 | *$py.class 36 | 37 | # C extensions 38 | *.so 39 | 40 | # Distribution / packaging 41 | .Python 42 | build/ 43 | develop-eggs/ 44 | dist/ 45 | downloads/ 46 | eggs/ 47 | .eggs/ 48 | lib/ 49 | lib64/ 50 | parts/ 51 | sdist/ 52 | var/ 53 | wheels/ 54 | pip-wheel-metadata/ 55 | share/python-wheels/ 56 | *.egg-info/ 57 | .installed.cfg 58 | *.egg 59 | MANIFEST 60 | 61 | # PyInstaller 62 | # Usually these files are written by a python script from a template 63 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 64 | *.manifest 65 | *.spec 66 | 67 | # Installer logs 68 | pip-log.txt 69 | pip-delete-this-directory.txt 70 | 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .nox/ 75 | .coverage 76 | .coverage.* 77 | .cache 78 | nosetests.xml 79 | coverage.xml 80 | *.cover 81 | *.py,cover 82 | .hypothesis/ 83 | .pytest_cache/ 84 | 85 | # Translations 86 | *.mo 87 | *.pot 88 | 89 | # Django stuff: 90 | local_settings.py 91 | db.sqlite3 92 | db.sqlite3-journal 93 | 94 | # Flask stuff: 95 | instance/ 96 | .webassets-cache 97 | 98 | # Scrapy stuff: 99 | .scrapy 100 | 101 | # Sphinx documentation 102 | docs/_build/ 103 | 104 | # PyBuilder 105 | target/ 106 | 107 | # Jupyter Notebook 108 | .ipynb_checkpoints 109 | *.ipynb 110 | 111 | # IPython 112 | profile_default/ 113 | ipython_config.py 114 | 115 | # pyenv 116 | .python-version 117 | 118 | # pipenv 119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 122 | # install all needed dependencies. 123 | #Pipfile.lock 124 | 125 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # ============================================================================= 154 | # DEVELOPMENT ENVIRONMENT 155 | # ============================================================================= 156 | 157 | # VS Code 158 | .vscode/ 159 | *.code-workspace 160 | 161 | # PyCharm 162 | .idea/ 163 | 164 | # Vim 165 | *.swp 166 | *.swo 167 | *~ 168 | 169 | # Emacs 170 | *~ 171 | \#*\# 172 | /.emacs.desktop 173 | /.emacs.desktop.lock 174 | *.elc 175 | auto-save-list 176 | tramp 177 | .\#* 178 | 179 | # ============================================================================= 180 | # PROJECT SPECIFIC 181 | # ============================================================================= 182 | 183 | # Model checkpoints and weights 184 | *.pth 185 | *.pt 186 | *.ckpt 187 | ckpt/ 188 | checkpoints/ 189 | 190 | # Output directories 191 | outputs/ 192 | results/ 193 | logs/ 194 | bin/ 195 | 196 | # Media files 197 | *.jpg 198 | *.jpeg 199 | *.gif 200 | *.mp4 201 | *.avi 202 | *.mov 203 | *.wav 204 | *.mp3 205 | 206 | # Configuration files with sensitive data 207 | fuse.cfg 208 | ray/auth.json 209 | 210 | # AI/ML related 211 | *.ai 212 | 213 | # ============================================================================= 214 | # LOGGING AND MONITORING 215 | # ============================================================================= 216 | 217 | # General logs 218 | *.log 219 | *.err 220 | *.out 221 | 222 | # Weights & Biases 223 | wandb/ 224 | _wandb/ 225 | 226 | # TensorBoard 227 | runs/ 228 | tensorboard/ 229 | 230 | # MLflow 231 | mlruns/ 232 | 233 | # ============================================================================= 234 | # TEMPORARY AND CACHE FILES 235 | # ============================================================================= 236 | 237 | # Temporary files 238 | *.tmp 239 | *.temp 240 | *.bak 241 | *.backup 242 | 243 | # Cache directories 244 | .cache/ 245 | cache/ 246 | .ruff_cache/ 247 | 248 | # ============================================================================= 249 | # FONTS AND ASSETS (if not part of the project) 250 | # ============================================================================= 251 | 252 | # Font files (uncomment if fonts should not be tracked) 253 | # *.ttf 254 | # *.otf 255 | # *.woff 256 | # *.woff2 257 | 258 | # ============================================================================= 259 | # PERSONAL/LOCAL FILES 260 | # ============================================================================= 261 | 262 | # Personal notes and documentation 263 | code_update_info.md 264 | TODO.md 265 | NOTES.md 266 | 267 | # Local configuration 268 | .env 269 | .env.local 270 | .env.*.local -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Wall-X 2 | 3 |
4 | 5 | 6 | 7 | Hugging Face 8 | 9 | 10 | Project Page 11 | 12 | 13 | 14 |
15 | Python 3.10 16 | PyTorch 17 | FlashAttention 18 | LeRobot 19 | CUDA 20 | Ubuntu 22.04 21 | 22 |
23 | 
24 | ## Building General-Purpose Robots Based on Embodied Foundation Model 
25 | We are building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. 
26 | 
27 | By creating a direct feedback loop between the model's decisions and the body's lived experience, we enable the emergence of a truly generalizable intelligence—one that understands not just how the world works, but how to act effectively within it. 
28 | 
29 | ## Repository 
30 | This repository provides the training and inference code that supports our WALL-series open-source embodied foundation models. It includes end-to-end pipelines for data preparation (LeRobot), model configuration, flow-matching and FAST action branches, and evaluation utilities for real and simulated robots. 
31 | 
32 | ## News 
33 | - We introduce [**WALL-OSS: Igniting VLMs toward the Embodied Space**](https://x2robot.com/en/research/68bc2cde8497d7f238dde690), an end-to-end embodied foundation model that leverages large-scale multimodal pretraining to achieve (1) embodiment-aware vision–language understanding, (2) strong language–action association, and (3) robust manipulation capability. 
34 | 
35 | ## Models 
36 | - WALL-OSS-FLOW: https://huggingface.co/x-square-robot/wall-oss-flow 
37 | - WALL-OSS-FAST: https://huggingface.co/x-square-robot/wall-oss-fast 
38 | 
39 | ## Environment Setup 
40 | 
41 | Create and activate a conda environment: 
42 | ```bash 
43 | conda create --name wallx python=3.10 
44 | conda activate wallx 
45 | ``` 
46 | 
47 | Install requirements: 
48 | ```bash 
49 | pip install -r requirements.txt 
50 | MAX_JOBS=4 pip install flash-attn==2.7.4.post1 --no-build-isolation 
51 | ``` 
52 | 
53 | Install lerobot: 
54 | ```bash 
55 | git clone https://github.com/huggingface/lerobot.git 
56 | cd lerobot 
57 | git checkout c66cd401767e60baece16e1cf68da2824227e076 
58 | pip install -e . 
59 | ``` 
60 | 
61 | Install wall_x: 
62 | ```bash 
63 | git submodule update --init --recursive 
64 | MAX_JOBS=4 pip install --no-build-isolation --verbose -e . 
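# Optional sanity check (not part of the original instructions): confirm the package imports.
python -c "import wall_x"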
65 | ``` 
66 | 
67 | ## Training 
68 | 
69 | ### Finetune on LeRobot Datasets 
70 | 
71 | Before training, please refer to `workspace/README.md` for detailed configuration instructions, including: 
72 | 
73 | - Training script path configuration 
74 | 
75 | - GPU setup 
76 | - Model and data paths 
77 | - Robot DOF configuration 
78 | - Training hyperparameters 
79 | 
80 | Download the Flow/FAST pretrained model and run: 
81 | ```bash 
82 | bash ./workspace/lerobot_example/run.sh 
83 | ``` 
84 | 
85 | ## Inference 
86 | 
87 | ### Basic Action Inference 
88 | 
89 | For model inference, run: 
90 | 
91 | ```bash 
92 | python ./scripts/fake_inference.py 
93 | ``` 
94 | 
95 | This script demonstrates how to: 
96 | - Load the Wall-OSS model using `Qwen2_5_VLMoEForAction.from_pretrained()` 
97 | - Prepare input data including proprioceptive information, attention masks, and dataset specifications 
98 | - Run inference in validation mode with proper data types (bfloat16) 
99 | - Validate model outputs and check for numerical stability 
100 | 
101 | ### Open-Loop Evaluation 
102 | 
103 | To generate an open-loop comparison plot, run: 
104 | 
105 | ```bash 
106 | python ./scripts/draw_openloop_plot.py 
107 | ``` 
108 | 
109 | ### VQA Inference and Chain-of-Thought Testing 
110 | 
111 | To run VQA inference and test the model's Chain-of-Thought (COT) reasoning capabilities, run: 
112 | 
113 | ```bash 
114 | python ./scripts/vqa_inference.py 
115 | ``` 
116 | 
117 | This script can be used to test the model's COT reasoning abilities for embodied tasks. Below is an example of COT testing: 
118 | 
119 | **Input Image:** 
120 | 
121 | ![COT Example Frame](assets/cot_example_frame.png) 
122 | 
123 | **Input Text:** 
124 | ``` 
125 | To move the red block in the plate with same color, what should you do next? Think step by step. 
126 | ``` 
127 | 
128 | **Model Output (COT Reasoning):** 
129 | ``` 
130 | To move the red block in the plate with the same color, you should first locate the red block. It is currently positioned on the table, not in the plate. Then, you should carefully grasp the red block using your fingers. Next, you should use your hand to lift the red block from the table and place it into the plate that is also red in color. Ensure that the red block is securely placed in the plate without slipping or falling. 
131 | ``` 
132 | 
133 | ## Join Our Community 
134 | - Scan the QR code on WeChat to join the discussion group, where you can engage in in-depth exchanges with community developers and the official team. 
135 | QR Code 136 | 137 | ## 📚 Cite Us 138 | 139 | If you find WALL-OSS models useful, please cite: 140 | 141 | ```bibtex 142 | @article{zhai2025igniting, 143 | title = {Igniting VLMs Toward the Embodied Space}, 144 | author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach}, 145 | journal = {arXiv preprint arXiv:2509.11766}, 146 | year = {2025} 147 | } 148 | ``` 149 | -------------------------------------------------------------------------------- /scripts/normalize.py: -------------------------------------------------------------------------------- 1 | # This file is copied from openpi 2 | import json 3 | import pathlib 4 | 5 | import numpy as np 6 | import numpydantic 7 | import pydantic 8 | 9 | 10 | @pydantic.dataclasses.dataclass 11 | class NormStats: 12 | mean: numpydantic.NDArray 13 | std: numpydantic.NDArray 14 | q01: numpydantic.NDArray | None = None # 1st quantile 15 | q99: numpydantic.NDArray | None = None # 99th quantile 16 | 17 | 18 | class RunningStats: 19 | """Compute running statistics of a batch of vectors.""" 20 | 21 | def __init__(self): 22 | self._count = 0 23 | self._mean = None 24 | self._mean_of_squares = None 25 | self._min = None 26 | self._max = None 27 | self._histograms = None 28 | self._bin_edges = None 29 | self._num_quantile_bins = 5000 # for computing quantiles on the fly 30 | 31 | def update(self, batch: np.ndarray) -> None: 32 | """ 33 | Update the running statistics with a batch of vectors. 34 | 35 | Args: 36 | vectors (np.ndarray): An array where all dimensions except the last are batch dimensions. 37 | """ 38 | batch = batch.reshape(-1, batch.shape[-1]) 39 | num_elements, vector_length = batch.shape 40 | if self._count == 0: 41 | self._mean = np.mean(batch, axis=0) 42 | self._mean_of_squares = np.mean(batch**2, axis=0) 43 | self._min = np.min(batch, axis=0) 44 | self._max = np.max(batch, axis=0) 45 | self._histograms = [ 46 | np.zeros(self._num_quantile_bins) for _ in range(vector_length) 47 | ] 48 | self._bin_edges = [ 49 | np.linspace( 50 | self._min[i] - 1e-10, 51 | self._max[i] + 1e-10, 52 | self._num_quantile_bins + 1, 53 | ) 54 | for i in range(vector_length) 55 | ] 56 | else: 57 | if vector_length != self._mean.size: 58 | raise ValueError( 59 | "The length of new vectors does not match the initialized vector length." 60 | ) 61 | new_max = np.max(batch, axis=0) 62 | new_min = np.min(batch, axis=0) 63 | max_changed = np.any(new_max > self._max) 64 | min_changed = np.any(new_min < self._min) 65 | self._max = np.maximum(self._max, new_max) 66 | self._min = np.minimum(self._min, new_min) 67 | 68 | if max_changed or min_changed: 69 | self._adjust_histograms() 70 | 71 | self._count += num_elements 72 | 73 | batch_mean = np.mean(batch, axis=0) 74 | batch_mean_of_squares = np.mean(batch**2, axis=0) 75 | 76 | # Update running mean and mean of squares. 77 | self._mean += (batch_mean - self._mean) * (num_elements / self._count) 78 | self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * ( 79 | num_elements / self._count 80 | ) 81 | 82 | self._update_histograms(batch) 83 | 84 | def get_statistics(self) -> NormStats: 85 | """ 86 | Compute and return the statistics of the vectors processed so far. 87 | 88 | Returns: 89 | dict: A dictionary containing the computed statistics. 
90 | """ 91 | if self._count < 2: 92 | raise ValueError("Cannot compute statistics for less than 2 vectors.") 93 | 94 | variance = self._mean_of_squares - self._mean**2 95 | stddev = np.sqrt(np.maximum(0, variance)) 96 | q01, q99 = self._compute_quantiles([0.01, 0.99]) 97 | return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99) 98 | 99 | def _adjust_histograms(self): 100 | """Adjust histograms when min or max changes.""" 101 | for i in range(len(self._histograms)): 102 | old_edges = self._bin_edges[i] 103 | new_edges = np.linspace( 104 | self._min[i], self._max[i], self._num_quantile_bins + 1 105 | ) 106 | 107 | # Redistribute the existing histogram counts to the new bins 108 | new_hist, _ = np.histogram( 109 | old_edges[:-1], bins=new_edges, weights=self._histograms[i] 110 | ) 111 | 112 | self._histograms[i] = new_hist 113 | self._bin_edges[i] = new_edges 114 | 115 | def _update_histograms(self, batch: np.ndarray) -> None: 116 | """Update histograms with new vectors.""" 117 | for i in range(batch.shape[1]): 118 | hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i]) 119 | self._histograms[i] += hist 120 | 121 | def _compute_quantiles(self, quantiles): 122 | """Compute quantiles based on histograms.""" 123 | results = [] 124 | for q in quantiles: 125 | target_count = q * self._count 126 | q_values = [] 127 | for hist, edges in zip(self._histograms, self._bin_edges, strict=True): 128 | cumsum = np.cumsum(hist) 129 | idx = np.searchsorted(cumsum, target_count) 130 | q_values.append(edges[idx]) 131 | results.append(np.array(q_values)) 132 | return results 133 | 134 | 135 | class _NormStatsDict(pydantic.BaseModel): 136 | norm_stats: dict[str, NormStats] 137 | 138 | 139 | def serialize_json(norm_stats: dict[str, NormStats]) -> str: 140 | """Serialize the running statistics to a JSON string.""" 141 | return _NormStatsDict(norm_stats=norm_stats).model_dump_json(indent=2) 142 | 143 | 144 | def deserialize_json(data: str) -> dict[str, NormStats]: 145 | """Deserialize the running statistics from a JSON string.""" 146 | return _NormStatsDict(**json.loads(data)).norm_stats 147 | 148 | 149 | def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None: 150 | """Save the normalization stats to a directory.""" 151 | path = pathlib.Path(directory) / "norm_stats.json" 152 | path.parent.mkdir(parents=True, exist_ok=True) 153 | path.write_text(serialize_json(norm_stats)) 154 | 155 | 156 | def load(directory: pathlib.Path | str) -> dict[str, NormStats]: 157 | """Load the normalization stats from a directory.""" 158 | path = pathlib.Path(directory) / "norm_stats.json" 159 | if not path.exists(): 160 | raise FileNotFoundError(f"Norm stats file not found at: {path}") 161 | return deserialize_json(path.read_text()) 162 | -------------------------------------------------------------------------------- /wall_x/serving/policy/wall_x_policy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Any, List 3 | import torch 4 | import numpy as np 5 | from transformers import AutoProcessor 6 | 7 | from wall_x.serving.websocket_policy_server import BasePolicy 8 | from wall_x.model.qwen2_5_based.modeling_qwen2_5_vl_act import Qwen2_5_VLMoEForAction 9 | from wall_x.serving.policy.utils import prepare_batch 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class WallXPolicy(BasePolicy): 15 | """Policy wrapper for Wall-X model that implements the BasePolicy interface.""" 16 | 17 | def __init__( 18 | 
self, 19 | model_path: str, 20 | train_config: dict, 21 | action_tokenizer_path: str, 22 | action_dim: int, 23 | agent_pos_dim: int, 24 | pred_horizon: int, 25 | camera_key: List[str], 26 | device: str = "cuda", 27 | dtype: str = "bfloat16", 28 | predict_mode: str = "fast", 29 | default_prompt: str | None = None, 30 | min_pixels: int = 4 * 28 * 28, 31 | max_pixels: int = 16384 * 28 * 28, 32 | image_factor: int = 28, 33 | max_length: int = 2048, 34 | ): 35 | """Initialize the Wall-X policy. 36 | 37 | Args: 38 | model_path: Path to the pretrained model checkpoint 39 | action_tokenizer_path: Path to the action tokenizer 40 | action_dim: Dimension of action space 41 | pred_horizon: Prediction horizon for actions 42 | device: Device to run model on ('cuda' or 'cpu') 43 | dtype: Data type for model ('bfloat16', 'float16', or 'float32') 44 | predict_mode: Prediction mode ('fast' or 'slow') 45 | default_prompt: Default text prompt for the model 46 | min_pixels: Minimum pixels for image resizing 47 | max_pixels: Maximum pixels for image resizing 48 | image_factor: Factor for smart resize 49 | max_length: Maximum sequence length for text 50 | """ 51 | logger.info(f"Loading Wall-X model from {model_path}") 52 | 53 | self.model = Qwen2_5_VLMoEForAction.from_pretrained( 54 | model_path, 55 | train_config=train_config, 56 | action_tokenizer_path=action_tokenizer_path, 57 | ) 58 | self.model.eval() 59 | self.model = self.model.to(device) 60 | 61 | self.model = self.model.bfloat16() 62 | 63 | # hard code the action dim to 20 for align to wall-x configuration 64 | self.fixed_action_dim = 20 65 | 66 | self.action_dim = action_dim 67 | self.agent_pos_dim = agent_pos_dim 68 | self.pred_horizon = pred_horizon 69 | self.device = device 70 | self.predict_mode = predict_mode 71 | self.default_prompt = default_prompt 72 | self.camera_key = camera_key 73 | 74 | # Image preprocessing config 75 | self.min_pixels = min_pixels 76 | self.max_pixels = max_pixels 77 | self.image_factor = image_factor 78 | self.max_length = max_length 79 | 80 | # Load processor 81 | logger.info("Loading processor and tokenizer...") 82 | self.processor = AutoProcessor.from_pretrained(model_path, use_fast=True) 83 | self.processor.tokenizer.padding_side = "left" 84 | 85 | # Action buffer for multi-step predictions 86 | self.action_buffer = [] 87 | self.buffer_index = 0 88 | 89 | logger.info( 90 | f"Model loaded successfully. Device: {device}, Action dim: {action_dim}, Horizon: {pred_horizon}" 91 | ) 92 | 93 | @property 94 | def metadata(self) -> Dict[str, Any]: 95 | """Return metadata about the policy.""" 96 | return { 97 | "action_dim": self.action_dim, 98 | "pred_horizon": self.pred_horizon, 99 | "device": self.device, 100 | "predict_mode": self.predict_mode, 101 | } 102 | 103 | def reset(self) -> None: 104 | """Reset the policy state.""" 105 | self.action_buffer = [] 106 | self.buffer_index = 0 107 | logger.debug("Policy reset") 108 | 109 | def infer(self, obs: Dict) -> Dict: 110 | """Infer action from observation. 
111 | 
112 |         Args: 
113 |             obs: Dictionary containing: 
114 |                 - One image per configured camera key (numpy array or PIL Image) 
115 |                 - 'prompt': Task instruction text 
116 |                 - 'state': Optional robot proprioception state 
117 |                 - 'dataset_names': Dataset names expected by the model 
118 | 
119 |         Returns: 
120 |             Dictionary containing: 
121 |                 - 'action': Predicted action (numpy array) 
122 |                 - Additional metadata 
123 |         """ 
124 |         try: 
125 |             # Need to predict new actions 
126 |             input_batch = prepare_batch( 
127 |                 obs, 
128 |                 self.processor, 
129 |                 self.camera_key, 
130 |                 self.agent_pos_dim, 
131 |                 self.action_dim, 
132 |                 self.pred_horizon, 
133 |                 self.fixed_action_dim, 
134 |                 self.max_length, 
135 |                 self.image_factor, 
136 |                 self.min_pixels, 
137 |                 self.max_pixels, 
138 |                 self.predict_mode, 
139 |                 self.device, 
140 |             ) 
141 | 
142 |             with torch.no_grad(): 
143 |                 outputs = self.model( 
144 |                     **input_batch, 
145 |                     action_dim=( 
146 |                         self.action_dim 
147 |                         if self.predict_mode == "fast" 
148 |                         else self.fixed_action_dim 
149 |                     ), 
150 |                     pred_horizon=self.pred_horizon, 
151 |                     mode="predict", 
152 |                     predict_mode=self.predict_mode, 
153 |                 ) 
154 | 
155 |             if outputs["predict_action"] is None: 
156 |                 predicted_actions = np.zeros( 
157 |                     [1, self.pred_horizon, self.action_dim] 
158 |                 ).astype(np.float32) 
159 |             else: 
160 |                 predicted_actions = ( 
161 |                     outputs["predict_action"][:, :, : self.action_dim] 
162 |                     .detach() 
163 |                     .cpu() 
164 |                     .to(torch.float32) 
165 |                     .numpy() 
166 |                 ) 
167 | 
168 |             logger.debug(f"Predicted actions shape: {predicted_actions.shape}") 
169 |             return {"action": predicted_actions} 
170 | 
171 |         except Exception as e: 
172 |             logger.error(f"Error during inference: {e}") 
173 |             raise 
174 | 
--------------------------------------------------------------------------------
/wall_x/serving/launch_serving.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 
2 | """ 
3 | Server script for Wall-X model. 
4 | 
5 | This script serves a Wall-X model using a websocket server, allowing 
6 | clients to connect and get action predictions from observations. 
7 | 
8 | Based on the OpenPI serve_policy.py script structure. 
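Example (placeholder paths; the flags correspond to the Args/ModelConfig dataclasses
defined below and to wall_x/serving/README.md):

    python -m wall_x.serving.launch_serving \
        --env libero \
        --model-config.model-path /path/to/model \
        --model-config.action-tokenizer-path /path/to/action_tokenizer \
        --model-config.train-config-path /path/to/train_config.yml \
        --port 8000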
9 | """ 10 | 11 | import dataclasses 12 | from dataclasses import field 13 | import enum 14 | import logging 15 | import socket 16 | import sys 17 | import yaml 18 | from pathlib import Path 19 | from typing import List 20 | 21 | import tyro 22 | 23 | from wall_x.serving.policy.wall_x_policy import WallXPolicy 24 | from wall_x.serving.websocket_policy_server import WebsocketPolicyServer 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class EnvMode(enum.Enum): 30 | """Supported environments/datasets.""" 31 | 32 | LIBERO = "libero" 33 | ALOHA = "aloha" 34 | 35 | 36 | @dataclasses.dataclass 37 | class ModelConfig: 38 | """Configuration for loading a Wall-X model.""" 39 | 40 | # Path to the pretrained model checkpoint 41 | model_path: str 42 | # Path to the action tokenizer 43 | action_tokenizer_path: str 44 | # Path to train config yaml 45 | train_config_path: str 46 | # Action dimension for the environment 47 | action_dim: int = 7 48 | # State dimension for the environment 49 | state_dim: int = 8 50 | # Prediction horizon (number of future actions to predict) 51 | pred_horizon: int = 32 52 | # Device to run model on 53 | device: str = "cuda" 54 | # Model dtype (bfloat16, float16, float32) 55 | dtype: str = "bfloat16" 56 | # Prediction mode (fast or slow) 57 | predict_mode: str = "fast" 58 | # Camera key for the environment 59 | camera_key: List[str] = field( 60 | default_factory=lambda: ["front_view", "left_wrist_view", "right_wrist_view"] 61 | ) 62 | 63 | 64 | @dataclasses.dataclass 65 | class Args: 66 | """Arguments for the serve_wall_x script.""" 67 | 68 | # Environment mode (used for default configurations) 69 | env: EnvMode = EnvMode.LIBERO 70 | 71 | # Model configuration. If not provided, uses default config for the environment 72 | model_config: ModelConfig | None = None 73 | 74 | # Default text prompt to use if not provided in observation 75 | default_prompt: str | None = None 76 | 77 | # Port to serve the policy on 78 | port: int = 8000 79 | 80 | # Host to bind the server to 81 | host: str = "0.0.0.0" 82 | 83 | # Enable debug logging 84 | debug: bool = False 85 | 86 | 87 | # Default model configurations for each environment 88 | DEFAULT_CONFIGS: dict[EnvMode, ModelConfig] = { 89 | EnvMode.LIBERO: ModelConfig( 90 | model_path="/path/to/model", 91 | action_tokenizer_path="/path/to/action_tokenizer", 92 | train_config_path="/path/to/train_config", 93 | state_dim=8, 94 | action_dim=7, 95 | pred_horizon=32, 96 | device="cuda", 97 | dtype="bfloat16", 98 | predict_mode="fast", 99 | camera_key=["front_view", "left_wrist_view"], 100 | ), 101 | EnvMode.ALOHA: ModelConfig( 102 | model_path="/path/to/model", 103 | action_tokenizer_path="/path/to/action_tokenizer", 104 | train_config_path="/path/to/train_config", 105 | state_dim=14, 106 | action_dim=14, 107 | pred_horizon=32, 108 | device="cuda", 109 | dtype="bfloat16", 110 | predict_mode="fast", 111 | camera_key=["face_view", "left_wrist_view", "right_wrist_view"], 112 | ), 113 | } 114 | 115 | 116 | def get_model_config(args: Args) -> ModelConfig: 117 | """Get model configuration from args or defaults.""" 118 | if args.model_config is not None: 119 | return args.model_config 120 | 121 | if config := DEFAULT_CONFIGS.get(args.env): 122 | logger.info(f"Using default configuration for {args.env.value}") 123 | return config 124 | 125 | raise ValueError( 126 | f"No default configuration for {args.env.value}. " 127 | f"Please provide --model-config with model_path and action_tokenizer_path." 
128 | ) 129 | 130 | 131 | def create_policy(args: Args) -> WallXPolicy: 132 | """Create a Wall-X policy from the given arguments.""" 133 | config = get_model_config(args) 134 | logger.info(f"Creating Wall-X policy with config: {config}") 135 | 136 | # Validate paths 137 | if not Path(config.model_path).exists(): 138 | logger.warning(f"Model path does not exist: {config.model_path}") 139 | 140 | if not Path(config.action_tokenizer_path).exists(): 141 | logger.warning( 142 | f"Action tokenizer path does not exist: {config.action_tokenizer_path}" 143 | ) 144 | 145 | with open(config.train_config_path, "r") as f: 146 | train_config = yaml.load(f, Loader=yaml.FullLoader) 147 | 148 | policy = WallXPolicy( 149 | model_path=config.model_path, 150 | train_config=train_config, 151 | action_tokenizer_path=config.action_tokenizer_path, 152 | action_dim=config.action_dim, 153 | agent_pos_dim=config.state_dim, 154 | pred_horizon=config.pred_horizon, 155 | device=config.device, 156 | dtype=config.dtype, 157 | predict_mode=config.predict_mode, 158 | default_prompt=args.default_prompt, 159 | camera_key=config.camera_key, 160 | ) 161 | 162 | return policy 163 | 164 | 165 | def main(args: Args) -> None: 166 | """Main function to start the Wall-X model server.""" 167 | log_level = logging.DEBUG if args.debug else logging.INFO 168 | logging.basicConfig( 169 | level=log_level, 170 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 171 | ) 172 | 173 | logger.info("Starting Wall-X model server") 174 | logger.info(f"Environment: {args.env.value}") 175 | logger.info(f"Port: {args.port}") 176 | logger.info(f"Host: {args.host}") 177 | 178 | # Create policy 179 | try: 180 | policy = create_policy(args) 181 | except Exception as e: 182 | logger.error(f"Failed to create policy: {e}") 183 | sys.exit(1) 184 | 185 | # Get policy metadata 186 | policy_metadata = policy.metadata 187 | policy_metadata["env"] = args.env.value 188 | 189 | # Get network info 190 | hostname = socket.gethostname() 191 | try: 192 | local_ip = socket.gethostbyname(hostname) 193 | except Exception: 194 | local_ip = "unknown" 195 | 196 | logger.info(f"Server hostname: {hostname}") 197 | logger.info(f"Server IP: {local_ip}") 198 | logger.info(f"Server will be available at: ws://{args.host}:{args.port}") 199 | logger.info(f"Health check endpoint: http://{args.host}:{args.port}/healthz") 200 | 201 | # Create and start server 202 | server = WebsocketPolicyServer( 203 | policy=policy, 204 | host=args.host, 205 | port=args.port, 206 | metadata=policy_metadata, 207 | ) 208 | 209 | logger.info("Starting server...") 210 | try: 211 | server.serve_forever() 212 | except KeyboardInterrupt: 213 | logger.info("Server stopped by user") 214 | except Exception as e: 215 | logger.error(f"Server error: {e}") 216 | sys.exit(1) 217 | 218 | 219 | if __name__ == "__main__": 220 | main(tyro.cli(Args)) 221 | -------------------------------------------------------------------------------- /workspace/README.md: -------------------------------------------------------------------------------- 1 | # Training Guide 2 | 3 | This document explains the key configuration parameters and memory requirements for Wall-X training. 
4 | 5 | ## Quick Start Checklist 6 | 7 | ### 🚀 **Step 1: Prepare Model** 8 | Choose one of our pretrained models: 9 | - **WALL-OSS-FLOW**: https://huggingface.co/x-square-robot/wall-oss-flow 10 | - **WALL-OSS-FAST**: https://huggingface.co/x-square-robot/wall-oss-fast 11 | Or from Qwen-2.5-VL 12 | - Download https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct, settings refer to `config_qact_from_vlm.yml` 13 | 14 | ### ⚙️ **Step 2: Configure Environment** 15 | - Update `run.sh`: Set `code_dir` and `config_path` to your actual paths 16 | - Set `CUDA_VISIBLE_DEVICES` for your available GPUs 17 | 18 | ### 📝 **Step 3: Update Configuration Files** 19 | - Replace all `/path/to/` placeholders in `config_qact.yml` with actual paths 20 | - Configure robot settings: `dof_config` and `agent_pos_config` 21 | - Set dataset: Choose appropriate `repo_id` 22 | - Adjust `batch_size_per_gpu` based on your GPU memory 23 | 24 | ### ▶️ **Step 4: Start Training** 25 | ```bash 26 | bash ./workspace/lerobot_example/run.sh 27 | ``` 28 | 29 | ## Enable FAST tokenizer 30 | To fine-tune using the FAST tokenizer, please download the repository and update the `action_tokenizer_path`. Make sure to set `use_fast_tokenizer` to `true` and q01 and q99 to normalize the dataset, refer to `wall-x/scripts/compute_norm_stats.py`: 31 | ```bash 32 | git clone https://huggingface.co/physical-intelligence/fast 33 | ``` 34 | 35 | ## Required Paths (Must Modify) 36 | ```yaml 37 | pretrained_wallx_path: "/path/to/wallx_model/" # Path to pretrained wallx model 38 | save_path: "/path/to/workspace/" # Path to save training outputs 39 | use_fast_tokenizer: False # True: train FAST, False: train Flow 40 | action_tokenizer_path: "/path/to/fast/" # Must set if use_fast_tokenizer is True 41 | norm_stats_path: "/path/to/stats/" # Must set for normalize dataset 42 | ``` 43 | ## Customize your robot configuration 44 | Ensure that the sum of the configuration dimensions corresponds to the values specified in norm_stats.json, and that each key is unique. The maximum dimensionality is set to 20, consistent with our robot configuration. 45 | ```yaml 46 | customized_dof_config: 47 | "action_eef": 6 48 | "action_gripper": 1 49 | 50 | customized_agent_pos_config: 51 | "state_eef_with_gripper": 7 52 | ``` 53 | 54 | ## Using Lerobot Dataset 55 | - Each dataset employs distinct keys; please specify the corresponding key mappings as described in `wall-x/wall_x/data/utils.py`. 56 | ```python 57 | "lerobot/aloha_mobile_cabinet": { 58 | "camera": { 59 | "observation.images.cam_high": "face_view", 60 | "observation.images.cam_left_wrist": "left_wrist_view", 61 | "observation.images.cam_right_wrist": "right_wrist_view", 62 | }, 63 | "state": "observation.state", 64 | "action": "action", 65 | } 66 | ``` 67 | 68 | ## Compute stats 69 | ```bash 70 | python wall-x/scripts/compute_norm_stats.py 71 | ``` 72 | 73 | ## Configuration Explain 74 | - `agent_pos_config` corresponds to `obs_action_keys` and subsequently to state, while `dof_config` corresponds to `predict_action_keys` and subsequently to action. Note that the state and action may not necessarily share the same set of DoF. 
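As a concrete illustration of that distinction, here is a minimal sketch with hypothetical keys (the real keys and dimensions must match your robot and `norm_stats.json`), where the state uses a single combined key while the action is split differently:

```yaml
# Illustrative sketch only, not a ready-to-use configuration.
agent_pos_config:              # consumed via obs_action_keys (state)
  state_eef_with_gripper: 7
dof_config:                    # consumed via predict_action_keys (action)
  action_eef: 6
  action_gripper: 1
```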
75 | 76 | ## Training Parameters (Commonly Modified) 77 | 78 | ### Learning Rate Settings 79 | - `learning_rate`: Initial learning rate (default: 0.00009) 80 | - `min_lr`: Minimum learning rate for scheduler (default: 0.00005) 81 | - `num_warmup_steps`: Number of warmup steps (default: 100) 82 | 83 | ### Batch Size and Memory 84 | - `batch_size_per_gpu`: Batch size per GPU - adjust based on GPU memory 85 | - `gradient_accumulation_steps`: Gradient accumulation steps 86 | - `num_training_steps`: Total training steps 87 | - `num_epoch`: Number of training epochs 88 | 89 | ### Training Optimization Settings 90 | - `FSDP2`: Enable FSDP2 for distributed training (default: True) - **Recommended for multi-GPU** 91 | - `torch_compile`: Enable PyTorch compilation optimization (default: False) 92 | 93 | **⚠️ Important Note on torch_compile:** 94 | - **Benefits**: Enabling `torch_compile` can significantly improve training efficiency 95 | - **Requirements**: Requires that the data input shape is always consistent throughout training 96 | - **Caution**: If you don't have sufficient understanding of torch compile, please **DO NOT** enable it as it may cause unexpected issues with dynamic input shapes 97 | 98 | ## Robot Configuration (Modify for Your Robot) 99 | 100 | ### DOF Configuration 101 | Modify `dof_config` to match your robot's action space: 102 | - Add/remove action keys based on your robot's capabilities 103 | - Ensure DOF numbers match your robot's action dimensions 104 | 105 | ### Agent Position Configuration 106 | Keep `agent_pos_config` consistent with `dof_config`. 107 | 108 | ### Action Keys 109 | - `obs_action_keys`: Actions used as observation context 110 | - `predict_action_keys`: Actions to predict/control 111 | 112 | ## Data Configuration 113 | 114 | ### Dataset 115 | - `repo_id`: LeRobot dataset identifier 116 | - `train_test_split`: Training/validation split ratio (default: 0.95) 117 | - `action_horizon`: Number of future actions to predict (default: 32) 118 | 119 | ### Image Settings 120 | - `resolution`: Image resolution for different camera views 121 | - `download_videos`: Whether to download video files (true/false) 122 | 123 | ## Resume Training (Optional) 124 | - `resume.ckpt`: Path to checkpoint for resuming training 125 | - `resume.load_ckpt_only`: Only load model weights, not optimizer state 126 | 127 | ## Merge checkpoint 128 | - If FSDP SHARDED_STATE_DICT is used, please run command below to merge checkpoint into a single safetensors 129 | ```bash 130 | # refer to accelerate/commands/merge.py 131 | accelerate merge-weights /path/to/sharded_tensors /path/to/model.safetensors 132 | # copy the saved processor files 133 | cp /path/to/saved_processor_dir/* /path/to/model.safetensors 134 | 135 | # In earlier versions of PyTorch, errors may occur. You can use our provided script to address this issue; refer to wall-x/scripts/merge_sharded_weights.py for details. 
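# Alternative example (placeholder paths): the compatibility script takes the sharded
# checkpoint directory and an output directory, and writes a merged model.safetensors into it.
python wall-x/scripts/merge_sharded_weights.py /path/to/sharded_tensors /path/to/merged_output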
136 | ``` 137 | 138 | ## Memory Usage 139 | 140 | Below are the memory consumption benchmarks for different training configurations using the `lerobot/aloha_mobile_cabinet` dataset: 141 | 142 | | Dataset | Batch Size | FSDP2 | Torch Compile | Num GPUs | Max Allocated Memory | 143 | |---------|------------|--------|---------------|----------|---------------------| 144 | | lerobot/aloha_mobile_cabinet | 1 | ❌ | ❌ | 1 | 40.11G | 145 | | lerobot/aloha_mobile_cabinet | 1 | ❌ | ❌ | 8 | 48.02G | 146 | | lerobot/aloha_mobile_cabinet | 1 | ✅ | ❌ | 2 | 43.70G | 147 | | lerobot/aloha_mobile_cabinet | 1 | ✅ | ❌ | 8 | 24.96G | 148 | | lerobot/aloha_mobile_cabinet | 1 | ✅ | ✅ | 8 | 24.21G | 149 | 150 | 151 | **Hardware Recommendations:** 152 | 153 | - For single GPU training: Ensure at least 48GB VRAM (e.g., RTX 6000 Ada, A6000) 154 | - For multi-GPU training: Enable FSDP2 for optimal memory distribution 155 | 156 | ## Reproduce 157 | 158 | Openloop plot `wall-x/workspace/lerobot_example/evaluation/lerobot_openloop.png` 159 | 160 | To reproduce the results, use the config file wall-x/workspace/lerobot_example/config_qact_from_vlm.yml with a global batch size of 128, adjusted via `gradient_accumulation_steps` and numbers of gpu. 161 | -------------------------------------------------------------------------------- /scripts/merge_sharded_weights.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Custom script to merge FSDP sharded checkpoints with compatibility handling. 4 | Works around the StorageMeta compatibility issue between PyTorch versions. 5 | """ 6 | 7 | import os 8 | import sys 9 | import torch 10 | from pathlib import Path 11 | from typing import Dict 12 | from safetensors.torch import save_file 13 | 14 | 15 | def patch_metadata_loader(): 16 | """Patch the metadata loader to handle missing StorageMeta class.""" 17 | import torch.distributed.checkpoint.metadata as metadata_module 18 | 19 | # Create a dummy StorageMeta class if it doesn't exist 20 | if not hasattr(metadata_module, "StorageMeta"): 21 | print("[INFO] Creating StorageMeta compatibility shim") 22 | 23 | class StorageMeta: 24 | """Compatibility shim for old StorageMeta class.""" 25 | 26 | def __init__(self, *args, **kwargs): 27 | # Store all args as attributes 28 | self.args = args 29 | self.kwargs = kwargs 30 | 31 | # Inject the class into the module 32 | metadata_module.StorageMeta = StorageMeta 33 | 34 | # Also make it available for unpickling 35 | sys.modules["torch.distributed.checkpoint.metadata"].StorageMeta = StorageMeta 36 | 37 | 38 | def load_sharded_checkpoint(checkpoint_dir: str) -> Dict[str, torch.Tensor]: 39 | """ 40 | Load a sharded FSDP checkpoint by manually reading all shard files. 
41 | 42 | Args: 43 | checkpoint_dir: Path to directory containing .distcp files 44 | 45 | Returns: 46 | Dictionary of merged model state 47 | """ 48 | import torch.distributed.checkpoint as dist_cp 49 | import torch.distributed.checkpoint.format_utils as dist_cp_format_utils 50 | 51 | print(f"[INFO] Loading checkpoint from {checkpoint_dir}") 52 | 53 | # Apply the compatibility patch 54 | patch_metadata_loader() 55 | 56 | # Try to load using the standard approach 57 | try: 58 | state_dict = {} 59 | storage_reader = dist_cp.FileSystemReader(checkpoint_dir) 60 | 61 | dist_cp_format_utils._load_state_dict( 62 | state_dict, 63 | storage_reader=storage_reader, 64 | planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(), 65 | no_dist=True, 66 | ) 67 | 68 | print(f"[INFO] Successfully loaded state dict with {len(state_dict)} keys") 69 | return state_dict 70 | 71 | except AttributeError as e: 72 | if "StorageMeta" in str(e): 73 | print(f"[ERROR] StorageMeta compatibility issue: {e}") 74 | print("[INFO] Attempting alternative loading method...") 75 | return load_checkpoint_alternative(checkpoint_dir) 76 | else: 77 | raise 78 | 79 | 80 | def load_checkpoint_alternative(checkpoint_dir: str) -> Dict[str, torch.Tensor]: 81 | """ 82 | Alternative method to load checkpoint by directly reading shard files. 83 | 84 | Args: 85 | checkpoint_dir: Path to directory containing .distcp files 86 | 87 | Returns: 88 | Dictionary of merged model state 89 | """ 90 | checkpoint_path = Path(checkpoint_dir) 91 | 92 | # Find all shard files 93 | shard_files = sorted(checkpoint_path.glob("*.distcp")) 94 | 95 | if not shard_files: 96 | raise FileNotFoundError(f"No .distcp files found in {checkpoint_dir}") 97 | 98 | print(f"[INFO] Found {len(shard_files)} shard files") 99 | 100 | # Load all shards 101 | merged_state = {} 102 | 103 | for shard_file in shard_files: 104 | print(f"[INFO] Loading shard: {shard_file.name}") 105 | try: 106 | shard_data = torch.load(shard_file, map_location="cpu") 107 | 108 | # Merge the shard into the state dict 109 | if isinstance(shard_data, dict): 110 | for key, value in shard_data.items(): 111 | if isinstance(value, torch.Tensor): 112 | if key in merged_state: 113 | # Handle duplicates - concatenate or overwrite based on shape 114 | print(f"[WARNING] Duplicate key found: {key}") 115 | merged_state[key] = value 116 | elif isinstance(value, dict): 117 | # Nested dict structure 118 | for subkey, subvalue in value.items(): 119 | full_key = f"{key}.{subkey}" if key else subkey 120 | if isinstance(subvalue, torch.Tensor): 121 | merged_state[full_key] = subvalue 122 | 123 | except Exception as e: 124 | print(f"[WARNING] Failed to load shard {shard_file.name}: {e}") 125 | continue 126 | 127 | if not merged_state: 128 | raise RuntimeError("Failed to load any checkpoint data from shards") 129 | 130 | print(f"[INFO] Loaded {len(merged_state)} tensors from shards") 131 | return merged_state 132 | 133 | 134 | def save_merged_checkpoint( 135 | state_dict: Dict[str, torch.Tensor], 136 | output_path: str, 137 | safe_serialization: bool = True, 138 | ): 139 | """ 140 | Save the merged checkpoint to disk. 
141 | 142 | Args: 143 | state_dict: Model state dictionary 144 | output_path: Directory to save the merged checkpoint 145 | safe_serialization: If True, save as .safetensors, else as .bin 146 | """ 147 | output_dir = Path(output_path) 148 | output_dir.mkdir(parents=True, exist_ok=True) 149 | 150 | # Handle nested state dict structure (e.g., {model: {...}}) 151 | if len(state_dict.keys()) == 1 and all( 152 | isinstance(v, dict) for v in state_dict.values() 153 | ): 154 | print("[INFO] Unwrapping nested state dict") 155 | state_dict = state_dict[list(state_dict.keys())[0]] 156 | 157 | # Prepare tensors for saving 158 | save_dict = {} 159 | for key, value in state_dict.items(): 160 | if isinstance(value, torch.Tensor): 161 | # Convert to CPU and contiguous 162 | save_dict[key] = value.cpu().contiguous() 163 | else: 164 | print(f"[WARNING] Skipping non-tensor key: {key} (type: {type(value)})") 165 | 166 | if safe_serialization: 167 | output_file = output_dir / "model.safetensors" 168 | print(f"[INFO] Saving merged checkpoint to {output_file}") 169 | save_file(save_dict, output_file) 170 | else: 171 | output_file = output_dir / "pytorch_model.bin" 172 | print(f"[INFO] Saving merged checkpoint to {output_file}") 173 | torch.save(save_dict, output_file) 174 | 175 | print("[SUCCESS] Checkpoint saved successfully!") 176 | print(f"[INFO] Saved {len(save_dict)} tensors") 177 | 178 | # Print size info 179 | total_params = sum(v.numel() for v in save_dict.values()) 180 | total_size_gb = sum(v.numel() * v.element_size() for v in save_dict.values()) / ( 181 | 1024**3 182 | ) 183 | print(f"[INFO] Total parameters: {total_params:,}") 184 | print(f"[INFO] Total size: {total_size_gb:.2f} GB") 185 | 186 | return output_file 187 | 188 | 189 | def main(): 190 | import argparse 191 | 192 | parser = argparse.ArgumentParser( 193 | description="Merge FSDP sharded checkpoints with compatibility handling" 194 | ) 195 | parser.add_argument( 196 | "checkpoint_dir", 197 | type=str, 198 | help="Directory containing sharded FSDP checkpoint files (*.distcp)", 199 | ) 200 | parser.add_argument( 201 | "output_path", type=str, help="Output directory for merged checkpoint" 202 | ) 203 | parser.add_argument( 204 | "--unsafe-serialization", 205 | action="store_true", 206 | help="Save as .bin instead of .safetensors", 207 | ) 208 | 209 | args = parser.parse_args() 210 | 211 | # Validate input 212 | if not os.path.exists(args.checkpoint_dir): 213 | print(f"[ERROR] Checkpoint directory not found: {args.checkpoint_dir}") 214 | sys.exit(1) 215 | 216 | try: 217 | # Load the sharded checkpoint 218 | state_dict = load_sharded_checkpoint(args.checkpoint_dir) 219 | 220 | # Save the merged checkpoint 221 | safe_serialization = not args.unsafe_serialization 222 | output_file = save_merged_checkpoint( 223 | state_dict, args.output_path, safe_serialization 224 | ) 225 | 226 | print("\n[COMPLETE] Checkpoint merging successful!") 227 | print(f"[COMPLETE] Output: {output_file}") 228 | 229 | except Exception as e: 230 | print(f"\n[ERROR] Failed to merge checkpoint: {e}") 231 | import traceback 232 | 233 | traceback.print_exc() 234 | sys.exit(1) 235 | 236 | 237 | if __name__ == "__main__": 238 | main() 239 | -------------------------------------------------------------------------------- /wall_x/serving/policy/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import logging 3 | import numpy as np 4 | from wall_x.data.utils import preprocesser_call 5 | from 
qwen_vl_utils.vision_process import smart_resize 6 | import torch 7 | from PIL import Image 8 | from transformers import BatchFeature 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def prepare_batch( 14 | obs: Dict, 15 | processor, 16 | camera_key: List[str], 17 | agent_pos_dim, 18 | action_dim, 19 | pred_horizon, 20 | fixed_action_dim, 21 | max_length, 22 | image_factor: int, 23 | min_pixels: int, 24 | max_pixels: int, 25 | predict_mode: str = "fast", 26 | device: str = "cuda", 27 | ) -> BatchFeature: 28 | """Prepare observation into model input format. 29 | 30 | Args: 31 | obs: Dictionary containing: 32 | - 'camera_key_0' : image 0 33 | - 'camera_key_1' : image 1 34 | ... 35 | - 'prompt': Text prompt 36 | - 'state': Robot state/proprioception 37 | - 'dataset_names': Dataset names 38 | 39 | Returns: 40 | BatchFeature object ready for model input 41 | """ 42 | # Handle images - can be single image, list of images, or dict of images 43 | images = [] 44 | images = [obs[key] for key in camera_key] 45 | # Convert numpy arrays to PIL Images 46 | processed_images = [] 47 | for img in images: 48 | if isinstance(img, np.ndarray): 49 | # Debug: Log the shape and dtype 50 | logger.debug(f"Image shape: {img.shape}, dtype: {img.dtype}") 51 | 52 | # Handle unexpected dimensions - squeeze if needed 53 | if img.ndim > 3: 54 | logger.warning( 55 | f"Image has {img.ndim} dimensions, squeezing extra dimensions" 56 | ) 57 | img = np.squeeze(img) 58 | 59 | # Verify shape is valid for PIL 60 | if img.ndim == 2: 61 | # Grayscale image 62 | pass 63 | elif img.ndim == 3: 64 | # Check if channel dimension is first or last 65 | if img.shape[0] == 3 or img.shape[0] == 1: 66 | # Channels first, transpose to channels last 67 | img = np.transpose(img, (1, 2, 0)) 68 | elif img.shape[2] == 3 or img.shape[2] == 1: 69 | # Already channels last 70 | pass 71 | else: 72 | raise ValueError( 73 | f"Unexpected image shape: {img.shape}. Expected (H, W, C) or (C, H, W)" 74 | ) 75 | else: 76 | raise ValueError( 77 | f"Invalid image dimensions: {img.ndim}. 
Expected 2 or 3 dimensions, got shape {img.shape}" 78 | ) 79 | 80 | # Convert to PIL Image 81 | if img.dtype == np.uint8: 82 | img = Image.fromarray(img) 83 | else: 84 | img = Image.fromarray((img * 255).astype(np.uint8)) 85 | processed_images.append(img) 86 | 87 | # Apply smart resize to images 88 | resized_images = process_images( 89 | processed_images, image_factor, min_pixels, max_pixels 90 | ) 91 | 92 | # Handle text prompt - format with vision tokens 93 | instruction = obs["prompt"] 94 | formatted_text = format_text_with_vision_tokens( 95 | instruction, camera_key, predict_mode, pred_horizon 96 | ) 97 | 98 | # Use processor to prepare inputs 99 | inputs = preprocesser_call( 100 | processor=processor, 101 | text=[formatted_text], 102 | images=[resized_images], 103 | videos=None, 104 | padding=True, 105 | truncation=True, 106 | return_tensors="pt", 107 | max_length=max_length, 108 | ) 109 | 110 | action_token_id = processor.tokenizer.convert_tokens_to_ids("<|action|>") 111 | moe_token_types = inputs.input_ids == action_token_id 112 | inputs["moe_token_types"] = moe_token_types 113 | 114 | # Handle robot state/proprioception if available 115 | if "state" in obs: 116 | state = obs["state"] 117 | if isinstance(state, np.ndarray): 118 | state = torch.from_numpy(state).float() 119 | elif not isinstance(state, torch.Tensor): 120 | state = torch.tensor(state, dtype=torch.float32) 121 | 122 | # Add batch dimension if needed 123 | if state.dim() == 1: 124 | state = state.unsqueeze(0) 125 | if state.dim() == 2: 126 | state = state.unsqueeze(1) # [batch, 1, state_dim] 127 | 128 | # Pad to 20 dimensions if needed (same as training) 129 | if state.shape[-1] < 20: 130 | padding = torch.zeros(state.shape[0], state.shape[1], 20 - state.shape[-1]) 131 | state = torch.cat([state, padding], dim=-1) 132 | 133 | # Create mask for valid dimensions 134 | agent_pos_mask = torch.ones_like(state) 135 | if state.shape[-1] > agent_pos_dim: 136 | agent_pos_mask[:, :, agent_pos_dim:] = 0 137 | 138 | inputs["proprioception"] = state 139 | inputs["agent_pos_mask"] = agent_pos_mask 140 | 141 | # Add dataset name (required by model) 142 | inputs["dataset_names"] = obs["dataset_names"] 143 | 144 | # Move all tensors to device 145 | for key in inputs: 146 | if isinstance(inputs[key], torch.Tensor): 147 | inputs[key] = inputs[key].to(device) 148 | 149 | dof_mask = torch.ones([state.shape[0], pred_horizon, fixed_action_dim]) 150 | dof_mask[:, :, action_dim:] = 0 151 | 152 | inputs["dof_mask"] = dof_mask 153 | 154 | # Convert to BatchFeature to maintain consistency with training pipeline 155 | return BatchFeature(data=dict(inputs)).to(device) 156 | 157 | 158 | def process_images( 159 | images: List[Image.Image], image_factor: int, min_pixels: int, max_pixels: int 160 | ) -> List[Image.Image]: 161 | """Process images with smart resize following the data loading pattern. 
162 | 163 | Args: 164 | images: List of PIL Images 165 | 166 | Returns: 167 | List of resized PIL Images 168 | """ 169 | resized_images = [] 170 | for img_pil in images: 171 | current_width, current_height = img_pil.size 172 | 173 | # Apply smart scaling (Qwen logic) 174 | resized_height, resized_width = smart_resize( 175 | current_height, 176 | current_width, 177 | factor=image_factor, 178 | min_pixels=min_pixels, 179 | max_pixels=max_pixels, 180 | ) 181 | 182 | resized_img = img_pil.resize((resized_width, resized_height)) 183 | resized_images.append(resized_img) 184 | 185 | return resized_images 186 | 187 | 188 | def format_text_with_vision_tokens( 189 | instruction: str, 190 | camera_key: List[str], 191 | predict_mode: str = "fast", 192 | pred_horizon: int = 32, 193 | ) -> str: 194 | """Format text prompt with vision tokens for the model. 195 | 196 | Args: 197 | instruction: Task instruction text 198 | camera_key: List of camera names 199 | 200 | Returns: 201 | Formatted text with special tokens 202 | """ 203 | # Special tokens for formatting 204 | role_start_symbol = "<|im_start|>" 205 | role_end_symbol = "<|im_end|>" 206 | vision_start_symbol = "<|vision_start|>" 207 | vision_end_symbol = "<|vision_end|>" 208 | image_pad_symbol = "<|image_pad|>" 209 | propri_symbol = "<|propri|>" 210 | action_symbol = "<|action|>" 211 | # action_fast_symbol = "<|action_fast|>" 212 | 213 | # Camera name mapping 214 | camera_name_mapping = { 215 | "front_view": "front view", 216 | "face_view": "front view", 217 | "left_wrist_view": "left wrist view", 218 | "right_wrist_view": "right wrist view", 219 | "top_view": "top view", 220 | "wall_view": "wall view", 221 | } 222 | 223 | # System prologue 224 | prologue = ( 225 | f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n" 226 | ) 227 | 228 | # User request with observation 229 | user_request = f"{role_start_symbol}user\nObservation:" 230 | if camera_key: 231 | for cam_name in camera_key: 232 | view_name = camera_name_mapping.get(cam_name, cam_name) 233 | user_request += f" {view_name}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}" 234 | user_request += "\nInstruction:" 235 | 236 | text_prompt = ( 237 | f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n" 238 | ) 239 | user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" 240 | assistant_output = f"{role_start_symbol}assistant\n" 241 | if predict_mode == "diffusion": 242 | assistant_output += f"{action_symbol * pred_horizon}" 243 | complete_text = prologue + user_message + assistant_output 244 | 245 | return complete_text 246 | -------------------------------------------------------------------------------- /wall_x/serving/README.md: -------------------------------------------------------------------------------- 1 | # Wall-X Model Serving 2 | 3 | This directory contains scripts for serving Wall-X models via a websocket server, allowing remote clients to connect and get action predictions from observations. 4 | 5 | ## Overview 6 | 7 | The serving infrastructure consists of three main components: 8 | 9 | 1. **WebsocketPolicyServer** (`wall_x/serving/websocket_policy_server.py`): Generic websocket server that can serve any policy implementing the `BasePolicy` interface 10 | 2. **WallXPolicy** (`wall_x/serving/policy/wall_x_policy.py`): Policy wrapper that adapts the Wall-X model to the `BasePolicy` interface 11 | 3. 
**launch_serving.py**: Main script for starting the server with various configurations 12 | 13 | ## Quick Start 14 | 15 | ### Basic Usage 16 | 17 | Serve a model with default LIBERO configuration: 18 | 19 | ```bash 20 | cd /x2robot_v2/vincent/workspace/opensource 21 | python -m wall_x.serving.launch_serving \ 22 | --env libero \ 23 | --model-config.model-path /path/to/libero_model_stuff \ 24 | --model-config.action-tokenizer-path /path/to/fast/ \ 25 | --model-config.train-config-path /path/to/config.yml 26 | ``` 27 | 28 | ### Specify Environment 29 | 30 | Serve with a specific environment preset: 31 | 32 | ```bash 33 | # LIBERO (single arm, 7 DOF) 34 | python -m wall_x.serving.launch_serving --env libero 35 | 36 | # ALOHA (dual arm, 14 DOF) 37 | python -m wall_x.serving.launch_serving --env aloha 38 | ``` 39 | 40 | ### Custom Configuration 41 | 42 | Serve with custom model paths and settings: 43 | 44 | ```bash 45 | python -m wall_x.serving.launch_serving \ 46 | --model-config.model-path /path/to/model \ 47 | --model-config.action-tokenizer-path /path/to/tokenizer \ 48 | --model-config.train-config-path /path/to/train_config.yml \ 49 | --model-config.action-dim 7 \ 50 | --model-config.state-dim 8 \ 51 | --model-config.pred-horizon 32 \ 52 | --model-config.camera-key front_view left_wrist_view \ 53 | --port 8000 54 | ``` 55 | 56 | ## Command Line Arguments 57 | 58 | ### Basic Arguments 59 | 60 | - `--env {libero,aloha}`: Environment mode (default: libero) 61 | - `--port PORT`: Port to serve on (default: 8000) 62 | - `--host HOST`: Host to bind to (default: 0.0.0.0) 63 | - `--default-prompt TEXT`: Default text prompt if not provided in observation 64 | - `--debug`: Enable debug logging 65 | 66 | ### Model Configuration 67 | 68 | All model configuration arguments use the `--model-config.` prefix: 69 | 70 | - `--model-config.model-path PATH`: Path to pretrained model checkpoint (required) 71 | - `--model-config.action-tokenizer-path PATH`: Path to action tokenizer (required) 72 | - `--model-config.train-config-path PATH`: Path to train config YAML file (required) 73 | - `--model-config.action-dim INT`: Action space dimension (default: 7) 74 | - `--model-config.state-dim INT`: Robot state dimension (default: 8) 75 | - `--model-config.pred-horizon INT`: Prediction horizon (default: 32) 76 | - `--model-config.device {cuda,cpu}`: Device to run on (default: cuda) 77 | - `--model-config.dtype {bfloat16,float16,float32}`: Model dtype (default: bfloat16) 78 | - `--model-config.predict-mode {fast,diffusion}`: Prediction mode (default: fast) 79 | - `--model-config.camera-key KEY1 KEY2 ...`: Camera keys for observation images 80 | 81 | ### Camera Keys 82 | 83 | The `camera-key` parameter specifies which camera views are expected in the observation dictionary. 
This is **critical** for proper operation: 84 | 85 | - Keys must match between server configuration and client observations 86 | - Order matters: keys are processed in the order specified 87 | - Common keys: `front_view`, `left_wrist_view`, `right_wrist_view`, `face_view` 88 | 89 | Example: 90 | ```bash 91 | --model-config.camera-key front_view left_wrist_view 92 | ``` 93 | 94 | Client must send observations with matching keys: 95 | ```python 96 | obs = { 97 | "front_view": image1, # Must match camera-key[0] 98 | "left_wrist_view": image2, # Must match camera-key[1] 99 | "prompt": "task description", 100 | "state": robot_state, 101 | } 102 | ``` 103 | 104 | ## Default Configurations 105 | 106 | ### LIBERO (Single Arm) 107 | 108 | ```python 109 | ModelConfig( 110 | model_path="/path/to/model", 111 | action_tokenizer_path="/path/to/action_tokenizer", 112 | train_config_path="/path/to/train_config", 113 | state_dim=8, 114 | action_dim=7, 115 | pred_horizon=32, 116 | device="cuda", 117 | dtype="bfloat16", 118 | predict_mode="fast", 119 | camera_key=["front_view", "left_wrist_view"], 120 | ) 121 | ``` 122 | 123 | ### ALOHA (Dual Arm) 124 | 125 | ```python 126 | ModelConfig( 127 | model_path="/path/to/model", 128 | action_tokenizer_path="/path/to/action_tokenizer", 129 | train_config_path="/path/to/train_config", 130 | state_dim=14, 131 | action_dim=14, 132 | pred_horizon=32, 133 | device="cuda", 134 | dtype="bfloat16", 135 | predict_mode="fast", 136 | camera_key=["face_view", "left_wrist_view", "right_wrist_view"], 137 | ) 138 | ``` 139 | 140 | ## Server Protocol 141 | 142 | ### Connection Flow 143 | 144 | 1. Client connects to `ws://host:port` 145 | 2. Server sends metadata JSON with policy information 146 | 3. Client sends observation (msgpack-encoded) 147 | 4. Server responds with action prediction (msgpack-encoded) 148 | 5. Repeat steps 3-4 for each inference 149 | 150 | ### Observation Format 151 | 152 | Observations must be a dictionary with camera keys matching server configuration: 153 | 154 | ```python 155 | obs = { 156 | # Image observations - keys must match server's camera_key configuration 157 | "front_view": np.ndarray, # (H, W, 3) uint8 or float 158 | "left_wrist_view": np.ndarray, # (H, W, 3) uint8 or float 159 | 160 | # Required fields 161 | "prompt": str, # Task description 162 | "dataset_names": List[str], # Dataset/robot name, e.g., ["physical-intelligence/libero"] 163 | "state": np.ndarray, # Robot proprioception state (state_dim,) 164 | } 165 | ``` 166 | 167 | **Important**: The image keys (`front_view`, `left_wrist_view`, etc.) must exactly match the `camera_key` parameter configured on the server. 
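### Raw Protocol Example

For a quick sanity check of the wire protocol, the server can also be queried directly with `websockets` and `msgpack-numpy`, without the `WallXClient` helper shown later in this README. The snippet below is a minimal sketch that follows the connection flow above; the local URI, the dummy images and state, and the task prompt are placeholder assumptions, and the camera keys assume the default LIBERO configuration.

```python
# Minimal wire-protocol sketch (assumes a local server started with the default LIBERO config).
import asyncio

import msgpack
import msgpack_numpy as m
import numpy as np
import websockets

m.patch()  # let msgpack (de)serialize numpy arrays


async def query_once():
    async with websockets.connect("ws://localhost:8000", max_size=None) as ws:
        # 1-2. On connect, the server immediately sends its policy metadata.
        metadata = msgpack.unpackb(await ws.recv())
        print("server metadata:", metadata)

        # 3. Build an observation; image keys must match the server's camera-key setting.
        obs = {
            "front_view": np.zeros((224, 224, 3), dtype=np.uint8),       # dummy image
            "left_wrist_view": np.zeros((224, 224, 3), dtype=np.uint8),  # dummy image
            "prompt": "pick up the object",                              # placeholder task
            "dataset_names": ["physical-intelligence/libero"],
            "state": np.zeros(8, dtype=np.float32),                      # state_dim = 8 for LIBERO
        }
        await ws.send(msgpack.packb(obs))

        # 4. Receive the predicted action chunk and timing info.
        response = msgpack.unpackb(await ws.recv())
        print("action shape:", response["action"].shape)  # (pred_horizon, action_dim)
        print("server timing:", response["server_timing"])


asyncio.run(query_once())
```

Because `msgpack_numpy.patch()` is applied, numpy arrays round-trip through msgpack without manual encoding, so raw `(H, W, 3)` uint8 images and float state vectors can be sent as-is.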
168 | 169 | ### Action Response Format 170 | 171 | Actions are returned as a dictionary: 172 | 173 | ```python 174 | { 175 | "action": np.ndarray, # Predicted action [pred_horizon, action_dim] 176 | "server_timing": { 177 | "infer_ms": float, # Inference time in milliseconds 178 | "prev_total_ms": float, # Total time for previous request 179 | } 180 | } 181 | ``` 182 | 183 | ### Server Metadata 184 | 185 | When connecting, the server sends metadata: 186 | 187 | ```python 188 | { 189 | "action_dim": int, # Action space dimension 190 | "pred_horizon": int, # Number of future actions predicted 191 | "device": str, # Device model runs on 192 | "predict_mode": str, # Prediction mode (fast/diffusion) 193 | "env": str, # Environment name 194 | } 195 | ``` 196 | 197 | ### Health Check 198 | 199 | HTTP health check endpoint available at: 200 | ``` 201 | http://host:port/healthz 202 | ``` 203 | 204 | Returns `200 OK` if the server is running. 205 | 206 | ## Client Example 207 | 208 | ### Synchronous Python Client 209 | 210 | For synchronous usage, see `wall_x/serving/client.py`: 211 | 212 | ```python 213 | from wall_x.serving.client import WallXClient 214 | 215 | # Create and connect (config_path is the training config YAML used by the server) 216 | client = WallXClient(config_path="/path/to/train_config.yml", uri="ws://localhost:8000") 217 | client.connect_sync() 218 | 219 | # Prepare observation 220 | obs = { 221 | "front_view": image1, 222 | "left_wrist_view": image2, 223 | "prompt": "task description", 224 | "state": robot_state, 225 | "dataset_names": ["physical-intelligence/libero"], 226 | } 227 | 228 | # Get prediction 229 | response = client.predict_sync(obs) 230 | action = response["action"] 231 | 232 | # Close connection 233 | client.close_sync() 234 | ``` 235 | 236 | ## Architecture 237 | 238 | ### WebsocketPolicyServer 239 | 240 | Generic websocket server that: 241 | - Handles websocket connections with msgpack serialization 242 | - Tracks inference timing and performance metrics 243 | - Provides health check endpoint 244 | - Handles errors gracefully with proper logging 245 | - Supports concurrent client connections 246 | 247 | ### WallXPolicy 248 | 249 | Policy wrapper that: 250 | - Loads and manages the Wall-X model from pretrained checkpoint 251 | - Processes multi-camera observations 252 | - Handles image preprocessing (smart resize, normalization) 253 | - Manages device placement and dtype conversion 254 | - Provides policy metadata to clients 255 | - Supports both fast tokenizer and diffusion prediction modes 256 | 257 | ### Image Processing Pipeline 258 | 259 | 1. **Camera Key Matching**: Extracts images from observation dict using configured camera keys 260 | 2. **Format Conversion**: Converts numpy arrays to PIL Images 261 | 3. **Smart Resize**: Applies Qwen's smart resize algorithm based on min/max pixels 262 | 4. **Vision Token Formatting**: Inserts vision tokens in text prompt 263 | 5. 
**Batch Preparation**: Creates model-ready BatchFeature input 264 | -------------------------------------------------------------------------------- /csrc/window_index.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | __global__ void compute_metadata( 8 | const int *grid_thw, // [num_grids, 3] 9 | int *grid_info, // [num_grids, 6]: [grid_elements, grid_windows, llm_h, llm_w, num_windows_h, num_windows_w] 10 | int *global_totals, // [total_elements, total_windows] 11 | int num_grids, 12 | int spatial_merge_size, 13 | int vit_merger_window_size) 14 | { 15 | int grid_idx = blockIdx.x * blockDim.x + threadIdx.x; 16 | if (grid_idx >= num_grids) 17 | return; 18 | 19 | int grid_t = grid_thw[grid_idx * 3 + 0]; 20 | int grid_h = grid_thw[grid_idx * 3 + 1]; 21 | int grid_w = grid_thw[grid_idx * 3 + 2]; 22 | 23 | int llm_h = grid_h / spatial_merge_size; 24 | int llm_w = grid_w / spatial_merge_size; 25 | 26 | int pad_h = (vit_merger_window_size - llm_h % vit_merger_window_size) % vit_merger_window_size; 27 | int pad_w = (vit_merger_window_size - llm_w % vit_merger_window_size) % vit_merger_window_size; 28 | 29 | int num_windows_h = (llm_h + pad_h) / vit_merger_window_size; 30 | int num_windows_w = (llm_w + pad_w) / vit_merger_window_size; 31 | 32 | int grid_elements = grid_t * llm_h * llm_w; 33 | int grid_windows = grid_t * num_windows_h * num_windows_w; 34 | 35 | grid_info[grid_idx * 6 + 0] = grid_elements; 36 | grid_info[grid_idx * 6 + 1] = grid_windows; 37 | grid_info[grid_idx * 6 + 2] = llm_h; 38 | grid_info[grid_idx * 6 + 3] = llm_w; 39 | grid_info[grid_idx * 6 + 4] = num_windows_h; 40 | grid_info[grid_idx * 6 + 5] = num_windows_w; 41 | 42 | atomicAdd(&global_totals[0], grid_elements); 43 | atomicAdd(&global_totals[1], grid_windows); 44 | } 45 | 46 | __global__ void compute_window_counts( 47 | const int *grid_thw, 48 | const int *grid_info, 49 | int *window_counts, 50 | int vit_merger_window_size, 51 | int spatial_merge_unit, 52 | int num_grids) 53 | { 54 | int grid_idx = blockIdx.y; 55 | int t_idx = blockIdx.x; 56 | 57 | if (grid_idx >= num_grids) 58 | return; 59 | 60 | int grid_t = grid_thw[grid_idx * 3 + 0]; 61 | if (t_idx >= grid_t) 62 | return; 63 | 64 | int llm_h = grid_info[grid_idx * 6 + 2]; 65 | int llm_w = grid_info[grid_idx * 6 + 3]; 66 | int num_windows_h = grid_info[grid_idx * 6 + 4]; 67 | int num_windows_w = grid_info[grid_idx * 6 + 5]; 68 | 69 | int window_base = 0; 70 | for (int g = 0; g < grid_idx; g++) 71 | { 72 | window_base += grid_info[g * 6 + 1]; 73 | } 74 | 75 | int t_window_base = window_base + t_idx * num_windows_h * num_windows_w; 76 | 77 | int thread_id = threadIdx.x; 78 | int warp_id = thread_id / 32; 79 | int lane_id = thread_id % 32; 80 | 81 | int windows_per_t = num_windows_h * num_windows_w; 82 | int warps_per_block = blockDim.x / 32; 83 | 84 | if (lane_id == 0) 85 | { 86 | for (int window_idx = warp_id; window_idx < windows_per_t; window_idx += warps_per_block) 87 | { 88 | int win_h = window_idx / num_windows_w; 89 | int win_w = window_idx % num_windows_w; 90 | 91 | int start_h = win_h * vit_merger_window_size; 92 | int start_w = win_w * vit_merger_window_size; 93 | 94 | int valid_h = min(vit_merger_window_size, llm_h - start_h); 95 | int valid_w = min(vit_merger_window_size, llm_w - start_w); 96 | 97 | int valid_count = (valid_h > 0 && valid_w > 0) ? 
valid_h * valid_w : 0; 98 | 99 | window_counts[t_window_base + window_idx] = valid_count; 100 | } 101 | } 102 | } 103 | 104 | __global__ void compute_cu_window_seqlens( 105 | const int *window_counts, 106 | int *cu_window_seqlens, // [total_windows + 1] 107 | int total_windows, 108 | int spatial_merge_unit) 109 | { 110 | int tid = blockIdx.x * blockDim.x + threadIdx.x; 111 | 112 | if (tid == 0) 113 | { 114 | cu_window_seqlens[0] = 0; 115 | } 116 | 117 | if (tid < total_windows) 118 | { 119 | cu_window_seqlens[tid + 1] = window_counts[tid] * spatial_merge_unit; 120 | } 121 | 122 | __syncthreads(); 123 | 124 | if (tid == 0) 125 | { 126 | for (int i = 1; i <= total_windows; i++) 127 | { 128 | cu_window_seqlens[i] += cu_window_seqlens[i - 1]; 129 | } 130 | } 131 | } 132 | 133 | __global__ void generate_window_indices( 134 | const int *grid_thw, 135 | const int *grid_info, 136 | const int *cu_window_seqlens, 137 | int *window_indices, 138 | int vit_merger_window_size, 139 | int spatial_merge_unit, 140 | int num_grids) 141 | { 142 | int grid_idx = blockIdx.y; 143 | int t_idx = blockIdx.x; 144 | 145 | if (grid_idx >= num_grids) 146 | return; 147 | 148 | int grid_t = grid_thw[grid_idx * 3 + 0]; 149 | if (t_idx >= grid_t) 150 | return; 151 | 152 | int llm_h = grid_info[grid_idx * 6 + 2]; 153 | int llm_w = grid_info[grid_idx * 6 + 3]; 154 | int num_windows_h = grid_info[grid_idx * 6 + 4]; 155 | int num_windows_w = grid_info[grid_idx * 6 + 5]; 156 | 157 | int element_base = 0; 158 | for (int g = 0; g < grid_idx; g++) 159 | { 160 | element_base += grid_info[g * 6 + 0]; 161 | } 162 | int t_element_base = element_base + t_idx * llm_h * llm_w; 163 | 164 | int window_base = 0; 165 | for (int g = 0; g < grid_idx; g++) 166 | { 167 | window_base += grid_info[g * 6 + 1]; 168 | } 169 | int t_window_base = window_base + t_idx * num_windows_h * num_windows_w; 170 | 171 | int thread_id = threadIdx.x; 172 | int warp_id = thread_id / 32; 173 | int lane_id = thread_id % 32; 174 | 175 | int windows_per_t = num_windows_h * num_windows_w; 176 | int warps_per_block = blockDim.x / 32; 177 | 178 | for (int window_idx = warp_id; window_idx < windows_per_t; window_idx += warps_per_block) 179 | { 180 | int win_h = window_idx / num_windows_w; 181 | int win_w = window_idx % num_windows_w; 182 | 183 | int global_window_idx = t_window_base + window_idx; 184 | int output_offset = cu_window_seqlens[global_window_idx] / spatial_merge_unit; 185 | 186 | int start_h = win_h * vit_merger_window_size; 187 | int start_w = win_w * vit_merger_window_size; 188 | 189 | int valid_h = min(vit_merger_window_size, llm_h - start_h); 190 | int valid_w = min(vit_merger_window_size, llm_w - start_w); 191 | 192 | for (int elem_idx = lane_id; elem_idx < valid_h * valid_w; elem_idx += 32) 193 | { 194 | int local_h = elem_idx / valid_w; 195 | int local_w = elem_idx % valid_w; 196 | 197 | int abs_h = start_h + local_h; 198 | int abs_w = start_w + local_w; 199 | 200 | int value = t_element_base + abs_h * llm_w + abs_w; 201 | 202 | int base_offset = output_offset + elem_idx; 203 | window_indices[base_offset] = value; 204 | } 205 | } 206 | } 207 | 208 | std::tuple get_window_index_cuda( 209 | torch::Tensor grid_thw, 210 | int spatial_merge_size, 211 | int vit_merger_window_size, 212 | int patch_size, 213 | int spatial_merge_unit) 214 | { 215 | TORCH_CHECK(grid_thw.is_cuda(), "grid_thw must be a CUDA tensor"); 216 | TORCH_CHECK(grid_thw.dim() == 2 && grid_thw.size(1) == 3); 217 | TORCH_CHECK(grid_thw.dtype() == torch::kInt32); 218 | 219 | int num_grids = 
grid_thw.size(0); 220 | if (num_grids == 0) 221 | { 222 | return std::make_tuple( 223 | torch::empty({0}, grid_thw.options()), 224 | torch::zeros({1}, grid_thw.options())); 225 | } 226 | 227 | const int *d_grid_thw = grid_thw.data_ptr<int>(); 228 | auto options = grid_thw.options(); 229 | 230 | auto grid_thw_cpu = grid_thw.cpu(); 231 | int max_grid_t = 0; 232 | for (int i = 0; i < num_grids; i++) 233 | { 234 | max_grid_t = std::max(max_grid_t, grid_thw_cpu[i][0].item<int>()); 235 | } 236 | 237 | auto grid_info_tensor = torch::empty({num_grids, 6}, options); 238 | auto global_totals_tensor = torch::zeros({2}, options); 239 | 240 | int *d_grid_info = grid_info_tensor.data_ptr<int>(); 241 | int *d_global_totals = global_totals_tensor.data_ptr<int>(); 242 | 243 | int threads1 = 256; 244 | int blocks1 = (num_grids + threads1 - 1) / threads1; 245 | compute_metadata<<<blocks1, threads1>>>( 246 | d_grid_thw, d_grid_info, d_global_totals, 247 | num_grids, spatial_merge_size, vit_merger_window_size); 248 | 249 | auto totals_cpu = global_totals_tensor.cpu(); 250 | int total_elements = totals_cpu[0].item<int>(); 251 | int total_windows = totals_cpu[1].item<int>(); 252 | 253 | if (total_elements == 0 || total_windows == 0) 254 | { 255 | return std::make_tuple( 256 | torch::empty({0}, options), 257 | torch::zeros({1}, options)); 258 | } 259 | 260 | torch::Tensor window_indices = torch::empty({total_elements}, options); 261 | torch::Tensor cu_window_seqlens = torch::empty({total_windows + 1}, options); 262 | 263 | int *d_window_indices = window_indices.data_ptr<int>(); 264 | int *d_cu_window_seqlens = cu_window_seqlens.data_ptr<int>(); 265 | 266 | auto window_counts_tensor = torch::empty({total_windows}, options); 267 | int *d_window_counts = window_counts_tensor.data_ptr<int>(); 268 | 269 | dim3 blocks2(max_grid_t, num_grids); 270 | dim3 threads2(256); 271 | 272 | compute_window_counts<<<blocks2, threads2>>>( 273 | d_grid_thw, d_grid_info, d_window_counts, 274 | vit_merger_window_size, spatial_merge_unit, num_grids); 275 | 276 | int threads4 = 256; 277 | int blocks4 = (total_windows + threads4 - 1) / threads4; 278 | compute_cu_window_seqlens<<<blocks4, threads4>>>( 279 | d_window_counts, d_cu_window_seqlens, total_windows, spatial_merge_unit); 280 | 281 | generate_window_indices<<<blocks2, threads2>>>( 282 | d_grid_thw, d_grid_info, d_cu_window_seqlens, d_window_indices, 283 | vit_merger_window_size, spatial_merge_unit, num_grids); 284 | 285 | return std::make_tuple(window_indices, cu_window_seqlens); 286 | } 287 | -------------------------------------------------------------------------------- /wall_x/model/qwen2_5_based/configuration_qwen2_5_vl.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | from transformers.modeling_rope_utils import rope_config_validation 3 | 4 | 5 | class Qwen2_5_VLVisionConfig(PretrainedConfig): 6 | model_type = "qwen2_5_vl" 7 | base_config_key = "vision_config" 8 | 9 | def __init__( 10 | self, 11 | depth=32, 12 | hidden_size=3584, 13 | hidden_act="silu", 14 | intermediate_size=3420, 15 | num_heads=16, 16 | in_channels=3, 17 | patch_size=14, 18 | spatial_merge_size=2, 19 | temporal_patch_size=2, 20 | tokens_per_second=4, 21 | window_size=112, 22 | out_hidden_size=3584, 23 | fullatt_block_indexes=[7, 15, 23, 31], 24 | **kwargs, 25 | ): 26 | super().__init__(**kwargs) 27 | 28 | self.depth = depth 29 | self.hidden_size = hidden_size 30 | self.hidden_act = hidden_act 31 | self.intermediate_size = intermediate_size 32 | self.num_heads = num_heads 33 | self.in_channels = in_channels 34 | 
self.patch_size = patch_size 35 | self.spatial_merge_size = spatial_merge_size 36 | self.temporal_patch_size = temporal_patch_size 37 | self.tokens_per_second = tokens_per_second 38 | self.window_size = window_size 39 | self.fullatt_block_indexes = fullatt_block_indexes 40 | self.out_hidden_size = out_hidden_size 41 | 42 | 43 | class Qwen2_5_VLConfig(PretrainedConfig): 44 | r""" 45 | This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a 46 | Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration 47 | with the defaults will yield a similar configuration to that of 48 | Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). 49 | 50 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 51 | documentation from [`PretrainedConfig`] for more information. 52 | 53 | 54 | Args: 55 | vocab_size (`int`, *optional*, defaults to 152064): 56 | Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the 57 | `inputs_ids` passed when calling [`Qwen2_5_VLModel`] 58 | hidden_size (`int`, *optional*, defaults to 8192): 59 | Dimension of the hidden representations. 60 | intermediate_size (`int`, *optional*, defaults to 29568): 61 | Dimension of the MLP representations. 62 | num_hidden_layers (`int`, *optional*, defaults to 80): 63 | Number of hidden layers in the Transformer encoder. 64 | num_attention_heads (`int`, *optional*, defaults to 64): 65 | Number of attention heads for each attention layer in the Transformer encoder. 66 | num_key_value_heads (`int`, *optional*, defaults to 8): 67 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 68 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 69 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When 70 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 71 | by meanpooling all the original heads within that group. For more details checkout [this 72 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. 73 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 74 | The non-linear activation function (function or string) in the decoder. 75 | max_position_embeddings (`int`, *optional*, defaults to 32768): 76 | The maximum sequence length that this model might ever be used with. 77 | initializer_range (`float`, *optional*, defaults to 0.02): 78 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 79 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 80 | The epsilon used by the rms normalization layers. 81 | use_cache (`bool`, *optional*, defaults to `True`): 82 | Whether or not the model should return the last key/values attentions (not used by all models). Only 83 | relevant if `config.is_decoder=True`. 84 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 85 | Whether the model's input and output word embeddings should be tied. 86 | rope_theta (`float`, *optional*, defaults to 1000000.0): 87 | The base period of the RoPE embeddings. 88 | use_sliding_window (`bool`, *optional*, defaults to `False`): 89 | Whether to use sliding window attention. 
90 | sliding_window (`int`, *optional*, defaults to 4096): 91 | Sliding window attention (SWA) window size. If not specified, will default to `4096`. 92 | max_window_layers (`int`, *optional*, defaults to 80): 93 | The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. 94 | attention_dropout (`float`, *optional*, defaults to 0.0): 95 | The dropout ratio for the attention probabilities. 96 | vision_config (`Dict`, *optional*): 97 | The config for the visual encoder initialization. 98 | rope_scaling (`Dict`, *optional*): 99 | Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type 100 | and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value 101 | accordingly. 102 | Expected contents: 103 | `rope_type` (`str`): 104 | The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 105 | 'llama3'], with 'default' being the original RoPE implementation. 106 | `factor` (`float`, *optional*): 107 | Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In 108 | most scaling types, a `factor` of x will enable the model to handle sequences of length x * 109 | original maximum pre-trained length. 110 | `original_max_position_embeddings` (`int`, *optional*): 111 | Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during 112 | pretraining. 113 | `attention_factor` (`float`, *optional*): 114 | Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention 115 | computation. If unspecified, it defaults to value recommended by the implementation, using the 116 | `factor` field to infer the suggested value. 117 | `beta_fast` (`float`, *optional*): 118 | Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear 119 | ramp function. If unspecified, it defaults to 32. 120 | `beta_slow` (`float`, *optional*): 121 | Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear 122 | ramp function. If unspecified, it defaults to 1. 123 | `short_factor` (`List[float]`, *optional*): 124 | Only used with 'longrope'. The scaling factor to be applied to short contexts (< 125 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 126 | size divided by the number of attention heads divided by 2 127 | `long_factor` (`List[float]`, *optional*): 128 | Only used with 'longrope'. The scaling factor to be applied to long contexts (< 129 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 130 | size divided by the number of attention heads divided by 2 131 | `low_freq_factor` (`float`, *optional*): 132 | Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE 133 | `high_freq_factor` (`float`, *optional*): 134 | Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE 135 | 136 | ```python 137 | >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig 138 | 139 | >>> # Initializing a Qwen2_5_VL style configuration 140 | >>> configuration = Qwen2_5_VLConfig() 141 | 142 | >>> # Initializing a model from the Qwen2-VL-7B style configuration 143 | >>> model = Qwen2_5_VLForConditionalGeneration(configuration) 144 | 145 | >>> # Accessing the model configuration 146 | >>> configuration = model.config 147 | ```""" 148 | 149 | model_type = "qwen2_5_vl" 150 | sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} 151 | keys_to_ignore_at_inference = ["past_key_values"] 152 | # Default tensor parallel plan for base model `Qwen2_5_VL` 153 | base_model_tp_plan = { 154 | "layers.*.self_attn.q_proj": "colwise", 155 | "layers.*.self_attn.k_proj": "colwise", 156 | "layers.*.self_attn.v_proj": "colwise", 157 | "layers.*.self_attn.o_proj": "rowwise", 158 | "layers.*.mlp.gate_proj": "colwise", 159 | "layers.*.mlp.up_proj": "colwise", 160 | "layers.*.mlp.down_proj": "rowwise", 161 | } 162 | base_model_pp_plan = { 163 | "embed_tokens": (["input_ids"], ["inputs_embeds"]), 164 | "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), 165 | "norm": (["hidden_states"], ["hidden_states"]), 166 | } 167 | 168 | def __init__( 169 | self, 170 | vocab_size=152064, 171 | hidden_size=8192, 172 | intermediate_size=29568, 173 | num_hidden_layers=80, 174 | num_attention_heads=64, 175 | num_key_value_heads=8, 176 | hidden_act="silu", 177 | max_position_embeddings=32768, 178 | initializer_range=0.02, 179 | rms_norm_eps=1e-05, 180 | use_cache=True, 181 | tie_word_embeddings=False, 182 | rope_theta=1000000.0, 183 | use_sliding_window=False, 184 | sliding_window=4096, 185 | max_window_layers=80, 186 | attention_dropout=0.0, 187 | vision_config=None, 188 | rope_scaling=None, 189 | num_experts=4, 190 | experts=None, 191 | dof_config=None, 192 | noise_scheduler=None, 193 | dim_inputs=(1536, 1536), 194 | attention_moe=False, 195 | mlp_moe=False, 196 | **kwargs, 197 | ): 198 | if isinstance(vision_config, dict): 199 | self.vision_config = self.sub_configs["vision_config"](**vision_config) 200 | elif vision_config is None: 201 | self.vision_config = self.sub_configs["vision_config"]() 202 | 203 | self.vocab_size = vocab_size 204 | self.max_position_embeddings = max_position_embeddings 205 | self.hidden_size = hidden_size 206 | self.intermediate_size = intermediate_size 207 | self.num_hidden_layers = num_hidden_layers 208 | self.num_attention_heads = num_attention_heads 209 | self.use_sliding_window = use_sliding_window 210 | self.sliding_window = sliding_window 211 | self.max_window_layers = max_window_layers 212 | 213 | # for backward compatibility 214 | if num_key_value_heads is None: 215 | num_key_value_heads = num_attention_heads 216 | 217 | self.num_key_value_heads = num_key_value_heads 218 | self.hidden_act = hidden_act 219 | self.initializer_range = initializer_range 220 | self.rms_norm_eps = rms_norm_eps 221 | self.use_cache = use_cache 222 | self.rope_theta = rope_theta 223 | self.attention_dropout = attention_dropout 224 | self.rope_scaling = rope_scaling 225 | 226 | self.num_experts = num_experts 227 | self.experts = experts 228 | self.dof_config = dof_config 229 | self.noise_scheduler = noise_scheduler 230 | self.dim_inputs = tuple(dim_inputs) 231 | self.attention_moe = attention_moe 232 | self.mlp_moe = mlp_moe 233 | 234 | # Validate the correctness of rotary position embeddings parameters 
235 | # BC: if there is a 'type' field, move it to 'rope_type'. 236 | # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations 237 | # one can set it to "linear"/"dynamic" etc. to have scaled RoPE 238 | # TODO: @raushan update config in the hub 239 | if self.rope_scaling is not None and "type" in self.rope_scaling: 240 | if self.rope_scaling["type"] == "mrope": 241 | self.rope_scaling["type"] = "default" 242 | self.rope_scaling["rope_type"] = self.rope_scaling["type"] 243 | rope_config_validation(self, ignore_keys={"mrope_section"}) 244 | 245 | super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) 246 | 247 | 248 | __all__ = ["Qwen2_5_VLConfig"] 249 | -------------------------------------------------------------------------------- /wall_x/serving/client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Example client for Wall-X model server with sync support. 4 | 5 | This script demonstrates how to connect to a Wall-X server and request 6 | action predictions from observations in both sync and async contexts. 7 | """ 8 | 9 | import asyncio 10 | import logging 11 | from typing import Dict, List 12 | import numpy as np 13 | import threading 14 | import yaml 15 | import torch 16 | import matplotlib.pyplot as plt 17 | import os 18 | 19 | from wall_x.data.utils import update_action_statistics 20 | from wall_x.utils.constant import action_statistic_dof 21 | from wall_x.model.action_head import Normalizer 22 | 23 | try: 24 | import msgpack 25 | import msgpack_numpy as m 26 | 27 | m.patch() 28 | except ImportError: 29 | print("Please install msgpack-numpy: pip install msgpack-numpy") 30 | exit(1) 31 | 32 | try: 33 | import websockets 34 | except ImportError: 35 | print("Please install websockets: pip install websockets") 36 | exit(1) 37 | 38 | logging.basicConfig(level=logging.INFO) 39 | logger = logging.getLogger(__name__) 40 | 41 | 42 | class WallXClient: 43 | """Client for connecting to Wall-X model server.""" 44 | 45 | def __init__( 46 | self, 47 | config_path: str, 48 | uri: str = "ws://localhost:8000", 49 | norm_stats_path: str = "x2_norm_stats.json", 50 | ): 51 | """Initialize client. 52 | 53 | Args: 54 | config_path: Path to train config file 55 | uri: WebSocket URI of the server (e.g., ws://localhost:8000) 56 | norm_stats_path: Path to normalization stats file 57 | """ 58 | self.uri = uri 59 | self.websocket = None 60 | self.metadata = None 61 | self._loop = None 62 | self._thread = None 63 | self.norm_stats_path = norm_stats_path 64 | 65 | with open(config_path, "r") as f: 66 | self.train_config = yaml.load(f, Loader=yaml.FullLoader) 67 | 68 | self.init_normalizer(self.train_config) 69 | 70 | async def connect(self): 71 | """Connect to the server and receive metadata.""" 72 | logger.info(f"Connecting to {self.uri}...") 73 | self.websocket = await websockets.connect( 74 | self.uri, 75 | ping_interval=None, 76 | ping_timeout=None, 77 | max_size=None, 78 | ) 79 | 80 | self.metadata = msgpack.unpackb(await self.websocket.recv()) 81 | logger.info(f"Connected! Server metadata: {self.metadata}") 82 | 83 | async def predict(self, obs: Dict) -> Dict: 84 | """Get action prediction from observation. 
85 | 86 | Args: 87 | obs: Observation dictionary containing: 88 | - camera view keys matching the server's camera_key config (e.g. 'front_view'): image arrays (H, W, 3) 89 | - 'prompt': Task description text 90 | - 'state': Robot proprioception state (state_dim,) 91 | - 'dataset_names': Dataset/robot names, e.g. ["physical-intelligence/libero"] 92 | Returns: 93 | Dictionary with: 94 | - 'action': Predicted action array 95 | - 'server_timing': Timing information 96 | """ 97 | if self.websocket is None: 98 | raise RuntimeError("Not connected. Call connect() first.") 99 | 100 | await self.websocket.send(msgpack.packb(obs)) 101 | response = msgpack.unpackb(await self.websocket.recv()) 102 | return response 103 | 104 | async def close(self): 105 | """Close the connection.""" 106 | if self.websocket: 107 | await self.websocket.close() 108 | logger.info("Connection closed") 109 | 110 | async def reset(self): 111 | """Reset the policy (if supported).""" 112 | pass 113 | 114 | # ============ Synchronous methods (using independent thread event loop) ============ 115 | 116 | def _start_background_loop(self): 117 | """Start event loop in background thread.""" 118 | self._loop = asyncio.new_event_loop() 119 | asyncio.set_event_loop(self._loop) 120 | self._loop.run_forever() 121 | 122 | def _ensure_loop(self): 123 | """Ensure background event loop is running.""" 124 | if self._loop is None or not self._loop.is_running(): 125 | self._thread = threading.Thread( 126 | target=self._start_background_loop, daemon=True 127 | ) 128 | self._thread.start() 129 | # Wait for loop to start 130 | import time 131 | 132 | while self._loop is None: 133 | time.sleep(0.01) 134 | 135 | def _run_async(self, coro): 136 | """Run coroutine in background event loop.""" 137 | self._ensure_loop() 138 | future = asyncio.run_coroutine_threadsafe(coro, self._loop) 139 | return future.result() 140 | 141 | def connect_sync(self): 142 | """Synchronously connect to server.""" 143 | return self._run_async(self.connect()) 144 | 145 | def norm_state( 146 | self, 147 | state: np.ndarray, 148 | dataset_names: List[str], 149 | state_mask: torch.Tensor = None, 150 | ) -> np.ndarray: 151 | """Normalize state.""" 152 | return self.normalizer_propri.normalize_data(state, dataset_names, state_mask) 153 | 154 | def predict_sync(self, obs: Dict) -> Dict: 155 | """Synchronous prediction method. 
156 | 157 | Args: 158 | obs: Observation dictionary 159 | 160 | Returns: 161 | Prediction result dictionary 162 | """ 163 | return self._run_async(self.predict(obs)) 164 | 165 | def close_sync(self): 166 | """Synchronously close connection.""" 167 | result = self._run_async(self.close()) 168 | # Stop event loop 169 | if self._loop: 170 | self._loop.call_soon_threadsafe(self._loop.stop) 171 | return result 172 | 173 | def init_normalizer(self, train_config): 174 | # Define default configurations 175 | dof_config = {"biarm_eed_with_base": 20} 176 | 177 | agent_pos_config = {"biarm_eed_with_base": 20} 178 | 179 | update_action_statistics( 180 | action_statistic_dof=action_statistic_dof, 181 | norm_stats_path=self.norm_stats_path, 182 | repo_id="x2", 183 | dof_config=dof_config, 184 | agent_pos_config=agent_pos_config, 185 | ) 186 | 187 | self.normalizer_action = Normalizer(action_statistic_dof, dof_config) 188 | self.normalizer_propri = Normalizer(action_statistic_dof, agent_pos_config) 189 | 190 | print("Normalizer initialized") 191 | 192 | 193 | def prepare_batch_sync(data, normalizer_action, normalizer_propri, dataset_names): 194 | """Synchronous version of prepare_batch.""" 195 | image = (data["image"].permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy() 196 | wrist_image = ( 197 | (data["wrist_image"].permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy() 198 | ) 199 | prompt = data["task"] 200 | 201 | state = data["state"].to("cuda") 202 | if state.dim() == 1: 203 | state = state.unsqueeze(0) 204 | 205 | state_mask = torch.ones([1, 32, 20]).to("cuda") 206 | state_mask[:, :, 8:] = 0 207 | 208 | state = normalizer_propri.normalize_data(state, dataset_names, state_mask) 209 | state = state.cpu().numpy().astype(np.float32) 210 | 211 | obs = { 212 | "front_view": image, 213 | "left_wrist_view": wrist_image, 214 | "prompt": prompt, 215 | "state": state, 216 | "dataset_names": dataset_names, 217 | } 218 | return obs 219 | 220 | 221 | def init_serving_sample_dataset(train_config): 222 | from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata 223 | 224 | repo_id = train_config["data"]["lerobot_config"]["repo_id"] 225 | 226 | meta_info = LeRobotDatasetMetadata(repo_id) 227 | dataset_fps = meta_info.fps 228 | delta_timestamps = { 229 | "actions": [t / dataset_fps for t in range(32)], 230 | } 231 | dataset = LeRobotDataset( 232 | repo_id, 233 | episodes=[0], 234 | delta_timestamps=delta_timestamps, 235 | video_backend="pyav", 236 | ) 237 | 238 | return dataset, repo_id 239 | 240 | 241 | # ============ Synchronous version of main function ============ 242 | 243 | 244 | def main_sync(args): 245 | """Synchronous version of main function.""" 246 | 247 | # Create client and connect 248 | client = WallXClient( 249 | args.config_path, uri=args.uri, norm_stats_path=args.norm_stats_path 250 | ) 251 | client.connect_sync() 252 | 253 | dataset, repo_id = init_serving_sample_dataset(client.train_config) 254 | 255 | total_frames = len(dataset) 256 | gt_traj = np.zeros((total_frames, args.action_dim)) 257 | pred_traj = np.zeros((total_frames, args.action_dim)) 258 | import torch 259 | 260 | dof_mask = torch.ones([1, 32, 20]).to("cuda") 261 | dof_mask[:, :, args.action_dim :] = 0 262 | 263 | # Synchronous processing 264 | for idx, data in enumerate(dataset): 265 | if idx % args.pred_horizon == 0 and idx + args.pred_horizon < total_frames: 266 | print(f"Processing frame {idx}") 267 | obs = prepare_batch_sync( 268 | data, 269 | client.normalizer_action, 270 | client.normalizer_propri, 
271 | dataset_names=[repo_id], 272 | ) 273 | response = client.predict_sync(obs) 274 | pred_action = response["action"] 275 | pred_traj[idx : idx + args.pred_horizon] = pred_action 276 | gt_traj[idx : idx + args.pred_horizon] = data["actions"] 277 | 278 | # Draw plot 279 | timesteps = gt_traj.shape[0] 280 | fig, axs = plt.subplots( 281 | args.action_dim, 1, figsize=(15, 5 * args.action_dim), sharex=True 282 | ) 283 | fig.suptitle("Action Comparison for lerobot", fontsize=16) 284 | 285 | for i in range(args.action_dim): 286 | axs[i].plot(range(timesteps), gt_traj[:, i], label="Ground Truth") 287 | axs[i].plot(range(timesteps), pred_traj[:, i], label="Prediction") 288 | axs[i].set_ylabel(f"Action Dim {i+1}") 289 | axs[i].legend() 290 | axs[i].grid(True) 291 | 292 | axs[-1].set_xlabel("Timestep") 293 | plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 294 | os.makedirs(args.save_dir, exist_ok=True) 295 | save_path = os.path.join(args.save_dir, "lerobot_comparison_serving.png") 296 | plt.savefig(save_path) 297 | print(f"Saved plot to {save_path}") 298 | plt.close() 299 | 300 | # Close connection 301 | client.close_sync() 302 | 303 | 304 | # ============ Asynchronous version of main function (keep original functionality) ============ 305 | 306 | 307 | async def main(args): 308 | client = WallXClient( 309 | args.config_path, uri=args.uri, norm_stats_path=args.norm_stats_path 310 | ) 311 | await client.connect() 312 | dataset, repo_id = init_serving_sample_dataset(client.train_config) 313 | 314 | total_frames = len(dataset) 315 | gt_traj = np.zeros((total_frames, args.action_dim)) 316 | pred_traj = np.zeros((total_frames, args.action_dim)) 317 | 318 | for idx, data in enumerate(dataset): 319 | if idx % args.pred_horizon == 0 and idx + args.pred_horizon < total_frames: 320 | print(f"Processing frame {idx}") 321 | obs = prepare_batch_sync( 322 | data, 323 | client.normalizer_action, 324 | client.normalizer_propri, 325 | dataset_names=[repo_id], 326 | ) 327 | response = await client.predict(obs) 328 | pred_action = response["action"] 329 | print(pred_action.shape) 330 | pred_traj[idx : idx + args.pred_horizon] = pred_action 331 | gt_traj[idx : idx + args.pred_horizon] = data["actions"] 332 | 333 | timesteps = gt_traj.shape[0] 334 | 335 | fig, axs = plt.subplots( 336 | args.action_dim, 1, figsize=(15, 5 * args.action_dim), sharex=True 337 | ) 338 | fig.suptitle("Action Comparison for lerobot", fontsize=16) 339 | 340 | for i in range(args.action_dim): 341 | axs[i].plot(range(timesteps), gt_traj[:, i], label="Ground Truth") 342 | axs[i].plot(range(timesteps), pred_traj[:, i], label="Prediction") 343 | axs[i].set_ylabel(f"Action Dim {i+1}") 344 | axs[i].legend() 345 | axs[i].grid(True) 346 | 347 | axs[-1].set_xlabel("Timestep") 348 | plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 349 | os.makedirs(args.save_dir, exist_ok=True) 350 | save_path = os.path.join(args.save_dir, "lerobot_comparison_serving.png") 351 | plt.savefig(save_path) 352 | print(f"Saved plot to {save_path}") 353 | plt.close() 354 | 355 | 356 | if __name__ == "__main__": 357 | """Asynchronous version of main function.""" 358 | import argparse 359 | 360 | parser = argparse.ArgumentParser(description="Wall-X client examples") 361 | parser.add_argument( 362 | "--example", 363 | choices=["single", "multiple", "benchmark"], 364 | default="single", 365 | help="Example to run", 366 | ) 367 | parser.add_argument( 368 | "--uri", 369 | default="ws://localhost:8000", 370 | help="Server URI", 371 | ) 372 | parser.add_argument( 373 | "--pred_horizon", type=int, 
default=32, help="Prediction horizon" 374 | ) 375 | parser.add_argument("--action_dim", type=int, default=7, help="Action dimension") 376 | parser.add_argument( 377 | "--config_path", 378 | default="config_from_qwen_libero.yml", 379 | help="Train config path", 380 | ) 381 | parser.add_argument( 382 | "--save_dir", 383 | default="libero", 384 | help="Save directory", 385 | ) 386 | parser.add_argument( 387 | "--norm_stats_path", 388 | default="x2_norm_stats.json", 389 | help="Normalization stats path", 390 | ) 391 | args = parser.parse_args() 392 | 393 | # Synchronous mode 394 | main_sync(args) 395 | 396 | # Asynchronous mode 397 | # asyncio.run(main(args)) 398 | -------------------------------------------------------------------------------- /wall_x/utils/constant.py: -------------------------------------------------------------------------------- 1 | action_statistic_dof = { 2 | "x2_normal": { 3 | # water flowers 4 | "follow_left_arm_joint_cur": { 5 | "min": [-3.7121], 6 | "delta": [7.6008], 7 | }, 8 | "follow_right_arm_joint_cur": { 9 | "min": [-3.6176], 10 | "delta": [8.5015], 11 | }, 12 | "follow_left_ee_cartesian_pos": { 13 | "min": [-0.036, -0.3241, -0.1245], 14 | "delta": [0.4389, 0.557, 0.479], 15 | }, 16 | "follow_left_ee_rotation": { 17 | "min": [-1.2373, -0.1929, -1.5182], 18 | "delta": [2.2009, 1.5669, 2.0936], 19 | }, 20 | "follow_left_gripper": {"min": [-0.1196], "delta": [4.5226]}, 21 | "follow_right_ee_cartesian_pos": { 22 | "min": [-0.0326, -0.2273, -0.1377], 23 | "delta": [0.4574, 0.5704, 0.4743], 24 | }, 25 | "follow_right_ee_rotation": { 26 | "min": [-1.2201, -0.2611, -0.7427], 27 | "delta": [2.6623, 1.6622, 2.4186], 28 | }, 29 | "follow_right_gripper": {"min": [-0.1208], "delta": [4.5261]}, 30 | "height": {"min": [-0.0001], "delta": [0.5051]}, 31 | "head_actions": {"min": [-1.5000, -1.4167], "delta": [2.5000, 1.8879]}, 32 | "base_velocity": { 33 | "min": [-0.0359, -0.084, -0.0162], 34 | "delta": [0.1539, 0.1848, 0.0322], 35 | }, 36 | }, 37 | "DobbE": { 38 | "follow_right_ee_cartesian_pos": { 39 | "min": [-0.6107, -0.3272, -0.4282], 40 | "delta": [1.2629, 1.5297, 0.8349], 41 | }, 42 | "follow_right_ee_rotation": { 43 | "min": [-1.7378, -1.4597, -1.8712], 44 | "delta": [2.7031, 2.8182, 3.5921], 45 | }, 46 | "follow_right_gripper": {"min": [0.0], "delta": [0.9983]}, 47 | }, 48 | "RH20T": { 49 | "follow_right_ee_cartesian_pos": { 50 | "min": [0.3646, -0.2722, 0.0066], 51 | "delta": [0.3813, 0.5973, 0.3277], 52 | }, 53 | "follow_right_ee_rotation": { 54 | "min": [-1.8716, -0.4398, -3.1414], 55 | "delta": [3.4145, 1.0225, 6.2828], 56 | }, 57 | "follow_right_gripper": {"min": [0.0], "delta": [95.0]}, 58 | }, 59 | "agibotworld_alpha": { 60 | "follow_left_ee_cartesian_pos": { 61 | "min": [0.4954, 0.0166, 0.1729], 62 | "delta": [0.3336, 0.5123, 0.9189], 63 | }, 64 | "follow_left_ee_rotation": { 65 | "min": [-3.1064, -1.2629, -3.1238], 66 | "delta": [6.2127, 2.5923, 6.2496], 67 | }, 68 | "follow_left_gripper": {"min": [34.6222], "delta": [86.1921]}, 69 | "follow_right_ee_cartesian_pos": { 70 | "min": [0.4615, -0.5975, 0.1638], 71 | "delta": [0.3823, 0.5577, 0.8873], 72 | }, 73 | "follow_right_ee_rotation": { 74 | "min": [-3.0891, -1.0739, -2.5091], 75 | "delta": [6.1707, 2.3074, 3.8533], 76 | }, 77 | "follow_right_gripper": {"min": [34.6222], "delta": [85.7635]}, 78 | "height": {"min": [0.0], "delta": [0.4535]}, 79 | "head_actions": {"min": [-0.1746, 0.0523], "delta": [0.2444, 0.4713]}, 80 | }, 81 | "austin_buds": { 82 | "follow_right_ee_cartesian_pos": { 83 | "min": [0.3496, 
-0.2855, 0.0105], 84 | "delta": [0.3748, 0.492, 0.3116], 85 | }, 86 | "follow_right_ee_rotation": { 87 | "min": [-3.1405, -0.151, -0.0737], 88 | "delta": [6.2813, 0.3218, 0.1536], 89 | }, 90 | "follow_right_gripper": {"min": [0.0076], "delta": [0.0724]}, 91 | }, 92 | "austin_sailor": { 93 | "follow_right_ee_cartesian_pos": { 94 | "min": [0.387, -0.3165, 0.0244], 95 | "delta": [0.2999, 0.5252, 0.2308], 96 | }, 97 | "follow_right_ee_rotation": { 98 | "min": [-3.1402, -0.1618, -1.5918], 99 | "delta": [6.2804, 0.337, 2.9478], 100 | }, 101 | "follow_right_gripper": {"min": [0.0005], "delta": [0.0773]}, 102 | }, 103 | "austin_sirius": { 104 | "follow_right_ee_cartesian_pos": { 105 | "min": [0.0, -0.1182, 0.0], 106 | "delta": [0.5329, 0.3812, 0.2723], 107 | }, 108 | "follow_right_ee_rotation": { 109 | "min": [-3.1407, -0.1243, -1.7434], 110 | "delta": [6.2823, 0.1975, 1.8073], 111 | }, 112 | "follow_right_gripper": {"min": [0.0334], "delta": [0.046]}, 113 | }, 114 | "bc_z": { 115 | "follow_right_ee_cartesian_pos": { 116 | "min": [-0.3883, -0.1116, 0.6113], 117 | "delta": [0.7199, 0.4288, 0.3709], 118 | }, 119 | "follow_right_ee_rotation": { 120 | "min": [-1.056, -1.0587, -2.6295], 121 | "delta": [1.9142, 1.9455, 4.8064], 122 | }, 123 | "follow_right_gripper": {"min": [0.2], "delta": [0.8]}, 124 | }, 125 | "berkeley_autolab_ur5": { 126 | "follow_right_ee_cartesian_pos": { 127 | "min": [0.3018, -0.2129, -0.1888], 128 | "delta": [0.3121, 0.52, 0.3107], 129 | }, 130 | "follow_right_ee_rotation": { 131 | "min": [-3.1396, -0.2278, 1.1413], 132 | "delta": [6.279, 0.454, 0.9841], 133 | }, 134 | "follow_right_gripper": {"min": [0.0], "delta": [1.0]}, 135 | }, 136 | "berkeley_cable_routing": { 137 | "follow_right_ee_cartesian_pos": { 138 | "min": [0.4617, -0.28, 0.03], 139 | "delta": [0.1838, 0.5665, 0.1272], 140 | }, 141 | "follow_right_ee_rotation": { 142 | "min": [-3.1413, -0.0299, -0.7665], 143 | "delta": [6.2826, 0.0692, 3.322], 144 | }, 145 | }, 146 | "berkeley_fanuc_manipulation": { 147 | "follow_right_ee_cartesian_pos": { 148 | "min": [0.3718, -0.4072, 0.0184], 149 | "delta": [0.3483, 0.7201, 0.5229], 150 | }, 151 | "follow_right_ee_rotation": { 152 | "min": [-3.1399, -1.0166, -1.6988], 153 | "delta": [6.2802, 1.4498, 3.2074], 154 | }, 155 | "follow_right_gripper": {"min": [0.0], "delta": [1.0]}, 156 | }, 157 | "bridge_data_v2": { 158 | "follow_right_ee_cartesian_pos": { 159 | "min": [0.1498, -0.2178, -0.0901], 160 | "delta": [0.3012, 0.469, 0.298], 161 | }, 162 | "follow_right_ee_rotation": { 163 | "min": [-0.3279, -0.6105, -1.0578], 164 | "delta": [0.7378, 1.0353, 2.2552], 165 | }, 166 | "follow_right_gripper": {"min": [0.0692], "delta": [0.9426]}, 167 | }, 168 | "dlr_edan_shared_control": { 169 | "follow_right_ee_cartesian_pos": { 170 | "min": [-0.8387, 0.1473, -0.3934], 171 | "delta": [0.6579, 0.6025, 1.1566], 172 | }, 173 | "follow_right_ee_rotation": { 174 | "min": [-3.1217, -1.5197, -2.2516], 175 | "delta": [6.2505, 1.5594, 4.2831], 176 | }, 177 | "follow_right_gripper": {"min": [0.0], "delta": [1.0]}, 178 | }, 179 | "droid": { 180 | "follow_right_ee_cartesian_pos": { 181 | "min": [0.2667, -0.4396, -0.0472], 182 | "delta": [0.5159, 0.8806, 0.8331], 183 | }, 184 | "follow_right_ee_rotation": { 185 | "min": [-3.1374, -1.216, -2.1741], 186 | "delta": [6.2749, 2.1075, 4.2259], 187 | }, 188 | "follow_right_gripper": {"min": [0.0], "delta": [0.9912]}, 189 | }, 190 | "fmb": { 191 | "follow_right_ee_cartesian_pos": { 192 | "min": [0.3554, -0.2844, 0.0354], 193 | "delta": [0.336, 0.4961, 0.2943], 194 
| }, 195 | "follow_right_ee_rotation": { 196 | "min": [-3.1404, -0.9302, -0.0599], 197 | "delta": [6.2807, 1.724, 1.8284], 198 | }, 199 | "follow_right_gripper": {"min": [0.0], "delta": [1.0]}, 200 | }, 201 | "fractal": { 202 | "follow_right_ee_cartesian_pos": { 203 | "min": [0.3242, -0.2836, 0.1405], 204 | "delta": [0.5518, 0.4963, 0.9328], 205 | }, 206 | "follow_right_ee_rotation": { 207 | "min": [-3.1308, -0.2421, -2.9685], 208 | "delta": [6.2609, 1.7343, 5.819], 209 | }, 210 | "follow_right_gripper": {"min": [0.0], "delta": [1.0]}, 211 | }, 212 | "furniture_bench": { 213 | "follow_right_ee_cartesian_pos": { 214 | "min": [0.3691, -0.181, 0.0058], 215 | "delta": [0.2962, 0.3582, 0.1775], 216 | }, 217 | "follow_right_ee_rotation": { 218 | "min": [-3.1394, -0.6121, -1.9958], 219 | "delta": [6.2786, 1.6114, 3.7748], 220 | }, 221 | "follow_right_gripper": {"min": [0.0035], "delta": [0.0762]}, 222 | }, 223 | "jaco_play": { 224 | "follow_right_ee_cartesian_pos": { 225 | "min": [-0.3787, -0.6294, 0.1682], 226 | "delta": [0.5898, 0.3587, 0.2183], 227 | }, 228 | "follow_right_ee_rotation": { 229 | "min": [0.9792, -0.0668, -0.0498], 230 | "delta": [0.0175, 0.1277, 0.0686], 231 | }, 232 | "follow_right_gripper": {"min": [0.0791], "delta": [0.1033]}, 233 | }, 234 | "nyu_rot": { 235 | "follow_right_ee_cartesian_pos": { 236 | "min": [0.25, -1.0, -0.2], 237 | "delta": [0.75, 2.0, 1.2], 238 | }, 239 | "follow_right_ee_rotation": { 240 | "min": [-3.1416, -3.1416, 6.2831], 241 | "delta": [9.4248, 4.1416, 0.0], 242 | }, 243 | "follow_right_gripper": {"min": [0.0], "delta": [1.0]}, 244 | }, 245 | "stanford_hydra": { 246 | "follow_right_ee_cartesian_pos": { 247 | "min": [0.2068, -0.274, 0.1317], 248 | "delta": [0.4929, 0.4981, 0.4588], 249 | }, 250 | "follow_right_ee_rotation": { 251 | "min": [-3.1321, -0.7496, -3.0269], 252 | "delta": [6.2658, 1.5176, 5.8261], 253 | }, 254 | "follow_right_gripper": {"min": [0.0], "delta": [0.0811]}, 255 | }, 256 | "stanford_kuka_multimodal": { 257 | "follow_right_ee_cartesian_pos": { 258 | "min": [0.4781, -0.0659, 0.3424], 259 | "delta": [0.0868, 0.0864, 0.1863], 260 | }, 261 | "follow_right_ee_rotation": { 262 | "min": [-3.136, -0.0521, -3.1413], 263 | "delta": [6.2727, 0.1109, 6.2825], 264 | }, 265 | "follow_right_gripper": {"min": [-0.4713], "delta": [0.9485]}, 266 | }, 267 | "taco_play": { 268 | "follow_right_ee_cartesian_pos": { 269 | "min": [0.1375, -0.4291, 0.2052], 270 | "delta": [0.5327, 1.0237, 0.3913], 271 | }, 272 | "follow_right_ee_rotation": { 273 | "min": [-3.1391, -0.6946, -1.2808], 274 | "delta": [6.2784, 0.8196, 3.0856], 275 | }, 276 | "follow_right_gripper": {"min": [0.0001], "delta": [0.0806]}, 277 | }, 278 | "utaustin_mutex": { 279 | "follow_right_ee_cartesian_pos": { 280 | "min": [0.3213, -0.4734, 0.0141], 281 | "delta": [0.2108, 0.8471, 0.5644], 282 | }, 283 | "follow_right_ee_rotation": { 284 | "min": [-3.1404, -0.2202, -1.5489], 285 | "delta": [6.2805, 0.582, 1.9282], 286 | }, 287 | "follow_right_gripper": {"min": [0.0019], "delta": [0.0738]}, 288 | }, 289 | "viola": { 290 | "follow_right_ee_cartesian_pos": { 291 | "min": [0.4011, -0.2521, 0.0103], 292 | "delta": [0.2444, 0.4305, 0.4355], 293 | }, 294 | "follow_right_ee_rotation": { 295 | "min": [-3.1403, -0.2737, -1.8626], 296 | "delta": [6.2804, 0.4901, 2.0618], 297 | }, 298 | "follow_right_gripper": {"min": [0.0002], "delta": [0.0773]}, 299 | }, 300 | "kuka": { 301 | "follow_right_ee_cartesian_pos": { 302 | "min": [0.3914, -0.4901, 0.0175], 303 | "delta": [0.3339, 0.8357, 0.9064], 304 | }, 305 | 
"follow_right_ee_rotation": { 306 | "min": [-3.1416, -0.9903, -3.1416], 307 | "delta": [6.2832, 2.2421, 6.2832], 308 | }, 309 | "follow_right_gripper": { 310 | "min": [0.0000], 311 | "delta": [1.0000], 312 | }, 313 | }, 314 | "UMI-biarm": { 315 | "follow_left_ee_cartesian_pos": { 316 | "min": [-0.2917, -0.4926, 0.0063], 317 | "delta": [0.9028, 0.8168, 0.3473], 318 | }, 319 | "follow_left_ee_rotation": { 320 | "min": [-2.5309, -1.5706, -1.3309], 321 | "delta": [0.8758, 2.3315, 1.7076], 322 | }, 323 | "follow_left_gripper": { 324 | "min": [0.0029], 325 | "delta": [0.0812], 326 | }, 327 | "follow_right_ee_cartesian_pos": { 328 | "min": [-0.0023, -0.5191, -0.0358], 329 | "delta": [0.7474, 0.8668, 0.351], 330 | }, 331 | "follow_right_ee_rotation": { 332 | "min": [-2.4945, -2.0149, -0.8088], 333 | "delta": [1.0941, 3.2628, 2.0018], 334 | }, 335 | "follow_right_gripper": { 336 | "min": [0.0019], 337 | "delta": [0.0814], 338 | }, 339 | }, 340 | "agibotworld_beta": { 341 | "follow_left_ee_cartesian_pos": { 342 | "min": [0.4954, 0.0166, 0.1729], 343 | "delta": [0.3336, 0.5123, 0.9189], 344 | }, 345 | "follow_left_ee_rotation": { 346 | "min": [-3.1064, -1.2629, -3.1238], 347 | "delta": [6.2127, 2.5923, 6.2496], 348 | }, 349 | "follow_left_gripper": {"min": [34.6222], "delta": [86.1921]}, 350 | "follow_right_ee_cartesian_pos": { 351 | "min": [0.4615, -0.5975, 0.1638], 352 | "delta": [0.3823, 0.5577, 0.8873], 353 | }, 354 | "follow_right_ee_rotation": { 355 | "min": [-3.0891, -1.0739, -2.5091], 356 | "delta": [6.1707, 2.3074, 3.8533], 357 | }, 358 | "follow_right_gripper": {"min": [34.6222], "delta": [85.7635]}, 359 | "height": {"min": [0.0], "delta": [0.4535]}, 360 | "head_actions": {"min": [-0.1746, 0.0523], "delta": [0.2444, 0.4713]}, 361 | }, 362 | } 363 | -------------------------------------------------------------------------------- /csrc/rot_pos.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | // CUDA kernel for fused rotary position embedding computation - int32 version 7 | __global__ void fused_rot_pos_emb_kernel_int32( 8 | const float *__restrict__ inv_freq, // [dim/2] - precomputed inverse frequencies 9 | const int32_t *__restrict__ grid_thw, // [num_grids, 3] - (t, h, w) for each grid 10 | float *__restrict__ output, // [total_tokens, dim] - output rotary embeddings 11 | const int32_t *__restrict__ cumsum_tokens, // [num_grids+1] - cumulative sum of tokens per grid 12 | const int dim_half, // dim/2 (size of inv_freq) 13 | const int spatial_merge_size, // spatial merge size 14 | const int num_grids // number of grids 15 | ) 16 | { 17 | const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; 18 | const int32_t total_tokens = cumsum_tokens[num_grids]; 19 | 20 | if (tid >= total_tokens * dim_half) 21 | return; 22 | 23 | const int32_t token_idx = tid / dim_half; 24 | const int freq_idx = tid % dim_half; 25 | 26 | // Find which grid this token belongs to 27 | int grid_idx = 0; 28 | int32_t local_token_idx = token_idx; 29 | for (int g = 0; g < num_grids; g++) 30 | { 31 | if (token_idx < cumsum_tokens[g + 1]) 32 | { 33 | grid_idx = g; 34 | local_token_idx = token_idx - cumsum_tokens[g]; 35 | break; 36 | } 37 | } 38 | 39 | // Get grid dimensions 40 | const int32_t h = grid_thw[grid_idx * 3 + 1]; 41 | const int32_t w = grid_thw[grid_idx * 3 + 2]; 42 | 43 | // Calculate spatial dimensions after merging 44 | const int32_t h_merged = h / spatial_merge_size; 45 | const int32_t w_merged = w / 
spatial_merge_size; 46 | const int32_t spatial_tokens = h_merged * w_merged * spatial_merge_size * spatial_merge_size; 47 | 48 | // Get spatial index 49 | const int32_t spatial_idx = local_token_idx % spatial_tokens; 50 | 51 | // Decompose spatial index to get merged block and position within block 52 | const int32_t tokens_per_block = spatial_merge_size * spatial_merge_size; 53 | const int32_t block_idx = spatial_idx / tokens_per_block; 54 | const int32_t within_block_idx = spatial_idx % tokens_per_block; 55 | 56 | // Get block coordinates in merged grid 57 | const int32_t block_h = block_idx / w_merged; 58 | const int32_t block_w = block_idx % w_merged; 59 | 60 | // Get position within block 61 | const int32_t within_h = within_block_idx / spatial_merge_size; 62 | const int32_t within_w = within_block_idx % spatial_merge_size; 63 | 64 | // Calculate actual h and w positions 65 | const int32_t h_pos = block_h * spatial_merge_size + within_h; 66 | const int32_t w_pos = block_w * spatial_merge_size + within_w; 67 | 68 | // Compute rotary embedding 69 | float freq_val = inv_freq[freq_idx]; 70 | 71 | // Output has shape [total_tokens, dim] where dim = 2 * dim_half 72 | int32_t out_idx = token_idx * dim_half * 2 + freq_idx; 73 | output[out_idx] = h_pos * freq_val; // h_pos frequencies 74 | output[out_idx + dim_half] = w_pos * freq_val; // w_pos frequencies 75 | } 76 | 77 | // CUDA kernel for fused rotary position embedding computation - int64 version 78 | __global__ void fused_rot_pos_emb_kernel_int64( 79 | const float *__restrict__ inv_freq, // [dim/2] - precomputed inverse frequencies 80 | const int64_t *__restrict__ grid_thw, // [num_grids, 3] - (t, h, w) for each grid 81 | float *__restrict__ output, // [total_tokens, dim] - output rotary embeddings 82 | const int64_t *__restrict__ cumsum_tokens, // [num_grids+1] - cumulative sum of tokens per grid 83 | const int dim_half, // dim/2 (size of inv_freq) 84 | const int spatial_merge_size, // spatial merge size 85 | const int num_grids // number of grids 86 | ) 87 | { 88 | const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; 89 | const int64_t total_tokens = cumsum_tokens[num_grids]; 90 | 91 | if (tid >= total_tokens * dim_half) 92 | return; 93 | 94 | const int64_t token_idx = tid / dim_half; 95 | const int freq_idx = tid % dim_half; 96 | 97 | // Find which grid this token belongs to 98 | int grid_idx = 0; 99 | int64_t local_token_idx = token_idx; 100 | for (int g = 0; g < num_grids; g++) 101 | { 102 | if (token_idx < cumsum_tokens[g + 1]) 103 | { 104 | grid_idx = g; 105 | local_token_idx = token_idx - cumsum_tokens[g]; 106 | break; 107 | } 108 | } 109 | 110 | // Get grid dimensions 111 | const int64_t h = grid_thw[grid_idx * 3 + 1]; 112 | const int64_t w = grid_thw[grid_idx * 3 + 2]; 113 | 114 | // Calculate spatial dimensions after merging 115 | const int64_t h_merged = h / spatial_merge_size; 116 | const int64_t w_merged = w / spatial_merge_size; 117 | const int64_t spatial_tokens = h_merged * w_merged * spatial_merge_size * spatial_merge_size; 118 | 119 | // Get spatial index 120 | const int64_t spatial_idx = local_token_idx % spatial_tokens; 121 | 122 | // Decompose spatial index to get merged block and position within block 123 | const int64_t tokens_per_block = spatial_merge_size * spatial_merge_size; 124 | const int64_t block_idx = spatial_idx / tokens_per_block; 125 | const int64_t within_block_idx = spatial_idx % tokens_per_block; 126 | 127 | // Get block coordinates in merged grid 128 | const int64_t block_h = block_idx / 
        w_merged;
    const int64_t block_w = block_idx % w_merged;

    // Get position within block
    const int64_t within_h = within_block_idx / spatial_merge_size;
    const int64_t within_w = within_block_idx % spatial_merge_size;

    // Calculate actual h and w positions
    const int64_t h_pos = block_h * spatial_merge_size + within_h;
    const int64_t w_pos = block_w * spatial_merge_size + within_w;

    // Compute rotary embedding
    float freq_val = inv_freq[freq_idx];

    // Output has shape [total_tokens, dim] where dim = 2 * dim_half
    int64_t out_idx = token_idx * dim_half * 2 + freq_idx;
    output[out_idx] = h_pos * freq_val;            // h_pos frequencies
    output[out_idx + dim_half] = w_pos * freq_val; // w_pos frequencies
}

// Parallel computation of token counts per grid - int32 version
__global__ void compute_token_counts_kernel_int32(
    const int32_t *__restrict__ grid_thw,
    int32_t *__restrict__ token_counts,
    const int spatial_merge_size,
    const int num_grids)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_grids)
        return;

    int32_t t = grid_thw[idx * 3 + 0];
    int32_t h = grid_thw[idx * 3 + 1];
    int32_t w = grid_thw[idx * 3 + 2];
    int32_t h_merged = h / spatial_merge_size;
    int32_t w_merged = w / spatial_merge_size;
    token_counts[idx] = t * h_merged * w_merged * spatial_merge_size * spatial_merge_size;
}

// Parallel computation of token counts per grid - int64 version
__global__ void compute_token_counts_kernel_int64(
    const int64_t *__restrict__ grid_thw,
    int64_t *__restrict__ token_counts,
    const int spatial_merge_size,
    const int num_grids)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_grids)
        return;

    int64_t t = grid_thw[idx * 3 + 0];
    int64_t h = grid_thw[idx * 3 + 1];
    int64_t w = grid_thw[idx * 3 + 2];
    int64_t h_merged = h / spatial_merge_size;
    int64_t w_merged = w / spatial_merge_size;
    token_counts[idx] = t * h_merged * w_merged * spatial_merge_size * spatial_merge_size;
}

// Implementation for int32
torch::Tensor fused_rot_pos_emb_cuda_int32(
    torch::Tensor inv_freq,  // [dim/2]
    torch::Tensor grid_thw,  // [num_grids, 3]
    int spatial_merge_size)
{
    TORCH_CHECK(inv_freq.dim() == 1, "inv_freq must be 1-dimensional");
    TORCH_CHECK(inv_freq.is_cuda(), "inv_freq must be a CUDA tensor");
    TORCH_CHECK(inv_freq.scalar_type() == torch::kFloat32, "inv_freq must be float32");

    TORCH_CHECK(grid_thw.dim() == 2, "grid_thw must be 2-dimensional");
    TORCH_CHECK(grid_thw.size(1) == 3, "grid_thw must have shape [num_grids, 3]");
    TORCH_CHECK(grid_thw.is_cuda(), "grid_thw must be a CUDA tensor");
    TORCH_CHECK(grid_thw.scalar_type() == torch::kInt32, "grid_thw must be int32");

    TORCH_CHECK(spatial_merge_size > 0, "spatial_merge_size must be positive");

    const int dim_half = inv_freq.size(0);
    const int num_grids = grid_thw.size(0);

    auto token_counts = torch::zeros({num_grids}, torch::TensorOptions().dtype(torch::kInt32).device(grid_thw.device()));
    const int threads = 256;
    const int blocks = (num_grids + threads - 1) / threads;

    compute_token_counts_kernel_int32<<<blocks, threads>>>(
        grid_thw.data_ptr<int32_t>(),
        token_counts.data_ptr<int32_t>(),
        spatial_merge_size,
        num_grids);

    auto cumsum_tokens = torch::cat({torch::zeros({1}, torch::TensorOptions().dtype(torch::kInt32).device(grid_thw.device())),
                                     token_counts.cumsum(0).to(torch::kInt32)},
                                    0);

    cudaDeviceSynchronize();

    int64_t total_tokens = cumsum_tokens[-1].item<int64_t>();
    TORCH_CHECK(total_tokens > 0, "total_tokens must be positive");

    auto output = torch::zeros({total_tokens, dim_half * 2},
                               torch::TensorOptions().dtype(torch::kFloat32).device(inv_freq.device()));

    const int threads_per_block = 256;
    const int64_t num_elements = total_tokens * dim_half;
    const int num_blocks = static_cast<int>((num_elements + threads_per_block - 1) / threads_per_block);

    fused_rot_pos_emb_kernel_int32<<<num_blocks, threads_per_block>>>(
        inv_freq.data_ptr<float>(),
        grid_thw.data_ptr<int32_t>(),
        output.data_ptr<float>(),
        cumsum_tokens.data_ptr<int32_t>(),
        dim_half,
        spatial_merge_size,
        num_grids);

    cudaDeviceSynchronize();

    TORCH_CHECK(output.scalar_type() == torch::kFloat32, "Output must be float32");
    TORCH_CHECK(output.size(0) == total_tokens, "Output token count mismatch");
    TORCH_CHECK(output.size(1) == dim_half * 2, "Output dimension mismatch");

    return output;
}

// Implementation for int64
torch::Tensor fused_rot_pos_emb_cuda_int64(
    torch::Tensor inv_freq,  // [dim/2]
    torch::Tensor grid_thw,  // [num_grids, 3]
    int spatial_merge_size)
{
    TORCH_CHECK(inv_freq.dim() == 1, "inv_freq must be 1-dimensional");
    TORCH_CHECK(inv_freq.is_cuda(), "inv_freq must be a CUDA tensor");
    TORCH_CHECK(inv_freq.scalar_type() == torch::kFloat32, "inv_freq must be float32");

    TORCH_CHECK(grid_thw.dim() == 2, "grid_thw must be 2-dimensional");
    TORCH_CHECK(grid_thw.size(1) == 3, "grid_thw must have shape [num_grids, 3]");
    TORCH_CHECK(grid_thw.is_cuda(), "grid_thw must be a CUDA tensor");
    TORCH_CHECK(grid_thw.scalar_type() == torch::kInt64, "grid_thw must be int64");

    TORCH_CHECK(spatial_merge_size > 0, "spatial_merge_size must be positive");

    const int dim_half = inv_freq.size(0);
    const int num_grids = grid_thw.size(0);

    auto token_counts = torch::zeros({num_grids}, torch::TensorOptions().dtype(torch::kInt64).device(grid_thw.device()));
    const int threads = 256;
    const int blocks = (num_grids + threads - 1) / threads;

    compute_token_counts_kernel_int64<<<blocks, threads>>>(
        grid_thw.data_ptr<int64_t>(),
        token_counts.data_ptr<int64_t>(),
        spatial_merge_size,
        num_grids);

    auto cumsum_tokens = torch::cat({torch::zeros({1}, torch::TensorOptions().dtype(torch::kInt64).device(grid_thw.device())),
                                     token_counts.cumsum(0).to(torch::kInt64)},
                                    0);

    cudaDeviceSynchronize();

    int64_t total_tokens = cumsum_tokens[-1].item<int64_t>();
    TORCH_CHECK(total_tokens > 0, "total_tokens must be positive");

    auto output = torch::zeros({total_tokens, dim_half * 2},
                               torch::TensorOptions().dtype(torch::kFloat32).device(inv_freq.device()));

    const int threads_per_block = 256;
    const int64_t num_elements = total_tokens * dim_half;
    const int num_blocks = static_cast<int>((num_elements + threads_per_block - 1) / threads_per_block);

    fused_rot_pos_emb_kernel_int64<<<num_blocks, threads_per_block>>>(
        inv_freq.data_ptr<float>(),
        grid_thw.data_ptr<int64_t>(),
        output.data_ptr<float>(),
        cumsum_tokens.data_ptr<int64_t>(),
        dim_half,
        spatial_merge_size,
        num_grids);

    cudaDeviceSynchronize();

    TORCH_CHECK(output.scalar_type() == torch::kFloat32, "Output must be float32");
    TORCH_CHECK(output.size(0) == total_tokens, "Output token count mismatch");
    TORCH_CHECK(output.size(1) == dim_half * 2, "Output dimension mismatch");

    return output;
}

// Main function that dispatches based on grid_thw scalar type
torch::Tensor fused_rot_pos_emb_cuda(
    torch::Tensor inv_freq,
    torch::Tensor grid_thw,
    int spatial_merge_size)
{
    if (grid_thw.scalar_type() == torch::kInt32)
    {
        return fused_rot_pos_emb_cuda_int32(inv_freq, grid_thw, spatial_merge_size);
    }
    else if (grid_thw.scalar_type() == torch::kInt64)
    {
        return fused_rot_pos_emb_cuda_int64(inv_freq, grid_thw, spatial_merge_size);
    }
    else
    {
        TORCH_CHECK(false, "Unsupported grid_thw scalar type: ", grid_thw.scalar_type());
    }
}

-------------------------------------------------------------------------------- /csrc/dual_asym_grouped_gemm.cu: --------------------------------------------------------------------------------
#include "dual_asym_grouped_gemm.h"

// System and PyTorch headers required by this translation unit (inferred from its use of the
// CUDA runtime, cuBLAS status codes, the c10 CUDA stream helper, and the torch extension API).
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "cutlass/bfloat16.h"
#include "cutlass/complex.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"

#define NUM_STREAM 4

#define CUDA_CALL(code)                               \
    do                                                \
    {                                                 \
        cudaError_t status = code;                    \
        std::string err = cudaGetErrorString(status); \
        TORCH_CHECK(status == cudaSuccess, err);      \
    } while (0)

#define CUBLAS_CALL(code)                                             \
    do                                                                \
    {                                                                 \
        cublasStatus_t status = code;                                 \
        TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error"); \
    } while (0)

#define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
#define GROUPED_GEMM_STRINGIFY(x) \
    GROUPED_GEMM_STRINGIFY_HELPER(x)

template <typename T>
torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device)
{
    size_t bytes = x.size() * sizeof(T);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
    torch::Tensor out = torch::empty(bytes, options);

    CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
                              x.data(), bytes,
                              cudaMemcpyHostToDevice,
                              c10::cuda::getCurrentCUDAStream()));
    return out;
}

using DualExpertGemmKernelNN = typename cutlass::gemm::kernel::DefaultGemmGrouped<
    ::cutlass::bfloat16_t,
    ::cutlass::layout::RowMajor,
    ::cutlass::ComplexTransform::kNone,
    8,
    ::cutlass::bfloat16_t,
    ::cutlass::layout::RowMajor,
    ::cutlass::ComplexTransform::kNone,
    8,
    ::cutlass::bfloat16_t,
    ::cutlass::layout::RowMajor,
    float,
    ::cutlass::arch::OpClassTensorOp,
    ::cutlass::arch::Sm80,
    ::cutlass::gemm::GemmShape<128, 128, 32>,
    ::cutlass::gemm::GemmShape<64, 64, 32>,
    ::cutlass::gemm::GemmShape<16, 8, 16>,
    ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
    ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
    4>::GemmKernel;

using DualExpertGemmKernelTN = typename cutlass::gemm::kernel::DefaultGemmGrouped<
    ::cutlass::bfloat16_t,
    ::cutlass::layout::ColumnMajor,
    ::cutlass::ComplexTransform::kNone,
    8,
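    // Argument order for these DefaultGemmGrouped instantiations: element / layout /
    // complex-transform / alignment for operand A, the same four for operand B, element and
    // layout for C, the float accumulator, tensor-op class, the SM80 target, the
    // threadblock / warp / instruction tile shapes, the linear-combination epilogue, the
    // threadblock swizzle, and the number of pipeline stages.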
    ::cutlass::bfloat16_t,
    ::cutlass::layout::RowMajor,
    ::cutlass::ComplexTransform::kNone,
    8,
    ::cutlass::bfloat16_t,
    ::cutlass::layout::RowMajor,
    float,
    ::cutlass::arch::OpClassTensorOp,
    ::cutlass::arch::Sm80,
    ::cutlass::gemm::GemmShape<128, 128, 32>,
    ::cutlass::gemm::GemmShape<64, 64, 32>,
    ::cutlass::gemm::GemmShape<16, 8, 16>,
    ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
    ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
    4>::GemmKernel;

using DualExpertGemmKernelNT = typename cutlass::gemm::kernel::DefaultGemmGrouped<
    ::cutlass::bfloat16_t,
    ::cutlass::layout::RowMajor,
    ::cutlass::ComplexTransform::kNone,
    8,
    ::cutlass::bfloat16_t,
    ::cutlass::layout::ColumnMajor,
    ::cutlass::ComplexTransform::kNone,
    8,
    ::cutlass::bfloat16_t,
    ::cutlass::layout::RowMajor,
    float,
    ::cutlass::arch::OpClassTensorOp,
    ::cutlass::arch::Sm80,
    ::cutlass::gemm::GemmShape<128, 128, 32>,
    ::cutlass::gemm::GemmShape<64, 64, 32>,
    ::cutlass::gemm::GemmShape<16, 8, 16>,
    ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
    ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
    4>::GemmKernel;

using DualExpertGemmKernelTT = typename cutlass::gemm::kernel::DefaultGemmGrouped<
    ::cutlass::bfloat16_t,
    ::cutlass::layout::ColumnMajor,
    ::cutlass::ComplexTransform::kNone,
    8,
    ::cutlass::bfloat16_t,
    ::cutlass::layout::ColumnMajor,
    ::cutlass::ComplexTransform::kNone,
    8,
    ::cutlass::bfloat16_t,
    ::cutlass::layout::RowMajor,
    float,
    ::cutlass::arch::OpClassTensorOp,
    ::cutlass::arch::Sm80,
    ::cutlass::gemm::GemmShape<128, 128, 32>,
    ::cutlass::gemm::GemmShape<64, 64, 32>,
    ::cutlass::gemm::GemmShape<16, 8, 16>,
    ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
    ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
    4>::GemmKernel;

using DualExpertGemmNN = ::cutlass::gemm::device::GemmGrouped<DualExpertGemmKernelNN>;
using DualExpertGemmTN = ::cutlass::gemm::device::GemmGrouped<DualExpertGemmKernelTN>;
using DualExpertGemmNT = ::cutlass::gemm::device::GemmGrouped<DualExpertGemmKernelNT>;
using DualExpertGemmTT = ::cutlass::gemm::device::GemmGrouped<DualExpertGemmKernelTT>;

template <typename Gemm>
typename Gemm::Arguments MakeAsymmetricArgumentsSeparated(
    torch::Tensor input_expert0,
    torch::Tensor input_expert1,
    torch::Tensor weight_expert0,
    torch::Tensor weight_expert1,
    torch::Tensor output_expert0,
    torch::Tensor output_expert1)
{

    TORCH_CHECK(input_expert0.dim() == 2 && input_expert1.dim() == 2,
                "Input tensors must be 2D");
    TORCH_CHECK(weight_expert0.dim() == 2 && weight_expert1.dim() == 2,
                "Weight tensors must be 2D");
    TORCH_CHECK(output_expert0.dim() == 2 && output_expert1.dim() == 2,
                "Output tensors must be 2D");

    using LayoutA = typename Gemm::LayoutA;
    using LayoutB = typename Gemm::LayoutB;
    using LayoutC = typename Gemm::LayoutC;

    bool a_is_column_major = std::is_same_v<LayoutA, ::cutlass::layout::ColumnMajor>;
    bool b_is_column_major = std::is_same_v<LayoutB, ::cutlass::layout::ColumnMajor>;

    int64_t m0, k0, n0;
    if (a_is_column_major)
    {
        m0 = input_expert0.size(1);
        k0 = input_expert0.size(0);
    }
    else
    {
        m0 = input_expert0.size(0);
        k0 = input_expert0.size(1);
    }

    if (b_is_column_major)
    {
        n0 = weight_expert0.size(0);
        TORCH_CHECK(weight_expert0.size(1) == k0, "Expert 0: k dimensions must match");
    }
    else
    {
        n0 = weight_expert0.size(1);
        TORCH_CHECK(weight_expert0.size(0) == k0, "Expert 0: k dimensions must match");
    }

    int64_t m1, k1, n1;
    if (a_is_column_major)
    {
        m1 = input_expert1.size(1);
        k1 = input_expert1.size(0);
    }
    else
    {
        m1 = input_expert1.size(0);
        k1 = input_expert1.size(1);
    }

    if (b_is_column_major)
    {
        n1 = weight_expert1.size(0);
        TORCH_CHECK(weight_expert1.size(1) == k1, "Expert 1: k dimensions must match");
    }
    else
    {
        n1 = weight_expert1.size(1);
        TORCH_CHECK(weight_expert1.size(0) == k1, "Expert 1: k dimensions must match");
    }

    std::vector<cutlass::gemm::GemmCoord> problem_sizes_host(2);
    problem_sizes_host[0] = cutlass::gemm::GemmCoord(m0, n0, k0);
    problem_sizes_host[1] = cutlass::gemm::GemmCoord(m1, n1, k1);

    int64_t num_experts = 2;
    int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
    if (!threadblock_count)
    {
        TORCH_CHECK(false, "Dual Expert Grouped GEMM execution not possible with HW");
    }

    std::vector<int64_t> lda_host(num_experts);
    std::vector<int64_t> ldb_host(num_experts);
    std::vector<int64_t> ldc_host(num_experts);

    using ElementA = typename Gemm::ElementA;
    using ElementB = typename Gemm::ElementB;
    using ElementC = typename Gemm::ElementC;

    std::vector<ElementA *> ptr_a_host(num_experts);
    std::vector<ElementB *> ptr_b_host(num_experts);
    std::vector<ElementC *> ptr_c_host(num_experts);

    auto problem_0 = problem_sizes_host[0];
    lda_host[0] = LayoutA::packed({problem_0.m(), problem_0.k()}).stride(0);
    ldb_host[0] = LayoutB::packed({problem_0.k(), problem_0.n()}).stride(0);
    ldc_host[0] = LayoutC::packed({problem_0.m(), problem_0.n()}).stride(0);

    ptr_a_host[0] = (ElementA *)input_expert0.data_ptr();
    ptr_b_host[0] = (ElementB *)weight_expert0.data_ptr();
    ptr_c_host[0] = (ElementC *)output_expert0.data_ptr();

    auto problem_1 = problem_sizes_host[1];
    lda_host[1] = LayoutA::packed({problem_1.m(), problem_1.k()}).stride(0);
    ldb_host[1] = LayoutB::packed({problem_1.k(), problem_1.n()}).stride(0);
    ldc_host[1] = LayoutC::packed({problem_1.m(), problem_1.n()}).stride(0);

    ptr_a_host[1] = (ElementA *)input_expert1.data_ptr();
    ptr_b_host[1] = (ElementB *)weight_expert1.data_ptr();
    ptr_c_host[1] = (ElementC *)output_expert1.data_ptr();

    torch::Tensor lda = CopyToDevice(lda_host, input_expert0.device());
    torch::Tensor ldb = CopyToDevice(ldb_host, input_expert0.device());
    torch::Tensor ldc = CopyToDevice(ldc_host, input_expert0.device());
    torch::Tensor ptr_a = CopyToDevice(ptr_a_host, input_expert0.device());
    torch::Tensor ptr_b = CopyToDevice(ptr_b_host, input_expert0.device());
    torch::Tensor ptr_c = CopyToDevice(ptr_c_host, input_expert0.device());
    torch::Tensor problem_sizes = CopyToDevice(problem_sizes_host, input_expert0.device());

    typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
    typename Gemm::Arguments arguments(
        (cutlass::gemm::GemmCoord *)problem_sizes.data_ptr(),
        (int)num_experts,
        (int)threadblock_count,
        epilogue_op,
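        // Remaining grouped-GEMM arguments: per-expert device pointer arrays for A, B, C and D
        // (ptr_c is passed twice so the epilogue writes D in place of C), the leading dimensions
        // lda/ldb/ldc (ldc reused for D), and the host-side copy of the problem sizes.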
        (ElementA **)ptr_a.data_ptr(),
        (ElementB **)ptr_b.data_ptr(),
        (ElementC **)ptr_c.data_ptr(),
        (ElementC **)ptr_c.data_ptr(),
        (int64_t *)lda.data_ptr(),
        (int64_t *)ldb.data_ptr(),
        (int64_t *)ldc.data_ptr(),
        (int64_t *)ldc.data_ptr(),
        (cutlass::gemm::GemmCoord *)problem_sizes_host.data());

    return arguments;
}

template <typename Gemm>
void executeDualExpertGemm(
    torch::Tensor input_expert0,
    torch::Tensor input_expert1,
    torch::Tensor weight_expert0,
    torch::Tensor weight_expert1,
    torch::Tensor output_expert0,
    torch::Tensor output_expert1)
{

    Gemm gemm;

    auto arguments = MakeAsymmetricArgumentsSeparated<Gemm>(
        input_expert0, input_expert1,
        weight_expert0, weight_expert1,
        output_expert0, output_expert1);

    int64_t workspace_size = gemm.get_workspace_size(arguments);
    auto options = torch::TensorOptions().dtype(torch::kInt8).device(input_expert0.device());
    torch::Tensor workspace = torch::empty(workspace_size, options);

    if (gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess)
    {
        TORCH_CHECK(false, "Failed to initialize CUTLASS Asymmetric Dual Expert GEMM");
    }

    if (gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess)
    {
        TORCH_CHECK(false, "Failed to run CUTLASS Asymmetric Dual Expert GEMM");
    }
}

void AsymmetricDualExpertGemm(
    torch::Tensor input_expert0,
    torch::Tensor input_expert1,
    torch::Tensor weight_expert0,
    torch::Tensor weight_expert1,
    torch::Tensor output_expert0,
    torch::Tensor output_expert1,
    bool trans_a, bool trans_b)
{

    TORCH_CHECK(input_expert0.device() == input_expert1.device() &&
                    input_expert0.device() == weight_expert0.device() &&
                    input_expert0.device() == weight_expert1.device() &&
                    input_expert0.device() == output_expert0.device() &&
                    input_expert0.device() == output_expert1.device(),
                "All tensors must be on the same device");

    TORCH_CHECK(input_expert0.device().is_cuda(),
                "All tensors must be on CUDA device for CUTLASS GEMM");

    TORCH_CHECK(input_expert0.dtype() == torch::kBFloat16,
                "All tensors must be BFloat16 for this kernel");

    torch::Tensor input0_contiguous = input_expert0.contiguous();
    torch::Tensor input1_contiguous = input_expert1.contiguous();
    torch::Tensor weight0_contiguous = weight_expert0.contiguous();
    torch::Tensor weight1_contiguous = weight_expert1.contiguous();
    torch::Tensor output0_contiguous = output_expert0.contiguous();
    torch::Tensor output1_contiguous = output_expert1.contiguous();

    if (!trans_a && !trans_b)
    {
        using Gemm = DualExpertGemmNN;
        executeDualExpertGemm<Gemm>(input0_contiguous, input1_contiguous,
                                    weight0_contiguous, weight1_contiguous,
                                    output0_contiguous, output1_contiguous);
    }
    else if (trans_a && !trans_b)
    {
        using Gemm = DualExpertGemmTN;
        executeDualExpertGemm<Gemm>(input0_contiguous, input1_contiguous,
                                    weight0_contiguous, weight1_contiguous,
                                    output0_contiguous, output1_contiguous);
    }
    else if (!trans_a && trans_b)
    {
        using Gemm = DualExpertGemmNT;
        executeDualExpertGemm<Gemm>(input0_contiguous, input1_contiguous,
                                    weight0_contiguous, weight1_contiguous,
                                    output0_contiguous, output1_contiguous);
    }
    else
    {
        using Gemm = DualExpertGemmTT;
        executeDualExpertGemm<Gemm>(input0_contiguous, input1_contiguous,
                                    weight0_contiguous, weight1_contiguous,
                                    output0_contiguous, output1_contiguous);
    }
}

-------------------------------------------------------------------------------- /wall_x/fusions/backend.py: --------------------------------------------------------------------------------
"""
High-performance C++ backend interface for optimized matrix operations.

This module provides Python bindings for custom CUDA kernels optimized for
transformer and MoE (Mixture of Experts) operations, including:
- Asymmetric dual expert operations
- Token permutation/unpermutation for MoE routing
- RoPE (Rotary Position Embedding) operations
"""

import torch
from typing import Tuple, Optional
import wallx_csrc as backend


def _allocate_asymmetric_dual_outputs(
    input_expert0: torch.Tensor,
    input_expert1: torch.Tensor,
    weight_expert0: torch.Tensor,
    weight_expert1: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Allocate output tensors for asymmetric dual expert GEMM operations.

    This function handles the case where two experts may have different output
    dimensions, which is common in heterogeneous MoE architectures.

    Args:
        input_expert0 (torch.Tensor): Expert 0 input tensor of shape [m0, k]
        input_expert1 (torch.Tensor): Expert 1 input tensor of shape [m1, k]
        weight_expert0 (torch.Tensor): Expert 0 weight tensor of shape [k, n0]
        weight_expert1 (torch.Tensor): Expert 1 weight tensor of shape [k, n1]

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Pre-allocated output tensors
            - output_expert0: Shape [m0, n0]
            - output_expert1: Shape [m1, n1]

    Raises:
        AssertionError: If tensor dimensions are incompatible
    """
    # Validate input tensor dimensions
    assert input_expert0.ndim == 2, "Expected 2D tensor for input_expert0"
    assert input_expert1.ndim == 2, "Expected 2D tensor for input_expert1"
    assert weight_expert0.ndim == 2, "Expected 2D tensor for weight_expert0"
    assert weight_expert1.ndim == 2, "Expected 2D tensor for weight_expert1"

    # Verify dimension compatibility for matrix multiplication
    assert input_expert0.size(1) == weight_expert0.size(
        0
    ), f"Input expert0 K dimension {input_expert0.size(1)} != weight expert0 K dimension {weight_expert0.size(0)}"
    assert input_expert1.size(1) == weight_expert1.size(
        0
    ), f"Input expert1 K dimension {input_expert1.size(1)} != weight expert1 K dimension {weight_expert1.size(0)}"

    # Calculate output shapes: [m, k] × [k, n] = [m, n]
    m0, n0 = input_expert0.size(0), weight_expert0.size(1)
    m1, n1 = input_expert1.size(0), weight_expert1.size(1)

    # Allocate output tensors with matching device and dtype
    output_expert0 = torch.empty(
        m0, n0, device=input_expert0.device, dtype=input_expert0.dtype
    )
    output_expert1 = torch.empty(
        m1, n1, device=input_expert1.device, dtype=input_expert1.dtype
    )

    return output_expert0, output_expert1


def asym_dual_gmm_separated(
    input_expert0: torch.Tensor,
    input_expert1: torch.Tensor,
    weight_expert0: torch.Tensor,
    weight_expert1: torch.Tensor,
    output_expert0: Optional[torch.Tensor] = None,
    output_expert1: Optional[torch.Tensor] = None,
    trans_a: bool = False,
79 | trans_b: bool = False, 80 | ) -> Tuple[torch.Tensor, torch.Tensor]: 81 | """ 82 | Asymmetric dual expert grouped GEMM with separated inputs and outputs. 83 | 84 | This is the recommended interface for maximum flexibility and performance when 85 | dealing with two experts that may have different intermediate dimensions. 86 | The operation is equivalent to: 87 | output_expert0 = input_expert0 @ weight_expert0 88 | output_expert1 = input_expert1 @ weight_expert1 89 | But optimized as a single fused kernel call. 90 | 91 | Args: 92 | input_expert0 (torch.Tensor): Expert 0 input tensor of shape [m0, k] 93 | input_expert1 (torch.Tensor): Expert 1 input tensor of shape [m1, k] 94 | weight_expert0 (torch.Tensor): Expert 0 weight tensor of shape [k, n0] 95 | weight_expert1 (torch.Tensor): Expert 1 weight tensor of shape [k, n1] 96 | Note: n0 can be different from n1 97 | output_expert0 (torch.Tensor, optional): Pre-allocated output for expert 0 [m0, n0] 98 | output_expert1 (torch.Tensor, optional): Pre-allocated output for expert 1 [m1, n1] 99 | trans_a (bool, optional): Whether to transpose input tensors. Defaults to False. 100 | trans_b (bool, optional): Whether to transpose weight tensors. Defaults to False. 101 | 102 | Returns: 103 | Tuple[torch.Tensor, torch.Tensor]: Output tensors (output_expert0, output_expert1) 104 | 105 | Example: 106 | >>> # Two experts with different output dimensions 107 | >>> input0 = torch.randn(512, 1024, device='cuda') # 512 tokens for expert 0 108 | >>> input1 = torch.randn(256, 1024, device='cuda') # 256 tokens for expert 1 109 | >>> weight0 = torch.randn(1024, 2048, device='cuda') # Expert 0: 1024->2048 110 | >>> weight1 = torch.randn(1024, 4096, device='cuda') # Expert 1: 1024->4096 111 | >>> out0, out1 = asym_dual_gmm_separated(input0, input1, weight0, weight1) 112 | """ 113 | # Allocate outputs if not provided 114 | if output_expert0 is None or output_expert1 is None: 115 | alloc_out0, alloc_out1 = _allocate_asymmetric_dual_outputs( 116 | input_expert0, input_expert1, weight_expert0, weight_expert1 117 | ) 118 | if output_expert0 is None: 119 | output_expert0 = alloc_out0 120 | if output_expert1 is None: 121 | output_expert1 = alloc_out1 122 | 123 | # Call optimized C++ backend kernel 124 | backend.asym_dual_gmm( 125 | input_expert0, 126 | input_expert1, 127 | weight_expert0, 128 | weight_expert1, 129 | output_expert0, 130 | output_expert1, 131 | trans_a, 132 | trans_b, 133 | ) 134 | 135 | return output_expert0, output_expert1 136 | 137 | 138 | def permute( 139 | input: torch.Tensor, 140 | indices: torch.Tensor, 141 | num_out_tokens: int, 142 | workspace: torch.Tensor, 143 | max_expanded_token_num: int, 144 | ) -> torch.Tensor: 145 | """ 146 | Permute input tokens according to expert assignment indices for MoE routing. 147 | 148 | This function reorders tokens based on their assigned experts to enable 149 | efficient grouped processing. Used in the forward pass of MoE layers. 
150 | 151 | Args: 152 | input (torch.Tensor): Input tokens to permute 153 | indices (torch.Tensor): Expert assignment indices for each token 154 | num_out_tokens (int): Number of output tokens after expansion 155 | workspace (torch.Tensor): Temporary workspace tensor for intermediate computations 156 | max_expanded_token_num (int): Maximum number of tokens after top-k expansion 157 | 158 | Returns: 159 | torch.Tensor: Permuted tokens grouped by expert assignment 160 | 161 | Note: 162 | This is typically used with top-k expert selection where each token 163 | can be routed to multiple experts. 164 | """ 165 | return backend.permute( 166 | input, indices, num_out_tokens, workspace, max_expanded_token_num 167 | ) 168 | 169 | 170 | def unpermute( 171 | input: torch.Tensor, 172 | row_id_map: torch.Tensor, 173 | prob: torch.Tensor, 174 | max_tokens: int, 175 | num_topK: int, 176 | ) -> torch.Tensor: 177 | """ 178 | Unpermute expert outputs back to original token order with probability weighting. 179 | 180 | This function reverses the permutation applied in the forward pass and combines 181 | outputs from multiple experts using their routing probabilities. 182 | 183 | Args: 184 | input (torch.Tensor): Permuted expert outputs to unpermute 185 | row_id_map (torch.Tensor): Mapping from permuted positions to original positions 186 | prob (torch.Tensor): Expert routing probabilities for weighted combination 187 | max_tokens (int): Maximum number of tokens in the sequence 188 | num_topK (int): Number of top experts selected per token 189 | 190 | Returns: 191 | torch.Tensor: Unpermuted tokens in original order with expert outputs combined 192 | 193 | Note: 194 | The output combines multiple expert predictions for each token using 195 | the routing probabilities as weights. 196 | """ 197 | return backend.unpermute(input, row_id_map, prob, max_tokens, num_topK) 198 | 199 | 200 | def unpermute_bwd( 201 | input_bwd: torch.Tensor, 202 | input_fwd: torch.Tensor, 203 | row_id_map: torch.Tensor, 204 | prob: Optional[torch.Tensor], 205 | ) -> torch.Tensor: 206 | """ 207 | Backward pass for unpermute operation with gradient flow. 208 | 209 | This function handles the backward pass through the unpermute operation, 210 | ensuring proper gradient flow for training MoE models. 211 | 212 | Args: 213 | input_bwd (torch.Tensor): Backward gradients from the next layer 214 | input_fwd (torch.Tensor): Forward pass inputs (for gradient computation) 215 | row_id_map (torch.Tensor): Row mapping used in forward unpermute 216 | prob (torch.Tensor, optional): Expert probabilities. If None, uniform weights are used. 217 | 218 | Returns: 219 | torch.Tensor: Gradients with respect to the input of unpermute forward pass 220 | 221 | Note: 222 | If prob is None, uniform probabilities are assumed for gradient computation. 223 | """ 224 | # Handle case where probabilities are not provided 225 | if prob is None: 226 | prob = torch.ones( 227 | [input_bwd.size(0), 1], dtype=torch.float32, device=input_bwd.device 228 | ) 229 | 230 | return backend.unpermute_bwd(input_bwd, input_fwd, row_id_map, prob) 231 | 232 | 233 | def rope( 234 | q: torch.Tensor, 235 | k: torch.Tensor, 236 | cos: torch.Tensor, 237 | sin: torch.Tensor, 238 | q_out: torch.Tensor, 239 | k_out: torch.Tensor, 240 | mrope_section_doubled: bool, 241 | ) -> None: 242 | """ 243 | Apply RoPE (Rotary Position Embedding) to query and key tensors. 244 | 245 | Applies rotary position embeddings to query and key tensors using precomputed 246 | cosine and sine values. 
Supports both standard RoPE and multi-dimensional RoPE (mRoPE). 247 | 248 | Args: 249 | q (torch.Tensor): Query tensor to apply RoPE to 250 | k (torch.Tensor): Key tensor to apply RoPE to 251 | cos (torch.Tensor): Precomputed cosine values for rotation 252 | sin (torch.Tensor): Precomputed sine values for rotation 253 | q_out (torch.Tensor): Output tensor for rotated queries (in-place operation supported) 254 | k_out (torch.Tensor): Output tensor for rotated keys (in-place operation supported) 255 | mrope_section_doubled (bool): Whether using multi-dimensional RoPE with doubled sections 256 | 257 | Note: 258 | This function performs in-place operations if q_out and k_out point to the same 259 | memory as q and k respectively. The rotation is applied using the standard 260 | RoPE formulation with complex number rotation. 261 | """ 262 | return backend.rope(q, k, cos, sin, q_out, k_out, mrope_section_doubled) 263 | 264 | 265 | def rope_bwd( 266 | grad_q_out: torch.Tensor, 267 | grad_k_out: torch.Tensor, 268 | q: torch.Tensor, 269 | k: torch.Tensor, 270 | cos: torch.Tensor, 271 | sin: torch.Tensor, 272 | grad_q: torch.Tensor, 273 | grad_k: torch.Tensor, 274 | mrope_section_doubled: bool, 275 | ) -> None: 276 | """ 277 | Backward pass for RoPE operation with gradient computation. 278 | 279 | Computes gradients with respect to the input query and key tensors 280 | for the RoPE operation used in transformer attention mechanisms. 281 | 282 | Args: 283 | grad_q_out (torch.Tensor): Gradient with respect to output queries 284 | grad_k_out (torch.Tensor): Gradient with respect to output keys 285 | q (torch.Tensor): Original query tensor from forward pass 286 | k (torch.Tensor): Original key tensor from forward pass 287 | cos (torch.Tensor): Cosine values used in forward pass 288 | sin (torch.Tensor): Sine values used in forward pass 289 | grad_q (torch.Tensor): Output tensor for query gradients 290 | grad_k (torch.Tensor): Output tensor for key gradients 291 | mrope_section_doubled (bool): Whether using multi-dimensional RoPE configuration 292 | 293 | Note: 294 | This function computes the analytical gradient of the RoPE operation, 295 | which involves the inverse rotation compared to the forward pass. 296 | """ 297 | return backend.rope_bwd( 298 | grad_q_out, grad_k_out, q, k, cos, sin, grad_q, grad_k, mrope_section_doubled 299 | ) 300 | 301 | 302 | def get_rope_index( 303 | input_ids: torch.Tensor, 304 | image_grid_thw: Optional[torch.Tensor], 305 | video_grid_thw: Optional[torch.Tensor], 306 | second_per_grid_ts: Optional[torch.Tensor], 307 | attention_mask: Optional[torch.Tensor], 308 | spatial_merge_size: int, 309 | image_token_id: int, 310 | video_token_id: int, 311 | vision_start_token_id: int, 312 | tokens_per_second: float, 313 | ) -> Tuple[torch.Tensor, torch.Tensor]: 314 | """ 315 | Generate position indices for multimodal RoPE (Rotary Position Embedding). 316 | 317 | This function computes 3D position indices for text, image, and video tokens 318 | to enable proper spatial-temporal position encoding in multimodal transformers. 
319 | 320 | Args: 321 | input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len] 322 | image_grid_thw (torch.Tensor, optional): Image grid specifications of shape [num_images, 3] (T, H, W) 323 | video_grid_thw (torch.Tensor, optional): Video grid specifications of shape [num_videos, 3] (T, H, W) 324 | second_per_grid_ts (torch.Tensor, optional): Temporal scaling per video grid of shape [num_videos] 325 | attention_mask (torch.Tensor, optional): Attention mask of shape [batch_size, seq_len] 326 | spatial_merge_size (int): Spatial dimension merge factor for patch grouping 327 | image_token_id (int): Token ID representing image patches 328 | video_token_id (int): Token ID representing video frames 329 | vision_start_token_id (int): Token ID marking vision sequence start 330 | tokens_per_second (float): Temporal scaling factor for video sequences 331 | 332 | Returns: 333 | Tuple[torch.Tensor, torch.Tensor]: A tuple containing: 334 | - position_ids: 3D position indices of shape [3, batch_size, seq_len] 335 | - mrope_deltas: Position deltas for multimodal RoPE of shape [batch_size, 1] 336 | 337 | Note: 338 | When both image_grid_thw and video_grid_thw are None, returns standard 339 | text-only position indices based on attention_mask or sequence order. 340 | """ 341 | return backend.rope_index( 342 | input_ids, 343 | image_grid_thw, 344 | video_grid_thw, 345 | second_per_grid_ts, 346 | attention_mask, 347 | spatial_merge_size, 348 | image_token_id, 349 | video_token_id, 350 | vision_start_token_id, 351 | tokens_per_second, 352 | ) 353 | 354 | 355 | def rot_pos_emb( 356 | inv_freq: torch.Tensor, 357 | grid_thw: torch.Tensor, 358 | spatial_merge_size: int, 359 | ) -> torch.Tensor: 360 | """ 361 | Compute fused rotary position embeddings for multimodal grids. 362 | 363 | This function efficiently computes rotary position embeddings for spatial-temporal 364 | grids using a fused CUDA kernel, supporting both int32 and int64 grid specifications. 365 | 366 | Args: 367 | inv_freq (torch.Tensor): Inverse frequencies for RoPE of shape [dim/2] 368 | Must be float32 dtype on CUDA device 369 | grid_thw (torch.Tensor): Grid specifications of shape [num_grids, 3] (T, H, W) 370 | Supports int32 or int64 dtype on CUDA device 371 | spatial_merge_size (int): Merge factor for spatial dimensions (must be positive) 372 | 373 | Returns: 374 | torch.Tensor: Computed rotary embeddings of shape [total_tokens, dim] 375 | where total_tokens is determined by grid layouts and spatial_merge_size 376 | 377 | Example: 378 | >>> inv_freq = torch.randn(64, device='cuda', dtype=torch.float32) # 128-dim model 379 | >>> grids = torch.tensor([[8, 14, 14], [16, 7, 7]], device='cuda', dtype=torch.int32) 380 | >>> embeddings = rot_pos_emb(inv_freq, grids, spatial_merge_size=2) 381 | >>> print(embeddings.shape) # [computed_tokens, 128] 382 | 383 | Note: 384 | The function automatically dispatches to int32 or int64 implementations 385 | based on the dtype of grid_thw. Output is always float32. 386 | """ 387 | return backend.rot_pos_emb(inv_freq, grid_thw, spatial_merge_size) 388 | 389 | 390 | def get_window_index( 391 | grid_thw: torch.Tensor, 392 | spatial_merge_size: int, 393 | vit_merger_window_size: int, 394 | patch_size: int, 395 | spatial_merge_unit: int, 396 | ) -> Tuple[torch.Tensor, torch.Tensor]: 397 | """ 398 | Generate window attention indices for Vision Transformer architectures. 
399 | 400 | Computes window-based attention indices for hierarchical processing of vision 401 | tokens, enabling efficient sliding window attention patterns in ViT models. 402 | 403 | Args: 404 | grid_thw (torch.Tensor): Grid specifications of shape [num_grids, 3] (T, H, W) 405 | Must be int32 dtype on CUDA device 406 | spatial_merge_size (int): Spatial dimension merge factor 407 | vit_merger_window_size (int): Size of attention windows for ViT processing 408 | patch_size (int): Size of vision patches in pixels 409 | spatial_merge_unit (int): Unit size for spatial merging operations 410 | 411 | Returns: 412 | Tuple[torch.Tensor, torch.Tensor]: A tuple containing: 413 | - window_indices: Flattened window indices of shape [total_elements] 414 | - cu_window_seqlens: Cumulative window sequence lengths of shape [num_windows + 1] 415 | 416 | Example: 417 | >>> grids = torch.tensor([[1, 14, 14]], device='cuda', dtype=torch.int32) 418 | >>> indices, seqlens = get_window_index( 419 | ... grids, spatial_merge_size=2, vit_merger_window_size=7, 420 | ... patch_size=16, spatial_merge_unit=4 421 | ... ) 422 | 423 | Note: 424 | Returns empty tensors if input grid is empty or no valid windows can be formed. 425 | The cu_window_seqlens tensor enables efficient batched attention computation. 426 | """ 427 | return backend.get_window_index( 428 | grid_thw, 429 | spatial_merge_size, 430 | vit_merger_window_size, 431 | patch_size, 432 | spatial_merge_unit, 433 | ) 434 | --------------------------------------------------------------------------------
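
To make the bindings above concrete, here is a minimal usage sketch. It assumes the `wallx_csrc` extension has been compiled and a CUDA device is available; the shapes, dtypes, and the `inv_freq` construction are illustrative choices, while the bfloat16 requirement for the dual expert GEMM and the float32 `inv_freq` / int32-or-int64 `grid_thw` requirements for `rot_pos_emb` come from the CUDA sources above.

```python
# Illustrative usage sketch (not part of the repository sources).
import torch

from wall_x.fusions import backend

device = "cuda"

# Asymmetric dual expert GEMM: the CUTLASS kernel checks that all tensors are bfloat16.
x0 = torch.randn(512, 1024, device=device, dtype=torch.bfloat16)   # tokens routed to expert 0
x1 = torch.randn(256, 1024, device=device, dtype=torch.bfloat16)   # tokens routed to expert 1
w0 = torch.randn(1024, 2048, device=device, dtype=torch.bfloat16)  # expert 0 projection (1024 -> 2048)
w1 = torch.randn(1024, 4096, device=device, dtype=torch.bfloat16)  # expert 1 projection (1024 -> 4096)
out0, out1 = backend.asym_dual_gmm_separated(x0, x1, w0, w1)
assert out0.shape == (512, 2048) and out1.shape == (256, 4096)

# Fused rotary position embeddings for two image grids merged with spatial_merge_size=2.
dim_half = 32
inv_freq = 1.0 / (10000.0 ** (torch.arange(dim_half, device=device, dtype=torch.float32) / dim_half))
grid_thw = torch.tensor([[1, 14, 14], [1, 28, 28]], device=device, dtype=torch.int64)
rot = backend.rot_pos_emb(inv_freq, grid_thw, spatial_merge_size=2)
print(rot.shape)  # [total_tokens, 2 * dim_half]
```

Both calls forward directly to the compiled `wallx_csrc` kernels, so dtype or device violations surface as `TORCH_CHECK` errors raised from the C++ side rather than Python exceptions.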