├── config ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ └── config.cpython-311.pyc ├── config_tinystories.json └── config.py ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ └── dataloader.cpython-311.pyc └── dataloader.py ├── layers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ └── layers.cpython-311.pyc └── layers.py ├── model ├── __init__.py ├── __pycache__ │ ├── model.cpython-311.pyc │ └── __init__.cpython-311.pyc └── model.py ├── tests ├── __init__.py ├── test_ring_all_gather.py ├── test_ring_reduce_scatter.py ├── test_ring_p2p.py ├── test_dataloader.py ├── test_all_reduce.py └── test_ring_attention.py ├── utils ├── __init__.py ├── utils.py └── hf_utils.py ├── parallel ├── communication │ ├── __init__.py │ ├── ring_p2p.py │ ├── ring_all_gather.py │ ├── ring_reduce_scatter.py │ ├── all_reduce.py │ └── ring_attention.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── distributed.cpython-311.pyc │ ├── parallel_cp.cpython-311.pyc │ ├── parallel_dp.cpython-311.pyc │ ├── parallel_ep.cpython-311.pyc │ ├── parallel_pp.cpython-311.pyc │ ├── parallel_tp.cpython-311.pyc │ └── parallel_fsdp.cpython-311.pyc ├── __init__.py ├── parallel_dp.py ├── parallel_pp.py ├── distributed.py ├── parallel_tp.py ├── parallel_fsdp.py ├── parallel_ep.py └── parallel_cp.py ├── requirements.txt ├── setup.py ├── CITATION.cff ├── LICENSE ├── README.md └── trainer.py /config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /parallel/communication/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wandb==0.23.1 2 | torch==2.3.0 3 | triton==2.3.0 4 | numpy==1.26.4 5 | datasets==2.19.1 6 | transformers==4.47.0 -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/model/__pycache__/model.cpython-311.pyc -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lwj2015/lightron/HEAD/config/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /config/__pycache__/config.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/config/__pycache__/config.cpython-311.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/data/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/data/__pycache__/dataloader.cpython-311.pyc -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/layers/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /layers/__pycache__/layers.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/layers/__pycache__/layers.cpython-311.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/model/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /parallel/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/parallel/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /parallel/__pycache__/distributed.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/parallel/__pycache__/distributed.cpython-311.pyc -------------------------------------------------------------------------------- /parallel/__pycache__/parallel_cp.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/parallel/__pycache__/parallel_cp.cpython-311.pyc -------------------------------------------------------------------------------- /parallel/__pycache__/parallel_dp.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/parallel/__pycache__/parallel_dp.cpython-311.pyc -------------------------------------------------------------------------------- /parallel/__pycache__/parallel_ep.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lwj2015/lightron/HEAD/parallel/__pycache__/parallel_ep.cpython-311.pyc 
--------------------------------------------------------------------------------
/parallel/__pycache__/parallel_pp.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lwj2015/lightron/HEAD/parallel/__pycache__/parallel_pp.cpython-311.pyc
--------------------------------------------------------------------------------
/parallel/__pycache__/parallel_tp.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lwj2015/lightron/HEAD/parallel/__pycache__/parallel_tp.cpython-311.pyc
--------------------------------------------------------------------------------
/parallel/__pycache__/parallel_fsdp.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lwj2015/lightron/HEAD/parallel/__pycache__/parallel_fsdp.cpython-311.pyc
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name="lightron",
5 |     version="0.1.0",
6 |     description="A lightweight, modern distributed training framework for LLMs",
7 |     author="Wenjun Liu",
8 |     packages=find_packages(),
9 |     install_requires=[
10 |         "torch==2.3.0",
11 |         "numpy",
12 |     ],
13 | )
14 | 
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: 'Lightron: A Modern Minimalist Distributed Training Framework'
3 | message: >-
4 |   If you use Lightron, please cite it using the
5 |   metadata from this file.
6 | type: software
7 | authors:
8 |   - given-names: Wenjun
9 |     family-names: Liu
10 | repository-code: 'https://github.com/lwj2015/lightron'
11 | abstract: "Lightron is a modern minimalist distributed training framework for educational purposes."
12 | keywords: 13 | - LLM training framework 14 | - distributed training 15 | license: MIT 16 | version: 1.0.1 -------------------------------------------------------------------------------- /parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # lightron/parallel/__init__.py 2 | 3 | from .distributed import setup_distributed, get_device_mesh 4 | from .parallel_fsdp import apply_fsdp1, apply_fsdp2 5 | from .parallel_tp import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding 6 | from .parallel_cp import ContextParallelAttention 7 | from .parallel_ep import ExpertParallel 8 | from .parallel_pp import PipelineStage 9 | from .parallel_dp import DataParallel 10 | 11 | __all__ = [ 12 | "setup_distributed", "get_device_mesh", 13 | "apply_fsdp1", "apply_fsdp2", 14 | "ColumnParallelLinear", "RowParallelLinear", "VocabParallelEmbedding", 15 | "ContextParallelAttention", 16 | "ExpertParallel", 17 | "PipelineStage", 18 | "DataParallel" 19 | ] 20 | -------------------------------------------------------------------------------- /config/config_tinystories.json: -------------------------------------------------------------------------------- 1 | { 2 | "distributed": { 3 | "tp_size": 2, 4 | "dp_size": 4, 5 | "pp_size": 1, 6 | "cp_size": 1, 7 | "ep_size": 1 8 | }, 9 | "model": { 10 | "name": "gpt2", 11 | "moe_num_experts": 1, 12 | "moe_topk": 2 13 | }, 14 | "dataset": { 15 | "name": "roneneldan/TinyStories", 16 | "split": "train", 17 | "num_workers": 0 18 | }, 19 | "training": { 20 | "micro_batch_size": 4, 21 | "gradient_accumulation_steps": 2, 22 | "seq_length": 128, 23 | "learning_rate": 3e-4, 24 | "weight_decay": 0.01, 25 | "total_steps": 50, 26 | "log_interval": 5, 27 | "max_samples": 5000 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArgs: 7 | dim: int = 4096 8 | n_layers: int = 32 9 | n_heads: int = 32 10 | n_kv_heads: Optional[int] = None # For GQA 11 | moe_num_experts: int = 0 12 | moe_topk: int = 0 13 | moe_layer_freq: int = 2 14 | vocab_size: int = 32000 15 | multiple_of: int = 256 # MLP hidden dim multiple 16 | ffn_dim_multiplier: Optional[float] = None 17 | norm_eps: float = 1e-5 18 | max_seq_len: int = 2048 19 | 20 | # parallel_mode: str = 'fsdp1' 21 | tp_size: int = 1 22 | cp_size: int = 1 23 | device_mesh_shape: tuple = (1, 1) 24 | 25 | def __post_init__(self): 26 | if self.n_kv_heads is None: 27 | self.n_kv_heads = self.n_heads -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | @torch.no_grad() 4 | def evaluate(model, tokenizer, eval_dataset, max_batches=10): 5 | model.eval() 6 | total_loss = 0 7 | steps = 0 8 | 9 | for i, batch in enumerate(eval_dataset): 10 | if i >= max_batches: break 11 | inputs = batch.to(model.device) # 假设 batch 已经是 tensor 12 | logits = model(inputs) 13 | 14 | # 简单的 Next Token Prediction Loss 15 | # Shift logits and labels 16 | shift_logits = logits[..., :-1, :].contiguous() 17 | shift_labels = inputs[..., 1:].contiguous() 18 | 19 | loss = torch.nn.functional.cross_entropy( 20 | shift_logits.view(-1, shift_logits.size(-1)), 21 | shift_labels.view(-1) 22 | ) 23 | total_loss += loss.item() 24 | 
steps += 1
25 | 
26 |     avg_loss = total_loss / steps
27 |     perplexity = torch.exp(torch.tensor(avg_loss))
28 |     model.train()
29 |     return avg_loss, perplexity.item()
30 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Wenjun Liu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/parallel/parallel_dp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.distributed as dist
4 | from .distributed import get_device_mesh
5 | 
6 | 
7 | class DataParallel(nn.Module):
8 |     """
9 |     A hand-rolled DDP wrapper (similar to PyTorch DDP, but simplified to fit the 4D parallel setup).
10 |     It does not shard the data itself; it is only responsible for all-reducing gradients.
11 |     """
12 | 
13 |     def __init__(self, module):
14 |         super().__init__()
15 |         self.module = module
16 |         self.mesh = get_device_mesh()
17 |         # Get the DP group (which may carry FSDP/DDP semantics)
18 |         # In 4D parallelism the DP group is usually mesh["dp"]
19 |         if self.mesh and "dp" in self.mesh.mesh_dim_names:
20 |             self.dp_group = self.mesh["dp"].get_group()
21 |         else:
22 |             self.dp_group = None
23 | 
24 |         # Register hooks so gradients are synchronized automatically during backward
25 |         # For simplicity we sync parameter by parameter instead of using buckets (or sync manually before step)
26 |         # A production implementation would use buckets to fuse small gradient communications
27 |         for param in self.module.parameters():
28 |             if param.requires_grad:
29 |                 param.register_post_accumulate_grad_hook(self._all_reduce_hook)
30 | 
31 |     def _all_reduce_hook(self, param):
32 |         if self.dp_group is None or param.grad is None:
33 |             return
34 | 
35 |         # All-reduce the gradient across the DP group (a blocking call here)
36 |         # Note: divide by dp_size afterwards to get the average
37 |         dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=self.dp_group)
38 |         param.grad.div_(dist.get_world_size(group=self.dp_group))
39 | 
40 |     def forward(self, *args, **kwargs):
41 |         return self.module(*args, **kwargs)
42 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lightron
2 | 
3 | **Lightron** is a lightweight, educational, yet modern distributed training framework for LLMs.
4 | Lightron aims to bridge the gap between minimal implementations and modern production features such as **4-D Parallelism**, including **Tensor Parallelism, Pipeline Parallelism, Data Parallelism**, and **Context Parallelism**.
5 | 
6 | # Key Features
7 | 
8 | - **Distributed Ready**: Supports 4-D Parallelism (TP, PP, DP, CP), EP, and FSDP V2.
9 | - **Modern Architecture**: RMSNorm, SwiGLU, Rotary Embeddings (RoPE), FlashAttention V2.
10 | - **Clean Code**: Type-hinted, dataclass-based configuration, <1000 lines of core code.
11 | 
12 | # Installation
13 | ```bash
14 | git clone https://github.com/lwj2015/lightron.git
15 | cd lightron
16 | pip install -r requirements.txt
17 | ```
18 | 
19 | # Quick Start
20 | ```bash
21 | # run on a local machine with 8 GPUs, tp_size=2, dp_size=4
22 | torchrun --nproc_per_node=8 trainer.py --config config/config_tinystories.json
23 | ```
24 | 
25 | # Local Tests
26 | Test All-Reduce Communication
27 | ```bash
28 | torchrun --nproc_per_node=8 tests/test_all_reduce.py
29 | ```
30 | 
31 | Test Ring Attention
32 | ```bash
33 | python tests/test_ring_attention.py
34 | ```
35 | 
36 | Test DataLoader
37 | ```bash
38 | torchrun --nproc_per_node=8 tests/test_dataloader.py
39 | ```
40 | 
41 | # Citation
42 | 
43 | If you use Lightron in your research or learning journey, please cite it as follows:
44 | ```bibtex
45 | @misc{lightron2025,
46 |   author = {Wenjun Liu},
47 |   title = {Lightron: A Modern Minimalist Distributed Training Framework},
48 |   year = {2025},
49 |   publisher = {GitHub},
50 |   journal = {GitHub repository},
51 |   howpublished = {\url{https://github.com/lwj2015/lightron}}
52 | }
53 | ```
54 | 
55 | 
56 | 
--------------------------------------------------------------------------------
/tests/test_ring_all_gather.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from torch.distributed.device_mesh import init_device_mesh
5 | from parallel.communication.ring_all_gather import ring_all_gather
6 | 
7 | 
8 | def init_dist():
9 |     local_rank = int(os.environ.get("LOCAL_RANK", 0))
10 |     rank = int(os.environ.get("RANK", 0))
11 |     world_size = int(os.environ.get("WORLD_SIZE", 1))
12 | 
13 |     torch.cuda.set_device(local_rank)
14 |     device = torch.device(f"cuda:{local_rank}")
15 | 
16 |     if not dist.is_initialized():
17 |         dist.init_process_group(backend="nccl", device_id=device)
18 | 
19 |     return rank, world_size, device
20 | 
21 | 
22 | def main():
23 |     rank, world_size, device = init_dist()
24 | 
25 |     # Build the 4D device mesh
26 |     if world_size < 8:
27 |         if rank == 0: print("⚠️ Warning: Less than 8 GPUs, using simplified mesh.")
28 |         mesh_shape = (1, 1, 2, 2)
29 |     else:
30 |         mesh_shape = (2, 1, 2, 2)
31 | 
32 |     mesh_dim_names = ("dp", "pp", "tp", "cp")
33 |     mesh = init_device_mesh("cuda", mesh_shape, mesh_dim_names=mesh_dim_names)
34 | 
35 |     if rank == 0:
36 |         print(f"\n🚀 Device Mesh Created: {mesh_shape} {mesh_dim_names}")
37 | 
38 |     dist.barrier()
39 | 
40 |     # Test ring all-gather (within the CP group)
41 |     cp_group = mesh["cp"].get_group()
42 |     local_vec = torch.arange(10, device=device, dtype=torch.float32) + rank * 100
43 | 
44 |     gather_list = [torch.zeros_like(local_vec) for _ in range(dist.get_world_size(cp_group))]
45 |     dist.all_gather(gather_list, local_vec, group=cp_group)
46 |     ref_gather = torch.cat(gather_list)
47 | 
48 |     my_gather = ring_all_gather(local_vec, group=cp_group)
49 | 
50 |     diff = (ref_gather - my_gather).abs().max()
51 |     if rank == 0:
52 |         print(f"\n[CP Group] All-Gather Test:")
53 |         print(f" Max Diff: {diff.item():.6f} {'✅' if diff < 1e-5 else '❌'}")
54 | 
55 |     dist.barrier()
56 |     dist.destroy_process_group()
57 | 
58 | 
59 | if __name__ == "__main__":
60 |     main()
61 | 
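Editor's note: the communication tests above all build the same `("dp", "pp", "tp", "cp")` device mesh, so it helps to see which global ranks land in each group. The standalone sketch below is not part of the repository; it only assumes that `init_device_mesh` lays ranks out in row-major order (i.e. `torch.arange(world_size).reshape(mesh_shape)`), and it runs on a CPU-only machine:

```python
import torch

# Mesh used by the tests when 8 GPUs are available: (dp, pp, tp, cp) = (2, 1, 2, 2).
mesh_shape = (2, 1, 2, 2)
mesh_dim_names = ("dp", "pp", "tp", "cp")
mesh = torch.arange(8).reshape(mesh_shape)  # row-major rank layout (assumption stated above)

for dim, name in enumerate(mesh_dim_names):
    # Ranks that differ only along `dim` form one process group for that dimension.
    groups = mesh.movedim(dim, -1).reshape(-1, mesh_shape[dim])
    print(f"{name}: {[g.tolist() for g in groups]}")

# Expected output:
# dp: [[0, 4], [1, 5], [2, 6], [3, 7]]
# pp: [[0], [1], [2], [3], [4], [5], [6], [7]]
# tp: [[0, 2], [1, 3], [4, 6], [5, 7]]
# cp: [[0, 1], [2, 3], [4, 5], [6, 7]]
```

With this layout each CP group has only two members, so the all-gather result checked on rank 0 is the concatenation of the chunks held by ranks 0 and 1.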
-------------------------------------------------------------------------------- /parallel/parallel_pp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | from .distributed import get_device_mesh 5 | 6 | 7 | class PipelineStage(nn.Module): 8 | """ 9 | PP Stage Wrapper. 10 | """ 11 | 12 | def __init__(self, module, stage_id, num_stages, chunk_id=0): 13 | super().__init__() 14 | self.module = module 15 | self.stage_id = stage_id 16 | self.num_stages = num_stages 17 | 18 | self.mesh = get_device_mesh() 19 | self.pp_group = self.mesh["pp"].get_group() if self.mesh and "pp" in self.mesh.mesh_dim_names else None 20 | 21 | if self.pp_group is None: 22 | raise ValueError("PP group not initialized") 23 | 24 | # 计算上下游 Rank 25 | # 注意:这里假设 PP group 内 rank 是连续的 0, 1, 2... 26 | # 实际需要根据 global rank 映射 27 | my_rank = dist.get_rank(group=self.pp_group) 28 | self.prev_rank = (my_rank - 1) if my_rank > 0 else None 29 | self.next_rank = (my_rank + 1) if my_rank < num_stages - 1 else None 30 | 31 | def forward(self, x=None): 32 | # PP 的 forward 比较特殊,通常由外部调度器调用 send/recv 33 | # 这里仅作为单步执行的逻辑 34 | return self.module(x) 35 | 36 | def send_forward(self, output): 37 | if self.next_rank is not None: 38 | dist.send(output.contiguous(), dst=self.next_rank, group=self.pp_group) 39 | 40 | def recv_forward(self, tensor_shape): 41 | if self.prev_rank is not None: 42 | buffer = torch.empty(tensor_shape, device="cuda", dtype=torch.bfloat16) 43 | dist.recv(buffer, src=self.prev_rank, group=self.pp_group) 44 | return buffer 45 | return None 46 | 47 | def send_backward(self, grad): 48 | if self.prev_rank is not None: 49 | dist.send(grad.contiguous(), dst=self.prev_rank, group=self.pp_group) 50 | 51 | def recv_backward(self, tensor_shape): 52 | if self.next_rank is not None: 53 | buffer = torch.empty(tensor_shape, device="cuda", dtype=torch.bfloat16) 54 | dist.recv(buffer, src=self.next_rank, group=self.pp_group) 55 | return buffer 56 | return None 57 | -------------------------------------------------------------------------------- /tests/test_ring_reduce_scatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from torch.distributed.device_mesh import init_device_mesh 5 | from parallel.communication.ring_reduce_scatter import ring_reduce_scatter 6 | 7 | 8 | def init_dist(): 9 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 10 | rank = int(os.environ.get("RANK", 0)) 11 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 12 | 13 | torch.cuda.set_device(local_rank) 14 | device = torch.device(f"cuda:{local_rank}") 15 | 16 | if not dist.is_initialized(): 17 | dist.init_process_group(backend="nccl", device_id=device) 18 | 19 | return rank, world_size, device 20 | 21 | 22 | def get_global_rank(group, group_rank): 23 | return dist.get_global_rank(group, group_rank) 24 | 25 | 26 | def main(): 27 | rank, world_size, device = init_dist() 28 | 29 | # 构建 4D Device Mesh 30 | if world_size < 8: 31 | if rank == 0: print("⚠️ Warning: Less than 8 GPUs, using simplified mesh.") 32 | mesh_shape = (1, 1, 2, 2) 33 | else: 34 | mesh_shape = (2, 1, 2, 2) 35 | 36 | mesh_dim_names = ("dp", "pp", "tp", "cp") 37 | mesh = init_device_mesh("cuda", mesh_shape, mesh_dim_names=mesh_dim_names) 38 | 39 | if rank == 0: 40 | print(f"\n🚀 Device Mesh Created: {mesh_shape} {mesh_dim_names}") 41 | 42 | dist.barrier() 43 | 44 | # 测试 
Reduce-Scatter (在 DP 组内) 45 | dp_group = mesh["dp"].get_group() 46 | dp_world_size = dist.get_world_size(dp_group) 47 | 48 | input_list = [torch.ones(10, device=device) * (rank + 1) * (i + 1) for i in range(dp_world_size)] 49 | 50 | ref_out = torch.zeros(10, device=device) 51 | dist.reduce_scatter(ref_out, input_list, group=dp_group) 52 | 53 | input_list_2 = [torch.ones(10, device=device) * (rank + 1) * (i + 1) for i in range(dp_world_size)] 54 | my_out = ring_reduce_scatter(input_list_2, group=dp_group) 55 | 56 | diff = (ref_out - my_out).abs().max() 57 | if dist.get_rank(dp_group) == 0 and rank == 0: 58 | print(f"\n[DP Group] Reduce-Scatter Test:") 59 | print(f" Max Diff: {diff.item():.6f} {'✅' if diff < 1e-5 else '❌'}") 60 | 61 | dist.barrier() 62 | dist.destroy_process_group() 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /tests/test_ring_p2p.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from torch.distributed.device_mesh import init_device_mesh 5 | from parallel.communication.ring_p2p import ring_p2p 6 | 7 | 8 | def init_dist(): 9 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 10 | rank = int(os.environ.get("RANK", 0)) 11 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 12 | 13 | torch.cuda.set_device(local_rank) 14 | device = torch.device(f"cuda:{local_rank}") 15 | 16 | if not dist.is_initialized(): 17 | dist.init_process_group(backend="nccl", device_id=device) 18 | 19 | return rank, world_size, device 20 | 21 | 22 | def get_global_rank(group, group_rank): 23 | return dist.get_global_rank(group, group_rank) 24 | 25 | 26 | def main(): 27 | rank, world_size, device = init_dist() 28 | 29 | # 构建 4D Device Mesh 30 | if world_size < 8: 31 | if rank == 0: print("⚠️ Warning: Less than 8 GPUs, using simplified mesh.") 32 | mesh_shape = (1, 1, 2, 2) 33 | else: 34 | mesh_shape = (2, 1, 2, 2) 35 | 36 | mesh_dim_names = ("dp", "pp", "tp", "cp") 37 | mesh = init_device_mesh("cuda", mesh_shape, mesh_dim_names=mesh_dim_names) 38 | 39 | if rank == 0: 40 | print(f"\n🚀 Device Mesh Created: {mesh_shape} {mesh_dim_names}") 41 | 42 | dist.barrier() 43 | 44 | # 测试 P2P (在 PP 组内) 45 | pp_group = mesh["pp"].get_group() 46 | pp_rank = dist.get_rank(pp_group) 47 | pp_size = dist.get_world_size(pp_group) 48 | 49 | if pp_size > 1: 50 | tensor_p2p = torch.tensor([rank * 1.0], device=device) 51 | recv_p2p = torch.tensor([-1.0], device=device) 52 | 53 | if pp_rank % 2 == 0: 54 | target_logical = pp_rank + 1 55 | if target_logical < pp_size: 56 | ring_p2p(tensor_p2p, recv_p2p, pp_rank, target_logical, pp_group) 57 | else: 58 | src_logical = pp_rank - 1 59 | ring_p2p(tensor_p2p, recv_p2p, src_logical, pp_rank, pp_group) 60 | 61 | if rank == 1: 62 | print(f"\n[PP Group] P2P Test (Rank 1 received from Rank 0):") 63 | print(f" Received: {recv_p2p.item()} (Expected 0.0) {'✅' if recv_p2p.item() == 0.0 else '❌'}") 64 | 65 | dist.barrier() 66 | dist.destroy_process_group() 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /utils/hf_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from model.model import LightronTransformer 4 | from config.config import ModelArgs 5 | 6 | 7 | def load_hf_llama_weights(model: LightronTransformer, hf_model_path: str): 8 | 
""" 9 | 从 HuggingFace Llama 加载权重到 Lightron。 10 | 需要处理 key 的映射 (例如: model.layers.0.self_attn.q_proj -> layers.0.attention.wq) 11 | """ 12 | from transformers import AutoModelForCausalLM 13 | print(f"Loading HF weights from {hf_model_path}...") 14 | hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16) 15 | hf_sd = hf_model.state_dict() 16 | 17 | my_sd = model.state_dict() 18 | 19 | # 简单的映射规则示例 (需要根据实际层名调整) 20 | mapping = { 21 | "model.embed_tokens.weight": "tok_embeddings.weight", 22 | "model.norm.weight": "norm.weight", 23 | "lm_head.weight": "output.weight" 24 | } 25 | 26 | # 遍历转换 27 | for hf_key, hf_val in hf_sd.items(): 28 | # 1. 处理基础层 29 | if hf_key in mapping: 30 | my_sd[mapping[hf_key]].copy_(hf_val) 31 | continue 32 | 33 | # 2. 处理 Block 层 34 | # HF: model.layers.0.self_attn.q_proj.weight 35 | # My: layers.0.attention.wq.weight 36 | if "layers" in hf_key: 37 | key_parts = hf_key.split(".") 38 | layer_idx = key_parts[2] 39 | 40 | new_key = None 41 | if "self_attn.q_proj" in hf_key: 42 | new_key = f"layers.{layer_idx}.attention.wq.weight" 43 | elif "self_attn.k_proj" in hf_key: 44 | new_key = f"layers.{layer_idx}.attention.wk.weight" 45 | elif "self_attn.v_proj" in hf_key: 46 | new_key = f"layers.{layer_idx}.attention.wv.weight" 47 | elif "self_attn.o_proj" in hf_key: 48 | new_key = f"layers.{layer_idx}.attention.wo.weight" 49 | elif "mlp.gate_proj" in hf_key: 50 | new_key = f"layers.{layer_idx}.feed_forward.w1.weight" 51 | elif "mlp.down_proj" in hf_key: 52 | new_key = f"layers.{layer_idx}.feed_forward.w2.weight" 53 | elif "mlp.up_proj" in hf_key: 54 | new_key = f"layers.{layer_idx}.feed_forward.w3.weight" 55 | elif "input_layernorm" in hf_key: 56 | new_key = f"layers.{layer_idx}.attention_norm.weight" 57 | elif "post_attention_layernorm" in hf_key: 58 | new_key = f"layers.{layer_idx}.ffn_norm.weight" 59 | 60 | if new_key and new_key in my_sd: 61 | my_sd[new_key].copy_(hf_val) 62 | 63 | model.load_state_dict(my_sd) 64 | print("Weights loaded successfully!") 65 | -------------------------------------------------------------------------------- /layers/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__(self, dim: int, eps: float = 1e-6): 7 | super().__init__() 8 | self.eps = eps 9 | self.weight = nn.Parameter(torch.ones(dim)) 10 | 11 | def _norm(self, x): 12 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 13 | 14 | def forward(self, x): 15 | output = self._norm(x.float()).type_as(x) 16 | return output * self.weight 17 | 18 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 19 | # 简化的 RoPE 预计算 20 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 21 | t = torch.arange(end, device=freqs.device) 22 | freqs = torch.outer(t, freqs).float() 23 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 24 | return freqs_cis 25 | 26 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 27 | """ 28 | freqs_cis: [S, D] 29 | x: [B, S, H, D] 30 | Target: [1, S, 1, D] 31 | """ 32 | ndim = x.ndim 33 | assert 0 <= 1 < ndim 34 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]), \ 35 | f"freqs_cis shape {freqs_cis.shape} does not match x shape {(x.shape[1], x.shape[-1])}" 36 | 37 | # 构造广播形状: [d if i==1 or i==ndim-1 else 1] 38 | # 对于 4D 输入 x [B, S, H, D],这将生成 [1, S, 1, D] 39 | shape = [d if i == 1 or 
i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 40 | 41 | return freqs_cis.view(*shape) 42 | 43 | 44 | def apply_rotary_emb(xq, xk, freqs_cis): 45 | # 将 Q, K 转为复数进行旋转 46 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 47 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 48 | 49 | # === 添加调试日志 === 50 | # import torch.distributed as dist 51 | # if dist.is_initialized() and dist.get_rank() == 0: 52 | # print(f"[Debug Rank 0] Inside apply_rotary_emb:") 53 | # print(f" xq_ (complex): {xq_.shape}") 54 | # print(f" xk_ (complex): {xk_.shape}") 55 | # print(f" freqs_cis (raw): {freqs_cis.shape}") 56 | # =================== 57 | 58 | # freqs_cis = freqs_cis[:xq.shape[1]] # 切片匹配 seq_len 59 | 60 | # 尝试广播 61 | # freqs_cis 需要 reshape 成 [1, S, 1, head_dim/2] 才能跟 xq_ [B, S, n_heads, head_dim/2] 相乘 62 | # 这里的 reshape 逻辑是关键 63 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 64 | # if dist.is_initialized() and dist.get_rank() == 0: 65 | # print(f" freqs_cis (broadcast): {freqs_cis.shape}") 66 | 67 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 68 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 69 | 70 | return xq_out.type_as(xq), xk_out.type_as(xk) 71 | -------------------------------------------------------------------------------- /parallel/communication/ring_p2p.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from torch.distributed.device_mesh import init_device_mesh 5 | 6 | 7 | def init_dist(): 8 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 9 | rank = int(os.environ.get("RANK", 0)) 10 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 11 | 12 | torch.cuda.set_device(local_rank) 13 | device = torch.device(f"cuda:{local_rank}") 14 | 15 | if not dist.is_initialized(): 16 | dist.init_process_group(backend="nccl", device_id=device) 17 | 18 | return rank, world_size, device 19 | 20 | 21 | def get_global_rank(group, group_rank): 22 | return dist.get_global_rank(group, group_rank) 23 | 24 | 25 | def ring_p2p(send_tensor, recv_tensor, src_group_rank, dst_group_rank, group): 26 | my_group_rank = dist.get_rank(group) 27 | src_global_rank = get_global_rank(group, src_group_rank) 28 | dst_global_rank = get_global_rank(group, dst_group_rank) 29 | 30 | ops = [] 31 | if my_group_rank == src_group_rank: 32 | ops.append(dist.P2POp(dist.isend, send_tensor, dst_global_rank, group=group)) 33 | if my_group_rank == dst_group_rank: 34 | ops.append(dist.P2POp(dist.irecv, recv_tensor, src_global_rank, group=group)) 35 | 36 | if ops: 37 | reqs = dist.batch_isend_irecv(ops) 38 | for req in reqs: req.wait() 39 | 40 | 41 | def main(): 42 | rank, world_size, device = init_dist() 43 | 44 | # 构建 4D Device Mesh 45 | if world_size < 8: 46 | if rank == 0: print("⚠️ Warning: Less than 8 GPUs, using simplified mesh.") 47 | mesh_shape = (1, 1, 2, 2) 48 | else: 49 | mesh_shape = (2, 1, 2, 2) 50 | 51 | mesh_dim_names = ("dp", "pp", "tp", "cp") 52 | mesh = init_device_mesh("cuda", mesh_shape, mesh_dim_names=mesh_dim_names) 53 | 54 | if rank == 0: 55 | print(f"\n🚀 Device Mesh Created: {mesh_shape} {mesh_dim_names}") 56 | 57 | dist.barrier() 58 | 59 | # 测试 P2P (在 PP 组内) 60 | pp_group = mesh["pp"].get_group() 61 | pp_rank = dist.get_rank(pp_group) 62 | pp_size = dist.get_world_size(pp_group) 63 | 64 | if pp_size > 1: 65 | tensor_p2p = torch.tensor([rank * 1.0], device=device) 66 | recv_p2p = torch.tensor([-1.0], device=device) 67 | 68 | if pp_rank % 2 == 
0: 69 | target_logical = pp_rank + 1 70 | if target_logical < pp_size: 71 | ring_p2p(tensor_p2p, recv_p2p, pp_rank, target_logical, pp_group) 72 | else: 73 | src_logical = pp_rank - 1 74 | ring_p2p(tensor_p2p, recv_p2p, src_logical, pp_rank, pp_group) 75 | 76 | if rank == 1: 77 | print(f"\n[PP Group] P2P Test (Rank 1 received from Rank 0):") 78 | print(f" Received: {recv_p2p.item()} (Expected 0.0) {'✅' if recv_p2p.item() == 0.0 else '❌'}") 79 | 80 | dist.barrier() 81 | dist.destroy_process_group() 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /tests/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from data.dataloader import MicroBatchDataLoader 7 | from parallel.distributed import setup_distributed 8 | 9 | 10 | def test_dataloader(tp_size = 1, dp_size = 1): 11 | # 1. 模拟分布式环境初始化 12 | if not dist.is_initialized(): 13 | dist.init_process_group("nccl") 14 | local_rank = int(os.environ["LOCAL_RANK"]) 15 | torch.cuda.set_device(local_rank) 16 | setup_distributed(tp_size=tp_size, dp_size=dp_size) 17 | 18 | print("\n=== 1. Environment Setup ===") 19 | print(f"Rank: {dist.get_rank()}, World Size: {dist.get_world_size()}") 20 | 21 | # 2. 配置参数 22 | model_name = "gpt2" 23 | dataset_name = "roneneldan/TinyStories" 24 | seq_length = 128 25 | micro_batch_size = 4 26 | 27 | print(f"\n=== 2. Initializing DataLoader ===") 28 | print(f"Model: {model_name}, Dataset: {dataset_name}, SeqLen: {seq_length}") 29 | 30 | try: 31 | dataloader = MicroBatchDataLoader( 32 | micro_batch_size=micro_batch_size, 33 | seq_length=seq_length, 34 | dataset_name=dataset_name, # load_dataset 会自动处理 35 | tokenizer_name=model_name, 36 | split="train", 37 | max_samples=1000, # 只取前1000条,加快测试速度 38 | num_workers=0 # 调试模式建议设为 0,避免多进程报错 39 | ) 40 | print("✅ DataLoader initialized successfully.") 41 | except Exception as e: 42 | print(f"❌ DataLoader initialization failed: {e}") 43 | import traceback 44 | traceback.print_exc() 45 | return 46 | 47 | # 3. 验证数据迭代 48 | print("\n=== 3. Testing Iteration ===") 49 | try: 50 | iterator = iter(dataloader) 51 | batch = next(iterator) 52 | 53 | input_ids = batch["input_ids"] 54 | target_ids = batch["target_ids"] 55 | 56 | print(f"Batch keys: {batch.keys()}") 57 | print(f"Input shape: {input_ids.shape}") # 预期: [4, 128] 58 | print(f"Target shape: {target_ids.shape}") # 预期: [4, 128] 59 | 60 | # 4. 验证逻辑正确性 61 | # Target 应该是 Input 向右移动一位 62 | # 比如 Input: [A, B, C], Target: [B, C, D] 63 | # 但由于我们是从长文本截断的,无法直接验证 input[i+1] == target[i] (除非我们拿到原始数据) 64 | # 不过根据 collate_fn 的逻辑: 65 | # input = data[0:S], target = data[1:S+1] 66 | # 所以在同一行内,target[t] 应该等于 input[t+1] (如果它们来自同一个连续片段) 67 | # 验证最后 5 个 token 68 | 69 | print("\n--- Sample Check (First Sequence) ---") 70 | print(f"Input (last 5): {input_ids[0, -5:].tolist()}") 71 | print(f"Target (last 5): {target_ids[0, -5:].tolist()}") 72 | 73 | # 验证形状 74 | assert input_ids.shape == (micro_batch_size, seq_length), "Input shape mismatch!" 75 | assert target_ids.shape == (micro_batch_size, seq_length), "Target shape mismatch!" 76 | 77 | # 验证类型 78 | assert input_ids.dtype == torch.long, "Input dtype should be long" 79 | 80 | print("✅ Data shape and type check passed.") 81 | 82 | except Exception as e: 83 | print(f"❌ Iteration failed: {e}") 84 | import traceback 85 | traceback.print_exc() 86 | 87 | # 5. 
清理 88 | dist.destroy_process_group() 89 | print("\n=== Test Finished ===") 90 | 91 | 92 | if __name__ == "__main__": 93 | tp_size = 2 94 | dp_size = 4 95 | test_dataloader(tp_size, dp_size) 96 | -------------------------------------------------------------------------------- /parallel/communication/ring_all_gather.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from torch.distributed.device_mesh import init_device_mesh 5 | 6 | 7 | def init_dist(): 8 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 9 | rank = int(os.environ.get("RANK", 0)) 10 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 11 | 12 | torch.cuda.set_device(local_rank) 13 | device = torch.device(f"cuda:{local_rank}") 14 | 15 | if not dist.is_initialized(): 16 | dist.init_process_group(backend="nccl", device_id=device) 17 | 18 | return rank, world_size, device 19 | 20 | 21 | def get_global_rank(group, group_rank): 22 | return dist.get_global_rank(group, group_rank) 23 | 24 | 25 | def ring_all_gather(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: 26 | rank = dist.get_rank(group) 27 | world_size = dist.get_world_size(group) 28 | 29 | if world_size == 1: 30 | return tensor.unsqueeze(0) 31 | 32 | output_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] 33 | output_tensors[rank] = tensor.clone() 34 | 35 | right_rank_logical = (rank + 1) % world_size 36 | left_rank_logical = (rank - 1 + world_size) % world_size 37 | 38 | right_rank_global = get_global_rank(group, right_rank_logical) 39 | left_rank_global = get_global_rank(group, left_rank_logical) 40 | 41 | curr_send_idx = rank 42 | curr_recv_idx = left_rank_logical 43 | 44 | for _ in range(world_size - 1): 45 | send_data = output_tensors[curr_send_idx] 46 | recv_data = output_tensors[curr_recv_idx] 47 | 48 | reqs = dist.batch_isend_irecv([ 49 | dist.P2POp(dist.isend, send_data, right_rank_global, group=group), 50 | dist.P2POp(dist.irecv, recv_data, left_rank_global, group=group) 51 | ]) 52 | for req in reqs: req.wait() 53 | 54 | curr_send_idx = curr_recv_idx 55 | curr_recv_idx = (curr_recv_idx - 1 + world_size) % world_size 56 | 57 | return torch.cat(output_tensors, dim=0) 58 | 59 | 60 | def main(): 61 | rank, world_size, device = init_dist() 62 | 63 | # 构建 4D Device Mesh 64 | if world_size < 8: 65 | if rank == 0: print("⚠️ Warning: Less than 8 GPUs, using simplified mesh.") 66 | mesh_shape = (1, 1, 2, 2) 67 | else: 68 | mesh_shape = (2, 1, 2, 2) 69 | 70 | mesh_dim_names = ("dp", "pp", "tp", "cp") 71 | mesh = init_device_mesh("cuda", mesh_shape, mesh_dim_names=mesh_dim_names) 72 | 73 | if rank == 0: 74 | print(f"\n🚀 Device Mesh Created: {mesh_shape} {mesh_dim_names}") 75 | 76 | dist.barrier() 77 | 78 | # 测试 Ring All-Gather (在 CP 组内) 79 | cp_group = mesh["cp"].get_group() 80 | local_vec = torch.arange(10, device=device, dtype=torch.float32) + rank * 100 81 | 82 | gather_list = [torch.zeros_like(local_vec) for _ in range(dist.get_world_size(cp_group))] 83 | dist.all_gather(gather_list, local_vec, group=cp_group) 84 | ref_gather = torch.cat(gather_list) 85 | 86 | my_gather = ring_all_gather(local_vec, group=cp_group) 87 | 88 | diff = (ref_gather - my_gather).abs().max() 89 | if rank == 0: 90 | print(f"\n[CP Group] All-Gather Test:") 91 | print(f" Max Diff: {diff.item():.6f} {'✅' if diff < 1e-5 else '❌'}") 92 | 93 | dist.barrier() 94 | dist.destroy_process_group() 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | 
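Editor's note: the loop in `ring_all_gather` above can be hard to follow because each rank keeps rotating the indices it sends and receives. The sketch below is illustrative only (plain Python, a hypothetical `simulate_ring_all_gather` helper, no process groups): it replays the same index rotation and shows that after `world_size - 1` steps every rank holds every chunk.

```python
def simulate_ring_all_gather(world_size: int):
    """Replay the chunk rotation of ring_all_gather with plain Python lists."""
    # slots[r][i] is rank r's buffer for chunk i; initially each rank only has its own chunk.
    slots = [[None] * world_size for _ in range(world_size)]
    for r in range(world_size):
        slots[r][r] = f"chunk{r}"

    send_idx = list(range(world_size))  # which chunk each rank sends next
    for _ in range(world_size - 1):
        recv_idx = [(s - 1) % world_size for s in send_idx]  # which chunk each rank receives
        for r in range(world_size):
            left = (r - 1) % world_size
            # Rank r receives, from its left neighbour, the chunk that neighbour is currently sending.
            slots[r][recv_idx[r]] = slots[left][send_idx[left]]
        send_idx = recv_idx  # next step: forward the chunk we just received

    return slots


for rank, buf in enumerate(simulate_ring_all_gather(4)):
    print(f"rank {rank}: {buf}")  # every rank ends with ['chunk0', 'chunk1', 'chunk2', 'chunk3']
```

The real implementation does the same rotation, except that each slot is a CUDA tensor and each hand-over is an isend/irecv pair batched with `dist.batch_isend_irecv`.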
-------------------------------------------------------------------------------- /parallel/communication/ring_reduce_scatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from torch.distributed.device_mesh import init_device_mesh 5 | 6 | 7 | def init_dist(): 8 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 9 | rank = int(os.environ.get("RANK", 0)) 10 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 11 | 12 | torch.cuda.set_device(local_rank) 13 | device = torch.device(f"cuda:{local_rank}") 14 | 15 | if not dist.is_initialized(): 16 | dist.init_process_group(backend="nccl", device_id=device) 17 | 18 | return rank, world_size, device 19 | 20 | 21 | def get_global_rank(group, group_rank): 22 | return dist.get_global_rank(group, group_rank) 23 | 24 | 25 | def ring_reduce_scatter(tensor_list: list, group: dist.ProcessGroup) -> torch.Tensor: 26 | """ 27 | 修正后的 Ring Reduce-Scatter 28 | 逻辑:确保 Rank r 最终持有 Chunk r 的总和 29 | """ 30 | rank = dist.get_rank(group) 31 | world_size = dist.get_world_size(group) 32 | 33 | if world_size == 1: 34 | return tensor_list[0] 35 | 36 | # 初始化:Result 包含我本地的贡献 37 | # 注意:我们直接在 tensor_list 上原地修改,这样最后 tensor_list[rank] 就是结果 38 | 39 | right_rank_logical = (rank + 1) % world_size 40 | left_rank_logical = (rank - 1 + world_size) % world_size 41 | 42 | right_rank_global = get_global_rank(group, right_rank_logical) 43 | left_rank_global = get_global_rank(group, left_rank_logical) 44 | 45 | recv_buffer = torch.zeros_like(tensor_list[0]) 46 | 47 | for i in range(world_size - 1): 48 | # 【关键修正】索引偏移 -1 49 | # Step 0: Rank r 发送 Chunk r-1, 接收 Chunk r-2 50 | # 这样经过 N-1 步,Chunk r 会正好传回到 Rank r 51 | 52 | send_chunk_idx = (rank - i - 1 + world_size) % world_size 53 | recv_chunk_idx = (rank - i - 2 + world_size) % world_size 54 | 55 | send_data = tensor_list[send_chunk_idx] 56 | 57 | reqs = dist.batch_isend_irecv([ 58 | dist.P2POp(dist.isend, send_data, right_rank_global, group=group), 59 | dist.P2POp(dist.irecv, recv_buffer, left_rank_global, group=group) 60 | ]) 61 | for req in reqs: req.wait() 62 | 63 | # 累加到对应的块 64 | tensor_list[recv_chunk_idx] += recv_buffer 65 | 66 | # 循环结束后,tensor_list[rank] 已经包含了所有人的贡献 67 | return tensor_list[rank] 68 | 69 | 70 | def main(): 71 | rank, world_size, device = init_dist() 72 | 73 | # 构建 4D Device Mesh 74 | if world_size < 8: 75 | if rank == 0: print("⚠️ Warning: Less than 8 GPUs, using simplified mesh.") 76 | mesh_shape = (1, 1, 2, 2) 77 | else: 78 | mesh_shape = (2, 1, 2, 2) 79 | 80 | mesh_dim_names = ("dp", "pp", "tp", "cp") 81 | mesh = init_device_mesh("cuda", mesh_shape, mesh_dim_names=mesh_dim_names) 82 | 83 | if rank == 0: 84 | print(f"\n🚀 Device Mesh Created: {mesh_shape} {mesh_dim_names}") 85 | 86 | dist.barrier() 87 | 88 | # 测试 Reduce-Scatter (在 DP 组内) 89 | dp_group = mesh["dp"].get_group() 90 | dp_world_size = dist.get_world_size(dp_group) 91 | 92 | input_list = [torch.ones(10, device=device) * (rank + 1) * (i + 1) for i in range(dp_world_size)] 93 | 94 | ref_out = torch.zeros(10, device=device) 95 | dist.reduce_scatter(ref_out, input_list, group=dp_group) 96 | 97 | input_list_2 = [torch.ones(10, device=device) * (rank + 1) * (i + 1) for i in range(dp_world_size)] 98 | my_out = ring_reduce_scatter(input_list_2, group=dp_group) 99 | 100 | diff = (ref_out - my_out).abs().max() 101 | if dist.get_rank(dp_group) == 0 and rank == 0: 102 | print(f"\n[DP Group] Reduce-Scatter Test:") 103 | print(f" Max Diff: {diff.item():.6f} {'✅' 
if diff < 1e-5 else '❌'}") 104 | 105 | dist.barrier() 106 | dist.destroy_process_group() 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /tests/test_all_reduce.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from torch.distributed.device_mesh import init_device_mesh 5 | from parallel.communication.all_reduce import ring_all_reduce 6 | 7 | def init_dist(): 8 | """初始化分布式环境""" 9 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 10 | rank = int(os.environ.get("RANK", 0)) 11 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 12 | 13 | # 1. 绑定设备 14 | torch.cuda.set_device(local_rank) 15 | device = torch.device(f"cuda:{local_rank}") 16 | 17 | # 2. 初始化默认进程组 (虽然 init_device_mesh 可以自动初始化, 18 | # 但显式初始化并指定 device_id 是消除 Warning 的最佳实践) 19 | if not dist.is_initialized(): 20 | dist.init_process_group(backend="nccl", device_id=device) 21 | 22 | return rank, world_size, device, local_rank 23 | 24 | 25 | def main(): 26 | rank, world_size, device, local_rank = init_dist() 27 | """ 28 | # init_process_group不是必须的,可用这一段来代替init_dist 29 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 30 | rank = int(os.environ.get("RANK", 0)) 31 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 32 | torch.cuda.set_device(local_rank) 33 | device = torch.device(f"cuda:{local_rank}") 34 | """ 35 | 36 | # 设定并行度:总卡数 8 = 2(DP) * 2(PP) * 2(TP) 37 | # 注意:这里的顺序很重要,决定了 Rank 如何映射到 Mesh 38 | # 通常顺序是 (Data, Pipeline, Tensor) 39 | mesh_shape = (2, 2, 2) 40 | mesh_dim_names = ("dp", "pp", "tp") 41 | 42 | # --- 一键生成 3D Mesh --- 43 | # 这行代码自动完成了之前几十行的 Group 创建逻辑 44 | mesh_3d = init_device_mesh( 45 | "cuda", 46 | mesh_shape, 47 | mesh_dim_names=mesh_dim_names 48 | ) 49 | 50 | # --- 直接通过名字获取 Group --- 51 | # 获取 TP 组 (沿着 "tp" 维度切分) 52 | tp_group = mesh_3d["tp"].get_group() 53 | 54 | # 获取 DP 组 (沿着 "dp" 维度切分) 55 | dp_group = mesh_3d["dp"].get_group() 56 | 57 | # 简单的同步屏障 58 | dist.barrier() 59 | if rank == 0: 60 | print("\n" + "=" * 50) 61 | print(f"🚀 Device Mesh 3D 并行测试 (Shape: {mesh_shape})") 62 | print("=" * 50 + "\n") 63 | # 打印 Mesh 结构看看 64 | print(f"Mesh Structure:\n{mesh_3d}") 65 | 66 | dist.barrier() 67 | 68 | # ------------------------------------------------- 69 | # 测试场景 1: TP AllReduce 70 | # ------------------------------------------------- 71 | tensor_tp = torch.randn(1024, device=device) * (rank + 1) 72 | tensor_tp_ref = tensor_tp.clone() 73 | 74 | # 传入从 Mesh 获取的 tp_group 75 | res_tp = ring_all_reduce(tensor_tp, group=tp_group) 76 | dist.all_reduce(tensor_tp_ref, op=dist.ReduceOp.SUM, group=tp_group) 77 | 78 | err_tp = torch.mean((res_tp - tensor_tp_ref) ** 2) 79 | 80 | if rank in [0, 1]: # 打印 Rank 0 和 1 (它们应该在同一个 TP 组) 81 | print(f"[TP Test] Rank {rank} (TP-Group Rank {dist.get_rank(tp_group)}): " 82 | f"Error = {err_tp.item():.5e}") 83 | 84 | dist.barrier() 85 | 86 | # ------------------------------------------------- 87 | # 测试场景 2: DP AllReduce 88 | # ------------------------------------------------- 89 | tensor_dp = torch.randn(1024, device=device) + (rank + 10) 90 | tensor_dp_ref = tensor_dp.clone() 91 | 92 | # 传入从 Mesh 获取的 dp_group 93 | res_dp = ring_all_reduce(tensor_dp, group=dp_group) 94 | dist.all_reduce(tensor_dp_ref, op=dist.ReduceOp.SUM, group=dp_group) 95 | 96 | err_dp = torch.mean((res_dp - tensor_dp_ref) ** 2) 97 | 98 | if rank in [0, 4]: # 打印 Rank 0 和 4 (它们应该在同一个 DP 组) 99 | print(f"[DP Test] Rank {rank} 
(DP-Group Rank {dist.get_rank(dp_group)}): " 100 | f"Error = {err_dp.item():.5e}") 101 | 102 | print(f"\n\n _flatten_mesh_list: {mesh_3d._flatten_mesh_list}") 103 | 104 | dist.barrier() 105 | 106 | # 清理 107 | dist.destroy_process_group() 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /parallel/distributed.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | from torch.distributed.device_mesh import init_device_mesh 3 | 4 | # 全局 Mesh 管理器 5 | _DEVICE_MESH = None 6 | _MESH_DIMS = {} 7 | 8 | 9 | def setup_distributed( 10 | tp_size: int = 1, 11 | pp_size: int = 1, 12 | cp_size: int = 1, 13 | ep_size: int = 1, 14 | dp_size: int = 1, 15 | ): 16 | """ 17 | 初始化 5D 并行环境。 18 | 19 | 层级结构 (Hierarchy): 20 | 1. PP (Pipeline): 最外层,通常跨机。 21 | 2. DP (Data): 数据并行层。 22 | 注意:EP (Expert) 通常是 DP 的一种变体。 23 | - 如果 ep_size == 1: 纯 DP,所有 DP rank 拥有相同的 MoE 参数。 24 | - 如果 ep_size == dp_size: 纯 EP,所有 DP rank 拥有不同的专家。 25 | - 如果 1 < ep_size < dp_size: 混合模式 (Hybrid EP)。 26 | 3. CP (Context): 上下文并行,切分 Sequence。 27 | 4. TP (Tensor): 最内层,切分算子,通常在单机 NVLink 范围内。 28 | """ 29 | if not dist.is_initialized(): 30 | dist.init_process_group("nccl") 31 | 32 | world_size = dist.get_world_size() 33 | rank = dist.get_rank() 34 | 35 | # 1. 校验 World Size 36 | # 注意:EP 不增加总 World Size,它是寄生在 DP 维度上的 37 | # 总卡数 = PP * DP * CP * TP 38 | expected_world_size = pp_size * dp_size * cp_size * tp_size 39 | 40 | if world_size != expected_world_size: 41 | raise ValueError( 42 | f"World Size Mismatch! Real: {world_size}, " 43 | f"Configured: {pp_size}(PP) * {dp_size}(DP) * {cp_size}(CP) * {tp_size}(TP) = {expected_world_size}" 44 | ) 45 | 46 | # 2. 校验 EP 合法性 47 | # EP 是在 DP 组内进行的,所以 ep_size 必须能整除 dp_size 48 | if dp_size % ep_size != 0: 49 | raise ValueError(f"DP size ({dp_size}) must be divisible by EP size ({ep_size})") 50 | 51 | # 3. 构建 Device Mesh 52 | # 维度顺序:(PP, DP, CP, TP) 53 | mesh_dims = [] 54 | mesh_names = [] 55 | 56 | if pp_size > 1: 57 | mesh_dims.append(pp_size) 58 | mesh_names.append("pp") 59 | 60 | if dp_size > 1: 61 | mesh_dims.append(dp_size) 62 | mesh_names.append("dp") 63 | 64 | if cp_size > 1: 65 | mesh_dims.append(cp_size) 66 | mesh_names.append("cp") 67 | 68 | if tp_size > 1: 69 | mesh_dims.append(tp_size) 70 | mesh_names.append("tp") 71 | 72 | global _DEVICE_MESH, _MESH_DIMS 73 | 74 | if len(mesh_dims) > 0: 75 | _DEVICE_MESH = init_device_mesh("cuda", tuple(mesh_dims), mesh_dim_names=tuple(mesh_names)) 76 | else: 77 | # 单卡模式 78 | _DEVICE_MESH = init_device_mesh("cuda", (1,), mesh_dim_names=("dp",)) 79 | 80 | # 4. 
存储配置供后续查询 81 | _MESH_DIMS = { 82 | "tp": tp_size, 83 | "pp": pp_size, 84 | "cp": cp_size, 85 | "ep": ep_size, 86 | "dp": dp_size 87 | } 88 | 89 | if rank == 0: 90 | print(f"🚀 Distributed Init Success!") 91 | print(f" Shape: PP={pp_size} | DP={dp_size} (EP={ep_size}) | CP={cp_size} | TP={tp_size}") 92 | print(f" Mesh: {mesh_names}") 93 | 94 | 95 | def get_device_mesh(): 96 | return _DEVICE_MESH 97 | 98 | 99 | def get_parallel_info(): 100 | return _MESH_DIMS 101 | 102 | 103 | # === 获取各个维度的 Process Group === 104 | 105 | def get_tp_group(): 106 | return _DEVICE_MESH["tp"].get_group() if "tp" in _DEVICE_MESH.mesh_dim_names else None 107 | 108 | 109 | def get_cp_group(): 110 | return _DEVICE_MESH["cp"].get_group() if "cp" in _DEVICE_MESH.mesh_dim_names else None 111 | 112 | 113 | def get_pp_group(): 114 | return _DEVICE_MESH["pp"].get_group() if "pp" in _DEVICE_MESH.mesh_dim_names else None 115 | 116 | 117 | def get_dp_group(): 118 | # 纯 DP 组 (用于同步非 MoE 参数) 119 | return _DEVICE_MESH["dp"].get_group() if "dp" in _DEVICE_MESH.mesh_dim_names else None 120 | 121 | 122 | def get_ep_group(): 123 | """ 124 | 获取 EP 通信组。 125 | EP 比较特殊,它是在 DP 维度上切分的。 126 | 如果 ep_size == dp_size,那么 EP group 就是 DP group。 127 | 如果 ep_size < dp_size,我们需要在 DP group 内部再切分。 128 | (为了简化,这里假设 ep_size == dp_size,即标准 MoE) 129 | """ 130 | return get_dp_group() 131 | -------------------------------------------------------------------------------- /tests/test_ring_attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | from torch.distributed.device_mesh import init_device_mesh 8 | from parallel.communication.ring_attention import ring_attention_kernel 9 | 10 | 11 | def run_demo(rank, world_size): 12 | # 初始化进程组 13 | os.environ['MASTER_ADDR'] = 'localhost' 14 | os.environ['MASTER_PORT'] = '12359' 15 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 16 | torch.cuda.set_device(rank) 17 | 18 | # --- A. 构建 Device Mesh (DP=1, PP=1, TP=2, CP=4) --- 19 | # 8 张卡: 20 | # Global: `[Batch, Seq, Head, Dim]` 21 | # Local: `[Batch, Seq/CP, Head/TP, Dim]` 22 | 23 | # TP 组: [0,1], [2,3], [4,5], [6,7] (负责切分 Head) 24 | # CP 组: [0,2,4,6], [1,3,5,7] (负责切分 Sequence) 25 | mesh = init_device_mesh("cuda", (1, 1, 2, 4), mesh_dim_names=("dp", "pp", "tp", "cp")) 26 | cp_group = mesh["cp"].get_group() 27 | tp_group = mesh["tp"].get_group() 28 | 29 | # --- B. 数据模拟 --- 30 | # 假设全局参数 31 | B, Global_Seq, Global_Head, Dim = 2, 32, 8, 64 32 | 33 | # 生成全局数据 (仅用于验证对比,实际训练中不会有这个变量) 34 | if rank == 0: 35 | global_q = torch.randn(B, Global_Seq, Global_Head, Dim, device="cuda") 36 | global_k = torch.randn(B, Global_Seq, Global_Head, Dim, device="cuda") 37 | global_v = torch.randn(B, Global_Seq, Global_Head, Dim, device="cuda") 38 | else: 39 | global_q = torch.empty(B, Global_Seq, Global_Head, Dim, device="cuda") 40 | global_k = torch.empty(B, Global_Seq, Global_Head, Dim, device="cuda") 41 | global_v = torch.empty(B, Global_Seq, Global_Head, Dim, device="cuda") 42 | 43 | # 广播全局数据,保证大家用来切分的数据源是一致的 44 | dist.broadcast(global_q, src=0) 45 | dist.broadcast(global_k, src=0) 46 | dist.broadcast(global_v, src=0) 47 | 48 | # --- C. 数据切分 (Sharding) --- 49 | # 1. TP 切分 (Head 维度) 50 | tp_rank = dist.get_rank(tp_group) 51 | tp_size = dist.get_world_size(tp_group) # 2 52 | local_head_num = Global_Head // tp_size 53 | 54 | # 2. 
CP 切分 (Sequence 维度) 55 | cp_rank = dist.get_rank(cp_group) 56 | cp_size = dist.get_world_size(cp_group) # 4 57 | local_seq_len = Global_Seq // cp_size 58 | 59 | # 执行切分 60 | # 先切 Head (TP) 61 | temp_q = global_q.chunk(tp_size, dim=2)[tp_rank] 62 | temp_k = global_k.chunk(tp_size, dim=2)[tp_rank] 63 | temp_v = global_v.chunk(tp_size, dim=2)[tp_rank] 64 | 65 | # 再切 Seq (CP) 66 | local_q = temp_q.chunk(cp_size, dim=1)[cp_rank].clone() 67 | local_k = temp_k.chunk(cp_size, dim=1)[cp_rank].clone() 68 | local_v = temp_v.chunk(cp_size, dim=1)[cp_rank].clone() 69 | 70 | print(f"[Rank {rank}] Mesh Coord: TP={tp_rank}, CP={cp_rank} | " 71 | f"Local Shape: {local_q.shape} (Seq={local_seq_len}, Head={local_head_num})") 72 | 73 | # --- D. 运行 Ring Attention --- 74 | dist.barrier() 75 | start_event = torch.cuda.Event(enable_timing=True) 76 | end_event = torch.cuda.Event(enable_timing=True) 77 | 78 | start_event.record() 79 | # 调用我们手写的 Kernel 80 | local_out = ring_attention_kernel(local_q, local_k, local_v, cp_group, tp_group) 81 | end_event.record() 82 | torch.cuda.synchronize() 83 | 84 | if rank == 0: 85 | print(f"Ring Attention Time: {start_event.elapsed_time(end_event):.3f} ms") 86 | 87 | # --- E. 结果验证 (Gather vs Standard) --- 88 | 89 | # 1. 逆向还原 (Un-shard) 90 | # 先把 CP (Seq) 拼回去 91 | gathered_seq_out = [torch.zeros_like(local_out) for _ in range(cp_size)] 92 | dist.all_gather(gathered_seq_out, local_out, group=cp_group) 93 | # 此时我们有了 [Seq_Chunk0, Seq_Chunk1, ...] -> 拼成完整 Seq 94 | seq_restored = torch.cat(gathered_seq_out, dim=1) # [B, Global_Seq, Local_Head, D] 95 | 96 | # 再把 TP (Head) 拼回去 97 | gathered_head_out = [torch.zeros_like(seq_restored) for _ in range(tp_size)] 98 | dist.all_gather(gathered_head_out, seq_restored, group=tp_group) 99 | # 此时我们有了 [Head_Chunk0, Head_Chunk1] -> 拼成完整 Head 100 | final_restored = torch.cat(gathered_head_out, dim=2) # [B, Global_Seq, Global_Head, D] 101 | 102 | # 2. 运行标准 Attention (Reference) 103 | if rank == 0: 104 | # PyTorch 标准实现 105 | # 需要转置为 [B, H, S, D] 106 | ref_q = global_q.transpose(1, 2) 107 | ref_k = global_k.transpose(1, 2) 108 | ref_v = global_v.transpose(1, 2) 109 | 110 | ref_out = F.scaled_dot_product_attention(ref_q, ref_k, ref_v) 111 | ref_out = ref_out.transpose(1, 2) # 转回 [B, S, H, D] 112 | 113 | # 3. 
对比 114 | # 由于浮点数累加顺序不同,会有微小误差 (1e-5 级别) 115 | max_diff = (final_restored - ref_out).abs().max() 116 | print(f"\n=== Verification Result ===") 117 | print(f"Max Difference: {max_diff.item():.6f}") 118 | 119 | if max_diff < 1e-4: 120 | print("✅ SUCCESS: Ring Attention matches Standard Attention!") 121 | else: 122 | print("❌ FAILED: Difference is too large.") 123 | 124 | dist.destroy_process_group() 125 | 126 | 127 | if __name__ == "__main__": 128 | WORLD_SIZE = 8 129 | # 检查是否有 8 个 GPU,没有的话报错 130 | if torch.cuda.device_count() < WORLD_SIZE: 131 | print(f"Error: This script requires {WORLD_SIZE} GPUs, but found {torch.cuda.device_count()}.") 132 | else: 133 | mp.spawn(run_demo, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True) 134 | -------------------------------------------------------------------------------- /parallel/parallel_tp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | from torch.distributed import _functional_collectives as funcol 6 | from .distributed import get_tp_group 7 | 8 | 9 | def _all_reduce(input_: torch.Tensor): 10 | """All-Reduce within TP group""" 11 | tp_group = get_tp_group() 12 | if tp_group is None: return input_ 13 | dist.all_reduce(input_, op=dist.ReduceOp.SUM, group=tp_group) 14 | return input_ 15 | 16 | 17 | def _split(input_: torch.Tensor, dim=-1): 18 | """Split tensor into parts for TP""" 19 | tp_group = get_tp_group() 20 | if tp_group is None: return input_ 21 | world_size = dist.get_world_size(group=tp_group) 22 | rank = dist.get_rank(group=tp_group) 23 | chunks = torch.chunk(input_, world_size, dim=dim) 24 | return chunks[rank].contiguous() 25 | 26 | 27 | def _gather(input_: torch.Tensor, dim=-1): 28 | """All-Gather within TP group""" 29 | tp_group = get_tp_group() 30 | if tp_group is None: return input_ 31 | return funcol.all_gather_tensor(input_, gather_dim=dim, group=tp_group) 32 | 33 | 34 | # === Sequence Parallel (SP) 原语 === 35 | # SP 的核心是:在 LayerNorm/Dropout 时,按 Seq 维度切分;在 Linear 计算时,按 Hidden 维度切分。 36 | # 这需要 all_to_all 通信 (Scatter/Gather 的变体) 37 | 38 | class ColumnParallelLinear(nn.Module): 39 | def __init__(self, in_features, out_features, bias=False, gather_output=False, sequence_parallel=False): 40 | super().__init__() 41 | self.gather_output = gather_output 42 | self.sequence_parallel = sequence_parallel 43 | 44 | tp_group = get_tp_group() 45 | self.tp_size = dist.get_world_size(group=tp_group) if tp_group else 1 46 | 47 | assert out_features % self.tp_size == 0 48 | self.out_features_per_partition = out_features // self.tp_size 49 | 50 | self.weight = nn.Parameter(torch.empty(self.out_features_per_partition, in_features)) 51 | if bias: 52 | self.bias = nn.Parameter(torch.empty(self.out_features_per_partition)) 53 | else: 54 | self.register_parameter('bias', None) 55 | 56 | # Init logic omitted for brevity (use master_seed) 57 | 58 | def forward(self, x): 59 | # x: [B, S, H] 60 | 61 | if self.sequence_parallel: 62 | # SP 模式下,输入 x 是在 Seq 维度切分的 [B, S/TP, H] 63 | # 我们需要把它 gather 回来变成 [B, S, H] 才能做矩阵乘法 64 | # 或者使用 Ring 算法。这里简化为 All-Gather 输入 65 | x = _gather(x, dim=1) 66 | 67 | # Local MatMul 68 | # Output: [B, S, H/TP] 69 | output = F.linear(x, self.weight, self.bias) 70 | 71 | if self.gather_output: 72 | output = _gather(output, dim=-1) 73 | 74 | return output 75 | 76 | 77 | class RowParallelLinear(nn.Module): 78 | def __init__(self, in_features, out_features, bias=False, 
input_is_parallel=True, sequence_parallel=False): 79 | super().__init__() 80 | self.input_is_parallel = input_is_parallel 81 | self.sequence_parallel = sequence_parallel 82 | 83 | tp_group = get_tp_group() 84 | self.tp_size = dist.get_world_size(group=tp_group) if tp_group else 1 85 | 86 | assert in_features % self.tp_size == 0 87 | self.in_features_per_partition = in_features // self.tp_size 88 | 89 | self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_partition)) 90 | if bias: 91 | self.bias = nn.Parameter(torch.empty(out_features)) 92 | else: 93 | self.register_parameter('bias', None) 94 | 95 | def forward(self, x): 96 | # x: [B, S, H/TP] 97 | 98 | # Local MatMul 99 | # Output: [B, S, H] (Partial Sum) 100 | output = F.linear(x, self.weight) 101 | 102 | if self.sequence_parallel: 103 | # SP 模式下,我们不做 All-Reduce (Sum),而是做 Reduce-Scatter 104 | # 结果变成 [B, S/TP, H],保持 Seq 维度切分 105 | # 注意:PyTorch 的 reduce_scatter_tensor API 106 | output = funcol.reduce_scatter_tensor(output, reduceOp='sum', scatter_dim=1, group=get_tp_group()) 107 | else: 108 | if self.input_is_parallel: 109 | output = _all_reduce(output) 110 | 111 | if self.bias is not None: 112 | output = output + self.bias 113 | return output 114 | 115 | 116 | class VocabParallelEmbedding(nn.Module): 117 | """TP for Embedding Layer""" 118 | 119 | def __init__(self, num_embeddings, embedding_dim): 120 | super().__init__() 121 | tp_group = get_tp_group() 122 | self.tp_size = dist.get_world_size(group=tp_group) if tp_group else 1 123 | self.tp_rank = dist.get_rank(group=tp_group) if tp_group else 0 124 | 125 | # 按 Vocab 维度切分 126 | self.vocab_start_index = self.tp_rank * (num_embeddings // self.tp_size) 127 | self.vocab_end_index = self.vocab_start_index + (num_embeddings // self.tp_size) 128 | 129 | self.weight = nn.Parameter(torch.empty(num_embeddings // self.tp_size, embedding_dim)) 130 | 131 | def forward(self, input_): 132 | # input_: [B, S] 133 | # Mask out tokens not in this partition 134 | input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) 135 | masked_input = input_.clone() - self.vocab_start_index 136 | masked_input[input_mask] = 0 137 | 138 | output = F.embedding(masked_input, self.weight) 139 | output[input_mask, :] = 0.0 140 | 141 | # All-Reduce to sum up embeddings from all partitions 142 | output = _all_reduce(output) 143 | return output 144 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.utils.data import DataLoader, DistributedSampler, IterableDataset 4 | from datasets import load_dataset, Features, Sequence, Value 5 | from transformers import AutoTokenizer 6 | from parallel.distributed import get_dp_group, get_cp_group 7 | 8 | 9 | class MicroBatchDataLoader(DataLoader): 10 | def __init__( 11 | self, 12 | micro_batch_size, 13 | seq_length, 14 | dataset_name, 15 | tokenizer_name, 16 | num_workers=0, 17 | num_proc=4, 18 | grad_acc_steps=1, 19 | split="train", 20 | max_samples=None, 21 | seed=42 22 | ): 23 | """ 24 | 通用 MicroBatch DataLoader。 25 | 支持: 26 | 1. 任意 HF 数据集 27 | 2. 自动 Tokenize 和 Chunking (拼接长文本) 28 | 3. 分布式采样 (DP) 29 | 4. 
上下文并行切分 (CP) 30 | """ 31 | self.micro_batch_size = micro_batch_size 32 | self.seq_length = seq_length 33 | self.grad_acc_steps = grad_acc_steps 34 | 35 | # 获取分布式信息 36 | dp_group = get_dp_group() 37 | self.dp_world_size = dist.get_world_size(group=dp_group) if dp_group else 1 38 | self.dp_rank = dist.get_rank(group=dp_group) if dp_group else 0 39 | 40 | cp_group = get_cp_group() 41 | self.cp_world_size = dist.get_world_size(group=cp_group) if cp_group else 1 42 | self.cp_rank = dist.get_rank(group=cp_group) if cp_group else 0 43 | 44 | self.global_batch_size = micro_batch_size * grad_acc_steps * self.dp_world_size 45 | 46 | # CP 模式下,每个 GPU 只负责序列的一部分 47 | self.seq_length_per_gpu = seq_length // self.cp_world_size 48 | 49 | # 1. 加载 Tokenizer (只在 Rank 0 加载,然后广播,或者利用 HF 的缓存机制) 50 | # 这里简化为每个进程都加载,HF 会处理缓存锁 51 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 52 | if self.tokenizer.pad_token is None: 53 | self.tokenizer.pad_token = self.tokenizer.eos_token 54 | 55 | # 2. 加载数据集 56 | # 使用 streaming=True 可以处理超大数据集,但为了 shuffle 和 map 方便, 57 | # 这里演示 map-style dataset (适合中小数据集如 TinyStories, Wikitext) 58 | # 如果是 TB 级数据,建议换成 IterableDataset + Buffer Shuffle 59 | print(f"[Rank {dist.get_rank()}] Loading dataset {dataset_name}...") 60 | dataset = load_dataset(dataset_name, split=split) 61 | 62 | if max_samples: 63 | dataset = dataset.select(range(min(max_samples, len(dataset)))) 64 | 65 | # 3. 预处理:Tokenize & Grouping 66 | # 将文本转换为定长的 input_ids 67 | self.tokenized_dataset = self.process_dataset(dataset, num_proc) 68 | 69 | # 4. 分布式采样器 70 | self.sampler = DistributedSampler( 71 | self.tokenized_dataset, 72 | num_replicas=self.dp_world_size, 73 | rank=self.dp_rank, 74 | shuffle=True, 75 | seed=seed 76 | ) 77 | 78 | super().__init__( 79 | self.tokenized_dataset, 80 | batch_size=micro_batch_size, 81 | sampler=self.sampler, 82 | num_workers=num_workers, 83 | collate_fn=self.collate_fn, 84 | pin_memory=True, 85 | drop_last=True 86 | ) 87 | 88 | def process_dataset(self, dataset, num_proc): 89 | """ 90 | 将文本数据 Tokenize 并拼接成 seq_length + 1 的块 91 | """ 92 | block_size = self.seq_length + 1 93 | 94 | def group_texts(examples): 95 | # 1. Tokenize 96 | tokenized_inputs = self.tokenizer( 97 | examples["text"], 98 | return_special_tokens_mask=True, 99 | truncation=False # 先不截断,全部拼起来 100 | ) 101 | concatenated_ids = {} 102 | for k, v in tokenized_inputs.items(): 103 | # 展平 list of list 104 | concatenated_ids[k] = sum(v, []) 105 | 106 | total_length = len(concatenated_ids["input_ids"]) 107 | # 丢弃最后不够一个 block 的部分 108 | if total_length >= block_size: 109 | total_length = (total_length // block_size) * block_size 110 | 111 | # 切分 112 | result = { 113 | k: [t[i: i + block_size] for i in range(0, total_length, block_size)] 114 | for k, t in concatenated_ids.items() 115 | } 116 | return result 117 | 118 | # 只有主进程打印进度条 119 | is_main_process = (dist.get_rank() == 0) if dist.is_initialized() else True 120 | 121 | tokenized_datasets = dataset.map( 122 | group_texts, 123 | batched=True, 124 | num_proc=num_proc, 125 | remove_columns=dataset.column_names, 126 | desc=f"Grouping texts in chunks of {block_size}", 127 | load_from_cache_file=True, 128 | disable_nullable=True 129 | ) 130 | return tokenized_datasets 131 | 132 | def collate_fn(self, batch): 133 | """ 134 | 处理 Batch,支持 CP 切分 135 | """ 136 | # batch 是一个 list of dict: [{'input_ids': [...]}, ...] 
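        # Illustrative example of the CP slice performed below (values hypothetical, not from any config):
        #   seq_length = 8, cp_world_size = 2  ->  seq_length_per_gpu = 4, batch_tensor is [B, 9]
        #     cp_rank 0: input_ids = tokens[0:4], target_ids = tokens[1:5]
        #     cp_rank 1: input_ids = tokens[4:8], target_ids = tokens[5:9]
        # Each rank keeps its own contiguous slice of the sequence, and the +1 shift
        # between input and target is preserved inside that slice.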
137 | input_ids_list = [item['input_ids'] for item in batch] 138 | batch_tensor = torch.tensor(input_ids_list, dtype=torch.long) 139 | 140 | # batch_tensor: [B, S+1] 141 | # input: 0 ~ S-1 142 | # target: 1 ~ S 143 | 144 | # CP 切分逻辑 145 | # 如果 CP=1, start=0, end=S 146 | # 如果 CP=2, Rank0: 0~S/2, Rank1: S/2~S 147 | start_idx = self.cp_rank * self.seq_length_per_gpu 148 | end_idx = start_idx + self.seq_length_per_gpu 149 | 150 | # 注意:这里切分的是 input (0~S-1) 和 target (1~S) 151 | # 原始数据长度是 S+1 152 | 153 | # Input: [B, S_local] 154 | input_ids = batch_tensor[:, start_idx: end_idx].contiguous() 155 | # Target: [B, S_local] 156 | target_ids = batch_tensor[:, start_idx + 1: end_idx + 1].contiguous() 157 | 158 | return { 159 | "input_ids": input_ids, 160 | "target_ids": target_ids 161 | } 162 | -------------------------------------------------------------------------------- /parallel/parallel_fsdp.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch.distributed.fsdp import ( 4 | FullyShardedDataParallel as FSDP, 5 | MixedPrecision, 6 | ShardingStrategy, 7 | ) 8 | from parallel.distributed import get_pp_group 9 | import torch.distributed as dist 10 | from torch.distributed.fsdp.wrap import ( 11 | transformer_auto_wrap_policy, 12 | size_based_auto_wrap_policy, 13 | ) 14 | 15 | # FSDP2 16 | try: 17 | # PyTorch 2.3+ (最新路径) 18 | from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy 19 | except ImportError: 20 | try: 21 | # PyTorch 2.2.0 (旧路径) 22 | from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy 23 | except ImportError: 24 | def fully_shard(*args, **kwargs): 25 | raise NotImplementedError("FSDP2 (fully_shard) requires PyTorch >= 2.2") 26 | class MixedPrecisionPolicy: 27 | def __init__(self, *args, **kwargs): 28 | pass 29 | 30 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 31 | checkpoint_wrapper, 32 | CheckpointImpl, 33 | apply_activation_checkpointing, 34 | ) 35 | from .distributed import get_device_mesh 36 | 37 | 38 | def apply_fsdp1( 39 | model, 40 | transformer_layer_cls, 41 | use_bf16=True, 42 | strategy="full", 43 | device_mesh=None, 44 | use_activation_checkpointing=False 45 | ): 46 | """ 47 | Classic FSDP (Wrapper based) - 生产级增强版 48 | """ 49 | # 1. 自动包裹策略 (Auto Wrap Policy) 50 | # 这是一个关键点:除了 TransformerBlock,有时我们也希望根据参数量包裹其他大层 51 | if transformer_layer_cls: 52 | auto_wrap_policy = functools.partial( 53 | transformer_auto_wrap_policy, 54 | transformer_layer_cls={transformer_layer_cls}, 55 | ) 56 | else: 57 | # 回退策略:如果没指定层类,按参数量切分 (例如 > 10M 的层) 58 | auto_wrap_policy = functools.partial( 59 | size_based_auto_wrap_policy, min_num_params=1e7 60 | ) 61 | 62 | # 2. 混合精度策略 (Mixed Precision) 63 | # 生产环境建议:reduce_dtype 保持 float32 以防止梯度下溢 (Underflow) 64 | mp_policy = MixedPrecision( 65 | param_dtype=torch.bfloat16, 66 | reduce_dtype=torch.float32, 67 | buffer_dtype=torch.float32, 68 | ) if use_bf16 else None 69 | 70 | # 3. 分片策略 (Sharding Strategy) 71 | # 处理 HSDP 的特殊逻辑 72 | if strategy == "hybrid": 73 | sharding_strategy = ShardingStrategy.HYBRID_SHARD 74 | # HSDP 必须提供 device_mesh,且必须是 2D 的 (Replicate, Shard) 75 | # 如果传入的 mesh 是 1D 的,这里会报错,需要做检查 76 | if device_mesh is None or device_mesh.ndim != 2: 77 | raise ValueError("HSDP requires a 2D DeviceMesh (Replicate, Shard).") 78 | elif strategy == "grad_op": 79 | sharding_strategy = ShardingStrategy.SHARD_GRAD_OP 80 | else: 81 | sharding_strategy = ShardingStrategy.FULL_SHARD 82 | 83 | # 4. 
初始化 FSDP 84 | # sync_module_states=True: 极其重要!确保所有 Rank 的随机初始化参数在开始前强制同步一致。 85 | # limit_all_gathers=True: 节省显存,防止在前向传播时同时 gather 太多层。 86 | fsdp_model = FSDP( 87 | model, 88 | auto_wrap_policy=auto_wrap_policy, 89 | mixed_precision=mp_policy, 90 | sharding_strategy=sharding_strategy, 91 | device_id=torch.cuda.current_device(), 92 | device_mesh=device_mesh, 93 | sync_module_states=True, # [改进点] 强制同步参数 94 | limit_all_gathers=True, # [改进点] 显存优化 95 | use_orig_params=True # [改进点] 允许 torch.compile 优化 96 | ) 97 | 98 | # 5. Activation Checkpointing (梯度检查点) 99 | # FSDP1 需要在 wrap 之后应用 100 | if use_activation_checkpointing and transformer_layer_cls: 101 | non_reentrant_wrapper = functools.partial( 102 | checkpoint_wrapper, 103 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 104 | ) 105 | check_fn = lambda submodule: isinstance(submodule, transformer_layer_cls) 106 | apply_activation_checkpointing( 107 | fsdp_model, 108 | checkpoint_wrapper_fn=non_reentrant_wrapper, 109 | check_fn=check_fn, 110 | ) 111 | 112 | return fsdp_model 113 | 114 | 115 | def apply_fsdp2( 116 | model, 117 | transformer_layer_cls=None, 118 | use_activation_checkpointing=False 119 | ): 120 | """ 121 | FSDP2 (Composable / DTensor based) - 生产级增强版 122 | """ 123 | mesh = get_device_mesh() 124 | if mesh is None: 125 | raise ValueError("Device Mesh must be initialized for FSDP2") 126 | 127 | # FSDP2 只需要知道在哪个维度上进行 Sharding (DP) 128 | # 如果 Mesh 是多维的 (例如 3D: pp/dp/tp),我们需要切出 'dp' 这一维 129 | fsdp_mesh = mesh 130 | if "dp" in mesh.mesh_dim_names: 131 | # 使用 mesh["dp"] 提取子网格 132 | # 这会返回一个 1D Mesh,只包含 DP 组的进程 133 | fsdp_mesh = mesh["dp"] 134 | 135 | mp_policy = MixedPrecisionPolicy( 136 | param_dtype=torch.bfloat16, 137 | reduce_dtype=torch.float32 138 | ) 139 | 140 | # FSDP2 的核心逻辑:自底向上应用 fully_shard 141 | 142 | # # 1. 对 Transformer Block 应用 fully_shard 143 | # # reshard_after_forward=True: 类似于 FSDP1 的 FULL_SHARD,算完就释放显存 144 | # if transformer_layer_cls: 145 | # for module in model.modules(): 146 | # if isinstance(module, transformer_layer_cls): 147 | # fully_shard( 148 | # module=module, 149 | # mesh=mesh, 150 | # mp_policy=mp_policy, 151 | # reshard_after_forward=True 152 | # ) 153 | 154 | # # [改进点] FSDP2 的 AC 集成 155 | # # 在 fully_shard 之后应用 AC 156 | # if use_activation_checkpointing: 157 | # checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.NO_REENTRANT) 158 | 159 | # # 2. 
对整个模型应用 fully_shard 160 | # # 这会处理剩下的 Embedding、Output Head 等层 161 | # fully_shard( 162 | # model, 163 | # mesh=mesh, 164 | # mp_policy=mp_policy, 165 | # reshard_after_forward=True 166 | # ) 167 | for layer_id, transformer_block in enumerate(model.layers): 168 | # 使用提取出的 fsdp_mesh (1D) 而不是全局 mesh (3D) 169 | fully_shard( 170 | transformer_block, 171 | mesh=fsdp_mesh, 172 | mp_policy=mp_policy, 173 | reshard_after_forward=True 174 | ) 175 | # Root level wrapping 176 | fully_shard( 177 | model, 178 | mesh=fsdp_mesh, 179 | mp_policy=mp_policy, 180 | reshard_after_forward=True 181 | ) 182 | 183 | return model 184 | -------------------------------------------------------------------------------- /parallel/communication/all_reduce.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import torch.nn.functional as F 5 | from torch.distributed.device_mesh import init_device_mesh 6 | 7 | 8 | def init_dist(): 9 | """初始化分布式环境""" 10 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 11 | rank = int(os.environ.get("RANK", 0)) 12 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 13 | 14 | # 1. 绑定设备 15 | torch.cuda.set_device(local_rank) 16 | device = torch.device(f"cuda:{local_rank}") 17 | 18 | # 2. 初始化默认进程组 (虽然 init_device_mesh 可以自动初始化, 19 | # 但显式初始化并指定 device_id 是消除 Warning 的最佳实践) 20 | if not dist.is_initialized(): 21 | dist.init_process_group(backend="nccl", device_id=device) 22 | 23 | return rank, world_size, device, local_rank 24 | 25 | 26 | def ring_all_reduce(tensor: torch.Tensor, group: dist.ProcessGroup = None) -> torch.Tensor: 27 | """ 28 | 通用的 Ring AllReduce 实现 29 | :param tensor: 输入张量 30 | :param group: 通信组 (TP组 或 DP组) 31 | """ 32 | if group is None: 33 | group = dist.group.WORLD 34 | 35 | # 获取组内逻辑 Rank 36 | rank_in_group = dist.get_rank(group) 37 | world_size_in_group = dist.get_world_size(group) 38 | 39 | if world_size_in_group == 1: 40 | return tensor 41 | 42 | # 预处理:Flatten + Padding 43 | original_shape = tensor.shape 44 | tensor_flat = tensor.flatten() 45 | numel = tensor_flat.numel() 46 | 47 | pad_len = (world_size_in_group - (numel % world_size_in_group)) % world_size_in_group 48 | if pad_len > 0: 49 | tensor_flat = F.pad(tensor_flat, (0, pad_len)) 50 | 51 | # 分块 52 | chunk_size = tensor_flat.numel() // world_size_in_group 53 | chunks = list(tensor_flat.split(chunk_size)) 54 | 55 | # 计算环形邻居 (逻辑 Rank -> 物理 Global Rank) 56 | right_rank_logical = (rank_in_group + 1) % world_size_in_group 57 | left_rank_logical = (rank_in_group - 1 + world_size_in_group) % world_size_in_group 58 | 59 | right_rank_global = dist.get_global_rank(group, right_rank_logical) 60 | left_rank_global = dist.get_global_rank(group, left_rank_logical) 61 | 62 | # Reduce-Scatter 63 | for step in range(world_size_in_group - 1): 64 | send_idx = (rank_in_group - step + world_size_in_group) % world_size_in_group 65 | recv_idx = (rank_in_group - step - 1 + world_size_in_group) % world_size_in_group 66 | 67 | send_chunk = chunks[send_idx] 68 | recv_buffer = torch.empty_like(chunks[recv_idx]) 69 | 70 | reqs = dist.batch_isend_irecv([ 71 | dist.P2POp(dist.isend, send_chunk, right_rank_global, group=group), 72 | dist.P2POp(dist.irecv, recv_buffer, left_rank_global, group=group) 73 | ]) 74 | for req in reqs: req.wait() 75 | chunks[recv_idx].add_(recv_buffer) 76 | 77 | # All-Gather 78 | for step in range(world_size_in_group - 1): 79 | send_idx = (rank_in_group - step + 1 + world_size_in_group) % world_size_in_group 80 | 
recv_idx = (rank_in_group - step + world_size_in_group) % world_size_in_group 81 | 82 | send_chunk = chunks[send_idx] 83 | reqs = dist.batch_isend_irecv([ 84 | dist.P2POp(dist.isend, send_chunk, right_rank_global, group=group), 85 | dist.P2POp(dist.irecv, chunks[recv_idx], left_rank_global, group=group) 86 | ]) 87 | for req in reqs: req.wait() 88 | 89 | # 恢复形状 90 | res = torch.cat(chunks) 91 | if pad_len > 0: 92 | res = res[:-pad_len] 93 | return res.reshape(original_shape) 94 | 95 | 96 | def main(): 97 | rank, world_size, device, local_rank = init_dist() 98 | """ 99 | # init_process_group不是必须的,可用这一段来代替init_dist 100 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 101 | rank = int(os.environ.get("RANK", 0)) 102 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 103 | torch.cuda.set_device(local_rank) 104 | device = torch.device(f"cuda:{local_rank}") 105 | """ 106 | 107 | # 设定并行度:总卡数 8 = 2(DP) * 2(PP) * 2(TP) 108 | # 注意:这里的顺序很重要,决定了 Rank 如何映射到 Mesh 109 | # 通常顺序是 (Data, Pipeline, Tensor) 110 | mesh_shape = (2, 2, 2) 111 | mesh_dim_names = ("dp", "pp", "tp") 112 | 113 | # --- 一键生成 3D Mesh --- 114 | # 这行代码自动完成了之前几十行的 Group 创建逻辑 115 | mesh_3d = init_device_mesh( 116 | "cuda", 117 | mesh_shape, 118 | mesh_dim_names=mesh_dim_names 119 | ) 120 | 121 | # --- 直接通过名字获取 Group --- 122 | # 获取 TP 组 (沿着 "tp" 维度切分) 123 | tp_group = mesh_3d["tp"].get_group() 124 | 125 | # 获取 DP 组 (沿着 "dp" 维度切分) 126 | dp_group = mesh_3d["dp"].get_group() 127 | 128 | # 简单的同步屏障 129 | dist.barrier() 130 | if rank == 0: 131 | print("\n" + "=" * 50) 132 | print(f"🚀 Device Mesh 3D 并行测试 (Shape: {mesh_shape})") 133 | print("=" * 50 + "\n") 134 | # 打印 Mesh 结构看看 135 | print(f"Mesh Structure:\n{mesh_3d}") 136 | 137 | dist.barrier() 138 | 139 | # ------------------------------------------------- 140 | # 测试场景 1: TP AllReduce 141 | # ------------------------------------------------- 142 | tensor_tp = torch.randn(1024, device=device) * (rank + 1) 143 | tensor_tp_ref = tensor_tp.clone() 144 | 145 | # 传入从 Mesh 获取的 tp_group 146 | res_tp = ring_all_reduce(tensor_tp, group=tp_group) 147 | dist.all_reduce(tensor_tp_ref, op=dist.ReduceOp.SUM, group=tp_group) 148 | 149 | err_tp = torch.mean((res_tp - tensor_tp_ref) ** 2) 150 | 151 | if rank in [0, 1]: # 打印 Rank 0 和 1 (它们应该在同一个 TP 组) 152 | print(f"[TP Test] Rank {rank} (TP-Group Rank {dist.get_rank(tp_group)}): " 153 | f"Error = {err_tp.item():.5e}") 154 | 155 | dist.barrier() 156 | 157 | # ------------------------------------------------- 158 | # 测试场景 2: DP AllReduce 159 | # ------------------------------------------------- 160 | tensor_dp = torch.randn(1024, device=device) + (rank + 10) 161 | tensor_dp_ref = tensor_dp.clone() 162 | 163 | # 传入从 Mesh 获取的 dp_group 164 | res_dp = ring_all_reduce(tensor_dp, group=dp_group) 165 | dist.all_reduce(tensor_dp_ref, op=dist.ReduceOp.SUM, group=dp_group) 166 | 167 | err_dp = torch.mean((res_dp - tensor_dp_ref) ** 2) 168 | 169 | if rank in [0, 4]: # 打印 Rank 0 和 4 (它们应该在同一个 DP 组) 170 | print(f"[DP Test] Rank {rank} (DP-Group Rank {dist.get_rank(dp_group)}): " 171 | f"Error = {err_dp.item():.5e}") 172 | 173 | print(f"\n\n _flatten_mesh_list: {mesh_3d._flatten_mesh_list}") 174 | 175 | dist.barrier() 176 | 177 | # 清理 178 | dist.destroy_process_group() 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /parallel/parallel_ep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.distributed as dist 4 | from .distributed import get_device_mesh 5 | 6 | 7 | def _all_to_all_with_handshake(data, splits, group): 8 | """ 9 | 带有握手机制的 All-to-All 通信。 10 | 11 | Args: 12 | data: 本地要发送的排序后的数据 [Total_Tokens, Hidden] 13 | splits: List[int], 长度为 world_size,表示发给每个 rank 的 token 数量 14 | group: 通信组 15 | 16 | Returns: 17 | received_data: 接收到的数据 [Total_Received, Hidden] 18 | received_splits: List[int], 表示从每个 rank 接收到的 token 数量 19 | """ 20 | world_size = dist.get_world_size(group=group) 21 | 22 | # 1. 握手 (Handshake): 交换数据量信息 23 | # 我们要发送的 counts 24 | send_counts = torch.tensor(splits, device=data.device, dtype=torch.long) 25 | # 我们准备接收的 counts 26 | recv_counts = torch.empty(world_size, device=data.device, dtype=torch.long) 27 | 28 | # All-to-All 交换 counts 29 | # 比如: Rank 0 发给 Rank 1 说 "我有 10 个 token",Rank 1 就会在 recv_counts[0] 收到 10 30 | dist.all_to_all_single(recv_counts, send_counts, group=group) 31 | 32 | # 2. 准备接收缓冲区 33 | recv_splits = recv_counts.tolist() 34 | total_recv_tokens = sum(recv_splits) 35 | hidden_dim = data.size(1) 36 | 37 | # 如果没有数据要收发,直接返回空 38 | if total_recv_tokens == 0 and data.numel() == 0: 39 | return torch.empty(0, hidden_dim, device=data.device, dtype=data.dtype), recv_splits 40 | 41 | recv_data = torch.empty(total_recv_tokens, hidden_dim, device=data.device, dtype=data.dtype) 42 | 43 | # 3. 传输实际数据 (Payload) 44 | # PyTorch 的 all_to_all_single 需要 input_split_sizes 和 output_split_sizes 45 | dist.all_to_all_single( 46 | recv_data, 47 | data, 48 | output_split_sizes=recv_splits, 49 | input_split_sizes=splits, 50 | group=group 51 | ) 52 | 53 | return recv_data, recv_splits 54 | 55 | 56 | class ExpertParallel(nn.Module): 57 | """ 58 | 专家并行 (EP) 核心模块。 59 | 实现了 Token 的 Dispatch (分发) 和 Combine (聚合)。 60 | """ 61 | 62 | def __init__(self, num_experts): 63 | super().__init__() 64 | self.num_experts = num_experts 65 | self.mesh = get_device_mesh() 66 | 67 | # 获取 EP 组 68 | if self.mesh and "ep" in self.mesh.mesh_dim_names: 69 | self.ep_group = self.mesh["ep"].get_group() 70 | else: 71 | self.ep_group = None 72 | 73 | def dispatch(self, x, expert_indices): 74 | """ 75 | 将 Token 分发到对应的 GPU。 76 | 77 | Args: 78 | x: [Batch * Seq, Hidden] 输入 Token 79 | expert_indices: [Batch * Seq, TopK] 每个 Token 选中的专家索引 80 | 81 | Returns: 82 | dispatched_x: [Total_Recv, Hidden] 接收到的需要计算的 Token 83 | metadata: dict, 包含恢复顺序所需的所有信息 84 | """ 85 | # 1. 本地模式 (无 EP) 86 | if self.ep_group is None: 87 | # 为了接口一致性,我们需要把 TopK 展开 88 | # x: [N, D] -> [N, K, D] -> [N*K, D] 89 | topk = expert_indices.size(1) 90 | x_expanded = x.unsqueeze(1).expand(-1, topk, -1).reshape(-1, x.size(-1)) 91 | return x_expanded, {"is_local": True, "topk": topk} 92 | 93 | world_size = dist.get_world_size(group=self.ep_group) 94 | rank = dist.get_rank(group=self.ep_group) 95 | 96 | # 2. 准备数据 97 | # x: [N, D], indices: [N, K] 98 | # 我们需要把 x 复制 K 份,因为一个 token 可能去多个专家 99 | topk = expert_indices.size(1) 100 | # [N, D] -> [N, K, D] -> [N*K, D] 101 | x_flat = x.unsqueeze(1).expand(-1, topk, -1).reshape(-1, x.size(-1)) 102 | # [N, K] -> [N*K] 103 | indices_flat = expert_indices.view(-1) 104 | 105 | # 3. 计算目标 Rank 106 | # 假设专家是均匀切分的。例如 8 个专家,4 个 GPU,则每个 GPU 负责 2 个。 107 | # Rank 0: Experts [0, 1], Rank 1: Experts [2, 3] ... 108 | num_local_experts = self.num_experts // world_size 109 | target_ranks = indices_flat // num_local_experts 110 | 111 | # 4. 排序 (Sorting) - 核心步骤 112 | # 我们必须把发往 Rank 0 的数据排在前面,Rank 1 的排在后面... 
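        # Worked example of the sort/split below (numbers are illustrative only, assuming a stable
        # ordering among equal ranks):
        #   world_size = 2, num_local_experts = 2, indices_flat = [2, 0, 3, 1]
        #     -> target_ranks      = [1, 0, 1, 0]
        #     -> sort_indices      = argsort([1, 0, 1, 0]) = [1, 3, 0, 2]
        #     -> x_sorted order    = tokens [1, 3] (bound for rank 0), then [0, 2] (bound for rank 1)
        #     -> splits (bincount) = [2, 2]
        # all_to_all_single then ships the first 2 rows to rank 0 and the last 2 rows to rank 1.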
113 | sort_indices = torch.argsort(target_ranks) 114 | 115 | # 根据排序结果重排 x 和 target_ranks 116 | x_sorted = x_flat[sort_indices] 117 | target_ranks_sorted = target_ranks[sort_indices] 118 | 119 | # 5. 计算 Split Sizes (发给每个 Rank 多少个) 120 | # 统计每个 rank 出现的次数 121 | # bincount 统计 [0, 1, 1, 2] -> [1, 2, 1] 122 | splits = torch.bincount(target_ranks_sorted, minlength=world_size).tolist() 123 | 124 | # 6. 执行 All-to-All 通信 125 | received_x, received_splits = _all_to_all_with_handshake(x_sorted, splits, self.ep_group) 126 | 127 | # 7. 保存 Metadata (用于 Combine 阶段恢复) 128 | metadata = { 129 | "is_local": False, 130 | "sort_indices": sort_indices, # 用于恢复顺序 131 | "send_splits": splits, # 发送时的切分 (Combine 时接收用) 132 | "recv_splits": received_splits, # 接收时的切分 (Combine 时发送用) 133 | "original_batch_size": x.size(0), 134 | "topk": topk 135 | } 136 | 137 | return received_x, metadata 138 | 139 | def combine(self, x, metadata): 140 | """ 141 | 将计算完成的 Token 发回原 GPU,并恢复顺序。 142 | 143 | Args: 144 | x: [Total_Recv, Hidden] 计算后的 Token (通常是经过 MLP 后的) 145 | metadata: dispatch 返回的元数据 146 | """ 147 | if metadata["is_local"]: 148 | # 恢复形状 [N*K, D] -> [N, K, D] 149 | return x.view(metadata["original_batch_size"], metadata["topk"], -1) 150 | 151 | # 1. 逆向通信 152 | # Dispatch: Send(splits) -> Recv(recv_splits) 153 | # Combine: Send(recv_splits) -> Recv(splits) 154 | # 注意:这里的 x 是已经在当前 GPU 算好的,数量等于 recv_splits 155 | 156 | received_x, _ = _all_to_all_with_handshake( 157 | x, 158 | metadata["recv_splits"], # 我现在发回去的数量,就是我之前收到的数量 159 | self.ep_group 160 | ) 161 | 162 | # 2. 恢复顺序 (Un-sort) 163 | # received_x 目前是按 Rank 排序的 (Rank 0 发回来的, Rank 1 发回来的...) 164 | # 我们需要把它变回 dispatch 之前的顺序 165 | 166 | # 创建一个空的 buffer 167 | # sort_indices[i] = j 意味着:排序后的第 i 个元素来自原始的第 j 个位置 168 | # 所以:original[sort_indices[i]] = sorted[i] 169 | # 逆操作:original[j] = sorted[inverse_map[j]] 170 | 171 | # 更简单的方法:直接根据 sort_indices 赋值 172 | output = torch.empty_like(received_x) 173 | # output[sort_indices] = received_x 174 | # 上面这行是错的,应该是 scatter 或者 index_copy 175 | # 正确逻辑:received_x 是排序后的状态,我们要把它放回 sort_indices 指定的位置 176 | output.index_copy_(0, metadata["sort_indices"], received_x) 177 | 178 | # 3. 恢复形状 179 | # [N*K, D] -> [N, K, D] 180 | output = output.view(metadata["original_batch_size"], metadata["topk"], -1) 181 | 182 | return output 183 | -------------------------------------------------------------------------------- /parallel/parallel_cp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | from .distributed import get_device_mesh 5 | from layers.layers import apply_rotary_emb 6 | 7 | 8 | def _all_to_all(input_, scatter_dim, gather_dim, group): 9 | """ 10 | All-to-All 通信原语 (基于 Ulysses 算法的核心) 11 | 作用:将张量在 scatter_dim 上切分,发送给不同 rank, 12 | 同时接收其他 rank 的数据,在 gather_dim 上拼接。 13 | """ 14 | # 1. 基础检查 15 | if group is None: 16 | return input_ 17 | 18 | world_size = dist.get_world_size(group=group) 19 | if world_size == 1: 20 | return input_ 21 | 22 | # 2. 预处理:确保输入连续 23 | input_ = input_.contiguous() 24 | 25 | # 3. 准备输入切片 (Split) 26 | # input shape: [..., scatter_dim_size, ...] 27 | # chunks shape: world_size * [..., scatter_dim_size/P, ...] 28 | input_chunks = list(input_.chunk(world_size, dim=scatter_dim)) 29 | 30 | # 4. 准备输出缓冲区 31 | # 输出形状与输入切片形状相同(假设切分均匀) 32 | output_chunks = [torch.empty_like(chunk) for chunk in input_chunks] 33 | 34 | # 5. 执行 All-to-All 通信 35 | # 这是一个同步操作 36 | dist.all_to_all(output_chunks, input_chunks, group=group) 37 | 38 | # 6. 
后处理:拼接 (Concat) 39 | # output shape: [..., gather_dim_size * P, ...] 40 | return torch.cat(output_chunks, dim=gather_dim) 41 | 42 | 43 | class ContextParallelAttention(nn.Module): 44 | """ 45 | CP Attention Wrapper (DeepSpeed Ulysses Style) 46 | 47 | 该模块用于包裹标准的 Attention 计算(如 FlashAttention)。 48 | 它负责在计算前后进行数据的“转置”: 49 | 50 | 流程: 51 | 1. 输入: [Batch, Seq/P, Heads, Dim] (序列被切分) 52 | 2. 通信: All-to-All (Seq -> Heads) 53 | 3. 中间: [Batch, Seq, Heads/P, Dim] (现在拥有完整的 Seq,但只有部分的 Heads) 54 | 4. 计算: Local Attention (因为有完整 Seq,所以可以算 Attention Score) 55 | 5. 通信: All-to-All (Heads -> Seq) 56 | 6. 输出: [Batch, Seq/P, Heads, Dim] (恢复为序列切分) 57 | """ 58 | 59 | def __init__(self, local_attn_module, args): 60 | """ 61 | Args: 62 | local_attn_module: 原始的 Attention 模块 (lightron.model.Attention) 63 | args: ModelArgs 配置,用于获取 head 数量等信息 64 | """ 65 | super().__init__() 66 | self.local_attn = local_attn_module 67 | self.num_heads = args.n_heads 68 | 69 | # 既然 local_attn 已经是 TP 后的模块,它的 n_heads 属性就是正确的本地头数 (6) 70 | # 不要直接读 args.n_heads (12)! 71 | if hasattr(local_attn_module, 'n_heads'): 72 | self.num_heads = local_attn_module.n_heads 73 | else: 74 | # 兜底逻辑:手动除以 TP 75 | tp_size = getattr(args, 'tp_size', 1) 76 | self.num_heads = args.n_heads // tp_size 77 | 78 | self.head_dim = args.dim // args.n_heads 79 | 80 | # 获取 CP 通信组 81 | self.mesh = get_device_mesh() 82 | # 检查 mesh 中是否有 'cp' 维度,如果没有则回退到 None (不做 CP) 83 | if self.mesh and "cp" in self.mesh.mesh_dim_names: 84 | self.cp_group = self.mesh["cp"].get_group() 85 | cp_size = dist.get_world_size(self.cp_group) 86 | # === 修复:添加整除断言 === 87 | assert self.num_heads % cp_size == 0, \ 88 | f"Context Parallelism requires local_num_heads ({self.num_heads}) " \ 89 | f"to be divisible by cp_size ({cp_size})." 90 | else: 91 | self.cp_group = None 92 | 93 | def forward(self, x, freqs_cis): 94 | """ 95 | x: 输入 Tensor,通常是 LayerNorm 后的结果 96 | Shape: [Batch, Seq_Local, Hidden_Dim] 97 | 注意:这里的 Seq_Local = Seq_Global / CP_Size 98 | """ 99 | # 如果没有开启 CP,直接调用原始 Attention 100 | if self.cp_group is None: 101 | return self.local_attn(x, freqs_cis) 102 | 103 | # === 1. 准备阶段 === 104 | B, S_local, Hidden = x.shape 105 | cp_size = dist.get_world_size(group=self.cp_group) 106 | 107 | # 此时 x 是 [B, S/P, H_total * D] 108 | # 我们需要先把它 reshape 成 [B, S/P, H_total, D] 109 | # 注意:这里假设 local_attn 内部的 wq, wk, wv 是线性的, 110 | # 为了适配 Ulysses,我们需要侵入到 Attention 内部,或者要求 Attention 的输入已经是 QKV。 111 | # **修正**:Lightron 的 Attention 模块内部包含了 WQ/WK/WV 投影。 112 | # Ulysses 通常要求投影后的 QKV 进行 All-to-All。 113 | # 为了不重写整个 Attention 类,我们采用一种更高级的策略: 114 | # 我们让 local_attn 正常计算 QKV 投影,但在 FlashAttention 之前拦截它。 115 | 116 | # 但由于 Python 无法直接拦截中间变量,我们需要修改 lightron/model.py 中的 Attention。 117 | # 为了让这个 Wrapper 生效,我们假设 self.local_attn 已经被修改为支持 118 | # 接收 "pre-computed QKV" 或者我们在这里手动执行投影。 119 | 120 | # 【方案 B:最稳健的实现】 121 | # 我们不 Wrap 整个 Attention,而是 Wrap "Attention 计算部分"。 122 | # 但为了符合你要求的代码结构,我们假设这个 forward 是替代原 Attention.forward 的。 123 | 124 | # 1. 投影 (Local Projection) 125 | # x: [B, S/P, Hidden] 126 | xq, xk, xv = self.local_attn.wq(x), self.local_attn.wk(x), self.local_attn.wv(x) 127 | 128 | # Reshape to heads 129 | # xq: [B, S/P, H_total, D] 130 | xq = xq.view(B, S_local, self.num_heads, self.head_dim) 131 | xk = xk.view(B, S_local, self.num_heads, self.head_dim) # 暂不考虑 GQA/MQA 的复杂情况,假设 MHA 132 | xv = xv.view(B, S_local, self.num_heads, self.head_dim) 133 | 134 | # === 2. 
第一次 All-to-All (Seq -> Head) === 135 | # 目标: [B, S_global, H_local, D] 136 | # 操作: Scatter dim 1 (Seq), Gather dim 2 (Head) 137 | 138 | xq = _all_to_all(xq, scatter_dim=2, gather_dim=1, group=self.cp_group) 139 | xk = _all_to_all(xk, scatter_dim=2, gather_dim=1, group=self.cp_group) 140 | xv = _all_to_all(xv, scatter_dim=2, gather_dim=1, group=self.cp_group) 141 | 142 | # 现在的形状: [B, S_global, H_total/P, D] 143 | # 此时我们拥有了完整的 Sequence,但只有部分的 Heads。 144 | 145 | # === 3. RoPE (需要完整 Seq) === 146 | # 因为现在 S 是完整的,所以可以直接应用 RoPE 147 | # 注意:freqs_cis 需要匹配 S_global 148 | # 如果传入的 freqs_cis 是切片过的,这里可能需要调整,假设传入的是完整的或自动广播的 149 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis) 150 | 151 | # === 4. Local Attention (FlashAttn) === 152 | # PyTorch 的 scaled_dot_product_attention 接受 [B, H, S, D] 153 | # 我们需要转置一下 154 | output = torch.nn.functional.scaled_dot_product_attention( 155 | xq.transpose(1, 2), # [B, H/P, S_global, D] 156 | xk.transpose(1, 2), 157 | xv.transpose(1, 2), 158 | is_causal=True 159 | ) 160 | # output: [B, H/P, S_global, D] 161 | output = output.transpose(1, 2).contiguous() # [B, S_global, H/P, D] 162 | 163 | # === 5. 第二次 All-to-All (Head -> Seq) === 164 | # 目标: [B, S_local, H_total, D] 165 | # 操作: Scatter dim 1 (Seq), Gather dim 2 (Head) -> 也就是逆操作 166 | 167 | # 注意:刚才 _all_to_all 是把 dim 2 切了拼到 dim 1 168 | # 现在我们要把 dim 1 切了拼到 dim 2 169 | output = _all_to_all(output, scatter_dim=1, gather_dim=2, group=self.cp_group) 170 | 171 | # 现在的形状: [B, S_local, H_total, D] 172 | 173 | # === 6. 输出投影 === 174 | output = output.flatten(2) # [B, S_local, Hidden] 175 | return self.local_attn.wo(output) 176 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import time 5 | import torch 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.distributed as dist 9 | from transformers import AutoConfig 10 | 11 | from config.config import ModelArgs 12 | from model.model import LightronTransformer 13 | from parallel.distributed import setup_distributed 14 | from parallel.parallel_fsdp import apply_fsdp2 15 | from data.dataloader import MicroBatchDataLoader 16 | 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser(description="Lightron Training Script") 20 | parser.add_argument("--config", type=str, required=True, help="Path to JSON config file") 21 | return parser.parse_args() 22 | 23 | 24 | def load_config(config_path): 25 | with open(config_path, "r") as f: 26 | return json.load(f) 27 | 28 | 29 | def train_step(model, batch, grad_acc_steps): 30 | """单步训练逻辑""" 31 | # 数据移动到 GPU 32 | input_ids = batch["input_ids"].cuda() 33 | target_ids = batch["target_ids"].cuda() 34 | 35 | # Forward 36 | # 注意:LightronTransformer 返回的是 [B, S, VocabSize] 37 | logits = model(input_ids) 38 | 39 | # Loss Calculation 40 | # Reshape: [B*S, V] vs [B*S] 41 | loss = F.cross_entropy( 42 | logits.view(-1, logits.size(-1)), 43 | target_ids.view(-1) 44 | ) 45 | 46 | # Scale loss for gradient accumulation 47 | loss = loss / grad_acc_steps 48 | loss.backward() 49 | 50 | return loss.item() * grad_acc_steps 51 | 52 | 53 | def main(): 54 | # 1. 解析参数与配置 55 | args = get_args() 56 | config = load_config(args.config) 57 | 58 | dist_cfg = config["distributed"] 59 | train_cfg = config["training"] 60 | model_cfg = config["model"] 61 | data_cfg = config["dataset"] 62 | 63 | # 2. 
初始化分布式环境 (4D Parallel Setup) 64 | # 优先从环境变量读取 (torchrun),如果没设则用 config 的默认值 65 | tp_size = int(os.environ.get("TP_SIZE", dist_cfg.get("tp_size", 1))) 66 | dp_size = int(os.environ.get("DP_SIZE", dist_cfg.get("dp_size", 1))) 67 | cp_size = int(os.environ.get("CP_SIZE", dist_cfg.get("cp_size", 1))) 68 | pp_size = int(os.environ.get("PP_SIZE", dist_cfg.get("pp_size", 1))) 69 | ep_size = int(os.environ.get("EP_SIZE", dist_cfg.get("ep_size", 1))) 70 | 71 | setup_distributed( 72 | tp_size=tp_size, 73 | pp_size=pp_size, 74 | cp_size=cp_size, 75 | ep_size=ep_size, 76 | dp_size=dp_size 77 | ) 78 | 79 | local_rank = int(os.environ["LOCAL_RANK"]) 80 | global_rank = int(os.environ["RANK"]) 81 | world_size = int(os.environ["WORLD_SIZE"]) 82 | torch.cuda.set_device(local_rank) 83 | 84 | if global_rank == 0: 85 | print(f"🚀 Starting training with config: {args.config}") 86 | print(f" World Size: {world_size} | TP={tp_size} DP={dp_size}") 87 | 88 | # 3. 自动加载模型配置 (从 HF) 89 | # 使用 HF_ENDPOINT 环境变量确保国内能下载 90 | if "HF_ENDPOINT" not in os.environ: 91 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" 92 | 93 | if global_rank == 0: 94 | print(f"Loading model config from {model_cfg['name']}...") 95 | 96 | # 让所有进程都加载 Config (Config 文件很小,不会有并发问题) 97 | hf_config = AutoConfig.from_pretrained(model_cfg["name"], trust_remote_code=True) 98 | 99 | vocab_size = hf_config.vocab_size 100 | if tp_size > 1: 101 | # 计算需要填充多少才能被 tp_size 整除 102 | if vocab_size % tp_size != 0: 103 | new_vocab_size = ((vocab_size // tp_size) + 1) * tp_size 104 | if global_rank == 0: 105 | print(f"⚠️ Vocab size {vocab_size} is not divisible by TP={tp_size}.") 106 | print(f" Padding vocab size to {new_vocab_size}...") 107 | vocab_size = new_vocab_size 108 | 109 | # 4. 转换为 Lightron ModelArgs 110 | # 自动映射 HF 参数到 Lightron 参数 111 | model_args = ModelArgs( 112 | dim=hf_config.hidden_size, 113 | n_layers=hf_config.num_hidden_layers, 114 | n_heads=hf_config.num_attention_heads, 115 | n_kv_heads=getattr(hf_config, "num_key_value_heads", hf_config.num_attention_heads), 116 | vocab_size=vocab_size, 117 | max_seq_len=train_cfg["seq_length"], 118 | norm_eps=getattr(hf_config, "rms_norm_eps", 1e-5), 119 | # 并行模式 120 | tp_size=tp_size, 121 | cp_size=cp_size, 122 | # MoE 配置 123 | moe_num_experts=model_cfg.get("moe_num_experts", 1), 124 | moe_topk=model_cfg.get("moe_topk", 2), 125 | moe_layer_freq=model_cfg.get("moe_layer_freq", 2) 126 | ) 127 | 128 | # 5. 初始化模型 129 | # 使用 Meta Device 初始化,秒级构建,不占显存 130 | with torch.device("meta"): 131 | model = LightronTransformer(model_args) 132 | 133 | # 6. 应用并行策略 134 | # A. TP/CP/EP: 已经在 model.py 内部通过 parallel_mode 处理了层结构 135 | # B. FSDP (DP): 处理剩余的参数切分 136 | if dp_size > 1: 137 | # FSDP2 会自动处理 Meta 到 Real 的参数初始化 138 | # 注意:如果 TP>1,这里是混合并行,FSDP2 会在 DP 维度切分 139 | 140 | # 1. 先切分 (此时还是 Meta Tensor) 141 | model = apply_fsdp2(model) 142 | 143 | # 2. 分配物理显存 (Materialize), 这会在每张卡上只分配它负责的那一部分参数 (Local Shard) 144 | model = model.to_empty(device="cuda") 145 | 146 | # 3. 
初始化参数数值 147 | # 因为是 Meta 初始化,现在显存里全是垃圾数据,必须 reset 148 | # 为了保证所有 DP Rank 初始权重一致,我们需要固定随机种子 149 | torch.manual_seed(42 + global_rank) # 注意:通常 DP 需要相同种子,但 FSDP2 这种局部初始化比较特殊 150 | 151 | # 更严谨的做法:设置相同的种子,让大家算出一样的随机数(如果切分逻辑允许), 或者 Rank 0 初始化后广播(太慢)。 152 | # 对于 FSDP2,最简单的做法是:设置全局统一种子,然后依靠 reset_parameters 153 | torch.manual_seed(train_cfg.get("seed", 42)) 154 | 155 | def init_weights(m): 156 | # 如果模块有自定义的重置方法(如 Linear, Embedding, 或我们的 ParallelLinear) 157 | if hasattr(m, 'reset_parameters'): 158 | m.reset_parameters() 159 | # 兜底逻辑:针对原生 PyTorch 层 160 | elif isinstance(m, (torch.nn.Linear, torch.nn.Embedding)): 161 | m.reset_parameters() 162 | 163 | model.apply(init_weights) 164 | else: 165 | # 纯 TP 模式或单卡模式,需要手动 materialize 166 | model = model.to_empty(device="cuda") 167 | model.apply(lambda m: m.reset_parameters() if hasattr(m, 'reset_parameters') else None) 168 | 169 | from layers.layers import precompute_freqs_cis 170 | if global_rank == 0: 171 | print("Re-computing RoPE frequencies for Meta-initialized model...") 172 | with torch.no_grad(): 173 | # 重新计算 174 | real_freqs = precompute_freqs_cis( 175 | model_args.dim // model_args.n_heads, 176 | model_args.max_seq_len 177 | ) 178 | # 移动到 GPU 并赋值给模型的 buffer 179 | model.freqs_cis.copy_(real_freqs.to("cuda")) 180 | 181 | if global_rank == 0: 182 | # 统计参数量 (FSDP 下可能不准,仅供参考) 183 | try: 184 | param_count = sum(p.numel() for p in model.parameters()) 185 | print(f"Model initialized. Total Parameters (Local/Meta): {param_count / 1e9:.2f}B") 186 | except: 187 | pass 188 | 189 | # 7. 初始化 DataLoader 190 | # 使用我们刚刚测试通过的 MicroBatchDataLoader 191 | dataloader = MicroBatchDataLoader( 192 | micro_batch_size=train_cfg["micro_batch_size"], 193 | seq_length=train_cfg["seq_length"], 194 | dataset_name=data_cfg["name"], 195 | tokenizer_name=model_cfg["name"], # 复用模型名作为 tokenizer 名 196 | grad_acc_steps=train_cfg["gradient_accumulation_steps"], 197 | num_workers=data_cfg.get("num_workers", 0), 198 | max_samples=train_cfg.get("max_samples", None), 199 | split=data_cfg.get("split", "train") 200 | ) 201 | 202 | # 8. 优化器 203 | optimizer = optim.AdamW( 204 | model.parameters(), 205 | lr=train_cfg["learning_rate"], 206 | weight_decay=train_cfg.get("weight_decay", 0.01) 207 | ) 208 | 209 | # 9. 
训练循环 210 | model.train() 211 | total_steps = train_cfg["total_steps"] 212 | step = 0 213 | tokens_seen = 0 214 | 215 | start_time = time.time() 216 | 217 | # 创建迭代器 218 | data_iter = iter(dataloader) 219 | 220 | if global_rank == 0: 221 | print("\n=== Start Training ===") 222 | 223 | while step < total_steps: 224 | optimizer.zero_grad() 225 | loss_accum = 0.0 226 | 227 | # Gradient Accumulation Loop 228 | for _ in range(train_cfg["gradient_accumulation_steps"]): 229 | try: 230 | batch = next(data_iter) 231 | except StopIteration: 232 | # Epoch 结束,重新开始 233 | data_iter = iter(dataloader) 234 | batch = next(data_iter) 235 | 236 | loss_val = train_step(model, batch, train_cfg["gradient_accumulation_steps"]) 237 | loss_accum += loss_val 238 | 239 | # Optimizer Step 240 | # FSDP 会自动处理梯度同步 241 | optimizer.step() 242 | 243 | step += 1 244 | # 计算吞吐量 245 | current_tokens = dataloader.global_batch_size * train_cfg["seq_length"] 246 | tokens_seen += current_tokens 247 | 248 | # Logging 249 | if global_rank == 0 and step % train_cfg.get("log_interval", 10) == 0: 250 | elapsed = time.time() - start_time 251 | tokens_per_sec = tokens_seen / elapsed 252 | print(f"Step {step}/{total_steps} | Loss: {loss_accum:.4f} | TPS: {tokens_per_sec:.2f} tokens/s") 253 | 254 | if global_rank == 0: 255 | print("Training Finished!") 256 | 257 | dist.destroy_process_group() 258 | 259 | 260 | if __name__ == "__main__": 261 | main() 262 | -------------------------------------------------------------------------------- /parallel/communication/ring_attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | from torch.distributed.device_mesh import init_device_mesh 8 | 9 | 10 | def ring_attention_kernel(local_q, local_k, local_v, cp_group, tp_group): 11 | """ 12 | 实现 Ring Attention + Online Softmax 13 | local_q: [B, S_local, H_local, D] 14 | """ 15 | rank = dist.get_rank(group=cp_group) 16 | world_size = dist.get_world_size(group=cp_group) 17 | 18 | # 维度信息 19 | B, S_local, H_local, D = local_q.shape 20 | scale = 1.0 / math.sqrt(D) 21 | 22 | # === 初始化 Online Softmax 的统计量 === 23 | # max_score: 当前行的最大值 (用于数值稳定) 24 | # sum_exp: 当前行的分母 (exp的和) 25 | # out: 当前的分子 (加权和) 26 | local_max = torch.full((B, S_local, H_local, 1), float('-inf'), device=local_q.device) 27 | local_sum_exp = torch.zeros((B, S_local, H_local, 1), device=local_q.device) 28 | local_out = torch.zeros_like(local_q) 29 | 30 | # === 准备通信缓冲区 === 31 | # curr: 当前计算用的 KV 32 | # next: 接收下一个 Step 的 KV 33 | curr_k, curr_v = local_k.clone(), local_v.clone() 34 | next_k, next_v = torch.empty_like(local_k), torch.empty_like(local_v) 35 | 36 | # === Ring Loop === 37 | # 环状通信:Rank i -> Rank i+1, Rank i-1 -> Rank i 38 | # 也就是向右发,从左接 39 | right_rank = (rank + 1) % world_size 40 | left_rank = (rank - 1 + world_size) % world_size 41 | 42 | # 获取全局 Rank ID,因为 P2P 通信需要全局 ID 43 | global_right_rank = dist.get_global_rank(cp_group, right_rank) 44 | global_left_rank = dist.get_global_rank(cp_group, left_rank) 45 | 46 | for step in range(world_size): 47 | # === 1. 
启动异步通信 (使用 batch_isend_irecv) === 48 | ops = [] 49 | if step < world_size - 1: 50 | # 定义操作列表 51 | # 发送 K, V 给右边 52 | ops.append(dist.P2POp(dist.isend, curr_k, global_right_rank, cp_group)) 53 | ops.append(dist.P2POp(dist.isend, curr_v, global_right_rank, cp_group)) 54 | # 从左边接收 K, V 55 | ops.append(dist.P2POp(dist.irecv, next_k, global_left_rank, cp_group)) 56 | ops.append(dist.P2POp(dist.irecv, next_v, global_left_rank, cp_group)) 57 | # 关键修改:原子提交所有 P2P 请求 58 | reqs = dist.batch_isend_irecv(ops) 59 | else: 60 | reqs = [] 61 | 62 | # 2. 计算 Attention (Computation) 63 | # Q [B, S, H, D] @ K.T [B, H, D, S] -> Score [B, H, S, S] 64 | # 注意:这里的 K 是来自 Ring 的某一段 Sequence 65 | # 为了方便矩阵乘法,我们调整维度: [B, H, S, D] 66 | q_ = local_q.transpose(1, 2) 67 | k_ = curr_k.transpose(1, 2) 68 | v_ = curr_v.transpose(1, 2) 69 | 70 | # [B, H, S_local, D] @ [B, H, D, S_remote] -> [B, H, S_local, S_remote] 71 | attn_score = torch.matmul(q_, k_.transpose(-1, -2)) * scale 72 | 73 | # 3. Online Softmax 更新逻辑 (核心数学) 74 | # 这一步通常在 CUDA Kernel 内部完成,这里用 Python 模拟 75 | # 维度转回 [B, S, H, 1] 以便广播 76 | attn_score = attn_score.transpose(1, 2) # [B, S_local, H, S_remote] 77 | 78 | # 找当前块的最大值 79 | block_max = torch.max(attn_score, dim=-1, keepdim=True).values 80 | 81 | # 更新全局最大值 82 | new_max = torch.maximum(local_max, block_max) 83 | 84 | # 计算缩放因子 85 | scale_old = torch.exp(local_max - new_max) 86 | scale_block = torch.exp(block_max - new_max) 87 | 88 | # 计算当前块的 exp 89 | # P_block = exp(score - block_max) 90 | p_block = torch.exp(attn_score - block_max) 91 | 92 | # 更新分母 Sum_Exp 93 | # sum_new = sum_old * scale_old + sum_block * scale_block 94 | block_sum = torch.sum(p_block, dim=-1, keepdim=True) 95 | local_sum_exp = local_sum_exp * scale_old + block_sum * scale_block 96 | 97 | # 更新分子 Out 98 | # out_new = out_old * scale_old + (P_block @ V_block) * scale_block 99 | # 先算 P @ V 100 | # [B, H, S_local, S_remote] @ [B, H, S_remote, D] -> [B, H, S_local, D] 101 | p_v = torch.matmul(p_block.transpose(1, 2), v_) 102 | p_v = p_v.transpose(1, 2) # [B, S_local, H, D] 103 | 104 | local_out = local_out * scale_old + p_v * scale_block 105 | 106 | # 更新 Max 107 | local_max = new_max 108 | 109 | # 4. 等待通信完成 110 | for req in reqs: 111 | req.wait() 112 | 113 | # 5. 交换 Buffer 114 | curr_k = next_k.clone() 115 | curr_v = next_v.clone() 116 | 117 | # 最后除以分母 118 | final_out = local_out / local_sum_exp 119 | return final_out 120 | 121 | 122 | def run_demo(rank, world_size): 123 | # 初始化进程组 124 | os.environ['MASTER_ADDR'] = 'localhost' 125 | os.environ['MASTER_PORT'] = '12359' 126 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 127 | torch.cuda.set_device(rank) 128 | 129 | # --- A. 构建 Device Mesh (DP=1, PP=1, TP=2, CP=4) --- 130 | # 8 张卡: 131 | # Global: `[Batch, Seq, Head, Dim]` 132 | # Local: `[Batch, Seq/CP, Head/TP, Dim]` 133 | 134 | # TP 组: [0,1], [2,3], [4,5], [6,7] (负责切分 Head) 135 | # CP 组: [0,2,4,6], [1,3,5,7] (负责切分 Sequence) 136 | mesh = init_device_mesh("cuda", (1, 1, 2, 4), mesh_dim_names=("dp", "pp", "tp", "cp")) 137 | cp_group = mesh["cp"].get_group() 138 | tp_group = mesh["tp"].get_group() 139 | 140 | # --- B. 
数据模拟 --- 141 | # 假设全局参数 142 | B, Global_Seq, Global_Head, Dim = 2, 32, 8, 64 143 | 144 | # 生成全局数据 (仅用于验证对比,实际训练中不会有这个变量) 145 | if rank == 0: 146 | global_q = torch.randn(B, Global_Seq, Global_Head, Dim, device="cuda") 147 | global_k = torch.randn(B, Global_Seq, Global_Head, Dim, device="cuda") 148 | global_v = torch.randn(B, Global_Seq, Global_Head, Dim, device="cuda") 149 | else: 150 | global_q = torch.empty(B, Global_Seq, Global_Head, Dim, device="cuda") 151 | global_k = torch.empty(B, Global_Seq, Global_Head, Dim, device="cuda") 152 | global_v = torch.empty(B, Global_Seq, Global_Head, Dim, device="cuda") 153 | 154 | # 广播全局数据,保证大家用来切分的数据源是一致的 155 | dist.broadcast(global_q, src=0) 156 | dist.broadcast(global_k, src=0) 157 | dist.broadcast(global_v, src=0) 158 | 159 | # --- C. 数据切分 (Sharding) --- 160 | # 1. TP 切分 (Head 维度) 161 | tp_rank = dist.get_rank(tp_group) 162 | tp_size = dist.get_world_size(tp_group) # 2 163 | local_head_num = Global_Head // tp_size 164 | 165 | # 2. CP 切分 (Sequence 维度) 166 | cp_rank = dist.get_rank(cp_group) 167 | cp_size = dist.get_world_size(cp_group) # 4 168 | local_seq_len = Global_Seq // cp_size 169 | 170 | # 执行切分 171 | # 先切 Head (TP) 172 | temp_q = global_q.chunk(tp_size, dim=2)[tp_rank] 173 | temp_k = global_k.chunk(tp_size, dim=2)[tp_rank] 174 | temp_v = global_v.chunk(tp_size, dim=2)[tp_rank] 175 | 176 | # 再切 Seq (CP) 177 | local_q = temp_q.chunk(cp_size, dim=1)[cp_rank].clone() 178 | local_k = temp_k.chunk(cp_size, dim=1)[cp_rank].clone() 179 | local_v = temp_v.chunk(cp_size, dim=1)[cp_rank].clone() 180 | 181 | print(f"[Rank {rank}] Mesh Coord: TP={tp_rank}, CP={cp_rank} | " 182 | f"Local Shape: {local_q.shape} (Seq={local_seq_len}, Head={local_head_num})") 183 | 184 | # --- D. 运行 Ring Attention --- 185 | dist.barrier() 186 | start_event = torch.cuda.Event(enable_timing=True) 187 | end_event = torch.cuda.Event(enable_timing=True) 188 | 189 | start_event.record() 190 | # 调用我们手写的 Kernel 191 | local_out = ring_attention_kernel(local_q, local_k, local_v, cp_group, tp_group) 192 | end_event.record() 193 | torch.cuda.synchronize() 194 | 195 | if rank == 0: 196 | print(f"Ring Attention Time: {start_event.elapsed_time(end_event):.3f} ms") 197 | 198 | # --- E. 结果验证 (Gather vs Standard) --- 199 | 200 | # 1. 逆向还原 (Un-shard) 201 | # 先把 CP (Seq) 拼回去 202 | gathered_seq_out = [torch.zeros_like(local_out) for _ in range(cp_size)] 203 | dist.all_gather(gathered_seq_out, local_out, group=cp_group) 204 | # 此时我们有了 [Seq_Chunk0, Seq_Chunk1, ...] -> 拼成完整 Seq 205 | seq_restored = torch.cat(gathered_seq_out, dim=1) # [B, Global_Seq, Local_Head, D] 206 | 207 | # 再把 TP (Head) 拼回去 208 | gathered_head_out = [torch.zeros_like(seq_restored) for _ in range(tp_size)] 209 | dist.all_gather(gathered_head_out, seq_restored, group=tp_group) 210 | # 此时我们有了 [Head_Chunk0, Head_Chunk1] -> 拼成完整 Head 211 | final_restored = torch.cat(gathered_head_out, dim=2) # [B, Global_Seq, Global_Head, D] 212 | 213 | # 2. 运行标准 Attention (Reference) 214 | if rank == 0: 215 | # PyTorch 标准实现 216 | # 需要转置为 [B, H, S, D] 217 | ref_q = global_q.transpose(1, 2) 218 | ref_k = global_k.transpose(1, 2) 219 | ref_v = global_v.transpose(1, 2) 220 | 221 | ref_out = F.scaled_dot_product_attention(ref_q, ref_k, ref_v) 222 | ref_out = ref_out.transpose(1, 2) # 转回 [B, S, H, D] 223 | 224 | # 3. 
对比 225 | # 由于浮点数累加顺序不同,会有微小误差 (1e-5 级别) 226 | max_diff = (final_restored - ref_out).abs().max() 227 | print(f"\n=== Verification Result ===") 228 | print(f"Max Difference: {max_diff.item():.6f}") 229 | 230 | if max_diff < 1e-4: 231 | print("✅ SUCCESS: Ring Attention matches Standard Attention!") 232 | else: 233 | print("❌ FAILED: Difference is too large.") 234 | 235 | dist.destroy_process_group() 236 | 237 | if __name__ == "__main__": 238 | WORLD_SIZE = 8 239 | # 检查是否有 8 个 GPU,没有的话报错 240 | if torch.cuda.device_count() < WORLD_SIZE: 241 | print(f"Error: This script requires {WORLD_SIZE} GPUs, but found {torch.cuda.device_count()}.") 242 | else: 243 | mp.spawn(run_demo, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True) 244 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from config.config import ModelArgs 5 | from layers.layers import RMSNorm, apply_rotary_emb, precompute_freqs_cis 6 | from parallel.parallel_tp import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding 7 | from parallel.parallel_cp import ContextParallelAttention 8 | from parallel.parallel_ep import ExpertParallel 9 | 10 | 11 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 12 | """ 13 | 将 KV 头的数量复制 n_rep 倍,以匹配 Query 的头数。 14 | 15 | Args: 16 | x: 输入张量,形状为 (Batch, SeqLen, n_kv_heads, HeadDim) 17 | n_rep: 复制倍数 18 | 19 | Returns: 20 | 输出张量,形状为 (Batch, SeqLen, n_kv_heads * n_rep, HeadDim) 21 | """ 22 | bs, slen, n_kv_heads, head_dim = x.shape 23 | if n_rep == 1: 24 | return x 25 | 26 | # 核心逻辑: 27 | # 1. 增加一个维度: (B, S, n_kv, 1, D) 28 | # 2. 在新维度复制: (B, S, n_kv, n_rep, D) 29 | # 3. 展平维度: (B, S, n_kv * n_rep, D) 30 | return ( 31 | x[:, :, :, None, :] 32 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 33 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 34 | ) 35 | 36 | 37 | def get_linear_cls(args: ModelArgs, parallel_type: str = None): 38 | """ 39 | 根据配置返回合适的 Linear 类 40 | parallel_type: 'col' (列并行), 'row' (行并行), None (普通) 41 | """ 42 | if args.tp_size > 1: 43 | if parallel_type == 'col': 44 | return ColumnParallelLinear 45 | elif parallel_type == 'row': 46 | return RowParallelLinear 47 | return nn.Linear 48 | 49 | 50 | class MLP(nn.Module): 51 | def __init__(self, args: ModelArgs): 52 | super().__init__() 53 | hidden_dim = 4 * args.dim 54 | hidden_dim = int(2 * hidden_dim / 3) 55 | if args.ffn_dim_multiplier is not None: 56 | hidden_dim = int(args.ffn_dim_multiplier * hidden_dim) 57 | hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of) 58 | 59 | ColLinear = get_linear_cls(args, 'col') 60 | RowLinear = get_linear_cls(args, 'row') 61 | 62 | # Llama 结构: w1(Gate), w3(Up), w2(Down) 63 | self.w1 = ColLinear(args.dim, hidden_dim, bias=False) # Gate 64 | self.w3 = ColLinear(args.dim, hidden_dim, bias=False) # Up 65 | self.w2 = RowLinear(hidden_dim, args.dim, bias=False) # Down 66 | 67 | def forward(self, x): 68 | # SwiGLU: (xW1 * SiLU(xW3)) * W2 69 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 70 | 71 | 72 | class MoEFeedForward(nn.Module): 73 | """ 74 | Sparse MoE Layer (集成 Expert Parallel) 75 | """ 76 | def __init__(self, args: ModelArgs): 77 | super().__init__() 78 | self.num_experts = args.moe_num_experts 79 | self.topk = args.moe_topk 80 | # 1. 
Gating Network (Router) 81 | # Router 通常不做 TP,因为输出维度很小 (num_experts) 82 | self.gate = nn.Linear(args.dim, self.num_experts, bias=False) 83 | # 2. Experts 84 | # 在 EP 模式下,每个 GPU 只持有 total_experts / world_size 个专家 85 | # 假设 args.ep_size 已经设置正确 86 | # 这里简化处理:我们创建 num_experts 个 MLP,但在 forward 时 EP 模块会负责路由 87 | # 实际生产中,这里应该只初始化 local_experts 88 | self.experts = nn.ModuleList([MLP(args) for _ in range(self.num_experts)]) 89 | # 3. EP 通信模块 90 | self.ep = ExpertParallel(self.num_experts) 91 | def forward(self, x): 92 | # x: [B, S, D] 93 | B, S, D = x.shape 94 | x_flat = x.view(-1, D) 95 | # 1. Routing 96 | router_logits = self.gate(x_flat) # [N, num_experts] 97 | probs = F.softmax(router_logits, dim=-1) 98 | weights, indices = torch.topk(probs, self.topk, dim=-1) # [N, k] 99 | # 2. EP Dispatch (分发到对应 GPU) 100 | # dispatched_x: [Total_Recv, D] 101 | # metadata: 用于恢复顺序 102 | dispatched_x, metadata = self.ep.dispatch(x_flat, indices) 103 | # 3. Computation (计算本地专家) 104 | # 这里是一个简化实现:实际应该根据 metadata 知道哪些 token 属于哪个专家 105 | # 为了代码跑通,假设所有 token 平均分配给本地专家 (仅作演示逻辑) 106 | # 真正的 MoE 实现需要在这里做复杂的 index select 107 | # 假设 dispatched_x 已经包含了所有需要本地计算的 token 108 | # 我们简单地通过第一个专家计算 (生产环境需要 loop over local experts) 109 | expert_output = self.experts[0](dispatched_x) 110 | # 4. EP Combine (聚合回原 GPU) 111 | combined_output = self.ep.combine(expert_output, metadata) 112 | # 5. 加权求和 (Weighted Sum) 113 | # combined_output: [N, k, D] 114 | # weights: [N, k] 115 | output = (combined_output * weights.unsqueeze(-1)).sum(dim=1) 116 | return output.view(B, S, D) 117 | 118 | 119 | class Attention(nn.Module): 120 | def __init__(self, args: ModelArgs): 121 | super().__init__() 122 | 123 | self.tp_size = args.tp_size if hasattr(args, 'tp_size') else 1 124 | 125 | assert args.n_heads % self.tp_size == 0, \ 126 | f"Tensor Parallelism requires n_heads ({args.n_heads}) " \ 127 | f"to be divisible by tp_size ({self.tp_size})." 128 | 129 | # 如果配置中没写 n_kv_heads,默认等于 n_heads (即退化为标准 MHA) 130 | self.n_heads = args.n_heads // self.tp_size 131 | 132 | global_n_kv_heads = args.n_kv_heads if args.n_kv_heads is not None else args.n_heads 133 | self.n_kv_heads = global_n_kv_heads // self.tp_size 134 | 135 | self.head_dim = args.dim // args.n_heads 136 | 137 | # 计算复制倍数,例如 32 / 8 = 4 138 | self.n_rep = self.n_heads // self.n_kv_heads 139 | 140 | ColLinear = get_linear_cls(args, 'col') 141 | RowLinear = get_linear_cls(args, 'row') 142 | 143 | # TP 模式下,这里的 dim 会被切分,所以传入 total dim 即可,Layer 内部会处理 144 | self.wq = ColLinear(args.dim, args.n_heads * self.head_dim, bias=False) 145 | # self.wk = ColLinear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 146 | # self.wv = ColLinear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 147 | self.wk = ColLinear(args.dim, global_n_kv_heads * self.head_dim, bias=False) 148 | self.wv = ColLinear(args.dim, global_n_kv_heads * self.head_dim, bias=False) 149 | self.wo = RowLinear(args.n_heads * self.head_dim, args.dim, bias=False) 150 | 151 | def forward(self, x, freqs_cis): 152 | B, S, _ = x.shape 153 | 154 | # 1. 投影 155 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 156 | 157 | # 打印 Linear 输出后的形状 158 | # print(f"[Debug] Linear Out - xq: {xq.shape}, xk: {xk.shape}") 159 | 160 | # 2. 
Reshape 161 | # xq = xq.view(B, S, -1, self.head_dim) 162 | # xk = xk.view(B, S, -1, self.head_dim) 163 | # xv = xv.view(B, S, -1, self.head_dim) 164 | 165 | xq = xq.view(B, S, self.n_heads, -1) 166 | xk = xk.view(B, S, self.n_heads, -1) 167 | xv = xv.view(B, S, self.n_heads, -1) 168 | 169 | # === 添加调试日志 === 170 | # import torch.distributed as dist 171 | # if dist.get_rank() == 0: # 只让 Rank 0 打印,避免刷屏 172 | # print(f"\n[Debug Rank 0] Attention Forward:") 173 | # print(f" Input x: {x.shape}") 174 | # print(f" self.n_heads: {self.n_heads}, self.head_dim: {self.head_dim}") 175 | # print(f" xq (reshaped): {xq.shape}") 176 | # print(f" xk (reshaped): {xk.shape}") 177 | # print(f" freqs_cis (input): {freqs_cis.shape}") 178 | # =================== 179 | 180 | # 3. Apply RoPE (旋转位置编码) 181 | # 注意:RoPE 是在 Attention 计算之前做的,且要在 repeat_kv 之前做 182 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis) 183 | 184 | # 4. GQA 核心步骤:如果 KV 头数少,就复制 185 | # 变换后 xk, xv 的形状将变为 [B, S, n_heads, D] 186 | if self.n_rep > 1: 187 | xk = repeat_kv(xk, self.n_rep) 188 | xv = repeat_kv(xv, self.n_rep) 189 | 190 | # 5. Flash Attention 191 | # 此时 xq, xk, xv 的维度完全对齐了 192 | # 需要转置为 [B, n_heads, S, D] 以符合 PyTorch API 要求 193 | output = F.scaled_dot_product_attention( 194 | xq.transpose(1, 2), 195 | xk.transpose(1, 2), 196 | xv.transpose(1, 2), 197 | is_causal=True 198 | ) 199 | 200 | # 6. 还原形状并输出 201 | output = output.transpose(1, 2).contiguous().view(B, S, -1) 202 | return self.wo(output) 203 | 204 | 205 | class TransformerBlock(nn.Module): 206 | def __init__(self, args: ModelArgs, layer_id: int): 207 | super().__init__() 208 | 209 | # 1. Attention (集成 CP) 210 | self.attention = Attention(args) 211 | if args.cp_size > 1: 212 | # 如果开启 CP,包裹 Attention 213 | # 这与 TP 是正交的:TP 切分权重,CP 切分 Sequence 214 | self.attention = ContextParallelAttention(self.attention, args) 215 | 216 | # 2. FeedForward (集成 MoE) 217 | # 策略:每 moe_layer_freq 层替换一个 MoE 218 | # 例如 freq=2: Layer 0 (Dense), Layer 1 (MoE), Layer 2 (Dense)... 
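        # Concretely, with the condition below (layer_id % moe_layer_freq == 1):
        #   moe_layer_freq = 2 -> layers 1, 3, 5, ... use MoEFeedForward; layers 0, 2, 4, ... stay dense
        #   moe_layer_freq = 3 -> layers 1, 4, 7, ... use MoEFeedForward
        # (moe_num_experts <= 1 keeps every layer dense; the frequency values above are illustrative.)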
219 | use_moe = (args.moe_num_experts > 1) and (layer_id % args.moe_layer_freq == 1) 220 | if use_moe: 221 | self.feed_forward = MoEFeedForward(args) 222 | else: 223 | self.feed_forward = MLP(args) 224 | 225 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 226 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 227 | 228 | def forward(self, x, freqs_cis): 229 | h = x + self.attention(self.attention_norm(x), freqs_cis) 230 | out = h + self.feed_forward(self.ffn_norm(h)) 231 | return out 232 | 233 | 234 | class LightronTransformer(nn.Module): 235 | def __init__(self, args: ModelArgs): 236 | super().__init__() 237 | self.args = args 238 | 239 | # TP Embedding 240 | if args.tp_size > 1: 241 | self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim) 242 | # Output 层通常也是 Column Parallel (Gather Output) 243 | self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, gather_output=True) 244 | else: 245 | self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) 246 | self.output = nn.Linear(args.dim, args.vocab_size, bias=False) 247 | 248 | # 传入 layer_id 以决定是否使用 MoE 249 | self.layers = nn.ModuleList([ 250 | TransformerBlock(args, layer_id=i) for i in range(args.n_layers) 251 | ]) 252 | self.norm = RMSNorm(args.dim, eps=args.norm_eps) 253 | 254 | # Precompute RoPE frequencies 注意:这里只计算一次,并在 forward 中根据当前 seq_len 切片 255 | freqs_cis = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len) 256 | self.register_buffer("freqs_cis", freqs_cis, persistent=False) 257 | # self.freqs_cis = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len) 258 | 259 | def forward(self, tokens): 260 | B, S = tokens.shape 261 | h = self.tok_embeddings(tokens) 262 | 263 | # 确保 freqs_cis 在同一设备 264 | # freqs_cis = self.freqs_cis[:S].to(h.device) 265 | freqs_cis = self.freqs_cis 266 | 267 | for layer in self.layers: 268 | h = layer(h, freqs_cis) 269 | 270 | h = self.norm(h) 271 | return self.output(h) 272 | --------------------------------------------------------------------------------
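The online-softmax recurrence at the heart of ring_attention_kernel can be sanity-checked on a single device, without any process groups: stream K/V in blocks through the same running-max / running-sum updates and compare the result against F.scaled_dot_product_attention. The sketch below is illustrative only (block count, tensor shapes, and the absence of a causal mask are arbitrary choices that mirror the non-causal demo above); it is not a file in the repository.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, S, H, D = 2, 32, 4, 16          # batch, sequence, heads, head_dim (illustrative sizes)
num_blocks = 4                      # stands in for the number of ring steps (cp_size)
scale = 1.0 / (D ** 0.5)

q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
v = torch.randn(B, S, H, D)

# Running statistics, laid out as [B, S, H, 1] exactly as in ring_attention_kernel
running_max = torch.full((B, S, H, 1), float("-inf"))
running_sum = torch.zeros(B, S, H, 1)
acc_out = torch.zeros_like(q)

q_t = q.transpose(1, 2)                                          # [B, H, S, D]

for k_blk, v_blk in zip(k.chunk(num_blocks, dim=1), v.chunk(num_blocks, dim=1)):
    k_t = k_blk.transpose(1, 2)                                  # [B, H, S_blk, D]
    v_t = v_blk.transpose(1, 2)

    score = torch.matmul(q_t, k_t.transpose(-1, -2)) * scale     # [B, H, S, S_blk]
    score = score.transpose(1, 2)                                 # [B, S, H, S_blk]

    block_max = score.max(dim=-1, keepdim=True).values
    new_max = torch.maximum(running_max, block_max)
    rescale_old = torch.exp(running_max - new_max)               # rescale previous numerator/denominator
    rescale_blk = torch.exp(block_max - new_max)                 # rescale the current block

    p_block = torch.exp(score - block_max)                       # un-normalised probabilities for this block
    running_sum = running_sum * rescale_old + p_block.sum(dim=-1, keepdim=True) * rescale_blk

    pv = torch.matmul(p_block.transpose(1, 2), v_t).transpose(1, 2)   # [B, S, H, D]
    acc_out = acc_out * rescale_old + pv * rescale_blk
    running_max = new_max

online_out = acc_out / running_sum                               # final normalisation, as in the kernel

ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

print("max diff:", (online_out - ref).abs().max().item())

In float32 at this scale the reported difference is typically well below the 1e-4 threshold that run_demo checks against.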