├── .github └── workflows │ └── static.yml ├── .gitignore ├── .nojekyll ├── LICENSE ├── README.md ├── db_info ├── bird_db_id2db_info.json ├── bird_db_id2sampled_db_values.json ├── spiderdev_db_id2db_info.json ├── spiderdev_db_id2sampled_db_values.json ├── spidertest_db_id2db_info.json └── spidertest_db_id2sampled_db_values.json ├── docker ├── Dockerfile.ngc.vllm └── Dockerfile.vemlp.vllm.te ├── example_data ├── sampled_Complex.jsonl ├── test.parquet └── train.parquet ├── images ├── overview.png └── table1.png ├── index.html ├── patches └── megatron_v4.patch ├── pyproject.toml ├── requirements.txt ├── sh ├── eval_bird.sh ├── eval_spider.sh ├── inference.sh └── train.sh ├── src ├── data_preprocess │ ├── nl2sql_synsql_nonvalue.py │ └── nl2sql_synsql_value.py ├── evaluation_bird_post.py ├── evaluation_spider.py ├── evaluation_spider_post.py ├── evaluations │ ├── bird_evaluations │ │ ├── bird_exec_evaluation.py │ │ └── bird_ves_evaluation.py │ ├── preprocess_bird_result.py │ └── spider1_evaluations │ │ ├── __init__.py │ │ ├── evaluation_utils.py │ │ └── src │ │ ├── __init__.py │ │ ├── bridge_content_encoder.py │ │ ├── dataset.py │ │ ├── exec_eval.py │ │ ├── files_to_convert_natsql2sql │ │ └── natsql2sql │ │ │ ├── natsql2sql.py │ │ │ ├── natsql_parser.py │ │ │ ├── preprocess │ │ │ ├── Schema_Token.py │ │ │ ├── TokenString.py │ │ │ ├── col_match.py │ │ │ ├── db_match.py │ │ │ ├── match.py │ │ │ ├── others_pattern.py │ │ │ ├── pattern_analyze.py │ │ │ ├── pattern_question_type.py │ │ │ ├── question_repair.py │ │ │ ├── sentence_analyse.py │ │ │ ├── sq.py │ │ │ ├── sql_back.py │ │ │ ├── stemmer.py │ │ │ ├── table_match.py │ │ │ └── utils.py │ │ │ ├── process_sql.py │ │ │ └── utils.py │ │ ├── get_tables.py │ │ ├── nltk_downloader.py │ │ ├── parse.py │ │ ├── process_sql.py │ │ └── token_preprocessing.py ├── inference.py └── utils │ └── prepare_input_seq.py ├── static ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ 
├── fontawesome.all.min.css │ └── index.css ├── images │ ├── carousel1.jpg │ ├── carousel2.jpg │ ├── carousel3.jpg │ ├── carousel4.jpg │ └── favicon.ico ├── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js ├── pdfs │ └── sample.pdf └── videos │ ├── banner_video.mp4 │ ├── carousel1.mp4 │ ├── carousel2.mp4 │ └── carousel3.mp4 └── verl ├── __init__.py ├── models ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── registry.cpython-39.pyc ├── llama │ ├── __init__.py │ └── megatron │ │ ├── __init__.py │ │ ├── checkpoint_utils │ │ ├── __init__.py │ │ ├── llama_loader.py │ │ └── llama_saver.py │ │ ├── layers │ │ ├── __init__.py │ │ ├── parallel_attention.py │ │ ├── parallel_decoder.py │ │ ├── parallel_linear.py │ │ ├── parallel_mlp.py │ │ └── parallel_rmsnorm.py │ │ └── modeling_llama_megatron.py ├── registry.py ├── transformers │ ├── __init__.py │ ├── llama.py │ ├── monkey_patch.py │ └── qwen2.py └── weight_loader_registry.py ├── protocol.py ├── single_controller ├── __init__.py ├── base │ ├── __init__.py │ ├── decorator.py │ ├── megatron │ │ ├── __init__.py │ │ ├── worker.py │ │ └── worker_group.py │ ├── register_center │ │ ├── __init__.py │ │ └── ray.py │ ├── worker.py │ └── worker_group.py ├── ray │ ├── __init__.py │ ├── base.py │ └── megatron.py └── version │ └── version ├── third_party ├── __init__.py └── vllm │ ├── __init__.py │ ├── vllm_v_0_3_1 │ ├── __init__.py │ ├── arg_utils.py │ ├── config.py │ ├── llm.py │ ├── llm_engine_sp.py │ ├── model_loader.py │ ├── model_runner.py │ ├── parallel_state.py │ ├── tokenizer.py │ ├── weight_loaders.py │ └── worker.py │ ├── vllm_v_0_4_2 │ ├── __init__.py │ ├── arg_utils.py │ ├── config.py │ ├── dtensor_weight_loaders.py │ ├── hf_weight_loader.py │ ├── llm.py │ ├── llm_engine_sp.py │ ├── megatron_weight_loaders.py │ ├── model_loader.py │ ├── model_runner.py │ ├── parallel_state.py │ ├── 
spmd_gpu_executor.py │ ├── tokenizer.py │ └── worker.py │ ├── vllm_v_0_5_4 │ ├── __init__.py │ ├── arg_utils.py │ ├── config.py │ ├── dtensor_weight_loaders.py │ ├── hf_weight_loader.py │ ├── llm.py │ ├── llm_engine_sp.py │ ├── megatron_weight_loaders.py │ ├── model_loader.py │ ├── model_runner.py │ ├── parallel_state.py │ ├── spmd_gpu_executor.py │ ├── tokenizer.py │ └── worker.py │ └── vllm_v_0_6_3 │ ├── __init__.py │ ├── arg_utils.py │ ├── config.py │ ├── dtensor_weight_loaders.py │ ├── hf_weight_loader.py │ ├── llm.py │ ├── llm_engine_sp.py │ ├── megatron_weight_loaders.py │ ├── model_loader.py │ ├── model_runner.py │ ├── parallel_state.py │ ├── spmd_gpu_executor.py │ ├── tokenizer.py │ └── worker.py ├── trainer ├── __init__.py ├── config │ ├── evaluation.yaml │ ├── generation.yaml │ ├── model │ │ └── lora_enabled.yaml │ ├── ppo_megatron_trainer.yaml │ ├── ppo_trainer.yaml │ └── sft_trainer.yaml ├── fsdp_sft_trainer.py ├── main_eval.py ├── main_generation.py ├── main_ppo.py ├── ppo │ ├── __init__.py │ ├── core_algos.py │ └── ray_trainer.py └── runtime_env.yaml ├── utils ├── __init__.py ├── config.py ├── dataset │ ├── README.md │ ├── __init__.py │ ├── rl_dataset.py │ ├── rm_dataset.py │ └── sft_dataset.py ├── debug │ ├── __init__.py │ ├── performance.py │ └── trajectory_tracker.py ├── distributed.py ├── flops_counter.py ├── fs.py ├── fsdp_utils.py ├── hdfs_io.py ├── import_utils.py ├── logger │ ├── __init__.py │ └── aggregate_logger.py ├── logging_utils.py ├── megatron │ ├── __init__.py │ ├── memory.py │ ├── optimizer.py │ ├── optimizer_config.py │ ├── pipeline_parallel.py │ ├── sequence_parallel.py │ └── tensor_parallel.py ├── megatron_utils.py ├── memory_buffer.py ├── model.py ├── py_functional.py ├── ray_utils.py ├── rendezvous │ ├── __init__.py │ └── ray_backend.py ├── reward_score │ ├── __init__.py │ ├── countdown.py │ ├── exec_eval.py │ ├── gsm8k.py │ ├── parse.py │ ├── patterns │ │ ├── nl_patterns.yml │ │ └── output_templates.yml │ └── synsql.py ├── 
seqlen_balancing.py ├── tokenizer.py ├── torch_dtypes.py ├── torch_functional.py ├── tracking.py └── ulysses.py ├── version └── version └── workers ├── __init__.py ├── actor ├── __init__.py ├── base.py ├── dp_actor.py └── megatron_actor.py ├── critic ├── __init__.py ├── base.py ├── dp_critic.py └── megatron_critic.py ├── fsdp_workers.py ├── megatron_workers.py ├── reward_model ├── __init__.py ├── base.py └── megatron │ ├── __init__.py │ └── reward_model.py ├── rollout ├── __init__.py ├── base.py ├── hf_rollout.py ├── naive │ ├── __init__.py │ └── naive_rollout.py ├── tokenizer.py └── vllm_rollout │ ├── __init__.py │ └── vllm_rollout.py └── sharding_manager ├── __init__.py ├── base.py ├── fsdp_ulysses.py ├── fsdp_vllm.py └── megatron_vllm.py /.github/workflows/static.yml: -------------------------------------------------------------------------------- 1 | # Simple workflow for deploying static content to GitHub Pages 2 | name: Deploy static content to Pages 3 | 4 | on: 5 | # Runs on pushes targeting the default branch 6 | push: 7 | branches: ["main"] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 13 | permissions: 14 | contents: read 15 | pages: write 16 | id-token: write 17 | 18 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 19 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
20 | concurrency: 21 | group: "pages" 22 | cancel-in-progress: false 23 | 24 | jobs: 25 | # Single deploy job since we're just deploying 26 | deploy: 27 | environment: 28 | name: github-pages 29 | url: ${{ steps.deployment.outputs.page_url }} 30 | runs-on: ubuntu-latest 31 | steps: 32 | - name: Checkout 33 | uses: actions/checkout@v4 34 | - name: Setup Pages 35 | uses: actions/configure-pages@v5 36 | - name: Upload artifact 37 | uses: actions/upload-pages-artifact@v3 38 | with: 39 | # Upload entire repository 40 | path: '.' 41 | - name: Deploy to GitHub Pages 42 | id: deployment 43 | uses: actions/deploy-pages@v4 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Folders 2 | data/ 3 | models/ 4 | logs/ 5 | outputs/ 6 | results/ 7 | wandb/ 8 | # sh/ 9 | *verl/models/ 10 | openr1_ckpts/ 11 | *.wandb 12 | *.out 13 | 14 | core 15 | test_all.ipynb 16 | 17 | 18 | # Python 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | *.so 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # IDE 41 | .idea/ 42 | .vscode/ 43 | *.swp 44 | *.swo 45 | 46 | # 环境和依赖 47 | venv/ 48 | env/ 49 | .env 50 | .venv 51 | ENV/ 52 | env.bak/ 53 | venv.bak/ 54 | .python-version 55 | 56 | # 日志和缓存 57 | *.log 58 | logs/ 59 | .cache 60 | .pytest_cache/ 61 | .coverage 62 | htmlcov/ 63 | 64 | # 数据和模型文件 65 | data/ 66 | *.pkl 67 | *.h5 68 | *.pt 69 | *.pth 70 | *.bin 71 | *.ckpt 72 | *.model 73 | results/ 74 | 75 | # 系统文件 76 | .DS_Store 77 | Thumbs.db 78 | 79 | # 配置文件 80 | config.ini 81 | secrets.json 82 | credentials.json 83 | *.config 84 | 85 | # 临时文件 86 | tmp/ 87 | temp/ 88 | .temp/ 89 | *.tmp 90 | -------------------------------------------------------------------------------- /.nojekyll: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docker/Dockerfile.ngc.vllm: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.05-py3 2 | 3 | # uninstall nv-pytorch fork 4 | RUN pip3 uninstall pytorch-quantization \ 5 | pytorch-triton \ 6 | torch \ 7 | torch-tensorrt \ 8 | torchvision \ 9 | xgboost transformer_engine flash_attn \ 10 | apex megatron-core -y 11 | 12 | RUN pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124 13 | 14 | # make sure torch version is kept 15 | RUN pip3 install --no-cache-dir \ 16 | "torch==2.4.0" \ 17 | accelerate \ 18 | codetiming \ 19 | datasets \ 20 | dill \ 21 | hydra-core \ 22 | numpy \ 23 | pybind11 \ 24 | tensordict \ 25 | "transformers<=4.46.0" 26 | 27 | # ray is installed via vllm 28 | RUN pip3 install --no-cache-dir vllm==0.6.3 29 | 30 | # we choose flash-attn v2.7.0 or v2.7.2 which contain pre-built wheels 31 | RUN pip3 install --no-cache-dir --no-build-isolation flash-attn==2.7.0.post2 32 | 33 | # install apex, set MAX_JOBS to avoid OOMs 34 | RUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \ 35 | --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" \ 36 | git+https://github.com/NVIDIA/apex 37 | 38 | # install Transformer Engine, which requires FA 2.5.8 39 | RUN MAX_JOBS=4 NINJA_FLAGS="-j4" pip3 install flash-attn==2.5.8 --no-cache-dir --no-build-isolation 40 | RUN MAX_JOBS=4 NINJA_FLAGS="-j4" pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@v1.7 41 | 42 | # Pin wandb to v0.18 since v0.19.1 is released with ImportError 43 | RUN pip3 install wandb==0.18.7 py-spy 44 | -------------------------------------------------------------------------------- /docker/Dockerfile.vemlp.vllm.te: 
-------------------------------------------------------------------------------- 1 | # docker buildx build --platform linux/x86_64 -t "verlai/verl:$TAG" -f docker/$FILE . 2 | 3 | # the one in docker.io is an alias for the one veturbo 4 | # FROM vemlp-cn-beijing.cr.volces.com/veturbo/pytorch:2.4-cu124 5 | FROM docker.io/haibinlin/verl:v0.0.5-th2.4.0-cu124-base 6 | 7 | # only config pip index with https://pypi.tuna.tsinghua.edu.cn/simple if needed 8 | # unset for now 9 | RUN pip3 config unset global.index-url 10 | 11 | # transformers 4.47.0 contains the following bug: 12 | # AttributeError: 'Gemma2Attention' object has no attribute '_flash_attn_uses_top_left_mask' 13 | RUN pip3 install --no-cache-dir \ 14 | torch==2.4.0 \ 15 | accelerate \ 16 | codetiming \ 17 | dill \ 18 | hydra-core \ 19 | numpy \ 20 | pybind11 \ 21 | tensordict \ 22 | "transformers <= 4.46.0" 23 | 24 | RUN pip3 install --no-cache-dir flash-attn==2.7.0.post2 --no-build-isolation 25 | 26 | # vllm depends on ray, and veRL does not support ray > 2.37 27 | RUN pip3 install --no-cache-dir vllm==0.6.3 ray==2.10 28 | 29 | # install apex 30 | RUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \ 31 | --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" \ 32 | git+https://github.com/NVIDIA/apex 33 | 34 | # install Transformer Engine 35 | # - flash-attn pinned to 2.5.3 by TransformerEngine, switch to eric-haibin-lin/TransformerEngine.git@v1.7.0 to relax version req 36 | # - install with: MAX_JOBS=1 NINJA_FLAGS="-j1" TE_BUILD_WITH_NINJA=0 to avoid OOM 37 | # - cudnn is required by TransformerEngine 38 | # RUN CUDNN_PATH=/opt/conda/lib/python3.11/site-packages/nvidia/cudnn \ 39 | # pip3 install git+https://github.com/eric-haibin-lin/TransformerEngine.git@v1.7.0 40 | RUN MAX_JOBS=1 NINJA_FLAGS="-j1" pip3 install flash-attn==2.5.3 --no-cache-dir --no-build-isolation 41 | RUN MAX_JOBS=1 NINJA_FLAGS="-j1" pip3 install 
git+https://github.com/NVIDIA/TransformerEngine.git@v1.7 42 | -------------------------------------------------------------------------------- /example_data/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/example_data/test.parquet -------------------------------------------------------------------------------- /example_data/train.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/example_data/train.parquet -------------------------------------------------------------------------------- /images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/images/overview.png -------------------------------------------------------------------------------- /images/table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/images/table1.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # ------------------------------- 2 | # build-system 3 | # ------------------------------- 4 | [build-system] 5 | requires = [ 6 | "setuptools>=61.0", 7 | "wheel" 8 | ] 9 | build-backend = "setuptools.build_meta" 10 | 11 | # ------------------------------- 12 | # project (PEP 621 metadata) 13 | # ------------------------------- 14 | [project] 15 | name = "verl" 16 | # We'll mark the version as "dynamic" because it's read from the file "verl/version/version" 17 | # (PEP 621 calls this "dynamic version"). 
18 | # The actual version is specified in the [tool.setuptools.dynamic] section below. 19 | dynamic = ["version"] 20 | 21 | description = "veRL: Volcano Engine Reinforcement Learning for LLM" 22 | license = {file = "LICENSE"} # or "Apache-2.0", if you prefer an SPDX identifier 23 | readme = {file = "README.md", content-type = "text/markdown"} 24 | requires-python = ">=3.8" 25 | 26 | authors = [ 27 | { name = "Bytedance - Seed - MLSys", email = "zhangchi.usc1992@bytedance.com" }, 28 | { name = "Bytedance - Seed - MLSys", email = "gmsheng@connect.hku.hk" }, 29 | ] 30 | 31 | # Dependencies corresponding to install_requires in setup.py 32 | dependencies = [ 33 | "accelerate", 34 | "codetiming", 35 | "datasets", 36 | "dill", 37 | "hydra-core", 38 | "numpy", 39 | "pybind11", 40 | "ray", 41 | "tensordict", 42 | "transformers<4.48", 43 | "vllm<=0.6.3", 44 | ] 45 | 46 | # Optional dependencies (extras_require in setup.py) 47 | [project.optional-dependencies] 48 | test = [ 49 | "pytest", "yapf" 50 | ] 51 | 52 | # URLs 53 | [project.urls] 54 | Homepage = "https://github.com/volcengine/verl" 55 | 56 | # ------------------------------- 57 | # tool.setuptools - Additional config 58 | # ------------------------------- 59 | [tool.setuptools] 60 | # True means `setuptools` will attempt to include all relevant files in package_data automatically. 61 | # This corresponds to `include_package_data=True` in setup.py. 62 | include-package-data = true 63 | 64 | # We read the version from a file in 'verl/version/version' 65 | [tool.setuptools.dynamic] 66 | version = {file = "verl/version/version"} 67 | 68 | # If you need to mimic `package_dir={'': '.'}`: 69 | [tool.setuptools.package-dir] 70 | "" = "." 
71 | 72 | # If you need to include specific non-Python data (like YAML files or version file): 73 | # This is the rough equivalent of package_data={'': ['version/*'], 'verl': ['trainer/config/*.yaml']} 74 | [tool.setuptools.package-data] 75 | verl = [ 76 | "version/*", 77 | "trainer/config/*.yaml" 78 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | codetiming 3 | datasets 4 | dill 5 | flash-attn 6 | hydra-core 7 | numpy 8 | pandas 9 | pybind11 10 | ray 11 | tensordict<0.6 12 | transformers<4.48 13 | vllm==0.6.3 14 | wandb 15 | func_timeout 16 | sqlparse 17 | -------------------------------------------------------------------------------- /sh/eval_bird.sh: -------------------------------------------------------------------------------- 1 | echo "Evaluating Bird dataset..." 2 | 3 | POST_PROCESS_MODE=Maj # [Maj] 4 | 5 | PRED_SQL_PATH= 6 | 7 | if [ "$POST_PROCESS_MODE" = "None" ]; then 8 | PRED_SQL_JSON_PATH=${PRED_SQL_PATH%.*}.json 9 | else 10 | echo "Current post-processing mode is $POST_PROCESS_MODE" 11 | fi 12 | 13 | GROUND_TRUTH_SQL_PATH=data/BIRD/dev.sql 14 | GROUND_TRUTH_JSON_PATH=data/BIRD/dev.json 15 | DB_ROOT_PATH=data/BIRD/dev_databases/ 16 | NUM_CPUS=64 17 | META_TIME_OUT=30.0 18 | MODE_GT=gt 19 | MODE_PREDICT=gpt 20 | ITERATE_NUM=100 21 | DIFF_JSON_PATH=data/BIRD/dev.json 22 | SAVE_DIR=results/eval/bird 23 | EVAL_MODE=acc 24 | 25 | 26 | if [ "$POST_PROCESS_MODE" = "Gre" ]; then 27 | python src/evaluation_bird_post.py \ 28 | --pred $PRED_SQL_PATH \ 29 | --gold $GROUND_TRUTH_JSON_PATH \ 30 | --db_path $DB_ROOT_PATH \ 31 | --mode greedy_search 32 | 33 | elif [ "$POST_PROCESS_MODE" = "Maj" ]; then 34 | python src/evaluation_bird_post.py \ 35 | --pred $PRED_SQL_PATH \ 36 | --gold $GROUND_TRUTH_JSON_PATH \ 37 | --db_path $DB_ROOT_PATH \ 38 | --mode major_voting 39 | else 40 | echo 'Please set the 
post-processing mode' 41 | fi -------------------------------------------------------------------------------- /sh/eval_spider.sh: -------------------------------------------------------------------------------- 1 | echo "Evaluating Spider dataset..." 2 | 3 | PRED_SQL= 4 | MODE=test # [dev, test] 5 | POST_PROCESS_MODE=Maj 6 | 7 | 8 | if [ "$MODE" = "dev" ]; then 9 | GOLD_SQL=data/NL2SQL/Spider/dev_gold.sql 10 | DB=data/NL2SQL/Spider/database 11 | TABLE=data/NL2SQL/Spider/tables.json 12 | elif [ "$MODE" = "test" ]; then 13 | GOLD_SQL=data/NL2SQL/Spider/test_gold.sql 14 | DB=data/NL2SQL/Spider/test_database 15 | TABLE=data/NL2SQL/Spider/test_tables.json 16 | else 17 | echo "Only support dev or test mode for Spider" 18 | exit 1 19 | fi 20 | 21 | ETYPE=all 22 | PLUG_VALUE=false 23 | KEEP_DISTINCT=false 24 | PROGRESS_BAR_FOR_EACH_DATAPOINT=false 25 | SAVE_DIR=results/eval/spider 26 | 27 | 28 | if [ "$POST_PROCESS_MODE" = "Maj" ]; then 29 | python src/evaluation_spider_post.py \ 30 | --pred $PRED_SQL \ 31 | --gold $GOLD_SQL \ 32 | --db_path $DB/ \ 33 | --table $TABLE \ 34 | --mode major_voting \ 35 | --save_pred_sqls False \ 36 | --save_dir $SAVE_DIR 37 | 38 | PRED_SQL=${PRED_SQL%.*}_pred_major_voting_sqls.txt 39 | python src/evaluation_spider.py \ 40 | --gold_sql $GOLD_SQL \ 41 | --pred_sql $PRED_SQL \ 42 | --db $DB \ 43 | --table $TABLE \ 44 | --etype $ETYPE \ 45 | --plug_value $PLUG_VALUE \ 46 | --keep_distinct $KEEP_DISTINCT \ 47 | --progress_bar_for_each_datapoint $PROGRESS_BAR_FOR_EACH_DATAPOINT \ 48 | --save_dir $SAVE_DIR 49 | 50 | elif [ "$POST_PROCESS_MODE" = "Gre" ]; then 51 | python src/evaluation_spider_post.py \ 52 | --pred $PRED_SQL \ 53 | --gold $GOLD_SQL \ 54 | --db_path $DB/ \ 55 | --table $TABLE \ 56 | --mode greedy_search \ 57 | --save_pred_sqls False \ 58 | --save_dir $SAVE_DIR 59 | 60 | PRED_SQL=${PRED_SQL%.*}_pred_greedy_search_sqls.txt 61 | python src/evaluation_spider.py \ 62 | --gold_sql $GOLD_SQL \ 63 | --pred_sql $PRED_SQL \ 64 | --db $DB \ 65 
65 | --table $TABLE \ 66 | --etype $ETYPE --plug_value $PLUG_VALUE --keep_distinct $KEEP_DISTINCT --progress_bar_for_each_datapoint $PROGRESS_BAR_FOR_EACH_DATAPOINT --save_dir $SAVE_DIR # same evaluation options as the Maj branch 67 | else 68 | echo 'Please set the post-processing mode' 69 | fi -------------------------------------------------------------------------------- /sh/inference.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | 3 | OUTPUT_FORMAT=json 4 | 5 | MODEL_ENV= # TODO: add model path 6 | OUTPUT_FILE_NAME=generated_sql.$OUTPUT_FORMAT 7 | DATASET=bird # [bird, spider, spider-dk, spider-syn, spider-realistic, spider2-lite] 8 | MODE=dev # [dev, test], only spider has test mode 9 | NUM_GPUS=4 10 | 11 | TEMPERATURE=0.8 12 | 13 | N=8 14 | 15 | if [ "$DATASET" = "spider" ]; then 16 | if [ "$MODE" = "test" ]; then 17 | INPUT_FILE=data/NL2SQL/Spider/test.json 18 | DATABASE_PATH=data/NL2SQL/Spider/test_database 19 | OUTPUT_FILE=results/spidertest-$OUTPUT_FILE_NAME 20 | TABLE_VALUE_CACHE_PATH=data/NL2SQL/Spider/spidertest_db_id2sampled_db_values.json 21 | TABLE_INFO_CACHE_PATH=data/NL2SQL/Spider/spidertest_db_id2db_info.json 22 | elif [ "$MODE" = "dev" ]; then 23 | INPUT_FILE=data/NL2SQL/Spider/dev.json 24 | DATABASE_PATH=data/NL2SQL/Spider/database 25 | OUTPUT_FILE=results/spiderdev-$OUTPUT_FILE_NAME 26 | TABLE_VALUE_CACHE_PATH=data/NL2SQL/Spider/spiderdev_db_id2sampled_db_values.json 27 | TABLE_INFO_CACHE_PATH=data/NL2SQL/Spider/spiderdev_db_id2db_info.json 28 | fi 29 | elif [ "$DATASET" = "bird" ]; then 30 | if [ "$MODE" = "dev" ]; then 31 | INPUT_FILE=data/NL2SQL/BIRD/dev/dev.json 32 | DATABASE_PATH=data/NL2SQL/BIRD/dev/dev_databases 33 | OUTPUT_FILE=results/birddev-$OUTPUT_FILE_NAME 34 | TABLE_VALUE_CACHE_PATH=data/NL2SQL/BIRD/dev/bird_db_id2sampled_db_values.json 35 | TABLE_INFO_CACHE_PATH=data/NL2SQL/BIRD/dev/bird_db_id2db_info.json 36 | else 37 | exit 1 38 | fi 39 | elif [ "$DATASET" = "spider-dk" ]; then 40 | if [ "$MODE" = "dev" ]; then 41 | INPUT_FILE=data/NL2SQL/Spider-DK/spiderdk_dev.json 42 | DATABASE_PATH=data/NL2SQL/Spider-DK/database 43 |
OUTPUT_FILE=results/spiderdkdev-$OUTPUT_FILE_NAME 44 | TABLE_VALUE_CACHE_PATH=data/NL2SQL/Spider-DK/spiderdkdev_db_id2sampled_db_values.json 45 | TABLE_INFO_CACHE_PATH=data/NL2SQL/Spider-DK/spiderdkdev_db_id2db_info.json 46 | else 47 | exit 1 48 | fi 49 | elif [ "$DATASET" = "spider-syn" ]; then 50 | if [ "$MODE" = "dev" ]; then 51 | INPUT_FILE=data/NL2SQL/Spider-Syn/spider_syn.json 52 | DATABASE_PATH=data/NL2SQL/Spider/database 53 | OUTPUT_FILE=results/spidersyn-$OUTPUT_FILE_NAME 54 | TABLE_VALUE_CACHE_PATH=data/NL2SQL/Spider/spiderdev_db_id2sampled_db_values.json 55 | TABLE_INFO_CACHE_PATH=data/NL2SQL/Spider/spiderdev_db_id2db_info.json 56 | else 57 | exit 1 58 | fi 59 | elif [ "$DATASET" = "spider-realistic" ]; then 60 | if [ "$MODE" = "dev" ]; then 61 | INPUT_FILE=data/NL2SQL/Spider-Realistic/spider-realistic.json 62 | DATABASE_PATH=data/NL2SQL/Spider/database 63 | OUTPUT_FILE=results/spiderrealdev-$OUTPUT_FILE_NAME 64 | TABLE_VALUE_CACHE_PATH=data/NL2SQL/Spider/spiderdev_db_id2sampled_db_values.json 65 | TABLE_INFO_CACHE_PATH=data/NL2SQL/Spider/spiderdev_db_id2db_info.json 66 | else 67 | exit 1 68 | fi 69 | elif [ "$DATASET" = "spider2-lite" ]; then 70 | if [ "$MODE" = "dev" ]; then 71 | # NOTE: this branch used to copy the spider-realistic paths verbatim, silently 72 | # running inference on the wrong dataset and overwriting results/spiderrealdev-*. 73 | # Fail loudly until real Spider2-lite input/database/cache paths are configured. 74 | echo "spider2-lite is not configured yet: set INPUT_FILE, DATABASE_PATH, OUTPUT_FILE and TABLE_*_CACHE_PATH for it" 75 | exit 1 76 | else 77 | exit 1 78 | fi 79 | else 80 | echo "Only support spider, bird, spider-dk, spider-syn, spider-realistic" 81 | exit 1 82 | fi 83 | 84 | 85 | python src/inference.py \ 86 | --nl2sql_ckpt_path $MODEL_ENV \ 87 | --dataset_name $DATASET \ 88 | --input_file $INPUT_FILE \ 89 | --output_file $OUTPUT_FILE \ 90 | --database_path $DATABASE_PATH \ 91 | --tensor_parallel_size $NUM_GPUS \ 92 | --n $N \ 93 | --temperature $TEMPERATURE \ 94 | --output_format $OUTPUT_FORMAT \ 95 |
--table_value_cache_path $TABLE_VALUE_CACHE_PATH \ 96 | --table_info_cache_path $TABLE_INFO_CACHE_PATH 97 | -------------------------------------------------------------------------------- /sh/train.sh: -------------------------------------------------------------------------------- 1 | export WANDB_API_KEY=your_wandb_api_key 2 | export VLLM_ATTENTION_BACKEND=XFORMERS 3 | 4 | DATA_DIR_PATH=data 5 | 6 | RUN_ID=7B 7 | GPU_ENV=8GPU 8 | MODEL_ENV=Qwen2.5-Coder-7B-Instruct 9 | PROJECT_NAME=SQL-R1 10 | 11 | LOG_PATH=logs/$PROJECT_NAME 12 | MODEL_PATH=models/$MODEL_ENV 13 | EXPERIMENT_NAME=$GPU_ENV-$MODEL_ENV-$RUN_ID 14 | 15 | mkdir -p $LOG_PATH 16 | 17 | set -x 18 | 19 | nvidia-smi 20 | 21 | python -m verl.trainer.main_ppo \ 22 | algorithm.adv_estimator=grpo \ 23 | data.train_files=$DATA_DIR_PATH/train.parquet \ 24 | data.val_files=$DATA_DIR_PATH/test.parquet \ 25 | data.train_batch_size=8 \ 26 | data.val_batch_size=8 \ 27 | data.max_prompt_length=4096 \ 28 | data.max_response_length=2048 \ 29 | actor_rollout_ref.model.path=$MODEL_PATH \ 30 | actor_rollout_ref.actor.optim.lr=3e-7 \ 31 | actor_rollout_ref.model.use_remove_padding=True \ 32 | actor_rollout_ref.actor.ppo_mini_batch_size=8 \ 33 | actor_rollout_ref.actor.ppo_micro_batch_size=8 \ 34 | actor_rollout_ref.actor.use_kl_loss=True \ 35 | actor_rollout_ref.actor.kl_loss_coef=0.001 \ 36 | actor_rollout_ref.actor.kl_loss_type=low_var_kl \ 37 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 38 | actor_rollout_ref.actor.fsdp_config.param_offload=True \ 39 | actor_rollout_ref.actor.fsdp_config.grad_offload=True \ 40 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ 41 | actor_rollout_ref.rollout.log_prob_micro_batch_size=80 \ 42 | actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ 43 | actor_rollout_ref.rollout.name=vllm \ 44 | actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \ 45 | actor_rollout_ref.rollout.n=8 \ 46 | actor_rollout_ref.rollout.temperature=1.1 \ 47 | 
actor_rollout_ref.ref.log_prob_micro_batch_size=80 \ 48 | actor_rollout_ref.ref.fsdp_config.param_offload=True \ 49 | algorithm.kl_ctrl.kl_coef=0.001 \ 50 | trainer.critic_warmup=0 \ 51 | trainer.logger=['wandb'] \ 52 | trainer.project_name=$PROJECT_NAME \ 53 | trainer.experiment_name=$EXPERIMENT_NAME \ 54 | trainer.n_gpus_per_node=8 \ 55 | trainer.nnodes=1 \ 56 | trainer.default_local_dir=$LOG_PATH/$EXPERIMENT_NAME \ 57 | trainer.default_hdfs_dir=null \ 58 | trainer.save_freq=100 \ 59 | trainer.test_freq=100 \ 60 | trainer.total_epochs=10 "$@" 2>&1 | tee $LOG_PATH/${EXPERIMENT_NAME}_grpo.log # only $LOG_PATH is created by mkdir -p above; $LOG_PATH/$MODEL_ENV never existed, so tee could not open the log 61 | -------------------------------------------------------------------------------- /src/evaluation_spider_post.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import random 5 | import re 6 | 7 | from evaluation_bird_post import major_voting, mark_invalid_sqls 8 | 9 | random.seed(42) 10 | 11 | def format_sql(sql): 12 | sql = sql.strip() 13 | # remove multi-line comments /* ...
*/ 14 | sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL) 15 | 16 | # remove single-line comments -- 17 | sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE) 18 | 19 | sql = sql.replace("\n", " ").replace("\t", " ") 20 | sql = sql.strip() 21 | 22 | if sql == "": 23 | sql = "Error SQL" 24 | 25 | return sql 26 | 27 | def run_spider_eval(gold_file, pred_file, db_path, table, mode, save_pred_sqls, save_dir): 28 | # assert mode in ["greedy_search", "major_voting"] 29 | gold_sqls = [line.split("\t")[0].strip() for line in open(gold_file).readlines()] 30 | db_ids = [line.split("\t")[1].strip() for line in open(gold_file).readlines()] 31 | pred = json.load(open(pred_file)) 32 | pred_sql_key = "pred_sqls" 33 | # pred_sql_key = "responses" 34 | 35 | pred_sqls = [] 36 | if mode == "greedy_search": 37 | pred_sqls = [pred_data[pred_sql_key][0] for pred_data in pred] 38 | assert len(pred_sqls) == len(db_ids) 39 | db_files = [os.path.join(db_path, db_id, db_id + ".sqlite") for db_id in db_ids] 40 | pred_sqls = mark_invalid_sqls(db_files, pred_sqls) 41 | elif mode == "major_voting": 42 | # perform major voting using the BIRD's evaluation script 43 | sampling_num = len(pred[0][pred_sql_key]) 44 | print("sampling_num:", sampling_num) 45 | 46 | all_db_files = [] 47 | for db_id in db_ids: 48 | all_db_files.extend([os.path.join(db_path, db_id, db_id + ".sqlite")] * sampling_num) 49 | 50 | all_pred_sqls = [] 51 | for pred_data in pred: 52 | all_pred_sqls.extend(pred_data[pred_sql_key]) 53 | assert len(all_db_files) == len(all_pred_sqls) 54 | 55 | pred_sqls = major_voting(all_db_files, all_pred_sqls, sampling_num, False) 56 | 57 | pred_sqls = [format_sql(pred_sql) for pred_sql in pred_sqls] 58 | assert len(pred_sqls) == len(gold_sqls) 59 | 60 | if save_pred_sqls: 61 | with open(pred_file[:-5] + f"_pred_{mode}_sqls.json", "w", encoding="utf-8") as f: 62 | f.write(json.dumps(pred_sqls, indent=2 ,ensure_ascii=False)) 63 | 64 | new_txt_file_name = pred_file.rsplit('.', 1)[0] 65 | with 
open(new_txt_file_name + f"_pred_{mode}_sqls.txt", 'w') as temp_file: 66 | for pred_sql in pred_sqls: 67 | temp_file.write(pred_sql + "\n") 68 | temp_file_name = temp_file.name 69 | print(temp_file_name) 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--pred', type = str, default = "predict_dev.json") 74 | parser.add_argument('--gold', type = str, default = "./data/spider/dev_gold.sql") 75 | parser.add_argument('--db_path', type = str, default = "./data/spider/database/") 76 | parser.add_argument('--table', type = str, default = "./data/spider/tables.json") 77 | parser.add_argument('--mode', type = str, default = "greedy_search") 78 | parser.add_argument('--save_pred_sqls', type = bool, default = False) 79 | parser.add_argument('--save_dir', type = str, default = "results/eval/spider") 80 | opt = parser.parse_args() 81 | 82 | run_spider_eval(opt.gold, opt.pred, opt.db_path, opt.table, opt.mode, opt.save_pred_sqls, opt.save_dir) -------------------------------------------------------------------------------- /src/evaluations/preprocess_bird_result.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | import re 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--txt_result_path', type=str) 9 | parser.add_argument('--json_result_path', type=str) 10 | parser.add_argument('--json_save_path', type=str) 11 | args = parser.parse_args() 12 | 13 | txt_result_path = args.txt_result_path 14 | json_result_path = args.json_result_path 15 | json_save_path = args.json_save_path 16 | 17 | with open(txt_result_path, 'r') as f: 18 | result_sqls = f.readlines() 19 | 20 | with open(json_result_path, 'r') as f: 21 | json_result = json.load(f) 22 | 23 | final_output_dict = {} 24 | 25 | for sql, json_data in zip(result_sqls, json_result): 26 | sql = sql.split('/*')[0].strip() 27 | 
final_output_dict[str(json_data['question_id'])] = sql + "\t----- bird -----\t" + json_data['db_id'] 28 | 29 | with open(json_save_path, 'w') as f: 30 | json.dump(final_output_dict, f, indent=4) -------------------------------------------------------------------------------- /src/evaluations/spider1_evaluations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/src/evaluations/spider1_evaluations/__init__.py -------------------------------------------------------------------------------- /src/evaluations/spider1_evaluations/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/src/evaluations/spider1_evaluations/src/__init__.py -------------------------------------------------------------------------------- /src/evaluations/spider1_evaluations/src/files_to_convert_natsql2sql/natsql2sql/preprocess/stemmer.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | from .match import ALL_JJS 3 | 4 | DICT = {"weight":"weigh", 5 | "won":"win", 6 | "nation":"country", 7 | 8 | } 9 | 10 | class MyStemmer(): 11 | def __init__(self): 12 | self.stemmer = nltk.stem.LancasterStemmer() 13 | 14 | def stem(self,w): 15 | result = w.lower() 16 | if result == "january": 17 | return "jan" 18 | elif result == "february": 19 | result = "feb" 20 | elif result == "march": 21 | return "mar" 22 | elif result == "april": 23 | return "apr" 24 | elif result == "may": 25 | return "may" 26 | elif result == "june": 27 | return "jun" 28 | elif result == "july": 29 | return "jul" 30 | elif result == "august": 31 | return "aug" 32 | elif result == "september": 33 | return "sep" 34 | elif result == "sept": 35 | return "sep" 36 | elif result == "october": 37 | return "oct" 38 | elif 
result == "november": 39 | return "nov" 40 | elif result == "december": 41 | return "dec" 42 | result = self.stemmer.stem(result) 43 | if result == "weight": 44 | result = "weigh" 45 | if result == "hight": 46 | result = "high" 47 | elif result == "won": 48 | result = "win" 49 | elif result in ALL_JJS: 50 | return ALL_JJS[result] 51 | elif result == "maxim": 52 | result = "max" 53 | elif result == "minim": 54 | result = "min" 55 | return result -------------------------------------------------------------------------------- /src/evaluations/spider1_evaluations/src/nltk_downloader.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | nltk.download('punkt') -------------------------------------------------------------------------------- /static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item 
.video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and 
(min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 
62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | padding: 20px; 121 | font-size: 0; 122 | } 123 | 124 | .results-carousel video { 125 | margin: 0; 126 | } 127 | 128 | .slider-pagination .slider-page { 129 | background: #000000; 130 | } 131 | 132 | .eql-cntrb { 133 | font-size: smaller; 134 | } 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /static/images/carousel1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/images/carousel1.jpg -------------------------------------------------------------------------------- /static/images/carousel2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/images/carousel2.jpg -------------------------------------------------------------------------------- 
/static/images/carousel3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/images/carousel3.jpg -------------------------------------------------------------------------------- /static/images/carousel4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/images/carousel4.jpg -------------------------------------------------------------------------------- /static/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/images/favicon.ico -------------------------------------------------------------------------------- /static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | 4 | $(document).ready(function() { 5 | // Check for click events on the navbar burger icon 6 | 7 | var options = { 8 | slidesToScroll: 1, 9 | slidesToShow: 1, 10 | loop: true, 11 | infinite: true, 12 | autoplay: true, 13 | autoplaySpeed: 5000, 14 | } 15 | 16 | // Initialize all div with carousel class 17 | var carousels = bulmaCarousel.attach('.carousel', options); 18 | 19 | bulmaSlider.attach(); 20 | 21 | }) 22 | -------------------------------------------------------------------------------- /static/pdfs/sample.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/pdfs/sample.pdf -------------------------------------------------------------------------------- /static/videos/banner_video.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/videos/banner_video.mp4 -------------------------------------------------------------------------------- /static/videos/carousel1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/videos/carousel1.mp4 -------------------------------------------------------------------------------- /static/videos/carousel2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/videos/carousel2.mp4 -------------------------------------------------------------------------------- /static/videos/carousel3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/static/videos/carousel3.mp4 -------------------------------------------------------------------------------- /verl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | 17 | version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) 18 | 19 | with open(os.path.join(version_folder, 'version/version')) as f: 20 | __version__ = f.read().strip() 21 | 22 | from .protocol import DataProto 23 | 24 | from .utils.logging_utils import set_basic_config 25 | import logging 26 | 27 | set_basic_config(level=logging.WARNING) 28 | -------------------------------------------------------------------------------- /verl/models/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | Common modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl. 3 | ## Adding a New Huggingface Model 4 | ### Step 1: Copy the model file from HF to verl 5 | - Add a new file under verl/models/hf 6 | - Copy ONLY the model file from huggingface/transformers/models to verl/models/hf 7 | 8 | ### Step 2: Modify the model file to use packed inputs 9 | - Remove all the code related to inference (kv cache) 10 | - Modify the inputs to include only 11 | - input_ids (total_nnz,) 12 | - cu_seqlens (total_nnz + 1,) 13 | - max_seqlen_in_batch: int 14 | - Note that this requires using flash attention with causal mask. 15 | 16 | ### Step 2.5: Add tests 17 | - Add a test to compare this version and the huggingface version 18 | - Following the infrastructure and add tests to tests/models/hf 19 | 20 | ### Step 3: Add a function to apply tensor parallelism 21 | - Please follow 22 | - https://pytorch.org/docs/stable/distributed.tensor.parallel.html 23 | - https://pytorch.org/tutorials/intermediate/TP_tutorial.html 24 | - General comments 25 | - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. 
These configs are then registered as hooks to perform input/output resharding before/after model forward. 26 | 27 | ### Step 4: Add a function to apply data parallelism 28 | - Please use FSDP2 APIs 29 | - See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413 30 | 31 | ### Step 5: Add a function to apply pipeline parallelism 32 | - Comes in Pytorch 2.4 33 | - Currently only in alpha in nightly version 34 | - Check torchtitan for more details 35 | 36 | -------------------------------------------------------------------------------- /verl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /verl/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/verl/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /verl/models/__pycache__/registry.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/SQL-R1/9d7f923fcd1c2ac41e9f8cc594d435824b0753f5/verl/models/__pycache__/registry.cpython-39.pyc -------------------------------------------------------------------------------- /verl/models/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/llama/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .modeling_llama_megatron import ( 16 | # original model with megatron 17 | ParallelLlamaModel, 18 | ParallelLlamaForCausalLM, 19 | # rmpad with megatron 20 | ParallelLlamaForCausalLMRmPad, 21 | ParallelLlamaForValueRmPad, 22 | # rmpad with megatron and pipeline parallelism 23 | ParallelLlamaForCausalLMRmPadPP, 24 | ParallelLlamaForValueRmPadPP) 25 | -------------------------------------------------------------------------------- /verl/models/llama/megatron/checkpoint_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/llama/megatron/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .parallel_attention import ParallelLlamaAttention 16 | from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad 17 | from .parallel_mlp import ParallelLlamaMLP 18 | from .parallel_rmsnorm import ParallelLlamaRMSNorm 19 | -------------------------------------------------------------------------------- /verl/models/llama/megatron/layers/parallel_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py 15 | 16 | from typing import Optional, Tuple 17 | 18 | from megatron.core import tensor_parallel 19 | 20 | 21 | class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): 22 | 23 | def __init__(self, 24 | input_size, 25 | num_heads, 26 | num_key_value_heads, 27 | head_dim, 28 | *, 29 | bias=True, 30 | gather_output=True, 31 | skip_bias_add=False, 32 | **kwargs): 33 | # Keep input parameters, and already restrict the head numbers 34 | self.input_size = input_size 35 | self.q_output_size = num_heads * head_dim 36 | self.kv_output_size = num_key_value_heads * head_dim 37 | self.head_dim = head_dim 38 | self.gather_output = gather_output 39 | self.skip_bias_add = skip_bias_add 40 | 41 | input_size = self.input_size 42 | output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim 43 | 44 | super().__init__(input_size=input_size, 45 | output_size=output_size, 46 | bias=bias, 47 | gather_output=gather_output, 48 | skip_bias_add=skip_bias_add, 49 | **kwargs) 50 | 51 | 52 | class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): 53 | 54 | def __init__(self, 55 | input_size, 56 | gate_ouput_size, 57 | up_output_size, 58 | *, 59 | bias=True, 60 | gather_output=True, 61 | skip_bias_add=False, 62 | **kwargs): 63 | # Keep input parameters, and already restrict the head numbers 64 | self.input_size = input_size 65 | self.output_size = gate_ouput_size + up_output_size 66 | self.gather_output = gather_output 67 | self.skip_bias_add = skip_bias_add 68 | 69 | super().__init__(input_size=self.input_size, 70 | output_size=self.output_size, 71 | bias=bias, 72 | gather_output=gather_output, 73 | skip_bias_add=skip_bias_add, 74 | **kwargs) 75 | -------------------------------------------------------------------------------- /verl/models/llama/megatron/layers/parallel_mlp.py: -------------------------------------------------------------------------------- 1 | 
# Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
20 | 21 | from megatron.core import parallel_state as mpu 22 | from megatron.core import tensor_parallel 23 | from megatron.core import ModelParallelConfig 24 | from torch import nn 25 | from transformers.activations import ACT2FN 26 | from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear 27 | 28 | from verl.utils.megatron import tensor_parallel as tp_utils 29 | 30 | 31 | class ParallelLlamaMLP(nn.Module): 32 | 33 | def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: 34 | super().__init__() 35 | self.config = config 36 | self.hidden_size = config.hidden_size 37 | self.intermediate_size = config.intermediate_size 38 | # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] 39 | 40 | column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() 41 | row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() 42 | 43 | if megatron_config is not None: 44 | assert column_kwargs.get('config', False), 'must have ModelParallelConfig' 45 | assert row_kwargs.get('config', False), 'must have ModelParallelConfig' 46 | tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) 47 | tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) 48 | 49 | tp_size = mpu.get_tensor_model_parallel_world_size() 50 | 51 | self.gate_up_proj = MergedColumnParallelLinear( 52 | input_size=self.hidden_size, 53 | gate_ouput_size=self.intermediate_size, 54 | up_output_size=self.intermediate_size, 55 | bias=False, 56 | gather_output=False, 57 | skip_bias_add=False, 58 | **column_kwargs, 59 | ) 60 | self.gate_size = self.intermediate_size // tp_size 61 | 62 | self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size, 63 | output_size=self.hidden_size, 64 | bias=False, 65 | input_is_parallel=True, 66 | skip_bias_add=False, 67 | **row_kwargs) 68 | 69 | self.act_fn = ACT2FN[config.hidden_act] 70 | 71 | def forward(self, x): 72 | gate_up = 
self.gate_up_proj(x)[0] 73 | gate, up = gate_up.split(self.gate_size, dim=-1) 74 | return self.down_proj(self.act_fn(gate) * up)[0] 75 | -------------------------------------------------------------------------------- /verl/models/llama/megatron/layers/parallel_rmsnorm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numbers 16 | import torch 17 | from megatron.core import ModelParallelConfig 18 | from torch import nn 19 | from transformers import LlamaConfig 20 | 21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine 22 | from verl.utils.megatron import sequence_parallel as sp_utils 23 | 24 | 25 | class ParallelLlamaRMSNorm(nn.Module): 26 | 27 | def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): 28 | """ 29 | LlamaRMSNorm is equivalent to T5LayerNorm 30 | """ 31 | super().__init__() 32 | if isinstance(config.hidden_size, numbers.Integral): 33 | normalized_shape = (config.hidden_size,) 34 | self.normalized_shape = torch.Size(normalized_shape) 35 | self.weight = nn.Parameter(torch.ones(self.normalized_shape)) 36 | self.variance_epsilon = config.rms_norm_eps 37 | 38 | if megatron_config.sequence_parallel: 39 | sp_utils.mark_parameter_as_sequence_parallel(self.weight) 40 | 41 | def forward(self, hidden_states): 42 | return fused_rms_norm_affine(input=hidden_states, 43 | weight=self.weight, 44 | normalized_shape=self.normalized_shape, 45 | eps=self.variance_epsilon, 46 | memory_efficient=True) -------------------------------------------------------------------------------- /verl/models/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import importlib 16 | from typing import List, Optional, Type 17 | 18 | import torch.nn as nn 19 | 20 | # Supported models using HF Rmpad 21 | # TODO(sgm): HF may supported more than listed here, we should add more after testing 22 | from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config 23 | 24 | _REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config} 25 | 26 | 27 | def check_model_support_rmpad(model_type: str): 28 | assert isinstance(model_type, str) 29 | if not model_type in _REOVEPAD_MODELS.keys(): 30 | raise ValueError(f"Model architecture {model_type} is not supported for now. " 31 | f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}." 32 | f"Please set `use_remove_padding=False` in the model config.") 33 | 34 | 35 | # Supported models in Megatron-LM 36 | # Architecture -> (module, class). 37 | _MODELS = { 38 | "LlamaForCausalLM": 39 | ("llama", ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad")), 40 | "MistralForCausalLM": ("mistral", ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", 41 | "ParallelMistralForCausalLMRmPad")) 42 | } 43 | 44 | 45 | # return model class 46 | class ModelRegistry: 47 | 48 | @staticmethod 49 | def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]: 50 | if model_arch not in _MODELS: 51 | return None 52 | 53 | megatron = "megatron" 54 | 55 | module_name, model_cls_name = _MODELS[model_arch] 56 | if not value: # actor/ref 57 | model_cls_name = model_cls_name[0] 58 | elif value: # critic/rm 59 | model_cls_name = model_cls_name[1] 60 | 61 | module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron") 62 | return getattr(module, model_cls_name, None) 63 | 64 | @staticmethod 65 | def get_supported_archs() -> List[str]: 66 | return list(_MODELS.keys()) 67 | 
-------------------------------------------------------------------------------- /verl/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/models/transformers/monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""
Apply monkey-patch function to models
"""

import importlib.metadata
from functools import lru_cache

from packaging import version
from transformers import PretrainedConfig

#### Open Source Models
#### transformers version < 4.48

# Inclusive range of installed ``transformers`` versions accepted by ``apply_monkey_patch``.
_MIN_TRANSFORMERS_VERSION = "4.45.0"
_MAX_TRANSFORMERS_VERSION = "4.47.1"


def apply_monkey_patch_to_llama():
    """Replace ``LlamaFlashAttention2.forward`` with verl's ``llama_flash_attn_forward``."""
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
    from verl.models.transformers.llama import llama_flash_attn_forward
    LlamaFlashAttention2.forward = llama_flash_attn_forward


def apply_monkey_patch_to_qwen2():
    """Replace ``Qwen2FlashAttention2.forward`` with verl's ``qwen2_flash_attn_forward``."""
    from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
    from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
    Qwen2FlashAttention2.forward = qwen2_flash_attn_forward


# ``config.model_type`` -> function that installs the patch for that model family.
_PATCH_NAME_TO_FUNC = {
    'llama': apply_monkey_patch_to_llama,
    'qwen2': apply_monkey_patch_to_qwen2,
}


def apply_monkey_patch(config: PretrainedConfig, verbose=True):
    """Install the ulysses attention patch for ``config.model_type``.

    Args:
        config: HF config whose ``model_type`` selects the patch.
        verbose: print a message once the patch has been applied.

    Returns:
        True. Every failure path raises instead of returning.

    Raises:
        AssertionError: the installed ``transformers`` is outside [4.45.0, 4.47.1].
        NotImplementedError: no patch is registered for ``config.model_type``.
    """
    if not is_transformers_version_in_range(_MIN_TRANSFORMERS_VERSION, _MAX_TRANSFORMERS_VERSION):
        raise AssertionError("The installed `transformers` version doesn't support ulysses patch. "
                             f"Please install a version between {_MIN_TRANSFORMERS_VERSION} and "
                             f"{_MAX_TRANSFORMERS_VERSION} to use this ulysses feature.")

    patch_fn = _PATCH_NAME_TO_FUNC.get(config.model_type)
    if patch_fn is None:
        # Implicit concatenation: the old backslash continuation inside the literal
        # embedded the next line's indentation in the message.
        raise NotImplementedError(f'Ulysses for model {config.model_type} is not implemented, '
                                  'please set `ulysses_sequence_parallel_size=1`')

    patch_fn()
    if verbose:
        print(f'Applying monkey patch to model {config.model_type}')
    return True


@lru_cache()
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
    """Return whether the installed ``transformers`` version is in [min_version, max_version].

    Results are cached per (min_version, max_version) pair.

    Raises:
        ModuleNotFoundError: ``transformers`` is not installed.
    """
    try:
        # Get the installed version of the transformers library
        transformers_version = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError as err:
        raise ModuleNotFoundError("The `transformers` package is not installed.") from err

    # Check if the version is within the specified range
    return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)


# --------------------------------------------------------------------------------
# /verl/models/weight_loader_registry.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_weight_loader(arch: str):
    """Look up the Megatron state-dict loader for HF architecture ``arch``.

    Returns:
        The loader function registered for ``arch``.

    Raises:
        ValueError: no loader is registered for ``arch``.
    """
    # The loader import stays inside the function, so importing this module does not
    # pull in the Megatron llama checkpoint utilities.
    from verl.models.llama.megatron.checkpoint_utils.llama_loader import load_state_dict_to_megatron_llama
    loaders_by_arch = {'LlamaForCausalLM': load_state_dict_to_megatron_llama}

    if arch not in loaders_by_arch:
        raise ValueError(f"Model architectures {arch} are not supported for now. "
                         f"Supported architectures: {loaders_by_arch.keys()}")
    return loaders_by_arch[arch]


# --------------------------------------------------------------------------------
# /verl/single_controller/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

# Directory containing this package; the version string is stored in
# ``version/version`` relative to it.
version_folder = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(version_folder, 'version/version')) as f:
    __version__ = f.read().strip()


