├── .gitignore
├── README.md
├── model.py
├── stream.py
├── LICENSE
├── tp.py
├── pp.py
└── dp.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
.vscode/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# parallel-demo
Minimal implementations of DP/TP/PP built on the `torch.distributed` package.

Run any of them with `torchrun --nproc_per_node=2 (dp|tp|pp).py`.
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class Net(nn.Module):
    """A two-layer MLP with He-style initialization, shared by all three demos."""

    def __init__(self, in_dim, out_dim, hid_dim):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(in_dim, hid_dim) * (2 / in_dim) ** 0.5)
        self.w2 = nn.Parameter(torch.randn(hid_dim, out_dim) * (2 / hid_dim) ** 0.5)

    def forward(self, x: torch.Tensor):
        return (x @ self.w1).relu() @ self.w2
--------------------------------------------------------------------------------
/stream.py:
--------------------------------------------------------------------------------
import torch

size = 1000000
one = torch.ones(size, device="cuda")

# On a single (default) stream, kernels execute in issue order,
# so y.add_(x) always sees the updated x.
x = torch.zeros(size, device="cuda")
y = torch.zeros(size, device="cuda")
x.add_(one)
y.add_(x)
print("x", x)
print("y", y)

# With a side stream and no synchronization, the kernels launched on s
# can run concurrently with x.add_(one) on the default stream, so
# y.add_(x) may read a stale (or partially updated) x.
s = torch.cuda.Stream()
x = torch.zeros(size, device="cuda")
y = torch.zeros(size, device="cuda")
for _ in range(10000):
    x.add_(one)
    with torch.cuda.stream(s):
        y.add_(x)
print("x", x)
print("y", y)
--------------------------------------------------------------------------------
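The second loop in stream.py imposes no ordering between the default stream and s, which is exactly the race it demonstrates. A minimal sketch of a synchronized variant (not part of the repo), using wait_stream to order the two streams:

import torch

size = 1000000
one = torch.ones(size, device="cuda")
s = torch.cuda.Stream()

x = torch.zeros(size, device="cuda")
y = torch.zeros(size, device="cuda")
for _ in range(10000):
    x.add_(one)
    # Make stream s wait for the work already queued on the current stream,
    # so y.add_(x) observes the completed x.add_(one).
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y.add_(x)
# Conversely, make the current stream wait for s before reading y (or reusing x).
torch.cuda.current_stream().wait_stream(s)
print("x", x)
print("y", y)

With both wait_stream calls in place, y ends up equal to the running sum of x, independent of kernel timing.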
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Xingkai Yu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/tp.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.distributed as dist
from model import Net


class LinearWithAsyncComm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        output = input @ weight
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output @ weight.t()
        # Launch the all-reduce of grad_input asynchronously so it overlaps
        # with the grad_weight matmul below.
        handle = dist.all_reduce(grad_input, async_op=True)
        # input and output may have more than two dimensions, but weight is always 2-D
        grad_weight = input.reshape(-1, weight.size(0)).t() @ grad_output.reshape(-1, weight.size(1))
        handle.wait()
        return grad_input, grad_weight


class AllReduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        dist.all_reduce(input)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class TPNet(Net):
    def forward(self, x):
        # Column-parallel first layer: each rank holds a slice of w1 along hid_dim.
        x = LinearWithAsyncComm.apply(x, self.w1)
        # Row-parallel second layer: each rank produces a partial output,
        # which the all-reduce sums across ranks.
        x = x.relu() @ self.w2
        x = AllReduce.apply(x)
        return x

    @classmethod
    def new_tp(cls, module: Net):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        device = next(module.parameters()).device
        in_dim, out_dim, hid_dim = module.w1.size(0), module.w2.size(1), module.w1.size(1)
        tp_module = cls(in_dim, out_dim, hid_dim // world_size).to(device)
        # Shard w1 column-wise and w2 row-wise so that each rank owns one
        # slice of the hidden dimension.
        tp_module.w1.data.copy_(module.w1.data.chunk(world_size, dim=1)[rank])
        tp_module.w2.data.copy_(module.w2.data.chunk(world_size, dim=0)[rank])
        return tp_module


if __name__ == '__main__':
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
    torch.manual_seed(666)
    torch.cuda.manual_seed_all(666)

    # Reference: full model on every rank (same seed, so identical weights).
    net = Net(64, 10, 128).cuda()
    X = torch.randn(32, 64, device="cuda")
    Y = net(X)
    Y.mean().backward()
    print(Y[:, -1])
    # print(net.w1.grad)
    net.zero_grad()

    # Tensor-parallel version: its output should match the reference.
    net = TPNet.new_tp(net)
    Y = net(X)
    Y.mean().backward()
    print(Y[:, -1])
    # print(net.w1.grad)
--------------------------------------------------------------------------------
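tp.py shards the hidden dimension: w1 is split by columns, w2 by rows, and the final all-reduce sums the per-rank partial outputs. Because relu is applied elementwise, slicing relu(X @ w1) along the hidden dimension commutes with the activation, so the sharded network reproduces the original one exactly. A small single-process check of that identity (the shapes and names below are illustrative only, not part of the repo):

import torch

torch.manual_seed(0)
world_size = 2
X = torch.randn(4, 8)
w1 = torch.randn(8, 6)
w2 = torch.randn(6, 3)

# Full (non-parallel) forward pass.
full = (X @ w1).relu() @ w2

# Emulate the TP ranks: column shards of w1, row shards of w2,
# then sum the partial outputs (the role of the final all-reduce).
partials = []
for rank in range(world_size):
    w1_shard = w1.chunk(world_size, dim=1)[rank]
    w2_shard = w2.chunk(world_size, dim=0)[rank]
    partials.append((X @ w1_shard).relu() @ w2_shard)
tp = sum(partials)

print(torch.allclose(full, tp, atol=1e-6))  # True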
/pp.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.distributed as dist
from model import Net


class Send(torch.autograd.Function):
    """Forward: send the activation to the next stage. Backward: receive its gradient."""

    @staticmethod
    def forward(ctx, input, rank):
        ctx.rank = rank
        dist.send(input, rank)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        rank = ctx.rank
        dist.recv(grad_output, rank)
        return grad_output, None


class Recv(torch.autograd.Function):
    """Forward: receive the activation from the previous stage. Backward: send its gradient back."""

    @staticmethod
    def forward(ctx, input, rank):
        ctx.rank = rank
        dist.recv(input, rank)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        rank = ctx.rank
        dist.send(grad_output, rank)
        return grad_output, None


class Pipe(torch.nn.Module):
    def __init__(self, module: torch.nn.Sequential, shape, chunks=1):
        super().__init__()
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.is_first = self.rank == 0
        self.is_last = self.rank == self.world_size - 1
        # Each rank keeps a contiguous slice of the layer stack.
        size = len(module) // self.world_size
        offset = size * self.rank
        self.module = module[offset:offset+size]
        self.chunks = chunks
        # Shape of one micro-batch arriving from the previous stage.
        shape = list(shape)
        shape[0] //= chunks
        self.shape = shape

    def forward(self, x: torch.Tensor):
        ys = []
        xs = x.chunk(self.chunks)
        for x in xs:
            if not self.is_first:
                # Non-first stages ignore the local input and receive the
                # previous stage's activation instead.
                x = x.new_empty(self.shape).requires_grad_()
                x = Recv.apply(x, self.rank - 1)
            y = self.module(x)
            if not self.is_last:
                y = Send.apply(y, self.rank + 1)
            ys.append(y)
        return torch.cat(ys)


if __name__ == '__main__':
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
    torch.manual_seed(666)
    torch.cuda.manual_seed_all(666)

    num_layers, in_dim, out_dim, hid_dim, inter_dim = 8, 64, 10, 128, 256
    bs, chunks = 32, 8
    layers = []
    layers.append(Net(in_dim, hid_dim, inter_dim))
    for _ in range(num_layers - 2):
        layers.append(Net(hid_dim, hid_dim, inter_dim))
    layers.append(Net(hid_dim, out_dim, inter_dim))

    # Reference: run the full stack of layers on every rank.
    net = torch.nn.Sequential(*layers).cuda()
    X = torch.randn(bs, in_dim, device="cuda")
    Y = net(X)
    Y.mean().backward()
    print(Y[:, -1])
    # print(net[0].w1.grad)
    net.zero_grad()

    # Pipeline-parallel version: only the last stage holds the real output and
    # computes the loss; earlier stages drive backward with a placeholder
    # gradient, since their Send nodes receive the real one via dist.recv.
    net = Pipe(net, (bs, hid_dim), chunks).cuda()
    Y = net(X)
    if net.is_last:
        Y.mean().backward()
    else:
        Y.backward(torch.empty_like(Y))
    if net.is_last:
        print(Y[:, -1])
    # if net.is_first:
    #     print(net.module[0].w1.grad)
--------------------------------------------------------------------------------
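Send and Recv wrap the blocking point-to-point primitives dist.send and dist.recv, which is all the pipeline needs to pass activations forward and gradients backward between adjacent stages. A minimal two-rank sketch of those primitives (the file name p2p_demo.py is hypothetical, not part of the repo):

# save as p2p_demo.py and run with: torchrun --nproc_per_node=2 p2p_demo.py
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
rank = dist.get_rank()

t = torch.zeros(4, device="cuda")
if rank == 0:
    t += 42
    dist.send(t, 1)   # blocking send to rank 1
else:
    dist.recv(t, 0)   # blocking receive from rank 0, written into t in place
    print("rank 1 received", t)

Both calls block until the matching peer arrives, which is why Pipe's per-micro-batch loop naturally staggers the stages: rank 1 starts computing micro-batch 0 while rank 0 moves on to micro-batch 1.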
/dp.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.distributed as dist
from torch import Tensor
from model import Net


class DDP(torch.nn.Module):
    MIN_BUCKET_SIZE = 1024 * 1024

    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module
        self.buckets: list[Tensor] = []  # flat buffers that hold the grads and are all-reduced as a unit
        self.comm_stream = torch.cuda.Stream()  # side stream so communication overlaps with backward

        num_params = len(list(module.parameters()))
        bucket_params: list[Tensor] = []  # parameters assigned to the bucket currently being filled
        bucket_size = 0

        # Walk the parameters in reverse so bucket order roughly matches the
        # order in which gradients become ready during backward.
        for idx, param in enumerate(reversed(list(module.parameters()))):
            if not param.requires_grad:
                continue
            bucket_size += param.numel()
            bucket_params.append(param)
            if bucket_size < DDP.MIN_BUCKET_SIZE and idx + 1 < num_params:
                continue
            # The bucket is full, or this is the last parameter.
            bucket = bucket_params[0].new_zeros(bucket_size)
            bucket.ready = False
            offset = 0
            for p in bucket_params:
                # Point each param's .grad at a slice of the bucket, so backward
                # accumulates directly into the flat buffer.
                p.grad = bucket[offset:offset+p.numel()].view_as(p)
                offset += p.numel()
                p.register_post_accumulate_grad_hook(self.make_hook(p, bucket, bucket_params))
                p.ready = False
            self.buckets.append(bucket)
            bucket_params = []
            bucket_size = 0

    def make_hook(self, param: Tensor, bucket: Tensor, bucket_params):
        def hook(*args):
            param.ready = True
            if all(p.ready for p in bucket_params):
                # All grads of this bucket have been accumulated: all-reduce the
                # whole bucket on the communication stream, overlapping with the
                # rest of backward on the current stream.
                self.comm_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.comm_stream):
                    dist.all_reduce(bucket, dist.ReduceOp.AVG)
                bucket.ready = True

                if all(b.ready for b in self.buckets):
                    # Last bucket: make subsequent work (e.g. the optimizer step)
                    # wait for the communication stream, then reset the flags.
                    torch.cuda.current_stream().wait_stream(self.comm_stream)
                    for p in self.module.parameters():
                        if p.requires_grad:
                            p.ready = False
                    for b in self.buckets:
                        b.ready = False
        return hook

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


if __name__ == '__main__':
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
    torch.manual_seed(666)
    torch.cuda.manual_seed_all(666)

    # Reference: full batch on a single model (same seed, so identical weights on every rank).
    net = Net(64, 10, 128).cuda()
    X = torch.randn(32, 64, device="cuda")
    Y = net(X)
    Y.mean().backward()
    print(Y[:, -1])
    # print(net.w1.grad)
    net.zero_grad()

    # Data-parallel version: each rank processes its shard of the batch, and the
    # averaged gradients should match the single-process reference.
    net = DDP(net)
    X = X.chunk(dist.get_world_size())[dist.get_rank()]
    Y = net(X)
    Y.mean().backward()
    print(Y[:, -1])
    # print(net.module.w1.grad)
--------------------------------------------------------------------------------
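Because the hooks average the buckets during backward, a training loop with this wrapper looks like ordinary single-GPU code. A minimal sketch of such a loop (the optimizer, the fake data, and the `from dp import DDP` import path are assumptions, not part of the repo); note that gradients must not be set to None, or the bucket views would be lost:

import os
import torch
import torch.distributed as dist
from model import Net
from dp import DDP  # assumes dp.py is importable from the working directory

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))

torch.manual_seed(666)  # same seed on every rank: initial weights match (the wrapper does not broadcast them)
model = DDP(Net(64, 10, 128).cuda())
opt = torch.optim.SGD(model.parameters(), lr=0.1)

torch.manual_seed(dist.get_rank())  # per-rank data; a real loader would shard the dataset instead
for step in range(10):
    x = torch.randn(8, 64, device="cuda")
    loss = model(x).mean()
    loss.backward()                   # hooks all-reduce each bucket as its grads become ready
    opt.step()                        # .grad tensors are views into the buckets, so this uses the averaged grads
    opt.zero_grad(set_to_none=False)  # keep the views; setting grads to None would detach them from the buckets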