├── Cargo.toml ├── README.md ├── lib ├── orchpy │ ├── orchpy │ │ ├── __init__.py │ │ ├── fundamental.py │ │ ├── main.pyx │ │ ├── unison.pyx │ │ └── utils.pxi │ └── setup.py └── papaya │ ├── __init__.py │ ├── dist.py │ ├── single.py │ └── test.py ├── schema ├── comm.proto ├── make-schema.sh └── types.proto ├── scripts ├── README.md ├── masterimage │ └── Dockerfile ├── orchestra-master-controller.json ├── orchestra-master-service.json ├── orchestra-slave-controller.json ├── orchestra-slave-service.json ├── slaveimage │ └── Dockerfile └── worker.py ├── shell.py ├── src ├── client.rs ├── graph.rs ├── lib.rs ├── main.rs ├── scheduler.rs ├── server.rs └── utils.rs └── test ├── mapreduce.py ├── matmul.py ├── run_papaya_test.py ├── runbasictests.py ├── runtest.py └── testprograms.py /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "orchestra" 3 | version = "0.2.0" 4 | authors = ["Philipp Moritz ", "Robert Nishihara "] 5 | 6 | [lib] 7 | name = "orchestralib" 8 | crate-type = ["dylib"] 9 | 10 | [dependencies] 11 | log = "0.3" 12 | env_logger = "0.3" 13 | libc = "0.1.10" 14 | rand = "*" 15 | argparse = "*" 16 | 17 | [dependencies.protobuf] 18 | git = "https://github.com/stepancheg/rust-protobuf.git" 19 | 20 | [dependencies.zmq] 21 | git = "https://github.com/erickt/rust-zmq.git" 22 | 23 | [dependencies.petgraph] 24 | git = "https://github.com/bluss/petulant-avenger-graphlibrary" 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Orchestra 2 | 3 | Orchestra is work in progress; all the functionality that is implemented is 4 | expected to work, if you run into problems you should file an issue on github. 5 | While performance is an explicit goal of this project, so far most focus has 6 | been on correctness rather than performance. 
7 | 8 | ## Setup 9 | The instructions below suffice to run the example code on an Ubuntu instance on EC2. 10 | 11 | Install dependencies 12 | 13 | - `sudo apt-get update` 14 | - `sudo apt-get install -y emacs git gcc libzmq3-dev python2.7-dev python-pip` 15 | - `sudo pip install numpy` 16 | - `sudo pip install protobuf` 17 | 18 | Install rust (we currently need the nightly build) 19 | 20 | - `curl -sSf https://static.rust-lang.org/rustup.sh | sh -s -- --channel=nightly` 21 | 22 | Build protobuf compiler 23 | 24 | - `sudo apt-get install protobuf-compiler` 25 | - `cd ~` 26 | - `git clone https://github.com/stepancheg/rust-protobuf.git` 27 | - `cd rust-protobuf` 28 | - `cargo build` 29 | - add the line `export PATH=$HOME/rust-protobuf/target/debug:$PATH` to `~/.bashrc` 30 | - `source ~/.bashrc` 31 | 32 | Install cprotobuf 33 | 34 | - `cd ~` 35 | - `git clone https://github.com/pcmoritz/cprotobuf.git` 36 | - `cd cprotobuf` 37 | - `python setup.py install` 38 | 39 | Clone Orchestra and create schema 40 | 41 | - `cd ~` 42 | - `git clone https://github.com/amplab/orchestra.git` 43 | - `cd orchestra/schema` 44 | - `bash make-schema.sh` 45 | - `cd $HOME/orchestra` 46 | - add `export LD_LIBRARY_PATH=$HOME/orchestra/target/debug/:$LD_LIBRARY_PATH` to `~/.bashrc` 47 | 48 | Build orchpy 49 | - cd `~/orchestra/lib/orchpy/` 50 | - `python setup.py build` 51 | - add something like `export PYTHONPATH=PATH_TO_ORCHESTRA/orchestra/lib/orchpy/build/lib.linux-x86_64-2.7:$PYTHONPATH` to `~/.bashrc`, this will vary depending on your operating system 52 | - add something like `export PYTHONPATH=PATH_TO_ORCHESTRA/orchestra/lib:$PYTHONPATH` to `~/.bashrc` 53 | - `source ~/.bashrc` 54 | 55 | Add Orchestra to your python path 56 | 57 | - add `export PYTHONPATH=$HOME/orchestra/lib/python:$PYTHONPATH` to `~/.bashrc` 58 | - `source ~/.bashrc` 59 | 60 | ## Running the tests 61 | 62 | In a terminal, run 63 | 64 | - `cd ~/orchestra/test` 65 | - `RUST_LOG=orchestra=info python runtest.py` 66 | 
-------------------------------------------------------------------------------- /lib/orchpy/orchpy/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import ObjRef, distributed, check_types, serialize_args, deserialize_args, context, Context, register_current, register_distributed 2 | from .fundamental import ObjRefs, ObjRefsProto 3 | -------------------------------------------------------------------------------- /lib/orchpy/orchpy/fundamental.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import orchpy as op 3 | from cprotobuf import ProtoEntity, Field 4 | 5 | class ObjRefsProto(ProtoEntity): 6 | shape = Field('uint64', 1, repeated=True) 7 | objrefs = Field('bytes', 2, required=False) 8 | 9 | class ObjRefs(object): 10 | 11 | def construct(self): 12 | self.array = np.frombuffer(self.proto.objrefs, dtype=np.dtype('uint64')) 13 | self.array.shape = self.proto.shape 14 | 15 | def deserialize(self, data): 16 | self.proto.ParseFromString(data) 17 | self.construct() 18 | 19 | def from_proto(self, proto): 20 | self.proto = proto 21 | self.construct() 22 | 23 | def __init__(self, shape=None): 24 | self.proto = ObjRefsProto() 25 | if shape != None: 26 | self.proto.shape = shape 27 | self.proto.objrefs = bytearray(np.product(shape) * np.dtype('uint64').itemsize) 28 | self.construct() 29 | 30 | def __getitem__(self, index): 31 | result = self.array[index] 32 | if type(result) == np.uint64: 33 | return op.ObjRef(result) 34 | else: 35 | return np.vectorize(op.ObjRef)(result) 36 | 37 | def __setitem__(self, index, val): 38 | self.array[index] = val.get_id() 39 | -------------------------------------------------------------------------------- /lib/orchpy/orchpy/main.pyx: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | # cython: language_level=3 4 | #cython.wraparound=False 
5 | #cython.boundscheck=False 6 | cimport cython 7 | from cpython cimport array 8 | from libc.stdint cimport uint16_t 9 | import array 10 | import cprotobuf 11 | import numpy as np 12 | import orchpy.unison as unison 13 | import orchpy.protos_pb as pb 14 | import types 15 | 16 | # see http://python-future.org/stdlib_incompatibilities.html 17 | from future.utils import bytes_to_native_str 18 | 19 | include "utils.pxi" 20 | 21 | cdef class ObjRef: 22 | cdef size_t _id 23 | 24 | def __cinit__(self, id): 25 | self._id = id 26 | 27 | def __richcmp__(self, other, int op): 28 | if op == 2: 29 | return self.get_id() == other.get_id() 30 | else: 31 | raise NotImplementedError("operator not implemented") 32 | 33 | cpdef get_id(self): 34 | return self._id 35 | 36 | cdef int get_id(ObjRef value): 37 | return value._id 38 | 39 | cdef inline bytes get_elements(bytearray buf, int start, int len): 40 | cdef char *buff = PyByteArray_AS_STRING(buf) 41 | return PyBytes_FromStringAndSize(buff + start, len) 42 | 43 | # this is a draft of the implementation, eventually we will use Python 3's typing 44 | # module and this backport: 45 | # https://github.com/python/typing/blob/master/python2/typing.py 46 | 47 | cpdef check_type(val, t): 48 | if type(val) == ObjRef: 49 | # at the moment, obj references can be of any type; think about making them typed 50 | return 51 | if type(val) == list: 52 | for i, elem in enumerate(val): 53 | try: 54 | check_type(elem, t[1]) 55 | except: 56 | raise Exception("Type error: Heterogeneous list " + str(val) + " at index " + str(i)) 57 | return 58 | if type(val) == tuple: 59 | for i, elem in enumerate(val): 60 | try: 61 | check_type(elem, t[1][i]) 62 | except: 63 | raise Exception("Type error: Type " + str(val) + " at index " + str(i) + " does not match") 64 | return 65 | if (type(val) == int or type(val) == long) and (t == int or t == long): 66 | return True 67 | if type(val) != t: 68 | raise Exception("Type of " + str(val) + " is not " + str(t)) 69 | 70 | 
# eventually move this into unison 71 | cpdef check_types(vals, schema): 72 | for i, val in enumerate(vals): 73 | check_type(val, schema[i]) 74 | 75 | cpdef serialize_args(args): 76 | result = pb.Args() 77 | cdef bytearray buf = bytearray() 78 | cdef size_t prev_index = 0 79 | objrefs = [] 80 | data = [] 81 | for arg in args: 82 | if type(arg) == ObjRef: 83 | objrefs.append(get_id(arg)) 84 | else: 85 | prev_index = len(buf) 86 | unison.serialize(buf, arg) 87 | data.append(get_elements(buf, prev_index, len(buf) - prev_index)) 88 | prev_index = len(buf) 89 | objrefs.append(-len(data)) 90 | result.objrefs = objrefs 91 | result.data = data 92 | return result 93 | 94 | cpdef deserialize_args(args, types): 95 | result = [] 96 | for k in range(len(args.objrefs)): 97 | elem = args.objrefs[k] 98 | if elem >= 0: # then elem is an ObjRef 99 | result.append(ObjRef(elem)) 100 | else: # then args[k] is being passed by value 101 | if k < len(types) - 1: 102 | arg_type = types[k] 103 | elif k == len(types) - 1 and types[-1] is not None: 104 | arg_type = types[k] 105 | elif k == len(types) - 1 and types[-1] is None: 106 | arg_type = types[-2] 107 | else: 108 | raise Exception() 109 | result.append(unison.deserialize(args.data[-elem - 1], arg_type)) 110 | return result 111 | 112 | cdef struct Slice: 113 | size_t size 114 | char* ptr 115 | 116 | cdef extern void* orchestra_create_context(const char* server_addr, uint16_t reply_port, uint16_t publish_port, const char* client_addr, uint16_t client_port) 117 | cdef extern size_t orchestra_register_function(void* context, const char* name) 118 | cdef extern size_t orchestra_step(void* context) 119 | cdef extern Slice orchestra_get_args(void* context) 120 | cdef extern size_t orchestra_function_index(void* context) 121 | cdef extern size_t orchestra_call(void* context, const char* name, const char* args, size_t argslen) 122 | cdef extern void orchestra_map(void* context, char* name, char* args, size_t argslen, size_t* retlist) 123 | cdef 
extern void orchestra_store_result(void* context, size_t objref, char* data, size_t datalen) 124 | cdef extern size_t orchestra_get_obj_len(void* Context, size_t objref) 125 | cdef extern char* orchestra_get_obj_ptr(void* context, size_t objref) 126 | cdef extern size_t orchestra_pull(void* context, size_t objref) 127 | cdef extern size_t orchestra_push(void* context) 128 | cdef extern void orchestra_debug_info(void* context) 129 | cdef extern void orchestra_destroy_context(void* context) 130 | 131 | cdef class Context: 132 | cdef void* context 133 | cdef public list functions 134 | cdef public list arg_types 135 | 136 | def __cinit__(self): 137 | self.context = NULL 138 | self.functions = [] 139 | self.arg_types = [] 140 | 141 | def connect(self, server_addr, reply_port, publish_port, client_addr, client_port): 142 | self.context = orchestra_create_context(server_addr, reply_port, publish_port, client_addr, client_port) 143 | 144 | def close(self): 145 | orchestra_destroy_context(self.context) 146 | 147 | def debug_info(self): 148 | orchestra_debug_info(self.context) 149 | 150 | cpdef get_object(self, ObjRef objref, type): 151 | index = objref.get_id() 152 | ptr = orchestra_get_obj_ptr(self.context, index) 153 | len = orchestra_get_obj_len(self.context, index) 154 | data = PyBytes_FromStringAndSize(ptr, len) 155 | return unison.deserialize(data, type) 156 | 157 | def main_loop(self): 158 | cdef size_t objref = 0 159 | while True: 160 | objref = orchestra_step(self.context) 161 | fnidx = orchestra_function_index(self.context) 162 | slice = orchestra_get_args(self.context) 163 | data = PyBytes_FromStringAndSize(slice.ptr, slice.size) 164 | func = self.functions[fnidx] 165 | args = pb.Args() 166 | args.ParseFromString(data) 167 | result = func(args) 168 | orchestra_store_result(self.context, objref, result, len(result)) 169 | 170 | """Args is serialized version of the arguments.""" 171 | def call(self, func_name, module_name, arglist): 172 | args = 
serialize_args(arglist).SerializeToString() 173 | return ObjRef(orchestra_call(self.context, module_name + "." + func_name, args, len(args))) 174 | 175 | def map(self, func, arglist): 176 | arraytype = bytes_to_native_str(b'L') 177 | args = serialize_args(arglist).SerializeToString() 178 | cdef array.array result = array.array(arraytype, len(arglist) * [0]) # TODO(pcmoritz) This might be slow 179 | orchestra_map(self.context, func.name, args, len(args), result.data.as_voidptr) 180 | retlist = [] 181 | for elem in result: 182 | retlist.append(ObjRef(elem)) 183 | return retlist 184 | 185 | """Register a function that can be called remotely.""" 186 | def register(self, func_name, module_name, function, *args): 187 | fnid = orchestra_register_function(self.context, module_name + "." + func_name) 188 | assert(fnid == len(self.functions)) 189 | self.functions.append(function) 190 | self.arg_types.append(args) 191 | 192 | def pull(self, type, objref): 193 | objref = orchestra_pull(self.context, objref.get_id()) 194 | return self.get_object(ObjRef(objref), type) 195 | 196 | def push(self, obj): 197 | buf = bytearray() 198 | unison.serialize(buf, obj) 199 | objref = orchestra_push(self.context) 200 | orchestra_store_result(self.context, objref, buf, len(buf)) 201 | return ObjRef(objref) 202 | 203 | context = Context() 204 | 205 | def distributed(types, return_type): 206 | def distributed_decorator(func): 207 | # deserialize arguments, execute function and serialize result 208 | def func_executor(args): 209 | arguments = [] 210 | protoargs = deserialize_args(args, types) 211 | for (i, proto) in enumerate(protoargs): 212 | if type(proto) == ObjRef: 213 | if i < len(types) - 1: 214 | arguments.append(context.get_object(proto, types[i])) 215 | elif i == len(types) - 1 and types[-1] is not None: 216 | arguments.append(context.get_object(proto, types[i])) 217 | elif types[-1] is None: 218 | arguments.append(context.get_object(proto, types[-2])) 219 | else: 220 | raise 
Exception("Passed in " + str(len(args)) + " arguments to function " + func.__name__ + ", which takes only " + str(len(types)) + " arguments.") 221 | else: 222 | arguments.append(proto) 223 | buf = bytearray() 224 | result = func(*arguments) 225 | if unison.unison_type(result) != return_type: 226 | raise Exception("Return type of " + func.func_name + " does not match the return type specified in the @distributed decorator, was expecting " + str(return_type) + " but received " + str(unison.unison_type(result))) 227 | unison.serialize(buf, result) 228 | return memoryview(buf).tobytes() 229 | # for remotely executing the function 230 | def func_call(*args, typecheck=False): 231 | if typecheck: 232 | check_types(args, func_call.types) 233 | return context.call(func_call.func_name, func_call.module_name, args) 234 | func_call.func_name = func.__name__.encode() # why do we call encode()? 235 | func_call.module_name = func.__module__.encode() # why do we call encode()? 236 | func_call.is_distributed = True 237 | func_call.executor = func_executor 238 | func_call.types = types 239 | return func_call 240 | return distributed_decorator 241 | 242 | def register_current(): 243 | for (name, val) in globals().items(): 244 | try: 245 | if val.is_distributed: 246 | context.register(name.encode(), __name__, val.executor, *val.types) 247 | except AttributeError: 248 | pass 249 | 250 | def register_distributed(module): 251 | moduledir = dir(module) 252 | for name in moduledir: 253 | val = getattr(module, name) 254 | try: 255 | if val.is_distributed: 256 | context.register(name.encode(), module.__name__, val.executor, *val.types) 257 | except AttributeError: 258 | pass 259 | -------------------------------------------------------------------------------- /lib/orchpy/orchpy/unison.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # unison: fast, space efficient and backward compatible python serialization 3 | # 4 | # 
This module exports: 5 | # 6 | # unison.serialize(bytearray, object): Serialize a python object 7 | # into a bytearray 8 | # unison.deserialize(bytes, schema): Deserialize a python object 9 | # with a given schema from a byte string 10 | # 11 | # The schema can be a python type, or an object of the form List[schema], 12 | # Tuple[schema1, schema2, ...] 13 | # 14 | # See runtest.py for examples 15 | 16 | import cprotobuf 17 | import orchpy.protos_pb as pb 18 | import orchpy.main as main 19 | import numpy as np 20 | 21 | try: 22 | import cPickle as pickle 23 | except ImportError: 24 | import pickle 25 | 26 | include "utils.pxi" 27 | 28 | class TypeAlias(object): 29 | """Class for defining generic aliases for library types.""" 30 | 31 | def __init__(self, target_type): 32 | self.target_type = target_type 33 | 34 | def __getitem__(self, typeargs): 35 | return (self.target_type, typeargs) 36 | 37 | cpdef unison_type(obj): 38 | """ 39 | Returns the unison type of obj. For example, 40 | unison_type([0, 1]) == (list, (int, int)) 41 | whereas, 42 | type([0, 1]) == list 43 | """ 44 | if type(obj) == list: 45 | return (list, tuple([unison_type(e) for e in obj])) 46 | if type(obj) == tuple: 47 | return (tuple, tuple([unison_type(e) for e in obj])) 48 | return type(obj) 49 | 50 | List = TypeAlias(list) 51 | # Dict = TypeAlias(dict) 52 | # Set = TypeAlias(set) 53 | Tuple = TypeAlias(tuple) 54 | # Callable = TypeAlias(callable) 55 | 56 | cdef int numpy_dtype_to_proto(dtype): 57 | if dtype == np.dtype('int32'): 58 | return pb.INT32 59 | if dtype == np.dtype('int64'): 60 | return pb.INT64 61 | if dtype == np.dtype('float32'): 62 | return pb.FLOAT32 63 | if dtype == np.dtype('float64'): 64 | return pb.FLOAT64 65 | 66 | cpdef array_to_proto(array): 67 | result = pb.Array() 68 | result.shape.extend(array.shape) 69 | result.data = np.getbuffer(array, 0, array.size * array.dtype.itemsize) 70 | result.dtype = numpy_dtype_to_proto(array.dtype) 71 | return result 72 | 73 | cdef 
proto_dtype_to_numpy(dtype): 74 | if dtype == pb.INT32: 75 | return np.dtype('int32') 76 | if dtype == pb.INT64: 77 | return np.dtype('int64') 78 | if dtype == pb.FLOAT32: 79 | return np.dtype('float32') 80 | if dtype == pb.FLOAT64: 81 | return np.dtype('float64') 82 | 83 | cpdef proto_to_array(proto): 84 | dtype = proto_dtype_to_numpy(proto.dtype) 85 | result = np.frombuffer(proto.data, dtype=dtype) 86 | result.shape = proto.shape 87 | return result 88 | 89 | cpdef serialize(bytearray buf, val): 90 | if type(val) == int or type(val) == long: 91 | raw_encode_uint64(buf, val) 92 | elif type(val) == float: 93 | encode_float(buf, val) 94 | elif type(val) == str or type(val) == unicode: 95 | encode_string(buf, val) 96 | elif type(val) == tuple: 97 | for elem in val: 98 | serialize(buf, elem) 99 | elif type(val) == list: 100 | serialize(buf, len(val)) 101 | for elem in val: 102 | serialize(buf, elem) 103 | elif type(val) == np.ndarray: 104 | proto = array_to_proto(val) 105 | data = proto.SerializeToString() 106 | serialize(buf, len(data)) 107 | buf.extend(data) 108 | elif type(val) == main.ObjRef: 109 | serialize(buf, val.get_id()) 110 | else: 111 | if hasattr(val, 'proto') and hasattr(val.proto, 'SerializeToString'): 112 | data = val.proto.SerializeToString() 113 | else: 114 | data = pickle.dumps(val, pickle.HIGHEST_PROTOCOL) 115 | serialize(buf, len(data)) 116 | buf.extend(data) 117 | 118 | cdef object deserialize_primitive(char **buff, char *end, type t): 119 | if t == int or t == long: 120 | return decode_uint64(buff, end) 121 | if t == float: 122 | return decode_float(buff, end) 123 | if t == str or t == unicode: 124 | return decode_string(buff, end) 125 | if t == np.ndarray: 126 | size = deserialize_primitive(buff, end, int) 127 | data = PyBytes_FromStringAndSize(buff[0], size) 128 | buff[0] += size 129 | array = pb.Array() 130 | array.ParseFromString(data) 131 | return proto_to_array(array) 132 | if t == main.ObjRef: 133 | return main.ObjRef(decode_uint64(buff, 
end)) 134 | else: 135 | size = deserialize_primitive(buff, end, int) 136 | data = PyBytes_FromStringAndSize(buff[0], size) 137 | buff[0] += size 138 | if hasattr(t, 'deserialize'): 139 | result = t() 140 | result.deserialize(data) 141 | else: 142 | result = pickle.loads(data) 143 | return result 144 | 145 | 146 | cdef object deserialize_buffer(char **buff, char *end, schema): 147 | if type(schema) == type: 148 | return deserialize_primitive(buff, end, schema) 149 | if type(schema) == tuple and schema[0] == tuple: 150 | result = [] 151 | for t in schema[1]: 152 | result.append(deserialize_buffer(buff, end, t)) 153 | return tuple(result) 154 | if type(schema) == tuple and schema[0] == list: 155 | result = [] 156 | len = deserialize_primitive(buff, end, long) 157 | for i in range(len): 158 | result.append(deserialize_buffer(buff, end, schema[1])) 159 | return result 160 | 161 | cpdef object deserialize(data, schema): 162 | cdef char *buff = data 163 | cdef Py_ssize_t size = len(data) 164 | cdef char *end = buff + size 165 | return deserialize_buffer(&buff, end, schema) 166 | -------------------------------------------------------------------------------- /lib/orchpy/orchpy/utils.pxi: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | from cpython cimport bytearray, PySequence_Length, PySequence_InPlaceConcat, PyUnicode_AsUTF8String 3 | from libc.stdint cimport uint64_t, int64_t, uint32_t, int32_t 4 | 5 | # {{{ definitions 6 | 7 | cdef extern from "Python.h": 8 | Py_ssize_t PyByteArray_GET_SIZE(object array) 9 | object PyUnicode_FromStringAndSize(char *buff, Py_ssize_t len) 10 | object PyBytes_FromStringAndSize(char *buff, Py_ssize_t len) 11 | object PyString_FromStringAndSize(char *buff, Py_ssize_t len) 12 | int PyByteArray_Resize(object self, Py_ssize_t size) except -1 13 | char* PyByteArray_AS_STRING(object bytearray) 14 | 15 | ctypedef object(*Decoder)(char **pointer, char *end) 16 | 17 | class 
InternalDecodeError(Exception): 18 | pass 19 | 20 | cdef inline object makeDecodeError(char* pointer, message): 21 | cdef uint64_t locator = pointer 22 | return InternalDecodeError(locator, message) 23 | 24 | class DecodeError(Exception): 25 | def __init__(self, pointer, message): 26 | self.pointer = pointer 27 | self.message = message 28 | def __str__(self): 29 | return self.message.format(self.pointer) 30 | 31 | # }}} 32 | 33 | # {{{ decoding 34 | 35 | # {{{ raw stuff 36 | 37 | cdef inline int raw_decode_uint32(char **start, char *end, uint32_t *result) nogil: 38 | cdef uint32_t value = 0 39 | cdef uint32_t byte 40 | cdef char *pointer = start[0] 41 | cdef int counter = 0 42 | while True: 43 | if pointer >= end: 44 | return -1 45 | byte = pointer[0] 46 | value |= (byte & 0x7f) << counter 47 | counter+=7 48 | pointer+=1 49 | if byte & 0x80 == 0: 50 | break 51 | start[0] = pointer 52 | result[0] = value 53 | return 0 54 | 55 | cdef inline int raw_decode_uint64(char **start, char *end, uint64_t *result) nogil: 56 | cdef uint64_t value = 0 57 | cdef uint64_t byte 58 | cdef char *pointer = start[0] 59 | cdef int counter = 0 60 | 61 | while True: 62 | if pointer >= end: 63 | return -1 64 | byte = pointer[0] 65 | value |= (byte & 0x7f) << counter 66 | counter+=7 67 | pointer+=1 68 | if byte & 0x80 == 0: 69 | break 70 | start[0] = pointer 71 | result[0] = value 72 | return 0 73 | 74 | cdef inline int raw_decode_fixed32(char **pointer, char *end, uint32_t *result) nogil: 75 | cdef uint32_t value = 0 76 | cdef char *start = pointer[0] 77 | cdef int i 78 | 79 | for i from 0 <= i < 4: 80 | if start == end: 81 | return -1 82 | value |= start[0] << (i * 8) 83 | start += 1 84 | pointer[0] = start 85 | result[0] = value 86 | return 0 87 | 88 | cdef inline int raw_decode_fixed64(char **pointer, char *end, uint64_t *result) nogil: 89 | cdef uint64_t value = 0 90 | cdef char *start = pointer[0] 91 | cdef uint64_t temp = 0 92 | cdef int i 93 | for i from 0 <= i < 8: 94 | if start == 
end: 95 | return -1 96 | temp = start[0] 97 | value |= temp << (i * 8) 98 | start += 1 99 | pointer[0] = start 100 | result[0] = value 101 | return 0 102 | 103 | cdef inline int raw_decode_delimited(char **pointer, char *end, char **result, uint64_t *size) nogil: 104 | if raw_decode_uint64(pointer, end, size): 105 | return -1 106 | 107 | cdef char* start = pointer[0] 108 | if start+size[0] > end: 109 | return -2 110 | 111 | result[0] = start 112 | pointer[0] = start+size[0] 113 | return 0 114 | 115 | # }}} 116 | 117 | cdef object decode_uint32(char **pointer, char *end): 118 | cdef uint32_t result 119 | if raw_decode_uint32(pointer, end, &result): 120 | raise makeDecodeError(pointer[0], "Can't decode value of type `uint32` at [{0}]") 121 | 122 | return result 123 | 124 | cdef object decode_uint64(char **pointer, char *end): 125 | cdef uint64_t result 126 | if raw_decode_uint64(pointer, end, &result): 127 | raise makeDecodeError(pointer[0], "Can't decode value of type `uint64` at [{0}]") 128 | 129 | return result 130 | 131 | cdef object decode_int32(char **pointer, char *end, ): 132 | cdef int32_t result 133 | if raw_decode_uint32(pointer, end, &result): 134 | raise makeDecodeError(pointer[0], "Can't decode value of type `int32` at [{0}]") 135 | 136 | return result 137 | 138 | cdef object decode_int64(char **pointer, char *end, ): 139 | cdef int64_t result 140 | if raw_decode_uint64(pointer, end, &result): 141 | raise makeDecodeError(pointer[0], "Can't decode value of type `int64` at [{0}]") 142 | 143 | return result 144 | 145 | cdef object decode_sint32(char **pointer, char *end, ): 146 | cdef uint32_t result 147 | if raw_decode_uint32(pointer, end, &result): 148 | raise makeDecodeError(pointer[0], "Can't decode value of type `sint32` at [{0}]") 149 | 150 | return ((result >> 1) ^ (-(result & 1))) 151 | 152 | cdef object decode_sint64(char **pointer, char *end, ): 153 | cdef uint64_t un 154 | if raw_decode_uint64(pointer, end, &un): 155 | raise 
makeDecodeError(pointer[0], "Can't decode value of type `sint64` at [{0}]") 156 | 157 | return ((un >> 1) ^ (-(un & 1))) 158 | 159 | cdef object decode_fixed32(char **pointer, char *end, ): 160 | cdef uint32_t result 161 | if raw_decode_fixed32(pointer, end, &result): 162 | raise makeDecodeError(pointer[0], "Can't decode value of type `fixed32` at [{0}]") 163 | 164 | return result 165 | 166 | cdef object decode_fixed64(char **pointer, char *end, ): 167 | cdef uint64_t result 168 | if raw_decode_fixed64(pointer, end, &result): 169 | raise makeDecodeError(pointer[0], "Can't decode value of type `fixed64` at [{0}]") 170 | 171 | return result 172 | 173 | cdef object decode_sfixed32(char **pointer, char *end, ): 174 | cdef int32_t result 175 | if raw_decode_fixed32(pointer, end, &result): 176 | raise makeDecodeError(pointer[0], "Can't decode value of type `sfixed32` at [{0}]") 177 | 178 | return result 179 | 180 | cdef object decode_sfixed64(char **pointer, char *end, ): 181 | cdef int64_t result 182 | if raw_decode_fixed64(pointer, end, &result): 183 | raise makeDecodeError(pointer[0], "Can't decode value of type `sfixed64` at [{0}]") 184 | 185 | return result 186 | 187 | cdef object decode_bytes(char **pointer, char *end, ): 188 | cdef char *result 189 | cdef uint64_t size 190 | cdef int ret = raw_decode_delimited(pointer, end, &result, &size) 191 | if ret==0: 192 | return PyBytes_FromStringAndSize(result, size) 193 | 194 | if ret == -1: 195 | raise makeDecodeError(pointer[0], "Can't decode size for value of type `bytes` at [{0}]") 196 | elif ret == -2: 197 | raise makeDecodeError(pointer[0], "Can't decode value of type `bytes` of size %d at [{0}]" % size) 198 | 199 | 200 | cdef object decode_string(char **pointer, char *end, ): 201 | cdef char *result 202 | cdef uint64_t size 203 | cdef int ret = raw_decode_delimited(pointer, end, &result, &size) 204 | if ret==0: 205 | return PyUnicode_FromStringAndSize(result, size) 206 | 207 | if ret == -1: 208 | raise 
makeDecodeError(pointer[0], "Can't decode size for value of type `string` at [{0}]") 209 | elif ret == -2: 210 | raise makeDecodeError(pointer[0], "Can't decode value of type `string` of size %d at [{0}]" % size) 211 | 212 | cdef object decode_float(char **pointer, char *end, ): 213 | cdef float result 214 | if raw_decode_fixed32(pointer, end, &result): 215 | raise makeDecodeError(pointer[0], "Can't decode value of type `float` at [{0}]") 216 | 217 | return result 218 | 219 | cdef object decode_double(char **pointer, char *end, ): 220 | cdef double result 221 | if raw_decode_fixed64(pointer, end, &result): 222 | raise makeDecodeError(pointer[0], "Can't decode value of type `double` at [{0}]") 223 | 224 | return result 225 | 226 | cdef object decode_bool(char **pointer, char *end, ): 227 | cdef char* start = pointer[0] 228 | pointer[0] = start + 1 229 | 230 | return start[0] 231 | 232 | # }}} 233 | 234 | # {{{ encoding 235 | 236 | cdef inline int raw_encode_uint32(bytearray array, uint32_t n) except -1: 237 | cdef unsigned short int rem 238 | cdef Py_ssize_t size = PyByteArray_GET_SIZE(array) 239 | PyByteArray_Resize(array, size + 10) 240 | cdef char *buff = PyByteArray_AS_STRING(array) + size 241 | 242 | if 0!=n: 243 | while True: 244 | rem = (n & 0x7f) 245 | n = n>>7 246 | if 0==n: 247 | buff[0] = rem 248 | buff+=1 249 | break 250 | else: 251 | rem = rem | 0x80 252 | buff[0] = rem 253 | buff+=1 254 | else: 255 | buff[0] = b'\0' 256 | buff+=1 257 | 258 | PyByteArray_Resize(array, buff - PyByteArray_AS_STRING(array)) 259 | return 0 260 | 261 | cdef inline int raw_encode_uint64(bytearray array, uint64_t n) except -1: 262 | cdef unsigned short int rem 263 | cdef Py_ssize_t size = PyByteArray_GET_SIZE(array) 264 | PyByteArray_Resize(array, size + 20) 265 | cdef char *buff = PyByteArray_AS_STRING(array) + size 266 | 267 | if 0!=n: 268 | while True: 269 | rem = (n & 0x7f) 270 | n = n>>7 271 | if 0==n: 272 | buff[0] = rem 273 | buff+=1 274 | break 275 | else: 276 | rem = 
rem | 0x80 277 | buff[0] = rem 278 | buff+=1 279 | else: 280 | buff[0] = b'\0' 281 | buff+=1 282 | PyByteArray_Resize(array, buff - PyByteArray_AS_STRING(array)) 283 | return 0 284 | 285 | cdef inline int raw_encode_fixed32(bytearray array, uint32_t n) except -1: 286 | cdef unsigned short int rem 287 | cdef Py_ssize_t size = PyByteArray_GET_SIZE(array) 288 | PyByteArray_Resize(array, size + 4) 289 | cdef char *buff = PyByteArray_AS_STRING(array) + size 290 | cdef int i 291 | 292 | for i from 0 <= i < 4: 293 | rem = n & 0xff 294 | n = n >> 8 295 | buff[0] = rem 296 | buff += 1 297 | 298 | return 0 299 | 300 | cdef inline int raw_encode_fixed64(bytearray array, uint64_t n) except -1: 301 | cdef unsigned short int rem 302 | cdef Py_ssize_t size = PyByteArray_GET_SIZE(array) 303 | PyByteArray_Resize(array, size + 8) 304 | cdef char *buff = PyByteArray_AS_STRING(array) + size 305 | cdef int i 306 | 307 | for i from 0 <= i < 8: 308 | rem = n & 0xff 309 | n = n >> 8 310 | buff[0] = rem 311 | buff += 1 312 | 313 | return 0 314 | 315 | cdef inline encode_string(bytearray array, object n): 316 | if isinstance(n, unicode): 317 | n = PyUnicode_AsUTF8String(n) 318 | cdef Py_ssize_t len = PySequence_Length(n) 319 | raw_encode_uint64(array, len) 320 | PySequence_InPlaceConcat(array, n) 321 | 322 | cdef inline encode_float(bytearray array, object value): 323 | cdef float f = value 324 | raw_encode_fixed32(array, (&f)[0]) 325 | 326 | # }}} 327 | -------------------------------------------------------------------------------- /lib/orchpy/setup.py: -------------------------------------------------------------------------------- 1 | #from distutils.core import setup 2 | #from distutils.extension import Extension 3 | from setuptools import setup, Extension, find_packages 4 | from Cython.Build import cythonize 5 | 6 | # because of relative paths, this must be run from inside orchestra/lib/orchpy/ 7 | 8 | setup( 9 | name = "orchestra", 10 | version = "0.1.dev0", 11 | ext_modules = 
cythonize([ 12 | Extension("orchpy/main", 13 | sources = ["orchpy/main.pyx"], libraries=["orchestralib"], 14 | library_dirs=['../../target/debug/']), 15 | Extension("orchpy/unison", sources = ["orchpy/unison.pyx"])], 16 | compiler_directives={'language_level': 3}), 17 | use_2to3=True, 18 | packages=find_packages() 19 | ) 20 | -------------------------------------------------------------------------------- /lib/papaya/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amplab/orchestra/280ab768e1df536367ab65eaee009384c3e522c3/lib/papaya/__init__.py -------------------------------------------------------------------------------- /lib/papaya/dist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import papaya.single as single 3 | import orchpy as op 4 | import unison 5 | from cprotobuf import ProtoEntity, Field 6 | 7 | block_size = 10 8 | 9 | class DistArrayProto(ProtoEntity): 10 | shape = Field('uint64', 1, repeated=True) 11 | dtype = Field('string', 2, required=True) 12 | objrefs = Field(op.ObjRefsProto, 3, required=False) 13 | 14 | class DistArray(object): 15 | def construct(self): 16 | self.dtype = self.proto.dtype 17 | self.shape = self.proto.shape 18 | self.num_blocks = [int(np.ceil(1.0 * a / block_size)) for a in self.proto.shape] 19 | self.blocks = op.ObjRefs() 20 | self.blocks.from_proto(self.proto.objrefs) 21 | 22 | def deserialize(self, data): 23 | self.proto.ParseFromString(data) 24 | self.construct() 25 | 26 | def from_proto(self, data): 27 | self.proto = proto 28 | construct() 29 | 30 | def __init__(self, dtype='float', shape=None): 31 | self.proto = DistArrayProto() 32 | if shape != None: 33 | self.proto.shape = shape 34 | self.proto.dtype = dtype 35 | self.num_blocks = [int(np.ceil(1.0 * a / block_size)) for a in self.proto.shape] 36 | objrefs = op.ObjRefs(self.num_blocks) 37 | self.proto.objrefs = objrefs.proto 38 | 
self.construct() 39 | 40 | def compute_block_lower(self, index): 41 | lower = [] 42 | for i in range(len(index)): 43 | lower.append(index[i] * block_size) 44 | return lower 45 | 46 | def compute_block_upper(self, index): 47 | upper = [] 48 | for i in range(len(index)): 49 | upper.append(min((index[i] + 1) * block_size, self.shape[i])) 50 | return upper 51 | 52 | def compute_block_shape(self, index): 53 | lower = self.compute_block_lower(index) 54 | upper = self.compute_block_upper(index) 55 | return [u - l for (l, u) in zip(lower, upper)] 56 | 57 | def assemble(self): 58 | """Assemble an array on this node from a distributed array object reference.""" 59 | result = np.zeros(self.shape) 60 | for index in np.ndindex(*self.num_blocks): 61 | lower = self.compute_block_lower(index) 62 | upper = self.compute_block_upper(index) 63 | result[[slice(l, u) for (l, u) in zip(lower, upper)]] = op.context.pull(np.ndarray, self.blocks[index]) 64 | return result 65 | 66 | def __getitem__(self, sliced): 67 | # TODO(rkn): fix this, this is just a placeholder that should work but is inefficient 68 | a = self.assemble() 69 | return a[sliced] 70 | 71 | #@op.distributed([DistArray], np.ndarray) 72 | def assemble(a): 73 | return a.assemble() 74 | 75 | #@op.distributed([unison.List[int], unison.List[int], str], DistArray) 76 | def zeros(shape, dtype): 77 | dist_array = DistArray(dtype, shape) 78 | for index in np.ndindex(*dist_array.num_blocks): 79 | dist_array.blocks[index] = single.zeros(dist_array.compute_block_shape(index)) 80 | return dist_array 81 | 82 | #@op.distributed([DistArray], DistArray) 83 | def copy(a): 84 | dist_array = DistArray(a.dtype, a.shape) 85 | for index in np.ndindex(*dist_array.num_blocks): 86 | dist_array.blocks[index] = single.copy(a.blocks[index]) 87 | return dist_array 88 | 89 | def eye(dim, dtype): 90 | # TODO(rkn): this code is pretty ugly, please clean it up 91 | dist_array = zeros([dim, dim], dtype) 92 | num_blocks = dist_array.num_blocks[0] 93 | for i in 
#@op.distributed([unison.List[int], unison.List[int], str], DistArray)
def random_normal(shape):
    """Return a DistArray of the given shape whose blocks are filled with
    draws from a standard normal distribution (via single.random_normal)."""
    dist_array = DistArray("float", shape)
    for index in np.ndindex(*dist_array.num_blocks):
        dist_array.blocks[index] = single.random_normal(dist_array.compute_block_shape(index))
    return dist_array

