├── .gitignore
├── Makefile
├── README.md
├── Serialization-timing.ipynb
├── __init__.py
├── mpi_comms.py
├── ps.py
├── serialization.py
├── test_comms.py
├── test_iallgather.py
└── test_mpi.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------

test:
	mpirun -n 2 py.test -s

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Why?

### Features
It fits our problem constraints:

* models fit on one device
* communication is reliable
* only one PS is required
* MPI is available.

Additionally, this implementation of a PS

* uses MPI, a message-passing standard that is common on clusters.
* can compress messages if bandwidth is a concern.

### Limitations
Or, why use other PS systems? You may...

* want multiple PSs.
* have models that may not fit on one machine.
* not want to use MPI.

### Why not torch.distributed?
* torch.distributed only sends tensors. We may want to send generic Python
  objects.
* torch.distributed does not have an `Igatherv`, which is what we want for a
  PS.

## Notes
* `MPI_Gather` requires the receive count of every message to be known in
  advance. Options:
  1. Send the sizes first, allocate the largest size (and then trim each
     message down to its size).
  2. (when receiving from any source)
* TensorFlow's `ConditionalAccumulator` (see its docs)
  1. makes sure gradients are current
  2. waits until $n$ gradients have been received from any number of workers

## Plan
Implement a parameter server with mpi4py:

0. Call encode.
1. Convert to a pickle object.
2. (optional) Compress.
3. Allocate the largest size expected + some padding (to detect overflow).
4. Convert to a NumPy buffer, use Igather.
5. Convert back to a pickle object when needed.
6. Send the dictionary to decode.

## Notes
* An MPI-enabled parameter server cannot handle models that do not fit on
  every machine.
* Though we can safely (?) assume that the data for every machine is constant.

## Resources
* Good summary of parameter servers at http://hunch.net/?p=151364

## Async PS pseudo-code
* This is algorithm AsySG-InCon (for inconsistent reads) in [1].
* [1]: Asynchronous parallel stochastic gradient for nonconvex optimization,
  https://arxiv.org/abs/1506.08272

``` python
# optimization step function (pseudo-code)
irequest_params()
for p in params:
    if rank == 0:
        # parameter server: wait for 32 gradients from any workers
        comms = []
        while True:
            comms += [recv(MPI.ANY_SOURCE)]
            if len(comms) == 32:
                break
        grads = [receive(comm) for comm in comms]
        p.grad = sum(grads)
        step()
    else:
        # worker: send this parameter's gradient to the PS (rank 0)
        send(p.grad, 0)
    # broadcast the (possibly stale) parameter values to every rank
    req = ibcast(p.data)
```

It's easy to implement inconsistent reads with `ibcast`. It's harder to do
consistent reads; we need a buffered broadcast so that workers can continue
computing gradients.

## Iallgatherv
I tried playing with `Iallgatherv`. Some problems:

* It does not support sending objects of *unknown* sizes. Every rank needs to
  know the size of every object it's receiving.
* Sending objects of varying sizes involves passing a
  `(buffer, (counts, displacements), datatype)` tuple, as shown in
  [the implementation][allgather-impl] and [the tests][allgather-tests].

[allgather-tests]:https://github.com/mpi4py/mpi4py/blob/bd5278b232bde9f40247c3af1a8aed6166e7cbcf/test/test_cco_nb_vec.py#L192
[allgather-impl]:https://github.com/mpi4py/mpi4py/blob/3fd4dbd57b54f412e28b84aa6f77fb440c120f7d/test/arrayimpl.py#L99

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
from .ps import MPI_PS, Adam, SGD

--------------------------------------------------------------------------------
/mpi_comms.py:
--------------------------------------------------------------------------------
import time
import pickle
import functools
import zlib
import blosc
from mpi4py import MPI
import torch
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Receive-slot size in bytes, per message name (used by igather/irecv)
max_bytes = {}


def compress(msg, level=0, name='blosclz'):
    """
    Compress a message with blosc. With the default level=0, blosc stores
    the data without compressing it.
    """
    if name in {'lz4', 'snappy'}:
        raise ValueError('Do not specify lz4 or snappy. I ran into hard-to-'
                         'debug issues when I did this. blosclz seems to work.')
    code = blosc.compress(msg, clevel=level, cname=name)
    return bytearray(code)


def decompress(code):
    msg = blosc.decompress(code)
    return msg


def to_np(d):
    """Recursively convert (CPU or CUDA) tensors to NumPy arrays."""
    if isinstance(d, torch.Tensor):
        return d.detach().cpu().numpy()
    if isinstance(d, dict):
        return {k: to_np(v) for k, v in d.items()}
    if isinstance(d, list):
        return list(map(to_np, d))
    if isinstance(d, map):
        return map(to_np, d)
    return d


def to_torch(d, cuda=False):
    """Recursively convert NumPy arrays to (optionally CUDA) tensors."""
    if isinstance(d, np.ndarray):
        d = torch.Tensor(d)
        if cuda:
            # `async` is a reserved word since Python 3.7; PyTorch's
            # argument is now called `non_blocking`
            d = d.cuda(non_blocking=True)
        return d
    if isinstance(d, dict):
        return {k: to_torch(v, cuda=cuda) for k, v in d.items()}
    if isinstance(d, list):
        return list(map(functools.partial(to_torch, cuda=cuda), d))
    if isinstance(d, map):
        return map(functools.partial(to_torch, cuda=cuda), d)
    return d


def igather(obj, name=""):
    """
    Gathers a python object to a root node.

    Returns
    =======
    recv : bytearray
        Receive buffer with one fixed-size slot per rank (read it with irecv)
    req : MPI.Request
        Supports req.Wait()
    timings : dict
        Timing information and the allocated slot size ('alloc_bytes')
    """
    global max_bytes
    obj = to_np(obj)
    t = [time.time()]
    pickled = pickle.dumps(obj)
    t += [time.time()]
    send = bytearray(pickled)
    send = compress(send)
    t += [time.time()]

    # Append a 32-byte 0x29 sentinel so trim_msg can find the end of the
    # message inside its (larger, zero-padded) receive slot
    send += bytearray(b'\x29'*32)

    # Slot size: 10x this message, and at least 15 KiB. Each rank computes
    # this from its own message, so the root's slots must be large enough
    # for every rank's message.
    max_bytes[name] = max(max_bytes.get(name, 0), (len(send) + 1) * 10)
    max_bytes[name] = max(max_bytes[name], 1024 * 15)
    recv = bytearray(max_bytes[name] * size)
    t += [time.time()]
    req = comm.Igatherv([send, MPI.BYTE], [recv, MPI.BYTE])
    t += [time.time()]
    return recv, req, {'pickle_time': t[1] - t[0], 'compress_time': t[2] - t[1],
                       'alloc_time': t[3] - t[2],
                       'igather_time': t[4] - t[3],
                       'alloc_bytes': max_bytes[name]}


def trim_msg(msg):
    """
    msg : bytearray
        Somewhere in msg, 32 elements are 0x29.
        Returns the msg before that.
    """
    i = msg.find(b'\x29'*32)
    if i == -1:
        raise Exception('trim_msg error; end of msg not found')
    return msg[:i]


def irecv(recv, req, name="", cuda=False):
    """Complete an igather; on the root, return the gathered objects."""
    global max_bytes
    # Every rank must complete the nonblocking collective, not only the root
    req.Wait()
    if rank != 0:
        return None
    bytes_ = max_bytes[name]
    msgs = [recv[bytes_*n:bytes_*(n+1)] for n in range(size)]
    msgs = map(trim_msg, msgs)
    msgs = map(blosc.decompress, msgs)
    objs = map(pickle.loads, msgs)
    objs = map(functools.partial(to_torch, cuda=cuda), objs)
    return list(objs)


def irecv1(recv, req, cuda=False):
    req.Wait()
    recv = decompress(recv)
    obj = pickle.loads(recv)
    return to_torch(obj, cuda=cuda)


def ibroadcast(obj):
    # NOTE: Ibcast needs the same buffer size on every rank, so this only
    # works when every rank's compressed message is as long as the root's
    obj = to_np(obj)
    pickled = pickle.dumps(obj)
    send = bytearray(pickled)
    send = compress(send)
    req = comm.Ibcast([send, MPI.BYTE])
    return send, req


def to_mpi_v(v, counts, dtype=MPI.BYTE):
    # Rank i's data starts after the data of ranks 0..i-1
    displacements = [sum(counts[:i]) for i in range(len(counts))]
    return (v, (counts, displacements), dtype)


def to_mpi(v, dtype=MPI.BYTE):
    return (v, dtype)


class Iallgather:
    def __init__(self):
        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()

    def _get_counts(self, rank_size):
        # Share this rank's message size with every rank (nonblocking)
        rank_size = np.array(rank_size, dtype=np.int32)
        counts = np.zeros(self.size, dtype=np.int32)
        req = self.comm.Iallgather(rank_size, counts)
        return req, counts

    def prepare(self, counts):
        responses = map(self._get_counts, counts)
        return list(responses)

    def send(self, send, counts):
        recv = bytearray(int(sum(counts)))
        req = self.comm.Iallgatherv(to_mpi(send), to_mpi_v(recv, counts))
        return recv, req, counts

    def recv(self, recv, req, counts, cuda=False):
        displacements = [sum(counts[:i])
                         for i in range(len(counts))]
        req.Wait()
        msgs = [recv[displacements[i]:displacements[i+1]]
                for i in range(len(displacements) - 1)]
        msgs += [recv[displacements[-1]:]]
        msgs = map(blosc.decompress, msgs)
        objs = map(pickle.loads, msgs)
        objs = map(to_np, objs)
        return list(objs)


def print_summary(flat_dict):
    string = " {"
    for k, v in flat_dict.items():
        if isinstance(v, (torch.Tensor, np.ndarray)):
            string += f"{k}: {v.shape}, "
        else:
            string += f"{k}: {v}, "
    string += "}"
    print(string)


def format_for_send(obj):
    """Pickle and compress obj; return the bytes and their sizes."""
    code = to_np(obj)
    pickled = pickle.dumps(code)
    send = bytearray(pickled)
    # TODO: get sizes from all other machines here (will reduce the straggler
    # effect)
    packaged = compress(send)
    return packaged, {'msg_bytes': len(send), 'packaged_bytes': len(packaged)}

--------------------------------------------------------------------------------
/ps.py:
--------------------------------------------------------------------------------
import torch
import time
from functools import partial
import os
import sys
import math
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor

# Make mpi_comms importable from any working directory. (A module-level
# string named __dir__ would break dir() on this module.)
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import mpi_comms as comms
from mpi4py import MPI


def _bytes_of(obj):
    # BUG: for 2D arrays doesn't return the number of bytes
    # that is, when sizes printed, only 1D sizes printed
    if isinstance(obj, torch.autograd.Variable):
        print('autograd variable')
        return _bytes_of(obj.grad) + obj.element_size()*obj.numel()
    cuda_tensor = getattr(obj, 'cuda', False)
    if isinstance(obj, torch.Tensor) or cuda_tensor:
        # t_size is a lower bound; only the number of elements
        t_size = obj.element_size() * obj.numel()
        return t_size

    if isinstance(obj, dict):
        return sum([_bytes_of(v) for k, v in obj.items()])
    if isinstance(obj, tuple) or isinstance(obj, list):
        return sum([_bytes_of(v) for v in obj])

    return sys.getsizeof(obj)  # only counting tensors as stores


def find_param(params, name):
    matches = [p for p in params if p.name == name]
    if len(matches) > 1:
        raise ValueError('More than one name found')
    if not matches:
        raise ValueError(f'No parameter named {name!r}')
    return matches[0]


class MPI_PS(torch.optim.Optimizer):
    def __init__(self, named_params, *args,
                 names=None,
                 optim='sgd',
                 code=None,
                 use_mpi=True, cuda=False,
                 **kwargs):
        self.code = code
        self.optim = optim

        # Encode each gradient in a background thread as soon as backward
        # has computed it
        for i, (name, param) in enumerate(named_params):
            param.name = name
            param.register_hook(partial(self.async_code, name=name,
                                        encode=code.encode))
        self.use_mpi = use_mpi
        self.cuda = cuda

        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()
        self.steps = 0
        self.iallgather = comms.Iallgather()
        super(MPI_PS, self).__init__(*args, **kwargs)

        self.recv_msgs = {}
        self.msgs = {}
        self.timings = []
        self.futures = []
        self.names = []
        self.pool = ThreadPoolExecutor(max_workers=200)
        # TODO: look into this, change to processes

    def close(self):
        """Shut down the thread pool that encodes gradients."""
        self.pool.shutdown()

    def format_for_send(self, grad, encode=None,
                        format=comms.format_for_send, **kwargs):
        code = encode(grad.data, **kwargs)
        msg, data = format(code)
        return msg, data

    def async_code(self, grad, *args, name=None, **kwargs):
        # Gradient hook: queue the encoding and remember the hook order
        future = self.pool.submit(self.format_for_send, grad, *args, **kwargs)
        self.futures += [future]
        self.names += [name]

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        self.steps += 1

        data = {'comm_wait': 0, 'optim_step_time': 0, 'decode_time': 0}
        for group in self.param_groups:
            if len(set(self.names)) != len(group['params']):
                raise ValueError('len(set(names)) != len(params)')

            # Parameters keyed by name, in (reverse) order of computation
            ordered_params = OrderedDict([(p.name, p)
                                          for p in group['params'][::-1]])

            # iallgather.send takes a long time because of stragglers (and it
            # has to send the counts to each machine).
            # To fix this, send all sizes asynchronously
            start = time.time()
            msgs_and_data = [future.result() for future in self.futures]
            msgs = [msg for msg, _ in msgs_and_data]
            msg_metadata = [datum for _, datum in msgs_and_data]
            for key in ['msg_bytes', 'packaged_bytes']:
                data[key] = sum(datum[key] for datum in msg_metadata) / len(msg_metadata)

            data['code_wait'] = time.time() - start
            start = time.time()
            sizes = self.iallgather.prepare(list(map(len, msgs)))
            data['iallgather_prepare_time'] = time.time() - start

            start = time.time()
            responses = []
            for (req, count), msg in zip(sizes, msgs):
                req.Wait()
                responses += [self.iallgather.send(msg, count)]
            data['isend_time'] = time.time() - start

            names = [p.name for p in group['params'][::-1]]
            if len(names) != len(set(names)):
                repeated = set([x for x in names if names.count(x) > 1])
                raise ValueError(f'names not unique. Repeated names = {repeated}')

            paired_info = [(name, ordered_params[name], msg, r)
                           for name, msg, r in zip(self.names, msgs, responses)]
            self.names = []
            self.futures = []
            for name, p, msg, response in paired_info:
                start = time.time()
                codes = self.iallgather.recv(*response, cuda=self.cuda)
                data['comm_wait'] += time.time() - start

                start = time.time()
                self.code.codes = codes
                grads = map(partial(self.code.decode, cuda=self.cuda), codes)
                grads = list(map(partial(comms.to_torch, cuda=self.cuda), grads))
                data['decode_time'] += time.time() - start

                start = time.time()

                cond = all([g.shape == grads[0].shape for g in grads])
                if not cond:
                    print(" !!", self.rank, p.name, [g.shape for g in grads])
                    raise ValueError('shapes not the same')
                # Sum (not average) the gradients from every rank
                d_p = sum(grads)

                if p.grad is None:
                    continue
                if self.optim == 'sgd':
                    kwargs = {k: group[k] for k in ['weight_decay', 'momentum',
                                                    'dampening', 'nesterov', 'lr']}
                elif self.optim == 'adam':
                    kwargs = {k: group[k] for k in ['betas', 'weight_decay',
                                                    'eps', 'lr', 'amsgrad']}
                else:
                    raise ValueError('self.optim not in [sgd, adam]')

                self.optim_step(p, d_p, **kwargs)
                data['optim_step_time'] += time.time() - start

        return loss, data


class SGD(MPI_PS, torch.optim.SGD):

    def optim_step(self, p, d_p, weight_decay=0, momentum=0, dampening=0,
                   nesterov=0, lr=0):
        if weight_decay != 0:
            d_p.add_(p.data, alpha=weight_decay)
        if momentum != 0:
            param_state = self.state[p]
            if 'momentum_buffer' not in param_state:
                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                buf.mul_(momentum).add_(d_p)
            else:
                buf = param_state['momentum_buffer']
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        p.data.add_(d_p, alpha=-lr)


class Adam(MPI_PS, torch.optim.Adam):
    def optim_step(self, p, grad, amsgrad=False, betas=(0.9, 0.999),
                   weight_decay=0, eps=1e-8, lr=1e-3):
        if grad.is_sparse:
            raise RuntimeError('Adam does not support sparse gradients, '
                               'please consider SparseAdam instead')

        state = self.state[p]

        # State initialization
        if len(state) == 0:
            state['step'] = 0
            # Exponential moving average of gradient values
            state['exp_avg'] = torch.zeros_like(p.data)
            # Exponential moving average of squared gradient values
            state['exp_avg_sq'] = torch.zeros_like(p.data)
            if amsgrad:
                # Maintains max of all exp. moving avg. of sq. grad. values
                state['max_exp_avg_sq'] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        if amsgrad:
            max_exp_avg_sq = state['max_exp_avg_sq']
        beta1, beta2 = betas

        state['step'] += 1

        if weight_decay != 0:
            grad = grad.add(p.data, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = max_exp_avg_sq.sqrt().add_(eps)
        else:
            denom = exp_avg_sq.sqrt().add_(eps)

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        p.data.addcdiv_(exp_avg, denom, value=-step_size)

--------------------------------------------------------------------------------
/serialization.py:
--------------------------------------------------------------------------------
import pickle
import cloudpickle
import blosc
import numpy as np
import torch


def _tensor_info(v, k):
    # Keep a contiguous CPU copy in the dict so the memory behind data_ptr
    # stays alive until blosc.compress_ptr has read it
    v = v.detach().cpu().contiguous()
    ret = {f: getattr(v, f)() for f in ['numel', 'data_ptr', 'element_size']}
    ret.update(k=k, tensor=v, shape=tuple(v.shape), dtype=v.numpy().dtype.str)
    return ret


def _predump(obj):
    # Split a flat dict into tensors and everything else (cloudpickled)
    tensors = {k: v for k, v in obj.items() if isinstance(v, torch.Tensor)}
    info = [_tensor_info(v, k) for k, v in tensors.items()]
    d = {k: v for k, v in obj.items() if k not in tensors}
    return cloudpickle.dumps(d), info


def _tensor_dump(info):
    # Compress the tensor's raw memory without copying it to bytes first
    return blosc.compress_ptr(info['data_ptr'], info['numel'],
                              info['element_size'])


def compress(obj):
    """Serialize a flat dict whose values may include tensors."""
    d, tensor_info = _predump(obj)
    msgs = list(map(_tensor_dump, tensor_info))
    meta = [(info['k'], info['shape'], info['dtype']) for info in tensor_info]
    return pickle.dumps((d, meta, msgs))


def decompress(msg):
    """Inverse of compress."""
    d, meta, msgs = pickle.loads(msg)
    obj = pickle.loads(d)
    for (k, shape, dtype), code in zip(meta, msgs):
        arr = np.frombuffer(blosc.decompress(code), dtype=dtype)
        obj[k] = torch.from_numpy(arr.reshape(shape).copy())
    return obj


if __name__ == "__main__":
    n = int(1e3)
    x = torch.linspace(0, 6.28, n)
    y = torch.sin(x) + torch.randn(n) / 4
    obj = {'x': x, 'y': y, 'n': n}

    msg = compress(obj)
    obj_hat = decompress(msg)
    assert obj_hat['n'] == n
    assert torch.equal(obj_hat['x'], x) and torch.equal(obj_hat['y'], y)

--------------------------------------------------------------------------------
/test_comms.py:
--------------------------------------------------------------------------------
from mpi4py import MPI
import mpi_comms as comms

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()


def test_gather():
    obj = {'str': 'str', 'rank': rank, 'list': [rank] * (rank + 1)}
    # igather returns (recv buffer, request, timings)
    recv, req, _ = comms.igather(obj, name=1)
    objs = comms.irecv(recv, req, name=1)
    sent = [{'str': 'str', 'rank': r, 'list': [r] * (r + 1)}
            for r in range(size)]
    if rank == 0:
        assert objs == sent


def test_bcast():
    # Both payloads serialize to the same number of bytes, which Ibcast needs
    obj = {'x': 'x', 'list': [1]}
    if rank == 0:
        obj = {'a': 'a', 'list': [0]}
    tmp = comms.ibroadcast(obj)
    recv = comms.irecv1(*tmp)
    sent = {'a': 'a', 'list': [0]}
    assert recv == sent


if __name__ == "__main__":
    test_bcast()
    test_gather()

--------------------------------------------------------------------------------
/test_iallgather.py:
--------------------------------------------------------------------------------
from mpi4py import MPI
import pickle
import numpy as np
import random


comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()


def to_mpi_v(v, counts, dtype=MPI.BYTE):
    displacements = [sum(counts[:i]) for i in range(len(counts))]
    return (v, (counts, displacements), dtype)


def to_mpi(v, dtype=MPI.BYTE):
    return (v, dtype)


def collect_sizes(rank_size, size=size):
    # int32: int16 would overflow for messages longer than 32767 bytes
    rank_size = np.array(rank_size, dtype=np.int32)
    counts = np.zeros(size, dtype=np.int32)
    req = comm.Iallgather(rank_size, counts)
    req.Wait()
    return counts


def _make_obj(rank, size):
    obj = {'rank': rank, 'list': [rank] * (rank + 1)}
    return obj


def format_for_send(obj):
    send = pickle.dumps(obj)
    return send


if __name__ == "__main__":
    obj = _make_obj(rank, size)
    send = format_for_send(obj)

    counts = collect_sizes(len(send))

    recv = bytearray(int(sum(counts)))
    req = comm.Iallgatherv(to_mpi(send), to_mpi_v(recv, counts))
    req.Wait()

    displacements = [sum(counts[:i]) for i in range(len(counts))]
    pickles = [recv[displacements[i]:displacements[i+1]]
               for i in range(len(displacements) - 1)]
    pickles += [recv[displacements[-1]:]]
    jar = [pickle.loads(p) for p in pickles]
    objs_true = [_make_obj(r, size) for r in range(size)]

    assert jar == objs_true

--------------------------------------------------------------------------------
/test_mpi.py:
--------------------------------------------------------------------------------
import pickle
import zlib
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
dtype = np.dtype('int16')

if __name__ == "__main__":
    # Scratch Ialltoallv experiment; guarded so that importing this module
    # under pytest does not start a collective that is never completed
    name = 1
    obj = {'str': 'str', 'int': 1}
    max_bytes = {}
    pickled = pickle.dumps(obj)
    send = bytearray(pickled)
    max_bytes[name] = max(max_bytes.get(name, 0), len(send)) * 100
    print(max_bytes, len(send))
    recv = bytearray(max_bytes[name] * size)
    req = comm.Ialltoallv([send, MPI.BYTE], [recv, MPI.BYTE])
    # req.Wait()
    # bytes_ = max_bytes[name]
    # msgs = [recv[bytes_*n:bytes_*(n+1)] for n in range(size)]
    # objs = [pickle.loads(msg) for msg in msgs]


def test_async_send_bytearray():
    obj = {'size': size, 'rank': [rank] * (rank+1)}
    pickled = pickle.dumps(obj)
    send = bytearray(pickled)
    recv = None
    if rank == 0:
        max_bytes = 1024 * 1024
        recv = bytearray(max_bytes * size)

    dtype = MPI.BYTE
    req = comm.Igatherv([send, dtype], [recv, dtype])
    req.Wait()

    if rank == 0:
        msgs = [recv[max_bytes*n:max_bytes*(n+1)] for n in range(size)]
        objs = [pickle.loads(msg) for msg in msgs]
        sent = [{'size': size, 'rank': [r]*(r+1)} for r in range(size)]
        assert objs == sent


def test_async_diff_sizes():
    data = {'a': 'a', 'async': [rank] * (rank + 1)}
    pickled = pickle.dumps(data)
    # Pad to a whole number of int16 elements; pickle.loads ignores
    # anything after the pickle's end
    while len(pickled) % dtype.itemsize != 0:
        pickled += b' '
    send = np.frombuffer(pickled, dtype=dtype)

    recv = None
    if rank == 0:
        max_bytes = 1024 * 1024
        recv = np.empty([size, max_bytes // 2], dtype=dtype)

    req = comm.Igatherv(send, recv)
    req.Wait()

    if rank == 0:
        msgs = np.array_split(recv, size)
        objs = [pickle.loads(msg.tobytes()) for msg in msgs]

        sent = [{'a': 'a', 'async': [r] * (r + 1)}
                for r in range(size)]
        assert objs == sent


def test_sync_same_size():
    data = {'a': 'a', 'sync': [rank] * (rank + 1)}
    pickled = pickle.dumps(data)

    send = np.frombuffer(pickled, dtype=np.uint8)
    recv = None
    if rank == 0:
        recv = np.empty([size, size*len(pickled)], dtype=np.uint8)

    comm.Gatherv(send, recv)

    if rank == 0:
        msgs = np.array_split(recv, size)
        elems = [msg.tobytes() for msg in msgs]
        objs = [pickle.loads(x) for x in elems]

        sent = [{'a': 'a', 'sync': [r] * (r + 1)}
                for r in range(size)]
        assert objs == sent

--------------------------------------------------------------------------------
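The README's plan (pickle, optionally compress, share every rank's message size, then gather everything into one buffer addressed by displacements) can be sanity-checked without launching MPI. The following is a minimal single-process sketch, not code from this repository: zlib stands in for blosc, a Python list stands in for the messages the ranks would contribute to `Iallgatherv`, and the helper names `pack`, `displacements`, `simulate_iallgatherv`, and `unpack_all` are hypothetical.

```python
import pickle
import zlib
from itertools import accumulate


def pack(obj):
    """Pickle then compress one object (zlib stands in for blosc here)."""
    return zlib.compress(pickle.dumps(obj))


def displacements(counts):
    """Byte offset of each rank's message: the running sum of earlier counts."""
    return [0] + list(accumulate(counts))[:-1]


def simulate_iallgatherv(messages):
    """Single-process stand-in for the size exchange plus Iallgatherv.

    With MPI, every rank would first learn all message sizes (an Iallgather
    of one integer per rank), allocate sum(counts) bytes, and Iallgatherv
    would place rank i's bytes at offset displacements(counts)[i].
    """
    counts = [len(m) for m in messages]
    recv = bytearray(sum(counts))
    for msg, start in zip(messages, displacements(counts)):
        recv[start:start + len(msg)] = msg
    return recv, counts


def unpack_all(recv, counts):
    """Split the gathered buffer by counts, then decompress and unpickle."""
    return [pickle.loads(zlib.decompress(bytes(recv[start:start + n])))
            for start, n in zip(displacements(counts), counts)]


if __name__ == "__main__":
    # Objects of different sizes, one per hypothetical rank
    objs = [{'rank': r, 'list': [r] * (r + 1)} for r in range(4)]
    recv, counts = simulate_iallgatherv([pack(o) for o in objs])
    assert unpack_all(recv, counts) == objs
    print(counts, displacements(counts))
```

`displacements` computes the same offsets as `to_mpi_v` in `mpi_comms.py` (which uses `sum(counts[:i])`); `itertools.accumulate` just does it in a single pass.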