├── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── utils.cpython-39.pyc
    │   └── __init__.cpython-39.pyc
    └── utils.py
├── NBI_method.zip
├── model
    ├── __pycache__
    │   ├── lenet.cpython-39.pyc
    │   ├── CNN_model.cpython-39.pyc
    │   ├── __init__.cpython-39.pyc
    │   ├── binarization.cpython-39.pyc
    │   └── quantization.cpython-39.pyc
    ├── quantization.py
    └── CNN_model.py
├── client
    ├── __pycache__
    │   ├── Client.cpython-39.pyc
    │   └── Client_Class.cpython-39.pyc
    └── Client_Class.py
├── server
    ├── __pycache__
    │   ├── Server.cpython-39.pyc
    │   └── Server_Class.cpython-39.pyc
    └── Server_Class.py
├── README.md
├── LICENSE
├── train_argument.py
├── Simulator.py
├── Split_Data.py
└── train.py

/utils/__init__.py:
--------------------------------------------------------------------------------
from .utils import *
--------------------------------------------------------------------------------
/NBI_method.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/NBI_method.zip
--------------------------------------------------------------------------------
/model/__pycache__/lenet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/model/__pycache__/lenet.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/utils/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/client/__pycache__/Client.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/client/__pycache__/Client.cpython-39.pyc
--------------------------------------------------------------------------------
/server/__pycache__/Server.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/server/__pycache__/Server.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/CNN_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/model/__pycache__/CNN_model.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/model/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/binarization.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/model/__pycache__/binarization.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/quantization.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/model/__pycache__/quantization.cpython-39.pyc
--------------------------------------------------------------------------------
/client/__pycache__/Client_Class.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/client/__pycache__/Client_Class.cpython-39.pyc
--------------------------------------------------------------------------------
/server/__pycache__/Server_Class.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/news-vt/Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design/HEAD/server/__pycache__/Server_Class.cpython-39.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Green-Quantized-FL-over-Wireless-Networks-An-Energy-Efficient-Design
This is a repository for the implementation of the paper "Green, Quantized Federated Learning over Wireless Networks: An Energy-Efficient Design".

"NBI_method.zip" contains the code for the introduced NBI method to find the Pareto boundary.

The other Python files (such as train.py) are a PyTorch implementation that serves as a proof of concept of Quantized FL. You can set arbitrary precision levels (n and m), the number of local epochs, the scheduling set size, or the neural network to run the presented Quantized FL algorithm (a sample invocation is sketched at the end of this dump).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 NEWS@VT

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/server/Server_Class.py:
--------------------------------------------------------------------------------
import copy
from collections import OrderedDict

import torch
import numpy as np
from numpy import random

class Server():

    def __init__(self, args, model):
        self.clients_list = np.arange(args.num_clients)
        self.args = args
        self.global_model = copy.deepcopy(model)

    def sample_clients(self):
        """
        Return: array of integers, which corresponds to the indices of sampled devices
        """
        sampling_set = np.random.choice(self.args.num_clients, self.args.schedulingsize, replace = False)

        return sampling_set

    def broadcast(self, Clients_list, Clients_list_idx = None):
        """
        Input: a list of Client class
        Flow: Set the current global model to sampled clients
        """
        for client_idx in Clients_list_idx:
            with torch.no_grad():
                Clients_list[client_idx].model.load_state_dict(copy.deepcopy(self.global_model.state_dict()))

    def aggregation(self, Clients_list, sampling_set):
        """
        Input: sampling_set: array of integers, which corresponds to the indices of sampled devices, and a list of Client class
        Flow: aggregate the updated model differences over the sampling set
        """
        #You can change the weights of clients arbitrarily
        #For simplicity, we use 1/args.schedulingsize here


        weight_dict = OrderedDict()

        weight_difference_dict = OrderedDict()
        for i, client in enumerate(sampling_set):
            local_difference = Clients_list[client].model_difference
            if i == 0:
                for key in local_difference.keys():
                    weight_difference_dict[key] = local_difference[key] * 1/self.args.schedulingsize
            else:
                for key in local_difference.keys():
                    weight_difference_dict[key] += local_difference[key] * 1/self.args.schedulingsize

        for key in weight_difference_dict.keys():
            weight_dict[key] = self.global_model.state_dict()[key] + weight_difference_dict[key]
        self.global_model.load_state_dict(weight_dict)

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import os
import json
import logging
import sys

path = os.getcwd() #current path
sys.path.append(os.path.abspath(os.path.join(path, os.pardir))) #import the parent directory

from model import quantization
import numpy as np
import torch
import torch.nn as nn
# from binarization import MaskedMLP, MaskedConv2d


def list2cuda(_list):
    array = np.array(_list)
    return numpy2cuda(array)

def numpy2cuda(array):
    tensor = torch.from_numpy(array)

    return tensor2cuda(tensor)

def tensor2cuda(tensor):
    if torch.cuda.is_available():
        tensor = tensor.cuda()

    return tensor

def one_hot(ids, n_class):
    assert len(ids.shape) == 1, 'the ids should be 1-D'

    out_tensor = torch.zeros(len(ids), n_class)

    out_tensor.scatter_(1, ids.cpu().unsqueeze(1), 1.)
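    # scatter_ writes 1.0 into column ids[i] of row i, so each row becomes the one-hot encoding of its label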

    return out_tensor

def evaluate(_input, _target, method='mean'):
    correct = (_input == _target).astype(np.float32)
    if method == 'mean':
        return correct.mean()
    else:
        return correct.sum()


def create_logger(save_path='', file_type='', level='debug'):

    if level == 'debug':
        _level = logging.DEBUG
    elif level == 'info':
        _level = logging.INFO

    logger = logging.getLogger() #This returns the root logger
    logger.setLevel(_level)

    cs = logging.StreamHandler() #It is one of the handlers in the logging module. It sends logging output to streams such as sys.stderr or any file-like object
    cs.setLevel(_level)
    logger.addHandler(cs)

    if save_path != '':
        file_name = os.path.join(save_path, file_type + '_log.txt')
        fh = logging.FileHandler(file_name, mode='w') #makes your custom logger also log to a separate file
        fh.setLevel(_level)

        logger.addHandler(fh)

    return logger

def makedirs(path):
    if not os.path.exists(path):
        os.makedirs(path)

def load_model(model, file_name):
    model.load_state_dict(
        torch.load(file_name, map_location=lambda storage, loc: storage))

def save_model(model, file_name):
    torch.save(model.state_dict(), file_name)


--------------------------------------------------------------------------------
/train_argument.py:
--------------------------------------------------------------------------------
import argparse

def parser():
    #This creates the parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', choices=["Base_CNN"], default="Base_CNN", #This adds an argument to the parser
                        help='Which model to use') #These arguments can be accessed with args.name, where args = parser.parse_args()
    parser.add_argument('--dataset', choices=['mnist'], default ='mnist')
    parser.add_argument('--data_root', default='data',
                        help='the directory to save the dataset')
    parser.add_argument('--log_root', default='log',
                        help='the directory to save the logs or other information (e.g. images)')
    parser.add_argument('--model_root', default='checkpoint', help='the directory to save the models')
    parser.add_argument('--load_checkpoint', default='./model/default/model.pth')
    parser.add_argument('--affix', default='natural_train', help='the affix for the save folder')


    ## Training related
    parser.add_argument('--num_clients', '-N', type=int, default=50, help='number of clients')
    parser.add_argument('--n_bit', type = int, default = 16, help = 'quantization level for local training')
    parser.add_argument('--m_bit', type = int, default = 16, help = 'quantization level for transmission')
    parser.add_argument('--schedulingsize', type=int, default = 5, help = 'how many clients will be sampled')
    parser.add_argument('--batch_size', '-b', type=int, default=32, help='batch size')
    parser.add_argument('--comm_rounds', '-m_e', type=int, default=200,
                        help='the maximum communication rounds')
    parser.add_argument('--learning_rate', '-lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help="SGD momentum (default: 0.9)")
    parser.add_argument('--gpu', '-g', default='0', help='which gpu to use')
    parser.add_argument('--seed', default=1, help='The random seed')
    parser.add_argument('--alpha', type=float, default=0.1, help="Dirichlet concentration parameter")
    parser.add_argument('--weight_decay', type=float, default=0., help="SGD weight decay (default: 0.)")
    parser.add_argument('--local_epoch', type=int, default = 1, help = "number of local iterations (default: 1)")

    return parser.parse_args()

def print_args(args, logger=None):
    for k, v in vars(args).items():
        if logger is not None:
            logger.info('{:<16} : {}'.format(k, v))
        else:
            print('{:<16} : {}'.format(k, v))
--------------------------------------------------------------------------------
/client/Client_Class.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import random
import sys
import os
from collections import OrderedDict
import copy

path = os.getcwd() #current path
sys.path.append(os.path.abspath(os.path.join(path, os.pardir))) #import the parent directory

from model import quantization


class Client():
    def __init__(self, args, model, loss, client_id, tr_loader, te_loader, device, scheduler = None):
        self.args = args
        self.model = model
        self.loss = loss
        self.scheduler = scheduler
        self.client_id = client_id
        self.tr_loader = tr_loader
        self.te_loader = te_loader
        self.device = device
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr= self.args.learning_rate,
                                         momentum=self.args.momentum, weight_decay=self.args.weight_decay)
        self.model_difference = OrderedDict()

    def local_training(self, comm_rounds):
        initial = copy.deepcopy(self.model)
        for epoch in range(1, self.args.local_epoch+1):
            for data, label in self.tr_loader:
                data, label = data.to(self.device), label.to(self.device)
                self.model.train()
                output = self.model(data)
                loss_val = self.loss(output, label)

                self.optimizer.zero_grad()
                loss_val.backward()
                self.optimizer.step()

        if self.scheduler is not None:
            self.scheduler.step()
        for name in self.model.state_dict():
            foo = self.model.state_dict()[name] - initial.state_dict()[name]
            quantized_foo = self.uniform_quantize(foo)
            self.model_difference[name] = quantized_foo

    def local_test(self):

        total_acc = 0.0
        num = 0
        self.model.eval()
        std_loss = 0.
        iteration = 0.
        with torch.no_grad():
            for data, label in self.te_loader:
                data, label = data.to(self.device), label.to(self.device)
                output = self.model(data)
                pred = torch.max(output, dim=1)[1]
                te_acc = (pred.cpu().numpy() == label.cpu().numpy()).astype(np.float32).sum()

                total_acc += te_acc
                num += output.shape[0]

                std_loss += self.loss(output, label)
                iteration += 1
        std_acc = total_acc/num*100.
        std_loss /= iteration


        return std_acc, std_loss

    def uniform_quantize(self, x):
        if self.args.m_bit == 32:
            return x
        elif self.args.m_bit == 1:
            return torch.sign(x)
        else:
            m = float(2 ** (self.args.m_bit - 1))
            out = torch.round(x * m) / m
            return out
--------------------------------------------------------------------------------
/model/quantization.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Function for quantization
reference: https://github.com/zzzxxxttt/pytorch_DoReFaNet/tree/master

Here, I implemented deterministic quantization, which quantizes a given input to the nearest level.
I tried to implement the stochastic version, but it took too much time for training (I was not able to vectorize the stochastic operations in Python).
According to I. Hubara et al. (2016), these two quantization schemes eventually work in almost the same way (the stochastic version generalizes a bit better).
Therefore, for fast training, I adopted the deterministic version for this proof-of-concept implementation.
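Concretely, for 1 < n_bit < 32, an input in [-1, 1] is rounded to the nearest multiple of 1 / 2**(n_bit - 1); n_bit = 1 keeps only the sign, and n_bit = 32 passes the input through unchanged (see uniform_quantize below).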
"""
def uniform_quantize(n_bit):
    class qfn(torch.autograd.Function):

        @staticmethod
        def forward(ctx, input):
            if n_bit == 32:
                out = input
            elif n_bit == 1:
                out = torch.sign(input)
            else:
                n = float(2 ** (n_bit - 1))
                out = torch.round(input * n) / n
            return out

        @staticmethod
        def backward(ctx, grad_output):
            grad_input = grad_output.clone()
            return grad_input

    return qfn().apply


class weight_quantize_fn(nn.Module):
    def __init__(self, n_bit):
        super(weight_quantize_fn, self).__init__()
        # assert w_bit <= 8 or w_bit == 32
        self.n_bit = n_bit
        self.uniform_q = uniform_quantize(n_bit = n_bit)

    def forward(self, x):
        if self.n_bit == 32:
            weight_q = x
        else:
            weight = torch.clamp(x, min =-1, max = 1) #It clips an input to [-1, 1]
            weight_q = self.uniform_q(weight)
        return weight_q


class activation_quantize_fn(nn.Module):
    def __init__(self, n_bit):
        super(activation_quantize_fn, self).__init__()
        self.n_bit = n_bit
        self.uniform_q = uniform_quantize(n_bit=n_bit)

    def forward(self, x):
        if self.n_bit == 32:
            activation_q = x
        else:
            activation_q = self.uniform_q(F.leaky_relu(x))
        return activation_q


class Conv2d_Q(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d_Q, self).__init__(in_channels, out_channels, kernel_size, stride,
                                       padding, dilation, groups, bias)
        self.n_bit = None
        self.quantize_fn = None

    def set_quantization_level(self, n_bit):
        self.n_bit = n_bit
        self.quantize_fn = weight_quantize_fn(n_bit=n_bit)

    def forward(self, input, order=None):
        weight_q = self.quantize_fn(self.weight)
        return F.conv2d(input, weight_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class Linear_Q(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(Linear_Q, self).__init__(in_features, out_features, bias)
        self.n_bit = None
        self.quantize_fn = None

    def set_quantization_level(self, n_bit):
        self.n_bit = n_bit
        self.quantize_fn = weight_quantize_fn(n_bit=n_bit)

    def forward(self, input):
        weight_q = self.quantize_fn(self.weight)
        # print(np.unique(weight_q.detach().numpy()))
        return F.linear(input, weight_q, self.bias)
--------------------------------------------------------------------------------
/Simulator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from time import time
import numpy as np
import copy
from model import CNN_model
from server import Server_Class
from Split_Data import Non_iid_split
from client import Client_Class
from utils import *

class Simulator():
    def __init__(self, args, logger, local_tr_data_loaders, local_te_data_loaders, device):
        self.args = args
        self.logger = logger
        self.Clients_list = None
        self.Server = None
        self.local_tr_data_loaders = local_tr_data_loaders
        self.local_te_data_loaders = local_te_data_loaders
        self.device = device


    def initialization(self, model):

        loss = nn.CrossEntropyLoss()

        self.Server = Server_Class.Server(self.args, model)

        self.Clients_list = [Client_Class.Client(self.args, copy.deepcopy(self.Server.global_model), loss,
                                                 client_id, tr_loader, te_loader, self.device, scheduler=None)
                             for (client_id, (tr_loader, te_loader)) in enumerate(zip(self.local_tr_data_loaders, self.local_te_data_loaders))]

    def FedAvg(self):

        best_acc = 0
        acc_history = []

        for rounds in np.arange(self.args.comm_rounds):
            begin_time = time()
            avg_acc = []
            avg_loss = []
            self.logger.info("-"*30 + "Epoch start" + "-"*30)

            sampled_clients = self.Server.sample_clients()

            self.Server.broadcast(self.Clients_list, sampled_clients)
            for client_idx in sampled_clients:
                acc, loss = self.Clients_list[client_idx].local_test()
                avg_acc.append(acc), avg_loss.append(loss)

            for client_idx in sampled_clients:
                self.Clients_list[client_idx].local_training(rounds)


            self.Server.aggregation(self.Clients_list, sampled_clients)


            avg_acc_round = np.mean(avg_acc)

            acc_history.append(avg_acc_round) #save the current average accuracy to the history

            self.logger.info('round: %d, avg_acc: %.3f, spent: %.2f' %(rounds, avg_acc_round,
                                                                       time()-begin_time))

            cur_acc = avg_acc_round
            if cur_acc > best_acc:
                best_acc = cur_acc

        #####Check final accuracy
        self.Server.broadcast(self.Clients_list, range(0, self.args.num_clients))
        final_acc = []
        for client_idx, client in enumerate(self.Clients_list):
            acc, loss = client.local_test()
            final_acc.append(acc)
            self.logger.info('client_id: %d, final acc: %.3f' %(
                client_idx, acc))
        final_avg_acc = np.mean(final_acc)

        self.logger.info(">>>>> Training process finished")
        self.logger.info("Best test accuracy {:.4f}".format(best_acc))
        self.logger.info("Final test accuracy {:.4f}".format(final_avg_acc))
        self.logger.info(">>>>> Accuracy history during training")
        self.logger.info(acc_history)

--------------------------------------------------------------------------------
/model/CNN_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .quantization import *

class Base_CNN(nn.Module):
    def __init__(self, n_bit, in_features=1, num_classes=10):
        super().__init__()
        self.n_bit = n_bit
        self.activation = activation_quantize_fn(self.n_bit)

        self.conv1 = Conv2d_Q(in_features, 128,
                              kernel_size=3,
                              padding=0,
                              stride=1,
                              bias=True)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = Conv2d_Q(128,
                              64,
                              kernel_size=3,
                              padding=0,
                              stride=1,
                              bias=True)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = Conv2d_Q(64,
                              64,
                              kernel_size=3,
                              padding =0,
                              stride =1,
                              bias =True)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = Conv2d_Q(64, 32, kernel_size=3, padding = 0, stride = 1, bias = True)
        self.bn4 = nn.BatchNorm2d(32)
        self.conv5 = Conv2d_Q(32, 32, kernel_size=3, padding = 0, stride = 1, bias = True)
        self.bn5 = nn.BatchNorm2d(32)
        self.fc1 = Linear_Q(32, 2000)
        #self.bn6 = nn.BatchNorm1d(2000)
        self.fc2 = Linear_Q(2000, 100)
        #self.bn7 = nn.BatchNorm1d(100)
        self.fc3 = Linear_Q(100, 10)

        for name, layer in self.named_modules():
            if isinstance(layer, Conv2d_Q) or isinstance(layer, Linear_Q):
                layer.set_quantization_level(self.n_bit)


    def forward(self, x):
        x = self.bn1(self.activation(self.conv1(x)))
        x = F.max_pool2d(x, (3, 3), 1)
        x = self.bn2(self.activation(self.conv2(x)))
        x = F.max_pool2d(x, (3, 3), 2)
        x = self.bn3(self.activation(self.conv3(x)))
        x = self.bn4(self.activation(self.conv4(x)))
        x = self.bn5(self.activation(self.conv5(x)))
        x = F.max_pool2d(x, (3, 3), 2)
        x = torch.flatten(x, 1)
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)

        return x

class CNN_simple(nn.Module):
    def __init__(self, in_features=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_features,
                               32,
                               kernel_size=5,
                               padding=0,
                               stride=1,
                               bias=True)
        self.conv2 = nn.Conv2d(32,
                               64,
                               kernel_size=5,
                               padding=0,
                               stride=1,
                               bias=True)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.act = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.maxpool(x)
        x = self.act(self.conv2(x))
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x


--------------------------------------------------------------------------------
/Split_Data.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import numpy as np
import torch
import random

class Non_iid(Dataset):
    def __init__(self, x, y):
        self.x_data = x.unsqueeze(1).to(torch.float32)
        # self.x_data = x.reshape(x.shape[0], 28, 28, 1)
        self.y_data = y.to(torch.int64)
        self.cuda_available = torch.cuda.is_available()

    #Return the number of data
    def __len__(self):
        return len(self.x_data)

    #Sampling
    def __getitem__(self, idx):
        x = self.x_data[idx]
        y = self.y_data[idx]

        if self.cuda_available:
            return x.cuda(), y.cuda()
        else:
            return x, y


def data_stats(non_iid_datasets, num_classes, num_clients):

    client_data_counts = {client:{} for client in range(num_clients)}
    client_total_samples = []
    for client, data in enumerate(non_iid_datasets):
        total_sample = 0
        for label in range(num_classes):
            idx_label = len(np.where(data.y_data == label)[0])
            # client_data_counts[client].append(idx_label/data.__len__() * 100)
            label_sum = np.sum(idx_label)
            client_data_counts[client][label] = label_sum
            total_sample += label_sum
        client_total_samples.append(total_sample)

    return client_data_counts, client_total_samples

def Non_iid_split(num_classes, num_clients, tr_datasets, te_datasets, alpha):
    """
    Input: num_classes, num_clients, datasets, alpha
    Output: num_clients Dataset objects for training and num_clients for test, one per client
    """
    tr_idx_batch = [[] for _ in range(num_clients)]
    tr_data_index_map = {}
    te_idx_batch = [[] for _ in range(num_clients)]
    te_data_index_map = {}

    #for each class in the training/test dataset
    for label in range(num_classes):
        proportions = np.random.dirichlet(np.repeat(alpha, num_clients)) #It generates a Dirichlet random vector with concentration alpha over num_clients clients

        tr_idx_label = np.where(tr_datasets.targets == label)[0] #np.where returns a tuple of index arrays; [0] takes the indices
        np.random.shuffle(tr_idx_label)
        tr_proportions = (np.cumsum(proportions) * len(tr_idx_label)).astype(int)[:-1]

        tr_idx_batch = [idx_j + idx.tolist() for idx_j, idx in
                        zip(tr_idx_batch, np.split(tr_idx_label, tr_proportions))]

        te_idx_label = np.where(te_datasets.targets == label)[0]
        np.random.shuffle(te_idx_label)
        te_proportions = (np.cumsum(proportions) * len(te_idx_label)).astype(int)[:-1]

        te_idx_batch = [idx_j + idx.tolist() for idx_j, idx in
                        zip(te_idx_batch, np.split(te_idx_label, te_proportions))]

    for client in range(num_clients):
        np.random.shuffle(tr_idx_batch[client])
        tr_data_index_map[client] = tr_idx_batch[client]
        te_data_index_map[client] = te_idx_batch[client]

    Non_iid_tr_datasets = []
    Non_iid_te_datasets = []

    for client in range(num_clients):
        tr_x_data = tr_datasets.data[tr_data_index_map[client]]
        tr_y_data = tr_datasets.targets[tr_data_index_map[client]]
        Non_iid_tr_datasets.append(Non_iid(tr_x_data, tr_y_data))

        te_x_data = te_datasets.data[te_data_index_map[client]]
        te_y_data = te_datasets.targets[te_data_index_map[client]]
        Non_iid_te_datasets.append(Non_iid(te_x_data, te_y_data))

    return Non_iid_tr_datasets, Non_iid_te_datasets

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import logging
import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision

from train_argument import parser, print_args

import random
import copy

from time import time
from model import CNN_model
from utils import *
from Simulator import Simulator
from Split_Data import Non_iid_split, data_stats

def main(args):
    save_folder = args.affix

    log_folder = os.path.join(args.log_root, save_folder) #returns a new path
    model_folder = os.path.join(args.model_root, save_folder)

    makedirs(log_folder)
    makedirs(model_folder)


    setattr(args, 'log_folder', log_folder) #setattr(obj, var, val) assigns val to the attribute var of obj, just like args.log_folder = log_folder
    setattr(args, 'model_folder', model_folder)

    logger = create_logger(log_folder, 'train', 'info')
    print_args(args, logger) #It prints the arguments

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 10

    if args.dataset =='mnist':
        tr_dataset = torchvision.datasets.MNIST(args.data_root,
                                                train=True,
                                                transform=torchvision.transforms.ToTensor(),
                                                download=True)

        # evaluation during training
        te_dataset = torchvision.datasets.MNIST(args.data_root,
                                                train=False,
                                                transform=torchvision.transforms.ToTensor(),
                                                download=True)


    Non_iid_tr_datasets, Non_iid_te_datasets = Non_iid_split(
        num_classes, args.num_clients, tr_dataset, te_dataset, args.alpha)

    local_tr_data_loaders = [DataLoader(dataset, num_workers = 0,
                                        batch_size = args.batch_size,
                                        shuffle = True)
                             for dataset in Non_iid_tr_datasets]
    local_te_data_loaders = [DataLoader(dataset, num_workers = 0,
                                        batch_size = args.batch_size,
                                        shuffle = True)
                             for dataset in Non_iid_te_datasets]

    client_data_counts, client_total_samples = data_stats(Non_iid_tr_datasets, num_classes, args.num_clients)
    client_te_data_counts, client_total_te_samples = data_stats(Non_iid_te_datasets, num_classes, args.num_clients)

    while 1 in np.remainder(client_total_samples, args.batch_size) or 1 in np.remainder(client_total_te_samples, args.batch_size): #There should be more than one sample in a batch
        Non_iid_tr_datasets, Non_iid_te_datasets = Non_iid_split(
            num_classes, args.num_clients, tr_dataset, te_dataset, args.alpha)
        client_data_counts, client_total_samples = data_stats(Non_iid_tr_datasets, num_classes, args.num_clients)
        client_te_data_counts, client_total_te_samples = data_stats(Non_iid_te_datasets, num_classes, args.num_clients)


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.model == "Base_CNN":
        model = CNN_model.Base_CNN(n_bit= args.n_bit).to(device)

    trainer = Simulator(args, logger, local_tr_data_loaders, local_te_data_loaders, device)
    trainer.initialization(copy.deepcopy(model))
    trainer.FedAvg()

if __name__ == '__main__':
    args = parser()
    print_args(args)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    random.seed(args.seed)
    main(args)
--------------------------------------------------------------------------------
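For reference, training is launched through train.py using the flags defined in train_argument.py, e.g. python train.py --n_bit 8 --m_bit 8 --schedulingsize 5 --local_epoch 1. The snippet below is a minimal, illustrative sketch (not a file of this repository) that exercises the deterministic quantizer from model/quantization.py in isolation; it assumes PyTorch is installed and that it is run from the repository root, and the input values are made up for illustration.

# Illustrative sketch (not a repository file): exercising the deterministic quantizer
# defined in model/quantization.py. Assumes PyTorch is installed and the script is
# run from the repository root so that the model package is importable.
import torch

from model import quantization

quant_4bit = quantization.uniform_quantize(n_bit=4)  # 4-bit grid: step = 1 / 2**(4 - 1) = 0.125
x = torch.tensor([-0.30, -0.06, 0.05, 0.49, 0.91])   # made-up values already in [-1, 1]
print(quant_4bit(x))                                 # each entry rounded to the nearest multiple of 0.125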