#@op.distributed([DistArray], DistArray)
def triu(a):
    """Return the upper-triangular part of the 2-D DistArray `a`."""
    if len(a.shape) != 2:
        raise Exception("input must have dimension 2, but len(a.shape) is " + str(len(a.shape)))
    dist_array = DistArray(a.dtype, a.shape)
    for i in range(a.num_blocks[0]):
        for j in range(a.num_blocks[1]):
            if i < j:
                dist_array.blocks[i, j] = single.copy(a.blocks[i, j])
            elif i == j:
                dist_array.blocks[i, j] = single.triu(a.blocks[i, j])
            else:
                # Bug fix: edge blocks can be smaller than block_size on
                # either axis, so size the zero block like the corresponding
                # block of `a` instead of hard-coding [block_size, block_size].
                dist_array.blocks[i, j] = single.zeros(dist_array.compute_block_shape([i, j]))
    return dist_array

#@op.distributed([DistArray], DistArray)
def tril(a):
    """Return the lower-triangular part of the 2-D DistArray `a`."""
    if len(a.shape) != 2:
        raise Exception("input must have dimension 2, but len(a.shape) is " + str(len(a.shape)))
    dist_array = DistArray(a.dtype, a.shape)
    for i in range(a.num_blocks[0]):
        for j in range(a.num_blocks[1]):
            if i > j:
                dist_array.blocks[i, j] = single.copy(a.blocks[i, j])
            elif i == j:
                # Bug fix: the diagonal blocks of a lower-triangular result
                # must themselves be lower-triangular; the original called
                # single.triu here (copy-paste from triu above).
                dist_array.blocks[i, j] = single.tril(a.blocks[i, j])
            else:
                # Same edge-block sizing fix as in triu.
                dist_array.blocks[i, j] = single.zeros(dist_array.compute_block_shape([i, j]))
    return dist_array