# --------------------------------------------------------------------------------
# /verl/single_controller/base/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Re-export the base worker / worker-group API at package level.
from .worker import Worker
from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool


# --------------------------------------------------------------------------------
# /verl/single_controller/base/megatron/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# --------------------------------------------------------------------------------
# /verl/single_controller/base/megatron/worker.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo


class MegatronWorker(Worker):
    """Worker that reports its Megatron-core tensor/data/pipeline parallel layout."""

    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_megatron_global_info(self):
        """Return the TP/DP/PP world sizes from ``megatron.core.parallel_state``."""
        from megatron.core import parallel_state as mpu
        return DistGlobalInfo(tp_size=mpu.get_tensor_model_parallel_world_size(),
                              dp_size=mpu.get_data_parallel_world_size(),
                              pp_size=mpu.get_pipeline_model_parallel_world_size())

    def get_megatron_rank_info(self):
        """Return this process's TP/DP/PP ranks from ``megatron.core.parallel_state``."""
        from megatron.core import parallel_state as mpu
        return DistRankInfo(tp_rank=mpu.get_tensor_model_parallel_rank(),
                            dp_rank=mpu.get_data_parallel_rank(),
                            pp_rank=mpu.get_pipeline_model_parallel_rank())
# --------------------------------------------------------------------------------
# /verl/single_controller/base/megatron/worker_group.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional

from .worker import DistRankInfo, DistGlobalInfo
from verl.single_controller.base import ResourcePool, WorkerGroup


class MegatronWorkerGroup(WorkerGroup):
    """WorkerGroup that also tracks the Megatron parallel layout of its workers.

    ``_megatron_rank_info`` (indexed by rank) and ``_megatron_global_info`` start as
    None and must be filled in by subclasses before the accessors below are used.
    """

    def __init__(self, resource_pool: ResourcePool, **kwargs):
        super().__init__(resource_pool=resource_pool, **kwargs)
        self._megatron_rank_info = None
        self._megatron_global_info: Optional[DistGlobalInfo] = None

    def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
        """Initialize Megatron on the workers; subclasses must override this."""
        raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten")

    def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
        """Return the Megatron rank info of worker ``rank`` (0 <= rank < world_size)."""
        assert 0 <= rank < self.world_size, f'rank must be from [0, world_size), Got {rank}'
        # Fail with a clear message instead of "'NoneType' object is not subscriptable".
        assert self._megatron_rank_info is not None, "MegatronWorkerGroup._megatron_rank_info must be initialized"
        return self._megatron_rank_info[rank]

    def _require_global_info(self) -> DistGlobalInfo:
        """Return ``_megatron_global_info``, asserting that it has been initialized."""
        assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
        return self._megatron_global_info

    @property
    def tp_size(self):
        """Tensor-parallel world size."""
        return self._require_global_info().tp_size

    @property
    def dp_size(self):
        """Data-parallel world size."""
        return self._require_global_info().dp_size

    @property
    def pp_size(self):
        """Pipeline-parallel world size."""
        return self._require_global_info().pp_size

    def get_megatron_global_info(self):
        """Return the stored global info (may be None before initialization)."""
        return self._megatron_global_info


