├── test.py ├── .gitignore ├── models ├── __pycache__ │ └── CNN.cpython-36.pyc └── CNN.py ├── utils ├── __pycache__ │ └── Arguments.cpython-36.pyc └── Arguments.py ├── .github └── workflows │ └── semgrep.yml ├── README.md ├── LICENSE └── main_fed.py /test.py: -------------------------------------------------------------------------------- 1 | vivek 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/data 2 | -------------------------------------------------------------------------------- /models/__pycache__/CNN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vivekkhimani/federated_learning_pysyft/HEAD/models/__pycache__/CNN.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Arguments.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vivekkhimani/federated_learning_pysyft/HEAD/utils/__pycache__/Arguments.cpython-36.pyc -------------------------------------------------------------------------------- /.github/workflows/semgrep.yml: -------------------------------------------------------------------------------- 1 | on: 2 | workflow_dispatch: {} 3 | pull_request: {} 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - .github/workflows/semgrep.yml 10 | schedule: 11 | # random HH:MM to avoid a load spike on GitHub Actions at 00:00 12 | - cron: 12 14 * * * 13 | name: Semgrep 14 | jobs: 15 | semgrep: 16 | name: semgrep/ci 17 | runs-on: ubuntu-20.04 18 | env: 19 | SEMGREP_APP_TOKEN: ${{ secrets.SEMGREP_APP_TOKEN }} 20 | container: 21 | image: returntocorp/semgrep 22 | steps: 23 | - uses: actions/checkout@v3 24 | - run: semgrep ci 25 | 
# -----------------------------------------------------------------------------
# /utils/Arguments.py
# -----------------------------------------------------------------------------
import torch


class Arguments:
    """Hyper-parameters and run settings for federated MNIST training.

    Every setting is a plain instance attribute with a default value. Keyword
    arguments override individual defaults, e.g. ``Arguments(epochs=2)``;
    unknown names raise ``TypeError`` so a typo is not silently ignored.
    """

    def __init__(self, **overrides):
        self.batch_size = 64         # batch size of the federated training loader
        self.test_batch_size = 128   # batch size of the test loader
        self.epochs = 50             # number of train/test rounds in launch()
        self.local_epochs = 5        # NOTE(review): not read anywhere in main_fed.py
        self.lr = 0.01               # SGD learning rate
        self.momentum = 0.5          # NOTE(review): not passed to optim.SGD in main_fed.py
        self.no_cuda = False         # True forces the CPU even if CUDA is available
        self.seed = 1                # passed to torch.manual_seed() below
        self.log_interval = 10       # train() prints the loss every N batches
        self.save_model = False      # launch() saves state_dict to "mnist_cnn.pt" if True
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(
                    "Arguments() got an unexpected keyword argument {!r}".format(name))
            setattr(self, name, value)


# Module-level run configuration shared by the training script.
args = Arguments()
# CUDA is used only when it is available and not disabled through ``no_cuda``.
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options, applied only on the CUDA path.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# -----------------------------------------------------------------------------
# /README.md (lines 1-18, carried over verbatim from the repository dump)
# -----------------------------------------------------------------------------
# # Implementing Federated Learning using PySyft
#
# ### Basics:
# - Dataset - MNIST
# - Number of Workers - 32
# - Classification Model - CNN (see the details in models directory)
# - Tools Used - PySyft, PyTorch
#
# ### Instructions:
# - Prerequisite: python3, pip3, pysyft, pytorch
# - RUN: "main_fed.py"
# - To edit the basic characteristics of the model, check "/utils/Arguments.py". No CLI has been provided for now.
# - To edit the classification model, check "/models/CNN.py"
#
#
# ### Future Work:
# - Add a CLI to make the process of editing the arguments easier.
# - Facilitate training by selecting a subset of workers instead of using all the workers.
# -----------------------------------------------------------------------------
# /models/CNN.py
# -----------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class CNN(nn.Module):
    """Small convolutional classifier returning per-class log-probabilities.

    Two (5x5 conv -> 2x2 max-pool -> ReLU) stages feed two fully connected
    layers. ``fc1`` is hard-wired to 320 = 20 * 4 * 4 input features, which is
    what the conv stages yield for a 1x28x28 image (assumed MNIST-sized input).
    """

    def __init__(self):
        super().__init__()
        # Construction order is significant: it fixes the order in which the
        # seeded RNG initializes parameters.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a batch of images to log-softmax scores over 10 classes."""
        hidden = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Channel-wise dropout sits between conv2 and its pooling step.
        hidden = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(hidden)), 2))

        # Collapse (channels, height, width) into one feature vector per sample.
        dims = hidden.shape
        flat_size = dims[1] * dims[2] * dims[3]
        hidden = hidden.view(-1, flat_size)

        hidden = F.dropout(F.relu(self.fc1(hidden)), training=self.training)
        scores = self.fc2(hidden)
        return F.log_softmax(scores, dim=1)