#@op.distributed([DistArray, DistArray], DistArray)
def dot(a, b):
    """Distributed matrix multiply: return a DistArray holding a * b.

    Both inputs must be 2-D DistArrays with equal dtypes and compatible
    inner dimensions.
    """
    assert(a.dtype == b.dtype)
    assert(len(a.shape) == len(b.shape) == 2)
    assert(a.shape[1] == b.shape[0])
    result = DistArray(a.dtype, [a.shape[0], b.shape[1]])
    for row in range(result.num_blocks[0]):
        for col in range(result.num_blocks[1]):
            # Each output block is the inner product of block-row `row` of
            # `a` with block-column `col` of `b`, evaluated remotely by
            # blockwise_inner on the interleaved block references.
            block_args = list(a.blocks[row, :]) + list(b.blocks[:, col])
            result.blocks[row, col] = blockwise_inner(*block_args)
    return result
def tsqr_hr(a):
    """Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf

    Householder-reconstruction TSQR: factor the tall-skinny DistArray `a`
    into (y, t, y_top, r) such that a = (I - y * t * y_top.T) * r, computed
    from the tsqr factorization plus a modified LU of the assembled Q.
    """
    q, r_temp = tsqr(a)
    # modified_lu runs locally on the fully assembled Q factor.
    y, u, s = single.modified_lu(assemble(q))
    s_full = np.diag(s)
    b = q.shape[1]
    y_top = y[:b, :b]
    # T factor of the compact WY representation.
    t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
    r = np.dot(s_full, r_temp)
    return y, t, y_top, r

def array_from_blocks(blocks):
    """Build a DistArray from an ndarray of block object references.

    The overall shape is inferred by pulling the *last* block along each
    dimension (index -1) to measure its remainder size; all other blocks
    are assumed to be full block_size blocks.
    """
    dims = len(blocks.shape)
    num_blocks = list(blocks.shape)
    shape = []
    for i in range(len(blocks.shape)):
        index = [0] * dims
        index[i] = -1
        index = tuple(index)
        remainder = op.context.pull(np.ndarray, blocks[index]).shape[i]
        shape.append(block_size * (num_blocks[i] - 1) + remainder)
    # NOTE(review): dtype is hard-coded to "float" here regardless of the
    # blocks' actual dtype — confirm this is intended.
    dist_array = DistArray("float", shape)
    for index in np.ndindex(*blocks.shape):
        dist_array.blocks[index] = blocks[index]
    return dist_array

def qr(a):
    """Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf

    Blocked Householder QR of the DistArray `a`. Returns (Ts, y_res, r_res):
    the per-panel T factors, the Y (Householder vector) blocks, and R.
    Reconstruction of an explicit Q from Ts and Ys is still TODO below.
    """
    m, n = a.shape[0], a.shape[1]
    k = min(m, n)

    # we will store our scratch work in a_work
    a_work = DistArray(a.dtype, a.shape)
    for index in np.ndindex(*a.num_blocks):
        a_work.blocks[index] = a.blocks[index]

    r_res = zeros([k, n], a.dtype)
    y_res = zeros([m, k], a.dtype)
    Ts = []

    for i in range(min(a.num_blocks[0], a.num_blocks[1])): # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0]
        # b is the width of the current panel (last panel may be narrower).
        b = min(block_size, a.shape[1] - block_size * i)
        # NOTE(review): column_dist_array is never used below — candidate
        # for removal.
        column_dist_array = DistArray(a_work.dtype, [m, b])
        # Panel factorization of the trailing column panel.
        y, t, _, R = tsqr_hr(array_from_blocks(a_work.blocks[i:, i:(i + 1)]))

        for j in range(i, a.num_blocks[0]):
            y_res.blocks[j, i] = op.context.push(y[((j - i) * block_size):((j - i + 1) * block_size), :]) # eventually this should go away
        if a.shape[0] > a.shape[1]:
            # in this case, R needs to be square
            r_res.blocks[i, i] = op.context.push(np.vstack([R, np.zeros((R.shape[1] - R.shape[0], R.shape[1]))]))
        else:
            r_res.blocks[i, i] = op.context.push(R)
        Ts.append(t)

        # Apply the panel's reflectors to the trailing submatrix:
        # A := (I - Y T^T Y^T) A, done block-row by block-row.
        for c in range(i + 1, a.num_blocks[1]):
            W_rcs = []
            for r in range(i, a.num_blocks[0]):
                y_ri = y[((r - i) * block_size):((r - i + 1) * block_size), :]
                W_rcs.append(np.dot(y_ri.T, op.context.pull(np.ndarray, a_work.blocks[r, c]))) # eventually the pull should go away
            W_c = np.sum(W_rcs, axis=0)
            for r in range(i, a.num_blocks[0]):
                y_ri = y[((r - i) * block_size):((r - i + 1) * block_size), :]
                A_rc = op.context.pull(np.ndarray, a_work.blocks[r, c]) - np.dot(y_ri, np.dot(t.T, W_c))
                a_work.blocks[r, c] = op.context.push(A_rc)
            r_res.blocks[i, c] = a_work.blocks[i, c]

    q_res = eye(a.shape[0], "float")
    # construct q_res from Ys and Ts
    #TODO(construct q_res from Ys and Ts)
    # for i in range(a.num_blocks[1]):

    #return q_res, r_res

    return Ts, y_res, r_res
# TODO(rkn): this should take the same optional "mode" argument as
# np.linalg.qr, except that the different options sometimes have different
# numbers of return values, which could be a problem
@op.distributed([np.ndarray], unison.Tuple[np.ndarray, np.ndarray])
def qr(a):
    """
    Suppose (n, m) = a.shape
    If n >= m:
      q.shape == (n, m)
      r.shape == (m, m)
    If n < m:
      q.shape == (n, n)
      r.shape == (n, m)
    """
    return np.linalg.qr(a)

# TODO(rkn): stopgap until we support returning tuples of object references
@op.distributed([np.ndarray], np.ndarray)
def qr_return_q(a):
    """Return only the Q factor of np.linalg.qr(a)."""
    q, r = np.linalg.qr(a)
    return q

# TODO(rkn): stopgap until we support returning tuples of object references
@op.distributed([np.ndarray], np.ndarray)
def qr_return_r(a):
    """Return only the R factor of np.linalg.qr(a)."""
    q, r = np.linalg.qr(a)
    return r

# TODO(rkn): My preferred signature would have been
# @op.distributed([unison.List[np.ndarray]], np.ndarray) but that currently
# doesn't work because that would expect a list of ndarrays not a list of
# ObjRefs
@op.distributed([np.ndarray, None], np.ndarray)
def vstack(*xs):
    """Stack the given arrays vertically (row-wise)."""
    return np.vstack(xs)

# would have preferred @op.distributed([unison.List[np.ndarray]], np.ndarray)
@op.distributed([np.ndarray, None], np.ndarray)
def hstack(*xs):
    """Stack the given arrays horizontally (column-wise)."""
    return np.hstack(xs)

# TODO(rkn): this doesn't parallel the numpy API, but we can't really slice an ObjRef, think about this
@op.distributed([np.ndarray, unison.List[int], unison.List[int]], np.ndarray)
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
    """Return the rectangular subarray a[lower:upper] over every dimension."""
    return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]

@op.distributed([np.ndarray], np.ndarray)
def copy(a):
    """Return a copy of `a`."""
    return np.copy(a)

@op.distributed([np.ndarray], np.ndarray)
def tril(a):
    """Return the lower-triangular part of `a` (entries above the diagonal zeroed)."""
    return np.tril(a)

# Decorator for triu; its `def` continues in the next chunk of the file.
@op.distributed([np.ndarray], np.ndarray)
#@op.distributed([np.ndarray], unison.Tuple[np.ndarray, np.ndarray, np.ndarray])
def modified_lu(q):
    """
    Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf

    Factor a matrix with orthonormal columns as q - s = l * u.

    arguments:
      q: a two dimensional orthonormal q
    return values:
      l: lower triangular
      u: upper triangular
      s: a diagonal matrix represented by its diagonal
    """
    m, b = q.shape[0], q.shape[1]
    sign_diag = np.zeros(b)
    work = np.copy(q)

    for i in range(b):
        # Shift the pivot away from zero by the negated sign of the diagonal
        # entry, recording the shift so q - s = l * u holds.
        sign_diag[i] = -1 * np.sign(work[i, i])
        work[i, i] -= sign_diag[i]
        # Scale the i-th column of the (unit-diagonal) L factor by the pivot.
        work[(i + 1):m, i] /= work[i, i]
        # Schur complement update of the trailing submatrix.
        work[(i + 1):m, (i + 1):b] -= np.outer(work[(i + 1):m, i], work[i, (i + 1):b])

    lower = np.tril(work)
    lower[np.arange(b), np.arange(b)] = 1
    upper = np.triu(work)[:b, :]
    return lower, upper, sign_diag
