├── LICENSE
├── ModelParallel.py
├── README.md
└── test.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Kai Tian

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/ModelParallel.py:
--------------------------------------------------------------------------------
import threading

import torch
import torch.nn as nn


def distribute_module(module, device):
    # Move one sub-network onto its assigned GPU.
    return module.cuda(device)


def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    # Run each module on its own input and device, one thread per module.
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(kwargs_tup) == len(modules)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        raise ValueError('devices must not be None')

    lock = threading.Lock()
    results = {}
    # grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        # torch.set_grad_enabled(grad_enabled)
        try:
            with torch.cuda.device(device):
                output = module(input, **kwargs)  # kwargs is empty unless provided by the caller
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    # Collect the results in order; re-raise the first exception hit by a worker.
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


class ModelParallel(nn.Module):

    def __init__(self, model, device_ids=None, output_device=None):
        super(ModelParallel, self).__init__()

        if not torch.cuda.is_available():
            # No GPU available: keep the whole model on CPU and run it as-is.
            self.module = model
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if not hasattr(model, 'module'):
            raise ValueError("model does not have a 'module' attribute "
                             "(expected an nn.ModuleList of sub-networks)")
        if len(device_ids) < len(model.module):
            print('warning: fewer GPUs than sub-networks; '
                  'model parallel needs one device per network')
        else:
            device_ids = device_ids[:len(model.module)]

        if output_device is None:
            output_device = device_ids[0]
        self.output_device = output_device
        self.device_ids = device_ids
        self.module = model.module  # an nn.ModuleList holding the sub-networks
        self.distribute(self.module, device_ids)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)

        if len(self.device_ids) == 1:
            # Only one device available: run the sub-network that was placed on it.
            return [self.module[0](inputs[0])]

        outputs = self.parallel_apply(self.module, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def distribute(self, module, device_ids):
        return [distribute_module(m, id) for m, id in zip(module, device_ids)]

    def scatter(self, inputs, kwargs, device_ids):
        # Copy a single input to every device, or pair inputs with devices one-to-one.
        if len(inputs) == 1:
            inputs = [inputs[0].cuda(id) for id in device_ids]
        else:
            inputs = [input.cuda(id) for input, id in zip(inputs, device_ids)]
        kwargs = None  # per-network kwargs are not supported yet (see the ToDo in the README)
        inputs = tuple(inputs)
        return inputs, kwargs

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids)

    def gather(self, outputs, output_device):
        # Move every network's output to the output device and keep them as a list.
        outputs = [output.cuda(output_device) for output in outputs]
        return outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Model Parallelism
Model parallelism for PyTorch: train multiple networks on multiple GPUs simultaneously.

## ToDo List
- [ ] Handle different kwargs for different networks

## Usage
`ModelParallel` is a wrapper for training multiple networks on multiple GPUs at the same time, for example ensemble models or multiple choice learning networks.

Unlike data parallel, model parallel returns its outputs as a list, one entry per network, to keep it general purpose.
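
Because the wrapper hands back a plain Python list, with one output per network already moved to `output_device`, a training step simply iterates over that list. Below is a minimal sketch of one such step; `model` and `x` refer to the `ModelParallel` instance and input batch built in the complete example that follows, while the criterion, optimizer, and `target` labels are illustrative placeholders rather than part of this repository:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# `model` (ModelParallel, output_device=0) and `x` (a batch of 128 images) come from
# the complete example below; criterion, optimizer, and target are illustrative only.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)     # covers the parameters of all sub-networks

target = Variable(torch.zeros(128).long().cuda(0))    # dummy labels on the output device

outputs = model(Variable(x))                          # list: one output per sub-network, on GPU 0
loss = sum(criterion(out, target) for out in outputs) # e.g. train every member on the same labels

optimizer.zero_grad()
loss.backward()                                       # gradients flow back to each network's own GPU
optimizer.step()
```

Every `criterion(out, target)` term lives on the output device, so a single `backward()` call is enough; autograd routes the gradients back across the device copies made in `gather`, and each sub-network's parameters are updated on its own GPU.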

A complete, end-to-end example:

```python
# First define an ensemble module
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable

from ModelParallel import ModelParallel


class Ensemble(nn.Module):
    def __init__(self, m):
        super(Ensemble, self).__init__()
        self.m = m
        self.module = nn.ModuleList([models.resnet50() for _ in range(m)])

    def forward(self, input):
        return [self.module[i](input) for i in range(self.m)]


model = Ensemble(4)
model = ModelParallel(model, device_ids=[0, 1, 2, 3], output_device=0)

x = torch.rand(128, 3, 224, 224)
y = model(Variable(x))  # y is a list with one output per network
```

## Useful links
The multithreading code is adapted from [PyTorch's data parallel implementation](https://github.com/pytorch/pytorch/blob/v0.3.1/torch/nn/parallel/parallel_apply.py).
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import time

import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable

from ModelParallel import ModelParallel


class SmallNet(nn.Module):
    def __init__(self):
        super(SmallNet, self).__init__()

        self.linear = nn.Sequential(nn.Linear(30, 1000),
                                    nn.Sigmoid(),
                                    nn.Linear(1000, 3000),
                                    nn.Sigmoid(),
                                    nn.Linear(3000, 2))

    def forward(self, input):
        return self.linear(input)


class Ensemble(nn.Module):
    def __init__(self, m, mode='small'):
        super(Ensemble, self).__init__()
        if mode == 'small':
            self.module = nn.ModuleList([SmallNet() for i in range(m)])
        elif mode == 'large':
            self.module = nn.ModuleList([models.resnet50() for i in range(m)])

    def forward(self, input):
        return [module(input) for module in self.module]


def synchronize_all(n_devices):
    # CUDA launches are asynchronous: wait for every GPU before stopping the clock.
    for d in range(n_devices):
        with torch.cuda.device(d):
            torch.cuda.synchronize()


def test_model_parallel(mode='small'):
    ensemble = Ensemble(4, mode)

    model = ModelParallel(ensemble, device_ids=[0, 1, 2, 3], output_device=0)
    if mode == 'small':
        input = Variable(torch.rand(512, 30))
    elif mode == 'large':
        input = Variable(torch.rand(128, 3, 224, 224))

    end = time.time()
    y = model(input)
    synchronize_all(4)
    print('using model parallel')
    print('time: ', time.time() - end)


def test_without_parallel(mode='small'):
    ensemble = Ensemble(4, mode)
    [ensemble.module[i].cuda(i) for i in range(4)]

    if mode == 'small':
        input = Variable(torch.rand(512, 30))
    elif mode == 'large':
        input = Variable(torch.rand(128, 3, 224, 224))

    end = time.time()
    y = [ensemble.module[i](input.cuda(i)) for i in range(4)]
    synchronize_all(4)
    print('without model parallel')
    print('time: ', time.time() - end)


if __name__ == '__main__':
    # on small net
    test_model_parallel('small')
    test_without_parallel('small')

    # on imagenet resnet50
    test_model_parallel('large')
    test_without_parallel('large')
--------------------------------------------------------------------------------