├── .dockerignore ├── .ebextensions └── websocket.config └── cloud-node ├── Dockerfile ├── README.md ├── aggregator.py ├── buildspec.yml ├── coordinator.py ├── message.py ├── model.py ├── requirements.txt ├── server.py ├── state.py ├── tools ├── assets │ ├── init_mlp_model_with_w.h5 │ ├── my_model.h5 │ └── saved_mlp_model_with_w.h5 └── start_new_session.py └── updatestore.py /.dockerignore: -------------------------------------------------------------------------------- 1 | tools/ 2 | temp/ 3 | # Elastic Beanstalk Files 4 | .elasticbeanstalk/* 5 | .git 6 | .gitignore -------------------------------------------------------------------------------- /.ebextensions/websocket.config: -------------------------------------------------------------------------------- 1 | option_settings: 2 | aws:elb:listener:80: 3 | ListenerProtocol: TCP 4 | InstancePort: 80 5 | InstanceProtocol: TCP 6 | aws:autoscaling:launchconfiguration: 7 | SecurityGroups: ebs-websocket 8 | -------------------------------------------------------------------------------- /cloud-node/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6 2 | COPY . /app 3 | WORKDIR /app 4 | 5 | # This is a weird fix for the dependencies. Be careful when modifying. 6 | RUN pip install -r requirements.txt 7 | RUN pip uninstall tensorflow tensorflow-estimator tensorflow-hub tensorflowjs numpy Keras Keras-Applications Keras-Preprocessing -y 8 | RUN pip install 'tensorflow==1.13.1' 'keras==2.2.4' 9 | RUN pip install 'tensorflowjs==1.0.1' 10 | RUN pip uninstall numpy -y 11 | RUN pip install 'numpy==1.16.3' 12 | 13 | EXPOSE 8999 14 | CMD ["python", "server.py"] 15 | -------------------------------------------------------------------------------- /cloud-node/README.md: -------------------------------------------------------------------------------- 1 | # DataAgora's Cloud Node 2 | 3 | Coordinator & aggregator for private and decentralized machine learning. 
4 | 5 | For more information, please see [the wiki.](https://github.com/DataAgora/cloud-node/wiki) 6 | 7 | 8 | ## Local Set Up 9 | 10 | To start off, install the required Python dependencies by running: 11 | 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | After this, you can run the program by running: 17 | 18 | ``` 19 | python server.py 20 | ``` 21 | 22 | 23 | ## Deploying 24 | 25 | Deployment of this code happens automatically upon the creation of a Repo, so there's no need in deploying it manually. 26 | 27 | However, if you would like to deploy manually, you'd have to follow the [instructions here](https://docs.aws.amazon.com/elasticbeanstalk/latest/dg/single-container-docker.html). 28 | 29 | 30 | ## Tests 31 | 32 | There aren't any tests available for this repo yet. Please help us write them! 33 | 34 | 35 | -------------------------------------------------------------------------------- /cloud-node/aggregator.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | 4 | import numpy as np 5 | 6 | import state 7 | from updatestore import store_update 8 | from coordinator import start_next_round 9 | 10 | 11 | logging.basicConfig(level=logging.DEBUG) 12 | 13 | def handle_new_weights(message, clients_dict): 14 | """ 15 | Handle new weights from a Library. 16 | """ 17 | results = {"error": False, "message": "Success."} 18 | 19 | # 1. Check things match. 20 | if state.state["session_id"] != message.session_id: 21 | return { 22 | "error": True, 23 | "message": "The session id in the message doesn't match the service's." 24 | } 25 | 26 | if state.state["current_round"] != message.round: 27 | return { 28 | "error": True, 29 | "message": "The round in the message doesn't match the current round." 30 | } 31 | 32 | # 2 Lock section/variables that will be changed... 33 | state.state_lock.acquire() 34 | 35 | state.state["last_message_time"] = time.time() 36 | 37 | # 3. 
Do running weighted average on the new weights. 38 | do_running_weighted_average(message) 39 | 40 | # 4. Update the number of nodes averaged (+1) 41 | state.state["num_nodes_averaged"] += 1 42 | 43 | # 5. Log this update. 44 | # NOTE: Disabled until we actually need it. Could be useful for a progress bar. 45 | # store_update("UPDATE_RECEIVED", message, with_weights=False) 46 | 47 | # 6. If 'Continuation Criteria' is met... 48 | if check_continuation_criteria(state.state["initial_message"].continuation_criteria): 49 | # 6.a. Update round number (+1) 50 | state.state["current_round"] += 1 51 | 52 | # 6.b. If 'Termination Criteria' isn't met, then kickstart a new FL round 53 | # NOTE: We need a way to swap the weights from the initial message 54 | # in node............ 55 | if not check_termination_criteria(state.state["initial_message"].termination_criteria): 56 | print("Going to the next round...") 57 | results = kickstart_new_round(clients_dict["LIBRARY"]) 58 | 59 | # 6.c. Log the resulting weights for the user (for this round) 60 | store_update("ROUND_COMPLETED", message) 61 | 62 | # 7. If 'Termination Criteria' is met... 63 | # (NOTE: can't and won't happen with step 6.c.) 64 | if check_termination_criteria(state.state["initial_message"].termination_criteria): 65 | # 7.a. Reset all state in the service and mark BUSY as false 66 | state.reset_state() 67 | 68 | # 8. Release section/variables that were changed... 69 | state.state_lock.release() 70 | 71 | return results 72 | 73 | 74 | def kickstart_new_round(clients_list): 75 | """ 76 | Selects new nodes to run federated averaging with, and passes them the new 77 | averaged model. 
78 | """ 79 | # Make the new message with new round (weights are swapped in the coordinator) 80 | new_message = state.state["initial_message"] 81 | new_message.round = state.state["current_round"] 82 | 83 | # Start a new round 84 | return start_next_round(new_message, clients_list) 85 | 86 | 87 | def do_running_weighted_average(message): 88 | """ 89 | Runs running weighted average with the new weights and the current weights 90 | and changes the global state with the result. 91 | """ 92 | # If this is the first weights we're averaging, just update them and return 93 | if state.state["current_weights"] is None or state.state["sigma_omega"] is None: 94 | state.state["current_weights"] = message.weights 95 | state.state["sigma_omega"] = message.omega 96 | return 97 | 98 | # Get the variables ready 99 | current_weights = state.state["current_weights"] 100 | sigma_omega = state.state["sigma_omega"] 101 | new_weights = message.weights 102 | new_omega = message.omega 103 | 104 | # Run the math 105 | temp = np.multiply(current_weights, float(sigma_omega)) 106 | temp = np.add(temp, np.multiply(new_weights, float(new_omega))) 107 | new_sigma_omega = sigma_omega + new_omega 108 | new_weighted_avg = np.divide(temp, float(new_sigma_omega)) 109 | 110 | # Update state 111 | state.state["current_weights"] = new_weighted_avg 112 | state.state["sigma_omega"] = new_sigma_omega 113 | 114 | 115 | def check_continuation_criteria(continuation_criteria): 116 | """ 117 | Right now only implements percentage of nodes averaged. 118 | 119 | TODO: Implement an absolute number of nodes to average (NUM_NODES_AVERAGED). 120 | """ 121 | if "type" not in continuation_criteria: 122 | raise Exception("Continuation criteria is not well defined.") 123 | 124 | if continuation_criteria["type"] == "PERCENTAGE_AVERAGED": 125 | if state.state["num_nodes_chosen"] == 0: 126 | # TODO: Implement a lower bound of how many nodes are needed to 127 | # continue to the next round. 
128 | 129 | # TODO: Count the nodes at the time of averaging instead of at the 130 | # time of session creation. 131 | 132 | # In the meantime, if 0 nodes were active at the beginning of the 133 | # session, then the update of the first node to finish training will 134 | # trigger the continuation criteria. 135 | return True 136 | percentage = state.state["num_nodes_averaged"] / state.state["num_nodes_chosen"] 137 | return continuation_criteria["value"] <= percentage 138 | else: 139 | raise Exception("Continuation criteria is not well defined.") 140 | 141 | 142 | def check_termination_criteria(termination_criteria): 143 | """ 144 | Right now only implements a maximum amount of rounds. 145 | """ 146 | if "type" not in termination_criteria: 147 | raise Exception("Termination criteria is not well defined.") 148 | 149 | if termination_criteria["type"] == "MAX_ROUND": 150 | return termination_criteria["value"] < state.state["current_round"] 151 | else: 152 | raise Exception("Termination criteria is not well defined.") 153 | -------------------------------------------------------------------------------- /cloud-node/buildspec.yml: -------------------------------------------------------------------------------- 1 | version: 0.2 2 | phases: 3 | install: 4 | runtime-versions: 5 | python: 3.7 6 | commands: 7 | - pip install -r requirements.txt 8 | artifacts: 9 | files: 10 | - '**/*' 11 | # build: 12 | # commands: 13 | # - python server.py 14 | 15 | 16 | -------------------------------------------------------------------------------- /cloud-node/coordinator.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import logging 3 | 4 | import state 5 | from model import convert_and_save_b64model, convert_and_save_model, swap_weights 6 | 7 | 8 | logging.basicConfig(level=logging.DEBUG) 9 | 10 | def start_new_session(message, clients_dict): 11 | """ 12 | Starts a new DML session. 
13 | """ 14 | print("Starting new session...") 15 | 16 | # // 1. If server is BUSY, error. Otherwise, mark the service as BUSY. 17 | if state.state["busy"]: 18 | print("Aborting because the server is busy.") 19 | return { 20 | "error": True, 21 | "message": "Server is already busy working." 22 | } 23 | state.state_lock.acquire() 24 | state.state["busy"] = True 25 | 26 | # // 2. Set the internal round variable to 1, reset the number of nodes 27 | # // averaged to 0, update the initial message. 28 | state.state["current_round"] = 1 29 | state.state["num_nodes_averaged"] = 0 30 | state.state["initial_message"] = message 31 | state.state["repo_id"] = message.repo_id 32 | state.state["session_id"] = str(uuid.uuid4()) 33 | 34 | # // 3. According to the 'Selection Criteria', choose clients to forward 35 | # // training messages to. 36 | chosen_clients = _choose_clients( 37 | message.selection_criteria, 38 | clients_dict["LIBRARY"] 39 | ) 40 | state.state["num_nodes_chosen"] = len(chosen_clients) 41 | 42 | # // 4. Convert .h5 model into TFJS model 43 | _ = convert_and_save_b64model(message.h5_model) 44 | 45 | # // 5. Kickstart a DML Session with the TFJS model and round # 1 46 | new_message = { 47 | "session_id": state.state["session_id"], 48 | "repo_id": state.state["repo_id"], 49 | "round": 1, 50 | "action": "TRAIN", 51 | "hyperparams": message.hyperparams, 52 | } 53 | state.state["last_message_sent_to_library"] = new_message 54 | state.state_lock.release() 55 | return { 56 | "error": False, 57 | "action": "BROADCAST", 58 | "client_list": chosen_clients, 59 | "message": new_message, 60 | } 61 | 62 | 63 | def start_next_round(message, clients_list): 64 | """ 65 | Starts a new round in the current DML Session. 66 | """ 67 | print("Starting next round...") 68 | state.state["num_nodes_averaged"] = 0 69 | 70 | # According to the 'Selection Criteria', choose clients to forward 71 | # training messages to. 
72 | chosen_clients = _choose_clients(message.selection_criteria, clients_list) 73 | state.state["num_nodes_chosen"] = len(chosen_clients) 74 | 75 | # Swap weights and convert (NEW) .h5 model into TFJS model 76 | swap_weights() 77 | assert state.state["current_round"] > 0 78 | _ = convert_and_save_model(state.state["current_round"] - 1) 79 | 80 | # Kickstart a DML Session with the TFJS model 81 | new_message = { 82 | "session_id": state.state["session_id"], 83 | "repo_id": state.state["repo_id"], 84 | "round": state.state["current_round"], 85 | "action": "TRAIN", 86 | "hyperparams": message.hyperparams, 87 | } 88 | state.state["last_message_sent_to_library"] = new_message 89 | return { 90 | "error": False, 91 | "action": "BROADCAST", 92 | "client_list": chosen_clients, 93 | "message": new_message, 94 | } 95 | 96 | 97 | def _choose_clients(selection_criteria, client_list): 98 | """ 99 | TO BE FINISHED. 100 | 101 | Need to define a selection criteria object first. 102 | 103 | Right now it just chooses all clients. 104 | """ 105 | return client_list 106 | -------------------------------------------------------------------------------- /cloud-node/message.py: -------------------------------------------------------------------------------- 1 | import json 2 | import base64 3 | from enum import Enum 4 | 5 | import numpy as np 6 | 7 | class MessageType(Enum): 8 | """ 9 | Message Type 10 | 11 | Message Types that the service can work with. 12 | 13 | """ 14 | 15 | REGISTER = "REGISTER" 16 | NEW_SESSION = "NEW_SESSION" 17 | NEW_WEIGHTS = "NEW_WEIGHTS" 18 | 19 | 20 | class Message: 21 | """ 22 | Message 23 | 24 | Base class for messages received by the service. 
25 | 26 | """ 27 | 28 | @staticmethod 29 | def make(serialized_message): 30 | type, data = serialized_message["type"], serialized_message 31 | for cls in Message.__subclasses__(): 32 | if cls.type == type: 33 | return cls(data) 34 | raise ValueError("Message type is invalid!") 35 | 36 | 37 | class RegistrationMessage(Message): 38 | """ 39 | Registration Message 40 | 41 | The type of message initially sent by a node with information of what type 42 | of node they are. 43 | 44 | `node_type` should be one of DASHBOARD or LIBRARY. 45 | 46 | """ 47 | 48 | type = MessageType.REGISTER.value 49 | 50 | def __init__(self, serialized_message): 51 | self.node_type = serialized_message["node_type"].upper() 52 | 53 | def __repr__(self): 54 | return json.dumps({ 55 | "node_type": self.node_type 56 | }) 57 | 58 | 59 | class NewSessionMessage(Message): 60 | """ 61 | New Session Message 62 | 63 | The type of message sent by Explora to start a new session. 64 | 65 | """ 66 | 67 | type = MessageType.NEW_SESSION.value 68 | 69 | def __init__(self, serialized_message): 70 | self.repo_id = serialized_message["repo_id"] 71 | self.h5_model = serialized_message["h5_model"] 72 | self.hyperparams = serialized_message["hyperparams"] 73 | self.selection_criteria = serialized_message["selection_criteria"] 74 | self.continuation_criteria = serialized_message["continuation_criteria"] 75 | self.termination_criteria = serialized_message["termination_criteria"] 76 | 77 | def __repr__(self): 78 | return json.dumps({ 79 | "repo_id": self.repo_id, 80 | "h5_model": self.h5_model[:20], 81 | "hyperparams": self.hyperparams, 82 | "selection_criteria": self.selection_criteria, 83 | "continuation_criteria": self.continuation_criteria, 84 | "termination_criteria": self.termination_criteria, 85 | }) 86 | 87 | 88 | class NewWeightsMessage(Message): 89 | """ 90 | New Weights Message 91 | 92 | The type of message sent by the Library. This is an update. 
93 | 94 | """ 95 | 96 | type = MessageType.NEW_WEIGHTS.value 97 | 98 | def __init__(self, serialized_message): 99 | self.session_id = serialized_message["session_id"] 100 | self.round = serialized_message["round"] 101 | self.action = serialized_message["action"] 102 | self.weights = np.array( 103 | serialized_message["results"]["weights"], 104 | dtype=np.dtype(float), 105 | ) 106 | self.omega = serialized_message["results"]["omega"] 107 | 108 | def __repr__(self): 109 | return json.dumps({ 110 | "session_id": self.session_id, 111 | "round": self.round, 112 | "action": self.action, 113 | "weights": "omitted", 114 | "omega": self.omega, 115 | }) 116 | -------------------------------------------------------------------------------- /cloud-node/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import json 4 | import base64 5 | from functools import reduce 6 | 7 | import keras 8 | import numpy as np 9 | import tensorflowjs as tfjs 10 | from keras import backend as K 11 | 12 | import state 13 | 14 | 15 | TEMP_FOLDER = 'temp' 16 | 17 | def convert_and_save_b64model(base64_h5_model): 18 | """ 19 | This function is to be called at the beginning of a DML Serssion. 20 | 21 | It takes the initial h5 model encoded in Base64, decodes it, saves it on 22 | disk (in //model.h5), then calls the helper 23 | function `_convert_and_save_model()` to convert the model into a tf.js 24 | model which will be served to library nodes. 
25 | 26 | base64_h5_model: base64 string of an h5 keras model 27 | """ 28 | session_id = state.state["session_id"] 29 | model_path = os.path.join(TEMP_FOLDER, session_id) 30 | 31 | # Create directory if necessary 32 | if not os.path.exists(model_path): 33 | os.makedirs(model_path) 34 | 35 | # Save model on disk 36 | h5_model_path = model_path + '/model.h5' 37 | h5_model_bytes = base64.b64decode(base64_h5_model) 38 | with open(h5_model_path, 'wb') as fp: 39 | fp.write(h5_model_bytes) 40 | 41 | # Convert and save model for serving 42 | return _convert_and_save_model(h5_model_path) 43 | 44 | 45 | def convert_and_save_model(round): 46 | """ 47 | This function is to be called when moving to a new round. 48 | 49 | It takes the current round number to construct the path where the model is 50 | stored, then it calls the helper function `_convert_and_save_model()` to 51 | convert the model into a tf.js model which will be served to library nodes. 52 | """ 53 | session_id = state.state["session_id"] 54 | round = state.state["current_round"] 55 | model_path = os.path.join(TEMP_FOLDER, session_id) 56 | h5_model_path = model_path + '/model{}.h5'.format(round) 57 | return _convert_and_save_model(h5_model_path) 58 | 59 | 60 | def _convert_and_save_model(h5_model_path): 61 | """ 62 | Helper function that converts the given h5 model (from the path) into a 63 | tf.js model, extracts metadata from the model, and prepares the temp folder 64 | where this new converted model will be served from. 65 | 66 | The new converted model gets stored in: 67 | // 68 | 69 | Where the following files get created: 70 | - group1.-shard1of1.bin 71 | - model.json 72 | - metadata.json 73 | 74 | This function returns the path to the converted model on disk. 
75 | """ 76 | session_id = state.state["session_id"] 77 | round = state.state["current_round"] 78 | converted_model_path = os.path.join(TEMP_FOLDER, session_id, str(round)) 79 | 80 | _keras_2_tfjs(h5_model_path, converted_model_path) 81 | 82 | model_json_path = converted_model_path + "/model.json" 83 | with open(model_json_path, 'r') as fp: 84 | model_json = json.loads(fp.read()) 85 | state.state["weights_shape"] = model_json["weightsManifest"][0]["weights"] 86 | 87 | metadata_path = converted_model_path + '/metadata.json' 88 | metadata = { 89 | "session_id": session_id, 90 | "current_round": round, 91 | } 92 | with open(metadata_path, 'w') as fp: 93 | json.dump(metadata, fp, sort_keys=True, indent=4) 94 | 95 | return converted_model_path 96 | 97 | 98 | def swap_weights(): 99 | """ 100 | Loads the initial stored h5 model in //model.h5, 101 | swaps the weights with the aggregated weights currently in the global state, 102 | then saves the new model in //model.h5. 103 | 104 | This function is to be called before running `convert_and_save_model()` 105 | when moving to a new round. 
106 | """ 107 | model_path = os.path.join(TEMP_FOLDER, state.state["session_id"]) 108 | h5_model_path = model_path + '/model.h5' 109 | model = keras.models.load_model(h5_model_path) 110 | 111 | weights_flat = state.state["current_weights"] 112 | weights_shape = state.state["weights_shape"] 113 | weights, start = [], 0 114 | for shape_data in weights_shape: 115 | shape = shape_data["shape"] 116 | size = reduce(lambda x, y: x*y, shape) 117 | weights_np = np.array(weights_flat[start:start+size]) 118 | weights_np.resize(tuple(shape)) 119 | weights.append(weights_np) 120 | start += size 121 | model.set_weights(weights) 122 | 123 | round = state.state["current_round"] 124 | new_h5_model_path = model_path + '/model{0}.h5'.format(round) 125 | model.save(new_h5_model_path) 126 | K.clear_session() 127 | 128 | def _keras_2_tfjs(h5_model_path, path_to_save): 129 | """ 130 | Converts a Keras h5 model into a tf.js model and saves it on disk. 131 | """ 132 | model = keras.models.load_model(h5_model_path) 133 | tfjs.converters.save_keras_model(model, path_to_save, np.uint16) 134 | K.clear_session() 135 | 136 | def _test(): 137 | """ 138 | Nothing important here. 
139 | """ 140 | state.init() 141 | out = _convert_and_save_model('../notebooks/saved_mlp_model_with_w.h5') 142 | print(out) 143 | 144 | 145 | # def decode_weights(h5_model_path): 146 | # """ 147 | # Do Keras stuff here 148 | # """ 149 | # model = keras.models.load_model(h5_model_path) 150 | # weights = model.get_weights() 151 | # return weights 152 | 153 | 154 | # def _test2(): 155 | # out = decode_weights('../notebooks/saved_mlp_model_with_w.h5') 156 | # print(out) 157 | -------------------------------------------------------------------------------- /cloud-node/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | appnope==0.1.0 3 | astor==0.7.1 4 | attrs==19.1.0 5 | autobahn==19.3.3 6 | Automat==0.7.0 7 | backcall==0.1.0 8 | blessed==1.15.0 9 | boto3==1.9.130 10 | botocore==1.12.136 11 | cached-property==1.5.1 12 | cement==2.8.2 13 | certifi==2019.3.9 14 | chardet==3.0.4 15 | Click==7.0 16 | colorama==0.3.9 17 | constantly==15.1.0 18 | decorator==4.4.0 19 | dnspython==1.16.0 20 | docker==3.7.2 21 | docker-compose==1.23.2 22 | docker-pycreds==0.4.0 23 | dockerpty==0.4.1 24 | docopt==0.6.2 25 | docutils==0.14 26 | eventlet==0.24.1 27 | Flask==1.0.2 28 | Flask-Cors==3.0.7 29 | Flask-SocketIO==3.3.2 30 | Flask-Sockets==0.2.1 31 | Flask-uWSGI-WebSocket==0.6.1 32 | future==0.16.0 33 | gast==0.2.2 34 | gevent==1.4.0 35 | gevent-websocket==0.10.1 36 | google-pasta==0.1.4 37 | greenlet==0.4.15 38 | grpcio==1.19.0 39 | h5py==2.8.0 40 | hyperlink==18.0.0 41 | idna==2.7 42 | incremental==17.5.0 43 | ipykernel==5.1.0 44 | ipython==7.4.0 45 | ipython-genutils==0.2.0 46 | itsdangerous==1.1.0 47 | jedi==0.13.3 48 | Jinja2==2.10 49 | jmespath==0.9.4 50 | jsonschema==2.6.0 51 | jupyter-client==5.2.4 52 | jupyter-core==4.4.0 53 | Keras==2.2.4 54 | Keras-Applications==1.0.7 55 | Keras-Preprocessing==1.0.9 56 | Markdown==3.1 57 | MarkupSafe==1.1.1 58 | mock==2.0.0 59 | monotonic==1.5 60 | numpy==1.15.1 61 | 
parso==0.4.0 62 | pathspec==0.5.9 63 | pbr==5.1.3 64 | pexpect==4.7.0 65 | pickleshare==0.7.5 66 | prompt-toolkit==2.0.9 67 | protobuf==3.7.1 68 | ptyprocess==0.6.0 69 | pyasn1==0.4.5 70 | Pygments==2.3.1 71 | PyHamcrest==1.9.0 72 | python-dateutil==2.8.0 73 | python-engineio==3.5.0 74 | python-socketio==3.1.2 75 | PyYAML==3.13 76 | pyzmq==18.0.1 77 | requests==2.20.1 78 | rsa==3.4.2 79 | s3transfer==0.2.0 80 | scipy==1.2.1 81 | semantic-version==2.5.0 82 | SimpleWebSocketServer==0.1.1 83 | six==1.11.0 84 | tb-nightly==1.14.0a20190319 85 | tensorboard==1.13.1 86 | tensorflow==1.13.1 87 | tensorflow-estimator==1.13.0 88 | tensorflow-estimator-2.0-preview==1.14.0.dev2019040400 89 | tensorflow-hub==0.3.0 90 | tensorflowjs==1.0.1 91 | termcolor==1.1.0 92 | texttable==0.9.1 93 | tf-nightly-2.0-preview==2.0.0.dev20190404 94 | tornado==6.0.2 95 | traitlets==4.3.2 96 | Twisted==18.9.0 97 | txaio==18.8.1 98 | urllib3==1.24.1 99 | uWSGI==2.0.18 100 | wcwidth==0.1.7 101 | websocket-client==0.56.0 102 | websocket-server==0.4 103 | websockets==7.0 104 | Werkzeug==0.15.2 105 | zope.interface==4.6.0 106 | -------------------------------------------------------------------------------- /cloud-node/server.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import uuid 7 | import json 8 | 9 | import io 10 | 11 | from flask_cors import CORS, cross_origin 12 | from twisted.python import log 13 | 14 | import werkzeug.formparser 15 | 16 | from twisted.web.server import Site 17 | from twisted.web.wsgi import WSGIResource 18 | from twisted.internet import task, reactor 19 | from flask import Flask, jsonify, send_from_directory 20 | from autobahn.twisted.websocket import WebSocketServerProtocol 21 | from autobahn.twisted.websocket import WebSocketServerFactory 22 | from autobahn.twisted.resource import 
class CloudNodeProtocol(WebSocketServerProtocol):
    """
    Cloud Node Protocol

    Class that implements part of the Cloud Node networking logic (what happens
    when a new node connects, sends a message, disconnects). The networking here
    happens through Websockets using the autobahn library.

    """

    def onConnect(self, request):
        """
        Logs that a node has successfully connected.
        """
        print("Client connecting: {}".format(request.peer))

    def onOpen(self):
        """
        Logs that a connection was opened.
        """
        print("WebSocket connection open.")

    def onClose(self, wasClean, code, reason):
        """
        Deregisters a node upon websocket closure and logs it.
        """
        print("WebSocket connection closed: {}".format(reason))
        self.factory.unregister(self)

    def onMessage(self, payload, isBinary):
        """
        Processes the payload received by a connected node.

        Messages are ignored unless the message is of type "REGISTER" or the
        node has already been registered (by sending a "REGISTER" type message).

        Dispatches on message type: REGISTER adds the node to the factory's
        client lists; NEW_SESSION starts a session via the coordinator;
        NEW_WEIGHTS feeds an update to the aggregator.
        """
        print("Got payload!")
        # NOTE(review): dumping the full environment on every message can
        # leak credentials into logs -- consider removing this debug print.
        print(os.environ)
        if isBinary:
            print("Binary message not supported.")
            return

        # Convert message to JSON
        try:
            serialized_message = json.loads(payload)
        except Exception:
            print("Error while converting JSON.")
            return

        # Deserialize message into the matching Message subclass.
        try:
            message = Message.make(serialized_message)
            print("Message ({0}) contents: {1}".format(message.type, message))
        except Exception as e:
            print("Error deserializing message!", e)
            error_json = json.dumps({"error": True, "message": "Error deserializing message: {}".format(e)})
            self.sendMessage(error_json.encode(), isBinary)
            return

        # Process message
        if message.type == MessageType.REGISTER.value:
            # Register the node
            if message.node_type in ["DASHBOARD", "LIBRARY"]:
                self.factory.register(self, message.node_type)
                print("Registered node as type: {}".format(message.node_type))

                if message.node_type == "LIBRARY" and state.state["busy"] is True:
                    # There's a session active, we should incorporate the just
                    # added node into the session! Re-send the last TRAIN
                    # message so the late joiner participates in this round.
                    print("Adding the new library node to this round!")
                    last_message = state.state["last_message_sent_to_library"]
                    message_json = json.dumps(last_message)
                    self.sendMessage(message_json.encode(), isBinary)
            else:
                print("WARNING: Incorrect node type ({}) -- ignoring!".format(message.node_type))
        elif message.type == MessageType.NEW_SESSION.value:
            # Verify this node has been registered
            if not self._nodeHasBeenRegistered(client_type="DASHBOARD"): return

            # Start new DML Session
            results = start_new_session(message, self.factory.clients)

            # Error check
            if results["error"]:
                self.sendMessage(json.dumps(results).encode(), isBinary)
                return

            # Handle results
            if results["action"] == "BROADCAST":
                self._broadcastMessage(
                    payload=results["message"],
                    client_list=results["client_list"],
                    isBinary=isBinary,
                )

        elif message.type == MessageType.NEW_WEIGHTS.value:
            # Verify this node has been registered
            if not self._nodeHasBeenRegistered(client_type="LIBRARY"): return

            # Handle new weights (average, move to next round, terminate session)

            results = handle_new_weights(message, self.factory.clients)

            # Error check
            if results["error"]:
                self.sendMessage(json.dumps(results).encode(), isBinary)
                return

            # Handle message
            if "action" in results:
                if results["action"] == "BROADCAST":
                    self._broadcastMessage(
                        payload=results["message"],
                        client_list=results["client_list"],
                        isBinary=isBinary,
                    )
            else:
                # Acknowledge message (temporarily! -- node doesn't need to know)
                self.sendMessage(json.dumps({"error": False, "message": "ack"}).encode(), isBinary)
                # Tell the dashboards a new model is available.
                # NOTE(review): if the session just terminated, reset_state()
                # has already cleared session_id/repo_id, so this payload may
                # carry None values -- confirm the dashboard tolerates that.
                message = {
                    "session_id": state.state["session_id"],
                    "repo_id": state.state["repo_id"],
                    "action": "NEW_MODEL"
                }
                self._broadcastMessage(
                    payload=message,
                    client_list = self.factory.clients["DASHBOARD"],
                    isBinary=isBinary
                )
        else:
            print("Unknown message type!")
            error_json = json.dumps({"error": True, "message": "Unknown message type!"})
            self.sendMessage(error_json.encode(), isBinary)

        # Debug trace of the full service state after every message.
        # (The "[[DEBUG]" bracket imbalance is in the runtime string itself.)
        print("[[DEBUG] State: {}".format(state.state))

    def _broadcastMessage(self, payload, client_list, isBinary):
        """
        Broadcast message (`payload`) to a `client_list`.
        """
        for c in client_list:
            results_json = json.dumps(payload)
            c.sendMessage(results_json.encode(), isBinary)

    def _nodeHasBeenRegistered(self, client_type):
        """
        Returns whether the node in scope has been registered into one of the
        `client_type`'s.
        """
        return self.factory.is_registered(self, client_type)
192 | 193 | """ 194 | 195 | def __init__(self): 196 | WebSocketServerFactory.__init__(self) 197 | self.clients = {"DASHBOARD": [], "LIBRARY": []} 198 | 199 | def register(self, client, type): 200 | client_already_exists = False 201 | for _, clients in self.clients.items(): 202 | if client in clients: 203 | client_already_exists = True 204 | if not client_already_exists: 205 | print("Registered client {}".format(client.peer)) 206 | self.clients[type].append(client) 207 | 208 | def unregister(self, client): 209 | for node_type, clients in self.clients.items(): 210 | if client in clients: 211 | print("Unregistered client {}".format(client.peer)) 212 | self.clients[node_type].remove(client) 213 | 214 | def is_registered(self, client, client_type): 215 | """Returns whether client is in the list of clients.""" 216 | return client in self.clients[client_type] 217 | 218 | # NOTE: We need to implement some ping-pong/ways to deal with disconnections. 219 | 220 | app = Flask(__name__) 221 | app.secret_key = str(uuid.uuid4()) 222 | CORS(app) 223 | 224 | @app.route("/status") 225 | def get_status(): 226 | """ 227 | Returns the status of the Cloud Node. 228 | 229 | The dashboard-api is the only hitting this endpoint, so it should be secured. 230 | """ 231 | return jsonify({"Busy": state.state["busy"]}) 232 | 233 | @app.route('/model/') 234 | def serve_model(filename): 235 | """ 236 | Serves the models to the user. 237 | 238 | TODO: Should do this through ngnix for a boost in performance. Should also 239 | have some auth token -> session id mapping (security fix in the future). 240 | """ 241 | session_id = state.state["session_id"] 242 | round = state.state["current_round"] 243 | return send_from_directory( 244 | app.root_path + '/temp/' + session_id + "/" + str(round), 245 | filename, 246 | ) 247 | 248 | @app.route('/secret/reset_state') 249 | def reset_state(): 250 | """ 251 | Resets the state of the cloud node. 252 | 253 | TODO: This is only for debugging. Should be deleted. 
254 | """ 255 | state.state_lock.acquire() 256 | state.reset_state() 257 | state.state_lock.release() 258 | return "State reset successfully!" 259 | 260 | @app.route('/secret/get_state') 261 | def get_state(): 262 | """ 263 | Get the state of the cloud node. 264 | 265 | TODO: This is only for debugging. Should be deleted. 266 | """ 267 | return repr(state.state) 268 | 269 | # def check_timeout_condition(): 270 | # """ 271 | # TO BE IMPLEMENTED. 272 | # """ 273 | # TIMEOUT_DELTA_IN_MINS = 10 274 | # time_now = time.time() 275 | # if time_now > TIMEOUT_DELTA_IN_MINS * 60: 276 | # # Need to trigger the event of broadcasting to all nodes. 277 | # # The nodes to drop everything they were doing. 278 | # pass 279 | 280 | 281 | if __name__ == '__main__': 282 | 283 | log.startLogging(sys.stdout) 284 | 285 | factory = CloudNodeFactory() 286 | factory.protocol = CloudNodeProtocol 287 | wsResource = WebSocketResource(factory) 288 | 289 | wsgiResource = WSGIResource(reactor, reactor.getThreadPool(), app) 290 | rootResource = WSGIRootResource(wsgiResource, {b'': wsResource}) 291 | site = Site(rootResource) 292 | 293 | state.init() 294 | 295 | reactor.listenTCP(8999, site) 296 | reactor.run() 297 | 298 | print("Starting cloud node...") 299 | print(os.environ) 300 | -------------------------------------------------------------------------------- /cloud-node/state.py: -------------------------------------------------------------------------------- 1 | """Global state for the service.""" 2 | 3 | def init(): 4 | import threading 5 | global state_lock 6 | state_lock = threading.Lock() 7 | 8 | global reset_state 9 | def reset_state(): 10 | global state 11 | state = { 12 | "busy": False, 13 | "session_id": None, 14 | "repo_id": None, 15 | "current_round": 0, 16 | "num_nodes_averaged": 0, 17 | "num_nodes_chosen": 0, 18 | "current_weights" : None, 19 | "sigma_omega": None, 20 | "weights_shape": None, 21 | "initial_message": None, 22 | "last_message_time": None, 23 | 
"last_message_sent_to_library": None, 24 | } 25 | 26 | reset_state() 27 | -------------------------------------------------------------------------------- /cloud-node/tools/assets/init_mlp_model_with_w.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DiscreetAI/cloud-node/6995bb282bfe37adcdc56a6729165cd7d15699dd/cloud-node/tools/assets/init_mlp_model_with_w.h5 -------------------------------------------------------------------------------- /cloud-node/tools/assets/my_model.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DiscreetAI/cloud-node/6995bb282bfe37adcdc56a6729165cd7d15699dd/cloud-node/tools/assets/my_model.h5 -------------------------------------------------------------------------------- /cloud-node/tools/assets/saved_mlp_model_with_w.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DiscreetAI/cloud-node/6995bb282bfe37adcdc56a6729165cd7d15699dd/cloud-node/tools/assets/saved_mlp_model_with_w.h5 -------------------------------------------------------------------------------- /cloud-node/tools/start_new_session.py: -------------------------------------------------------------------------------- 1 | import json 2 | import base64 3 | 4 | from autobahn.twisted.websocket import WebSocketClientProtocol 5 | 6 | # This ID is used to log things into DynamoDB. 7 | # If you want your results to be saved somewhere you can easily access, then 8 | # change this to something you can remember. 
9 | REPO_ID = "test" 10 | 11 | CLOUD_NODE_HOST = "3e55b6e37447aca26c807c2aa5961d89.au4c4pd2ch.us-west-1.elasticbeanstalk.com" 12 | CLOUD_NODE_PORT = 80 13 | 14 | with open('assets/my_model.h5', mode='rb') as file: 15 | file_content = file.read() 16 | encoded_content = base64.encodebytes(file_content) 17 | h5_model = encoded_content.decode('ascii') 18 | 19 | NEW_MESSAGE = { 20 | "type": "NEW_SESSION", 21 | "repo_id": REPO_ID, 22 | "h5_model": h5_model, 23 | "hyperparams": { 24 | "batch_size": 128, 25 | "epochs": 10, 26 | }, 27 | "selection_criteria": { 28 | "type": "ALL_NODES", 29 | }, 30 | "continuation_criteria": { 31 | "type": "PERCENTAGE_AVERAGED", 32 | "value": 0.75 33 | }, 34 | "termination_criteria": { 35 | "type": "MAX_ROUND", 36 | "value": 2 37 | } 38 | } 39 | 40 | NEW_CONNECTION_MESSAGE = { 41 | "type": "REGISTER", 42 | "node_type": "dashboard", 43 | } 44 | 45 | class NewSessionTestProtocol(WebSocketClientProtocol): 46 | 47 | def onOpen(self): 48 | json_data = json.dumps(NEW_CONNECTION_MESSAGE) 49 | self.sendMessage(json_data.encode()) 50 | json_data = json.dumps(NEW_MESSAGE) 51 | self.sendMessage(json_data.encode()) 52 | 53 | 54 | if __name__ == '__main__': 55 | 56 | import sys 57 | 58 | from twisted.python import log 59 | from twisted.internet import reactor 60 | log.startLogging(sys.stdout) 61 | 62 | from autobahn.twisted.websocket import WebSocketClientFactory 63 | factory = WebSocketClientFactory() 64 | factory.protocol = NewSessionTestProtocol 65 | 66 | reactor.connectTCP(CLOUD_NODE_HOST, CLOUD_NODE_PORT, factory) 67 | reactor.run() 68 | -------------------------------------------------------------------------------- /cloud-node/updatestore.py: -------------------------------------------------------------------------------- 1 | import time 2 | import uuid 3 | import os 4 | 5 | import boto3 6 | 7 | import state 8 | from model import TEMP_FOLDER 9 | 10 | 11 | def store_update(type, message, with_weights=True): 12 | """ 13 | Stores an update in 
DynamoDB. If weights are present, it stores them in S3. 14 | """ 15 | 16 | print("[{0}]: {1}".format(type, message)) 17 | 18 | access_key = os.environ["ACCESS_KEY_ID"] 19 | secret_key = os.environ["SECRET_ACCESS_KEY"] 20 | if with_weights: 21 | try: 22 | repo_id = state.state['repo_id'] 23 | session_id = state.state['session_id'] 24 | round = state.state['current_round'] 25 | s3 = boto3.resource('s3', aws_access_key_id=access_key, aws_secret_access_key=secret_key) 26 | weights_s3_key = '{0}/{1}/{2}/model.h5'.format(repo_id, session_id, round) 27 | object = s3.Object('updatestore', weights_s3_key) 28 | h5_filepath = TEMP_FOLDER + "/{0}/model{1}.h5".format(session_id, round) 29 | object.put(Body=open(h5_filepath, 'rb')) 30 | except Exception as e: 31 | print("S3 Error: {0}".format(e)) 32 | 33 | try: 34 | dynamodb = boto3.resource('dynamodb', region_name='us-west-1', aws_access_key_id=access_key, aws_secret_access_key=secret_key) 35 | table = dynamodb.Table("UpdateStore") 36 | item = { 37 | 'Id': str(uuid.uuid4()), 38 | 'RepoId': state.state["repo_id"], 39 | 'Timestamp': int(time.time()), 40 | 'ContentType': type, 41 | 'SessionId': state.state["session_id"], 42 | 'Content': repr(message), 43 | } 44 | if with_weights: 45 | item['WeightsS3Key'] = "s3://updatestore/" + weights_s3_key 46 | table.put_item(Item=item) 47 | except Exception as e: 48 | print("DB Error: {0}".format(e)) 49 | --------------------------------------------------------------------------------