op.register_distributed(papaya.single) 20 | op.context.main_loop() 21 | 22 | from cprotobuf import ProtoEntity, Field 23 | 24 | class Test(ProtoEntity): 25 | objrefs = Field('bytes', 1, required=False) 26 | -------------------------------------------------------------------------------- /schema/comm.proto: -------------------------------------------------------------------------------- 1 | // This file is defining the communication module of Orchestra 2 | 3 | // The master builds the computation graph and does the scheduling, but has no 4 | // contact with data. The messages it sends and receives are very small, which 5 | // allows us to scale up to many clients. Data is typically referenced using so 6 | // called objrefs. These are remote references to objects, which might be 7 | // implemented using a global id or hashing. Each objref is associated with an 8 | // immutable blob of data which for fault tolerance and efficiency reasons is 9 | // distributed among a number of clients. 10 | 11 | message Args { 12 | repeated int64 objrefs = 1; // negative objrefs are indices into the data array, nonnegative ones are objrefs 13 | repeated bytes data = 2; 14 | } 15 | 16 | message Call { 17 | optional string name = 1; 18 | optional Args args = 2; 19 | repeated uint64 result = 3; 20 | enum Type { 21 | INVOKE_CALL = 1; // normal function call 22 | MAP_CALL = 2; // perform a map 23 | REDUCE_CALL = 3; // perform a reduce 24 | } 25 | optional Type type = 4; 26 | } 27 | 28 | message Blob { 29 | optional uint64 objref = 1; 30 | optional bytes data = 2; 31 | } 32 | 33 | message PullInfo { 34 | optional uint64 workerid = 1; 35 | optional uint64 objref = 2; 36 | } 37 | 38 | message FnInfo { 39 | optional string fnname = 1; 40 | repeated uint64 workerid = 2; 41 | } 42 | 43 | message ObjInfo { 44 | optional uint64 objref = 1; 45 | repeated uint64 workerid = 2; 46 | } 47 | 48 | message SchedulerInfo { 49 | repeated uint64 worker_queue = 1; 50 | repeated Call job_queue = 2; 51 | 
repeated PullInfo pull_queue = 3; 52 | repeated ObjInfo objtable = 4; 53 | repeated FnInfo fntable = 5; 54 | } 55 | 56 | enum MessageType { 57 | ACK = 1; // acknowledge a message 58 | INVOKE = 2; // invoke a distributed function call (uses call) 59 | REGISTER_CLIENT = 3; // register a client (uses address) 60 | REGISTER_FUNCTION = 4; // register a function (uses workerid and fnname) 61 | PUSH = 5; // client delivers an object to another machine (uses blob) 62 | PULL = 6; // client tells server to initiate sending data from nearest client (uses objref) 63 | HELLO = 7; // for registering the subscription channel 64 | DELIVER = 8; // server wants the client to deliver an object to a client (uses objref and address) 65 | DONE = 9; // client signals to the server that the current function call is completed 66 | DEBUG = 10; // sending and receiving debug info 67 | ACC = 11; // accept the delivery of an object 68 | } 69 | 70 | message Message { 71 | optional MessageType type = 1; // type of the message 72 | optional Call call = 2; // function call to be executed 73 | optional string address = 3; // address of the client we are registering or delivering data to 74 | optional Blob blob = 4; 75 | optional uint64 objref = 5; 76 | optional uint64 workerid = 6; // uniquely identify worker 77 | optional string fnname = 7; 78 | optional SchedulerInfo scheduler_info = 8; 79 | optional uint64 setup_port = 9; // the setup port for the client 80 | } 81 | -------------------------------------------------------------------------------- /schema/make-schema.sh: -------------------------------------------------------------------------------- 1 | # for rust 2 | protoc --rust_out ../src/ comm.proto 3 | protoc --rust_out ../src/ types.proto 4 | 5 | # for cprotobuf 6 | protoc --cprotobuf_out ../lib/orchpy/orchpy comm.proto types.proto 7 | -------------------------------------------------------------------------------- /schema/types.proto: 
// Data types for Orchestra

// Scalar wrapper messages so primitive values can travel through the same
// object machinery as composite data.

message Int {
  optional int32 val = 1;
}

message Float {
  optional float val = 1;
}

message String {
  optional string val = 1;
}

// A tensor library for Orchestra

// Element type of an Array's packed `data` buffer.
enum DataType {
  INT32 = 1;
  INT64 = 2;
  FLOAT32 = 3;
  FLOAT64 = 4;
}

// A dense n-dimensional tensor; `data` holds the raw element bytes.
message Array {
  repeated uint64 shape = 1;
  optional DataType dtype = 2;
  optional bytes data = 3;
}

// A tensor partitioned into blocks distributed across nodes.
message DistArray {
  repeated uint64 shape = 1;
  repeated uint64 block_shape = 2;
  optional DataType dtype = 3;
  optional Array objrefs = 4; // tensor of objrefs, each representing a block subtensor living on a node
}
It can be started with 23 | 24 | ``` 25 | ./kubectl.sh run master --image=pcmoritz/orchestra:pre 26 | ``` 27 | 28 | # Creating and uploading docker images 29 | 30 | To create the docker image, run `docker build .` in the directory that contains 31 | the `Dockerfile`. 32 | 33 | You can then start a container that runs the image by executing 34 | ``` 35 | docker run -t -i 36 | ``` 37 | 38 | Run `docker ps` to get the id of the container and commit the image using 39 | ``` 40 | docker commit -m "" -a "" pcmoritz/orchestra:pre 41 | ``` 42 | 43 | The image can then be uploaded by running 44 | ``` 45 | docker login 46 | docker push pcmoritz/orchestra:pre 47 | ``` 48 | -------------------------------------------------------------------------------- /scripts/masterimage/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu 2 | MAINTAINER Philipp Moritz email: pcm@berkeley.edu 3 | RUN apt-get update 4 | RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends build-essential ca-certificates curl git wget nano unzip libzmq3-dev python2.7-dev python-pip 5 | RUN pip install numpy cython 6 | RUN pip install protobuf 7 | RUN pip install psutil 8 | ENV RUST_VERSION=1.8.0 9 | RUN cd /tmp && \ 10 | curl -sO https://static.rust-lang.org/dist/rust-nightly-x86_64-unknown-linux-gnu.tar.gz && \ 11 | tar -xvzf rust-nightly-x86_64-unknown-linux-gnu.tar.gz && \ 12 | ./rust-nightly-x86_64-unknown-linux-gnu/install.sh --without=rust-docs && \ 13 | rm -rf \ 14 | rust-$RUST_VERSION-x86_64-unknown-linux-gnu \ 15 | rust-$RUST_VERSION-x86_64-unknown-linux-gnu.tar.gz \ 16 | /var/lib/apt/lists/* \ 17 | /tmp/* \ 18 | /var/tmp/* 19 | RUN cd ~ && \ 20 | git clone https://github.com/pcmoritz/cprotobuf.git && \ 21 | cd cprotobuf && \ 22 | python setup.py install 23 | RUN cd ~ && \ 24 | wget https://github.com/amplab/orchestra/releases/download/v0.1alpha/orchestra.zip && \ 25 | unzip orchestra.zip -d orchestra && \ 26 | cd 
orchestra && \ 27 | cargo build 28 | RUN cd ~ && \ 29 | cd orchestra/lib/orchpy/ && \ 30 | python setup.py build 31 | ENV PYTHONPATH=/root/orchestra/lib:/root/orchestra/lib/orchpy/build/lib.linux-x86_64-2.7:${PYTHONPATH} 32 | ENV LD_LIBRARY_PATH=/root/orchestra/target/debug/:${LD_LIBRARY_PATH} 33 | 34 | WORKDIR /root/orchestra 35 | CMD ["cargo", "run", "7114", "7227", "7228"] 36 | EXPOSE 7114 7227 7228 37 | -------------------------------------------------------------------------------- /scripts/orchestra-master-controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "apiVersion": "v1", 3 | "kind": "ReplicationController", 4 | "metadata": { 5 | "name": "orchestra-master", 6 | "labels": { 7 | "app": "orchestra", 8 | "role": "master" 9 | } 10 | }, 11 | "spec": { 12 | "replicas": 1, 13 | "selector":{ 14 | "app": "orchestra", 15 | "role": "master" 16 | }, 17 | "template": { 18 | "metadata": { 19 | "labels": { 20 | "app": "orchestra", 21 | "role": "master" 22 | } 23 | }, 24 | "spec": { 25 | "containers": [{ 26 | "name": "orchestra-master", 27 | "image": "pcmoritz/orchestra:pre", 28 | "ports": [{"containerPort": 7114}, {"containerPort": 7227}] 29 | }] 30 | } 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /scripts/orchestra-master-service.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"Service", 3 | "apiVersion":"v1", 4 | "metadata":{ 5 | "name":"orchestra-service", 6 | "labels":{ 7 | "app":"orchestra", 8 | "role":"master" 9 | } 10 | }, 11 | "spec":{ 12 | "clusterIP": "10.0.171.131", 13 | "ports": [ 14 | { 15 | "name": "incoming", 16 | "port": 7114 17 | }, 18 | { 19 | "name": "publish", 20 | "port":7227 21 | }, 22 | { 23 | "name": "setup", 24 | "port":7228 25 | } 26 | ], 27 | "selector":{ 28 | "app":"orchestra", 29 | "role":"master" 30 | } 31 | } 32 | } 33 | 
-------------------------------------------------------------------------------- /scripts/orchestra-slave-controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"ReplicationController", 3 | "apiVersion":"v1", 4 | "metadata":{ 5 | "name":"orchestra-slave", 6 | "labels":{ 7 | "app":"orchestra", 8 | "role":"slave" 9 | } 10 | }, 11 | "spec":{ 12 | "replicas":4, 13 | "selector":{ 14 | "app":"orchestra", 15 | "role":"slave" 16 | }, 17 | "template":{ 18 | "metadata":{ 19 | "labels":{ 20 | "app":"orchestra", 21 | "role":"slave" 22 | } 23 | }, 24 | "spec":{ 25 | "containers":[ 26 | { 27 | "name":"orchestra-slave", 28 | "image":"pcmoritz/orchestra-slave:pre", 29 | "ports":[ 30 | { 31 | "name":"orch-slave", 32 | "containerPort":4000 33 | }, 34 | { 35 | "name":"orch-shell", 36 | "containerPort":2222 37 | } 38 | ], 39 | "env": [{ 40 | "name": "SLAVE_IP", 41 | "valueFrom": { 42 | "fieldRef": { 43 | "fieldPath": "status.podIP" 44 | } 45 | } 46 | }] 47 | } 48 | ] 49 | } 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /scripts/orchestra-slave-service.json: -------------------------------------------------------------------------------- 1 | { 2 | "kind":"Service", 3 | "apiVersion":"v1", 4 | "metadata":{ 5 | "name":"orchestra-slaves", 6 | "labels":{ 7 | "app":"orchestra", 8 | "role":"slave" 9 | } 10 | }, 11 | "spec":{ 12 | "clusterIP": "None", 13 | "ports": [ 14 | { 15 | "name": "reply", 16 | "port": 4000, 17 | "targetPort": "orch-slave" 18 | } 19 | ], 20 | "selector":{ 21 | "app":"orchestra", 22 | "role":"slave" 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /scripts/slaveimage/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu 2 | MAINTAINER Philipp Moritz email: pcm@berkeley.edu 3 | RUN apt-get update 4 | RUN DEBIAN_FRONTEND=noninteractive apt-get 
import orchpy as op
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('server_port', type=int, help='the port to post requests to')
parser.add_argument('client_port', type=int, help='the port to listen at')
parser.add_argument('subscriber_port', type=int, help='the port used to set up the connections')

# Consistency fix: every other @op.distributed call site in this repo (see
# lib/papaya/single.py) passes the argument types as a *list*; the original
# `@op.distributed(str, int)` here did not.
@op.distributed([str], int)
def setup(filename):
    """Placeholder remote call registered by the worker; always returns 0."""
    return 0

if __name__ == "__main__":
    args = parser.parse_args()
    # SERVER_IP is injected by the Kubernetes pod spec; default to localhost
    # for local runs.
    server_ip = os.getenv('SERVER_IP', "127.0.0.1")
    op.context.connect(server_ip, args.server_port, args.subscriber_port, "127.0.0.1", args.client_port)
    # NOTE(review): test.py calls op.register_current() with no arguments —
    # confirm which signature orchpy expects.
    op.register_current(globals().items())
    op.context.main_loop()
np.testing.assert_allclose(np.dot(q_val, r), a_val) # check that a = q * r 38 | np.testing.assert_allclose(np.dot(q_val.T, q_val), np.eye(min(d1, d2)), atol=1e-6) # check that q.T * q = I 39 | np.testing.assert_allclose(np.triu(r), r) # check that r is upper triangular 40 | 41 | def test_modified_lu(d1, d2): 42 | print "testing modified_lu with d1 = " + str(d1) + ", d2 = " + str(d2) 43 | assert d1 >= d2 44 | k = min(d1, d2) 45 | m = np.random.normal(size=(d1, d2)) 46 | q, r = np.linalg.qr(m) 47 | l, u, s = single.modified_lu(q) 48 | s_mat = np.zeros((d1, d2)) 49 | for i in range(len(s)): 50 | s_mat[i, i] = s[i] 51 | np.testing.assert_allclose(q - s_mat, np.dot(l, u)) # check that q - s = l * u 52 | np.testing.assert_allclose(np.triu(u), u) # check that u is upper triangular 53 | np.testing.assert_allclose(np.tril(l), l) # check that u is lower triangular 54 | 55 | def test_tsqr_hr(d1, d2): 56 | print "testing tsqr_hr with d1 = " + str(d1) + ", d2 = " + str(d2) 57 | a = dist.random_normal([d1, d2]) 58 | a_val = a.assemble() 59 | y, t, y_top, r = dist.tsqr_hr(a) 60 | tall_eye = np.zeros((d1, min(d1, d2))) 61 | np.fill_diagonal(tall_eye, 1) 62 | q = tall_eye - np.dot(y, np.dot(t, y_top.T)) 63 | np.testing.assert_allclose(np.dot(q.T, q), np.eye(min(d1, d2)), atol=1e-6) # check that q.T * q = I 64 | np.testing.assert_allclose(np.dot(q, r), a_val) # check that a = (I - y * t * y_top.T) * r 65 | 66 | def test_qr(d1, d2): 67 | print "testing qr with d1 = " + str(d1) + ", d2 = " + str(d2) 68 | a = dist.random_normal([d1, d2]) 69 | a_val = a.assemble() 70 | Ts, y_res, r_res = dist.qr(a) 71 | r_val = r_res.assemble() 72 | y_val = y_res.assemble() 73 | q = np.eye(d1) 74 | for i in range(min(a.num_blocks[0], a.num_blocks[1])): 75 | q = np.dot(q, np.eye(d1) - np.dot(y_val[:, (i * 10):((i + 1) * 10)], np.dot(Ts[i], y_val[:, (i * 10):((i + 1) * 10)].T))) 76 | q = q[:, :min(d1, d2)] 77 | np.testing.assert_allclose(np.triu(r_val), r_val) # check that r is upper triangular 78 | 
np.testing.assert_allclose(np.dot(q.T, q), np.eye(min(d1, d2)), atol=1e-6) # check that q.T * q = I 79 | np.testing.assert_allclose(np.dot(q, r_val), a_val) # check that a = q * r 80 | 81 | 82 | for i in range(10): 83 | d1 = np.random.randint(1, 100) 84 | d2 = np.random.randint(1, 100) 85 | d3 = np.random.randint(1, 100) 86 | test_dot(d1, d2, d3) 87 | 88 | for i in range(10): 89 | d1 = np.random.randint(1, 50) 90 | d2 = np.random.randint(1, 11) 91 | test_tsqr(d1, d2) 92 | 93 | for i in range(10): 94 | d2 = np.random.randint(1, 100) 95 | d1 = np.random.randint(d2, 100) 96 | test_modified_lu(d1, d2) 97 | 98 | for i in range(10): 99 | d1 = np.random.randint(1, 100) 100 | d2 = np.random.randint(1, 11) 101 | test_tsqr_hr(d1, d2) 102 | 103 | for i in range(10): 104 | d1 = np.random.randint(1, 100) 105 | d2 = np.random.randint(1, 100) 106 | test_qr(d1, d2) 107 | 108 | import IPython 109 | IPython.embed() 110 | -------------------------------------------------------------------------------- /src/client.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use zmq; 3 | use zmq::{Socket}; 4 | 5 | use comm; 6 | use utils::{ObjRef, WorkerID, receive_message, send_message, receive_subscription, send_ack, receive_ack, connect_socket, to_zmq_socket_addr}; 7 | use std::thread; 8 | use std::sync::Arc; 9 | use std::sync::Mutex; 10 | use std::sync::MutexGuard; 11 | use std::str::FromStr; 12 | use std::net::IpAddr; 13 | 14 | use protobuf::{Message, RepeatedField}; 15 | 16 | use std::sync::mpsc; 17 | use std::sync::mpsc::{Sender, Receiver}; 18 | 19 | pub type FnRef = usize; // Index of locally registered function 20 | pub type ObjStore = HashMap>; // collection of objects stored on the client 21 | 22 | pub enum Event { 23 | Obj(ObjRef), // a new object becomes available 24 | Invoke(comm::Call), // a new job request 25 | Debug(comm::Message) // for debugging purposes 26 | } 27 | 28 | #[derive(Clone, PartialEq)] 29 | 
/// Execution state of a worker's main loop.
/// (Generic parameters in this block are reconstructed; the dump stripped
/// text between angle brackets.)
pub enum State {
    /// A call has been received and is waiting for its arguments to arrive.
    Processing {
        call: comm::Call,
        deps: Vec<ObjRef> // sorted
    },
    /// No call is currently being evaluated.
    Waiting
}

/// Per-worker client context: object store, registered functions, and the
/// sockets/channels used to talk to the server and to other clients.
pub struct Context {
    zmq_ctx: zmq::Context,

    objects: Arc<Mutex<ObjStore>>, // mapping from objrefs to data
    functions: HashMap<String, FnRef>, // mapping from function name to interpreter-local function reference
    types: HashMap<String, i32>, // mapping from type name to type id

    state: State, // Some(call) if function call has just been evaluated and None otherwise
    function: FnRef, // function that is currently active
    pub args: Vec<u8>, // serialized version of the Args datastructure

    notify_main: Receiver<Event>, // reply thread signals main thread
    request: Socket, // REQ socket to the server (strict send/recv lockstep)
    workerid: WorkerID
}

impl Context {
    /// Spawn the REP thread that serves PUSH (incoming object data) and
    /// INVOKE (new call) messages and forwards them to the main thread.
    pub fn start_reply_thread(zmq_ctx: &mut zmq::Context, client_addr: &str, notify_main: Sender<Event>, objects: Arc<Mutex<ObjStore>>) {
        let mut reply = zmq_ctx.socket(zmq::REP).unwrap();
        reply.bind(client_addr).unwrap();

        thread::spawn(move || {
            loop {
                let msg = receive_message(&mut reply);
                match msg.get_field_type() {
                    comm::MessageType::PUSH => {
                        let blob = msg.get_blob();
                        let objref = blob.get_objref();
                        {
                            // scope the lock: drop it before acking/notifying
                            objects.lock().unwrap().insert(objref, blob.get_data().to_vec());
                        }
                        send_ack(&mut reply);
                        notify_main.send(Event::Obj(objref)).unwrap();
                    },
                    comm::MessageType::INVOKE => {
                        notify_main.send(Event::Invoke(msg.get_call().clone())).unwrap();
                        send_ack(&mut reply);
                    },
                    _ => {
                        error!("error, got {:?}", msg.get_field_type());
                        error!("{:?}", msg.get_address());
                    }
                }
            }
        });
    }
    /// Connect to the server, register this client, and spawn the reply and
    /// subscription threads. Returns the fully wired-up Context.
    pub fn new(server_addr: &IpAddr, reply_port: u16, publish_port: u16, client_addr: &IpAddr, client_port: u16) -> Context {
        let mut zmq_ctx = zmq::Context::new();

        let mut request = zmq_ctx.socket(zmq::REQ).unwrap();
        request.connect(&to_zmq_socket_addr(server_addr, reply_port)[..]).unwrap();

        let (reply_sender, reply_receiver) = mpsc::channel(); // TODO: rename this

        info!("connecting to server...");
        let mut reg = comm::Message::new();
        reg.set_field_type(comm::MessageType::REGISTER_CLIENT);
        reg.set_address(to_zmq_socket_addr(client_addr, client_port));

        let objects = Arc::new(Mutex::new(HashMap::new()));

        // bind the reply socket on all interfaces
        let localhost = IpAddr::from_str("0.0.0.0").unwrap();
        Context::start_reply_thread(&mut zmq_ctx, &to_zmq_socket_addr(&localhost, client_port)[..], reply_sender.clone(), objects.clone());

        // NOTE(review): ad-hoc startup delay; thread::sleep_ms is deprecated
        // in later Rust versions (use std::time::Duration + thread::sleep).
        thread::sleep_ms(10);

        send_message(&mut request, &mut reg);
        let ack = receive_message(&mut request);
        let workerid = ack.get_workerid() as WorkerID;
        info!("my workerid is {}", workerid);
        let setup_port = ack.get_setup_port() as u16;
        info!("setup port is {}", setup_port);

        // the network thread listens to commands on the master subscription channel and serves the
        // other client channels with data. It notifies the main thread if new data becomes available.

        // do handshake with server and get client id, so we can communicate on the subscription channel

        // let (main_sender, main_receiver) = mpsc::channel();
        // let (network_sender, network_receiver) = mpsc::channel();
        let mut clients: HashMap<String, Socket> = HashMap::new(); // other clients that are part of the cluster

        let thread_objects = objects.clone();
        let server_addr = server_addr.clone();

        thread::spawn(move || {
            let mut zmq_ctx = zmq::Context::new();
            let mut subscriber = Context::connect_network_thread(&mut zmq_ctx, workerid, &server_addr, setup_port, publish_port);

            loop {
                let msg = receive_subscription(&mut subscriber);

                match msg.get_field_type() {
                    comm::MessageType::REGISTER_CLIENT => {
                        // push onto workers
                        info!("connecting to client {}", msg.get_address());
                        let mut other = zmq_ctx.socket(zmq::REQ).unwrap();
                        other.connect(msg.get_address()).ok().unwrap();
                        clients.insert(msg.get_address().into(), other);
                        // new code: create new thread, insert it here
                    }
                    comm::MessageType::DELIVER => {
                        // server asks us to deliver one of our objects to a peer
                        let objref = msg.get_objref();
                        let mut answer = comm::Message::new();
                        answer.set_field_type(comm::MessageType::PUSH);
                        let mut blob = comm::Blob::new();
                        blob.set_objref(objref);
                        let data = {
                            let objs : MutexGuard<HashMap<ObjRef, Vec<u8>>> = thread_objects.lock().unwrap();
                            objs.get(&objref).and_then(|data| Some(data.to_vec())).expect("data not available on this client")
                        };
                        blob.set_data(data);
                        answer.set_blob(blob);
                        let target = &mut clients.get_mut(msg.get_address()).expect("target client not found"); // shouldn't happen
                        send_message(target, &mut answer);
                        receive_ack(target);
                    },
                    comm::MessageType::DEBUG => {
                        reply_sender.send(Event::Debug(msg)).unwrap();
                    },
                    _ => {}
                }
            }
        });

        return Context {
            zmq_ctx: zmq_ctx,
            objects: objects.clone(), functions: HashMap::new(), types: HashMap::new(),
            state: State::Waiting, function: 0, args: Vec::new(),
            notify_main: reply_receiver,
            request: request,
            workerid: workerid
        }
    }

    /// Create the SUB socket, subscribe to this worker's zero-padded id
    /// prefix, and complete the join handshake with the server's setup port.
    fn connect_network_thread(zmq_ctx: &mut zmq::Context, workerid: WorkerID, server_addr: &IpAddr, setup_port: u16, subscriber_port: u16) -> Socket {
        let mut subscriber = zmq_ctx.socket(zmq::SUB).unwrap();
        info!("subscriber_port {}", subscriber_port);
        connect_socket(&mut subscriber, server_addr, subscriber_port);
        subscriber.set_subscribe(format!("{:0>#07}", workerid).as_bytes()).unwrap();

        let mut setup = zmq_ctx.socket(zmq::REQ).unwrap();
        connect_socket(&mut setup, server_addr, setup_port);
        info!("setup_port {}", setup_port);
        thread::sleep_ms(10);
        // set up sub/pub socket
        let mut msg = zmq::Message::new().unwrap();
        subscriber.recv(&mut msg, 0).unwrap();

        setup.send(b"joining", 0).unwrap();

        info!("accepted server invitation");

        return subscriber
    }

    /// Store `data` under `objref` in the local object store.
    pub fn add_object<'b>(self: &'b mut Context, objref: ObjRef, data: Vec<u8>) {
        self.objects.lock().unwrap().insert(objref, data);
    }

    /// Register a function locally and announce it to the server; returns the
    /// interpreter-local function index.
    pub fn add_function<'b>(self: &'b mut Context, name: String) -> usize {
        info!("registering function {}", name);
        let idx = self.functions.len();
        self.functions.insert(name.to_string(), idx);

        let mut msg = comm::Message::new();
        msg.set_field_type(comm::MessageType::REGISTER_FUNCTION);
        msg.set_fnname(name.to_string());
        msg.set_workerid(self.workerid as u64);
        send_message(&mut self.request, &mut msg);
        receive_ack(&mut self.request);

        return idx;
    }
    /// Index of the function currently being evaluated.
    pub fn get_function<'b>(self: &'b Context) -> FnRef {
        return self.function;
    }
    // TODO: Make this more efficient, i.e. use only one lookup
    /// Length in bytes of a locally stored object, if present.
    pub fn get_obj_len<'b>(self: &'b Context, objref: ObjRef) -> Option<usize> {
        self.objects.lock().unwrap().get(&objref).and_then(|data| Some(data[..].len()))
    }
    /// Raw pointer to a locally stored object's bytes, if present.
    /// NOTE(review): the pointer outlives the mutex guard; it stays valid
    /// only while the entry is not removed or reallocated -- confirm callers.
    pub fn get_obj_ptr<'b>(self: &'b Context, objref: ObjRef) -> Option<*const u8> {
        self.objects.lock().unwrap().get(&objref).and_then(|data| Some(data[..].as_ptr()))
    }
    /// Register a type name, assigning it the next sequential id.
    pub fn add_type<'b>(self: &'b mut Context, name: String) {
        let index = self.types.len();
        self.types.insert(name, index as i32);
    }
    /// Look up a previously registered type id.
    pub fn get_type<'b>(self: &'b mut Context, name: String) -> Option<i32> {
        return self.types.get(&name).and_then(|&num| Some(num));
    }
    /// Ask the server to invoke `name` with `args`; returns the objref that
    /// will hold the (single) result.
    pub fn remote_call_function<'b>(self: &'b mut Context, name: String, args: comm::Args) -> ObjRef {
        let mut msg = comm::Message::new();
        msg.set_field_type(comm::MessageType::INVOKE);
        let mut call = comm::Call::new();
        call.set_name(name);
        call.set_args(args);
        msg.set_call(call);
        send_message(&mut self.request, &mut msg);
        let answer = receive_message(&mut self.request);
        let result = answer.get_call().get_result();
        assert!(result.len() == 1);
        return result[0];
    }
    // TODO: Remove duplication between remote_call_function and remote_call_map
    /// Like `remote_call_function`, but a MAP_CALL: one result objref per element.
    pub fn remote_call_map<'b>(self: &'b mut Context, name: String, args: comm::Args) -> Vec<ObjRef> {
        let mut msg = comm::Message::new();
        msg.set_field_type(comm::MessageType::INVOKE);
        let mut call = comm::Call::new();
        call.set_field_type(comm::Call_Type::MAP_CALL);
        call.set_name(name);
        call.set_args(args);
        msg.set_call(call);
        send_message(&mut self.request, &mut msg);
        let answer = receive_message(&mut self.request);
        return answer.get_call().get_result().to_vec(); // TODO: get rid of this copy
    }
    /// Ensure `objref` is available locally, requesting it from the server
    /// and blocking on the reply-thread channel until it has been pushed.
    pub fn pull_remote_object<'b>(self: &'b mut Context, objref: ObjRef) -> ObjRef {
        {
            // already here? (scoped so the lock is released before the request)
            let objects = self.objects.lock().unwrap();
            if objects.contains_key(&objref) {
                return objref;
            }
        }
        let mut msg = comm::Message::new();
        msg.set_field_type(comm::MessageType::PULL);
        msg.set_objref(objref);
        msg.set_workerid(self.workerid as u64);
        send_message(&mut self.request, &mut msg);
        receive_ack(&mut self.request);
        loop {
            // println!("looping");
            match self.notify_main.recv().unwrap() {
                Event::Obj(pushedref) => {
                    if pushedref == objref {
                        return objref;
                    }
                }
                _ => {}
            }
        }
    }
    /// Ask the server for a fresh objref to push a new object under.
    pub fn push_remote_object<'b>(self: &'b mut Context) -> ObjRef {
        let mut msg = comm::Message::new();
        msg.set_field_type(comm::MessageType::PUSH);
        msg.set_workerid(self.workerid as u64);
        send_message(&mut self.request, &mut msg);
        let answer = receive_message(&mut self.request);
        let result = answer.get_call().get_result();
        assert!(result.len() == 1);
        return result[0];
    }
    /// Request the scheduler's debug snapshot and block until it arrives on
    /// the subscription channel.
    pub fn pull_debug_info<'b>(self: &'b mut Context) -> comm::Message {
        let mut msg = comm::Message::new();
        msg.set_field_type(comm::MessageType::DEBUG);
        msg.set_workerid(self.workerid as u64);
        send_message(&mut self.request, &mut msg);
        receive_ack(&mut self.request);
        loop {
            match self.notify_main.recv().unwrap() {
                Event::Debug(msg) => {
                    return msg;
                },
                _ => {}
            }
        }
    }
    /// If a call was being processed, report DONE to the server and return
    /// to the Waiting state; no-op otherwise.
    pub fn finish_request<'b>(self: &'b mut Context) {
        match self.state.clone() { // TODO: remove the clone
            State::Processing{call, deps: _} => {
                let mut done = comm::Message::new();
                done.set_field_type(comm::MessageType::DONE);
                done.set_call(call.clone());
                done.set_workerid(self.workerid as u64);
                send_message(&mut self.request, &mut done);
                receive_ack(&mut self.request);
                self.state = State::Waiting;
            }
            State::Waiting => {}
        }
    }

    /// Main worker step: consume events until the pending call has all of its
    /// dependencies, then stage its args and return the result objref.
    pub fn client_step<'b>(self: &'b mut Context) -> ObjRef {
        loop {
            match self.notify_main.recv().unwrap() {
                Event::Obj(objref) => {
                    // TODO: Make this more efficient:
                    // START
                    let mut acc = comm::Message::new();
                    acc.set_field_type(comm::MessageType::ACC);
                    acc.set_workerid(self.workerid as u64);
                    acc.set_objref(objref);
                    send_message(&mut self.request, &mut acc);
                    let answer = receive_message(&mut self.request);
                    // END
                    // if all elements for the current call are satisfied, evaluate it
                    match self.state {
                        State::Waiting => {},
                        State::Processing {call: _, ref mut deps} => {
                            match deps.binary_search(&objref) {
                                Ok(idx) => {
                                    deps.remove(idx); // TODO: use more efficient data structure
                                }
                                _ => {}
                            }
                        }
                    }
                },
                Event::Invoke(call) => {
                    info!("starting to evaluate {:?}", call.get_name());
                    assert!(self.state == State::Waiting);
                    // collect the argument objrefs that are not yet local
                    let mut args = vec!();
                    {
                        let objects = self.objects.lock().unwrap();
                        for elem in call.get_args().get_objrefs() {
                            if *elem >= 0 {
                                let objref = *elem as u64;
                                if !objects.contains_key(&objref) {
                                    args.push(objref);
                                }
                            }
                        }
                    }
                    args.sort();
                    args.dedup();
                    info!("need args {:?}", args);
                    self.state = State::Processing{call: call.clone(), deps: args};
                    // if all elements for the current call are satisfied, evaluate it
                },
                _ => {}
            }

            match self.state {
                State::Waiting => {},
                State::Processing {ref call, ref deps} => {
                    if deps.len() == 0 {
                        // serializing args datastructure
                        self.args.clear();
                        call.get_args().write_to_writer(&mut self.args).unwrap();
                        // calling the function
                        let name = call.get_name().to_string();
                        self.function = self.functions.get(&name).expect("function not found").clone();
                        let result = call.get_result();
379 | assert!(result.len() == 1); 380 | return result[0] 381 | } 382 | } 383 | } 384 | } 385 | } 386 | } 387 | -------------------------------------------------------------------------------- /src/graph.rs: -------------------------------------------------------------------------------- 1 | use petgraph::{Graph, Directed}; 2 | use petgraph::graph::NodeIndex; 3 | use utils::ObjRef; 4 | 5 | pub type Host = u64; 6 | 7 | // A node in the computation graph, can be a data node (Obj), a function call node (Op), a Map node 8 | // or a Reduce node. An opid is a pointer into the ops vector of the computation graph. 9 | #[derive(Hash, PartialOrd, Ord, PartialEq, Eq, Clone, Copy, Debug)] 10 | enum Node<'a> { 11 | Map { 12 | opid: usize 13 | }, 14 | Reduce { 15 | opid: usize 16 | }, 17 | Op { 18 | opid: usize 19 | }, 20 | Obj { 21 | objref: ObjRef, 22 | hosts: &'a [Host] 23 | } 24 | } 25 | 26 | pub struct CompGraph<'a> { 27 | objs: Vec, // mapping from objrefs to nodes in the graph 28 | ops: Vec, // names of operations 29 | graph: Graph, f32, Directed> // computation graph 30 | } 31 | 32 | impl<'a> CompGraph<'a> { 33 | pub fn new() -> CompGraph<'a> { 34 | return CompGraph { 35 | graph: Graph::new(), 36 | objs: Vec::new(), 37 | ops: Vec::new() 38 | }; 39 | } 40 | pub fn add_obj(self: &mut CompGraph<'a>) -> (ObjRef, NodeIndex) { 41 | let objref = self.objs.len() as ObjRef; 42 | let obj = self.graph.add_node(Node::Obj{objref: objref, hosts: &[]}); 43 | self.objs.push(obj); 44 | return (objref, obj); 45 | } 46 | pub fn add_op<'b>(self: &mut CompGraph<'a>, name: String, args: &'b [ObjRef], result: ObjRef) { 47 | self.ops.push(name); // TODO: only store unique names 48 | let func = self.graph.add_node(Node::Op {opid: self.ops.len() - 1}); 49 | let res = self.graph.add_node(Node::Obj {objref: result, hosts: &[]}); 50 | for arg in args { 51 | self.graph.add_edge(self.objs[*arg as usize], func, 0.0); 52 | } 53 | self.graph.add_edge(func, res, 0.0); 54 | } 55 | pub fn add_map<'b>(self: 
&mut CompGraph<'a>, name: String, args: &'b [ObjRef], results: &'b [ObjRef]) { 56 | assert!(args.len() == results.len()); 57 | self.ops.push(name); // TODO: only store unique names 58 | let map = self.graph.add_node(Node::Map {opid: self.ops.len() - 1}); 59 | for i in 0..args.len() { 60 | let res = self.graph.add_node(Node::Obj {objref: results[i], hosts: &[]}); 61 | self.graph.add_edge(self.objs[args[i] as usize], map, 0.0); 62 | self.graph.add_edge(map, self.objs[results[i] as usize], 0.0); 63 | } 64 | } 65 | pub fn add_reduce<'b>(self: &mut CompGraph<'a>, name: String, args: &'b [ObjRef], result: ObjRef) { 66 | self.ops.push(name); // TODO: only store unique names 67 | let reduce = self.graph.add_node(Node::Reduce {opid: self.ops.len() - 1}); 68 | let res = self.graph.add_node(Node::Obj {objref: result, hosts: &[]}); 69 | for arg in args { 70 | self.graph.add_edge(self.objs[*arg as usize], reduce, 0.0); 71 | } 72 | self.graph.add_edge(reduce, res, 0.0); 73 | } 74 | } 75 | 76 | pub struct DotBuilder { 77 | buf: String, 78 | } 79 | 80 | impl DotBuilder { 81 | pub fn new_digraph(name: &str) -> Self { 82 | DotBuilder{buf: format!("digraph \"{}\" {}", name, "{\n")} 83 | } 84 | 85 | pub fn set_node_attrs(&mut self, node: &str, attrs: &str) { 86 | self.buf.push_str(&format!("\"{}\" [{}];\n", node, attrs)); 87 | } 88 | 89 | pub fn add_edge(&mut self, from: &str, to: &str) { 90 | self.buf.push_str(&format!("\"{}\" -> \"{}\";\n", from, to)); 91 | } 92 | 93 | pub fn finish(&mut self) { 94 | self.buf.push_str("}\n"); 95 | } 96 | } 97 | 98 | pub fn to_dot<'a>(graph: &CompGraph<'a>) -> String { 99 | let mut builder = DotBuilder::new_digraph(""); 100 | for i in 0..graph.graph.node_count() { 101 | let idx = NodeIndex::new(i); 102 | let id = i.to_string(); 103 | let weight = graph.graph.node_weight(idx).unwrap(); 104 | let label = match *weight { 105 | Node::Op { opid } => format!("label=\"{}\"", &graph.ops[opid]), 106 | Node::Obj { objref, hosts } => format!("label=\"{}\"", 
objref), 107 | Node::Map { opid } => format!("label=\"{}\"", &graph.ops[opid]), 108 | Node::Reduce { opid } => format!("label=\"reduce {}\"", &graph.ops[opid]) 109 | }; 110 | builder.set_node_attrs(&id, &label); 111 | } 112 | for edge in graph.graph.raw_edges() { 113 | let src = edge.source().index().to_string(); 114 | let target = edge.target().index().to_string(); 115 | builder.add_edge(&src[..], &target[..]); 116 | } 117 | builder.finish(); 118 | return builder.buf; 119 | } 120 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // This file is supposed to shield all the unsafe functionality from the rest of the program 2 | #![crate_type = "dylib"] 3 | 4 | #![feature(ip_addr)] 5 | #![feature(convert)] 6 | #![feature(box_syntax)] 7 | #[macro_use] 8 | extern crate log; 9 | extern crate env_logger; 10 | extern crate protobuf; 11 | extern crate libc; 12 | extern crate rand; 13 | 14 | extern crate zmq; 15 | 16 | pub mod comm; 17 | pub mod client; 18 | pub mod utils; 19 | 20 | use libc::{size_t, c_char, uint8_t}; 21 | use std::slice; 22 | use client::{Context}; 23 | use std::ffi::CStr; 24 | use std::mem::transmute; 25 | use std::str; 26 | use std::str::FromStr; 27 | use std::net::IpAddr; 28 | use protobuf::{CodedInputStream, Message}; 29 | 30 | fn string_from_c(string: *const c_char) -> String { 31 | let c_str: &CStr = unsafe { CStr::from_ptr(string) }; 32 | let str_slice: &str = str::from_utf8(c_str.to_bytes()).unwrap(); 33 | return str_slice.to_owned(); 34 | } 35 | 36 | #[repr(C)] 37 | pub struct Slice { 38 | len: usize, 39 | data: *const uint8_t 40 | } 41 | 42 | #[no_mangle] 43 | pub extern "C" fn orchestra_create_context(server_addr: *const c_char, reply_port: u16, publish_port: u16, client_addr: *const c_char, client_port: u16) -> *mut Context { 44 | let server_string = string_from_c(server_addr); 45 | let server_addr = 
IpAddr::from_str(&server_string).unwrap(); // TODO: Proper error handling 46 | let client_string = string_from_c(client_addr); 47 | let client_addr = IpAddr::from_str(&client_string).unwrap(); // TODO: Proper error handling 48 | 49 | match env_logger::init() { 50 | Ok(()) => {}, 51 | SetLoggerError => {} // logging framework already initialized 52 | } 53 | 54 | let res = unsafe { transmute(box Context::new(&server_addr, reply_port, publish_port, &client_addr, client_port)) }; 55 | return res; 56 | } 57 | 58 | #[no_mangle] 59 | pub extern "C" fn orchestra_destroy_context(context: *mut Context) { 60 | let _drop_me: Box = unsafe { transmute(context) }; 61 | } 62 | 63 | /* 64 | 65 | #[no_mangle] 66 | pub extern "C" fn orchestra_register_type(context: *mut Context, name: *const c_char) { 67 | let name = string_from_c(name); 68 | unsafe { (*context).client().add_type(name) }; 69 | } 70 | 71 | #[no_mangle] 72 | pub extern "C" fn orchestra_get_type(context: *mut Context, name: *const c_char) -> c_int { 73 | let name = string_from_c(name); 74 | unsafe { 75 | match (*context).client().get_type(name) { 76 | Some(id) => return id as c_int, 77 | None => return -1 78 | } 79 | } 80 | } 81 | 82 | */ 83 | 84 | #[no_mangle] 85 | pub extern "C" fn orchestra_register_function(context: *mut Context, name: *const c_char) -> usize { 86 | let name = string_from_c(name); 87 | unsafe { return (*context).add_function(name) }; 88 | } 89 | 90 | #[no_mangle] 91 | pub extern "C" fn orchestra_store_result(context: *mut Context, objref: size_t, data: *const uint8_t, datalen: size_t) { 92 | let data = unsafe { slice::from_raw_parts(data, datalen as usize) }; 93 | unsafe { (*context).add_object(objref, data.to_vec()) }; 94 | } 95 | 96 | pub fn args_from_c(args: *const uint8_t, argslen: size_t) -> comm::Args { 97 | let bytes = unsafe { slice::from_raw_parts::(args, argslen as usize) }; 98 | let mut result = comm::Args::new(); 99 | let mut is = CodedInputStream::from_bytes(bytes); 100 | 
    result.merge_from(&mut is).unwrap();
    return result;
}