# -----------------------------------------------------------------------------
# /LICENSE (beginning, carried over verbatim from the repository dump)
# -----------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2020 Vivek Khimani
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# -----------------------------------------------------------------------------
# /LICENSE (end, carried over verbatim from the repository dump)
# -----------------------------------------------------------------------------
# IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# -----------------------------------------------------------------------------
# /main_fed.py
#
# Federated training of models.CNN.CNN on MNIST with PySyft virtual workers.
# The training set is federated across every worker in ``vList``; for each
# batch the model is sent to the worker holding that batch, one optimizer step
# is taken there, and the model is fetched back.
# -----------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import syft as sy
from collections import defaultdict

from utils import Arguments
from models import CNN

hook = sy.TorchHook(torch)  # extra functionality to support FL

# Ids of the virtual workers, in creation order.
_WORKER_IDS = (
    "one", "two", "three", "four", "five", "six", "seven", "eight",
    "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
    "sixteen", "seventeen", "eighteen", "nineteen", "twenty",
    "twenty_one", "twenty_two", "twenty_three", "twenty_four",
    "twenty_five", "twenty_six", "twenty_seven", "twenty_eight",
    "twenty_nine", "thirty", "thirty_one", "thirty_two",
)


def virtualWorkers():
    """Create one ``sy.VirtualWorker`` per id in ``_WORKER_IDS``.

    Returns:
        list: the 32 workers, in ``_WORKER_IDS`` order.
    """
    return [sy.VirtualWorker(hook, id=worker_id) for worker_id in _WORKER_IDS]


vList = virtualWorkers()


def loadMNISTData():
    """Build the federated MNIST training loader and the local test loader.

    Returns:
        tuple: ``(federated_train_loader, test_loader)``. The training split is
        federated across every worker in ``vList``; the test split stays local.
    """
    # Same preprocessing for both splits: ToTensor, then Normalize(mean, std).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform)
        .federate(tuple(vList)),
        batch_size=Arguments.args.batch_size,
        shuffle=True,
        **Arguments.kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transform),
        batch_size=Arguments.args.test_batch_size,
        shuffle=True,
        **Arguments.kwargs
    )
    return federated_train_loader, test_loader


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of federated training.

    Args:
        args: an ``Arguments.Arguments`` instance; ``batch_size`` and
            ``log_interval`` are read here. NOTE(review): ``local_epochs`` is
            never used, so each batch gets exactly one step.
        model: the network to train; moved to each batch's worker and back.
        device: torch device the batch tensors are moved to.
        train_loader: a ``sy.FederatedDataLoader`` yielding remote batches.
        optimizer: optimizer over ``model.parameters()``.
        epoch: 1-based epoch number, used only in the log line.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):  # distributed dataset
        model.send(data.location)  # send the model to the worker holding this batch
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # CNN.forward already applies log_softmax, so the matching loss is the
        # negative log-likelihood. (nn.CrossEntropyLoss is a module class: it
        # must be instantiated and then called; passing the tensors to its
        # constructor does not compute a loss.)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get()  # get the updated model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get()  # get the loss value back from the worker
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * args.batch_size,
                len(train_loader) * args.batch_size,
                100. * batch_idx / len(train_loader),
                loss.item()))


def test(args, model, device, test_loader):
    """Evaluate ``model`` on the local test set and print loss and accuracy.

    Args:
        args: unused; kept for signature symmetry with ``train``.
        model: the network to evaluate (expected to be local, not sent).
        device: torch device the batches are moved to.
        test_loader: a regular ``torch.utils.data.DataLoader``.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum of per-sample NLL; output is already log-probabilities.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))


def launch():
    """Train for ``Arguments.args.epochs`` epochs, testing after each one."""
    model = CNN.CNN().to(Arguments.device)
    # NOTE(review): Arguments.args.momentum is not passed here. The PySyft MNIST
    # tutorial also uses plain SGD in this federated setup; confirm momentum
    # works with send/get before enabling it.
    optimizer = optim.SGD(model.parameters(), lr=Arguments.args.lr)
    train_loader, test_loader = loadMNISTData()

    for epoch in range(1, Arguments.args.epochs + 1):
        train(Arguments.args, model, Arguments.device, train_loader, optimizer, epoch)
        test(Arguments.args, model, Arguments.device, test_loader)

    if Arguments.args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


## DRIVER ##
if __name__ == '__main__':
    launch()