├── README.md
├── data_parallel.py
├── data_parallel_my.py
└── data_parallel_my_v2.py

/README.md:
--------------------------------------------------------------------------------
# Balanced-DataParallel

This is an improved version of PyTorch's `DataParallel` that balances the memory usage of the first GPU.

The code comes from Transformer-XL: https://github.com/kimiyoung/transformer-xl

I did not write the code myself, but it works well, so I am sharing it here.

# How to use

The `BalancedDataParallel` class is used much like `DataParallel`. Here is an example:

```
my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()
```

It takes three arguments. The first is the batch size to allocate to the first GPU (GPU 0); note that if you use gradient accumulation, the value passed here is the actual per-step batch size. For example, suppose you train on 3 GPUs and each GPU can hold at most 3 samples, but GPU 0 also has to gather and combine the outputs, so it can only hold 2 samples. The total batch size you can run is then 2 + 3 + 3 = 8, and you would set the parameters like this:

```
batch_size = 8
gpu0_bsz = 2
acc_grad = 1
my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()
```

What if you now want to run a batch size of 16? That splits as 4 + 6 + 6 = 16, so simply set the gradient accumulation steps to 2:

```
batch_size = 16
gpu0_bsz = 4
acc_grad = 2
my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()
```

### Versions of data_parallel

- data_parallel.py: the original author's code. In practice, if the batch size is smaller than the number of GPUs, the last batch of an epoch may not contain enough samples to be split across all GPUs, and the code raises an error.

- data_parallel_my.py: my slightly modified version. After some light testing, it appears to fix the problem above.

- data_parallel_my_v2.py: the fix in the first modified version made it impossible to set gpu0_bsz=0; this version should repair that.

--------------------------------------------------------------------------------
/data_parallel.py:
--------------------------------------------------------------------------------

from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply

def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs

class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)

        # print('bsz: ', bsz)
        # print('num_dev: ', num_dev)
        # print('gpu0_bsz: ', gpu0_bsz)
        # print('bsz_unit: ', bsz_unit)
        # print('chunk_sizes: ', chunk_sizes)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

--------------------------------------------------------------------------------
/data_parallel_my.py:
--------------------------------------------------------------------------------

from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply

def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | # print('len(inputs)1: ', str(len(inputs))) 66 | # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)])) 67 | if len(self.device_ids) == 1: 68 | return self.module(*inputs[0], **kwargs[0]) 69 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 70 | if self.gpu0_bsz == 0: 71 | replicas = replicas[1:] 72 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 73 | return self.gather(outputs, self.output_device) 74 | 75 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 76 | return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)]) 77 | 78 | def scatter(self, inputs, kwargs, device_ids): 79 | bsz = inputs[0].size(self.dim) 80 | num_dev = len(self.device_ids) 81 | gpu0_bsz = self.gpu0_bsz 82 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 83 | if gpu0_bsz < bsz_unit: 84 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 85 | delta = bsz - sum(chunk_sizes) 86 | for i in range(delta): 87 | chunk_sizes[i + 1] += 1 88 | if gpu0_bsz == 0: 89 | chunk_sizes = chunk_sizes[1:] 90 | else: 91 | return super().scatter(inputs, kwargs, device_ids) 92 | 93 | # print('bsz: ', bsz) 94 | # print('num_dev: ', num_dev) 95 | # print('gpu0_bsz: ', gpu0_bsz) 96 | # print('bsz_unit: ', bsz_unit) 97 | # print('chunk_sizes: ', chunk_sizes) 98 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 99 | 100 | 
--------------------------------------------------------------------------------
/data_parallel_my_v2.py:
--------------------------------------------------------------------------------

from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply

def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs

class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)

        print('len(inputs): ', str(len(inputs)))
        print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))

        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        if self.gpu0_bsz == 0:
            replicas = self.replicate(self.module, self.device_ids)
        else:
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])

        # replicas = self.replicate(self.module, device_ids[:len(inputs)])
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]

        print('replicas:', str(len(replicas)))

        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs,
                              device_ids[:len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)

        print('bsz: ', bsz)
        print('num_dev: ', num_dev)
        print('gpu0_bsz: ', gpu0_bsz)
        print('bsz_unit: ', bsz_unit)
        print('chunk_sizes: ', chunk_sizes)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

--------------------------------------------------------------------------------
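For completeness, here is a minimal training-loop sketch showing how the `gpu0_bsz // acc_grad` argument from the README combines with gradient accumulation. It assumes at least two visible GPUs; `nn.Linear`, the random tensors, and the learning rate are placeholders standing in for a real model, data loader, and optimizer settings.

```
import torch
import torch.nn as nn
from data_parallel_my_v2 import BalancedDataParallel

batch_size = 16   # effective batch size after gradient accumulation
gpu0_bsz = 4      # GPU 0's share of that effective batch
acc_grad = 2      # accumulation steps: each forward pass sees 8 samples, 2 of them on GPU 0

model = BalancedDataParallel(gpu0_bsz // acc_grad, nn.Linear(32, 4), dim=0).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

for step in range(acc_grad):
    x = torch.randn(batch_size // acc_grad, 32).cuda()        # one micro-batch of 8 samples
    y = torch.randint(0, 4, (batch_size // acc_grad,)).cuda()
    loss = criterion(model(x), y) / acc_grad                   # scale so accumulated grads average
    loss.backward()
optimizer.step()
optimizer.zero_grad()
```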