# --------------------------------------------------------------------------------
# /verl/single_controller/base/register_center/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# --------------------------------------------------------------------------------
# /verl/single_controller/base/register_center/ray.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ray


@ray.remote
class WorkerGroupRegisterCenter:
    """Ray actor that stores one worker group's rank-0 info and returns it on request."""

    def __init__(self, rank_zero_info):
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        """Return the info object this actor was constructed with."""
        return self.rank_zero_info


def create_worker_group_register_center(name, info):
    """Start a ``WorkerGroupRegisterCenter`` actor named ``name`` holding ``info``; return its handle."""
    named_actor = WorkerGroupRegisterCenter.options(name=name)
    return named_actor.remote(info)


# --------------------------------------------------------------------------------
# /verl/single_controller/ray/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from .megatron import (MegatronRayWorkerGroup, DistRankInfo, DistGlobalInfo)


# --------------------------------------------------------------------------------
# /verl/single_controller/ray/megatron.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional

import ray

from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup


def _fetch_megatron_info(worker_group: MegatronWorkerGroup) -> None:
    """Query the workers' Megatron layout and store it on ``worker_group``.

    ``_megatron_rank_info`` receives the result of ``get_megatron_rank_info`` run on
    every worker (indexed by rank by MegatronWorkerGroup). ``_megatron_global_info``
    receives rank 0's ``get_megatron_global_info``. The dispatcher reads both to
    route data.
    """
    worker_group._megatron_rank_info = worker_group.execute_all_sync(method_name='get_megatron_rank_info')
    worker_group._megatron_global_info = ray.get(
        worker_group.execute_rank_zero_async(method_name='get_megatron_global_info'))


# NOTE(sgm): for opensource megatron-core
class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """
    MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
    so that the dispatcher can use it to dispatch data.
    """

    def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):
        super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)
        _fetch_megatron_info(self)


class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
    """
    MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup
    so that the dispatcher can use it to dispatch data.
    """

    def __init__(self,
                 resource_pool: RayResourcePool,
                 ray_cls_with_init: RayClassWithInitArgs,
                 default_megatron_kwargs: Optional[Dict] = None,
                 **kwargs):
        super().__init__(resource_pool=resource_pool,
                         ray_cls_with_init=ray_cls_with_init,
                         default_megatron_kwargs=default_megatron_kwargs,
                         **kwargs)
        # After super().__init__ has created the workers, initialize Megatron on each.
        self.init_megatron(default_megatron_kwargs=default_megatron_kwargs)
        _fetch_megatron_info(self)

    def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
        """Run ``init_megatron`` on every worker unless ``_is_init_with_detached_workers`` is set."""
        # only init_megatron if the WorkerGroup is created from scratch
        if not self._is_init_with_detached_workers:
            self.execute_all_sync(method_name='init_megatron', default_megatron_kwargs=default_megatron_kwargs)


# --------------------------------------------------------------------------------
# /verl/single_controller/version/version:
# --------------------------------------------------------------------------------
# 0.0.2
-------------------------------------------------------------------------------- /verl/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/third_party/vllm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from importlib.metadata import version, PackageNotFoundError


def get_version(pkg):
    """Return the installed version string of distribution ``pkg``, or None if absent."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


package_name = 'vllm'
package_version = get_version(package_name)

# Each supported vLLM release has a matching ``vllm_v_<x>_<y>_<z>`` adapter subpackage.
if package_version == '0.3.1':
    vllm_version = '0.3.1'
    from .vllm_v_0_3_1.llm import LLM
    from .vllm_v_0_3_1.llm import LLMEngine
    from .vllm_v_0_3_1 import parallel_state
elif package_version == '0.4.2':
    vllm_version = '0.4.2'
    from .vllm_v_0_4_2.llm import LLM
    from .vllm_v_0_4_2.llm import LLMEngine
    from .vllm_v_0_4_2 import parallel_state
elif package_version == '0.5.4':
    vllm_version = '0.5.4'
    from .vllm_v_0_5_4.llm import LLM
    from .vllm_v_0_5_4.llm import LLMEngine
    from .vllm_v_0_5_4 import parallel_state
elif package_version == '0.6.3':
    vllm_version = '0.6.3'
    from .vllm_v_0_6_3.llm import LLM
    from .vllm_v_0_6_3.llm import LLMEngine
    from .vllm_v_0_6_3 import parallel_state
elif package_version is None:
    # Previously reported as "vllm version None not supported", which hid the real cause.
    raise ValueError(
        'vllm is not installed. Please install one of the supported versions: 0.3.1, 0.4.2, 0.5.4 or 0.6.3.')
else:
    raise ValueError(
        f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.'
    )


# --------------------------------------------------------------------------------
# /verl/third_party/vllm/vllm_v_0_3_1/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_3_1/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py

from typing import List, Optional, Tuple, Union

from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)

from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import *


class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters."""

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        """
        Args:
            tokenizer: base tokenizer, also reused for every LoRA request.
            enable_lora: whether LoRA-specific tokenizer lookup is enabled.
            max_num_seqs: capacity of the per-LoRA tokenizer LRU cache.
            max_input_length: stored on the instance; not read by this class.
        """
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        if enable_lora:
            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
        else:
            self.lora_tokenizers = None

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Tokenize ``prompt`` with the tokenizer chosen for ``lora_request``; ``request_id`` is unused."""
        tokenizer = self.get_lora_tokenizer(lora_request)
        return tokenizer.encode(prompt)

    async def encode_async(self,
                           prompt: str,
                           request_id: Optional[str] = None,
                           lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Async variant of ``encode``."""
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return tokenizer.encode(prompt)

    def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``.

        Returns the base tokenizer when there is no request or LoRA is disabled.
        Otherwise the tokenizer is cached in the LRU under ``lora_int_id``.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            # TODO(sgm): the lora tokenizer is also passed, but may be different
            tokenizer = self.tokenizer
            # tokenizer = (get_lora_tokenizer(
            #     lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers.get(lora_request.lora_int_id)

    async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Async counterpart of ``get_lora_tokenizer``, awaited by ``encode_async``.

        ``encode_async`` called this method although the class never defined it, so
        every call raised AttributeError. ``get_lora_tokenizer`` only reads the base
        tokenizer or the LRU cache, so this delegates to it directly.
        """
        return self.get_lora_tokenizer(lora_request)

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        """Padding token id of the base tokenizer."""
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        """End-of-sequence token id of the base tokenizer."""
        return self.tokenizer.eos_token_id


# --------------------------------------------------------------------------------
# /verl/third_party/vllm/vllm_v_0_3_1/weight_loaders.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models

from typing import Dict
import torch
import torch.nn as nn

# NOTE(shengguangming): the megatron llama may have this prefix
_MEGATRON_PARAM_PREFIX = '0.module.module.'

# GPT-2 submodules that HF implements with Conv1D instead of Linear; their weights
# must be transposed before loading.
_GPT2_CONV1D_NAMES = ("c_attn", "c_proj", "c_fc")


# NOTE(shengguangming): replace the origin weight loader function in the class
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Parallel Linear weight loader.

    ``self`` is unused; the signature matches a ``weight_loader`` installed on a class.
    After the call, ``param`` shares ``loaded_weight``'s storage (no copy).

    Raises:
        AssertionError: if the sizes or dtypes differ.
    """
    assert param.size() == loaded_weight.size(
    ), 'the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format(
        param.size(), loaded_weight.size())
    assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same"

    param.data = loaded_weight.data


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader: make ``param`` share ``loaded_weight``'s storage.

    Raises:
        AssertionError: if the sizes or dtypes differ.
    """
    assert param.size() == loaded_weight.size()
    assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same"

    param.data = loaded_weight.data


def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
    """Load GPT-2 actor weights into ``vllm_model`` by parameter name.

    Skips ``lm_head.weight`` and attention-mask buffers, adds the ``transformer.``
    prefix when missing, and transposes Conv1D ``.weight`` tensors.

    Raises:
        KeyError: if a (renamed) weight has no matching parameter in ``vllm_model``.
    """
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "lm_head.weight" in name:
            # GPT-2 ties the weights of the embedding layer and the final
            # linear layer.
            continue
        if ".attn.bias" in name or ".attn.masked_bias" in name:
            # Skip attention mask.
            # NOTE: "c_attn.bias" should not be skipped.
            continue
        if not name.startswith("transformer."):
            name = "transformer." + name
        param = params_dict[name]
        # The HF's GPT-2 implementation uses Conv1D instead of Linear, so these weights
        # are transposed. Transpose at most once, even if a name contained several
        # Conv1D module names (the old per-name loop would transpose once per match).
        # Note(zhuohan): the logic below might break quantized models.
        # TODO: check megatron
        if name.endswith(".weight") and any(conv_name in name for conv_name in _GPT2_CONV1D_NAMES):
            loaded_weight = loaded_weight.t()
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)


def _load_weights_stripping_prefix(actor_weights: Dict, vllm_model: nn.Module, prefix: str) -> None:
    """Load ``actor_weights`` into ``vllm_model`` after removing a leading ``prefix``.

    Entries whose name contains ``rotary_emb.inv_freq`` are skipped.
    """
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        if "rotary_emb.inv_freq" in name:
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)


def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module, prefix: str = _MEGATRON_PARAM_PREFIX) -> None:
    """Load Llama actor weights into ``vllm_model``.

    Args:
        actor_weights: mapping of parameter name to tensor.
        vllm_model: target vLLM model.
        prefix: leading name prefix to strip (defaults to ``'0.module.module.'``).

    Raises:
        KeyError: if a weight has no matching parameter in ``vllm_model``.
    """
    _load_weights_stripping_prefix(actor_weights, vllm_model, prefix)


def mistral_weight_loader(actor_weights: Dict, vllm_model: nn.Module, prefix: str = _MEGATRON_PARAM_PREFIX) -> None:
    """Load Mistral actor weights into ``vllm_model`` (same rules as ``llama_weight_loader``).

    Raises:
        KeyError: if a weight has no matching parameter in ``vllm_model``.
    """
    # TODO: need to implement a general way to deal with prefix
    _load_weights_stripping_prefix(actor_weights, vllm_model, prefix)


# --------------------------------------------------------------------------------
# /verl/third_party/vllm/vllm_v_0_4_2/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_4_2/hf_weight_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models

from typing import Dict, Union, Optional, Iterable, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def update_hf_weight_loader():
    """Install ``gemma_load_weights`` as ``GemmaForCausalLM.load_weights`` (monkey patch)."""
    from vllm.model_executor.models.gemma import GemmaForCausalLM
    GemmaForCausalLM.load_weights = gemma_load_weights


def gemma_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load ``(name, tensor)`` pairs into a vLLM Gemma model.

    Separate q/k/v and gate/up projection tensors are routed into the fused
    ``qkv_proj`` / ``gate_up_proj`` parameters via ``stacked_params_mapping``;
    everything else is loaded into the parameter of the same name.
    ``lm_head.weight`` is skipped, and RMSNorm weights are shifted by +1
    out of place (``loaded_weight + 1.0``) so the caller's actor tensors are
    not modified.

    Args:
        self: the ``GemmaForCausalLM`` instance this function is patched onto.
        weights: iterable of ``(parameter name, tensor)`` pairs.

    Raises:
        RuntimeError: if some model parameter received no weight.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params = set()
    for name, loaded_weight in weights:
        for (param_name, shard_name, shard_id) in stacked_params_mapping:
            if shard_name not in name:
                continue
            # e.g. "...q_proj.weight" -> "...qkv_proj.weight"; shard_id picks the slice.
            name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Not a stacked shard: load directly into the parameter of the same name.
            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if "lm_head.weight" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # GemmaRMSNorm is different from Llama's in that it multiplies
            # (1 + weight) to the output, instead of just weight.
            if "norm.weight" in name:
                norm_weight = loaded_weight + 1.0  # prevent inplace modify actor weights
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, norm_weight)
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
        # NOTE(review): indentation reconstructed from a collapsed dump. Recorded for
        # stacked and plain names alike; if this sat inside the else branch, fused
        # qkv/gate_up names would never be recorded and the check below would always
        # raise. Confirm against the original file.
        loaded_params.add(name)
    unloaded_params = params_dict.keys() - loaded_params
    if unloaded_params:
        raise RuntimeError("Some weights are not initialized from checkpoints: "
                           f"{unloaded_params}")


def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load HF-named ``actor_weights`` into ``vllm_model`` and move it to CUDA.

    Weights go through the model's own ``load_weights``; afterwards every
    submodule's post-load hooks (``quant_method.process_weights_after_loading``
    and ``process_weights_after_loading``) are run.
    """
    assert isinstance(actor_weights, Dict)
    # Tensors created while loading default to the dtype of the model's existing parameters.
    with set_default_torch_dtype(next(vllm_model.parameters()).dtype):  # TODO
        vllm_model.load_weights(actor_weights.items())
    for _, module in vllm_model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
        # FIXME: Remove this after Mixtral is updated
        # to use quant_method.
        if hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()
    # nn.Module.cuda() moves parameters in place; rebinding the local name does not affect the caller.
    vllm_model = vllm_model.cuda()
92 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_4_2/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) 19 | 20 | from vllm.lora.request import LoRARequest 21 | from vllm.utils import make_async, LRUCache 22 | from vllm.transformers_utils.tokenizers import * 23 | 24 | 25 | class TokenizerGroup: 26 | """A group of tokenizers that can be used for LoRA adapters.""" 27 | 28 | def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, 29 | max_input_length: Optional[int]): 30 | self.enable_lora = enable_lora 31 | self.max_input_length = max_input_length 32 | self.tokenizer = tokenizer 33 | self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None 34 | 35 | def ping(self) -> bool: 36 | """Check if the tokenizer group is alive.""" 37 | return True 38 | 39 | def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]: 40 | """Get the maximum input length for the LoRA request.""" 41 | return self.max_input_length 42 | 43 | def encode(self, 44 | prompt: str, 45 | request_id: Optional[str] = None, 46 | lora_request: Optional[LoRARequest] = None) -> List[int]: 47 | tokenizer = self.get_lora_tokenizer(lora_request) 48 | return tokenizer.encode(prompt) 49 | 50 | async def encode_async(self, 51 | prompt: str, 52 | request_id: Optional[str] = None, 53 | lora_request: 
Optional[LoRARequest] = None) -> List[int]: 54 | tokenizer = await self.get_lora_tokenizer_async(lora_request) 55 | return tokenizer.encode(prompt) 56 | 57 | def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": 58 | if not lora_request or not self.enable_lora: 59 | return self.tokenizer 60 | if lora_request.lora_int_id not in self.lora_tokenizers: 61 | # TODO(sgm): the lora tokenizer is also passed, but may be different 62 | tokenizer = self.tokenizer 63 | # tokenizer = (get_lora_tokenizer( 64 | # lora_request, **self.tokenizer_config) or self.tokenizer) 65 | self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) 66 | return tokenizer 67 | else: 68 | return self.lora_tokenizers.get(lora_request.lora_int_id) 69 | 70 | # FIXME(sgm): for simplicity, we assign the special token here 71 | @property 72 | def pad_token_id(self): 73 | return self.tokenizer.pad_token_id 74 | 75 | @property 76 | def eos_token_id(self): 77 | return self.tokenizer.eos_token_id 78 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_5_4/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models 15 | 16 | from typing import Dict, Union, Optional, Iterable, Tuple 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype 22 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader 23 | 24 | 25 | def update_hf_weight_loader(): 26 | print('no hf weight loader need to be updated') 27 | return 28 | 29 | 30 | def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): 31 | assert isinstance(actor_weights, Dict) 32 | with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO 33 | if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys(): 34 | del actor_weights["lm_head.weight"] 35 | vllm_model.load_weights(actor_weights.items()) 36 | for _, module in vllm_model.named_modules(): 37 | quant_method = getattr(module, "quant_method", None) 38 | if quant_method is not None: 39 | quant_method.process_weights_after_loading(module) 40 | # FIXME: 
Remove this after Mixtral is updated 41 | # to use quant_method. 42 | if hasattr(module, "process_weights_after_loading"): 43 | module.process_weights_after_loading() 44 | vllm_model = vllm_model.cuda() 45 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) 19 | 20 | from vllm.lora.request import LoRARequest 21 | from vllm.utils import make_async, LRUCache 22 | from vllm.transformers_utils.tokenizers import * 23 | 24 | 25 | class TokenizerGroup: 26 | """A group of tokenizers that can be used for LoRA adapters.""" 27 | 28 | def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, 29 | max_input_length: Optional[int]): 30 | self.enable_lora = enable_lora 31 | self.max_input_length = max_input_length 32 | self.tokenizer = tokenizer 33 | self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None 34 | 35 | def ping(self) -> bool: 36 | """Check if the tokenizer group is alive.""" 37 | return True 38 | 39 | def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]: 40 | """Get the maximum input length for the LoRA request.""" 41 | return self.max_input_length 42 | 43 | def encode(self, 44 | prompt: str, 45 | request_id: Optional[str] = None, 46 | lora_request: Optional[LoRARequest] = None) -> List[int]: 47 | tokenizer = self.get_lora_tokenizer(lora_request) 48 | return tokenizer.encode(prompt) 49 | 50 | async def encode_async(self, 51 | prompt: str, 52 | request_id: Optional[str] = None, 53 | lora_request: Optional[LoRARequest] = None) -> List[int]: 54 | tokenizer = await self.get_lora_tokenizer_async(lora_request) 55 | return tokenizer.encode(prompt) 56 | 57 | def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": 58 | if not lora_request or not self.enable_lora: 59 | return self.tokenizer 60 | if lora_request.lora_int_id not in self.lora_tokenizers: 61 | # TODO(sgm): the lora tokenizer is also passed, but 
may be different 62 | tokenizer = self.tokenizer 63 | # tokenizer = (get_lora_tokenizer( 64 | # lora_request, **self.tokenizer_config) or self.tokenizer) 65 | self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) 66 | return tokenizer 67 | else: 68 | return self.lora_tokenizers.get(lora_request.lora_int_id) 69 | 70 | # FIXME(sgm): for simplicity, we assign the special token here 71 | @property 72 | def pad_token_id(self): 73 | return self.tokenizer.pad_token_id 74 | 75 | @property 76 | def eos_token_id(self): 77 | return self.tokenizer.eos_token_id 78 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_6_3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py

