├── SPA_agent ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-39.pyc │ ├── base_llm.cpython-39.pyc │ ├── agent_proxy.cpython-39.pyc │ ├── ctx_manager.cpython-39.pyc │ ├── es_manager.cpython-39.pyc │ ├── agent_proxy.cpython-311.pyc │ ├── ctx_manager.cpython-311.pyc │ ├── generate_sft_data.cpython-39.pyc │ └── generate_sft_data_mask.cpython-39.pyc ├── base_llm.py ├── generate_sft_data.py ├── agent_proxy.py ├── es_manager.py └── ctx_manager.py ├── config ├── evaluation │ └── llama-30b.yaml ├── _2_sokoban.yaml ├── _5_metamathqa.yaml ├── _3_frozen_lake.yaml ├── _4_countdown.yaml ├── _10_sudoku.yaml ├── _6_webshop.yaml ├── stream.yaml ├── _1_bandit.yaml ├── evaluate_api_llm.yaml ├── base.yaml ├── base-lora.yaml ├── ppo_trainer.yaml └── envs.yaml ├── Internalizing_World_Models_via_Self_Play_Finetuning_for_Agentic_RL.pdf ├── .gitignore ├── sft ├── __pycache__ │ ├── spa_sft_dataset.cpython-39.pyc │ └── spa_sft_trainer.cpython-39.pyc ├── config │ └── sft_trainer.yaml ├── finetune_ft.sh ├── filter_sft_by_tag.py ├── spa_sft_dataset.py └── spa_sft_trainer.py ├── train_ppo_sfted.sh ├── run_baseline.sh ├── run_spa.sh └── README.md /SPA_agent/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/evaluation/llama-30b.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Internalizing_World_Models_via_Self_Play_Finetuning_for_Agentic_RL.pdf: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | sftckpt 2 | sftdata 3 | outputs 4 | log 5 | data 6 | wandb 7 | sft_data* 8 | sftdata* 9 | __pycache__ -------------------------------------------------------------------------------- /config/_2_sokoban.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | trainer: 5 | experiment_name: sokoban-main 6 | 7 | 8 | -------------------------------------------------------------------------------- /SPA_agent/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/base_llm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/base_llm.cpython-39.pyc -------------------------------------------------------------------------------- /sft/__pycache__/spa_sft_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/sft/__pycache__/spa_sft_dataset.cpython-39.pyc 
-------------------------------------------------------------------------------- /sft/__pycache__/spa_sft_trainer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/sft/__pycache__/spa_sft_trainer.cpython-39.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/agent_proxy.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/agent_proxy.cpython-39.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/ctx_manager.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/ctx_manager.cpython-39.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/es_manager.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/es_manager.cpython-39.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/agent_proxy.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/agent_proxy.cpython-311.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/ctx_manager.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/ctx_manager.cpython-311.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/generate_sft_data.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/generate_sft_data.cpython-39.pyc -------------------------------------------------------------------------------- /SPA_agent/__pycache__/generate_sft_data_mask.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiqichen17/SPA/HEAD/SPA_agent/__pycache__/generate_sft_data_mask.cpython-39.pyc -------------------------------------------------------------------------------- /config/_5_metamathqa.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | es_manager: 5 | train: 6 | env_configs: 7 | tags: ["MetamathQA"] 8 | val: 9 | env_configs: 10 | tags: ["MetamathQA"] 11 | 12 | -------------------------------------------------------------------------------- /config/_3_frozen_lake.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | trainer: 5 | experiment_name: frozenlake-main 6 | 7 | 8 | es_manager: 9 | train: 10 | env_configs: 11 | tags: ["FrozenLake"] 12 | val: 13 | env_configs: 14 | tags: ["FrozenLake"] 15 | -------------------------------------------------------------------------------- /config/_4_countdown.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | 
trainer: 5 | experiment_name: countdown 6 | 7 | 8 | agent_proxy: 9 | max_turn: 1 10 | max_actions_per_turn: 1 11 | 12 | es_manager: 13 | train: 14 | env_configs: 15 | tags: ["Countdown"] 16 | val: 17 | env_configs: 18 | tags: ["Countdown"] 19 | -------------------------------------------------------------------------------- /config/_10_sudoku.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | trainer: 5 | experiment_name: sudoku-${oc.env:MODEL}-RENDER_MODE${oc.env:RENDER_MODE}-fixed-easy12 6 | 7 | agent_proxy: 8 | max_turn: 5 9 | max_actions_per_turn: 5 # how many actions can be output at most in a single turn 10 | 11 | es_manager: 12 | train: 13 | env_configs: 14 | tags: ["Sudoku"] 15 | val: 16 | env_configs: 17 | tags: ["Sudoku"] -------------------------------------------------------------------------------- /config/_6_webshop.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | micro_batch_size_per_gpu: 1 5 | ppo_mini_batch_size: 32 6 | model_path: Qwen/Qwen2.5-3B-Instruct 7 | 8 | trainer: 9 | experiment_name: webshop 10 | 11 | 12 | agent_proxy: 13 | max_turn: 9 14 | max_actions_per_turn: 1 15 | 16 | actor_rollout_ref: 17 | rollout: 18 | max_model_len: 15000 19 | max_num_batched_tokens: 15000 20 | 21 | es_manager: 22 | train: 23 | env_configs: 24 | tags: ["WebShop"] 25 | val: 26 | env_configs: 27 | tags: ["WebShop"] 28 | -------------------------------------------------------------------------------- /config/stream.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | trainer: 5 | experiment_name: sokoban-main 6 | 7 | 8 | es_manager: 9 | val: 10 | env_groups: 1 11 | group_size: 1 # should be set to 1 because when val temperature is set to 0 and group size > 1, there will be repetitive prompts which leads to same trajectory. 12 | env_configs: 13 | tags: ["SimpleSokoban"] 14 | n_groups: [1] # TODO: If not set, all env names divide nums equally. 
Under the same group, the env config and env seed (prompt) are equal in each generation 15 | -------------------------------------------------------------------------------- /config/_1_bandit.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | actor_rollout_ref: 5 | rollout: 6 | response_length: 500 7 | val_kwargs: 8 | do_sample: True 9 | temperature: 0.5 # enabling randomness in evaluation 10 | 11 | trainer: 12 | experiment_name: bandit-base 13 | 14 | agent_proxy: 15 | max_turn: 1 16 | max_actions_per_turn: 1 # how many actions can be output at most in a single turn 17 | 18 | es_manager: 19 | train: 20 | env_configs: 21 | tags: ["Bandit"] # BanditGeneralizationNoThink 22 | val: 23 | env_groups: 512 24 | env_configs: 25 | tags: ["Bandit", "BanditTest"] 26 | n_groups: [256, 256] 27 | -------------------------------------------------------------------------------- /config/evaluate_api_llm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base # this is a symbolic link to the verl/verl/trainer/config/ppo_trainer.yaml file 3 | 4 | model_config: 5 | model_name: gpt-4o # should be registered in model_info 6 | max_concurrency: 16 7 | 8 | model_info: 9 | Qwen2.5-7B-Instruct: 10 | provider_name: together 11 | model_name: Qwen/Qwen2.5-7B-Instruct-Turbo 12 | generation_kwargs: 13 | temperature: 0 14 | max_tokens: 512 15 | Qwen2.5-72B-Instruct: 16 | provider_name: together 17 | model_name: Qwen/Qwen2.5-72B-Instruct-Turbo 18 | generation_kwargs: 19 | temperature: 0 20 | max_tokens: 512 21 | claude-3.7: 22 | provider_name: anthropic 23 | model_name: claude-3-7-sonnet-20250219 24 | generation_kwargs: 25 | temperature: 0 26 | max_tokens: 512 # max_completion_tokens if o1-mini 27 | gpt-4o: 28 | provider_name: openai 29 | model_name: gpt-4o 30 | generation_kwargs: 31 | temperature: 0 32 | max_tokens: 512 # max_completion_tokens if o1-mini 33 | deepseek-r1: 34 | provider_name: deepseek 35 | model_name: deepseek-reasoner 36 | generation_kwargs: 37 | temperature: 0 38 | max_completion_tokens: 512 39 | deepseek-v3: 40 | provider_name: deepseek 41 | model_name: deepseek-chat 42 | generation_kwargs: 43 | temperature: 0 44 | max_completion_tokens: 512 45 | 46 | 47 | 48 | es_manager: 49 | val: 50 | env_groups: 256 51 | group_size: 1 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output 52 | env_configs: 53 | tags: ["SimpleSokoban"] 54 | n_groups: [256] # If not set, all env names divide nums equally. 
Under the same group, the env config and env seed (prompt) are equal in each generation 55 | 56 | 57 | -------------------------------------------------------------------------------- /sft/config/sft_trainer.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_batch_size: 256 3 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu 4 | micro_batch_size_per_gpu: 4 # this is also val batch size 5 | train_files: ~/data/gsm8k/train.parquet 6 | val_files: ~/data/gsm8k/test.parquet 7 | # Single-turn settings 8 | prompt_key: question 9 | response_key: answer 10 | prompt_dict_keys: ['question'] 11 | response_dict_keys: ['answer'] 12 | # Multi-turn settings 13 | multiturn: 14 | enable: false # Set to true to use multi-turn dataset 15 | messages_key: messages # Key for messages list in multi-turn mode 16 | max_length: 1024 17 | truncation: left 18 | balance_dp_token: False 19 | chat_template: null 20 | custom_cls: 21 | path: null 22 | name: null 23 | model: 24 | partial_pretrain: ~/models/gemma-1.1-7b-it 25 | fsdp_config: 26 | wrap_policy: 27 | min_num_params: 0 28 | cpu_offload: False 29 | offload_params: False 30 | external_lib: null 31 | enable_gradient_checkpointing: False 32 | trust_remote_code: False 33 | lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) 34 | lora_alpha: 16 # LoRA scaling factor 35 | target_modules: all-linear # Target modules for LoRA adaptation 36 | use_liger: False 37 | optim: 38 | lr: 1e-5 39 | betas: [0.9, 0.95] 40 | weight_decay: 0.01 41 | warmup_steps_ratio: 0.1 42 | clip_grad: 1.0 43 | lr_scheduler: cosine 44 | ulysses_sequence_parallel_size: 1 45 | use_remove_padding: False 46 | trainer: 47 | default_local_dir: /tmp/sft_model 48 | default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here 49 | resume_path: null 50 | project_name: gsm8k-sft 51 | experiment_name: test 52 | total_epochs: 4 53 | total_training_steps: null 54 | logger: ['console'] 55 | seed: 1 56 | 57 | -------------------------------------------------------------------------------- /train_ppo_sfted.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONHASHSEED=10000 4 | 5 | # Script arguments 6 | CONFIG_NAME=$1 7 | ckpt=$2 8 | name=$3 9 | 10 | # Validate arguments 11 | if [ -z "$CONFIG_NAME" ] || [ -z "$ckpt" ] || [ -z "$name" ]; then 12 | echo "Usage: $0 " 13 | echo "Example: $0 _2_sokoban /path/to/checkpoint experiment_name" 14 | exit 1 15 | fi 16 | 17 | # Create log directory 18 | mkdir -p log 19 | 20 | # Training parameters 21 | MODEL="1.5B" 22 | MODE="base" 23 | REWARD=0.0 24 | COS=False 25 | RENDER_MODE="text_with_coordinates" 26 | 27 | # Export training variables 28 | export MODEL=$MODEL 29 | export MODE=$MODE 30 | export REWARD=$REWARD 31 | export COS=$COS 32 | export RENDER_MODE=$RENDER_MODE 33 | 34 | echo "Starting PPO training with:" 35 | echo " - Config: $CONFIG_NAME" 36 | echo " - Checkpoint: $ckpt" 37 | echo " - Experiment: $name" 38 | echo " - Model: $MODEL" 39 | echo " - Mode: $MODE" 40 | echo " - Render Mode: $RENDER_MODE" 41 | echo "WANDB_ENTITY: $WANDB_ENTITY" 42 | 43 | if [ -z "$BASE_DIR" ]; then 44 | echo "BASE_DIR is not set" 45 | BASE_DIR=/home/aiops/zhuty/ragen/checkpoints 46 | echo "BASE_DIR is not set, using default: $BASE_DIR" 47 | fi 48 | 49 | # Run training 50 | CUDA_VISIBLE_DEVICES='0,1,2,3' python ../train.py \ 51 | --config-path=SPA/config \ 52 | --config-name $CONFIG_NAME \ 53 | 
model_path=$ckpt \ 54 | system.CUDA_VISIBLE_DEVICES=\'0,1,2,3\' \ 55 | trainer.n_gpus_per_node=4 \ 56 | trainer.total_training_steps=1000 \ 57 | trainer.experiment_name=${name} \ 58 | trainer.save_freq=100 \ 59 | trainer.default_local_dir=$BASE_DIR/${CONFIG_NAME}/$ckpt \ 60 | actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ 61 | +actor_rollout_ref.rollout.tp_size_check=True \ 62 | +algorithm.bi_level_gae=False 63 | 64 | echo "Training completed!" 65 | -------------------------------------------------------------------------------- /sft/finetune_ft.sh: -------------------------------------------------------------------------------- 1 | # NOTE only tested with 1 GPU 2 | set -x 3 | 4 | 5 | env_type=$1 6 | nproc_per_node=$2 7 | save_path=$3 8 | data_path=$4 9 | size=$5 10 | 11 | shift 5 12 | 13 | 14 | if [ ! -d $save_path ]; then 15 | mkdir -p $save_path 16 | fi 17 | 18 | 19 | 20 | export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" 21 | 22 | # Sanity check: ensure nproc_per_node does not exceed visible GPUs 23 | if [ -n "$CUDA_VISIBLE_DEVICES" ]; then 24 | IFS=',' read -ra gpu_array <<< "$CUDA_VISIBLE_DEVICES" 25 | num_visible_gpus=${#gpu_array[@]} 26 | else 27 | # Fallback to nvidia-smi if CUDA_VISIBLE_DEVICES is not set 28 | if command -v nvidia-smi >/dev/null 2>&1; then 29 | num_visible_gpus=$(nvidia-smi --list-gpus | wc -l) 30 | else 31 | num_visible_gpus=0 32 | fi 33 | fi 34 | 35 | if [ "$nproc_per_node" -gt "$num_visible_gpus" ]; then 36 | echo "Error: Requested nproc_per_node=$nproc_per_node but only $num_visible_gpus GPUs are visible (CUDA_VISIBLE_DEVICES='$CUDA_VISIBLE_DEVICES')." >&2 37 | echo "Please reduce nproc_per_node or expose more GPUs." >&2 38 | exit 1 39 | fi 40 | 41 | echo "Starting training..." 42 | echo "nproc_per_node: $nproc_per_node" 43 | 44 | torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ 45 | -m sft.spa_sft_trainer \ 46 | data.train_files=$data_path/wm_train.parquet \ 47 | data.val_files=$data_path/wm_val.parquet \ 48 | data.prompt_key=prompt \ 49 | data.response_key=response \ 50 | data.max_length=2048 \ 51 | optim.lr=1e-4 \ 52 | data.train_batch_size=16 \ 53 | data.micro_batch_size_per_gpu=1 \ 54 | model.partial_pretrain=Qwen/Qwen2.5-$size-Instruct \ 55 | trainer.default_local_dir=$save_path \ 56 | trainer.experiment_name=${env_type}-sft-qwen-2.5-$size-instuct \ 57 | trainer.logger=['console'] \ 58 | trainer.total_epochs=5 \ 59 | trainer.default_hdfs_dir=null \ 60 | +trainer.max_ckpt_to_keep=2 \ 61 | model.target_modules=all-linear \ 62 | trainer.project_name=spa-sft \ 63 | model.enable_gradient_checkpointing=False $@ \ 64 | 2>&1 | tee $save_path/train.log 65 | 66 | 67 | -------------------------------------------------------------------------------- /run_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ============================================================================= 4 | # SPA Training Script 5 | # ============================================================================= 6 | # This script sets up the environment and launches training for the SPA model. 7 | # It handles environment configuration, checkpoint discovery, and training execution. 
8 | 9 | set -euo pipefail # Exit on error, undefined vars, pipe failures 10 | 11 | # ============================================================================= 12 | # Environment Configuration 13 | # ============================================================================= 14 | 15 | # Ray configuration 16 | export RAY_object_spilling_threshold=0.99 17 | export RAY_BACKEND_LOG_LEVEL=FATAL 18 | # Reproducibility 19 | export PYTHONHASHSEED=10000 20 | 21 | 22 | export JAVA_HOME=/home/aiops/zhuty/ragen-dev/jdk-21.0.6 23 | export PATH=$JAVA_HOME/bin:$PATH 24 | # ensure that the JAVA_HOME exists 25 | if [ ! -d "$JAVA_HOME" ]; then 26 | echo "JAVA_HOME: $JAVA_HOME does not exist" 27 | exit 1 28 | fi 29 | 30 | # ============================================================================= 31 | # Training Configuration 32 | # ============================================================================= 33 | 34 | # Training parameters 35 | MODE="add_worldmodel" 36 | MODEL="1.5B" 37 | RENDER_MODE="text_with_coordinates" 38 | BSZ_NUM=5 39 | CONFIG_NAME=$1 40 | # CONFIG_NAME="_2_sokoban" 41 | # CONFIG_NAME="_10_sudoku" 42 | # ensure that the CONFIG_NAME is set, and it is one of the following: _2_sokoban, _10_sudoku 43 | if [ -z "$CONFIG_NAME" ]; then 44 | echo "CONFIG_NAME is not set" 45 | exit 1 46 | fi 47 | if [ "$CONFIG_NAME" != "_2_sokoban" ] && [ "$CONFIG_NAME" != "_10_sudoku" ]; then 48 | echo "CONFIG_NAME is not one of the following: _2_sokoban, _10_sudoku" 49 | exit 1 50 | fi 51 | 52 | 53 | # Derived paths 54 | OUTPUT_DIR="./sftdata/${CONFIG_NAME}-${MODEL}-${RENDER_MODE}" 55 | CHECKPOINT_DIR="./sftckpt/checkpoints${CONFIG_NAME}-${MODEL}-${RENDER_MODE}-qwen/" 56 | 57 | # Export training variables 58 | export MODE="$MODE" 59 | export MODEL="$MODEL" 60 | export PENALTY_VALUE=0.0 61 | export RENDER_MODE="$RENDER_MODE" 62 | export BT_NUM="$BSZ_NUM" 63 | export CONFIG_NAME="$CONFIG_NAME" 64 | export OUTPUT_DIR="$OUTPUT_DIR" 65 | 66 | # ============================================================================= 67 | # Validation and Setup 68 | # ============================================================================= 69 | 70 | 71 | # ============================================================================= 72 | # Training Pipeline 73 | # ============================================================================= 74 | 75 | echo "Starting training pipeline..." 76 | 77 | 78 | 79 | # Step 3: PPO Training 80 | echo "Step 3: Starting PPO training..." 81 | if [ "$CONFIG_NAME" == "_2_sokoban" ]; then 82 | EXPERIMENT_NAME="sokoban-${MODEL}-RENDER_MODE${RENDER_MODE}-baseline" 83 | elif [ "$CONFIG_NAME" == "_3_frozen_lake" ]; then 84 | EXPERIMENT_NAME="frozen_lake-${MODEL}-RENDER_MODE${RENDER_MODE}-baseline" 85 | elif [ "$CONFIG_NAME" == "_10_sudoku" ]; then 86 | EXPERIMENT_NAME="sudoku-${MODEL}-RENDER_MODE${RENDER_MODE}-baseline" 87 | fi 88 | 89 | echo "Training configuration:" 90 | echo " - Config: $CONFIG_NAME" 91 | echo " - Experiment: $EXPERIMENT_NAME" 92 | echo " - Model: $MODEL" 93 | echo " - Render Mode: $RENDER_MODE" 94 | 95 | # Launch training 96 | bash ./train_ppo_sfted.sh "$CONFIG_NAME" "Qwen/Qwen2.5-1.5B-Instruct" "$EXPERIMENT_NAME" 97 | 98 | echo "Training pipeline completed successfully!" 
99 | 100 | 101 | -------------------------------------------------------------------------------- /config/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ppo_trainer # this is a symbolic link to the verl/verl/trainer/config/ppo_trainer.yaml file 3 | - envs 4 | 5 | system: 6 | CUDA_VISIBLE_DEVICES: "0" 7 | 8 | seed: 9 | train: 10000 10 | val: 123 11 | 12 | micro_batch_size_per_gpu: 1 13 | # ppo_mini_batch_size: 512 14 | ppo_mini_batch_size: 64 15 | model_path: Qwen/Qwen2.5-1.5B-Instruct 16 | enable_response_mask: False 17 | 18 | actor_rollout_ref: 19 | model: 20 | path: ${model_path} 21 | actor: 22 | mode: base 23 | ppo_mini_batch_size: ${ppo_mini_batch_size} # by default, ppo_mini_batch_size = train_batch_size / 4 24 | micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 25 | ppo_micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micrqo_batch_size_per_gpu 26 | use_ref: True 27 | entropy_coeff: 0.001 28 | use_kl_loss: False 29 | kl_loss_coef: 0.000 30 | kl_loss_type: kl 31 | clip_ratio_low: 0.2 32 | clip_ratio_high: 0.28 33 | optim: 34 | betas: [0.9, 0.999] 35 | ref: 36 | mode: base 37 | log_prob_micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 38 | rollout: 39 | log_prob_micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 40 | tensor_model_parallel_size: 1 41 | max_model_len: 3600 42 | prompt_length: 1 # useless. Just put it here 43 | response_length: 400 # single-turn response length 44 | gpu_memory_utilization: 0.6 45 | # max_num_batched_tokens: 8192 # set only when enable_chunked_prefill is true 46 | temperature: 1 47 | # rollout_filter_ratio: 0.25 48 | rollout_filter_ratio: 0.25 49 | rollout_filter_type: std # max_mean or std 50 | val_kwargs: 51 | do_sample: True 52 | temperature: 1 53 | 54 | critic: 55 | mode: base 56 | ppo_mini_batch_size: ${ppo_mini_batch_size} # by default, ppo_mini_batch_size = train_batch_size / 4 57 | ppo_micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 58 | model: 59 | path: ${model_path} 60 | optim: 61 | betas: [0.9, 0.999] 62 | 63 | data: 64 | max_prompt_length: null 65 | max_response_length: null 66 | train_batch_size: null 67 | 68 | algorithm: 69 | gamma: 1.0 70 | lam: 1.0 71 | high_level_gamma: 0.95 72 | adv_estimator: gae 73 | kl_penalty: kl # how to estimate kl divergence 74 | kl_ctrl: 75 | type: fixed 76 | kl_coef: 0.000 77 | 78 | trainer: 79 | project_name: textgame-ppo 80 | 81 | experiment_name: test 82 | total_training_steps: 1000 83 | validation_steps: 2 # validation instances = validation_steps * val_env_groups * group_size 84 | val_before_train: True 85 | n_gpus_per_node: 2 86 | test_freq: 10 87 | generations_to_log_to_wandb: 88 | train: 128 89 | val: 20 90 | logger: [ 'console', 'wandb' ] 91 | 92 | 93 | agent_proxy: 94 | max_turn: 5 95 | action_sep: "||" 96 | max_actions_per_turn: 5 # how many actions can be output at most in a single turn 97 | use_turn_scores: False # important to GAE when applying token-level rewards to token-level advantages. If False, will take the sum of scores as the reward for the last turn. 
98 | enable_think: True # False -> no think RL 99 | reward_normalization: 100 | grouping: "state" # state / batch / inductive 101 | method: "identity" # asym_clip / identity / mean_std 102 | 103 | 104 | es_manager: 105 | format_penalty: -0.1 106 | train: 107 | env_groups: 32 108 | # under the same group, the env config and env seed are ensured to be equal 109 | group_size: 8 110 | env_configs: 111 | tags: ["SimpleSokoban"] 112 | n_groups: [32] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation 113 | val: 114 | env_groups: 128 115 | group_size: 8 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output 116 | env_configs: 117 | tags: ["SimpleSokoban"] 118 | n_groups: [128] # [128] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation 119 | 120 | ctx_manager: 121 | generation: # go to vllm 122 | gen_config: 123 | response_length: ${actor_rollout_ref.rollout.response_length} 124 | temperature: ${actor_rollout_ref.rollout.temperature} 125 | top_p: ${actor_rollout_ref.rollout.top_p} 126 | top_k: ${actor_rollout_ref.rollout.top_k} 127 | kwargs: null 128 | mode: ${oc.env:MODE} # add_worldmodel, base 129 | -------------------------------------------------------------------------------- /run_spa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ============================================================================= 4 | # SPA Training Script 5 | # ============================================================================= 6 | # This script sets up the environment and launches training for the SPA model. 7 | # It handles environment configuration, checkpoint discovery, and training execution. 8 | 9 | set -euo pipefail # Exit on error, undefined vars, pipe failures 10 | 11 | # ============================================================================= 12 | # Environment Configuration 13 | # ============================================================================= 14 | 15 | # Ray configuration 16 | export RAY_object_spilling_threshold=0.99 17 | export RAY_BACKEND_LOG_LEVEL=FATAL 18 | # Reproducibility 19 | export PYTHONHASHSEED=10000 20 | 21 | 22 | export JAVA_HOME=/home/aiops/zhuty/ragen-dev/jdk-21.0.6 23 | export PATH=$JAVA_HOME/bin:$PATH 24 | # ensure that the JAVA_HOME exists 25 | if [ ! 
-d "$JAVA_HOME" ]; then 26 | echo "JAVA_HOME: $JAVA_HOME does not exist" 27 | exit 1 28 | fi 29 | 30 | # ============================================================================= 31 | # Training Configuration 32 | # ============================================================================= 33 | 34 | # Training parameters 35 | MODE="add_worldmodel" 36 | MODEL="1.5B" 37 | RENDER_MODE="text_with_coordinates" 38 | BSZ_NUM=5 39 | # BSZ_NUM=30 40 | CONFIG_NAME=$1 41 | CKPT=${2:-last} 42 | GENERATE_DATA=${3:-False} 43 | # CONFIG_NAME="_2_sokoban" 44 | # CONFIG_NAME="_10_sudoku" 45 | # ensure that the CONFIG_NAME is set, and it is one of the following: _2_sokoban, _10_sudoku, _3_frozen_lake 46 | if [ -z "$CONFIG_NAME" ]; then 47 | echo "CONFIG_NAME is not set" 48 | exit 1 49 | fi 50 | if [ "$CONFIG_NAME" != "_2_sokoban" ] && [ "$CONFIG_NAME" != "_10_sudoku" ] && [ "$CONFIG_NAME" != "_3_frozen_lake" ]; then 51 | echo "CONFIG_NAME is not one of the following: _2_sokoban, _10_sudoku, _3_frozen_lake" 52 | exit 1 53 | fi 54 | 55 | # Derived paths 56 | OUTPUT_DIR="./sftdata/${CONFIG_NAME}-${MODEL}-${RENDER_MODE}" 57 | CHECKPOINT_DIR="./sftckpt/checkpoints${CONFIG_NAME}-${MODEL}-${RENDER_MODE}-qwen/" 58 | mkdir -p "$OUTPUT_DIR" 59 | mkdir -p "$CHECKPOINT_DIR" 60 | 61 | # Export training variables 62 | export MODE="$MODE" 63 | export MODEL="$MODEL" 64 | export PENALTY_VALUE=0.0 65 | export RENDER_MODE="$RENDER_MODE" 66 | export BT_NUM="$BSZ_NUM" 67 | export CONFIG_NAME="$CONFIG_NAME" 68 | export OUTPUT_DIR="$OUTPUT_DIR" 69 | export CKPT="$CKPT" 70 | export GENERATE_DATA="$GENERATE_DATA" 71 | 72 | # ============================================================================= 73 | # Validation and Setup 74 | # ============================================================================= 75 | 76 | 77 | # ============================================================================= 78 | # Training Pipeline 79 | # ============================================================================= 80 | 81 | echo "Starting training pipeline..." 82 | 83 | if [ "$GENERATE_DATA" == "True" ]; then 84 | # Step 1: Generate SFT data (commented out - uncomment if needed) 85 | echo "Step 1: Generating SFT data..." 86 | python -m SPA_agent.generate_sft_data --config-name "$CONFIG_NAME" 87 | # exit 0 # exit if you need to filter 88 | 89 | # Step 2: Fine-tuning (commented out - uncomment if needed) 90 | echo "Step 2: Fine-tuning..." 91 | bash sft/finetune_ft.sh "$CONFIG_NAME" 4 "$CHECKPOINT_DIR" "$OUTPUT_DIR" "$MODEL" 92 | fi 93 | 94 | # Validate checkpoint directory exists 95 | if [[ ! -d "$CHECKPOINT_DIR" ]]; then 96 | echo "Error: Checkpoint directory '$CHECKPOINT_DIR' does not exist!" 97 | echo "Please ensure the SFT checkpoint has been created first." 
98 | exit 1 99 | fi 100 | 101 | # Create output directory if it doesn't exist 102 | mkdir -p "$OUTPUT_DIR" 103 | 104 | # ============================================================================= 105 | # Checkpoint Discovery 106 | # ============================================================================= 107 | 108 | echo "Searching for latest checkpoint in: $CHECKPOINT_DIR" 109 | 110 | # Find the latest checkpoint 111 | LATEST_CKPT=$(ls -d "${CHECKPOINT_DIR%/}"/*/ 2>/dev/null | sort -V | tail -n 1) 112 | if [[ -z "$LATEST_CKPT" ]]; then 113 | echo "Error: No checkpoints found in '$CHECKPOINT_DIR'" 114 | exit 1 115 | fi 116 | 117 | LATEST_CKPT=${LATEST_CKPT%/} 118 | echo "Latest checkpoint found: $LATEST_CKPT" 119 | 120 | if [ -z "$CKPT" ] || [ "$CKPT" == "last" ]; then 121 | CKPT=$LATEST_CKPT 122 | else 123 | CKPT=$CHECKPOINT_DIR/global_step_$CKPT 124 | fi 125 | 126 | 127 | # Validate checkpoint directory 128 | if [[ ! -d "$CKPT" ]]; then 129 | echo "Error: Latest checkpoint directory '$CKPT' is not accessible" 130 | exit 1 131 | fi 132 | 133 | 134 | 135 | 136 | # Step 3: PPO Training 137 | echo "Step 3: Starting PPO training..." 138 | EXPERIMENT_NAME="${CONFIG_NAME}-${MODEL}-RENDER_MODE${RENDER_MODE}-spa-${CKPT}" 139 | 140 | echo "Training configuration:" 141 | echo " - Config: $CONFIG_NAME" 142 | echo " - Checkpoint: $CKPT" 143 | echo " - Experiment: $EXPERIMENT_NAME" 144 | echo " - Model: $MODEL" 145 | echo " - Render Mode: $RENDER_MODE" 146 | 147 | # Launch training 148 | bash ./train_ppo_sfted.sh "$CONFIG_NAME" "$CKPT" "$EXPERIMENT_NAME" 149 | 150 | echo "Training pipeline completed successfully!" 151 | 152 | 153 | -------------------------------------------------------------------------------- /config/base-lora.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ppo_trainer # this is a symbolic link to the verl/verl/trainer/config/ppo_trainer.yaml file 3 | - envs 4 | 5 | system: 6 | CUDA_VISIBLE_DEVICES: "0" 7 | 8 | seed: 9 | train: 10000 10 | val: 123 11 | 12 | micro_batch_size_per_gpu: 1 13 | ppo_mini_batch_size: 32 14 | model_path: Qwen/Qwen2.5-0.5B-Instruct 15 | enable_response_mask: True # Enabling response mask could improve stability of rollout/old_log_prob, as P(st|history) are no longer calculated in loss here. See https://docs.google.com/document/d/1bg7obeiKTExuHHBl5uOiSpec5uLDZ2Tgvxy6li5pHX4/edit?usp=sharing for more details. 
16 | grpo_advantage_length_weight: False # if you do not enable this and critic/advantage_estimator is GRPO, and the critic/advantages/mean is too low, then you can try enabling this to encourage reasoning and forbid collapse 17 | 18 | lora: 19 | rank: 64 20 | alpha: 64 21 | target_modules: all-linear 22 | 23 | actor_rollout_ref: 24 | model: 25 | path: ${model_path} 26 | lora_rank: ${lora.rank} 27 | lora_alpha: ${lora.alpha} 28 | target_modules: ${lora.target_modules} 29 | actor: 30 | ppo_mini_batch_size: ${ppo_mini_batch_size} # by default, ppo_mini_batch_size = train_batch_size / 4 31 | micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 32 | ppo_micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 33 | use_ref: True 34 | entropy_coeff: 0.001 35 | use_kl_loss: False 36 | kl_loss_coef: 0.000 37 | kl_loss_type: kl 38 | clip_ratio_low: 0.2 39 | clip_ratio_high: 0.28 40 | grpo_advantage_length_weight: ${grpo_advantage_length_weight} 41 | optim: 42 | betas: [0.9, 0.999] 43 | lr: 1e-5 44 | ref: 45 | log_prob_micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 46 | rollout: 47 | log_prob_micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 48 | tensor_model_parallel_size: 1 49 | max_model_len: 3600 50 | prompt_length: 1 # useless. Just put it here 51 | response_length: 400 # single-turn response length 52 | gpu_memory_utilization: 0.5 53 | max_num_batched_tokens: 8192 # set only when enable_chunked_prefill is true 54 | temperature: 1 55 | rollout_filter_ratio: 0.25 56 | rollout_filter_type: std # max_mean or std 57 | enforce_eager: True # for small models, set both enforce_eager and free_cache_engine to False to make rollout faster 58 | free_cache_engine: True 59 | val_kwargs: 60 | do_sample: True 61 | temperature: 0.5 62 | tp_size_check: true 63 | 64 | critic: 65 | ppo_mini_batch_size: ${ppo_mini_batch_size} # by default, ppo_mini_batch_size = train_batch_size / 4 66 | ppo_micro_batch_size_per_gpu: ${micro_batch_size_per_gpu} # following micro_batch_size_per_gpu 67 | model: 68 | path: ${model_path} 69 | lora_rank: ${lora.rank} 70 | lora_alpha: ${lora.alpha} 71 | target_modules: ${lora.target_modules} 72 | optim: 73 | betas: [0.9, 0.999] 74 | lr: 1e-4 75 | 76 | data: 77 | max_prompt_length: null 78 | max_response_length: null 79 | train_batch_size: null 80 | 81 | algorithm: 82 | gamma: 1.0 83 | lam: 1.0 84 | high_level_gamma: 0.95 85 | adv_estimator: gae 86 | bi_level_gae: False 87 | kl_penalty: kl # how to estimate kl divergence 88 | kl_ctrl: 89 | type: fixed 90 | kl_coef: 0.000 91 | 92 | trainer: 93 | project_name: ragen_latest 94 | experiment_name: test 95 | total_training_steps: 200 96 | validation_steps: 1 # validation instances = validation_steps * val_env_groups * group_size 97 | val_before_train: True 98 | n_gpus_per_node: 1 99 | test_freq: 10 100 | generations_to_log_to_wandb: 101 | train: 128 # TODO: will be implemented 102 | val: 20 103 | logger: [ 'console', 'wandb' ] 104 | 105 | agent_proxy: 106 | max_turn: 5 107 | action_sep: "||" 108 | max_actions_per_turn: 5 # how many actions can be output at most in a single turn 109 | use_turn_scores: False # important to GAE when applying token-level rewards to token-level advantages. If False, will take the sum of scores as the reward for the last turn. 
110 | enable_think: True # False -> no think RL 111 | reward_normalization: 112 | grouping: "state" # state / batch / inductive 113 | method: "identity" # asym_clip / identity / mean_std 114 | 115 | es_manager: 116 | format_penalty: -0.1 117 | train: 118 | env_groups: 8 119 | # under the same group, the env config and env seed are ensured to be equal 120 | group_size: 16 121 | env_configs: 122 | tags: ["SimpleSokoban"] 123 | n_groups: [8] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation 124 | val: 125 | env_groups: 256 126 | group_size: 1 # should be set to 1 because when val temperature is set to 0 and group size > 1, there will be repetitive prompts which leads to same trajectory. 127 | env_configs: 128 | tags: ["SimpleSokoban"] 129 | n_groups: [256] # TODO: If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation 130 | 131 | ctx_manager: 132 | generation: # go to vllm 133 | gen_config: 134 | response_length: ${actor_rollout_ref.rollout.response_length} 135 | temperature: ${actor_rollout_ref.rollout.temperature} 136 | top_p: ${actor_rollout_ref.rollout.top_p} 137 | top_k: ${actor_rollout_ref.rollout.top_k} 138 | kwargs: null 139 | -------------------------------------------------------------------------------- /sft/filter_sft_by_tag.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Filter SFT parquet splits by minimal structural rules: 4 | 5 | Keep only rows whose response follows this structure (case-insensitive tags): 6 | - Contains ' ... ' 7 | - Inside think: at least one ' ... ' and at least one 8 | ' ... ', with the first observation before the first prediction 9 | 10 | Optional length filter and de-duplication are still supported. 11 | 12 | Usage: 13 | python filter_sft_by_tags.py \ 14 | --in-train /projects/b1222/shiqi/Ragen-dev-test/ragen/sft/data/sudoku_coords_0920_1756/wm_train.parquet \ 15 | --in-val /projects/b1222/shiqi/Ragen-dev-test/ragen/sft/data/sudoku_coords_0920_1756/wm_val.parquet \ 16 | --out-dir /projects/b1222/shiqi/Ragen-dev-test/ragen/sft/data/sudoku_filtered \ 17 | --min-len 1 --max-len 100000 --no-dedup # all optional 18 | Requires: pandas, pyarrow 19 | """ 20 | 21 | import argparse 22 | import os 23 | import re 24 | from typing import Optional, Tuple 25 | 26 | import pandas as pd 27 | 28 | 29 | def load_parquet(path: str) -> pd.DataFrame: 30 | if not os.path.isfile(path): 31 | raise FileNotFoundError(f"File not found: {path}") 32 | try: 33 | return pd.read_parquet(path) 34 | except Exception as e: 35 | raise RuntimeError(f"Failed to read parquet {path}: {e}") 36 | 37 | 38 | _DIR_PATTERN = re.compile(r"^\s*(?:Up|Down|Left|Right)(?:\s*\|\|\s*(?:Up|Down|Left|Right)){0,19}\s*$", 39 | flags=re.IGNORECASE) 40 | 41 | 42 | def _extract_tag(text: str, tag: str) -> Tuple[Optional[re.Match], Optional[str]]: 43 | m = re.search(fr"<{tag}>(.*?)", text, flags=re.IGNORECASE | re.DOTALL) 44 | return m, (m.group(1) if m else None) 45 | 46 | 47 | def is_good_response(resp: str) -> bool: 48 | """Apply basic structural checks to a response string.""" 49 | if not isinstance(resp, str) or resp.strip() == "": 50 | return False 51 | 52 | # Must have complete structure: ...... 
53 | # And think must appear before answer 54 | m_think, think_body = _extract_tag(resp, "think") 55 | m_answer, answer_body = _extract_tag(resp, "answer") 56 | 57 | if not m_think or not m_answer: 58 | return False 59 | 60 | # Check that think appears before answer 61 | if m_think.start() >= m_answer.start(): 62 | return False 63 | 64 | # Inside think: need at least one observation and one prediction 65 | # And first observation must appear before first prediction 66 | if not think_body: 67 | return False 68 | 69 | obs_iter = list(re.finditer(r"(.*?)", think_body, flags=re.IGNORECASE | re.DOTALL)) 70 | pred_iter = list(re.finditer(r"(.*?)", think_body, flags=re.IGNORECASE | re.DOTALL)) 71 | 72 | # Require both observation and prediction, and observation must appear first 73 | if len(obs_iter) == 0 or len(pred_iter) == 0: 74 | return False 75 | if obs_iter[0].start() > pred_iter[0].start(): 76 | return False 77 | 78 | return True 79 | 80 | 81 | def filter_df(df: pd.DataFrame, min_len: int, max_len: int, dedup: bool) -> pd.DataFrame: 82 | # Basic column presence 83 | needed = ["prompt", "response"] 84 | for col in needed: 85 | if col not in df.columns: 86 | raise KeyError(f"Input dataframe missing required column: {col}") 87 | 88 | # Drop NA 89 | df = df.dropna(subset=["prompt", "response"]).copy() 90 | 91 | # Ensure strings 92 | df["response"] = df["response"].astype(str) 93 | 94 | # Strict structural filter 95 | mask_struct = df["response"].apply(is_good_response) 96 | df = df[mask_struct] 97 | 98 | # Length filter 99 | if min_len is not None or max_len is not None: 100 | min_len = 0 if min_len is None else min_len 101 | max_len = 10**12 if max_len is None else max_len 102 | lens = df["response"].str.len() 103 | df = df[(lens >= min_len) & (lens <= max_len)] 104 | 105 | # Dedup on (prompt, response) 106 | if dedup: 107 | df = df.drop_duplicates(subset=["prompt", "response"]) 108 | 109 | return df.reset_index(drop=True) 110 | 111 | 112 | def main(): 113 | ap = argparse.ArgumentParser() 114 | ap.add_argument("--in-train", required=True, help="Input train parquet path") 115 | ap.add_argument("--in-val", required=True, help="Input val parquet path") 116 | ap.add_argument("--out-dir", required=True, help="Output directory for filtered parquet files") 117 | ap.add_argument("--min-len", type=int, default=None, help="Min response length filter (chars)") 118 | ap.add_argument("--max-len", type=int, default=None, help="Max response length filter (chars)") 119 | ap.add_argument("--no-dedup", action="store_true", help="Disable deduplication on (prompt, response)") 120 | args = ap.parse_args() 121 | 122 | os.makedirs(args.out_dir, exist_ok=True) 123 | 124 | print(f"Loading train: {args.in_train}") 125 | train = load_parquet(args.in_train) 126 | print(f"Loading val : {args.in_val}") 127 | val = load_parquet(args.in_val) 128 | 129 | print(f"Train size before: {len(train)}; Val size before: {len(val)}") 130 | train_f = filter_df(train, args.min_len, args.max_len, dedup=(not args.no_dedup)) 131 | val_f = filter_df(val, args.min_len, args.max_len, dedup=(not args.no_dedup)) 132 | 133 | # # only keep 1/7 of the train and valid 134 | # train_f = train_f.sample(frac=1/7) 135 | # val_f = val_f.sample(frac=1/7) 136 | 137 | print(f"Train size after : {len(train_f)}; Val size after : {len(val_f)}") 138 | 139 | out_train = os.path.join(args.out_dir, "wm_train.parquet") 140 | out_val = os.path.join(args.out_dir, "wm_val.parquet") 141 | 142 | train_f.to_parquet(out_train, index=False) 143 | val_f.to_parquet(out_val, 
index=False) 144 | 145 | print(f"Saved filtered train: {out_train}") 146 | print(f"Saved filtered val : {out_val}") 147 | 148 | 149 | if __name__ == "__main__": 150 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPA: Self-Play with World Model for LLM Agents 2 | 3 |

4 | 5 | 6 | 7 | 8 |

9 | 10 | SPA (Self-Play Agent) is a reinforcement learning recipe for training Large Language Model (LLM) agents in **out-of-distribution (OOD) environments**. By equipping agents with an **internal world model** through self-play supervised finetuning (SFT), SPA enables better grounding, broader exploration, and more reliable generalization. 11 | 12 | --- 13 | 14 | ## Overview 15 | 16 | LLM agents often struggle when deployed in environments that differ from their pre-training distribution. Standard reinforcement learning tends to overfit to narrow solution paths, improving **Pass@1** slightly but causing **Pass@k** to degrade. This reflects brittle exploration and weak generalization. 17 | 18 | SPA addresses this by introducing a **world model** with two key components: 19 | 20 | * **State Representation**: structured abstractions (e.g., symbolic coordinates in Sokoban) that lower perplexity and make spatial relations explicit. 21 | * **Transition Modeling**: predicting next states during self-play, enabling the agent to internalize environment dynamics before policy optimization. 22 | 23 | This initialization makes subsequent PPO training more stable and effective. 24 | 25 | --- 26 | 27 | ## Key Results 28 | 29 | SPA significantly improves performance across challenging environments: 30 | 31 | * **Sokoban**: Pass@1 success rate from **25.6% → 59.8%** 32 | * **FrozenLake**: Pass@1 success rate from **22.1% → 70.9%** 33 | * **Sudoku**: Pass@1 success rate from **0.0% → 59.6%** 34 | 35 | These improvements are consistent across different LLM families, including **Qwen** and **LLaMA** models. 36 | 37 | --- 38 | 39 | ## Framework 40 | 41 | SPA training consists of three stages: 42 | 43 | 1. **Data Generation**: Collect self-play trajectories with `` and `` states. 44 | 2. **Supervised Finetuning (SFT)**: Train the agent to predict next states and actions. 45 | 3. **PPO Optimization**: Reinforce policies initialized with the learned world model. 46 | 47 | This exploration-before-exploitation process allows agents to first **learn environment rules**, then optimize for rewards. 48 | 49 | --- 50 | 51 | ## Repository Setup 52 | 53 | Clone **RAGEN** and place SPA inside: 54 | 55 | ```bash 56 | git clone git@github.com:RAGEN-AI/RAGEN.git 57 | cd RAGEN 58 | git clone git@github.com:shiqichen17/SPA.git 59 | ``` 60 | 61 | --- 62 | 63 | ## Environment Setup 64 | 65 | From the RAGEN root directory: 66 | 67 | ```bash 68 | bash scripts/setup_ragen.sh 69 | pip uninstall -y torch torchvision torchaudio && pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 70 | pip uninstall -y vllm flash-attn flash_attn 71 | pip install vllm==0.8.5.post1 72 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl 73 | python -c "import torch; import flash_attn; import vllm; print('✅ All modules loaded successfully.')" 74 | ``` 75 | 76 | > **Note**: Use the versions above exactly to avoid runtime errors. 
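The pipeline below exchanges its intermediate self-play data as parquet files of `prompt`/`response` pairs (the column names and the `wm_train.parquet`/`wm_val.parquet` file names come from `sft/finetune_ft.sh`). As a rough reference, a single world-model SFT record is assumed to look like the sketch below; the prompt and response text are purely illustrative, and only the column layout and file name match the training scripts.

```python
import pandas as pd  # pyarrow is also needed for parquet I/O

# Hypothetical world-model SFT record; real records are produced by
# SPA_agent.generate_sft_data, so the text here is illustrative only.
record = {
    "prompt": "Sokoban state:\n######\n#___O#\n#__X_#\n###P_#\n###__#\n######\nChoose the next action.",
    "response": (
        "<think><observation>Player (P) is below the box (X); the goal (O) is above.</observation>"
        "<prediction>Pushing up moves the box one cell toward the goal and the player into the box's old cell.</prediction></think>"
        "<answer>Up</answer>"
    ),
}

# File name expected by sft/finetune_ft.sh (data.train_files=$data_path/wm_train.parquet).
pd.DataFrame([record]).to_parquet("wm_train.parquet", index=False)
```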
77 | 78 | --- 79 | 80 | ## Quick Start 81 | 82 | From the SPA directory: 83 | 84 | ```bash 85 | cd SPA 86 | bash run_spa.sh <CONFIG_NAME> [CKPT] [GENERATE_DATA] 87 | ``` 88 | 89 | **Arguments:** 90 | 91 | * `CONFIG_NAME` (required): Environment config - `_2_sokoban`, `_10_sudoku`, or `_3_frozen_lake` 92 | * `CKPT` (optional, default: `last`): Checkpoint to use (`last` for latest, or a step number like `1000`) 93 | * `GENERATE_DATA` (optional, default: `False`): Set to `True` to run the full pipeline, `False` for PPO only 94 | 95 | **Examples:** 96 | 97 | ```bash 98 | # Full pipeline (generate data → SFT → PPO) 99 | bash run_spa.sh _2_sokoban last True 100 | 101 | # PPO training only with existing checkpoint 102 | bash run_spa.sh _2_sokoban last False 103 | 104 | # Use specific checkpoint step 105 | bash run_spa.sh _10_sudoku 2000 False 106 | ``` 107 | 108 | This script runs the **full pipeline** (when `GENERATE_DATA=True`): 109 | 110 | * Generate self-play training data 111 | * Perform SFT world-model training 112 | * Run PPO policy optimization 113 | 114 | --- 115 | 116 | ## Pretrained Models and Datasets 117 | 118 | We provide pretrained models and training datasets for all three environments on Hugging Face: 119 | 120 | | Environment | 📊 SFT Training Data | 🤖 Model (after self-play finetuning) | 121 | |------------|------------|------------| 122 | | **Sokoban** | [SPA-sokoban-data](https://huggingface.co/datasets/tyzhu/SPA-sokoban-data) | [SPA-sokoban-qwen2.5-1.5b-instruct](https://huggingface.co/tyzhu/SPA-sokoban-qwen2.5-1.5b-instruct) | 123 | | **FrozenLake** | [SPA-frozenlake-data](https://huggingface.co/datasets/tyzhu/SPA-frozenlake-data) | [SPA-frozenlake-qwen2.5-1.5b-instruct](https://huggingface.co/tyzhu/SPA-frozenlake-qwen2.5-1.5b-instruct) | 124 | | **Sudoku** | [SPA-sudoku-data](https://huggingface.co/datasets/tyzhu/SPA-sudoku-data) | [SPA-sudoku-qwen2.5-1.5b-instruct](https://huggingface.co/tyzhu/SPA-sudoku-qwen2.5-1.5b-instruct) | 125 | 126 | These resources allow you to: 127 | - **Use the pretrained models** directly for inference or further finetuning 128 | - **Reproduce the SFT stage** using the provided training data 129 | - **Skip data generation** and start from the SFT or PPO stages 130 | 131 | > **Note**: The FrozenLake and Sudoku datasets are filtered to remove trajectories that do not follow the required format, while the Sokoban dataset contains unfiltered raw trajectories from self-play data generation. 132 | 133 | --- 134 | 135 | ## Supported Environments 136 | 137 | SPA supports a variety of environments integrated through RAGEN: 138 | 139 | * **Sokoban** (grid-based spatial puzzles) 140 | * **FrozenLake** (navigation under stochastic transitions) 141 | * **Sudoku** (4×4 logical puzzles) 142 | 143 | --- 144 | 145 | ## Example World Model Trace 146 | 147 | For Sokoban, SPA generates structured reasoning traces: 148 | 149 | ``` 150 | <think> 151 | <observation> 152 | ###### 153 | #___O# 154 | #__X_# 155 | ###P_# 156 | ###__# 157 | ###### 158 | Player (P) at (3,3); box (X) at (2,3); goal at (1,4). 159 | </observation> 160 | <prediction> 161 | ###### 162 | #___O# 163 | #____# 164 | ###X_# 165 | ###P_# 166 | ###### 167 | </prediction> 168 | </think> 169 | <answer>Up</answer> 170 | 171 | ``` 172 | This explicit **observation → prediction → action** format grounds decision-making in environment dynamics.
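The structure shown above is what `sft/filter_sft_by_tag.py` checks when cleaning the SFT data: a `<think>` block must come before the `<answer>` block, and inside `<think>` an `<observation>` must appear before a `<prediction>`. A minimal sketch of that check (a simplified re-implementation, not the project's exact code):

```python
import re

FLAGS = re.IGNORECASE | re.DOTALL

def has_world_model_structure(trace: str) -> bool:
    """Check <think> (containing <observation> before <prediction>) followed by <answer>."""
    think = re.search(r"<think>(.*?)</think>", trace, FLAGS)
    answer = re.search(r"<answer>(.*?)</answer>", trace, FLAGS)
    if not think or not answer or think.start() >= answer.start():
        return False
    body = think.group(1)
    obs = re.search(r"<observation>(.*?)</observation>", body, FLAGS)
    pred = re.search(r"<prediction>(.*?)</prediction>", body, FLAGS)
    return bool(obs and pred) and obs.start() < pred.start()

trace = ("<think><observation>Player (P) at (3,3); box (X) at (2,3).</observation>"
         "<prediction>The box is pushed one cell toward the goal.</prediction></think>"
         "<answer>Up</answer>")
print(has_world_model_structure(trace))  # True
```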
173 | 174 | --- 175 | 176 | ## Configuration 177 | 178 | Key configuration files are located in `config/`: 179 | 180 | * `base.yaml`: core training settings 181 | * `_2_sokoban.yaml`, `_3_frozen_lake.yaml`, etc.: environment-specific configs 182 | * `envs.yaml`: environment registry 183 | 184 | Important parameters: 185 | 186 | * `model_path`: base model (e.g., `Qwen/Qwen2.5-1.5B-Instruct`) 187 | * `trainer.total_training_steps`: PPO steps 188 | * `agent_proxy.max_turn`: max turns per episode 189 | * `es_manager.train.env_groups`: number of environment groups 190 | 191 | --- 192 | 193 | ## Citation 194 | 195 | If you use SPA in your work, please cite: 196 | 197 | ```bibtex 198 | @misc{chen2025spa, 199 | title={Internalizing World Models via Self-Play Finetuning for Agentic RL}, 200 | author={Shiqi Chen and Tongyao Zhu and Zian Wang and Jinghan Zhang and Kangrui Wang and Siyang Gao and Teng Xiao and Yee Whye Teh and Junxian He and Manling Li}, 201 | year={2025}, 202 | eprint={2510.15047}, 203 | archivePrefix={arXiv}, 204 | primaryClass={cs.LG}, 205 | url={https://arxiv.org/abs/2510.15047}, 206 | } 207 | 208 | ``` 209 | 210 | --- 211 | 212 | ## License 213 | 214 | This project is licensed under the Apache 2.0 License. See the LICENSE file for details. 215 | 216 | --- 217 | 218 | ## Acknowledgments 219 | 220 | SPA is built on top of the [RAGEN](https://github.com/RAGEN-AI/RAGEN) framework, extending it with explicit world-model pretraining for improved RL scalability. 221 | -------------------------------------------------------------------------------- /config/ppo_trainer.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | tokenizer: null 3 | train_files: ~/data/rlhf/gsm8k/train.parquet 4 | val_files: ~/data/rlhf/gsm8k/test.parquet 5 | prompt_key: prompt 6 | reward_fn_key: data_source 7 | max_prompt_length: 512 8 | max_response_length: 512 9 | train_batch_size: 1024 10 | val_batch_size: null 11 | return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs 12 | return_raw_chat: False 13 | shuffle: True 14 | filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. 
15 | filter_overlong_prompts_workers: 1 16 | truncation: error 17 | image_key: images 18 | video_key: videos 19 | custom_cls: 20 | path: null 21 | name: null 22 | 23 | actor_rollout_ref: 24 | hybrid_engine: True 25 | model: 26 | path: ~/models/deepseek-llm-7b-chat 27 | external_lib: null 28 | override_config: { } 29 | enable_gradient_checkpointing: True 30 | use_remove_padding: False 31 | use_liger: False 32 | actor: 33 | strategy: fsdp # [fsdp, fsdp2], This is for backward-compatibility 34 | ppo_mini_batch_size: 256 35 | ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu 36 | ppo_micro_batch_size_per_gpu: null 37 | use_dynamic_bsz: False 38 | ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} 39 | grad_clip: 1.0 40 | # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) 41 | clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified 42 | clip_ratio_low: 0.2 43 | clip_ratio_high: 0.2 44 | clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 45 | loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" 46 | entropy_coeff: 0 47 | use_kl_loss: False # True for GRPO 48 | use_torch_compile: True # False to disable torch compile 49 | kl_loss_coef: 0.001 # for grpo 50 | kl_loss_type: low_var_kl # for grpo 51 | ppo_epochs: 1 52 | shuffle: False 53 | ulysses_sequence_parallel_size: 1 # sp size 54 | checkpoint: 55 | contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space 56 | optim: 57 | lr: 1e-6 58 | lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. 59 | lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime 60 | min_lr_ratio: null # only useful for warmup with cosine 61 | warmup_style: constant # select from constant/cosine 62 | total_training_steps: -1 # must be override by program 63 | weight_decay: 0.01 64 | fsdp_config: 65 | wrap_policy: 66 | # transformer_layer_cls_to_wrap: None 67 | min_num_params: 0 68 | param_offload: False 69 | optimizer_offload: False 70 | offload_policy: False # only for fsdp2, offload param\grad\optimizer during train 71 | reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] 72 | fsdp_size: -1 73 | ref: 74 | strategy: fsdp 75 | fsdp_config: 76 | param_offload: False 77 | reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] 78 | wrap_policy: 79 | # transformer_layer_cls_to_wrap: None 80 | min_num_params: 0 81 | use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} 82 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu 83 | log_prob_micro_batch_size_per_gpu: null 84 | log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} 85 | log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} 86 | ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size 87 | rollout: 88 | name: vllm 89 | mode: sync # sync: LLM, async: AsyncLLM 90 | chat_scheduler: null # async chat scheduler, e.g examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler 91 | temperature: 1.0 92 | top_k: -1 # 0 for hf rollout, -1 for vllm rollout 93 | top_p: 1 94 | use_fire_sampling: False # https://arxiv.org/abs/2410.21236 95 | prompt_length: ${data.max_prompt_length} # not use for opensource 96 | response_length: ${data.max_response_length} 97 | # for vllm rollout 98 | dtype: bfloat16 # should align with FSDP 99 | gpu_memory_utilization: 0.5 100 | ignore_eos: False 101 | enforce_eager: True 102 | free_cache_engine: True 103 | load_format: dummy_dtensor 104 | tensor_model_parallel_size: 2 105 | max_num_batched_tokens: 8192 106 | max_model_len: null 107 | max_num_seqs: 1024 108 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu 109 | log_prob_micro_batch_size_per_gpu: null 110 | log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} 111 | log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} 112 | disable_log_stats: True 113 | enable_chunked_prefill: True # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. 114 | # for hf rollout 115 | do_sample: True 116 | # number of responses (i.e. 
num sample times) 117 | n: 1 # > 1 for grpo 118 | engine_kwargs: # inference engine parameters 119 | swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB 120 | val_kwargs: 121 | # sampling parameters for validation 122 | top_k: -1 # 0 for hf rollout, -1 for vllm rollout 123 | top_p: 1.0 124 | temperature: 0 125 | n: 1 126 | do_sample: False # default eager for validation 127 | multi_turn: 128 | enable: False # should set rollout.name to sglang_async if True 129 | max_turns: null # null for no limit (default max_length // 3) 130 | tool_config_path: null # null for no tool 131 | format: chatml # chatml, more formats will be supported in the future 132 | 133 | critic: 134 | rollout_n: ${actor_rollout_ref.rollout.n} 135 | strategy: fsdp # [fsdp, fsdp2] 136 | optim: 137 | lr: 1e-5 138 | lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime 139 | min_lr_ratio: null # only useful for warmup with cosine 140 | warmup_style: constant # select from constant/cosine 141 | total_training_steps: -1 # must be override by program 142 | weight_decay: 0.01 143 | model: 144 | path: ~/models/deepseek-llm-7b-chat 145 | tokenizer_path: ${actor_rollout_ref.model.path} 146 | override_config: { } 147 | external_lib: ${actor_rollout_ref.model.external_lib} 148 | enable_gradient_checkpointing: True 149 | use_remove_padding: False 150 | fsdp_config: 151 | param_offload: False 152 | optimizer_offload: False 153 | offload_policy: False # only for fsdp2, offload param\grad\optimizer during train 154 | reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] 155 | wrap_policy: 156 | # transformer_layer_cls_to_wrap: None 157 | min_num_params: 0 158 | fsdp_size: -1 159 | ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} 160 | ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu 161 | ppo_micro_batch_size_per_gpu: null 162 | forward_micro_batch_size: ${critic.ppo_micro_batch_size} 163 | forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} 164 | use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} 165 | ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 166 | forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} 167 | ulysses_sequence_parallel_size: 1 # sp size 168 | ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} 169 | shuffle: ${actor_rollout_ref.actor.shuffle} 170 | grad_clip: 1.0 171 | cliprange_value: 0.5 172 | checkpoint: 173 | contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space 174 | 175 | reward_model: 176 | enable: False 177 | strategy: fsdp 178 | model: 179 | input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical 180 | path: ~/models/FsfairX-LLaMA3-RM-v0.1 181 | external_lib: ${actor_rollout_ref.model.external_lib} 182 | use_remove_padding: False 183 | fsdp_config: 184 | wrap_policy: 185 | min_num_params: 0 186 | param_offload: False 187 | reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] 188 | fsdp_size: -1 189 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu 190 | micro_batch_size_per_gpu: null # set a number 191 | max_length: null 192 | ulysses_sequence_parallel_size: 1 # sp size 193 | use_dynamic_bsz: ${critic.use_dynamic_bsz} 194 | forward_max_token_len_per_gpu: 
${critic.forward_max_token_len_per_gpu} 195 | reward_manager: naive 196 | launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob 197 | 198 | custom_reward_function: 199 | path: null 200 | name: compute_score 201 | 202 | algorithm: 203 | gamma: 1.0 204 | lam: 1.0 205 | adv_estimator: gae 206 | norm_adv_by_std_in_grpo: True 207 | use_kl_in_reward: False 208 | kl_penalty: kl # how to estimate kl divergence 209 | kl_ctrl: 210 | type: fixed 211 | kl_coef: 0.001 212 | horizon: 10000 213 | target_kl: 0.1 214 | 215 | trainer: 216 | balance_batch: True 217 | total_epochs: 30 218 | total_training_steps: null 219 | project_name: verl_examples 220 | experiment_name: gsm8k 221 | logger: [ 'console', 'wandb' ] 222 | log_val_generations: 0 223 | rollout_data_dir: null # directory for logging the rollout data, no dump if null 224 | validation_data_dir: null # directory for logging the validation data, no dump if null 225 | nnodes: 1 226 | n_gpus_per_node: 8 227 | save_freq: -1 228 | # auto: find the last ckpt to resume. If can't find, start from scratch 229 | resume_mode: auto # or disable or resume_path if resume_from_path is set 230 | resume_from_path: null 231 | val_before_train: True 232 | test_freq: -1 233 | critic_warmup: 0 234 | default_hdfs_dir: null 235 | del_local_ckpt_after_load: False 236 | default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} 237 | max_actor_ckpt_to_keep: 1 238 | max_critic_ckpt_to_keep: 1 239 | # The timeout for ray worker group to wait for the register center to be ready 240 | ray_wait_register_center_timeout: 300 241 | 242 | ray_init: 243 | num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. -------------------------------------------------------------------------------- /sft/spa_sft_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | SFT dataset 16 | - We assume user pass a single parquet file. 17 | - We load all the data into the memory. 18 | Each parquet file contains 19 | """ 20 | import pdb 21 | 22 | from typing import List, Union 23 | 24 | import pandas as pd 25 | 26 | import torch 27 | from torch.utils.data import Dataset 28 | from transformers import AutoTokenizer, PreTrainedTokenizer 29 | 30 | from verl.utils.fs import copy_local_path_from_hdfs 31 | from verl.utils.model import compute_position_id_with_mask 32 | from verl.utils import hf_tokenizer 33 | 34 | import numpy as np 35 | 36 | from typing import Union 37 | import numpy as np 38 | import copy 39 | 40 | def apply_chat_template( 41 | tokenizer, 42 | messages: Union[np.ndarray, str], 43 | response: str, 44 | with_thinking=True, 45 | **kwargs 46 | ): 47 | """ 48 | Apply a chat template to given messages. 
49 | ================================ 50 | Args: 51 | - tokenizer: The tokenizer to use. 52 | - messages (Union[np.ndarray, str]): The messages to apply the template to. 53 | - str: The messages are already formatted. 54 | - np.ndarray: The messages are not formatted. 55 | - with_thinking: Whether the assistant needs to output think tags. 56 | - e.g., <|im_start|>assistant\n --> <|im_start|>assistant\n, assistant needs to output tags. 57 | - assistant message is added at the beginning 58 | - kwargs: Additional keyword arguments to pass to the tokenizer.apply_chat_template method. 59 | ================================ 60 | Returns: 61 | The formatted messages (str). 62 | """ 63 | if isinstance(messages, str): 64 | return messages 65 | # pdb.set_trace() 66 | messages=np.array(messages) 67 | # print(len(messages)) 68 | # pdb.set_trace() 69 | assert isinstance(messages, np.ndarray), "The messages must be a numpy array." 70 | messages = copy.deepcopy(messages.tolist()) 71 | assert messages[-1]['role'] == 'user', "The last message must be a user message." 72 | if not with_thinking: 73 | prompt_chat_str = tokenizer.apply_chat_template(messages, **kwargs) 74 | if with_thinking: 75 | for msg in messages: 76 | if msg['role'] == 'assistant': 77 | msg['content'] = f"{msg['content']}" 78 | prompt_chat_str = tokenizer.apply_chat_template(messages, **kwargs) 79 | prompt_chat_str = f"{prompt_chat_str}" 80 | response_chat_str = f"{response}" 81 | return prompt_chat_str, response_chat_str 82 | class SFTDataset(Dataset): 83 | """ 84 | This is an in-memory SFTDataset 85 | """ 86 | 87 | def __init__(self, 88 | parquet_files: Union[str, List[str]], 89 | tokenizer, 90 | prompt_key='prompt', 91 | prompt_dict_keys=None, 92 | response_key='response', 93 | response_dict_keys=None, 94 | max_length=1024, 95 | truncation='left'): 96 | assert truncation in ['error', 'left', 'right'] 97 | self.truncation = truncation 98 | 99 | if not isinstance(parquet_files, List): 100 | parquet_files = [parquet_files] 101 | 102 | self.parquet_files = parquet_files 103 | if isinstance(tokenizer, str): 104 | tokenizer = hf_tokenizer(tokenizer) 105 | self.tokenizer: PreTrainedTokenizer = tokenizer 106 | 107 | ########################## Original Codes ########################## 108 | # self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key] 109 | # self.response_key = response_key if isinstance(response_key, (tuple, list)) else [response_key] 110 | ########################## Original Codes ########################## 111 | assert isinstance(prompt_key, str) 112 | assert isinstance(response_key, str) 113 | self.prompt_key = prompt_key 114 | self.response_key = response_key 115 | 116 | self.prompt_dict_keys = [] if not prompt_dict_keys else prompt_dict_keys 117 | self.response_dict_keys = [] if not response_dict_keys else response_dict_keys 118 | 119 | self.max_length = max_length 120 | 121 | self._download() 122 | self._read_files_and_tokenize() 123 | 124 | def _download(self): 125 | for i, parquet_file in enumerate(self.parquet_files): 126 | self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) 127 | 128 | def _read_files_and_tokenize(self): 129 | 130 | def series_to_item(ls): 131 | import pandas, numpy 132 | while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: 133 | ls = ls[0] 134 | return ls 135 | 136 | dataframes = [] 137 | for parquet_file in self.parquet_files: 138 | # read parquet files and cache 139 | dataframe = pd.read_parquet(parquet_file) 140 | 
dataframes.append(dataframe) 141 | self.dataframe = pd.concat(dataframes) 142 | self.prompts = self.dataframe[self.prompt_key] 143 | self.prompt_dict_keys=[] 144 | self.response_dict_keys=[] 145 | for key in self.prompt_dict_keys: 146 | # type(x): pandas.core.series.Series 147 | # type(x[0]): numpy.ndarray 148 | # type(x[0][0]): dict 149 | try: 150 | self.prompts = self.prompts.apply(lambda x: series_to_item(x)[key], axis=1) 151 | except Exception: 152 | print(f'self.prompts={self.prompts}') 153 | raise 154 | self.prompts = self.prompts.tolist() 155 | self.responses = self.dataframe[self.response_key] 156 | for key in self.response_dict_keys: 157 | try: 158 | self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1) 159 | except Exception: 160 | print(f'self.responses={self.responses}') 161 | raise 162 | self.responses = self.responses.tolist() 163 | 164 | def __len__(self): 165 | return len(self.prompts) 166 | 167 | def __getitem__(self, item): 168 | tokenizer = self.tokenizer 169 | 170 | prompt = self.prompts[item] 171 | response = self.responses[item] 172 | 173 | # Preprocess text to replace observation and prediction tags with special tokens 174 | def preprocess_text_for_tokenization(text): 175 | """Preprocess text to replace observation and prediction tags with special tokens before tokenization.""" 176 | import re 177 | 178 | # Replace observation tags and their content with special token 179 | text = re.sub(r'.*?', '', text, flags=re.DOTALL) 180 | 181 | # Replace prediction tags and their content with special token 182 | text = re.sub(r'.*?', '', text, flags=re.DOTALL) 183 | 184 | return text 185 | 186 | # Apply preprocessing to prompt and response 187 | if isinstance(prompt, str): 188 | prompt = preprocess_text_for_tokenization(prompt) 189 | elif isinstance(prompt, np.ndarray): 190 | # If prompt is an array of messages, process each message 191 | prompt = copy.deepcopy(prompt.tolist()) 192 | for msg in prompt: 193 | if isinstance(msg, dict) and 'content' in msg: 194 | msg['content'] = preprocess_text_for_tokenization(msg['content']) 195 | 196 | response = preprocess_text_for_tokenization(response) 197 | 198 | ########################## Original Codes ########################## 199 | # # apply chat template 200 | # prompt_chat = [{'role': 'user', 'content': prompt}] 201 | 202 | # # string 203 | # prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False) 204 | # response_chat_str = response + tokenizer.eos_token 205 | ########################## Original Codes ########################## 206 | prompt_chat_str, response_chat_str = apply_chat_template(tokenizer, prompt, response, with_thinking=True, add_generation_prompt=True, tokenize=False) 207 | response_chat_str = response_chat_str + tokenizer.eos_token 208 | 209 | # tokenize 210 | prompt_ids_output = tokenizer(prompt_chat_str, return_tensors='pt', add_special_tokens=False) 211 | prompt_ids = prompt_ids_output['input_ids'][0] 212 | prompt_attention_mask = prompt_ids_output['attention_mask'][0] 213 | 214 | response_ids_output = tokenizer(response_chat_str, return_tensors='pt', add_special_tokens=False) 215 | response_ids = response_ids_output['input_ids'][0] 216 | response_attention_mask = response_ids_output['attention_mask'][0] 217 | 218 | prompt_length = prompt_ids.shape[0] 219 | response_length = response_ids.shape[0] 220 | 221 | input_ids = torch.cat((prompt_ids, response_ids), dim=-1) 222 | attention_mask = torch.cat((prompt_attention_mask, 
response_attention_mask), dim=-1) 223 | 224 | # padding to max length 225 | sequence_length = input_ids.shape[0] 226 | if sequence_length < self.max_length: 227 | padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), 228 | dtype=input_ids.dtype) * self.tokenizer.pad_token_id 229 | padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) 230 | 231 | input_ids = torch.cat((input_ids, padded_input_ids)) 232 | attention_mask = torch.cat((attention_mask, padded_attention_mask)) 233 | elif sequence_length > self.max_length: 234 | if self.truncation == 'left': 235 | # actually, left truncation may not be reasonable 236 | input_ids = input_ids[-self.max_length:] 237 | attention_mask = attention_mask[-self.max_length:] 238 | elif self.truncation == 'right': 239 | input_ids = input_ids[:self.max_length] 240 | attention_mask = attention_mask[:self.max_length] 241 | elif self.truncation == 'error': 242 | raise NotImplementedError(f'{sequence_length=} is larger than {self.max_length=}') 243 | else: 244 | raise NotImplementedError(f'Unknown truncation method {self.truncation}') 245 | 246 | position_ids = compute_position_id_with_mask(attention_mask) 247 | 248 | loss_mask = attention_mask.clone() 249 | if prompt_length > 1: 250 | # mask out prompt for SFT. 251 | loss_mask[:min(prompt_length, loss_mask.size(0)) - 1] = 0 252 | # mask out the last token in response 253 | loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0 254 | 255 | return { 256 | 'input_ids': input_ids, 257 | 'attention_mask': attention_mask, 258 | 'position_ids': position_ids, 259 | 'loss_mask': loss_mask 260 | } -------------------------------------------------------------------------------- /SPA_agent/base_llm.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from abc import ABC, abstractmethod 3 | from typing import List, Dict, Optional, Union, Any, Tuple 4 | import os 5 | import asyncio 6 | import time 7 | 8 | from anthropic import AsyncAnthropic 9 | from openai import AsyncOpenAI 10 | from together import AsyncTogether 11 | 12 | @dataclass 13 | class LLMResponse: 14 | """Unified response format across all LLM providers""" 15 | content: str 16 | model_name: str 17 | 18 | class LLMProvider(ABC): 19 | """Abstract base class for LLM providers""" 20 | 21 | @abstractmethod 22 | async def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse: 23 | """Generate a response from the LLM""" 24 | pass 25 | 26 | class OpenAIProvider(LLMProvider): 27 | """OpenAI API provider implementation""" 28 | 29 | def __init__(self, model_name: str = "gpt-4o", api_key: Optional[str] = None): 30 | self.model_name = model_name 31 | self.api_key = api_key or os.environ.get("OPENAI_API_KEY") 32 | if not self.api_key: 33 | raise ValueError("OpenAI API key not provided and not found in environment variables") 34 | 35 | self.client = AsyncOpenAI(api_key=self.api_key) 36 | 37 | async def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse: 38 | if "o1-mini" in self.model_name: 39 | if messages[0]["role"] == "system": 40 | messages = messages[1:] 41 | 42 | response = await self.client.chat.completions.create( 43 | model=self.model_name, 44 | messages=messages, 45 | **kwargs 46 | ) 47 | if response.choices[0].finish_reason in ['length', 'content_filter']: 48 | raise ValueError("Content filtered or length exceeded") 49 | return LLMResponse( 50 | 
content=response.choices[0].message.content, 51 | model_name=response.model 52 | ) 53 | 54 | class DeepSeekProvider(LLMProvider): 55 | """DeepSeek API provider implementation""" 56 | 57 | def __init__(self, model_name: str = "deepseek-reasoner", api_key: Optional[str] = None): 58 | self.model_name = model_name 59 | self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY") 60 | if not self.api_key: 61 | raise ValueError("DeepSeek API key not provided and not found in environment variables") 62 | 63 | self.client = AsyncOpenAI(api_key=self.api_key, base_url="https://api.deepseek.com") 64 | 65 | async def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse: 66 | if "o1-mini" in self.model_name: 67 | if messages[0]["role"] == "system": 68 | messages = messages[1:] 69 | 70 | response = await self.client.chat.completions.create( 71 | model=self.model_name, 72 | messages=messages, 73 | **kwargs 74 | ) 75 | if response.choices[0].finish_reason in ['length', 'content_filter']: 76 | raise ValueError("Content filtered or length exceeded") 77 | return LLMResponse( 78 | content=response.choices[0].message.content, 79 | model_name=response.model 80 | ) 81 | 82 | class AnthropicProvider(LLMProvider): 83 | """Anthropic Claude API provider implementation 84 | Refer to https://github.com/anthropics/anthropic-sdk-python 85 | """ 86 | 87 | def __init__(self, model_name: str = "claude-3.5-sonnet-20240620", api_key: Optional[str] = None): 88 | self.model_name = model_name 89 | self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") 90 | if not self.api_key: 91 | raise ValueError("Anthropic API key not provided and not found in environment variables") 92 | 93 | self.client = AsyncAnthropic(api_key=self.api_key) 94 | 95 | async def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse: 96 | # Extract system message if present 97 | system_content = "" 98 | chat_messages = [] 99 | 100 | for msg in messages: 101 | if msg["role"] == "system": 102 | system_content = msg["content"] 103 | else: 104 | # Map to Anthropic's format 105 | chat_messages.append({ 106 | "role": "assistant" if msg["role"] == "assistant" else "user", 107 | "content": msg["content"] 108 | }) 109 | 110 | response = await self.client.messages.create( 111 | model=self.model_name, 112 | system=system_content, 113 | messages=chat_messages, 114 | **kwargs 115 | ) 116 | if response.stop_reason == "max_tokens": 117 | raise ValueError("Max tokens exceeded") 118 | return LLMResponse( 119 | content=response.content[0].text, 120 | model_name=response.model 121 | ) 122 | 123 | class TogetherProvider(LLMProvider): 124 | """Together AI API provider implementation""" 125 | 126 | def __init__(self, model_name: str = "meta-llama/Llama-3-70b-chat-hf", api_key: Optional[str] = None): 127 | self.model_name = model_name 128 | self.api_key = api_key or os.environ.get("TOGETHER_API_KEY") 129 | if not self.api_key: 130 | raise ValueError("Together API key not provided and not found in environment variables") 131 | 132 | self.client = AsyncTogether(api_key=self.api_key) 133 | 134 | async def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse: 135 | response = await self.client.chat.completions.create( 136 | model=self.model_name, 137 | messages=messages, 138 | **kwargs 139 | ) 140 | return LLMResponse( 141 | content=response.choices[0].message.content, 142 | model_name=response.model 143 | ) 144 | 145 | class ConcurrentLLM: 146 | """Unified concurrent interface for multiple LLM providers""" 147 | 148 | def 
__init__(self, provider: Union[str, LLMProvider], model_name: Optional[str] = None, 149 | api_key: Optional[str] = None, max_concurrency: int = 4): 150 | """ 151 | Initialize the concurrent LLM client. 152 | 153 | Args: 154 | provider: Either a provider instance or a string ('openai', 'anthropic', 'together') 155 | model_name: Model name (if provider is a string) 156 | api_key: API key (if provider is a string) 157 | max_concurrency: Maximum number of concurrent requests 158 | """ 159 | if isinstance(provider, LLMProvider): 160 | self.provider = provider 161 | else: 162 | if provider.lower() == "openai": 163 | self.provider = OpenAIProvider(model_name or "gpt-4o", api_key) 164 | elif provider.lower() == "deepseek": 165 | self.provider = DeepSeekProvider(model_name or "deepseek-reasoner", api_key) 166 | elif provider.lower() == "anthropic": 167 | self.provider = AnthropicProvider(model_name or "claude-3-7-sonnet-20250219", api_key) 168 | elif provider.lower() == "together": 169 | self.provider = TogetherProvider(model_name or "meta-llama/Llama-3-70b-chat-hf", api_key) 170 | else: 171 | raise ValueError(f"Unknown provider: {provider}") 172 | 173 | # Store max_concurrency but don't create the semaphore yet 174 | self.max_concurrency = max_concurrency 175 | self._semaphore = None 176 | 177 | @property 178 | def semaphore(self): 179 | """ 180 | Lazy initialization of the semaphore. 181 | This ensures the semaphore is created in the event loop where it's used. 182 | """ 183 | if self._semaphore is None: 184 | self._semaphore = asyncio.Semaphore(self.max_concurrency) 185 | return self._semaphore 186 | 187 | async def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse: 188 | """Generate a response with concurrency control""" 189 | async with self.semaphore: 190 | return await self.provider.generate(messages, **kwargs) 191 | 192 | def run_batch(self, 193 | messages_list: List[List[Dict[str, str]]], 194 | **kwargs) -> Tuple[List[Dict[str, Any]], List[List[Dict[str, str]]]]: 195 | """Process batches with retries in separate event loops, using id() to track messages""" 196 | 197 | results = [None] * len(messages_list) 198 | position_map = {id(messages): i for i, messages in enumerate(messages_list)} 199 | 200 | # Queue to store unfinished or failed tasks 201 | current_batch = messages_list.copy() 202 | max_retries = kwargs.get("max_retries", 100) 203 | retry_count = 0 204 | 205 | while current_batch and retry_count < max_retries: 206 | async def process_batch(): 207 | self._semaphore = None # Reset semaphore for this event loop 208 | batch_results = [] 209 | failures = [] 210 | 211 | tasks_with_messages = [(msg, asyncio.create_task(self.generate(msg, **kwargs))) 212 | for msg in current_batch] 213 | for messages, task in tasks_with_messages: 214 | try: 215 | response = await task 216 | position = position_map[id(messages)] 217 | batch_results.append((position, { 218 | "messages": messages, 219 | "response": response.content, 220 | "model": response.model_name, 221 | "success": True 222 | })) 223 | except Exception as e: 224 | print(f'[DEBUG] error: {e}') 225 | failures.append(messages) 226 | 227 | return batch_results, failures 228 | 229 | # Run in fresh event loop 230 | batch_results, next_batch = asyncio.run(process_batch()) 231 | 232 | # Update results with successful responses 233 | for position, result in batch_results: 234 | results[position] = result 235 | 236 | # Update for next iteration 237 | if next_batch: 238 | retry_count += 1 239 | # Update position map for failed 
messages 240 | position_map = {id(messages): position_map[id(messages)] 241 | for messages in next_batch} 242 | 243 | current_batch = next_batch 244 | time.sleep(5) 245 | print(f'[DEBUG] {len(next_batch)} failed messages, retry_count: {retry_count}') 246 | else: 247 | break 248 | 249 | return results, next_batch 250 | 251 | 252 | 253 | if __name__ == "__main__": 254 | # llm = ConcurrentLLM(provider="openai", model_name="gpt-4o") 255 | # llm = ConcurrentLLM(provider="anthropic", model_name="claude-3-5-sonnet-20240620") 256 | llm = ConcurrentLLM(provider="together", model_name="Qwen/Qwen2.5-7B-Instruct-Turbo") 257 | messages = [ 258 | [{"role": "user", "content": "what is 2+2?"}], 259 | [{"role": "user", "content": "what is 2+3?"}], 260 | [{"role": "user", "content": "what is 2+4?"}], 261 | [{"role": "user", "content": "what is 2+5?"}], 262 | [{"role": "user", "content": "what is 2+6?"}], 263 | [{"role": "user", "content": "what is 2+7?"}], 264 | [{"role": "user", "content": "what is 2+8?"}], 265 | [{"role": "user", "content": "what is 2+9?"}], 266 | [{"role": "user", "content": "what is 2+10?"}], 267 | [{"role": "user", "content": "what is 2+11?"}], 268 | [{"role": "user", "content": "what is 2+12?"}], 269 | [{"role": "user", "content": "what is 2+13?"}], 270 | [{"role": "user", "content": "what is 2+14?"}], 271 | [{"role": "user", "content": "what is 2+15?"}], 272 | [{"role": "user", "content": "what is 2+16?"}], 273 | [{"role": "user", "content": "what is 2+17?"}], 274 | ] 275 | response = llm.run_batch(messages, max_tokens=100) 276 | print(f"final response: {response}") 277 | -------------------------------------------------------------------------------- /SPA_agent/generate_sft_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for generating trajectories with worldmodel mode for SFT training. 3 | """ 4 | 5 | import os 6 | import pdb 7 | import json 8 | import numpy as np 9 | import time 10 | import logging 11 | import sys 12 | import re 13 | from datetime import datetime 14 | import hydra 15 | from omegaconf import OmegaConf 16 | from transformers import AutoTokenizer 17 | from SPA_agent.agent_proxy import LLMAgentProxy, VllmWrapperWg 18 | from verl import DataProto 19 | import pdb 20 | 21 | def init_logging(to_file_only=False, log_dir="log"): 22 | """Set up logging: redirect stdout/stderr to file and optionally keep console output.""" 23 | os.makedirs(log_dir, exist_ok=True) 24 | log_file = os.path.join(log_dir, f"debug_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log") 25 | 26 | # Clear existing handlers 27 | for handler in logging.root.handlers[:]: 28 | logging.root.removeHandler(handler) 29 | 30 | handlers = [logging.FileHandler(log_file, mode='w', encoding='utf-8')] 31 | if not to_file_only: 32 | handlers.append(logging.StreamHandler(sys.__stdout__)) 33 | 34 | logging.basicConfig( 35 | level=logging.DEBUG, 36 | format="%(asctime)s - %(levelname)s - %(message)s", 37 | handlers=handlers 38 | ) 39 | 40 | # Redirect print and errors to logging 41 | class StreamToLogger: 42 | def __init__(self, level): self.level = level 43 | def write(self, message): 44 | message = message.strip() 45 | if message: self.level(message) 46 | def flush(self): pass 47 | 48 | sys.stdout = StreamToLogger(logging.info) 49 | sys.stderr = StreamToLogger(logging.error) 50 | 51 | 52 | def convert_to_sft_format_add_worldmodel(trajectories, tokenizer): 53 | """Convert trajectories to SFT format by replacing predicted states with real states. 
54 | 55 | Each turn has the structure: 56 | <|imstart|> xxxxx real_state <|imend|> 57 | """ 58 | sft_data = [] 59 | index = 0 60 | for traj in trajectories: 61 | # Get the messages and states 62 | messages = traj['messages_list'] 63 | real_states = traj['real_states'] 64 | 65 | if real_states is None: 66 | continue 67 | 68 | # Process conversation 69 | input_ids = [] 70 | labels = [] 71 | 72 | # Add system message 73 | system_message = messages[0] 74 | system_tokens = tokenizer.encode('<|im_start|>system\n' + system_message['content'] + '\n<|im_end|>', add_special_tokens=False) 75 | input_ids.extend(system_tokens) 76 | labels.extend([-100] * len(system_tokens)) 77 | 78 | # Process each turn (user + assistant) 79 | for turn_idx in range(len(messages)//2): 80 | user_idx = 1 + turn_idx * 2 81 | assistant_idx = user_idx + 1 82 | cur_state_idx = turn_idx 83 | next_state_idx = turn_idx + 1 84 | 85 | # Add user message 86 | if user_idx < len(messages): 87 | user_tokens = tokenizer.encode('<|im_start|>user\n' + messages[user_idx]['content'] + '\n<|im_end|>', add_special_tokens=False) 88 | input_ids.extend(user_tokens) 89 | labels.extend([-100] * len(user_tokens)) 90 | 91 | # Add assistant message with real state 92 | if assistant_idx < len(messages): 93 | assistant_message = messages[assistant_idx]['content'] 94 | 95 | # Get real states for replacement 96 | cur_state = "" 97 | if cur_state_idx < len(real_states) and real_states[cur_state_idx] is not None: 98 | cur_state = real_states[cur_state_idx] 99 | 100 | next_state = "" 101 | if next_state_idx < len(real_states) and real_states[next_state_idx] is not None: 102 | next_state = real_states[next_state_idx] 103 | 104 | # Replace predicted state with real state using regex 105 | def replace_states_in_message(msg, obs, pred): 106 | 107 | pattern = r'()(.*?)()' 108 | msg = re.sub(pattern, f'\\1{obs}\\3', msg, flags=re.DOTALL) 109 | pattern = r'()(.*?)()' 110 | msg = re.sub(pattern, f'\\1{pred}\\3', msg, flags=re.DOTALL) 111 | # import pdb; pdb.set_trace() 112 | return msg 113 | 114 | assistant_message = replace_states_in_message(assistant_message, cur_state, next_state) 115 | 116 | # Update the assistant message in the original messages list 117 | messages[assistant_idx]['content'] = assistant_message 118 | 119 | # Add the complete message 120 | assistant_tokens = tokenizer.encode('<|im_start|>assistant\n' + assistant_message + '\n<|im_end|>', add_special_tokens=False) 121 | input_ids.extend(assistant_tokens) 122 | 123 | # Create labels with -100 for non-state positions 124 | turn_labels = [-100] * len(assistant_tokens) 125 | 126 | # Find state positions 127 | state_start_tokens = tokenizer.encode('', add_special_tokens=False) 128 | state_end_tokens = tokenizer.encode('', add_special_tokens=False) 129 | 130 | # Find indices of state tags 131 | start_indices = [i for i in range(len(assistant_tokens)) if assistant_tokens[i:i+len(state_start_tokens)] == state_start_tokens] 132 | end_indices = [i for i in range(len(assistant_tokens)) if assistant_tokens[i:i+len(state_end_tokens)] == state_end_tokens] 133 | 134 | if start_indices and end_indices: 135 | start_idx = start_indices[0] 136 | end_idx = end_indices[0] + len(state_end_tokens) 137 | # Only keep values between state tags 138 | turn_labels[start_idx:end_idx] = assistant_tokens[start_idx:end_idx] 139 | 140 | labels.extend(turn_labels) 141 | index += 1 142 | # Add to SFT data 143 | 144 | 145 | sft_data.append({ 146 | 'id':index, 147 | 'messages_list': messages, # Use the modified messages list 148 | 
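# NOTE (editorial comment, not part of the original source; assumption from the surrounding code): only the tag-corrected `messages_list` is persisted per sample here. The `input_ids`/`labels` assembled above are not stored, since `main()` later flattens these messages into prompt/response rows and `sft/spa_sft_dataset.py` re-tokenizes them during SFT.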
149 | }) 150 | 151 | return sft_data 152 | 153 | @hydra.main(version_base=None, config_path="../config", config_name="base") 154 | def main(config): 155 | """Generate trajectories with worldmodel mode and save them for SFT training.""" 156 | # Initialize logging 157 | init_logging(to_file_only=False) 158 | import argparse 159 | 160 | # Set environment variables 161 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" 162 | os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", str(config.system.CUDA_VISIBLE_DEVICES)) 163 | 164 | model_path = config.model_path 165 | print(f"Loading model from {model_path}") 166 | config.actor_rollout_ref.model.path = model_path 167 | print(f"Loading tokenizer from {config.actor_rollout_ref.model.path}") 168 | tokenizer = AutoTokenizer.from_pretrained(config.actor_rollout_ref.model.path) 169 | actor_wg = VllmWrapperWg(config, tokenizer) 170 | proxy = LLMAgentProxy(config, actor_wg, tokenizer) 171 | 172 | # Create output directory 173 | output_dir = os.getenv("OUTPUT_DIR") 174 | # Ensure OUTPUT_DIR environment variable is provided; otherwise raise an error 175 | if output_dir is None or str(output_dir).strip() == "": 176 | raise RuntimeError("Environment variable 'OUTPUT_DIR' is not set. " 177 | "Please export OUTPUT_DIR before running this script.") 178 | os.makedirs(output_dir, exist_ok=True) 179 | 180 | # Generate multiple batches of trajectories 181 | num_batches = os.getenv("BT_NUM") # You can adjust this number 182 | all_trajectories = [] 183 | 184 | total_start_time = time.time() 185 | 186 | for batch_idx in range(int(num_batches)): 187 | batch_start_time = time.time() 188 | print(f"Generating batch {batch_idx + 1}/{num_batches}") 189 | 190 | # Generate trajectories 191 | meta_info = { 192 | 'eos_token_id': tokenizer.eos_token_id, 193 | 'pad_token_id': tokenizer.pad_token_id, 194 | 'recompute_log_prob': False, 195 | 'do_sample': config.actor_rollout_ref.rollout.val_kwargs.do_sample, 196 | 'validate': True, 197 | } 198 | test_gen_batch = DataProto(batch=None, non_tensor_batch=None, meta_info=meta_info) 199 | 200 | # Get trajectories with both predicted and real states 201 | test_batch, rollout_states, all_states = proxy.rollout(test_gen_batch, val=False) 202 | 203 | # Process and store trajectories 204 | for i in range(len(test_batch.batch['responses'])): 205 | success_flag = ( 206 | rollout_states[i].get('metrics', {}).get('SimpleSokoban/success', 0.0) == 1.0 207 | or rollout_states[i].get('metrics', {}).get('FrozenLake/success', 0.0) == 1.0 208 | ) 209 | # If neither environment reports success, skip this trajectory 210 | # if not success_flag: 211 | # continue 212 | trajectory = { 213 | 'messages_list': test_batch.non_tensor_batch['messages_list'][i], 214 | 'real_states': all_states[i], 215 | 'rewards': test_batch.batch['rm_scores'][i].cpu().tolist(), 216 | } 217 | all_trajectories.append(trajectory) 218 | 219 | 220 | 221 | batch_time = time.time() - batch_start_time 222 | print(f"Batch {batch_idx + 1} completed in {batch_time:.2f} seconds") 223 | print(f"Generated {len(test_batch.batch['responses'])} trajectories in this batch") 224 | 225 | total_time = time.time() - total_start_time 226 | print(f"\nTotal generation time: {total_time:.2f} seconds") 227 | 228 | 229 | # Convert to SFT format 230 | print("\nConverting trajectories to SFT format...") 231 | if config.ctx_manager.mode == 'add_worldmodel': 232 | sft_data = convert_to_sft_format_add_worldmodel(all_trajectories, tokenizer) 233 | 234 | # Save trajectories 235 | 
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 236 | 237 | # Save raw trajectories 238 | raw_output_file = os.path.join(output_dir, f"raw_trajectories_{timestamp}.json") 239 | 240 | with open(raw_output_file, 'w') as f: 241 | json.dump(sft_data, f, indent=2) 242 | print(f"Saved raw trajectories to {raw_output_file}") 243 | 244 | # Convert directly to a DataFrame and split into train/val 245 | import pandas as pd 246 | from sklearn.model_selection import train_test_split 247 | rows = [] 248 | DEFAULT_DATA_SOURCE = "sokoban" 249 | DEFAULT_ABILITY = "bfs" 250 | DEFAULT_REWARD_MODEL = "{'ground_truth': {'numbers': array([0, 0]), 'target': 0}, 'style': 'rule'}" 251 | DEFAULT_EXTRA_INFO = "{'index': 100016, 'split': 'train'}" 252 | 253 | 254 | for sample in sft_data: 255 | messages = sample["messages_list"] 256 | for i, msg in enumerate(messages): 257 | if msg["role"] == "assistant": 258 | prompt_list = np.array(messages[:i]) 259 | print(prompt_list.dtype, prompt_list.shape) 260 | rows.append({ 261 | 'data_source': DEFAULT_DATA_SOURCE, 262 | # 'prompt': json.dumps(prompt_list, ensure_ascii=False), 263 | 'prompt': prompt_list, 264 | 'response': msg['content'], 265 | 'ability': DEFAULT_ABILITY, 266 | 'reward_model': DEFAULT_REWARD_MODEL, 267 | 'extra_info': DEFAULT_EXTRA_INFO, 268 | }) 269 | df = pd.DataFrame(rows) 270 | train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, shuffle=True) 271 | train_csv = os.path.join(output_dir, 'wm_train.csv') 272 | val_csv = os.path.join(output_dir, 'wm_val.csv') 273 | train_parquet = os.path.join(output_dir, 'wm_train.parquet') 274 | val_parquet = os.path.join(output_dir, 'wm_val.parquet') 275 | train_df.to_csv(train_csv, index=False) 276 | val_df.to_csv(val_csv, index=False) 277 | train_df.to_parquet(train_parquet, index=False) 278 | val_df.to_parquet(val_parquet, index=False) 279 | print(f"Train: {len(train_df)}, Val: {len(val_df)}") 280 | print(f"Saved to {train_csv}, {val_csv}, {train_parquet}, {val_parquet}") 281 | print(f"\nGenerated {len(all_trajectories)} total trajectories") 282 | print(f"Converted to {len(sft_data)} SFT training examples") 283 | 284 | if __name__ == "__main__": 285 | main() -------------------------------------------------------------------------------- /SPA_agent/agent_proxy.py: -------------------------------------------------------------------------------- 1 | from .ctx_manager import ContextManager 2 | from .es_manager import EnvStateManager 3 | from vllm import LLM, SamplingParams 4 | from verl.single_controller.ray.base import RayWorkerGroup 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | from verl import DataProto 7 | import hydra 8 | import os 9 | from typing import List, Dict 10 | from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto 11 | from .base_llm import ConcurrentLLM 12 | import pdb 13 | import sys 14 | import os 15 | import logging 16 | import random 17 | import re 18 | from datetime import datetime 19 | 20 | 21 | class VllmWrapperWg: # This is a development class for eval and test 22 | def __init__(self, config, tokenizer): 23 | self.config = config 24 | self.tokenizer = tokenizer 25 | model_name = config.actor_rollout_ref.model.path 26 | ro_config = config.actor_rollout_ref.rollout 27 | 28 | print(f"[DEBUG] loading model from: {model_name}") 29 | print(f"[DEBUG] transformers cache: {os.environ.get('TRANSFORMERS_CACHE')}") 30 | print(f"[DEBUG] vllm cache dir: {os.environ.get('VLLM_CACHE_DIR')}") 31 | 32 | self.llm = LLM( 33 | model_name, 34 | enable_sleep_mode=True,
tensor_parallel_size=ro_config.tensor_model_parallel_size, 36 | dtype=ro_config.dtype, 37 | enforce_eager=ro_config.enforce_eager, 38 | gpu_memory_utilization=ro_config.gpu_memory_utilization, 39 | disable_custom_all_reduce=True, 40 | disable_mm_preprocessor_cache=True, 41 | skip_tokenizer_init=False, 42 | max_model_len=ro_config.max_model_len, 43 | disable_log_stats=ro_config.disable_log_stats, 44 | max_num_batched_tokens=ro_config.max_num_batched_tokens, 45 | enable_chunked_prefill=ro_config.enable_chunked_prefill, 46 | enable_prefix_caching=True, 47 | ) 48 | print("LLM initialized") 49 | self.sampling_params = SamplingParams( 50 | max_tokens=ro_config.response_length, 51 | temperature=ro_config.val_kwargs.temperature, 52 | top_p=ro_config.val_kwargs.top_p, 53 | top_k=ro_config.val_kwargs.top_k, 54 | # min_p=0.1, 55 | ) 56 | 57 | def generate_sequences(self, lm_inputs: DataProto): 58 | """ 59 | Convert the input ids to text, and then generate the sequences. Finally create a dataproto. 60 | This aligns with the verl Worker Group interface. 61 | """ 62 | # NOTE: free_cache_engine is not used in the vllm wrapper. Only used in the verl vllm. 63 | # cache_action = lm_inputs.meta_info.get('cache_action', None) 64 | 65 | input_ids = lm_inputs.batch['input_ids'] 66 | input_texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False) 67 | input_texts = [i.replace("<|endoftext|>", "") for i in input_texts] 68 | 69 | outputs = self.llm.generate(input_texts, sampling_params=self.sampling_params) 70 | texts = [output.outputs[0].text for output in outputs] 71 | lm_outputs = DataProto() 72 | lm_outputs.non_tensor_batch = { 73 | 'response_texts': texts, 74 | 'env_ids': lm_inputs.non_tensor_batch['env_ids'], 75 | 'group_ids': lm_inputs.non_tensor_batch['group_ids'] 76 | } # this is a bit hard-coded to bypass the __init__ check in DataProto 77 | lm_outputs.meta_info = lm_inputs.meta_info 78 | 79 | return lm_outputs 80 | 81 | class ApiCallingWrapperWg: 82 | """Wrapper class for API-based LLM calls that fits into the VERL framework""" 83 | 84 | def __init__(self, config, tokenizer): 85 | self.config = config 86 | self.tokenizer = tokenizer 87 | model_info = config.model_info[config.model_config.model_name] 88 | self.llm_kwargs = model_info.generation_kwargs 89 | 90 | 91 | self.llm = ConcurrentLLM( 92 | provider=model_info.provider_name, 93 | model_name=model_info.model_name, 94 | max_concurrency=config.model_config.max_concurrency 95 | ) 96 | 97 | print(f'API-based LLM ({model_info.provider_name} - {model_info.model_name}) initialized') 98 | 99 | 100 | def generate_sequences(self, lm_inputs: DataProto) -> DataProto: 101 | """ 102 | Convert the input ids to text, make API calls to generate responses, 103 | and create a DataProto with the results. 
104 | """ 105 | 106 | messages_list = lm_inputs.non_tensor_batch['messages_list'].tolist() 107 | results, failed_messages = self.llm.run_batch( 108 | messages_list=messages_list, 109 | **self.llm_kwargs 110 | ) 111 | assert not failed_messages, f"Failed to generate responses for the following messages: {failed_messages}" 112 | 113 | texts = [result["response"] for result in results] 114 | print(f'[DEBUG] texts: {texts}') 115 | lm_outputs = DataProto() 116 | lm_outputs.non_tensor_batch = { 117 | 'response_texts': texts, 118 | 'env_ids': lm_inputs.non_tensor_batch['env_ids'], 119 | 'group_ids': lm_inputs.non_tensor_batch['group_ids'] 120 | } # this is a bit hard-coded to bypass the __init__ check in DataProto 121 | lm_outputs.meta_info = lm_inputs.meta_info 122 | 123 | return lm_outputs 124 | 125 | class RandomActionWrapperWg: 126 | """Wrapper class for random action generation that fits into the VERL framework""" 127 | 128 | def __init__(self, config, tokenizer): 129 | self.config = config 130 | self.tokenizer = tokenizer 131 | print(f'Random Action Wrapper initialized') 132 | 133 | def generate_random_action(self): 134 | """Generate three random actions from up, down, left, right, separated by ||.""" 135 | actions = ["up", "down", "left", "right"] 136 | three_actions = [random.choice(actions) for _ in range(3)] 137 | return " || ".join(three_actions) 138 | 139 | def generate_sequences(self, lm_inputs: DataProto) -> DataProto: 140 | """ 141 | Generate random action responses following the strict format: 142 | ... I will push the box to the target. ...ACTION 143 | """ 144 | messages_list = lm_inputs.non_tensor_batch['messages_list'].tolist() 145 | 146 | # Generate random actions for each message 147 | response_texts = [] 148 | for messages in messages_list: 149 | # Find the last assistant message to modify 150 | assistant_message = None 151 | for msg in reversed(messages): 152 | if msg.get("role") == "assistant": 153 | assistant_message = msg 154 | break 155 | 156 | if assistant_message: 157 | # Generate random action 158 | random_action = self.generate_random_action() 159 | 160 | # Create the response following the strict format 161 | # Extract observation and prediction from the original message 162 | content = assistant_message.get("content", "") 163 | 164 | # Extract observation 165 | obs_match = re.search(r'(.*?)', content, re.DOTALL) 166 | observation = obs_match.group(1) if obs_match else "" 167 | 168 | # Extract prediction 169 | pred_match = re.search(r'(.*?)', content, re.DOTALL) 170 | prediction = pred_match.group(1) if pred_match else "" 171 | 172 | # Create new response with random action 173 | response = f" {observation} I will push the box to the target. {prediction}{random_action}" 174 | response_texts.append(response) 175 | else: 176 | # Fallback if no assistant message found 177 | random_action = self.generate_random_action() 178 | response = f" I will push the box to the target. 
{random_action}" 179 | response_texts.append(response) 180 | 181 | lm_outputs = DataProto() 182 | lm_outputs.non_tensor_batch = { 183 | 'response_texts': response_texts, 184 | 'env_ids': lm_inputs.non_tensor_batch['env_ids'], 185 | 'group_ids': lm_inputs.non_tensor_batch['group_ids'] 186 | } 187 | lm_outputs.meta_info = lm_inputs.meta_info 188 | 189 | return lm_outputs 190 | 191 | class LLMAgentProxy: 192 | """ 193 | The proxy means the llm agent is trying to generate some rollout **at this time**, **at this model state**, **at this env state from the env config** 194 | """ 195 | def __init__(self, config, actor_rollout_wg, tokenizer): 196 | self.config = config 197 | self.train_ctx_manager = ContextManager(config, tokenizer, mode="train") 198 | self.train_es_manager = EnvStateManager(config, mode="train") 199 | self.val_ctx_manager = ContextManager(config, tokenizer, mode="val") 200 | self.val_es_manager = EnvStateManager(config, mode="val") 201 | self.actor_wg = actor_rollout_wg 202 | self.tokenizer = tokenizer 203 | # -------- prepare a validation trajectory file path (once per proxy) -------- 204 | timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 205 | # model name (strip path separators) 206 | model_path = getattr(self.config, "model_path", None) 207 | if model_path is None and hasattr(self.config, "actor_rollout_ref"): 208 | model_path = getattr(self.config.actor_rollout_ref.model, "path", "unknown_model") 209 | model_name = os.path.basename(str(model_path)).replace("/", "_") 210 | # ctx manager mode (e.g. base / add_worldmodel ...) 211 | mode_name = getattr(self.config.ctx_manager, "mode", "unknown_mode") 212 | file_name = f"valtraj_{model_name}_{mode_name}_{timestamp}.txt" 213 | out_dir = "val_trajectories" 214 | os.makedirs(out_dir, exist_ok=True) 215 | self.valtraj_file_path = os.path.join(out_dir, file_name) 216 | 217 | def generate_sequences(self, lm_inputs: DataProto): 218 | # TODO: add kv cache both for the vllm wrapper here and for verl vllm. 219 | if isinstance(self.actor_wg, RayWorkerGroup): 220 | padded_lm_inputs, pad_size = pad_dataproto_to_divisor(lm_inputs, self.actor_wg.world_size) 221 | 222 | padded_lm_outputs = self.actor_wg.generate_sequences(padded_lm_inputs) 223 | # pdb.set_trace() 224 | lm_outputs = unpad_dataproto(padded_lm_outputs, pad_size=pad_size) 225 | lm_outputs.meta_info = lm_inputs.meta_info 226 | lm_outputs.non_tensor_batch = lm_inputs.non_tensor_batch 227 | elif isinstance(self.actor_wg, VllmWrapperWg) or isinstance(self.actor_wg, ApiCallingWrapperWg) or isinstance(self.actor_wg, RandomActionWrapperWg): 228 | lm_outputs = self.actor_wg.generate_sequences(lm_inputs) 229 | else: 230 | raise ValueError(f"Unsupported actor worker type: {type(self.actor_wg)}") 231 | 232 | return lm_outputs 233 | 234 | def rollout(self, dataproto: DataProto, val=False): 235 | es_manager = self.val_es_manager if val else self.train_es_manager 236 | ctx_manager = self.val_ctx_manager if val else self.train_ctx_manager 237 | env_outputs = es_manager.reset() 238 | 239 | for i in range(self.config.agent_proxy.max_turn): 240 | lm_inputs: DataProto = ctx_manager.get_lm_inputs(env_outputs, prepare_for_update=False) 241 | lm_inputs.meta_info = dataproto.meta_info # TODO: setup vllm early stop when max length is reached. 
make sure this can be done 242 | lm_outputs: DataProto = self.generate_sequences(lm_inputs) #len: 128 243 | env_inputs: List[Dict] = ctx_manager.get_env_inputs(lm_outputs) 244 | env_outputs: List[Dict] = es_manager.step(env_inputs) 245 | 246 | if len(env_outputs) == 0: # all finished 247 | break 248 | 249 | if self.config.ctx_manager.mode=='add_worldmodel': 250 | rollout_states,all_states = es_manager.get_rollout_states_add_worldmodel() 251 | 252 | 253 | elif self.config.ctx_manager.mode=='base': 254 | rollout_states = es_manager.get_rollout_states_base() 255 | 256 | rollouts = ctx_manager.formulate_rollouts(rollout_states) 257 | # self.tokenizer.batch_decode(rollouts.batch['input_ids'], skip_special_tokens=False) # see all the trajectories 258 | messages_list = rollouts.non_tensor_batch["messages_list"] 259 | 260 | for i, (messages, rollout_state) in enumerate(zip(messages_list, rollout_states)): 261 | if i >= 10: 262 | break 263 | print(f"Env {i}:") 264 | print("Messages:") 265 | for msg in messages: 266 | print(msg) 267 | # print("Success:", rollout_state["metrics"].get("success", None)) 268 | print("-" * 40) 269 | 270 | # Get success status for each environment 271 | # pdb.set_trace() 272 | # success_list = [rollout_state["metrics"].get("FrozenLake/success", False) for rollout_state in rollout_states] 273 | # success_list = [rollout_state["metrics"].get("SimpleSokoban/success", False) for rollout_state in rollout_states] 274 | success_list = [rollout_state["metrics"].get("FrozenLake/success", rollout_state["metrics"].get("SimpleSokoban/success", rollout_state["metrics"].get("O3Sokoban/success", False))) for rollout_state in rollout_states] 275 | 276 | # print("Success status for all environments:", success_list) 277 | # Calculate and print success rate 278 | success_rate = sum(success_list) / len(success_list) 279 | print(f"Overall success rate: {success_rate:.3f}") 280 | ''' 281 | if val: 282 | # Start Generation Here 283 | # Save messages and corresponding rollout metrics into a JSON file 284 | try: 285 | # Dynamically import to avoid polluting global namespace 286 | import os, json 287 | # Record (best-effort) the current training step so that it can be 288 | # referenced later when dumping validation trajectories. 
289 | 290 | # -------- use persistent save path -------- 291 | file_path = self.valtraj_file_path 292 | train_step = 0 293 | # -------- dump trajectories -------- 294 | 295 | with open(file_path, "a", encoding="utf-8") as fp: 296 | train_step=os.environ.get("STEP", 0) 297 | fp.write(f" Current training step: {train_step}\n") 298 | # pdb.set_trace() 299 | for i, (messages, rollout_state) in enumerate(zip(messages_list, rollout_states)): 300 | fp.write(f"Env {i}\n") 301 | fp.write("Messages:\n") 302 | for msg in messages: 303 | # msg can be dict / str 304 | if isinstance(msg, (dict, list)): 305 | fp.write(json.dumps(msg, ensure_ascii=False) + "\n") 306 | else: 307 | fp.write(str(msg) + "\n") 308 | fp.write("Metrics:\n") 309 | fp.write(json.dumps(rollout_state.get("metrics", {}), ensure_ascii=False) + "\n") 310 | fp.write("-" * 60 + "\n") 311 | print(f"[Validation] Trajectories saved to {file_path}") 312 | except Exception as e: 313 | print(f"[Validation] Failed to save trajectories: {e}") 314 | ''' 315 | if self.config.ctx_manager.mode=='add_worldmodel': 316 | return rollouts,rollout_states,all_states 317 | elif self.config.ctx_manager.mode=='base': 318 | return rollouts,rollout_states 319 | 320 | def init_logging(to_file_only=False, log_dir="log"): 321 | """Set up logging: redirect stdout/stderr to file and optionally keep console output.""" 322 | os.makedirs(log_dir, exist_ok=True) 323 | log_file = os.path.join(log_dir, f"debug_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log") 324 | 325 | # Clear existing handlers (important if libraries already added theirs) 326 | for handler in logging.root.handlers[:]: 327 | logging.root.removeHandler(handler) 328 | 329 | handlers = [logging.FileHandler(log_file, mode='w', encoding='utf-8')] 330 | if not to_file_only: 331 | handlers.append(logging.StreamHandler(sys.__stdout__)) 332 | 333 | logging.basicConfig( 334 | level=logging.DEBUG, 335 | format="%(asctime)s - %(levelname)s - %(message)s", 336 | handlers=handlers 337 | ) 338 | 339 | # Redirect print and errors to logging 340 | class StreamToLogger: 341 | def __init__(self, level): self.level = level 342 | def write(self, message): 343 | message = message.strip() 344 | if message: self.level(message) 345 | def flush(self): pass 346 | 347 | sys.stdout = StreamToLogger(logging.info) 348 | sys.stderr = StreamToLogger(logging.error) 349 | @hydra.main(version_base=None, config_path="../../config", config_name="base") 350 | def main(config): 351 | 352 | init_logging(to_file_only=False) # log to file + keep terminal output 353 | # detect config name from python -m ragen.llm_agent.agent_proxy --config_name frozen_lake 354 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" 355 | os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", str(config.system.CUDA_VISIBLE_DEVICES)) 356 | tokenizer = AutoTokenizer.from_pretrained(config.actor_rollout_ref.model.path) 357 | actor_wg = VllmWrapperWg(config, tokenizer) 358 | proxy = LLMAgentProxy(config, actor_wg, tokenizer) 359 | import time 360 | start_time = time.time() 361 | rollouts = proxy.rollout(DataProto(batch=None, non_tensor_batch=None, meta_info={'eos_token_id': 151645, 'pad_token_id': 151643, 'recompute_log_prob': False, 'do_sample':config.actor_rollout_ref.rollout.do_sample, 'validate': True}), val=True) 362 | end_time = time.time() 363 | print(f'rollout time: {end_time - start_time} seconds') 364 | # print rollout rewards from the rm_scores 365 | rm_scores = rollouts.batch["rm_scores"] 366 | metrics = rollouts.meta_info["metrics"] 
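# NOTE (editorial comment, not part of the original source; assumption): `rm_scores` is taken to be a [num_envs, response_length] tensor of per-token reward scores attached during rollout, so `rm_scores.sum(-1)` below yields one return per trajectory and the batch mean printed next is the average rollout reward.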
367 | avg_reward = rm_scores.sum(-1).mean().item() 368 | print(f'rollout rewards: {avg_reward}') 369 | print(f'metrics:') 370 | for k, v in metrics.items(): 371 | print(f'{k}: {v}') 372 | 373 | # @hydra.main(version_base=None, config_path="../../config", config_name="evaluate_api_llm") 374 | # def main(config): 375 | # # detect config name from python -m ragen.llm_agent.agent_proxy --config_name frozen_lake 376 | # tokenizer = AutoTokenizer.from_pretrained(config.actor_rollout_ref.model.path) 377 | # actor_wg = ApiCallingWrapperWg(config, tokenizer) 378 | # proxy = LLMAgentProxy(config, actor_wg, tokenizer) 379 | # import time 380 | # start_time = time.time() 381 | # rollouts = proxy.rollout(DataProto(batch=None, non_tensor_batch=None, meta_info={'eos_token_id': 151645, 'pad_token_id': 151643, 'recompute_log_prob': False, 'do_sample': False, 'validate': True}), val=True) 382 | # print(f'[DEBUG] rollouts: {rollouts}') 383 | # end_time = time.time() 384 | # print(f'rollout time: {end_time - start_time} seconds') 385 | # # print rollout rewards from the rm_scores 386 | # rm_scores = rollouts.batch["rm_scores"] 387 | # metrics = rollouts.meta_info["metrics"] 388 | # avg_reward = rm_scores.sum(-1).mean().item() 389 | # print(f'rollout rewards: {avg_reward}') 390 | # print(f'metrics:') 391 | # for k, v in metrics.items(): 392 | # print(f'{k}: {v}') 393 | 394 | 395 | 396 | if __name__ == "__main__": 397 | main() -------------------------------------------------------------------------------- /config/envs.yaml: -------------------------------------------------------------------------------- 1 | custom_envs: 2 | SimpleSokoban: 3 | env_type: sokoban 4 | max_actions_per_traj: 10 # used in environment state manager to control the actual max actions executed per trajectory 5 | 6 | env_instruction: "You are solving the Sokoban puzzle. You are the player and you need to push all boxes to targets You are provided with a symbol grid. When you are exactly next to a box, you can push it by moving in the same direction. You cannot push a box through a wall, and you cannot pull a box. The answer should be a sequence of actions, like Right || Right || Up." 7 | env_instruction_add_worldmodel: "You are solving the Sokoban puzzle. You are the player and you need to push all boxes to targets You are provided with a symbol grid and the zero-indexed coordinates of the player, each box, and each target. Coordinates range from the top-left corner (0, 0) to the bottom-right corner (5, 5). When you are exactly next to a box, you can push it by moving in the same direction. You cannot push a box through a wall, and you cannot pull a box. The answer should be a sequence of actions, like Right || Right || Up. A sample full output is as follows: ######\n#_####\n#_P###\n#_X#_#\n#__O_#\n######\nPlayer (P) is at (2,2); box (X) is at (3,2); target (O) is at (4,3). 1 Down – I push box to (4,2). 2 Left – I step to (3,1). 3 Down – I stand left of box, ready to push it Right onto target. ######\n#_####\n#__###\n#__#_#\n#PXO_#\n###### Down || Left || Down " 8 | 9 | max_tokens: 100 # used to curate llm prompt "max words", not used for rollout 10 | env_config: # keys should be a subset of SokobanConfig 11 | dim_x: 6 12 | dim_y: 6 13 | num_boxes: 1 14 | max_steps: 100 15 | 16 | Sokoban2Boxes: 17 | env_type: sokoban 18 | max_actions_per_traj: 10 19 | env_instruction: "You are solving the Sokoban puzzle. Move all boxes to targets. 
Example answer format: I should go left first, then move the box to the up side.Left || Up" 20 | max_tokens: 30 21 | env_config: # keys should be a subset of SokobanConfig 22 | search_depth: 30 23 | dim_x: 6 24 | dim_y: 6 25 | num_boxes: 2 26 | max_steps: 100 27 | 28 | SokobanDifferentGridVocab: 29 | env_type: sokoban 30 | max_actions_per_traj: 10 31 | env_instruction: "You are solving the Sokoban puzzle. Move all boxes to targets. Example answer format: I should go left first, then move the box to the up side.Left || Up" 32 | max_tokens: 30 33 | env_config: # keys should be a subset of SokobanConfig 34 | search_depth: 30 35 | dim_x: 6 36 | dim_y: 6 37 | num_boxes: 1 38 | max_steps: 100 39 | grid_lookup: 40 | 0: "W" 41 | 1: "." 42 | 2: "G" 43 | 3: "C" 44 | 4: "B" 45 | 5: "A" 46 | 6: "@" 47 | grid_vocab: 48 | "W": "wall" 49 | ".": "empty" 50 | "G": "target" 51 | "C": "box on target" 52 | "B": "box" 53 | "A": "player" 54 | "@": "player on target" 55 | 56 | SokobanDifferentActionVocab: 57 | env_type: sokoban 58 | max_actions_per_traj: 10 59 | env_instruction: "You are solving the Sokoban puzzle. Move all boxes to targets. Example answer format: I should go left first, then move the box to the up side.Left || Up" 60 | max_tokens: 30 61 | env_config: # keys should be a subset of SokobanConfig 62 | search_depth: 30 63 | dim_x: 6 64 | dim_y: 6 65 | num_boxes: 1 66 | max_steps: 100 67 | action_lookup: 68 | 1: "Left" 69 | 2: "Right" 70 | 3: "Up" 71 | 4: "Down" 72 | 73 | 74 | 75 | VisualSimpleSokoban: 76 | env_type: sokoban 77 | max_actions_per_traj: 10 78 | env_instruction: "You are solving the Visual Simple Sokoban puzzle. Move all boxes to targets. Example answer format: I should go left first, then move the box to the up side.Left || Up" 79 | max_tokens: 30 80 | env_config: # keys should be a subset of SokobanConfig 81 | dim_x: 6 82 | dim_y: 6 83 | num_boxes: 1 84 | max_steps: 100 85 | render_mode: "rgb_array" 86 | 87 | HarderSokoban: 88 | env_type: sokoban 89 | max_actions_per_traj: 10 90 | env_instruction: "You are solving the Harder Sokoban puzzle. Move all boxes to targets. Example answer format: I should go left first, then move the box to the up side.Left || Up" 91 | max_tokens: 30 92 | env_config: 93 | dim_x: 10 94 | dim_y: 10 95 | num_boxes: 2 96 | max_steps: 100 97 | 98 | 99 | Countdown: 100 | env_type: countdown 101 | max_actions_per_traj: 1 102 | env_instruction: "You are solving the Countdown puzzle. You should use the num list to create an equation that equals the target. Example answer format: To find an equation using [3, 5, 2] to get 4. Let's check 2 + 5 = 7, 7 - 3 = 4. So the answer is 2 + 5 - 3 = 4. 2 + 5 - 3" 103 | max_tokens: 30 104 | env_config: null 105 | 106 | Bandit: 107 | env_type: bandit 108 | max_actions_per_traj: 1 109 | env_instruction: "" 110 | max_tokens: 30 111 | env_config: 112 | lo_arm_name: "Phoenix" 113 | hi_arm_name: "Dragon" 114 | 115 | BanditTest: 116 | env_type: bandit 117 | max_actions_per_traj: 1 118 | env_instruction: "" 119 | max_tokens: 30 120 | env_config: 121 | lo_arm_name: "Trader" 122 | hi_arm_name: "Librarian" 123 | 124 | FrozenLake: 125 | env_type: frozen_lake 126 | max_actions_per_traj: 10 127 | env_instruction: "You are solving the FrozenLake puzzle. Avoid the hole and go to the target. You may move in an unintended direction due to the slippery ice. Example answer format: To avoid the hole and go to the target, I should go left then go up.Left || Up" 128 | env_instruction_add_worldmodel: "You are solving the FrozenLake puzzle.
You are the player and you need to Forbid the hole and go to the target. You are provided with a symbol grid and the zero-indexed coordinates of the player, each hole, and each target. Coordinates range from the top-left corner (0, 0) to the bottom-right corner (3, 3). When you are exactly next to goal, you can move in the same direction. You cannot move to a hole cause you will fall into the hole. The answer should be a sequence of actions, like Right || Right || Up. A sample full output is as follows: _O__\nO___\nG___\n__P_ Player (3,2); holes at (0,1) and (1,0); goal at (2,0). 1 Up - move to safe ice (2,2). 2 Left - slide to (2,1), adjacent to goal. 3 Left - reach goal (2,0); player now on G. _O__\nO___\n√___\n____ Up || Left || Left " 129 | max_tokens: 100 130 | env_config: null 131 | 132 | MetamathQA: 133 | env_type: metamathqa 134 | max_actions_per_traj: 1 135 | env_instruction: "You are solving Math problems. " 136 | max_tokens: 30 137 | env_config: null 138 | 139 | Sudoku: 140 | env_type: sudoku 141 | max_actions_per_traj: 50 142 | env_instruction: "You are solving 4x4 Sudoku. Fill empty cells with digits 1–4. Use a 1-indexed grid (rows/cols start at 1). A move is exactly: row,col,value (three integers). In one turn you may output multiple moves, separated by ||. Only propose moves that keep the row, column, and 2x2 subgrid valid. Always output EXACTLY: [brief reasoning] [r,c,v || r,c,v ...] No extra text outside the two tags. Keep the response under 50 words. Example: Row 1 has one empty cell → place 1. Column 2 then needs 2. 1,3,1 || 3,2,2" 143 | env_instruction_add_worldmodel: "You are solving 4x4 Sudoku. Fill empty cells with digits 1–4. Use a 1-indexed grid (rows/cols start at 1). A move is exactly: row,col,value (three integers). In one turn you may output multiple moves, separated by ||. Only propose moves that keep the row, column, and 2x2 subgrid valid. Always output EXACTLY: [brief reasoning] [r,c,v || r,c,v ...] No extra text outside the two tags. Keep the response under 100 words. An example output: | . . 1 4 | 1 4 . 3 | 4 2 . . | . 1 4 2 \nEmpty positions to be filled are at (1,1), (1,2), (2,3), (3,3), (3,4), (4,1) | 2 3 1 4 | 1 4 2 3 | 4 2 3 1 | . 1 4 2\n Empty positions to be filled are at (4,1) 1,1,2 || 1,2,3 || 2,3,2 || 3,3,3 || 3,4,1 ." 144 | 145 | max_tokens: 100 146 | env_config: null 147 | 148 | WebShop: 149 | env_type: webshop 150 | max_actions_per_traj: 9 151 | # env_instruction: "You are browsing an online shop. Based on the instruction, find the product that close to the production description. You need to read the website and decide what action to take next until buying a product. Available actions depends on the page: in the search page you can search keywords, in the search result page you can click an item url or click[next >] to navigate to next page, in the product page you can click[description] or click[features] to see the details, click[blue] or click[x-large] to choose size and colors, click[buy now] when you decided to buy the product, click[back to search] to return to search page. You should only choose action from the available actions list. Example process: I need a gingko light and 20x20 pillow cover that is hand painted. First search[gingko light 20x20 pillow cover hand painted], answer format: search[blanket with fleece throw]. Valid answer is search[] or click[]." 152 | env_instruction: "You are browsing an online shop. Based on the instruction, buy a product that close to the production description. 
You need to search, read the search results, pick a product, choose the size and color and buy. You should only choose action from the available actions list provided later. Example process: I need a gingko light and 20x20 pillow cover that is hand painted. First search[gingko light 20x20 pillow cover hand painted], answer format: search[blanket with fleece throw]. Valid answer is search[] or click[]." 153 | # env_instruction: > 154 | # You are browsing an online shop. Based on the instruction, find a product 155 | # that closely matches the production description. You need to iteratively take 156 | # actions(search or click) in the browser and buy the chosen product. 157 | # Example: 158 | # WebShop [SEP] Instruction: [SEP] Find me machine wash men's dress shirts with cotton spandex, classic fit, short sleeve with color: deep atlantic, and size: large tall, and price lower than 60.00 dollars [SEP] Search 159 | # Available actions: ['search[]'] 160 | # Search for the item: men's dress shirts with cotton spandex, classic fit, short sleeve. Do not search for color, size or price, because they will be on the search result or product pagesearch[men's dress shirts with cotton spandex, classic fit, short sleeve] 161 | # Instruction: [SEP] Find me machine wash men's dress shirts with cotton spandex, classic fit, short sleeve with color: deep atlantic, and size: large tall, and price lower than 60.00 dollars [SEP] Back to Search [SEP] Page 1 (Total results: 1) [SEP] Next > [SEP] B09M63B87V [SEP] YALFJV Women Long Sleeve Crew Neck Side Button T Shirts Tunic Dress Loose Asymmetric Hem Tunic Pullover to Wear with Leggings [SEP] $10.71 to $18.34 162 | # The product on this page is for women. None of the products are close to the description. Click next page to see more products.click[next >] 163 | # Instruction: [SEP] Find me machine wash men's dress shirts with cotton spandex, classic fit, short sleeve with color: deep atlantic, and size: large tall, and price lower than 60.00 dollars [SEP] Back to Search [SEP] Page 2 (Total results: 2) [SEP] B07HRFSNL4 [SEP] Nautica Men's Solid Crew Neck Short-Sleeve Pocket T-Shirt [SEP] $16.05 to $40.98 [SEP] B07N7TDKXQ [SEP] SOCKS'NBULK Mens Cotton Crew Neck Short Sleeve T-Shirts Mix Colors Bulk [SEP] $80.79 to $172.8 164 | # Available actions: ['click[back to search]', 'click[< prev]', 'click[next >]', 'click[b07hrfsnl4]', 'click[b07n7tdkxq]'] 165 | # The B07HRFSNL4 products seems close to the description, and with the price range. Click it.click[b07hrfsnl4] 166 | # Instruction: [SEP] Find me machine wash men's dress shirts with cotton spandex, classic fit, short sleeve with color: deep atlantic, and size: large tall, and price lower than 60.00 dollars [SEP] Back to Search [SEP] < Prev [SEP] size [SEP] x-small [SEP] small [SEP] large tall [SEP] color [SEP] navy [SEP] deep atlantic [SEP] deep atlantic 167 | # Available actions: ['click[back to search]', 'click[< prev]', 'click[description]', 'click[features]', 'click[reviews]', 'click[buy now]', 'click[x-small]', 'click[small]', 'click[large tall]', 'click[navy]', 'click[deep atlantic]'] 168 | # I need to choose the right size. 
Click large tallclick[large tall] 169 | # Instruction: [SEP] Find me machine wash men's dress shirts with cotton spandex, classic fit, short sleeve with color: deep atlantic, and size: large tall, and price lower than 60.00 dollars [SEP] Back to Search [SEP] < Prev [SEP] size [SEP] x-small [SEP] small [SEP] large tall [SEP] color [SEP] navy [SEP] deep atlantic [SEP] deep atlantic 170 | # Available actions: ['click[back to search]', 'click[< prev]', 'click[description]', 'click[features]', 'click[reviews]', 'click[buy now]', 'click[x-small]', 'click[small]', 'click[large tall]', 'click[navy]', 'click[deep atlantic]'] 171 | # Now I need to choose the right color. Click deep atlanticclick[deep atlantic] 172 | # Instruction: [SEP] Find me machine wash men's dress shirts with cotton spandex, classic fit, short sleeve with color: deep atlantic, and size: large tall, and price lower than 60.00 dollars [SEP] Back to Search [SEP] < Prev [SEP] size [SEP] x-small [SEP] small [SEP] large tall [SEP] color [SEP] navy [SEP] deep atlantic [SEP] deep atlantic 173 | # Available actions: ['click[back to search]', 'click[< prev]', 'click[description]', 'click[features]', 'click[reviews]', 'click[buy now]', 'click[x-small]', 'click[small]', 'click[large tall]', 'click[navy]', 'click[deep atlantic]'] 174 | # Having selected the size and color, I am ready to buy. Click buy now.click[buy now] 175 | 176 | # You are browsing an online shop. Based on the instruction, find the product 177 | # that best matches the production description. You need to iteratively take 178 | # actions(search or click) in the browser and buy the chosen product. Example 179 | # process: 180 | # Instruction: Find me machine wash men's t-shirts with long sleeve with color: black, and size: xx-large big tall, and price lower than 50.00 dollars 181 | # WebShop [SEP] Instruction: [SEP] Find me machine wash men's t-shirts with long sleeve with color: black, and size: xx-large big tall, and price lower than 50.00 dollars [SEP] Search 182 | # Available actions: ['search[]'] 183 | # First search for the big catagory: machine wash men's t-shirts with long sleeve. 
Do not search for color, size or price, because they will be on the search result or product pagesearch[machine wash men's t-shirts with long sleeve] 184 | # Instruction: [SEP] Find me machine wash men's t-shirts with long sleeve with color: black, and size: xx-large big tall, and price lower than 50.00 dollars [SEP] Back to Search [SEP] Page 1 (Total results: 50) [SEP] Next > [SEP] B09QQP3356 [SEP] HAUKLIE Men's Sports Waffle Ribbed Polo Shirts Summer Short Sleeve Cotton Muscle Quarter-Zip Henley T-Shirt Tunics Tops [SEP] $10.99 [SEP] B09Q8RD8YN [SEP] Bungo Stray Anime Dogs Anime Character, Long Sleeve, Sweatshirt, Hoodie, T shirt [SEP] $19.99 [SEP] B09QGK5XHZ [SEP] WENKOMG1 Men's Long Sleeve Undershirt with Mask Turtleneck Hooded T-Shirt Solid Color Workout Tops Zipper Side Slit Shirts Slim Fit Sweatshirt Spring/Summer Tee Shirts(Gray,) [SEP] $8.39 [SEP] B09QQJJ3KM [SEP] One Lucky Teacher St Patrick Day Shamrock Tee Teachers Custom Personalized Unisex T-Shirts Long Sleeve Hoodie Sweatshirt Gifts [SEP] $100.0 [SEP] B09S3BN15C [SEP] Mens Linen Shirt,Men's Striped Shirts Casual Short Sleeve Button Down Shirts Regular Fit Hawaiian Shirts Beach Tees Tops [SEP] $3.78 to $11.38 [SEP] B09ND9DP7J [SEP] InterestPrint Gold Horse Pattern Men's 2-Piece Sleepwear Set, Long Sleeve Shirt with Pants Loungewear [SEP] $43.59 [SEP] B09T756KQ5 [SEP] WENKOMG1 Men's Crewneck Goth Tee Shirts Skull Print Tops Spring/Summer Long Sleeve Sports T-Shirt Baggy Y2K Soft Streetwear [SEP] $4.09 to $8.09 [SEP] B09Q67H373 [SEP] JSPOYOU 2022 Newly T-Shirt for Mens Funny 3D Graphics Pattern Crewneck Short Sleeve Tees Big and Tall Summer Casual Comfy Top [SEP] $1.99 to $8.99 [SEP] B09R9YCM6R [SEP] XXBR Summer T-shirts for Mens, Soldier Short Sleeve 3D Street Vintage Printed Shirt Slim Fit Muscle Casual Tee Tops [SEP] $8.98 to $11.99 [SEP] B09KLQLLT2 [SEP] Long Sleeve Superhero T Shirt Tank Top Mens Compression Shirt Men Workout Fitness Gym Shirt [SEP] $19.99 185 | # Available actions: ['click[back to search]', 'click[next >]', 'click[b09qqp3356]', 'click[b09q8rd8yn]', 'click[b09qgk5xhz]', 'click[b09qqjj3km]', 'click[b09s3bn15c]', 'click[b09nd9dp7j]', 'click[b09t756kq5]', 'click[b09q67h373]', 'click[b09r9ycm6r]', 'click[b09klqllt2]'] 186 | # Product b09klqllt2 is a Men's Long Sleeve Superhero T Shirt priced 19.99$, which satisfies my needclick[b09klqllt2] 187 | # Instruction: [SEP] Find me machine wash men's t-shirts with long sleeve with color: black, and size: xx-large big tall, and price lower than 50.00 dollars [SEP] Back to Search [SEP] < Prev [SEP] color [SEP] black | blue [SEP] blue [SEP] blue | red [SEP] bronze [SEP] brown [SEP] gold [SEP] green [SEP] red [SEP] red | blue [SEP] yellow [SEP] grey | red [SEP] size [SEP] small [SEP] medium [SEP] large [SEP] x-large [SEP] xx-large [SEP] Long Sleeve Superhero T Shirt Tank Top Mens Compression Shirt Men Workout Fitness Gym Shirt [SEP] Price: $19.99 [SEP] Rating: N.A. 
[SEP] Description [SEP] Features [SEP] Reviews [SEP] Buy Now 188 | # Available actions: ['click[back to search]', 'click[< prev]', 'click[description]', 'click[features]', 'click[reviews]', 'click[buy now]', 'click[black | blue]', 'click[blue]', 'click[blue | red]', 'click[bronze]', 'click[brown]', 'click[gold]', 'click[green]', 'click[red]', 'click[red | blue]', 'click[yellow]', 'click[grey | red]', 'click[small]', 'click[medium]', 'click[large]', 'click[x-large]', 'click[xx-large]'] 189 | # Produce description and price looks good, now select the right colorclick[black | blue] 190 | # Instruction: [SEP] Find me machine wash men's t-shirts with long sleeve with color: black, and size: xx-large big tall, and price lower than 50.00 dollars [SEP] Back to Search [SEP] < Prev [SEP] color [SEP] black | blue [SEP] blue [SEP] blue | red [SEP] bronze [SEP] brown [SEP] gold [SEP] green [SEP] red [SEP] red | blue [SEP] yellow [SEP] grey | red [SEP] size [SEP] small [SEP] medium [SEP] large [SEP] x-large [SEP] xx-large [SEP] Long Sleeve Superhero T Shirt Tank Top Mens Compression Shirt Men Workout Fitness Gym Shirt [SEP] Price: $19.99 [SEP] Rating: N.A. [SEP] Description [SEP] Features [SEP] Reviews [SEP] Buy Now 191 | # Available actions: ['click[back to search]', 'click[< prev]', 'click[description]', 'click[features]', 'click[reviews]', 'click[buy now]', 'click[black | blue]', 'click[blue]', 'click[blue | red]', 'click[bronze]', 'click[brown]', 'click[gold]', 'click[green]', 'click[red]', 'click[red | blue]', 'click[yellow]', 'click[grey | red]', 'click[small]', 'click[medium]', 'click[large]', 'click[x-large]', 'click[xx-large]'] 192 | # I already selected color, now click size xx-largeclick[xx-large] 193 | # Instruction: [SEP] Find me machine wash men's t-shirts with long sleeve with color: black, and size: xx-large big tall, and price lower than 50.00 dollars [SEP] Back to Search [SEP] < Prev [SEP] color [SEP] black | blue [SEP] blue [SEP] blue | red [SEP] bronze [SEP] brown [SEP] gold [SEP] green [SEP] red [SEP] red | blue [SEP] yellow [SEP] grey | red [SEP] size [SEP] small [SEP] medium [SEP] large [SEP] x-large [SEP] xx-large [SEP] Long Sleeve Superhero T Shirt Tank Top Mens Compression Shirt Men Workout Fitness Gym Shirt [SEP] Price: $19.99 [SEP] Rating: N.A. [SEP] Description [SEP] Features [SEP] Reviews [SEP] Buy Now 194 | # Available actions: ['click[back to search]', 'click[< prev]', 'click[description]', 'click[features]', 'click[reviews]', 'click[buy now]', 'click[black | blue]', 'click[blue]', 'click[blue | red]', 'click[bronze]', 'click[brown]', 'click[gold]', 'click[green]', 'click[red]', 'click[red | blue]', 'click[yellow]', 'click[grey | red]', 'click[small]', 'click[medium]', 'click[large]', 'click[x-large]', 'click[xx-large]'] 195 | # I already selected color and size, now click buy nowclick[buy now] 196 | max_tokens: 200 197 | # env_config: null 198 | env_config: 199 | dataset: small # choose from small or full 200 | # data_path: /RAGEN-Dev/external/webshop-minimal/webshop_minimal 201 | # file_path: /RAGEN-Dev/external/webshop-minimal/webshop_minimal/data/items_shuffle.json -------------------------------------------------------------------------------- /SPA_agent/es_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the environment state manager for the LLM agent. 
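A minimal usage sketch (illustrative only; assumes a Hydra config such as config/base.yaml
with es_manager and custom_envs sections, mirroring the unit test in main() below):

    es = EnvStateManager(config, mode="train")
    cache = es.reset(seed=123)                  # one rollout-cache entry per environment
    outs = es.step([{"env_id": 0,
                     "llm_raw_response": "Go down",
                     "llm_response": "Go down",
                     "actions": ["down"]}])     # already-done envs are dropped from the output
    final = es.get_rollout_states_base()        # attaches per-env metrics to the cache
    es.close()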
3 | author: Pingyue Zhang 4 | date: 2025-03-30 5 | """ 6 | from dataclasses import dataclass, field 7 | from typing import Dict, List, Optional, Any, Union 8 | import PIL.Image 9 | import hydra 10 | import random 11 | import numpy as np 12 | import pdb 13 | 14 | import sys 15 | import os 16 | # Dynamically find RAGEN root directory 17 | current_dir = os.path.dirname(os.path.abspath(__file__)) 18 | ragen_root = os.path.abspath(os.path.join(current_dir, '..', '..')) 19 | sys.path.append(ragen_root) 20 | from ragen.env import REGISTERED_ENVS, REGISTERED_ENV_CONFIGS 21 | from ragen.utils import register_resolvers 22 | register_resolvers() 23 | 24 | @dataclass 25 | class EnvStatus: 26 | """Status of an environment""" 27 | truncated: bool = False # done but not success 28 | terminated: bool = False # done and success 29 | num_actions: int = 0 # current action step (single action) 30 | rewards: List[float] = field(default_factory=list) # rewards for each turn 31 | seed: Optional[int] = None # what seed is used to reset this environment 32 | 33 | 34 | 35 | class EnvStateManager: 36 | """Manager for the environment state 37 | The class is responsible for managing multiple (kinds of) environments 38 | 39 | """ 40 | def __init__(self, config, mode: str = "train"): 41 | self.sys_config = config 42 | self.mode = mode 43 | self.config = getattr(self.sys_config.es_manager, mode) 44 | self.env_groups = int(self.config.env_groups) 45 | self.group_size = self.config.group_size 46 | seed_cfg = getattr(self.sys_config, "seed", None) 47 | if seed_cfg is not None: 48 | self.base_seed = seed_cfg.get(mode, None) 49 | else: 50 | self.base_seed = None 51 | self.seed_counter = 0 52 | self._init_envs() 53 | self.rollout_cache = None 54 | 55 | def _init_envs(self): 56 | """Initialize the environments. train_envs and val_envs are lists of envs: 57 | Input: tags: ["SimpleSokoban", "HarderSokoban"]; n_groups: [1, 1]; group_size: 16 58 | Output: envs: List[Dict], each **entry** is a dict with keys: tag, group_id, env_id, env, env_config, status 59 | Example: [{"tag": "SimpleSokoban", "group_id": 0, "env_id": 0, "env": env, "config": env_config, "status": EnvStatus()}, 60 | ... 61 | {"tag": "SimpleSokoban", "group_id": 0, "env_id": 15 (group_size - 1), ...}, 62 | {"tag": "HarderSokoban", "group_id": 1, "env_id": 16, ...} 63 | ...] 64 | """ 65 | assert sum(self.config.env_configs.n_groups) == self.env_groups, f"Sum of n_groups must equal env_groups. Got sum({self.config.env_configs.n_groups}) != {self.env_groups}" 66 | assert len(self.config.env_configs.tags) == len(self.config.env_configs.n_groups), f"Number of tags must equal number of n_groups. 
Got {len(self.config.env_configs.tags)} != {len(self.config.env_configs.n_groups)}" 67 | self.envs = self._init_env_instances(self.config) 68 | 69 | def _init_env_instances(self, config): 70 | env_list = [] 71 | done_groups = 0 72 | for tag, n_group in zip(config.env_configs.tags, config.env_configs.n_groups): 73 | for env_id in range(done_groups * self.group_size, (done_groups + n_group) * self.group_size): 74 | cfg_template = self.sys_config.custom_envs[tag] 75 | env_class = cfg_template.env_type 76 | max_actions_per_traj = cfg_template.max_actions_per_traj 77 | if cfg_template.env_config is None: 78 | env_config = REGISTERED_ENV_CONFIGS[env_class]() 79 | else: 80 | env_config = REGISTERED_ENV_CONFIGS[env_class](**cfg_template.env_config) 81 | env_obj = REGISTERED_ENVS[env_class](env_config) 82 | entry = {'tag': tag, 'group_id': env_id // self.group_size, 'env_id': env_id, 83 | 'env': env_obj, 'config': env_config, 'status': EnvStatus(), 'max_actions_per_traj': max_actions_per_traj} 84 | env_list.append(entry) 85 | done_groups += n_group 86 | return env_list 87 | 88 | def reset(self, seed: Optional[int] = None): 89 | """ 90 | Reset the environments and get initial observation 91 | build up rollout cache like [{"env_id": int, "history": List[Dict], "group_id": int}, ...] 92 | """ 93 | def _expand_seed(seed: int): 94 | seeds = [[seed + i] * self.group_size for i in range(self.env_groups)] # [[seed, ..., seed], [seed+1, ..., seed+1], ...] 95 | return sum(seeds, []) 96 | 97 | envs = self.envs 98 | rollout_cache = [{"env_id": entry['env_id'], "history": [], "group_id": entry['group_id'], "tag": entry['tag'], "penalty": 0} for entry in envs] 99 | 100 | # reset all environments 101 | # if self.mode == "train": 102 | # seed = random.randint(0, 1000000) if seed is None else seed # get a random seed 103 | 104 | # else: 105 | # seed = 123 106 | # reset all environments 107 | if seed is None: 108 | if self.mode == "train": 109 | if self.base_seed is not None: 110 | seed = self.base_seed + self.seed_counter 111 | self.seed_counter += self.env_groups 112 | else: 113 | seed = random.randint(0, 1000000) 114 | else: 115 | seed = 123 if self.base_seed is None else self.base_seed 116 | else: 117 | if self.mode == "train" and self.base_seed is not None: 118 | self.seed_counter = seed - self.base_seed + 1 119 | # import pdb; pdb.set_trace() 120 | seeds = _expand_seed(seed) 121 | for seed, entry in zip(seeds, envs): 122 | entry['env'].reset(seed=seed) 123 | entry['status'] = EnvStatus(seed=seed) 124 | 125 | # update rollout cache 126 | for cache, env in zip(rollout_cache, envs): 127 | next_state = self._handle_mm_state(env['env'].render()) 128 | cache['history'] = self._update_cache_history(cache['history'], next_state=next_state, actions_left=env['max_actions_per_traj'], num_actions_info=None) 129 | 130 | self.rollout_cache = rollout_cache 131 | return rollout_cache 132 | 133 | def step(self, all_env_inputs: List[Dict]): 134 | """Step the environments. 135 | 1. extract valid actions from the action lookup table (if exists) and execute the actions, and update rollout cache 136 | 2. Since rollout does not need to act over done envs, whenever the environment is done, we only update rollout cache, but not output env_outputs. 
137 | Input: 138 | all_env_inputs: List[Dict] 139 | {env_id: int, llm_response: str, actions: List[str]} 140 | NOTE: should use env_id as index for existing some already done envs 141 | env_outputs: List[Dict] 142 | {env_id: int, history: List[Dict][{state: str, actions: List[str], reward: float, info: Dict, llm_response: str, llm_raw_response: str, (Optional)images: List[PIL.Image.Image]}]} 143 | """ 144 | def _execute_actions(env, actions): 145 | acc_reward, turn_info, turn_done = 0, {}, False 146 | executed_actions = [] 147 | for action in actions: 148 | obs, reward, done, info = env.step(action) 149 | # print(f"action: {action}, obs: {obs}, reward: {reward}, done: {done}, info: {info}") 150 | acc_reward += reward 151 | turn_info.update(info) # NOTE: currently use last info for multi-action 152 | executed_actions.append(action) 153 | if done: 154 | turn_done = True 155 | break 156 | return acc_reward, turn_info, turn_done, executed_actions 157 | 158 | def _log_env_state(status, history, cur_obs, max_actions_per_traj, executed_actions, all_actions, acc_reward, turn_done, turn_info, env_input): 159 | obs = self._handle_mm_state(cur_obs) 160 | status.num_actions += len(executed_actions) 161 | status.rewards.append(acc_reward) # NOTE use turn-wise acc_reward 162 | actions_left = max_actions_per_traj - status.num_actions 163 | if turn_done: 164 | status.terminated = True # TODO check terminated definition in gymnasium 165 | status.truncated = not turn_info.get('success', False) 166 | history = self._update_cache_history(history, next_state=obs, actions_left=actions_left, num_actions_info={ 167 | 'actions': executed_actions, 'reward': acc_reward, 'info': turn_info, 168 | 'llm_response': env_input['llm_response'], 'llm_raw_response': env_input['llm_raw_response'] 169 | }) 170 | # filter out invalid actions 171 | # history = [content for content in history[:-1] if content['actions']] + [history[-1]] 172 | return status, history 173 | 174 | envs = self.envs 175 | env_outputs = [] 176 | 177 | for env_input in all_env_inputs: 178 | acc_reward, turn_info, turn_done = 0, {}, False 179 | entry = envs[env_input['env_id']] 180 | env_id, env = entry['env_id'], entry['env'] 181 | actions_left_before = entry['max_actions_per_traj'] - entry['status'].num_actions 182 | 183 | # execute actions in envs 184 | valid_actions = self._extract_map_valid_actions(entry, env_input['actions']) 185 | acc_reward, turn_info, turn_done, executed_actions = _execute_actions(env, valid_actions[:actions_left_before]) 186 | # print(f"acc_reward, turn_info, turn_done, executed_actions: {acc_reward}, {turn_info}, {turn_done}, {executed_actions}") 187 | if len(valid_actions) != len(env_input['actions']) or not valid_actions: 188 | self.rollout_cache[env_id]["penalty"] += self.sys_config.es_manager.format_penalty 189 | 190 | status, history = _log_env_state(entry['status'], self.rollout_cache[env_id]['history'], entry['env'].render(), entry['max_actions_per_traj'], executed_actions, valid_actions, acc_reward, turn_done, turn_info, env_input) 191 | entry['status'] = status 192 | if entry['status'].num_actions >= entry['max_actions_per_traj'] and not turn_done: 193 | entry['status'].truncated = True 194 | entry['status'].terminated = True 195 | turn_done = True 196 | self.rollout_cache[env_id]['history'] = history 197 | if not turn_done: # NOTE done environments are not sent for further llm generation (for efficiency) 198 | env_outputs.append(self.rollout_cache[env_id]) 199 | # print(all_env_inputs) 200 | # 
print(entry['status'].num_actions) 201 | # print([env["status"].num_actions for env in envs]) 202 | # calcuate terminated but not truncated 203 | # print([env["status"].terminated and (not env["status"].truncated) for env in envs]) 204 | # print('success_ratio: ', np.mean([env["status"].terminated and (not env["status"].truncated) for env in envs])) 205 | # Calculate pass@k metrics using reshape 206 | success = np.array([float(env["status"].terminated and (not env["status"].truncated)) for env in envs]) 207 | success = success.reshape(self.env_groups, self.group_size) 208 | 209 | 210 | # # Calculate and print pass rates 211 | # full_pass = np.mean(np.any(success, axis=1)) 212 | # print(f"Full pass rate: {full_pass:.3f}") 213 | 214 | # for k in [4, 8, 16, 32, 64, 128, 256, 512, 1024]: 215 | # sampled = np.array([np.random.choice(group, size=k, replace=False) for group in success]) 216 | # pass_k = np.mean(np.any(sampled, axis=1)) 217 | # print(f"Pass@{k} rate: {pass_k:.3f}") 218 | 219 | # # breakpoint() 220 | # pdb.set_trace() 221 | 222 | return env_outputs 223 | 224 | def get_rollout_states_add_worldmodel(self): 225 | """Get the final output for all environment""" 226 | envs = self.envs 227 | rollout_cache = self.rollout_cache 228 | all_states = [] 229 | max_len = 0 230 | # add metrics to rollout cache 231 | for entry, cache in zip(envs, rollout_cache): 232 | status = entry['status'] 233 | env_metric = { 234 | 'success': float(status.terminated and (not status.truncated)), 235 | 'num_actions': status.num_actions, 236 | } 237 | this_states = [] 238 | custom_metric = {} 239 | for turn in cache['history']: 240 | this_states.append(turn['state']) 241 | for k, v in turn.get('info', {}).items(): 242 | if k == 'success': 243 | continue 244 | if k not in custom_metric: 245 | custom_metric[k] = [] 246 | custom_metric[k].append(float(v)) 247 | for k, v in custom_metric.items(): 248 | env_metric[k] = np.sum(v) / (len(cache['history']) - 1) # NOTE: exclude the last observation 249 | 250 | cache['history'][-1]['metrics'] = custom_metric 251 | env_metric = {f"{entry['tag']}/{k}": v for k, v in env_metric.items()} 252 | cache['metrics'] = env_metric 253 | all_states.append(this_states) 254 | max_len = max(max_len, len(this_states)) 255 | 256 | # Pad all states to the same length 257 | padded_states = [] 258 | for states in all_states: 259 | if len(states) < max_len: 260 | # Pad with the last state 261 | states = states + [states[-1]] * (max_len - len(states)) 262 | padded_states.append(states) 263 | 264 | # Compute pass@k (group-level success) where a group is successful if any env in the group succeeds 265 | group_success = {} 266 | 267 | for entry in envs: 268 | gid = entry['group_id'] 269 | env_success = float(entry['status'].terminated and (not entry['status'].truncated)) 270 | if gid not in group_success: 271 | group_success[gid] = env_success 272 | 273 | else: 274 | group_success[gid] = max(group_success[gid], env_success) 275 | 276 | 277 | group_success_counts: Dict[int, float] = {} # successful envs count per group 278 | group_total_counts: Dict[int, int] = {} # total envs count per group 279 | 280 | # Aggregate successes and totals for every group 281 | for entry in envs: 282 | gid = entry['group_id'] 283 | success_flag = float(entry['status'].terminated and (not entry['status'].truncated)) 284 | group_success_counts[gid] = group_success_counts.get(gid, 0.0) + success_flag 285 | group_total_counts[gid] = group_total_counts.get(gid, 0) + 1 286 | 287 | # Calculate per-group mean success (success 
rate) 288 | group_success_mean = { 289 | gid: group_success_counts[gid] / group_total_counts[gid] 290 | for gid in group_total_counts 291 | } 292 | 293 | # Overall pass_mean_k is the average of per-group success rates 294 | mean_k = float(np.mean(list(group_success_mean.values()))) if group_success_mean else 0.0 295 | 296 | # Record per-env metrics for success_mean_k / success_meank 297 | for entry, cache in zip(envs, rollout_cache): 298 | tag = entry['tag'] 299 | gid = entry['group_id'] 300 | cache['metrics'][f"{tag}/success_mean_k"] = float(group_success_mean[gid]) 301 | # cache['metrics'][f"{tag}/success_meank"] = float(group_success_mean[gid]) 302 | 303 | 304 | # Override the stored success metric with group-level success 305 | for entry, cache in zip(envs, rollout_cache): 306 | tag = entry['tag'] 307 | cache['metrics'][f"{tag}/success_passk"] = float(group_success[entry['group_id']]) 308 | 309 | return rollout_cache, np.array(padded_states) 310 | 311 | def get_rollout_states_base(self): 312 | """Get the final output for all environment""" 313 | envs = self.envs 314 | rollout_cache = self.rollout_cache 315 | 316 | # add metrics to rollout cache 317 | for entry, cache in zip(envs, rollout_cache): 318 | status = entry['status'] 319 | env_metric = { 320 | 'success': float(status.terminated and (not status.truncated)), 321 | 'num_actions': status.num_actions, 322 | } 323 | custom_metric = {} 324 | for turn in cache['history']: 325 | for k, v in turn.get('info', {}).items(): 326 | # print(f"k: {k}, v: {v}") 327 | # pdb.set_trace() 328 | if k == 'success': 329 | continue 330 | if k not in custom_metric: 331 | custom_metric[k] = [] 332 | 333 | custom_metric[k].append(float(v)) 334 | for k, v in custom_metric.items(): 335 | env_metric[k] = np.sum(v) / (len(cache['history']) - 1) # NOTE: exclude the last observation 336 | 337 | cache['history'][-1]['metrics'] = custom_metric 338 | env_metric = {f"{entry['tag']}/{k}": v for k, v in env_metric.items()} 339 | cache['metrics'] = env_metric 340 | 341 | # Compute pass@k (group-level success) where a group is successful if any env in the group succeeds 342 | group_success = {} 343 | 344 | for entry in envs: 345 | gid = entry['group_id'] 346 | env_success = float(entry['status'].terminated and (not entry['status'].truncated)) 347 | if gid not in group_success: 348 | group_success[gid] = env_success 349 | 350 | else: 351 | group_success[gid] = max(group_success[gid], env_success) 352 | 353 | 354 | group_success_counts: Dict[int, float] = {} # successful envs count per group 355 | group_total_counts: Dict[int, int] = {} # total envs count per group 356 | 357 | # Aggregate successes and totals for every group 358 | for entry in envs: 359 | gid = entry['group_id'] 360 | success_flag = float(entry['status'].terminated and (not entry['status'].truncated)) 361 | group_success_counts[gid] = group_success_counts.get(gid, 0.0) + success_flag 362 | group_total_counts[gid] = group_total_counts.get(gid, 0) + 1 363 | 364 | # Calculate per-group mean success (success rate) 365 | group_success_mean = { 366 | gid: group_success_counts[gid] / group_total_counts[gid] 367 | for gid in group_total_counts 368 | } 369 | 370 | # Overall pass_mean_k is the average of per-group success rates 371 | mean_k = float(np.mean(list(group_success_mean.values()))) if group_success_mean else 0.0 372 | 373 | # Record per-env metrics for success_mean_k / success_meank 374 | for entry, cache in zip(envs, rollout_cache): 375 | tag = entry['tag'] 376 | gid = entry['group_id'] 377 | 
cache['metrics'][f"{tag}/success_mean_k"] = float(group_success_mean[gid]) 378 | # cache['metrics'][f"{tag}/success_meank"] = float(group_success_mean[gid]) 379 | 380 | 381 | # Override the stored success metric with group-level success 382 | for entry, cache in zip(envs, rollout_cache): 383 | tag = entry['tag'] 384 | cache['metrics'][f"{tag}/success_passk"] = float(group_success[entry['group_id']]) 385 | return rollout_cache 386 | 387 | 388 | def _update_cache_history(self, history: List[Dict], next_state, actions_left, num_actions_info: Optional[Dict] = None): 389 | """ 390 | Update last step info and append state to history 391 | """ 392 | if num_actions_info is not None: # update last step info 393 | assert len(history), "History should not be empty" 394 | history[-1].update(num_actions_info) 395 | 396 | entry = {} # append state to history 397 | if isinstance(next_state, str): # text state 398 | entry['state'] = next_state 399 | else: # multimodal state 400 | entry['state'] = "" * len(next_state) 401 | entry['images'] = next_state 402 | entry['actions_left'] = actions_left 403 | history.append(entry) 404 | return history 405 | 406 | def _extract_map_valid_actions(self, entry: Dict, actions: List[str]): 407 | """extract valid actions from the action lookup table (if exists)""" 408 | mapped_actions = [] 409 | action_lookup = getattr(entry['env'].config, 'action_lookup', None) 410 | if action_lookup is None: 411 | mapped_actions = actions 412 | else: # the envs have pre-defined action lookup 413 | rev_action_lookup = {v.lower(): k for k, v in action_lookup.items()} 414 | actions = [action.lower() for action in actions] 415 | mapped_actions = [rev_action_lookup[action] for action in actions if action in rev_action_lookup] 416 | def match_template(action: str, templates: dict) -> bool: 417 | action_tokens = action.split() 418 | action_tokens = [token for token in action_tokens if not token.isdigit()] # remove number in action 419 | for template in templates: 420 | template_tokens = template.split() 421 | 422 | if len(action_tokens) != len(template_tokens): 423 | continue 424 | 425 | match = True 426 | for atok, ttok in zip(action_tokens, template_tokens): 427 | if ttok.startswith('<') and ttok.endswith('>'): 428 | continue 429 | if atok != ttok: 430 | match = False 431 | break 432 | 433 | if match: 434 | return True 435 | 436 | return False 437 | 438 | if 'look' in rev_action_lookup.keys(): 439 | mapped_actions = [action for action in actions if match_template(action, rev_action_lookup)] 440 | return mapped_actions 441 | 442 | def _handle_mm_state(self, state: Union[str, np.ndarray, list[np.ndarray]]): 443 | """Handle the state from the environment 444 | """ 445 | if isinstance(state, str): # text state 446 | return state 447 | elif isinstance(state, np.ndarray): # when env state is a single image, convert it to a list to unify output format 448 | state = [state] 449 | results = [PIL.Image.fromarray(_state, mode='RGB') for _state in state] 450 | return results 451 | 452 | def render(self): 453 | rendered_list = [entry['env'].render() for entry in self.envs] 454 | return rendered_list 455 | 456 | def close(self): 457 | for entry in self.envs: 458 | entry['env'].close() 459 | 460 | 461 | 462 | 463 | @hydra.main(version_base=None, config_path="../../config", config_name="base") 464 | def main(config): 465 | """ 466 | Unit test for EnvStateManager 467 | """ 468 | es_manager = EnvStateManager(config, mode="train") 469 | print("Initializing environments...") 470 | es_manager.reset(seed=123) 471 
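    # reset(seed=123) expands the seed per group: every env in group g is reset with
    # seed 123 + g, so members of one group start from an identical state (which the
    # context manager's group-wise reward normalization can exploit) while groups differ.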
| 472 | renders = es_manager.render() 473 | for i, render in enumerate(renders[:4]): # Show first 2 environments 474 | print(f"Environment {i}:\n{render}\n") 475 | 476 | print("\nRunning step for training environments...") 477 | all_env_inputs = [ 478 | { 479 | "env_id": 0, 480 | "llm_raw_response": "Go down", 481 | "llm_response": "Go down", 482 | "actions": ["down"] 483 | }, 484 | { 485 | "env_id": 3, 486 | "llm_raw_response": "Go down", 487 | "llm_response": "Go down", 488 | "actions": ["down"] 489 | } 490 | ] 491 | env_outputs = es_manager.step(all_env_inputs) 492 | print(f"Active environments after step: {len(env_outputs)}") 493 | print(f"env_outputs[:2]: {env_outputs[:2]}") 494 | 495 | renders = es_manager.render() 496 | for i, render in enumerate(renders[:4]): # Show first 2 environments 497 | print(f"Environment {i}:\n{render}\n") 498 | 499 | all_env_inputs = [ 500 | { 501 | "env_id": 0, 502 | "llm_raw_response": "Go left, go up", 503 | "llm_response": "Go left, go up", 504 | "actions": ["left", "up"] 505 | }, 506 | { 507 | "env_id": 3, 508 | "llm_raw_response": "Go up, go up", 509 | "llm_response": "Go up, go up", 510 | "actions": ["up", "up", "up", "up", "up"] 511 | } 512 | ] 513 | env_outputs = es_manager.step(all_env_inputs) 514 | print(f"Active environments after step: {len(env_outputs)}") 515 | print(f"env_outputs[:2]: {env_outputs[:2]}") 516 | 517 | renders = es_manager.render() 518 | for i, render in enumerate(renders[:4]): # Show first 2 environments 519 | print(f"Environment {i}:\n{render}\n") 520 | 521 | print("\nRendering final output...") 522 | final_outputs = es_manager.get_rollout_states() 523 | print(f"final outputs[:4]: {final_outputs[:4]}") 524 | 525 | print("\nClosing environments...") 526 | es_manager.close() 527 | print("Test completed successfully!") 528 | 529 | 530 | if __name__ == "__main__": 531 | main() 532 | -------------------------------------------------------------------------------- /SPA_agent/ctx_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the context manager for the LLM agent. 3 | author: Kangrui Wang, Zihan Wang 4 | date: 2025-03-30 5 | """ 6 | import torch 7 | import numpy as np 8 | import pdb 9 | from typing import List, Dict, Any, Optional, Union 10 | from dataclasses import dataclass 11 | import re 12 | from verl import DataProto 13 | from verl.utils.dataset.rl_dataset import collate_fn 14 | from transformers import AutoTokenizer 15 | import hydra 16 | import sys 17 | import os 18 | # Dynamically find RAGEN root directory 19 | current_dir = os.path.dirname(os.path.abspath(__file__)) 20 | ragen_root = os.path.abspath(os.path.join(current_dir, '..', '..')) 21 | sys.path.append(ragen_root) 22 | from ragen.utils import register_resolvers 23 | from ragen.env import REGISTERED_ENV_CONFIGS 24 | from tensordict import TensorDict 25 | 26 | from dataclasses import asdict 27 | register_resolvers() 28 | 29 | def get_masks_and_scores_add_worldmodel(input_ids: torch.Tensor, tokenizer: AutoTokenizer, all_scores: List[List[float]] = None, llm_input_texts= None,observations= None,use_turn_scores: bool = False): 30 | """ 31 | input_ids: shape (bsz, seq_len) 32 | Get loss mask that only learns between <|im_start|>assistant and <|im_end|>. Currently only supports qwen. 33 | NOTE: important! 
This assumes that the input_ids starts with system and then user & assistant in alternative ways 34 | """ 35 | special_token = tokenizer.encode("<|im_start|>")[0] 36 | turn_starts = torch.where(input_ids == special_token, 1, 0) 37 | turn_indicators = torch.cumsum(turn_starts, dim=-1) 38 | response_mask = (turn_indicators % 2 == 1) & (turn_indicators > 1) # only learns all assistant turns 39 | loss_mask = (turn_indicators > 1) # learns everything after system prompt 40 | reward_token = tokenizer.encode("<|im_end|>")[0] 41 | score_tensor = torch.zeros_like(input_ids, dtype=torch.float32) 42 | 43 | # Get state tokens 44 | start_state_tokens = [tokenizer.encode(">", add_special_tokens=False),tokenizer.encode("", add_special_tokens=False)] 45 | end_state_tokens = tokenizer.encode("", add_special_tokens=False) 46 | 47 | # Create state mask 48 | state_mask = torch.zeros_like(input_ids, dtype=torch.bool) 49 | for i in range(input_ids.shape[0]): 50 | # Find start positions by checking for continuous sequence of start_state_tokens 51 | start_positions = [] 52 | for j in range(input_ids.shape[1] - len(start_state_tokens[0]) + 1): 53 | if torch.all(input_ids[i, j:j+len(start_state_tokens[0])] == torch.tensor(start_state_tokens[0], device=input_ids.device)) \ 54 | or torch.all(input_ids[i, j:j+len(start_state_tokens[1])] == torch.tensor(start_state_tokens[1], device=input_ids.device)): 55 | start_positions.append(j) 56 | 57 | # Find end positions by checking for continuous sequence of end_state_tokens 58 | end_positions = [] 59 | for j in range(input_ids.shape[1] - len(end_state_tokens) + 1): 60 | if torch.all(input_ids[i, j:j+len(end_state_tokens)] == torch.tensor(end_state_tokens, device=input_ids.device)): 61 | end_positions.append(j) 62 | # pdb.set_trace() 63 | # For each turn, find the first valid pair of start and end tokens 64 | if len(start_positions) > 0 and len(end_positions) > 0: 65 | # Get the turn number for each position 66 | start_turns = turn_indicators[i, start_positions] 67 | end_turns = turn_indicators[i, end_positions] 68 | 69 | # For each turn, find the first valid pair 70 | for turn in range(1, turn_indicators[i].max().item() + 1): 71 | # Find start and end positions in this turn 72 | turn_start_positions = [pos for pos, t in zip(start_positions, start_turns) if t == turn] 73 | turn_end_positions = [pos for pos, t in zip(end_positions, end_turns) if t == turn] 74 | if turn_start_positions and turn_end_positions: 75 | # Take the first start token in this turn 76 | first_start = turn_start_positions[0] 77 | # Find the first end token that comes after the first start 78 | valid_ends = [end for end in turn_end_positions if end > first_start] 79 | if valid_ends: 80 | first_end = valid_ends[0] 81 | state_mask[i, first_start+len(start_state_tokens[0]):first_end] = True 82 | 83 | if use_turn_scores: 84 | for idx, scores in enumerate(list(zip(*all_scores))): 85 | scores = torch.tensor(scores, dtype=torch.float32, device=input_ids.device) 86 | turn_indicator = idx * 2 + 3 # 0: pad. 1: system. 2+2n: user. 
3+2n: assistant 87 | reward_position = (input_ids == reward_token) & (turn_indicators == turn_indicator) 88 | score_tensor[reward_position] = scores 89 | scores_ = [sum(i) for i in all_scores] 90 | score_tensor[:, -1] = torch.tensor(scores_, dtype=torch.float32, device=input_ids.device) 91 | else: 92 | scores = [sum(i) for i in all_scores] 93 | score_tensor[:, -1] = torch.tensor(scores, dtype=torch.float32, device=input_ids.device) 94 | 95 | 96 | loss_mask = loss_mask[:, :-1] # remove the last token 97 | score_tensor = score_tensor[:, 1:] # remove the first token 98 | 99 | return loss_mask, score_tensor, response_mask, state_mask 100 | def get_masks_and_scores_base(input_ids: torch.Tensor, tokenizer: AutoTokenizer, all_scores: List[List[float]] = None, use_turn_scores: bool = False): 101 | """ 102 | input_ids: shape (bsz, seq_len) 103 | Get loss mask that only learns between <|im_start|>assistant and <|im_end|>. Currently only supports qwen. 104 | NOTE: important! This assumes that the input_ids starts with system and then user & assistant in alternative ways 105 | """ 106 | special_token = tokenizer.encode("<|im_start|>")[0] 107 | turn_starts = torch.where(input_ids == special_token, 1, 0) 108 | turn_indicators = torch.cumsum(turn_starts, dim=-1) 109 | response_mask = (turn_indicators % 2 == 1) & (turn_indicators > 1) # only learns all assistant turns 110 | loss_mask = (turn_indicators > 1) # learns everything after system prompt 111 | 112 | reward_token = tokenizer.encode("<|im_end|>")[0] 113 | score_tensor = torch.zeros_like(input_ids, dtype=torch.float32) 114 | 115 | if use_turn_scores: 116 | for idx, scores in enumerate(list(zip(*all_scores))): 117 | scores = torch.tensor(scores, dtype=torch.float32) 118 | turn_indicator = idx * 2 + 3 # 0: pad. 1: system. 2+2n: user. 3+2n: assistant 119 | reward_position = (input_ids == reward_token) & (turn_indicators == turn_indicator) 120 | score_tensor[reward_position] = scores 121 | scores_ = [sum(i) for i in all_scores] 122 | score_tensor[:, -1] = torch.tensor(scores_, dtype=torch.float32) 123 | else: 124 | scores = [sum(i) for i in all_scores] 125 | score_tensor[:, -1] = torch.tensor(scores, dtype=torch.float32) 126 | loss_mask = loss_mask[:, :-1] # remove the last token 127 | score_tensor = score_tensor[:, 1:] # remove the first token 128 | 129 | return loss_mask, score_tensor, response_mask 130 | 131 | 132 | 133 | 134 | class ContextManager: 135 | """ 136 | Manages the context for LLM interactions with environments. 137 | Translates between environment outputs and LLM inputs, and vice versa. 138 | """ 139 | 140 | def __init__(self, 141 | config, 142 | tokenizer, 143 | processor = None, 144 | mode: str = "train", 145 | ): 146 | """ 147 | Initialize the ContextManager. 148 | Processor is used to process the image data. 
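        config.ctx_manager.mode selects the prompting scheme: "base" parses only the
        model's think/answer spans, while "add_worldmodel" additionally parses the
        predicted current-state and next-state spans (see _parse_response_base and
        _parse_response_add_worldmodel below).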
149 | """ 150 | self.config = config 151 | self.tokenizer = tokenizer 152 | self.processor = processor 153 | self.action_sep = self.config.agent_proxy.action_sep 154 | 155 | 156 | if self.config.ctx_manager.mode=='add_worldmodel': 157 | self.special_token_list = ["", "", "", "", "", "","", "", "<|im_start|>", "<|im_end|>"] 158 | elif self.config.ctx_manager.mode=='base': 159 | self.special_token_list = ["", "", "", "", "<|im_start|>", "<|im_end|>"] 160 | 161 | self.es_cfg = self.config.es_manager[mode] 162 | self.env_nums = { 163 | env_tag: n_group * self.es_cfg.group_size 164 | for n_group, env_tag in zip(self.es_cfg.env_configs.n_groups, self.es_cfg.env_configs.tags) 165 | } 166 | self._init_prefix_lookup() 167 | 168 | def _init_prefix_lookup(self): 169 | prefix_lookup = {} 170 | prefixes = {} 171 | env_config_lookup = {} 172 | env_config = {} 173 | for env_tag, env_config in self.config.custom_envs.items(): 174 | if env_tag not in self.es_cfg.env_configs.tags: 175 | continue 176 | env_config_new = asdict(REGISTERED_ENV_CONFIGS[env_config.env_type]()) 177 | for k,v in env_config.items(): 178 | env_config_new[k] = v 179 | if self.config.ctx_manager.mode=='add_worldmodel': 180 | if env_config_new.get("env_instruction_add_worldmodel", True): 181 | env_instruction = env_config_new.get("env_instruction_add_worldmodel", "") 182 | else: 183 | env_instruction = env_config_new.get("env_instruction", "") 184 | 185 | else: 186 | env_instruction = env_config_new.get("env_instruction", "") 187 | if env_config_new.get("grid_vocab", False): 188 | grid_vocab_str = "\nThe meaning of each symbol in the state is:\n" + ", ".join([f"{k}: {v}" for k, v in env_config_new["grid_vocab"].items()]) 189 | env_instruction += grid_vocab_str 190 | if env_config_new.get("action_lookup", False): 191 | action_lookup_str = "\nYour available actions are:\n" + ", ".join([f"{v}" for k, v in env_config_new["action_lookup"].items()]) 192 | action_lookup_str += f"\nYou can make up to {env_config_new['max_actions_per_traj']} actions, separated by the action separator \" " + self.action_sep + " \"\n" 193 | env_instruction += action_lookup_str 194 | prefixes[env_tag] = env_instruction 195 | env_config_lookup[env_tag] = {'max_tokens': env_config.get("max_tokens", self.config.actor_rollout_ref.rollout.response_length)} 196 | 197 | tags = self.es_cfg.env_configs.tags 198 | n_groups = self.es_cfg.env_configs.n_groups 199 | group_size = self.es_cfg.group_size 200 | 201 | cur_group = 0 202 | for env_tag, n_group in zip(tags, n_groups): 203 | env_instruction = prefixes[env_tag] 204 | start_idx = cur_group * group_size 205 | end_idx = (cur_group + n_group) * group_size 206 | for i in range(start_idx, end_idx): 207 | prefix_lookup[i] = env_instruction 208 | env_config_lookup[i] = env_config_lookup[env_tag] 209 | cur_group += n_group 210 | 211 | self.prefix_lookup = prefix_lookup 212 | self.env_config_lookup = env_config_lookup 213 | 214 | def _parse_response_base(self, response: str) -> List: 215 | pattern = r'(.*?)\s*(.*?)' if self.config.agent_proxy.enable_think else r'(.*?)' 216 | match = re.search(pattern, response, re.DOTALL) 217 | if not match: 218 | # think_content, action_content, actions = "", "", [] # do not remove this kind of invalid string 219 | llm_response, actions = response, [] 220 | else: 221 | if self.config.agent_proxy.enable_think: 222 | think_content, action_content = match.group(1), match.group(2) 223 | else: 224 | think_content, action_content = "", match.group(1) 225 | 226 | for special_token in 
self.special_token_list: 227 | action_content = action_content.replace(special_token, "").strip() 228 | think_content = think_content.replace(special_token, "").strip() 229 | 230 | actions = [action.strip() for action in action_content.split(self.action_sep) if action.strip()] 231 | max_actions = self.config.agent_proxy.max_actions_per_turn 232 | 233 | if len(actions) > max_actions: 234 | actions = actions[:max_actions] #Only the first MAX_ACTIONS actions are kept in the rollout. 235 | action_content = (" " + self.action_sep + " ").join(actions) 236 | 237 | llm_response = f"{think_content}{action_content}" if self.config.agent_proxy.enable_think else f"{action_content}" 238 | return llm_response, actions 239 | def _parse_response_add_worldmodel(self, response: str) -> List: 240 | 241 | pattern = r'\s*(.*?)(.*?)(.*?)\s*\s*(.*?)' if self.config.agent_proxy.enable_think else r'(.*?)' 242 | match = re.search(pattern, response, re.DOTALL) 243 | if not match: 244 | answer_pattern = r'(.*?)' 245 | answer_match = re.search(answer_pattern, response) 246 | if answer_match: 247 | llm_response, cur_state_content, next_state_content, actions = response, [],[],answer_match.group(1) 248 | else: 249 | llm_response, cur_state_content, next_state_content, actions = response, [],[],[] 250 | else: 251 | if self.config.agent_proxy.enable_think: 252 | cur_state_content, think_content, next_state_content, action_content = match.group(1), match.group(2),match.group(3),match.group(4) 253 | else: 254 | cur_state_content, next_state_content, action_content = "","", match.group(1) 255 | for special_token in self.special_token_list: 256 | action_content = action_content.replace(special_token, "").strip() 257 | think_content = think_content.replace(special_token, "").strip() 258 | cur_state_content = cur_state_content.replace(special_token, "").strip() 259 | next_state_content = next_state_content.replace(special_token, "").strip() 260 | 261 | actions = [action.strip() for action in action_content.split(self.action_sep) if action.strip()] 262 | max_actions = self.config.agent_proxy.max_actions_per_turn 263 | 264 | if len(actions) > max_actions: 265 | actions = actions[:max_actions] #Only the first MAX_ACTIONS actions are kept in the rollout. 266 | action_content = (" " + self.action_sep + " ").join(actions) 267 | llm_response = f" {cur_state_content} {think_content} {next_state_content}{action_content}" if self.config.agent_proxy.enable_think else f"{action_content}" 268 | return llm_response, actions, cur_state_content, next_state_content 269 | 270 | 271 | def _normalize_score_tensor(self, score_tensor: torch.Tensor, env_outputs: List[Dict]) -> torch.Tensor: 272 | """ 273 | Normalize the score tensor to be between 0 and 1. 
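        Rollouts are grouped by reward_normalization.grouping ("state" -> group_id,
        "inductive" -> env tag, "batch" -> one group) and the last-token score of each
        group is normalized with reward_normalization.method; e.g. "mean_std" computes
            x_norm = (x - mean(x)) / (std(x) + 1e-6)
        after which the accumulated format penalty is added back per environment.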
274 | NOTE: only support score at the last token for now 275 | """ 276 | # assert self.config.agent_proxy.use_turn_scores == False, "Reward normalization is not supported for use_turn_scores == True" 277 | 278 | # score_tensor.shape : torch.Size([128, 2627]) 279 | rn_cfg = self.config.agent_proxy.reward_normalization 280 | grouping, method = rn_cfg.grouping, rn_cfg.method 281 | if grouping == "state": 282 | group_tags = [env_output["group_id"] for env_output in env_outputs] 283 | elif grouping == "inductive": 284 | group_tags = [env_output["tag"] for env_output in env_outputs] 285 | elif grouping == "batch": 286 | group_tags = [1] * len(env_outputs) 287 | else: 288 | raise ValueError(f"Invalid grouping: {grouping}") 289 | 290 | 291 | if method == "mean_std": 292 | norm_func = lambda x: (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + 1e-6) if x.std(dim=-1, keepdim=True).abs().max() > 1e-6 else torch.zeros_like(x) # stable to bf16 than x.std() 293 | elif method == "mean": 294 | norm_func = lambda x: (x - x.mean(dim=-1, keepdim=True)) 295 | elif method == "asym_clip": 296 | norm_func = lambda x: ((x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + 1e-6) if x.std(dim=-1, keepdim=True).abs().max() > 1e-6 else torch.zeros_like(x)).clamp(min=-1, max=3) 297 | elif method == "identity": 298 | norm_func = lambda x: x 299 | else: 300 | raise ValueError(f"Invalid normalization method: {method}") 301 | 302 | # apply groupwise normalization 303 | group2index = {} 304 | for i, env_tag in enumerate(group_tags): 305 | if env_tag not in group2index: 306 | group2index[env_tag] = [] 307 | group2index[env_tag].append(i) 308 | group2index = {k: torch.tensor(v) for k, v in group2index.items()} 309 | 310 | 311 | acc_scores = score_tensor[:, -1] 312 | normalized_acc_scores = acc_scores.clone() 313 | for group, index in group2index.items(): 314 | normalized_acc_scores[index] = norm_func(normalized_acc_scores[index]) 315 | 316 | # apply penalty 317 | penalty = torch.tensor([env_output["penalty"] for env_output in env_outputs], dtype=torch.float32) 318 | normalized_acc_scores = normalized_acc_scores + penalty 319 | 320 | score_tensor[:, -1] = normalized_acc_scores 321 | 322 | return score_tensor 323 | 324 | def get_lm_inputs(self, env_outputs: List[Dict], prepare_for_update: bool) -> DataProto: 325 | """ 326 | env_outputs - please see below example 327 | [ 328 | {"env_id": 1, "history": [{"state": "###\n#x_#", "llm_response": "Response 1", "reward": 0.5}, {"state": "###\n#x_#"}]}, 329 | {"env_id": 2, "history": [{"state": "###\n#x_#"}]}, 330 | ... 331 | ] 332 | prefix_lookup - from env_id to initial prompt 333 | """ 334 | llm_input_texts = [] 335 | messages_list = [] # for api calling 336 | for env_output in env_outputs: 337 | if 'state' in env_output['history'][-1] and prepare_for_update: 338 | env_output['history'] = env_output['history'][:-1] # when prepare for update, we do not add the state from the n+1 turn to the trajectory 339 | messages = [ 340 | {"role": "system", "content": f"You're a helpful assistant. 
"}, 341 | {"role": "user", "content": self.prefix_lookup[env_output["env_id"]]} 342 | ] 343 | 344 | for idx, content in enumerate(env_output["history"]): 345 | messages[-1]["content"] += f"\nTurn {idx + 1}:\n" 346 | if "state" in content: 347 | 348 | if self.config.ctx_manager.mode=='add_worldmodel': 349 | FORMAT_PROMPT = " [current state] [Your thoughts] [next state] [your answer] " if self.config.agent_proxy.enable_think else " [your answer] " 350 | elif self.config.ctx_manager.mode=='base': 351 | FORMAT_PROMPT = " [Your thoughts] [your answer] " if self.config.agent_proxy.enable_think else " [your answer] " 352 | LENGTH_PROMPT = f"Max response length: {self.env_config_lookup[env_output['env_id']]['max_tokens']} words (tokens)." 353 | messages[-1]["content"] += f"State:\n{content['state']}\nYou have {content['actions_left']} actions left. Always output: {FORMAT_PROMPT} with no extra text. Strictly follow this format. {LENGTH_PROMPT}\n" 354 | if "llm_response" in content: 355 | messages.append({"role": "assistant", "content": content["llm_response"]}) 356 | if "reward" in content and not (prepare_for_update and idx == len(env_output["history"]) - 1): 357 | # when prepare for update, we do not add the reward from the n+1 turn to the trajectory 358 | messages.append({"role": "user", "content": f"Reward:\n{content['reward']}\n"}) 359 | 360 | # NOTE: this assertion is important for loss mask computation 361 | assert all(msg["role"] == "assistant" for msg in messages[2::2]) 362 | text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=(not prepare_for_update), tokenize=False) 363 | if not prepare_for_update: 364 | if self.config.agent_proxy.enable_think: 365 | text += "" # force the LLM to think before answering 366 | else: 367 | text += "" # force the LLM to answer 368 | llm_input_texts.append(text) 369 | messages_list.append(messages) 370 | # pdb.set_trace() 371 | inputs = self.tokenizer(llm_input_texts, return_tensors="pt", padding=True, padding_side="left", truncation=False) # do not truncate here. 
Process later at TODO 372 | input_ids, attention_mask = inputs.input_ids, inputs.attention_mask 373 | position_ids = attention_mask.cumsum(dim=-1) 374 | if prepare_for_update: 375 | scores = [[i['reward'] for i in env_output['history']] for env_output in env_outputs] 376 | observations = [i['state'] for i in env_output['history']] 377 | # observations_id = self.tokenizer(observations, return_tensors="pt", padding=T, truncation=False) # do not use padding to ensure one-to-one alignment 378 | if self.config.ctx_manager.mode=='add_worldmodel': 379 | loss_mask, score_tensor, response_mask,state_mask = get_masks_and_scores_add_worldmodel(input_ids, self.tokenizer, scores, llm_input_texts,observations,use_turn_scores=self.config.agent_proxy.use_turn_scores) 380 | 381 | elif self.config.ctx_manager.mode=='base': 382 | loss_mask, score_tensor, response_mask = get_masks_and_scores_base(input_ids, self.tokenizer, scores,use_turn_scores=self.config.agent_proxy.use_turn_scores) 383 | normalized_score_tensor = self._normalize_score_tensor(score_tensor, env_outputs) 384 | response_length = response_mask.sum(dim=-1).float().mean().item() 385 | 386 | llm_inputs = DataProto() 387 | llm_inputs.batch = TensorDict({ 388 | "input_ids": input_ids, 389 | "attention_mask": attention_mask, 390 | "position_ids": position_ids, 391 | "responses": input_ids[:, 1:], # remove the first token 392 | }, batch_size=input_ids.shape[0]) 393 | 394 | if prepare_for_update: 395 | llm_inputs.batch["loss_mask"] = loss_mask # remove the first token 396 | llm_inputs.batch["rm_scores"] = normalized_score_tensor # remove the first token 397 | if self.config.ctx_manager.mode=='add_worldmodel': 398 | llm_inputs.batch["state_mask"] = state_mask 399 | 400 | 401 | llm_inputs.non_tensor_batch = { 402 | "env_ids": np.array([env_output["env_id"] for env_output in env_outputs], dtype=object), 403 | "group_ids": np.array([env_output["group_id"] for env_output in env_outputs], dtype=object), 404 | "messages_list": np.array(messages_list, dtype=object), 405 | } 406 | 407 | if self.config.ctx_manager.mode=='add_worldmodel': 408 | llm_inputs.non_tensor_batch["predicted_next_states"] = np.array([[None] * (len([m for m in messages if m["role"] == "assistant"])) for messages in messages_list], dtype=object) 409 | llm_inputs.non_tensor_batch["real_states"] = np.array([[None] * (len([m for m in messages if m["role"] == "assistant"])) for messages in messages_list], dtype=object) 410 | llm_inputs.non_tensor_batch["predicted_cur_states"] = np.array([[None] * (len([m for m in messages if m["role"] == "assistant"])) for messages in messages_list], dtype=object) 411 | # Extract next_state content from assistant messages 412 | for env_idx, messages in enumerate(messages_list): 413 | assistant_turn = 0 # Counter for assistant turns 414 | for message in messages: 415 | if message["role"] == "assistant": 416 | content = message["content"] 417 | # Extract content between and 418 | 419 | if self.config.ctx_manager.mode=='add_worldmodel': 420 | state_match = re.search(r'(.*?).*?(.*?)', content, re.DOTALL) 421 | if state_match: 422 | llm_inputs.non_tensor_batch["predicted_cur_states"][env_idx][assistant_turn] = state_match.group(1).strip() 423 | llm_inputs.non_tensor_batch["predicted_next_states"][env_idx][assistant_turn] = state_match.group(2).strip() 424 | 425 | 426 | assistant_turn += 1 427 | 428 | # Extract next_state content from User messages 429 | for env_idx, messages in enumerate(messages_list): 430 | user_turn = 0 # Counter for assistant turns 431 | for 
message in messages: 432 | if message["role"] == "user": 433 | content = message["content"] 434 | 435 | 436 | if prepare_for_update: 437 | metrics = {} 438 | for env_output in env_outputs: 439 | for key, value in env_output["metrics"].items(): 440 | if key not in metrics: 441 | metrics[key] = [] 442 | metrics[key].append(value) 443 | metrics = { 444 | key: np.sum(value) / self.env_nums[key.split("/")[0]] 445 | for key, value in metrics.items() 446 | } 447 | metrics["response_length"] = response_length 448 | llm_inputs.meta_info = {"metrics": metrics} 449 | return llm_inputs 450 | 451 | def get_env_inputs(self, lm_outputs: DataProto) -> List[Dict]: 452 | if lm_outputs.batch is not None and 'responses' in lm_outputs.batch.keys(): 453 | responses = self.tokenizer.batch_decode( 454 | lm_outputs.batch['responses'], 455 | skip_special_tokens=True 456 | ) 457 | else: # dataproto has textual responses 458 | responses = lm_outputs.non_tensor_batch['response_texts'] 459 | responses = ["" + response if self.config.agent_proxy.enable_think else "" + response for response in responses] # The LLM generation does not include tags. Add them back here. 460 | 461 | env_ids = lm_outputs.non_tensor_batch['env_ids'] 462 | env_inputs = [] 463 | 464 | if self.config.ctx_manager.mode=='add_worldmodel': 465 | for env_id, response in zip(env_ids, responses): 466 | llm_response, actions, cur_state_content, next_state_content = self._parse_response_add_worldmodel(response) 467 | env_inputs.append({ 468 | "env_id": env_id, 469 | "llm_raw_response": response, 470 | "llm_response": llm_response, 471 | "actions": actions, 472 | "llm_predict_cur_state": cur_state_content, 473 | "llm_predict_next_state": next_state_content, 474 | }) 475 | elif self.config.ctx_manager.mode=='base': 476 | for env_id, response in zip(env_ids, responses): 477 | llm_response, actions = self._parse_response_base(response) 478 | env_inputs.append({ 479 | "env_id": env_id, 480 | "llm_raw_response": response, 481 | "llm_response": llm_response, 482 | "actions": actions, 483 | }) 484 | 485 | return env_inputs 486 | 487 | def formulate_rollouts(self, env_outputs: List[Dict]) -> DataProto: 488 | llm_inputs = self.get_lm_inputs(env_outputs, prepare_for_update=True) 489 | return llm_inputs 490 | 491 | 492 | 493 | 494 | 495 | @hydra.main(version_base = None, config_path = "/ssddata/shiqi/RAGEN/SPA/config", config_name = "base") 496 | def main(config): 497 | import json 498 | tokenizer = AutoTokenizer.from_pretrained(config.actor_rollout_ref.model.path) 499 | ctx_manager = ContextManager(config=config, tokenizer=tokenizer) 500 | print("ctx_manager prefix", ctx_manager.prefix_lookup) 501 | batch_list = [ 502 | { 503 | "env_id": 0, 504 | "chat_response": " 123. say | hi ", 505 | }, 506 | { 507 | "env_id": 1, 508 | "chat_response": " 456. 
love ; you mlll nb lxxx ; you ", 509 | } 510 | ] 511 | ctx_manager.action_sep_lookup = { 512 | 0: "|", 513 | 1: ";" 514 | } 515 | for item in batch_list: 516 | item["responses"] = tokenizer.encode(item["chat_response"], return_tensors="pt",max_length=512, truncation=True,padding="max_length")[0] 517 | batch_dict = collate_fn(batch_list) 518 | batch = DataProto.from_single_dict(batch_dict) 519 | env_inputs = ctx_manager.get_env_inputs(batch) 520 | print(env_inputs) 521 | 522 | 523 | 524 | env_outputs = [ 525 | { 526 | "env_id": 1, 527 | "history": [ 528 | {"state": "###\n#x_#", "llm_response": "Response 1", "reward": 0.5}, 529 | {"state": "###\n#x_#", "llm_response": "Response 2", "reward": 0.8}, 530 | {"state": "###\n#x_#"} 531 | ], 532 | "group_id": 0 533 | }, 534 | { 535 | "env_id": 2, 536 | "history": [ 537 | {"state": "###\n#x_#", "llm_response": "Response 3", "reward": 0.3}, 538 | {"state": "###\n#x_#"} 539 | ], 540 | "group_id": 1 541 | } 542 | ] 543 | 544 | prefix_lookup = {1: "Initial prompt", 2: "Initial prompt 2"} 545 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") 546 | env_prompt = ctx_manager.get_lm_inputs(env_outputs, prepare_for_update=False) 547 | print(env_prompt) 548 | formulate_rollouts_rst= ctx_manager.formulate_rollouts(env_outputs) 549 | print(formulate_rollouts_rst) 550 | 551 | if __name__ == "__main__": 552 | main() 553 | -------------------------------------------------------------------------------- /sft/spa_sft_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | A lightweight one-file FSDP SFT Trainer 16 | TODO(zhangchi.usc1992) 17 | - Add calculation of mfu 18 | - Add validation 19 | """ 20 | 21 | import os 22 | 23 | os.environ['NCCL_DEBUG'] = 'WARN' 24 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' 25 | 26 | import logging 27 | import re 28 | import pdb 29 | from contextlib import nullcontext 30 | import torch 31 | import torch.distributed 32 | from torch import nn, optim 33 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload 34 | from tqdm import tqdm 35 | from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig 36 | from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup 37 | from tensordict import TensorDict 38 | from torch.utils.data import DataLoader, DistributedSampler 39 | from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis 40 | 41 | from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager 42 | from sft.spa_sft_dataset import SFTDataset 43 | 44 | 45 | from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset 46 | from verl.utils.fs import copy_to_local 47 | from verl.utils.tracking import Tracking 48 | from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group 49 | from torch.distributed.device_mesh import DeviceMesh 50 | 51 | import verl.utils.hdfs_io as hdfs_io 52 | from verl.utils.debug import log_gpu_memory_usage 53 | from peft import LoraConfig, TaskType, get_peft_model 54 | 55 | from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager 56 | from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad 57 | from verl import DataProto 58 | 59 | logger = logging.getLogger(__file__) 60 | logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN')) 61 | 62 | 63 | def extract_step(path): 64 | match = re.search(r'global_step_(\d+)', path) 65 | if match: 66 | return int(match.group(1)) 67 | return None 68 | 69 | 70 | def convert_to_regular_types(obj): 71 | """Convert Hydra configs and other special types to regular Python types.""" 72 | from omegaconf import ListConfig, DictConfig 73 | if isinstance(obj, (ListConfig, DictConfig)): 74 | return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj) 75 | elif isinstance(obj, (list, tuple)): 76 | return [convert_to_regular_types(x) for x in obj] 77 | elif isinstance(obj, dict): 78 | return {k: convert_to_regular_types(v) for k, v in obj.items()} 79 | return obj 80 | 81 | 82 | class FSDPSFTTrainer(object): 83 | 84 | def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh): 85 | self.config = config 86 | self.device_mesh = device_mesh 87 | self.ulysses_device_mesh = ulysses_device_mesh 88 | self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) 89 | # build tokenizer first 90 | local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) 91 | # local_model_path = '/projects/b1222/model_hub/hub/step_200' 92 | from verl.utils import hf_tokenizer 93 | self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code) 94 | if self.config.data.chat_template is not None: 95 | raise ValueError('Apply Chat template from config is not supported yet.') 96 | 97 | # Setup special tokens for observation and 
prediction content 98 | self._setup_special_tokens() 99 | 100 | # normalize dp size 101 | self._normalize_config_bsz() 102 | 103 | # Set sequence parallel size 104 | self.config.ulysses_sequence_parallel_size = getattr(self.config, 'ulysses_sequence_parallel_size', 1) 105 | self.use_remove_padding = getattr(self.config, 'use_remove_padding', False) 106 | if self.device_mesh.get_rank() == 0: 107 | print(f'Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}') 108 | print(f'Using remove padding: {self.use_remove_padding}') 109 | 110 | self._build_dataloader() 111 | # build model 112 | self._build_model_optimizer() 113 | 114 | # TODO: add checkpoint manager 115 | if self.device_mesh.get_rank() == 0: 116 | print(self.config) 117 | 118 | def _normalize_config_bsz(self): 119 | dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) 120 | if self.device_mesh.get_rank() == 0: 121 | print(f'Normalize batch size by dp {dp_size}') 122 | 123 | assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" 124 | 125 | self.config.data.train_batch_size //= dp_size 126 | 127 | assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0 128 | 129 | def _build_dataloader(self): 130 | config = self.config 131 | # build dataset 132 | from verl.utils.import_utils import load_extern_type 133 | 134 | # First check if a custom dataset class is specified 135 | if config.data.custom_cls.get("path", None): 136 | dataset_cls = load_extern_type(config.data.custom_cls.path, config.data.custom_cls.name) 137 | # Then check if multi-turn dataset should be used 138 | elif config.data.get('multiturn', {}).get('enable', False): 139 | dataset_cls = MultiTurnSFTDataset 140 | # Default to single-turn dataset 141 | else: 142 | dataset_cls = SFTDataset 143 | 144 | self.train_dataset = SFTDataset(parquet_files=config.data.train_files, 145 | tokenizer=self.tokenizer, 146 | prompt_key=config.data.prompt_key, 147 | prompt_dict_keys=config.data.get('prompt_dict_keys', None), 148 | response_key=config.data.response_key, 149 | response_dict_keys=config.data.get('response_dict_keys', None), 150 | max_length=config.data.max_length, 151 | truncation=config.data.truncation) 152 | self.val_dataset = SFTDataset(parquet_files=config.data.val_files, 153 | tokenizer=self.tokenizer, 154 | prompt_key=config.data.prompt_key, 155 | prompt_dict_keys=config.data.get('prompt_dict_keys', None), 156 | response_key=config.data.response_key, 157 | response_dict_keys=config.data.get('response_dict_keys', None), 158 | max_length=config.data.max_length, 159 | truncation=config.data.truncation) 160 | # pdb.set_trace() 161 | # build dataloader 162 | # Use data parallel rank and size instead of global rank and world size 163 | 164 | # If doing SP, we need to use the local rank and size 165 | if self.config.ulysses_sequence_parallel_size > 1: 166 | rank = self.ulysses_device_mesh.get_local_rank('dp') 167 | world_size = self.ulysses_device_mesh.size(0) 168 | if self.ulysses_device_mesh.get_rank() == 0: 169 | print(f'Using SP rank {rank} and size {world_size} for data distribution') 170 | print(f'Each SP rank gets different data, but the same data WITHIN the same rank') 171 | else: 172 | rank = self.device_mesh.get_rank() 173 | world_size = self.device_mesh.size() 174 | if self.device_mesh.get_rank() == 0: 175 | print(f'Using FSDP rank {rank} and size {world_size} for data 
distribution') 176 | 177 | self.train_sampler = DistributedSampler(self.train_dataset, 178 | shuffle=True, 179 | num_replicas=world_size, 180 | rank=rank, 181 | drop_last=True) 182 | self.train_dataloader = DataLoader(dataset=self.train_dataset, 183 | batch_size=config.data.train_batch_size, 184 | sampler=self.train_sampler, 185 | num_workers=8, 186 | pin_memory=True, 187 | drop_last=True) 188 | 189 | self.val_sampler = DistributedSampler(self.val_dataset, 190 | shuffle=False, 191 | num_replicas=world_size, 192 | rank=rank, 193 | drop_last=True) 194 | self.val_dataloader = DataLoader(dataset=self.val_dataset, 195 | batch_size=config.data.micro_batch_size_per_gpu, 196 | sampler=self.val_sampler, 197 | num_workers=8, 198 | pin_memory=True, 199 | drop_last=True) 200 | 201 | print("Validation dataset size:", len(self.val_dataset)) 202 | print("Validation files:", self.config.data.val_files) 203 | 204 | def _build_model_optimizer(self): 205 | # TODO (zhangchi.usc1992): 206 | # 1. support pretrain from random weights 207 | # 2. support init directly from sharded weights 208 | local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) 209 | 210 | if self.config.model.get('external_lib', None) is not None: 211 | # This is used to import external_lib into the huggingface systems 212 | import importlib 213 | importlib.import_module(self.config.model.external_lib) 214 | 215 | log_gpu_memory_usage('Before model allocation', logger=logger) 216 | 217 | trust_remote_code = self.config.model.trust_remote_code 218 | # load config first 219 | config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) 220 | if self.config.ulysses_sequence_parallel_size > 1: 221 | assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" 222 | 223 | # This may be very large 224 | init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, 225 | mesh=self.device_mesh) 226 | 227 | with init_context(): 228 | self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path, 229 | config=config, 230 | torch_dtype=torch.float32, 231 | attn_implementation='flash_attention_2', 232 | trust_remote_code=trust_remote_code) 233 | 234 | # Resize token embeddings if special tokens were added 235 | if hasattr(self.tokenizer, 'obs_content_token_id'): 236 | original_vocab_size = self.model.config.vocab_size 237 | new_vocab_size = len(self.tokenizer) 238 | if new_vocab_size != original_vocab_size: 239 | self.model.resize_token_embeddings(new_vocab_size) 240 | print(f"Resized token embeddings from {original_vocab_size} to {new_vocab_size}") 241 | 242 | if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: 243 | from verl.models.transformers.monkey_patch import apply_monkey_patch 244 | apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size) 245 | 246 | # Apply Liger kernel if use_liger is enabled 247 | if self.config.model.get('use_liger', False): 248 | from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance 249 | _apply_liger_kernel_to_instance(model=self.model) 250 | 251 | if self.config.model.get('lora_rank', 0) > 0: 252 | self.model.enable_input_require_grads() 253 | # Convert config to regular Python types before creating PEFT model 254 | lora_config = { 255 | 'task_type': TaskType.CAUSAL_LM, 256 | 'r': self.config.model.lora_rank, 257 | 'lora_alpha': self.config.model.lora_alpha, 258 | 
'target_modules': convert_to_regular_types(self.config.model.target_modules), 259 | 'bias': "none" 260 | } 261 | self.model = get_peft_model(self.model, LoraConfig(**lora_config)) 262 | 263 | if self.config.model.enable_gradient_checkpointing: 264 | self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) 265 | 266 | log_gpu_memory_usage('After model allocation', logger=logger) 267 | 268 | mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, 269 | reduce_dtype=torch.float32, 270 | buffer_dtype=torch.float32) 271 | 272 | auto_wrap_policy = get_fsdp_wrap_policy(self.model, 273 | config=self.config.model.fsdp_config.wrap_policy, 274 | is_lora=self.config.model.get('lora_rank', 0) > 0) 275 | if self.device_mesh.get_rank() == 0: 276 | print(auto_wrap_policy) 277 | 278 | if not self.config.model.fsdp_config.cpu_offload: 279 | cpu_offload = None 280 | else: 281 | cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params) 282 | 283 | self.fsdp_model = FSDP(module=self.model, 284 | auto_wrap_policy=auto_wrap_policy, 285 | param_init_fn=init_fn, 286 | sharding_strategy=ShardingStrategy.FULL_SHARD, 287 | mixed_precision=mixed_precision, 288 | device_mesh=self.device_mesh, 289 | sync_module_states=True, 290 | device_id=torch.cuda.current_device(), 291 | cpu_offload=cpu_offload, 292 | use_orig_params=False) 293 | 294 | log_gpu_memory_usage('After FSDP wrapping', logger=logger) 295 | 296 | self.optimizer = optim.AdamW(self.fsdp_model.parameters(), 297 | lr=self.config.optim.lr, 298 | betas=self.config.optim.betas, 299 | weight_decay=self.config.optim.weight_decay) 300 | 301 | log_gpu_memory_usage('After initialize optimizer', logger=logger) 302 | 303 | self.steps_per_epoch = len(self.train_dataloader) 304 | self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs 305 | 306 | if self.device_mesh.get_rank() == 0: 307 | print( 308 | f'Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}' 309 | ) 310 | 311 | num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) 312 | 313 | if not hasattr(self.config.optim, 'lr_scheduler') or self.config.optim.lr_scheduler == 'cosine': 314 | self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer, 315 | num_warmup_steps=num_warmup_steps, 316 | num_training_steps=self.total_steps) 317 | elif self.config.optim.lr_scheduler == 'wsd': 318 | self.lr_scheduler = get_wsd_schedule_with_warmup(optimizer=self.optimizer, 319 | num_warmup_steps=num_warmup_steps, 320 | num_training_steps=self.total_steps) 321 | else: 322 | raise ValueError(f'Unknown lr scheduler: {self.config.optim.lr_scheduler}') 323 | 324 | def _create_observation_prediction_mask(self, input_ids, tokenizer): 325 | """Create a mask that excludes special tokens for observation and prediction content from loss computation.""" 326 | batch_size, seq_len = input_ids.shape 327 | mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=input_ids.device) 328 | 329 | # Get the special token IDs for observation and prediction content 330 | obs_content_token_id = getattr(tokenizer, 'obs_content_token_id', None) 331 | pred_content_token_id = getattr(tokenizer, 'pred_content_token_id', None) 332 | 333 | if obs_content_token_id is not None: 334 | # Mask out observation content tokens 335 | mask = mask & (input_ids != obs_content_token_id) 336 | 337 | if pred_content_token_id is not None: 338 | # Mask out 
prediction content tokens 339 | mask = mask & (input_ids != pred_content_token_id) 340 | 341 | return mask 342 | 343 | def _preprocess_text_for_tokenization(self, text): 344 | """Preprocess text to replace observation and prediction tags with special tokens before tokenization.""" 345 | import re 346 | 347 | # Replace observation tags and their content with special token 348 | text = re.sub(r'<observation>.*?</observation>', '<obs_content>', text, flags=re.DOTALL) 349 | 350 | # Replace prediction tags and their content with special token 351 | text = re.sub(r'<prediction>.*?</prediction>', '<pred_content>', text, flags=re.DOTALL) 352 | 353 | return text 354 | 355 | def _setup_special_tokens(self): 356 | """Setup special tokens for observation and prediction content if they don't exist.""" 357 | # Check if special tokens already exist 358 | if not hasattr(self.tokenizer, 'obs_content_token_id'): 359 | # Add special tokens to tokenizer 360 | special_tokens_dict = { 361 | 'additional_special_tokens': ['<obs_content>', '<pred_content>'] 362 | } 363 | self.tokenizer.add_special_tokens(special_tokens_dict) 364 | 365 | # Store the token IDs for easy access 366 | self.tokenizer.obs_content_token_id = self.tokenizer.convert_tokens_to_ids('<obs_content>') 367 | self.tokenizer.pred_content_token_id = self.tokenizer.convert_tokens_to_ids('<pred_content>') 368 | 369 | print(f"Added special tokens: obs_content_token_id={self.tokenizer.obs_content_token_id}, pred_content_token_id={self.tokenizer.pred_content_token_id}") 370 | 371 | def _compute_loss_and_backward(self, batch, do_backward=True): 372 | """Compute loss with optional sequence parallelism and remove padding features""" 373 | use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 374 | 375 | # Move inputs to GPU and prepare loss mask 376 | input_ids = batch['input_ids'].cuda() 377 | attention_mask = batch['attention_mask'].cuda() 378 | position_ids = batch['position_ids'].cuda() 379 | loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda() 380 | loss_fct = nn.CrossEntropyLoss(reduction='none') 381 | 382 | # Create observation/prediction mask to exclude these tags from loss 383 | obs_pred_mask = self._create_observation_prediction_mask(input_ids, self.tokenizer) 384 | obs_pred_mask = obs_pred_mask[:, :-1].reshape(-1).cuda() # Remove last token and flatten 385 | 386 | # Context manager for sequence parallel if needed 387 | context = self.sharding_manager if use_sp else nullcontext() 388 | with context: 389 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 390 | if not use_sp: 391 | # Standard forward pass without sequence parallel 392 | labels = input_ids[:, 1:].contiguous() 393 | output = self.fsdp_model(input_ids=input_ids, 394 | attention_mask=attention_mask, 395 | position_ids=position_ids, 396 | use_cache=False) 397 | logits = output.logits 398 | 399 | shift_logits = logits[..., :-1, :].contiguous() 400 | shift_labels = labels.contiguous() 401 | # Flatten the tokens 402 | shift_logits = shift_logits.view(-1, self.model.config.vocab_size) 403 | shift_labels = shift_labels.view(-1) 404 | # Enable model parallelism 405 | shift_labels = shift_labels.to(shift_logits.device) 406 | loss = loss_fct(shift_logits, shift_labels) 407 | # Apply both loss_mask and obs_pred_mask 408 | combined_mask = loss_mask.to(loss.device) & obs_pred_mask.to(loss.device) 409 | loss = loss * combined_mask 410 | else: 411 | # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks 412 | # i.e., each GPU has <1 sequence, and each SP group has 1 sequence 413 | # 1. All SP ranks will receive the *SAME* batch 414 | # 2.
Different SP groups will receive *DIFFERENT* batches 415 | # This is implemented by the DistributedSampler 416 | 417 | batch_size, seqlen = input_ids.shape 418 | # Remove padding 419 | input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), 420 | attention_mask) # input_ids_rmpad (total_nnz, ...) 421 | input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) 422 | 423 | # Unpad position_ids to align rotary 424 | position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), 425 | indices).transpose(0, 1) 426 | 427 | # Pad and slice inputs for sequence parallelism 428 | input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( 429 | input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) 430 | # For computing loss 431 | input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) 432 | input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( 433 | input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size()) 434 | input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) 435 | 436 | # Forward pass 437 | output = self.fsdp_model( 438 | input_ids=input_ids_rmpad_sliced, 439 | attention_mask=None, # Not needed with flash attention varlen 440 | position_ids=position_ids_rmpad_padded, 441 | use_cache=False) 442 | 443 | # Compute loss locally then aggregate 444 | logits_rmpad = output.logits.squeeze(0) 445 | input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) 446 | loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) 447 | # Gather and unpad for sequence parallelism 448 | loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) 449 | 450 | # This is the loss collected from all ulysses ranks 451 | full_loss = pad_input(hidden_states=loss.unsqueeze(-1), 452 | indices=indices, 453 | batch=batch_size, 454 | seqlen=seqlen) 455 | full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss 456 | full_loss = full_loss.reshape(-1) 457 | loss_mask = loss_mask.to(full_loss.device) 458 | # Apply both loss_mask and obs_pred_mask 459 | combined_mask = loss_mask & obs_pred_mask.to(full_loss.device) 460 | loss = full_loss * combined_mask 461 | 462 | valid_token_this_rank = torch.sum(combined_mask) 463 | 464 | if self.config.data.balance_dp_token: 465 | torch.distributed.all_reduce(valid_token_this_rank) 466 | dp_size = self.ulysses_device_mesh.size('dp') if use_sp else torch.distributed.get_world_size() 467 | else: 468 | dp_size = 1 469 | 470 | loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size 471 | 472 | if do_backward: 473 | loss.backward() 474 | return loss 475 | 476 | def training_step(self, batch: TensorDict): 477 | self.fsdp_model.train() 478 | 479 | log_gpu_memory_usage('Before optimizer zero_grad', logger=logger) 480 | 481 | self.optimizer.zero_grad() 482 | 483 | log_gpu_memory_usage('After optimizer zero_grad', logger=logger) 484 | 485 | micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) 486 | n_micro_batches = len(micro_batches) 487 | step_loss = 0 488 | for micro_batch in micro_batches: 489 | loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches 490 | step_loss += loss.item() 491 | 492 | grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) 493 | 494 | log_gpu_memory_usage('Before optimizer step', logger=logger) 495 | 496 | # if grad_norm is not finite, skip the 
update 497 | if not torch.isfinite(grad_norm): 498 | print(f"WARN: grad_norm is not finite: {grad_norm}") 499 | self.optimizer.zero_grad() 500 | else: 501 | self.optimizer.step() 502 | 503 | log_gpu_memory_usage('After optimizer step', logger=logger) 504 | 505 | self.lr_scheduler.step() 506 | 507 | # reduce loss across dp ranks 508 | lr = self.lr_scheduler.get_last_lr()[0] 509 | 510 | log_gpu_memory_usage('After offload weights', logger=logger) 511 | 512 | step_loss = torch.tensor(step_loss).cuda() 513 | torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) 514 | return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} 515 | 516 | def validation_step(self, batch: TensorDict): 517 | self.fsdp_model.eval() 518 | with torch.no_grad(): 519 | loss = self._compute_loss_and_backward(batch, do_backward=False) 520 | torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) 521 | return loss 522 | 523 | def save_checkpoint(self, step): 524 | # save checkpoint 525 | from torch.distributed.fsdp import FullStateDictConfig, StateDictType 526 | cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 527 | with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg): 528 | state_dict = self.fsdp_model.state_dict() 529 | 530 | path = os.path.join(self.config.trainer.default_local_dir, f'global_step_{step}') 531 | # save huggingface model 532 | if self.device_mesh.get_rank() == 0: 533 | os.makedirs(path, exist_ok=True) 534 | self.model.save_pretrained(path, state_dict=state_dict) 535 | self.tokenizer.save_pretrained(path) 536 | if self.config.trainer.default_hdfs_dir: 537 | hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) 538 | hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True) 539 | torch.distributed.barrier() 540 | 541 | def fit(self): 542 | rank = self.device_mesh.get_rank() 543 | 544 | # TODO: add a unified tracking 545 | if rank == 0: 546 | tracking = Tracking(project_name=self.config.trainer.project_name, 547 | experiment_name=self.config.trainer.experiment_name, 548 | default_backend=self.config.trainer.logger) 549 | 550 | global_step = 0 551 | # compute the total training steps. 552 | # the total training steps in SFT is mainly for early exit 553 | total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs 554 | 555 | if self.config.trainer.total_training_steps is not None: 556 | total_training_steps = self.config.trainer.total_training_steps 557 | 558 | self.total_training_steps = total_training_steps 559 | print(f'Total training steps: {self.total_training_steps}') 560 | 561 | # TODO (zhangchi.usc1992) add back checkpoint manager. Currently, it blocks when uploading to hdfs. So very slow. 
562 | 563 | for epoch in range(self.config.trainer.total_epochs): 564 | self.train_sampler.set_epoch(epoch=epoch) 565 | for data in tqdm(self.train_dataloader, 566 | total=self.steps_per_epoch, 567 | desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"): 568 | global_step += 1 569 | data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() 570 | metric = self.training_step(data) 571 | if rank == 0: 572 | tracking.log(data=metric, step=global_step) 573 | 574 | # for early exit validation 575 | if global_step >= self.total_training_steps: 576 | # Perform final validation 577 | val_losses = [] 578 | for val_data in self.val_dataloader: 579 | val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() 580 | val_loss = self.validation_step(val_data) 581 | val_losses.append(val_loss) 582 | if rank == 0: 583 | avg_val_loss = torch.mean(torch.stack(val_losses)) 584 | metric = {'val/loss': avg_val_loss.detach().item()} 585 | tracking.log(data=metric, step=global_step) 586 | torch.distributed.barrier() 587 | 588 | # Save final checkpoint 589 | self.save_checkpoint(step=global_step) 590 | return 591 | 592 | # validation 593 | val_losses = [] 594 | print("Starting validation...") 595 | for data in self.val_dataloader: 596 | print("Processing validation batch...") 597 | data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() 598 | val_loss = self.validation_step(data) 599 | val_losses.append(val_loss) 600 | print("Validation completed") 601 | if rank == 0: 602 | val_loss = torch.mean(torch.stack(val_losses)) 603 | metric = {'val/loss': val_loss.detach().item()} 604 | tracking.log(data=metric, step=global_step) 605 | torch.distributed.barrier() 606 | 607 | # save checkpoint 608 | self.save_checkpoint(step=global_step) 609 | 610 | 611 | # from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer 612 | import hydra 613 | 614 | from torch.distributed.device_mesh import init_device_mesh 615 | 616 | from verl.utils.distributed import initialize_global_process_group 617 | 618 | 619 | @hydra.main(config_path='config', config_name='sft_trainer', version_base=None) 620 | def main(config): 621 | local_rank, rank, world_size = initialize_global_process_group() 622 | 623 | device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) 624 | dp_size = world_size // config.ulysses_sequence_parallel_size 625 | ulysses_device_mesh = init_device_mesh(device_type='cuda', 626 | mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), 627 | mesh_dim_names=('dp', 'sp')) 628 | trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) 629 | trainer.fit() 630 | 631 | 632 | if __name__ == '__main__': 633 | main() 634 | --------------------------------------------------------------------------------
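For reference, a minimal standalone sketch (not part of the repository) of the loss-masking idea used in _compute_loss_and_backward above: per-token cross-entropy is kept only where the SFT loss_mask and the observation/prediction special-token mask both allow it. The token ids below are made up for illustration.

# illustrative sketch; 7 stands in for a hypothetical <obs_content>/<pred_content> special-token id
import torch

input_ids = torch.tensor([[11, 7, 12, 5, 7, 9]])
loss_mask = torch.tensor([[0, 1, 1, 1, 1, 1]], dtype=torch.bool)   # prompt tokens already masked out

# exclude the special observation/prediction tokens from the loss
obs_pred_mask = torch.ones_like(input_ids, dtype=torch.bool)
for tok_id in (7,):
    obs_pred_mask &= input_ids != tok_id

# a token contributes to the loss only if both masks allow it
combined_mask = loss_mask & obs_pred_mask
per_token_loss = torch.rand(input_ids.shape)                        # stand-in for per-token cross-entropy
loss = (per_token_loss * combined_mask).sum() / combined_mask.sum().clamp(min=1)
print(combined_mask)
print(loss)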