/*
pub fn result_to_c(context: *mut Context, result: comm::Args) -> Slice {
    unsafe {
        result.write_to_writer(&mut (*context).result).unwrap();
        return Slice { len: (*context).result[..].len() as u64, data: (*context).result[..].as_ptr() }
    }
}
*/

/// C entry point: synchronous remote call; returns the result objref.
#[no_mangle]
pub extern "C" fn orchestra_call(context: *mut Context, name: *const c_char, args: *const uint8_t, argslen: size_t) -> size_t {
    let name = string_from_c(name);
    let arguments = args_from_c(args, argslen);
    unsafe {
        return (*context).remote_call_function(name, arguments);
    }
}

/// retlist needs to be preallocated on caller side
#[no_mangle]
pub extern "C" fn orchestra_map(context: *mut Context, name: *const c_char, args: *const uint8_t, argslen: size_t, retlist: *mut size_t) {
    let name = string_from_c(name);
    let arguments = args_from_c(args, argslen);
    unsafe {
        let result = (*context).remote_call_map(name, arguments);
        // copy the result objrefs into the caller-provided array
        for (i, elem) in result.iter().enumerate() {
            *retlist.offset(i as isize) = *elem;
        }
    };
}

/// C entry point: block until `objref` is available locally.
#[no_mangle]
pub extern "C" fn orchestra_pull(context: *mut Context, objref: size_t) -> size_t {
    unsafe { return (*context).pull_remote_object(objref); }
}