import os
from dataclasses import dataclass

from transformers import PretrainedConfig
from vllm.config import EngineConfig
from vllm.engine.arg_utils import EngineArgs

from .config import LoadConfig, ModelConfig


@dataclass
class EngineArgs(EngineArgs):
    """vLLM ``EngineArgs`` extended for verl (same class name, subclassing it).

    Adds ``model_hf_config`` so the model config is built from an in-memory
    HF config, builds its configs with verl's local ``ModelConfig`` /
    ``LoadConfig``, and takes the parallel world size from ``WORLD_SIZE``.
    """
    # HF config of the model; forwarded to ModelConfig as ``hf_config``.
    model_hf_config: PretrainedConfig = None  # for verl

    def __post_init__(self):
        # Intentionally a no-op that replaces (and therefore skips) the parent's
        # __post_init__. NOTE(review): presumably to avoid upstream defaulting /
        # validation that expects ``model`` to be a path — confirm.
        pass

    def create_model_config(self) -> ModelConfig:
        """Build verl's ``ModelConfig`` by forwarding these args field by field."""
        return ModelConfig(
            hf_config=self.model_hf_config,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            # The arg is phrased negatively ("disable"), the config field positively.
            use_async_output_proc=not self.disable_async_output_proc,
            override_neuron_config=self.override_neuron_config,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
        )

    def create_load_config(self) -> LoadConfig:
        """Build verl's ``LoadConfig`` from the load-related args."""
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )

    def create_engine_config(self) -> EngineConfig:
        """Build the parent's engine config, then override the parallel world size.

        Raises:
            AssertionError: if ``WORLD_SIZE`` is unset (or "-1"); being an
                ``assert``, the check is skipped under ``python -O``.
        """
        engine_config = super().create_engine_config()

        # NOTE[VERL]: Use the world_size set by torchrun
        world_size = int(os.getenv("WORLD_SIZE", "-1"))
        assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
        engine_config.parallel_config.world_size = world_size

        return engine_config
79 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_6_3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py

import enum
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Union

from transformers import PretrainedConfig

# Add for verl
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip

if TYPE_CHECKING:
    from vllm.model_executor.model_loader.loader import BaseModelLoader

logger = init_logger(__name__)


class LoadFormat(str, enum.Enum):
    """Weight-loading formats accepted by ``LoadConfig.load_format``."""
    AUTO = "auto"
    MEGATRON = "megatron"
    HF = "hf"
    DTENSOR = "dtensor"
    DUMMY_HF = "dummy_hf"
    DUMMY_MEGATRON = "dummy_megatron"
    DUMMY_DTENSOR = "dummy_dtensor"


class ModelConfig(ModelConfig):
    """vLLM ``ModelConfig`` built from an in-memory HF config.

    ``hf_config._name_or_path`` is used as both the model and tokenizer name,
    and ``hf_config`` then replaces whatever the parent constructor stored.
    """

    def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None:
        # NOTE(review): any positional *args would bind before the explicit
        # model=/tokenizer= keywords and collide with them (TypeError); the caller
        # seen here (EngineArgs.create_model_config) passes keywords only.
        super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs)
        self.hf_config = hf_config


@dataclass
class LoadConfig:
    """
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        load_format: The format of the model weights to load. Either a
            ``LoadFormat`` member, its string value (case-insensitive), or a
            model loader instance (anything that is not a ``str`` is left
            untouched). Strings are converted to ``LoadFormat``, so only
            "auto", "megatron", "hf", "dtensor", "dummy_hf",
            "dummy_megatron" and "dummy_dtensor" are accepted; any other
            string (e.g. upstream vLLM's "pt" or "safetensors") raises
            ``ValueError``. How each format is actually loaded is decided by
            the model loader, not by this class.
        model_loader_extra_config: Extra loader options, as a dict or a JSON
            string (a string is parsed with ``json.loads``).
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.

    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        """Parse a JSON-string extra config, validate ``load_format``, default ``ignore_patterns``."""
        model_loader_extra_config = self.model_loader_extra_config or {}
        # Only a JSON string is written back; an explicit None stays None on the instance.
        if isinstance(model_loader_extra_config, str):
            self.model_loader_extra_config = json.loads(model_loader_extra_config)
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Convert a string ``load_format`` to ``LoadFormat``; reject ROCm-unsupported formats.

        Raises:
            ValueError: if the string is not a ``LoadFormat`` value, or the
                format is unsupported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Currently empty, so the ROCm rejection below never triggers.
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
                             f"Supported load formats are "
                             f"{rocm_supported_load_format}")
106 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team.
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader 15 | 16 | from typing import Dict 17 | 18 | import torch.nn as nn 19 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype 20 | 21 | 22 | def update_hf_weight_loader(): 23 | print("no hf weight loader need to be updated") 24 | return 25 | 26 | 27 | def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): 28 | assert isinstance(actor_weights, Dict) 29 | with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO 30 | if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys(): 31 | del actor_weights["lm_head.weight"] 32 | vllm_model.load_weights(actor_weights.items()) 33 | for _, module in vllm_model.named_modules(): 34 | quant_method = getattr(module, "quant_method", None) 35 | if quant_method is not None: 36 | quant_method.process_weights_after_loading(module) 37 | # FIXME: Remove this after Mixtral is updated 38 | # to use quant_method. 
39 | if hasattr(module, "process_weights_after_loading"): 40 | module.process_weights_after_loading() 41 | vllm_model = vllm_model.cuda() 42 | -------------------------------------------------------------------------------- /verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2023 The vLLM team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py

from typing import Optional

from transformers import PreTrainedTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import LRUCache


class TokenizerGroup(TokenizerGroup):
    """A group of tokenizers that can be used for LoRA adapters.

    verl variant of vLLM's ``TokenizerGroup`` (same name, subclassing it) that
    wraps an already-constructed tokenizer instead of loading one.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int]):
        # NOTE(review): super().__init__ is not called; only the four attributes
        # below are set. Any inherited method that reads another attribute set by
        # the parent constructor would raise AttributeError — confirm which ones
        # the parent relies on (e.g. on the LoRA path).
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = tokenizer
        # Per-LoRA tokenizer cache (capacity max_num_seqs); None when LoRA is disabled.
        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None

    # FIXME(sgm): for simplicity, we assign the special token here
    @property
    def pad_token_id(self):
        """Pad token id of the wrapped tokenizer."""
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        """EOS token id of the wrapped tokenizer."""
        return self.tokenizer.eos_token_id
