├── README.md ├── images └── image.png ├── math_tir ├── .env ├── routed_sandbox.py ├── test_e2b_sandbox.ipynb └── routed_morph.py ├── sequence_parallel ├── images │ ├── splitQ.png │ ├── KV-rotate.gif │ ├── ulysses.png │ ├── sl4096_rank7.png │ ├── sl65536_rank7.png │ ├── ulysses_simple.png │ ├── zigzag_iteration.png │ ├── zigzag_workflow.png │ ├── ring_attention_fwd.png │ ├── ring_attn_workflow.png │ ├── strided_attention.png │ ├── stripe_attn_profile.png │ ├── sl8192_rank7_overlap.png │ ├── stripe_attn_iteration.png │ ├── striped_attn_sequence.png │ └── workload_distribution.png ├── utils.py ├── basic │ ├── test_flash_attn_func.py │ ├── test_flash_attn_qkvpacked_func.py │ ├── test_flash_attn_varlen_qkvpacked_func.py │ ├── test_flash_attn_qkvpacked_func_dist.py │ └── benchmark_qkvpacked_func.py ├── ring_flash_attention │ ├── test_ring_flash_attn.py │ ├── test_stripe_flash_attn_func.py │ ├── test_zigzag_ring_flash_attn_func.py │ ├── benchmark_ring_flash_attn.py │ ├── rfa_utils.py │ ├── benchmark_stripe_ring_flash_attn.py │ └── benchmark_zigzag_ring_flash_attn.py ├── ulysses │ ├── test_ulysses_attn.py │ └── ulysses_attn.py ├── usp │ ├── test_usp_qkvpacked_attn.py │ ├── test_usp_attn.py │ └── usp_utils.py ├── loongtrain │ └── test_double_ring_attn.py ├── reference.py └── readme_usp.md ├── verl_test ├── scripts │ ├── download_corpus.sh │ ├── readme.md │ ├── start_retrieval_server.sh │ ├── train_sppo_32b.sh │ ├── train_rfpp.sh │ ├── train_grpo_32b.sh │ ├── train_ppo_14b.sh │ ├── run_qwen_search.sh │ └── train_dapo_32b.sh ├── megatron_test │ ├── math_score.py │ ├── test_megatron_sft.sh │ ├── megatron_sft.yaml │ ├── test_meagtron_utils.ipynb │ └── test_actor_rollout_megatron.py ├── rollout_test │ ├── sandbox_fusion_tool_config │ ├── search_tool_config │ ├── test_vllm_spmd_profile.py │ ├── test_spmd.py │ ├── test_sglang_async_rollout_without_tools.py │ ├── test_actor_rollout_fsdp.py │ └── test_vllm_tp_dp_spmd.py ├── test_fa_ops.py ├── utils.py ├── test_memory_buffers.py ├── test_ray_tp.py ├── agent_test │ ├── agent_utils.py │ ├── test_agent_single_turn.py │ └── vllm_async_rollout_perf.py ├── test_rm_worker_fsdp.py └── test_torch_func.py ├── collective_ops ├── test_broadcast.py ├── test_all_reduce.py ├── test_all_to_all.py ├── test_p2pop.py ├── test_all_gather_into_tensor.py ├── test_reduce_scatter.py ├── test_all_gather.py ├── test_gather.py ├── test_reduce.py ├── test_scatter.py ├── test_all_to_all_single.py └── test_ring_comm_customized.py ├── parallel_framework ├── megatron_utils.py └── test_megatron_dp.py ├── moe_ep ├── deepseek │ └── config.json ├── test_ep_torch.py ├── test_layer_ep_torch.py └── test_moe_kernel_vllm.py ├── torch_memory_walkthrough.ipynb └── rlhf └── reward_verl.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # AI_analysis 2 | analyse problems of AI with Math and Code -------------------------------------------------------------------------------- /images/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/images/image.png -------------------------------------------------------------------------------- /math_tir/.env: -------------------------------------------------------------------------------- 1 | E2B_API_KEY=e2b_6872be0e5e96a6b8c299c15832a3edbc8650a788 2 | MORPH_API_KEY=morph_5JB7l1wh2q5i5hNY9rrQQT -------------------------------------------------------------------------------- /sequence_parallel/images/splitQ.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/splitQ.png -------------------------------------------------------------------------------- /sequence_parallel/images/KV-rotate.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/KV-rotate.gif -------------------------------------------------------------------------------- /sequence_parallel/images/ulysses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/ulysses.png -------------------------------------------------------------------------------- /sequence_parallel/images/sl4096_rank7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/sl4096_rank7.png -------------------------------------------------------------------------------- /sequence_parallel/images/sl65536_rank7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/sl65536_rank7.png -------------------------------------------------------------------------------- /sequence_parallel/images/ulysses_simple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/ulysses_simple.png -------------------------------------------------------------------------------- /sequence_parallel/images/zigzag_iteration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/zigzag_iteration.png -------------------------------------------------------------------------------- /sequence_parallel/images/zigzag_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/zigzag_workflow.png -------------------------------------------------------------------------------- /sequence_parallel/images/ring_attention_fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/ring_attention_fwd.png -------------------------------------------------------------------------------- /sequence_parallel/images/ring_attn_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/ring_attn_workflow.png -------------------------------------------------------------------------------- /sequence_parallel/images/strided_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/strided_attention.png -------------------------------------------------------------------------------- /sequence_parallel/images/stripe_attn_profile.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/stripe_attn_profile.png
--------------------------------------------------------------------------------
/sequence_parallel/images/sl8192_rank7_overlap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/sl8192_rank7_overlap.png
--------------------------------------------------------------------------------
/sequence_parallel/images/stripe_attn_iteration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/stripe_attn_iteration.png
--------------------------------------------------------------------------------
/sequence_parallel/images/striped_attn_sequence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/striped_attn_sequence.png
--------------------------------------------------------------------------------
/sequence_parallel/images/workload_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ifromeast/AI_analysis/HEAD/sequence_parallel/images/workload_distribution.png
--------------------------------------------------------------------------------
/verl_test/scripts/download_corpus.sh:
--------------------------------------------------------------------------------
1 | save_path=/data2/zzd/data/search_r1
2 | python examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py --save_path $save_path
3 | cat $save_path/part_* > $save_path/e5_Flat.index
4 | gzip -d $save_path/wiki-18.jsonl.gz
--------------------------------------------------------------------------------
/verl_test/megatron_test/math_score.py:
--------------------------------------------------------------------------------
1 |
2 | from verl.utils.reward_score import math
3 |
4 | def compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None):
5 |     return math.compute_score(solution_str, ground_truth)
6 |
--------------------------------------------------------------------------------
/verl_test/scripts/readme.md:
--------------------------------------------------------------------------------
1 |
2 | create the environment on all nodes by
3 | ```
4 | conda env create -f environment.yml
5 | ```
6 |
7 | start ray on the head node by
8 | ```
9 | ray start --head --dashboard-host=0.0.0.0
10 | ```
11 |
12 | start ray on each worker node by
13 | ```
14 | ray start --address='10.157.150.10:6379'
15 | ```
16 |
17 |
--------------------------------------------------------------------------------
/collective_ops/test_broadcast.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 |
4 | dist.init_process_group("nccl")
5 | rank = dist.get_rank()
6 | world_size = dist.get_world_size()
7 |
8 | tensor = torch.tensor([rank + 1.0, rank - 1.0], device=torch.device(f'cuda:{rank}'))
9 |
10 | # print the tensor before the broadcast
11 | print(f"Rank {rank} before broadcast: {tensor}")
12 |
13 | # broadcast the tensor from rank 0 to all processes
14 | dist.broadcast(tensor, src=0)
15 |
16 | # print the tensor after the broadcast
17 | print(f"Rank {rank} after broadcast: {tensor}")
18 |
19 | # destroy the process group
20 | dist.destroy_process_group()
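The collective_ops test above (test_broadcast.py) and its sibling scripts later in this dump all call dist.init_process_group("nccl") and take their rank and world size from the environment, so they are meant to be launched with torchrun rather than plain python; the comment at the top of verl_test/rollout_test/test_spmd.py uses the same launcher. A minimal launch sketch, assuming a single node with 4 visible GPUs (test_all_to_all_single.py additionally hard-codes exactly 4 ranks):

```
# assumed launch command, not part of this repository;
# adjust --nproc_per_node to the number of GPUs on the node
torchrun --nproc_per_node=4 collective_ops/test_broadcast.py
```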
-------------------------------------------------------------------------------- /parallel_framework/megatron_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | 5 | def initialize_global_process_group(timeout_second=36000): 6 | from datetime import timedelta 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | dist.init_process_group("nccl", timeout=timedelta(seconds=timeout_second)) 12 | local_rank = int(os.environ["LOCAL_RANK"]) 13 | rank = int(os.environ["RANK"]) 14 | world_size = int(os.environ["WORLD_SIZE"]) 15 | 16 | if dist.is_initialized(): 17 | torch.cuda.set_device(local_rank) 18 | return local_rank, rank, world_size -------------------------------------------------------------------------------- /verl_test/rollout_test/sandbox_fusion_tool_config: -------------------------------------------------------------------------------- 1 | tools: 2 | - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool" 3 | config: 4 | sandbox_fusion_url: "https://xxx.apigateway-cn-beijing.volceapi.com/run_code" 5 | tool_schema: 6 | type: "function" 7 | function: 8 | name: "code_interpreter" 9 | description: "A tool for executing code." 10 | parameters: 11 | type: "object" 12 | properties: 13 | code: 14 | type: "string" 15 | description: "The code to execute." 16 | required: ["code"] -------------------------------------------------------------------------------- /collective_ops/test_all_reduce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | dist.init_process_group("nccl") 6 | rank = dist.get_rank() 7 | world_size = dist.get_world_size() 8 | 9 | 10 | # 创建一个张量并放到当前 GPU 上 11 | tensor = torch.tensor([rank + 1.0], device=torch.device(f'cuda:{rank}')) 12 | 13 | # 打印初始张量 14 | print(f"Rank {rank} (GPU {rank}) has tensor {tensor}") 15 | 16 | # 使用 all_reduce 操作 17 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 18 | 19 | # 打印 all_reduce 后的张量 20 | print(f"Rank {rank} (GPU {rank}) has tensor {tensor} after all_reduce") 21 | 22 | # 销毁进程组 23 | dist.destroy_process_group() -------------------------------------------------------------------------------- /collective_ops/test_all_to_all.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | # 初始化进程组 5 | dist.init_process_group("nccl") 6 | rank = dist.get_rank() 7 | world_size = dist.get_world_size() 8 | 9 | device = torch.device("cuda:{}".format(rank)) 10 | 11 | input = torch.arange(8) + rank * 8 12 | input = list(input.to(device).chunk(8)) 13 | print(f"Rank {rank} before all_to_all input={input}") 14 | 15 | output = list(torch.empty([8], dtype=torch.int64).to(device).chunk(8)) 16 | dist.all_to_all(output, input) 17 | print(f"Rank {rank} after all_to_all output={output}") 18 | 19 | # 销毁进程组 20 | dist.destroy_process_group() 21 | -------------------------------------------------------------------------------- /verl_test/scripts/start_retrieval_server.sh: -------------------------------------------------------------------------------- 1 | 2 | export HF_ENDPOINT='https://hf-mirror.com' 3 | export HF_DATASETS_CACHE=/data2/zzd/cache 4 | export HF_HOME=/data2/zzd/cache 5 | 6 | save_path=/data2/zzd/data/search_r1 7 | index_file=$save_path/e5_Flat.index 8 | corpus_file=$save_path/wiki-18.jsonl 9 | retriever_name=e5 10 | retriever_path=intfloat/e5-base-v2 11 | 12 | python 
examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \ 13 | --index_path $index_file \ 14 | --corpus_path $corpus_file \ 15 | --topk 3 \ 16 | --retriever_name $retriever_name \ 17 | --retriever_model $retriever_path \ 18 | --faiss_gpu -------------------------------------------------------------------------------- /verl_test/rollout_test/search_tool_config: -------------------------------------------------------------------------------- 1 | tools: 2 | - class_name: verl.tools.search_tool.SearchTool 3 | config: 4 | retrieval_service_url: http://127.0.0.1:8000/retrieve 5 | num_workers: 120 6 | rate_limit: 120 7 | timeout: 30 8 | tool_schema: 9 | type: function 10 | function: 11 | name: search 12 | description: Searches the web for relevant information based on the given query. 13 | parameters: 14 | type: object 15 | properties: 16 | query_list: 17 | type: array 18 | item: 19 | type: string 20 | description: A list of fully-formed semantic queries. The tool will return search results for each query. 21 | required: 22 | - query_list -------------------------------------------------------------------------------- /collective_ops/test_p2pop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | # 初始化进程组 6 | dist.init_process_group("nccl") 7 | rank = dist.get_rank() 8 | world_size = dist.get_world_size() 9 | device = torch.device("cuda:{}".format(rank)) 10 | 11 | # 创建一个简单的张量 12 | tensor = torch.tensor([rank, rank+1], dtype=torch.float32, device=device) 13 | print("rank {} tensor {}".format(rank, tensor)) 14 | 15 | send_op = dist.P2POp(dist.isend, tensor, (rank + 1)%world_size) 16 | 17 | recv_tensor = torch.zeros((1,2), dtype=torch.float32, device=device) 18 | recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size) 19 | reqs = dist.batch_isend_irecv([send_op, recv_op]) 20 | 21 | for req in reqs: 22 | req.wait() 23 | print(f"Rank {rank} received: {recv_tensor}") 24 | 25 | 26 | # 销毁进程组 27 | dist.destroy_process_group() 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /collective_ops/test_all_gather_into_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | # 初始化进程组 6 | dist.init_process_group("nccl") 7 | rank = dist.get_rank() 8 | world_size = dist.get_world_size() 9 | 10 | shape = (1, 2) 11 | device = torch.device("cuda:{}".format(rank)) 12 | 13 | # 每个进程的本地张量 14 | local_tensor = torch.ones((shape[0], shape[1]), dtype=torch.int64, device=device) * rank 15 | print(f"Rank {rank}: Local tensor = {local_tensor}") 16 | 17 | # 创建一个用于存储所有张量的全局张量 18 | # 注意:全局张量的大小应为 (world_size, *local_tensor.shape) 19 | global_tensor = torch.zeros(world_size, *local_tensor.shape, dtype=local_tensor.dtype, device=device) 20 | 21 | # 使用 all_gather_into_tensor 收集所有张量 22 | dist.all_gather_into_tensor(global_tensor, local_tensor, group=None) 23 | 24 | print(f"Rank {rank}: Global tensor after all_gather = {global_tensor}") 25 | 26 | 27 | # 销毁进程组 28 | dist.destroy_process_group() -------------------------------------------------------------------------------- /collective_ops/test_reduce_scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | # 初始化进程组 6 | dist.init_process_group("nccl") 7 
| rank = dist.get_rank() 8 | world_size = dist.get_world_size() 9 | 10 | shape = (8, 2) 11 | device = torch.device("cuda:{}".format(rank)) 12 | 13 | # 每个进程的本地张量 14 | local_tensor = (torch.ones((shape[0], shape[1]), dtype=torch.int64) * rank).to(device) 15 | 16 | # 用于存储 reduce_scatter 结果的张量 17 | output_tensor = torch.zeros((shape[0] // world_size, shape[1]), dtype=torch.int64).to(device) 18 | 19 | # 打印 reduce_scatter 前的张量 20 | print(f"Rank {rank} before reduce_scatter: local_tensor={local_tensor}, output_tensor={output_tensor}") 21 | 22 | # 执行 reduce_scatter 操作 23 | dist.reduce_scatter(output_tensor, [local_tensor for _ in range(world_size)]) 24 | 25 | # 打印 reduce_scatter 后的张量 26 | print(f"Rank {rank} after reduce_scatter: output_tensor={output_tensor}") 27 | 28 | # 销毁进程组 29 | dist.destroy_process_group() -------------------------------------------------------------------------------- /collective_ops/test_all_gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | # 初始化进程组 6 | dist.init_process_group("nccl") 7 | rank = dist.get_rank() 8 | world_size = dist.get_world_size() 9 | 10 | shape = (8, 2) 11 | device = torch.device("cuda:{}".format(rank)) 12 | 13 | # 每个进程的本地张量 14 | local_tensor = (torch.ones((shape[0] // world_size, shape[1]), dtype=torch.int64) * rank).to(device) 15 | 16 | # 用于存储所有进程张量的列表 17 | output_tensor_list = [torch.zeros((shape[0] // world_size, shape[1]), dtype=torch.int64).to(device) for _ in range(world_size)] 18 | 19 | # 打印 all_gather 前的张量 20 | print(f"Rank {rank} before all_gather: local_tensor={local_tensor}, output_tensor_list={output_tensor_list}") 21 | 22 | # 执行 all_gather 操作 23 | dist.all_gather(output_tensor_list, local_tensor) 24 | 25 | # 打印 all_gather 后的张量 26 | print(f"Rank {rank} after all_gather: output_tensor_list={output_tensor_list}") 27 | 28 | # 销毁进程组 29 | dist.destroy_process_group() -------------------------------------------------------------------------------- /verl_test/megatron_test/test_megatron_sft.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | ray stop 3 | 4 | export PYTHONPATH=$PYTHONPATH:/data2/zzd/rl_llm/verl 5 | 6 | python3 -m megatron_sft_trainer \ 7 | data.train_files=/data2/zzd/data/data/bella_train.parquet \ 8 | data.val_files=/data2/zzd/data/data/bella_train.parquet \ 9 | data.multiturn.enable=true \ 10 | data.multiturn.messages_key=messages \ 11 | data.train_batch_size=64 \ 12 | data.micro_batch_size_per_gpu=4 \ 13 | data.max_length=256 \ 14 | data.truncation='right' \ 15 | model.path=/data3/ckpt/Qwen/Qwen2.5-1.5B-Instruct \ 16 | megatron.tensor_model_parallel_size=2 \ 17 | megatron.pipeline_model_parallel_size=1 \ 18 | optim.lr=1e-3 \ 19 | trainer.default_local_dir=/data2/zzd/out_test \ 20 | trainer.project_name=megatron-sft \ 21 | trainer.experiment_name=verl-qwen-megatron-sft \ 22 | trainer.logger=['console'] \ 23 | trainer.total_epochs=2 \ 24 | trainer.n_gpus_per_node=8 \ 25 | trainer.nnodes=1 \ 26 | use_remove_padding=true -------------------------------------------------------------------------------- /collective_ops/test_gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | # 初始化进程组 6 | dist.init_process_group("nccl") 7 | rank = dist.get_rank() 8 | world_size = dist.get_world_size() 9 | 10 | shape = (16, 2) 11 | device = torch.device("cuda:{}".format(rank)) 
12 | 13 | # 每个进程生成一个张量 14 | input_tensor = (torch.ones((shape[0] // world_size, shape[1]), dtype=torch.int64) * rank).to(device) 15 | 16 | # 在 root 进程中创建一个列表来接收所有张量 17 | if rank == 1: 18 | tensor_list = [torch.zeros((shape[0] // world_size, shape[1]), dtype=torch.int64).to(device) for _ in range(world_size)] 19 | else: 20 | tensor_list = None 21 | 22 | # 打印 gather 前的张量 23 | print(f"Rank {rank} before gather: input_tensor={input_tensor}") 24 | 25 | # 执行 gather 操作 26 | dist.gather(input_tensor, gather_list=tensor_list, dst=1) 27 | 28 | # 打印 gather 后的张量 29 | if rank == 1: 30 | print(f"Rank {rank} after gather: gathered_tensors={tensor_list}") 31 | else: 32 | print(f"Rank {rank} after gather: input_tensor={input_tensor}") 33 | 34 | # 销毁进程组 35 | dist.destroy_process_group() -------------------------------------------------------------------------------- /collective_ops/test_reduce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | # 初始化进程组 6 | dist.init_process_group("nccl") 7 | rank = dist.get_rank() 8 | world_size = dist.get_world_size() 9 | 10 | # 定义张量的形状和设备 11 | shape = (2, 2) 12 | device = torch.device("cuda:{}".format(rank)) 13 | 14 | # 每个进程初始化一个输入张量 15 | input_tensor = (torch.ones(shape, dtype=torch.int64) * rank).to(device) 16 | 17 | # 目标进程(rank 0)初始化一个输出张量 18 | if rank == 0: 19 | output_tensor = torch.zeros(shape, dtype=torch.int64).to(device) 20 | else: 21 | output_tensor = None 22 | 23 | # 打印 reduce 前的张量 24 | print(f"Rank {rank} before reduce: input_tensor={input_tensor}, output_tensor={output_tensor}") 25 | 26 | # 执行 reduce 操作,将所有进程的 input_tensor 求和到 rank 0 的 output_tensor 27 | dist.reduce(tensor=input_tensor, dst=0, op=dist.ReduceOp.SUM) 28 | 29 | # 打印 reduce 后的张量 30 | if rank == 0: 31 | print(f"Rank {rank} after reduce: output_tensor={input_tensor}") 32 | else: 33 | print(f"Rank {rank} after reduce: input_tensor={input_tensor}") 34 | 35 | # 销毁进程组 36 | dist.destroy_process_group() -------------------------------------------------------------------------------- /collective_ops/test_scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | # 初始化进程组 6 | dist.init_process_group("nccl") 7 | rank = dist.get_rank() 8 | world_size = dist.get_world_size() 9 | 10 | shape = (16,2) 11 | device = torch.device("cuda:{}".format(rank)) 12 | 13 | output_tensor=torch.zeros((shape[0]//world_size,shape[1]),dtype=torch.int64).to(device) 14 | 15 | # 在 root 进程中创建一个列表来发送所有张量 16 | if rank == 0: 17 | tensor_list=[(torch.ones((shape[0]//world_size,shape[1]),dtype=torch.int64)*i).to(device) for i in range(world_size)] 18 | 19 | # # 打印 scatter 前的张量 20 | if rank == 0: 21 | print(f"Rank {rank} before scatter: scattered_tensors={tensor_list}, recv_tensor={output_tensor}") 22 | else: 23 | print(f"Rank {rank} before scatter: recv_tensor={output_tensor}") 24 | 25 | if rank == 0: 26 | dist.scatter(output_tensor, scatter_list=tensor_list, src=0) 27 | else: 28 | dist.scatter(output_tensor, src=0) 29 | 30 | # # 打印 scatter 后的张量 31 | print(f"Rank {rank} after scatter: recv_tensor={output_tensor}") 32 | 33 | # 销毁进程组 34 | dist.destroy_process_group() -------------------------------------------------------------------------------- /collective_ops/test_all_to_all_single.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 
| 4 | # 初始化分布式环境 5 | dist.init_process_group(backend="nccl") 6 | rank = dist.get_rank() 7 | world_size = dist.get_world_size() 8 | 9 | device = torch.device("cuda:{}".format(rank)) 10 | 11 | # 输入张量 12 | if rank == 0: 13 | input = torch.tensor([0, 1, 2, 3, 4, 5]).to(device) 14 | input_splits = [2, 2, 1, 1] 15 | output_splits = [2, 3, 2, 2] 16 | elif rank == 1: 17 | input = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]).to(device) 18 | input_splits = [3, 2, 2, 2] 19 | output_splits = [2, 2, 1, 2] 20 | elif rank == 2: 21 | input = torch.tensor([20, 21, 22, 23, 24]).to(device) 22 | input_splits = [2, 1, 1, 1] 23 | output_splits = [1, 2, 1, 2] 24 | elif rank == 3: 25 | input = torch.tensor([30, 31, 32, 33, 34, 35, 36]).to(device) 26 | input_splits = [2, 2, 2, 1] 27 | output_splits = [1, 2, 1, 1] 28 | 29 | 30 | output = torch.empty(sum(output_splits), dtype=torch.int64).to(device) 31 | 32 | # 调用 all_to_all_single 33 | dist.all_to_all_single(output, input, output_split_sizes=output_splits, input_split_sizes=input_splits) 34 | 35 | print(f"Rank {rank}: Input = {input}, Output = {output}") 36 | 37 | # 销毁进程组 38 | dist.destroy_process_group() 39 | -------------------------------------------------------------------------------- /verl_test/test_fa_ops.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/data/zzd/verl") 3 | 4 | import torch 5 | from flash_attn.ops.triton.cross_entropy import cross_entropy_loss 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from verl.utils.debug import log_gpu_memory_usage 10 | from verl.utils.torch_functional import logprobs_from_logits_naive 11 | 12 | 13 | def test_flash_attn_cross_entropy(): 14 | log_gpu_memory_usage("At start") 15 | 16 | hidden_states = torch.randn(size=(2048, 5120), device="cuda", requires_grad=True, dtype=torch.float32) 17 | linear = nn.Linear(in_features=5120, out_features=155136, bias=False, device="cuda", dtype=torch.float32) 18 | logits = linear(hidden_states) # (2048, 155136) 19 | 20 | labels = torch.randint(low=0, high=155136, size=(2048,), device="cuda") 21 | 22 | log_gpu_memory_usage("before computation") 23 | output = cross_entropy_loss(logits, labels)[0] 24 | log_gpu_memory_usage("After forward") 25 | 26 | output.sum().backward() 27 | log_gpu_memory_usage("After backward") 28 | 29 | groundtruth = -logprobs_from_logits_naive(logits.float(), labels) 30 | torch.testing.assert_close(output, groundtruth) 31 | 32 | loss = F.cross_entropy(logits, labels) 33 | torch.testing.assert_close(loss, output.mean()) 34 | log_gpu_memory_usage("After loss") 35 | 36 | 37 | if __name__ == "__main__": 38 | test_flash_attn_cross_entropy() -------------------------------------------------------------------------------- /sequence_parallel/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | import torch.distributed as dist 5 | 6 | def set_seed(rank, seed=42): 7 | seed = rank + seed 8 | random.seed(seed) 9 | torch.manual_seed(seed) 10 | torch.cuda.manual_seed(seed) 11 | torch.cuda.manual_seed_all(seed) 12 | 13 | 14 | def log(msg, a, rank0_only=False): 15 | world_size = dist.get_world_size() 16 | rank = dist.get_rank() 17 | if rank0_only: 18 | if rank == 0: 19 | print( 20 | f"{msg}: " 21 | f"max {a.abs().max().item():.3g}, " 22 | f"mean {a.abs().mean().item():.3g}", 23 | flush=True, 24 | ) 25 | return 26 | 27 | for i in range(world_size): 28 | if i == rank: 29 | if rank == 0: 30 | 
print(f"{msg}:") 31 | print( 32 | f"Rank[{rank}] " 33 | f"max {a.abs().max().item():.3g}, " 34 | f"mean {a.abs().mean().item():.3g}", 35 | flush=True, 36 | ) 37 | dist.barrier() 38 | 39 | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): 40 | assert mode in ["fwd", "bwd", "fwd_bwd"] 41 | f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) 42 | return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) 43 | 44 | def efficiency(flop, time): 45 | return (flop / time / 10**12) if not math.isnan(time) else 0.0 46 | -------------------------------------------------------------------------------- /verl_test/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def levenshtein(s1, s2): 3 | m, n = len(s1), len(s2) 4 | # Initialize matrix of zeros 5 | dp = [[0] * (n + 1) for _ in range(m + 1)] 6 | # Initialize first column and first row of the matrix 7 | for i in range(m + 1): 8 | dp[i][0] = i # Deletion from s1 to empty string 9 | for j in range(n + 1): 10 | dp[0][j] = j # Insertion to s1 from empty string 11 | # Compute the Levenshtein distance matrix 12 | for i in range(1, m + 1): 13 | for j in range(1, n + 1): 14 | cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match 15 | dp[i][j] = min( 16 | dp[i - 1][j] + 1, # Deletion 17 | dp[i][j - 1] + 1, # Insertion 18 | dp[i - 1][j - 1] + cost, # Substitution 19 | ) 20 | return dp[m][n] 21 | 22 | 23 | def are_lists_similar(a, b): 24 | if len(a) != len(b): 25 | print("The lists are of different lengths.") 26 | return False 27 | 28 | total_length = 0 29 | total_diff = 0 30 | 31 | for s1, s2 in zip(a, b): 32 | max_len = max(len(s1), len(s2)) 33 | total_length += max_len 34 | diff = levenshtein(s1, s2) 35 | total_diff += diff 36 | print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n") 37 | 38 | percentage_difference = (total_diff / total_length) * 100 39 | print(f"Total difference: {percentage_difference:.2f}%") 40 | 41 | return percentage_difference <= 15 42 | -------------------------------------------------------------------------------- /moe_ep/deepseek/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "DeepseekV2ForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "auto_map": { 8 | "AutoConfig": "configuration_deepseek.DeepseekV2Config", 9 | "AutoModel": "modeling_deepseek.DeepseekV2Model", 10 | "AutoModelForCausalLM": "modeling_deepseek.DeepseekV2ForCausalLM" 11 | }, 12 | "aux_loss_alpha": 0.001, 13 | "bos_token_id": 100000, 14 | "eos_token_id": 100001, 15 | "ep_size": 1, 16 | "first_k_dense_replace": 1, 17 | "hidden_act": "silu", 18 | "hidden_size": 2048, 19 | "initializer_range": 0.02, 20 | "intermediate_size": 10944, 21 | "kv_lora_rank": 512, 22 | "max_position_embeddings": 4096, 23 | "model_type": "deepseek_v2", 24 | "moe_intermediate_size": 1408, 25 | "moe_layer_freq": 1, 26 | "n_group": 1, 27 | "n_routed_experts": 64, 28 | "n_shared_experts": 2, 29 | "norm_topk_prob": false, 30 | "num_attention_heads": 16, 31 | "num_experts_per_tok": 6, 32 | "num_hidden_layers": 27, 33 | "num_key_value_heads": 16, 34 | "pretraining_tp": 1, 35 | "q_lora_rank": null, 36 | "qk_nope_head_dim": 128, 37 | "qk_rope_head_dim": 64, 38 | "rms_norm_eps": 1e-06, 39 | "rope_scaling": null, 40 | "rope_theta": 10000, 41 | "routed_scaling_factor": 1.0, 42 | "scoring_func": "softmax", 43 | "seq_aux": true, 44 | "tie_word_embeddings": false, 45 | 
"topk_group": 1, 46 | "topk_method": "greedy", 47 | "torch_dtype": "bfloat16", 48 | "transformers_version": "4.33.1", 49 | "use_cache": true, 50 | "v_head_dim": 128, 51 | "vocab_size": 102400 52 | } 53 | -------------------------------------------------------------------------------- /verl_test/rollout_test/test_vllm_spmd_profile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vllm import LLM, SamplingParams 3 | 4 | # Create prompts, the same across all ranks 5 | prompts = [ 6 | "奇变偶不变", 7 | "The president of the United States is", 8 | "大鹏一日同风起,", 9 | "The future of AI is", 10 | ] 11 | 12 | # Create sampling parameters, the same across all ranks 13 | sampling_params = SamplingParams(temperature=0.8, top_p=0.95) 14 | 15 | # Use `distributed_executor_backend="external_launcher"` so that 16 | # this llm engine/instance only creates one worker. 17 | 18 | model_name = "/data3/ckpt/Qwen/Qwen2.5-3B-Instruct" 19 | llm = LLM( 20 | model=model_name, 21 | tensor_parallel_size=4, 22 | distributed_executor_backend="external_launcher", 23 | dtype="bfloat16", 24 | seed=1, 25 | ) 26 | 27 | outputs = llm.generate(prompts, sampling_params) 28 | 29 | # Use torch profiler to profile the model 30 | if True: 31 | torch.backends.cudnn.benchmark = True 32 | profiler = torch.profiler.profile( 33 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], 34 | schedule=torch.profiler.schedule(wait=5, warmup=5, active=5,), 35 | record_shapes=True, 36 | profile_memory=True, 37 | with_flops=True, 38 | with_modules=True, 39 | with_stack=True, 40 | on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiles/vllm_spmd_profile"), 41 | ) 42 | profiler.start() 43 | 44 | for _ in range(20): 45 | outputs = llm.generate(prompts, sampling_params) 46 | profiler.step() 47 | profiler.stop() 48 | 49 | 50 | # all ranks will have the same outputs 51 | for output in outputs: 52 | prompt = output.prompt 53 | generated_text = output.outputs[0].text 54 | print(f"Prompt: {prompt!r}, " 55 | f"Generated text: {generated_text!r}") -------------------------------------------------------------------------------- /verl_test/test_memory_buffers.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Test memory buffers 4 | - We start with two models with the same weights 5 | - We use Memory buffer to make one of the models and then compare the parameters 6 | """ 7 | import sys 8 | sys.path.append("/data/zzd/verl") 9 | import gc 10 | 11 | import torch 12 | from transformers import LlamaConfig, LlamaModel 13 | 14 | from verl.utils.memory_buffer import MemoryBufferModuleWrapper 15 | 16 | 17 | def test_memory_buffers(): 18 | llama_config = LlamaConfig( 19 | vocab_size=155136, 20 | hidden_size=4096, 21 | intermediate_size=11008, 22 | num_hidden_layers=2, 23 | num_attention_heads=16, 24 | num_key_value_heads=16, 25 | ) 26 | 27 | model = LlamaModel(config=llama_config).cuda() 28 | model_copy = LlamaModel(config=llama_config).cuda() 29 | model_copy.load_state_dict(model.state_dict()) 30 | 31 | model_named_params = dict(model.named_parameters()) 32 | model_copy_named_params = dict(model_copy.named_parameters()) 33 | 34 | norm_factor = 1024**3 35 | 36 | t_before = torch.cuda.get_device_properties(0).total_memory / norm_factor 37 | r_before = torch.cuda.memory_reserved(0) / norm_factor 38 | a_before = torch.cuda.memory_allocated(0) / norm_factor 39 | 40 | print(f"Before Total memory: {t_before} GB, reserved: 
{r_before} GB, allocated: {a_before} GB") 41 | 42 | model_wrapper = MemoryBufferModuleWrapper(model) 43 | 44 | t = torch.cuda.get_device_properties(0).total_memory / norm_factor 45 | r = torch.cuda.memory_reserved(0) / norm_factor 46 | a = torch.cuda.memory_allocated(0) / norm_factor 47 | 48 | gc.collect() 49 | torch.cuda.empty_cache() 50 | 51 | print(f"After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB") 52 | 53 | change_ratio = (a - a_before) / a_before 54 | assert change_ratio < 0.01, f"make sure the allocated change is less than 1%, Got {change_ratio}" 55 | 56 | for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters()): 57 | assert name1 == name2 58 | assert torch.eq(param1.data, param2.data).all(), f"{param1.data}, {param2.data}, {name1}" 59 | 60 | 61 | if __name__ == "__main__": 62 | test_memory_buffers() 63 | -------------------------------------------------------------------------------- /sequence_parallel/basic/test_flash_attn_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flash_attn import flash_attn_func 3 | from reference import attention_ref 4 | from utils import set_seed 5 | 6 | 7 | if __name__ == "__main__": 8 | 9 | device = "cuda" 10 | set_seed(rank=0, seed=32) # set seed 11 | batch_size = 4 12 | seqlen_q = seqlen_k = 1024 13 | d = 128 14 | nheads_k = nheads = 8 15 | dtype = torch.float16 16 | 17 | assert nheads % nheads_k == 0 18 | dropout_p = 0 19 | causal = True 20 | deterministic = False 21 | window_size = (-1, -1) 22 | 23 | 24 | q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) 25 | k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True) 26 | v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True) 27 | 28 | out, lse, S_dmask = flash_attn_func( 29 | q, k, v, 30 | dropout_p, 31 | causal=causal, 32 | window_size=window_size, 33 | softcap=0.0, 34 | alibi_slopes=None, 35 | deterministic=deterministic, 36 | return_attn_probs=True, 37 | ) 38 | 39 | out_ref, attn_ref = attention_ref( 40 | q, k, v, 41 | None, 42 | None, 43 | None, 44 | dropout_p, 45 | None, 46 | causal=causal, 47 | window_size=window_size, 48 | softcap=0.0, 49 | ) 50 | 51 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 52 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 53 | 54 | g = torch.randn_like(out) 55 | (dq, dk, dv,) = torch.autograd.grad(out, (q, k, v), g) 56 | (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_ref, (q, k, v), g) 57 | 58 | print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") 59 | print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") 60 | print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") 61 | print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") 62 | print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") 63 | print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") 64 | 65 | 66 | -------------------------------------------------------------------------------- /verl_test/rollout_test/test_spmd.py: -------------------------------------------------------------------------------- 1 | # 该代码需要使用 torchrun --nproc_per_node=4 test_spmd.py 来运行 2 | 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.distributed as dist 7 | 8 | dist.init_process_group(backend="nccl") 9 | torch.cuda.set_device(int(os.environ['RANK'])) 10 | world_size = dist.get_world_size() 
11 | rank = dist.get_rank() 12 | print(f"Rank {rank} of {world_size} initializing...") 13 | 14 | torch.manual_seed(42) 15 | torch.cuda.manual_seed_all(42) 16 | 17 | # 构建一个从column维度切分的linear layer 18 | class HybridParallelLayer(torch.nn.Module): 19 | def __init__(self, input_size, output_size, world_size): 20 | super().__init__() 21 | self.world_size = world_size 22 | self.layer = nn.Linear(input_size, output_size // world_size, bias=False).to(device='cuda') 23 | 24 | def forward(self, x): 25 | local_output = self.layer(x) 26 | 27 | # 跨设备收集所有分片结果 28 | output_list = [torch.empty_like(local_output) for _ in range(self.world_size)] 29 | dist.all_gather(output_list, local_output) 30 | 31 | # 沿特征维度拼接 32 | return torch.cat(output_list, dim=-1) 33 | 34 | def load_weights(self, weight, rank): 35 | dim_per_rank = weight.shape[0] // self.world_size 36 | self.layer.weight.data.copy_(weight[rank*dim_per_rank: (rank+1)*dim_per_rank, :]) 37 | 38 | 39 | # 数据准备 ----------------------------------------------------------------- 40 | batch_size_per_gpu = 16 41 | global_batch_size = batch_size_per_gpu * world_size 42 | 43 | # 生成全局数据(模拟数据加载器行为) 44 | global_input = torch.randn(global_batch_size, 128).cuda() # 由于 seed 固定,每个 rank 数据是一样的 45 | local_input = global_input.chunk(world_size, dim=0)[rank].detach().clone() 46 | 47 | full_layer = torch.nn.Linear(128, 512, bias=False).cuda() 48 | weight = full_layer.weight.data 49 | print(f"full_layer weight shape: {weight.shape}") 50 | 51 | tp_layer = HybridParallelLayer(128, 512, world_size) 52 | tp_layer.load_weights(weight, rank) 53 | 54 | tp_ret = tp_layer(global_input) # TP 输入的数据必须一样 55 | with torch.no_grad(): 56 | fl_ret = full_layer(local_input) 57 | 58 | torch.testing.assert_close(tp_ret.chunk(world_size, dim=0)[rank].detach().cpu(), 59 | fl_ret.detach().cpu(), atol=1e-3, rtol=1e-3) 60 | if rank == 0: 61 | print("✅ 前向传播验证通过") 62 | 63 | -------------------------------------------------------------------------------- /verl_test/test_ray_tp.py: -------------------------------------------------------------------------------- 1 | # 该代码只需要使用 python test_ray_tp.py 即可运行 2 | 3 | import os 4 | import socket 5 | import torch 6 | import torch.nn as nn 7 | import torch.distributed as dist 8 | import ray 9 | 10 | # 构建一个从column维度切分的linear layer 11 | class HybridParallelLayer(torch.nn.Module): 12 | def __init__(self, input_size, output_size, world_size): 13 | super().__init__() 14 | self.world_size = world_size 15 | if not dist.is_initialized(): 16 | dist.init_process_group(backend="nccl") 17 | self.layer = nn.Linear(input_size, output_size // world_size, bias=False).to(device='cuda') 18 | 19 | def forward(self, x): 20 | local_output = self.layer(x) 21 | 22 | # 跨设备收集所有分片结果 23 | output_list = [torch.empty_like(local_output) for _ in range(self.world_size)] 24 | dist.all_gather(output_list, local_output) 25 | 26 | # 沿特征维度拼接 27 | return torch.cat(output_list, dim=-1) 28 | 29 | def load_weights(self, weight, rank): 30 | dim_per_rank = weight.shape[0] // self.world_size 31 | self.layer.weight.data.copy_(weight[rank*dim_per_rank: (rank+1)*dim_per_rank, :]) 32 | 33 | 34 | ray.init() 35 | 36 | master_addr = ray._private.services.get_node_ip_address() 37 | with socket.socket() as sock: 38 | sock.bind(('', 0)) 39 | master_port = sock.getsockname()[1] 40 | 41 | num_gpus = 4 42 | workers = [] 43 | for i in range(num_gpus): 44 | options = {'runtime_env': {'env_vars': {'WORLD_SIZE': str(num_gpus), 'RANK': str(i), 'MASTER_ADDR': master_addr, 'MASTER_PORT': str(master_port)}}} 45 | 
workers.append(ray.remote(num_gpus=1)(HybridParallelLayer).options(**options).remote(128, 512, num_gpus)) 46 | 47 | batch_size = 10 48 | input_data = torch.randn(batch_size, 128).cuda() 49 | 50 | full_layer = torch.nn.Linear(128, 512, bias=False).cuda() 51 | weight = full_layer.state_dict()['weight'] 52 | 53 | ret_list = [] 54 | for i in range(num_gpus): 55 | _ = ray.get(workers[i].load_weights.remote(weight, i)) 56 | 57 | for i in range(num_gpus): 58 | ret_list.append(workers[i].forward.remote(input_data)) 59 | 60 | ret = ray.get(ret_list) 61 | ray.shutdown() 62 | 63 | fl_ret = full_layer(input_data).cpu() 64 | torch.testing.assert_close(ret[0], ret[1]) 65 | torch.testing.assert_close(ret[0].cpu(), fl_ret) 66 | 67 | print("✅ 前向传播验证通过") 68 | 69 | -------------------------------------------------------------------------------- /torch_memory_walkthrough.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/zzd/miniconda3/envs/test/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n", 14 | "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import torch\n", 20 | "from transformers import AutoModelForCausalLM\n", 21 | "\n", 22 | "# Start recording memory snapshot history\n", 23 | "torch.cuda.memory._record_memory_history(max_entries=100000)\n", 24 | "\n", 25 | "model = AutoModelForCausalLM.from_pretrained(\"/data/ckpt/Qwen/Qwen2.5-0.5B\").to(\"cuda\")\n", 26 | "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n", 27 | "inputs = torch.randint(0, 100, (10, 256), device=\"cuda\") # Dummy input\n", 28 | "\n", 29 | "for i in range(3):\n", 30 | " input = inputs[i:i+2]\n", 31 | " loss = torch.mean(model(inputs).logits) # Dummy loss\n", 32 | " loss.backward()\n", 33 | " optimizer.step()\n", 34 | " optimizer.zero_grad()\n", 35 | "\n", 36 | "# Dump memory snapshot history to a file and stop recording\n", 37 | "torch.cuda.memory._dump_snapshot(\"profile3.pkl\")\n", 38 | "torch.cuda.memory._record_memory_history(enabled=None)\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "![image.png](images/image.png)" 46 | ] 47 | } 48 | ], 49 | "metadata": { 50 | "kernelspec": { 51 | "display_name": "test", 52 | "language": "python", 53 | "name": "python3" 54 | }, 55 | "language_info": { 56 | "codemirror_mode": { 57 | "name": "ipython", 58 | "version": 3 59 | }, 60 | "file_extension": ".py", 61 | "mimetype": "text/x-python", 62 | "name": "python", 63 | "nbconvert_exporter": "python", 64 | "pygments_lexer": "ipython3", 65 | "version": "3.10.16" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 2 70 | } 71 | -------------------------------------------------------------------------------- /sequence_parallel/basic/test_flash_attn_qkvpacked_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange, repeat 4 | from flash_attn import flash_attn_qkvpacked_func 5 | from reference import attention_ref 6 | 7 | 8 | if __name__ 
== "__main__": 9 | device = "cuda" 10 | # set seed 11 | torch.random.manual_seed(0) 12 | batch_size = 4 13 | seqlen = 1024 14 | d = 128 15 | nheads_k = nheads = 8 16 | dtype = torch.float16 17 | dropout_p = 0 18 | causal = True 19 | deterministic = False 20 | window_size=(-1, -1) 21 | alibi_slopes, attn_bias = None, None 22 | dropout_mask = None 23 | 24 | assert nheads % nheads_k == 0 25 | # window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) 26 | qkv = torch.randn( 27 | batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True 28 | ) 29 | q, k, v = qkv.clone().unbind(2) 30 | 31 | out, lse, S_dmask = flash_attn_qkvpacked_func( 32 | qkv, 33 | dropout_p, 34 | causal=causal, 35 | window_size=window_size, 36 | softcap=0.0, 37 | alibi_slopes=alibi_slopes, 38 | deterministic=deterministic, 39 | return_attn_probs=True, 40 | ) 41 | 42 | out_ref, attn_ref = attention_ref( 43 | q, k, v, 44 | None, 45 | None, 46 | attn_bias, 47 | dropout_p, 48 | dropout_mask, 49 | causal=causal, 50 | window_size=window_size, 51 | softcap=0.0, 52 | ) 53 | 54 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 55 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 56 | 57 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 58 | out.backward(dout) 59 | dqkv = qkv.grad 60 | 61 | (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_ref, (q, k, v), dout) 62 | 63 | print(f"dQ max diff: {(dqkv[:,:,0] - dq_ref).abs().max().item()}") 64 | print(f"dK max diff: {(dqkv[:,:,1] - dk_ref).abs().max().item()}") 65 | print(f"dV max diff: {(dqkv[:,:,2] - dv_ref).abs().max().item()}") 66 | print(f"dQ mean diff: {(dqkv[:,:,0] - dq_ref).abs().mean().item()}") 67 | print(f"dK mean diff: {(dqkv[:,:,1] - dk_ref).abs().mean().item()}") 68 | print(f"dV mean diff: {(dqkv[:,:,2] - dv_ref).abs().mean().item()}") 69 | 70 | -------------------------------------------------------------------------------- /verl_test/scripts/train_sppo_32b.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | limo_train_path=/mnt/nvme0/zzd/data/LIMO/train.parquet 4 | limo_test_path=/mnt/nvme0/zzd/data/LIMO/test.parquet 5 | math_train_path=/mnt/nvme0/zzd/data/MATH/train.parquet 6 | math_test_path=/mnt/nvme0/zzd/data/MATH/test.parquet 7 | 8 | train_files="['$limo_train_path', '$math_train_path']" 9 | test_files="['$limo_test_path', '$math_test_path']" 10 | 11 | WORKING_DIR=${WORKING_DIR:-"${PWD}"} 12 | RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} 13 | 14 | ray job submit --address="http://10.157.150.10:8265" \ 15 | --runtime-env="${RUNTIME_ENV}" \ 16 | -- python3 -m recipe.sppo.main_sppo \ 17 | data.train_files="$train_files" \ 18 | data.val_files="$test_files" \ 19 | data.train_batch_size=1024 \ 20 | data.max_prompt_length=1024 \ 21 | data.max_response_length=1024 \ 22 | data.filter_overlong_prompts=True \ 23 | data.truncation='error' \ 24 | data.return_raw_chat=True \ 25 | actor_rollout_ref.model.path=/mnt/nvme0/zzd/ckpt/Qwen/Qwen2.5-32B \ 26 | actor_rollout_ref.actor.optim.lr=1e-6 \ 27 | actor_rollout_ref.model.use_remove_padding=True \ 28 | actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ 29 | actor_rollout_ref.actor.ppo_mini_batch_size=256 \ 30 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ 31 | actor_rollout_ref.actor.use_kl_loss=False \ 32 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 33 | actor_rollout_ref.actor.fsdp_config.param_offload=False \ 
34 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ 35 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ 36 | actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ 37 | actor_rollout_ref.rollout.name=vllm \ 38 | actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ 39 | custom_reward_function.path=/mnt/nvme0/zzd/verl/math_score.py \ 40 | algorithm.use_kl_in_reward=False \ 41 | trainer.critic_warmup=0 \ 42 | trainer.logger=['console','wandb'] \ 43 | trainer.project_name='verl-h20' \ 44 | trainer.val_before_train=True \ 45 | trainer.experiment_name='sppo_32b_0511' \ 46 | trainer.n_gpus_per_node=8 \ 47 | trainer.nnodes=2 \ 48 | trainer.save_freq=-1 \ 49 | trainer.test_freq=5 \ 50 | trainer.total_epochs=200 $@ 51 | # Note that we set lr_warmup_steps = 15 in config/sppo_trainer.yaml 52 | # The experiment will converge to 0.656 on MATH dataset after 20 epochs -------------------------------------------------------------------------------- /collective_ops/test_ring_comm_customized.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.distributed as dist 4 | 5 | 6 | class RingComm: 7 | def __init__(self, process_group: dist.ProcessGroup): 8 | self._process_group = process_group 9 | self._ops = [] 10 | self.rank = dist.get_rank(self._process_group) 11 | self.world_size = dist.get_world_size(self._process_group) 12 | self._reqs = None 13 | 14 | self.send_rank = (self.rank + 1) % self.world_size 15 | self.recv_rank = (self.rank - 1) % self.world_size 16 | 17 | if process_group is not None: 18 | self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) 19 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) 20 | 21 | def send_recv( 22 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None 23 | ) -> torch.Tensor: 24 | if recv_tensor is None: 25 | res = torch.empty_like(to_send) 26 | else: 27 | res = recv_tensor 28 | 29 | send_op = dist.P2POp( 30 | dist.isend, to_send, self.send_rank, group=self._process_group 31 | ) 32 | recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) 33 | self._ops.append(send_op) 34 | self._ops.append(recv_op) 35 | return res 36 | 37 | def commit(self): 38 | if self._reqs is not None: 39 | raise RuntimeError("commit called twice") 40 | self._reqs = dist.batch_isend_irecv(self._ops) 41 | 42 | def wait(self): 43 | if self._reqs is None: 44 | raise RuntimeError("wait called before commit") 45 | for req in self._reqs: 46 | req.wait() 47 | self._reqs = None 48 | self._ops = [] 49 | 50 | 51 | # 初始化进程组 52 | dist.init_process_group("nccl") 53 | rank = dist.get_rank() 54 | world_size = dist.get_world_size() 55 | device = torch.device("cuda:{}".format(rank)) 56 | 57 | ring_comm = RingComm(process_group=None) 58 | 59 | tensor = torch.tensor([rank, rank+1], dtype=torch.float32, device=device) 60 | print("rank {} tensor {}".format(rank, tensor)) 61 | 62 | recv_tensor = ring_comm.send_recv(tensor) 63 | # Commit the operations 64 | ring_comm.commit() 65 | # Wait for the operations to complete 66 | ring_comm.wait() 67 | 68 | print(f"Rank {rank} received: {recv_tensor}") 69 | 70 | 71 | # 销毁进程组 72 | dist.destroy_process_group() 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /verl_test/agent_test/agent_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import 
Union 3 | 4 | import ray 5 | from omegaconf import DictConfig 6 | 7 | from verl.experimental.agent_loop import AgentLoopManager 8 | from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup 9 | from verl.single_controller.ray.base import create_colocated_worker_cls 10 | from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role 11 | from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker 12 | 13 | 14 | def init_agent_loop_manager(config: DictConfig) -> Union[AgentLoopManager, RayWorkerGroup]: 15 | # =========================== 1. Create hybrid ActorRollout workers =========================== 16 | actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker 17 | role_worker_mapping = {Role.ActorRollout: ray.remote(actor_rollout_cls),} 18 | 19 | global_pool_id = "global_pool" 20 | resource_pool_spec = {global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,} 21 | mapping = {Role.ActorRollout: global_pool_id,} 22 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) 23 | resource_pool_manager.create_resource_pool() 24 | resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} 25 | 26 | # create actor and rollout 27 | resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) 28 | actor_rollout_cls = RayClassWithInitArgs(cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout") 29 | resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls 30 | 31 | all_wg = {} 32 | for resource_pool, class_dict in resource_pool_to_cls.items(): 33 | worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) 34 | wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) 35 | spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) 36 | all_wg.update(spawn_wg) 37 | actor_rollout_wg = all_wg["actor_rollout"] 38 | actor_rollout_wg.init_model() 39 | 40 | if config.actor_rollout_ref.rollout.mode == "sync": 41 | return actor_rollout_wg 42 | 43 | # =========================== 2. 
Create AgentLoopManager =========================== 44 | print("Creating AgentLoopManager...") 45 | agent_loop_manager = AgentLoopManager( 46 | config=config, 47 | worker_group=actor_rollout_wg, 48 | ) 49 | 50 | return agent_loop_manager 51 | -------------------------------------------------------------------------------- /verl_test/scripts/train_rfpp.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | limo_train_path=/mnt/nvme0/zzd/data/LIMO/train.parquet 4 | limo_test_path=/mnt/nvme0/zzd/data/LIMO/test.parquet 5 | math_train_path=/mnt/nvme0/zzd/data/MATH/train.parquet 6 | math_test_path=/mnt/nvme0/zzd/data/MATH/test.parquet 7 | 8 | train_files="['$limo_train_path', '$math_train_path']" 9 | test_files="['$limo_test_path', '$math_test_path']" 10 | 11 | WORKING_DIR=${WORKING_DIR:-"${PWD}"} 12 | RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} 13 | 14 | ray job submit --address="http://10.157.150.10:8265" \ 15 | --runtime-env="${RUNTIME_ENV}" \ 16 | --working-dir "${WORKING_DIR}" \ 17 | -- python3 -m verl.trainer.main_ppo \ 18 | algorithm.adv_estimator=reinforce_plus_plus \ 19 | data.train_files="$train_files" \ 20 | data.val_files="$test_files" \ 21 | data.train_batch_size=1024 \ 22 | data.max_prompt_length=1024 \ 23 | data.max_response_length=1024 \ 24 | data.filter_overlong_prompts=True \ 25 | data.truncation='error' \ 26 | actor_rollout_ref.model.path=/mnt/nvme0/zzd/ckpt/Qwen/Qwen2.5-32B \ 27 | actor_rollout_ref.actor.optim.lr=3e-6 \ 28 | actor_rollout_ref.model.use_remove_padding=True \ 29 | actor_rollout_ref.actor.ppo_mini_batch_size=256 \ 30 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ 31 | actor_rollout_ref.actor.use_kl_loss=True \ 32 | actor_rollout_ref.actor.kl_loss_coef=0.001 \ 33 | actor_rollout_ref.actor.kl_loss_type=mse \ 34 | actor_rollout_ref.actor.entropy_coeff=0 \ 35 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 36 | actor_rollout_ref.actor.fsdp_config.param_offload=False \ 37 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ 38 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ 39 | actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ 40 | actor_rollout_ref.rollout.name=vllm \ 41 | actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ 42 | actor_rollout_ref.rollout.n=8 \ 43 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ 44 | actor_rollout_ref.ref.fsdp_config.param_offload=False \ 45 | custom_reward_function.path=/mnt/nvme0/zzd/verl/math_score.py \ 46 | algorithm.use_kl_in_reward=False \ 47 | trainer.critic_warmup=0 \ 48 | trainer.logger=['console','wandb'] \ 49 | trainer.project_name='verl-h20' \ 50 | trainer.experiment_name='rfpp_32b_math_0429' \ 51 | trainer.val_before_train=True \ 52 | trainer.n_gpus_per_node=8 \ 53 | trainer.nnodes=2 \ 54 | trainer.save_freq=-1 \ 55 | trainer.test_freq=5 \ 56 | trainer.total_epochs=20 $@ -------------------------------------------------------------------------------- /verl_test/scripts/train_grpo_32b.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | gsm8k_train_path=/mnt/nvme0/zzd/data/GSM8K/train.parquet 4 | gsm8k_test_path=/mnt/nvme0/zzd/data/GSM8K/test.parquet 5 | math_train_path=/mnt/nvme0/zzd/data/MATH/train.parquet 6 | math_test_path=/mnt/nvme0/zzd/data/MATH/test.parquet 7 | 8 | train_files="['$gsm8k_train_path', '$math_train_path']" 9 | test_files="['$gsm8k_test_path', '$math_test_path']" 10 | 11 | 
WORKING_DIR=${WORKING_DIR:-"${PWD}"} 12 | RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} 13 | 14 | ray job submit --address="http://10.157.150.10:8265" \ 15 | --runtime-env="${RUNTIME_ENV}" \ 16 | --working-dir "${WORKING_DIR}" \ 17 | -- python3 -m verl.trainer.main_ppo \ 18 | algorithm.adv_estimator=grpo \ 19 | data.train_files="$train_files" \ 20 | data.val_files="$test_files" \ 21 | data.train_batch_size=1024 \ 22 | data.max_prompt_length=1024 \ 23 | data.max_response_length=1024 \ 24 | data.filter_overlong_prompts=True \ 25 | data.truncation='error' \ 26 | actor_rollout_ref.model.path=/mnt/nvme0/zzd/ckpt/Qwen/Qwen2.5-32B-Instruct \ 27 | actor_rollout_ref.actor.optim.lr=1e-6 \ 28 | actor_rollout_ref.model.use_remove_padding=True \ 29 | actor_rollout_ref.actor.ppo_mini_batch_size=256 \ 30 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \ 31 | actor_rollout_ref.actor.use_kl_loss=True \ 32 | actor_rollout_ref.actor.kl_loss_coef=0.001 \ 33 | actor_rollout_ref.actor.kl_loss_type=low_var_kl \ 34 | actor_rollout_ref.actor.entropy_coeff=0 \ 35 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 36 | actor_rollout_ref.actor.fsdp_config.param_offload=False \ 37 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ 38 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ 39 | actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ 40 | actor_rollout_ref.rollout.name=vllm \ 41 | actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ 42 | actor_rollout_ref.rollout.n=5 \ 43 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ 44 | actor_rollout_ref.ref.fsdp_config.param_offload=False \ 45 | custom_reward_function.path=/mnt/nvme0/zzd/verl/math_score.py \ 46 | algorithm.use_kl_in_reward=False \ 47 | trainer.critic_warmup=0 \ 48 | trainer.logger=['console','wandb'] \ 49 | trainer.project_name='verl-h20' \ 50 | trainer.experiment_name='grpo_math_32b_0427' \ 51 | trainer.val_before_train=True \ 52 | trainer.n_gpus_per_node=8 \ 53 | trainer.nnodes=2 \ 54 | trainer.save_freq=-1 \ 55 | trainer.test_freq=5 \ 56 | trainer.total_epochs=10 $@ -------------------------------------------------------------------------------- /sequence_parallel/ring_flash_attention/test_ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn import flash_attn_qkvpacked_func 4 | from ring_flash_attention.ring_flash_attn import ring_flash_attn_qkvpacked_func 5 | from utils import log, set_seed 6 | 7 | 8 | if __name__ == "__main__": 9 | dist.init_process_group("nccl") 10 | rank = dist.get_rank() 11 | set_seed(rank) 12 | world_size = dist.get_world_size() 13 | dtype = torch.bfloat16 14 | device = torch.device(f"cuda:{rank}") 15 | 16 | batch_size = 1 17 | seqlen = 3816 18 | nheads = 5 19 | d = 128 20 | dropout_p = 0 21 | causal = True 22 | deterministic = False 23 | 24 | assert seqlen % world_size == 0 25 | assert d % 8 == 0 26 | 27 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 28 | dist.broadcast(qkv, src=0) 29 | 30 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 31 | dist.broadcast(dout, src=0) 32 | 33 | local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone() 34 | local_qkv.requires_grad = True 35 | local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() 36 | 37 | dist.barrier() 38 | if rank == 0: 39 | print("#" * 30) 40 | print("# 
forward:") 41 | print("#" * 30) 42 | 43 | out, lse, _ = flash_attn_qkvpacked_func( 44 | qkv, 45 | dropout_p=dropout_p, 46 | causal=causal, 47 | window_size=(-1, -1), 48 | alibi_slopes=None, 49 | deterministic=deterministic, 50 | return_attn_probs=True, 51 | ) 52 | 53 | local_out = out.chunk(world_size, dim=1)[rank] 54 | local_lse = lse.chunk(world_size, dim=-1)[rank] 55 | 56 | ring_out, ring_lse, _ = ring_flash_attn_qkvpacked_func( 57 | local_qkv, 58 | dropout_p=dropout_p, 59 | causal=causal, 60 | window_size=(-1, -1), 61 | alibi_slopes=None, 62 | deterministic=deterministic, 63 | return_attn_probs=True, 64 | ) 65 | 66 | log("out diff", local_out - ring_out) 67 | log("lse diff", local_lse - ring_lse) 68 | 69 | dist.barrier() 70 | if rank == 0: 71 | print("#" * 30) 72 | print("# backward:") 73 | print("#" * 30) 74 | 75 | out.backward(dout) 76 | dqkv = qkv.grad 77 | local_dqkv = dqkv.chunk(world_size, dim=1)[rank] 78 | 79 | ring_out.backward(local_dout) 80 | ring_dqkv = local_qkv.grad 81 | 82 | log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0]) 83 | log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) 84 | log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) 85 | 86 | -------------------------------------------------------------------------------- /verl_test/scripts/train_ppo_14b.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | limo_train_path=/mnt/nvme0/zzd/data/LIMO/train.parquet 4 | limo_test_path=/mnt/nvme0/zzd/data/LIMO/test.parquet 5 | math_train_path=/mnt/nvme0/zzd/data/MATH/train.parquet 6 | math_test_path=/mnt/nvme0/zzd/data/MATH/test.parquet 7 | 8 | train_files="['$limo_train_path', '$math_train_path']" 9 | test_files="['$limo_test_path', '$math_test_path']" 10 | 11 | WORKING_DIR=${WORKING_DIR:-"${PWD}"} 12 | RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} 13 | 14 | ray job submit --address="http://10.157.150.10:8265" \ 15 | --runtime-env="${RUNTIME_ENV}" \ 16 | --working-dir "${WORKING_DIR}" \ 17 | -- python3 -m verl.trainer.main_ppo \ 18 | algorithm.adv_estimator=gae \ 19 | data.train_files="$train_files" \ 20 | data.val_files="$test_files" \ 21 | data.train_batch_size=1024 \ 22 | data.max_prompt_length=1024 \ 23 | data.max_response_length=1024 \ 24 | data.filter_overlong_prompts=True \ 25 | data.truncation='error' \ 26 | actor_rollout_ref.model.path=/mnt/nvme0/zzd/ckpt/Qwen/Qwen2.5-14B-Instruct \ 27 | actor_rollout_ref.actor.optim.lr=1e-6 \ 28 | actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ 29 | actor_rollout_ref.model.use_remove_padding=True \ 30 | actor_rollout_ref.actor.ppo_mini_batch_size=128 \ 31 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ 32 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 33 | actor_rollout_ref.actor.fsdp_config.param_offload=False \ 34 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ 35 | actor_rollout_ref.actor.use_kl_loss=False \ 36 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ 37 | actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ 38 | actor_rollout_ref.rollout.name=vllm \ 39 | actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ 40 | critic.optim.lr=1e-5 \ 41 | critic.model.use_remove_padding=True \ 42 | critic.model.path=/mnt/nvme0/zzd/ckpt/Qwen/Qwen2.5-14B-Instruct \ 43 | critic.model.enable_gradient_checkpointing=False \ 44 | critic.ppo_micro_batch_size_per_gpu=4 \ 45 | critic.model.fsdp_config.param_offload=False \ 46 | critic.model.fsdp_config.optimizer_offload=True \ 47 | 
custom_reward_function.path=/mnt/nvme0/zzd/verl/math_score.py \ 48 | algorithm.use_kl_in_reward=False \ 49 | trainer.critic_warmup=0 \ 50 | trainer.logger=['console','wandb'] \ 51 | trainer.project_name='verl-h20' \ 52 | trainer.experiment_name='ppo_math_14b_0429' \ 53 | trainer.val_before_train=False \ 54 | trainer.n_gpus_per_node=8 \ 55 | trainer.nnodes=2 \ 56 | trainer.save_freq=-1 \ 57 | trainer.test_freq=5 \ 58 | trainer.total_epochs=15 $@ -------------------------------------------------------------------------------- /verl_test/test_rm_worker_fsdp.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.append("/data2/zzd/rl_llm/verl") 4 | import pdb 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | import ray 9 | import torch 10 | ray.init( 11 | runtime_env={ 12 | "working_dir": "/data2/zzd/rl_llm/verl", # working directory (will be uploaded to the cluster) 13 | } 14 | ) 15 | 16 | def get_config(): 17 | from omegaconf import OmegaConf 18 | config = OmegaConf.load("/data2/zzd/rl_llm/verl/verl/trainer/config/ppo_trainer.yaml") 19 | config.reward_model.model.input_tokenizer=None 20 | config.reward_model.model.path="/data3/ckpt/sfairXC/FsfairX-LLaMA3-RM-v0.1" 21 | config.reward_model.micro_batch_size_per_gpu=4 22 | return config 23 | 24 | config = get_config() 25 | 26 | 27 | def gen_test_batch_data(): 28 | from verl import DataProto 29 | batch_dict = { 30 | "input_ids": torch.randint(0, 100, (16, 128)), 31 | "attention_mask": torch.ones((16, 128)), 32 | "position_ids": torch.arange(128).expand(16, -1), 33 | "responses": torch.randint(0, 100, (16, 32)), 34 | } 35 | batch = DataProto.from_single_dict(batch_dict) 36 | return batch 37 | 38 | 39 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool 40 | from verl.workers.fsdp_workers import RewardModelWorker 41 | from verl.single_controller.ray import RayWorkerGroup 42 | 43 | rm_worker = ray.remote(RewardModelWorker) 44 | 45 | resource_pool = RayResourcePool([8], use_gpu=True, max_colocate_count=1) 46 | rm_cls = RayClassWithInitArgs(rm_worker, config=config.reward_model) 47 | rm_worker_group = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=rm_cls, name_prefix='rm') 48 | 49 | print("world size:", rm_worker_group.world_size) 50 | print("worker_names:", rm_worker_group.worker_names) 51 | rm_worker_group.init_model() 52 | 53 | batch = gen_test_batch_data() 54 | reward_tensor = rm_worker_group.compute_rm_score(batch) 55 | reward_tensor.to("cpu") 56 | print(reward_tensor.batch['rm_scores']) 57 | 58 | # Use torch profiler to profile the model 59 | if False: 60 | torch.backends.cudnn.benchmark = True 61 | profiler = torch.profiler.profile( 62 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], 63 | schedule=torch.profiler.schedule(wait=5, warmup=5, active=5,), 64 | record_shapes=True, 65 | profile_memory=True, 66 | with_flops=True, 67 | with_modules=True, 68 | with_stack=True, 69 | on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiles/rm_worker_profile"), 70 | ) 71 | profiler.start() 72 | 73 | for _ in range(20): 74 | reward_tensor = rm_worker_group.compute_rm_score(batch) 75 | profiler.step() 76 | profiler.stop() 77 | 78 | ray.shutdown() -------------------------------------------------------------------------------- /sequence_parallel/ring_flash_attention/test_stripe_flash_attn_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as
dist 3 | from flash_attn import flash_attn_qkvpacked_func 4 | from ring_flash_attention.stripe_flash_attn import stripe_flash_attn_qkvpacked_func, extract_local 5 | from utils import log, set_seed 6 | 7 | 8 | if __name__ == "__main__": 9 | dist.init_process_group("nccl") 10 | rank = dist.get_rank() 11 | set_seed(rank) 12 | world_size = dist.get_world_size() 13 | dtype = torch.bfloat16 14 | device = torch.device(f"cuda:{rank}") 15 | 16 | batch_size = 1 17 | seqlen = 3824 18 | nheads = 5 19 | d = 128 20 | dropout_p = 0 21 | causal = True 22 | deterministic = False 23 | 24 | assert causal 25 | assert seqlen % (2 * world_size) == 0 26 | assert d % 8 == 0 27 | 28 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 29 | dist.broadcast(qkv, src=0) 30 | 31 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 32 | dist.broadcast(dout, src=0) 33 | 34 | local_qkv = extract_local(qkv, rank, world_size).detach().clone() 35 | local_qkv.requires_grad = True 36 | local_dout = extract_local(dout, rank, world_size).detach().clone() 37 | 38 | dist.barrier() 39 | if rank == 0: 40 | print("#" * 30) 41 | print("# forward:") 42 | print("#" * 30) 43 | 44 | out, lse, _ = flash_attn_qkvpacked_func( 45 | qkv, 46 | dropout_p=dropout_p, 47 | causal=causal, 48 | window_size=(-1, -1), 49 | alibi_slopes=None, 50 | deterministic=deterministic, 51 | return_attn_probs=True, 52 | ) 53 | 54 | local_out = extract_local(out, rank, world_size) 55 | local_lse = extract_local(lse, rank, world_size, dim=2) 56 | 57 | ring_out, ring_lse, _ = stripe_flash_attn_qkvpacked_func( 58 | local_qkv, 59 | dropout_p=dropout_p, 60 | causal=causal, 61 | window_size=(-1, -1), 62 | alibi_slopes=None, 63 | deterministic=deterministic, 64 | return_attn_probs=True, 65 | ) 66 | 67 | log("out", out, rank0_only=True) 68 | log("lse", lse, rank0_only=True) 69 | log("out diff", local_out - ring_out) 70 | log("lse diff", local_lse - ring_lse) 71 | 72 | dist.barrier() 73 | if rank == 0: 74 | print("#" * 30) 75 | print("# backward:") 76 | print("#" * 30) 77 | 78 | out.backward(dout) 79 | dqkv = qkv.grad 80 | 81 | local_dqkv = extract_local(dqkv, rank, world_size) 82 | 83 | ring_out.backward(local_dout) 84 | ring_dqkv = local_qkv.grad 85 | 86 | log("local_dqkv", local_dqkv) 87 | log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0]) 88 | log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) 89 | log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) 90 | -------------------------------------------------------------------------------- /verl_test/scripts/run_qwen_search.sh: -------------------------------------------------------------------------------- 1 | # run on 8xH20 2 | # make sure your current working directory is the root of the project 3 | 4 | set -x 5 | 6 | ulimit -n 65535 7 | 8 | export TENSORBOARD_DIR="tensorboard_log_0711" 9 | 10 | PROJECT_DIR="$(pwd)" 11 | CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" 12 | 13 | TRAIN_DATA="/data2/zzd/data/searchR1_processed_direct/train.parquet" 14 | VAL_DATA="/data2/zzd/data/searchR1_processed_direct/test.parquet" 15 | 16 | TOOL_CONFIG="$CONFIG_PATH/tool_config/search_tool_config.yaml" 17 | 18 | python3 -m verl.trainer.main_ppo \ 19 | --config-path="$CONFIG_PATH" \ 20 | --config-name='search_multiturn_grpo' \ 21 | algorithm.adv_estimator=grpo \ 22 | data.train_batch_size=512 \ 23 | data.val_batch_size=256 \ 24 | data.max_prompt_length=2048 \ 25 | data.max_response_length=1024 \ 26 | data.filter_overlong_prompts=True \ 27 | 
data.truncation='error' \ 28 | data.return_raw_chat=True \ 29 | actor_rollout_ref.model.path=/data3/ckpt/Qwen/Qwen2.5-1.5B-Instruct \ 30 | actor_rollout_ref.actor.optim.lr=1e-6 \ 31 | actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \ 32 | actor_rollout_ref.model.use_remove_padding=True \ 33 | actor_rollout_ref.actor.ppo_mini_batch_size=128 \ 34 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ 35 | actor_rollout_ref.actor.use_kl_loss=True \ 36 | actor_rollout_ref.actor.kl_loss_coef=0.001 \ 37 | actor_rollout_ref.actor.kl_loss_type=low_var_kl \ 38 | actor_rollout_ref.actor.entropy_coeff=0 \ 39 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 40 | actor_rollout_ref.actor.fsdp_config.param_offload=True \ 41 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ 42 | actor_rollout_ref.rollout.max_model_len=15000 \ 43 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ 44 | actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ 45 | actor_rollout_ref.rollout.name=sglang \ 46 | actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ 47 | actor_rollout_ref.rollout.n=5 \ 48 | actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \ 49 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ 50 | actor_rollout_ref.ref.fsdp_config.param_offload=True \ 51 | algorithm.use_kl_in_reward=False \ 52 | trainer.critic_warmup=0 \ 53 | trainer.val_before_train=False \ 54 | trainer.logger=['console','tensorboard'] \ 55 | trainer.project_name='search_r1_like_async_rl' \ 56 | trainer.experiment_name='qwen2.5-1.5b-search-0711' \ 57 | trainer.n_gpus_per_node=8 \ 58 | trainer.nnodes=1 \ 59 | trainer.save_freq=100 \ 60 | trainer.test_freq=50 \ 61 | data.train_files="$TRAIN_DATA" \ 62 | data.val_files="$VAL_DATA" \ 63 | actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \ 64 | trainer.total_epochs=1 65 | 66 | -------------------------------------------------------------------------------- /rlhf/reward_verl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "433a71d8", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "Advantages: tensor([[ 0., -1., -3.],\n", 14 | " [ -6., -10., 0.]])\n", 15 | "Returns: tensor([[6., 5., 3.],\n", 16 | " [9., 5., 0.]])\n" 17 | ] 18 | } 19 | ], 20 | "source": [ 21 | "import torch\n", 22 | "\n", 23 | "\n", 24 | "def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor,\n", 25 | " response_mask: torch.Tensor):\n", 26 | " \"\"\"\n", 27 | " Compute advantage for ReMax, operating only on Outcome reward \n", 28 | " This implementation is based on the paper: https://arxiv.org/abs/2310.10505\n", 29 | "\n", 30 | " (with only one scalar reward for each response).\n", 31 | " Args:\n", 32 | " token_level_rewards: `(torch.Tensor)`\n", 33 | " shape: (bs, response_length)\n", 34 | " reward_baselines: `(torch.Tensor)`\n", 35 | " shape: (bs,)\n", 36 | " response_mask: `(torch.Tensor)`\n", 37 | " shape: (bs, response_length)\n", 38 | " \n", 39 | " Returns:\n", 40 | " advantages: `(torch.Tensor)`\n", 41 | " shape: (bs, response_length)\n", 42 | " Returns: `(torch.Tensor)`\n", 43 | " shape: (bs, response_length)\n", 44 | " \"\"\"\n", 45 | "\n", 46 | " with torch.no_grad():\n", 47 | " returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])\n", 48 | " 
advantages = returns - reward_baselines.unsqueeze(-1) * response_mask\n", 49 | "\n", 50 | " return advantages, returns\n", 51 | "\n", 52 | "token_level_rewards = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n", 53 | "reward_baselines = torch.tensor([6.0, 15.0])\n", 54 | "response_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])\n", 55 | "\n", 56 | "advantages, returns = compute_remax_outcome_advantage(token_level_rewards, reward_baselines, response_mask)\n", 57 | "print(\"Advantages:\", advantages)\n", 58 | "print(\"Returns:\", returns)\n" 59 | ] 60 | } 61 | ], 62 | "metadata": { 63 | "kernelspec": { 64 | "display_name": "rl-env", 65 | "language": "python", 66 | "name": "python3" 67 | }, 68 | "language_info": { 69 | "codemirror_mode": { 70 | "name": "ipython", 71 | "version": 3 72 | }, 73 | "file_extension": ".py", 74 | "mimetype": "text/x-python", 75 | "name": "python", 76 | "nbconvert_exporter": "python", 77 | "pygments_lexer": "ipython3", 78 | "version": "3.10.16" 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 5 83 | } 84 | -------------------------------------------------------------------------------- /sequence_parallel/ring_flash_attention/test_zigzag_ring_flash_attn_func.py: -------------------------------------------------------------------------------- 1 | from flash_attn import flash_attn_qkvpacked_func 2 | import torch 3 | import torch.distributed as dist 4 | from ring_flash_attention.zigzag_ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func, extract_local 5 | from utils import log, set_seed 6 | 7 | 8 | if __name__ == "__main__": 9 | dist.init_process_group("nccl") 10 | rank = dist.get_rank() 11 | set_seed(rank) 12 | world_size = dist.get_world_size() 13 | dtype = torch.bfloat16 14 | device = torch.device(f"cuda:{rank}") 15 | 16 | batch_size = 1 17 | seqlen = 3824 18 | nheads = 5 19 | d = 128 20 | dropout_p = 0 21 | causal = True 22 | deterministic = False 23 | 24 | assert causal 25 | assert seqlen % (2 * world_size) == 0 26 | assert d % 8 == 0 27 | 28 | qkv = torch.randn( 29 | batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True 30 | ) 31 | dist.broadcast(qkv, src=0) 32 | 33 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 34 | dist.broadcast(dout, src=0) 35 | 36 | local_qkv = extract_local(qkv, rank, world_size).detach().clone() 37 | local_qkv.requires_grad = True 38 | local_dout = extract_local(dout, rank, world_size).detach().clone() 39 | 40 | dist.barrier() 41 | if rank == 0: 42 | print("#" * 30) 43 | print("# forward:") 44 | print("#" * 30) 45 | 46 | out, lse, _ = flash_attn_qkvpacked_func( 47 | qkv, 48 | dropout_p=dropout_p, 49 | causal=causal, 50 | window_size=(-1, -1), 51 | alibi_slopes=None, 52 | deterministic=deterministic, 53 | return_attn_probs=True, 54 | ) 55 | 56 | local_out = extract_local(out, rank, world_size) 57 | local_lse = extract_local(lse, rank, world_size, dim=2) 58 | 59 | ring_out, ring_lse, _ = zigzag_ring_flash_attn_qkvpacked_func( 60 | local_qkv, 61 | dropout_p=dropout_p, 62 | causal=causal, 63 | window_size=(-1, -1), 64 | alibi_slopes=None, 65 | deterministic=deterministic, 66 | return_attn_probs=True, 67 | ) 68 | 69 | log("out", out, rank0_only=True) 70 | log("lse", lse, rank0_only=True) 71 | log("out diff", local_out - ring_out) 72 | log("lse diff", local_lse - ring_lse) 73 | 74 | dist.barrier() 75 | if rank == 0: 76 | print("#" * 30) 77 | print("# backward:") 78 | print("#" * 30) 79 | 80 | out.backward(dout) 81 | dqkv = qkv.grad 82 | 83 | local_dqkv = 
extract_local(dqkv, rank, world_size) 84 | 85 | ring_out.backward(local_dout) 86 | ring_dqkv = local_qkv.grad 87 | 88 | log("local_dqkv", local_dqkv) 89 | log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0]) 90 | log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1]) 91 | log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2]) 92 | -------------------------------------------------------------------------------- /sequence_parallel/basic/test_flash_attn_varlen_qkvpacked_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | from flash_attn import flash_attn_varlen_qkvpacked_func 4 | from flash_attn.bert_padding import pad_input, unpad_input 5 | from reference import attention_ref 6 | 7 | 8 | if __name__ == "__main__": 9 | device = "cuda" 10 | # set seed 11 | torch.random.manual_seed(0) 12 | batch_size = 4 13 | seqlen = 1024 14 | d = 128 15 | nheads_k = nheads = 8 16 | dtype = torch.float16 17 | 18 | assert nheads % nheads_k == 0 19 | # window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) 20 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 21 | q, k, v = qkv.clone().unbind(2) 22 | 23 | random_lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device) 24 | key_padding_mask = (repeat(torch.arange(seqlen, device=device), "s -> b s", b=batch_size) < random_lengths) 25 | query_padding_mask = key_padding_mask.clone() 26 | 27 | q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) 28 | k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) 29 | v_unpad, _, _, _ = unpad_input(v, key_padding_mask) 30 | qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1).detach().requires_grad_() 31 | qkv = torch.stack([q, k, v], dim=2).detach().requires_grad_() 32 | 33 | output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen) 34 | dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen) 35 | 36 | cu_seqlens = cu_seqlens_q 37 | max_seqlen = max_seqlen_q 38 | 39 | out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func( 40 | qkv_unpad, 41 | cu_seqlens, 42 | max_seqlen, 43 | return_attn_probs=True, 44 | ) 45 | out = output_pad_fn(out_unpad) 46 | 47 | out_ref, attn_ref = attention_ref( 48 | q, k, v, 49 | query_padding_mask, 50 | key_padding_mask 51 | ) 52 | 53 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 54 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 55 | 56 | g = torch.randn_like(out) 57 | 58 | (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) 59 | dqkv = dqkv_pad_fn(dqkv_unpad) 60 | 61 | (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_ref, (q, k, v), g) 62 | 63 | print(f"dQ max diff: {(dqkv[:,:,0] - dq_ref).abs().max().item()}") 64 | print(f"dK max diff: {(dqkv[:,:,1] - dk_ref).abs().max().item()}") 65 | print(f"dV max diff: {(dqkv[:,:,2] - dv_ref).abs().max().item()}") 66 | print(f"dQ mean diff: {(dqkv[:,:,0] - dq_ref).abs().mean().item()}") 67 | print(f"dK mean diff: {(dqkv[:,:,1] - dk_ref).abs().mean().item()}") 68 | print(f"dV mean diff: {(dqkv[:,:,2] - dv_ref).abs().mean().item()}") 69 | 70 | -------------------------------------------------------------------------------- /verl_test/megatron_test/megatron_sft.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_batch_size: 256 3 | 
val_batch_size: 8 4 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu 5 | micro_batch_size_per_gpu: 4 # this is also val batch size 6 | train_files: ~/data/gsm8k/train.parquet 7 | val_files: ~/data/gsm8k/test.parquet 8 | # Single-turn settings 9 | prompt_key: question 10 | response_key: answer 11 | prompt_dict_keys: ['question'] 12 | response_dict_keys: ['answer'] 13 | # Multi-turn settings 14 | multiturn: 15 | enable: false # Set to true to use multi-turn dataset 16 | messages_key: messages # Key for messages list in multi-turn mode 17 | max_length: 1024 18 | truncation: error 19 | balance_dp_token: False 20 | chat_template: null 21 | custom_cls: 22 | path: null 23 | name: null 24 | shuffle: False 25 | model: 26 | path: ~/models/gemma-1.1-7b-it 27 | override_config: 28 | model_config: {} 29 | moe_config: 30 | freeze_moe_router: False 31 | external_lib: null 32 | enable_gradient_checkpointing: False 33 | trust_remote_code: False 34 | lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) 35 | lora_alpha: 16 # LoRA scaling factor 36 | target_modules: all-linear # Target modules for LoRA adaptation 37 | use_liger: False 38 | optim: 39 | lr: 1e-5 40 | betas: [0.9, 0.95] 41 | weight_decay: 0.01 42 | warmup_steps_ratio: 0.1 43 | clip_grad: 1.0 44 | lr_scheduler: cosine 45 | megatron: 46 | param_offload: False 47 | grad_offload: False 48 | optimizer_offload: False 49 | tensor_model_parallel_size: 1 50 | expert_model_parallel_size: 1 51 | expert_tensor_parallel_size: null 52 | pipeline_model_parallel_size: 1 53 | virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests 54 | context_parallel_size: 1 55 | sequence_parallel: True 56 | use_distributed_optimizer: True 57 | use_dist_checkpointing: False 58 | dist_checkpointing_path: null 59 | seed: 42 60 | override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage 61 | ulysses_sequence_parallel_size: 1 62 | data_loader_seed: null 63 | use_remove_padding: False 64 | checkpoint: 65 | contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space 66 | trainer: 67 | default_local_dir: /tmp/sft_model 68 | default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here 69 | resume_path: null 70 | project_name: megatron-sft 71 | experiment_name: test 72 | total_epochs: 4 73 | total_training_steps: null 74 | logger: ['console'] 75 | seed: 1 76 | nnodes: 1 77 | n_gpus_per_node: 8 78 | ray_wait_register_center_timeout: 300 79 | device: cuda 80 | ray_init: 81 | num_cpus: null # `None` means using all CPUs 82 | profile: # profile the actor model in `update_policy` 83 | use_profile: False # open it when you want to profile the actor model 84 | profile_ranks: null # list, you can specify the ranks to profile 85 | step_start: -1 # start step in update_policy 86 | step_end: -1 # end step 87 | save_path: null # the path to save the profile result -------------------------------------------------------------------------------- /moe_ep/test_ep_torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.distributed as dist 10 | from deepseek.configuration_deepseek import DeepseekV2Config 11 | from 
deepseek.modeling_deepseek import DeepseekV2MoE 12 | 13 | 14 | 15 | def init_parallel_groups(ep_size=1): 16 | dist.init_process_group("nccl") 17 | # world_size = int(os.getenv("WORLD_SIZE", "0")) 18 | # local_rank = int(os.getenv("LOCAL_RANK", "0")) 19 | local_rank = dist.get_rank() 20 | world_size = dist.get_world_size() 21 | torch.cuda.set_device(local_rank) 22 | ep_group = edp_group = None 23 | for i in range(0, world_size, ep_size): 24 | ranks = list(range(i, i + ep_size)) 25 | group = dist.new_group(ranks) 26 | if local_rank in ranks: 27 | ep_group = group 28 | edp_group = None 29 | for i in range(ep_size): 30 | ranks = list(range(i, world_size, ep_size)) 31 | group = dist.new_group(ranks) 32 | if local_rank in ranks: 33 | edp_group = group 34 | dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group) 35 | dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group) 36 | return world_size, local_rank, ep_group, edp_group 37 | 38 | 39 | if __name__ == "__main__": 40 | ep_size = 8 41 | batch_size = 8 42 | seq_length = 4096 43 | only_forward = False 44 | 45 | world_size, local_rank, ep_group, edp_group = init_parallel_groups(ep_size) 46 | device = torch.device(f"cuda:{local_rank}") 47 | torch.cuda.set_device(device) 48 | with open("deepseek/config.json", "r") as f: 49 | config_dict = json.load(f) 50 | 51 | config = DeepseekV2Config(**config_dict) 52 | config.ep_size = ep_size 53 | dsv2_moe = DeepseekV2MoE(config).to(device) 54 | 55 | input_x = torch.randn(batch_size, seq_length, config.hidden_size).to(device) 56 | dout = torch.randn(batch_size, seq_length, config.hidden_size).to(device) 57 | 58 | moe_outputs = dsv2_moe(input_x) # warmup 59 | moe_outputs.backward(dout) 60 | 61 | torch.backends.cudnn.benchmark = True 62 | profiler = torch.profiler.profile( 63 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], 64 | schedule=torch.profiler.schedule(wait=1, warmup=1, active=1,), 65 | record_shapes=True, 66 | profile_memory=True, 67 | with_flops=True, 68 | with_modules=True, 69 | with_stack=True, 70 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 71 | f"./profiles/dsv2moe_ep_bs_{batch_size}_seq_{seq_length}_rank_{local_rank}_{'fwd' if only_forward else 'fwd_bwd'}" 72 | ), 73 | ) 74 | profiler.start() 75 | 76 | begin = torch.cuda.Event(enable_timing=True) 77 | begin.record() 78 | 79 | for i in range(8): 80 | if only_forward: 81 | with torch.no_grad(): 82 | moe_outputs = dsv2_moe(input_x) 83 | else: 84 | moe_outputs = dsv2_moe(input_x) 85 | moe_outputs.backward(dout) 86 | profiler.step() 87 | 88 | end = torch.cuda.Event(enable_timing=True) 89 | end.record() 90 | torch.cuda.synchronize(device=device) 91 | profiler.stop() 92 | dist.destroy_process_group() 93 | -------------------------------------------------------------------------------- /sequence_parallel/basic/test_flash_attn_qkvpacked_func_dist.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.distributed as dist 4 | from flash_attn import flash_attn_qkvpacked_func 5 | from utils import set_seed, log 6 | from reference import attention_ref 7 | 8 | 9 | if __name__ == "__main__": 10 | dist.init_process_group("nccl") 11 | rank = dist.get_rank() 12 | set_seed(rank) 13 | world_size = dist.get_world_size() 14 | dtype = torch.bfloat16 15 | device = torch.device(f"cuda:{rank}") 16 | 17 | batch_size = 1 18 | seqlen = 3816 19 | nheads = 5 20 | d = 128 21 | dropout_p = 0 22 | causal = True 23 | deterministic 
= False 24 | 25 | assert seqlen % world_size == 0 26 | assert d % 8 == 0 27 | 28 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 29 | dist.broadcast(qkv, src=0) 30 | 31 | q, k, v = qkv.clone().unbind(2) 32 | q, k, v = q.contiguous(), k.contiguous(), v.contiguous() 33 | dist.broadcast(q, src=0) 34 | dist.broadcast(k, src=0) 35 | dist.broadcast(v, src=0) 36 | 37 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 38 | dist.broadcast(dout, src=0) 39 | 40 | local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone() 41 | local_qkv.requires_grad = True 42 | local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() 43 | 44 | dist.barrier() 45 | if rank == 0: 46 | print("#" * 30) 47 | print("# forward:") 48 | print("#" * 30) 49 | 50 | out, lse, _ = flash_attn_qkvpacked_func( 51 | qkv, 52 | dropout_p=dropout_p, 53 | causal=causal, 54 | window_size=(-1, -1), 55 | alibi_slopes=None, 56 | deterministic=deterministic, 57 | return_attn_probs=True, 58 | ) 59 | 60 | local_out = out.chunk(world_size, dim=1)[rank].detach().clone() 61 | local_lse = lse.chunk(world_size, dim=-1)[rank] 62 | 63 | 64 | out_pt_ref, attn_pt_ref = attention_ref( 65 | q, 66 | k, 67 | v, 68 | None, 69 | None, 70 | None, 71 | dropout_p, 72 | dropout_mask=None, 73 | causal=causal, 74 | window_size=(-1, -1), 75 | ) 76 | local_out_pt_ref = out_pt_ref.chunk(world_size, dim=1)[rank].detach().clone() 77 | dist.barrier() 78 | 79 | print(f'rank {rank} out (distributed) - out_ref (non-distributed) diff: {(local_out - local_out_pt_ref).abs().max().item()}') 80 | 81 | if rank == 0: 82 | print("#" * 30) 83 | print("# backward:") 84 | print("#" * 30) 85 | 86 | out.backward(dout) 87 | dqkv = qkv.grad 88 | local_dqkv = dqkv.chunk(world_size, dim=1)[rank].detach().clone() 89 | 90 | (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_pt_ref, (q, k, v), dout) 91 | local_dq_ref = dq_ref.chunk(world_size, dim=1)[rank].detach().clone() 92 | local_dk_ref = dk_ref.chunk(world_size, dim=1)[rank].detach().clone() 93 | local_dv_ref = dv_ref.chunk(world_size, dim=1)[rank].detach().clone() 94 | 95 | dist.barrier() 96 | 97 | log("dq diff", local_dqkv[:, :, 0] - local_dq_ref) 98 | log("dk diff", local_dqkv[:, :, 1] - local_dk_ref) 99 | log("dv diff", local_dqkv[:, :, 2] - local_dv_ref) 100 | 101 | if dist.is_initialized(): 102 | dist.destroy_process_group() 103 | -------------------------------------------------------------------------------- /verl_test/rollout_test/test_sglang_async_rollout_without_tools.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | torchrun --standalone --nnodes=1 --nproc_per_node=2 test_sglang_async_rollout_without_tools.py 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | from tensordict import TensorDict 9 | from torch.distributed.device_mesh import init_device_mesh 10 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 11 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 12 | from utils_sglang import ( 13 | are_lists_similar, 14 | clean_torchelastic_env, 15 | generate_hf_output, 16 | get_rollout_config, 17 | initialize_global_process_group, 18 | load_tokenizer_and_model, 19 | prepare_inputs, 20 | ) 21 | 22 | from verl import DataProto 23 | from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout 24 | from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager 25 | 26 | 27 | def 
test_async_sglang_rollout_without_tool(): 28 | assert torch.cuda.device_count() >= 2 29 | initialize_global_process_group() 30 | clean_torchelastic_env() 31 | 32 | max_prompt_length = 32 33 | max_response_length = 16 34 | dtype = "bfloat16" 35 | tensor_parallel_size = 1 36 | local_model_path = "/data3/ckpt/Qwen/Qwen2.5-3B-Instruct" 37 | 38 | tokenizer, actor_model = load_tokenizer_and_model(local_model_path) 39 | 40 | preencode_prompts = [ 41 | [{"role": "user", "content": prompt, "tool_calls": None}] 42 | for prompt in [ 43 | "Who won the Champions League in 2019?", 44 | "The founder of Apple is", 45 | "What's the best way to learn python?", 46 | ] 47 | ] 48 | prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts] 49 | input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) 50 | 51 | rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, "./sandbox_fusion_tool_config") 52 | rollout = SGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config) 53 | 54 | prompt_dict = TensorDict( 55 | { 56 | "input_ids": input_ids, 57 | "attention_mask": attention_mask, 58 | "position_ids": position_ids, 59 | }, 60 | batch_size=input_ids.shape[0], 61 | ) 62 | print(f"preprocessed {input_ids.shape=}") 63 | 64 | messages = np.asarray(preencode_prompts) 65 | prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": np.array([{}] * input_ids.shape[0], dtype=object)}) 66 | 67 | prompts.meta_info.update( 68 | { 69 | "eos_token_id": tokenizer.eos_token_id, 70 | "pad_token_id": tokenizer.pad_token_id, 71 | } 72 | ) 73 | 74 | output = rollout.generate_sequences(prompts=prompts) 75 | print(f"generated {output.batch['responses'].shape}") 76 | sglang_output = output.to("cpu") 77 | 78 | sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch["responses"]) 79 | 80 | print(f"sglang response: {sglang_response_tokens}") 81 | print("✅ SGLang w/o tool Test Passed!") 82 | 83 | torch.distributed.barrier() 84 | torch.distributed.destroy_process_group() 85 | 86 | 87 | if __name__ == "__main__": 88 | test_async_sglang_rollout_without_tool() 89 | -------------------------------------------------------------------------------- /moe_ep/test_layer_ep_torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.distributed as dist 10 | from deepseek.configuration_deepseek import DeepseekV2Config 11 | from deepseek.modeling_deepseek import DeepseekV2DecoderLayer 12 | 13 | 14 | 15 | def init_parallel_groups(ep_size=1): 16 | dist.init_process_group("nccl") 17 | # world_size = int(os.getenv("WORLD_SIZE", "0")) 18 | # local_rank = int(os.getenv("LOCAL_RANK", "0")) 19 | local_rank = dist.get_rank() 20 | world_size = dist.get_world_size() 21 | torch.cuda.set_device(local_rank) 22 | ep_group = edp_group = None 23 | for i in range(0, world_size, ep_size): 24 | ranks = list(range(i, i + ep_size)) 25 | group = dist.new_group(ranks) 26 | if local_rank in ranks: 27 | ep_group = group 28 | edp_group = None 29 | for i in range(ep_size): 30 | ranks = list(range(i, world_size, ep_size)) 31 | group = dist.new_group(ranks) 32 | if local_rank in ranks: 33 | 
edp_group = group 34 | dist.all_reduce(torch.zeros(1, device="cuda"), group=ep_group) 35 | dist.all_reduce(torch.zeros(1, device="cuda"), group=edp_group) 36 | return world_size, local_rank, ep_group, edp_group 37 | 38 | 39 | if __name__ == "__main__": 40 | ep_size = 8 41 | batch_size = 2 42 | seq_length = 4096 43 | only_forward = False 44 | 45 | world_size, local_rank, ep_group, edp_group = init_parallel_groups(ep_size) 46 | device = torch.device(f"cuda:{local_rank}") 47 | torch.cuda.set_device(device) 48 | with open("deepseek/config.json", "r") as f: 49 | config_dict = json.load(f) 50 | 51 | config = DeepseekV2Config(**config_dict) 52 | config.ep_size = ep_size 53 | config._attn_implementation = "flash_attention_2" 54 | dsv2_layer = DeepseekV2DecoderLayer(config, layer_idx=10).to(torch.float16).to(device) 55 | 56 | input_x = torch.randn(batch_size, seq_length, config.hidden_size).to(torch.float16).to(device) 57 | dout = torch.randn(batch_size, seq_length, config.hidden_size).to(torch.float16).to(device) 58 | 59 | layer_outputs = dsv2_layer(input_x)[0] # warmup 60 | layer_outputs.backward(dout) 61 | 62 | torch.backends.cudnn.benchmark = True 63 | profiler = torch.profiler.profile( 64 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], 65 | schedule=torch.profiler.schedule(wait=1, warmup=1, active=1,), 66 | record_shapes=True, 67 | profile_memory=True, 68 | with_flops=True, 69 | with_modules=True, 70 | with_stack=True, 71 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 72 | f"./profiles/dsv2layer_ep_bs_{batch_size}_seq_{seq_length}_rank_{local_rank}_{'fwd' if only_forward else 'fwd_bwd'}" 73 | ), 74 | ) 75 | profiler.start() 76 | 77 | begin = torch.cuda.Event(enable_timing=True) 78 | begin.record() 79 | 80 | for i in range(8): 81 | if only_forward: 82 | with torch.no_grad(): 83 | layer_outputs = dsv2_layer(input_x) 84 | else: 85 | layer_outputs = dsv2_layer(input_x)[0] 86 | layer_outputs.backward(dout) 87 | profiler.step() 88 | 89 | end = torch.cuda.Event(enable_timing=True) 90 | end.record() 91 | torch.cuda.synchronize(device=device) 92 | profiler.stop() 93 | dist.destroy_process_group() 94 | -------------------------------------------------------------------------------- /sequence_parallel/ulysses/test_ulysses_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn import flash_attn_func 4 | from ulysses.ulysses_attn import UlyssesAttention 5 | from utils import log, set_seed 6 | 7 | 8 | if __name__ == "__main__": 9 | dist.init_process_group("nccl") 10 | rank = dist.get_rank() 11 | set_seed(rank) 12 | world_size = dist.get_world_size() 13 | dtype = torch.bfloat16 14 | device = torch.device(f"cuda:{rank}") 15 | 16 | batch_size = 1 17 | seqlen = 3816 18 | nheads = 8 19 | d = 128 20 | dropout_p = 0 21 | causal = True 22 | deterministic = False 23 | 24 | assert seqlen % world_size == 0 25 | assert d % 8 == 0 26 | 27 | q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) 28 | k = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) 29 | v = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) 30 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 31 | 32 | dist.broadcast(q, src=0) 33 | dist.broadcast(k, src=0) 34 | dist.broadcast(v, src=0) 35 | dist.broadcast(dout, src=0) 36 | 37 | local_q = 
q.chunk(world_size, dim=1)[rank].detach().clone() 38 | local_q.requires_grad = True 39 | local_k = k.chunk(world_size, dim=1)[rank].detach().clone() 40 | local_k.requires_grad = True 41 | local_v = v.chunk(world_size, dim=1)[rank].detach().clone() 42 | local_v.requires_grad = True 43 | local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone() 44 | 45 | sp_pg = None #dist.new_group(ranks=[i for i in range(world_size)]) 46 | ulysses_attn = UlyssesAttention(sp_pg) 47 | 48 | dist.barrier() 49 | if rank == 0: 50 | print("#" * 30) 51 | print("# ds-ulysses forward:") 52 | print("#" * 30) 53 | 54 | out, lse, _ = flash_attn_func( 55 | q, 56 | k, 57 | v, 58 | dropout_p=dropout_p, 59 | causal=causal, 60 | window_size=(-1, -1), 61 | alibi_slopes=None, 62 | deterministic=deterministic, 63 | return_attn_probs=True, 64 | ) 65 | 66 | local_out = out.chunk(world_size, dim=1)[rank] 67 | # local_lse = lse.chunk(world_size, dim=1)[rank] 68 | 69 | ulysses_out = ulysses_attn( 70 | local_q, 71 | local_k, 72 | local_v, 73 | dropout_p=dropout_p, 74 | causal=causal, 75 | window_size=(-1, -1), 76 | alibi_slopes=None, 77 | deterministic=deterministic, 78 | return_attn_probs=True, 79 | ) 80 | 81 | log("out diff", local_out - ulysses_out) 82 | # log("lse diff", local_lse - ulysses_lse) 83 | 84 | dist.barrier() 85 | if rank == 0: 86 | print("#" * 30) 87 | print("# backward:") 88 | print("#" * 30) 89 | 90 | ulysses_out.backward(local_dout) 91 | dist.barrier() 92 | 93 | out.backward(dout) 94 | dist.barrier() 95 | dq, dk, dv = q.grad, k.grad, v.grad 96 | local_dq_ref = dq.chunk(world_size, dim=1)[rank] 97 | local_dk_ref = dk.chunk(world_size, dim=1)[rank] 98 | local_dv_ref = dv.chunk(world_size, dim=1)[rank] 99 | 100 | log("dq diff", local_dq_ref - local_q.grad) 101 | log("dk diff", local_dk_ref - local_k.grad) 102 | log("dv diff", local_dv_ref - local_v.grad) 103 | 104 | -------------------------------------------------------------------------------- /verl_test/test_torch_func.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/data/zzd/verl") 3 | 4 | import pdb 5 | import torch 6 | from flash_attn.bert_padding import unpad_input 7 | 8 | from verl.utils.model import create_random_mask 9 | 10 | 11 | def test_log_probs_from_logits_response_rmpad(): 12 | from verl.utils.torch_functional import log_probs_from_logits_response, log_probs_from_logits_response_rmpad 13 | 14 | vocab_size = 32000 15 | batch_size = 2 16 | prompt_length = 256 17 | response_length = 256 18 | 19 | input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, prompt_length + response_length), device="cuda") 20 | attention_mask = create_random_mask( 21 | input_ids=input_ids, max_ratio_of_left_padding=0.2, max_ratio_of_valid_token=0.8, min_ratio_of_valid_token=0.6 22 | ) 23 | response_mask = attention_mask[:, -response_length:] 24 | assert torch.all(response_mask[:, 0] == 1) 25 | 26 | logits = torch.randn(batch_size, prompt_length + response_length, vocab_size, device="cuda") 27 | logits_rmpad = unpad_input(logits, attention_mask)[0] 28 | 29 | expected_output = log_probs_from_logits_response( 30 | input_ids=input_ids, logits=logits, response_length=response_length 31 | ) 32 | actual_output = log_probs_from_logits_response_rmpad( 33 | input_ids=input_ids, attention_mask=attention_mask, logits_rmpad=logits_rmpad, response_length=response_length 34 | ) 35 | 36 | # This should bitwise align as only this operation only contains gather operators 37 | assert 
torch.all(torch.eq(actual_output * response_mask, expected_output * response_mask)) 38 | 39 | 40 | def test_logprobs_from_logits_v2(dtype=torch.bfloat16): 41 | from verl.utils.torch_functional import logprobs_from_logits_naive, logprobs_from_logits_v2 42 | 43 | vocab_size = 32000 44 | batch_size = 2 45 | seq_len = 512 46 | 47 | labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") 48 | logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) 49 | 50 | expected_output = logprobs_from_logits_naive(labels=labels, logits=logits) 51 | actual_output = logprobs_from_logits_v2(labels=labels, logits=logits) 52 | 53 | if dtype in [torch.float16, torch.bfloat16]: # float16 falls back to an exactly equivalent method 54 | assert torch.equal(actual_output, expected_output) 55 | else: # small numerical difference when using gather / logsumexp approach 56 | torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) 57 | 58 | 59 | def test_lr_scheduler(): 60 | from torch import nn 61 | 62 | model = nn.Linear(10, 10) 63 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 64 | 65 | from verl.utils.torch_functional import get_constant_schedule_with_warmup 66 | 67 | constant_lr = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=2) 68 | 69 | lr_lst = [] 70 | 71 | for _ in range(5): 72 | lr_lst.append(constant_lr.get_last_lr()[0]) 73 | constant_lr.step() 74 | 75 | torch.testing.assert_close(lr_lst, [0.0, 0.0005, 0.001, 0.001, 0.001]) 76 | 77 | from verl.utils.torch_functional import get_cosine_schedule_with_warmup 78 | 79 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 80 | cosine_lr = get_cosine_schedule_with_warmup( 81 | optimizer=optimizer, num_warmup_steps=2, num_training_steps=5, min_lr_ratio=0.1 82 | ) 83 | 84 | lr_lst = [] 85 | 86 | for _ in range(5): 87 | lr_lst.append(cosine_lr.get_last_lr()[0]) 88 | cosine_lr.step() 89 | 90 | torch.testing.assert_close(lr_lst, [0.0, 0.0005, 0.001, 0.0007750000000000002, 0.0003250000000000002]) 91 | 92 | 93 | if __name__ == "__main__": 94 | test_log_probs_from_logits_response_rmpad() 95 | test_logprobs_from_logits_v2() 96 | test_lr_scheduler() -------------------------------------------------------------------------------- /sequence_parallel/ulysses/ulysses_attn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from typing import Any 5 | from torch import Tensor 6 | import torch.distributed as dist 7 | from flash_attn import flash_attn_func 8 | from ulysses.ulyssess_utils import SeqAllToAll4D 9 | 10 | 11 | class UlyssesAttention(torch.nn.Module): 12 | """Initialization. 13 | 14 | Arguments: 15 | local_attention (Module): local attention with q,k,v 16 | sequence_process_group (ProcessGroup): sequence parallel process group 17 | scatter_idx (int): scatter_idx for all2all comm 18 | gather_idx (int): gather_idx for all2all comm 19 | use_sync (bool): whether to synchronize after all-to-all. This flag can save cuda memory but will slow down the speed. 
20 | attn_type (FlashAttentionImpl): attention type enum 21 | """ 22 | 23 | def __init__( 24 | self, 25 | sequence_process_group: dist.ProcessGroup = None, 26 | scatter_idx: int = 2, 27 | gather_idx: int = 1, 28 | use_sync: bool = False, 29 | ) -> None: 30 | super(UlyssesAttention, self).__init__() 31 | self.spg = sequence_process_group 32 | self.scatter_idx = scatter_idx 33 | self.gather_idx = gather_idx 34 | self.use_sync = use_sync 35 | 36 | def forward( 37 | self, 38 | query: Tensor, 39 | key: Tensor, 40 | value: Tensor, 41 | dropout_p=0.0, 42 | softmax_scale=None, 43 | causal=False, 44 | window_size=(-1, -1), 45 | softcap=0.0, 46 | alibi_slopes=None, 47 | deterministic=False, 48 | return_attn_probs=False, 49 | *args: Any 50 | ) -> Tensor: 51 | """forward 52 | 53 | Arguments: 54 | query (Tensor): query input to the layer 55 | key (Tensor): key input to the layer 56 | value (Tensor): value input to the layer 57 | args: other args 58 | 59 | Returns: 60 | * output (Tensor): context output 61 | """ 62 | # TODO Merge three alltoall calls into one 63 | # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 64 | # in shape : e.g., [s/p:h:] 65 | # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) 66 | 67 | # scatter 2, gather 1 68 | q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx, self.use_sync) 69 | k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx, self.use_sync) 70 | v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx, self.use_sync) 71 | 72 | if softmax_scale is None: 73 | softmax_scale = q.shape[-1] ** -0.5 74 | 75 | context_layer = flash_attn_func( 76 | q, 77 | k, 78 | v, 79 | dropout_p=dropout_p, 80 | softmax_scale = softmax_scale, 81 | causal=causal, 82 | window_size=window_size, 83 | softcap=softcap, 84 | alibi_slopes=alibi_slopes, 85 | deterministic=deterministic, 86 | return_attn_probs=return_attn_probs, 87 | ) 88 | 89 | if isinstance(context_layer, tuple): 90 | context_layer = context_layer[0] 91 | 92 | # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) 93 | # scatter 1, gather 2 94 | output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync) 95 | 96 | # out e.g., [s/p::h] 97 | return output 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /verl_test/rollout_test/test_actor_rollout_fsdp.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.append("/data2/zzd/rl_llm/verl") 4 | import pdb 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | import ray 9 | import torch 10 | ray.init( 11 | runtime_env={ 12 | "working_dir": "/data2/zzd/rl_llm/verl", # working directory (will be uploaded to the cluster) 13 | } 14 | ) 15 | 16 | def get_config(): 17 | from omegaconf import OmegaConf 18 | config = OmegaConf.load("/data2/zzd/rl_llm/verl/verl/trainer/config/ppo_trainer.yaml") 19 | config.data.train_files="/data2/zzd/data/full_hh_rlhf/rl/train.parquet" 20 | config.data.max_prompt_length=128 21 | config.data.filter_overlong_prompts=True 22 | config.actor_rollout_ref.model.path="/data3/ckpt/Qwen/Qwen2.5-3B-Instruct" 23 | return config 24 | 25 | config = get_config() 26 | 27 | def get_test_data(): 28 | from verl.utils import hf_processor, hf_tokenizer 29 | from verl.utils.fs import copy_to_local 30 | from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
31 | from verl.utils.dataset.rl_dataset import collate_fn 32 | from verl import DataProto 33 | from torchdata.stateful_dataloader import StatefulDataLoader 34 | 35 | local_path = copy_to_local(config.actor_rollout_ref.model.path) 36 | trust_remote_code = config.data.get("trust_remote_code", False) 37 | tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) 38 | processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none 39 | train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) 40 | train_sampler = create_rl_sampler(config.data, train_dataset) 41 | 42 | train_dataloader = StatefulDataLoader( 43 | dataset=train_dataset, 44 | batch_size=config.data.get("gen_batch_size", 8), 45 | num_workers=config.data.get("dataloader_num_workers", 1), 46 | drop_last=True, 47 | collate_fn=collate_fn, 48 | sampler=train_sampler, 49 | ) 50 | for batch_dict in train_dataloader: 51 | batch: DataProto = DataProto.from_single_dict(batch_dict) 52 | 53 | batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] 54 | non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] 55 | if "multi_modal_inputs" in batch.non_tensor_batch: 56 | non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) 57 | if "raw_prompt" in batch.non_tensor_batch: 58 | non_tensor_batch_keys_to_pop.append("raw_prompt") 59 | if "tools_kwargs" in batch.non_tensor_batch: 60 | non_tensor_batch_keys_to_pop.append("tools_kwargs") 61 | gen_batch = batch.pop( 62 | batch_keys=batch_keys_to_pop, 63 | non_tensor_batch_keys=non_tensor_batch_keys_to_pop, 64 | ) 65 | 66 | break 67 | return batch, gen_batch 68 | 69 | 70 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool 71 | from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker 72 | from verl.single_controller.ray import RayWorkerGroup 73 | 74 | actor_rollout_mode = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker 75 | actor_worker = ray.remote(actor_rollout_mode) 76 | 77 | resource_pool = RayResourcePool([8], use_gpu=True, max_colocate_count=1) 78 | actor_cls = RayClassWithInitArgs(actor_worker, config=config.actor_rollout_ref, role="actor_rollout",) 79 | actor_worker_group = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) 80 | 81 | print("world size:", actor_worker_group.world_size) 82 | print("worker_names:", actor_worker_group.worker_names) 83 | actor_worker_group.init_model() 84 | 85 | batch, gen_batch = get_test_data() 86 | print(gen_batch) 87 | gen_batch_output = actor_worker_group.generate_sequences(gen_batch) 88 | print("gen_batch_output:", gen_batch_output) 89 | 90 | # batch = batch.union(gen_batch_output) 91 | # print("batch:", batch) 92 | 93 | ray.shutdown() -------------------------------------------------------------------------------- /verl_test/megatron_test/test_meagtron_utils.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "395936a5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "os.environ[\"RANK\"] = \"0\"\n", 12 | "os.environ[\"WORLD_SIZE\"] = \"1\"\n", 13 | "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", 14 | "os.environ[\"MASTER_PORT\"] = \"29500\"\n", 15 | "\n", 16 | "import sys\n", 17 | "sys.path.append(\"/data2/zzd/rl_llm/verl\")\n", 18 | "\n", 19 | "import 
torch\n", 20 | "from megatron.core import parallel_state as mpu\n", 21 | "\n", 22 | "torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n", 23 | "mpu.initialize_model_parallel(\n", 24 | " tensor_model_parallel_size=1,\n", 25 | " pipeline_model_parallel_size=1,\n", 26 | " virtual_pipeline_model_parallel_size=None,\n", 27 | " pipeline_model_parallel_split_rank=None,\n", 28 | " use_sharp=False,\n", 29 | " context_parallel_size=1,\n", 30 | " expert_model_parallel_size=1,\n", 31 | " nccl_communicator_config_path=None,\n", 32 | ")" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 7, 38 | "id": "6ac0ac0c", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from verl.models.mcore.util import preprocess_packed_seqs, postprocess_packed_seqs\n", 43 | "\n", 44 | "batch_size = 2\n", 45 | "seq_len = 32\n", 46 | "\n", 47 | "input_ids = torch.randint(0, 100, (batch_size, seq_len))\n", 48 | "attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool)\n", 49 | "attention_mask[:, -20:] = False \n", 50 | "\n", 51 | "input_ids_pad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 9, 57 | "id": "d9dd49d3", 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "tensor([[60, 88, 81, 31, 50, 4, 34, 63, 2, 78, 59, 49, 61, 52, 17, 0, 54, 73,\n", 65 | " 6, 23, 80, 85, 12, 41, 37, 50, 76, 29, 15, 55, 99, 59],\n", 66 | " [52, 84, 62, 95, 14, 96, 55, 8, 28, 97, 27, 49, 34, 43, 50, 71, 8, 48,\n", 67 | " 98, 36, 10, 71, 11, 94, 45, 77, 5, 3, 93, 54, 20, 32]])\n", 68 | "tensor([[ True, True, True, True, True, True, True, True, True, True,\n", 69 | " True, True, False, False, False, False, False, False, False, False,\n", 70 | " False, False, False, False, False, False, False, False, False, False,\n", 71 | " False, False],\n", 72 | " [ True, True, True, True, True, True, True, True, True, True,\n", 73 | " True, True, False, False, False, False, False, False, False, False,\n", 74 | " False, False, False, False, False, False, False, False, False, False,\n", 75 | " False, False]])\n", 76 | "tensor([[60, 88, 81, 31, 50, 4, 34, 63, 2, 78, 59, 49, 52, 84, 62, 95, 14, 96,\n", 77 | " 55, 8, 28, 97, 27, 49]])\n", 78 | "PackedSeqParams(qkv_format='thd', cu_seqlens_q=tensor([ 0, 12, 24], dtype=torch.int32), cu_seqlens_kv=tensor([ 0, 12, 24], dtype=torch.int32), cu_seqlens_q_padded=tensor([ 0, 12, 24], dtype=torch.int32), cu_seqlens_kv_padded=tensor([ 0, 12, 24], dtype=torch.int32), max_seqlen_q=12, max_seqlen_kv=12)\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "print(input_ids)\n", 84 | "print(attention_mask)\n", 85 | "print(input_ids_pad)\n", 86 | "print(packed_seq_params)" 87 | ] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "rl-env", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.10.16" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 5 111 | } 112 | -------------------------------------------------------------------------------- /verl_test/agent_test/test_agent_single_turn.py: -------------------------------------------------------------------------------- 1 | import 
sys 2 | VERL_PATH = "/data2/zzd/rl_llm/verl" 3 | sys.path.append(VERL_PATH) 4 | 5 | import json 6 | import os 7 | from typing import Any, Tuple 8 | 9 | import numpy as np 10 | import ray 11 | from omegaconf import DictConfig, OmegaConf 12 | from transformers.utils import get_json_schema 13 | 14 | from agent_utils import init_agent_loop_manager 15 | from verl.protocol import DataProto 16 | from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema 17 | from verl.utils import hf_tokenizer 18 | 19 | def init_config() -> DictConfig: 20 | from hydra import compose, initialize_config_dir 21 | 22 | with initialize_config_dir(config_dir=os.path.abspath(f"{VERL_PATH}/verl/trainer/config")): 23 | config = compose( 24 | config_name="ppo_trainer", 25 | overrides=[ 26 | "actor_rollout_ref.actor.use_dynamic_bsz=true", 27 | # test sleep/wake_up with fsdp offload 28 | "actor_rollout_ref.actor.fsdp_config.param_offload=True", 29 | "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True", 30 | ], 31 | ) 32 | model_path = "/data3/ckpt/Qwen/Qwen2.5-1.5B-Instruct" 33 | config.actor_rollout_ref.model.path = model_path 34 | config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") 35 | config.actor_rollout_ref.rollout.mode = "async" 36 | config.actor_rollout_ref.rollout.prompt_length = 4096 37 | config.actor_rollout_ref.rollout.response_length = 4096 38 | config.actor_rollout_ref.rollout.n = 4 39 | config.actor_rollout_ref.rollout.agent.num_workers = 2 40 | 41 | # test sleep/wake_up with fsdp offload 42 | config.actor_rollout_ref.actor.fsdp_config.param_offload = True 43 | config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True 44 | 45 | return config 46 | 47 | 48 | def test_single_turn(init_config): 49 | ray.init( 50 | runtime_env={ 51 | "env_vars": { 52 | "TOKENIZERS_PARALLELISM": "true", 53 | "NCCL_DEBUG": "WARN", 54 | "VLLM_LOGGING_LEVEL": "INFO", 55 | "VLLM_USE_V1": "1", 56 | }, 57 | "working_dir": VERL_PATH, # working directory (will be uploaded to the cluster) 58 | } 59 | ) 60 | 61 | print("Initializing agent loop manager...") 62 | agent_loop_manager = init_agent_loop_manager(init_config) 63 | 64 | raw_prompts = [ 65 | [{"role": "user","content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.",}], 66 | [{"role": "user", "content": "Let's play a role playing game. 
Your name is Bob, your favorite color is red."}], 67 | ] 68 | batch = DataProto( 69 | non_tensor_batch={ 70 | "raw_prompt": np.array(raw_prompts), 71 | "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)), 72 | }, 73 | ) 74 | 75 | print("Generating sequences by agent...") 76 | n = init_config.actor_rollout_ref.rollout.n 77 | batch = batch.repeat(n) 78 | result = agent_loop_manager.generate_sequences(prompts=batch) 79 | print(result.batch["responses"]) 80 | assert len(result) == len(raw_prompts) * n 81 | 82 | # decode responses 83 | tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) 84 | responses = result.batch["responses"] 85 | response_mask = result.batch["response_mask"] 86 | for i in range(len(responses)): 87 | valid_tokens = responses[i][response_mask[i].bool()] 88 | response_str = tokenizer.decode(valid_tokens) 89 | print(f"response {i}: {response_str}") 90 | 91 | # check result 92 | seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) 93 | assert result.batch["input_ids"].size(1) == seq_len 94 | assert result.batch["attention_mask"].size(1) == seq_len 95 | assert result.batch["position_ids"].size(1) == seq_len 96 | 97 | # check turns 98 | num_turns = result.non_tensor_batch["__num_turns__"] 99 | assert np.all(num_turns == 2) 100 | 101 | print("Test passed!") 102 | ray.shutdown() 103 | 104 | 105 | if __name__ == "__main__": 106 | init_config = init_config() 107 | test_single_turn(init_config) 108 | print("✅ All tests passed!") -------------------------------------------------------------------------------- /sequence_parallel/usp/test_usp_qkvpacked_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn import flash_attn_qkvpacked_func 4 | from usp.usp_attn import LongContextAttentionQKVPacked 5 | from usp.usp_utils import set_seq_parallel_pg, EXTRACT_FUNC_DICT 6 | from utils import log, set_seed 7 | 8 | 9 | if __name__ == "__main__": 10 | dist.init_process_group("nccl") 11 | rank = dist.get_rank() 12 | set_seed(rank) 13 | world_size = dist.get_world_size() 14 | dtype = torch.bfloat16 15 | device = torch.device(f"cuda:{rank}") 16 | 17 | batch_size = 1 18 | seqlen = 4096 19 | nheads = 8 20 | d = 128 21 | dropout_p = 0 22 | causal = True 23 | deterministic = False 24 | ring_attn_type = "basic" # ["basic", "stripe", "zigzag"] 25 | 26 | assert seqlen % world_size == 0 27 | assert d % 8 == 0 28 | 29 | # global tensors 30 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 31 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 32 | 33 | with torch.no_grad(): 34 | dist.broadcast(qkv, src=0) 35 | dist.broadcast(dout, src=0) 36 | 37 | # prepare process group for hybrid sequence parallelism 38 | use_ring_low_dim = True 39 | sp_ulysses_degree = 2 40 | sp_ring_degree = world_size // sp_ulysses_degree 41 | print(f"rank {rank}, sp_ulysses_degree: {sp_ulysses_degree}, sp_ring_degree: {sp_ring_degree}") 42 | set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) 43 | 44 | # sharded tensors for long context attn 45 | local_qkv = (EXTRACT_FUNC_DICT[ring_attn_type](qkv, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()) 46 | local_qkv.requires_grad = True 47 | 48 | local_dout = (EXTRACT_FUNC_DICT[ring_attn_type](dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()) 49 | 50 | 
usp_attn = LongContextAttentionQKVPacked(ring_impl_type=ring_attn_type) 51 | 52 | dist.barrier() 53 | if rank == 0: 54 | print("#" * 30) 55 | print("# USP forward:") 56 | print("#" * 30) 57 | 58 | out_ref, lse, _ = flash_attn_qkvpacked_func( 59 | qkv, 60 | dropout_p=dropout_p, 61 | causal=causal, 62 | window_size=(-1, -1), 63 | alibi_slopes=None, 64 | deterministic=deterministic, 65 | return_attn_probs=True, 66 | ) 67 | local_out_ref = EXTRACT_FUNC_DICT[ring_attn_type](out_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree) 68 | 69 | # usp attn forward 70 | usp_out = usp_attn( 71 | local_qkv, 72 | dropout_p=dropout_p, 73 | causal=causal, 74 | window_size=(-1, -1), 75 | alibi_slopes=None, 76 | deterministic=deterministic, 77 | return_attn_probs=True, 78 | ) 79 | 80 | log("out diff", usp_out - local_out_ref) 81 | 82 | max_memory = torch.cuda.max_memory_allocated(device) / (1024 * 1024) # Convert to MB 83 | print(f"[Rank#{rank}] Maximum GPU memory used: {max_memory:.2f} MB") 84 | torch.cuda.reset_peak_memory_stats(device) # Reset stats 85 | 86 | dist.barrier() 87 | 88 | if rank == 0: 89 | print("#" * 30) 90 | print("# backward:") 91 | print("#" * 30) 92 | 93 | out_ref.backward(dout) 94 | dqkv = qkv.grad 95 | local_dqkv_ref = EXTRACT_FUNC_DICT[ring_attn_type](dqkv, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree) 96 | 97 | usp_out.backward(local_dout) 98 | local_dqkv = local_qkv.grad 99 | 100 | log("dq diff", local_dqkv_ref[:,:,0] - local_dqkv[:,:,0]) 101 | log("dk diff", local_dqkv_ref[:,:,1] - local_dqkv[:,:,1]) 102 | log("dv diff", local_dqkv_ref[:,:,2] - local_dqkv[:,:,2]) 103 | 104 | if dist.is_initialized(): 105 | dist.destroy_process_group() -------------------------------------------------------------------------------- /verl_test/rollout_test/test_vllm_tp_dp_spmd.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/data2/zzd/rl_llm/verl") 3 | 4 | import os 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy 9 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 10 | from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 12 | from vllm import LLM, SamplingParams 13 | 14 | from verl.utils.distributed import initialize_global_process_group 15 | from torch.distributed.device_mesh import init_device_mesh 16 | from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout 17 | 18 | 19 | def get_config(): 20 | from omegaconf import OmegaConf 21 | config = OmegaConf.load("/data2/zzd/rl_llm/verl/verl/trainer/config/generation.yaml") 22 | config.data.path="/data2/zzd/data/full_hh_rlhf/rl/train.parquet" 23 | config.model.path="/data3/ckpt/Qwen/Qwen2.5-3B-Instruct" 24 | config.rollout.tensor_model_parallel_size=4 25 | config.rollout.gpu_memory_utilization=0.8 26 | return config 27 | 28 | config = get_config() 29 | 30 | def get_test_data(tokenizer, rank): 31 | from verl import DataProto 32 | from verl.utils.model import compute_position_id_with_mask 33 | from verl.utils.torch_functional import pad_sequence_to_length 34 | 35 | max_prompt_length = 32 36 | preencode_prompts = [ 37 | "Who won the Champions League in 2019?", 38 | # "The founder of Apple is", 39 | # "痛饮狂歌空度日", 40 | # "13*24=" 41 | ] 42 | tokenizer.pad_token = tokenizer.eos_token 43 | prompts = 
tokenizer(preencode_prompts, return_tensors="pt", padding=True) 44 | input_ids = prompts["input_ids"] 45 | attention_mask = prompts["attention_mask"] 46 | 47 | # position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).unsqueeze(0) 48 | 49 | input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) 50 | attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) 51 | # position_ids = pad_sequence_to_length(position_ids, max_prompt_length, 0, left_pad=True) 52 | position_ids = compute_position_id_with_mask(attention_mask) 53 | 54 | print("start generation") 55 | input_ids = input_ids.cuda() 56 | attention_mask = attention_mask.cuda() 57 | position_ids = position_ids.cuda() 58 | 59 | data = DataProto.from_single_dict({ 60 | "input_ids": input_ids, 61 | "attention_mask": attention_mask, 62 | "position_ids": position_ids, 63 | }) 64 | data.meta_info["eos_token_id"] = tokenizer.eos_token_id 65 | 66 | return data 67 | 68 | 69 | def test_vllm_spmd(): 70 | assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." 71 | local_rank, rank, world_size = initialize_global_process_group() 72 | 73 | # Initialize model and token 74 | local_cache_path = "/data2/zzd/.cache/verl" 75 | local_cache_path = os.path.expanduser(local_cache_path) 76 | from verl.utils.fs import copy_to_local 77 | 78 | local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path) 79 | tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left", trust_remote_code=True) 80 | hf_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True) 81 | 82 | infer_tp = config.rollout.tensor_model_parallel_size 83 | dp = world_size // infer_tp 84 | assert world_size % infer_tp == 0, f"rollout world_size: {world_size} is not divisible by infer_tp: {infer_tp}" 85 | rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) 86 | 87 | 88 | rollout = vLLMRollout( 89 | model_path=local_model_path, 90 | config=config.rollout, 91 | tokenizer=tokenizer, 92 | model_hf_config=hf_config, 93 | device_mesh=rollout_device_mesh, 94 | trust_remote_code=True, 95 | ) 96 | 97 | prompts_data = get_test_data(tokenizer, rank) 98 | print("start generation") 99 | 100 | print(f"rank:{rank}, prompts_data: {prompts_data}") 101 | response = rollout.generate_sequences(prompts_data) 102 | print(f"rank:{rank}, response: {response}") 103 | 104 | 105 | if __name__ == "__main__": 106 | test_vllm_spmd() 107 | -------------------------------------------------------------------------------- /parallel_framework/test_megatron_dp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | import pytest 4 | import torch 5 | from packaging import version 6 | from torch import testing 7 | 8 | from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig 9 | from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups 10 | from megatron.core.transformer import TransformerConfig 11 | from megatron.core import parallel_state as mpu 12 | from megatron_utils import initialize_global_process_group 13 | 14 | 15 | # Test model for testing DDP 16 | class TestModel(torch.nn.Module): 17 | def __init__(self, input_dim, output_dim): 18 | super().__init__() 19 | self.linear1 = torch.nn.Linear(input_dim, input_dim * 4) 20 | self.activation = torch.nn.ReLU() 21 | self.linear2 = torch.nn.Linear(input_dim * 4, output_dim) 22 | 23 | def forward(self, x): 24 | x = self.linear1(x) 25 | x = self.activation(x) 26 | x = self.linear2(x) 27 | return x 28 | 29 | 30 | 31 | def test_ddp_with_dp_process_groups(): 32 | """Test that DDP works correctly with dp pgs from parallel state and user defined pgs.""" 33 | from torch.distributed.device_mesh import init_device_mesh 34 | 35 | local_rank, rank, world_size = initialize_global_process_group() 36 | 37 | dp_size = world_size 38 | mpu.initialize_model_parallel( 39 | tensor_model_parallel_size=1, 40 | pipeline_model_parallel_size=1, 41 | virtual_pipeline_model_parallel_size=None, 42 | pipeline_model_parallel_split_rank=None, 43 | use_sharp=False, 44 | context_parallel_size=1, 45 | expert_model_parallel_size=1, 46 | nccl_communicator_config_path=None, 47 | ) 48 | 49 | # Simple model config 50 | input_dim = 13 51 | output_dim = 17 52 | 53 | # Setup DDP config 54 | ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000) 55 | 56 | # Create two identical models 57 | model1 = TestModel(input_dim=input_dim, output_dim=output_dim).cuda() 58 | model2 = TestModel(input_dim=input_dim, output_dim=output_dim).cuda() 59 | 60 | # Ensure identical weights 61 | for p1, p2 in zip(model1.parameters(), model2.parameters()): 62 | p2.data.copy_(p1.data) 63 | 64 | # Wrap first model with default process groups 65 | transformer_config = TransformerConfig( 66 | num_attention_heads=1, num_layers=1, context_parallel_size=1 67 | ) 68 | 69 | ddp_model1 = DistributedDataParallel( 70 | transformer_config, ddp_config=ddp_config, module=model1 71 | ) 72 | 73 | # Create device mesh for explicit process groups 74 | # Create a mesh with dimension dp [dp_size], 1 pp size and 1 ep size 75 | device_mesh = init_device_mesh("cuda", (dp_size, 1, 1), mesh_dim_names=("dp", "ep", "pp")) 76 | 77 | # Create process groups config with ONLY dp group 78 | grad_comm_pgs = GradCommProcessGroups() 79 | model_comm_pgs = ModelCommProcessGroups() 80 | 81 | grad_comm_pgs.dp = device_mesh.get_group(mesh_dim="dp") 82 | model_comm_pgs.pp = device_mesh.get_group(mesh_dim="pp") 83 | model_comm_pgs.ep = device_mesh.get_group(mesh_dim="ep") 84 | 85 | # Wrap second model with minimal process groups (only dp) 86 | ddp_model2 = DistributedDataParallel( 87 | transformer_config, 88 | ddp_config=ddp_config, 89 | module=model2, 90 | grad_comm_pgs=grad_comm_pgs, 91 | model_comm_pgs=model_comm_pgs, 92 | ) 93 | 94 | # Create identical inputs with integer values 95 | batch_size = 2 96 | input_data = torch.randint(0, 10, (batch_size, input_dim), device='cuda', dtype=torch.long) 97 | input_data = input_data.float() # Convert to float for model compatibility 98 | 99 | # Forward pass 100 | out1 = ddp_model1(input_data) 101 | out2 = 
ddp_model2(input_data) 102 | 103 | testing.assert_close(out1, out2, rtol=0, atol=0) 104 | 105 | # Loss and backward 106 | loss1 = out1.sum() 107 | loss2 = out2.sum() 108 | 109 | loss1.backward() 110 | loss2.backward() 111 | 112 | # Check gradients are identical using torch.testing 113 | for p1, p2 in zip(ddp_model1.parameters(), ddp_model2.parameters()): 114 | if hasattr(p1, 'main_grad') and hasattr(p2, 'main_grad'): 115 | testing.assert_close(p1.main_grad, p2.main_grad, rtol=0, atol=0) 116 | print("✅ Test Passed!") 117 | 118 | if __name__ == "__main__": 119 | test_ddp_with_dp_process_groups() 120 | -------------------------------------------------------------------------------- /verl_test/megatron_test/test_actor_rollout_megatron.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.append("/data2/zzd/rl_llm/verl") 4 | 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | import ray 8 | ray.shutdown() 9 | ray.init( 10 | runtime_env={ 11 | "working_dir": "/data2/zzd/rl_llm/verl", # working directory (will be uploaded to the cluster) 12 | } 13 | ) 14 | 15 | def get_config(): 16 | from omegaconf import OmegaConf 17 | config = OmegaConf.load("/data2/zzd/rl_llm/verl/verl/trainer/config/ppo_megatron_trainer.yaml") 18 | config.data.train_files="/data2/zzd/data/full_hh_rlhf/rl/train.parquet" 19 | config.data.max_prompt_length=128 20 | config.data.filter_overlong_prompts=True 21 | config.actor_rollout_ref.model.path="/data3/ckpt/Qwen/Qwen2.5-3B-Instruct" 22 | config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=4 23 | config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 24 | config.actor_rollout_ref.actor.megatron.sequence_parallel=False 25 | config.actor_rollout_ref.rollout.gpu_memory_utilization=0.8 26 | return config 27 | 28 | config = get_config() 29 | 30 | def get_test_data(): 31 | from verl.utils import hf_processor, hf_tokenizer 32 | from verl.utils.fs import copy_to_local 33 | from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler 34 | from verl.utils.dataset.rl_dataset import collate_fn 35 | from verl import DataProto 36 | from torchdata.stateful_dataloader import StatefulDataLoader 37 | 38 | local_path = copy_to_local(config.actor_rollout_ref.model.path) 39 | trust_remote_code = config.data.get("trust_remote_code", False) 40 | tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) 41 | processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none 42 | train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) 43 | train_sampler = create_rl_sampler(config.data, train_dataset) 44 | 45 | train_dataloader = StatefulDataLoader( 46 | dataset=train_dataset, 47 | batch_size=config.data.get("gen_batch_size", 8), 48 | num_workers=config.data.get("dataloader_num_workers", 8), 49 | drop_last=True, 50 | collate_fn=collate_fn, 51 | sampler=train_sampler, 52 | ) 53 | for batch_dict in train_dataloader: 54 | batch: DataProto = DataProto.from_single_dict(batch_dict) 55 | 56 | batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] 57 | non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] 58 | if "multi_modal_inputs" in batch.non_tensor_batch: 59 | non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) 60 | if "raw_prompt" in batch.non_tensor_batch: 61 | non_tensor_batch_keys_to_pop.append("raw_prompt") 62 | if "tools_kwargs" in batch.non_tensor_batch: 63 | non_tensor_batch_keys_to_pop.append("tools_kwargs") 
64 | gen_batch = batch.pop( 65 | batch_keys=batch_keys_to_pop, 66 | non_tensor_batch_keys=non_tensor_batch_keys_to_pop, 67 | ) 68 | 69 | break 70 | return batch, gen_batch 71 | 72 | 73 | from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool 74 | from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup 75 | from verl.workers.megatron_workers import ActorRolloutRefWorker 76 | from verl.single_controller.ray.base import create_colocated_worker_cls 77 | 78 | actor_worker = ray.remote(ActorRolloutRefWorker) 79 | 80 | resource_pool = RayResourcePool([8], use_gpu=True, max_colocate_count=1, name_prefix="GPU") 81 | actor_worker = RayClassWithInitArgs(actor_worker, config=config.actor_rollout_ref, role="actor_rollout",) 82 | 83 | all_wg = {} 84 | 85 | class_dict = {"actor_worker": actor_worker} 86 | worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) 87 | wg_dict = NVMegatronRayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, device_name="cuda") 88 | spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) 89 | all_wg.update(spawn_wg) 90 | actor_worker_group = all_wg["actor_worker"] 91 | 92 | print("world size:", actor_worker_group.world_size) 93 | print("worker_names:", actor_worker_group.worker_names) 94 | print("TP size:", actor_worker_group.tp_size) 95 | print("PP size:", actor_worker_group.pp_size) 96 | print("DP size:", actor_worker_group.dp_size) 97 | 98 | actor_worker_group.init_model() 99 | 100 | batch, gen_batch = get_test_data() 101 | print(gen_batch) 102 | gen_batch_output = actor_worker_group.generate_sequences(gen_batch) 103 | print("gen_batch_output:", gen_batch_output) 104 | 105 | ray.shutdown() -------------------------------------------------------------------------------- /math_tir/routed_sandbox.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import List, Optional 17 | 18 | import requests 19 | from e2b_code_interpreter.models import Execution, ExecutionError, Result 20 | 21 | 22 | class RoutedSandbox: 23 | """ 24 | A sandbox environment that routes code execution requests to the E2B Router. 25 | This class is designed for batched execution of scripts, primarily for Python code. 26 | It mimics the usage of 'Sandbox' from 'e2b_code_interpreter', but adds support for batch processing. 27 | 28 | Attributes: 29 | router_url (str): The URL of the E2B Router to which code execution requests are sent. 30 | """ 31 | 32 | def __init__(self, router_url: str): 33 | """ 34 | Initializes the RoutedSandbox with the specified router URL. 35 | 36 | Args: 37 | router_url (str): The URL of the E2B Router. 
38 | """ 39 | self.router_url = router_url 40 | 41 | def run_code( 42 | self, 43 | scripts: list[str], 44 | languages: Optional[List[str]] = None, 45 | timeout: Optional[int] = None, 46 | request_timeout: Optional[int] = None, 47 | ) -> list[Execution]: 48 | """ 49 | Executes a batch of scripts in the sandbox environment. 50 | 51 | Args: 52 | scripts (list[str]): A list of code scripts to execute. 53 | languages (list[str], optional): List of programming languages for each script. If None, defaults to Python for all scripts. 54 | timeout (Optional[int], optional): The maximum execution time for each script in seconds. Defaults to 300 seconds. 55 | request_timeout (Optional[int], optional): The timeout for the HTTP request in seconds. Defaults to 30 seconds. 56 | 57 | Returns: 58 | list[Execution]: A list of Execution objects containing the results, logs, and errors (if any) for each script. 59 | """ 60 | # Set default values for timeouts if not provided 61 | if timeout is None: 62 | timeout = 300 # Default to 5 minutes 63 | if request_timeout is None: 64 | request_timeout = 30 # Default to 30 seconds 65 | 66 | # Default to Python for all scripts if languages is not provided 67 | if languages is None: 68 | languages = ["python"] * len(scripts) 69 | 70 | # Prepare the payload for the HTTP POST request 71 | payload = { 72 | "scripts": scripts, 73 | "languages": languages, 74 | "timeout": timeout, 75 | "request_timeout": request_timeout, 76 | } 77 | 78 | # Send the request to the E2B Router 79 | response = requests.post(f"http://{self.router_url}/execute_batch", json=payload) 80 | if not response.ok: 81 | print(f"Request failed with status code: {response.status_code}") 82 | 83 | # Parse the response and construct Execution objects 84 | results = response.json() 85 | output = [] 86 | for result in results: 87 | if result["execution"] is None: 88 | # If execution is None, create an empty Execution object 89 | # This can happen when a script times out or fails to execute 90 | execution = Execution() 91 | else: 92 | execution = Execution( 93 | results=[Result(**r) for r in result["execution"]["results"]], 94 | logs=result["execution"]["logs"], 95 | error=(ExecutionError(**result["execution"]["error"]) if result["execution"]["error"] else None), 96 | execution_count=result["execution"]["execution_count"], 97 | ) 98 | output.append(execution) 99 | 100 | return output 101 | 102 | 103 | if __name__ == "__main__": 104 | # for local testing launch an E2B router with: python e2b_router.py 105 | sbx = RoutedSandbox(router_url="0.0.0.0:8001") 106 | codes = ["print('hello world')", "import math\nprint(math.sqrt(3*2+10))"] 107 | executions = sbx.run_code(codes) # Execute Python inside the sandbox 108 | 109 | print(executions) -------------------------------------------------------------------------------- /math_tir/test_e2b_sandbox.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 23, 6 | "id": "37bb70a9", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "['1/3\\n']\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "from dotenv import load_dotenv\n", 19 | "load_dotenv()\n", 20 | "from e2b_code_interpreter import Sandbox, AsyncSandbox\n", 21 | "\n", 22 | "sandbox = Sandbox(timeout=60)\n", 23 | "\n", 24 | "script = \"\"\"\n", 25 | "import sympy as sp\n", 26 | "\n", 27 | "# Define the variable\n", 28 | "x = sp.symbols('x')\n", 29 | 
"\n", 30 | "# Define the function to integrate\n", 31 | "f = x**2\n", 32 | "\n", 33 | "# Compute the definite integral from 0 to 1\n", 34 | "integral_value = sp.integrate(f, (x, 0, 1))\n", 35 | "\n", 36 | "# Print the result\n", 37 | "print(integral_value)\n", 38 | "\"\"\"\n", 39 | "language = 'python'\n", 40 | "\n", 41 | "execution = sandbox.run_code(script, language=language)\n", 42 | "print(execution.logs.stdout)\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 24, 48 | "id": "6379f383", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "SANDBOX_TIMEOUT = 30\n", 53 | "MARGIN = 2\n", 54 | "REQUEST_TIMEOUT = SANDBOX_TIMEOUT - MARGIN\n", 55 | "ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN\n", 56 | "\n", 57 | "sandbox = await AsyncSandbox.create(timeout=SANDBOX_TIMEOUT, request_timeout=REQUEST_TIMEOUT)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 26, 63 | "id": "f42dfaef", 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "['1/3\\n']\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "import asyncio\n", 76 | "\n", 77 | "script = \"\"\"\n", 78 | "import sympy as sp\n", 79 | "\n", 80 | "# Define the variable\n", 81 | "x = sp.symbols('x')\n", 82 | "\n", 83 | "# Define the function to integrate\n", 84 | "f = x**2\n", 85 | "\n", 86 | "# Compute the definite integral from 0 to 1\n", 87 | "integral_value = sp.integrate(f, (x, 0, 1))\n", 88 | "\n", 89 | "# Print the result\n", 90 | "print(integral_value)\n", 91 | "\"\"\"\n", 92 | "language = 'python'\n", 93 | "\n", 94 | "execution = await asyncio.wait_for(sandbox.run_code(script, language=language), timeout=ASYNCIO_TIMEOUT)\n", 95 | "print(execution.logs.stdout)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 27, 101 | "id": "1c7add76", 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "{\n", 109 | " \"status\": \"Success\",\n", 110 | " \"message\": \"\",\n", 111 | " \"compile_result\": {\n", 112 | " \"status\": \"Finished\",\n", 113 | " \"execution_time\": 0.37775158882141113,\n", 114 | " \"return_code\": 0,\n", 115 | " \"stdout\": \"\",\n", 116 | " \"stderr\": \"\"\n", 117 | " },\n", 118 | " \"run_result\": {\n", 119 | " \"status\": \"Finished\",\n", 120 | " \"execution_time\": 0.008985280990600586,\n", 121 | " \"return_code\": 0,\n", 122 | " \"stdout\": \"Hello, world!\\n\",\n", 123 | " \"stderr\": \"\"\n", 124 | " },\n", 125 | " \"executor_pod_name\": null,\n", 126 | " \"files\": {}\n", 127 | "}\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "# docker run -it -p 8080:8080 vemlp-cn-beijing.cr.volces.com/preset-images/code-sandbox:server-20241204\n", 133 | "\n", 134 | "import requests\n", 135 | "import json\n", 136 | "\n", 137 | "response = requests.post('http://localhost:8080/run_code', json={\n", 138 | " 'code': '''\n", 139 | "#include \n", 140 | "\n", 141 | "int main() {\n", 142 | " std::cout << \"Hello, world!\" << std::endl;\n", 143 | " return 0;\n", 144 | "}\n", 145 | "''',\n", 146 | " 'language': 'cpp',\n", 147 | "})\n", 148 | "\n", 149 | "print(json.dumps(response.json(), indent=2))" 150 | ] 151 | } 152 | ], 153 | "metadata": { 154 | "kernelspec": { 155 | "display_name": "base", 156 | "language": "python", 157 | "name": "python3" 158 | }, 159 | "language_info": { 160 | "codemirror_mode": { 161 | "name": "ipython", 162 | "version": 3 163 | }, 164 | "file_extension": ".py", 165 | "mimetype": "text/x-python", 
166 | "name": "python", 167 | "nbconvert_exporter": "python", 168 | "pygments_lexer": "ipython3", 169 | "version": "3.12.4" 170 | } 171 | }, 172 | "nbformat": 4, 173 | "nbformat_minor": 5 174 | } 175 | -------------------------------------------------------------------------------- /sequence_parallel/usp/test_usp_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn import flash_attn_func 4 | from usp.usp_attn import LongContextAttention 5 | from usp.usp_utils import set_seq_parallel_pg, EXTRACT_FUNC_DICT 6 | from utils import log, set_seed 7 | 8 | 9 | if __name__ == "__main__": 10 | dist.init_process_group("nccl") 11 | rank = dist.get_rank() 12 | set_seed(rank) 13 | world_size = dist.get_world_size() 14 | dtype = torch.bfloat16 15 | device = torch.device(f"cuda:{rank}") 16 | 17 | batch_size = 1 18 | seqlen = 4096 19 | nheads = 8 20 | d = 128 21 | dropout_p = 0 22 | causal = True 23 | deterministic = False 24 | ring_attn_type = "zigzag" # ["basic", "stripe", "zigzag"] 25 | 26 | assert seqlen % world_size == 0 27 | assert d % 8 == 0 28 | 29 | q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) 30 | k = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) 31 | v = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) 32 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 33 | 34 | dist.broadcast(q, src=0) 35 | dist.broadcast(k, src=0) 36 | dist.broadcast(v, src=0) 37 | dist.broadcast(dout, src=0) 38 | 39 | # prepare process group for hybrid sequence parallelism 40 | use_ring_low_dim = True 41 | sp_ulysses_degree = 2 42 | sp_ring_degree = world_size // sp_ulysses_degree 43 | print(f"rank {rank}, sp_ulysses_degree: {sp_ulysses_degree}, sp_ring_degree: {sp_ring_degree}") 44 | set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) 45 | 46 | # Use EXTRACT_FUNC_DICT to shard the tensors 47 | local_q = EXTRACT_FUNC_DICT[ring_attn_type](q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone() 48 | local_k = EXTRACT_FUNC_DICT[ring_attn_type](k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone() 49 | local_v = EXTRACT_FUNC_DICT[ring_attn_type](v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone() 50 | 51 | local_q.requires_grad = True 52 | local_k.requires_grad = True 53 | local_v.requires_grad = True 54 | 55 | # extract local dout 56 | local_dout = EXTRACT_FUNC_DICT[ring_attn_type]( dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone() 57 | 58 | usp_attn = LongContextAttention(ring_impl_type=ring_attn_type) 59 | 60 | dist.barrier() 61 | if rank == 0: 62 | print("#" * 30) 63 | print("# USP forward:") 64 | print("#" * 30) 65 | 66 | out_ref, lse, _ = flash_attn_func( 67 | q, 68 | k, 69 | v, 70 | dropout_p=dropout_p, 71 | causal=causal, 72 | window_size=(-1, -1), 73 | alibi_slopes=None, 74 | deterministic=deterministic, 75 | return_attn_probs=True, 76 | ) 77 | 78 | local_out_ref = EXTRACT_FUNC_DICT[ring_attn_type](out_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree) 79 | 80 | # usp attn forward 81 | usp_out = usp_attn( 82 | local_q, 83 | local_k, 84 | local_v, 85 | dropout_p=dropout_p, 86 | causal=causal, 87 | window_size=(-1, -1), 88 | alibi_slopes=None, 89 | 
deterministic=deterministic, 90 | return_attn_probs=True, 91 | ) 92 | 93 | log("out diff", usp_out - local_out_ref) 94 | 95 | max_memory = torch.cuda.max_memory_allocated(device) / (1024 * 1024) # Convert to MB 96 | print(f"[Rank#{rank}] Maximum GPU memory used: {max_memory:.2f} MB") 97 | torch.cuda.reset_peak_memory_stats(device) # Reset stats 98 | 99 | dist.barrier() 100 | if rank == 0: 101 | print("#" * 30) 102 | print("# backward:") 103 | print("#" * 30) 104 | 105 | out_ref.backward(dout) 106 | local_dq_ref = EXTRACT_FUNC_DICT[ring_attn_type](q.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree) 107 | local_dk_ref = EXTRACT_FUNC_DICT[ring_attn_type](k.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree) 108 | local_dv_ref = EXTRACT_FUNC_DICT[ring_attn_type](v.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree) 109 | 110 | 111 | usp_out.backward(local_dout) 112 | 113 | log("dq diff", local_dq_ref - local_q.grad) 114 | log("dk diff", local_dk_ref - local_k.grad) 115 | log("dv diff", local_dv_ref - local_v.grad) 116 | 117 | if dist.is_initialized(): 118 | dist.destroy_process_group() -------------------------------------------------------------------------------- /sequence_parallel/ring_flash_attention/benchmark_ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from ring_flash_attention.ring_flash_attn import ring_flash_attn_qkvpacked_func 5 | import argparse 6 | from utils import flops, efficiency 7 | 8 | 9 | def benchmark(f, num_iter=100, forward_only=True, log=True, profile=False): 10 | dtype = torch.float16 11 | rank = dist.get_rank() 12 | world_size = dist.get_world_size() 13 | device = torch.device(f"cuda:{rank}") 14 | torch.cuda.set_device(device) 15 | 16 | batch_size = args.batch_size 17 | seqlen = args.seq_len 18 | nheads = args.nheads 19 | d = args.head_size 20 | 21 | dropout_p = 0 22 | causal = True 23 | deterministic = False 24 | 25 | assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}" 26 | assert d % 8 == 0 27 | 28 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 29 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 30 | 31 | if profile: 32 | torch.backends.cudnn.benchmark = True 33 | profiler = torch.profiler.profile( 34 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], 35 | schedule=torch.profiler.schedule(wait=5, warmup=5, active=5,), 36 | record_shapes=True, 37 | profile_memory=True, 38 | with_flops=True, 39 | with_modules=True, 40 | with_stack=True, 41 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 42 | f"./benchmark/profiles/{f.__name__}_bs_{batch_size}_seq_{seqlen}_heads_{nheads}_d_{d}_rank_{dist.get_rank()}_fwd_only_{forward_only}" 43 | ), 44 | ) 45 | 46 | if profile: 47 | profiler.start() 48 | 49 | begin = torch.cuda.Event(enable_timing=True) 50 | begin.record() 51 | 52 | # warmup 53 | out = f( 54 | qkv, 55 | dropout_p=dropout_p, 56 | causal=causal, 57 | window_size=(-1, -1), 58 | alibi_slopes=None, 59 | deterministic=deterministic, 60 | return_attn_probs=False, 61 | ) 62 | out.backward(dout) 63 | 64 | begin = torch.cuda.Event(enable_timing=True) 65 | begin.record() 66 | 67 | if forward_only: 68 | with torch.no_grad(): 69 | for _ in range(num_iter): 70 | _ = f( 71 | qkv, 72 | dropout_p=dropout_p, 73 | 
causal=causal, 74 | window_size=(-1, -1), 75 | alibi_slopes=None, 76 | deterministic=deterministic, 77 | return_attn_probs=False, 78 | ) 79 | if profile: 80 | profiler.step() 81 | 82 | else: 83 | for _ in range(num_iter): 84 | qkv.grad = None 85 | out = f( 86 | qkv, 87 | dropout_p=dropout_p, 88 | causal=causal, 89 | window_size=(-1, -1), 90 | alibi_slopes=None, 91 | deterministic=deterministic, 92 | return_attn_probs=False, 93 | ) 94 | out.backward(dout) 95 | if profile: 96 | profiler.step() 97 | end = torch.cuda.Event(enable_timing=True) 98 | end.record() 99 | torch.cuda.synchronize(device=device) 100 | time = begin.elapsed_time(end)/1000 101 | 102 | if profile: 103 | profiler.stop() 104 | 105 | if rank == 0 and log: 106 | print(f"{num_iter / time:.3f} iters/s, {time*1000/num_iter:.3f} ms/iter") 107 | print(f"peak memory: {torch.cuda.max_memory_allocated(device=device) / 1024 / 1024:.3f} MB") 108 | mode = "fwd" if forward_only else "fwd_bwd" 109 | speed_f = efficiency( 110 | flops(batch_size, seqlen, d, nheads, causal, mode=mode), 111 | time/num_iter 112 | ) 113 | print(f"speed: {speed_f:.3f} TFLOPs/s") 114 | 115 | 116 | if __name__ == "__main__": 117 | dist.init_process_group("nccl") 118 | rank = dist.get_rank() 119 | 120 | parser = argparse.ArgumentParser(description="Process some integers.") 121 | parser.add_argument("--nheads", type=int, default=16, help="head number") 122 | parser.add_argument("--head_size", type=int, default=128, help="head dimension") 123 | parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length") 124 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 125 | parser.add_argument("--fwd_only", action="store_true", help="benchmark forward pass only") 126 | parser.add_argument("--profile", action="store_true", help="generate torch profile or not") 127 | args = parser.parse_args() 128 | 129 | torch.cuda.empty_cache() 130 | if rank == 0: 131 | print(f"{ring_flash_attn_qkvpacked_func.__name__} BS:{args.batch_size} seq_len:{args.seq_len} nheads:{args.nheads} head_size:{args.head_size}, fwd_only: {args.fwd_only}") 132 | benchmark(ring_flash_attn_qkvpacked_func, forward_only=args.fwd_only, log=True, profile=args.profile) 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /sequence_parallel/basic/benchmark_qkvpacked_func.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from flash_attn import flash_attn_qkvpacked_func 5 | import argparse 6 | from utils import flops, efficiency 7 | 8 | 9 | def benchmark(f, num_iter=100, forward_only=True, log=True, profile=False): 10 | dtype = torch.float16 11 | rank = dist.get_rank() 12 | world_size = dist.get_world_size() 13 | device = torch.device(f"cuda:{rank}") 14 | torch.cuda.set_device(device) 15 | 16 | batch_size = args.batch_size 17 | seqlen = args.seq_len 18 | nheads = args.nheads 19 | d = args.head_size 20 | 21 | dropout_p = 0 22 | causal = True 23 | deterministic = False 24 | 25 | assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}" 26 | assert d % 8 == 0 27 | 28 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 29 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 30 | 31 | if profile: 32 | torch.backends.cudnn.benchmark = True 33 | profiler = torch.profiler.profile( 34 | 
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], 35 | schedule=torch.profiler.schedule(wait=5, warmup=5, active=5,), 36 | record_shapes=True, 37 | profile_memory=True, 38 | with_flops=True, 39 | with_modules=True, 40 | with_stack=True, 41 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 42 | f"./benchmark/profiles/{f.__name__}_bs_{batch_size}_seq_{seqlen}_heads_{nheads}_d_{d}_rank_{dist.get_rank()}_fwd_only_{forward_only}" 43 | ), 44 | ) 45 | 46 | if profile: 47 | profiler.start() 48 | 49 | begin = torch.cuda.Event(enable_timing=True) 50 | begin.record() 51 | 52 | # warmup 53 | out = f( 54 | qkv, 55 | dropout_p=dropout_p, 56 | causal=causal, 57 | window_size=(-1, -1), 58 | softcap=0.0, 59 | alibi_slopes=None, 60 | deterministic=deterministic, 61 | return_attn_probs=False, 62 | ) 63 | out.backward(dout) 64 | 65 | begin = torch.cuda.Event(enable_timing=True) 66 | begin.record() 67 | 68 | if forward_only: 69 | with torch.no_grad(): 70 | for _ in range(num_iter): 71 | _ = f( 72 | qkv, 73 | dropout_p=dropout_p, 74 | causal=causal, 75 | window_size=(-1, -1), 76 | softcap=0.0, 77 | alibi_slopes=None, 78 | deterministic=deterministic, 79 | return_attn_probs=False, 80 | ) 81 | if profile: 82 | profiler.step() 83 | 84 | else: 85 | for _ in range(num_iter): 86 | qkv.grad = None 87 | out = f( 88 | qkv, 89 | dropout_p=dropout_p, 90 | causal=causal, 91 | window_size=(-1, -1), 92 | softcap=0.0, 93 | alibi_slopes=None, 94 | deterministic=deterministic, 95 | return_attn_probs=False, 96 | ) 97 | out.backward(dout) 98 | if profile: 99 | profiler.step() 100 | end = torch.cuda.Event(enable_timing=True) 101 | end.record() 102 | torch.cuda.synchronize(device=device) 103 | time = begin.elapsed_time(end)/1000 104 | 105 | if profile: 106 | profiler.stop() 107 | 108 | if rank == 0 and log: 109 | print(f"{num_iter / time:.3f} iters/s, {time*1000/num_iter:.3f} ms/iter") 110 | print(f"peak memory: {torch.cuda.max_memory_allocated(device=device) / 1024 / 1024:.3f} MB") 111 | mode = "fwd" if forward_only else "fwd_bwd" 112 | speed_f = efficiency( 113 | flops(batch_size, seqlen, d, nheads, causal, mode=mode), 114 | time/num_iter 115 | ) 116 | print(f"speed: {speed_f:.3f} TFLOPs/s") 117 | 118 | 119 | if __name__ == "__main__": 120 | dist.init_process_group("nccl") 121 | rank = dist.get_rank() 122 | 123 | parser = argparse.ArgumentParser(description="Process some integers.") 124 | parser.add_argument("--nheads", type=int, default=16, help="head number") 125 | parser.add_argument("--head_size", type=int, default=128, help="head dimension") 126 | parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length") 127 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 128 | parser.add_argument("--fwd_only", action="store_true", help="benchmark forward pass only") 129 | parser.add_argument("--profile", action="store_true", help="generate torch profile or not") 130 | args = parser.parse_args() 131 | 132 | torch.cuda.empty_cache() 133 | if rank == 0: 134 | print(f"{flash_attn_qkvpacked_func.__name__} BS:{args.batch_size} seq_len:{args.seq_len} nheads:{args.nheads} head_size:{args.head_size}, fwd_only: {args.fwd_only}") 135 | benchmark(flash_attn_qkvpacked_func, forward_only=args.fwd_only, log=True, profile=args.profile) 136 | 137 | -------------------------------------------------------------------------------- /sequence_parallel/usp/usp_utils.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.distributed as dist 3 | 4 | class Singleton: 5 | _instance = None 6 | 7 | def __new__(cls, *args, **kwargs): 8 | if not cls._instance: 9 | cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs) 10 | return cls._instance 11 | 12 | 13 | class ProcessGroupSingleton(Singleton): 14 | def __init__(self): 15 | self.ULYSSES_PG = None 16 | self.RING_PG = None 17 | 18 | 19 | PROCESS_GROUP = ProcessGroupSingleton() 20 | 21 | def set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size, use_ulysses_low=True): 22 | """ 23 | sp_ulysses_degree x sp_ring_degree = seq_parallel_degree 24 | (ulysses_degree, dp_degree) 25 | """ 26 | sp_degree = sp_ring_degree * sp_ulysses_degree 27 | dp_degree = world_size // sp_degree 28 | 29 | assert (world_size % sp_degree == 0), f"world_size {world_size} % sp_degree {sp_ulysses_degree} == 0" 30 | 31 | num_ulysses_pgs = sp_ring_degree # world_size // sp_ulysses_degree 32 | num_ring_pgs = sp_ulysses_degree # world_size // sp_ring_degree 33 | 34 | if use_ulysses_low: 35 | for dp_rank in range(dp_degree): 36 | offset = dp_rank * sp_degree 37 | for i in range(num_ulysses_pgs): 38 | ulysses_ranks = list(range(i * sp_ulysses_degree + offset, (i + 1) * sp_ulysses_degree + offset,)) 39 | group = torch.distributed.new_group(ulysses_ranks) 40 | if rank in ulysses_ranks: 41 | ulyssess_pg = group 42 | 43 | for i in range(num_ring_pgs): 44 | ring_ranks = list(range(i + offset, sp_degree + offset, num_ring_pgs)) 45 | group = torch.distributed.new_group(ring_ranks) 46 | if rank in ring_ranks: 47 | ring_pg = group 48 | 49 | else: 50 | for dp_rank in range(dp_degree): 51 | offset = dp_rank * sp_degree 52 | for i in range(num_ring_pgs): 53 | ring_ranks = list( 54 | range( 55 | i * sp_ring_degree + offset, (i + 1) * sp_ring_degree + offset 56 | ) 57 | ) 58 | group = torch.distributed.new_group(ring_ranks) 59 | if rank in ring_ranks: 60 | ring_pg = group 61 | 62 | for i in range(num_ulysses_pgs): 63 | ulysses_ranks = list( 64 | range(i + offset, sp_degree + offset, num_ulysses_pgs) 65 | ) 66 | group = torch.distributed.new_group(ulysses_ranks) 67 | if rank in ulysses_ranks: 68 | ulyssess_pg = group 69 | 70 | PROCESS_GROUP.ULYSSES_PG = ulyssess_pg 71 | PROCESS_GROUP.RING_PG = ring_pg 72 | 73 | 74 | def stripe_extract_local(value, rank, world_size, rd, ud, *args, **kwargs): 75 | # ud at the highest dim 76 | input_dim = value.dim() 77 | assert input_dim >= 2 78 | 79 | batch_size, seqlen, *rest = value.shape 80 | 81 | assert dist.get_world_size(group=PROCESS_GROUP.RING_PG) == rd 82 | assert dist.get_world_size(group=PROCESS_GROUP.ULYSSES_PG) == ud 83 | 84 | value = value.reshape(batch_size, seqlen // rd, rd, -1).contiguous() 85 | value = value.transpose(1, 2).reshape(batch_size, seqlen, -1).contiguous() 86 | value = value.chunk(world_size, dim=1)[rank] 87 | 88 | new_shape = [batch_size, seqlen // world_size] + rest 89 | return value.reshape(new_shape) 90 | 91 | 92 | def basic_extract_local(value, rank, world_size, *args, **kwargs): 93 | return value.chunk(world_size, dim=1)[rank].detach().clone() 94 | 95 | 96 | def zigzag_extract_local(value, rank, world_size, rd, ud, dim=1, *args, **kwargs): 97 | """ 98 | value is a tensor of shape (bs, seqlen, ...) 
99 | """ 100 | input_dim = value.dim() 101 | assert input_dim >= 2 102 | batch_size, seqlen, *rest = value.shape 103 | 104 | value_chunks = value.chunk(2 * rd, dim=dim) 105 | r_rank = dist.get_rank(group=PROCESS_GROUP.RING_PG) 106 | u_rank = dist.get_rank(group=PROCESS_GROUP.ULYSSES_PG) 107 | 108 | assert dist.get_world_size(group=PROCESS_GROUP.RING_PG) == rd 109 | assert dist.get_world_size(group=PROCESS_GROUP.ULYSSES_PG) == ud 110 | 111 | local_value = torch.cat([value_chunks[r_rank], value_chunks[2 * rd - r_rank - 1]], dim=dim).chunk(ud, dim=dim)[u_rank] 112 | 113 | new_shape = [batch_size, seqlen // world_size] + rest 114 | return local_value.reshape(new_shape).contiguous() 115 | 116 | 117 | EXTRACT_FUNC_DICT = { 118 | "basic": basic_extract_local, 119 | "stripe": stripe_extract_local, 120 | "zigzag": zigzag_extract_local 121 | } 122 | 123 | 124 | # test if flash_attn is available 125 | try: 126 | import flash_attn 127 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 128 | HAS_FLASH_ATTN = True 129 | except ImportError: 130 | HAS_FLASH_ATTN = False 131 | 132 | try: 133 | from flash_attn_interface import _flash_attn_forward as flash_attn_forward_hopper 134 | from flash_attn_interface import _flash_attn_backward as flash_attn_func_hopper_backward 135 | from flash_attn_interface import flash_attn_func as flash3_attn_func 136 | HAS_FLASH_ATTN_HOPPER = True 137 | except ImportError: 138 | HAS_FLASH_ATTN_HOPPER = False -------------------------------------------------------------------------------- /moe_ep/test_moe_kernel_vllm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn.functional as F 4 | 5 | from vllm.model_executor.custom_op import CustomOp 6 | from vllm.platforms import current_platform 7 | 8 | from vllm.model_executor.layers.fused_moe import fused_moe 9 | from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( 10 | fused_moe as iterative_moe) 11 | 12 | import pdb 13 | 14 | @CustomOp.register("silu_and_mul") 15 | class SiluAndMul(CustomOp): 16 | """An activation function for SwiGLU. 17 | 18 | The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. 
19 | 20 | Shapes: 21 | x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) 22 | return: (num_tokens, d) or (batch_size, seq_len, d) 23 | """ 24 | 25 | def __init__(self): 26 | super().__init__() 27 | if current_platform.is_cuda_alike() or current_platform.is_cpu(): 28 | self.op = torch.ops._C.silu_and_mul 29 | elif current_platform.is_xpu(): 30 | from vllm._ipex_ops import ipex_ops 31 | self.op = ipex_ops.silu_and_mul 32 | 33 | def forward_native(self, x: torch.Tensor) -> torch.Tensor: 34 | """PyTorch-native implementation equivalent to forward().""" 35 | d = x.shape[-1] // 2 36 | return F.silu(x[..., :d]) * x[..., d:] 37 | 38 | def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: 39 | d = x.shape[-1] // 2 40 | output_shape = (x.shape[:-1] + (d, )) 41 | out = torch.empty(output_shape, dtype=x.dtype, device=x.device) 42 | self.op(out, x) 43 | return out 44 | 45 | def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: 46 | d = x.shape[-1] // 2 47 | output_shape = (x.shape[:-1] + (d, )) 48 | out = torch.empty(output_shape, dtype=x.dtype, device=x.device) 49 | self.op(out, x) 50 | return out 51 | 52 | def forward_neuron(self, x: torch.Tensor) -> torch.Tensor: 53 | d = x.shape[-1] // 2 54 | x_reshaped = x.view(-1, x.shape[-1]) 55 | s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d]) 56 | result = s * x_reshaped[:, d:] 57 | return result.view(*x.shape[:-1], d) 58 | 59 | 60 | def torch_moe(a, w1, w2, score, topk, expert_map): 61 | B, D = a.shape 62 | a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) 63 | out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) 64 | score = torch.softmax(score, dim=-1, dtype=torch.float32) 65 | topk_weight, topk_ids = torch.topk(score, topk) 66 | topk_weight = topk_weight.view(-1) 67 | topk_ids = topk_ids.view(-1) 68 | if expert_map is not None: 69 | topk_ids = expert_map[topk_ids] 70 | for i in range(w1.shape[0]): 71 | mask = topk_ids == i 72 | if mask.sum(): 73 | out[mask] = SiluAndMul()( 74 | a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) 75 | return (out.view(B, -1, w2.shape[1]) * 76 | topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) 77 | 78 | 79 | def test_fused_moe( 80 | m: int, 81 | n: int, 82 | k: int, 83 | e: int, 84 | topk: int, 85 | ep_size: int, 86 | dtype: torch.dtype, 87 | ): 88 | a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 89 | w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 90 | w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 91 | 92 | score = torch.randn((m, e), device="cuda", dtype=dtype) 93 | 94 | if ep_size > 1: 95 | local_e = e // ep_size 96 | e_ids = torch.randint(0, 97 | e, (local_e, ), 98 | device="cuda", 99 | dtype=torch.int32) 100 | e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) 101 | e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) 102 | w1 = w1[e_ids] 103 | w2 = w2[e_ids] 104 | else: 105 | e_map = None 106 | 107 | triton_output = fused_moe(a, 108 | w1, 109 | w2, 110 | score, 111 | topk, 112 | global_num_experts=e, 113 | expert_map=e_map, 114 | renormalize=False) 115 | 116 | torch_output = torch_moe(a, w1, w2, score, topk, e_map) 117 | torch.allclose(triton_output, torch_output, atol=2e-2, rtol=0) 118 | iterative_output = iterative_moe(a, 119 | w1, 120 | w2, 121 | score, 122 | topk, 123 | global_num_experts=e, 124 | expert_map=e_map, 125 | renormalize=False) 126 | torch.allclose(iterative_output, 127 | torch_output, 128 | atol=2e-2, 129 | rtol=0) 130 | 131 | if __name__ == "__main__": 132 | NUM_EXPERTS = 8 133 | 
EP_SIZE = 4 134 | TOP_KS = 3 135 | 136 | m = 33 137 | n = 128 138 | k = 512 139 | e = NUM_EXPERTS 140 | topk = TOP_KS 141 | ep_size = EP_SIZE 142 | dtype = torch.float16 143 | test_fused_moe(m, n, k, e, topk, ep_size, dtype) 144 | 145 | -------------------------------------------------------------------------------- /sequence_parallel/loongtrain/test_double_ring_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn import flash_attn_func 4 | from loongtrain.double_ring_attn import zigzag_ring_flash_attn_func_with_sliding_window 5 | from ring_flash_attention.zigzag_ring_flash_attn import extract_local 6 | from loongtrain.double_ring_utils import generate_2d_attn_process_group 7 | from utils import log, set_seed 8 | 9 | 10 | if __name__ == "__main__": 11 | dist.init_process_group("nccl") 12 | rank = dist.get_rank() 13 | set_seed(rank) 14 | world_size = dist.get_world_size() 15 | dtype = torch.bfloat16 16 | device = torch.device(f"cuda:{rank}") 17 | 18 | batch_size = 1 19 | seqlen = 4096 20 | nheads = 8 21 | d = 128 22 | dropout_p = 0 23 | causal = True 24 | deterministic = False 25 | 26 | assert seqlen % world_size == 0 27 | assert d % 8 == 0 28 | 29 | q = torch.randn( 30 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 31 | ) 32 | k = torch.randn( 33 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 34 | ) 35 | v = torch.randn( 36 | batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True 37 | ) 38 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 39 | 40 | dist.broadcast(q, src=0) 41 | dist.broadcast(k, src=0) 42 | dist.broadcast(v, src=0) 43 | dist.broadcast(dout, src=0) 44 | 45 | # prepare process group for double ring attention sequence parallelism 46 | context_parallel_size = 8 47 | double_ring_window_size = 4 48 | 49 | group_results = generate_2d_attn_process_group( 50 | world_size, 51 | rank, 52 | head_size=1, 53 | context_size=context_parallel_size, 54 | window_size=double_ring_window_size, 55 | head_first=True, 56 | interleaved=False, 57 | sp_size=world_size, 58 | with_cpu_group=False, 59 | ) 60 | 61 | for item in group_results: 62 | if item[5] == "head": 63 | head_group = item[2] 64 | elif item[5] == "context": 65 | context_group = item[2] 66 | elif item[5] == "intra_window": 67 | intra_window_group = item[2] 68 | elif item[5] == "inter_window": 69 | inter_window_group = item[2] 70 | elif item[5] == "dkv_intra_window": 71 | dkv_intra_window_group = item[2] 72 | elif item[5] == "dkv_inter_window": 73 | dkv_inter_window_group = item[2] 74 | 75 | # Use EXTRACT_FUNC_DICT to shard the tensors 76 | local_q = extract_local(q, rank, world_size).detach().clone() 77 | local_k = extract_local(k, rank, world_size).detach().clone() 78 | local_v = extract_local(v, rank, world_size).detach().clone() 79 | 80 | local_q.requires_grad = True 81 | local_k.requires_grad = True 82 | local_v.requires_grad = True 83 | 84 | # extract local dout 85 | local_dout = extract_local(dout, rank, world_size).detach().clone() 86 | 87 | dist.barrier() 88 | if rank == 0: 89 | print("#" * 30) 90 | print("# USP forward:") 91 | print("#" * 30) 92 | 93 | out_ref, lse, _ = flash_attn_func( 94 | q, 95 | k, 96 | v, 97 | dropout_p=dropout_p, 98 | causal=causal, 99 | window_size=(-1, -1), 100 | alibi_slopes=None, 101 | deterministic=deterministic, 102 | return_attn_probs=True, 103 | ) 104 | 105 | 
local_out_ref = extract_local(out_ref, rank, world_size).detach().clone() 106 | 107 | # usp attn forward 108 | double_ring_out, double_ring_lse, _ = ( 109 | zigzag_ring_flash_attn_func_with_sliding_window( 110 | local_q, 111 | local_k, 112 | local_v, 113 | dropout_p=dropout_p, 114 | causal=causal, 115 | window_size=(-1, -1), 116 | alibi_slopes=None, 117 | deterministic=deterministic, 118 | return_attn_probs=True, 119 | context_group=context_group, 120 | inter_window_group=inter_window_group, 121 | intra_window_group=intra_window_group, 122 | dkv_inter_window_group=dkv_inter_window_group, 123 | dkv_intra_window_group=dkv_intra_window_group, 124 | double_ring_window_size=double_ring_window_size, 125 | ) 126 | ) 127 | 128 | log("out diff", double_ring_out - local_out_ref) 129 | 130 | max_memory = torch.cuda.max_memory_allocated(device) / ( 131 | 1024 * 1024 132 | ) # Convert to MB 133 | print(f"[Rank#{rank}] Maximum GPU memory used: {max_memory:.2f} MB") 134 | torch.cuda.reset_peak_memory_stats(device) # Reset stats 135 | 136 | dist.barrier() 137 | if rank == 0: 138 | print("#" * 30) 139 | print("# backward:") 140 | print("#" * 30) 141 | 142 | out_ref.backward(dout) 143 | local_dq_ref = extract_local(q.grad, rank, world_size) 144 | local_dk_ref = extract_local(k.grad, rank, world_size) 145 | local_dv_ref = extract_local(v.grad, rank, world_size) 146 | 147 | double_ring_out.backward(local_dout) 148 | 149 | log("dq diff", local_dq_ref - local_q.grad) 150 | log("dk diff", local_dk_ref - local_k.grad) 151 | log("dv diff", local_dv_ref - local_v.grad) 152 | 153 | if dist.is_initialized(): 154 | dist.destroy_process_group() 155 | -------------------------------------------------------------------------------- /math_tir/routed_morph.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import List, Optional 17 | 18 | import requests 19 | 20 | 21 | class RoutedMorphSandbox: 22 | """ 23 | Client for the MorphCloud router service that mimics the API of MorphCloud's Sandbox. 24 | 25 | This class provides a simple interface to execute code via a central MorphCloud router, 26 | which manages sandbox creation and cleanup. It allows batch processing of multiple scripts 27 | in a single request for improved efficiency. 28 | 29 | Attributes: 30 | router_url (str): The URL of the MorphCloud router service. 31 | timeout (int): Execution timeout in seconds. 32 | request_timeout (int): HTTP request timeout in seconds. 33 | """ 34 | 35 | def __init__(self, router_url: str, timeout: int = 300, request_timeout: int = 60): 36 | """ 37 | Initialize the routed MorphCloud sandbox client. 38 | 39 | Args: 40 | router_url: The URL of the MorphCloud router, including host and port. 41 | timeout: Default execution timeout in seconds. 42 | request_timeout: Default HTTP request timeout in seconds. 
43 | """ 44 | self.router_url = router_url 45 | self.timeout = timeout 46 | self.request_timeout = request_timeout 47 | 48 | def run_code( 49 | self, 50 | scripts: List[str], 51 | languages: Optional[List[str]] = None, 52 | timeout: Optional[int] = None, 53 | request_timeout: Optional[int] = None, 54 | ) -> List: 55 | """ 56 | Execute multiple scripts using MorphCloud via the router. 57 | 58 | Args: 59 | scripts: List of code scripts to execute. 60 | languages: List of programming languages for each script. If None, defaults to Python for all scripts. 61 | timeout: Execution timeout in seconds. If None, uses the instance timeout. 62 | request_timeout: HTTP request timeout in seconds. If None, uses the instance request_timeout. 63 | 64 | Returns: 65 | List of execution results with text and exception_str properties. 66 | """ 67 | 68 | actual_timeout = timeout if timeout is not None else self.timeout 69 | actual_request_timeout = request_timeout if request_timeout is not None else self.request_timeout 70 | 71 | # Default to Python for all scripts if languages is not provided 72 | if languages is None: 73 | languages = ["python"] * len(scripts) 74 | 75 | payload = { 76 | "scripts": scripts, 77 | "languages": languages, 78 | "timeout": actual_timeout, 79 | "request_timeout": actual_request_timeout, 80 | } 81 | 82 | try: 83 | endpoint = f"http://{self.router_url}/execute_batch" 84 | response = requests.post(endpoint, json=payload, timeout=actual_request_timeout) 85 | 86 | if response.status_code != 200: 87 | error = f"Request to MorphCloud router failed with status code: {response.status_code}" 88 | print(error) 89 | 90 | results = [] 91 | for _ in scripts: 92 | results.append(type("obj", (object,), {"text": None, "exception_str": error})) 93 | return results 94 | 95 | response_data = response.json() 96 | results = [] 97 | 98 | for item in response_data: 99 | # Log the response data to see what we're getting 100 | # print(f"RoutedMorphSandbox: Got response item: {item}") 101 | result = type( 102 | "obj", 103 | (object,), 104 | { 105 | "text": item.get("text"), 106 | "exception_str": item.get("exception_str"), 107 | }, 108 | ) 109 | results.append(result) 110 | 111 | return results 112 | 113 | except Exception as e: 114 | error = f"Error communicating with MorphCloud router: {str(e)}" 115 | print(error) 116 | 117 | results = [] 118 | for _ in scripts: 119 | results.append(type("obj", (object,), {"text": None, "exception_str": error})) 120 | return results 121 | 122 | if __name__ == "__main__": 123 | # for local testing launch an E2B router with: python e2b_router.py 124 | sbx = RoutedMorphSandbox(router_url="0.0.0.0:8001") 125 | codes = ["print('hello world')", "import math\nprint(math.sqrt(3*2+10))"] 126 | executions = sbx.run_code(codes) # Execute Python inside the sandbox 127 | for execution in executions: 128 | print(f"Execution text: {execution.text}") 129 | print(f"Execution exception: {execution.exception_str}") 130 | -------------------------------------------------------------------------------- /sequence_parallel/ring_flash_attention/rfa_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn.functional as F 6 | import inspect 7 | from functools import cache 8 | 9 | 10 | __all__ = ["update_out_and_lse", "RingComm", "get_default_args"] 11 | 12 | 13 | @cache 14 | def get_default_args(func): 15 | spec = inspect.getfullargspec(func) 16 | 
defaults = spec.defaults if spec.defaults is not None else () 17 | padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults 18 | args = dict(zip(spec.args, padded_defaults)) 19 | if "softcap" in args: 20 | args["softcap"] = 0.0 21 | return args 22 | 23 | 24 | @torch.jit.script 25 | def _update_out_and_lse( 26 | out: torch.Tensor, 27 | lse: torch.Tensor, 28 | block_out: torch.Tensor, 29 | block_lse: torch.Tensor, 30 | ) -> Tuple[torch.Tensor, torch.Tensor]: 31 | 32 | block_out = block_out.to(torch.float32) 33 | block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 34 | 35 | # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) 36 | # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out 37 | # For additional context and discussion, please refer to: 38 | # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 39 | out = out - F.sigmoid(block_lse - lse) * (out - block_out) 40 | lse = lse - F.logsigmoid(lse - block_lse) 41 | 42 | return out, lse 43 | 44 | 45 | def update_out_and_lse( 46 | out: Optional[torch.Tensor], 47 | lse: Optional[torch.Tensor], 48 | block_out: torch.Tensor, 49 | block_lse: torch.Tensor, 50 | slice_=None, 51 | ) -> Tuple[torch.Tensor, torch.Tensor]: 52 | if out is None: 53 | if slice_ is not None: 54 | raise RuntimeError("first update_out_and_lse should not pass slice_ args") 55 | out = block_out.to(torch.float32) 56 | lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 57 | elif slice_ is not None: 58 | slice_out, slice_lse = out[slice_], lse[slice_] 59 | slice_out, slice_lse = _update_out_and_lse( 60 | slice_out, slice_lse, block_out, block_lse 61 | ) 62 | out[slice_], lse[slice_] = slice_out, slice_lse 63 | else: 64 | out, lse = _update_out_and_lse(out, lse, block_out, block_lse) 65 | return out, lse 66 | 67 | 68 | @torch.jit.script 69 | def flatten_varlen_lse(lse, cu_seqlens): 70 | new_lse = [] 71 | for i in range(len(cu_seqlens) - 1): 72 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 73 | new_lse.append(lse[i, :, : end - start]) 74 | return torch.cat(new_lse, dim=1) 75 | 76 | 77 | @torch.jit.script 78 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 79 | num_seq = len(cu_seqlens) - 1 80 | num_head = lse.shape[-2] 81 | new_lse = torch.empty( 82 | (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device 83 | ) 84 | for i in range(num_seq): 85 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 86 | new_lse[i, : end - start] = lse[start:end] 87 | return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() 88 | 89 | 90 | class RingComm: 91 | def __init__(self, process_group: dist.ProcessGroup): 92 | self._process_group = process_group 93 | self._ops = [] 94 | self.rank = dist.get_rank(self._process_group) 95 | self.world_size = dist.get_world_size(self._process_group) 96 | self._reqs = None 97 | 98 | self.send_rank = (self.rank + 1) % self.world_size 99 | self.recv_rank = (self.rank - 1) % self.world_size 100 | 101 | if process_group is not None: 102 | self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) 103 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) 104 | 105 | def send_recv( 106 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None 107 | ) -> torch.Tensor: 108 | if recv_tensor is None: 109 | res = torch.empty_like(to_send) 110 | else: 111 | res = recv_tensor 112 | 113 | send_op = dist.P2POp( 114 | dist.isend, to_send, self.send_rank, group=self._process_group 115 | ) 116 | recv_op = 
dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) 117 | self._ops.append(send_op) 118 | self._ops.append(recv_op) 119 | return res 120 | 121 | def commit(self): 122 | if self._reqs is not None: 123 | raise RuntimeError("commit called twice") 124 | self._reqs = dist.batch_isend_irecv(self._ops) 125 | 126 | def wait(self): 127 | if self._reqs is None: 128 | raise RuntimeError("wait called before commit") 129 | for req in self._reqs: 130 | req.wait() 131 | self._reqs = None 132 | self._ops = [] 133 | 134 | 135 | class AllGatherComm: 136 | def __init__(self, group=None) -> None: 137 | self.group = group 138 | self.handles = [] 139 | 140 | def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): 141 | handle = dist.all_gather_into_tensor( 142 | output_tensor, input_tensor, group=self.group, async_op=True 143 | ) 144 | self.handles.append(handle) 145 | 146 | def wait(self): 147 | for handle in self.handles: 148 | handle.wait() 149 | self.handles = [] 150 | -------------------------------------------------------------------------------- /sequence_parallel/ring_flash_attention/benchmark_stripe_ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from ring_flash_attention.stripe_flash_attn import stripe_flash_attn_qkvpacked_func, extract_local 5 | import argparse 6 | from utils import flops, efficiency 7 | 8 | 9 | def benchmark(f, num_iter=100, forward_only=True, log=True, profile=False): 10 | dtype = torch.float16 11 | rank = dist.get_rank() 12 | world_size = dist.get_world_size() 13 | device = torch.device(f"cuda:{rank}") 14 | torch.cuda.set_device(device) 15 | 16 | batch_size = args.batch_size 17 | seqlen = args.seq_len 18 | nheads = args.nheads 19 | d = args.head_size 20 | 21 | dropout_p = 0 22 | causal = True 23 | deterministic = False 24 | 25 | assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}" 26 | assert d % 8 == 0 27 | 28 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 29 | dist.broadcast(qkv, src=0) 30 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 31 | dist.broadcast(dout, src=0) 32 | 33 | local_qkv = extract_local(qkv, rank, world_size).detach().clone() 34 | local_qkv.requires_grad = True 35 | local_dout = extract_local(dout, rank, world_size).detach().clone() 36 | 37 | if profile: 38 | torch.backends.cudnn.benchmark = True 39 | profiler = torch.profiler.profile( 40 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], 41 | schedule=torch.profiler.schedule(wait=5, warmup=5, active=5,), 42 | record_shapes=True, 43 | profile_memory=True, 44 | with_flops=True, 45 | with_modules=True, 46 | with_stack=True, 47 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 48 | f"./benchmark/profiles/{f.__name__}_bs_{batch_size}_seq_{seqlen}_heads_{nheads}_d_{d}_rank_{dist.get_rank()}_fwd_only_{forward_only}" 49 | ), 50 | ) 51 | 52 | if profile: 53 | profiler.start() 54 | 55 | begin = torch.cuda.Event(enable_timing=True) 56 | begin.record() 57 | 58 | # warmup 59 | out = f( 60 | local_qkv, 61 | dropout_p=dropout_p, 62 | causal=causal, 63 | window_size=(-1, -1), 64 | alibi_slopes=None, 65 | deterministic=deterministic, 66 | return_attn_probs=False, 67 | ) 68 | out.backward(local_dout) 69 | 70 | begin = torch.cuda.Event(enable_timing=True) 71 | begin.record() 72 | 73 | if forward_only: 
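        # Forward-only timing: run the attention kernel under no_grad and time only the
        # steady-state loop (the `begin` event was re-recorded after the single warmup
        # fwd+bwd iteration above). When --profile is set, the profiler schedule
        # (wait=5, warmup=5, active=5) captures 5 profiled steps per cycle.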
74 | with torch.no_grad(): 75 | for _ in range(num_iter): 76 | _ = f( 77 | local_qkv, 78 | dropout_p=dropout_p, 79 | causal=causal, 80 | window_size=(-1, -1), 81 | alibi_slopes=None, 82 | deterministic=deterministic, 83 | return_attn_probs=False, 84 | ) 85 | if profile: 86 | profiler.step() 87 | 88 | else: 89 | for _ in range(num_iter): 90 | qkv.grad = None 91 | out = f( 92 | local_qkv, 93 | dropout_p=dropout_p, 94 | causal=causal, 95 | window_size=(-1, -1), 96 | alibi_slopes=None, 97 | deterministic=deterministic, 98 | return_attn_probs=False, 99 | ) 100 | out.backward(local_dout) 101 | if profile: 102 | profiler.step() 103 | end = torch.cuda.Event(enable_timing=True) 104 | end.record() 105 | torch.cuda.synchronize(device=device) 106 | time = begin.elapsed_time(end)/1000 107 | 108 | if profile: 109 | profiler.stop() 110 | 111 | if rank == 0 and log: 112 | print(f"{num_iter / time:.3f} iters/s, {time*1000/num_iter:.3f} ms/iter") 113 | print(f"peak memory: {torch.cuda.max_memory_allocated(device=device) / 1024 / 1024:.3f} MB") 114 | mode = "fwd" if forward_only else "fwd_bwd" 115 | speed_f = efficiency( 116 | flops(batch_size, seqlen, d, nheads, causal, mode=mode), 117 | time/num_iter 118 | ) 119 | print(f"speed: {speed_f:.3f} TFLOPs/s") 120 | 121 | 122 | if __name__ == "__main__": 123 | dist.init_process_group("nccl") 124 | rank = dist.get_rank() 125 | 126 | parser = argparse.ArgumentParser(description="Process some integers.") 127 | parser.add_argument("--nheads", type=int, default=16, help="head number") 128 | parser.add_argument("--head_size", type=int, default=128, help="head dimension") 129 | parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length") 130 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 131 | parser.add_argument("--fwd_only", action="store_true", help="benchmark forward pass only") 132 | parser.add_argument("--profile", action="store_true", help="generate torch profile or not") 133 | args = parser.parse_args() 134 | 135 | torch.cuda.empty_cache() 136 | if rank == 0: 137 | print(f"{stripe_flash_attn_qkvpacked_func.__name__} BS:{args.batch_size} seq_len:{args.seq_len} nheads:{args.nheads} head_size:{args.head_size}, fwd_only: {args.fwd_only}") 138 | benchmark(stripe_flash_attn_qkvpacked_func, num_iter=500, forward_only=args.fwd_only, log=True, profile=args.profile) 139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /sequence_parallel/reference.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | 6 | 7 | 8 | def construct_local_mask( 9 | seqlen_q, 10 | seqlen_k, 11 | window_size=(-1, -1), # -1 means infinite window size 12 | query_padding_mask=None, 13 | key_padding_mask=None, 14 | device=None, 15 | key_leftpad=None, 16 | ): 17 | row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") 18 | col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) 19 | if key_leftpad is not None: 20 | key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") 21 | col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) 22 | col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) 23 | sk = ( 24 | seqlen_k 25 | if key_padding_mask is None 26 | else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") 27 | ) 28 | sq = ( 29 | seqlen_q 30 | if query_padding_mask is None 31 | else 
rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") 32 | ) 33 | if window_size[0] < 0: 34 | return col_idx > row_idx + sk - sq + window_size[1] 35 | else: 36 | sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk 37 | return torch.logical_or( 38 | col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), 39 | col_idx < row_idx + sk - sq - window_size[0], 40 | ) 41 | 42 | 43 | def attention_ref( 44 | q, 45 | k, 46 | v, 47 | query_padding_mask=None, 48 | key_padding_mask=None, 49 | attn_bias=None, 50 | dropout_p=0.0, 51 | dropout_mask=None, 52 | causal=False, 53 | window_size=(-1, -1), # -1 means infinite window size 54 | softcap=0.0, 55 | upcast=True, 56 | reorder_ops=False, 57 | key_leftpad=None, 58 | ): 59 | """ 60 | Arguments: 61 | q: (batch_size, seqlen_q, nheads, head_dim) 62 | k: (batch_size, seqlen_k, nheads_k, head_dim) 63 | v: (batch_size, seqlen_k, nheads_k, head_dim) 64 | query_padding_mask: (batch_size, seqlen_q) 65 | key_padding_mask: (batch_size, seqlen_k) 66 | attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) 67 | dropout_p: float 68 | dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) 69 | causal: whether to apply causal masking 70 | window_size: (int, int), left and right window size 71 | upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast 72 | output back to fp16/bf16. 73 | reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) 74 | without changing the math. This is to estimate the numerical error from operation 75 | reordering. 76 | Output: 77 | output: (batch_size, seqlen_q, nheads, head_dim) 78 | attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout 79 | """ 80 | if causal: 81 | window_size = (window_size[0], 0) 82 | dtype_og = q.dtype 83 | if upcast: 84 | q, k, v = q.float(), k.float(), v.float() 85 | seqlen_q, seqlen_k = q.shape[1], k.shape[1] 86 | k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) 87 | v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) 88 | d = q.shape[-1] 89 | if not reorder_ops: 90 | scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) 91 | else: 92 | scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) 93 | if softcap > 0: 94 | scores /= softcap 95 | scores = scores.tanh() 96 | scores *= softcap 97 | if key_padding_mask is not None: 98 | scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) 99 | if window_size[0] >= 0 or window_size[1] >= 0: 100 | local_mask = construct_local_mask( 101 | seqlen_q, 102 | seqlen_k, 103 | window_size, 104 | query_padding_mask, 105 | key_padding_mask, 106 | q.device, 107 | key_leftpad=key_leftpad, 108 | ) 109 | scores.masked_fill_(local_mask, float("-inf")) 110 | if attn_bias is not None: 111 | scores = scores + attn_bias 112 | attention = torch.softmax(scores, dim=-1).to(v.dtype) 113 | # Some rows might be completely masked out so we fill them with zero instead of NaN 114 | if window_size[0] >= 0 or window_size[1] >= 0: 115 | attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) 116 | # We want to mask here so that the attention matrix doesn't have any NaNs 117 | # Otherwise we'll get NaN in dV 118 | if query_padding_mask is not None: 119 | attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) 120 | dropout_scaling = 1.0 / (1 - dropout_p) 121 | # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling 122 | # 
output = torch.einsum('bhts,bshd->bthd', attention_drop , v) 123 | if dropout_mask is not None: 124 | attention_drop = attention.masked_fill(~dropout_mask, 0.0) 125 | else: 126 | attention_drop = attention 127 | output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) 128 | if query_padding_mask is not None: 129 | output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) 130 | return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) 131 | -------------------------------------------------------------------------------- /sequence_parallel/ring_flash_attention/benchmark_zigzag_ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from ring_flash_attention.zigzag_ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func, extract_local 5 | import argparse 6 | from utils import flops, efficiency 7 | 8 | 9 | def benchmark(f, num_iter=100, forward_only=True, log=True, profile=False): 10 | dtype = torch.float16 11 | rank = dist.get_rank() 12 | world_size = dist.get_world_size() 13 | device = torch.device(f"cuda:{rank}") 14 | torch.cuda.set_device(device) 15 | 16 | batch_size = args.batch_size 17 | seqlen = args.seq_len 18 | nheads = args.nheads 19 | d = args.head_size 20 | 21 | dropout_p = 0 22 | causal = True 23 | deterministic = False 24 | 25 | assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}" 26 | assert d % 8 == 0 27 | 28 | qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True) 29 | dist.broadcast(qkv, src=0) 30 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) 31 | dist.broadcast(dout, src=0) 32 | 33 | local_qkv = extract_local(qkv, rank, world_size).detach().clone() 34 | local_qkv.requires_grad = True 35 | local_dout = extract_local(dout, rank, world_size).detach().clone() 36 | 37 | if profile: 38 | torch.backends.cudnn.benchmark = True 39 | profiler = torch.profiler.profile( 40 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,], 41 | schedule=torch.profiler.schedule(wait=5, warmup=5, active=5,), 42 | record_shapes=True, 43 | profile_memory=True, 44 | with_flops=True, 45 | with_modules=True, 46 | with_stack=True, 47 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 48 | f"./benchmark/profiles/{f.__name__}_bs_{batch_size}_seq_{seqlen}_heads_{nheads}_d_{d}_rank_{dist.get_rank()}_fwd_only_{forward_only}" 49 | ), 50 | ) 51 | 52 | if profile: 53 | profiler.start() 54 | 55 | begin = torch.cuda.Event(enable_timing=True) 56 | begin.record() 57 | 58 | # warmup 59 | out = f( 60 | local_qkv, 61 | dropout_p=dropout_p, 62 | causal=causal, 63 | window_size=(-1, -1), 64 | alibi_slopes=None, 65 | deterministic=deterministic, 66 | return_attn_probs=False, 67 | ) 68 | out.backward(local_dout) 69 | 70 | begin = torch.cuda.Event(enable_timing=True) 71 | begin.record() 72 | 73 | if forward_only: 74 | with torch.no_grad(): 75 | for _ in range(num_iter): 76 | _ = f( 77 | local_qkv, 78 | dropout_p=dropout_p, 79 | causal=causal, 80 | window_size=(-1, -1), 81 | alibi_slopes=None, 82 | deterministic=deterministic, 83 | return_attn_probs=False, 84 | ) 85 | if profile: 86 | profiler.step() 87 | 88 | else: 89 | for _ in range(num_iter): 90 | local_qkv.grad = None 91 | out = f( 92 | local_qkv, 93 | dropout_p=dropout_p, 94 | causal=causal, 95 | window_size=(-1, -1), 96 | alibi_slopes=None, 97 | 
deterministic=deterministic, 98 | return_attn_probs=False, 99 | ) 100 | out.backward(local_dout) 101 | if profile: 102 | profiler.step() 103 | end = torch.cuda.Event(enable_timing=True) 104 | end.record() 105 | torch.cuda.synchronize(device=device) 106 | time = begin.elapsed_time(end)/1000 107 | 108 | if profile: 109 | profiler.stop() 110 | 111 | if rank == 0 and log: 112 | print(f"{num_iter / time:.3f} iters/s, {time*1000/num_iter:.3f} ms/iter") 113 | print(f"peak memory: {torch.cuda.max_memory_allocated(device=device) / 1024 / 1024:.3f} MB") 114 | mode = "fwd" if forward_only else "fwd_bwd" 115 | speed_f = efficiency( 116 | flops(batch_size, seqlen, d, nheads, causal, mode=mode), 117 | time/num_iter 118 | ) 119 | print(f"speed: {speed_f:.3f} TFLOPs/s") 120 | 121 | 122 | if __name__ == "__main__": 123 | dist.init_process_group("nccl") 124 | rank = dist.get_rank() 125 | 126 | parser = argparse.ArgumentParser(description="Process some integers.") 127 | parser.add_argument("--nheads", type=int, default=16, help="head number") 128 | parser.add_argument("--head_size", type=int, default=128, help="head dimension") 129 | parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length") 130 | parser.add_argument("--batch_size", type=int, default=2, help="batch size") 131 | parser.add_argument("--fwd_only", action="store_true", help="benchmark forward pass only") 132 | parser.add_argument("--profile", action="store_true", help="generate torch profile or not") 133 | args = parser.parse_args() 134 | 135 | torch.cuda.empty_cache() 136 | if rank == 0: 137 | print(f"{zigzag_ring_flash_attn_qkvpacked_func.__name__} BS:{args.batch_size} seq_len:{args.seq_len} nheads:{args.nheads} head_size:{args.head_size}, fwd_only: {args.fwd_only}") 138 | benchmark(zigzag_ring_flash_attn_qkvpacked_func, num_iter=500, forward_only=args.fwd_only, log=True, profile=args.profile) 139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /sequence_parallel/readme_usp.md: -------------------------------------------------------------------------------- 1 | # USP 性能基准 2 | 3 | 好了,接下来可以改变参数记录下USP性能基准(设备为 8 * 4090D)(bs=2, nheads=16, head_size=16),如下表所示: 4 | 5 | - seq_len 与 ring attention type 的影响 6 | 7 | | ring type | seq_len | ulysses_degree | ring_degree | fwd_only | throughput(iters/s) | latency(ms/iter) | peak memory(MB/device) | speed(TFLOPS) | 8 | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | 9 | | basic | 4096 | 2 | 4 | True | 84.019 | 11.902 | 240 | 11.5 | 10 | | basic | 8192 | 2 | 4 | True | 84.852 | 11.785 | 480 | 46.6 | 11 | | basic | 16384 | 2 | 4 | True | 44.255 | 22.596 | 960 | 97.3 | 12 | | basic | 32768 | 2 | 4 | True | 12.837 | 77.898 | 1920 | 112.9 | 13 | | basic | 65536 | 2 | 4 | True | 6.444 | 155.191 | 3841 | 226.7 | 14 | | basic | 128000 | 2 | 4 | True | 2.644 | 378.174 | 7518.9 | 354.9 | 15 | | stripe | 4096 | 2 | 4 | True | 59.643 | 16.767 | 244.1 | 8.2 | 16 | | stripe | 8192 | 2 | 4 | True | 48.455 | 20.638 | 488.3 | 26.6 | 17 | | stripe | 16384 | 2 | 4 | True | 31.839 | 31.408 | 976.7 | 70.0 | 18 | | stripe | 32768 | 2 | 4 | True | 17.832 | 56.078 | 1953.5 | 156.8 | 19 | | stripe | 65536 | 2 | 4 | True | 6.540 | 152.911 | 3907.0 | 230.1 | 20 | | stripe | 128000 | 2 | 4 | True | 3.184 | 314.057 | 7638.8 | 427.3 | 21 | | zigzag | 4096 | 2 | 4 | True | 71.727 | 13.942 | 240.1 | 9.8 | 22 | | zigzag | 8192 | 2 | 4 | True | 44.264 | 22.592 | 480.3 | 24.3 | 23 | | zigzag | 16384 | 2 | 4 | True | 32.001 | 
31.249 | 960.5 | 70.4 |
24 | | zigzag | 32768 | 2 | 4 | True | 17.036 | 58.700 | 1921 | 149.8 |
25 | | zigzag | 65536 | 2 | 4 | True | 6.026 | 165.935 | 3842.0 | 212.0 |
26 | | zigzag | 128000 | 2 | 4 | True | 3.123 | 320.243 | 7514.0 | 419.1 |
27 | 
28 | 
29 | - Effect of seq_len and parallel degree (ulysses_degree vs. ring_degree)
30 | 
31 | | ring type | seq_len | ulysses_degree | ring_degree | fwd_only | throughput(iters/s) | latency(ms/iter) | peak memory(MB/device) | speed(TFLOPS) |
32 | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
33 | | zigzag | 4096 | 2 | 4 | True | 71.727 | 13.942 | 240.1 | 9.8 |
34 | | zigzag | 8192 | 2 | 4 | True | 44.264 | 22.592 | 480.3 | 24.3 |
35 | | zigzag | 16384 | 2 | 4 | True | 32.001 | 31.249 | 960.5 | 70.4 |
36 | | zigzag | 32768 | 2 | 4 | True | 17.036 | 58.700 | 1921 | 149.8 |
37 | | zigzag | 65536 | 2 | 4 | True | 6.026 | 165.935 | 3842.0 | 212.0 |
38 | | zigzag | 128000 | 2 | 4 | True | 3.123 | 320.243 | 7514.0 | 419.1 |
39 | | zigzag | 4096 | 4 | 2 | True | 203.134 | 4.923 | 240.1 | 27.9 |
40 | | zigzag | 8192 | 4 | 2 | True | 102.374 | 9.768 | 480.3 | 56.3 |
41 | | zigzag | 16384 | 4 | 2 | True | 51.234 | 19.518 | 960.5 | 112.7 |
42 | | zigzag | 32768 | 4 | 2 | True | 19.720 | 50.709 | 1921 | 173.5 |
43 | | zigzag | 65536 | 4 | 2 | True | 8.613 | 116.098 | 3842.0 | 303.1 |
44 | | zigzag | 128000 | 4 | 2 | True | 3.639 | 274.807 | 7514.0 | 488.4 |
45 | 
46 | 
47 | - Effect of seq_len and the backward pass
48 | 
49 | | ring type | seq_len | ulysses_degree | ring_degree | fwd_only | throughput(iters/s) | latency(ms/iter) | peak memory(MB/device) | speed(TFLOPS) |
50 | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
51 | | zigzag | 4096 | 2 | 4 | True | 71.727 | 13.942 | 240.1 | 9.8 |
52 | | zigzag | 8192 | 2 | 4 | True | 44.264 | 22.592 | 480.3 | 24.3 |
53 | | zigzag | 16384 | 2 | 4 | True | 32.001 | 31.249 | 960.5 | 70.4 |
54 | | zigzag | 32768 | 2 | 4 | True | 17.036 | 58.700 | 1921 | 149.8 |
55 | | zigzag | 65536 | 2 | 4 | True | 6.026 | 165.935 | 3842.0 | 212.0 |
56 | | zigzag | 128000 | 2 | 4 | True | 3.123 | 320.243 | 7514.0 | 419.1 |
57 | | zigzag | 4096 | 2 | 4 | False | 31.615 | 31.631 | 240.1 | 15.2 |
58 | | zigzag | 8192 | 2 | 4 | False | 20.200 | 49.505 | 480.3 | 38.8 |
59 | | zigzag | 16384 | 2 | 4 | False | 10.513 | 95.121 | 960.5 | 80.9 |
60 | | zigzag | 32768 | 2 | 4 | False | 5.714 | 175.013 | 1921 | 175.9 |
61 | | zigzag | 65536 | 2 | 4 | False | 2.134 | 468.519 | 3842.0 | 262.8 |
62 | | zigzag | 128000 | 2 | 4 | False | 1.170 | 854.512 | 7514.0 | 549.7 |
63 | 
64 | 
65 | # loongtrain double ring performance benchmark
66 | 
67 | Next, we record the performance of LoongTrain double ring attention (hardware: 8 * 4090D) (bs=2, nheads=16, head_size=16), as shown in the tables below:
68 | 
69 | | ring type | seq_len | context_size | window_size | fwd_only | throughput(iters/s) | latency(ms/iter) | peak memory(MB/device) | speed(TFLOPS) |
70 | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
71 | | zigzag | 4096 | 8 | 4 | True | 57.112 | 17.509 | 228.1 | 7.8 |
72 | | zigzag | 8192 | 8 | 4 | True | 46.554 | 21.480 | 456.3 | 25.6 |
73 | | zigzag | 16384 | 8 | 4 | True | 30.797 | 32.471 | 912.7 | 67.7 |
74 | | zigzag | 32768 | 8 | 4 | True | 17.982 | 55.610 | 1825.5 | 158.2 |
75 | | zigzag | 65536 | 8 | 4 | True | 7.416 | 134.849 | 3651.0 | 260.9 |
76 | | zigzag | 128000 | 8 | 4 | True | 3.270 | 305.810 | 7145.8 | 438.8 |
77 | | zigzag | 4096 | 8 | 4 | False | 29.410 | 34.002 | 228.1 | 14.1 |
78 | | zigzag | 8192 | 8 | 4 | False | 18.230 | 54.853 | 456.4 | 35.1 |
79 | | zigzag | 16384
| 8 | 4 | False | 10.783 | 92.736 | 912.7 | 83.0 | 80 | | zigzag | 32768 | 8 | 4 | False | 5.580 | 179.200 | 1825.5 | 171.8 | 81 | | zigzag | 65536 | 8 | 4 | False | 1.774 | 563.702 | 3651.0 | 218.4 | 82 | | zigzag | 128000 | 8 | 4 | False | 0.998 | 1001.561 | 7145.8 | 469.0 | 83 | 84 | 85 | - 与 window_size 的关系 86 | 87 | | ring type | seq_len | context_size | window_size | fwd_only | throughput(iters/s) | latency(ms/iter) | peak memory(MB/device) | speed(TFLOPS) | 88 | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | 89 | | zigzag | 4096 | 8 | 2 | True | 71.676 | 13.952 | 228.1 | 9.8 | 90 | | zigzag | 8192 | 8 | 2 | True | 52.277 | 19.129 | 456.3 | 28.7 | 91 | | zigzag | 16384 | 8 | 2 | True | 31.828 | 31.419 | 912.7 | 69.9 | 92 | | zigzag | 32768 | 8 | 2 | True | 17.982 | 50.553 | 1825.5 | 173.9 | 93 | | zigzag | 65536 | 8 | 2 | True | 8.186 | 122.157 | 3651.0 | 288.0 | 94 | | zigzag | 128000 | 8 | 2 | True | 3.206 | 311.917 | 7144.8 | 430.3 | 95 | -------------------------------------------------------------------------------- /verl_test/agent_test/vllm_async_rollout_perf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Compare vLLM AsyncLLM backend: ExternalRayDistributedExecutor(remote call) vs RayDistributedExecutor(compiled graph) 16 | 17 | 1. Prepare openai/gsm8k dataset 18 | python3 examples/data_preprocess/gsm8k.py 19 | 20 | 2. 
Run perf test 21 | python3 agent_test/vllm_async_rollout_perf.py > perf.log 2>&1 22 | 23 | hardware: Nvidia 8*4090D 24 | packages: 25 | - torch==2.7.0 26 | - vllm==0.9.0.1 27 | 28 | [DEBUG] backend: sync, n_gpus_per_node: 8, batch_size: 128, step: 0, step_time: 15.25 secs 29 | [DEBUG] backend: zeromq, n_gpus_per_node: 8, batch_size: 128, step: 0, step_time: 14.77 secs 30 | [DEBUG] backend: ray, n_gpus_per_node: 8, batch_size: 128, step: 0, step_time: 17.59 secs 31 | """ 32 | import sys 33 | 34 | VERL_PATH = "/data2/zzd/rl_llm/verl" 35 | sys.path.append(VERL_PATH) 36 | 37 | import os 38 | import time 39 | 40 | import ray 41 | from omegaconf import DictConfig 42 | from torch.utils.data import SequentialSampler 43 | from torchdata.stateful_dataloader import StatefulDataLoader 44 | 45 | from agent_utils import AgentLoopManager, RayWorkerGroup, init_agent_loop_manager 46 | from verl.protocol import DataProto 47 | from verl.utils import hf_tokenizer 48 | from verl.utils.dataset import RLHFDataset 49 | from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn 50 | 51 | 52 | def init_config(n_gpus_per_node) -> DictConfig: 53 | from hydra import compose, initialize_config_dir 54 | 55 | with initialize_config_dir(config_dir=os.path.abspath(f"{VERL_PATH}/verl/trainer/config")): 56 | config = compose( 57 | config_name="ppo_trainer", 58 | overrides=[ 59 | "actor_rollout_ref.actor.use_dynamic_bsz=true", 60 | "actor_rollout_ref.actor.fsdp_config.param_offload=True", 61 | "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True", 62 | ], 63 | ) 64 | config.trainer.n_gpus_per_node = n_gpus_per_node 65 | config.data.train_batch_size = 128 66 | config.data.return_raw_chat = True 67 | config.actor_rollout_ref.model.path = "/data3/ckpt/Qwen/Qwen2.5-7B-Instruct" 68 | config.actor_rollout_ref.rollout.mode = "async" 69 | config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 70 | config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9 71 | config.actor_rollout_ref.rollout.multi_turn.format = "hermes" 72 | config.actor_rollout_ref.rollout.prompt_length = 4096 73 | config.actor_rollout_ref.rollout.response_length = 4096 74 | config.actor_rollout_ref.rollout.n = 16 75 | 76 | return config 77 | 78 | 79 | def initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]: 80 | env_vars = { 81 | "NCCL_DEBUG": "WARN", 82 | "VLLM_USE_V1": "1", 83 | "VERL_VLLM_DISTRIBUTED_BACKEND": backend, 84 | } 85 | ray.init(runtime_env={"env_vars": env_vars, "working_dir": VERL_PATH,}) 86 | 87 | # STEP 1: init async llm server 88 | server = init_agent_loop_manager(config) 89 | 90 | # STEP 2: create dataloader 91 | tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path) 92 | dataset = RLHFDataset( 93 | data_files=os.path.expanduser("/data2/zzd/data/GSM8K/train.parquet"), 94 | tokenizer=tokenizer, 95 | config=config.data, 96 | ) 97 | dataloader = StatefulDataLoader( 98 | dataset=dataset, 99 | batch_size=config.data.get("gen_batch_size", config.data.train_batch_size), 100 | num_workers=config.data.get("dataloader_num_workers", 8), 101 | drop_last=True, 102 | collate_fn=default_collate_fn, 103 | sampler=SequentialSampler(dataset), 104 | ) 105 | 106 | return server, dataloader 107 | 108 | 109 | def perf_rollout(mode, backend, n_gpus_per_node, num_steps): 110 | config = init_config(n_gpus_per_node) 111 | config.actor_rollout_ref.rollout.mode = mode 112 | agent_loop_manager, dataloader = initialize(config, backend) 113 | 114 | for step, batch in enumerate(dataloader): 115 
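        # Keep only the fields the rollout needs (prompt token ids, attention mask,
        # position ids and the raw prompts), then time a single generate_sequences
        # call for this batch.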
| batch: DataProto = DataProto.from_single_dict(batch) 116 | batch = batch.pop( 117 | batch_keys=["input_ids", "attention_mask", "position_ids"], 118 | non_tensor_batch_keys=["raw_prompt_ids", "raw_prompt"], 119 | ) 120 | t_start = time.time() 121 | gen_batch = agent_loop_manager.generate_sequences(batch) 122 | t_end = time.time() 123 | print( 124 | f"[DEBUG] backend: {backend}, n_gpus_per_node: {n_gpus_per_node}, batch_size: {len(gen_batch)}, " 125 | f"step: {step}, step_time: {t_end - t_start:.2f} secs" 126 | ) 127 | if step + 1 >= num_steps: 128 | break 129 | 130 | ray.shutdown() 131 | 132 | 133 | if __name__ == "__main__": 134 | num_steps = 1 135 | n_gpus_per_node = 8 136 | 137 | test_cases = [("sync", "sync"), ("async", "zeromq"), ("async", "ray")] 138 | # test_cases = [("async", "zeromq"), ("async", "ray")] 139 | for mode, backend in test_cases: 140 | perf_rollout(mode=mode, backend=backend, n_gpus_per_node=n_gpus_per_node, num_steps=num_steps) 141 | -------------------------------------------------------------------------------- /verl_test/scripts/train_dapo_32b.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -xeuo pipefail 3 | 4 | 5 | adv_estimator=grpo 6 | 7 | use_kl_in_reward=False 8 | kl_coef=0.0 9 | use_kl_loss=False 10 | kl_loss_coef=0.0 11 | 12 | clip_ratio_low=0.2 13 | clip_ratio_high=0.28 14 | 15 | max_prompt_length=$((1024 * 1)) 16 | max_response_length=$((1024 * 1)) 17 | enable_overlong_buffer=True 18 | overlong_buffer_len=$((1024 * 4)) 19 | overlong_penalty_factor=1.0 20 | 21 | loss_agg_mode="token-mean" 22 | 23 | enable_filter_groups=True 24 | filter_groups_metric=acc 25 | max_num_gen_batches=10 26 | train_prompt_bsz=256 27 | gen_prompt_bsz=$((train_prompt_bsz * 3)) 28 | n_resp_per_prompt=16 29 | train_prompt_mini_bsz=32 30 | 31 | 32 | # Ray 33 | RAY_ADDRESS=${RAY_ADDRESS:-"http://10.157.150.10:8265"} 34 | WORKING_DIR=${WORKING_DIR:-"${PWD}"} 35 | RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} 36 | # Paths 37 | RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} 38 | MODEL_PATH=${MODEL_PATH:-"/mnt/nvme0/zzd/ckpt/Qwen/Qwen2.5-32B"} 39 | CKPTS_DIR=${CKPTS_DIR:-"$/mnt/nvme0/zzd/ckpt/verl_dapo_32b"} 40 | TRAIN_FILE=${TRAIN_FILE:-"/mnt/nvme0/zzd/data/dapo_data/dapo-math-17k.parquet"} 41 | TEST_FILE=${TEST_FILE:-"/mnt/nvme0/zzd/data/dapo_data/aime-2024.parquet"} 42 | 43 | 44 | # Algorithm 45 | temperature=1.0 46 | top_p=1.0 47 | top_k=-1 # 0 for HF rollout, -1 for vLLM rollout 48 | val_top_p=0.7 49 | 50 | # Performance Related Parameter 51 | sp_size=8 52 | use_dynamic_bsz=True 53 | actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) 54 | infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) 55 | offload=True 56 | gen_tp=4 57 | 58 | ray job submit --runtime-env="${RUNTIME_ENV}" \ 59 | --working-dir "${WORKING_DIR}" \ 60 | -- python3 -m recipe.dapo.src.main_dapo \ 61 | data.train_files="${TRAIN_FILE}" \ 62 | data.val_files="${TEST_FILE}" \ 63 | data.prompt_key=prompt \ 64 | data.truncation='left' \ 65 | data.max_prompt_length=${max_prompt_length} \ 66 | data.max_response_length=${max_response_length} \ 67 | data.gen_batch_size=${gen_prompt_bsz} \ 68 | data.train_batch_size=${train_prompt_bsz} \ 69 | actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ 70 | algorithm.adv_estimator=${adv_estimator} \ 71 | algorithm.use_kl_in_reward=${use_kl_in_reward} \ 72 | algorithm.kl_ctrl.kl_coef=${kl_coef} \ 73 | actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ 74 | 
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ 75 | actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ 76 | actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ 77 | actor_rollout_ref.actor.clip_ratio_c=10.0 \ 78 | algorithm.filter_groups.enable=${enable_filter_groups} \ 79 | algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ 80 | algorithm.filter_groups.metric=${filter_groups_metric} \ 81 | actor_rollout_ref.model.use_remove_padding=True \ 82 | actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ 83 | actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ 84 | actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ 85 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ 86 | actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ 87 | actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ 88 | actor_rollout_ref.model.path="${MODEL_PATH}" \ 89 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 90 | actor_rollout_ref.actor.optim.lr=1e-6 \ 91 | actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ 92 | actor_rollout_ref.actor.optim.weight_decay=0.1 \ 93 | actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ 94 | actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ 95 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ 96 | actor_rollout_ref.actor.entropy_coeff=0 \ 97 | actor_rollout_ref.actor.grad_clip=1.0 \ 98 | actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ 99 | actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ 100 | actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ 101 | actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ 102 | actor_rollout_ref.rollout.enable_chunked_prefill=True \ 103 | actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ 104 | actor_rollout_ref.rollout.temperature=${temperature} \ 105 | actor_rollout_ref.rollout.top_p=${top_p} \ 106 | actor_rollout_ref.rollout.top_k="${top_k}" \ 107 | actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ 108 | actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ 109 | actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ 110 | actor_rollout_ref.rollout.val_kwargs.do_sample=True \ 111 | actor_rollout_ref.rollout.val_kwargs.n=1 \ 112 | actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ 113 | actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ 114 | actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ 115 | reward_model.reward_manager=dapo \ 116 | reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ 117 | reward_model.overlong_buffer.len=${overlong_buffer_len} \ 118 | reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ 119 | trainer.logger=['console','wandb'] \ 120 | trainer.project_name="verl-h20" \ 121 | trainer.experiment_name="dapo_math_32b_0502" \ 122 | trainer.n_gpus_per_node=8 \ 123 | trainer.nnodes=2 \ 124 | trainer.val_before_train=True \ 125 | trainer.test_freq=5 \ 126 | trainer.save_freq=100 \ 127 | trainer.total_epochs=1 \ 128 | trainer.default_local_dir="${CKPTS_DIR}" \ 129 | trainer.resume_mode=auto --------------------------------------------------------------------------------