├── src ├── split_with_socket │ ├── utils │ │ ├── atomic_socket.py │ │ ├── test_client.py │ │ ├── relay_protocol.py │ │ ├── client.py │ │ └── server.py │ ├── data_package.py │ ├── mnistMIT.py │ └── mnistHarvard.py ├── no_split │ └── mnist.py └── split_no_socket │ └── mnist_split.py ├── LICENSE ├── .gitignore └── README.md /src/split_with_socket/utils/atomic_socket.py: -------------------------------------------------------------------------------- 1 | # Socket wrapper that sends the amount of data before each message. 2 | 3 | buffer_size = 4096 4 | 5 | def send(socket, data): 6 | bytes = len(data.encode()) 7 | socket.sendall(str(format(bytes, '10d')).encode()) 8 | socket.sendall(data.encode()) 9 | 10 | def recv(socket): 11 | # Look for the response 12 | data = socket.recv(10) 13 | if (len(data) == 0): 14 | return 15 | 16 | amount_received = 0 17 | amount_expected = int(data) 18 | 19 | complete_str = "" 20 | 21 | while amount_received < amount_expected: 22 | data = socket.recv(buffer_size) 23 | amount_received += len(data) 24 | complete_str += data.decode() 25 | 26 | return complete_str -------------------------------------------------------------------------------- /src/split_with_socket/utils/test_client.py: -------------------------------------------------------------------------------- 1 | import utils.client as client 2 | 3 | #relay_server_host = 'nebula.media.mit.edu' # as both code is running on same pc 4 | #relay_server_host = 'ec2-3-14-72-103.us-east-2.compute.amazonaws.com' 5 | relay_server_host = '127.0.0.1' 6 | relay_server_port = 5004 # socket server port number 7 | 8 | client_id = 'ME' 9 | client_send_to = 'ME' 10 | test_data = 'Hello' 11 | 12 | class DataPackage(): 13 | def __init__(self, s): 14 | self.s = s 15 | 16 | 17 | client.connect(relay_server_host, relay_server_port, client_id) 18 | 19 | print('Testing sending the word ' + test_data) 20 | 21 | client.sendData(client_send_to, DataPackage(test_data)) 22 | 23 | print('Received back the word: ' + client.receiveData().s) 24 | 25 | client.disconnect() # close the connection 26 | 27 | -------------------------------------------------------------------------------- /src/split_with_socket/data_package.py: -------------------------------------------------------------------------------- 1 | # This file contains the objects that encompass the basic Split Learning protocol. 2 | # They can be changed to anything you want. The Client API will serialize this objects 3 | # and send them to the desired client. 4 | 5 | # Asks the other side to reset the current model if any. Receives confirmation back 6 | class ResetModelPackage(): 7 | pass 8 | 9 | class NewModelReadyPackage(): 10 | pass 11 | 12 | # Asks the other side to train with the following data and labels. It should return a backward prop object. 13 | class ForwardPropPackage(): 14 | def __init__(self, y, labels): 15 | self.y = y 16 | self.labels = labels 17 | 18 | class BackwardPropPackage(): 19 | def __init__(self, grad, loss): 20 | self.grad = grad 21 | self.loss = loss 22 | 23 | # Packages to evaluate the model 24 | class EvaluatePackage(): 25 | def __init__(self, y): 26 | self.y = y 27 | 28 | class EvaluationPackage(): 29 | def __init__(self, logps): 30 | self.logps = logps -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Vitor Pamplona 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Split Learning with a Relay Server 2 | 3 | Simple MNIST Machine Learning code in three ways: 4 | 1. Normal: sequential code to train and test an MNIST model (no split, no sockets) 5 | 2. Split: Split learning code to train and test an MNIST model (no sockets) 6 | 3. Split w/ Sockets: Split learning code to train and test an MNIST model between machines at Harvard - first layers - and MIT - last layers - using a relay message server. 7 | 8 | # Running Locally 9 | 10 | Open 5 terminal windows and run in this sequence. 11 | 12 | ## Terminal 1: Regular MNist code 13 | ``` 14 | python3 src/no_split/mnist.py 15 | ``` 16 | 17 | Expected output: Model Accuracy = 0.9775 18 | 19 | ## Terminal 2: Split Learning (no sockets) code 20 | ``` 21 | python3 src/split_no_socket/mnist_split.py 22 | ``` 23 | 24 | Expected output: Model Accuracy = 0.9742 25 | 26 | ## Terminal 3,4,5: Split Learning with a Relay Server 27 | 28 | ### Terminal 3: Relay Server 29 | ``` 30 | python3 src/split_with_socket/utils/server.py 31 | ``` 32 | 33 | if you need to kill the process associated with the port, type: 34 | ``` 35 | sudo fuser -k 8000/tcp 36 | ``` 37 | 38 | ### Terminal 4: Run MIT 39 | ``` 40 | python3 src/split_with_socket/mnistMIT.py 41 | ``` 42 | 43 | ### Terminal 5: Run Harvard 44 | ``` 45 | python3 src/split_with_socket/mnistHarvard.py 46 | ``` 47 | Expected output: Model Accuracy = 0.9719 48 | 49 | # Installing a server: 50 | 51 | Basic OS update: 52 | ``` 53 | sudo apt update 54 | sudo apt upgrade 55 | sudo hostnamectl set-hostname mit-relay 56 | sudo shutdown -r now 57 | ``` 58 | 59 | Create a Personal Access Token at https://github.mit.edu/settings/tokens (permissions: admin:repo_hook, notifications, read:discussion, read:org, repo, user) to Setup GIT 60 | ``` 61 | git config --global user.email "" 62 | git config --global user.name "" 63 | git config credential.helper store 64 | git clone https://github.mit.edu/vitor-1/split_learning_relay_server.git 65 | Username: 66 | Password: 67 | ``` 68 | 69 | Install Python for the server 70 | ``` 71 | sudo apt install python3 72 | sudo apt install python3-pip 73 | ``` 74 | 75 | Install the ML tools if you want to run everything on the server for testing purposes. 76 | ``` 77 | pip3 install torch 78 | pip3 install torchvision 79 | pip3 install matplotlib 80 | pip3 install numpy 81 | ``` 82 | 83 | Pytorch needs at least 1.5GB of RAM. If you need more memory, don't forget to add a swap file 84 | ``` 85 | sudo dd if=/dev/zero of=/swapfile bs=1k count=4096k 86 | sudo chown root:root /swapfile 87 | sudo chmod 0600 /swapfile 88 | sudo mkswap /swapfile 89 | sudo swapon /swapfile 90 | ``` 91 | 92 | To remove the Swap file 93 | ``` 94 | sudo swapoff /swapfile 95 | sudo rm /swapfile 96 | ``` 97 | -------------------------------------------------------------------------------- /src/no_split/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import matplotlib.pyplot as plt 6 | from time import time 7 | from torchvision import datasets, transforms 8 | from torch import nn, optim 9 | from torch.autograd import Variable 10 | 11 | #define transformations 12 | transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),]) 13 | 14 | #download dataset 15 | trainset = datasets.MNIST('/Users/vitorhome/Documents/workspace/splitlearning_relayserver/datasets/train', download=True, train=True, transform=transform) 16 | valset = datasets.MNIST('/Users/vitorhome/Documents/workspace/splitlearning_relayserver/datasets/val', download=True, train=False, transform=transform) 17 | 18 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) 19 | valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True) 20 | 21 | #define models 22 | input_size = 784 23 | hidden_sizes = [128, 64] 24 | output_size = 10 25 | 26 | # Build a feed-forward network 27 | model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), 28 | nn.ReLU(), 29 | nn.Linear(hidden_sizes[0], hidden_sizes[1]), 30 | nn.ReLU(), 31 | nn.Linear(hidden_sizes[1], output_size), 32 | nn.LogSoftmax(dim=1)) 33 | 34 | criterion = nn.NLLLoss() 35 | optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9) 36 | 37 | time0 = time() 38 | epochs = 15 39 | for e in range(epochs): 40 | running_loss = 0 41 | img = 0 42 | 43 | for images, labels in trainloader: 44 | # Flatten MNIST images into a 784 long vector 45 | images = images.view(images.shape[0], -1) 46 | 47 | # Clean the gradients 48 | optimizer.zero_grad() 49 | 50 | # evaluate full model in one pass. 51 | output = model(images) 52 | 53 | # calculate loss 54 | loss = criterion(output, labels) 55 | 56 | #backprop the second model 57 | loss.backward() 58 | 59 | #optimize the weights 60 | optimizer.step() 61 | 62 | running_loss += loss.item() 63 | 64 | img = img+1 65 | 66 | print("Epoch {} {} - Training loss: {}".format(e, img, running_loss/len(trainloader))) 67 | else: 68 | print("Epoch {} - Training loss: {}".format(e, running_loss/len(trainloader))) 69 | 70 | print("\nTraining Time (in minutes) =",(time()-time0)/60) 71 | 72 | 73 | correct_count, all_count = 0, 0 74 | for images,labels in valloader: 75 | for i in range(len(labels)): 76 | img = images[i].view(1, 784) 77 | with torch.no_grad(): 78 | logps = model(img) 79 | 80 | ps = torch.exp(logps) 81 | probab = list(ps.numpy()[0]) 82 | pred_label = probab.index(max(probab)) 83 | true_label = labels.numpy()[i] 84 | if(true_label == pred_label): 85 | correct_count += 1 86 | all_count += 1 87 | 88 | print("Number Of Images Tested =", all_count) 89 | print("\nModel Accuracy =", (correct_count/all_count)) -------------------------------------------------------------------------------- /src/split_with_socket/mnistMIT.py: -------------------------------------------------------------------------------- 1 | import utils.client as client 2 | import data_package as dataPkg 3 | 4 | import os 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | import matplotlib.pyplot as plt 9 | from time import time 10 | from torchvision import datasets, transforms 11 | from torch import nn, optim 12 | from torch.autograd import Variable 13 | 14 | #relay_server_host = 'nebula.media.mit.edu' # as both code is running on same pc 15 | #relay_server_host = 'ec2-3-14-72-103.us-east-2.compute.amazonaws.com' 16 | relay_server_host = '127.0.0.1' 17 | relay_server_port = 5004 # socket server port number 18 | 19 | client_id = 'MIT' 20 | client_send_to = 'HARVARD' 21 | 22 | #define transformations 23 | transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),]) 24 | 25 | #download dataset 26 | trainset = datasets.MNIST('/tmp/splitlearning_relayserver/datasets/train', download=True, train=True, transform=transform) 27 | valset = datasets.MNIST('/tmp/splitlearning_relayserver/datasets/val', download=True, train=False, transform=transform) 28 | 29 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) 30 | valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True) 31 | 32 | #define models 33 | hidden_sizes = [128, 64] 34 | output_size = 10 35 | 36 | criterion = nn.NLLLoss() 37 | 38 | model = None 39 | optimizer = None 40 | 41 | def reset_model(): 42 | global model 43 | model = nn.Sequential(nn.ReLU(), 44 | nn.Linear(hidden_sizes[1], output_size), 45 | nn.LogSoftmax(dim=1)) 46 | 47 | global optimizer 48 | optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9) 49 | 50 | return dataPkg.NewModelReadyPackage() 51 | 52 | def train_loop(fwd_package): 53 | # cleaning gradients 54 | optimizer.zero_grad() 55 | 56 | # evaluate 57 | output = model(fwd_package.y) 58 | 59 | # calculate losses 60 | loss = criterion(output, fwd_package.labels) 61 | 62 | # backprop 63 | loss.backward() 64 | 65 | # update weights 66 | optimizer.step() 67 | 68 | #return backward Prop package to Harvard. 69 | return dataPkg.BackwardPropPackage(fwd_package.y.grad, loss.item()) 70 | 71 | def eval(eval_package): 72 | output = model(eval_package.y) 73 | return dataPkg.EvaluationPackage(output) 74 | 75 | def mit_program(): 76 | client.connect(relay_server_host, relay_server_port, client_id) 77 | 78 | reset_model() 79 | 80 | while True: 81 | package = client.receiveData() 82 | 83 | if type(package) is dataPkg.ResetModelPackage: 84 | client.sendData(client_send_to, reset_model()) 85 | if type(package) is dataPkg.ForwardPropPackage: 86 | client.sendData(client_send_to, train_loop(package)) 87 | if type(package) is dataPkg.EvaluatePackage: 88 | client.sendData(client_send_to, eval(package)) 89 | 90 | client.disconnect() # close the connection 91 | 92 | if __name__ == '__main__': 93 | mit_program() -------------------------------------------------------------------------------- /src/split_with_socket/utils/relay_protocol.py: -------------------------------------------------------------------------------- 1 | # Protocol using JSON for a Relay Server 2 | # 3 | # 1. Registration: {"register_client_id": [name of the requested client]} 4 | # This informs the server the name of the current socket connection 5 | # 6 | # 2. Send message to: {"to": the_other_client_id, "data": any object} 7 | # Informs the server to send the data to another known client 8 | # 9 | # If the client is not found, the server returns: 10 | # Client not connected: {"client_id_not_connected": [name of the requested client]} 11 | # 12 | # If the message arrives broken, you can ask the client to send again: 13 | # Repeat last message: {"repeat_last_from": client_id} 14 | 15 | import json # to enconde and decode the data 16 | 17 | # Registration object for new connections 18 | class Registration: 19 | def __init__(self, client_id): 20 | self.client_id = client_id 21 | 22 | # Client receives a message from another client. 23 | class ReceivedFrom: 24 | def __init__(self, client_id, data): 25 | self.client_id = client_id 26 | self.data = data 27 | 28 | # Client sends a message to another client. 29 | class SendTo: 30 | def __init__(self, client_id, data): 31 | self.client_id = client_id 32 | self.data = data 33 | 34 | # Repeat call to ask servers and clients to repeat their last message 35 | class Repeat: 36 | def __init__(self, client_id): 37 | self.client_id = client_id 38 | 39 | # Tells the client that the partner is not connected 40 | class PartnerNotConnected: 41 | def __init__(self, client_id): 42 | self.client_id = client_id 43 | 44 | # Process any received messages in the objects above. 45 | def parse_data_msg(data): 46 | msg = json.loads(data) 47 | 48 | if 'repeat_last_from' in msg.keys(): 49 | return Repeat(msg['repeat_last_from']) 50 | 51 | if 'client_id_not_connected' in msg.keys(): 52 | return PartnerNotConnected(msg['client_id_not_connected']) 53 | 54 | if 'from' in msg.keys(): 55 | return ReceivedFrom(msg['from'], msg['data']) 56 | 57 | if 'to' in msg.keys(): 58 | return SendTo(msg['to'], msg['data']) 59 | 60 | if 'register_client_id' in msg.keys(): 61 | return Registration(msg['register_client_id']) 62 | 63 | return 64 | 65 | # Creates the send message to another client. 66 | def send(to, serialized_str): 67 | message = {"to": to, "data": serialized_str} 68 | return json.dumps(message) 69 | 70 | # Server relays a SendTo package into a ReceivedFrom package and sends it to the receiving client. 71 | def relay_message(client_id, data): 72 | return json.dumps({'from': client_id, 'data': data}) 73 | 74 | # Creates the registration message for the client to send to the server. 75 | def register(client_id): 76 | return json.dumps({"register_client_id": client_id}) 77 | 78 | # Creates the repeat message to send to the server. 79 | def repeat(client_id): 80 | return json.dumps({"repeat_last_from": client_id}) 81 | 82 | # Tells the client that the partner he is asking to connect to in not available. 83 | def not_connected(to): 84 | return json.dumps({"client_id_not_connected": to}) 85 | -------------------------------------------------------------------------------- /src/split_no_socket/mnist_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import matplotlib.pyplot as plt 6 | from time import time 7 | from torchvision import datasets, transforms 8 | from torch import nn, optim 9 | from torch.autograd import Variable 10 | 11 | #define transformations 12 | transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),]) 13 | 14 | #download dataset 15 | trainset = datasets.MNIST('/Users/vitorhome/Documents/workspace/splitlearning_relayserver/datasets/train', download=True, train=True, transform=transform) 16 | valset = datasets.MNIST('/Users/vitorhome/Documents/workspace/splitlearning_relayserver/datasets/val', download=True, train=False, transform=transform) 17 | 18 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) 19 | valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True) 20 | 21 | #define models 22 | input_size = 784 23 | hidden_sizes = [128, 64] 24 | output_size = 10 25 | 26 | model_1 = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), 27 | nn.ReLU(), 28 | nn.Linear(hidden_sizes[0], hidden_sizes[1])) 29 | 30 | 31 | model_2 = nn.Sequential(nn.ReLU(), 32 | nn.Linear(hidden_sizes[1], output_size), 33 | nn.LogSoftmax(dim=1)) 34 | 35 | 36 | criterion = nn.NLLLoss() 37 | 38 | optimizer_1 = optim.SGD(model_1.parameters(), lr=0.003, momentum=0.9) 39 | optimizer_2 = optim.SGD(model_2.parameters(), lr=0.003, momentum=0.9) 40 | 41 | time0 = time() 42 | 43 | epochs = 15 44 | 45 | for e in range(epochs): 46 | running_loss = 0 47 | img = 0 48 | 49 | for images, labels in trainloader: 50 | # Flatten MNIST images into a 784 long vector 51 | images = images.view(images.shape[0], -1) 52 | 53 | # clean the gradients 54 | optimizer_1.zero_grad() 55 | optimizer_2.zero_grad() 56 | 57 | # evaluate the first model 58 | output_1 = model_1(images) 59 | 60 | # get the gradients 61 | y2 = Variable(output_1.data, requires_grad=True) 62 | 63 | # evaluate the 2nd model with the gradients 64 | output_2 = model_2(y2) 65 | 66 | # calculate the loss 67 | loss = criterion(output_2, labels) 68 | 69 | # backprop the second model 70 | loss.backward() 71 | 72 | # continue backproping through the first model with the gradients (that have been updated) 73 | output_1.backward(y2.grad) 74 | 75 | #optimize the weights for both models 76 | optimizer_1.step() 77 | optimizer_2.step() 78 | 79 | running_loss += loss.item() 80 | 81 | img = img+1 82 | 83 | print("Epoch {} {} - Training loss: {}".format(e, img, running_loss/len(trainloader))) 84 | else: 85 | print("Epoch {} - Training loss: {}".format(e, running_loss/len(trainloader))) 86 | 87 | print("\nTraining Time (in minutes) =",(time()-time0)/60) 88 | 89 | 90 | correct_count, all_count = 0, 0 91 | for images,labels in valloader: 92 | for i in range(len(labels)): 93 | img = images[i].view(1, 784) 94 | with torch.no_grad(): 95 | output1 = model_1(img) 96 | y2 = Variable(output1.data, requires_grad=True) 97 | logps = model_2(y2) 98 | 99 | ps = torch.exp(logps) 100 | probab = list(ps.numpy()[0]) 101 | pred_label = probab.index(max(probab)) 102 | true_label = labels.numpy()[i] 103 | if(true_label == pred_label): 104 | correct_count += 1 105 | all_count += 1 106 | 107 | print("Number Of Images Tested =", all_count) 108 | print("\nModel Accuracy =", (correct_count/all_count)) -------------------------------------------------------------------------------- /src/split_with_socket/utils/client.py: -------------------------------------------------------------------------------- 1 | # Simple client library for the relay server, a communication server where two clients 2 | # connect to and can message each other by name. 3 | # 4 | # It sends socket connections with the relay protocol and guarantees the delivery of 5 | # the messages to the server and back to the client. 6 | # 7 | # 1. Registration: {"client_id": client_id} 8 | # This informs the server the name of the current socket connection 9 | # 10 | # 2. Send message to: {"to": the_other_client_id, "data": any object} 11 | # Informs the server to send the data to another known client 12 | # 13 | # USAGE: 14 | # 1. connect(SERVER_IP, PORT, CLIENT_ID) connects to the server and registers as a CLIENT_ID 15 | # 2. sendData(CLIENT_ID, DATA) sends any python object to the desired client 16 | # 3. receiveData() locks the current thread until it receives an object back. 17 | # 4. disconnect() closes the socket. 18 | 19 | import socket # for connecting to the server 20 | import time 21 | import utils.relay_protocol as protocol 22 | import utils.atomic_socket as atomic_socket 23 | import pickle # object serialization 24 | import base64 # Make sure the serialized object can be transformed into a string 25 | 26 | buffer_size = 35000 27 | 28 | # instantiate 29 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 30 | client_name = '' 31 | 32 | last_message = '' 33 | 34 | def connect(relay_server_host, relay_server_port, client_id): 35 | global client_name 36 | client_name = client_id 37 | print('Starting Client ' + client_id) # show in terminal 38 | 39 | # connect to the server 40 | client_socket.connect((relay_server_host, relay_server_port)) 41 | 42 | # register this computer by name 43 | atomic_socket.send(client_socket, protocol.register(client_id)) 44 | 45 | # Wait for 1 second otherwise socket concatenates the two sends 46 | time.sleep(1) 47 | 48 | def sendData(to, data): 49 | global last_message 50 | print('Sending ' + data.__class__.__name__ + ' data to ' + to) 51 | last_message = protocol.send(to, serializeObjToStr(data)) 52 | atomic_socket.send(client_socket, last_message) 53 | 54 | def receiveData(): 55 | while True: 56 | try: 57 | msgRaw = atomic_socket.recv(client_socket) 58 | obj = protocol.parse_data_msg(msgRaw) 59 | 60 | if type(obj) is protocol.Repeat: 61 | print('Repeating: ' + last_message) 62 | atomic_socket.send(client_socket, last_message) 63 | elif type(obj) is protocol.PartnerNotConnected: 64 | print('Error: ' + obj.client_id + ' not connected') 65 | elif type(obj) is protocol.ReceivedFrom: 66 | object = deserializeStrToObj(obj.data) 67 | print('Received ' + object.__class__.__name__ + ' data from ' + obj.client_id) 68 | return object 69 | except Exception as inst: 70 | print(inst) 71 | atomic_socket.send(client_socket, protocol.repeat(client_name)) 72 | 73 | return 74 | 75 | def serializeObjToStr(data): 76 | picked_byte_representation = pickle.dumps(data) 77 | serialized_bytes = base64.b64encode(picked_byte_representation) 78 | return serialized_bytes.decode('utf-8') 79 | 80 | def deserializeStrToObj(str): 81 | encodedStr = str.encode('utf-8') 82 | picked_byte_representation = base64.b64decode(encodedStr) 83 | return pickle.loads(picked_byte_representation) 84 | 85 | def disconnect(): 86 | client_socket.close() # close the connection -------------------------------------------------------------------------------- /src/split_with_socket/utils/server.py: -------------------------------------------------------------------------------- 1 | # Simple relay server: a communication server where two clients 2 | # connect to and can message each other by name 3 | # 4 | # It accepts socket connections with the relay protocol 5 | # 6 | # USAGE: 7 | # python server.py 8 | 9 | import socket # for connecting to the server 10 | import atomic_socket 11 | import _thread # to manage multiple clients 12 | import relay_protocol as protocol 13 | 14 | host = '0.0.0.0' # host public IP address 15 | port = 5004 # initiate port no above 1024 16 | 17 | clients = {} # {id, socket} list of all clients connected to the server. 18 | lastFormattedMsg = {} # {id,message} 19 | 20 | def server_program(): 21 | print('Starting Relay Server at ' + host + ':' + str(port)) # show in terminal 22 | print('Waiting for clients...') 23 | 24 | server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # get instance 25 | server_socket.bind((host, port)) # bind host address and port together 26 | 27 | # configure how many clients the server can listen simultaneously 28 | server_socket.listen(2) 29 | 30 | while True: 31 | conn, address = server_socket.accept() # accept new connection 32 | _thread.start_new_thread(on_new_client,(conn, address)) #puts the communication of this socket into a thread. 33 | 34 | server_socket.close() # close the connection when done. 35 | 36 | def on_new_client(clientsocket, addr): 37 | # wait for the first message, the client's name. 38 | msgRaw = atomic_socket.recv(clientsocket) 39 | 40 | # First object is a registration object. 41 | # Register the name of the client into the list. 42 | obj = protocol.parse_data_msg(msgRaw) 43 | 44 | print("Connection from: " + str(addr) + " " + obj.client_id) 45 | 46 | # Register the socket name at a list. 47 | clients[obj.client_id] = clientsocket 48 | 49 | # keeps waitinf for the messages. 50 | wait_for_next_message(obj.client_id) 51 | 52 | def wait_for_next_message(client_id): 53 | clientsocket = clients[client_id] 54 | 55 | # keeps waitinf for the messages. 56 | # msgRaw = clientsocket.recv(buffer_size) 57 | msgRaw = atomic_socket.recv(clientsocket) 58 | 59 | while msgRaw: 60 | print("Message received from " + client_id) # + ": " + msgRaw 61 | 62 | try: 63 | obj = protocol.parse_data_msg(msgRaw) 64 | print("Object decoded " + obj.__class__.__name__) 65 | 66 | if type(obj) is protocol.Repeat: 67 | print('Repeating the message ' + lastFormattedMsg[obj.client_id]) 68 | atomic_socket.send(clientsocket, lastFormattedMsg[obj.client_id]) 69 | elif type(obj) is protocol.SendTo: 70 | if obj.client_id in clients.keys(): 71 | #reformat the message and store it if we need to repeat it. 72 | lastFormattedMsg[obj.client_id] = protocol.relay_message(client_id, obj.data) 73 | print('Sending message to ' + obj.client_id) # ' with data ' + lastFormattedMsg[obj.client_id] 74 | # sending it to the next partner 75 | atomic_socket.send(clients[obj.client_id], lastFormattedMsg[obj.client_id]) 76 | else: 77 | print(obj.client_id + ' is not connected. Sending error message back.') 78 | atomic_socket.send(clientsocket, protocol.not_connected(obj.client_id)) 79 | else: 80 | print('Unexpected object ' + obj.__class__.__name__ + ' as a reply to server') 81 | 82 | except Exception as inst: 83 | print('Exception') 84 | print(inst) 85 | atomic_socket.send(clientsocket, protocol.repeat(client_id)) 86 | 87 | #wait for next message 88 | msgRaw = atomic_socket.recv(clientsocket) 89 | 90 | clientsocket.close() 91 | 92 | if __name__ == '__main__': 93 | server_program() -------------------------------------------------------------------------------- /src/split_with_socket/mnistHarvard.py: -------------------------------------------------------------------------------- 1 | import utils.client as client 2 | import data_package as dataPkg 3 | 4 | import os 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | import matplotlib.pyplot as plt 9 | from time import time 10 | from torchvision import datasets, transforms 11 | from torch import nn, optim 12 | from torch.autograd import Variable 13 | 14 | #relay_server_host = 'nebula.media.mit.edu' # as both code is running on same pc 15 | #relay_server_host = 'ec2-3-14-72-103.us-east-2.compute.amazonaws.com' 16 | relay_server_host = '127.0.0.1' 17 | relay_server_port = 5004 # socket server port number 18 | 19 | client_id = 'HARVARD' 20 | client_send_to = 'MIT' 21 | 22 | #define transformations 23 | transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),]) 24 | 25 | #download dataset 26 | trainset = datasets.MNIST('/tmp/splitlearning_relayserver/datasets/train', download=True, train=True, transform=transform) 27 | valset = datasets.MNIST('/tmp/splitlearning_relayserver/datasets/val', download=True, train=False, transform=transform) 28 | 29 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) 30 | valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True) 31 | 32 | #define models 33 | input_size = 784 34 | hidden_sizes = [128, 64] 35 | 36 | epochs = 2 37 | 38 | def reset_model(): 39 | model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), 40 | nn.ReLU(), 41 | nn.Linear(hidden_sizes[0], hidden_sizes[1])) 42 | 43 | client.sendData(client_send_to, dataPkg.ResetModelPackage()) 44 | package = client.receiveData() 45 | 46 | if (type(package) is dataPkg.NewModelReadyPackage): 47 | return model 48 | else: 49 | print('Could not reset model') 50 | 51 | def train(client, model): 52 | optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9) 53 | for e in range(epochs): 54 | running_loss = 0 55 | img = 0 56 | for images, labels in trainloader: 57 | # Flatten MNIST images into a 784 long vector 58 | images = images.view(images.shape[0], -1) 59 | 60 | # Cleaning gradients 61 | optimizer.zero_grad() 62 | 63 | # evaluate 64 | output = model(images) 65 | 66 | if output.data.size() != (64,64): 67 | continue 68 | 69 | # prepare data for MIT 70 | y2 = Variable(output.data, requires_grad=True) 71 | 72 | # send to MIT to contine the process. 73 | client.sendData(client_send_to, dataPkg.ForwardPropPackage(y2,labels)) 74 | 75 | # wait for MIT to calculate 76 | bwd_package = client.receiveData() 77 | 78 | # backprop 79 | output.backward(bwd_package.grad) 80 | 81 | # optimize the weights 82 | optimizer.step() 83 | 84 | running_loss += bwd_package.loss 85 | 86 | img = img+1 87 | 88 | print("Epoch {} {} - Training loss: {}".format(e, img, running_loss/len(trainloader))) 89 | else: 90 | print("Epoch {} - Training loss: {}".format(e, running_loss/len(trainloader))) 91 | 92 | print("Training finished") 93 | 94 | return model 95 | 96 | def test(client, model): 97 | correct_count, all_count = 0, 0 98 | image_idx = 0 99 | for images,labels in valloader: 100 | for i in range(len(labels)): 101 | img = images[i].view(1, 784) 102 | 103 | with torch.no_grad(): 104 | output = model(img) 105 | # Prepare data to send to MIT 106 | y2 = Variable(output.data, requires_grad=True) 107 | # Send to MIT to contine the process. 108 | client.sendData(client_send_to, dataPkg.EvaluatePackage(y2)) 109 | # Wait for MIT to calculate and return the logPs 110 | logps = client.receiveData().logps 111 | 112 | ps = torch.exp(logps) 113 | probab = list(ps.detach().numpy()[0]) 114 | 115 | pred_label = probab.index(max(probab)) 116 | true_label = labels.numpy()[i] 117 | 118 | if (true_label == pred_label): 119 | correct_count += 1 120 | 121 | all_count += 1 122 | 123 | print("Eval {} Label {} - Evaluation: {}".format(image_idx, i, true_label == pred_label)) 124 | 125 | image_idx += 1 126 | 127 | print("Number Of Images Tested =", all_count) 128 | print("\nModel Accuracy =", (correct_count/all_count)) 129 | 130 | def harvard_program(): 131 | client.connect(relay_server_host, relay_server_port, client_id) 132 | 133 | model = reset_model() 134 | train(client, model) 135 | test(client, model) 136 | 137 | client.disconnect() 138 | 139 | if __name__ == '__main__': 140 | harvard_program() --------------------------------------------------------------------------------