41 | -------------------------------------------------------------------------------- /verl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | -------------------------------------------------------------------------------- /verl/trainer/config/evaluation.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | path: /tmp/math_Qwen2-7B-Instruct.parquet 3 | prompt_key: prompt 4 | response_key: responses 5 | data_source_key: data_source 6 | reward_model_key: reward_model -------------------------------------------------------------------------------- /verl/trainer/config/generation.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | nnodes: 1 3 | n_gpus_per_node: 8 4 | 5 | data: 6 | path: ~/data/rlhf/math/test.parquet 7 | prompt_key: prompt 8 | n_samples: 5 9 | output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet 10 | batch_size: 128 11 | 12 | model: 13 | path: ~/models/Qwen2-7B-Instruct 14 | external_lib: null 15 | rollout: 16 | name: vllm 17 | temperature: 1.0 18 | top_k: 50 # 0 for hf rollout, -1 for vllm rollout 19 | top_p: 0.7 20 | prompt_length: 1536 21 | response_length: 512 22 | # for vllm rollout 23 | dtype: bfloat16 # should align with FSDP 24 | gpu_memory_utilization: 0.5 25 | ignore_eos: False 26 | micro_batch_size: 256 27 | enforce_eager: True 28 | free_cache_engine: True 29 | load_format: dummy_dtensor 30 | tensor_model_parallel_size: 1 31 | max_num_batched_tokens: 8192 32 | max_num_seqs: 1024 33 | log_prob_micro_batch_size: 8 34 | # for hf rollout 35 | do_sample: True -------------------------------------------------------------------------------- /verl/trainer/config/model/lora_enabled.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /verl/trainer/config/ppo_megatron_trainer.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | tokenizer: null 3 | train_files: ~/data/rlhf/gsm8k/train.parquet 4 | val_files: 
~/data/rlhf/gsm8k/test.parquet 5 | prompt_key: prompt 6 | max_prompt_length: 512 7 | max_response_length: 512 8 | train_batch_size: 1024 9 | val_batch_size: 1312 10 | return_raw_input_ids: False # Set this to True when the policy and reward-model tokenizers differ 11 | return_raw_chat: False 12 | 13 | actor_rollout_ref: 14 | hybrid_engine: True 15 | model: 16 | path: ~/models/deepseek-llm-7b-chat 17 | external_lib: null 18 | override_config: {} 19 | enable_gradient_checkpointing: False 20 | actor: 21 | strategy: megatron # This is for backward-compatibility 22 | ppo_mini_batch_size: 256 23 | ppo_micro_batch_size: 64 24 | clip_ratio: 0.2 25 | entropy_coeff: 0.001 26 | ppo_epochs: 1 27 | shuffle: True 28 | optim: 29 | lr: 1e-6 30 | clip_grad: 1.0 31 | lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime 32 | min_lr_ratio: null # only useful for warmup with cosine 33 | warmup_style: constant # select from constant/cosine 34 | total_training_steps: -1 # must be overridden by the program 35 | megatron: 36 | tensor_model_parallel_size: 4 37 | pipeline_model_parallel_size: 1 38 | num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. 39 | sequence_parallel: True 40 | seed: 1 41 | load_weight: True 42 | ref: 43 | megatron: 44 | tensor_model_parallel_size: 4 45 | pipeline_model_parallel_size: 1 46 | num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
47 | sequence_parallel: True 48 | seed: 1 49 | load_weight: True 50 | param_offload: False 51 | log_prob_micro_batch_size: 32 52 | rollout: 53 | name: vllm 54 | temperature: 1.0 55 | top_k: -1 # 0 for hf rollout, -1 for vllm rollout 56 | top_p: 1 57 | prompt_length: ${data.max_prompt_length} # for xperf_gpt 58 | response_length: ${data.max_response_length} 59 | # for vllm rollout 60 | dtype: bfloat16 # should align with FSDP 61 | gpu_memory_utilization: 0.5 62 | ignore_eos: False 63 | enforce_eager: True 64 | free_cache_engine: True 65 | load_format: dummy_megatron 66 | tensor_model_parallel_size: 2 67 | max_num_batched_tokens: 8192 68 | max_num_seqs: 1024 69 | log_prob_micro_batch_size: 2 70 | # for hf rollout 71 | do_sample: True 72 | layer_name_map: 73 | qkv_layer_name: qkv 74 | gate_proj_layer_name: gate_up 75 | # number of responses sampled per prompt 76 | n: 1 77 | 78 | critic: 79 | strategy: megatron 80 | optim: 81 | lr: 1e-5 82 | clip_grad: 1.0 83 | lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime 84 | min_lr_ratio: null # only useful for warmup with cosine 85 | warmup_style: constant # select from constant/cosine 86 | total_training_steps: -1 # must be overridden by the program 87 | model: 88 | path: ~/models/deepseek-llm-7b-chat 89 | tokenizer_path: ${actor_rollout_ref.model.path} 90 | override_config: {} 91 | external_lib: ${actor_rollout_ref.model.external_lib} 92 | enable_gradient_checkpointing: False 93 | megatron: 94 | tensor_model_parallel_size: 4 95 | pipeline_model_parallel_size: 1 96 | num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug.
97 | sequence_parallel: True 98 | seed: 1 99 | load_weight: True 100 | ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} 101 | ppo_micro_batch_size: 2 102 | ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} 103 | shuffle: ${actor_rollout_ref.actor.shuffle} 104 | cliprange_value: 0.5 105 | kl_ctrl: 106 | type: fixed 107 | kl_coef: 0.001 108 | 109 | reward_model: 110 | enable: False 111 | strategy: megatron 112 | megatron: 113 | tensor_model_parallel_size: 4 114 | pipeline_model_parallel_size: 1 115 | num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. 116 | sequence_parallel: True 117 | seed: 1 118 | model: 119 | input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical 120 | path: ~/models/FsfairX-LLaMA3-RM-v0.1 121 | external_lib: ${actor_rollout_ref.model.external_lib} 122 | load_weight: True 123 | param_offload: False 124 | micro_batch_size: 64 125 | max_length: null 126 | 127 | algorithm: 128 | gamma: 1.0 129 | lam: 1.0 130 | adv_estimator: gae 131 | kl_penalty: kl # how to estimate kl divergence 132 | kl_ctrl: 133 | type: fixed 134 | kl_coef: 0.001 135 | 136 | trainer: 137 | total_epochs: 30 138 | total_training_steps: null 139 | project_name: verl_examples 140 | experiment_name: gsm8k 141 | logger: ['console', 'wandb'] 142 | nnodes: 1 143 | n_gpus_per_node: 8 144 | save_freq: -1 145 | test_freq: 2 146 | critic_warmup: 0 147 | default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} 148 | default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} 149 | -------------------------------------------------------------------------------- /verl/trainer/config/sft_trainer.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_batch_size: 256 3 | micro_batch_size: 16 # this is also val batch size 4 | train_files: ~/data/gsm8k/train.parquet 5 | val_files: ~/data/gsm8k/test.parquet 6 | prompt_key: 
question 7 | response_key: answer 8 | max_length: 1024 9 | truncation: error 10 | balance_dp_token: False 11 | chat_template: null 12 | model: 13 | partial_pretrain: ~/models/gemma-1.1-7b-it 14 | fsdp_config: 15 | wrap_policy: 16 | min_num_params: 0 17 | cpu_offload: False 18 | offload_params: False 19 | external_lib: null 20 | enable_gradient_checkpointing: False 21 | trust_remote_code: False 22 | optim: 23 | lr: 1e-5 24 | betas: [0.9, 0.95] 25 | weight_decay: 0.01 26 | warmup_steps_ratio: 0.1 27 | clip_grad: 1.0 28 | 29 | trainer: 30 | default_local_dir: /tmp/sft_model 31 | default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here 32 | resume_path: null 33 | project_name: gsm8k-sft 34 | experiment_name: test 35 | total_epochs: 4 36 | logger: ['console'] 37 | seed: 1 38 | 39 | -------------------------------------------------------------------------------- /verl/trainer/main_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Offline evaluate the performance of a generated file using reward model and ground truth verifier. 16 | The input is a parquet file that contains N generated sequences and (optional) the ground truth. 
17 | 18 | """ 19 | 20 | import hydra 21 | from verl.utils.fs import copy_local_path_from_hdfs 22 | from verl.utils.reward_score import math, gsm8k, kk 23 | import pandas as pd 24 | import numpy as np 25 | 26 | 27 | def select_reward_fn(data_source): 28 | if data_source == 'lighteval/MATH': 29 | return math.compute_score 30 | if 'kk' in data_source: 31 | return kk.compute_score 32 | else: 33 | raise NotImplementedError 34 | 35 | 36 | @hydra.main(config_path='config', config_name='evaluation', version_base=None) 37 | def main(config): 38 | local_path = copy_local_path_from_hdfs(config.data.path) 39 | dataset = pd.read_parquet(local_path) 40 | prompts = dataset[config.data.prompt_key] 41 | responses = dataset[config.data.response_key] 42 | data_sources = dataset[config.data.data_source_key] 43 | reward_model_data = dataset[config.data.reward_model_key] 44 | 45 | passes = 0 46 | 47 | total = len(dataset) 48 | 49 | for i in range(total): 50 | response_lst = responses[i] 51 | data_source = data_sources[i] 52 | # select reward score based on data_source 53 | prompt = prompts[i] 54 | reward_data = reward_model_data[i] 55 | reward_fn = select_reward_fn(data_source) 56 | ground_truth = reward_data['ground_truth'] 57 | score_lst = [] 58 | for r in response_lst: 59 | score = reward_fn(r, ground_truth) 60 | score_lst.append(score) 61 | 62 | max_score = np.max(score_lst) 63 | 64 | if max_score == 3: 65 | passes += 1 66 | 67 | print(f'pass@5: {passes / total}') 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /verl/trainer/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/trainer/runtime_env.yaml: -------------------------------------------------------------------------------- 1 | working_dir: ./ 2 | excludes: ["/.git/"] 3 | env_vars: 4 | TORCH_NCCL_AVOID_RECORD_STREAMS: "1" 5 | VLLM_ATTENTION_BACKEND: "XFORMERS" -------------------------------------------------------------------------------- /verl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import tokenizer 16 | from .tokenizer import * 17 | 18 | __all__ = tokenizer.__all__ -------------------------------------------------------------------------------- /verl/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict 16 | 17 | from omegaconf import DictConfig 18 | 19 | 20 | def update_dict_with_config(dictionary: Dict, config: DictConfig): 21 | for key in dictionary: 22 | if hasattr(config, key): 23 | dictionary[key] = getattr(config, key) 24 | -------------------------------------------------------------------------------- /verl/utils/dataset/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Format 2 | ## RLHF dataset 3 | We combine all the data sources into a single parquet files. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction following texts to guide the model output the answers in a particular format so that we can extract the answers. 4 | 5 | Math problems 6 | ```json 7 | { 8 | "data_source": "openai/gsm8k", 9 | "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 
Let's think step by step and output the final answer after \"####\""}], 10 | "ability": "math", 11 | "reward_model": { 12 | "style": "rule", 13 | "ground_truth": ["72"] 14 | } 15 | } 16 | ``` 17 | -------------------------------------------------------------------------------- /verl/utils/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .rl_dataset import RLHFDataset 16 | from .rm_dataset import RMDataset 17 | from .sft_dataset import SFTDataset 18 | -------------------------------------------------------------------------------- /verl/utils/debug/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .performance import log_gpu_memory_usage -------------------------------------------------------------------------------- /verl/utils/debug/performance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.distributed as dist 17 | import logging 18 | 19 | 20 | def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): 21 | if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank): 22 | memory_allocated = torch.cuda.memory_allocated() / 1024**3 23 | memory_reserved = torch.cuda.memory_reserved() / 1024**3 24 | 25 | message = f'{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}' 26 | 27 | if logger is None: 28 | print(message) 29 | else: 30 | logger.log(msg=message, level=level) 31 | -------------------------------------------------------------------------------- /verl/utils/debug/trajectory_tracker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Trajectory tracker can be inserted into code to save the intermediate results. 16 | The results will be dump to hdfs for offline comparison. 17 | Each process will have a client that first move all the tensors to CPU 18 | """ 19 | 20 | from verl.utils.hdfs_io import makedirs, copy 21 | import torch 22 | import os 23 | import ray 24 | import io 25 | import tempfile 26 | 27 | from collections import deque 28 | 29 | remote_copy = ray.remote(copy) 30 | 31 | 32 | @ray.remote 33 | def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose): 34 | filename = name + '.pth' 35 | with tempfile.TemporaryDirectory() as tmpdirname: 36 | local_filepath = os.path.join(tmpdirname, filename) 37 | with open(local_filepath, 'wb') as f: 38 | f.write(data.getbuffer()) 39 | # upload to hdfs 40 | 41 | if verbose: 42 | print(f'Saving {local_filepath} to {hdfs_dir}') 43 | try: 44 | copy(local_filepath, hdfs_dir) 45 | except Exception as e: 46 | print(e) 47 | 48 | 49 | @ray.remote 50 | class TrajectoryTracker(): 51 | 52 | def __init__(self, hdfs_dir, verbose) -> None: 53 | self.hdfs_dir = hdfs_dir 54 | makedirs(hdfs_dir) 55 | self.verbose = verbose 56 | 57 | self.handle = deque() 58 | 59 | def dump(self, data: io.BytesIO, name): 60 | # get a temp file and write to it 61 | self.handle.append(save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose)) 62 | 63 | def 
wait_for_hdfs(self): 64 | while len(self.handle) != 0: 65 | future = self.handle.popleft() 66 | ray.get(future) 67 | 68 | 69 | def dump_data(data, name): 70 | enable = os.getenv('VERL_ENABLE_TRACKER', '0') == '1' 71 | if not enable: 72 | return 73 | buffer = io.BytesIO() 74 | torch.save(data, buffer) 75 | tracker = get_trajectory_tracker() 76 | ray.get(tracker.dump.remote(buffer, name)) 77 | 78 | 79 | def get_trajectory_tracker(): 80 | hdfs_dir = os.getenv('VERL_TRACKER_HDFS_DIR', default=None) 81 | verbose = os.getenv('VERL_TRACKER_VERBOSE', default='0') == '1' 82 | assert hdfs_dir is not None 83 | tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True, 84 | lifetime="detached").remote(hdfs_dir, verbose) 85 | return tracker 86 | 87 | 88 | if __name__ == '__main__': 89 | # testing 90 | os.environ['VERL_ENABLE_TRACKER'] = '1' 91 | os.environ['VERL_TRACKER_HDFS_DIR'] = '~/debug/test' 92 | 93 | @ray.remote 94 | def process(iter): 95 | data = {'obs': torch.randn(10, 20)} 96 | dump_data(data, f'process_{iter}_obs') 97 | 98 | ray.init() 99 | 100 | output_lst = [] 101 | 102 | for i in range(10): 103 | output_lst.append(process.remote(i)) 104 | 105 | out = ray.get(output_lst) 106 | 107 | tracker = get_trajectory_tracker() 108 | ray.get(tracker.wait_for_hdfs.remote()) 109 | -------------------------------------------------------------------------------- /verl/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for distributed training.""" 15 | import os 16 | 17 | 18 | def initialize_global_process_group(timeout_second=36000): 19 | import torch.distributed 20 | from datetime import timedelta 21 | torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second)) 22 | local_rank = int(os.environ["LOCAL_RANK"]) 23 | rank = int(os.environ["RANK"]) 24 | world_size = int(os.environ["WORLD_SIZE"]) 25 | 26 | if torch.distributed.is_initialized(): 27 | torch.cuda.set_device(local_rank) 28 | return local_rank, rank, world_size 29 | -------------------------------------------------------------------------------- /verl/utils/flops_counter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | from transformers import PretrainedConfig, Qwen2Config, LlamaConfig 17 | 18 | VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig) 19 | 20 | 21 | def get_device_flops(unit="T"): 22 | 23 | def unit_convert(number, level): 24 | units = ["B", "K", "M", "G", "T", "P"] 25 | if number <= 0: 26 | return number 27 | ptr = 0 28 | while ptr < len(units) and units[ptr] != level: 29 | number /= 1000 30 | ptr += 1 31 | return number 32 | 33 | device_name = torch.cuda.get_device_name() 34 | flops = float("inf") # INF flops for unkown gpu type 35 | if "H100" in device_name or "H800" in device_name: 36 | flops = 989e12 37 | elif "A100" in device_name or "A800" in device_name: 38 | flops = 312e12 39 | elif "L40" in device_name: 40 | flops = 181.05e12 41 | elif "L20" in device_name: 42 | flops = 119.5e12 43 | elif "H20" in device_name: 44 | flops = 148e12 45 | elif "910B" in device_name: 46 | flops = 354e12 47 | flops_unit = unit_convert(flops, unit) 48 | return flops_unit 49 | 50 | 51 | class FlopsCounter: 52 | """ 53 | Used to count mfu during training loop 54 | 55 | Example: 56 | flops_counter = FlopsCounter(config) 57 | flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time) 58 | 59 | """ 60 | 61 | def __init__(self, config: PretrainedConfig): 62 | if not isinstance(config, VALID_CONFIG_TYPE): 63 | print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. 
" 64 | f"MFU will always be zero.") 65 | 66 | self.estimate_func = {"qwen2": self._estimate_qwen2_flops, 'llama': self._estimate_qwen2_flops} 67 | self.config = config 68 | 69 | def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time): 70 | return 0 71 | 72 | def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time): 73 | assert isinstance(self.config, (Qwen2Config, LlamaConfig)) 74 | hidden_size = self.config.hidden_size 75 | vocab_size = self.config.vocab_size 76 | num_hidden_layers = self.config.num_hidden_layers 77 | num_key_value_heads = self.config.num_key_value_heads 78 | num_attention_heads = self.config.num_attention_heads 79 | intermediate_size = self.config.intermediate_size 80 | 81 | head_dim = hidden_size // num_attention_heads 82 | q_size = num_attention_heads * head_dim 83 | k_size = num_key_value_heads * head_dim 84 | v_size = num_key_value_heads * head_dim 85 | 86 | # non-attn per layer parm 87 | # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp 88 | mlp_N = hidden_size * intermediate_size * 3 89 | attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) 90 | emd_and_lm_head_N = vocab_size * hidden_size * 2 91 | # non-attn all_layer parm 92 | dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N 93 | # non-attn all_layer & all_token fwd & bwd flops 94 | dense_N_flops = 6 * dense_N * tokens_sum 95 | 96 | # attn all_layer & all_token fwd & bwd flops 97 | seqlen_square_sum = 0 98 | for seqlen in batch_seqlens: 99 | seqlen_square_sum += seqlen * seqlen 100 | attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers 101 | 102 | # all_layer & all_token fwd & bwd flops 103 | flops_all_token = dense_N_flops + attn_qkv_flops 104 | flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 105 | return flops_achieved 106 | 107 | def estimate_flops(self, batch_seqlens, delta_time): 108 | """ 109 | Estimate the FLOPS 
based on the number of valid tokens in the current batch and the time taken. 110 | 111 | Args: 112 | batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch. 113 | delta_time (float): The time taken to process the batch, in seconds. 114 | 115 | Returns: 116 | estimated_flops (float): The estimated FLOPS based on the input tokens and time. 117 | promised_flops (float): The expected FLOPS of the current device. 118 | """ 119 | tokens_sum = sum(batch_seqlens) 120 | func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops) 121 | estimated_flops = func(tokens_sum, batch_seqlens, delta_time) 122 | promised_flops = get_device_flops() 123 | return estimated_flops, promised_flops 124 | -------------------------------------------------------------------------------- /verl/utils/fs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # -*- coding: utf-8 -*- 17 | """File-system agnostic IO APIs""" 18 | import os 19 | import tempfile 20 | import hashlib 21 | 22 | from .hdfs_io import copy, makedirs, exists 23 | 24 | __all__ = ["copy", "exists", "makedirs"] 25 | 26 | _HDFS_PREFIX = "hdfs://" 27 | 28 | 29 | def _is_non_local(path): 30 | return path.startswith(_HDFS_PREFIX) 31 | 32 | 33 | def md5_encode(path: str) -> str: 34 | return hashlib.md5(path.encode()).hexdigest() 35 | 36 | 37 | def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str: 38 | """Return a local temp path that joins cache_dir and basename of hdfs_path 39 | 40 | Args: 41 | hdfs_path: 42 | cache_dir: 43 | 44 | Returns: 45 | 46 | """ 47 | # make a base64 encoding of hdfs_path to avoid directory conflict 48 | encoded_hdfs_path = md5_encode(hdfs_path) 49 | temp_dir = os.path.join(cache_dir, encoded_hdfs_path) 50 | os.makedirs(temp_dir, exist_ok=True) 51 | dst = os.path.join(temp_dir, os.path.basename(hdfs_path)) 52 | return dst 53 | 54 | 55 | def copy_local_path_from_hdfs(src: str, cache_dir=None, filelock='.file.lock', verbose=False) -> str: 56 | """Copy src from hdfs to local if src is on hdfs or directly return src. 57 | If cache_dir is None, we will use the default cache dir of the system. Note that this may cause conflicts if 58 | the src name is the same between calls 59 | 60 | Args: 61 | src (str): a HDFS path of a local path 62 | 63 | Returns: 64 | a local path of the copied file 65 | """ 66 | from filelock import FileLock 67 | 68 | assert src[-1] != '/', f'Make sure the last char in src is not / because it will cause error. 
Got {src}' 69 | 70 | if _is_non_local(src): 71 | # download from hdfs to local 72 | if cache_dir is None: 73 | # get a temp folder 74 | cache_dir = tempfile.gettempdir() 75 | os.makedirs(cache_dir, exist_ok=True) 76 | assert os.path.exists(cache_dir) 77 | local_path = get_local_temp_path(src, cache_dir) 78 | # get a specific lock 79 | filelock = md5_encode(src) + '.lock' 80 | lock_file = os.path.join(cache_dir, filelock) 81 | with FileLock(lock_file=lock_file): 82 | if not os.path.exists(local_path): 83 | if verbose: 84 | print(f'Copy from {src} to {local_path}') 85 | copy(src, local_path) 86 | return local_path 87 | else: 88 | return src 89 | -------------------------------------------------------------------------------- /verl/utils/hdfs_io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import shutil 17 | import logging 18 | 19 | logger = logging.getLogger(__file__) 20 | logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN')) 21 | 22 | _HDFS_PREFIX = "hdfs://" 23 | 24 | _HDFS_BIN_PATH = shutil.which('hdfs') 25 | 26 | 27 | def exists(path: str, **kwargs) -> bool: 28 | r"""Works like os.path.exists() but supports hdfs. 29 | 30 | Test whether a path exists. Returns False for broken symbolic links. 
31 | 32 | Args: 33 | path (str): path to test 34 | 35 | Returns: 36 | bool: True if the path exists, False otherwise 37 | """ 38 | if _is_non_local(path): 39 | return _exists(path, **kwargs) 40 | return os.path.exists(path) 41 | 42 | 43 | def _exists(file_path: str): 44 | """ hdfs capable to check whether a file_path is exists """ 45 | if file_path.startswith("hdfs"): 46 | return _run_cmd(_hdfs_cmd(f"-test -e {file_path}")) == 0 47 | return os.path.exists(file_path) 48 | 49 | 50 | def makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None: 51 | r"""Works like os.makedirs() but supports hdfs. 52 | 53 | Super-mkdir; create a leaf directory and all intermediate ones. Works like 54 | mkdir, except that any intermediate path segment (not just the rightmost) 55 | will be created if it does not exist. If the target directory already 56 | exists, raise an OSError if exist_ok is False. Otherwise no exception is 57 | raised. This is recursive. 58 | 59 | Args: 60 | name (str): directory to create 61 | mode (int): file mode bits 62 | exist_ok (bool): if True, do not raise an exception if the directory already exists 63 | kwargs: keyword arguments for hdfs 64 | 65 | """ 66 | if _is_non_local(name): 67 | # TODO(haibin.lin): 68 | # - handle OSError for hdfs(?) 69 | # - support exist_ok for hdfs(?) 70 | _mkdir(name, **kwargs) 71 | else: 72 | os.makedirs(name, mode=mode, exist_ok=exist_ok) 73 | 74 | 75 | def _mkdir(file_path: str) -> bool: 76 | """hdfs mkdir""" 77 | if file_path.startswith("hdfs"): 78 | _run_cmd(_hdfs_cmd(f"-mkdir -p {file_path}")) 79 | else: 80 | os.makedirs(file_path, exist_ok=True) 81 | return True 82 | 83 | 84 | def copy(src: str, dst: str, **kwargs) -> bool: 85 | r"""Works like shutil.copy() for file, and shutil.copytree for dir, and supports hdfs. 86 | 87 | Copy data and mode bits ("cp src dst"). Return the file's destination. 88 | The destination may be a directory. 89 | If source and destination are the same file, a SameFileError will be 90 | raised. 
91 | 92 | Arg: 93 | src (str): source file path 94 | dst (str): destination file path 95 | kwargs: keyword arguments for hdfs copy 96 | 97 | Returns: 98 | str: destination file path 99 | 100 | """ 101 | if _is_non_local(src) or _is_non_local(dst): 102 | # TODO(haibin.lin): 103 | # - handle SameFileError for hdfs files(?) 104 | # - return file destination for hdfs files 105 | return _copy(src, dst) 106 | else: 107 | if os.path.isdir(src): 108 | return shutil.copytree(src, dst, **kwargs) 109 | else: 110 | return shutil.copy(src, dst, **kwargs) 111 | 112 | 113 | def _copy(from_path: str, to_path: str, timeout: int = None) -> bool: 114 | if to_path.startswith("hdfs"): 115 | if from_path.startswith("hdfs"): 116 | returncode = _run_cmd(_hdfs_cmd(f"-cp -f {from_path} {to_path}"), timeout=timeout) 117 | else: 118 | returncode = _run_cmd(_hdfs_cmd(f"-put -f {from_path} {to_path}"), timeout=timeout) 119 | else: 120 | if from_path.startswith("hdfs"): 121 | returncode = _run_cmd(_hdfs_cmd(f"-get \ 122 | {from_path} {to_path}"), timeout=timeout) 123 | else: 124 | try: 125 | shutil.copy(from_path, to_path) 126 | returncode = 0 127 | except shutil.SameFileError: 128 | returncode = 0 129 | except Exception as e: 130 | logger.warning(f"copy {from_path} {to_path} failed: {e}") 131 | returncode = -1 132 | return returncode == 0 133 | 134 | 135 | def _run_cmd(cmd: str, timeout=None): 136 | return os.system(cmd) 137 | 138 | 139 | def _hdfs_cmd(cmd: str) -> str: 140 | return f"{_HDFS_BIN_PATH} dfs {cmd}" 141 | 142 | 143 | def _is_non_local(path: str): 144 | return path.startswith(_HDFS_PREFIX) 145 | -------------------------------------------------------------------------------- /verl/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Utilities to check if packages are available. 16 | We assume package availability won't change during runtime. 17 | """ 18 | 19 | from functools import cache 20 | from typing import List 21 | 22 | 23 | @cache 24 | def is_megatron_core_available(): 25 | try: 26 | from megatron.core import parallel_state as mpu 27 | return True 28 | except ImportError: 29 | return False 30 | 31 | 32 | @cache 33 | def is_vllm_available(): 34 | try: 35 | import vllm 36 | return True 37 | except ImportError: 38 | return False 39 | 40 | 41 | def import_external_libs(external_libs=None): 42 | if external_libs is None: 43 | return 44 | if not isinstance(external_libs, List): 45 | external_libs = [external_libs] 46 | import importlib 47 | for external_lib in external_libs: 48 | importlib.import_module(external_lib) 49 | -------------------------------------------------------------------------------- /verl/utils/logger/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# --------------------------------------------------------------------------------
# /verl/utils/logger/aggregate_logger.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A Ray logger will receive logging info from different processes.
"""
import numbers
from typing import Dict


def concat_dict_to_str(dict: Dict, step):
    """Render ``step`` and the numeric entries of ``dict`` as one line.

    The result looks like ``'step:3 - loss:0.125 - acc:0.900'``. Values that
    are not ``numbers.Number`` instances are skipped, and numbers are formatted
    with three decimals.
    """
    numeric_parts = [f'{k}:{v:.3f}' for k, v in dict.items() if isinstance(v, numbers.Number)]
    return ' - '.join([f'step:{step}', *numeric_parts])


class LocalLogger:
    """Console-only logger (deprecated).

    ``remote_logger`` and ``enable_wandb`` are accepted for signature
    compatibility but are not used.
    """

    def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False):
        self.print_to_console = print_to_console
        if not print_to_console:
            return
        print('Using LocalLogger is deprecated. The constructor API will change ')

    def flush(self):
        """No-op; :meth:`log` already flushes each line it prints."""

    def log(self, data, step):
        """Print ``data`` for ``step`` via ``concat_dict_to_str`` when console output is enabled."""
        if not self.print_to_console:
            return
        print(concat_dict_to_str(data, step=step), flush=True)


# --------------------------------------------------------------------------------
# /verl/utils/logging_utils.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging


def set_basic_config(level):
    """Set the global logging format and level through ``logging.basicConfig``.

    The original note says this is called when ``verl`` is imported.
    """
    log_format = '%(levelname)s:%(asctime)s:%(message)s'
    logging.basicConfig(format=log_format, level=level)


# --------------------------------------------------------------------------------
# /verl/utils/megatron/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/utils/megatron/memory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import torch


class MemoryBuffer:
    """A flat, preallocated tensor from which shaped views are carved out.

    The backing tensor has ``numel_padded`` zero-initialised elements on the
    current CUDA device; :meth:`get` only hands out ranges within the first
    ``numel`` elements.
    """

    def __init__(self, numel, numel_padded, dtype):
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        target_device = torch.cuda.current_device()
        self.data = torch.zeros(numel_padded, dtype=dtype, device=target_device, requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a view of ``shape`` over the flat data beginning at ``start_index``.

        The view shares storage with the buffer (no copy).

        Raises:
            AssertionError: if the requested range ends past ``self.numel``.
        """
        stop = start_index + shape.numel()
        assert stop <= self.numel, 'requested tensor is out of the buffer range.'
        return self.data[start_index:stop].view(shape)


# --------------------------------------------------------------------------------
# /verl/utils/megatron/optimizer.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron.optimizer.distrib_optimizer import DistributedOptimizer
from megatron.optimizer.grad_scaler import ConstantGradScaler, DynamicGradScaler
from megatron.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
from megatron.optimizer import get_param_groups

from verl.utils.megatron.optimizer_config import OptimizerConfig


def _build_grad_scaler(config: OptimizerConfig):
    """Return the gradient scaler for a mixed-precision / distributed optimizer.

    - ``config.loss_scale`` truthy -> ``ConstantGradScaler(config.loss_scale)``.
    - otherwise, ``config.fp16`` -> ``DynamicGradScaler`` configured from ``config``.
    - otherwise (e.g. bf16 without a loss scale) -> ``None``.
    """
    if config.loss_scale:
        return ConstantGradScaler(config.loss_scale)
    if config.fp16:
        return DynamicGradScaler(initial_scale=config.initial_loss_scale,
                                 min_scale=config.min_loss_scale,
                                 growth_factor=2.0,
                                 backoff_factor=0.5,
                                 growth_interval=config.loss_scale_window,
                                 hysteresis=config.hysteresis)
    return None