/// C entry point: reserve a fresh objref for a push.
#[no_mangle]
pub extern "C" fn orchestra_push(context: *mut Context) -> size_t {
    unsafe { return (*context).push_remote_object(); }
}

/// C entry point: fetch the scheduler's debug snapshot and print it to stdout.
#[no_mangle]
pub extern "C" fn orchestra_debug_info(context: *mut Context) {
    unsafe {
        let msg = (*context).pull_debug_info();
        println!("worker queue: {:?}", msg.get_scheduler_info().get_worker_queue());
        println!("job queue:");
        for call in msg.get_scheduler_info().get_job_queue() {
            println!("call: {:?}, {:?} -> {:?}", call.get_name(), call.get_args(), call.get_result());
        }
        println!("object table:");
        for info in msg.get_scheduler_info().get_objtable() {
            println!("entry: {:?}: {:?}", info.get_objref(), info.get_workerid());
        }
        println!("function table");
        for info in msg.get_scheduler_info().get_fntable() {
            println!("entry: {:?}: {:?}", info.get_fnname(), info.get_workerid());
        }
    }
}

/// C entry point: finish the previous request (if any) and block until the
/// next call is ready; returns its result objref.
#[no_mangle]
pub extern "C" fn orchestra_step(context: *mut Context) -> size_t {
    unsafe {
        (*context).finish_request();
        return (*context).client_step();
    }
}

/// C entry point: index of the function staged by the last orchestra_step.
#[no_mangle]
pub extern "C" fn orchestra_function_index(context: *mut Context) -> usize {
    unsafe { (*context).get_function() }
}

/// C entry point: view of the serialized Args of the current call.
/// NOTE(review): the returned pointer aliases Context-owned memory; it is
/// invalidated by the next step -- confirm callers copy it promptly.
#[no_mangle]
pub extern "C" fn orchestra_get_args(context: *mut Context) -> Slice {
    unsafe { return Slice { len: (*context).args[..].len(), data: (*context).args[..].as_ptr() } }
}

/// C entry point: byte length of a locally stored object (panics if absent).
#[no_mangle]
pub extern "C" fn orchestra_get_obj_len(context: *mut Context, objref: u64) -> usize {
    unsafe { (*context).get_obj_len(objref).expect("object reference not found") }
}

/// C entry point: raw pointer to a locally stored object (panics if absent).
#[no_mangle]
pub extern "C" fn orchestra_get_obj_ptr(context: *mut Context, objref: u64) -> *const uint8_t {
    unsafe { (*context).get_obj_ptr(objref).expect("object reference not found") }
}
// ----------------------------------------------------------------- src/main.rs
#![feature(ip_addr)]
#![feature(convert)]
#![feature(custom_derive)]
#![feature(deque_extras)]

#[macro_use]
extern crate log;
extern crate argparse;
extern crate env_logger;
extern crate rand;
extern crate petgraph;
extern crate protobuf;
extern crate zmq;

