├── openrlhf ├── __init__.py ├── cli │ ├── __init__.py │ ├── lora_combiner.py │ ├── serve_rm.py │ ├── interactive_chat.py │ └── train_prm.py ├── utils │ ├── deepspeed │ │ ├── __init__.py │ │ └── deepspeed_utils.py │ ├── __init__.py │ ├── logging_utils.py │ ├── remote_rm_utils.py │ ├── distributed_util.py │ ├── processor.py │ ├── utils.py │ └── distributed_sampler.py ├── trainer │ ├── ppo_utils │ │ ├── __init__.py │ │ ├── kl_controller.py │ │ └── replay_buffer.py │ ├── ray │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── vllm_worker_wrap.py │ │ ├── vllm_engine.py │ │ └── ppo_critic.py │ ├── __init__.py │ └── prm_trainer.py ├── datasets │ ├── __init__.py │ ├── utils.py │ ├── prompts_dataset.py │ ├── process_reward_dataset.py │ ├── unpaired_preference_dataset.py │ ├── reward_dataset.py │ └── sft_dataset.py └── models │ ├── __init__.py │ ├── ring_attn_utils.py │ ├── utils.py │ └── actor.py ├── version.txt ├── docs ├── logo.png ├── ray_architecture.png └── ppo_examples.md ├── CONTRIBUTING.md ├── examples └── scripts │ ├── reward_func.py │ ├── serve_remote_rm.sh │ ├── docker_run.sh │ ├── train_continue_pretrain_llama.sh │ ├── train_sft_llama.sh │ ├── train_sft_mixtral_lora.sh │ ├── train_kto_llama.sh │ ├── train_sft_llama_tensor_parallelism.sh │ ├── train_rm_llama.sh │ ├── train_knowledge_distillation.sh │ ├── train_prm_mistral.sh │ ├── train_dpo_llama.sh │ ├── nvidia_docker_install.sh │ ├── train_dpo_ring_llama.sh │ ├── train_conditional_llama.sh │ ├── train_ppo_llama_ray_70b.sh │ ├── train_reinforce_llama_ray.sh │ ├── train_ppo_llama_with_remote_rm.sh │ ├── train_grpo_llama_ray.sh │ ├── train_ppo_llama_with_reward_fn.sh │ ├── train_ppo_llama_ray_hybrid_engine.sh │ ├── train_ppo_llama_ray.sh │ ├── train_ppo_llama_ray_tensor_parallelism.sh │ ├── train_ppo_llama_ray_ring.sh │ ├── train_grpo_ray_hybrid_engine.sh │ ├── train_llama_slurm.sh │ ├── train_reinforce_llama_ray_hybrid_engine.sh │ ├── train_reinforce_baseline_llama_ray_hybrid_engine.sh │ ├── ckpt_ds_zero_to_universal.sh │ ├── train_iterative_dpo_llama.sh │ ├── train_rejection_sampling_llama.sh │ └── train_ppo_llama_ray_slurm.sh ├── requirements.txt ├── dockerfile ├── docker-entrypoint.sh └── Dockerfile ├── .pre-commit-config.yaml ├── .github └── workflows │ └── python-package.yml ├── pyproject.toml ├── setup.py └── .gitignore /openrlhf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.7.4.post2 -------------------------------------------------------------------------------- /openrlhf/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ItamarRocha/OpenRLHF_Lora_Rl/main/docs/logo.png -------------------------------------------------------------------------------- /docs/ray_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ItamarRocha/OpenRLHF_Lora_Rl/main/docs/ray_architecture.png -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .deepspeed import DeepspeedStrategy 2 | 3 | __all__ = [ 4 | "DeepspeedStrategy", 5 | ] 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to OpenRLHF 2 | 3 | After cloning the repository, please install pre-commit hooks with: 4 | ``` 5 | pip install pre-commit 6 | pre-commit install 7 | ``` -------------------------------------------------------------------------------- /examples/scripts/reward_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def reward_func(queries, prompts, labels): 5 | # queries is prompts + responses 6 | # labels is answers 7 | print(queries) 8 | return torch.randn(len(queries)) 9 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .kl_controller import AdaptiveKLController, FixedKLController 2 | from .replay_buffer import NaiveReplayBuffer 3 | 4 | __all__ = [ 5 | "AdaptiveKLController", 6 | "FixedKLController", 7 | "NaiveReplayBuffer", 8 | ] 9 | -------------------------------------------------------------------------------- /examples/scripts/serve_remote_rm.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | python -m openrlhf.cli.serve_rm \ 4 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ 5 | --port 5000 \ 6 | --bf16 \ 7 | --flash_attn \ 8 | --normalize_reward \ 9 | --max_len 8192 \ 10 | --batch_size 16 -------------------------------------------------------------------------------- /openrlhf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import get_processor, reward_normalization 2 | from .utils import blending_datasets, get_strategy, get_tokenizer 3 | 4 | __all__ = [ 5 | "get_processor", 6 | "reward_normalization", 7 | "blending_datasets", 8 | "get_strategy", 9 | "get_tokenizer", 10 | ] 11 | -------------------------------------------------------------------------------- /examples/scripts/docker_run.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | PROJECT_PATH=$(cd $(dirname $0)/../../; pwd) 4 | IMAGE_NAME="nvcr.io/nvidia/pytorch:24.07-py3" 5 | 6 | docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN \ 7 | -v $PROJECT_PATH:/openrlhf -v $HOME/.cache:/root/.cache -v $HOME/.bash_history2:/root/.bash_history \ 8 | $IMAGE_NAME bash -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | datasets 4 | deepspeed==0.16.7 5 | einops 6 | flash-attn==2.7.4.post1 7 | isort 8 | jsonlines 9 | loralib 10 | optimum 11 | optree>=0.13.0 12 | packaging 13 | peft 14 | pynvml>=12.0.0 15 | ray[default]==2.43.0 16 | tensorboard 17 | torch 18 | torchmetrics 19 | tqdm 20 | transformers==4.51.3 21 | transformers_stream_generator 22 | wandb 23 | wheel 24 | -------------------------------------------------------------------------------- /openrlhf/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .process_reward_dataset import ProcessRewardDataset 2 | from .prompts_dataset import PromptDataset 
3 | from .reward_dataset import RewardDataset 4 | from .sft_dataset import SFTDataset 5 | from .unpaired_preference_dataset import UnpairedPreferenceDataset 6 | 7 | __all__ = ["ProcessRewardDataset", "PromptDataset", "RewardDataset", "SFTDataset", "UnpairedPreferenceDataset"] 8 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/__init__.py: -------------------------------------------------------------------------------- 1 | from .launcher import DistributedTorchRayActor, PPORayActorGroup, ReferenceModelRayActor, RewardModelRayActor 2 | from .vllm_engine import batch_vllm_engine_call, create_vllm_engines 3 | 4 | __all__ = [ 5 | "DistributedTorchRayActor", 6 | "PPORayActorGroup", 7 | "ReferenceModelRayActor", 8 | "RewardModelRayActor", 9 | "create_vllm_engines", 10 | "batch_vllm_engine_call", 11 | ] 12 | -------------------------------------------------------------------------------- /openrlhf/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .dpo_trainer import DPOTrainer 2 | from .kd_trainer import KDTrainer 3 | from .kto_trainer import KTOTrainer 4 | from .ppo_trainer import PPOTrainer 5 | from .prm_trainer import ProcessRewardModelTrainer 6 | from .rm_trainer import RewardModelTrainer 7 | from .sft_trainer import SFTTrainer 8 | 9 | __all__ = [ 10 | "DPOTrainer", 11 | "KDTrainer", 12 | "KTOTrainer", 13 | "PPOTrainer", 14 | "ProcessRewardModelTrainer", 15 | "RewardModelTrainer", 16 | "SFTTrainer", 17 | ] 18 | -------------------------------------------------------------------------------- /openrlhf/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def zero_pad_sequences(sequences, side: str = "left", value=0): 6 | assert side in ("left", "right") 7 | max_len = max(seq.size(-1) for seq in sequences) 8 | padded_sequences = [] 9 | for seq in sequences: 10 | pad_len = max_len - seq.size(-1) 11 | padding = (pad_len, 0) if side == "left" else (0, pad_len) 12 | padded_sequences.append(F.pad(seq, padding, value=value)) 13 | return torch.stack(padded_sequences, dim=0) 14 | 15 | 16 | def exist_and_not_none(d, key): 17 | return key in d and not d[key] is None 18 | -------------------------------------------------------------------------------- /dockerfile/docker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | USER=${LOCAL_USER:-"root"} 4 | 5 | if [[ "${USER}" != "root" ]]; then 6 | USER_ID=${LOCAL_USER_ID:-9001} 7 | echo ${USER} 8 | echo ${USER_ID} 9 | 10 | chown ${USER_ID} /home/${USER} 11 | useradd --shell /bin/bash -u ${USER_ID} -o -c "" -m ${USER} 12 | usermod -a -G root ${USER} 13 | adduser ${USER} sudo 14 | 15 | # user:password 16 | echo "${USER}:123" | chpasswd 17 | 18 | export HOME=/home/${USER} 19 | export PATH=/home/${USER}/.local/bin/:$PATH 20 | else 21 | export PATH=/root/.local/bin/:$PATH 22 | fi 23 | 24 | cd $HOME 25 | exec gosu ${USER} "$@" -------------------------------------------------------------------------------- /openrlhf/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import Actor 2 | from .loss import ( 3 | DPOLoss, 4 | GPTLMLoss, 5 | KDLoss, 6 | KTOLoss, 7 | LogExpLoss, 8 | PairWiseLoss, 9 | PolicyLoss, 10 | PRMLoss, 11 | SFTLoss, 12 | ValueLoss, 13 | VanillaKTOLoss, 14 | ) 15 | from .model import 
get_llm_for_sequence_regression 16 | 17 | __all__ = [ 18 | "Actor", 19 | "SFTLoss", 20 | "DPOLoss", 21 | "GPTLMLoss", 22 | "KDLoss", 23 | "KTOLoss", 24 | "LogExpLoss", 25 | "PairWiseLoss", 26 | "PolicyLoss", 27 | "PRMLoss", 28 | "ValueLoss", 29 | "VanillaKTOLoss", 30 | "get_llm_for_sequence_regression", 31 | ] 32 | -------------------------------------------------------------------------------- /dockerfile/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.07-py3 2 | 3 | WORKDIR /app 4 | 5 | RUN set -eux && \ 6 | apt-get update && \ 7 | apt-get install -y gosu && \ 8 | rm -rf /var/lib/apt/lists/* && \ 9 | gosu nobody true 10 | 11 | RUN apt-get update && apt-get -y install sudo 12 | RUN sudo su - 13 | 14 | RUN DEBIAN_FRONTEND=noninteractive apt install -y tzdata 15 | 16 | RUN apt-get -y install build-essential git python3-dev python3-pip libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev gdb 17 | RUN pip uninstall xgboost transformer_engine flash_attn pynvml opencv-python-headless -y 18 | RUN pip install vllm==0.8.3 19 | 20 | COPY docker-entrypoint.sh . 21 | RUN chmod a+x docker-entrypoint.sh 22 | 23 | ENTRYPOINT ["/app/docker-entrypoint.sh"] 24 | -------------------------------------------------------------------------------- /examples/scripts/train_continue_pretrain_llama.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | read -r -d '' training_commands <> ${JOBLOG} 26 | 27 | # load training commands 28 | source ./${training_script} slurm 29 | echo training_commands &>> ${JOBLOG} 30 | echo $training_commands &>> ${JOBLOG} 31 | 32 | # master addr and port 33 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 34 | export MASTER_PORT=9901 35 | 36 | srun --container-image="$IMAGE_NAME" \ 37 | --container-mounts="$PROJECT_PATH:/openrlhf,$HOME/.cache:/root/.cache" \ 38 | bash -c " cd /openrlhf; pip install . ; torchrun \ 39 | --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \ 40 | --master_addr $MASTER_ADDR --master_port $MASTER_PORT -m ${training_commands}" &>> ${JOBLOG} 41 | 42 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} stopped ..." 
&>> ${JOBLOG} -------------------------------------------------------------------------------- /examples/scripts/train_reinforce_llama_ray_hybrid_engine.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | ray job submit --address="http://127.0.0.1:8265" \ 4 | --runtime-env-json='{"working_dir": "/openrlhf"}' \ 5 | -- python3 -m openrlhf.cli.train_ppo_ray \ 6 | --ref_num_nodes 1 \ 7 | --ref_num_gpus_per_node 8 \ 8 | --reward_num_nodes 1 \ 9 | --reward_num_gpus_per_node 8 \ 10 | --actor_num_nodes 1 \ 11 | --actor_num_gpus_per_node 8 \ 12 | --vllm_num_engines 4 \ 13 | --vllm_tensor_parallel_size 2 \ 14 | --colocate_all_models \ 15 | --vllm_gpu_memory_utilization 0.6 \ 16 | --advantage_estimator reinforce \ 17 | --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ 18 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ 19 | --save_path /openrlhf/examples/test_scripts/final/llama3-8b-rlhf \ 20 | --ckpt_path /openrlhf/examples/test_scripts/ckpt/llama3-8b-rlhf \ 21 | --save_hf_ckpt \ 22 | --micro_train_batch_size 4 \ 23 | --train_batch_size 128 \ 24 | --micro_rollout_batch_size 8 \ 25 | --rollout_batch_size 1024 \ 26 | --n_samples_per_prompt 1 \ 27 | --max_epochs 1 \ 28 | --prompt_max_len 1024 \ 29 | --max_samples 100000 \ 30 | --generate_max_len 1024 \ 31 | --zero_stage 3 \ 32 | --bf16 \ 33 | --actor_learning_rate 5e-7 \ 34 | --critic_learning_rate 9e-6 \ 35 | --init_kl_coef 1e-4 \ 36 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 37 | --input_key context_messages \ 38 | --apply_chat_template \ 39 | --normalize_reward \ 40 | --gradient_checkpointing \ 41 | --packing_samples \ 42 | --vllm_sync_backend nccl \ 43 | --enforce_eager \ 44 | --vllm_enable_sleep \ 45 | --deepspeed_enable_sleep 46 | 47 | # You could also try 48 | # --use_kl_loss \ 49 | # --kl_estimator k3 | k2 \ 50 | 51 | # also supports --advantage_estimator rloo | reinforce_baseline 52 | -------------------------------------------------------------------------------- /examples/scripts/train_reinforce_baseline_llama_ray_hybrid_engine.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | ray job submit --address="http://127.0.0.1:8265" \ 4 | --runtime-env-json='{"working_dir": "/openrlhf"}' \ 5 | -- python3 -m openrlhf.cli.train_ppo_ray \ 6 | --ref_num_nodes 1 \ 7 | --ref_num_gpus_per_node 8 \ 8 | --reward_num_nodes 1 \ 9 | --reward_num_gpus_per_node 8 \ 10 | --actor_num_nodes 1 \ 11 | --actor_num_gpus_per_node 8 \ 12 | --vllm_num_engines 4 \ 13 | --vllm_tensor_parallel_size 2 \ 14 | --colocate_all_models \ 15 | --vllm_gpu_memory_utilization 0.6 \ 16 | --advantage_estimator reinforce_baseline \ 17 | --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ 18 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ 19 | --save_path /openrlhf/examples/test_scripts/final/llama3-8b-rlhf \ 20 | --ckpt_path /openrlhf/examples/test_scripts/ckpt/llama3-8b-rlhf \ 21 | --save_hf_ckpt \ 22 | --micro_train_batch_size 4 \ 23 | --train_batch_size 128 \ 24 | --micro_rollout_batch_size 8 \ 25 | --rollout_batch_size 128 \ 26 | --n_samples_per_prompt 8 \ 27 | --init_kl_coef 1e-3 \ 28 | --gamma 1.0 \ 29 | --use_kl_loss \ 30 | --kl_estimator k2 \ 31 | --max_epochs 1 \ 32 | --prompt_max_len 1024 \ 33 | --max_samples 100000 \ 34 | --generate_max_len 1024 \ 35 | --zero_stage 3 \ 36 | --bf16 \ 37 | --actor_learning_rate 5e-7 \ 38 | --critic_learning_rate 9e-6 \ 39 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 40 | --input_key context_messages \ 41 | --apply_chat_template \ 42 | 
--normalize_reward \ 43 | --gradient_checkpointing \ 44 | --packing_samples \ 45 | --vllm_sync_backend nccl \ 46 | --enforce_eager \ 47 | --vllm_enable_sleep \ 48 | --deepspeed_enable_sleep 49 | 50 | # You could also try 51 | # --use_kl_loss \ 52 | # --kl_estimator k3 | k2 \ 53 | 54 | # also supports --advantage_estimator rloo | reinforce_baseline 55 | -------------------------------------------------------------------------------- /openrlhf/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py 3 | """Logging configuration for vLLM.""" 4 | import logging 5 | import sys 6 | 7 | _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" 8 | _DATE_FORMAT = "%m-%d %H:%M:%S" 9 | 10 | 11 | class NewLineFormatter(logging.Formatter): 12 | """Adds logging prefix to newlines to align multi-line messages.""" 13 | 14 | def __init__(self, fmt, datefmt=None): 15 | logging.Formatter.__init__(self, fmt, datefmt) 16 | 17 | def format(self, record): 18 | msg = logging.Formatter.format(self, record) 19 | if record.message != "": 20 | parts = msg.split(record.message) 21 | msg = msg.replace("\n", "\r\n" + parts[0]) 22 | return msg 23 | 24 | 25 | _root_logger = logging.getLogger("openrlhf") 26 | _default_handler = None 27 | 28 | 29 | def _setup_logger(): 30 | _root_logger.setLevel(logging.DEBUG) 31 | global _default_handler 32 | if _default_handler is None: 33 | _default_handler = logging.StreamHandler(sys.stdout) 34 | _default_handler.flush = sys.stdout.flush # type: ignore 35 | _default_handler.setLevel(logging.INFO) 36 | _root_logger.addHandler(_default_handler) 37 | fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) 38 | _default_handler.setFormatter(fmt) 39 | # Setting this will avoid the message 40 | # being propagated to the parent logger. 41 | _root_logger.propagate = False 42 | 43 | 44 | # The logger is initialized when the module is imported. 45 | # This is thread-safe as the module is only imported once, 46 | # guaranteed by the Python GIL. 
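# Illustrative usage of this module elsewhere in the repo (e.g. remote_rm_utils.py and serve_rm.py): logger = init_logger(__name__)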
47 | _setup_logger() 48 | 49 | 50 | def init_logger(name: str): 51 | # Use the same settings as above for root logger 52 | logger = logging.getLogger(name) 53 | logger.setLevel(logging.DEBUG) 54 | logger.addHandler(_default_handler) 55 | logger.propagate = False 56 | return logger 57 | -------------------------------------------------------------------------------- /openrlhf/utils/remote_rm_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import ray 3 | import requests 4 | import torch 5 | 6 | from openrlhf.utils.logging_utils import init_logger 7 | 8 | logger = init_logger(__name__) 9 | 10 | 11 | def request_api_wrapper(url, data, score_key="rewards", try_max_times=5): 12 | """Synchronous request API wrapper""" 13 | headers = { 14 | "Content-Type": "application/json", 15 | } 16 | for _ in range(try_max_times): 17 | try: 18 | response = requests.post(url=url, json=data, headers=headers, timeout=180) 19 | response.raise_for_status() # Raise an HTTPError for bad responses 20 | response = response.json() 21 | assert score_key in response, f"{score_key} not in {response}" 22 | return response.get(score_key) 23 | except requests.RequestException as e: 24 | logger.info(f"Request error, please check: {e}") 25 | except Exception as e: 26 | logger.info(f"Unexpected error, please check: {e}") 27 | time.sleep(1) 28 | 29 | raise Exception(f"Request error for {try_max_times} times, returning None. Please check the API server.") 30 | 31 | 32 | def remote_rm_fn(api_url, queries, prompts, labels, score_key="rewards"): 33 | """remote reward model API 34 | api_url: RM API, We assume that the API supports two modes: merging query + response and not merging 35 | queries: query+response with the template 36 | design is made optional. 37 | score_key: RM score key 38 | """ 39 | scores = request_api_wrapper(api_url, {"query": queries, "prompts": prompts, "labels": labels}, score_key) 40 | return torch.tensor(scores) 41 | 42 | 43 | @ray.remote 44 | def remote_rm_fn_ray(api_url, queries, prompts, labels, score_key="rewards"): 45 | return remote_rm_fn(api_url, queries, prompts, labels, score_key) 46 | 47 | 48 | if __name__ == "__main__": 49 | # test utils 50 | url = "http:xxx/get_rm_score" 51 | score = remote_rm_fn(url, ["example query"], ["example response"]) 52 | print(score) 53 | -------------------------------------------------------------------------------- /examples/scripts/ckpt_ds_zero_to_universal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Ensure at least one argument is provided. 4 | if [ "$#" -lt 1 ]; then 5 | echo "This script converts the latest DeepSpeed ZeRO checkpoint to a universal checkpoint." 6 | echo "Usage: $0 [additional arguments for deepspeed.checkpoint.ds_to_universal]" 7 | exit 1 8 | fi 9 | 10 | # Set CKPT_PATH to the first argument and shift it out so that "$@" contains the extra arguments. 11 | CKPT_PATH="$1" 12 | shift 13 | EXTRA_ARGS="$@" 14 | 15 | # Function to process a given directory. 16 | process_dir() { 17 | local path="$1" 18 | echo "Processing checkpoint: $path" 19 | 20 | # Check if the latest tag exists. 21 | if [ ! -f "$path/latest" ]; then 22 | echo "latest tag file not found in $path, ensure the directory contains a valid DeepSpeed ZeRO checkpoint." 23 | return 1 24 | fi 25 | 26 | # Read the latest tag. 27 | LATEST_TAG=$(cat "$path/latest") 28 | LATEST_UNI_TAG="${LATEST_TAG}_uni" 29 | 30 | # Write the universal tag. 
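# e.g. if the `latest` file contains "global_step2000", the universal checkpoint is written under the tag "global_step2000_uni" (tag name here is illustrative; the actual value depends on how the checkpoint was saved).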
31 | echo "$LATEST_UNI_TAG" > "$path/latest_universal" 32 | 33 | # Run the python command with any additional arguments. 34 | python -m deepspeed.checkpoint.ds_to_universal --inject_missing_state \ 35 | --input_folder "$path/$LATEST_TAG" \ 36 | --output_folder "$path/$LATEST_UNI_TAG" \ 37 | $EXTRA_ARGS 38 | } 39 | 40 | # Flag to check if at least one of the specific subdirectories exists. 41 | found_subdir=0 42 | 43 | ## For PPO, checkpoints for each model are stored under "_actor" and "_critic" separately. 44 | # Check for the subdirectory named exactly "_actor". 45 | if [ -d "$CKPT_PATH/_actor" ]; then 46 | process_dir "$CKPT_PATH/_actor" 47 | found_subdir=1 48 | fi 49 | 50 | # Check for the subdirectory named exactly "_critic". 51 | if [ -d "$CKPT_PATH/_critic" ]; then 52 | process_dir "$CKPT_PATH/_critic" 53 | found_subdir=1 54 | fi 55 | 56 | # If neither subdirectory exists, process the main CKPT_PATH. 57 | if [ "$found_subdir" -eq 0 ]; then 58 | process_dir "$CKPT_PATH" 59 | fi 60 | -------------------------------------------------------------------------------- /openrlhf/cli/lora_combiner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from peft import PeftModel 5 | from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer 6 | 7 | 8 | def apply_lora(model_name_or_path, lora_path, output_path, is_rm, bf16): 9 | print(f"Loading the base model from {model_name_or_path}") 10 | model_cls = AutoModelForCausalLM if not is_rm else AutoModelForSequenceClassification 11 | base = model_cls.from_pretrained( 12 | model_name_or_path, torch_dtype=torch.bfloat16 if bf16 else "auto", low_cpu_mem_usage=True 13 | ) 14 | base_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 15 | 16 | print(f"Loading the LoRA adapter from {lora_path}") 17 | # apply lora to transformer 18 | lora_model = PeftModel.from_pretrained( 19 | base, 20 | lora_path, 21 | torch_dtype=torch.bfloat16 if bf16 else "auto", 22 | ) 23 | 24 | print("Applying and merging the LoRA weights") 25 | lora_model.merge_and_unload() 26 | 27 | print(f"Saving the complete model to {output_path}") 28 | base.save_pretrained(output_path) 29 | base_tokenizer.save_pretrained(output_path) 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser(description="Apply LoRA to a base model and save the combined model.") 34 | parser.add_argument("--model_path", type=str, required=True, help="Path to the base model directory.") 35 | parser.add_argument("--lora_path", type=str, required=True, help="Path to the LoRA adapter directory.") 36 | parser.add_argument("--output_path", type=str, required=True, help="Path to save the combined model.") 37 | parser.add_argument( 38 | "--is_rm", 39 | action="store_true", 40 | default=False, 41 | help="Whether to treat the model as a reward model (AutoModelForSequenceClassification)", 42 | ) 43 | parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") 44 | args = parser.parse_args() 45 | apply_lora(args.model_path, args.lora_path, args.output_path, args.is_rm, args.bf16) 46 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | build-and-publish: 10 | # do not run in forks 11 | if: 
${{ github.repository_owner == 'OpenRLHF' && (github.event_name == 'release' || github.event_name == 'workflow_dispatch') }} 12 | name: build wheel and upload 13 | runs-on: ubuntu-22.04 14 | 15 | strategy: 16 | matrix: 17 | python-version: [3.10.14, 3.11.0] 18 | cuda-version: [12.1] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install CUDA ${{ matrix.cuda-version }} 29 | run: | 30 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin 31 | sudo mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 32 | sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub 33 | sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" 34 | sudo apt-get update 35 | sudo apt-get -y install cuda-${{ matrix.cuda-version }} 36 | 37 | - name: Set up CUDA environment variables 38 | run: | 39 | echo "/usr/local/cuda-${{ matrix.cuda-version }}/lib64" | sudo tee -a /etc/ld.so.conf.d/cuda.conf 40 | echo "export PATH=/usr/local/cuda-${{ matrix.cuda-version }}/bin:\$PATH" | sudo tee -a /etc/environment 41 | sudo ldconfig 42 | shell: bash 43 | 44 | - name: Install dependencies 45 | run: | 46 | python -m pip install --upgrade pip 47 | pip install setuptools wheel twine packaging 48 | 49 | - name: Build package 50 | run: | 51 | python setup.py bdist_wheel --dist-dir=dist 52 | 53 | - name: Publish package 54 | env: 55 | TWINE_USERNAME: __token__ 56 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 57 | run: | 58 | python -m twine upload dist/* -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "packaging", 4 | "setuptools >= 49.4.0", 5 | "wheel", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [tool.isort] 10 | profile = "black" # black-compatible 11 | line_length = 119 # should match black parameters 12 | ignore_whitespace = true # ignore whitespace for compatibility with the initial style 13 | py_version = 310 # python 3.10 as a target version 14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] 15 | default_section = "THIRDPARTY" 16 | extend_skip = ["setup.py", "docs/source/conf.py"] 17 | 18 | 19 | [tool.black] 20 | line_length = 119 21 | 22 | [tool.ruff] 23 | line-length = 119 24 | 25 | [tool.pytest.ini_options] 26 | # durations=0 will display all tests execution time, sorted in ascending order starting from from the slowest one. 27 | # -vv will also display tests with durration = 0.00s 28 | addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest 29 | testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module 30 | # directories to ignore when discovering tests 31 | norecursedirs = [ 32 | "external", 33 | "examples", 34 | "docs", 35 | "scripts", 36 | "tools", 37 | "tutorials", 38 | "*.egg", 39 | ".*", 40 | "_darcs", 41 | "build", 42 | "CVS", 43 | "dist", 44 | "venv", 45 | "{arch}", 46 | ] 47 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m ""` to select tests 48 | markers = [ 49 | "unit: marks unit test, i.e. 
testing a single, well isolated functionality (deselect with '-m \"not unit\"')", 50 | "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')", 51 | "system: marks test working at the highest integration level (deselect with '-m \"not system\"')", 52 | "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')", 53 | "docs: mark tests related to documentation (deselect with '-m \"not docs\"')", 54 | "skipduringci: marks tests that are skipped ci as they are addressed by Jenkins jobs but should be run to test user setups", 55 | "pleasefixme: marks tests that are broken and need fixing", 56 | ] 57 | -------------------------------------------------------------------------------- /openrlhf/datasets/prompts_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from tqdm import tqdm 3 | 4 | 5 | def preprocess_data(data, input_template=None, input_key="input", label_key=None, apply_chat_template=None) -> str: 6 | if apply_chat_template: 7 | chat = data[input_key] 8 | if isinstance(chat, str): 9 | chat = [{"role": "user", "content": chat}] 10 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 11 | else: 12 | prompt = data[input_key] 13 | if input_template: 14 | prompt = input_template.format(prompt) 15 | 16 | # for Reinforced Fine-tuning 17 | label = "" if label_key is None else data[label_key] 18 | return prompt, label 19 | 20 | 21 | class PromptDataset(Dataset): 22 | """ 23 | Dataset for PPO model 24 | 25 | Args: 26 | dataset: dataset for PPO model 27 | tokenizer: tokenizer for PPO model 28 | max_length: max length of input 29 | """ 30 | 31 | def __init__( 32 | self, 33 | dataset, 34 | tokenizer, 35 | strategy, 36 | input_template=None, 37 | ) -> None: 38 | super().__init__() 39 | self.strategy = strategy 40 | self.tokenizer = tokenizer 41 | 42 | # chat_template 43 | self.input_template = input_template 44 | input_key = getattr(self.strategy.args, "input_key", None) 45 | label_key = getattr(self.strategy.args, "label_key", None) 46 | apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 47 | 48 | if apply_chat_template: 49 | apply_chat_template = self.tokenizer.apply_chat_template 50 | 51 | self.prompts = [] 52 | self.labels = [] 53 | self.datasources = [] 54 | for data in tqdm(dataset, desc="Preprocessing data", disable=not self.strategy.is_rank_0()): 55 | prompt, label = preprocess_data(data, input_template, input_key, label_key, apply_chat_template) 56 | self.prompts.append(prompt) 57 | self.labels.append(label) 58 | self.datasources.append(data.get("datasource", "default")) 59 | 60 | def __len__(self): 61 | length = len(self.prompts) 62 | return length 63 | 64 | def __getitem__(self, idx): 65 | return self.datasources[idx], self.prompts[idx], self.labels[idx] 66 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import platform 4 | 5 | from datetime import datetime 6 | from setuptools import find_packages, setup 7 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 8 | 9 | _build_mode = os.getenv("OPENRLHF_BUILD_MODE", "") 10 | 11 | 12 | def _is_nightly(): 13 | return _build_mode.lower() == "nightly" 14 | 15 | 16 | def 
_fetch_requirements(path): 17 | with open(path, "r") as fd: 18 | return [r.strip() for r in fd.readlines()] 19 | 20 | 21 | def _fetch_readme(): 22 | with open("README.md", encoding="utf-8") as f: 23 | return f.read() 24 | 25 | 26 | def _fetch_version(): 27 | with open("version.txt", "r") as f: 28 | version = f.read().strip() 29 | 30 | if _is_nightly(): 31 | now = datetime.now() 32 | date_str = now.strftime("%Y%m%d") 33 | version += f".dev{date_str}" 34 | 35 | return version 36 | 37 | 38 | def _fetch_package_name(): 39 | return "openrlhf-nightly" if _is_nightly() else "openrlhf" 40 | 41 | 42 | # Custom wheel class to modify the wheel name 43 | class bdist_wheel(_bdist_wheel): 44 | def finalize_options(self): 45 | _bdist_wheel.finalize_options(self) 46 | self.root_is_pure = False 47 | 48 | def get_tag(self): 49 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 50 | abi_tag = f"{python_version}" 51 | 52 | if platform.system() == "Linux": 53 | platform_tag = "manylinux1_x86_64" 54 | else: 55 | platform_tag = platform.system().lower() 56 | 57 | return python_version, abi_tag, platform_tag 58 | 59 | 60 | # Setup configuration 61 | setup( 62 | author="OpenRLHF Team", 63 | name=_fetch_package_name(), 64 | version=_fetch_version(), 65 | packages=find_packages( 66 | exclude=( 67 | "data", 68 | "docs", 69 | "examples", 70 | ) 71 | ), 72 | description="A Ray-based High-performance RLHF framework.", 73 | long_description=_fetch_readme(), 74 | long_description_content_type="text/markdown", 75 | install_requires=_fetch_requirements("requirements.txt"), 76 | extras_require={ 77 | "vllm": ["vllm==0.8.3"], 78 | "vllm_latest": ["vllm>0.8.3"], 79 | "ring": ["ring_flash_attn"], 80 | "liger": ["liger_kernel"], 81 | }, 82 | python_requires=">=3.10", 83 | classifiers=[ 84 | "Programming Language :: Python :: 3.10", 85 | "Programming Language :: Python :: 3.11", 86 | "Environment :: GPU :: NVIDIA CUDA", 87 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 88 | "Topic :: System :: Distributed Computing", 89 | ], 90 | cmdclass={"bdist_wheel": bdist_wheel}, 91 | ) 92 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | # Address https://github.com/ray-project/ray/issues/51117 5 | # This function is used to get the bundle indices of a placement group 6 | # and ensure that the bundles placed on the same node are grouped together. 
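# Illustrative call (hypothetical values): get_bundle_indices(pg, index=1, length=4) returns the four bundle indices reserved for the second group of workers, with bundles from the same node kept adjacent.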
7 | def get_bundle_indices(placement_group, index, length): 8 | import ray 9 | 10 | pg_infos = ray.util.placement_group_table(placement_group) 11 | 12 | node_id_to_bundles = {} 13 | for bundle, node_id in pg_infos["bundles_to_node_id"].items(): 14 | node_id_to_bundles.setdefault(node_id, []).append(bundle) 15 | 16 | sorted_bundle_indices = sum(node_id_to_bundles.values(), []) 17 | return sorted_bundle_indices[index * length : (index + 1) * length] 18 | 19 | 20 | def ray_noset_visible_devices(env_vars=os.environ): 21 | # Refer to 22 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 23 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 24 | # https://github.com/ray-project/ray/blob/3b9e729f6a669ffd85190f901f5e262af79771b0/python/ray/_private/accelerators/amd_gpu.py#L114-L115 25 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 26 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 27 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 28 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 29 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 30 | NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ 31 | "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", 32 | "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", 33 | "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES", 34 | "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", 35 | "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", 36 | "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", 37 | "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", 38 | "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", 39 | ] 40 | return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) 41 | 42 | 43 | def get_physical_gpu_id(): 44 | import torch 45 | 46 | device = torch.cuda.current_device() 47 | props = torch.cuda.get_device_properties(device) 48 | return str(props.uuid) 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | docs/.build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # IDE 133 | .idea/ 134 | .vscode/ 135 | 136 | # macos 137 | *.DS_Store 138 | #data/ 139 | 140 | docs/.build 141 | 142 | # pytorch checkpoint 143 | *.pt 144 | 145 | core 146 | */ckpt/* 147 | .vscode 148 | .nfs* 149 | *jianh* 150 | *test_scripts* 151 | */checkpoint/* -------------------------------------------------------------------------------- /examples/scripts/train_iterative_dpo_llama.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | export VLLM_WORKER_MULTIPROC_METHOD=spawn 3 | 4 | checkSuccess() { 5 | if [[ $? 
!= 0 ]]; then 6 | echo "FAILED $1" 7 | exit 1 8 | fi 9 | } 10 | 11 | mkdir -p ./checkpoint/llama-3-8b-iter-dpo 12 | GENERATE_OUTPUT=./checkpoint/llama-3-8b-iter-dpo/generate.jsonl 13 | RM_OUTPUT=./checkpoint/llama-3-8b-iter-dpo/rm.jsonl 14 | MODEL_OUTPUT_PATH=./checkpoint/llama-3-8b-iter-dpo/checkpoint 15 | ITER_LOG_PATH=null 16 | 17 | TRAINING_ITERS=5 18 | ROLLOUT_BATCH_SIZE=10240 19 | 20 | POLICY_MODEL_PATH=OpenRLHF/Llama-3-8b-sft-mixture 21 | REF_MODEL_PATH=$POLICY_MODEL_PATH 22 | 23 | iter=0 24 | if [ -f $ITER_LOG_PATH ]; then 25 | iter=$(cat $ITER_LOG_PATH) 26 | fi 27 | 28 | while (($iter < $TRAINING_ITERS)); do 29 | echo "Iter: $iter" 30 | # Use latest model if past first iteration 31 | if ((iter > 0)); then 32 | POLICY_MODEL_PATH=$MODEL_OUTPUT_PATH 33 | fi 34 | 35 | read -r -d '' generate_commands <$ITER_LOG_PATH 98 | fi 99 | done -------------------------------------------------------------------------------- /examples/scripts/train_rejection_sampling_llama.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | export VLLM_WORKER_MULTIPROC_METHOD=spawn 3 | 4 | checkSuccess() { 5 | if [[ $? != 0 ]]; then 6 | echo "FAILED $1" 7 | exit 1 8 | fi 9 | } 10 | 11 | mkdir -p ./checkpoint/llama-3-8b-rejection 12 | GENERATE_OUTPUT=./checkpoint/llama-3-8b-rejection/generate.jsonl 13 | RM_OUTPUT=./checkpoint/llama-3-8b-rejection/rm.jsonl 14 | ITER_LOG_PATH=./checkpoint/llama-3-8b-rejection/iter.log 15 | MODEL_OUTPUT_PATH=./checkpoint/llama-3-8b-rejection 16 | 17 | TRAINING_ITERS=10 18 | ROLLOUT_BATCH_SIZE=10240 19 | 20 | POLICY_MODEL_PATH=OpenRLHF/Llama-3-8b-sft-mixture 21 | 22 | iter=0 23 | if [ -f $ITER_LOG_PATH ]; then 24 | iter=$(cat $ITER_LOG_PATH) 25 | fi 26 | 27 | while (($iter < $TRAINING_ITERS)); do 28 | echo "Iter: $iter" 29 | # Use latest model if past first iteration 30 | if ((iter > 0)); then 31 | POLICY_MODEL_PATH=$MODEL_OUTPUT_PATH 32 | fi 33 | 34 | read -r -d '' generate_commands <$ITER_LOG_PATH 98 | fi 99 | done -------------------------------------------------------------------------------- /openrlhf/utils/distributed_util.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any, Optional, Union 3 | 4 | import torch 5 | import torch.distributed 6 | from torch.distributed.distributed_c10d import ( 7 | Backend, 8 | PrefixStore, 9 | Store, 10 | _new_process_group_helper, 11 | _world, 12 | default_pg_timeout, 13 | rendezvous, 14 | ) 15 | 16 | 17 | def torch_dist_barrier_and_cuda_sync(): 18 | """Synchronize distributed training and CUDA operations. 19 | This function ensures that: 20 | 1. All distributed processes reach this point (barrier) 21 | 2. All CUDA operations are completed (synchronize) 22 | """ 23 | torch.distributed.barrier() 24 | torch.cuda.synchronize() 25 | 26 | 27 | # Copy from pytorch to allow creating multiple main groups. 28 | # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py 29 | def init_process_group( 30 | backend: Union[str, Backend] = None, 31 | init_method: Optional[str] = None, 32 | timeout: Optional[timedelta] = None, 33 | world_size: int = -1, 34 | rank: int = -1, 35 | store: Optional[Store] = None, 36 | group_name: str = None, 37 | pg_options: Optional[Any] = None, 38 | ): 39 | assert (store is None) or (init_method is None), "Cannot specify both init_method and store." 
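# In this repo the TCP rendezvous form is used, e.g. init_method=f"tcp://{master_address}:{master_port}" (see WorkerWrap.init_process_group in trainer/ray/vllm_worker_wrap.py).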
40 | 41 | if store is not None: 42 | assert world_size > 0, "world_size must be positive if using store" 43 | assert rank >= 0, "rank must be non-negative if using store" 44 | elif init_method is None: 45 | init_method = "env://" 46 | 47 | if backend: 48 | backend = Backend(backend) 49 | else: 50 | backend = Backend("undefined") 51 | 52 | if timeout is None: 53 | timeout = default_pg_timeout 54 | 55 | # backward compatible API 56 | if store is None: 57 | rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) 58 | store, rank, world_size = next(rendezvous_iterator) 59 | store.set_timeout(timeout) 60 | 61 | # Use a PrefixStore to avoid accidental overrides of keys used by 62 | # different systems (e.g. RPC) in case the store is multi-tenant. 63 | store = PrefixStore(group_name, store) 64 | 65 | # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 66 | # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 67 | # We need to determine the appropriate parameter name based on PyTorch version 68 | pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" 69 | pg, _ = _new_process_group_helper( 70 | world_size, 71 | rank, 72 | [], 73 | backend, 74 | store, 75 | group_name=group_name, 76 | **{pg_options_param_name: pg_options}, 77 | timeout=timeout, 78 | ) 79 | 80 | _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} 81 | 82 | return pg 83 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/vllm_worker_wrap.py: -------------------------------------------------------------------------------- 1 | class WorkerWrap: 2 | def init_process_group( 3 | self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl", use_ray=False 4 | ): 5 | """Init torch process group for model weights update""" 6 | import torch 7 | from openrlhf.utils.distributed_util import init_process_group 8 | 9 | assert torch.distributed.is_initialized(), f"default torch process group must be initialized" 10 | assert group_name != "", f"group name must not be empty" 11 | 12 | rank = torch.distributed.get_rank() + rank_offset 13 | if use_ray: 14 | import ray.util.collective as collective 15 | 16 | collective.init_collective_group(world_size=world_size, rank=rank, backend=backend, group_name=group_name) 17 | self._model_update_group = group_name 18 | else: 19 | self._model_update_group = init_process_group( 20 | backend=backend, 21 | init_method=f"tcp://{master_address}:{master_port}", 22 | world_size=world_size, 23 | rank=rank, 24 | group_name=group_name, 25 | ) 26 | self._model_update_with_ray = use_ray 27 | print( 28 | f"init_process_group: master_address={master_address}, master_port={master_port}, ", 29 | f"rank={rank}, world_size={world_size}, group_name={group_name}", 30 | ) 31 | 32 | def update_weight(self, name, dtype, shape, empty_cache=False): 33 | import torch 34 | 35 | """Broadcast weight to all vllm workers from source rank 0 (actor model)""" 36 | if torch.distributed.get_rank() == 0: 37 | print(f"update weight: {name}, dtype: {dtype}, shape: {shape}") 38 | 39 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" 40 | weight = torch.empty(shape, dtype=dtype, device="cuda") 41 | if self._model_update_with_ray: 42 | import ray.util.collective as collective 43 | 44 | collective.broadcast(weight, 0, group_name=self._model_update_group) 45 | else: 46 | 
torch.distributed.broadcast(weight, 0, group=self._model_update_group) 47 | 48 | self.model_runner.model.load_weights(weights=[(name, weight)]) 49 | 50 | del weight 51 | # TODO: should we empty cache if all weights have updated? 52 | # if empty_cache: 53 | # torch.cuda.empty_cache() 54 | 55 | def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None, empty_cache=False): 56 | import torch 57 | from openrlhf.trainer.ray.utils import get_physical_gpu_id 58 | 59 | if torch.distributed.get_rank() == 0: 60 | print(f"update weight: {name}, dtype: {dtype}, shape: {shape}") 61 | 62 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" 63 | 64 | handle = ipc_handles[get_physical_gpu_id()] 65 | device_id = self.device.index 66 | func, args = handle 67 | list_args = list(args) 68 | # the key is to change device id to the current device id 69 | # in case two processes have different CUDA_VISIBLE_DEVICES 70 | list_args[6] = device_id 71 | weight = func(*list_args) 72 | self.model_runner.model.load_weights(weights=[(name, weight)]) 73 | torch.cuda.synchronize() 74 | -------------------------------------------------------------------------------- /openrlhf/utils/processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | 5 | def reward_normalization(objs): 6 | rewards = [float(obj["reward"]) for obj in objs] 7 | rewards = torch.tensor(rewards, dtype=torch.float64) 8 | rewards = (rewards - rewards.mean()) / rewards.std() 9 | for i, obj in enumerate(objs): 10 | obj["reward"] = rewards[i].item() 11 | 12 | 13 | # Conditional SFT 14 | # See https://arxiv.org/abs/2308.12050 15 | DEFAULT_REWARD_PROMPT = "{input} : {reward} " 16 | 17 | 18 | def conditional_sft_processor(args, objs): 19 | if "reward_template" not in args or args.reward_template is None: 20 | reward_template = DEFAULT_REWARD_PROMPT 21 | else: 22 | reward_template = args.reward_template 23 | assert "{input}" in reward_template 24 | assert "{reward}" in reward_template 25 | 26 | if args.normalize_reward: 27 | reward_normalization(objs) 28 | 29 | for obj in tqdm(objs, desc="Conditional SFT process..."): 30 | input = obj["input"] 31 | reward = "{:.2f}".format(float(obj["reward"])) 32 | input = reward_template.replace("{reward}", reward).replace("{input}", input) 33 | obj["input"] = input 34 | 35 | return objs 36 | 37 | 38 | # Rejection Sampling 39 | # See https://arxiv.org/abs/2307.09288 40 | def rejection_sampling_processor(args, objs): 41 | out = {} 42 | for obj in tqdm(objs, desc="Rejection Sampling process...."): 43 | input = obj["input"] 44 | output = obj["output"] 45 | reward = float(obj["reward"]) 46 | 47 | if input not in out: 48 | out[input] = {"output": output, "reward": reward} 49 | elif reward > out[input]["reward"]: 50 | out[input]["reward"] = reward 51 | out[input]["output"] = output 52 | 53 | return [{"input": k, "output": v["output"], "reward": v["reward"]} for k, v in out.items()] 54 | 55 | 56 | # Iterative DPO 57 | # See https://github.com/RLHFlow/Online-RLHF/blob/main/run_loop.sh 58 | def iterative_dpo_processor(args, objs): 59 | out = {} 60 | for obj in tqdm(objs, desc="Iterative DPO process...."): 61 | input = obj["input"] 62 | output = obj["output"] 63 | reward = float(obj["reward"]) 64 | 65 | if input not in out: 66 | out[input] = { 67 | "output": output, 68 | "chosen": output, 69 | "chosen_reward": reward, 70 | "rejected": output, 71 | "rejected_reward": reward, 72 | } 73 | elif 
reward > out[input]["chosen_reward"]: 74 | out[input]["chosen_reward"] = reward 75 | out[input]["chosen"] = output 76 | elif reward < out[input]["rejected_reward"]: 77 | out[input]["rejected_reward"] = reward 78 | out[input]["rejected"] = output 79 | 80 | return [ 81 | { 82 | "prompt": k, 83 | "chosen": v["chosen"], 84 | "chosen_reward": v["chosen_reward"], 85 | "rejected": v["rejected"], 86 | "rejected_reward": v["rejected_reward"], 87 | } 88 | for k, v in out.items() 89 | ] 90 | 91 | 92 | PROCESSORS = { 93 | "rs": rejection_sampling_processor, 94 | "csft": conditional_sft_processor, 95 | "iter_dpo": iterative_dpo_processor, 96 | } 97 | 98 | 99 | def get_processor(name): 100 | if name in PROCESSORS: 101 | return PROCESSORS[name] 102 | else: 103 | raise ValueError(f"Processor {name} does not exist.") 104 | -------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama_ray_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p { partition } 4 | #SBATCH -A { account } 5 | #SBATCH -J { jobname } 6 | #SBATCH -N 2 # 64x8x4 7 | #SBATCH -t {LIMIT_TIME} # wall time 8 | #SBATCH --ntasks-per-node=1 # tasks per node 9 | #SBATCH --exclusive # exclusive node access 10 | #SBATCH --mem=0 # all mem avail 11 | #SBATCH --mail-type=FAIL # only send email on failure 12 | #SBATCH --overcommit # needed for pytorch 13 | 14 | # project settings 15 | OPENRLHF_PATH= 16 | MOUNT="$OPENRLHF_PATH:/openrlhf,$HOME/.cache:/root/.cache" 17 | IMAGE_NAME="nvcr.io/nvidia/pytorch:24.07-py3" 18 | RAY_VERSION=2.12.0 19 | 20 | JOBLOG="$(realpath .)/train_ppo_llama_ray-$SLURM_JOB_ID.log" 21 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} started ..." &>> ${JOBLOG} 22 | 23 | # launch ray daemon 24 | nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") # Getting the node names 25 | nodes_array=( $nodes ) 26 | node_1=${nodes_array[0]} 27 | ip=$node_1 28 | 29 | port=6379 30 | ip_head=$ip:$port 31 | export ip_head 32 | echo "IP Head: $ip_head" &>> ${JOBLOG} 33 | 34 | echo "STARTING HEAD at $node_1" &>> ${JOBLOG} 35 | srun --nodes=1 --ntasks=1 -w "$node_1" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 36 | "pip install ray[default]==$RAY_VERSION \ 37 | && /root/.local/bin/ray start --head --node-ip-address=$ip --port=$port --block" &>> ${JOBLOG} & 38 | sleep 10s 39 | 40 | worker_num=$((SLURM_JOB_NUM_NODES)) #number of nodes other than the head node 41 | for ((i = 1; i < worker_num; i++)); do 42 | node_i=${nodes_array[$i]} 43 | echo "STARTING WORKER $i at $node_i" &>> ${JOBLOG} 44 | srun --nodes=1 --ntasks=1 -w "$node_i" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 45 | "pip install ray[default]==$RAY_VERSION \ 46 | && /root/.local/bin/ray start --address $ip_head --block" &>> ${JOBLOG} & 47 | sleep 1s; 48 | done 49 | 50 | sleep 30s 51 | 52 | # ===== submit ray job ===== 53 | # Job start 54 | srun --overlap --nodes=1 --ntasks=1 -w "$node_1" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 55 | "pip install ray[default]==$RAY_VERSION \ 56 | && /root/.local/bin/ray job submit --address=http://localhost:8265 \ 57 | --runtime-env-json='{\"working_dir\": \"/openrlhf\", \"pip\": \"/openrlhf/requirements.txt\"}' \ 58 | -- python3 -m openrlhf.cli.train_ppo_ray \ 59 | --ref_num_nodes 1 \ 60 | --ref_num_gpus_per_node 4 \ 61 | --reward_num_nodes 1 \ 62 | --reward_num_gpus_per_node 4 \ 63 | --critic_num_nodes 1 \ 64 | --critic_num_gpus_per_node 4 \ 
65 | --actor_num_nodes 1 \ 66 | --actor_num_gpus_per_node 4 \ 67 | --vllm_num_engines 4 \ 68 | --vllm_tensor_parallel_size 2 \ 69 | --colocate_critic_reward \ 70 | --colocate_actor_ref \ 71 | --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ 72 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ 73 | --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ 74 | --micro_train_batch_size 8 \ 75 | --train_batch_size 128 \ 76 | --micro_rollout_batch_size 16 \ 77 | --rollout_batch_size 1024 \ 78 | --max_samples 100000 \ 79 | --max_epochs 1 \ 80 | --prompt_max_len 1024 \ 81 | --generate_max_len 1024 \ 82 | --zero_stage 3 \ 83 | --bf16 \ 84 | --actor_learning_rate 5e-7 \ 85 | --critic_learning_rate 9e-6 \ 86 | --init_kl_coef 0.01 \ 87 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 88 | --input_key context_messages \ 89 | --apply_chat_template \ 90 | --normalize_reward \ 91 | --adam_offload \ 92 | --flash_attn \ 93 | --packing_samples \ 94 | --vllm_sync_backend nccl \ 95 | --gradient_checkpointing \ 96 | --use_wandb {wandb_token}" &>> ${JOBLOG} 97 | 98 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} stopped ..." &>> ${JOBLOG} -------------------------------------------------------------------------------- /openrlhf/datasets/process_reward_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from openrlhf.utils.utils import convert_token_to_id 7 | from .utils import zero_pad_sequences 8 | 9 | 10 | class ProcessRewardDataset(Dataset): 11 | """ 12 | Dataset for process reward model 13 | 14 | Args: 15 | dataset: dataset for reward model 16 | self.tokenizer: self.tokenizer for reward model 17 | self.max_length: max length of input 18 | """ 19 | 20 | def __init__( 21 | self, 22 | dataset, 23 | tokenizer: Callable, 24 | max_length: int, 25 | strategy, 26 | multiple_of=1, 27 | ) -> None: 28 | super().__init__() 29 | self.tokenizer = tokenizer 30 | self.strategy = strategy 31 | self.max_length = max_length 32 | self.multiple_of = multiple_of 33 | 34 | # chat_template 35 | self.input_key = getattr(self.strategy.args, "input_key", None) 36 | self.label_key = getattr(self.strategy.args, "label_key", None) 37 | self.placeholder_token = getattr(self.strategy.args, "placeholder_token", None) 38 | self.reward_tokens = getattr(self.strategy.args, "reward_tokens", None) 39 | 40 | self.placeholder_token_id = convert_token_to_id(self.placeholder_token, self.tokenizer) 41 | 42 | # Store the processed data in class attributes 43 | self.inputs = dataset[self.input_key] 44 | self.labels = dataset[self.label_key] 45 | 46 | def __len__(self): 47 | length = len(self.inputs) 48 | return length 49 | 50 | def __getitem__(self, idx): 51 | input_token = self.tokenizer( 52 | self.inputs[idx], 53 | max_length=self.max_length, 54 | padding=False, 55 | truncation=True, 56 | return_tensors="pt", 57 | add_special_tokens=False, 58 | ) 59 | 60 | input_ids = input_token["input_ids"] 61 | label_values = self.labels[idx] 62 | assert isinstance(label_values, list), "labels should be a list of strings or numbers" 63 | if isinstance(label_values[0], str): 64 | label_tokens = [] 65 | for label in label_values: 66 | assert ( 67 | self.reward_tokens is None or label in self.reward_tokens 68 | ), f"label should be in reward tokens {self.reward_tokens}, got {label}" 69 | label_tokens.append(convert_token_to_id(label, self.tokenizer)) 70 | 71 | # label_tokens is list of token id (for '+', '-', etc) 
72 | label_tensor = torch.tensor(label_tokens, dtype=input_ids.dtype) 73 | else: 74 | # label_values is list of float numbers (for reward values) 75 | label_tensor = torch.tensor(label_values, dtype=torch.float) 76 | # Motivation: inputs_ids maybe truncated to self.max_length, where placeholder_tokens at the end may be removed. 77 | # We should also truncate the labels to match the length of input_ids 78 | # Step 1: Create a mask for placeholder token positions 79 | mask = input_ids == self.placeholder_token_id 80 | # Step 2: Ensure that label_tensor is truncated along the last dimension 81 | # Find the length of the last dimension of the mask 82 | num_placeholders = mask.sum(dim=-1) 83 | # Truncate label_tensor along the last dimension to match num_placeholders 84 | truncated_labels = label_tensor[..., : num_placeholders.max()] 85 | # Step 3: Update labels at placeholder token positions 86 | labels = torch.full_like(input_ids, -100) 87 | labels[mask] = truncated_labels 88 | 89 | return ( 90 | input_ids, 91 | input_token["attention_mask"], 92 | labels, 93 | ) 94 | 95 | def collate_fn(self, item_list): 96 | input_ids = [] 97 | input_masks = [] 98 | label_ids = [] 99 | for input_id, input_mask, label_id in item_list: 100 | input_ids.append(input_id) 101 | input_masks.append(input_mask) 102 | label_ids.append(label_id) 103 | 104 | padding_side = "right" 105 | input_ids = zero_pad_sequences(input_ids, side=padding_side, value=self.tokenizer.pad_token_id) 106 | input_masks = zero_pad_sequences(input_masks, side=padding_side) 107 | label_ids = zero_pad_sequences(label_ids, side=padding_side, value=self.tokenizer.pad_token_id) 108 | return input_ids, input_masks, label_ids 109 | -------------------------------------------------------------------------------- /openrlhf/cli/serve_rm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | 4 | import torch 5 | import uvicorn 6 | from fastapi import FastAPI, Request 7 | from fastapi.responses import JSONResponse 8 | 9 | from openrlhf.models import get_llm_for_sequence_regression 10 | from openrlhf.utils import get_tokenizer 11 | from openrlhf.utils.logging_utils import init_logger 12 | 13 | logger = init_logger(__name__) 14 | 15 | 16 | def strip_sequence(text, pad_token, eos_token): 17 | pad_token_escaped = re.escape(pad_token) 18 | eos_token_escaped = re.escape(eos_token) 19 | 20 | pattern = f"^({eos_token_escaped}|{pad_token_escaped})+" 21 | text = re.sub(pattern, "", text) 22 | 23 | pattern = f"({eos_token_escaped}|{pad_token_escaped})+$" 24 | text = re.sub(pattern, "", text) 25 | return text 26 | 27 | 28 | class RewardModelProxy: 29 | def __init__(self, args): 30 | self.reward_model = get_llm_for_sequence_regression( 31 | args.reward_pretrain, 32 | "reward", 33 | normalize_reward=args.normalize_reward, 34 | use_flash_attention_2=args.flash_attn, 35 | bf16=args.bf16, 36 | load_in_4bit=args.load_in_4bit, 37 | value_head_prefix=args.value_head_prefix, 38 | device_map="auto", 39 | ) 40 | self.reward_model.eval() 41 | 42 | self.tokenizer = get_tokenizer( 43 | args.reward_pretrain, self.reward_model, "left", None, use_fast=not args.disable_fast_tokenizer 44 | ) 45 | self.max_length = args.max_len 46 | self.batch_size = args.batch_size 47 | 48 | def get_reward(self, queries, prompts): 49 | if self.batch_size is None: 50 | batch_size = len(queries) 51 | else: 52 | batch_size = self.batch_size 53 | 54 | # remove pad_token 55 | for i in range(len(queries)): 56 | queries[i] = ( 57 | 
strip_sequence(queries[i], self.tokenizer.pad_token, self.tokenizer.eos_token) 58 | + self.tokenizer.eos_token 59 | ) 60 | logger.info(f"queries[0]: {queries[0]}") 61 | 62 | scores = [] 63 | # batch 64 | with torch.no_grad(): 65 | for i in range(0, len(queries), batch_size): 66 | inputs = self.tokenize_fn( 67 | queries[i : min(len(queries), i + batch_size)], device=self.reward_model.device 68 | ) 69 | r = self.reward_model(inputs["input_ids"], inputs["attention_mask"]) 70 | r = r.tolist() 71 | scores.extend(r) 72 | return scores 73 | 74 | def tokenize_fn(self, texts, device): 75 | batch = self.tokenizer( 76 | texts, 77 | return_tensors="pt", 78 | add_special_tokens=False, 79 | max_length=self.max_length, 80 | padding=True, 81 | truncation=True, 82 | ) 83 | return {k: v.to(device) for k, v in batch.items()} 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | # Reward Model 89 | parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") 90 | parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normalization") 91 | parser.add_argument("--value_head_prefix", type=str, default="score") 92 | parser.add_argument("--max_len", type=int, default=2048) 93 | 94 | parser.add_argument("--port", type=int, default=5000, help="Port number for the server") 95 | parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server") 96 | 97 | # Performance 98 | parser.add_argument("--load_in_4bit", action="store_true", default=False) 99 | parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") 100 | parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") 101 | parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) 102 | parser.add_argument("--batch_size", type=int, default=None) 103 | 104 | # ModelScope parameters 105 | parser.add_argument("--use_ms", action="store_true", default=False) 106 | 107 | args = parser.parse_args() 108 | 109 | if args.use_ms: 110 | from modelscope.utils.hf_util import patch_hub 111 | 112 | # Patch hub to download models from modelscope to speed up. 113 | patch_hub() 114 | 115 | # server 116 | reward_model = RewardModelProxy(args) 117 | app = FastAPI() 118 | 119 | @app.post("/get_reward") 120 | async def get_reward(request: Request): 121 | data = await request.json() 122 | queries = data.get("query") 123 | prompts = data.get("prompts") 124 | rewards = reward_model.get_reward(queries, prompts) 125 | result = {"rewards": rewards} 126 | logger.info(f"Sent JSON: {result}") 127 | return JSONResponse(result) 128 | 129 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 130 | -------------------------------------------------------------------------------- /openrlhf/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datasets import interleave_datasets, load_dataset, load_from_disk 4 | from transformers import AutoTokenizer 5 | 6 | 7 | def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True): 8 | tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast) 9 | tokenizer.padding_side = padding_side 10 | # NOTE: When enabling vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM.
11 | # https://github.com/facebookresearch/llama-recipes/pull/196 12 | if tokenizer.pad_token is None: 13 | tokenizer.pad_token = tokenizer.eos_token 14 | tokenizer.pad_token_id = tokenizer.eos_token_id 15 | if model is not None: 16 | model.config.pad_token_id = tokenizer.pad_token_id 17 | 18 | return tokenizer 19 | 20 | 21 | def get_strategy(args): 22 | from openrlhf.utils.deepspeed import DeepspeedStrategy 23 | 24 | strategy = DeepspeedStrategy( 25 | seed=getattr(args, "seed", 42), 26 | full_determinism=getattr(args, "full_determinism", False), 27 | max_norm=getattr(args, "max_norm", 1.0), 28 | micro_train_batch_size=getattr(args, "micro_train_batch_size", 1), 29 | train_batch_size=getattr(args, "train_batch_size", 128), 30 | zero_stage=args.zero_stage, 31 | bf16=getattr(args, "bf16", True), 32 | args=args, 33 | ) 34 | return strategy 35 | 36 | 37 | def blending_datasets( 38 | datasets, 39 | probabilities=None, 40 | strategy=None, 41 | seed=42, 42 | max_count=1e8, 43 | stopping_strategy="all_exhausted", 44 | dataset_split="train", 45 | ): 46 | """Blend multiple datasets with optional probability sampling. 47 | 48 | Args: 49 | datasets (str): Comma-separated list of dataset paths 50 | probabilities (str, optional): Comma-separated list of probabilities for sampling. 51 | If None, datasets will be concatenated without probability sampling. 52 | strategy: Training strategy object 53 | seed (int): Random seed 54 | max_count (int): Maximum number of samples per dataset 55 | """ 56 | datasets = datasets.split(",") 57 | if probabilities is not None: 58 | probabilities = list(map(float, probabilities.split(","))) 59 | assert len(probabilities) == len(datasets) 60 | 61 | data_list = [] 62 | for i, dataset in enumerate(datasets): 63 | dataset = dataset.strip() 64 | strategy.print(f"dataset: {dataset}") 65 | 66 | data_dir = dataset.split("@")[1].strip() if "@" in dataset else None 67 | dataset = dataset.split("@")[0].strip() 68 | dataset_basename = os.path.basename(dataset) 69 | 70 | ext = os.path.splitext(dataset)[-1] 71 | # local python script 72 | if ext == ".py" or ( 73 | os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{dataset_basename}.py")) 74 | ): 75 | data = load_dataset(dataset, trust_remote_code=True) 76 | strategy.print(f"loaded {dataset} with python script") 77 | # local text file 78 | elif ext in [".json", ".jsonl", ".csv", ".parquet", ".arrow"]: 79 | ext = ext.lower().strip(".") 80 | if ext == "jsonl": 81 | ext = "json" 82 | data = load_dataset(ext, data_files=dataset) 83 | strategy.print(f"loaded {dataset} with data_files={dataset}") 84 | # local dataset saved with `datasets.Dataset.save_to_disk` 85 | elif os.path.isdir(dataset): 86 | try: 87 | data = load_from_disk(dataset) 88 | strategy.print(f"loaded {dataset} from disk") 89 | except Exception as e: 90 | strategy.print(f"failed to load {dataset} from disk: {e}") 91 | data = load_dataset(dataset, data_dir=data_dir) 92 | strategy.print(f"loaded {dataset} from files") 93 | # remote/local folder or common file 94 | elif strategy.args.use_ms: 95 | from modelscope.msdatasets import MsDataset 96 | 97 | namespace, dataset = dataset.split("/") 98 | data = MsDataset.load(dataset, namespace=namespace) 99 | else: 100 | data = load_dataset(dataset, data_dir=data_dir) 101 | strategy.print(f"loaded {dataset} from files") 102 | 103 | # Select dataset 104 | if dataset_split and dataset_split in data: 105 | data = data[dataset_split] 106 | data = data.select(range(min(max_count, len(data)))) 107 | data_list.append(data) 108 | 109 
| # merge datasets 110 | if strategy.is_rank_0(): 111 | print(data_list) 112 | 113 | # If probabilities is None, concatenate datasets directly 114 | if probabilities is None: 115 | from datasets import concatenate_datasets 116 | 117 | dataset = concatenate_datasets(data_list) 118 | else: 119 | dataset = interleave_datasets( 120 | data_list, 121 | probabilities=probabilities, 122 | seed=seed, 123 | stopping_strategy=stopping_strategy, 124 | ) 125 | 126 | return dataset 127 | 128 | 129 | def convert_token_to_id(token, tokenizer): 130 | if isinstance(token, str): 131 | token = tokenizer.encode(token, add_special_tokens=False) 132 | assert len(token) == 1 133 | return token[0] 134 | else: 135 | raise ValueError(f"token should be a string, got {type(token)}") 136 | -------------------------------------------------------------------------------- /docs/ppo_examples.md: -------------------------------------------------------------------------------- 1 | ### Training settings 2 | 3 | * Base model: LLaMA2 7B 4 | * SFT dataset: Orca 500k 5 | * RM datasets: Anthropic HH, LMSys, OASST, WebGPT ~ 240K+ 6 | * PPO prompts: alpaca, Anthropic HH, OASST ~ 100K 7 | 8 | ### Question: Please introduce the game GTA5 9 | 10 | **SFT** 11 | 12 | Human: Please introduce the game GTA5 13 | 14 | Assistant: GTA5, or Grand Theft Auto V, is an action-adventure video game developed by Rockstar North and published by Rockstar Games. It was released in 2013 for the PlayStation 3 and Xbox 360, and later in 2014 for the PlayStation 4 and Xbox One. The game is set in the fictional city of Los Santos, based in Los Angeles, and follows the story of three criminals as they plan and execute heists. 15 | 16 | GTA5 is an open-world game, meaning that players can explore the game's vast and detailed world at their leisure, completing missions, engaging in side activities, and interacting with other characters. The game features an extensive array of vehicles, from cars to boats to planes, and players can use them to travel around the city and engage in various forms of criminal activity. 17 | 18 | GTA5 is known for its immersive gameplay, stunning graphics, and rich storytelling. It has received critical acclaim and has become one of the best-selling video games of all time, with over 150 million copies sold worldwide. 19 | 20 | **PPO** 21 | 22 | Human: Please introduce the game GTA5 23 | 24 | Assistant: 1. **Game Name:** Grand Theft Auto V (GTA V) 25 | 26 | 2. **Platforms:** GTA V is available on PlayStation 3 and Xbox 360, and later released on PlayStation 4, Xbox One, and Microsoft Windows. 27 | 28 | 3. **Genre:** Action-adventure game 29 | 30 | 4. **Gameplay:** GTA V is an open-world game where players can explore the fictional city of Los Santos and its surrounding areas, completing missions, engaging in side activities, and interacting with non-player characters. The game features a third-person perspective and allows players to freely roam the game world, engage in combat, and complete objectives. 31 | 32 | 5. **Storyline:** The game follows the story of three protagonists, Michael De Santa, Franklin Clinton, and Trevor Philips, as they become embroiled in a series of heists and other criminal activities. The game's storyline is considered one of the strongest aspects of GTA 33 | 34 | 6. **Graphics and Sound:** GTA V features stunning graphics and realistic character animations, making the game world feel alive and immersive.
The soundtrack and voice acting are also highly praised, with a diverse and engaging musical score that complements the game's action-packed gameplay. 35 | 36 | 7. **Reception:** GTA V received critical acclaim upon its release, with many praising its storyline, gameplay, graphics, and sound. It is considered one of the best video games of all time and has sold over 150 million copies worldwide, making it one of the best-selling video games of all time. 37 | 38 | In summary, GTA V is an action-adventure game set in a massive open world, with a complex and engaging storyline, stunning graphics, and highly praised soundtrack and voice acting. It is considered one of the greatest video games of all time and continues to be a popular choice among gamers worldwide. 39 | 40 | 41 | ### Question: Which GPU is better, AMD or NVIDIA? 42 | 43 | **SFT** 44 | 45 | Human: Which GPU is better, AMD or NVIDIA? 46 | 47 | Assistant: 48 | 49 | The performance and suitability of a GPU (Graphics Processing Unit) depend on various factors, including the specific task or application being used, the model and generation of the GPU, and the system it is being used in. Both AMD and NVIDIA are well-known and reputable manufacturers of GPUs, and they both have a wide range of models and generations. 50 | 51 | In general, the performance of a GPU can vary depending on the specific task or application being used. For example, AMD's GPUs are known for their performance in gaming and video editing, while NVIDIA's GPUs are known for their performance in machine learning and deep learning. 52 | 53 | It is important to note that the performance of a GPU can also vary depending on the generation of the GPU. Newer generations of GPUs tend to have better performance and features than older generations. 54 | 55 | Ultimately, the best GPU for you will depend on your specific needs and the tasks you plan to use it for. It is recommended to research and compare the specific models and generations of AMD and NVIDIA GPUs to determine which one is the best fit for your needs. 56 | 57 | **PPO** 58 | 59 | Human: Which GPU is better, AMD or NVIDIA? 60 | 61 | Assistant: 1. **Performance:** In terms of raw performance, NVIDIA typically outperforms AMD in most cases, especially in gaming and video editing. 62 | 63 | 2. **Price:** NVIDIA GPUs tend to be more expensive than AMD GPUs, although this can vary depending on the specific model and its features. 64 | 65 | 3. **Architecture:** NVIDIA's GPU architecture is based on the CUDA (Compute Unified Device Architecture) platform, which allows for more efficient processing of parallel computing tasks. AMD's GPU architecture, on the other hand, is based on the Stream Compute architecture, which is designed for general-purpose computing. 66 | 67 | 4. **Application Support:** NVIDIA's GPUs have better support for certain applications, particularly in the fields of deep learning and artificial intelligence. AMD's GPUs, however, have better support for certain games and video editing software. 68 | 69 | In summary, NVIDIA's GPUs tend to be more powerful and efficient for certain tasks, while AMD's GPUs are better suited for other tasks and have better support for certain applications. Ultimately, the choice between the two depends on the specific needs and preferences of the user. 
70 | -------------------------------------------------------------------------------- /openrlhf/cli/interactive_chat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from openrlhf.models import Actor 4 | from openrlhf.utils import get_tokenizer 5 | 6 | 7 | def generate(args): 8 | # dummy strategy 9 | class Empty: 10 | pass 11 | 12 | dummy_strategy = Empty() 13 | dummy_strategy.print = print 14 | dummy_strategy.is_rank_0 = lambda: True 15 | dummy_strategy.args = args 16 | 17 | # configure model 18 | model = Actor( 19 | args.pretrain, 20 | use_flash_attention_2=args.flash_attn, 21 | bf16=args.bf16, 22 | load_in_4bit=args.load_in_4bit, 23 | device_map="auto", 24 | ) 25 | 26 | # configure tokenizer 27 | tokenizer = get_tokenizer( 28 | args.pretrain, model.model, "left", dummy_strategy, use_fast=not args.disable_fast_tokenizer 29 | ) 30 | 31 | if args.ta_prompt: 32 | with open(args.ta_prompt, "r") as f: 33 | user_prompt = f.read() 34 | else: 35 | user_prompt = "" 36 | 37 | if args.apply_chat_template: 38 | conversations = [] 39 | 40 | while True: 41 | inputs = input("Please enter a prompt (or type 'exit' to quit): ") 42 | if inputs.strip().lower() == "exit": 43 | print("Exiting program...") 44 | break 45 | if inputs.strip().lower() == "clear": 46 | if args.apply_chat_template: 47 | conversations = [] 48 | else: 49 | user_prompt = "" 50 | continue 51 | 52 | # get input prompt 53 | if args.apply_chat_template: 54 | conversations.append({"role": "user", "content": inputs}) 55 | user_prompt = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=True) 56 | else: 57 | user_prompt = user_prompt + "\n" + args.input_template.format(inputs) 58 | 59 | if args.enable_csft: 60 | user_prompt += args.csft_prompt.strip() + " " 61 | 62 | user_prompt_len = len(user_prompt) 63 | input_ids = tokenizer.encode(user_prompt, return_tensors="pt").to(torch.cuda.current_device()) 64 | outputs = model.generate( 65 | input_ids=input_ids, 66 | use_cache=True, 67 | max_length=args.max_len, 68 | do_sample=not args.greedy_sampling, 69 | top_p=args.top_p, 70 | early_stopping=False, 71 | num_beams=1, 72 | temperature=args.temperature, 73 | repetition_penalty=args.repetition_penalty, 74 | pad_token_id=tokenizer.pad_token_id, 75 | eos_token_id=tokenizer.eos_token_id, 76 | ) 77 | 78 | if args.apply_chat_template: 79 | generated_ids = outputs[0][:, input_ids.shape[1] :] 80 | response = tokenizer.batch_decode( 81 | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True 82 | )[0] 83 | conversations.append({"role": "assistant", "content": response}) 84 | else: 85 | user_prompt = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)[0] 86 | response = user_prompt[user_prompt_len:] 87 | 88 | print(response) 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") 94 | parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") 95 | parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) 96 | 97 | # QLora 98 | parser.add_argument("--load_in_4bit", action="store_true", default=False, help="Use QLoRA") 99 | 100 | # Sampling 101 | parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") 102 | parser.add_argument("--max_len", type=int, default=4096) 103 | 
parser.add_argument("--greedy_sampling", action="store_true", default=False, help="Use Greedy sampling") 104 | parser.add_argument("--top_p", type=float, default=0.9, help="top_p for Sampling") 105 | parser.add_argument("--temperature", type=float, default=0.2, help="temperature for Sampling") 106 | parser.add_argument("--repetition_penalty", type=float, default=1.2) 107 | parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ") 108 | parser.add_argument( 109 | "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" 110 | ) 111 | 112 | parser.add_argument("--ta_prompt", type=str, default=None) 113 | parser.add_argument("--enable_csft", action="store_true", default=False) 114 | parser.add_argument("--csft_prompt", type=str, default=": 5.00", help="conditional SFT prompt") 115 | 116 | # ModelScope parameters 117 | parser.add_argument("--use_ms", action="store_true", default=False) 118 | 119 | args = parser.parse_args() 120 | 121 | if args.input_template and "{}" not in args.input_template: 122 | print("[Warning] {} not in args.input_template, set to None") 123 | args.input_template = None 124 | 125 | if args.input_template and "\\n" in args.input_template: 126 | print( 127 | "[Warning] input_template contains \\n characters instead of newline. " 128 | "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 129 | ) 130 | 131 | if args.use_ms: 132 | from modelscope.utils.hf_util import patch_hub 133 | 134 | # Patch hub to download models from modelscope to speed up. 135 | patch_hub() 136 | 137 | print(args) 138 | generate(args) 139 | -------------------------------------------------------------------------------- /openrlhf/datasets/unpaired_preference_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from .utils import zero_pad_sequences 7 | 8 | 9 | def preprocess_data( 10 | data, input_template=None, input_key=None, output_key=None, label_key=None, apply_chat_template=None 11 | ): 12 | """ 13 | Preprocess data from raw dataset to prompt, response, label 14 | 15 | Args: 16 | data: raw data from dataset 17 | """ 18 | label = data[label_key] 19 | 20 | if apply_chat_template: 21 | if output_key: 22 | prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True) 23 | response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt) :] 24 | else: 25 | prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True) 26 | response = apply_chat_template(data[input_key], tokenize=False)[len(prompt) :] 27 | else: 28 | prompt = data[input_key] 29 | response = data[output_key] 30 | if input_template: 31 | prompt = input_template.format(prompt) 32 | return prompt, response, label 33 | 34 | 35 | class UnpairedPreferenceDataset(Dataset): 36 | """ 37 | Unpaired preference dataset for algorithms like KTO 38 | 39 | Args: 40 | dataset: raw dataset 41 | tokenizer: tokenizer for the model 42 | max_length: max length of input 43 | """ 44 | 45 | def __init__( 46 | self, dataset, tokenizer: Callable, max_length: int, strategy, input_template=None, num_processors=8 47 | ) -> None: 48 | super().__init__() 49 | self.tokenizer = tokenizer 50 | self.strategy = strategy 51 | self.max_length = max_length 52 | 53 | # chat_template 54 | self.input_template = input_template 55 | self.input_key = 
getattr(self.strategy.args, "input_key", None) 56 | self.output_key = getattr(self.strategy.args, "output_key", None) 57 | self.label_key = getattr(self.strategy.args, "label_key", None) 58 | self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 59 | 60 | if self.apply_chat_template: 61 | self.apply_chat_template = self.tokenizer.apply_chat_template 62 | tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None) 63 | if tokenizer_chat_template: 64 | self.tokenizer.chat_template = tokenizer_chat_template 65 | 66 | # Parallel loading datasets 67 | processed_dataset = dataset.map( 68 | self.process_data, remove_columns=dataset.column_names, num_proc=num_processors 69 | ) 70 | 71 | # Filter out None values if necessary 72 | processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) 73 | 74 | # Store the processed data in class attributes 75 | self.prompts = processed_dataset["prompt"] 76 | self.responses = processed_dataset["response"] 77 | self.labels = processed_dataset["label"] 78 | self.prompt_ids_lens = processed_dataset["prompt_ids_len"] 79 | 80 | def process_data(self, data): 81 | prompt, response, label = preprocess_data( 82 | data, self.input_template, self.input_key, self.output_key, self.label_key, self.apply_chat_template 83 | ) 84 | prompt_token = self.tokenizer( 85 | prompt, 86 | max_length=self.max_length, 87 | padding=False, 88 | truncation=True, 89 | return_tensors="pt", 90 | add_special_tokens=False, 91 | ) 92 | prompt_ids_len = prompt_token["attention_mask"].int().sum().item() 93 | 94 | # filter the sample whose length is greater than max_length (2 for answer length) 95 | if prompt_ids_len >= self.max_length - 2: 96 | prompt = None 97 | 98 | return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len} 99 | 100 | def __len__(self): 101 | return len(self.prompts) 102 | 103 | def __getitem__(self, index): 104 | return self.prompts[index], self.responses[index], self.labels[index], self.prompt_ids_lens[index] 105 | 106 | def collate_fn(self, item_list): 107 | def tokenizer(prompt, response): 108 | text = (prompt + response).rstrip("\n") 109 | if not text.endswith(self.tokenizer.eos_token): 110 | text += " " + self.tokenizer.eos_token 111 | inputs = self.tokenizer( 112 | text, 113 | max_length=self.max_length, 114 | padding=False, 115 | truncation=True, 116 | return_tensors="pt", 117 | add_special_tokens=False, 118 | ) 119 | 120 | inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id 121 | inputs["attention_mask"][0][-1] = True 122 | return inputs["input_ids"], inputs["attention_mask"] 123 | 124 | tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], [] 125 | for prompt, response, label, prompt_ids_len in item_list: 126 | input_ids, attention_mask = tokenizer(prompt, response) 127 | tot_ids.append(input_ids) 128 | tot_masks.append(attention_mask) 129 | tot_labels.append(label) 130 | prompt_ids_lens.append(prompt_ids_len) 131 | 132 | # add unmatched y'| x (used to estimate the KL divergence between policy and reference) 133 | for idx in range(len(item_list)): 134 | next_idx = (idx + 1) % len(item_list) 135 | input_ids, attention_mask = tokenizer(item_list[idx][0], item_list[next_idx][1]) 136 | tot_ids.append(input_ids) 137 | tot_masks.append(attention_mask) 138 | tot_labels.append(-1) 139 | prompt_ids_lens.append(item_list[idx][3]) 140 | 141 | input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id) 142 | attention_mask = 
zero_pad_sequences(tot_masks, side="right") 143 | return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens 144 | -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/deepspeed_utils.py: -------------------------------------------------------------------------------- 1 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 2 | 3 | 4 | def get_train_ds_config( 5 | offload, 6 | adam_offload=True, 7 | stage=2, 8 | bf16=True, 9 | max_norm=1.0, 10 | zpg=8, 11 | grad_accum_dtype=None, 12 | overlap_comm=False, 13 | use_ds_universal_ckpt=False, 14 | deepcompile=False, 15 | tensor_parallel_size=1, 16 | ): 17 | device = "cpu" if offload else "none" 18 | zero_opt_dict = { 19 | "stage": stage, 20 | "offload_param": {"device": device}, 21 | "offload_optimizer": { 22 | "device": "cpu" if adam_offload else "none", 23 | "pin_memory": True, 24 | }, 25 | "sub_group_size": "auto", 26 | "stage3_max_live_parameters": "auto", 27 | "stage3_max_reuse_distance": "auto", 28 | "stage3_param_persistence_threshold": "auto", 29 | "stage3_prefetch_bucket_size": "auto", 30 | "reduce_bucket_size": "auto", 31 | # ZeRO++ 32 | "zero_hpz_partition_size": zpg, 33 | "zero_quantized_weights": False, 34 | "zero_quantized_gradients": False, 35 | } 36 | if overlap_comm: 37 | zero_opt_dict["overlap_comm"] = True 38 | zero_opt_dict["contiguous_gradients"] = True 39 | if stage == 3: 40 | zero_opt_dict["reduce_scatter"] = True 41 | 42 | return { 43 | "steps_per_print": 100, 44 | "zero_optimization": zero_opt_dict, 45 | "bf16": { 46 | "enabled": bf16, 47 | }, 48 | "gradient_clipping": max_norm, 49 | "prescale_gradients": False, 50 | "wall_clock_breakdown": False, 51 | "data_types": {"grad_accum_dtype": grad_accum_dtype}, 52 | "checkpoint": { 53 | "load_universal": use_ds_universal_ckpt, 54 | }, 55 | "compile": { 56 | "deepcompile": deepcompile, 57 | }, 58 | "tensor_parallel": { 59 | "autotp_size": tensor_parallel_size, 60 | }, 61 | } 62 | 63 | 64 | def get_eval_ds_config( 65 | offload, 66 | stage=0, 67 | bf16=True, 68 | deepcompile=False, 69 | tensor_parallel_size=1, 70 | ): 71 | # At least for 0.16.6, DeepCompile hasn't support pure inference mode 72 | # https://github.com/deepspeedai/DeepSpeed/pull/7225 73 | deepcompile = False 74 | 75 | zero_opt_dict = { 76 | "stage": stage, 77 | "stage3_max_live_parameters": "auto", 78 | "stage3_max_reuse_distance": "auto", 79 | "stage3_param_persistence_threshold": "auto", 80 | "stage3_prefetch_bucket_size": "auto", 81 | "offload_param": { 82 | "device": "cpu" if offload else "none", 83 | "pin_memory": True, 84 | }, 85 | } 86 | return { 87 | "steps_per_print": 100, 88 | "zero_optimization": zero_opt_dict, 89 | "bf16": { 90 | "enabled": bf16, 91 | }, 92 | "gradient_clipping": 1.0, 93 | "prescale_gradients": False, 94 | "wall_clock_breakdown": False, 95 | "compile": { 96 | "deepcompile": deepcompile, 97 | }, 98 | "tensor_parallel": { 99 | "autotp_size": tensor_parallel_size, 100 | }, 101 | } 102 | 103 | 104 | def get_optimizer_grouped_parameters( 105 | model, 106 | weight_decay, 107 | no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], 108 | ): 109 | optimizer_grouped_parameters = [ 110 | { 111 | "params": [ 112 | p 113 | for n, p in model.named_parameters() 114 | if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) 115 | ], 116 | "weight_decay": weight_decay, 117 | }, 118 | { 119 | "params": [ 120 | p 121 | for n, p in 
model.named_parameters() 122 | if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) 123 | ], 124 | "weight_decay": 0.0, 125 | }, 126 | ] 127 | return optimizer_grouped_parameters 128 | 129 | 130 | def _z3_params_to_fetch(param_list): 131 | return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] 132 | 133 | 134 | def offload_deepspeed_states(model, pin_memory=True, non_blocking=True): 135 | zero_stage = model.zero_optimization_stage() # config['zero_optimization']['stage'] 136 | adam_offload = model.config["zero_optimization"]["offload_optimizer"]["device"] == "cpu" 137 | 138 | # state offloading not required when using Adam optimizer offloading 139 | if adam_offload: 140 | return 141 | 142 | if zero_stage != 3: 143 | raise NotImplementedError("Only Zero stage 3 is currently supported") 144 | 145 | # if zero_stage == 3 and not adam_offload: 146 | import deepspeed 147 | import torch 148 | from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum 149 | 150 | offload_state_types = [ 151 | OffloadStateTypeEnum.optim_states, 152 | OffloadStateTypeEnum.contiguous_grad_buffer, 153 | OffloadStateTypeEnum.hp_params, 154 | ] 155 | 156 | if deepspeed.__version__ >= "0.16.5": 157 | # These offload types are fixed in https://github.com/deepspeedai/DeepSpeed/pull/7050 158 | offload_state_types += [ 159 | OffloadStateTypeEnum.lp_grads, 160 | # OffloadStateTypeEnum.lp_params, 161 | ] 162 | 163 | model.optimizer.offload_states( 164 | include=offload_state_types, 165 | device=OffloadDeviceEnum.cpu, 166 | pin_memory=pin_memory, 167 | non_blocking=non_blocking, 168 | ) 169 | model.empty_partition_cache() 170 | torch.cuda.empty_cache() 171 | torch.distributed.barrier() 172 | torch.cuda.synchronize() 173 | 174 | 175 | def reload_deepspeed_states(model, non_blocking=True): 176 | zero_stage = model.zero_optimization_stage() # config['zero_optimization']['stage'] 177 | adam_offload = model.config["zero_optimization"]["offload_optimizer"]["device"] == "cpu" 178 | 179 | # state offloading not required when using Adam optimizer offloading 180 | if adam_offload: 181 | return 182 | 183 | if zero_stage != 3: 184 | raise NotImplementedError("Only Zero stage 3 is currently supported") 185 | 186 | # if zero_stage == 3 and not adam_offload: 187 | import torch 188 | 189 | model.reload_states(non_blocking=non_blocking) 190 | torch.cuda.empty_cache() 191 | torch.distributed.barrier() 192 | torch.cuda.synchronize() 193 | -------------------------------------------------------------------------------- /openrlhf/utils/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Iterator, Optional, TypeVar 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch.utils.data.dataset import Dataset 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | __all__ = ["DistributedSampler"] 11 | 12 | 13 | _T_co = TypeVar("_T_co", covariant=True) 14 | 15 | 16 | # Adapted from https://github.com/pytorch/pytorch/blob/5298acb5c76855bc5a99ae10016efc86b27949bd/torch/utils/data/distributed.py 17 | class DistributedSampler(Sampler[_T_co]): 18 | r"""Sampler that restricts data loading to a subset of the dataset. 19 | 20 | It is especially useful in conjunction with 21 | :class:`torch.nn.parallel.DistributedDataParallel`. 
In such a case, each 22 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a 23 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 24 | original dataset that is exclusive to it. 25 | 26 | .. note:: 27 | Dataset is assumed to be of constant size and that any instance of it always 28 | returns the same elements in the same order. 29 | 30 | Args: 31 | dataset: Dataset used for sampling. 32 | num_replicas (int, optional): Number of processes participating in 33 | distributed training. By default, :attr:`world_size` is retrieved from the 34 | current distributed group. 35 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 36 | By default, :attr:`rank` is retrieved from the current distributed 37 | group. 38 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the 39 | indices. 40 | seed (int, optional): random seed used to shuffle the sampler if 41 | :attr:`shuffle=True`. This number should be identical across all 42 | processes in the distributed group. Default: ``0``. 43 | drop_last (bool, optional): if ``True``, then the sampler will drop the 44 | tail of the data to make it evenly divisible across the number of 45 | replicas. If ``False``, the sampler will add extra indices to make 46 | the data evenly divisible across the replicas. Default: ``False``. 47 | 48 | .. warning:: 49 | In distributed mode, calling the :meth:`set_epoch` method at 50 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 51 | is necessary to make shuffling work properly across multiple epochs. Otherwise, 52 | the same ordering will be always used. 53 | 54 | Example:: 55 | 56 | >>> # xdoctest: +SKIP 57 | >>> sampler = DistributedSampler(dataset) if is_distributed else None 58 | >>> loader = DataLoader(dataset, shuffle=(sampler is None), 59 | ... sampler=sampler) 60 | >>> for epoch in range(start_epoch, n_epochs): 61 | ... if is_distributed: 62 | ... sampler.set_epoch(epoch) 63 | ... train(loader) 64 | """ 65 | 66 | def __init__( 67 | self, 68 | dataset: Dataset, 69 | num_replicas: Optional[int] = None, 70 | rank: Optional[int] = None, 71 | shuffle: bool = True, 72 | seed: int = 0, 73 | drop_last: bool = False, 74 | consumed_samples=0, 75 | ) -> None: 76 | if num_replicas is None: 77 | if not dist.is_available(): 78 | raise RuntimeError("Requires distributed package to be available") 79 | num_replicas = dist.get_world_size() 80 | if rank is None: 81 | if not dist.is_available(): 82 | raise RuntimeError("Requires distributed package to be available") 83 | rank = dist.get_rank() 84 | if rank >= num_replicas or rank < 0: 85 | raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") 86 | self.dataset = dataset 87 | self.num_replicas = num_replicas 88 | self.rank = rank 89 | self.epoch = 0 90 | self.drop_last = drop_last 91 | # If the dataset length is evenly divisible by # of replicas, then there 92 | # is no need to drop any data, since the dataset will be split equally. 93 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] 94 | # Split to nearest available length that is evenly divisible. 95 | # This is to ensure each rank receives the same amount of data when 96 | # using this Sampler. 
97 | self.num_samples = math.ceil( 98 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] 99 | ) 100 | else: 101 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] 102 | self.total_size = self.num_samples * self.num_replicas 103 | self.shuffle = shuffle 104 | self.seed = seed 105 | self.consumed_indicies = consumed_samples // self.num_replicas 106 | 107 | def __iter__(self) -> Iterator[_T_co]: 108 | if self.shuffle: 109 | # deterministically shuffle based on epoch and seed 110 | g = torch.Generator() 111 | g.manual_seed(self.seed + self.epoch) 112 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 113 | else: 114 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 115 | 116 | if not self.drop_last: 117 | # add extra samples to make it evenly divisible 118 | padding_size = self.total_size - len(indices) 119 | if padding_size <= len(indices): 120 | indices += indices[:padding_size] 121 | else: 122 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 123 | else: 124 | # remove tail of data to make it evenly divisible. 125 | indices = indices[: self.total_size] 126 | assert len(indices) == self.total_size 127 | 128 | # subsample 129 | indices = indices[self.rank : self.total_size : self.num_replicas] 130 | # skip consumed_samples 131 | indices = indices[self.consumed_indicies :] 132 | assert len(indices) == self.num_samples - self.consumed_indicies 133 | 134 | return iter(indices) 135 | 136 | def __len__(self) -> int: 137 | return self.num_samples - self.consumed_indicies 138 | 139 | def set_epoch(self, epoch: int, consumed_samples=0) -> None: 140 | r""" 141 | Set the epoch for this sampler. 142 | 143 | When :attr:`shuffle=True`, this ensures all replicas 144 | use a different random ordering for each epoch. Otherwise, the next iteration of this 145 | sampler will yield the same ordering. 146 | 147 | Args: 148 | epoch (int): Epoch number. 149 | """ 150 | self.epoch = epoch 151 | self.consumed_indicies = consumed_samples // self.num_replicas 152 | -------------------------------------------------------------------------------- /openrlhf/models/ring_attn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input 4 | from flash_attn.utils.distributed import all_gather 5 | 6 | RING_ATTN_GROUP = None 7 | 8 | 9 | def set_ring_attn_group(group): 10 | global RING_ATTN_GROUP 11 | RING_ATTN_GROUP = group 12 | 13 | 14 | def get_ring_attn_group(): 15 | return RING_ATTN_GROUP 16 | 17 | 18 | def reset_ring_attn_position_ids(start, end, packed_seq_lens): 19 | """ 20 | Calculate position ids for packed_seq_ids[start:end]. 21 | For example, if the packed_seq_lens is [3, 2, 4, 1], start=2, end=8, 22 | the position ids will be [2, 0, 1, 0, 1, 2]. 
23 | 24 | Args: 25 | start: the start position 26 | end: the end position 27 | packed_seq_lens: the sequence lengths of packed sequences 28 | """ 29 | position_ids = torch.zeros((1, end - start), dtype=torch.long, device=torch.cuda.current_device()) 30 | offset = 0 31 | for seqlen in packed_seq_lens: 32 | seq_start = max(offset, start) 33 | seq_end = min(offset + seqlen, end) 34 | if seq_start < seq_end: 35 | position_ids[0, seq_start - start : seq_end - start] = torch.arange(seq_start - offset, seq_end - offset) 36 | 37 | offset += seqlen 38 | if offset >= end: 39 | break 40 | return position_ids 41 | 42 | 43 | def update_ring_attn_params(cu_seqlens): 44 | """ 45 | Calculate the cu_seqlens for the current forward pass and pass the value to 46 | the substituted ring_flash_attn. 47 | 48 | Note that total_seq_len may be larger than the sum of packed_seq_lens because of padding. 49 | """ 50 | assert RING_ATTN_GROUP is not None 51 | 52 | from ring_flash_attn import update_ring_flash_attn_params 53 | 54 | update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP) 55 | 56 | 57 | def get_tensor_in_current_ring_attn_rank(tensors: list[torch.Tensor] | torch.Tensor, ring_attn_group, pad_id): 58 | """ 59 | Deal with padding and slice the tensor to current ring_attn_rank. 60 | Args: 61 | tensors: Each tensor shaped (batch, seqlen) or (1, total_seqs) 62 | ring_attn_group: Ring attention group 63 | pad_id: Padding id 64 | Returns: 65 | Processed tensor 66 | """ 67 | if isinstance(tensors, torch.Tensor): 68 | tensors = [tensors] 69 | ring_attn_rank = dist.get_rank(group=ring_attn_group) 70 | ring_attn_size = dist.get_world_size(group=ring_attn_group) 71 | seqlen = tensors[0].shape[-1] 72 | total_seq_len = tensors[0].numel() 73 | ring_attn_pad_len = (ring_attn_size - seqlen % ring_attn_size) % ring_attn_size 74 | output_tensors = [] 75 | for tensor in tensors: 76 | if tensor.numel() != total_seq_len: 77 | raise ValueError(f"tensor.numel() {tensor.numel()} != total_seq_len {total_seq_len}") 78 | tensor = torch.nn.functional.pad(tensor, (0, ring_attn_pad_len), value=pad_id) 79 | local_seq_len = tensor.numel() // ring_attn_size 80 | start, end = ring_attn_rank * local_seq_len, (ring_attn_rank + 1) * local_seq_len 81 | tensor = tensor[:, start:end] 82 | output_tensors.append(tensor) 83 | if len(output_tensors) == 1: 84 | output_tensors = output_tensors[0] 85 | return output_tensors, ring_attn_pad_len 86 | 87 | 88 | def unpad_and_slice_tensor(sequences, attention_mask, ring_attn_group): 89 | """ 90 | Unpad and slice tensor for distributed training with ring attention. 91 | 92 | This function performs several operations: 93 | 1. Removes padding, unpads sequences from (batch, seqlen) to (1, total_seqs) 94 | 2. Adapts to ring_attn_group, pads sequences to be divisible by ring_attn_group 95 | 3. 
Slices the sequences for the current ring_attn_rank 96 | 97 | Example: 98 | >>> # Input sequences shape: (batch=2, seqlen=4) 99 | >>> sequences = [[1, 2, 3, 0], [4, 5, 0, 0]] # 0 is padding 100 | >>> attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]] 101 | >>> # After unpad: 102 | >>> # sequences: [1, 2, 3, 4, 5] # shape (1, total_seqs=5) 103 | >>> # If ring_attn_group size is 2, it will pad to length 6 104 | >>> # Then slice for current rank (e.g., rank 0 gets [1,2,3], rank 1 gets [4,5,0]) 105 | 106 | Args: 107 | sequences: Input sequences tensor of shape (batch, seqlen) 108 | attention_mask: Attention mask tensor for the sequences 109 | ring_attn_group: Ring attention group for distributed processing 110 | 111 | Returns: 112 | tuple: Processed sequences and related tensors for ring attention 113 | """ 114 | rolled_sequences = torch.roll(sequences, shifts=-1, dims=1) 115 | sequences, indices, cu_seqlens, _, _ = unpad_input(sequences.unsqueeze(-1), attention_mask) 116 | sequences = sequences.transpose(0, 1) # (1, total_seqs) 117 | rolled_sequences = index_first_axis( 118 | rearrange(rolled_sequences.unsqueeze(-1), "b s ... -> (b s) ..."), indices 119 | ).transpose( 120 | 0, 1 121 | ) # (1, total_seqs) 122 | position_ids = torch.clip(torch.cumsum(attention_mask, dim=-1) - 1, min=0, max=None) 123 | position_ids = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose( 124 | 0, 1 125 | ) # (1, total_seqs) 126 | ring_attn_pad_len = 0 127 | if ring_attn_group is not None: 128 | (sequences, position_ids, rolled_sequences), ring_attn_pad_len = get_tensor_in_current_ring_attn_rank( 129 | [sequences, position_ids, rolled_sequences], ring_attn_group, 0 130 | ) 131 | cu_seqlens[-1] += ring_attn_pad_len 132 | update_ring_attn_params(cu_seqlens) 133 | return sequences, position_ids, rolled_sequences, ring_attn_pad_len, indices 134 | 135 | 136 | def gather_and_pad_tensor(tensor, ring_attn_group, ring_attn_pad_len, indices, batch, seqlen): 137 | """ 138 | Gather and pad tensor data (such as logits, log_probs, etc.). 139 | 140 | Example: 141 | >>> # Input tensor from each rank (shape: (1, local_seq_len)) 142 | >>> # Rank 0: [1, 2, 3] 143 | >>> # Rank 1: [4, 5, 0] # 0 is padding 144 | >>> # After all_gather: 145 | >>> # tensor: [1, 2, 3, 4, 5, 0] # shape (1, total_seqs=6) 146 | >>> # After removing padding (ring_attn_pad_len=1): 147 | >>> # tensor: [1, 2, 3, 4, 5] # shape (1, total_seqs=5) 148 | >>> # After pad_input with original indices: 149 | >>> # tensor: [[1, 2, 3, 0], [4, 5, 0, 0]] # shape (batch=2, seqlen=4) 150 | 151 | Args: 152 | tensor: Input tensor, can be logits, log_probs, etc. 
153 | ring_attn_group: Ring attention group 154 | ring_attn_pad_len: Padding length 155 | indices: Indices 156 | batch: Batch size 157 | seqlen: Sequence length 158 | 159 | Returns: 160 | Padded tensor 161 | """ 162 | if ring_attn_group is not None: 163 | tensor = all_gather(tensor.transpose(0, 1), ring_attn_group).transpose(0, 1) # (1, total_seqs) 164 | if ring_attn_pad_len > 0: 165 | tensor = tensor[:, :-ring_attn_pad_len] 166 | tensor = pad_input(tensor.transpose(0, 1), indices, batch, seqlen).squeeze(-1) # (batch, seqlen) 167 | return tensor 168 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import ABC 3 | from dataclasses import dataclass 4 | from typing import List, Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | from .experience_maker import Experience 11 | 12 | 13 | @dataclass 14 | class BufferItem: 15 | """BufferItem is an item of experience data. 16 | 17 | Shapes of each tensor: 18 | sequences: (S) 19 | action_log_probs: (A) 20 | base_action_log_probs: (A) 21 | values: (1) 22 | returns: (1) 23 | advantages: (1) 24 | attention_mask: (S) 25 | action_mask: (A) 26 | 27 | "A" is the number of actions. 28 | """ 29 | 30 | sequences: torch.Tensor 31 | action_log_probs: torch.Tensor 32 | base_action_log_probs: torch.Tensor 33 | values: torch.Tensor 34 | returns: torch.Tensor 35 | advantages: torch.Tensor 36 | attention_mask: Optional[torch.LongTensor] 37 | action_mask: Optional[torch.BoolTensor] 38 | info: Optional[dict] 39 | 40 | 41 | def split_experience_batch(experience: Experience) -> List[BufferItem]: 42 | batch_size = len(experience.sequences) 43 | batch_kwargs = [{} for _ in range(batch_size)] 44 | keys = ( 45 | "sequences", 46 | "action_log_probs", 47 | "base_action_log_probs", 48 | "values", 49 | "returns", 50 | "advantages", 51 | "attention_mask", 52 | "action_mask", 53 | ) 54 | for key in keys: 55 | value = getattr(experience, key) 56 | if value is None: 57 | for i in range(batch_size): 58 | batch_kwargs[i][key] = None 59 | continue 60 | vals = value 61 | if isinstance(vals, torch.Tensor): 62 | vals = torch.unbind(vals) 63 | assert batch_size == len(vals) 64 | for i, v in enumerate(vals): 65 | batch_kwargs[i][key] = v 66 | 67 | for i in range(batch_size): 68 | batch_kwargs[i]["info"] = {} 69 | for k, v in experience.info.items(): 70 | vals = torch.unbind(v) 71 | assert batch_size == len(vals) 72 | for i, vv in enumerate(vals): 73 | if isinstance(vv, torch.Tensor): 74 | assert vv.numel() == 1, f"info[{k}] must be a scalar tensor, but got {vv.shape}" 75 | vv = vv.item() 76 | batch_kwargs[i]["info"][k] = vv 77 | 78 | items = [BufferItem(**kwargs) for kwargs in batch_kwargs] 79 | return items 80 | 81 | 82 | def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor: 83 | assert side in ("left", "right") 84 | max_len = max(seq.size(0) for seq in sequences) 85 | padded_sequences = [] 86 | for seq in sequences: 87 | pad_len = max_len - seq.size(0) 88 | padding = (pad_len, 0) if side == "left" else (0, pad_len) 89 | padded_sequences.append(F.pad(seq, padding)) 90 | return torch.stack(padded_sequences, dim=0) 91 | 92 | 93 | def make_experience_batch(items: List[BufferItem], packing_samples=False) -> Experience: 94 | kwargs = {} 95 | keys = ( 96 | "sequences", 97 | "action_log_probs", 98 | "base_action_log_probs", 99 | "values", 100 | 
"returns", 101 | "advantages", 102 | "attention_mask", 103 | "action_mask", 104 | ) 105 | for key in keys: 106 | vals = [getattr(item, key) for item in items] 107 | vals = zero_pad_sequences(vals, "left") if vals[0] is not None else None 108 | kwargs[key] = vals 109 | 110 | kwargs["info"] = {} 111 | for key in items[0].info.keys(): 112 | vals = torch.tensor([item.info[key] for item in items]) 113 | kwargs["info"][key] = vals 114 | return Experience(**kwargs) 115 | 116 | 117 | def remove_padding_in_sequences(items): 118 | for item in items: 119 | seq, act_log_prob, base_act_log_prob, value, ret, adv, att_mask, act_mask = ( 120 | item.sequences, 121 | item.action_log_probs, 122 | item.base_action_log_probs, 123 | item.values, 124 | item.returns, 125 | item.advantages, 126 | item.attention_mask, 127 | item.action_mask, 128 | ) 129 | right_pad = (1 - act_mask.long()).sum() 130 | right_pad = None if right_pad == 0 else -right_pad 131 | 132 | # left_pad for seq and att_mask 133 | left_pad = att_mask.long().argmax() 134 | ( 135 | item.sequences, 136 | item.action_log_probs, 137 | item.base_action_log_probs, 138 | item.values, 139 | item.returns, 140 | item.advantages, 141 | item.attention_mask, 142 | item.action_mask, 143 | ) = ( 144 | seq[left_pad:right_pad], 145 | act_log_prob[:right_pad], 146 | base_act_log_prob[:right_pad] if item.base_action_log_probs is not None else None, 147 | value[:right_pad] if item.values is not None else None, 148 | ret[:right_pad], 149 | adv[:right_pad], 150 | att_mask[left_pad:right_pad], 151 | act_mask[:right_pad], 152 | ) 153 | return items 154 | 155 | 156 | class NaiveReplayBuffer(ABC): 157 | """Naive replay buffer class. It stores experience. 158 | 159 | Args: 160 | sample_batch_size (int): Batch size when sampling. 161 | limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. 162 | cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. 
163 | """ 164 | 165 | def __init__( 166 | self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True, packing_samples: bool = False 167 | ) -> None: 168 | super().__init__() 169 | self.sample_batch_size = sample_batch_size 170 | # limit <= 0 means unlimited 171 | self.limit = limit 172 | self.cpu_offload = cpu_offload 173 | self.packing_samples = packing_samples 174 | self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}") 175 | self.items: List[BufferItem] = [] 176 | 177 | @torch.no_grad() 178 | def append(self, experience: Experience) -> None: 179 | if self.cpu_offload: 180 | experience.to_device(torch.device("cpu")) 181 | items = split_experience_batch(experience) 182 | items = remove_padding_in_sequences(items) 183 | self.items.extend(items) 184 | if self.limit > 0: 185 | samples_to_remove = len(self.items) - self.limit 186 | if samples_to_remove > 0: 187 | self.items = self.items[samples_to_remove:] 188 | 189 | def clear(self) -> None: 190 | self.items.clear() 191 | 192 | @torch.no_grad() 193 | def sample(self) -> Experience: 194 | items = random.sample(self.items, self.sample_batch_size) 195 | experience = make_experience_batch(items, self.packing_samples) 196 | if self.cpu_offload: 197 | experience.to_device(self.target_device) 198 | return experience 199 | 200 | def __len__(self) -> int: 201 | return len(self.items) 202 | 203 | def __getitem__(self, idx: int) -> BufferItem: 204 | return self.items[idx] 205 | 206 | def collate_fn(self, batch) -> Experience: 207 | experience = make_experience_batch(batch, self.packing_samples) 208 | return experience 209 | -------------------------------------------------------------------------------- /openrlhf/models/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def compute_approx_kl( 8 | log_probs: torch.Tensor, 9 | log_probs_base: torch.Tensor, 10 | kl_estimator: str = "k1", 11 | ) -> torch.Tensor: 12 | """ 13 | Compute the approximate KL divergence between two distributions. 14 | Schulman blog: http://joschu.net/blog/kl-approx.html 15 | 16 | Args: 17 | log_probs: Log probabilities of the new distribution. 18 | log_probs_base: Log probabilities of the base distribution. 19 | """ 20 | 21 | if kl_estimator == "k1": 22 | log_ratio = log_probs.float() - log_probs_base.float() 23 | 24 | # The k2 estimator is the non negative kl approximation in 25 | # http://joschu.net/blog/kl-approx.html 26 | # The k2_loss is approximately equivalent to the 27 | # one-step KL divergence penalty with the k1 estimator 28 | # used in https://arxiv.org/pdf/2310.10505. 
29 | if kl_estimator == "k2": 30 | log_ratio = log_probs.float() - log_probs_base.float() 31 | log_ratio = log_ratio**2 / 2.0 32 | 33 | # The k3 estimator is the non negative kl approximation in 34 | # http://joschu.net/blog/kl-approx.html 35 | if kl_estimator == "k3": 36 | log_ratio = log_probs.float() - log_probs_base.float() 37 | log_ratio = -log_ratio 38 | log_ratio = log_ratio.exp() - 1 - log_ratio 39 | 40 | return log_ratio 41 | 42 | 43 | def compute_reward( 44 | r: Union[torch.Tensor, float], 45 | kl_coef: float, 46 | kl: Union[torch.Tensor, list[torch.Tensor]], 47 | action_mask: Optional[torch.Tensor] = None, 48 | reward_clip_range: Tuple[float, float] = None, 49 | ) -> Union[torch.Tensor, list[torch.Tensor]]: 50 | if kl_coef <= 0.0: 51 | kl_coef = 0.0 52 | 53 | if reward_clip_range: 54 | r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1]) 55 | 56 | kl_reward = -kl_coef * kl 57 | # The following code is equivalent to: 58 | # 59 | # last_reward = torch.zeros_like(kl) 60 | # for i in range(last_reward.size(0)): 61 | # for t in reversed(range(last_reward.size(1))): 62 | # if action_mask[i][t] > 0.5: 63 | # last_reward[i][t] = r[i] 64 | # break 65 | # 66 | eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True) 67 | last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype)) 68 | 69 | reward = last_reward + kl_reward 70 | 71 | return reward 72 | 73 | 74 | def _logsumexp_by_chunk(logits: torch.Tensor, chunk_size: int = 1024) -> torch.Tensor: 75 | seq_len = logits.shape[0] 76 | logsumexp_values = torch.zeros((seq_len), device=logits.device, dtype=logits.dtype) 77 | for s_idx in range(0, seq_len, chunk_size): 78 | end_idx = min(s_idx + chunk_size, seq_len) 79 | logsumexp_values[s_idx:end_idx] = torch.logsumexp(logits[s_idx:end_idx], dim=-1) 80 | 81 | return logsumexp_values 82 | 83 | 84 | def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: 85 | if temperature != 1.0: 86 | logits.div_(temperature) 87 | # https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881 88 | if logits.dtype in [torch.float32, torch.float64]: 89 | batch_dim = logits.shape[:-1] 90 | last_dim = logits.shape[-1] 91 | try: 92 | from flash_attn.ops.triton.cross_entropy import cross_entropy_loss 93 | 94 | output = cross_entropy_loss(logits.reshape(-1, last_dim), labels.reshape(-1)) 95 | log_probs_labels = -output[0].view(*batch_dim) 96 | except ImportError: 97 | logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) 98 | logsumexp_values = _logsumexp_by_chunk(logits.reshape(-1, last_dim)) 99 | logsumexp_values = logsumexp_values.view(*batch_dim) 100 | log_probs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) 101 | else: 102 | log_probs_labels = [] 103 | for row_logits, row_labels in zip(logits, labels): # loop to reduce peak mem consumption 104 | row_log_probs = F.log_softmax(row_logits, dim=-1) 105 | row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) 106 | log_probs_labels.append(row_log_probs_labels) 107 | log_probs_labels = torch.stack(log_probs_labels) 108 | return log_probs_labels 109 | 110 | 111 | def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = None) -> torch.Tensor: 112 | if mask is None: 113 | return tensor.mean(dim=dim) 114 | return (tensor * mask).sum(dim=dim) / mask.sum(dim=dim) 115 | 116 | 117 | def 
masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: 118 | tensor = tensor * mask 119 | mean = masked_mean(tensor, mask, dim=dim) 120 | mean_centered = tensor - mean 121 | var = masked_mean(mean_centered**2, mask, dim=dim) 122 | return mean_centered * var.clamp(min=eps).rsqrt() 123 | 124 | 125 | def compute_entropy(logits: torch.Tensor): 126 | pd = torch.nn.functional.softmax(logits, dim=-1) 127 | entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) 128 | return entropy 129 | 130 | 131 | def process_sequences(sequences: torch.Tensor, input_len, eos_token_id, pad_token_id): 132 | """ 133 | Process generated sequences to create attention masks and action masks. 134 | 135 | Args: 136 | sequences (torch.Tensor): Generated sequence tensor 137 | input_len (int): Length of the input sequence 138 | eos_token_id (int): Token ID for the end-of-sequence token 139 | pad_token_id (int): Token ID for the padding token 140 | 141 | Returns: 142 | tuple: A tuple containing three elements: 143 | - sequences: Original sequence 144 | - attention_mask: Attention mask indicating valid token positions 145 | - action_mask: Action mask indicating valid action token positions 146 | """ 147 | # Create initial attention mask by marking positions that are neither EOS nor padding tokens 148 | attention_mask = (sequences.ne(eos_token_id) & sequences.ne(pad_token_id)).to(dtype=torch.long) 149 | seq_length = attention_mask.size(1) 150 | 151 | # Find the position of the last valid token in each sequence 152 | eos_indices = seq_length - attention_mask.long().fliplr().argmax(dim=1, keepdim=True).clamp(min=1) 153 | 154 | # Handle cases where EOS tokens might appear in the middle of the prompt (for Llama3 and Qwen2 models) 155 | # Find the position of the first valid token in each sequence 156 | first_token_indices = attention_mask.long().argmax(dim=1, keepdim=True) 157 | # Create position mask 158 | mask = torch.arange(seq_length).unsqueeze(0).expand(sequences.size(0), -1).to(device=sequences.device) 159 | # Generate final attention mask, keeping only positions between first and last valid tokens 160 | attention_mask = (mask >= first_token_indices) & (mask <= eos_indices).to(dtype=torch.long) 161 | 162 | # In reinforcement learning, the state transition is represented as: 163 | # state_i (current token) + action_i (next token) -> state_i+1 (next token) 164 | # Generate state sequence from input_len-1 to second-to-last token 165 | state_seq = sequences[:, input_len - 1 : -1] 166 | # Generate action mask indicating valid action token positions 167 | action_mask = state_seq.ne(eos_token_id) & state_seq.ne(pad_token_id) 168 | action_mask[:, 0] = 1 169 | 170 | return sequences, attention_mask, action_mask 171 | -------------------------------------------------------------------------------- /openrlhf/datasets/reward_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from torch.utils.data import Dataset 4 | 5 | from .utils import exist_and_not_none, zero_pad_sequences 6 | 7 | 8 | def preprocess_data( 9 | data, 10 | input_template=None, 11 | prompt_key=None, 12 | chosen_key="chosen", 13 | rejected_key="rejected", 14 | apply_chat_template=None, 15 | is_dpo=False, 16 | ) -> str: 17 | if apply_chat_template: 18 | if prompt_key: 19 | prompt = apply_chat_template(data[prompt_key], tokenize=False, add_generation_prompt=True) 20 | chosen = apply_chat_template(data[prompt_key] + 
data[chosen_key], tokenize=False)[len(prompt) :] 21 | rejected = apply_chat_template(data[prompt_key] + data[rejected_key], tokenize=False)[len(prompt) :] 22 | else: 23 | prompt = "" 24 | chosen = apply_chat_template(data[chosen_key], tokenize=False) 25 | rejected = apply_chat_template(data[rejected_key], tokenize=False) 26 | 27 | if is_dpo: 28 | prompt = apply_chat_template(data[chosen_key][:-1], tokenize=False, add_generation_prompt=True) 29 | chosen = chosen[len(prompt) :] 30 | rejected = rejected[len(prompt) :] 31 | else: 32 | if prompt_key: 33 | prompt = data[prompt_key] 34 | if input_template: 35 | prompt = input_template.format(prompt) 36 | else: 37 | prompt = "" 38 | chosen = data[chosen_key] 39 | rejected = data[rejected_key] 40 | 41 | # margin loss 42 | margin = data["margin"] if exist_and_not_none(data, "margin") else 0 43 | 44 | return prompt, chosen, rejected, margin 45 | 46 | 47 | class RewardDataset(Dataset): 48 | """ 49 | Dataset for reward model 50 | 51 | Args: 52 | dataset: dataset for reward model 53 | self.tokenizer: self.tokenizer for reward model 54 | self.max_length: max length of input 55 | """ 56 | 57 | def __init__( 58 | self, 59 | dataset, 60 | tokenizer: Callable, 61 | max_length: int, 62 | strategy, 63 | input_template=None, 64 | is_dpo=False, 65 | num_processors=8, 66 | ) -> None: 67 | super().__init__() 68 | self.is_dpo = is_dpo 69 | self.tokenizer = tokenizer 70 | self.strategy = strategy 71 | self.max_length = max_length 72 | 73 | # chat_template 74 | self.input_template = input_template 75 | self.prompt_key = getattr(self.strategy.args, "prompt_key", None) 76 | self.chosen_key = getattr(self.strategy.args, "chosen_key", None) 77 | self.rejected_key = getattr(self.strategy.args, "rejected_key", None) 78 | self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 79 | 80 | if self.apply_chat_template: 81 | self.apply_chat_template = self.tokenizer.apply_chat_template 82 | tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None) 83 | if tokenizer_chat_template: 84 | self.tokenizer.chat_template = tokenizer_chat_template 85 | 86 | # Parallel loading datasets 87 | processed_dataset = dataset.map( 88 | self.process_data, remove_columns=dataset.column_names, num_proc=num_processors 89 | ) 90 | 91 | # Filter out None values if necessary 92 | processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) 93 | 94 | # Store the processed data in class attributes 95 | self.prompts = processed_dataset["prompt"] 96 | self.chosens = processed_dataset["chosen"] 97 | self.rejects = processed_dataset["reject"] 98 | self.extras = processed_dataset["extra"] 99 | 100 | def process_data(self, data): 101 | prompt, chosen, reject, margin = preprocess_data( 102 | data, 103 | self.input_template, 104 | self.prompt_key, 105 | self.chosen_key, 106 | self.rejected_key, 107 | self.apply_chat_template, 108 | self.is_dpo, 109 | ) 110 | 111 | if self.is_dpo: 112 | prompt_token = self.tokenizer( 113 | prompt, 114 | max_length=self.max_length, 115 | padding=False, 116 | truncation=True, 117 | return_tensors="pt", 118 | add_special_tokens=False, 119 | ) 120 | prompt_ids_len = prompt_token["attention_mask"].int().sum().item() 121 | 122 | # Filter the sample whose length is greater than max_length (2 for answer length) 123 | if prompt_ids_len >= self.max_length - 2: 124 | prompt = None 125 | 126 | return { 127 | "prompt": prompt, 128 | "chosen": chosen, 129 | "reject": reject, 130 | "extra": prompt_ids_len if self.is_dpo 
else margin, 131 | } 132 | 133 | def __len__(self): 134 | length = len(self.chosens) 135 | return length 136 | 137 | def __getitem__(self, idx): 138 | prompt, chosen, reject, extra = self.prompts[idx], self.chosens[idx], self.rejects[idx], self.extras[idx] 139 | 140 | chosen = (prompt + chosen).rstrip("\n") 141 | if not chosen.endswith(self.tokenizer.eos_token): 142 | chosen += " " + self.tokenizer.eos_token 143 | chosen_token = self.tokenizer( 144 | chosen, 145 | max_length=self.max_length, 146 | padding=False, 147 | truncation=True, 148 | return_tensors="pt", 149 | add_special_tokens=False, 150 | ) 151 | 152 | reject = (prompt + reject).rstrip("\n") 153 | if not reject.endswith(self.tokenizer.eos_token): 154 | reject += " " + self.tokenizer.eos_token 155 | reject_token = self.tokenizer( 156 | reject, 157 | max_length=self.max_length, 158 | padding=False, 159 | truncation=True, 160 | return_tensors="pt", 161 | add_special_tokens=False, 162 | ) 163 | 164 | # to avoid EOS_token truncation 165 | chosen_token["input_ids"][0][-1] = self.tokenizer.eos_token_id 166 | reject_token["input_ids"][0][-1] = self.tokenizer.eos_token_id 167 | chosen_token["attention_mask"][0][-1] = True 168 | reject_token["attention_mask"][0][-1] = True 169 | 170 | return ( 171 | chosen_token["input_ids"], 172 | chosen_token["attention_mask"], 173 | reject_token["input_ids"], 174 | reject_token["attention_mask"], 175 | extra, 176 | ) 177 | 178 | def collate_fn(self, item_list): 179 | chosen_ids = [] 180 | chosen_masks = [] 181 | reject_ids = [] 182 | rejects_masks = [] 183 | extras = [] 184 | for chosen_id, chosen_mask, reject_id, rejects_mask, extra in item_list: 185 | chosen_ids.append(chosen_id) 186 | chosen_masks.append(chosen_mask) 187 | reject_ids.append(reject_id) 188 | rejects_masks.append(rejects_mask) 189 | extras.append(extra) 190 | 191 | if self.is_dpo: 192 | padding_side = "right" 193 | else: 194 | padding_side = "left" 195 | chosen_ids = zero_pad_sequences(chosen_ids, side=padding_side, value=self.tokenizer.pad_token_id) 196 | chosen_masks = zero_pad_sequences(chosen_masks, side=padding_side) 197 | reject_ids = zero_pad_sequences(reject_ids, side=padding_side, value=self.tokenizer.pad_token_id) 198 | rejects_masks = zero_pad_sequences(rejects_masks, side=padding_side) 199 | return chosen_ids, chosen_masks, reject_ids, rejects_masks, extras 200 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/vllm_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import queue 3 | from collections import defaultdict 4 | from typing import Any, List 5 | 6 | import ray 7 | from ray.util.placement_group import placement_group 8 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy 9 | from vllm.inputs import TokensPrompt 10 | 11 | from openrlhf.utils.logging_utils import init_logger 12 | 13 | from .utils import get_bundle_indices, ray_noset_visible_devices 14 | 15 | logger = init_logger(__name__) 16 | 17 | 18 | @ray.remote 19 | def get_all_env_variables(): 20 | import os 21 | 22 | return os.environ 23 | 24 | 25 | @ray.remote 26 | class LLMRayActor: 27 | 28 | def __init__(self, *args, bundle_indices: list = None, **kwargs): 29 | noset_visible_devices = kwargs.pop("noset_visible_devices") 30 | if kwargs.get("distributed_executor_backend") == "ray": 31 | # a hack to make the script work. 
32 | # stop ray from manipulating *_VISIBLE_DEVICES 33 | # at the top-level when the distributed_executor_backend is ray. 34 | os.environ.pop("CUDA_VISIBLE_DEVICES", None) 35 | os.environ.pop("ROCR_VISIBLE_DEVICES", None) 36 | os.environ.pop("HIP_VISIBLE_DEVICES", None) 37 | elif noset_visible_devices: 38 | # We need to set CUDA_VISIBLE_DEVICES to the ray assigned GPU 39 | # when the distributed_executor_backend is not ray and 40 | # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set. 41 | os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0]) 42 | 43 | num_gpus = kwargs.pop("num_gpus") 44 | if bundle_indices is not None: 45 | os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(num_gpus) 46 | os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) 47 | print(f"creating LLM with bundle_indices={bundle_indices}") 48 | 49 | # Number of actors that will send prompt to this engine 50 | self.requests = {} 51 | self.response_queues = defaultdict(queue.Queue) 52 | 53 | import vllm 54 | 55 | full_determinism = kwargs.pop("full_determinism", False) 56 | if full_determinism or vllm.__version__ == "0.8.2": 57 | # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11 58 | os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" 59 | 60 | self.llm = vllm.LLM(*args, **kwargs) 61 | 62 | def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray): 63 | return self.llm.collective_rpc( 64 | "init_process_group", 65 | args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray), 66 | ) 67 | 68 | def update_weight(self, name, dtype, shape, empty_cache=False): 69 | return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) 70 | 71 | def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False): 72 | return self.llm.collective_rpc("update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)) 73 | 74 | def reset_prefix_cache(self): 75 | self.llm.llm_engine.reset_prefix_cache() 76 | 77 | def sleep(self, level=1): 78 | self.llm.sleep(level=level) 79 | 80 | def wake_up(self): 81 | self.llm.wake_up() 82 | 83 | def add_requests(self, actor_rank, *, sampling_params, prompt_token_ids): 84 | """ 85 | Save the requests from actors and generate responses when all actors have sent their requests 86 | """ 87 | self.requests[actor_rank] = prompt_token_ids 88 | num_requests = [] 89 | requests: list[TokensPrompt] = [] 90 | for actor_rank, request in self.requests.items(): 91 | num_requests.append((actor_rank, len(request))) 92 | for r in request: 93 | requests.append(TokensPrompt(prompt_token_ids=r)) 94 | 95 | responses = self.llm.generate(prompts=requests, sampling_params=sampling_params) 96 | offset = 0 97 | self.responses = {} 98 | for actor_rank, num in num_requests: 99 | self.response_queues[actor_rank].put(responses[offset : offset + num]) 100 | offset += num 101 | 102 | self.requests = {} 103 | 104 | def get_responses(self, actor_rank): 105 | """ 106 | Return the responses for the actor with the given rank 107 | """ 108 | return self.response_queues[actor_rank].get() 109 | 110 | 111 | def create_vllm_engines( 112 | num_engines: int, 113 | tensor_parallel_size: int, 114 | pretrain: str, 115 | seed: int, 116 | full_determinism: bool, 117 | enable_prefix_caching: bool, 118 | enforce_eager: bool, 119 | max_model_len: int, 120 | shared_pg=None, 121 | gpu_memory_utilization=None, 122 | 
vllm_enable_sleep=False, 123 | ): 124 | import vllm 125 | from packaging.version import Version 126 | assert Version(vllm.__version__) >= Version("0.8.1"), "OpenRLHF only supports vllm >= 0.8.1" 127 | 128 | vllm_engines = [] 129 | noset_visible_devices = ray_noset_visible_devices(ray.get(get_all_env_variables.remote())) 130 | distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray" 131 | use_hybrid_engine = shared_pg is not None 132 | num_gpus = int(tensor_parallel_size == 1) 133 | if use_hybrid_engine and tensor_parallel_size == 1: 134 | # every worker will use 0.2 GPU, so that we can schedule 135 | # 2 instances on the same GPUs. 136 | num_gpus = 0.2 137 | 138 | if not use_hybrid_engine: 139 | # Create a big placement group to ensure that all engines are packed 140 | bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_engines * tensor_parallel_size)] 141 | shared_pg = placement_group(bundles, strategy="PACK") 142 | ray.get(shared_pg.ready()) 143 | 144 | for i in range(num_engines): 145 | bundle_indices = None 146 | if tensor_parallel_size > 1: 147 | bundle_indices = get_bundle_indices(shared_pg, i, tensor_parallel_size) 148 | 149 | scheduling_strategy = PlacementGroupSchedulingStrategy( 150 | placement_group=shared_pg, 151 | placement_group_capture_child_tasks=True, 152 | placement_group_bundle_index=bundle_indices[0] if bundle_indices else i, 153 | ) 154 | 155 | vllm_engines.append( 156 | LLMRayActor.options( 157 | num_cpus=num_gpus, 158 | num_gpus=num_gpus, 159 | scheduling_strategy=scheduling_strategy, 160 | ).remote( 161 | model=pretrain, 162 | enforce_eager=enforce_eager, 163 | worker_extension_cls="openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap", 164 | tensor_parallel_size=tensor_parallel_size, 165 | seed=seed + i, 166 | distributed_executor_backend=distributed_executor_backend, 167 | max_model_len=max_model_len, 168 | enable_prefix_caching=enable_prefix_caching, 169 | dtype="bfloat16", 170 | trust_remote_code=True, 171 | full_determinism=full_determinism, 172 | gpu_memory_utilization=gpu_memory_utilization, 173 | bundle_indices=bundle_indices, 174 | num_gpus=0.2 if use_hybrid_engine else 1, 175 | enable_sleep_mode=vllm_enable_sleep, 176 | noset_visible_devices=noset_visible_devices, 177 | ) 178 | ) 179 | 180 | if vllm_enable_sleep: 181 | batch_vllm_engine_call(vllm_engines, "sleep") 182 | 183 | return vllm_engines 184 | 185 | 186 | def batch_vllm_engine_call(engines: List[Any], method_name: str, *args, rank_0_only: bool = True, **kwargs): 187 | """ 188 | Batch call a method on multiple vLLM engines.
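    When torch.distributed is initialized and rank_0_only is True (the default), only the
    rank-0 process dispatches the remote calls; all other ranks return None immediately.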
189 | Args: 190 | engines: List of vLLM engine instances 191 | method_name: Name of the method to call 192 | rank_0_only: Only execute on rank 0 if True 193 | *args: Positional arguments to pass to the method 194 | **kwargs: Keyword arguments to pass to the method 195 | Returns: 196 | List of results from ray.get() if on rank 0, None otherwise 197 | """ 198 | import torch 199 | 200 | if torch.distributed.is_initialized(): 201 | if rank_0_only and torch.distributed.get_rank() != 0: 202 | return None 203 | 204 | refs = [] 205 | for engine in engines: 206 | method = getattr(engine, method_name) 207 | refs.append(method.remote(*args, **kwargs)) 208 | 209 | return ray.get(refs) 210 | -------------------------------------------------------------------------------- /openrlhf/datasets/sft_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from .utils import zero_pad_sequences 7 | 8 | 9 | def preprocess_data( 10 | data, input_template=None, input_key="input", output_key=None, apply_chat_template=None, multiturn=False 11 | ): 12 | if apply_chat_template: 13 | if output_key: 14 | prompt_message = data[input_key] 15 | response_message = data[output_key] 16 | 17 | if isinstance(prompt_message, str) and isinstance(response_message, str): 18 | prompt_message = [{"role": "user", "content": prompt_message}] 19 | response_message = [{"role": "assistant", "content": response_message}] 20 | 21 | prompt = apply_chat_template(prompt_message, tokenize=False, add_generation_prompt=True) 22 | response = apply_chat_template(prompt_message + response_message, tokenize=False)[len(prompt) :] 23 | else: 24 | prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True) 25 | response = apply_chat_template(data[input_key], tokenize=False)[len(prompt) :] 26 | else: 27 | prompt = data[input_key] 28 | if input_template: 29 | prompt = input_template.format(prompt) 30 | # output_key is None for continue pretrain 31 | response = data[output_key] if output_key else "" 32 | return prompt, response 33 | 34 | 35 | class SFTDataset(Dataset): 36 | """ 37 | Dataset for SFT model 38 | 39 | Args: 40 | dataset: dataset for SFT model 41 | tokenizer: tokenizer for SFT model 42 | max_length: max length of input 43 | """ 44 | 45 | def __init__( 46 | self, 47 | dataset, 48 | tokenizer: Callable, 49 | max_length: int, 50 | strategy, 51 | input_template=None, 52 | pretrain_mode=False, 53 | num_processors=8, # Specify the number of processors you want to use 54 | multiturn=False, 55 | ) -> None: 56 | super().__init__() 57 | self.tokenizer = tokenizer 58 | self.strategy = strategy 59 | self.pretrain_mode = pretrain_mode 60 | self.max_length = max_length 61 | self.multiturn = multiturn 62 | 63 | # chat template 64 | self.input_template = input_template 65 | self.input_key = getattr(self.strategy.args, "input_key", None) 66 | self.output_key = getattr(self.strategy.args, "output_key", None) 67 | self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 68 | 69 | if self.apply_chat_template: 70 | self.apply_chat_template = self.tokenizer.apply_chat_template 71 | tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None) 72 | if tokenizer_chat_template: 73 | self.tokenizer.chat_template = tokenizer_chat_template 74 | 75 | # Parallel loading datasets 76 | processed_dataset = dataset.map( 77 | self.process_data, 78 | 
remove_columns=dataset.column_names, 79 | num_proc=num_processors, 80 | ) 81 | processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) 82 | 83 | # Store the processed data in class attributes 84 | self.prompts = processed_dataset["prompt"] 85 | self.responses = processed_dataset["response"] 86 | self.prompt_ids_lens = processed_dataset["prompt_ids_len"] 87 | self.response_ranges = processed_dataset["response_ranges"] if self.multiturn else None 88 | 89 | def process_data(self, data): 90 | if self.multiturn and self.output_key: 91 | data[self.input_key].append(data[self.output_key]) 92 | data[self.output_key] = None 93 | 94 | if self.multiturn: 95 | assert ( 96 | not self.output_key or not data[self.output_key] 97 | ), "You should put the whole trajactory into data[input_key] and do not set output_key" 98 | input_key = self.input_key 99 | apply_chat_template = self.apply_chat_template 100 | response_ranges = [] 101 | for idx, message in enumerate(data[input_key]): 102 | if message["role"] == "assistant": 103 | prompt = apply_chat_template(data[input_key][:idx], tokenize=False, add_generation_prompt=True) 104 | response = apply_chat_template(data[input_key][: idx + 1], tokenize=False)[len(prompt) :] 105 | 106 | start_idx = ( 107 | self.tokenizer( 108 | prompt, 109 | max_length=self.max_length, 110 | padding=False, 111 | truncation=True, 112 | return_tensors="pt", 113 | add_special_tokens=False, 114 | )["attention_mask"] 115 | .int() 116 | .sum() 117 | .item() 118 | ) 119 | 120 | end_idx = ( 121 | start_idx 122 | + self.tokenizer( 123 | response, 124 | max_length=self.max_length, 125 | padding=False, 126 | truncation=True, 127 | return_tensors="pt", 128 | add_special_tokens=False, 129 | )["attention_mask"] 130 | .int() 131 | .sum() 132 | .item() 133 | - 1 134 | ) 135 | response_ranges.append((start_idx, end_idx)) # left close right close 136 | 137 | prompt, response = preprocess_data( 138 | data, 139 | None if self.pretrain_mode else self.input_template, 140 | self.input_key, 141 | self.output_key, 142 | apply_chat_template=None if self.pretrain_mode else self.apply_chat_template, 143 | multiturn=self.multiturn, 144 | ) 145 | 146 | if not self.pretrain_mode: 147 | prompt_token = self.tokenizer( 148 | prompt, 149 | max_length=self.max_length, 150 | padding=False, 151 | truncation=True, 152 | return_tensors="pt", 153 | add_special_tokens=False, 154 | ) 155 | prompt_ids_len = prompt_token["attention_mask"].int().sum().item() 156 | # filter the sample whose length is greater than max_length (2 for answer length) 157 | if not prompt or not response or prompt_ids_len >= self.max_length - 2: 158 | prompt = None 159 | else: 160 | prompt_ids_len = 0 161 | 162 | return { 163 | "prompt": prompt, 164 | "response": response, 165 | "prompt_ids_len": prompt_ids_len, 166 | "response_ranges": response_ranges if self.multiturn else None, 167 | } 168 | 169 | def __len__(self): 170 | length = len(self.prompts) 171 | return length 172 | 173 | def __getitem__(self, idx): 174 | prompt = self.prompts[idx] 175 | response = self.responses[idx] 176 | 177 | if not self.pretrain_mode: 178 | text = (prompt + response).rstrip("\n") 179 | if not text.endswith(self.tokenizer.eos_token): 180 | text += " " + self.tokenizer.eos_token 181 | else: 182 | text = prompt 183 | 184 | input_token = self.tokenizer( 185 | text, 186 | max_length=self.max_length, 187 | padding=False, 188 | truncation=True, 189 | return_tensors="pt", 190 | add_special_tokens=False, 191 | ) 192 | input_ids = input_token["input_ids"] 
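        # The tokenizer is called with return_tensors="pt", so these tensors keep a leading
        # batch dimension of 1 (shape [1, seq_len]); get_loss_mask below relies on this layout.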
193 | attention_mask = input_token["attention_mask"] 194 | loss_mask = self.get_loss_mask(input_ids, idx) 195 | 196 | if not self.pretrain_mode: 197 | # to avoid EOS_token truncation 198 | input_ids[0][-1] = self.tokenizer.eos_token_id 199 | attention_mask[0][-1] = True 200 | return input_ids, attention_mask, loss_mask 201 | 202 | def get_loss_mask(self, input_ids, idx): 203 | if self.pretrain_mode: 204 | return torch.ones_like(input_ids, dtype=torch.float32) # shape:[1, seq_len] 205 | 206 | loss_mask = torch.zeros_like(input_ids, dtype=torch.float32) 207 | if not self.multiturn: 208 | prompt_ids_len = self.prompt_ids_lens[idx] 209 | loss_mask[0, prompt_ids_len - 1 : -1] = 1 210 | else: 211 | response_ranges = self.response_ranges[idx] 212 | for start_idx, end_idx in response_ranges: 213 | loss_mask[0, start_idx - 1 : end_idx] = 1 214 | return loss_mask 215 | 216 | def collate_fn(self, item_list): 217 | input_ids = [] 218 | attention_masks = [] 219 | loss_masks = [] 220 | 221 | for input_id, attention_mask, loss_mask in item_list: 222 | input_ids.append(input_id) 223 | attention_masks.append(attention_mask) 224 | loss_masks.append(loss_mask) 225 | 226 | input_ids = zero_pad_sequences(input_ids, "right", self.tokenizer.pad_token_id) 227 | attention_masks = zero_pad_sequences(attention_masks, "right") 228 | loss_masks = zero_pad_sequences(loss_masks, "right") 229 | return input_ids, attention_masks, loss_masks 230 | -------------------------------------------------------------------------------- /openrlhf/trainer/prm_trainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | from torch.optim import Optimizer 5 | from tqdm import tqdm 6 | 7 | from openrlhf.models import PRMLoss 8 | from openrlhf.utils.distributed_sampler import DistributedSampler 9 | from openrlhf.utils.utils import convert_token_to_id 10 | 11 | 12 | class ProcessRewardModelTrainer(ABC): 13 | """ 14 | Trainer for training a process reward model. 15 | 16 | Args: 17 | model (torch.nn.Module): The model to be trained. 18 | strategy (Strategy): The training strategy to apply. 19 | optim (Optimizer): The optimizer to use during training. 20 | train_dataloader (DataLoader): The dataloader for the training dataset. 21 | eval_dataloader (DataLoader): The dataloader for the evaluation dataset. 22 | scheduler (Scheduler): The learning rate scheduler for dynamic adjustments during training. 23 | tokenizer (Tokenizer): The tokenizer for processing input text data. 24 | max_norm (float, defaults to 1.0): Maximum gradient norm for gradient clipping. 25 | max_epochs (int, defaults to 2): Maximum number of training epochs.
26 | """ 27 | 28 | def __init__( 29 | self, 30 | model, 31 | strategy, 32 | optim: Optimizer, 33 | train_dataloader, 34 | eval_dataloader, 35 | scheduler, 36 | max_norm: float = 1, 37 | batch_size: int = 1, 38 | max_epochs: int = 2, 39 | tokenizer=None, 40 | ) -> None: 41 | super().__init__() 42 | self.strategy = strategy 43 | self.epochs = max_epochs 44 | self.batch_size = batch_size 45 | self.max_norm = max_norm 46 | self.train_dataloader = train_dataloader 47 | self.eval_dataloader = eval_dataloader 48 | self.scheduler = scheduler 49 | self.model = model 50 | self.tokenizer = tokenizer 51 | self.optimizer = optim 52 | self.args = strategy.args 53 | 54 | # set placeholder token 55 | self.placeholder_token_id = convert_token_to_id(strategy.args.placeholder_token, self.tokenizer) 56 | self.reward_token_ids = self.args.reward_tokens 57 | if self.reward_token_ids is not None: 58 | self.reward_token_ids = [convert_token_to_id(token, self.tokenizer) for token in self.reward_token_ids] 59 | 60 | self.ignore_index = -100 61 | self.loss_fn = PRMLoss(self.placeholder_token_id, self.reward_token_ids) 62 | 63 | # Mixtral 8*7b 64 | self.aux_loss = self.args.aux_loss_coef > 1e-8 65 | 66 | # packing samples 67 | self.packing_samples = strategy.args.packing_samples 68 | 69 | # wandb setting 70 | self._wandb = None 71 | if self.strategy.args.use_wandb and self.strategy.is_rank_0(): 72 | import wandb 73 | 74 | self._wandb = wandb 75 | if not wandb.api.api_key: 76 | wandb.login(key=strategy.args.use_wandb) 77 | wandb.init( 78 | entity=strategy.args.wandb_org, 79 | project=strategy.args.wandb_project, 80 | group=strategy.args.wandb_group, 81 | name=strategy.args.wandb_run_name, 82 | config=strategy.args.__dict__, 83 | reinit=True, 84 | ) 85 | 86 | wandb.define_metric("train/global_step") 87 | wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True) 88 | wandb.define_metric("eval/global_step") 89 | wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True) 90 | 91 | def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): 92 | # get eval and save steps 93 | if args.eval_steps == -1: 94 | args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch 95 | if args.save_steps == -1: 96 | args.save_steps = float("inf") # do not save ckpt 97 | 98 | # Restore step and start_epoch 99 | step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 100 | start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch 101 | consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) 102 | 103 | epoch_bar = tqdm( 104 | range(start_epoch, self.epochs), 105 | desc="Train epoch", 106 | disable=not self.strategy.is_rank_0(), 107 | ) 108 | loss_sum = 0 109 | acc_sum = 0 110 | for epoch in range(start_epoch, self.epochs): 111 | if isinstance(self.train_dataloader.sampler, DistributedSampler): 112 | self.train_dataloader.sampler.set_epoch( 113 | epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples 114 | ) 115 | 116 | step_bar = tqdm( 117 | range(self.train_dataloader.__len__()), 118 | desc="Train step of epoch %d" % epoch, 119 | disable=not self.strategy.is_rank_0(), 120 | ) 121 | 122 | # train 123 | self.model.train() 124 | for data in self.train_dataloader: 125 | inputs, attention_masks, labels = data 126 | inputs = inputs.to(torch.cuda.current_device()).squeeze(1) 127 | attention_mask = attention_masks.to(torch.cuda.current_device()).squeeze(1) 128 | labels = 
labels.to(torch.cuda.current_device()).squeeze(1) 129 | 130 | output = self.model( 131 | inputs, 132 | attention_mask=attention_mask, 133 | return_output=True, 134 | allgather_logits=True, 135 | ) 136 | 137 | # mixtral 138 | if self.aux_loss: 139 | aux_loss = output.aux_loss 140 | else: 141 | aux_loss = 0 142 | 143 | prm_loss, acc = self.loss_fn(inputs, output.logits, labels, return_acc=True) 144 | loss = prm_loss + aux_loss * self.args.aux_loss_coef 145 | self.strategy.backward(loss, self.model, self.optimizer) 146 | self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) 147 | 148 | loss_sum += loss.item() 149 | acc_sum += acc.item() 150 | logs_dict = { 151 | "prm_loss": prm_loss.item(), 152 | "acc": acc.item(), 153 | "lr": self.scheduler.get_last_lr()[0], 154 | } 155 | if self.aux_loss: 156 | logs_dict["aux_loss"] = aux_loss.item() 157 | # step bar 158 | logs_dict = self.strategy.all_reduce(logs_dict) 159 | step_bar.set_postfix(logs_dict) 160 | step_bar.update() 161 | 162 | # logs/checkpoints/evaluation 163 | if step % self.strategy.accumulated_gradient == 0: 164 | logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient 165 | logs_dict["acc_mean"] = acc_sum / self.strategy.accumulated_gradient 166 | loss_sum = 0 167 | acc_sum = 0 168 | global_step = step // self.strategy.accumulated_gradient 169 | client_states = {"consumed_samples": global_step * args.train_batch_size} 170 | self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) 171 | 172 | step += 1 173 | 174 | epoch_bar.update() 175 | 176 | # logs/checkpoints/evaluation 177 | def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): 178 | if global_step % args.logging_steps == 0: 179 | # wandb 180 | if self._wandb is not None and self.strategy.is_rank_0(): 181 | logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} 182 | self._wandb.log(logs) 183 | 184 | # eval 185 | if self.eval_dataloader is not None and global_step % args.eval_steps == 0: 186 | self.evaluate(self.eval_dataloader, global_step) 187 | # save ckpt 188 | # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric 189 | if global_step % args.save_steps == 0: 190 | tag = f"global_step{global_step}" 191 | self.strategy.save_ckpt( 192 | self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states 193 | ) 194 | 195 | def evaluate(self, eval_dataloader, steps=0): 196 | times = 0 197 | self.model.eval() 198 | with torch.no_grad(): 199 | loss_sum = 0 200 | acc_sum = 0 201 | step_bar = tqdm( 202 | range(eval_dataloader.__len__()), 203 | desc="Eval stage of steps %d" % steps, 204 | disable=not self.strategy.is_rank_0(), 205 | ) 206 | 207 | for data in eval_dataloader: 208 | inputs, attention_masks, labels = data 209 | inputs = inputs.to(torch.cuda.current_device()).squeeze(1) 210 | attention_mask = attention_masks.to(torch.cuda.current_device()).squeeze(1) 211 | labels = labels.to(torch.cuda.current_device()).squeeze(1) 212 | 213 | output = self.model( 214 | inputs, 215 | attention_mask=attention_mask, 216 | return_output=True, 217 | allgather_logits=True, 218 | ) 219 | 220 | loss, acc = self.loss_fn(inputs, output.logits, labels, return_acc=True) 221 | 222 | times += 1 223 | loss_sum += loss.item() 224 | acc_sum += acc.item() 225 | bar_dict = {"eval prm_loss": loss_sum / times, "eval acc": acc_sum / times} 226 | step_bar.update() 227 | logs = self.strategy.all_reduce(bar_dict) 228 | 
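                # all_reduce aggregates (by default, averages) the running loss/acc across
                # data-parallel ranks, so the logged values reflect the whole eval split
                # rather than a single rank's shard.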
step_bar.set_postfix(logs) 229 | 230 | if self._wandb is not None and self.strategy.is_rank_0(): 231 | logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} 232 | self._wandb.log(logs) 233 | self.model.train() # reset model state 234 | -------------------------------------------------------------------------------- /openrlhf/cli/train_prm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | from datetime import datetime 5 | 6 | from transformers.trainer import get_scheduler 7 | 8 | from openrlhf.datasets import ProcessRewardDataset 9 | from openrlhf.models import Actor 10 | from openrlhf.trainer import ProcessRewardModelTrainer 11 | from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer 12 | 13 | 14 | def train(args): 15 | # configure strategy 16 | strategy = get_strategy(args) 17 | strategy.setup_distributed() 18 | 19 | # configure model 20 | # load huggingface model 21 | model = Actor( 22 | args.pretrain, 23 | use_flash_attention_2=args.flash_attn, 24 | bf16=args.bf16, 25 | load_in_4bit=args.load_in_4bit, 26 | lora_rank=args.lora_rank, 27 | lora_alpha=args.lora_alpha, 28 | target_modules=args.target_modules, 29 | lora_dropout=args.lora_dropout, 30 | ds_config=strategy.get_ds_train_config(is_actor=True), 31 | packing_samples=args.packing_samples, 32 | use_liger_kernel=args.use_liger_kernel, 33 | ) 34 | # configure tokenizer 35 | tokenizer = get_tokenizer(args.pretrain, model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) 36 | strategy.print(model) 37 | 38 | # gradient_checkpointing 39 | if args.gradient_checkpointing: 40 | model.gradient_checkpointing_enable( 41 | gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} 42 | ) 43 | 44 | # configure optimizer 45 | optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) 46 | 47 | # prepare for data and dataset 48 | train_data = blending_datasets( 49 | args.dataset, 50 | args.dataset_probs, 51 | strategy, 52 | args.seed, 53 | max_count=args.max_samples, 54 | dataset_split=args.dataset_split, 55 | ) 56 | 57 | train_data = train_data.select(range(min(args.max_samples, len(train_data)))) 58 | train_dataset = ProcessRewardDataset(train_data, tokenizer, args.max_len, strategy) 59 | 60 | # prepare dataloader 61 | train_dataloader = strategy.setup_dataloader( 62 | train_dataset, 63 | args.micro_train_batch_size, 64 | True, 65 | True, 66 | train_dataset.collate_fn, 67 | ) 68 | 69 | eval_dataset = None 70 | eval_dataloader = None 71 | if getattr(args, "eval_dataset", None): 72 | eval_data = blending_datasets( 73 | args.eval_dataset, 74 | None, # No probability sampling for eval datasets 75 | strategy, 76 | dataset_split=args.eval_split, 77 | ) 78 | eval_dataset = ProcessRewardDataset(eval_data, tokenizer, args.max_len, strategy) 79 | eval_dataloader = strategy.setup_dataloader( 80 | eval_dataset, 81 | args.micro_train_batch_size, 82 | True, 83 | False, 84 | eval_dataset.collate_fn, 85 | ) 86 | 87 | # scheduler 88 | num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size 89 | max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) 90 | 91 | scheduler = get_scheduler( 92 | args.lr_scheduler, 93 | optim, 94 | num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), 95 | num_training_steps=max_steps, 96 | scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, 97 | ) 98 | 99 | # 
prepare models 100 | (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) 101 | 102 | # load checkpoint 103 | consumed_samples = 0 104 | if args.load_checkpoint and os.path.exists(args.ckpt_path): 105 | _, states = strategy.load_ckpt(model.model, args.ckpt_path) 106 | consumed_samples = states["consumed_samples"] 107 | strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") 108 | 109 | os.makedirs(args.save_path, exist_ok=True) 110 | 111 | # configure Trainer 112 | trainer = ProcessRewardModelTrainer( 113 | model=model, 114 | strategy=strategy, 115 | optim=optim, 116 | train_dataloader=train_dataloader, 117 | eval_dataloader=eval_dataloader, 118 | scheduler=scheduler, 119 | max_norm=args.max_norm, 120 | batch_size=args.train_batch_size, 121 | max_epochs=args.max_epochs, 122 | tokenizer=tokenizer, 123 | ) 124 | 125 | trainer.fit(args, consumed_samples, num_update_steps_per_epoch) 126 | 127 | # save model checkpoint after fitting on only rank0 128 | strategy.save_model(model, tokenizer, args.save_path) 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser() 133 | # Checkpoint 134 | parser.add_argument("--save_path", type=str, default="./ckpt") 135 | parser.add_argument("--save_steps", type=int, default=-1) 136 | parser.add_argument("--logging_steps", type=int, default=1) 137 | parser.add_argument("--eval_steps", type=int, default=-1) 138 | parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_prm") 139 | parser.add_argument("--max_ckpt_num", type=int, default=3) 140 | parser.add_argument("--max_ckpt_mem", type=int, default=1e8) 141 | parser.add_argument("--load_checkpoint", action="store_true", default=False) 142 | parser.add_argument("--use_ds_universal_ckpt", action="store_true", default=False) 143 | 144 | # DeepSpeed 145 | parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") 146 | parser.add_argument("--gradient_checkpointing", action="store_true", default=False) 147 | parser.add_argument("--deepcompile", action="store_true", default=False) 148 | parser.add_argument("--seed", type=int, default=42) 149 | parser.add_argument( 150 | "--full_determinism", 151 | action="store_true", 152 | default=False, 153 | help="Enable reproducible behavior during distributed training", 154 | ) 155 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") 156 | parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") 157 | parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") 158 | parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") 159 | parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") 160 | parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") 161 | parser.add_argument("--use_liger_kernel", action="store_true", default=False, help="Enable Liger Kernel") 162 | parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") 163 | parser.add_argument("--overlap_comm", action="store_true", default=False) 164 | parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) 165 | parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) 166 | 167 | # LoRA 168 | parser.add_argument("--load_in_4bit", action="store_true", default=False) 169 | 
parser.add_argument("--lora_rank", type=int, default=0) 170 | parser.add_argument("--lora_alpha", type=int, default=16) 171 | parser.add_argument("--lora_dropout", type=float, default=0) 172 | parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") 173 | 174 | # PRM training 175 | parser.add_argument("--pretrain", type=str, default=None) 176 | parser.add_argument("--max_epochs", type=int, default=1) 177 | parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") 178 | parser.add_argument("--learning_rate", type=float, default=1e-6) 179 | parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) 180 | parser.add_argument("--micro_train_batch_size", type=int, default=1) 181 | parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") 182 | parser.add_argument("--lr_scheduler", type=str, default="cosine_with_min_lr") 183 | parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") 184 | parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") 185 | parser.add_argument("--placeholder_token", type=str, default=None) 186 | parser.add_argument("--reward_tokens", type=str, nargs="*", default=None) 187 | 188 | # packing samples using Flash Attention2 189 | parser.add_argument("--packing_samples", action="store_true", default=False) 190 | 191 | # custom dataset 192 | parser.add_argument("--dataset", type=str, default=None, help="Path to the training dataset") 193 | parser.add_argument("--dataset_probs", type=str, default=None, help="Sampling probabilities for training datasets") 194 | parser.add_argument("--eval_dataset", type=str, default=None, help="Path to the evaluation dataset") 195 | parser.add_argument("--dataset_split", type=str, default="train") 196 | parser.add_argument("--eval_split", type=str, default="test") 197 | parser.add_argument("--max_samples", type=int, default=1000000, help="Maximum number of samples to use") 198 | 199 | parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") 200 | parser.add_argument("--label_key", type=str, default="label", help="JSON dataset key") 201 | parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples") 202 | 203 | # wandb parameters 204 | parser.add_argument("--use_wandb", type=str, default=None) 205 | parser.add_argument("--wandb_org", type=str, default=None) 206 | parser.add_argument("--wandb_group", type=str, default=None) 207 | parser.add_argument("--wandb_project", type=str, default="openrlhf_train_prm") 208 | parser.add_argument( 209 | "--wandb_run_name", 210 | type=str, 211 | default="prm_%s" % datetime.now().strftime("%m%dT%H:%M"), 212 | ) 213 | 214 | # ModelScope parameters 215 | parser.add_argument("--use_ms", action="store_true", default=False) 216 | 217 | args = parser.parse_args() 218 | 219 | # Add positive token and negative token to reward_tokens and remove duplicates 220 | if args.reward_tokens is not None: 221 | print( 222 | "If you are running with soft labels (float values), " 223 | f"the first token in reward_tokens ({args.reward_tokens[0]}) should be the positive token " 224 | "and the second token should be the negative token." 225 | ) 226 | 227 | if args.use_ms: 228 | from modelscope.utils.hf_util import patch_hub 229 | 230 | # Patch hub to download models from modelscope to speed up. 
231 | patch_hub() 232 | 233 | train(args) 234 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/ppo_critic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from abc import ABC 4 | from typing import Dict, Optional, Union 5 | 6 | import ray 7 | import torch 8 | from torch.optim import Optimizer 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | from transformers.trainer import get_scheduler 12 | 13 | from openrlhf.models import ValueLoss, get_llm_for_sequence_regression 14 | from openrlhf.models.utils import masked_mean 15 | from openrlhf.trainer.ppo_utils.experience_maker import Experience 16 | from openrlhf.utils import get_tokenizer 17 | from openrlhf.utils.deepspeed import DeepspeedStrategy 18 | from openrlhf.utils.deepspeed.deepspeed_utils import offload_deepspeed_states, reload_deepspeed_states 19 | 20 | from ..ppo_utils import NaiveReplayBuffer 21 | from .launcher import BasePPORole 22 | 23 | 24 | class CriticPPOTrainer(ABC): 25 | def __init__( 26 | self, 27 | strategy, 28 | critic: torch.nn.Module, 29 | critic_optim: Optimizer, 30 | critic_scheduler, 31 | micro_train_batch_size: int = 8, 32 | buffer_limit: int = 0, 33 | buffer_cpu_offload: bool = True, 34 | value_clip: float = 0.2, 35 | dataloader_pin_memory: bool = True, 36 | **kwargs, 37 | ): 38 | self.strategy = strategy 39 | self.args = strategy.args 40 | self.critic = critic 41 | self.critic_optim = critic_optim 42 | self.critic_scheduler = critic_scheduler 43 | self.micro_train_batch_size = micro_train_batch_size 44 | self.buffer_limit = buffer_limit 45 | self.buffer_cpu_offload = buffer_cpu_offload 46 | self.value_clip = value_clip 47 | self.dataloader_pin_memory = dataloader_pin_memory 48 | self.max_epochs = self.args.max_epochs 49 | 50 | self.replay_buffer = NaiveReplayBuffer( 51 | micro_train_batch_size, buffer_limit, buffer_cpu_offload, getattr(self.args, "packing_samples", False) 52 | ) 53 | 54 | self.critic_loss_fn = ValueLoss(value_clip) 55 | 56 | # Mixtral 8x7b 57 | self.aux_loss = self.args.aux_loss_coef > 1e-8 58 | 59 | def ppo_train(self): 60 | # replay buffer may be empty at first, we should rebuild at each training 61 | dataloader = DataLoader( 62 | self.replay_buffer, 63 | batch_size=self.replay_buffer.sample_batch_size, 64 | shuffle=False if self.strategy.ring_attn_group is not None else True, 65 | drop_last=True, 66 | pin_memory=self.dataloader_pin_memory, 67 | collate_fn=self.replay_buffer.collate_fn, 68 | ) 69 | device = torch.cuda.current_device() 70 | 71 | status_list = [] 72 | status_mean = {} 73 | for epoch in range(self.max_epochs): 74 | pbar = tqdm( 75 | dataloader, 76 | desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]", 77 | disable=not self.strategy.is_rank_0(), 78 | ) 79 | for experience in pbar: 80 | experience.to_device(device) 81 | status = self.training_step(experience) 82 | 83 | # for DP 84 | status = self.strategy.all_reduce(status) 85 | 86 | status_list.append(status) 87 | pbar.set_postfix(status) 88 | 89 | if status_list: 90 | status_mean = status_list[0] 91 | for m in status_list[1:]: 92 | for k, v in m.items(): 93 | status_mean[k] += v 94 | for k in status_mean.keys(): 95 | status_mean[k] /= len(status_list) 96 | return status_mean 97 | 98 | def training_step(self, experience: Experience) -> Dict[str, float]: 99 | self.critic.train() 100 | 101 | sequences = experience.sequences 102 | old_values = experience.values 103 | returns = 
experience.returns 104 | action_mask = experience.action_mask 105 | packed_seq_lens = None 106 | attention_mask = experience.attention_mask 107 | 108 | # critic loss 109 | values, output = self.critic( 110 | sequences, 111 | action_mask=action_mask, 112 | attention_mask=attention_mask, 113 | return_output=True, 114 | ring_attn_group=self.strategy.ring_attn_group, 115 | values_allgather=True, 116 | packed_seq_lens=packed_seq_lens, 117 | ) 118 | 119 | # loss function 120 | critic_loss = self.critic_loss_fn( 121 | values, 122 | old_values, 123 | returns, 124 | action_mask=experience.action_mask, 125 | ) 126 | # mixtral 127 | if self.aux_loss: 128 | aux_loss = output.aux_loss 129 | else: 130 | aux_loss = 0 131 | loss = critic_loss + aux_loss * self.args.aux_loss_coef 132 | self.strategy.backward(loss, self.critic, self.critic_optim) 133 | self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic") 134 | 135 | # status 136 | status = { 137 | "critic_loss": critic_loss.detach().item(), 138 | "values": masked_mean(values, experience.action_mask).detach().item(), 139 | "critic_lr": self.critic_scheduler.get_last_lr()[0], 140 | } 141 | return status 142 | 143 | 144 | @ray.remote(num_gpus=1) 145 | class CriticModelRayActor(BasePPORole): 146 | def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain, max_steps): 147 | args = strategy.args 148 | 149 | self._setup_distributed(strategy) 150 | critic = get_llm_for_sequence_regression( 151 | pretrain, 152 | "critic", 153 | normalize_reward=strategy.args.normalize_reward, 154 | use_flash_attention_2=strategy.args.flash_attn, 155 | bf16=strategy.args.bf16, 156 | load_in_4bit=strategy.args.load_in_4bit, 157 | lora_rank=strategy.args.lora_rank, 158 | lora_alpha=strategy.args.lora_alpha, 159 | target_modules=strategy.args.target_modules, 160 | lora_dropout=strategy.args.lora_dropout, 161 | ds_config=strategy.get_ds_train_config(is_actor=False), 162 | value_head_prefix=strategy.args.value_head_prefix, 163 | init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain, 164 | packing_samples=strategy.args.packing_samples, 165 | ) 166 | strategy.print(critic) 167 | strategy.print("reward normalization status: {}".format(strategy.args.normalize_reward)) 168 | strategy.print("mean: {}, std {}".format(critic.mean, critic.std)) 169 | 170 | # configure tokenizer 171 | if strategy.args.save_value_network: 172 | self.tokenizer = get_tokenizer( 173 | pretrain, critic, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer 174 | ) 175 | 176 | # configure optimizer 177 | critic_optim = strategy.create_optimizer( 178 | critic, lr=args.critic_learning_rate, betas=args.adam_betas, weight_decay=args.l2 179 | ) 180 | 181 | # configure scheduler 182 | critic_scheduler = get_scheduler( 183 | "cosine_with_min_lr", 184 | critic_optim, 185 | num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), 186 | num_training_steps=max_steps, 187 | scheduler_specific_kwargs={"min_lr": args.critic_learning_rate * 0.1}, 188 | ) 189 | 190 | if args.gradient_checkpointing: 191 | critic.gradient_checkpointing_enable( 192 | gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} 193 | ) 194 | 195 | # prepare models/optimizers... 
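        # strategy.prepare initializes the DeepSpeed engine around the critic and returns the
        # wrapped (model, optimizer, scheduler) tuple that is used for training below.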
196 | self.critic, self.critic_optim, self.critic_scheduler = strategy.prepare( 197 | (critic, critic_optim, critic_scheduler), 198 | is_rlhf=True, 199 | ) 200 | 201 | # load checkpoint 202 | if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): 203 | ckpt_path = os.path.join(args.ckpt_path, "_critic") 204 | strategy.print(f"Loading the checkpoint: {ckpt_path}") 205 | strategy.load_ckpt(self.critic, ckpt_path) 206 | 207 | # initial offload 208 | if strategy.args.deepspeed_enable_sleep: 209 | self.offload_states() 210 | 211 | # configure Trainer 212 | self.trainer = CriticPPOTrainer( 213 | strategy, 214 | critic=self.critic, 215 | critic_optim=self.critic_optim, 216 | critic_scheduler=self.critic_scheduler, 217 | micro_train_batch_size=args.micro_train_batch_size, 218 | value_clip=args.value_clip, 219 | ) 220 | 221 | def forward( 222 | self, 223 | sequences: torch.LongTensor, 224 | action_mask: Optional[Union[int, list[int]]] = None, 225 | attention_mask: Optional[torch.Tensor] = None, 226 | packed_seq_lens=None, 227 | ) -> torch.Tensor: 228 | """Generates critic values.""" 229 | device = torch.cuda.current_device() 230 | self.critic.eval() 231 | with torch.no_grad(): 232 | value = self.critic( 233 | sequences.to(device), 234 | action_mask.to(device), 235 | attention_mask.to(device), 236 | ring_attn_group=self.strategy.ring_attn_group, 237 | values_allgather=True, 238 | ) 239 | self.critic.train() # reset model state 240 | return value.to("cpu") 241 | 242 | def append(self, experience): 243 | """Append experience to replay buffer.""" 244 | self.trainer.replay_buffer.append(experience) 245 | 246 | def fit(self): 247 | """Train critic model with the replay buffer.""" 248 | torch.cuda.empty_cache() 249 | self.critic.train() 250 | status = self.trainer.ppo_train() 251 | self.trainer.replay_buffer.clear() 252 | torch.cuda.empty_cache() 253 | torch.cuda.synchronize() 254 | return status 255 | 256 | def save_model(self): 257 | args = self.strategy.args 258 | 259 | # save model checkpoint after fitting on only rank0 260 | self.strategy.save_model( 261 | self.critic, 262 | self.tokenizer, 263 | args.save_path + "_critic", 264 | ) 265 | 266 | def save_checkpoint(self, tag): 267 | args = self.strategy.args 268 | self.strategy.save_ckpt( 269 | self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem 270 | ) 271 | 272 | def reload_states(self): 273 | reload_deepspeed_states(self.critic) 274 | 275 | def offload_states(self): 276 | offload_deepspeed_states(self.critic) 277 | -------------------------------------------------------------------------------- /openrlhf/models/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn as nn 6 | from peft import LoraConfig, TaskType, get_peft_model 7 | from peft.tuners.lora import LoraLayer 8 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig 9 | from transformers.integrations.deepspeed import HfDeepSpeedConfig 10 | 11 | from .ring_attn_utils import gather_and_pad_tensor, unpad_and_slice_tensor 12 | from .utils import compute_entropy, log_probs_from_logits, process_sequences 13 | 14 | compute_entropy = torch.compile(compute_entropy) 15 | 16 | 17 | class Actor(nn.Module): 18 | """ 19 | Base class for Actor models in reinforcement learning. 
20 | 21 | This class serves as a foundation for implementing various actor models, which are responsible for selecting actions based on the policy learned from the environment. 22 | 23 | Args: 24 | pretrain_or_model (nn.Module): A pretrained model or a new model instance to be used as the actor. 25 | use_flash_attention_2 (bool, optional): Whether to utilize Flash Attention 2.0 for improved performance. Defaults to False. 26 | bf16 (bool, optional): Enable bfloat16 precision for model computations. Defaults to True. 27 | load_in_4bit (bool, optional): Load the model in 4-bit precision. Defaults to False. 28 | lora_rank (int, optional): Rank for LoRA adaptation. Defaults to 0. 29 | lora_alpha (int, optional): Alpha parameter for LoRA. Defaults to 16. 30 | lora_dropout (float, optional): Dropout rate for LoRA layers. Defaults to 0. 31 | target_modules (list, optional): List of target modules for applying LoRA. Defaults to None. 32 | ds_config (dict, optional): Configuration for DeepSpeed, enabling model partitioning across multiple GPUs. Defaults to None. 33 | device_map (dict, optional): Device mapping for loading the model onto specific devices. Defaults to None. 34 | packing_samples (bool, optional): Whether to pack samples during training. Defaults to False. 35 | temperature (float, optional): Temperature for action selection. Defaults to 1.0. 36 | use_liger_kernel (bool, optional): Whether to use Liger Kernel for the model. Defaults to False. 37 | """ 38 | 39 | def __init__( 40 | self, 41 | pretrain_or_model, 42 | use_flash_attention_2=False, 43 | bf16=True, 44 | load_in_4bit=False, 45 | lora_rank=0, 46 | lora_alpha=16, 47 | lora_dropout=0, 48 | target_modules=None, 49 | ds_config=None, 50 | device_map=None, 51 | packing_samples=False, 52 | temperature=1.0, 53 | use_liger_kernel=False, 54 | **kwargs, 55 | ) -> None: 56 | super().__init__() 57 | self.temperature = temperature 58 | 59 | if isinstance(pretrain_or_model, str): 60 | attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" 61 | 62 | # Note: dschf is defined in function scope to avoid global effects 63 | # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration 64 | if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: 65 | dschf = HfDeepSpeedConfig(ds_config) 66 | else: 67 | dschf = None 68 | 69 | if load_in_4bit: 70 | assert bf16, "we only support bnb_4bit_compute_dtype = bf16" 71 | nf4_config = BitsAndBytesConfig( 72 | load_in_4bit=True, 73 | bnb_4bit_quant_type="nf4", 74 | bnb_4bit_use_double_quant=True, 75 | bnb_4bit_compute_dtype=torch.bfloat16, 76 | ) 77 | else: 78 | nf4_config = None 79 | 80 | if use_liger_kernel: 81 | from liger_kernel.transformers import AutoLigerKernelForCausalLM 82 | 83 | model_class = AutoLigerKernelForCausalLM 84 | else: 85 | model_class = AutoModelForCausalLM 86 | 87 | self.model = model_class.from_pretrained( 88 | pretrain_or_model, 89 | trust_remote_code=True, 90 | attn_implementation=attn_implementation, 91 | quantization_config=nf4_config, 92 | torch_dtype=torch.bfloat16 if bf16 else "auto", 93 | device_map=device_map, 94 | ) 95 | 96 | # LoRA 97 | if lora_rank > 0: 98 | # https://github.com/huggingface/peft/issues/137 99 | self.model.enable_input_require_grads() 100 | lora_config = LoraConfig( 101 | task_type=TaskType.CAUSAL_LM, 102 | r=lora_rank, 103 | lora_alpha=lora_alpha, 104 | target_modules=target_modules, 105 | lora_dropout=lora_dropout, 106 | bias="none", 107 | ) 108 | self.model = get_peft_model(self.model, 
lora_config) 109 | 110 | if load_in_4bit: 111 | for name, module in self.model.named_modules(): 112 | if isinstance(module, LoraLayer): 113 | module = module.to(torch.bfloat16) 114 | if "norm" in name: 115 | module = module.to(torch.float32) 116 | if "lm_head" in name or "embed_tokens" in name: 117 | if hasattr(module, "weight"): 118 | module = module.to(torch.bfloat16) 119 | 120 | # MoE - balancing loss 121 | model_config = self.model.config.to_dict() 122 | if "output_router_logits" in model_config: 123 | print("[MoE] set output_router_logits as True") 124 | self.model.config.output_router_logits = True 125 | 126 | # https://github.com/huggingface/transformers/issues/26877 127 | # Use `model.generate(use_cache=True)` instead. 128 | self.model.config.use_cache = False 129 | 130 | # packing samples using Flash Attention 2 131 | self.packing_samples = packing_samples 132 | else: 133 | self.model = pretrain_or_model 134 | 135 | @torch.no_grad() 136 | def generate(self, input_ids: torch.Tensor, **kwargs) -> Union[ 137 | Tuple[torch.LongTensor, torch.LongTensor], 138 | Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor], 139 | ]: 140 | generate_args = { 141 | "input_ids": input_ids, 142 | "top_k": kwargs.get("top_k", None), 143 | "top_p": kwargs.get("top_p", None), 144 | "do_sample": kwargs.get("do_sample", True), 145 | "early_stopping": kwargs.get("num_beams", 1) > 1, 146 | "temperature": kwargs.get("temperature", 1), 147 | "use_cache": True, 148 | "num_beams": kwargs.get("num_beams", 1), 149 | "attention_mask": kwargs.get("attention_mask"), 150 | "eos_token_id": kwargs.get("eos_token_id"), 151 | "pad_token_id": kwargs.get("pad_token_id"), 152 | "min_new_tokens": kwargs.get("min_new_tokens", 1), 153 | } 154 | 155 | if kwargs.get("max_new_tokens", None): 156 | generate_args["max_new_tokens"] = kwargs.get("max_new_tokens") 157 | if kwargs.get("max_length", None): 158 | generate_args["max_length"] = kwargs.get("max_length") 159 | 160 | # Call generate 161 | sequences = self.model.generate(**generate_args) 162 | 163 | # Prepare mask tensor 164 | eos_token_id = generate_args["eos_token_id"] 165 | pad_token_id = generate_args["pad_token_id"] 166 | 167 | return process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id) 168 | 169 | def forward( 170 | self, 171 | sequences: torch.LongTensor, 172 | action_mask: Optional[torch.Tensor] = None, 173 | attention_mask: Optional[torch.Tensor] = None, 174 | return_output=False, 175 | allgather_logits=False, 176 | return_logprobs=False, 177 | ring_attn_group: Optional[dist.ProcessGroup] = None, 178 | packed_seq_lens: Optional[list[int]] = None, 179 | return_entropy=False, 180 | ) -> torch.Tensor: 181 | """Returns action log probs""" 182 | batch, seqlen = sequences.size() 183 | forward_attention_mask = attention_mask 184 | if self.packing_samples: 185 | sequences, position_ids, rolled_sequences, ring_attn_pad_len, indices = unpad_and_slice_tensor( 186 | sequences, attention_mask, ring_attn_group 187 | ) 188 | forward_attention_mask = None 189 | else: 190 | # https://github.com/OpenRLHF/OpenRLHF/issues/217 191 | rolled_sequences = torch.roll(sequences, shifts=-1, dims=1) 192 | position_ids = attention_mask.long().cumsum(-1) - 1 193 | position_ids.masked_fill_(attention_mask == 0, 1) 194 | 195 | output = self.model(sequences, attention_mask=forward_attention_mask, position_ids=position_ids) 196 | # https://github.com/OpenRLHF/OpenRLHF/pull/634 197 | output["logits"] = output["logits"].to(torch.float32) 198 | 199 | if return_entropy: 200 |
assert return_output 201 | entropy = compute_entropy(output["logits"]) 202 | if self.packing_samples: 203 | entropy = gather_and_pad_tensor(entropy, ring_attn_group, ring_attn_pad_len, indices, batch, seqlen) 204 | setattr(output, "entropy", entropy[:, :-1]) 205 | 206 | return_action_log_probs = action_mask is not None 207 | if not return_action_log_probs and not return_logprobs: 208 | assert return_output 209 | if allgather_logits and self.packing_samples: 210 | output["logits"] = gather_and_pad_tensor( 211 | output["logits"], ring_attn_group, ring_attn_pad_len, indices, batch, seqlen 212 | ) 213 | return output 214 | 215 | log_probs = log_probs_from_logits(output["logits"], rolled_sequences, temperature=self.temperature) 216 | 217 | if self.packing_samples: 218 | log_probs = gather_and_pad_tensor(log_probs, ring_attn_group, ring_attn_pad_len, indices, batch, seqlen) 219 | 220 | log_probs = log_probs[:, :-1] 221 | if not return_action_log_probs and return_logprobs: 222 | return (log_probs, output) if return_output else log_probs 223 | 224 | action_log_probs = log_probs[:, -action_mask.shape[1] :] * action_mask.float() 225 | 226 | return (action_log_probs, output) if return_output else action_log_probs 227 | 228 | def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={"use_reentrant": False}): 229 | self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) 230 | 231 | def gradient_checkpointing_disable(self): 232 | self.model.gradient_checkpointing_disable() 233 | 234 | def print_trainable_parameters(self): 235 | self.model.print_trainable_parameters() 236 | --------------------------------------------------------------------------------
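The Actor class above is normally constructed and driven by the PPO/REINFORCE trainers, but a short standalone sketch can make the generate -> forward flow concrete. This is an illustrative sketch under stated assumptions, not repository code: the checkpoint name is a placeholder, a single CUDA device is assumed, and the 3-tuple unpacking of generate() (sequences, attention_mask, action_mask) is assumed from its return annotation and process_sequences.

import torch
from transformers import AutoTokenizer

from openrlhf.models.actor import Actor

pretrain = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model name (assumption)
tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # common fallback for decoder-only models

# bf16 weights, no LoRA, no packing; other arguments keep the constructor defaults above
actor = Actor(pretrain, bf16=True).cuda()

inputs = tokenizer("Explain PPO in one sentence.", return_tensors="pt").to("cuda")

# generate() assembles the HF generate_args dict shown above and post-processes the
# raw sequences with process_sequences; the 3-tuple unpacking here is an assumption.
sequences, attention_mask, action_mask = actor.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=64,
    do_sample=True,
    top_p=0.9,
)

# forward() returns per-token log probs of the generated action tokens, which PPO
# later compares against an old/reference policy to form the importance ratio.
action_log_probs = actor(sequences, action_mask=action_mask, attention_mask=attention_mask)
print(action_log_probs.shape)  # (batch_size, num_action_tokens)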