def get_megatron_optimizer(
        model,
        config: OptimizerConfig,
        no_weight_decay_cond=None,
        scale_lr_cond=None,
        lr_mult=1.0,
        check_for_nan_in_loss_and_grad=False,
        overlap_param_gather=False  # add for verl
):
    """Build a Megatron optimizer around an apex FusedAdam / FusedSGD base optimizer.

    Args:
        model: passed to ``get_param_groups`` and to the Megatron optimizer wrapper.
        config: optimizer hyper-parameters and precision settings.
        no_weight_decay_cond, scale_lr_cond, lr_mult: forwarded to ``get_param_groups``.
        check_for_nan_in_loss_and_grad: forwarded to the Megatron optimizer wrapper.
        overlap_param_gather: forwarded to ``DistributedOptimizer`` only.

    Returns:
        ``DistributedOptimizer`` when ``config.use_distributed_optimizer``;
        otherwise ``Float16OptimizerWithFloat16Params`` when fp16 or bf16 is set;
        otherwise ``FP32Optimizer``.

    Raises:
        ValueError: if ``config.optimizer`` is neither ``'adam'`` nor ``'sgd'``
            (a subclass of the ``Exception`` raised previously).
    """
    # Base optimizer.
    param_groups = get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr_mult)

    if config.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=config.lr,
                         weight_decay=config.weight_decay,
                         betas=(config.adam_beta1, config.adam_beta2),
                         eps=config.adam_eps)
    elif config.optimizer == 'sgd':
        optimizer = SGD(param_groups, lr=config.lr, weight_decay=config.weight_decay, momentum=config.sgd_momentum)
    else:
        raise ValueError('{} optimizer is not supported.'.format(config.optimizer))

    # Determine whether the params have main-grad field.
    params_have_main_grad = True

    # Mixed precision optimizer.
    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
    #   from the MixedPrecisionOptimizer, which manages any optimizer where
    #   the model params and main params are distinct.
    if config.fp16 or config.bf16 or config.use_distributed_optimizer:
        grad_scaler = _build_grad_scaler(config)

        # Megatron optimizer. Arguments are positional, so their order must match
        # the installed Megatron constructors exactly.
        if config.use_distributed_optimizer:
            return DistributedOptimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
                                        check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16, config.bf16,
                                        config.params_dtype, grad_scaler, model, overlap_param_gather)
        return Float16OptimizerWithFloat16Params(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
                                                 check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16,
                                                 config.bf16, config.params_dtype, grad_scaler, model)

    # FP32.
    return FP32Optimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad, check_for_nan_in_loss_and_grad,
                         params_have_main_grad, model)


# --------------------------------------------------------------------------------
# /verl/utils/megatron/optimizer_config.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class OptimizerConfig:
    """Configuration for optimizer."""

    ##############
    # General
    ##############
    optimizer: str = 'adam'
    """Optimizer to use: 'adam' or 'sgd'."""

    lr: Optional[float] = None
    """Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
    iteration would be different.
    """

    min_lr: Optional[float] = None
    """Minimum value for learning rate. The scheduler clips values below this threshold."""

    decoupled_lr: Optional[float] = None
    """Separate learning rate for the input and output layer."""

    decoupled_min_lr: Optional[float] = None
    """Minimum value for learning rate for the input and output layer. The scheduler clips values
    below this threshold.
    """

    weight_decay: float = 0.01
    """Weight decay coefficient for L2 regularization."""

    ##############
    # Precision
    ##############
    fp16: bool = False
    """If true, train with fp16 mixed precision training. Defaults to False."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training. Defaults to False."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when initializing the weights. Defaults to torch.float32."""

    ###############
    # Loss scaling
    ###############
    loss_scale: Optional[float] = None
    """Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
    dynamic loss scaling is used.
    """

    initial_loss_scale: float = 2**32
    """Initial loss-scale for dynamic loss scaling."""

    min_loss_scale: float = 1.0
    """Minimum loss scale for dynamic loss scaling."""

    loss_scale_window: float = 1000
    """Window over which to raise/lower dynamic scale."""

    hysteresis: int = 2
    """Hysteresis for dynamic loss scaling."""

    ##############
    # Optimizer
    ##############
    # Adam
    adam_beta1: float = 0.9
    """First coefficient for computing running averages of gradient and its square in Adam
    optimizer.
    """

    adam_beta2: float = 0.999
    """Second coefficient for computing running averages of gradient and its square in Adam
    optimizer.
    """

    adam_eps: float = 1e-08
    """Term added to the denominator to improve numerical stability in Adam optimizer."""

    # SGD.
    sgd_momentum: float = 0.9
    """Momentum factor for SGD optimizer."""

    #######################
    # Distributed optimizer
    #######################
    use_distributed_optimizer: bool = False
    """Distribute optimizer state over data-parallel replicas."""

    overlap_grad_reduce: bool = False
    """If true, overlap grad reduce-scatter with backward compute in distributed optimizer."""

    overlap_param_gather: bool = False
    """If true, overlap param all-gather with forward compute in distributed optimizer."""

    ################
    # Miscellaneous
    ################
    clip_grad: float = 1.0
    """Gradient clipping based on global L2 norm."""

    log_num_zeros_in_grad: bool = False
    """If true, calculate and log the number of zeros in gradient."""

    barrier_with_L1_time: bool = False
    """If true, use barrier with level 1 time measurements."""

    timers: Optional[Callable] = None
    """Function to get timers (None when not provided)."""


# --------------------------------------------------------------------------------
# /verl/utils/megatron/pipeline_parallel.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron.core import parallel_state as mpu

from .sequence_parallel import pad_to_sequence_parallel


def compute_transformers_input_shapes(batches, meta_info):
    """Pre-compute the input ``torch.Size`` of every micro-batch at each pp stage.

    For each micro-batch, padding is removed with ``flash_attn``'s ``unpad_input``.
    The resulting shape is ``(num_tokens, 1, meta_info['hidden_size'])``. When
    ``meta_info['sequence_parallel']`` is set, the tokens are first padded to a
    multiple of the TP world size and ``num_tokens`` is divided by that size.
    """
    from flash_attn.bert_padding import unpad_input  # flash 2 is a must for Megatron
    shapes = []
    for micro_batch in batches:
        # assumes input_ids / attention_mask are (batch, seq) — TODO confirm; result is (total_nnz, 1)
        tokens_rmpad = unpad_input(micro_batch['input_ids'].unsqueeze(dim=-1), micro_batch['attention_mask'])[0]
        if meta_info['sequence_parallel']:
            tokens_rmpad = pad_to_sequence_parallel(tokens_rmpad)
            num_tokens = tokens_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size()
        else:
            num_tokens = tokens_rmpad.shape[0]
        shapes.append(torch.Size([num_tokens, 1, meta_info['hidden_size']]))
    return shapes


def make_batch_generator(batches, vpp_size):
    """Return an iterator over ``batches``, or one independent iterator per vpp chunk.

    With ``vpp_size > 1`` the result is a list of ``vpp_size`` iterators that all
    walk the same ``batches``; otherwise it is a single iterator.
    """
    if vpp_size <= 1:
        # no vpp
        return iter(batches)
    # has vpp: one fresh iterator per virtual pipeline chunk
    return [iter(batches) for _ in range(vpp_size)]


# --------------------------------------------------------------------------------
# /verl/utils/megatron/sequence_parallel.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from megatron.core import parallel_state as mpu


def mark_parameter_as_sequence_parallel(parameter):
    """Tag ``parameter`` by setting its ``sequence_parallel`` attribute to True."""
    setattr(parameter, 'sequence_parallel', True)


def is_sequence_parallel_param(param):
    """Return ``param.sequence_parallel`` if that attribute exists, else False."""
    return hasattr(param, 'sequence_parallel') and param.sequence_parallel


def pad_to_sequence_parallel(unpad_tokens: torch.Tensor):
    """Pad the tokens so that the total length is a multiple of the sp world size.

    The TP world size is used as the sequence-parallel size. Zeros are appended
    at the end of dim 0; all other dims are left untouched.

    Args:
        unpad_tokens: (total_nnz, ...). Tokens after removing padding; any rank >= 1.

    Returns:
        ``unpad_tokens`` itself when ``total_nnz`` is already a multiple, otherwise a
        new tensor of shape ``(total_nnz + pad_size, ...)``.
    """
    total_nnz = unpad_tokens.shape[0]
    sp_world_size = mpu.get_tensor_model_parallel_world_size()

    pad_size = (sp_world_size - total_nnz % sp_world_size) % sp_world_size
    if pad_size == 0:
        return unpad_tokens

    # F.pad takes (left, right) pairs starting from the LAST dim, so leave every
    # trailing dim unpadded and right-pad dim 0. The previous version only handled
    # ndim 1/2 and its error path called `ndim()` on an int (TypeError).
    pad = (0, 0) * (unpad_tokens.ndim - 1) + (0, pad_size)
    return F.pad(unpad_tokens, pad)


# --------------------------------------------------------------------------------
# /verl/utils/py_functional.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd.
# and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains small python utility functions
"""

from typing import Dict
from types import SimpleNamespace


def union_two_dict(dict1: Dict, dict2: Dict):
    """Merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    Keys present in both dicts must map to equal (``==``) values.

    Raises:
        AssertionError: on a conflicting key. It is raised explicitly so the check
            also runs under ``python -O``, where a bare ``assert`` would be stripped
            and the conflict silently overwritten.
    """
    for key, val in dict2.items():
        if key in dict1 and not (val == dict1[key]):
            raise AssertionError(f'{key} in meta_dict1 and meta_dict2 are not the same object')
        dict1[key] = val

    return dict1


def append_to_dict(data: Dict, new_data: Dict):
    """Append each value of ``new_data`` to the list stored under the same key in ``data``."""
    for key, val in new_data.items():
        data.setdefault(key, []).append(val)


class NestedNamespace(SimpleNamespace):
    """A SimpleNamespace built from a dict; nested dict values become NestedNamespace too.

    ``kwargs`` are set first and are overwritten by same-named keys of
    ``dictionary``. Dicts inside lists or other containers are not converted.
    """

    def __init__(self, dictionary, **kwargs):
        super().__init__(**kwargs)
        for key, value in dictionary.items():
            if isinstance(value, dict):
                self.__setattr__(key, NestedNamespace(value))
            else:
                self.__setattr__(key, value)


# --------------------------------------------------------------------------------
# /verl/utils/ray_utils.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd.
# and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains commonly used utilities for ray
"""

import ray

import concurrent.futures


def parallel_put(data_list, max_workers=None):
    """Put every item of ``data_list`` into the Ray object store using a thread pool.

    Args:
        data_list: sequence of objects to ``ray.put``.
        max_workers: thread-pool size; defaults to ``min(len(data_list), 16)``.

    Returns:
        A list of object refs in the same order as ``data_list`` (``[]`` for empty input).
    """
    # An empty list used to produce ThreadPoolExecutor(max_workers=0) -> ValueError.
    if len(data_list) == 0:
        return []

    if max_workers is None:
        max_workers = min(len(data_list), 16)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in input order, so no index bookkeeping is needed.
        return list(executor.map(ray.put, data_list))


# --------------------------------------------------------------------------------
# /verl/utils/rendezvous/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/utils/rendezvous/ray_backend.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import logging
import time

from cupy.cuda.nccl import NcclCommunicator, get_unique_id

import ray
from ray.util import list_named_actors


@ray.remote
class NCCLIDStore:
    """Ray actor that holds a single NCCL unique id.

    Rank 0 in ``create_nccl_communicator_in_ray`` creates it under the group
    name; the other ranks look it up by that name and call ``get``.
    """

    def __init__(self, nccl_id):
        self._nccl_id = nccl_id

    def get(self):
        return self._nccl_id


def get_nccl_id_store_by_name(name):
    """Return a handle to the only named actor called ``name`` (searching all namespaces).

    Returns None, after logging, when no actor or more than one actor has that name.
    """
    all_actors = list_named_actors(all_namespaces=True)
    matched_actors = [actor for actor in all_actors if actor.get("name", None) == name]
    if len(matched_actors) == 1:
        # presumably each entry is a dict of ray.get_actor kwargs (name, namespace) — verify against Ray version
        return ray.get_actor(**matched_actors[0])
    if matched_actors:
        logging.warning("multiple actors with same name found: %s", matched_actors)
    else:
        logging.info("failed to get any actor named %s", name)
    return None


def create_nccl_communicator_in_ray(rank: int,
                                    world_size: int,
                                    group_name: str,
                                    max_retries: int = 100,
                                    interval_s: int = 5):
    """Create a cupy ``NcclCommunicator`` for ``rank`` in a group of ``world_size``.

    Rank 0 generates the NCCL unique id and publishes it through a named
    ``NCCLIDStore`` actor. Every other rank polls for that actor up to
    ``max_retries`` times, sleeping ``interval_s`` seconds after each miss.

    Returns:
        The communicator, or None (with an error logged) if a non-zero rank could
        not find the id store within ``max_retries`` attempts.
    """
    if rank == 0:
        nccl_id = get_unique_id()
        nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id)

        # Keep the blocking ray.get outside the assert: under `python -O` the assert
        # is stripped, and with it the wait for the store actor to be ready.
        stored_id = ray.get(nccl_id_store.get.remote())
        assert stored_id == nccl_id
        communicator = NcclCommunicator(
            ndev=world_size,
            commId=nccl_id,
            rank=0,
        )
        return communicator

    for i in range(max_retries):
        nccl_id_store = get_nccl_id_store_by_name(group_name)
        if nccl_id_store is not None:
            logging.info("nccl_id_store %s got", group_name)
            nccl_id = ray.get(nccl_id_store.get.remote())
            logging.info("nccl id for %s got: %s", group_name, nccl_id)
            communicator = NcclCommunicator(
                ndev=world_size,
                commId=nccl_id,
                rank=rank,
            )
            return communicator
        logging.info("failed to get nccl_id for %s time, sleep for %s seconds", i + 1, interval_s)
        time.sleep(interval_s)

    # Previously this fell off the end and returned None without any trace.
    logging.error("failed to get nccl_id for group %s after %s retries", group_name, max_retries)
    return None


# --------------------------------------------------------------------------------
# /verl/utils/reward_score/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# --------------------------------------------------------------------------------
# /verl/utils/reward_score/countdown.py:
# --------------------------------------------------------------------------------
import re
import random
import ast
import operator


def extract_solution(solution_str):
    """Extract the equation from the solution string.

    Text up to and including the first assistant marker ("Assistant:" or
    "<|im_start|>assistant") is dropped. Only the last line of the remainder is
    kept, and the stripped content of the last ``<answer>...</answer>`` pair on
    that line is returned.

    Returns:
        The equation string, or None when there is no assistant marker or no
        answer tag on the last line.
    """
    # Remove everything before the first "Assistant:"
    if "Assistant:" in solution_str:
        solution_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        solution_str = solution_str.split("<|im_start|>assistant", 1)[1]
    else:
        return None
    solution_str = solution_str.split('\n')[-1]

    # The answer must be wrapped in <answer> tags; a bare r'(.*?)' matches the
    # empty string everywhere and would always "extract" ''.
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = list(re.finditer(answer_pattern, solution_str))
    if matches:
        return matches[-1].group(1).strip()
    return None


def validate_equation(equation_str, available_numbers):
    """Validate that equation only uses available numbers and each number once.

    The integers found in ``equation_str`` must equal ``available_numbers`` as a
    multiset. Any error (e.g. ``equation_str`` not a string) yields False.
    """
    try:
        # Extract all numbers from the equation
        numbers_in_eq = [int(n) for n in re.findall(r'\d+', equation_str)]
        # Each number should be used exactly once
        return sorted(numbers_in_eq) == sorted(available_numbers)
    except Exception:
        return False


def evaluate_equation(equation_str):
    """Evaluate an arithmetic expression built from digits, + - * / ( ) . and whitespace.

    Exponentiation (``**``) is rejected: the input is model output, and something
    like ``9**9**9`` passes the character whitelist but takes effectively forever
    to compute.

    Returns:
        The numeric result, or None if the expression is rejected or fails to
        evaluate (syntax error, division by zero, ...).
    """
    try:
        # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
        allowed_pattern = r'^[\d+\-*/().\s]+$'
        if not re.match(allowed_pattern, equation_str):
            raise ValueError("Invalid characters in equation.")
        if '**' in equation_str:
            raise ValueError("Exponentiation is not allowed.")

        # NOTE(review): eval on untrusted model output — acceptable only because of the
        # character whitelist and `**` ban above plus empty builtins.
        return eval(equation_str, {"__builtins__": None}, {})
    except Exception:
        return None


def compute_score(solution_str, ground_truth, method='strict', format_score=0.1, score=1.):
    """The scoring function for countdown task.

    Args:
        solution_str: the solution text
        ground_truth: dictionary containing target number and available numbers
            (keys ``'target'`` and ``'numbers'``)
        method: the method to extract the solution (currently unused)
        format_score: the score for correct format but wrong answer
        score: the score for the correct answer

    Returns:
        ``0`` when no equation is extracted, ``score`` when the equation uses the
        given numbers and hits the target within 1e-5, otherwise ``format_score``.
    """
    target = ground_truth['target']
    numbers = ground_truth['numbers']

    equation = extract_solution(solution_str=solution_str)
    # Print roughly 1 in 64 samples for debugging.
    do_print = random.randint(1, 64) == 1

    if do_print:
        print("--------------------------------")
        print(f"Target: {target} | Numbers: {numbers}")
        print(f"Extracted equation: {equation}")
        print(f"Solution string: {solution_str}")

    if equation is None:
        if do_print:
            print("No equation found")
        return 0

    # Validate equation uses correct numbers
    if not validate_equation(equation, numbers):
        if do_print:
            print("Invalid equation")
        return format_score

    # Evaluate equation
    try:
        result = evaluate_equation(equation)
        if result is None:
            if do_print:
                print("Could not evaluate equation")
            return format_score

        if abs(result - target) < 1e-5:  # Account for floating point precision
            if do_print:
                print(f"Correct equation: {equation} = {result}")
            return score
        else:
            if do_print:
                print(f"Wrong result: equation = {result}, target = {target}")
            return format_score
    except Exception:
        if do_print:
            print("Error evaluating equation")
        return format_score


# --------------------------------------------------------------------------------
# /verl/utils/reward_score/gsm8k.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re


def extract_solution(solution_str, method='strict'):
    """Extract the final numeric answer from a GSM8k-style solution.

    Args:
        solution_str: the model output.
        method: ``'strict'`` requires a ``'#### <number>'`` marker (commas and ``$``
            are removed from the number). ``'flexible'`` takes the last number-like
            token that is not a bare ``'.'``.

    Returns:
        The answer string, or None when nothing suitable is found.
    """
    assert method in ['strict', 'flexible']

    if method == 'strict':
        # this also tests the formatting of the model
        solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
        if solution is None:
            final_answer = None
        else:
            final_answer = solution.group(1).replace(',', '').replace('$', '')
    elif method == 'flexible':
        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
        invalid_str = ['', '.']
        # Last token that is not '.'; None if there is none. The old loop leaked
        # its loop variable and returned '.' when every token was invalid.
        final_answer = next((candidate for candidate in reversed(answer) if candidate not in invalid_str), None)
    return final_answer


def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.):
    """The scoring function for GSM8k.

    Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

    Args:
        solution_str: the solution text
        ground_truth: the ground truth
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_score: the score for the format
        score: the score for the correct answer

    Returns:
        ``0`` if no answer is extracted, ``score`` on an exact string match with
        ``ground_truth``, otherwise ``format_score``.
    """
    answer = extract_solution(solution_str=solution_str, method=method)
    if answer is None:
        return 0
    else:
        if answer == ground_truth:
            return score
        else:
            return format_score