pub mod comm;
mod types;
mod graph;
pub mod server;
pub mod scheduler;
pub mod utils;
| 22 | use argparse::{ArgumentParser, Store}; 23 | 24 | fn main() { 25 | let mut incoming_port = 0; 26 | let mut publish_port = 0; 27 | let mut setup_port = 0; 28 | { 29 | let mut ap = ArgumentParser::new(); 30 | ap.set_description("Orchestra server"); 31 | ap.refer(&mut incoming_port).add_argument("incoming_port", Store, "port for incoming requests"); 32 | ap.refer(&mut publish_port).add_argument("publish_port", Store, "port for message broadcasting"); 33 | ap.refer(&mut setup_port).add_argument("setup_port", Store, "port for setting up broadcasting"); 34 | ap.parse_args_or_exit(); 35 | } 36 | env_logger::init().unwrap(); 37 | let mut server = server::Server::new(publish_port); 38 | server.main_loop(incoming_port, setup_port); 39 | } 40 | -------------------------------------------------------------------------------- /src/scheduler.rs: -------------------------------------------------------------------------------- 1 | use std::iter::FromIterator; 2 | use std::collections::VecDeque; 3 | use std::thread; 4 | use std::sync::mpsc; 5 | use std::sync::mpsc::{Sender, Receiver}; 6 | use std::sync::{Arc, RwLock, Mutex, MutexGuard}; 7 | use comm; 8 | use utils::{WorkerID, ObjRef, ObjTable, FnTable}; 9 | use server::Worker; 10 | use protobuf::RepeatedField; 11 | 12 | /// Notify the scheduler that something happened 13 | pub enum Event { 14 | /// A worker becomes available for computation. 15 | Worker(WorkerID), 16 | /// An object becomes available. 17 | Obj(ObjRef), 18 | /// A job is being scheduled. 19 | Job(comm::Call), 20 | /// A pull request was issued. 21 | Pull(WorkerID, ObjRef), 22 | /// A new worker has been added. 23 | Register(WorkerID, Sender), 24 | /// Dump status of the scheduler. 25 | Debug(WorkerID) 26 | } 27 | 28 | /// A scheduler assigns incoming jobs to workers. It communicates with the worker pool through 29 | /// channels. 
If a job is scheduled or a worker becomes available, this is signaled to the 30 | /// Scheduler using the channel returned by the `Scheduler::start` method. The scheduler signals the 31 | /// execution of a function call to the appropriate worker thread via a channel that is registered 32 | /// using `Event::Register`. 33 | pub struct Scheduler { 34 | objtable: Arc>, 35 | fntable: Arc>, 36 | } 37 | 38 | impl Scheduler { 39 | /// Start the scheduling thread. 40 | pub fn start(objtable: Arc>, fntable: Arc>) -> Sender { 41 | let (event_sender, event_receiver) = mpsc::channel(); // notify the scheduler that a worker, job or object becomes available 42 | let scheduler = Scheduler { objtable: objtable.clone(), fntable: fntable.clone() }; 43 | scheduler.start_dispatch_thread(event_receiver); 44 | return event_sender 45 | } 46 | 47 | fn send_function_call(workers: &Vec>, workerid: WorkerID, job: comm::Call) { 48 | info!("scheduling function call {} on worker {}", job.get_name(), workerid); 49 | let mut msg = comm::Message::new(); 50 | msg.set_field_type(comm::MessageType::INVOKE); 51 | msg.set_call(job); 52 | workers[workerid].send(msg).unwrap(); 53 | } 54 | 55 | fn send_pull_request(workers: &Vec>, workerid: WorkerID, objref: ObjRef) { 56 | let mut msg = comm::Message::new(); 57 | msg.set_field_type(comm::MessageType::PULL); 58 | msg.set_workerid(workerid as u64); 59 | msg.set_objref(objref); 60 | workers[workerid].send(msg).unwrap(); 61 | } 62 | 63 | fn send_debugging_info(self: &Scheduler, socket: &Sender, worker_queue: &VecDeque, job_queue: &VecDeque) { 64 | let mut scheduler_info = comm::SchedulerInfo::new(); 65 | scheduler_info.set_worker_queue(worker_queue.iter().map(|x| *x as u64).collect()); 66 | let mut jobs = Vec::new(); 67 | for job in job_queue.iter() { 68 | jobs.push(job.clone()); 69 | } 70 | scheduler_info.set_job_queue(RepeatedField::from_vec(jobs)); 71 | let objtable = self.objtable.lock().unwrap(); 72 | let mut objs = Vec::new(); 73 | for (objref, 
workers) in objtable.iter().enumerate() { 74 | let mut info = comm::ObjInfo::new(); 75 | info.set_objref(objref as u64); 76 | let workers : &Vec = workers; 77 | info.set_workerid((*workers).iter().map(|x| *x as u64).collect()); 78 | objs.push(info); 79 | } 80 | scheduler_info.set_objtable(RepeatedField::from_vec(objs)); 81 | 82 | let fntable = self.fntable.read().unwrap(); 83 | let mut fns = Vec::new(); 84 | for (fnname, workers) in fntable.iter() { 85 | let mut info = comm::FnInfo::new(); 86 | info.set_fnname(fnname.to_string()); 87 | info.set_workerid((*workers).iter().map(|x| *x as u64).collect()); 88 | fns.push(info); 89 | } 90 | scheduler_info.set_fntable(RepeatedField::from_vec(fns)); 91 | 92 | let mut msg = comm::Message::new(); 93 | msg.set_field_type(comm::MessageType::DEBUG); 94 | msg.set_scheduler_info(scheduler_info); 95 | 96 | socket.send(msg).unwrap(); 97 | } 98 | 99 | /// Find job whose dependencies are met. 100 | fn find_next_job(self: &Scheduler, workerid: WorkerID, job_queue: &VecDeque) -> Option { 101 | let objtable = &self.objtable.lock().unwrap(); 102 | for (i, job) in job_queue.iter().enumerate() { 103 | if !self.fntable.read().unwrap().contains_key(job.get_name()) { 104 | panic!("next job bailing"); 105 | return None; 106 | } 107 | if self.fntable.read().unwrap()[job.get_name()].binary_search(&workerid).is_ok() && self.can_run(job, objtable) { 108 | return Some(i); 109 | } 110 | } 111 | return None; 112 | } 113 | 114 | fn can_run(self: &Scheduler, job: &comm::Call, objtable: &MutexGuard) -> bool { 115 | for elem in job.get_args().get_objrefs() { 116 | if *elem >= 0 { 117 | if objtable[*elem as usize].len() == 0 { 118 | return false; 119 | } 120 | } 121 | } 122 | return true; 123 | } 124 | 125 | // TODO: replace fntable vector with bitfield 126 | fn find_next_worker(self: &Scheduler, job: &comm::Call, worker_queue: &VecDeque) -> Option { 127 | let objtable = &self.objtable.lock().unwrap(); 128 | for (i, workerid) in 
worker_queue.iter().enumerate() { 129 | if !self.fntable.read().unwrap().contains_key(job.get_name()) { 130 | panic!("next worker bailing"); 131 | return None; 132 | } 133 | if self.fntable.read().unwrap()[job.get_name()].binary_search(workerid).is_ok() && self.can_run(job, objtable) { 134 | return Some(i); 135 | } 136 | } 137 | return None; 138 | } 139 | 140 | // will be notified of workers or jobs that become available throught the worker_notify or job_notify channel 141 | fn start_dispatch_thread(self: Scheduler, event_notify: Receiver) { 142 | thread::spawn(move || { 143 | let mut workers = Vec::>::new(); 144 | let mut worker_queue = VecDeque::::new(); 145 | let mut job_queue = VecDeque::::new(); 146 | let mut pull_queue = VecDeque::<(WorkerID, ObjRef)>::new(); 147 | 148 | loop { 149 | // use the most simple algorithms for now 150 | match event_notify.recv().unwrap() { 151 | Event::Worker(workerid) => { 152 | match self.find_next_job(workerid, &job_queue) { 153 | Some(jobidx) => { 154 | let job = job_queue.swap_remove_front(jobidx).unwrap(); 155 | Scheduler::send_function_call(&mut workers, workerid, job); 156 | } 157 | None => { 158 | worker_queue.push_back(workerid); 159 | } 160 | } 161 | }, 162 | Event::Job(job) => { 163 | match self.find_next_worker(&job, &worker_queue) { 164 | Some(workeridx) => { 165 | let workerid = worker_queue.swap_remove_front(workeridx).unwrap(); 166 | Scheduler::send_function_call(&mut workers, workerid, job); 167 | } 168 | None => { 169 | job_queue.push_back(job); 170 | } 171 | } 172 | }, 173 | Event::Obj(newobjref) => { 174 | // TODO: do this with a binary search 175 | for &(workerid, objref) in pull_queue.iter() { 176 | if objref == newobjref { 177 | Scheduler::send_pull_request(&mut workers, workerid, objref); 178 | } 179 | } 180 | // see if we can evaluate one of the pending jobs now 181 | let mut workeridx = 0; 182 | while workeridx < worker_queue.len() { 183 | let workerid = *worker_queue.get(workeridx).unwrap(); 184 | match 
self.find_next_job(workerid, &job_queue) { 185 | Some(jobidx) => { 186 | let job = job_queue.swap_remove_front(jobidx).unwrap(); 187 | worker_queue.swap_remove_front(workeridx).unwrap(); 188 | Scheduler::send_function_call(&mut workers, workerid, job); 189 | } 190 | None => { 191 | workeridx += 1; 192 | } 193 | } 194 | } 195 | }, 196 | Event::Pull(workerid, objref) => { 197 | if self.objtable.lock().unwrap()[objref as usize].len() > 0 { 198 | Scheduler::send_pull_request(&mut workers, workerid, objref); 199 | } else { 200 | pull_queue.push_back((workerid, objref)); 201 | } 202 | }, 203 | Event::Register(workerid, incoming) => { 204 | while workers.len() < workerid + 1 { 205 | workers.push(incoming.clone()); 206 | } 207 | workers[workerid] = incoming; 208 | }, 209 | Event::Debug(workerid) => { 210 | self.send_debugging_info(&workers[workerid], &worker_queue, &job_queue); 211 | } 212 | } 213 | } 214 | }); 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /src/server.rs: -------------------------------------------------------------------------------- 1 | use comm; 2 | use graph; 3 | use scheduler; 4 | use scheduler::{Scheduler, Event}; 5 | use utils::{send_message, receive_message, receive_ack, send_ack, bind_socket, push_objrefs}; 6 | use utils::{WorkerID, ObjRef, ObjTable, FnTable}; 7 | use graph::CompGraph; 8 | use rand; 9 | use rand::distributions::{IndependentSample, Range}; 10 | use std::io::{Read, Write}; 11 | use std::collections::VecDeque; 12 | use zmq; 13 | use zmq::Socket; 14 | use std::process; 15 | use std::sync::mpsc::{Sender, Receiver}; 16 | use std::sync::mpsc; 17 | use std::thread; 18 | use std::sync::{Arc, RwLock, Mutex, MutexGuard, RwLockReadGuard}; 19 | use std::str::FromStr; 20 | use std::net::IpAddr; 21 | use std::collections::HashMap; 22 | use protobuf::{Message, RepeatedField}; 23 | use std::iter::Iterator; 24 | 25 | /// Contains informations about worker. 
26 | pub struct Worker { 27 | addr: String 28 | } 29 | 30 | /// A group of workers that are managed and scheduled together. They are connected with the server 31 | /// using a zero mq `PUB` channel used for one-way communication from server to client. 32 | /// Furthermore, each client is connected to each other client using a REP/REQ socket pair; all 33 | /// data is transferred using these client side connections. It is the `WorkerPool`s task to 34 | /// establish the connections. 35 | pub struct WorkerPool { 36 | /// Workers that have been registered with this pool. 37 | workers: Arc>>, 38 | /// Notify the scheduler that a worker, job or object becomes available. 39 | scheduler_notify: Sender, 40 | /// Send delivery requests to clients. 41 | publish_notify: Sender<(WorkerID, comm::Message)>, 42 | } 43 | 44 | impl WorkerPool { 45 | /// Create a new `WorkerPool`. 46 | pub fn new(objtable: Arc>, fntable: Arc>, publish_port: u16) -> WorkerPool { 47 | let (publish_sender, publish_receiver) = mpsc::channel(); 48 | let scheduler_notify = Scheduler::start(objtable, fntable); 49 | WorkerPool::start_publisher_thread(publish_receiver, publish_port); 50 | return WorkerPool { workers: Arc::new(RwLock::new(Vec::new())), publish_notify: publish_sender, scheduler_notify: scheduler_notify } 51 | } 52 | 53 | /// Start the thread that is used to feed the PUB/SUB network between the server and the workers. 
54 | pub fn start_publisher_thread(publish_notify: Receiver<(WorkerID, comm::Message)>, publish_port: u16) { 55 | thread::spawn(move || { 56 | let mut zmq_ctx = zmq::Context::new(); 57 | let mut publisher = zmq_ctx.socket(zmq::PUB).unwrap(); 58 | let localhost = IpAddr::from_str("0.0.0.0").unwrap(); 59 | bind_socket(&mut publisher, &localhost, Some(publish_port)); 60 | loop { 61 | match publish_notify.recv().unwrap() { 62 | (workerid, msg) => { 63 | let mut buf = Vec::new(); 64 | write!(buf, "{:0>#07}", workerid).unwrap(); 65 | msg.write_to_writer(&mut buf).unwrap(); 66 | publisher.send(buf.as_slice(), 0).unwrap(); 67 | } 68 | } 69 | } 70 | }); 71 | } 72 | 73 | /// Add new job to the queue. 74 | pub fn queue_job(self: &mut WorkerPool, job: comm::Call) { 75 | self.scheduler_notify.send(scheduler::Event::Job(job)).unwrap(); 76 | } 77 | 78 | /// Return the number of workers in the pool. 79 | pub fn len(self: &WorkerPool) -> usize { 80 | return self.workers.read().unwrap().len(); 81 | } 82 | 83 | /// Connect a new worker to the workers already present in the pool. 
84 | fn connect(self: &mut WorkerPool, zmq_ctx: &mut zmq::Context, addr: &str, workerid: WorkerID, setup_socket: &mut Socket) -> Socket { 85 | info!("connecting worker {}", workerid); 86 | let mut socket = zmq_ctx.socket(zmq::REQ).unwrap(); 87 | socket.connect(addr).unwrap(); 88 | let mut buf = zmq::Message::new().unwrap(); 89 | loop { 90 | let mut hello = comm::Message::new(); 91 | hello.set_field_type(comm::MessageType::HELLO); 92 | self.publish_notify.send((workerid, hello)).unwrap(); 93 | thread::sleep_ms(10); // don't float the message queue 94 | match setup_socket.recv(&mut buf, zmq::DONTWAIT) { 95 | Ok(_) => break, 96 | Err(_) => continue 97 | } 98 | } 99 | // connect new client with other clients that are already connected 100 | // and connect already connected clients with the new client 101 | for i in 0..self.len() { 102 | let mut message = comm::Message::new(); 103 | message.set_field_type(comm::MessageType::REGISTER_CLIENT); 104 | let other_party = &self.workers.read().unwrap()[i].addr; 105 | message.set_address(other_party.clone()); // fix this 106 | self.publish_notify.send((workerid, message)).unwrap(); 107 | 108 | let mut request = comm::Message::new(); 109 | request.set_field_type(comm::MessageType::REGISTER_CLIENT); 110 | request.set_address(addr.into()); 111 | self.publish_notify.send((i, request)).unwrap(); 112 | } 113 | return socket; 114 | } 115 | 116 | /// Tell a client `pullid` to deliver an object to another client with address `addr`. 117 | pub fn send_deliver_request(pullid: WorkerID, addr: &str, objref: ObjRef, publish_notify: &Sender<(WorkerID, comm::Message)>) { 118 | let mut deliver = comm::Message::new(); 119 | deliver.set_field_type(comm::MessageType::DELIVER); 120 | deliver.set_objref(objref); 121 | deliver.set_address(addr.into()); 122 | publish_notify.send((pullid, deliver)).unwrap(); 123 | } 124 | 125 | /// Deliver the object with id `objref` to the worker with id `workerid`. 
126 | pub fn deliver_object(workerid: WorkerID, objref: ObjRef, workers: &Arc>>, objtable: &Arc>, publish_notify: &Sender<(WorkerID, comm::Message)>) { 127 | if !objtable.lock().unwrap()[objref as usize].contains(&workerid) { 128 | // pick random worker 129 | let mut rng = rand::thread_rng(); // supposed to have no performance penalty 130 | let range = Range::new(0, objtable.lock().unwrap()[objref as usize].len()); 131 | let idx = range.ind_sample(&mut rng); 132 | let pullid = objtable.lock().unwrap()[objref as usize][idx]; 133 | info!("delivering object {} from {} to {}, addr {}", objref, pullid, workerid, &workers.read().unwrap()[workerid].addr); 134 | WorkerPool::send_deliver_request(pullid, &workers.read().unwrap()[workerid].addr, objref, &publish_notify); 135 | } 136 | } 137 | 138 | /// Register a new worker with the worker pool. 139 | pub fn register(self: &mut WorkerPool, zmq_ctx: &mut zmq::Context, addr: &str, objtable: Arc>, setup_socket: &mut Socket) -> WorkerID { 140 | info!("registering new worker"); 141 | let (incoming, receiver) = mpsc::channel(); 142 | let workerid = self.len(); 143 | let sender = self.scheduler_notify.clone(); 144 | let publish_notify = self.publish_notify.clone(); 145 | let mut socket = self.connect(zmq_ctx, addr, workerid, setup_socket); 146 | let workers = self.workers.clone(); 147 | let objtable = objtable.clone(); 148 | thread::spawn(move || { 149 | sender.send(scheduler::Event::Worker(workerid)).unwrap(); // pull for new work 150 | loop { 151 | let request : comm::Message = receiver.recv().unwrap(); // get the item of work the scheduler chose for us 152 | match request.get_field_type() { 153 | comm::MessageType::INVOKE => { 154 | // orchestrate packages being sent to worker node, start the work there 155 | let results = request.get_call().get_result(); 156 | assert!(results.len() == 1); 157 | send_function_call(&mut socket, request.get_call().get_name(), request.get_call().get_args(), results[0]); 158 | receive_ack(&mut 
socket); // TODO: Avoid this round trip 159 | // deduplicate: (TODO: get rid of inefficiency): 160 | let mut args = Vec::new(); 161 | push_objrefs(request.get_call().get_args(), &mut args); 162 | args.sort(); 163 | args.dedup(); 164 | info!("sending args {:?}", args); 165 | for objref in args.iter() { 166 | WorkerPool::deliver_object(workerid, *objref, &workers, &objtable, &publish_notify) 167 | } 168 | }, 169 | comm::MessageType::PULL => { 170 | let objref = request.get_objref(); 171 | WorkerPool::deliver_object(workerid, objref, &workers, &objtable, &publish_notify); 172 | }, 173 | comm::MessageType::DEBUG => { 174 | println!("pull through to {}", workerid); 175 | publish_notify.send((workerid, request)).unwrap(); // pull request through 176 | }, 177 | _ => {} 178 | } 179 | } 180 | }); 181 | self.workers.write().unwrap().push(Worker {addr: addr.into()}); 182 | self.scheduler_notify.send(scheduler::Event::Register(workerid, incoming)); 183 | return workerid; 184 | } 185 | } 186 | 187 | /// The server orchestrates the computation. 188 | pub struct Server<'a> { 189 | /// For each object reference, the `objtable` stores the list of workers that hold this object. 190 | objtable: Arc>, 191 | /// The `fntable` is the mapping from function names to workers that can execute the function (sorted). 192 | fntable: Arc>, 193 | /// Computation graph for this server. 194 | graph: graph::CompGraph<'a>, 195 | /// A pool of workers that are managed by this server. 196 | workerpool: WorkerPool, 197 | /// The ZeroMQ context for this server. 198 | zmq_ctx: zmq::Context 199 | } 200 | 201 | impl<'a> Server<'a> { 202 | /// Create a new server. 
203 | pub fn new(publish_port: u16) -> Server<'a> { 204 | let mut ctx = zmq::Context::new(); 205 | 206 | let objtable = Arc::new(Mutex::new(Vec::new())); 207 | let fntable = Arc::new(RwLock::new(HashMap::new())); 208 | 209 | Server { 210 | workerpool: WorkerPool::new(objtable.clone(), fntable.clone(), publish_port), 211 | objtable: objtable, 212 | fntable: fntable, 213 | graph: CompGraph::new(), 214 | zmq_ctx: ctx 215 | } 216 | } 217 | 218 | /// Start the server's main loop. 219 | pub fn main_loop<'b>(self: &'b mut Server<'a>, incoming_port: u16, setup_port: u16) { 220 | let mut socket = self.zmq_ctx.socket(zmq::REP).ok().unwrap(); 221 | let localhost = IpAddr::from_str("0.0.0.0").unwrap(); 222 | bind_socket(&mut socket, &localhost, Some(incoming_port)); 223 | loop { 224 | self.process_request(&mut socket, setup_port); 225 | } 226 | } 227 | 228 | /// Add new object to the computation graph and the object pool. 229 | pub fn register_new_object<'b>(self: &'b mut Server<'a>) -> ObjRef { 230 | let (objref, _) = self.graph.add_obj(); 231 | assert!(objref as usize == self.objtable.lock().unwrap().len()); 232 | self.objtable.lock().unwrap().push(vec!()); 233 | return objref; 234 | } 235 | 236 | /// Tell the server that a worker holds a certain object. 237 | pub fn register_result<'b>(self: &'b mut Server<'a>, objref: ObjRef, workerid: WorkerID) { 238 | // TODO: Keep vector sorted while inserting 239 | self.objtable.lock().unwrap()[objref as usize].push(workerid); 240 | } 241 | 242 | /// Add a new call to the computation graph. 243 | pub fn add_call<'b>(self: &'b mut Server<'a>, fnname: String, args: &'b [ObjRef]) -> ObjRef { 244 | let result = self.register_new_object(); 245 | self.graph.add_op(fnname, args, result); 246 | return result; 247 | } 248 | 249 | /// Add a map call to the computation graph. 
250 | pub fn add_map<'b>(self: &'b mut Server<'a>, fnname: String, args: &'b comm::Args) -> Vec { 251 | // TODO: Do this with only one lock 252 | let mut result = Vec::new(); 253 | for arg in args.get_objrefs() { 254 | let objref = self.register_new_object(); 255 | result.push(objref); 256 | } 257 | // TODO: add the op here 258 | return result; 259 | } 260 | 261 | /// Add a reduce call to the computation graph. 262 | pub fn add_reduce<'b>(self: &'b mut Server<'a>, fname: String, args: &'b [ObjRef]) -> ObjRef { 263 | let objref = self.register_new_object(); 264 | // TODO: add the op here 265 | return objref; 266 | } 267 | 268 | /// Add a worker's request for evaluation to the computation graph and notify the scheduler. 269 | pub fn add_request<'b>(self: &'b mut Server<'a>, call: &'b comm::Call) -> comm::Message { 270 | let mut call = call.clone(); 271 | let mut args = Vec::new(); 272 | push_objrefs(call.get_args(), &mut args); 273 | if call.get_field_type() == comm::Call_Type::INVOKE_CALL { 274 | let objref = self.add_call(call.get_name().into(), &args[..]); 275 | call.set_result(vec!(objref)); 276 | self.workerpool.queue_job(call.clone()); // can we get rid of this clone? 277 | } 278 | if call.get_field_type() == comm::Call_Type::MAP_CALL { 279 | let objrefs = self.add_map(call.get_name().into(), call.get_args()); 280 | // Add to the scheduler 281 | for (arg, res) in call.get_args().get_objrefs().iter().zip(objrefs.iter()) { 282 | let mut c = comm::Call::new(); 283 | let mut a = comm::Args::new(); 284 | a.set_objrefs(vec!((*arg).clone())); // TODO: copy needed? 
285 | c.set_args(a); 286 | c.set_result(vec!(*res)); 287 | c.set_name(call.get_name().into()); 288 | // INVOKE_CALL is already the default 289 | self.workerpool.queue_job(c); 290 | } 291 | call.set_result(objrefs); 292 | } 293 | if call.get_field_type() == comm::Call_Type::REDUCE_CALL { 294 | 295 | } 296 | // add obj refs here 297 | let mut message = comm::Message::new(); 298 | message.set_field_type(comm::MessageType::DONE); 299 | message.set_call(call); 300 | return message; 301 | } 302 | 303 | /// Dump the computation graph to a .dot file. 304 | pub fn dump<'b>(self: &'b mut Server<'a>, out: &'b mut Write) { 305 | let res = graph::to_dot(&self.graph); 306 | out.write(res.as_bytes()).unwrap(); 307 | } 308 | 309 | /// Establish the setup port that will be used for setting up the client server connection 310 | fn bind_setup_socket(zmq_ctx: &mut zmq::Context) -> (Socket, u16) { 311 | let mut setup_socket = zmq_ctx.socket(zmq::REP).ok().unwrap(); 312 | let localhost = IpAddr::from_str("0.0.0.0").unwrap(); 313 | let port = bind_socket(&mut setup_socket, &localhost, None); 314 | return (setup_socket, port) 315 | } 316 | 317 | /// Process request by client. 
318 | pub fn process_request<'b>(self: &'b mut Server<'a>, socket: &'b mut Socket, setup_port: u16) { 319 | let msg = receive_message(socket); 320 | match msg.get_field_type() { 321 | comm::MessageType::INVOKE => { 322 | let mut message = self.add_request(msg.get_call()); 323 | // info!("add request {:?} {:?}, result {:?}", msg.get_call().get_field_type(), msg.get_call().get_name(), message.get_call().get_result()); 324 | send_message(socket, &mut message); 325 | }, 326 | comm::MessageType::PUSH => { 327 | let workerid = msg.get_workerid() as WorkerID; 328 | let objref = self.register_new_object(); 329 | self.register_result(objref, workerid); 330 | self.workerpool.scheduler_notify.send(scheduler::Event::Obj(objref)).unwrap(); 331 | 332 | let mut call = comm::Call::new(); 333 | call.set_result(vec!(objref)); 334 | let mut message = comm::Message::new(); 335 | message.set_field_type(comm::MessageType::DONE); // this is never used 336 | message.set_call(call); // this is not really a call, just used to store the objref 337 | send_message(socket, &mut message); 338 | }, 339 | comm::MessageType::REGISTER_CLIENT => { 340 | let workerid = self.workerpool.len(); 341 | let (mut setup_socket, setup_port) = Server::bind_setup_socket(&mut self.zmq_ctx); 342 | info!("chose port {}", setup_port); 343 | let mut ack = comm::Message::new(); 344 | ack.set_field_type(comm::MessageType::ACK); 345 | ack.set_workerid(workerid as u64); 346 | ack.set_setup_port(setup_port as u64); 347 | send_message(socket, &mut ack); 348 | self.workerpool.register(&mut self.zmq_ctx, msg.get_address(), self.objtable.clone(), &mut setup_socket); 349 | }, 350 | comm::MessageType::REGISTER_FUNCTION => { 351 | let workerid = msg.get_workerid() as WorkerID; 352 | let fnname = msg.get_fnname(); 353 | info!("function {} registered (worker {})", fnname.to_string(), workerid); 354 | let mut table = self.fntable.write().unwrap(); 355 | if !table.contains_key(fnname) { 356 | table.insert(fnname.into(), vec!()); 357 
| } 358 | match table.get(fnname).unwrap().binary_search(&workerid) { 359 | Ok(_) => {}, 360 | Err(idx) => { table.get_mut(fnname).unwrap().insert(idx, workerid); } 361 | } 362 | send_ack(socket); 363 | } 364 | comm::MessageType::PULL => { 365 | let workerid = msg.get_workerid() as WorkerID; 366 | let objref = msg.get_objref(); 367 | info!("object {} pulled (worker {})", objref, workerid); 368 | send_ack(socket); 369 | self.workerpool.scheduler_notify.send(scheduler::Event::Pull(workerid, objref)).unwrap(); 370 | }, 371 | comm::MessageType::DONE => { 372 | send_ack(socket); 373 | let result = msg.get_call().get_result(); 374 | assert!(result.len() == 1); 375 | let workerid = msg.get_workerid() as WorkerID; 376 | self.register_result(result[0], workerid); // this must happen before we notify the scheduler 377 | self.workerpool.scheduler_notify.send(scheduler::Event::Worker(msg.get_workerid() as usize)).unwrap(); 378 | self.workerpool.scheduler_notify.send(scheduler::Event::Obj(result[0])).unwrap(); 379 | }, 380 | comm::MessageType::ACC => { 381 | send_ack(socket); 382 | self.objtable.lock().unwrap()[msg.get_objref() as usize].push(msg.get_workerid() as usize); 383 | info!("delivery of {} to {} successful", msg.get_objref(), msg.get_workerid()); 384 | } 385 | comm::MessageType::DEBUG => { 386 | info!("received debug request"); 387 | send_ack(socket); 388 | self.workerpool.scheduler_notify.send(scheduler::Event::Debug(msg.get_workerid() as usize)).unwrap(); 389 | }, 390 | _ => { 391 | error!("message {:?} not allowed in this state", msg.get_field_type()); 392 | process::exit(1); 393 | } 394 | } 395 | } 396 | } 397 | 398 | /// Send request for function execution to a worker through the socket `socket`. 
399 | pub fn send_function_call(socket: &mut Socket, name: &str, arguments: &comm::Args, result: ObjRef) { 400 | let mut message = comm::Message::new(); 401 | message.set_field_type(comm::MessageType::INVOKE); 402 | let mut call = comm::Call::new(); 403 | call.set_field_type(comm::Call_Type::INVOKE_CALL); 404 | call.set_name(name.into()); 405 | call.set_args(arguments.clone()); // TODO: get rid of this copy 406 | call.set_result(vec!(result)); 407 | message.set_call(call); 408 | send_message(socket, &mut message); 409 | } 410 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | use comm; 2 | use protobuf; 3 | use protobuf::Message; 4 | use protobuf::core::MessageStatic; 5 | use zmq; 6 | use zmq::{Socket}; 7 | use std::io::Cursor; 8 | use std::ops::{Deref}; 9 | use std::collections::HashMap; 10 | use std::net::IpAddr; 11 | use rand; 12 | use rand::distributions::{IndependentSample, Range}; 13 | 14 | /// A unique identifier for an object stored on one of the workers. 15 | pub type ObjRef = u64; 16 | /// A unique identifier for a worker. 17 | pub type WorkerID = usize; 18 | /// For each object, contains a vector of worker ids that hold the object. 19 | pub type ObjTable = Vec>; 20 | /// For each function, contains a sorted vector of worker ids that can execute the function. 21 | pub type FnTable = HashMap>; 22 | 23 | /// Given a predicate `absent` that can test if an object is unavailable on the client, compute 24 | /// which objects fom `args` still need to be send so the function call can be invoked. 
25 | pub fn args_to_send bool>(args: &[ObjRef], absent: F) -> Vec { 26 | let mut scratch = args.to_vec(); 27 | scratch.sort(); 28 | // deduplicate 29 | let mut curr = 0; 30 | for i in 0..scratch.len() { 31 | let arg = scratch[i]; 32 | if i > 0 && arg == scratch[i-1] { 33 | continue; 34 | } 35 | if absent(arg) { 36 | scratch[curr] = arg; 37 | curr += 1 38 | } 39 | } 40 | scratch.truncate(curr); 41 | return scratch 42 | } 43 | 44 | #[test] 45 | fn test_args_to_send() { 46 | let args = vec![1, 4, 5, 5, 2, 2, 3, 3]; 47 | let present = vec![1, 2, 4]; 48 | let res = args_to_send(&args, |objref| present.binary_search(&objref).is_err()); 49 | assert_eq!(res, vec!(3, 5)); 50 | } 51 | 52 | pub fn push_objrefs(args: &comm::Args, result: &mut Vec) { 53 | for elem in args.get_objrefs() { 54 | if *elem >= 0 { 55 | result.push(*elem as u64); 56 | } 57 | } 58 | } 59 | 60 | /// Send a protocol buffer message on a socket. 61 | pub fn send_message(socket: &mut Socket, message: &mut comm::Message) { 62 | let mut buf = Vec::new(); 63 | message.write_to_vec(&mut buf).unwrap(); 64 | socket.send(buf.as_slice(), 0).unwrap(); 65 | } 66 | 67 | /// Receive a protocol buffer message over a socket. 68 | pub fn receive_message(socket: &mut Socket) -> comm::Message { 69 | let mut msg = zmq::Message::new().unwrap(); 70 | socket.recv(&mut msg, 0).unwrap(); 71 | let mut input_stream = protobuf::CodedInputStream::from_bytes(msg.deref()); 72 | return protobuf::core::parse_from::(&mut input_stream).unwrap(); 73 | } 74 | 75 | /// Receive a protocol buffer message through a subscription socket. 76 | pub fn receive_subscription(subscriber: &mut Socket) -> comm::Message { 77 | let mut msg = zmq::Message::new().unwrap(); 78 | subscriber.recv(&mut msg, 0).unwrap(); 79 | let mut read_buf = Cursor::new(msg.as_mut()); 80 | read_buf.set_position(7); 81 | return protobuf::parse_from_reader(&mut read_buf).unwrap(); 82 | } 83 | 84 | /// Send an acknowledgement package. 
85 | pub fn send_ack(socket: &mut Socket) { 86 | let mut ack = comm::Message::new(); 87 | ack.set_field_type(comm::MessageType::ACK); 88 | send_message(socket, &mut ack); 89 | } 90 | 91 | /// Receive an acknowledgement package. 92 | pub fn receive_ack(socket: &mut Socket) { 93 | let ack = receive_message(socket); 94 | assert!(ack.get_field_type() == comm::MessageType::ACK); 95 | } 96 | 97 | pub fn to_zmq_socket_addr(addr: &IpAddr, port: u16) -> String { 98 | return format!("tcp://{}:{}", addr, port).into(); 99 | } 100 | 101 | /// Bind a ZeroMQ socket to specific address. If port is None, connect to a free port. Return port. 102 | pub fn bind_socket(socket: &mut Socket, host: &IpAddr, port: Option) -> u16 { 103 | match port { 104 | None => { 105 | loop { 106 | let mut rng = rand::thread_rng(); 107 | let range = Range::new(2048, 65535); 108 | let port = range.ind_sample(&mut rng); 109 | match socket.bind(&to_zmq_socket_addr(host, port)[..]) { 110 | Ok(()) => { return port }, 111 | Err(err) => { continue } 112 | } 113 | } 114 | } 115 | Some(port) => { 116 | match socket.bind(&to_zmq_socket_addr(host, port)[..]) { 117 | Ok(()) => { return port }, 118 | Err(err) => { panic!("Could not bind socket. Make sure port {} is not used yet. {}", port, err) } 119 | } 120 | } 121 | } 122 | } 123 | 124 | /// Connect a ZeroMQ socket to specific address 125 | pub fn connect_socket(socket: &mut Socket, host: &IpAddr, port: u16) { 126 | match socket.connect(&to_zmq_socket_addr(host, port)[..]) { 127 | Ok(()) => {}, 128 | Err(_) => { panic!("Could not connect socket. 
Make sure port is set correctly.") } 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /test/mapreduce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import orchpy as op 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('server_port', type=int, help='the port to post requests to') 7 | parser.add_argument('client_port', type=int, help='the port to listen at') 8 | parser.add_argument('subscriber_port', type=int, help='the port used to set up the connections') 9 | 10 | @op.distributed([], np.ndarray) 11 | def zeros(): 12 | return np.zeros((100, 100)) 13 | 14 | @op.distributed([str], str) 15 | def str_identity(string): 16 | return string 17 | 18 | @op.distributed([], np.ndarray) 19 | def create_dist_array(): 20 | objrefs = np.empty((2, 2), dtype=np.dtype("int64")) 21 | for i in range(2): 22 | for j in range(2): 23 | objrefs[i,j] = zeros().get_id() 24 | return objrefs 25 | 26 | @op.distributed([np.ndarray], np.ndarray) 27 | def plusone(matrix): 28 | return matrix + 1.0 29 | 30 | if __name__ == "__main__": 31 | args = parser.parse_args() 32 | op.context.connect("127.0.0.1", args.server_port, args.subscriber_port, "127.0.0.1", args.client_port) 33 | op.register_current() 34 | op.context.main_loop() 35 | -------------------------------------------------------------------------------- /test/matmul.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import orchpy as op 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('server_port', type=int, help='the port to post requests to') 7 | parser.add_argument('client_port', type=int, help='the port to listen at') 8 | parser.add_argument('subscriber_port', type=int, help='the port used to set up the connections') 9 | 10 | @op.distributed([np.ndarray, np.ndarray, np.ndarray, 
np.ndarray], np.ndarray) 11 | def blockwise_dot(*matrices): 12 | result = np.zeros((100, 100)) 13 | k = len(matrices) / 2 14 | for i in range(k): 15 | result += np.dot(matrices[i], matrices[k+i]) 16 | return result 17 | 18 | @op.distributed([np.ndarray, np.ndarray], np.ndarray) 19 | def matrix_multiply(first_dist_mat, second_dist_mat): 20 | objrefs = np.zeros((2, 2), dtype=np.dtype("int64")) 21 | for i in range(2): 22 | for j in range(2): 23 | args = list(map(op.ObjRef, first_dist_mat[i,:])) + list(map(op.ObjRef, second_dist_mat[:,j])) 24 | objrefs[i,j] = blockwise_dot(*args).get_id() 25 | return objrefs 26 | 27 | if __name__ == "__main__": 28 | args = parser.parse_args() 29 | op.context.connect("127.0.0.1", args.server_port, args.subscriber_port, "127.0.0.1", args.client_port) 30 | op.register_current(globals().items()) 31 | op.context.main_loop() 32 | -------------------------------------------------------------------------------- /test/run_papaya_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import orchpy as op 4 | import orchpy.unison as unison 5 | import subprocess, os, socket, signal 6 | from testprograms import zeros, testfunction, testobjrefs, arrayid 7 | import time 8 | 9 | import papaya.dist as dist 10 | import papaya.single as single 11 | 12 | def get_unused_port(): 13 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 14 | s.bind(('localhost', 0)) 15 | addr, port = s.getsockname() 16 | s.close() 17 | return port 18 | 19 | numworkers = 10 20 | 21 | class Papayatest(unittest.TestCase): 22 | 23 | def setUp(self): 24 | self.incoming_port = get_unused_port() 25 | print "incoming port is", self.incoming_port 26 | self.publish_port = get_unused_port() 27 | print "publish port is", self.publish_port 28 | 29 | self.master = subprocess.Popen(["cargo", "run", "--release", "--bin", "orchestra", "--", str(self.incoming_port), str(self.publish_port)], env=dict(os.environ, 
RUST_BACKTRACE="1"), preexec_fn=os.setsid) 30 | self.workers = map(lambda worker: subprocess.Popen(["python", "testprograms.py", str(self.incoming_port), str(get_unused_port()), str(self.publish_port)], preexec_fn=os.setsid), range(numworkers)) 31 | 32 | def testConnect(self): 33 | self.client_port = get_unused_port() 34 | op.context.connect("127.0.0.1", self.incoming_port, self.publish_port, "127.0.0.1", self.client_port) 35 | op.context.debug_info() 36 | 37 | time.sleep(1.0) # todo(pcmoritz) fix this 38 | 39 | 40 | def testPapayaFunctions(self): 41 | def test_dist_dot(d1, d2, d3): 42 | print "testing dist_dot with d1 = " + str(d1) + ", d2 = " + str(d2) + ", d3 = " + str(d3) 43 | a = dist.dist_random_normal([d1, d2], [10, 10]) 44 | b = dist.dist_random_normal([d2, d3], [10, 10]) 45 | c = dist.dist_dot(a, b) 46 | a_val = a.assemble() 47 | b_val = b.assemble() 48 | c_val = c.assemble() 49 | np.testing.assert_allclose(np.dot(a_val, b_val), c_val) 50 | 51 | def test_dist_tsqr(d1, d2): 52 | print "testing dist_tsqr with d1 = " + str(d1) + ", d2 = " + str(d2) 53 | a = dist.dist_random_normal([d1, d2], [10, 10]) 54 | q, r = dist.dist_tsqr(a) 55 | a_val = a.assemble() 56 | q_val = q.assemble() 57 | np.testing.assert_allclose(np.dot(q_val, r), a_val) # check that a = q * r 58 | np.testing.assert_allclose(np.dot(q_val.T, q_val), np.eye(min(d1, d2)), atol=1e-6) # check that q.T * q = I 59 | np.testing.assert_allclose(np.triu(r), r) # check that r is upper triangular 60 | 61 | def test_single_modified_lu(d1, d2): 62 | print "testing single_modified_lu with d1 = " + str(d1) + ", d2 = " + str(d2) 63 | assert d1 >= d2 64 | k = min(d1, d2) 65 | m = np.random.normal(size=(d1, d2)) 66 | q, r = np.linalg.qr(m) 67 | l, u, s = single.single_modified_lu(q) 68 | s_mat = np.zeros((d1, d2)) 69 | for i in range(len(s)): 70 | s_mat[i, i] = s[i] 71 | np.testing.assert_allclose(q - s_mat, np.dot(l, u)) # check that q - s = l * u 72 | np.testing.assert_allclose(np.triu(u), u) # check that u is 
upper triangular 73 | np.testing.assert_allclose(np.tril(l), l) # check that u is lower triangular 74 | 75 | def test_dist_tsqr_hr(d1, d2): 76 | print "testing dist_tsqr_hr with d1 = " + str(d1) + ", d2 = " + str(d2) 77 | a = dist.dist_random_normal([d1, d2], [10, 10]) 78 | a_val = a.assemble() 79 | y, t, y_top, r = dist.dist_tsqr_hr(a) 80 | tall_eye = np.zeros((d1, min(d1, d2))) 81 | np.fill_diagonal(tall_eye, 1) 82 | q = tall_eye - np.dot(y, np.dot(t, y_top.T)) 83 | np.testing.assert_allclose(np.dot(q.T, q), np.eye(min(d1, d2)), atol=1e-6) # check that q.T * q = I 84 | np.testing.assert_allclose(np.dot(q, r), a_val) # check that a = (I - y * t * y_top.T) * r 85 | 86 | for i in range(10): 87 | d1 = np.random.randint(1, 100) 88 | d2 = np.random.randint(1, 100) 89 | d3 = np.random.randint(1, 100) 90 | test_dist_dot(d1, d2, d3) 91 | 92 | for i in range(10): 93 | d1 = np.random.randint(1, 50) 94 | d2 = np.random.randint(1, 11) 95 | test_dist_tsqr(d1, d2) 96 | 97 | for i in range(10): 98 | d2 = np.random.randint(1, 100) 99 | d1 = np.random.randint(d2, 100) 100 | test_single_modified_lu(d1, d2) 101 | 102 | for i in range(10): 103 | d1 = np.random.randint(1, 100) 104 | d2 = np.random.randint(1, 11) 105 | test_dist_tsqr_hr(d1, d2) 106 | 107 | def tearDown(self): 108 | os.killpg(self.master.pid, signal.SIGTERM) 109 | for worker in self.workers: 110 | os.killpg(worker.pid, signal.SIGTERM) 111 | 112 | # self.context.close() 113 | 114 | 115 | if __name__ == '__main__': 116 | unittest.main() 117 | -------------------------------------------------------------------------------- /test/runbasictests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import orchpy as op 4 | import orchpy.unison as unison 5 | import subprocess, os, socket, signal 6 | from testprograms import zeros, testfunction, testobjrefs, arrayid 7 | import time 8 | 9 | 10 | def get_unused_port(): 11 | s = 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) 12 | s.bind(('localhost', 0)) 13 | addr, port = s.getsockname() 14 | s.close() 15 | return port 16 | 17 | class UnisonTest(unittest.TestCase): 18 | 19 | def testSerializeArray(self): 20 | buf = bytearray() 21 | a = np.zeros((100, 100)) 22 | unison.serialize(buf, a) 23 | data = memoryview(buf).tobytes() 24 | res = unison.deserialize(data, np.ndarray) 25 | self.assertTrue(np.alltrue(a == res)) 26 | 27 | buf = bytearray() 28 | t = (123, 42) 29 | unison.serialize(buf, t) 30 | schema = unison.Tuple[int, int] 31 | data = memoryview(buf).tobytes() 32 | res = unison.deserialize(data, schema) 33 | self.assertTrue(t == res) 34 | 35 | buf = bytearray() 36 | l = ([1, 2, 3, 4], [1.0, 2.0]) 37 | unison.serialize(buf, l) 38 | schema = unison.Tuple[unison.List[int], unison.List[float]] 39 | data = memoryview(buf).tobytes() 40 | res = unison.deserialize(data, schema) 41 | self.assertTrue(l == res) 42 | 43 | class UnisonTest(unittest.TestCase): 44 | def testTypeCheck(self): 45 | l = 200 * [op.ObjRef(1)] + 50 * [1L] + 50 * [1.0] + 50 * [u"hi"] 46 | t = 200 * [op.ObjRef] + 50 * [int] + 50 * [float] + 50 * [unicode] 47 | op.check_types(l, t) 48 | try: 49 | l = [1, 2, 3, 4, "hi"] 50 | t = unison.List[int] 51 | op.check_types([l], [t]) 52 | self.assertFalse(True) 53 | except: 54 | self.assertTrue(True) 55 | l = ("hello", "world") 56 | t = unison.Tuple[str, str] 57 | op.check_types([l], [t]) 58 | l = [[1, 2, 3], ([1, 2, 3], ("hello", "world")), np.array([1.0, 2.0, 3.0])] 59 | t = [unison.List[int], unison.Tuple[unison.List[int], unison.Tuple[str, str]], np.ndarray] 60 | op.check_types(l, t) 61 | 62 | class OrchestraTest(unittest.TestCase): 63 | 64 | def testArgs(self): 65 | l = 200 * [op.ObjRef(1)] + 50 * [1L] + 50 * [1.0] + 50 * [u"hi"] 66 | t = 200 * [op.ObjRef] + 50 * [int] + 50 * [float] + 50 * [unicode] 67 | args = op.serialize_args(l) 68 | res = op.deserialize_args(args, t) 69 | self.assertTrue(res == l) 70 | 71 | def 
testDistributed(self): 72 | l = [u"hello", op.ObjRef(2), 3] 73 | args = op.serialize_args(l) 74 | res = op.deserialize_args(args, testfunction.types) 75 | self.assertTrue(res == l) 76 | 77 | numworkers = 2 78 | 79 | class ClientTest(unittest.TestCase): 80 | 81 | def setUp(self): 82 | self.incoming_port = get_unused_port() 83 | print "incoming port is", self.incoming_port 84 | self.publish_port = get_unused_port() 85 | print "publish port is", self.publish_port 86 | 87 | self.master = subprocess.Popen(["cargo", "run", "--release", "--bin", "orchestra", "--", str(self.incoming_port), str(self.publish_port)], env=dict(os.environ, RUST_BACKTRACE="1"), preexec_fn=os.setsid) 88 | self.workers = map(lambda worker: subprocess.Popen(["python", "testprograms.py", str(self.incoming_port), str(get_unused_port()), str(self.publish_port)], preexec_fn=os.setsid), range(numworkers)) 89 | 90 | def testConnect(self): 91 | self.client_port = get_unused_port() 92 | op.context.connect("127.0.0.1", self.incoming_port, self.publish_port, "127.0.0.1", self.client_port) 93 | op.context.debug_info() 94 | 95 | time.sleep(1.0) # todo(pcmoritz) fix this 96 | 97 | res = zeros([100, 100]) 98 | objrefs = testobjrefs() 99 | 100 | arrayid(res) 101 | 102 | res = op.context.pull(op.ObjRefs, objrefs) 103 | 104 | 105 | def tearDown(self): 106 | os.killpg(self.master.pid, signal.SIGTERM) 107 | for worker in self.workers: 108 | os.killpg(worker.pid, signal.SIGTERM) 109 | 110 | # self.context.close() 111 | 112 | 113 | if __name__ == '__main__': 114 | unittest.main() 115 | -------------------------------------------------------------------------------- /test/runtest.py: -------------------------------------------------------------------------------- 1 | import subprocess, os, signal, time, socket 2 | import unittest 3 | import orchpy as op 4 | from random import randint 5 | import numpy as np 6 | 7 | def get_unused_port(): 8 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 9 | s.bind(('localhost', 
0)) 10 | addr, port = s.getsockname() 11 | s.close() 12 | return port 13 | 14 | numworkers = 5 15 | 16 | class OrchestraTest(unittest.TestCase): 17 | 18 | def setUp(self): 19 | incoming_port = get_unused_port() 20 | print "incoming port is", incoming_port 21 | publish_port = get_unused_port() 22 | print "publish port is", publish_port 23 | 24 | self.master = subprocess.Popen(["cargo", "run", "--bin", "orchestra", "--", str(incoming_port), str(publish_port)], env=dict(os.environ, RUST_BACKTRACE="1"), preexec_fn=os.setsid) 25 | self.workers = map(lambda worker: subprocess.Popen(["python", "mapreduce.py", str(incoming_port), str(get_unused_port()), str(publish_port)], preexec_fn=os.setsid), range(numworkers)) 26 | self.workers = map(lambda worker: subprocess.Popen(["python", "matmul.py", str(incoming_port), str(get_unused_port()), str(publish_port)], preexec_fn=os.setsid), range(numworkers)) 27 | op.context.connect("127.0.0.1", incoming_port, publish_port, "127.0.0.1", get_unused_port()) 28 | 29 | def tearDown(self): 30 | os.killpg(self.master.pid, signal.SIGTERM) 31 | for worker in self.workers: 32 | os.killpg(worker.pid, signal.SIGTERM) 33 | 34 | class CallTest(OrchestraTest): 35 | 36 | def testCall(self): 37 | time.sleep(0.5) 38 | 39 | import mapreduce 40 | M = mapreduce.zeros() 41 | res = op.context.pull(np.ndarray, M) 42 | self.assertTrue(np.linalg.norm(res) < 1e-5) 43 | 44 | M = mapreduce.create_dist_array() 45 | res = op.context.pull(np.ndarray, M) 46 | 47 | class MapTest(OrchestraTest): 48 | 49 | def testMap(self): 50 | time.sleep(0.5) 51 | 52 | m = 5 53 | import mapreduce 54 | from mapreduce import plusone 55 | args = [] 56 | for i in range(m): 57 | args.append(mapreduce.zeros()) 58 | res = op.context.map(plusone, args) 59 | for i in range(m): 60 | mat = op.context.pull(np.ndarray, res[i]) 61 | 62 | class MatMulTest(OrchestraTest): 63 | 64 | def testMatMul(self): 65 | time.sleep(1.0) 66 | 67 | print "starting computation" 68 | 69 | import matmul 70 | import 
mapreduce 71 | 72 | M = mapreduce.create_dist_array() 73 | res = matmul.matrix_multiply(M, M) 74 | 75 | A = op.context.assemble(M) 76 | B = op.context.assemble(res) 77 | 78 | self.assertTrue(np.linalg.norm(A.dot(A) - B) <= 1e-4) 79 | 80 | class CallByValueTest(OrchestraTest): 81 | 82 | def testCallByValue(self): 83 | time.sleep(0.5) 84 | 85 | import mapreduce 86 | res = mapreduce.str_identity("hello world") 87 | 88 | 89 | if __name__ == '__main__': 90 | unittest.main() 91 | -------------------------------------------------------------------------------- /test/testprograms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import orchpy as op 3 | import orchpy.unison as unison 4 | import argparse 5 | 6 | import papaya.dist 7 | import papaya.single 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('server_port', type=int, help='the port to post requests to') 11 | parser.add_argument('client_port', type=int, help='the port to listen at') 12 | parser.add_argument('subscriber_port', type=int, help='the port used to set up the connections') 13 | 14 | @op.distributed([unicode, op.ObjRef, int], unicode) 15 | def testfunction(a, b, c): 16 | return a 17 | 18 | @op.distributed([], op.ObjRefs) 19 | def testobjrefs(): 20 | return op.ObjRefs((10,10)) 21 | 22 | @op.distributed([unison.List[int]], np.ndarray) 23 | def zeros(shape): 24 | return np.zeros(shape) 25 | 26 | @op.distributed([np.ndarray], np.ndarray) 27 | def arrayid(array): 28 | return array 29 | 30 | if __name__ == "__main__": 31 | args = parser.parse_args() 32 | op.context.connect("127.0.0.1", args.server_port, args.subscriber_port, "127.0.0.1", args.client_port) 33 | op.register_current() 34 | op.register_distributed(papaya.dist) 35 | op.register_distributed(papaya.single) 36 | op.context.main_loop() 37 | --------------------------------------------------------------------------------