# --------------------------------------------------------------------------------
# /verl/utils/tokenizer.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for tokenization."""
import warnings

__all__ = ['hf_tokenizer']


def set_pad_token_id(tokenizer):
    """Set pad_token_id to eos_token_id if it is None.

    ``pad_token`` is likewise set to ``eos_token`` when it is None. A warning is
    emitted for each field that is changed.

    Args:
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set.

    """
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        warnings.warn(f'tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        warnings.warn(f'tokenizer.pad_token is None. Now set to {tokenizer.eos_token}')


def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs):
    """Create a huggingface pretrained tokenizer.

    Args:
        name_or_path (str): The name or local path of the tokenizer, passed to
            ``AutoTokenizer.from_pretrained``.
        correct_pad_token (bool): Whether to correct the pad token id.
        correct_gemma2 (bool): Whether to correct the gemma2 tokenizer.
        **kwargs: The keyword arguments for the tokenizer.

    Returns:
        transformers.PreTrainedTokenizer: The pretrained tokenizer.

    """
    from transformers import AutoTokenizer
    if correct_gemma2 and isinstance(name_or_path, str) and 'gemma-2-2b-it' in name_or_path:
        # the EOS token in gemma2 is ambiguous, which may worsen RL performance.
        # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
        # Use <end_of_turn> (id 107); an empty eos_token string was a bug.
        warnings.warn('Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to <end_of_turn> and 107.')
        kwargs['eos_token'] = '<end_of_turn>'
        kwargs['eos_token_id'] = 107
    tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
    if correct_pad_token:
        set_pad_token_id(tokenizer)
    return tokenizer


# --------------------------------------------------------------------------------
# /verl/utils/torch_dtypes.py:
# --------------------------------------------------------------------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Adapted from Cruise. 16 | """ 17 | 18 | import torch 19 | 20 | from typing import Union 21 | 22 | HALF_LIST = [16, "16", "fp16", "float16"] 23 | FLOAT_LIST = [32, "32", "fp32", "float32"] 24 | BFLOAT_LIST = ["bf16", "bfloat16"] 25 | 26 | 27 | class PrecisionType(object): 28 | """Type of precision used. 29 | 30 | >>> PrecisionType.HALF == 16 31 | True 32 | >>> PrecisionType.HALF in (16, "16") 33 | True 34 | """ 35 | 36 | HALF = "16" 37 | FLOAT = "32" 38 | FULL = "64" 39 | BFLOAT = "bf16" 40 | MIXED = "mixed" 41 | 42 | @staticmethod 43 | def supported_type(precision: Union[str, int]) -> bool: 44 | return any(x == precision for x in PrecisionType) 45 | 46 | @staticmethod 47 | def supported_types() -> list[str]: 48 | return [x.value for x in PrecisionType] 49 | 50 | @staticmethod 51 | def is_fp16(precision): 52 | return precision in HALF_LIST 53 | 54 | @staticmethod 55 | def is_fp32(precision): 56 | return precision in FLOAT_LIST 57 | 58 | @staticmethod 59 | def is_bf16(precision): 60 | return precision in BFLOAT_LIST 61 | 62 | @staticmethod 63 | def to_dtype(precision): 64 | if precision in HALF_LIST: 65 | return torch.float16 66 | elif precision in FLOAT_LIST: 67 | return torch.float32 68 | elif precision in BFLOAT_LIST: 69 | return torch.bfloat16 70 | else: 71 | raise RuntimeError(f"unexpected precision: {precision}") 72 | 73 | @staticmethod 74 | def to_str(precision): 75 | if precision == torch.float16: 76 | return 'fp16' 77 | elif precision == torch.float32: 78 | return 'fp32' 79 | elif precision == torch.bfloat16: 80 | return 'bf16' 81 | else: 82 | raise RuntimeError(f"unexpected precision: {precision}") 83 | -------------------------------------------------------------------------------- /verl/utils/tracking.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. 
and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A unified tracking interface that supports logging data to different backend 16 | """ 17 | import dataclasses 18 | from enum import Enum 19 | from functools import partial 20 | from pathlib import Path 21 | from typing import List, Union, Dict, Any 22 | 23 | 24 | class Tracking(object): 25 | supported_backend = ['wandb', 'mlflow', 'console'] 26 | 27 | def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None): 28 | if isinstance(default_backend, str): 29 | default_backend = [default_backend] 30 | for backend in default_backend: 31 | if backend == 'tracking': 32 | import warnings 33 | warnings.warn("`tracking` logger is deprecated. 
use `wandb` instead.", DeprecationWarning) 34 | else: 35 | assert backend in self.supported_backend, f'{backend} is not supported' 36 | 37 | self.logger = {} 38 | 39 | if 'tracking' in default_backend or 'wandb' in default_backend: 40 | import wandb 41 | import os 42 | WANDB_API_KEY = os.environ.get("WANDB_API_KEY", None) 43 | if WANDB_API_KEY: 44 | wandb.login(key=WANDB_API_KEY) 45 | wandb.init(project=project_name, name=experiment_name, config=config) 46 | self.logger['wandb'] = wandb 47 | 48 | if 'mlflow' in default_backend: 49 | import mlflow 50 | mlflow.start_run(run_name=experiment_name) 51 | mlflow.log_params(_compute_mlflow_params_from_objects(config)) 52 | self.logger['mlflow'] = _MlflowLoggingAdapter() 53 | 54 | if 'console' in default_backend: 55 | from verl.utils.logger.aggregate_logger import LocalLogger 56 | self.console_logger = LocalLogger(print_to_console=True) 57 | self.logger['console'] = self.console_logger 58 | 59 | def log(self, data, step, backend=None): 60 | for default_backend, logger_instance in self.logger.items(): 61 | if backend is None or default_backend in backend: 62 | logger_instance.log(data=data, step=step) 63 | 64 | 65 | class _MlflowLoggingAdapter: 66 | 67 | def log(self, data, step): 68 | import mlflow 69 | mlflow.log_metrics(metrics=data, step=step) 70 | 71 | 72 | def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]: 73 | if params is None: 74 | return {} 75 | 76 | return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep='/') 77 | 78 | 79 | def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): 80 | _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict) 81 | 82 | if dataclasses.is_dataclass(x): 83 | return _transform(dataclasses.asdict(x)) 84 | if isinstance(x, dict): 85 | return {k: _transform(v) for k, v in x.items()} 86 | if isinstance(x, list): 87 | if convert_list_to_dict: 88 | return {'list_len': 
len(x)} | {f'{i}': _transform(v) for i, v in enumerate(x)} 89 | else: 90 | return [_transform(v) for v in x] 91 | if isinstance(x, Path): 92 | return str(x) 93 | if isinstance(x, Enum): 94 | return x.value 95 | 96 | return x 97 | 98 | 99 | def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]: 100 | import pandas as pd 101 | ans = pd.json_normalize(raw, sep=sep).to_dict(orient='records')[0] 102 | assert isinstance(ans, dict) 103 | return ans 104 | -------------------------------------------------------------------------------- /verl/version/version: -------------------------------------------------------------------------------- 1 | 0.1 -------------------------------------------------------------------------------- /verl/workers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /verl/workers/actor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import BasePPOActor
from .dp_actor import DataParallelPPOActor

__all__ = ["BasePPOActor", "DataParallelPPOActor"]

# -------------------- /verl/workers/actor/base.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class for Actor
"""
from abc import ABC, abstractmethod
from typing import Iterable, Dict

from verl import DataProto
import torch

__all__ = ['BasePPOActor']


class BasePPOActor(ABC):
    """Abstract interface every PPO actor implementation provides."""

    def __init__(self, config):
        """The base class for PPO actor

        Args:
            config (DictConfig): a config passed to the PPOActor. We expect the type to be
                DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.
        """
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute log probabilities given a batch of data.

        Args:
            data (DataProto): a batch of data represented by DataProto. It must contain the keys
                ``input_ids``, ``attention_mask`` and ``position_ids``.

        Returns:
            torch.Tensor: the computed log probabilities (the exact shape is defined by the
            concrete implementation).
        """
        pass

    @abstractmethod
    def update_policy(self, data: DataProto) -> Dict:
        """Update the policy with a batch of data.

        Args:
            data (DataProto): the data used for the policy update; how it is split into
                mini-batches is left to the implementation.

        Returns:
            Dict: a dictionary of arbitrary content, typically the statistics collected while
            updating the model, such as ``loss`` and ``grad_norm``.
        """
        pass

# -------------------- /verl/workers/critic/__init__.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPOCritic
from .dp_critic import DataParallelPPOCritic

__all__ = ["BasePPOCritic", "DataParallelPPOCritic"]

# -------------------- /verl/workers/critic/base.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for a critic
"""
from abc import ABC, abstractmethod

import torch

from verl import DataProto

__all__ = ['BasePPOCritic']


class BasePPOCritic(ABC):
    """Abstract interface every PPO critic implementation provides.

    Args:
        config: the critic configuration, stored as ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_values(self, data: DataProto) -> torch.Tensor:
        """Return the critic's value estimates for ``data``."""
        ...

    @abstractmethod
    def update_critic(self, data: DataProto):
        """Run one update of the critic on ``data``."""
        ...

# -------------------- /verl/workers/reward_model/__init__.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import BasePPORewardModel

# -------------------- /verl/workers/reward_model/base.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class for reward model
"""

from abc import ABC, abstractmethod

from verl import DataProto


class BasePPORewardModel(ABC):
    """Abstract interface for PPO reward models.

    Args:
        config: the reward model configuration, stored as ``self.config``.
    """

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def compute_reward(self, data: DataProto) -> DataProto:
        """Compute rewards for a batch of token sequences.

        The underlying model is expected to output a ``[batch_size, sequence_length]``
        tensor, from which the value at the [EOS] position is gathered.

        Args:
            data: must contain keys "input_ids", "attention_mask" and "position_ids",
                each of shape ``[batch_size, sequence_length]``.

        Returns:
            A DataProto containing "reward" of shape ``[batch_size, sequence_length]``.
            Only the [EOS] position holds the reward; every other position is zero. This
            may change if dense rewards are adopted, so the interface is kept general.
        """
        ...

# -------------------- /verl/workers/reward_model/megatron/__init__.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .reward_model import MegatronRewardModel

# -------------------- /verl/workers/rollout/__init__.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import BaseRollout
from .naive import NaiveRollout
from .hf_rollout import HFRollout

__all__ = ["BaseRollout", "NaiveRollout", "HFRollout"]

# -------------------- /verl/workers/rollout/base.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Iterable, Union

from verl import DataProto

__all__ = ['BaseRollout']


class BaseRollout(ABC):
    """Abstract interface for rollout engines that turn prompts into generated sequences."""

    def __init__(self):
        """Initialize the rollout.

        The base class takes no arguments; concrete rollouts accept their model and
        configuration in their own constructors.
        """
        super().__init__()

    @abstractmethod
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate sequences for a batch of prompts.

        Args:
            prompts (DataProto): the batch of prompts to complete.

        Returns:
            DataProto: the generated sequences.
        """
        pass

# -------------------- /verl/workers/rollout/naive/__init__.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .naive_rollout import NaiveRollout

# -------------------- /verl/workers/rollout/naive/naive_rollout.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In single GPU rollout, the sequences are generated directly by sampling from the model.
The output batch contains
1. input_ids (the prompt)
2. responses
3. sequences (prompt followed by response)
4. old_log_probs
5. attention_mask (left padding, extended over the response; tokens after the first EOS are masked)
6. position_ids
"""
from typing import Iterable, Union

import torch
import torch.nn.functional as F
from tensordict import TensorDict
from torch import nn

from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits
from ..base import BaseRollout

__all__ = ['NaiveRollout']


class NaiveRollout(BaseRollout):

    def __init__(self, module: nn.Module, config):
        """A naive rollout. It requires the module to be compatible with huggingface APIs. That is:
        The module should define __call__ to receive input_ids, attention_mask and position_ids.
        It outputs a structure that contains logits field.

        Args:
            module: module here follows huggingface APIs
            config: DictConfig (reads response_length, temperature, top_k and do_sample)
        """
        super().__init__()
        self.config = config
        self.module = module

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Autoregressively sample ``config.response_length`` tokens for every prompt.

        The whole running sequence is fed to the module at every step (no KV cache).

        Args:
            prompts: ``batch`` must hold ``input_ids`` (bs, prompt_length), a left-padded
                ``attention_mask`` and ``position_ids``; ``meta_info`` must hold ``eos_token_id``.

        Returns:
            DataProto whose batch holds ``input_ids`` (prompt), ``responses``, ``sequences``,
            ``old_log_probs`` and the ``attention_mask`` / ``position_ids`` extended over the response.
        """
        idx = prompts.batch['input_ids']  # (bs, prompt_length)
        attention_mask = prompts.batch['attention_mask']  # left-padded attention_mask
        position_ids = prompts.batch['position_ids']

        # used to construct attention_mask
        eos_token_id = prompts.meta_info['eos_token_id']

        batch_size = idx.size(0)
        prompt_length = idx.size(1)

        self.module.eval()

        # Mask value for the token sampled in the current step: 1 up to and including the
        # first EOS of each sequence, 0 for every token generated after it.
        prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)

        logits_lst = []
        for _ in range(self.config.response_length):
            # forward the model to get the logits for the index in the sequence
            # we use huggingface APIs here; no block_size cropping is applied
            output = self.module(input_ids=idx, attention_mask=attention_mask, position_ids=position_ids)
            logits = output.logits
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / self.config.temperature  # (bs, vocab_size)
            # optionally crop the logits to only the top k options
            if self.config.top_k is not None:
                v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            if self.config.do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                idx_next = torch.argmax(probs, dim=-1, keepdim=True)

            attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1)

            # `Tensor.to` is not in-place: reassign so the mask keeps attention_mask's dtype
            # instead of silently staying bool.
            prev_attention_mask = torch.logical_and(idx_next != eos_token_id, prev_attention_mask.bool())
            prev_attention_mask = prev_attention_mask.to(attention_mask.dtype)

            position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)

            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
            logits_lst.append(logits)

        logits = torch.stack(logits_lst, dim=1)  # (bs, response_length, vocab_size)
        prompt_ids = idx[:, :prompt_length]  # (bs, prompt_length)
        response = idx[:, prompt_length:]  # (bs, response_length)
        log_probs = logprobs_from_logits(logits=logits, labels=response)
        batch = TensorDict(
            {
                'input_ids': prompt_ids,
                'responses': response,
                'sequences': idx,
                'old_log_probs': log_probs,
                'attention_mask': attention_mask,
                'position_ids': position_ids,
            },
            batch_size=batch_size)

        # NOTE: always switches the module back to train mode, whatever its mode was on entry.
        self.module.train()

        return DataProto(batch=batch)

# -------------------- /verl/workers/rollout/vllm_rollout/__init__.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vllm_rollout import vLLMRollout
# -------------------- /verl/workers/sharding_manager/__init__.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from verl.utils.import_utils import is_vllm_available, is_megatron_core_available

from .base import BaseShardingManager
from .fsdp_ulysses import FSDPUlyssesShardingManager

# Optional managers: when their backends are unavailable the names are bound to None,
# so they are always defined on this package.
if is_megatron_core_available() and is_vllm_available():
    from .megatron_vllm import AllGatherPPModel, MegatronVLLMShardingManager
else:
    AllGatherPPModel = None
    MegatronVLLMShardingManager = None

if is_vllm_available():
    from .fsdp_vllm import FSDPVLLMShardingManager
else:
    FSDPVLLMShardingManager = None

# -------------------- /verl/workers/sharding_manager/base.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sharding manager to implement HybridEngine
"""

from verl import DataProto


class BaseShardingManager:
    """No-op sharding manager.

    Used as a context manager around a region of work. The hooks pass data through
    unchanged; subclasses may override them to switch parallel state on enter/exit and to
    reshard data before and after the region.
    """

    def __enter__(self):
        """Enter the managed region; the base implementation does nothing."""
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the managed region; returning None lets any exception propagate."""
        return None

    def preprocess_data(self, data: DataProto) -> DataProto:
        """Return ``data`` unchanged."""
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Return ``data`` unchanged."""
        return data

# -------------------- /verl/workers/sharding_manager/fsdp_ulysses.py: --------------------
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a sharding manager that switches to a model-specific Ulysses sequence-parallel
group and reshards data between the FSDP and Ulysses (SP) layouts.
"""
from typing import Optional
from .base import BaseShardingManager

import random
from torch.distributed.device_mesh import DeviceMesh

from verl.utils.torch_functional import allgather_dict_tensors
from verl.utils.ulysses import set_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_group
import numpy as np

import torch
import torch.distributed

from verl import DataProto


class FSDPUlyssesShardingManager(BaseShardingManager):
    """
    Sharding manager to support data resharding when using FSDP + Ulysses.

    When ``device_mesh`` is None every hook is a no-op.
    """

    def __init__(self, device_mesh: DeviceMesh):
        super().__init__()
        self.device_mesh = device_mesh
        self.seed_offset = 12345

    def __enter__(self):
        """Install this model's SP group as the Ulysses group, remembering the previous one."""
        if self.device_mesh is None:
            return
        # A global SP group may already be set; swap in the model-specific group.
        self.prev_sp_group = get_ulysses_sequence_parallel_group()
        set_ulysses_sequence_parallel_group(self.device_mesh['sp'].get_group())
        # TODO: check how to set seed for each model

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the SP group that was active before ``__enter__``."""
        if self.device_mesh is None:
            return
        set_ulysses_sequence_parallel_group(self.prev_sp_group)
        # TODO: check how to set seed for each model

    def preprocess_data(self, data: DataProto) -> DataProto:
        """
        AllGather data from the SP region.

        The data is first sharded along the FSDP dimension (DP_COMPUTE), while Ulysses needs
        every rank of an SP group to see the same data, so the shards are gathered along dim 0.
        """
        if self.device_mesh is None:
            return data

        sp_mesh = self.device_mesh['sp']
        sp_size = sp_mesh.size()
        group = sp_mesh.get_group()

        # Tensor part: move to the current GPU for the collective, then back to the original device.
        original_device = data.batch.device
        data.batch = data.batch.cuda(device=torch.cuda.current_device())
        data.batch = allgather_dict_tensors(data.batch.contiguous(), size=sp_size, group=group, dim=0)
        data.batch = data.batch.to(original_device)

        # Non-tensor part: gather the per-rank dicts and concatenate each key's arrays.
        gathered_non_tensor = [None] * sp_size
        torch.distributed.all_gather_object(gathered_non_tensor, data.non_tensor_batch, group=group)
        data.non_tensor_batch = {
            key: np.concatenate([shard[key] for shard in gathered_non_tensor]) for key in data.non_tensor_batch
        }
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """
        Split the data to follow the FSDP partition: keep this rank's chunk of the SP group.
        """
        if self.device_mesh is None:
            return data
        sp_mesh = self.device_mesh['sp']
        sp_size = sp_mesh.size()
        sp_rank = sp_mesh.get_local_rank()
        return data.chunk(chunks=sp_size)[sp_rank]

# --------------------------------------------------------------------------------