├── fedlab ├── board │ ├── builtin │ │ ├── __init__.py │ │ ├── renderer.py │ │ └── charts.py │ ├── front │ │ ├── __init__.py │ │ └── assets │ │ │ ├── favicon.png │ │ │ └── fedboard.png │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── roles.py │ │ ├── color.py │ │ ├── data.py │ │ └── io.py │ ├── requirements.txt │ └── delegate.py ├── contrib │ ├── algorithm │ │ ├── cfl.py │ │ ├── bypass_bn.py │ │ ├── __init__.py │ │ ├── fedmgda+.py │ │ ├── fedavgm.py │ │ ├── fedasam.py │ │ ├── qfedavg.py │ │ ├── powerofchoice.py │ │ ├── fednova.py │ │ ├── fedopt.py │ │ ├── fedavg.py │ │ ├── feddyn.py │ │ ├── fedsam.py │ │ ├── ifca.py │ │ ├── minimizers.py │ │ ├── ditto.py │ │ ├── fedGamma.py │ │ ├── fedSMOO_woReg.py │ │ ├── scaffold.py │ │ ├── fedprox.py │ │ ├── mofedsam.py │ │ └── fedgf.py │ ├── client_sampler │ │ ├── divfl.py │ │ ├── mabs.py │ │ ├── vrb.py │ │ ├── __init__.py │ │ ├── power_of_choice.py │ │ ├── base_sampler.py │ │ ├── uniform_sampler.py │ │ └── importance_sampler.py │ ├── __init__.py │ ├── compressor │ │ ├── __init__.py │ │ ├── compressor.py │ │ ├── topk.py │ │ └── quantization.py │ └── dataset │ │ ├── __init__.py │ │ ├── femnist.py │ │ ├── celeba.py │ │ ├── synthetic_dataset.py │ │ ├── shakespeare.py │ │ ├── sent140.py │ │ ├── rotated_cifar10.py │ │ ├── rotated_mnist.py │ │ ├── pathological_mnist.py │ │ ├── adult.py │ │ ├── basic_dataset.py │ │ └── fcube.py ├── __init__.py ├── core │ ├── server │ │ ├── __init__.py │ │ ├── hierarchical │ │ │ ├── __init__.py │ │ │ └── scheduler.py │ │ └── handler.py │ ├── __init__.py │ ├── client │ │ ├── __init__.py │ │ └── trainer.py │ ├── communicator │ │ ├── __init__.py │ │ └── processor.py │ ├── standalone.py │ ├── network_manager.py │ ├── coordinator.py │ └── model_maintainer.py ├── models │ ├── __init__.py │ ├── mlp.py │ ├── FedSAMcnn.py │ ├── resnet_cifar100_del_batch.py │ ├── rnn.py │ └── cnn.py └── utils │ ├── __init__.py │ ├── dataset │ └── __init__.py │ ├── message_code.py │ ├── logger.py │ └── aggregator.py ├── tools ├── Lib │ ├── models.py │ ├── arg_parser.py │ └── datasets.py └── main.py ├── README.md ├── environment.yaml └── .gitignore /fedlab/board/builtin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedlab/board/front/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/cfl.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedlab/contrib/client_sampler/divfl.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedlab/contrib/client_sampler/mabs.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedlab/contrib/client_sampler/vrb.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedlab/contrib/client_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /fedlab/contrib/client_sampler/power_of_choice.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedlab/board/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['delegate', 'fedboard'] 2 | -------------------------------------------------------------------------------- /fedlab/board/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['color', 'data', 'io'] 2 | -------------------------------------------------------------------------------- /fedlab/board/front/assets/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwan-sig/Official-FedGF/HEAD/fedlab/board/front/assets/favicon.png -------------------------------------------------------------------------------- /fedlab/board/front/assets/fedboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwan-sig/Official-FedGF/HEAD/fedlab/board/front/assets/fedboard.png -------------------------------------------------------------------------------- /fedlab/board/requirements.txt: -------------------------------------------------------------------------------- 1 | dash 2 | dash[diskcache] 3 | dash-mantine-components 4 | dash-core-components 5 | dash-cytoscape 6 | dash-html-components 7 | dash_iconify -------------------------------------------------------------------------------- /fedlab/board/utils/roles.py: -------------------------------------------------------------------------------- 1 | CLIENT_HOLDER = 1 2 | SERVER = 1 << 1 3 | BOARD_SHOWER = 1 << 2 4 | 5 | ALL = CLIENT_HOLDER | SERVER | BOARD_SHOWER 6 | SERVER_SHOWER = SERVER | BOARD_SHOWER 7 | 8 | 9 | def is_client_holder(role): 10 | return bool(role & CLIENT_HOLDER) 11 | 12 | 13 | def is_server(role): 14 | return bool(role & SERVER) 15 | 16 | 17 | def is_board_shower(role): 18 | return bool(role & BOARD_SHOWER) 19 | -------------------------------------------------------------------------------- /fedlab/board/utils/color.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | def randomcolor(): 5 | colorArr = ['1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'] 6 | color = "" 7 | for i in range(6): 8 | color += colorArr[random.randint(0, 14)] 9 | return "#" + color 10 | 11 | 12 | color_box = [randomcolor() for c in range(20000)] 13 | 14 | 15 | def random_color(index): 16 | return color_box[index] 17 | -------------------------------------------------------------------------------- /fedlab/contrib/client_sampler/base_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | class FedSampler: 4 | __metaclass__ = ABCMeta 5 | 6 | @abstractmethod 7 | def __init__(self, n): 8 | self.n = n 9 | 10 | @abstractmethod 11 | def candidate(self, size): 12 | pass 13 | 14 | @abstractmethod 15 | def sample(self, size): 16 | pass 17 | 18 | @abstractmethod 19 | def update(self, val): 20 | pass 21 | -------------------------------------------------------------------------------- /fedlab/board/utils/data.py: 
-------------------------------------------------------------------------------- 1 | def encode_int_array(arr: list[int]): 2 | arr = sorted(arr) 3 | bits = [] 4 | for idx in arr: 5 | while idx > len(bits): 6 | bits.append(0) 7 | bits.append(1) 8 | while len(bits) % 4 != 0: 9 | bits.append(0) 10 | grouped_list = [bits[i:i + 4] for i in range(0, len(bits), 4)] 11 | hex_list = [hex(int(''.join(map(str, group)), 2))[2:] for group in grouped_list] 12 | hex_string = ''.join(hex_list) 13 | return hex_string 14 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/bypass_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | 5 | 6 | def disable_running_stats(model): 7 | def _disable(module): 8 | if isinstance(module, _BatchNorm): 9 | module.backup_momentum = module.momentum 10 | module.momentum = 0 11 | 12 | model.apply(_disable) 13 | 14 | def enable_running_stats(model): 15 | def _enable(module): 16 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"): 17 | module.momentum = module.backup_momentum 18 | 19 | model.apply(_enable) -------------------------------------------------------------------------------- /fedlab/board/delegate.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any 3 | 4 | 5 | class FedBoardDelegate: 6 | def __init__(self): 7 | pass 8 | 9 | @abstractmethod 10 | def read_client_label(self, client_id: str, client_rank: str, type: str) -> list[Any]: 11 | """ 12 | 13 | Args: 14 | client_id: which client 15 | type: usually 'train', 'test' and 'val' 16 | 17 | Returns: 18 | list of client labels 19 | 20 | """ 21 | pass 22 | 23 | @abstractmethod 24 | def sample_client_data(self, client_id: str, client_rank: str, type: str, amount: int) -> list[tuple[Any, Any]]: 25 | pass 26 | -------------------------------------------------------------------------------- /fedlab/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /fedlab/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __version__ = "1.3.0" 16 | -------------------------------------------------------------------------------- /fedlab/contrib/compressor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .quantization import QSGDCompressor 16 | from .topk import TopkCompressor -------------------------------------------------------------------------------- /fedlab/core/server/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
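# A minimal launch sketch for the synchronous mode (illustrative only: the
# constructor arguments are abbreviated, and the exact signatures should be
# checked against the FedLab examples; the address, port, and model below are
# placeholders):
#
#     from fedlab.core.network import DistNetwork
#     from fedlab.core.server.manager import SynchronousServerManager
#     from fedlab.contrib.algorithm.basic_server import SyncServerHandler
#
#     handler = SyncServerHandler(model, global_round=100, sample_ratio=0.1)
#     network = DistNetwork(address=("127.0.0.1", 3002), world_size=2, rank=0)
#     SynchronousServerManager(handler=handler, network=network).run()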
14 | 15 | 16 | from .manager import SynchronousServerManager, AsynchronousServerManager 17 | -------------------------------------------------------------------------------- /fedlab/contrib/client_sampler/uniform_sampler.py: -------------------------------------------------------------------------------- 1 | from .base_sampler import FedSampler 2 | import numpy as np 3 | 4 | class RandomSampler(FedSampler): 5 | def __init__(self, n, probs=None): 6 | self.name = "random_sampling" 7 | self.n = n 8 | self.p = probs if probs is not None else np.ones(n) / float(n) 9 | 10 | assert len(self.p) == self.n 11 | 12 | def sample(self, k, replace=False): 13 | if k == self.n: 14 | self.last_sampled = np.arange(self.n), self.p 15 | return np.arange(self.n) 16 | else: 17 | sampled = np.random.choice(self.n, k, p=self.p, replace=replace) 18 | self.last_sampled = sampled, self.p[sampled] 19 | return np.sort(sampled) 20 | 21 | def update(self, probs): 22 | self.p = probs 23 | -------------------------------------------------------------------------------- /fedlab/core/server/hierarchical/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .connector import ClientConnector, ServerConnector 17 | from .scheduler import Scheduler 18 | -------------------------------------------------------------------------------- /fedlab/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .cnn import CNN_CIFAR10, CNN_FEMNIST, CNN_MNIST 16 | from .rnn import RNN_Shakespeare 17 | from .mlp import MLP, MLP_CelebA -------------------------------------------------------------------------------- /fedlab/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import client, server, communicator 16 | from .network import DistNetwork 17 | from .network_manager import NetworkManager 18 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 4 | from .basic_server import SyncServerHandler, AsyncServerHandler 5 | 6 | from .ditto import DittoSerialClientTrainer, DittoServerHandler 7 | from .fedavg import FedAvgSerialClientTrainer, FedAvgServerHandler 8 | from .feddyn import FedDynSerialClientTrainer, FedDynServerHandler 9 | from .fednova import FedNovaSerialClientTrainer, FedNovaServerHandler 10 | from .fedprox import FedProxSerialClientTrainer, FedProxClientTrainer, FedProxServerHandler 11 | from .ifca import IFCASerialClientTrainer, IFCAServerHander 12 | from .powerofchoice import PowerofchoiceSerialClientTrainer, PowerofchoicePipeline, Powerofchoice 13 | from .qfedavg import qFedAvgClientTrainer, qFedAvgServerHandler 14 | from .scaffold import ScaffoldSerialClientTrainer, ScaffoldServerHandler -------------------------------------------------------------------------------- /fedlab/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .aggregator import Aggregators 17 | from .logger import Logger 18 | from .message_code import MessageCode 19 | from .serialization import SerializationTool -------------------------------------------------------------------------------- /fedlab/core/client/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ORDINARY_TRAINER = 0 16 | SERIAL_TRAINER = 1 17 | 18 | from .manager import ClientManager, ActiveClientManager, PassiveClientManager 19 | from .trainer import ClientTrainer -------------------------------------------------------------------------------- /fedlab/contrib/compressor/compressor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC 16 | 17 | 18 | class Compressor(ABC): 19 | def __init__(self) -> None: 20 | super().__init__() 21 | 22 | def compress(self, *args, **kwargs): 23 | raise NotImplementedError() 24 | 25 | def decompress(self, *args, **kwargs): 26 | raise NotImplementedError() 27 | -------------------------------------------------------------------------------- /fedlab/utils/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
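# A usage sketch for a Dirichlet non-IID split like the one described in the
# project README (keyword names follow the FedLab partition API and should be
# treated as indicative, not definitive):
#
#     import torchvision
#     from fedlab.utils.dataset import CIFAR10Partitioner
#
#     trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
#     partitioner = CIFAR10Partitioner(trainset.targets, num_clients=100,
#                                      balance=True, partition="dirichlet",
#                                      dir_alpha=0.3, seed=0)
#     client_indices = partitioner.client_dict  # {client_id: indices of its samples}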
14 | 15 | from .partition import DataPartitioner, BasicPartitioner, VisionPartitioner 16 | from .partition import CIFAR10Partitioner, CIFAR100Partitioner, FMNISTPartitioner, MNISTPartitioner, \ 17 | SVHNPartitioner 18 | from .partition import FCUBEPartitioner 19 | from .partition import AdultPartitioner, RCV1Partitioner, CovtypePartitioner 20 | -------------------------------------------------------------------------------- /tools/Lib/models.py: -------------------------------------------------------------------------------- 1 | from fedlab.models.mlp import MLP 2 | from fedlab.models.cnn import * 3 | from fedlab.models.FedSAMcnn import * 4 | from fedlab.models.resnet_cifar100_del_batch import ResNet18 as resnet18_DelBatch 5 | 6 | def get_model(args): 7 | if args.model == 'MLP': 8 | if args.dataset == 'mnist': 9 | model = MLP(784, 10) 10 | elif args.model == 'cnn': 11 | if args.dataset == 'mnist': 12 | model = CNN_MNIST() 13 | elif args.dataset == 'femnist': 14 | model = CNN_FEMNIST() 15 | elif args.dataset == 'cifar10': 16 | model = CNN_CIFAR10() 17 | elif args.dataset == 'cifar100': 18 | model = CNN_CIFAR100() 19 | elif args.model == 'FedSAMcnn': 20 | if args.dataset == 'cifar10': 21 | model = FedsamCNN_CIFAR10() 22 | elif args.dataset == 'cifar100': 23 | model = FedsamCNN_CIFAR100() 24 | elif args.model == 'resnet18_nonorm': 25 | if args.dataset == 'cifar100': 26 | model = resnet18_DelBatch() 27 | return model 28 | -------------------------------------------------------------------------------- /fedlab/utils/message_code.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum 16 | 17 | 18 | class MessageCode(Enum): 19 | """Different types of messages between client and server that we support go here.""" 20 | # Server and Client communication agreements 21 | ParameterRequest = 0 22 | GradientUpdate = 1 23 | ParameterUpdate = 2 24 | EvaluateParams = 3 25 | Exit = 4 26 | SetUp = 5 27 | Activation = 6 28 | 29 | -------------------------------------------------------------------------------- /fedlab/contrib/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .basic_dataset import FedDataset, BaseDataset, Subset 16 | from .fcube import FCUBE 17 | from .covtype import Covtype 18 | from .rcv1 import RCV1 19 | 20 | from .pathological_mnist import PathologicalMNIST 21 | from .rotated_mnist import RotatedMNIST 22 | from .rotated_cifar10 import RotatedCIFAR10 23 | from .partitioned_mnist import PartitionedMNIST 24 | from .partitioned_cifar10 import PartitionedCIFAR10 25 | from .synthetic_dataset import SyntheticDataset 26 | -------------------------------------------------------------------------------- /fedlab/contrib/dataset/femnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class FemnistDataset(Dataset): 7 | def __init__(self, client_id: int, client_str: str, data: list, 8 | targets: list): 9 | """get `Dataset` for femnist dataset 10 | Args: 11 | client_id (int): client id 12 | client_str (str): client name string 13 | data (list): image data list 14 | targets (list): image class target list 15 | """ 16 | self.client_id = client_id 17 | self.client_str = client_str 18 | self.data = data 19 | self.targets = targets 20 | self._process_data_target() 21 | 22 | def _process_data_target(self): 23 | """process client's data and target 24 | """ 25 | self.data = torch.tensor(self.data, 26 | dtype=torch.float32).reshape(-1, 1, 28, 28) 27 | self.targets = torch.tensor(self.targets, dtype=torch.long) 28 | 29 | def __len__(self): 30 | return len(self.targets) 31 | 32 | def __getitem__(self, index): 33 | return self.data[index], self.targets[index] -------------------------------------------------------------------------------- /fedlab/models/mlp.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | 4 | 5 | class MLP_CelebA(nn.Module): 6 | """Used for celeba experiment""" 7 | 8 | def __init__(self): 9 | super(MLP_CelebA, self).__init__() 10 | self.fc1 = nn.Linear(12288, 2048) # image_size=64, 64*64*3 11 | self.relu1 = nn.ReLU() 12 | self.fc2 = nn.Linear(2048, 500) 13 | self.relu2 = nn.ReLU() 14 | self.fc3 = nn.Linear(500, 100) 15 | self.relu3 = nn.ReLU() 16 | self.fc4 = nn.Linear(100, 2) 17 | 18 | def forward(self, x): 19 | x = x.view(x.shape[0], -1) 20 | x = self.relu1(self.fc1(x)) 21 | x = self.relu2(self.fc2(x)) 22 | x = self.relu3(self.fc3(x)) 23 | x = self.fc4(x) 24 | return x 25 | 26 | 27 | class MLP(nn.Module): 28 | def __init__(self, input_size, output_size): 29 | super(MLP, self).__init__() 30 | self.fc1 = nn.Linear(input_size, 200) 31 | self.fc2 = nn.Linear(200, 200) 32 | self.fc3 = nn.Linear(200, output_size) 33 | self.relu = nn.ReLU() 34 | 35 | def forward(self, x): 36 | x = x.view(x.shape[0], -1) 37 | x = self.relu(self.fc1(x)) 38 | x = self.relu(self.fc2(x)) 39 | x = self.fc3(x) 40 | return x -------------------------------------------------------------------------------- /fedlab/core/communicator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """FedLab communication API""" 16 | HEADER_SENDER_RANK_IDX = 0 17 | HEADER_RECEIVER_RANK_IDX = 1 18 | HEADER_SLICE_SIZE_IDX = 2 19 | HEADER_MESSAGE_CODE_IDX = 3 20 | HEADER_DATA_TYPE_IDX = 4 21 | 22 | DEFAULT_RECEIVER_RANK = -1 23 | DEFAULT_SLICE_SIZE = 0 24 | DEFAULT_MESSAGE_CODE_VALUE = 0 25 | 26 | HEADER_SIZE = 5 27 | 28 | # DATA TYPE CONSTANT 29 | INT8 = 0 30 | INT16 = 1 31 | INT32 = 2 32 | INT64 = 3 33 | 34 | FLOAT16 = 4 35 | FLOAT32 = 5 36 | FLOAT64 = 6 37 | 38 | 39 | def dtype_torch2flab(torch_type): 40 | return supported_torch_dtypes.index(torch_type) 41 | 42 | def dtype_flab2torch(fedlab_type): 43 | return supported_torch_dtypes[fedlab_type] 44 | 45 | from .package import Package, supported_torch_dtypes 46 | from .processor import PackageProcessor 47 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedmgda+.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .utils_algorithms import MinNormSolver 4 | 5 | from .basic_server import SyncServerHandler 6 | from ...utils.aggregator import Aggregators 7 | from ...utils.serialization import SerializationTool 8 | 9 | 10 | 11 | class FedMGDAServerHandler(SyncServerHandler): 12 | def setup_optim(self, sampler, lr): 13 | self.n = self.num_clients 14 | self.num_to_sample = int(self.sample_ratio * self.n) 15 | self.round_clients = int(self.sample_ratio * self.n) 16 | self.sampler = sampler 17 | 18 | self.lr = lr 19 | self.solver = MinNormSolver 20 | 21 | @property 22 | def num_clients_per_round(self): 23 | return self.round_clients 24 | 25 | def sample_clients(self, num_to_sample=None): 26 | clients = self.sampler.sample(self.num_to_sample) 27 | self.round_clients = len(clients) 28 | assert self.num_clients_per_round == len(clients) 29 | return clients 30 | 31 | def global_update(self, buffer): 32 | gradient_list = [ 33 | torch.sub(self.model_parameters, ele[0]) for ele in buffer 34 | ] 35 | 36 | # MGDA+ 37 | norms = np.array( 38 | [torch.norm(grad, p=2, dim=0).item() for grad in gradient_list]) 39 | normlized_gradients = [ 40 | grad / n for grad, n in zip(gradient_list, norms) 41 | ] 42 | sol, val = self.solver.find_min_norm_element_FW(normlized_gradients) 43 | print("GDA {}".format(val)) 44 | assert val > 1e-5 45 | estimates = Aggregators.fedavg_aggregate(normlized_gradients, sol) 46 | 47 | serialized_parameters = self.model_parameters - self.lr * estimates 48 | SerializationTool.deserialize_model(self._model, serialized_parameters) 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official-FedGF 2 | 3 | This repository contains the official implementation of 4 | 5 | > Taehwan Lee and Sung Whan Yoon, [Rethinking the Flat Minima Searching in Federated Learning](https://openreview.net/pdf?id=6TM62kpI5c), International Conference on Machine Learning (ICML) 2024. 
6 | 
7 | Our project is built on top of [FedLab](https://github.com/SMILELab-FL/FedLab). 
8 | 
9 | ## Setup 
10 | 
11 | ### Environment 
12 | - Install the conda environment (preferred): `conda env create -f environment.yaml` 
13 | 
14 | ### Weights and Biases 
15 | - The code logs to WANDB. To set up your profile, see the [quickstart documentation](https://docs.wandb.ai/quickstart). 
16 | - The WANDB mode is set to "online" by default. 
17 | - If you set `args.wandb_project_name` to `debug`, WANDB is set to 'disabled'. 
18 | - You can also switch to "offline" [here](https://github.com/hwan-sig/Official-FedGF/blob/main/tools/main.py#L32). 
19 | 
20 | ## Datasets 
21 | - Overview: image datasets based on [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) and the [Federated Vision Datasets](https://github.com/google-research/google-research/tree/master/federated_vision_datasets) 
22 | - The JSON files that distribute the images with a Dirichlet distribution are stored in `tools/json_data`. 
23 | - 100 users hold 500 images each; different Dirichlet $\alpha$ values are supported. 
24 | 
25 | ## Running experiments 
26 | An example command can be found in `tools/experiments`: 
27 | ```shell 
28 | cd tools/experiments 
29 | chmod +x cifar.sh 
30 | ./cifar.sh 
31 | ``` 
32 | 
33 | ## Bibtex 
34 | ``` 
35 | @inproceedings{leerethinking, 
36 |   title={Rethinking the Flat Minima Searching in Federated Learning}, 
37 |   author={Lee, Taehwan and Yoon, Sung Whan}, 
38 |   booktitle={Forty-first International Conference on Machine Learning}, 
39 |   year={2024} 
40 | } 
41 | ``` 
42 | 
--------------------------------------------------------------------------------
/fedlab/contrib/dataset/celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from pathlib import Path
4 | from torch.utils.data import Dataset
5 | 
6 | from .basic_dataset import FedDataset
7 | 
8 | class CelebADataset(Dataset):
9 |     def __init__(self,
10 |                  client_id: int,
11 |                  client_str: str,
12 |                  data: list,
13 |                  targets: list,
14 |                  image_root: str,
15 |                  transform=None):
16 |         """get `Dataset` for CelebA dataset
17 |         Args:
18 |             client_id (int): client id
19 |             client_str (str): client name string
20 |             data (list): list of input image file names
21 |             targets (list): output label list
22 |         """
23 |         self.client_id = client_id
24 |         self.client_str = client_str
25 |         self.image_root = Path(__file__).parent.resolve() / image_root
26 |         self.transform = transform
27 |         self.data = data
28 |         self.targets = targets
29 |         self._process_data_target()
30 | 
31 |     def _process_data_target(self):
32 |         """process client's data and target
33 |         """
34 |         data = []
35 |         targets = []
36 |         for idx in range(len(self.data)):
37 |             image_path = self.image_root / self.data[idx]
38 |             image = Image.open(image_path).convert('RGB')
39 |             data.append(image)
40 |             targets.append(torch.tensor(self.targets[idx], dtype=torch.long))
41 |         self.data = data
42 |         self.targets = targets
43 | 
44 |     def __len__(self):
45 |         return len(self.targets)
46 | 
47 |     def __getitem__(self, index):
48 |         data = self.data[index]
49 |         if self.transform:
50 |             data = self.transform(data)
51 |         target = self.targets[index]
52 |         return data, target
--------------------------------------------------------------------------------
/fedlab/core/standalone.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from .client.trainer import SerialClientTrainer
16 | from .server.handler import ServerHandler
17 | 
18 | class StandalonePipeline(object):
19 |     def __init__(self, handler: ServerHandler, trainer: SerialClientTrainer):
20 |         """Perform the standalone simulation process.
21 | 
22 |         Args:
23 |             handler (ServerHandler): server-side handler that samples clients, broadcasts the global model, and aggregates the uploaded packages.
24 |             trainer (SerialClientTrainer): client-side trainer that simulates local training for the sampled clients in serial.
25 |         """
26 |         self.handler = handler
27 |         self.trainer = trainer
28 | 
29 |         # initialization
30 |         self.handler.num_clients = self.trainer.num_clients
31 | 
32 |     def main(self):
33 |         while self.handler.if_stop is False:
34 |             # server side
35 |             sampled_clients = self.handler.sample_clients()
36 |             broadcast = self.handler.downlink_package
37 | 
38 |             # client side
39 |             self.trainer.local_process(broadcast, sampled_clients)
40 |             uploads = self.trainer.uplink_package
41 | 
42 |             # server side
43 |             for pack in uploads:
44 |                 self.handler.load(pack)
45 | 
46 |             # evaluate
47 |             self.evaluate()
48 |             # self.handler.evaluate()
49 | 
50 |     def evaluate(self):
51 |         print("This is an example implementation. Please read the source code at fedlab.core.standalone.")
52 | 
--------------------------------------------------------------------------------
/fedlab/models/FedSAMcnn.py:
--------------------------------------------------------------------------------
1 | """CNN model in pytorch
2 | References:
3 |     [1] Reddi S, Charles Z, Zaheer M, et al.
4 |     Adaptive Federated Optimization. ICML 2020.
5 | https://arxiv.org/pdf/2003.00295.pdf 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class FedsamCNN_CIFAR10(nn.Module): 13 | """from torch tutorial 14 | https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 15 | """ 16 | def __init__(self): 17 | super(FedsamCNN_CIFAR10,self).__init__() 18 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 19 | self.pool = nn.MaxPool2d(2, 2) 20 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 21 | self.fc1 = nn.Linear(64*5*5, 384) 22 | self.fc2 = nn.Linear(384, 192) 23 | self.fc3 = nn.Linear(192, 10) 24 | 25 | def forward(self, x): 26 | x = self.pool(F.relu(self.conv1(x))) 27 | x = self.pool(F.relu(self.conv2(x))) 28 | x = torch.flatten(x, 1) # flatten all dimensions except batch 29 | x = F.relu(self.fc1(x)) 30 | x = F.relu(self.fc2(x)) 31 | x = self.fc3(x) 32 | return x 33 | 34 | class FedsamCNN_CIFAR100(nn.Module): 35 | """from torch tutorial 36 | https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 37 | """ 38 | def __init__(self): 39 | super(FedsamCNN_CIFAR100,self).__init__() 40 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 41 | self.pool = nn.MaxPool2d(2, 2) 42 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 43 | self.fc1 = nn.Linear(64 * 5 * 5, 384) 44 | self.fc2 = nn.Linear(384, 192) 45 | self.fc3 = nn.Linear(192, 100) 46 | 47 | def forward(self, x): 48 | x = self.pool(F.relu(self.conv1(x))) 49 | x = self.pool(F.relu(self.conv2(x))) 50 | x = torch.flatten(x, 1) # flatten all dimensions except batch 51 | x = F.relu(self.fc1(x)) 52 | x = F.relu(self.fc2(x)) 53 | x = self.fc3(x) 54 | return x 55 | -------------------------------------------------------------------------------- /fedlab/core/network_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from torch.multiprocessing import Process 16 | 17 | from .network import DistNetwork 18 | 19 | class NetworkManager(Process): 20 | """Abstract class. 21 | 22 | Args: 23 | network (DistNetwork): object to manage torch.distributed network communication. 24 | """ 25 | def __init__(self, network: DistNetwork): 26 | super(NetworkManager, self).__init__() 27 | self._network = network 28 | 29 | def run(self): 30 | """ 31 | Main Process: 32 | 33 | 1. Initialization stage. 34 | 2. FL communication stage. 35 | 3. Shutdown stage. Close network connection. 36 | """ 37 | self.setup() 38 | self.main_loop() 39 | self.shutdown() 40 | 41 | def setup(self): 42 | """Initialize network connection and necessary setups. 43 | 44 | At first, ``self._network.init_network_connection()`` is required to be called. 45 | 46 | Overwrite this method to implement system setup message communication procedure. 
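
        A typical override calls ``super().setup()`` first and then performs
        its own handshake. A sketch (the message payload and call shown here
        are illustrative, not a fixed protocol):

            def setup(self):
                super().setup()
                # e.g. announce readiness to the server process (rank 0)
                self._network.send(message_code=MessageCode.SetUp, dst=0)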
47 | """ 48 | self._network.init_network_connection() 49 | 50 | def main_loop(self): 51 | """Define the actions of communication stage.""" 52 | raise NotImplementedError() 53 | 54 | def shutdown(self): 55 | """Shutdown stage. 56 | 57 | Close the network connection in the end. 58 | """ 59 | self._network.close_network_connection() 60 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedavgm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # from utils_algorithms import MinNormSolver 4 | from fedlab.utils.aggregator import Aggregators 5 | from fedlab.utils.serialization import SerializationTool 6 | from fedlab.contrib.algorithm.basic_server import SyncServerHandler 7 | 8 | 9 | class FedAvgMServerHandler(SyncServerHandler): 10 | """ 11 | Hsu, Tzu-Ming Harry, Hang Qi, and Matthew Brown. "Measuring the effects of non-identical data distribution for federated visual classification." arXiv preprint arXiv:1909.06335 (2019). 12 | """ 13 | 14 | def setup_optim(self, beta): 15 | self.momentum = torch.zeros_like(self.model_parameters) 16 | self.beta = beta 17 | # def setup_optim(self, epochs, batch_size, lr, weight_decay, momentum, beta): 18 | # super().setup_optim(epochs, batch_size, lr, weight_decay, momentum) 19 | # # self.n = self.num_clients 20 | # self.num_to_sample = int(self.sample_ratio * self.n) 21 | # self.round_clients = int(self.sample_ratio * self.n) 22 | # self.sampler = sampler 23 | 24 | # self.args = args 25 | # self.lr = args.glr 26 | #self.k = args.k 27 | 28 | 29 | # @property 30 | # def num_clients_per_round(self): 31 | # return self.round_clients 32 | 33 | # def sample_clients(self, num_to_sample=None): 34 | # clients = self.sampler.sample(self.num_to_sample) 35 | # self.round_clients = len(clients) 36 | # assert self.num_clients_per_round == len(clients) 37 | # return clients 38 | 39 | def global_update(self, buffer): 40 | gradient_list = [ 41 | torch.sub(self.model_parameters, ele[0]) for ele in buffer 42 | ] 43 | weights = [ele[1] for ele in buffer] 44 | 45 | # indices, _ = self.sampler.last_sampled 46 | estimates = Aggregators.fedavg_aggregate(gradient_list, 47 | weights) 48 | self.momentum = self.beta * self.momentum + estimates 49 | 50 | serialized_parameters = self.model_parameters - self.momentum 51 | SerializationTool.deserialize_model(self._model, serialized_parameters) 52 | -------------------------------------------------------------------------------- /fedlab/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
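# Usage sketch: log to stdout and to a file at the same time.
#
#     from fedlab.utils import Logger
#
#     LOGGER = Logger(log_name="server", log_file="./server.log")
#     LOGGER.info("round 1 finished")       # printed to stdout and written to the file
#     LOGGER.warning("client 3 timed out")  # hypothetical message, for illustration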
14 | 15 | import logging 16 | import sys 17 | 18 | logging.basicConfig( 19 | stream=sys.stdout, 20 | level=logging.INFO, 21 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 22 | datefmt='%H:%M:%S') 23 | 24 | logging.getLogger().setLevel(logging.INFO) 25 | 26 | 27 | class Logger(object): 28 | """record cmd info to file and print it to cmd at the same time 29 | 30 | Args: 31 | log_name (str): log name for output. 32 | log_file (str): a file path of log file. 33 | """ 34 | def __init__(self, log_name=None, log_file=None): 35 | if log_name is not None: 36 | self.logger = logging.getLogger(log_name) 37 | self.name = log_name 38 | else: 39 | logging.getLogger().setLevel(logging.INFO) 40 | self.logger = logging 41 | self.name = "root" 42 | 43 | if log_file is not None: 44 | handler = logging.FileHandler(log_file, mode='w') 45 | handler.setLevel(level=logging.INFO) 46 | formatter = logging.Formatter( 47 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 48 | handler.setFormatter(formatter) 49 | self.logger.addHandler(handler) 50 | 51 | def info(self, log_str): 52 | """Print information to logger""" 53 | self.logger.info(log_str) 54 | 55 | def warning(self, warning_str): 56 | """Print warning to logger""" 57 | self.logger.warning(warning_str) 58 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedasam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .basic_server import SyncServerHandler 4 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 5 | from .fedavg import FedAvgServerHandler 6 | from .minimizers import ASAM 7 | from ...utils import Aggregators 8 | 9 | ################## 10 | # 11 | # Server 12 | # 13 | ################## 14 | 15 | 16 | class FedASamServerHandler(FedAvgServerHandler): 17 | pass 18 | 19 | ################## 20 | # 21 | # Client 22 | # 23 | ################## 24 | 25 | 26 | class FedASamSerialClientTrainer(SGDSerialClientTrainer): 27 | def __init__(self, model, num_clients, rho, eta, cuda=True, device=None, logger=None, personal=False) -> None: 28 | super().__init__(model, num_clients, cuda, device, logger, personal) 29 | self.rho = rho 30 | self.eta = eta 31 | 32 | def local_process(self, payload, id_list): 33 | model_parameters = payload[0] 34 | for id in id_list: 35 | data_loader = self.dataset.get_dataloader(id, self.batch_size) 36 | # optimizer = torch.optim.SGD(model_parameters, lr=self.lr) 37 | minimizer = ASAM(self.optimizer, self.model, self.rho, self.eta) 38 | pack = self.train(id, model_parameters, minimizer, data_loader) 39 | self.cache.append(pack) 40 | 41 | def train(self, id, model_parameters, minimizer, train_loader): 42 | self.set_model(model_parameters) 43 | 44 | data_size = 0 45 | for _ in range(self.epochs): 46 | for data, target in train_loader: 47 | if self.cuda: 48 | data = data.cuda(self.device) 49 | target = target.cuda(self.device) 50 | 51 | # Ascent Step 52 | output = self.model(data) 53 | loss = self.criterion(output, target) 54 | 55 | loss.backward() 56 | minimizer.ascent_step() 57 | 58 | # Descent Step 59 | self.criterion(self.model(data), target).backward() 60 | minimizer.descent_step() 61 | 62 | data_size += len(target) 63 | 64 | return [self.model_parameters, data_size] 65 | -------------------------------------------------------------------------------- /tools/Lib/arg_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 
| MODELS = ['FedSAMcnn', 'resnet18_nonorm'] 3 | 4 | def get_parser(): 5 | parser = argparse.ArgumentParser(description="Standalone training") 6 | parser.add_argument("--seed", type=int, default=0) 7 | 8 | parser.add_argument("--model", type=str, required=True, choices=MODELS) 9 | parser.add_argument("--pre_trained", action='store_true') 10 | parser.add_argument("--alg", type=str, required=True) 11 | parser.add_argument("--dataset", type=str, required=True) 12 | parser.add_argument("--eval_every", type=int, required=True) 13 | parser.add_argument("--avg_test", action='store_true') 14 | parser.add_argument("--save_model", action='store_true') 15 | 16 | # dataset distribution 17 | parser.add_argument("--balance", action='store_true') 18 | parser.add_argument("--partition", type=str) 19 | parser.add_argument("--dir_alpha", type=str) 20 | parser.add_argument("--transform", action='store_true') 21 | 22 | parser.add_argument("--wandb_project_name", type=str, required=True) 23 | 24 | parser.add_argument("--total_client", type=int, default=100) 25 | parser.add_argument("--com_round", type=int) 26 | 27 | parser.add_argument("--sample_ratio", type=float) 28 | parser.add_argument("--batch_size", type=int) 29 | parser.add_argument("--epochs", type=int) 30 | parser.add_argument("--lr", type=float) 31 | parser.add_argument("--weight_decay", type=float, default=0.0004) 32 | parser.add_argument("--momentum", type=float, default=0) 33 | 34 | # scaffold 35 | parser.add_argument("--g_lr", type=float) 36 | # feddyn 37 | parser.add_argument("--alpha", type=float) 38 | # fedprox 39 | parser.add_argument("--mu", type=float) 40 | # fedsam 41 | parser.add_argument("--rho", type=float) 42 | # fedasam 43 | parser.add_argument("--eta", type=float) 44 | # mofedsam, fedavgm 45 | parser.add_argument("--beta", type=float) 46 | parser.add_argument("--eta_g", type=float, default=1) 47 | # FedGF 48 | parser.add_argument("--T_D", type=float) 49 | parser.add_argument("--g_rho", type=float) 50 | parser.add_argument("--W", type=int) 51 | 52 | return parser.parse_args() 53 | -------------------------------------------------------------------------------- /fedlab/utils/aggregator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | class Aggregators(object): 19 | """Define the algorithm of parameters aggregation""" 20 | 21 | @staticmethod 22 | def fedavg_aggregate(serialized_params_list, weights=None): 23 | """FedAvg aggregator 24 | 25 | Paper: http://proceedings.mlr.press/v54/mcmahan17a.html 26 | 27 | Args: 28 | serialized_params_list (list[torch.Tensor])): Merge all tensors following FedAvg. 
29 |             weights (list, numpy.array or torch.Tensor, optional): weights for each entry; the length of ``weights`` must equal the length of ``serialized_params_list``.
30 | 
31 |         Returns:
32 |             torch.Tensor
33 |         """
34 |         if weights is None:
35 |             weights = torch.ones(len(serialized_params_list)).cuda()
36 | 
37 |         if not isinstance(weights, torch.Tensor):
38 |             weights = torch.tensor(weights, device='cuda:0')
39 | 
40 |         weights = weights / torch.sum(weights)
41 |         assert torch.all(weights >= 0), "weights should be non-negative values"
42 |         serialized_parameters = torch.sum(
43 |             torch.stack(serialized_params_list, dim=-1) * weights, dim=-1)
44 | 
45 |         return serialized_parameters
46 | 
47 |     @staticmethod
48 |     def fedasync_aggregate(server_param, new_param, alpha):
49 |         """FedAsync aggregator
50 | 
51 |         Paper: https://arxiv.org/abs/1903.03934
52 |         """
53 |         serialized_parameters = torch.mul(1 - alpha, server_param) + \
54 |             torch.mul(alpha, new_param)
55 |         return serialized_parameters
56 | 
--------------------------------------------------------------------------------
/fedlab/contrib/dataset/synthetic_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | from torch.utils.data import DataLoader
5 | import torchvision
6 | 
7 | from .basic_dataset import FedDataset, BaseDataset
8 | 
9 | 
10 | 
11 | class SyntheticDataset(FedDataset):
12 |     def __init__(self, root, path, preprocess=False) -> None:
13 |         self.root = root
14 |         self.path = path
15 |         if preprocess:
16 |             self.preprocess(root, path)
17 |         else:
18 |             print("Warning: please make sure that you have preprocessed the data once!")
19 | 
20 |     def preprocess(self, root, path, partition=0.2):
21 |         """Preprocess the raw data to fedlab dataset format.
22 | 
23 |         Args:
24 |             root (str): path to the raw data.
25 |             path (str): path to save the preprocessed datasets.
26 |             partition (float, optional): proportion of each client's samples assigned to the training split; the remainder forms the test split. Defaults to 0.2.
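
            The raw file is expected to be a ``torch.save``-d dict with keys
            ``users`` and ``user_data``, where ``user_data[id]['x']`` and
            ``user_data[id]['y']`` hold each client's features and labels (see
            the loading code below).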
27 | """ 28 | if os.path.exists(self.path) is not True: 29 | os.mkdir(self.path) 30 | os.mkdir(os.path.join(self.path, "train")) 31 | os.mkdir(os.path.join(self.path, "var")) 32 | os.mkdir(os.path.join(self.path, "test")) 33 | 34 | raw_data = torch.load(root) 35 | users, user_data = raw_data["users"], raw_data["user_data"] 36 | 37 | for id in users: 38 | data, label = user_data[id]['x'], user_data[id]['y'] 39 | train_size = int(len(label)*partition) 40 | 41 | trainset = BaseDataset(torch.Tensor(data[0:train_size]), label[0:train_size]) 42 | torch.save(trainset, os.path.join(path, "train","data{}.pkl".format(id))) 43 | 44 | testset = BaseDataset(torch.Tensor(data[train_size:]), label[train_size:]) 45 | torch.save(testset, os.path.join(path, "test","data{}.pkl".format(id))) 46 | 47 | def get_dataset(self, id, type="train"): 48 | dataset = torch.load( 49 | os.path.join(self.path, type, "data{}.pkl".format(id))) 50 | return dataset 51 | 52 | def get_dataloader(self, id, batch_size, type="train"): 53 | dataset = self.get_dataset(id, type) 54 | batch_size = len(dataset) if batch_size is None else batch_size 55 | data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 56 | return data_loader 57 | 58 | -------------------------------------------------------------------------------- /fedlab/board/builtin/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.manifold import TSNE 3 | 4 | from fedlab.board import fedboard 5 | 6 | 7 | def client_param_tsne(round: int, client_ids: list[str]): 8 | if len(client_ids) < 2: 9 | return None, None 10 | client_params = {} 11 | for role_id, params in fedboard.read_logged_obj(round, 'client_params').items(): 12 | for k, v in params.items(): 13 | client_params[k] = v 14 | raw_params = {str(id): param for id, param in client_params.items()} 15 | params_selected = [raw_params[id] for id in client_ids if id in raw_params.keys()] 16 | id_existed = [id for id in client_ids if id in raw_params.keys()] 17 | if len(params_selected) <= 1: 18 | return None, None 19 | params_selected = torch.stack(params_selected) 20 | params_tsne = TSNE(n_components=2, learning_rate=100, random_state=501, 21 | perplexity=min(30.0, len(params_selected) - 1)).fit_transform( 22 | params_selected) 23 | return params_tsne, id_existed 24 | 25 | 26 | def get_client_dataset_tsne(client_ids: list[str], type: str, size, client_ranks: list[str]): 27 | if len(client_ids) < 1: 28 | return None 29 | if not fedboard.get_delegate(): 30 | return None 31 | raw = [] 32 | client_range = {} 33 | for client_id, rank in zip(client_ids, client_ranks): 34 | data, label = fedboard.get_delegate().sample_client_data(client_id, rank, type, size) 35 | client_range[client_id] = (len(raw), len(raw) + len(data)) 36 | raw += data 37 | if len(raw) == 0: 38 | return None 39 | raw = torch.stack(raw).view(len(raw), -1) 40 | tsne = TSNE(n_components=3, learning_rate=100, random_state=501, 41 | perplexity=min(30.0, len(raw) - 1)).fit_transform(raw) 42 | tsne = {cid: tsne[s:e] for cid, (s, e) in client_range.items()} 43 | return tsne 44 | 45 | 46 | def get_client_data_report(clients_ids: list[str], type: str, client_ranks: list[str]): 47 | res = {} 48 | for client_id, rank in zip(clients_ids, client_ranks): 49 | def rd(): 50 | if fedboard.get_delegate(): 51 | return fedboard.get_delegate().read_client_label(client_id, rank, type=type) 52 | else: 53 | return {} 54 | 55 | obj = fedboard.read_obj_with_cache('data', 'partition', f'{client_id}', 
rd) 56 | res[client_id] = obj 57 | return res 58 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/qfedavg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .basic_server import SyncServerHandler 4 | from .basic_client import SGDClientTrainer 5 | 6 | 7 | ################## 8 | # 9 | # Server 10 | # 11 | ################## 12 | 13 | 14 | class qFedAvgServerHandler(SyncServerHandler): 15 | """qFedAvg server handler.""" 16 | def global_update(self, buffer): 17 | deltas = [ele[0] for ele in buffer] 18 | hks = [ele[1] for ele in buffer] 19 | 20 | hk = sum(hks) 21 | updates = sum([delta/hk for delta in deltas]) 22 | model_parameters = self.model_parameters - updates 23 | 24 | self.set_model(model_parameters) 25 | 26 | 27 | ################## 28 | # 29 | # Client 30 | # 31 | ################## 32 | 33 | 34 | class qFedAvgClientTrainer(SGDClientTrainer): 35 | """Federated client with modified upload package and local SGD solver.""" 36 | @property 37 | def uplink_package(self): 38 | return [self.delta, self.hk] 39 | 40 | def setup_optim(self, epochs, batch_size, lr, q): 41 | super().setup_optim(epochs, batch_size, lr) 42 | self.q = q 43 | 44 | def train(self, model_parameters, train_loader) -> None: 45 | """Client trains its local model on local dataset. 46 | Args: 47 | model_parameters (torch.Tensor): Serialized model parameters. 48 | """ 49 | self.set_model(model_parameters) 50 | self._LOGGER.info("Local train procedure is running") 51 | for ep in range(self.epochs): 52 | self._model.train() 53 | ret_loss = 0.0 54 | for data, target in train_loader: 55 | if self.cuda: 56 | data, target = data.cuda(self.gpu), target.cuda( 57 | self.gpu) 58 | 59 | outputs = self._model(data) 60 | loss = self.criterion(outputs, target) 61 | 62 | self.optimizer.zero_grad() 63 | loss.backward() 64 | self.optimizer.step() 65 | 66 | ret_loss += loss.detach().item() 67 | self._LOGGER.info("Local train procedure is finished") 68 | 69 | grad = (model_parameters - self.model_parameters) / self.lr 70 | self.delta = grad * np.float_power(ret_loss + 1e-10, self.q) 71 | self.hk = self.q * np.float_power( 72 | ret_loss + 1e-10, self.q - 1) * grad.norm( 73 | )**2 + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q) 74 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/powerofchoice.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | from .basic_server import SyncServerHandler 6 | from .basic_client import SGDSerialClientTrainer 7 | from ...core.standalone import StandalonePipeline 8 | from ...utils import functional as F 9 | 10 | 11 | ##################### 12 | # # 13 | # Pipeline # 14 | # # 15 | ##################### 16 | 17 | class PowerofchoicePipeline(StandalonePipeline): 18 | def main(self): 19 | while self.handler.if_stop is False: 20 | candidates = self.handler.sample_candidates() 21 | losses = self.trainer.evaluate(candidates, 22 | self.handler.model_parameters) 23 | 24 | # server side 25 | sampled_clients = self.handler.sample_clients(candidates, losses) 26 | broadcast = self.handler.downlink_package 27 | 28 | # client side 29 | self.trainer.local_process(broadcast, sampled_clients) 30 | uploads = self.trainer.uplink_package 31 | 32 | # server side 33 | for pack in uploads: 34 | self.handler.load(pack) 35 | 36 | 37 | 
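# A rough end-to-end sketch of wiring the pieces in this file together
# (hedged: the handler/trainer constructor arguments follow the generic
# FedLab pattern used by the other algorithms in this repo and are
# illustrative, not verbatim):
#
#   handler = Powerofchoice(model, global_round=100, sample_ratio=0.1)
#   handler.setup_optim(d=20)                       # d candidates per round
#   trainer = PowerofchoiceSerialClientTrainer(model, num_clients=100)
#   trainer.setup_dataset(dataset)
#   trainer.setup_optim(epochs=5, batch_size=64, lr=0.1)
#   PowerofchoicePipeline(handler, trainer).main()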
#####################
38 | #                   #
39 | #      Server       #
40 | #                   #
41 | #####################
42 | 
43 | 
44 | class Powerofchoice(SyncServerHandler):
45 |     def setup_optim(self, d):
46 |         self.d = d  # the number of candidates
47 | 
48 |     def sample_candidates(self):
49 |         selection = random.sample(range(self.num_clients), self.d)
50 |         selection = sorted(selection)
51 |         return selection
52 | 
53 |     def sample_clients(self, candidates, losses):
54 |         sort = np.array(losses).argsort().tolist()
55 |         sort.reverse()
56 |         selected_clients = np.array(candidates)[sort][0:self.num_clients_per_round]
57 |         return selected_clients.tolist()
58 | 
59 | 
60 | #####################
61 | #                   #
62 | #      Client       #
63 | #                   #
64 | #####################
65 | 
66 | class PowerofchoiceSerialClientTrainer(SGDSerialClientTrainer):
67 |     def evaluate(self, id_list, model_parameters):
68 |         self.set_model(model_parameters)
69 |         losses = []
70 |         for id in id_list:
71 |             dataloader = self.dataset.get_dataloader(id)
72 |             loss, acc = F.evaluate(self._model, torch.nn.CrossEntropyLoss(),
73 |                                    dataloader)
74 |             losses.append(loss)
75 |         return losses
--------------------------------------------------------------------------------
/fedlab/contrib/algorithm/fednova.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .basic_server import SyncServerHandler
4 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer
5 | from ...utils import Aggregators
6 | 
7 | ##################
8 | #
9 | #   Server
10 | #
11 | ##################
12 | 
13 | 
14 | class FedNovaServerHandler(SyncServerHandler):
15 |     """FedNova server handler."""
16 | 
17 |     def setup_optim(self, option="weighted_scale"):
18 |         self.option = option  # weighted_scale, uniform, weighted_com
19 | 
20 |     def global_update(self, buffer):
21 |         models = [elem[0] for elem in buffer]
22 |         taus = [elem[1] for elem in buffer]
23 | 
24 |         deltas = [(model - self.model_parameters)/tau for model, tau in zip(models, taus)]
25 | 
26 |         # p is the FedAvg weight; we simply set it to 1/m here.
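# FedNova recap (Wang et al., 2020, https://arxiv.org/abs/2007.07481): client k
# runs tau_k local steps, so its raw update is renormalized above as
# d_k = (x_k - x) / tau_k, and the server then applies
#   x <- x + tau_eff * sum_k p_k * d_k,
# where tau_eff is the effective number of steps under the option chosen below.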
27 |         p = [
28 |             1.0 / self.num_clients_per_round
29 |             for _ in range(self.num_clients_per_round)
30 |         ]
31 | 
32 |         if self.option == 'weighted_scale':
33 |             K = len(deltas)
34 |             N = self.num_clients
35 |             tau_eff = sum([tauk * pk for tauk, pk in zip(taus, p)])
36 |             delta = sum([dk * pk
37 |                          for dk, pk in zip(deltas, p)]) * N / K
38 | 
39 |         elif self.option == 'uniform':
40 |             tau_eff = 1.0 * sum(taus) / len(deltas)
41 |             delta = Aggregators.fedavg_aggregate(deltas)
42 | 
43 |         elif self.option == 'weighted_com':
44 |             tau_eff = sum([tauk * pk for tauk, pk in zip(taus, p)])
45 |             delta = sum([dk * pk for dk, pk in zip(deltas, p)])
46 | 
47 |         else:
48 |             sump = sum(p)
49 |             p = [pk / sump for pk in p]
50 |             tau_eff = sum([tauk * pk for tauk, pk in zip(taus, p)])
51 |             delta = sum([dk * pk for dk, pk in zip(deltas, p)])
52 | 
53 |         self.set_model(self.model_parameters + tau_eff * delta)
54 | 
55 | 
56 | ##################
57 | #
58 | #   Client
59 | #
60 | ##################
61 | 
62 | 
63 | class FedNovaSerialClientTrainer(SGDSerialClientTrainer):
64 |     """Federated client with local SGD solver."""
65 | 
66 |     def local_process(self, payload, id_list):
67 |         model_parameters = payload[0]
68 |         for id in id_list:
69 |             data_loader = self.dataset.get_dataloader(id, self.batch_size)
70 |             pack = self.train(model_parameters, data_loader)
71 |             tau = [torch.Tensor([len(data_loader) * self.epochs])]
72 |             pack += tau
73 |             self.cache.append(pack)
--------------------------------------------------------------------------------
/fedlab/contrib/dataset/shakespeare.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | 
5 | 
6 | class ShakespeareDataset(Dataset):
7 |     def __init__(self, client_id: int, client_str: str, data: list,
8 |                  targets: list):
9 |         """get `Dataset` for shakespeare dataset
10 |         Args:
11 |             client_id (int): client id
12 |             client_str (str): client name string
13 |             data (list): sentence list data
14 |             targets (list): next-character target list
15 |         """
16 |         self.client_id = client_id
17 |         self.client_str = client_str
18 |         self.ALL_LETTERS, self.VOCAB_SIZE = self._build_vocab()
19 |         self.data = data
20 |         self.targets = targets
21 |         self._process_data_target()
22 | 
23 |     def _build_vocab(self):
24 |         """Build the vocabulary from all letters.
25 |         Vocabulary re-used from the Federated Learning for Text Generation tutorial.
https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
27 |         Returns:
28 |             the vocabulary string of all letters and its length
29 |         """
30 |         ALL_LETTERS = "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}"
31 |         VOCAB_SIZE = len(ALL_LETTERS)
32 |         return ALL_LETTERS, VOCAB_SIZE
33 | 
34 |     def _process_data_target(self):
35 |         """process client's data and target
36 |         """
37 |         self.data = torch.tensor(
38 |             [self.__sentence_to_indices(sentence) for sentence in self.data])
39 |         self.targets = torch.tensor(
40 |             [self.__letter_to_index(letter) for letter in self.targets])
41 | 
42 |     def __sentence_to_indices(self, sentence: str):
43 |         """Return the list of integer character indices in ALL_LETTERS
44 |         Args:
45 |             sentence (str): input sentence
46 |         Returns: an integer list of character indices
47 |         """
48 |         indices = []
49 |         for c in sentence:
50 |             indices.append(self.ALL_LETTERS.find(c))
51 |         return indices
52 | 
53 |     def __letter_to_index(self, letter: str):
54 |         """Returns index in ALL_LETTERS of given letter
55 |         Args:
56 |             letter (char/str[0]): input letter
57 |         Returns: int index of input letter
58 |         """
59 |         index = self.ALL_LETTERS.find(letter)
60 |         return index
61 | 
62 |     def __len__(self):
63 |         return len(self.targets)
64 | 
65 |     def __getitem__(self, index):
66 |         return self.data[index], self.targets[index]
67 | 
--------------------------------------------------------------------------------
/fedlab/contrib/client_sampler/importance_sampler.py:
--------------------------------------------------------------------------------
1 | from .base_sampler import FedSampler
2 | import numpy as np
3 | 
4 | 
5 | 
6 | class MultiArmedBanditSampler(FedSampler):
7 |     "Refer to [Stochastic Optimization with Bandit Sampling](https://arxiv.org/abs/1708.02544)."
8 | 
9 |     def __init__(self, n, T, L):
10 |         super().__init__(n)
11 |         self.name = "mabs"
12 |         self.w = np.ones(n)
13 |         self.p = np.ones(n) / float(n)
14 | 
15 |         self.eta = 0.4
16 |         self.delta = np.sqrt(
17 |             (self.eta**4) * np.log(self.n) / ((self.n**5) * T * (L**2)))
18 |         self.last_sampled = None
19 | 
20 |     def sample(self, batch_size):
21 |         sampled = np.random.choice(np.arange(self.n),
22 |                                    size=batch_size,
23 |                                    replace=True,
24 |                                    p=self.p)
25 |         p = self.p[sampled]
26 |         self.last_sampled = (sampled, p)
27 |         return np.sort(sampled)
28 | 
29 |     def update(self, loss):
30 |         at = loss**2 / (self.n**2)
31 |         indices, p = self.last_sampled
32 |         self.w[indices] *= np.exp(self.delta * at / p**3)
33 |         self.p = (1 - self.eta) * self.w / np.sum(self.w) + self.eta / self.n
34 | 
35 | 
36 | class OptimalSampler(FedSampler):
37 |     "Refer to [Optimal Client Sampling for Federated Learning](https://arxiv.org/abs/2010.13723)."
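    # Hedged usage sketch: ``grad_norms`` is an illustrative length-n array of
    # per-client update norms fed back after each round. Note that ``update``
    # must run before the first ``sample`` call, since ``self.p`` starts as None:
    #
    #   sampler = OptimalSampler(n=100, k=10)
    #   sampler.update(loss=grad_norms)
    #   selected = sampler.sample()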
38 | 39 | def __init__(self, n, k): 40 | super().__init__(n) 41 | self.name = "optimal" 42 | self.k = k 43 | self.p = None 44 | 45 | def sample(self, size=None): 46 | indices = np.arange( 47 | (self.n))[np.random.random_sample(self.n) <= self.p] 48 | self.last_sampled = indices, self.p[indices] 49 | return indices 50 | 51 | def update(self, loss): 52 | self.p = self.optim_solver(loss) 53 | 54 | def optim_solver(self, norms): 55 | norms = np.array(norms) 56 | idx = np.argsort(norms) 57 | probs = np.zeros(len(norms)) 58 | l = 0 59 | for l, id in enumerate(idx): 60 | l = l + 1 61 | if self.k + l - self.n > sum(norms[idx[0:l]]) / norms[id]: 62 | l -= 1 63 | break 64 | 65 | m = sum(norms[idx[0:l]]) 66 | for i in range(len(idx)): 67 | if i <= l: 68 | probs[idx[i]] = (self.k + l - self.n) * norms[idx[i]] / m 69 | else: 70 | probs[idx[i]] = 1 71 | 72 | return np.array(probs) 73 | 74 | # def estimate(self): 75 | # indices = np.arange( 76 | # (self.n))[np.random.random_sample(self.n) <= self.p] 77 | # return indices, self.p[indices] -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedopt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from tqdm import tqdm 4 | from torch.utils.data import DataLoader 5 | 6 | from ...utils.functional import evaluate, setup_seed, AverageMeter 7 | from ...utils.serialization import SerializationTool 8 | from ...utils.aggregator import Aggregators 9 | from ...contrib.algorithm.fedavg import FedAvgSerialClientTrainer, FedAvgServerHandler 10 | 11 | 12 | class FedOptServerHandler(FedAvgServerHandler): 13 | def setup_optim(self, sampler, args): 14 | self.n = self.num_clients 15 | self.num_to_sample = int(self.sample_ratio * self.n) 16 | self.round_clients = int(self.sample_ratio * self.n) 17 | self.sampler = sampler 18 | 19 | self.args = args 20 | self.lr = args.glr 21 | # momentum 22 | self.beta1 = self.args.beta1 23 | self.beta2 = self.args.beta2 24 | self.option = self.args.option 25 | self.tau = self.args.tau 26 | self.momentum = torch.zeros_like(self.model_parameters) 27 | self.vt = torch.zeros_like(self.model_parameters) 28 | assert self.option in ["adagrad", "yogi", "adam"] 29 | 30 | @property 31 | def num_clients_per_round(self): 32 | return self.round_clients 33 | 34 | def local_process(self, payload, id_list): 35 | model_parameters = payload[0] 36 | loss_ = AverageMeter() 37 | acc_ = AverageMeter() 38 | for id in tqdm(id_list): 39 | data_loader = self.dataset.get_dataloader(id, self.batch_size) 40 | pack = self.train(model_parameters, data_loader, loss_, acc_) 41 | self.cache.append(pack) 42 | return loss_, acc_ 43 | 44 | def global_update(self, buffer): 45 | gradient_list = [ 46 | torch.sub(ele[0], self.model_parameters) for ele in buffer 47 | ] 48 | indices, _ = self.sampler.last_sampled 49 | delta = Aggregators.fedavg_aggregate(gradient_list, 50 | self.args.weights[indices]) 51 | self.momentum = self.beta1 * self.momentum + (1 - self.beta1) * delta 52 | 53 | delta_2 = torch.pow(delta, 2) 54 | if self.option == "adagrad": 55 | self.vt += delta_2 56 | elif self.option == "yogi": 57 | self.vt = self.vt - ( 58 | 1 - self.beta2) * delta_2 * torch.sign(self.vt - delta_2) 59 | else: 60 | # adam 61 | self.vt = self.beta2 * self.vt + (1 - self.beta2) * delta_2 62 | 63 | serialized_parameters = self.model_parameters + self.lr * self.momentum / ( 64 | torch.sqrt(self.vt) + self.tau) 65 | self.set_model(serialized_parameters) 66 | 67 | 
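# The server step above, written out (Reddi et al., 2021, "Adaptive Federated
# Optimization", https://arxiv.org/abs/2003.00295), with delta the aggregated
# pseudo-gradient:
#   m_t = beta1 * m_{t-1} + (1 - beta1) * delta
#   adagrad: v_t = v_{t-1} + delta^2
#   yogi:    v_t = v_{t-1} - (1 - beta2) * delta^2 * sign(v_{t-1} - delta^2)
#   adam:    v_t = beta2 * v_{t-1} + (1 - beta2) * delta^2
#   x_t = x_{t-1} + lr * m_t / (sqrt(v_t) + tau)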
-------------------------------------------------------------------------------- /fedlab/contrib/compressor/topk.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | 18 | from .compressor import Compressor 19 | 20 | 21 | class TopkCompressor(Compressor): 22 | """ Compressor for federated communication 23 | Top-k gradient or weights selection 24 | Args: 25 | compress_ratio (float): compress ratio 26 | """ 27 | def __init__(self, compress_ratio): 28 | self.compress_ratio = compress_ratio if compress_ratio <= 1.0 else 1.0 / compress_ratio 29 | self.index_dtype = torch.int64 30 | self.value_dtype = torch.float32 31 | 32 | def compress(self, tensor): 33 | """compress tensor into (values, indices) 34 | Args: 35 | tensor (torch.Tensor): tensor 36 | Returns: 37 | tuple: (values, indices) 38 | """ 39 | if torch.is_tensor(tensor): 40 | tensor = tensor.detach() 41 | else: 42 | raise TypeError( 43 | "Invalid type error, expecting {}, but get {}".format( 44 | torch.Tensor, type(tensor))) 45 | 46 | numel = tensor.numel() 47 | top_k_samples = int(math.ceil(numel * self.compress_ratio)) 48 | 49 | tensor = tensor.view(-1) 50 | importance = tensor.abs() 51 | 52 | _, indices = torch.topk(importance, 53 | top_k_samples, 54 | 0, 55 | largest=True, 56 | sorted=False) 57 | values = tensor[indices] 58 | 59 | values = values.to(dtype=self.value_dtype) 60 | indices = indices.to(dtype=self.index_dtype) 61 | 62 | return values, indices 63 | 64 | def decompress(self, values, indices, shape): 65 | """decompress tensor""" 66 | de_tensor = torch.zeros(size=shape, dtype=self.value_dtype).view(-1) 67 | de_tensor = de_tensor.index_put_([indices], values, 68 | accumulate=True).view(shape) 69 | return de_tensor 70 | -------------------------------------------------------------------------------- /fedlab/core/server/hierarchical/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | import torch
16 | from torch.multiprocessing import Queue
17 | 
18 | from .connector import ClientConnector, ServerConnector
19 | from ...network import DistNetwork
20 | from ....utils import Logger
21 | 
22 | torch.multiprocessing.set_sharing_strategy("file_system")
23 | 
24 | 
25 | class Scheduler():
26 |     """Middle Topology for hierarchical communication pattern.
27 | 
28 |     Scheduler uses message queues to decouple connector modules.
29 | 
30 |     Args:
31 |         net_upper (DistNetwork): Distributed network manager of server from upper level.
32 |         net_lower (DistNetwork): Distributed network manager of clients from lower level.
33 |     """
34 | 
35 |     def __init__(self, net_upper: DistNetwork, net_lower: DistNetwork):
36 |         super(Scheduler, self).__init__()
37 |         self.__MQs = [Queue(), Queue()]
38 |         self.net_upper = net_upper
39 |         self.logger_upper = Logger(
40 |             log_name="Scheduler{}-ServerConnector".format(self.net_upper.rank))
41 | 
42 |         self.net_lower = net_lower
43 |         self.logger_lower = Logger(
44 |             log_name="Scheduler{}-ClientConnector".format(self.net_upper.rank))
45 | 
46 |     def run(self):
47 |         connect_server = ServerConnector(self.net_upper,
48 |                                          write_queue=self.__MQs[1],
49 |                                          read_queue=self.__MQs[0],
50 |                                          logger=self.logger_upper)
51 | 
52 |         connect_client = ClientConnector(self.net_lower,
53 |                                          write_queue=self.__MQs[0],
54 |                                          read_queue=self.__MQs[1],
55 |                                          logger=self.logger_lower)
56 | 
57 |         connect_server.start()
58 |         connect_client.start()
59 | 
60 |         # This is a tiny bug.
61 |         # The process always blocks somehow on connect_client.join() during the shutdown stage.
62 |         # This remains to be solved. (It looks like a process synchronization problem.)
63 |         connect_server.join()
64 |         # connect_client.join()
65 |         connect_client.kill()
--------------------------------------------------------------------------------
/fedlab/contrib/algorithm/fedavg.py:
--------------------------------------------------------------------------------
1 | from .basic_server import SyncServerHandler
2 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer
3 | from ...utils.aggregator import Aggregators
4 | from ...utils.serialization import SerializationTool
5 | import torch
6 | import copy
7 | 
8 | ##################
9 | #
10 | #   Server
11 | #
12 | ##################
13 | 
14 | 
15 | class FedAvgServerHandler(SyncServerHandler):
16 |     """FedAvg server handler."""
17 |     def global_update(self, buffer, upload_res=False):
18 |         parameters_list = [ele[0] for ele in buffer]
19 |         weights = torch.tensor([ele[1] for ele in buffer]).to(self.device)
20 |         serialized_parameters = Aggregators.fedavg_aggregate(parameters_list, weights)
21 |         SerializationTool.deserialize_model(self._model, serialized_parameters)
22 | 
23 |     def setup_swa_model(self):
24 |         self.swa_model = copy.deepcopy(self.model_parameters)
25 | 
26 |     def update_swa_model(self, alpha):
27 |         self.swa_model *= (1.0 - alpha)
28 |         self.swa_model += self.model_parameters * alpha
29 |         # for param1, param2 in zip(self.swa_model, self.model_parameters):
30 |         #     param1.data *= (1.0 - alpha)
31 |         #     param1.data += param2.data * alpha
32 | 
33 |     def update_clients_lr(self, lr, clients=None):
34 |         if clients is None:
35 |             clients = self.round_clients
36 |         for c in clients:
37 |             c.update_lr(lr)
38 | 
39 | 
40 | ##################
41 | #
42 | #   Client
43 | #
44 | ##################
45 | 
46 | 
47 | class FedAvgClientTrainer(SGDClientTrainer):
48 |     """Federated client with local SGD solver."""
49 |     def global_update(self, buffer):
50 |         parameters_list = [ele[0] for
ele in buffer] 51 | weights = [ele[1] for ele in buffer] 52 | serialized_parameters = Aggregators.fedavg_aggregate( 53 | parameters_list, weights) 54 | SerializationTool.deserialize_model(self._model, serialized_parameters) 55 | 56 | 57 | class FedAvgSerialClientTrainer(SGDSerialClientTrainer): 58 | """Federated client with local SGD solver.""" 59 | def train(self, model_parameters, train_loader): 60 | self.set_model(model_parameters) 61 | self._model.train() 62 | 63 | data_size = 0 64 | for _ in range(self.epochs): 65 | for batch_idx, (data, target) in enumerate(train_loader): 66 | if self.cuda: 67 | data = data.cuda(self.device) 68 | target = target.cuda(self.device) 69 | 70 | output = self.model(data) 71 | loss = self.criterion(output, target) 72 | 73 | data_size += len(target) 74 | 75 | self.optimizer.zero_grad() 76 | loss.backward() 77 | self.optimizer.step() 78 | 79 | return [self.model_parameters, data_size] 80 | -------------------------------------------------------------------------------- /fedlab/core/server/handler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import abstractmethod 16 | from typing import List 17 | 18 | import torch 19 | 20 | from ..model_maintainer import ModelMaintainer 21 | 22 | 23 | class ServerHandler(ModelMaintainer): 24 | """An abstract class representing handler of parameter server. 25 | 26 | Please make sure that your self-defined server handler class subclasses this class 27 | 28 | Example: 29 | Read source code of :class:`SyncServerHandler` and :class:`AsyncServerHandler`. 30 | 31 | Args: 32 | model (torch.nn.Module): PyTorch model. 33 | cuda (bool): Use GPUs or not. 34 | device (str, optional): Assign model/data to the given GPUs. E.g., 'device:0' or 'device:0,1'. Defaults to None. If device is None and cuda is True, FedLab will set the gpu with the largest memory as default. 35 | """ 36 | def __init__(self, 37 | model: torch.nn.Module, 38 | cuda: bool, 39 | device: str = None) -> None: 40 | super().__init__(model, cuda, device) 41 | 42 | @property 43 | @abstractmethod 44 | def downlink_package(self) -> List[torch.Tensor]: 45 | """Property for manager layer. 
Server manager will call this property when it activates clients."""
46 |         raise NotImplementedError()
47 | 
48 |     # only the sync handler needs this property
49 |     # @property
50 |     # def num_clients_per_round(self):
51 |     #     return self.round_clients
52 | 
53 |     @property
54 |     @abstractmethod
55 |     def if_stop(self) -> bool:
56 |         """:class:`NetworkManager` keeps monitoring this attribute, and it will stop all related processes and threads when ``True`` returned."""
57 |         return False
58 | 
59 |     @abstractmethod
60 |     def setup_optim(self):
61 |         """Override this function to load your optimization hyperparameters."""
62 |         raise NotImplementedError()
63 | 
64 |     @abstractmethod
65 |     def global_update(self, buffer):
66 |         raise NotImplementedError()
67 | 
68 |     @abstractmethod
69 |     def load(self, payload):
70 |         """Override this function to define how to update global model (aggregation or optimization)."""
71 |         raise NotImplementedError()
72 | 
73 |     @abstractmethod
74 |     def evaluate(self):
75 |         """Override this function to define the evaluation of global model."""
76 |         raise NotImplementedError()
77 | 
--------------------------------------------------------------------------------
/fedlab/contrib/algorithm/feddyn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .basic_server import SyncServerHandler
4 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer
5 | from ...utils import Aggregators
6 | 
7 | 
8 | ##################
9 | #
10 | #   Server
11 | #
12 | ##################
13 | 
14 | 
15 | class FedDynServerHandler(SyncServerHandler):
16 |     """FedDyn server handler."""
17 |     def setup_optim(self, alpha):
18 |         self.alpha = alpha
19 |         self.h = torch.zeros_like(self.model_parameters)
20 | 
21 |     def global_update(self, buffer):
22 |         parameters_list = [ele[0] for ele in buffer]
23 |         deltas = sum([parameters-self.model_parameters for parameters in parameters_list])
24 |         self.h = self.h - self.alpha * (1.0/self.num_clients) * deltas
25 |         new_parameters = Aggregators.fedavg_aggregate(parameters_list) - 1.0 / self.alpha * self.h
26 |         self.set_model(new_parameters)
27 | 
28 | 
29 | ##################
30 | #
31 | #   Client
32 | #
33 | ##################
34 | 
35 | 
36 | class FedDynSerialClientTrainer(SGDSerialClientTrainer):
37 |     def __init__(self, model, num_clients, cuda=True, device=None, logger=None, personal=False) -> None:
38 |         super().__init__(model, num_clients, cuda, device, logger, personal)
39 | 
40 |         self.L = [None for _ in range(num_clients)]
41 | 
42 | 
43 |     def setup_dataset(self, dataset):
44 |         return super().setup_dataset(dataset)
45 | 
46 |     def setup_optim(self, epochs, batch_size, lr, weight_decay, momentum, alpha):
47 |         self.alpha = alpha
48 |         super().setup_optim(epochs, batch_size, lr, weight_decay, momentum)
49 | 
50 |     def local_process(self, payload, id_list):
51 |         model_parameters = payload[0]
52 |         for id in id_list:
53 |             data_loader = self.dataset.get_dataloader(id, self.batch_size)
54 |             pack = self.train(id, model_parameters, data_loader)
55 |             self.cache.append(pack)
56 | 
57 |     def train(self, id, model_parameters, train_loader):
58 |         if self.L[id] is None:
59 |             self.L[id] = torch.zeros_like(model_parameters)
60 | 
61 |         L_t = self.L[id]
62 |         frz_parameters = model_parameters
63 | 
64 |         self.set_model(model_parameters)
65 |         self._model.train()
66 | 
67 |         for _ in range(self.epochs):
68 |             for data, target in train_loader:
69 |                 if self.cuda:
70 |                     data = data.cuda(self.device)
71 |                     target = target.cuda(self.device)
72 | 
73 |                 output = 
self.model(data) 74 | l1 = self.criterion(output, target) 75 | l2 = torch.dot(L_t, self.model_parameters) 76 | l3 = torch.sum(torch.pow(self.model_parameters - frz_parameters,2)) 77 | 78 | loss = l1 - l2 + 0.5 * self.alpha * l3 79 | 80 | self.optimizer.zero_grad() 81 | loss.backward() 82 | self.optimizer.step() 83 | 84 | self.L[id] = L_t - self.alpha * (self.model_parameters-frz_parameters) 85 | 86 | return [self.model_parameters] 87 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedsam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .basic_server import SyncServerHandler 4 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 5 | from .fedavg import FedAvgServerHandler 6 | from .minimizers import SAM 7 | from ...utils import Aggregators 8 | 9 | ################## 10 | # 11 | # Server 12 | # 13 | ################## 14 | 15 | 16 | # class FedSAMServerHandler(SyncServerHandler): 17 | class FedSamServerHandler(FedAvgServerHandler): 18 | pass 19 | 20 | # super().__init__() 21 | """FedAvg server handler.""" 22 | # @property 23 | # def downlink_package(self): 24 | # return [self.model_parameters] #, self.global_c] 25 | # 26 | # def setup_optim(self, lr): 27 | # self.lr = lr 28 | # self.global_c = torch.zeros_like(self.model_parameters) 29 | 30 | # def global_update(self, buffer): 31 | # # unpack 32 | # dys = [ele[0] for ele in buffer] 33 | # dcs = [ele[1] for ele in buffer] 34 | # 35 | # dx = Aggregators.fedavg_aggregate(dys) 36 | # dc = Aggregators.fedavg_aggregate(dcs) 37 | # 38 | # next_model = self.model_parameters + self.lr * dx 39 | # self.set_model(next_model) 40 | 41 | # self.global_c += 1.0 * len(dcs) / self.num_clients * dc 42 | 43 | 44 | ################## 45 | # 46 | # Client 47 | # 48 | ################## 49 | 50 | 51 | class FedSamSerialClientTrainer(SGDSerialClientTrainer): 52 | def __init__(self, model, num_clients, rho, cuda=True, device=None, logger=None, personal=False) -> None: 53 | super().__init__(model, num_clients, cuda, device, logger, personal) 54 | self.rho = rho 55 | 56 | # def setup_optim(self, epochs, batch_size, lr): 57 | # super().setup_optim(epochs, batch_size, lr) 58 | # optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr) 59 | # self.SAM = SAM(self.optimizer, self.model, self.rho) 60 | 61 | def local_process(self, payload, id_list): 62 | model_parameters = payload[0] 63 | for id in id_list: 64 | data_loader = self.dataset.get_dataloader(id, self.batch_size) 65 | # optimizer = torch.optim.SGD(model_parameters, lr=self.lr) 66 | minimizer = SAM(self.optimizer, self.model, self.rho) 67 | pack = self.train(id, model_parameters, minimizer, data_loader) 68 | self.cache.append(pack) 69 | 70 | def train(self, id, model_parameters, minimizer, train_loader): 71 | self.set_model(model_parameters) 72 | 73 | data_size = 0 74 | for _ in range(self.epochs): 75 | for data, target in train_loader: 76 | if self.cuda: 77 | data = data.cuda(self.device) 78 | target = target.cuda(self.device) 79 | 80 | # Ascent Step 81 | output = self.model(data) 82 | loss = self.criterion(output, target) 83 | 84 | loss.backward() 85 | minimizer.ascent_step() 86 | 87 | # Descent Step 88 | self.criterion(self.model(data), target).backward() 89 | minimizer.descent_step() 90 | 91 | data_size += len(target) 92 | 93 | return [self.model_parameters, data_size] 94 | -------------------------------------------------------------------------------- 
/fedlab/contrib/compressor/quantization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | from .compressor import Compressor
17 | 
18 | 
19 | class QSGDCompressor(Compressor):
20 |     """Quantization compressor.
21 | 
22 |     An implementation of the paper https://proceedings.neurips.cc/paper/2017/file/6c340f25839e6acdc73414517203f5f0-Paper.pdf.
23 | 
24 |     Alistarh, Dan, et al. "QSGD: Communication-efficient SGD via gradient quantization and encoding." Advances in Neural Information Processing Systems 30 (2017): 1709-1720.
25 |     Thanks to git repo: https://github.com/xinyandai/gradient-quantization
26 | 
27 |     Args:
28 |         n_bit (int): the number of bits for quantization. A bigger n_bit gives better compression precision but costs more communication.
29 |         random (bool, optional): Carry bit with probability. Defaults to True.
30 |         cuda (bool, optional): use GPU. Defaults to False.
31 |     """
32 |     def __init__(self, n_bit, random=True, cuda=False):
33 |         self.random = random
34 |         self.bit = n_bit
35 | 
36 |         self.cuda = cuda
37 |         self.s = 2**self.bit
38 | 
39 |         self.code_dtype = torch.int32
40 | 
41 |     def compress(self, tensor):
42 |         """Compress a tensor with quantization
43 |         Args:
44 |             tensor (torch.Tensor): tensor to compress.
45 |         Returns:
46 |             norm (torch.Tensor): The normalization number.
47 |             signs (torch.Tensor): Tensor that indicates the sign of corresponding number.
48 |             quantized_intervals (torch.Tensor): Quantized tensor with each item in [0, 2**n_bit - 1].
49 |         """
50 |         shape = tensor.shape
51 |         vec = tensor.view(-1)
52 |         # norm = torch.norm(vec, dim=1, keepdim=True)
53 |         norm = torch.max(torch.abs(vec), dim=0, keepdim=True)[0]
54 |         normalized_vec = vec / norm
55 | 
56 |         scaled_vec = torch.abs(normalized_vec) * self.s
57 |         l = torch.clamp(scaled_vec, 0, self.s - 1).type(self.code_dtype)
58 | 
59 |         if self.random:
60 |             # l[i] <- l[i] + 1 with probability |v_i| / ||v|| * s - l
61 |             probabilities = scaled_vec - l.type(torch.float32)
62 |             r = torch.rand(l.size())
63 |             if self.cuda:
64 |                 r = r.cuda()
65 |             l[:] += (probabilities > r).type(self.code_dtype)
66 | 
67 |         signs = torch.sign(vec) > 0
68 |         return [norm, signs.view(shape), l.view(shape)]
69 | 
70 |     def decompress(self, signature):
71 |         """Decompress tensor
72 |         Args:
73 |             signature (list): [norm, signs, quantized_intervals], returned by :func:``compress``.
74 |         Returns:
75 |             torch.Tensor: Raw tensor represented by signature.
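        Example (a rough round-trip sketch; sizes and values are illustrative):

            compressor = QSGDCompressor(n_bit=8)
            norm, signs, l = compressor.compress(torch.randn(16))
            restored = compressor.decompress([norm, signs, l])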
76 | """ 77 | [norm, signs, l] = signature 78 | assert l.shape == signs.shape 79 | shape = l.shape 80 | scaled_vec = l.type( 81 | torch.float32) * (2 * signs.type(torch.float32) - 1) 82 | compressed = (scaled_vec.view((-1))) * norm / self.s 83 | return compressed.view(shape) -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/ifca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import numpy as np 4 | 5 | from .basic_server import SyncServerHandler 6 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 7 | from ...utils import SerializationTool, Aggregators 8 | from ...utils.functional import evaluate 9 | 10 | 11 | ################## 12 | # 13 | # Server 14 | # 15 | ################## 16 | 17 | 18 | class IFCAServerHander(SyncServerHandler): 19 | def __init__(self, model: torch.nn.Module, global_round: int, sample_ratio: float, cuda: bool = False, device: str = None, logger = None): 20 | super().__init__(model, global_round, sample_ratio, cuda, device, logger) 21 | 22 | @property 23 | def downlink_package(self): 24 | return [self.shared_paramters] + self.global_models 25 | 26 | def setup_optim(self, share_size, k, init_parameters): 27 | """_summary_ 28 | 29 | Args: 30 | share_size (_type_): _description_ 31 | k (_type_): _description_ 32 | init_parameters (_type_): _description_ 33 | """ 34 | assert k == len(init_parameters) 35 | self.k = k 36 | self.share_size = share_size 37 | 38 | self.global_models = init_parameters 39 | self.shared_paramters = Aggregators.fedavg_aggregate(self.global_models)[0:self.share_size] 40 | 41 | def global_update(self, buffer): 42 | cluster_model = [[] for _ in range(self.k)] 43 | # weights = [[] for _ in range(self.k)] 44 | for i, (cid, id, paramters) in enumerate(buffer): 45 | cluster_model[cid].append(paramters) 46 | # weights[cid].append(self.client_trainer.weights[id]) 47 | 48 | parameters = Aggregators.fedavg_aggregate([ele for _, _, ele in buffer]) 49 | self.shared_paramters[0:self.share_size] = parameters[0:self.share_size] 50 | 51 | for i, ele in enumerate(cluster_model): 52 | if len(ele) > 0: 53 | self.global_models[i] = Aggregators.fedavg_aggregate(ele) 54 | 55 | 56 | ################## 57 | # 58 | # Client 59 | # 60 | ################## 61 | 62 | 63 | class IFCASerialClientTrainer(SGDSerialClientTrainer): 64 | def __init__(self, model, num_clients, cuda=False, device=None, logger=None, personal=False) -> None: 65 | super().__init__(model, num_clients, cuda, device, logger, personal) 66 | 67 | def setup_dataset(self, dataset): 68 | return super().setup_dataset(dataset) 69 | 70 | def setup_optim(self, epochs, batch_size, lr): 71 | return super().setup_optim(epochs, batch_size, lr) 72 | 73 | def local_process(self, payload, id_list): 74 | shared_model = payload[0] 75 | payload = payload[1:0] 76 | 77 | criterion = torch.nn.CrossEntropyLoss() 78 | results = [] 79 | for id in tqdm(id_list): 80 | data_loader = self._get_dataloader(id, self.args.batch_size) 81 | if len(payload) > 1: 82 | eval_loss = [] 83 | for i, model_parameters in enumerate(payload): 84 | 85 | model_parameters[0:shared_model.shape[0]] = shared_model[:] 86 | payload[i] = model_parameters 87 | 88 | SerializationTool.deserialize_model(self._model, model_parameters) 89 | loss, _ = evaluate(self._model, criterion, data_loader) 90 | eval_loss.append(loss) 91 | latent_cluster = np.argmin(eval_loss) 92 | else: 93 | latent_cluster = 0 94 | 95 | 
model_parameters = self.train(payload[latent_cluster], data_loader) 96 | results.append((latent_cluster, id, model_parameters)) 97 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/minimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | 4 | class ASAM: 5 | def __init__(self, optimizer, model, rho=0.5, eta=0.01): 6 | self.optimizer = optimizer 7 | self.model = model 8 | self.rho = rho 9 | self.eta = eta 10 | self.state = defaultdict(dict) 11 | 12 | @torch.no_grad() 13 | def ascent_step(self): 14 | wgrads = [] 15 | for n, p in self.model.named_parameters(): 16 | if p.grad is None: 17 | continue 18 | t_w = self.state[p].get("eps") 19 | if t_w is None: 20 | t_w = torch.clone(p).detach() 21 | self.state[p]["eps"] = t_w 22 | if 'weight' in n: 23 | t_w[...] = p[...] 24 | t_w.abs_().add_(self.eta) 25 | p.grad.mul_(t_w) 26 | wgrads.append(torch.norm(p.grad, p=2)) 27 | wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16 28 | for n, p in self.model.named_parameters(): 29 | if p.grad is None: 30 | continue 31 | t_w = self.state[p].get("eps") 32 | if 'weight' in n: 33 | p.grad.mul_(t_w) 34 | eps = t_w 35 | eps[...] = p.grad[...] 36 | eps.mul_(self.rho / wgrad_norm) 37 | p.add_(eps) 38 | self.optimizer.zero_grad() 39 | 40 | @torch.no_grad() 41 | def descent_step(self, init_model=None): 42 | if init_model is None: 43 | for n, p in self.model.named_parameters(): 44 | if p.grad is None: 45 | continue 46 | p.sub_(self.state[p]["eps"]) 47 | self.optimizer.step() 48 | self.optimizer.zero_grad() 49 | 50 | class SAM(ASAM): 51 | @torch.no_grad() 52 | def ascent_step(self): 53 | grads = [] 54 | for n, p in self.model.named_parameters(): 55 | if p.grad is None: 56 | continue 57 | grads.append(torch.norm(p.grad, p=2)) 58 | grad_norm = torch.norm(torch.stack(grads), p=2) + 1.e-16 59 | for n, p in self.model.named_parameters(): 60 | if p.grad is None: 61 | continue 62 | eps = self.state[p].get("eps") 63 | if eps is None: 64 | eps = torch.clone(p).detach() 65 | self.state[p]["eps"] = eps 66 | eps[...] = p.grad[...] 
67 | eps.mul_(self.rho / grad_norm) 68 | p.add_(eps) 69 | self.optimizer.zero_grad() 70 | 71 | 72 | class MoSAM(SAM): 73 | def __init__(self, optimizer, model, rho, beta, delta): 74 | super().__init__(optimizer, model, rho) 75 | self.beta = beta 76 | self.delta = delta 77 | # self.model_parameters_np = model_parameters_np 78 | 79 | @torch.no_grad() 80 | def descent_step(self): 81 | idx = 0 82 | for n, p in self.model.named_parameters(): 83 | layer_size = p.grad.numel() 84 | shape = p.grad.shape 85 | 86 | if p.grad is None: 87 | continue 88 | p.sub_(self.state[p]["eps"]) 89 | 90 | p.grad.mul_(self.beta) 91 | momentum_grad = self.delta[idx:idx + layer_size].view(shape)[:] 92 | momentum_grad = momentum_grad.mul_(1 - self.beta).cuda() 93 | 94 | p.grad.add_(momentum_grad) 95 | 96 | idx += layer_size 97 | self.optimizer.step() 98 | self.optimizer.zero_grad() 99 | class GF_ADMM(SAM): 100 | @torch.no_grad() 101 | def descent_step(self): 102 | self.optimizer.step() 103 | self.optimizer.zero_grad() -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/ditto.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from tqdm import * 4 | from copy import deepcopy 5 | 6 | from .basic_server import SyncServerHandler 7 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 8 | from ...utils.serialization import SerializationTool 9 | 10 | 11 | ################## 12 | # 13 | # Server 14 | # 15 | ################## 16 | 17 | 18 | class DittoServerHandler(SyncServerHandler): 19 | """Ditto server acts the same as fedavg server.""" 20 | None 21 | 22 | 23 | ################## 24 | # 25 | # Client 26 | # 27 | ################## 28 | 29 | 30 | class DittoSerialClientTrainer(SGDSerialClientTrainer): 31 | def __init__(self, model, num, cuda=False, device=None, logger=None, personal=True) -> None: 32 | super().__init__(model, num, cuda, device, logger, personal) 33 | self.ditto_gmodels = [] 34 | self.local_models = self.parameters 35 | 36 | def setup_dataset(self, dataset): 37 | return super().setup_dataset(dataset) 38 | 39 | def setup_optim(self, epochs, batch_size, lr): 40 | return super().setup_optim(epochs, batch_size, lr) 41 | 42 | def local_process(self, payload, id_list): 43 | global_model = payload[0] 44 | for id in tqdm(id_list): 45 | # self._LOGGER.info("Local process is running. 
Training client {}".format(id)) 46 | train_loader = self.dataset.get_dataloader(id, batch_size=self.batch_size) 47 | self.local_models[id], glb_model = self.train(global_model, self.local_models[id], train_loader) 48 | self.ditto_gmodels.append(deepcopy(glb_model)) 49 | 50 | @property 51 | def uplink_package(self): 52 | ditto_gmodels = deepcopy(self.ditto_gmodels) 53 | self.ditto_gmodels = [] 54 | return [[parameter] for parameter in ditto_gmodels] 55 | 56 | def train(self, global_model_parameters, local_model_parameters, train_loader): 57 | criterion = torch.nn.CrossEntropyLoss() 58 | SerializationTool.deserialize_model(self._model, global_model_parameters) 59 | self._model.train() 60 | for ep in range(self.epochs): 61 | for data, label in train_loader: 62 | if self.cuda: 63 | data, label = data.cuda(self.device), label.cuda(self.device) 64 | 65 | preds = self._model(data) 66 | loss = criterion(preds,label) 67 | self.optimizer.zero_grad() 68 | loss.backward() 69 | self.optimizer.step() 70 | 71 | updated_glb_models = deepcopy(self.model_parameters) 72 | 73 | frz_model = deepcopy(self._model) 74 | SerializationTool.deserialize_model(frz_model, global_model_parameters) 75 | 76 | SerializationTool.deserialize_model(self._model, local_model_parameters) 77 | criterion = torch.nn.CrossEntropyLoss() 78 | optimizer = torch.optim.SGD(self._model.parameters(), lr=self.lr) 79 | 80 | self._model.train() 81 | for ep in range(self.epochs): 82 | for data, label in train_loader: 83 | if self.cuda: 84 | data, label = data.cuda(self.device), label.cuda(self.device) 85 | 86 | preds = self._model(data) 87 | l1 = criterion(preds,label) 88 | l2 = 0.0 89 | for w0, w in zip(frz_model.parameters(), self._model.parameters()): 90 | l2 += torch.sum(torch.pow(w - w0, 2)) 91 | 92 | # loss = l1 + 0.5 * self.args.mu * l2 93 | loss = l1 + 0.5 * 0.1 * l2 # fedprox 的 mu 94 | optimizer.zero_grad() 95 | loss.backward() 96 | optimizer.step() 97 | return self.model_parameters, updated_glb_models 98 | -------------------------------------------------------------------------------- /fedlab/core/coordinator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class Coordinator(object): 17 | """Deal with the mapping relation between client id and process rank in FL system. 18 | 19 | Note 20 | Server Manager creates a Coordinator following: 21 | 1. init network connection. 22 | 2. client send local group info (the number of client simulating in local) to server. 23 | 4. server receive all info and init a server Coordinator. 24 | 25 | Args: 26 | setup_dict (dict): A dict like {rank:client_num ...}, representing the map relation between process rank and client id. 27 | mode (str, optional): “GLOBAL” and "LOCAL". Coordinator will map client id to (rank, global id) or (rank, local id) according to mode. 
For example, client id 51 is in a machine which has 1 manager and serial trainer simulating 10 clients. LOCAL id means the index of its 10 clients. Therefore, global id 51 will be mapped into local id 1 (depending on setting). 28 | """ 29 | def __init__(self, setup_dict: dict, mode: str = 'LOCAL') -> None: 30 | self.map = setup_dict 31 | self.mode = mode 32 | 33 | def map_id(self, id): 34 | """a map function from client id to (rank,local id) 35 | 36 | Args: 37 | id (int): client id 38 | 39 | Returns: 40 | rank, id : rank in distributed group and local id. 41 | """ 42 | m_id = id 43 | for rank, num in self.map.items(): 44 | if m_id >= num: 45 | m_id -= num 46 | else: 47 | local_id = m_id 48 | global_id = id 49 | ret_id = local_id if self.mode == 'LOCAL' else global_id 50 | return rank, ret_id 51 | 52 | def map_id_list(self, id_list: list): 53 | """a map function from id_list to dict{rank:local id} 54 | 55 | This can be very useful in Scale modules. 56 | 57 | Args: 58 | id_list (list(int)): a list of client id. 59 | 60 | Returns: 61 | map_dict (dict): contains process rank and its relative local client ids. 62 | """ 63 | map_dict = {} 64 | for id in id_list: 65 | rank, id = self.map_id(id) 66 | if rank in map_dict.keys(): 67 | map_dict[rank].append(id) 68 | else: 69 | map_dict[rank] = [id] 70 | return map_dict 71 | 72 | def switch(self): 73 | if self.mode == 'GLOBAL': 74 | self.mode = 'LOCAL' 75 | elif self.mode == 'LOCAL': 76 | self.mode = 'GLOBAL' 77 | else: 78 | raise ValueError("Invalid Map Mode {}".format(self.mode)) 79 | 80 | @property 81 | def total(self): 82 | return int(sum(self.map.values())) 83 | 84 | def __str__(self) -> str: 85 | return "Coordinator map information: {} \nMap mode: {} \nTotal: {}".format( 86 | self.map, self.mode, self.total) 87 | 88 | def __call__(self, info): 89 | if isinstance(info, int): 90 | return self.map_id(info) 91 | if isinstance(info, list): 92 | return self.map_id_list(info) -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedGamma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .basic_server import SyncServerHandler 4 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 5 | from .fedavg import FedAvgServerHandler 6 | from .minimizers import SAM 7 | from ...utils import Aggregators 8 | 9 | ################## 10 | # 11 | # Server 12 | # 13 | ################## 14 | 15 | 16 | class FedGammaServerHandler(FedAvgServerHandler): 17 | pass 18 | # super().__init__() 19 | @property 20 | def downlink_package(self): 21 | return [self.model_parameters, self.c] 22 | 23 | def setup_optim(self): 24 | self.c = torch.zeros_like(self.model_parameters) 25 | 26 | def global_update(self, buffer): 27 | weights = [ele[0] for ele in buffer] 28 | delta_c = [ele[2] for ele in buffer] 29 | 30 | avg_model = Aggregators.fedavg_aggregate(weights) 31 | dc = Aggregators.fedavg_aggregate(delta_c) / self.num_clients 32 | self.c += dc 33 | self.set_model(avg_model) 34 | 35 | 36 | ################## 37 | # 38 | # Client 39 | # 40 | ################## 41 | 42 | 43 | class FedGammaSerialClientTrainer(SGDSerialClientTrainer): 44 | def __init__(self, model, num_clients, rho, cuda=True, device=None, logger=None, personal=False) -> None: 45 | super().__init__(model, num_clients, cuda, device, logger, personal) 46 | self.rho = rho 47 | self.c_i = [torch.zeros_like(self.model_parameters) for _ in range(num_clients)] 48 | 49 | def local_process(self, 
payload, id_list):
50 |         model_parameters = payload[0]
51 |         c = payload[1]
52 |         for id in id_list:
53 |             data_loader = self.dataset.get_dataloader(id, self.batch_size)
54 |             # optimizer = torch.optim.SGD(model_parameters, lr=self.lr)
55 |             minimizer = SAM(self.optimizer, self.model, self.rho)
56 |             pack = self.train(id, model_parameters, minimizer, data_loader, c)
57 |             self.cache.append(pack)
58 | 
59 |     def train(self, id, model_parameters, minimizer, train_loader, c):
60 |         self.set_model(model_parameters)
61 |         c_i = self.c_i[id]
62 | 
63 |         data_size = 0
64 |         K = 0
65 |         for _ in range(self.epochs):
66 |             for data, target in train_loader:
67 |                 origin_param = self.model_parameters
68 |                 K += 1
69 | 
70 |                 if self.cuda:
71 |                     data = data.cuda(self.device)
72 |                     target = target.cuda(self.device)
73 | 
74 |                 # Ascent Step
75 |                 output = self.model(data)
76 |                 loss = self.criterion(output, target)
77 | 
78 |                 loss.backward()
79 |                 minimizer.ascent_step()
80 | 
81 |                 # Descent Step
82 |                 self.criterion(self.model(data), target).backward()
83 |                 g_hat = self.model_gradients
84 |                 self.optimizer.zero_grad()
85 |                 grad = g_hat - c_i + c
86 | 
87 |                 self.set_model(origin_param)
88 | 
89 |                 current_index = 0
90 |                 for n, p in self.model.named_parameters():
91 |                     numel = p.data.numel()
92 |                     size = p.data.size()
93 |                     p.grad.copy_(
94 |                         grad[current_index:current_index + numel].view(size))
95 |                     current_index += numel
96 | 
97 |                 self.optimizer.step()
98 |                 data_size += len(target)
99 | 
100 |         delta_c_i = (model_parameters - self.model_parameters) / (self.lr * K) - c
101 |         c_i += delta_c_i
102 |         self.c_i[id] = c_i
103 |         return [self.model_parameters, data_size, delta_c_i]
104 | 
--------------------------------------------------------------------------------
/fedlab/contrib/dataset/sent140.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from pathlib import Path
4 | 
5 | BASE_DIR = Path(__file__).resolve().parents[2]
6 | sys.path.append(str(BASE_DIR))
7 | 
8 | from torch.utils.data import Dataset
9 | from .utils import Tokenizer, Vocab
10 | 
11 | 
12 | class Sent140Dataset(Dataset):
13 |     def __init__(self,
14 |                  client_id: int,
15 |                  client_str: str,
16 |                  data: list,
17 |                  targets: list,
18 |                  is_to_tokens: bool = True,
19 |                  tokenizer: Tokenizer = None):
20 |         """get `Dataset` for sent140 dataset
21 |         Args:
22 |             client_id (int): client id
23 |             client_str (str): client name string
24 |             data (list): sentence list data
25 |             targets (list): sentiment label list
26 |             is_to_tokens (bool, optional): whether to tokenize the data with the tokenizer
27 |             tokenizer (Tokenizer, optional): tokenizer
28 |         """
29 |         self.client_id = client_id
30 |         self.client_str = client_str
31 |         self.data = data
32 |         self.targets = targets
33 |         self.data_token = []
34 |         self.data_tokens_tensor = []
35 |         self.targets_tensor = []
36 |         self.tokenizer = tokenizer if tokenizer else Tokenizer()
37 | 
38 |         self._process_data_target()
39 |         if is_to_tokens:
40 |             self._data2token()
41 | 
42 |     def _process_data_target(self):
43 |         """process client's data and target
44 |         """
45 |         self.data = [e[4] for e in self.data]
46 |         self.targets = torch.tensor(self.targets, dtype=torch.long)
47 | 
48 |     def _data2token(self):
49 |         assert self.data is not None
50 |         for sen in self.data:
51 |             self.data_token.append(self.tokenizer(sen))
52 | 
53 |     def encode(self, vocab: 'Vocab', fix_len: int):
54 |         """transform token data to indices sequence by `Vocab`
55 |         Args:
56 |             vocab (fedlab_benchmark.leaf.nlp_utils.util.vocab): vocab for data_token
57 | 
fix_len (int): max length of sentence
58 |         Returns:
59 |             list of integer list for data_token, and a list of tensor target
60 |         """
61 |         if len(self.data_tokens_tensor) > 0:
62 |             self.data_tokens_tensor.clear()
63 |             self.targets_tensor.clear()
64 |         pad_idx = vocab.get_index('<pad>')
65 |         assert self.data_token is not None
66 |         for tokens in self.data_token:
67 |             self.data_tokens_tensor.append(self.__encode_tokens(tokens, vocab, pad_idx, fix_len))
68 |         for target in self.targets:
69 |             self.targets_tensor.append(torch.tensor(target))
70 | 
71 |     def __encode_tokens(self, tokens, vocab, pad_idx, fix_len) -> torch.Tensor:
72 |         """encode `fix_len` length for token_data to get indices list in `self.vocab`
73 |         if one sentence length is shorter than fix_len, it will use the pad word for padding to fix_len
74 |         if one sentence length is longer than fix_len, it will keep only the first fix_len words
75 |         Args:
76 |             tokens (list[str]): data after tokenizer
77 |             vocab (fedlab_benchmark.leaf.nlp_utils.util.vocab): vocab for data_token
78 |             pad_idx (int): '<pad>' index in vocab
79 |             fix_len (int): max length of sentence
80 |         Returns:
81 |             integer list of indices with `fix_len` length for tokens input
82 |         """
83 |         x = [pad_idx for _ in range(fix_len)]
84 |         for idx, word in enumerate(tokens[:fix_len]):
85 |             x[idx] = vocab.get_index(word)
86 |         return torch.tensor(x)
87 | 
88 |     def __len__(self):
89 |         return len(self.targets_tensor)
90 | 
91 |     def __getitem__(self, item):
92 |         return self.data_tokens_tensor[item], self.targets_tensor[item]
--------------------------------------------------------------------------------
/fedlab/contrib/dataset/rotated_cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | import os
17 | 
18 | import torch
19 | from torch.utils.data import DataLoader
20 | import torchvision
21 | from torchvision import transforms
22 | 
23 | from .basic_dataset import FedDataset, BaseDataset
24 | from ...utils.dataset.functional import noniid_slicing, random_slicing
25 | 
26 | class RotatedCIFAR10(FedDataset):
27 |     """Rotate CIFAR10 and partition it.
28 | 
29 |     Args:
30 |         root (str): Path to download raw dataset.
31 |         save_dir (str): Path to save partitioned subdataset.
32 |         num_clients (int): Number of clients.
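    Example (sketch; the paths are illustrative; with the default two angles
    in ``preprocess``, ``shards`` client shards are produced per angle):

        dataset = RotatedCIFAR10(root="./datasets/cifar10/",
                                 save_dir="./datasets/rotated_cifar10/",
                                 num_clients=20)
        dataset.preprocess(shards=10)   # 10 shards for each of the 2 angles
        loader = dataset.get_data_loader(id=0, batch_size=32)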
33 | """ 34 | def __init__(self, root, save_dir, num_clients): 35 | self.root = os.path.expanduser(root) 36 | self.dir = save_dir 37 | self.num_clients = num_clients 38 | # "./datasets/rotated_mnist/" 39 | if os.path.exists(save_dir) is not True: 40 | os.mkdir(save_dir) 41 | os.mkdir(os.path.join(save_dir, "train")) 42 | os.mkdir(os.path.join(save_dir, "test")) 43 | 44 | self.transform = transforms.Compose( 45 | [transforms.ToTensor(), 46 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 47 | 48 | def preprocess(self, shards, thetas = [0, 180]): 49 | """_summary_ 50 | 51 | Args: 52 | shards (_type_): _description_ 53 | thetas (list, optional): _description_. Defaults to [0, 180]. 54 | """ 55 | cifar10 = torchvision.datasets.CIFAR10(self.root, train=True) 56 | id = 0 57 | for theta in thetas: 58 | rotated_data = [] 59 | partition = random_slicing(cifar10, shards) 60 | for x, _ in cifar10: 61 | x = self.transform(transforms.functional.rotate(x, theta)) 62 | rotated_data.append(x) 63 | for key, value in partition.items(): 64 | data = [rotated_data[i] for i in value] 65 | label = [cifar10.targets[i] for i in value] 66 | dataset = BaseDataset(data, label) 67 | torch.save(dataset, os.path.join(self.dir, "train", "data{}.pkl".format(id))) 68 | id += 1 69 | 70 | # test 71 | cifar10_test = torchvision.datasets.CIFAR10(self.root, train=False) 72 | labels = cifar10_test.targets 73 | for i, theta in enumerate(thetas): 74 | rotated_data = [] 75 | for x, y in cifar10_test: 76 | x = self.transform(transforms.functional.rotate(x, theta)) 77 | rotated_data.append(x) 78 | dataset = BaseDataset(rotated_data, labels) 79 | torch.save(dataset, os.path.join(self.dir,"test", "data{}.pkl".format(i))) 80 | 81 | def get_dataset(self, id, type="train"): 82 | dataset = torch.load(os.path.join(self.dir, type, "data{}.pkl".format(id))) 83 | return dataset 84 | 85 | def get_data_loader(self, id, batch_size=None, type="train"): 86 | dataset = self.get_dataset(id, type) 87 | batch_size = len(dataset) if batch_size is None else batch_size 88 | data_loader = DataLoader(dataset, batch_size=batch_size) 89 | return data_loader 90 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedSMOO_woReg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from collections import OrderedDict 4 | 5 | from .basic_server import SyncServerHandler 6 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 7 | from .fedavg import FedAvgServerHandler 8 | from .minimizers import SAM 9 | from ...utils import Aggregators 10 | from .bypass_bn import disable_running_stats, enable_running_stats 11 | from statistics import mean 12 | from collections import defaultdict 13 | 14 | ################## 15 | # 16 | # Server 17 | # 18 | ################## 19 | 20 | 21 | # class FedSAMServerHandler(SyncServerHandler): 22 | class FedSMOONoRegServerHandler(FedAvgServerHandler): 23 | 24 | @property 25 | def downlink_package(self): 26 | return [self.model_parameters, self.s] 27 | 28 | def setup_optim(self, rho): 29 | self.rho = rho 30 | self.s = torch.zeros_like(self.model_parameters) 31 | 32 | def global_update(self, buffer, upload_res=False): 33 | 34 | self.s = self.calc_s(buffer) 35 | super().global_update(buffer) 36 | 37 | def calc_s(self, buffer): 38 | parameters_list = [ele[2] for ele in buffer] 39 | weights = torch.ones(len(parameters_list)).cuda() 40 | weights = weights / torch.sum(weights) 41 | 42 | 
serialized_parameters = torch.sum(torch.stack(parameters_list, dim=-1) / weights, dim=-1) 43 | return self.rho * serialized_parameters / serialized_parameters.norm() 44 | 45 | ################## 46 | # 47 | # Client 48 | # 49 | ################## 50 | 51 | 52 | class FedSMOONoRegSerialClientTrainer(SGDSerialClientTrainer): 53 | def __init__(self, model, num_clients, rho, cuda=True, device=None, logger=None, personal=False) -> None: 54 | super().__init__(model, num_clients, cuda, device, logger, personal) 55 | self.mu_i = [torch.zeros_like(self.model_parameters) for _ in range(num_clients)] 56 | self.rho = rho 57 | 58 | def local_process(self, payload, id_list): 59 | model_parameters = payload[0] 60 | s = payload[1] 61 | 62 | for id in id_list: 63 | data_loader = self.dataset.get_dataloader(id, self.batch_size) 64 | pack = self.train(id, model_parameters, data_loader, s) 65 | self.cache.append(pack) 66 | 67 | def train(self, id, model_parameters, train_loader, s): 68 | self.set_model(model_parameters) 69 | hat_s = None 70 | mu_i = copy.deepcopy(self.mu_i[id]) 71 | 72 | data_size = 0 73 | for _ in range(self.epochs): 74 | for data, target in train_loader: 75 | origin_param = self.model_parameters 76 | if self.cuda: 77 | data = data.cuda(self.device) 78 | target = target.cuda(self.device) 79 | 80 | output = self.model(data) 81 | loss = self.criterion(output, target) 82 | 83 | data_size += len(target) 84 | 85 | self.optimizer.zero_grad() 86 | loss.backward() 87 | 88 | tier = self.model_gradients - mu_i - s 89 | hat_s = self.calc_hats(tier) 90 | mu_i = mu_i + hat_s - s 91 | 92 | self.optimizer.zero_grad() 93 | 94 | self.set_model(self.model_parameters+hat_s) 95 | output = self.model(data) 96 | loss = self.criterion(output, target) 97 | loss.backward() 98 | 99 | self.set_model(origin_param) 100 | 101 | self.optimizer.step() 102 | 103 | 104 | tilde_si = mu_i - hat_s 105 | self.mu_i[id] = mu_i 106 | 107 | return [self.model_parameters, data_size, tilde_si] 108 | 109 | def calc_hats(self, tier): 110 | return self.rho * tier / tier.norm() 111 | -------------------------------------------------------------------------------- /fedlab/contrib/dataset/rotated_mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | from torch.utils.data import DataLoader 19 | import torchvision 20 | from torchvision import transforms 21 | 22 | from .basic_dataset import FedDataset, BaseDataset 23 | from ...utils.dataset.functional import noniid_slicing, random_slicing 24 | 25 | 26 | class RotatedMNIST(FedDataset): 27 | """Rotate MNIST and partition it. 28 | 29 | Args: 30 | root (str): Path to download raw dataset. 31 | path (str): Path to save partitioned subdataset. 32 | num (int): Number of clients.
33 | """ 34 | def __init__(self, root, path, num) -> None: 35 | self.root = os.path.expanduser(root) 36 | self.path = path 37 | self.num = num 38 | 39 | def preprocess(self, thetas=[0, 90, 180, 270], download=True): 40 | self.download = download 41 | # "./datasets/rotated_mnist/" 42 | if os.path.exists(self.path) is not True: 43 | os.mkdir(self.path) 44 | os.mkdir(os.path.join(self.path, "train")) 45 | os.mkdir(os.path.join(self.path, "var")) 46 | os.mkdir(os.path.join(self.path, "test")) 47 | 48 | # train 49 | mnist = torchvision.datasets.MNIST(self.root, 50 | train=True, 51 | download=self.download) 52 | id = 0 53 | to_tensor = transforms.ToTensor() 54 | for theta in thetas: 55 | rotated_data = [] 56 | labels = [] 57 | partition = random_slicing(mnist, int(self.num / len(thetas))) 58 | for x, y in mnist: 59 | x = to_tensor(transforms.functional.rotate(x, theta)) 60 | rotated_data.append(x) 61 | labels.append(y) 62 | for key, value in partition.items(): 63 | data = [rotated_data[i] for i in value] 64 | label = [labels[i] for i in value] 65 | dataset = BaseDataset(data, label) 66 | torch.save( 67 | dataset, 68 | os.path.join(self.dir, "train", "data{}.pkl".format(id))) 69 | id += 1 70 | 71 | # test 72 | mnist_test = torchvision.datasets.MNIST( 73 | self.root, 74 | train=False, 75 | download=self.download) 76 | labels = mnist_test.targets 77 | for i, theta in enumerate(thetas): 78 | rotated_data = [] 79 | for x, y in mnist_test: 80 | x = to_tensor(transforms.functional.rotate(x, theta)) 81 | rotated_data.append(x) 82 | dataset = BaseDataset(rotated_data, labels) 83 | torch.save(dataset, 84 | os.path.join(self.dir, "test", "data{}.pkl".format(i))) 85 | 86 | def get_dataset(self, id, type="train"): 87 | dataset = torch.load( 88 | os.path.join(self.dir, type, "data{}.pkl".format(id))) 89 | return dataset 90 | 91 | def get_data_loader(self, id, batch_size=None, type="train"): 92 | dataset = self.get_dataset(id, type) 93 | batch_size = len(dataset) if batch_size is None else batch_size 94 | data_loader = DataLoader(dataset, batch_size=batch_size) 95 | return data_loader 96 | -------------------------------------------------------------------------------- /fedlab/board/builtin/charts.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objects as go 2 | 3 | from fedlab.board import fedboard 4 | from fedlab.board.builtin.renderer import client_param_tsne, get_client_dataset_tsne, get_client_data_report 5 | 6 | 7 | def add_built_in_charts(): 8 | fedboard.add_section('dataset', 'normal') 9 | fedboard.add_section('parameters', 'slider') 10 | 11 | @fedboard.add_chart(section='parameters', figure_name='figure_tsne', span=1.0) 12 | def update_tsne_figure(value, selected_client, selected_colors, selected_ranks): 13 | tsne_data, id_existed = client_param_tsne(value, selected_client) 14 | if tsne_data is not None: 15 | data = [] 16 | for idx, cid in enumerate(id_existed): 17 | data.append(go.Scatter( 18 | x=[tsne_data[idx, 0]], y=[tsne_data[idx, 1]], mode='markers', 19 | marker=dict(color=selected_colors[idx], size=16), 20 | name=f'Client{cid}' 21 | )) 22 | tsne_figure = go.Figure(data=data, 23 | layout_title_text=f"Parameters t-SNE") 24 | else: 25 | tsne_figure = [] 26 | return tsne_figure 27 | 28 | @fedboard.add_chart(section='dataset', figure_name='figure_client_classes', span=0.5) 29 | def update_data_classes(selected_client, selected_colors, selected_ranks): 30 | client_targets = get_client_data_report(selected_client, 'train', selected_ranks) 31 
| class_sizes: dict[str, dict[str, int]] = {} 32 | for cid, targets in client_targets.items(): 33 | for y in targets: 34 | class_sizes.setdefault(y, {id: 0 for id in selected_client}) 35 | class_sizes[y][cid] += 1 36 | client_classes = go.Figure( 37 | data=[ 38 | go.Bar(y=[f'Client{id}' for id in selected_client], 39 | x=[sizes[id] for id in selected_client], 40 | name=f'Class {clz}', orientation='h') 41 | # marker=dict(color=[viewModel.colors[id] for id in selected_client])), 42 | for clz, sizes in class_sizes.items() 43 | ], 44 | layout_title_text="Label Distribution" 45 | ) 46 | client_classes.update_layout(barmode='stack', margin=dict(l=48, r=48, b=64, t=86)) 47 | return client_classes 48 | 49 | @fedboard.add_chart(section='dataset', figure_name='figure_client_sizes', span=0.5) 50 | def update_data_sizes(selected_client, selected_colors, selected_ranks): 51 | client_targets = get_client_data_report(selected_client, 'train', selected_ranks) 52 | client_sizes = go.Figure( 53 | data=[go.Bar(x=[f'Client{n}' for n, _ in client_targets.items()], 54 | y=[len(ce) for _, ce in client_targets.items()], 55 | marker=dict(color=selected_colors))], 56 | layout_title_text="Dataset Sizes" 57 | ) 58 | client_sizes.update_layout(margin=dict(l=48, r=48, b=64, t=86)) 59 | return client_sizes 60 | 61 | @fedboard.add_chart(section='dataset', figure_name='figure_client_data_tsne', span=1.0) 62 | def update_data_tsne_value(selected_client, selected_colors, selected_ranks): 63 | tsne_data = get_client_dataset_tsne(selected_client, "train", 200, selected_ranks) 64 | if tsne_data is not None: 65 | data = [] 66 | for idx, cid in enumerate(selected_client): 67 | data.append(go.Scatter3d( 68 | x=tsne_data[cid][:, 0], y=tsne_data[cid][:, 1], z=tsne_data[cid][:, 2], mode='markers', 69 | marker=dict(color=selected_colors[idx], size=4, opacity=0.8), 70 | name=f'Client{cid}' 71 | )) 72 | else: 73 | data = [] 74 | tsne_figure = go.Figure(data=data, 75 | layout_title_text="Local Dataset t-SNE") 76 | tsne_figure.update_layout(margin=dict(l=48, r=48, b=64, t=64), dict1={"height": 600}) 77 | return tsne_figure 78 | -------------------------------------------------------------------------------- /fedlab/contrib/dataset/pathological_mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | from torch.utils.data import DataLoader 19 | import torchvision 20 | from torchvision import transforms 21 | 22 | from .basic_dataset import FedDataset, BaseDataset 23 | from ...utils.dataset.functional import noniid_slicing, random_slicing 24 | from pathlib import Path 25 | 26 | class PathologicalMNIST(FedDataset): 27 | """The partition strategy in FedAvg.
See http://proceedings.mlr.press/v54/mcmahan17a 28 | 29 | Args: 30 | data_root (str): Path to download raw dataset. 31 | num_clients (int, optional): Number of clients. Defaults to 100. 32 | shards (int, optional): Sort the dataset by label and uniformly partition it into ``shards`` shards; 33 | each client is then randomly assigned ``shards // num_clients`` shards. Defaults to 200. 34 | download (bool, optional): Whether to download the raw dataset. Defaults to True. 35 | """ 36 | def __init__(self, data_root, num_clients=100, shards=200, download=True, preprocess=False, transform=None) -> None: 37 | self.data_root = data_root 38 | self.home = Path.home() 39 | self.num_clients = num_clients 40 | self.shards = shards 41 | self.transform = transform 42 | if preprocess: 43 | self.preprocess(download) 44 | 45 | def preprocess(self, download=True): 46 | # self.num_clients = num_clients 47 | # self.shards = shards 48 | self.download = download 49 | 50 | # train 51 | mnist = torchvision.datasets.MNIST(self.data_root, train=True, download=self.download, 52 | transform=transforms.ToTensor()) 53 | data_indices = noniid_slicing(mnist, self.num_clients, self.shards) 54 | 55 | samples, labels = [], [] 56 | for x, y in mnist: 57 | samples.append(x) 58 | labels.append(y) 59 | for id, indices in data_indices.items(): 60 | data, label = [], [] 61 | for idx in indices: 62 | x, y = samples[idx], labels[idx] 63 | data.append(x) 64 | label.append(y) 65 | dataset = BaseDataset(data, label) 66 | torch.save(dataset, os.path.join(self.home, 'MNIST', "train", "data{}.pkl".format(id))) 67 | 68 | def get_dataset(self, id, type="train"): 69 | """Load subdataset for client with client ID ``id`` from local file. 70 | 71 | Args: 72 | id (int): client id 73 | type (str, optional): Dataset type, can be ``"train"``, ``"val"`` or ``"test"``. Default as ``"train"``. 74 | 75 | Returns: 76 | Dataset 77 | """ 78 | dataset = torch.load(os.path.join(self.home, 'MNIST', type, "data{}.pkl".format(id))) 79 | return dataset 80 | 81 | def get_dataloader(self, id, batch_size=None, type="train"): 82 | """Return a DataLoader for the client with client ID ``id``. 83 | 84 | Args: 85 | id (int): client id 86 | batch_size (int, optional): batch size in DataLoader. 87 | type (str, optional): Dataset type, can be ``"train"``, ``"val"`` or ``"test"``. Default as ``"train"``.
88 | """ 89 | dataset = self.get_dataset(id, type) 90 | batch_size = len(dataset) if batch_size is None else batch_size 91 | data_loader = DataLoader(dataset, batch_size=batch_size) 92 | return data_loader 93 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: fedgf 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - anaconda 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - blas=1.0=mkl 12 | - bottleneck=1.3.5=py39h7deecbd_0 13 | - brotlipy=0.7.0=py39h27cfd23_1003 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2023.7.22=hbcca054_0 16 | - certifi=2023.7.22=pyhd8ed1ab_0 17 | - cffi=1.15.1=py39h5eee18b_3 18 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 19 | - cryptography=41.0.3=py39hdda0065_0 20 | - cuda-cudart=11.7.99=0 21 | - cuda-cupti=11.7.101=0 22 | - cuda-libraries=11.7.1=0 23 | - cuda-nvrtc=11.7.99=0 24 | - cuda-nvtx=11.7.91=0 25 | - cuda-runtime=11.7.1=0 26 | - ffmpeg=4.3=hf484d3e_0 27 | - freetype=2.12.1=h4a9f257_0 28 | - giflib=5.2.1=h5eee18b_3 29 | - gmp=6.2.1=h295c915_3 30 | - gnutls=3.6.15=he1e5248_0 31 | - idna=3.4=py39h06a4308_0 32 | - intel-openmp=2023.1.0=hdb19cb5_46305 33 | - joblib=1.1.1=py39h06a4308_0 34 | - jpeg=9e=h5eee18b_1 35 | - lame=3.100=h7b6447c_0 36 | - lcms2=2.12=h3be6417_0 37 | - ld_impl_linux-64=2.38=h1181459_1 38 | - lerc=3.0=h295c915_0 39 | - libcublas=11.10.3.66=0 40 | - libcufft=10.7.2.124=h4fbf590_0 41 | - libcufile=1.7.2.10=0 42 | - libcurand=10.3.3.141=0 43 | - libcusolver=11.4.0.1=0 44 | - libcusparse=11.7.4.91=0 45 | - libdeflate=1.17=h5eee18b_0 46 | - libffi=3.4.4=h6a678d5_0 47 | - libgcc-ng=11.2.0=h1234567_1 48 | - libgfortran-ng=11.2.0=h00389a5_1 49 | - libgfortran5=11.2.0=h1234567_1 50 | - libgomp=11.2.0=h1234567_1 51 | - libiconv=1.16=h7f8727e_2 52 | - libidn2=2.3.4=h5eee18b_0 53 | - libnpp=11.7.4.75=0 54 | - libnvjpeg=11.8.0.2=0 55 | - libpng=1.6.39=h5eee18b_0 56 | - libstdcxx-ng=11.2.0=h1234567_1 57 | - libtasn1=4.19.0=h5eee18b_0 58 | - libtiff=4.5.1=h6a678d5_0 59 | - libunistring=0.9.10=h27cfd23_0 60 | - libwebp=1.2.4=h11a3e52_1 61 | - libwebp-base=1.2.4=h5eee18b_1 62 | - lz4-c=1.9.4=h6a678d5_0 63 | - mkl=2023.1.0=h213fc3f_46343 64 | - mkl-service=2.4.0=py39h5eee18b_1 65 | - mkl_fft=1.3.8=py39h5eee18b_0 66 | - mkl_random=1.2.4=py39hdb19cb5_0 67 | - munch=2.5.0=pyhd3eb1b0_0 68 | - ncurses=6.4=h6a678d5_0 69 | - nettle=3.7.3=hbbd107a_1 70 | - numexpr=2.8.4=py39hc78ab66_1 71 | - numpy=1.25.2=py39h5f9d8c6_0 72 | - numpy-base=1.25.2=py39hb5e798b_0 73 | - openh264=2.1.1=h4ff587b_0 74 | - openssl=3.0.10=h7f8727e_2 75 | - pandas=1.5.2=py39h417a72b_0 76 | - pillow=9.4.0=py39h6a678d5_0 77 | - pip=23.2.1=py39h06a4308_0 78 | - pycparser=2.21=pyhd3eb1b0_0 79 | - pynvml=11.5.0=pyhd8ed1ab_0 80 | - pyopenssl=23.2.0=py39h06a4308_0 81 | - pysocks=1.7.1=py39h06a4308_0 82 | - python=3.9.18=h955ad1f_0 83 | - python-dateutil=2.8.2=pyhd3eb1b0_0 84 | - pytorch=1.13.0=py3.9_cuda11.7_cudnn8.5.0_0 85 | - pytorch-cuda=11.7=h778d358_5 86 | - pytorch-mutex=1.0=cuda 87 | - pytz=2022.7=py39h06a4308_0 88 | - readline=8.2=h5eee18b_0 89 | - requests=2.31.0=py39h06a4308_0 90 | - scikit-learn=1.2.0=py39h6a678d5_1 91 | - scipy=1.9.3=py39hf6e8229_2 92 | - setuptools=68.0.0=py39h06a4308_0 93 | - six=1.16.0=pyhd3eb1b0_1 94 | - sqlite=3.41.2=h5eee18b_0 95 | - tbb=2021.8.0=hdb19cb5_0 96 | - threadpoolctl=2.2.0=pyh0d69192_0 97 | - tk=8.6.12=h1ccaba5_0 98 | - 
torchaudio=0.13.0=py39_cu117 99 | - torchvision=0.14.0=py39_cu117 100 | - tqdm=4.64.1=py39h06a4308_0 101 | - typing_extensions=4.7.1=py39h06a4308_0 102 | - tzdata=2023c=h04d1e81_0 103 | - urllib3=1.26.16=py39h06a4308_0 104 | - wheel=0.38.4=py39h06a4308_0 105 | - xz=5.4.2=h5eee18b_0 106 | - zlib=1.2.13=h5eee18b_0 107 | - zstd=1.5.5=hc292b87_0 108 | - pip: 109 | - appdirs==1.4.4 110 | - click==8.1.7 111 | - docker-pycreds==0.4.0 112 | - gitdb==4.0.10 113 | - gitpython==3.1.37 114 | - lib==4.0.0 115 | - pathtools==0.1.2 116 | - protobuf==4.24.3 117 | - psutil==5.9.5 118 | - pyyaml==6.0.1 119 | - sentry-sdk==1.31.0 120 | - setproctitle==1.3.2 121 | - smmap==5.0.1 122 | - wandb==0.15.11 123 | #prefix: /home/user/.conda/envs/fedgf 124 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/scaffold.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .basic_server import SyncServerHandler 4 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 5 | from ...utils import Aggregators 6 | 7 | 8 | ################## 9 | # 10 | # Server 11 | # 12 | ################## 13 | 14 | 15 | class ScaffoldServerHandler(SyncServerHandler): 16 | """SCAFFOLD server handler.""" 17 | 18 | @property 19 | def downlink_package(self): 20 | return [self.model_parameters, self.global_c] 21 | 22 | def setup_optim(self, lr): 23 | self.lr = lr 24 | self.global_c = torch.zeros_like(self.model_parameters) 25 | 26 | def global_update(self, buffer): 27 | # unpack 28 | dys = [ele[0] for ele in buffer] 29 | dcs = [ele[1] for ele in buffer] 30 | 31 | dx = Aggregators.fedavg_aggregate(dys) 32 | dc = Aggregators.fedavg_aggregate(dcs) 33 | 34 | next_model = self.model_parameters + self.lr * dx 35 | self.set_model(next_model) 36 | 37 | self.global_c += 1.0 * len(dcs) / self.num_clients * dc 38 | 39 | 40 | ################## 41 | # 42 | # Client 43 | # 44 | ################## 45 | 46 | 47 | class ScaffoldSerialClientTrainer(SGDSerialClientTrainer): 48 | def setup_optim(self, epochs, batch_size, lr, weight_decay, momentum): 49 | super().setup_optim(epochs, batch_size, lr, weight_decay, momentum) 50 | self.cs = [None for _ in range(self.num_clients)] 51 | 52 | def local_process(self, payload, id_list): 53 | model_parameters = payload[0] 54 | global_c = payload[1] 55 | for id in id_list: 56 | data_loader = self.dataset.get_dataloader(id, self.batch_size) 57 | pack = self.train(id, model_parameters, global_c, data_loader) 58 | self.cache.append(pack) 59 | 60 | def train(self, id, model_parameters, global_c, train_loader): 61 | self.set_model(model_parameters) 62 | frz_model = model_parameters 63 | 64 | if self.cs[id] is None: 65 | self.cs[id] = torch.zeros_like(model_parameters) 66 | 67 | for _ in range(self.epochs): 68 | for data, target in train_loader: 69 | if self.cuda: 70 | data = data.cuda(self.device) 71 | target = target.cuda(self.device) 72 | 73 | output = self.model(data) 74 | loss = self.criterion(output, target) 75 | 76 | self.optimizer.zero_grad() 77 | loss.backward() 78 | 79 | grad = self.model_gradients 80 | # grad = self.model_grads 81 | grad = grad - self.cs[id] + global_c 82 | idx = 0 83 | 84 | parameters = self._model.parameters() 85 | for p in self._model.state_dict().values(): 86 | if p.grad is None: # BatchNorm buffers have no grad 87 | layer_size = p.numel() 88 | else: 89 | parameter = next(parameters) 90 | layer_size = parameter.data.numel() 91 | shape = parameter.grad.shape 92 |
parameter.grad.data[:] = grad[idx:idx + layer_size].view(shape)[:] 93 | idx += layer_size 94 | 95 | # for parameter in self._model.parameters(): 96 | # layer_size = parameter.grad.numel() 97 | # shape = parameter.grad.shape 98 | # #parameter.grad = parameter.grad - self.cs[id][idx:idx + layer_size].view(parameter.grad.shape) + global_c[idx:idx + layer_size].view(parameter.grad.shape) 99 | # parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:] 100 | # idx += layer_size 101 | 102 | self.optimizer.step() 103 | 104 | dy = self.model_parameters - frz_model 105 | dc = -1.0 / (self.epochs * len(train_loader) * self.lr) * dy - global_c 106 | self.cs[id] += dc 107 | return [dy, dc] 108 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedprox.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | 4 | from .basic_server import SyncServerHandler 5 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 6 | 7 | 8 | ################## 9 | # 10 | # Server 11 | # 12 | ################## 13 | 14 | 15 | class FedProxServerHandler(SyncServerHandler): 16 | """FedProx server handler.""" 17 | None 18 | 19 | 20 | ################## 21 | # 22 | # Client 23 | # 24 | ################## 25 | 26 | class FedProxClientTrainer(SGDClientTrainer): 27 | """Federated client with local SGD with proximal term solver.""" 28 | def setup_optim(self, epochs, batch_size, lr, weight_decay, momentum, mu): 29 | super().setup_optim(epochs, batch_size, lr, weight_decay, momentum) 30 | self.mu = mu 31 | 32 | def local_process(self, payload, id): 33 | model_parameters = payload[0] 34 | train_loader = self.dataset.get_dataloader(id, self.batch_size) 35 | self.train(model_parameters, train_loader, self.mu) 36 | 37 | def train(self, model_parameters, train_loader, mu): 38 | """Client trains its local model on local dataset. 39 | 40 | Args: 41 | model_parameters (torch.Tensor): Serialized model parameters. 42 | """ 43 | self.set_model(model_parameters) 44 | frz_model = deepcopy(self._model) 45 | for ep in range(self.epochs): 46 | self._model.train() 47 | for data, target in train_loader: 48 | if self.cuda: 49 | data, target = data.cuda(self.device), target.cuda( 50 | self.device) 51 | 52 | preds = self._model(data) 53 | l1 = self.criterion(preds, target) 54 | l2 = 0.0 55 | for w0, w in zip(frz_model.parameters(), self._model.parameters()): 56 | l2 += torch.sum(torch.pow(w - w0, 2)) 57 | 58 | loss = l1 + 0.5 * mu * l2 59 | 60 | self.optimizer.zero_grad() 61 | loss.backward() 62 | self.optimizer.step() 63 | return [self.model_parameters] 64 | 65 | class FedProxSerialClientTrainer(SGDSerialClientTrainer): 66 | def setup_optim(self, epochs, batch_size, lr, weight_decay, momentum, mu): 67 | super().setup_optim(epochs, batch_size, lr, weight_decay, momentum) 68 | self.mu = mu 69 | 70 | def local_process(self, payload, id_list): 71 | model_parameters = payload[0] 72 | for id in id_list: 73 | data_loader = self.dataset.get_dataloader(id, self.batch_size) 74 | pack = self.train(model_parameters, data_loader, self.mu) 75 | self.cache.append(pack) 76 | 77 | def train(self, model_parameters, train_loader, mu): 78 | """Client trains its local model on local dataset. 79 | 80 | Args: 81 | model_parameters (torch.Tensor): serialized model parameters. 82 | train_loader (torch.utils.data.DataLoader): :class:`torch.utils.data.DataLoader` for this client.
83 | mu (float): parameter of FedProx. 84 | 85 | """ 86 | self.set_model(model_parameters) 87 | frz_model = deepcopy(self._model) 88 | frz_model.eval() 89 | 90 | for ep in range(self.epochs): 91 | self._model.train() 92 | for data, target in train_loader: 93 | if self.cuda: 94 | data, target = data.cuda(self.device), target.cuda( 95 | self.device) 96 | 97 | preds = self._model(data) 98 | l1 = self.criterion(preds, target) 99 | l2 = 0.0 100 | for w0, w in zip(frz_model.parameters(), self._model.parameters()): 101 | l2 += torch.sum(torch.pow(w - w0, 2)) 102 | 103 | loss = l1 + 0.5 * mu * l2 104 | 105 | self.optimizer.zero_grad() 106 | loss.backward() 107 | self.optimizer.step() 108 | 109 | return [self.model_parameters] 110 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # System file 2 | *.DS_Store 3 | 4 | # Pycharm 5 | .idea 6 | 7 | # VS Code 8 | .vscode 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | #*.mo 64 | #*.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | # docs/_build/ 81 | docs/build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ 148 | 149 | # dataset 150 | fedlab_benchmarks/algorithm/fedavg/scale/shakespeare-rnn/pkl_dataset/* 151 | 152 | fedlab_benchmarks/datasets/data/mnist/* 153 | !fedlab_benchmarks/datasets/data/mnist/*.py 154 | 155 | 156 | fedlab_benchmarks/datasets/data/cifar10/* 157 | !fedlab_benchmarks/datasets/data/cifar10/*.py 158 | 159 | 160 | fedlab_benchmarks/datasets/data/*/data 161 | fedlab_benchmarks/datasets/data/*/meta 162 | tests/data/mnist/* 163 | actions-runner 164 | 165 | # tests 166 | fedlab_benchmarks/algorithm/fedavg/standalone/*.txt 167 | fedlab_benchmarks/algorithm/fedavg/standalone/*.ipynb 168 | fedlab_benchmarks/algorithm/fedavg/standalone/*.sh 169 | tests/data/mnist 170 | codecov 171 | codecov_upload.sh 172 | pypi_release.sh 173 | 174 | tmp_test 175 | 176 | unit_test.ipynb 177 | /exp_logs 178 | # release cache 179 | 180 | /datasets/mnist/*mnist 181 | /datasets/mnist/fedmnist* 182 | /datasets/femnist/data 183 | /datasets/femnist/meta 184 | /datasets/celeba/data 185 | /datasets/celeba/meta 186 | /datasets/sent140/data 187 | /datasets/shakespeare/data 188 | /datasets/shakespeare/meta 189 | 190 | 191 | test_client.py 192 | test_server.py 193 | unit_test* 194 | 195 | /datasets/*/data 196 | /datasets/*/meta 197 | /datasets/*/data/* 198 | /datasets/*/meta/* 199 | /datasets/*/train/* 200 | /datasets/*/test/* 201 | -------------------------------------------------------------------------------- /tools/main.py: -------------------------------------------------------------------------------- 1 | from json import load 2 | import os 3 | import random 4 | from copy import deepcopy 5 | from torch import nn 6 | import sys 7 | import torch 8 | import numpy as np 9 | import pandas as pd 10 | 11 | sys.path.append("../") # To use fedlab library 12 | DATA_ROOT = '/data/' 13 | JSON_PATH = os.path.join(os.path.dirname(__file__), 'json_data') 14 | torch.manual_seed(0) 15 | 16 | from fedlab.utils.aggregator import Aggregators 17 | from fedlab.utils.serialization import SerializationTool 18 | from fedlab.utils.functional import evaluate, get_best_gpu, val_eval, schedule_cycling_lr 19 | 20 | from fedlab.core.standalone import StandalonePipeline 21 | 22 | from Lib.algorithms import load_algorithms, average 23 | from Lib.datasets import load_datasets 24 | from Lib.arg_parser import get_parser 25 | from Lib.models import get_model 26 | import wandb 27 | 28 | # configuration 29 | args = get_parser() 30 | c_values = [] 31 | 32 | if 'debug' in args.wandb_project_name: 33 | mode = 'disabled' 34 | else: 35 | mode = 'online' 36 | 37 | wandb.init(project=args.wandb_project_name, mode=mode) 38 | wandb.config.update(args) 39 | model = get_model(args) 40 | 41 | handler, trainer = load_algorithms(args, model) 42 | train_dataset, test_loader = load_datasets(args, DATA_ROOT, JSON_PATH, trainer) 43 | 44 | num_rounds = args.com_round 45 | 
eval_every = args.eval_every 46 | 47 | accuracy = [] 48 | val_accuracy = [] 49 | handler.num_clients = trainer.num_clients 50 | 51 | wandb_dict = {} 52 | 53 | while not handler.if_stop: 54 | wandb_dict.clear() 55 | 56 | sampled_clients = handler.sample_clients() 57 | if 'debug' in args.wandb_project_name: 58 | print(f"round:{handler.round},clients:{sampled_clients}") 59 | broadcast = handler.downlink_package 60 | 61 | trainer.local_process(broadcast, sampled_clients) 62 | uploads = trainer.uplink_package 63 | 64 | # server side 65 | for pack in uploads: 66 | handler.load(pack) 67 | 68 | FLAG = False 69 | if (args.avg_test and ((handler.round >= args.com_round - 100) or ((handler.round // eval_every) and (handler.round % eval_every) < 100))) or\ 70 | (handler.round >= args.com_round - 100) or ((handler.round // eval_every) and not handler.round % eval_every): 71 | 72 | FLAG = True 73 | val_loss, val_acc = val_eval(handler._model, nn.CrossEntropyLoss(), sampled_clients, train_dataset) 74 | loss, acc = evaluate(handler._model, nn.CrossEntropyLoss(), test_loader) 75 | 76 | wandb_dict.update({"test loss": loss, "test acc.": acc * 100, "val loss": val_loss, "val acc": val_acc * 100}) 77 | 78 | accuracy.append(acc * 100) 79 | val_accuracy.append(val_acc * 100) 80 | 81 | print("round {}, Test Accuracy: {:.4f}, Max Acc: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format( 82 | handler.round, acc * 100, max(accuracy), val_loss, val_acc)) 83 | 84 | elif handler.round == 0: 85 | FLAG = True 86 | val_loss, val_acc = val_eval(handler._model, nn.CrossEntropyLoss(), sampled_clients, train_dataset) 87 | loss, acc = evaluate(handler._model, nn.CrossEntropyLoss(), test_loader) 88 | 89 | wandb_dict.update({"test loss": loss, "test acc.": acc * 100, "val loss": val_loss, "val acc": val_acc * 100}) 90 | 91 | 92 | if FLAG: 93 | print(wandb_dict) 94 | wandb.log(wandb_dict, step=handler.round) 95 | 96 | handler.round += 1 97 | 98 | if args.save_model: 99 | save_info = {'model_state_dict': handler.model.state_dict(), 100 | 'round': handler.round, 101 | 'args': args} 102 | torch.save(save_info, os.path.join(wandb.run.dir, f"Algo{args.alg}_Data{args.dataset}_Alp{args.dir_alpha}.ckpt")) 103 | 104 | wandb.log({ 105 | "Eval:Avg Acc.": round(average(accuracy[-100:]), 2), 106 | "Eval:Max Acc.": round(max(accuracy), 2), 107 | "Eval:std": round(np.array(accuracy[-100:]).std(), 2), 108 | "Val:Avg Acc.": round(average(val_accuracy[-100:]), 2), 109 | "Val:Max Acc.": round(max(val_accuracy), 2), 110 | "Val:std": round(np.array(val_accuracy[-100:]).std(), 2), 111 | }) 112 | -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/mofedsam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import copy 4 | from .basic_server import SyncServerHandler 5 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 6 | from .fedavg import FedAvgServerHandler 7 | from .minimizers import MoSAM 8 | from ...utils import Aggregators 9 | from ...utils.serialization import SerializationTool 10 | # Do not convert the BatchNorm terms into the momentum update.
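# Illustrative sketch (assumes `handler` and `trainer` are already-constructed,
# fully configured instances of the two classes defined below): one synchronous
# MoFedSAM communication round, mirroring the driver loop in tools/main.py.
def _example_mofedsam_round(handler, trainer):
    sampled_clients = handler.sample_clients()          # server samples participants
    broadcast = handler.downlink_package                # [model_parameters, delta]
    trainer.local_process(broadcast, sampled_clients)   # local SAM steps with server momentum
    for pack in trainer.uplink_package:                 # each pack: [model_parameters, num_update]
        handler.load(pack)                              # aggregation runs once all packs arrive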
11 | ################## 12 | # 13 | # Server 14 | # 15 | ################## 16 | 17 | 18 | # class FedSAMServerHandler(SyncServerHandler): 19 | class MoFedSamServerHandler(FedAvgServerHandler): 20 | @property 21 | def downlink_package(self): 22 | return [self.model_parameters, self.delta] 23 | 24 | def setup_optim(self, eta_l, eta_g=1): 25 | self.delta = torch.zeros_like(self.model_parameters) 26 | self.eta_l = eta_l 27 | self.eta_g = eta_g 28 | # self.K = K 29 | 30 | def global_update(self, buffer): 31 | self.delta = self.calc_momentum(buffer) # grad = theta_prev - theta_current 32 | # super().global_update(buffer) 33 | serialized_parameters = self.model_parameters - self.delta*self.eta_g 34 | self.set_model(serialized_parameters) 35 | # self.set_momentum() 36 | # parameters_list = [ele[0] for ele in buffer] 37 | # weights = torch.tensor([ele[1] for ele in buffer]).to(self.device) 38 | # serialized_parameters = Aggregators.fedavg_aggregate(parameters_list, weights) 39 | # SerializationTool.deserialize_model(self._model, serialized_parameters) 40 | def set_momentum(self): 41 | SerializationTool.deserialize_model(self.delta, self.delta_parameters) 42 | 43 | def calc_momentum(self, buffer): 44 | # parameters_list = [ele[0] for ele in buffer] 45 | # weights = torch.tensor(weights) 46 | # weights = weights / torch.sum(weights) 47 | # S = len(buffer) 48 | 49 | eta_l = self.eta_l 50 | weights = [ele[1] for ele in buffer] 51 | K = np.array(weights).mean() 52 | 53 | # K = self.K 54 | # K = buffer[0][2] # number of epoch 55 | gradient_list = [ 56 | torch.sub(ele[0], self.model_parameters) for idx, ele in enumerate(buffer) 57 | ] 58 | 59 | delta = torch.mean(torch.stack(gradient_list, dim=0), dim=0) 60 | delta.div_(-1*eta_l*K) 61 | return delta 62 | 63 | ################## 64 | # 65 | # Client 66 | # 67 | ################## 68 | 69 | 70 | class MoFedSamSerialClientTrainer(SGDSerialClientTrainer): 71 | def __init__(self, model, num_clients, rho, beta, cuda=True, device=None, logger=None, personal=False) -> None: 72 | super().__init__(model, num_clients, cuda, device, logger, personal) 73 | self.rho = rho 74 | self.beta = beta 75 | 76 | def local_process(self, payload, id_list): 77 | model_parameters = payload[0] 78 | delta = payload[1] 79 | # model_parameters_np = payload[2] 80 | 81 | for id in id_list: 82 | data_loader = self.dataset.get_dataloader(id, self.batch_size) 83 | minimizer = MoSAM(self.optimizer, self.model, self.rho, self.beta, delta) 84 | pack = self.train(id, model_parameters, minimizer, data_loader) 85 | self.cache.append(pack) 86 | 87 | def train(self, id, model_parameters, minimizer, train_loader): 88 | self.set_model(model_parameters) 89 | 90 | # data_size = 0 91 | num_update = 0 92 | for _ in range(self.epochs): 93 | for data, target in train_loader: 94 | if self.cuda: 95 | data = data.cuda(self.device) 96 | target = target.cuda(self.device) 97 | 98 | # Ascent Step 99 | output = self.model(data) 100 | loss = self.criterion(output, target) 101 | 102 | loss.backward() 103 | minimizer.ascent_step() 104 | 105 | # Descent Step 106 | self.criterion(self.model(data), target).backward() 107 | minimizer.descent_step() 108 | 109 | # data_size += len(target) 110 | num_update += 1 111 | 112 | return [self.model_parameters, num_update] 113 | -------------------------------------------------------------------------------- /tools/Lib/datasets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os.path 3 | 4 | from 
fedlab.contrib.dataset.pathological_mnist import PathologicalMNIST 5 | from fedlab.contrib.dataset.partitioned_cifar10 import PartitionedCIFAR10 6 | from fedlab.contrib.dataset.partitioned_cifar100 import PartitionedCIFAR100 7 | 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import DataLoader 10 | import torchvision 11 | 12 | IMAGE_SIZE = 32 13 | 14 | def load_datasets(args, data_root, json_path, trainer): 15 | json_path = os.path.join(json_path, args.dataset) 16 | 17 | if args.dataset == 'mnist': 18 | 19 | transform = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize((0.1307,), (0.3081,)) 22 | ]) 23 | 24 | dataset = PathologicalMNIST(data_root=data_root, num_clients=args.total_client, transform=transform) 25 | 26 | dataset.preprocess() 27 | 28 | trainer.setup_dataset(dataset) 29 | test_data = torchvision.datasets.MNIST(root=data_root, 30 | train=False, 31 | transform=transform) 32 | test_loader = DataLoader(test_data, batch_size=1024) 33 | elif args.dataset == 'cifar10': 34 | 35 | if args.transform: 36 | train_transform = transforms.Compose([ 37 | transforms.RandomCrop(IMAGE_SIZE, padding=4), 38 | transforms.RandomHorizontalFlip(), 39 | transforms.ToTensor(), 40 | transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))]) 41 | else: 42 | train_transform = transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))]) 45 | 46 | test_transform = transforms.Compose([ 47 | transforms.ToTensor(), 48 | transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))]) 49 | 50 | dataset = PartitionedCIFAR10(data_root=data_root, num_clients=args.total_client, transform=train_transform, json_path=json_path, dir_alpha=args.dir_alpha) 51 | dataset.preprocess(balance=args.balance, 52 | partition=args.partition, 53 | dir_alpha=args.dir_alpha, 54 | batch_size=args.batch_size, 55 | ) 56 | 57 | trainer.setup_dataset(dataset) 58 | test_data = torchvision.datasets.CIFAR10(root=data_root, 59 | train=False, 60 | transform=test_transform) 61 | test_loader = DataLoader(test_data, batch_size=64, shuffle=False) 62 | elif args.dataset == 'cifar100': 63 | if args.transform: 64 | train_transform = transforms.Compose([ 65 | transforms.RandomCrop(IMAGE_SIZE, padding=4), 66 | transforms.RandomHorizontalFlip(), 67 | transforms.ToTensor(), 68 | transforms.Normalize((0.5071, 0.4867, 0.4408),(0.2675, 0.2565, 0.2761))]) 69 | else: 70 | train_transform = transforms.Compose([ 71 | transforms.ToTensor(), 72 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))]) 73 | 74 | test_transform = transforms.Compose([ 75 | transforms.ToTensor(), 76 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 77 | ]) 78 | 79 | dataset = PartitionedCIFAR100(data_root=data_root, num_clients=args.total_client, transform=train_transform, json_path=json_path, dir_alpha=args.dir_alpha) 80 | dataset.preprocess(balance=args.balance, 81 | partition=args.partition, 82 | dir_alpha=args.dir_alpha, 83 | batch_size=args.batch_size, 84 | ) 85 | 86 | trainer.setup_dataset(dataset) 87 | test_data = torchvision.datasets.CIFAR100(root=data_root, 88 | train=False, 89 | transform=test_transform) 90 | test_loader = DataLoader(test_data, batch_size=64, shuffle=False) 91 | else: 92 | raise ValueError(f"Unsupported dataset: {args.dataset}") 93 | 94 | return dataset, test_loader 95 | -------------------------------------------------------------------------------- /fedlab/contrib/dataset/adult.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | from sklearn.datasets import load_svmlight_file 17 | import random 18 | import os 19 | from urllib.request import urlretrieve 20 | 21 | from torch.utils.data import Dataset 22 | 23 | 24 | class Adult(Dataset): 25 | """`Adult <https://archive.ics.uci.edu/ml/datasets/adult>`_ dataset from `LIBSVM Data <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/>`_. 26 | 27 | Args: 28 | root (str): Root directory of raw dataset to download if ``download`` is set to ``True``. 29 | train (bool, optional): If True, creates dataset from training set, otherwise creates from test set. 30 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. Default as ``None``. 31 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. Default as ``None``. 32 | download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. 33 | 34 | """ 35 | url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/" 36 | train_file_name = "a9a" 37 | test_file_name = "a9a.t" 38 | num_classes = 2 39 | num_features = 123 40 | 41 | def __init__(self, root, train=True, 42 | transform=None, 43 | target_transform=None, 44 | download=False): 45 | self.root = root 46 | self.train = train 47 | self.transform = transform 48 | self.target_transform = target_transform 49 | 50 | if not os.path.exists(root): 51 | os.mkdir(root) 52 | 53 | if self.train: 54 | self.full_file_path = os.path.join(self.root, self.train_file_name) 55 | else: 56 | self.full_file_path = os.path.join(self.root, self.test_file_name) 57 | 58 | if download: 59 | self.download() 60 | 61 | if not self._local_file_existence(): 62 | raise RuntimeError( 63 | f"Adult-a9a source data file not found. 
You can use download=True to " 64 | f"download it.") 65 | 66 | # now load from source file 67 | X, y = load_svmlight_file(self.full_file_path) 68 | X = X.todense() # transform 69 | 70 | if not self.train: 71 | X = np.c_[X, np.zeros((len(y), 1))] # append a zero vector at the end of X_test 72 | 73 | X = np.array(X, dtype=np.float32) 74 | y = np.array((y + 1) / 2, dtype=np.int32) # map elements of y from {-1, 1} to {0, 1} 75 | print(f"Local file {self.full_file_path} loaded.") 76 | self.data, self.targets = X, y 77 | 78 | def download(self): 79 | if self._local_file_existence(): 80 | print(f"Source file already downloaded.") 81 | return 82 | 83 | if self.train: 84 | download_url = self.url + self.train_file_name 85 | else: 86 | download_url = self.url + self.test_file_name 87 | 88 | urlretrieve(download_url, self.full_file_path) 89 | 90 | def _local_file_existence(self): 91 | return os.path.exists(self.full_file_path) 92 | 93 | def __getitem__(self, index): 94 | """ 95 | Args: 96 | index (int): Index 97 | 98 | Returns: 99 | tuple: (features, target) where target is index of the target class. 100 | """ 101 | data, label = self.data[index], self.targets[index] 102 | 103 | if self.transform is not None: 104 | data = self.transform(data) 105 | 106 | if self.target_transform is not None: 107 | label = self.target_transform(label) 108 | 109 | return data, label 110 | 111 | def __len__(self): 112 | return len(self.targets) 113 | 114 | def extra_repr(self) -> str: 115 | return "Split: {}".format("Train" if self.train is True else "Test") 116 | -------------------------------------------------------------------------------- /fedlab/models/resnet_cifar100_del_batch.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | # self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 23 | stride=1, padding=1, bias=False) 24 | # self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, 30 | kernel_size=1, stride=stride, bias=False), 31 | # nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | # out = F.relu(self.bn1(self.conv1(x))) 36 | out = F.relu(self.conv1(x)) 37 | # out = self.bn2(self.conv2(out)) 38 | out = self.conv2(out) 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | return out 42 | 43 | 44 | class Bottleneck(nn.Module): 45 | expansion = 4 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(Bottleneck, self).__init__() 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 50 | # self.bn1 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 52 | stride=stride, padding=1, bias=False) 53 | # self.bn2 = nn.BatchNorm2d(planes) 54 | self.conv3 = nn.Conv2d(planes, self.expansion * 55 | planes, kernel_size=1, bias=False) 56 | # self.bn3 = nn.BatchNorm2d(self.expansion*planes) 57 | 58 | self.shortcut = nn.Sequential() 59 | if stride != 1 or in_planes != self.expansion*planes: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, self.expansion*planes, 62 | kernel_size=1, stride=stride, bias=False), 63 | # nn.BatchNorm2d(self.expansion*planes) 64 | ) 65 | 66 | def forward(self, x): 67 | # out = F.relu(self.bn1(self.conv1(x))) 68 | out = F.relu(self.conv1(x)) 69 | # out = F.relu(self.bn2(self.conv2(out))) 70 | out = F.relu(self.conv2(out)) 71 | # out = self.bn3(self.conv3(out)) 72 | out = self.conv3(out) 73 | out += self.shortcut(x) 74 | out = F.relu(out) 75 | return out 76 | 77 | 78 | class ResNet(nn.Module): 79 | def __init__(self, block, num_blocks, num_classes=100): 80 | super(ResNet, self).__init__() 81 | self.in_planes = 64 82 | 83 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 84 | stride=1, padding=1, bias=False) 85 | # self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512*block.expansion, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes * block.expansion 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | # out = F.relu(self.bn1(self.conv1(x))) 102 | out = F.relu(self.conv1(x)) 103 | out = self.layer1(out) 104 | out = self.layer2(out) 105 | out = self.layer3(out) 106 | out = self.layer4(out) 107 | out = F.avg_pool2d(out, 4) 108 | out = out.view(out.size(0), -1) 109 | out = self.linear(out) 110 | return out 
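# Hedged note (illustrative addition; `_check_no_buffers` is a hypothetical helper,
# not part of FedLab): because the BatchNorm layers above are commented out, every
# entry of this model's state_dict() corresponds to a trainable parameter, so flat
# gradient (de)serialization -- see e.g. fedlab/contrib/algorithm/scaffold.py --
# needs no special-casing for buffers. A minimal self-check of that assumption:
def _check_no_buffers(model):
    # BatchNorm would contribute running_mean / running_var / num_batches_tracked
    # buffers, which appear in state_dict() but not in parameters().
    assert sum(1 for _ in model.parameters()) == len(model.state_dict())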
111 | 112 | 113 | def ResNet18(): 114 | return ResNet(BasicBlock, [2, 2, 2, 2]) 115 | 116 | 117 | def ResNet34(): 118 | return ResNet(BasicBlock, [3, 4, 6, 3]) 119 | 120 | 121 | def ResNet50(): 122 | return ResNet(Bottleneck, [3, 4, 6, 3]) 123 | 124 | 125 | def ResNet101(): 126 | return ResNet(Bottleneck, [3, 4, 23, 3]) 127 | 128 | 129 | def ResNet152(): 130 | return ResNet(Bottleneck, [3, 8, 36, 3]) 131 | 132 | 133 | def test(): 134 | net = ResNet18() 135 | y = net(torch.randn(1, 3, 32, 32)) 136 | print(y.size()) 137 | 138 | # test() -------------------------------------------------------------------------------- /fedlab/models/rnn.py: -------------------------------------------------------------------------------- 1 | """RNN model in pytorch 2 | References: 3 | [1] H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. 4 | Communication-Efficient Learning of Deep Networks from Decentralized Data. AISTATS 2017. 5 | https://arxiv.org/abs/1602.05629 6 | [2] Reddi S, Charles Z, Zaheer M, et al. 7 | Adaptive Federated Optimization. ICLR 2021. 8 | https://arxiv.org/pdf/2003.00295.pdf 9 | """ 10 | import torch.nn as nn 11 | import torch 12 | 13 | 14 | class RNN_Shakespeare(nn.Module): 15 | def __init__(self, vocab_size=80, embedding_dim=8, hidden_size=256): 16 | """Creates an RNN model using LSTM layers for Shakespeare language models (next character prediction task). 17 | Args: 18 | vocab_size (int, optional): the size of the vocabulary, used as a dimension in the input embedding, 19 | Defaults to 80. 20 | embedding_dim (int, optional): the size of the embedding vector, used as a dimension in the output embedding, 21 | Defaults to 8. 22 | hidden_size (int, optional): the size of hidden layer. Defaults to 256. 23 | Returns: 24 | A `torch.nn.Module`. 25 | Examples: 26 | RNN_Shakespeare( 27 | (embeddings): Embedding(80, 8, padding_idx=0) 28 | (lstm): LSTM(8, 256, num_layers=2, batch_first=True) 29 | (fc): Linear(in_features=256, out_features=80, bias=True) 30 | ), total 819920 parameters 31 | """ 32 | super(RNN_Shakespeare, self).__init__() 33 | self.embeddings = nn.Embedding(num_embeddings=vocab_size, 34 | embedding_dim=embedding_dim, 35 | padding_idx=0) 36 | self.lstm = nn.LSTM(input_size=embedding_dim, 37 | hidden_size=hidden_size, 38 | num_layers=2, 39 | batch_first=True) 40 | self.fc = nn.Linear(hidden_size, vocab_size) 41 | 42 | def forward(self, input_seq): 43 | embeds = self.embeddings(input_seq) # (batch, seq_len, embedding_dim) 44 | lstm_out, _ = self.lstm(embeds) 45 | final_hidden_state = lstm_out[:, -1] 46 | output = self.fc(final_hidden_state) 47 | return output 48 | 49 | 50 | class LSTMModel(nn.Module): 51 | def __init__(self, 52 | vocab_size, embedding_dim, hidden_size, num_layers, output_dim, pad_idx=0, 53 | using_pretrained=False, embedding_weights=None, bid=False): 54 | """Creates an RNN model using LSTM layers, optionally initialized from pretrained embedding_weights 55 | Args: 56 | vocab_size (int): the size of the vocabulary, used as a dimension in the input embedding 57 | embedding_dim (int): the size of the embedding vector, used as a dimension in the output embedding 58 | hidden_size (int): the size of hidden layer, e.g. `256` 59 | num_layers (int): the number of recurrent layers, e.g. `2` 60 | output_dim (int): the dimension of output, e.g. 
`10` 61 | pad_idx (int): the index of pad_token 62 | using_pretrained (bool, optional): set `True` to initialize the embedding layer from `embedding_weights`, defaults to `False` 63 | embedding_weights (torch.Tensor, optional): vectors to pretrain model, defaults to `None` 64 | bid (bool, optional): set `True` to use a bidirectional LSTM, defaults to `False` 65 | Returns: 66 | A `torch.nn.Module`. 67 | """ 68 | super(LSTMModel, self).__init__() 69 | self.embeddings = nn.Embedding(num_embeddings=vocab_size, 70 | embedding_dim=embedding_dim, 71 | padding_idx=pad_idx) 72 | if using_pretrained: 73 | assert embedding_weights.shape[0] == vocab_size 74 | assert embedding_weights.shape[1] == embedding_dim 75 | self.embeddings.weight.data.copy_(embedding_weights) 76 | 77 | 78 | self.dropout = nn.Dropout(0.5) 79 | self.encoder = nn.LSTM( 80 | input_size=embedding_dim, 81 | hidden_size=hidden_size, 82 | num_layers=num_layers, 83 | bidirectional=bid, 84 | dropout=0.3, 85 | batch_first=True 86 | ) 87 | 88 | # using bidirectional, *2 89 | if bid: 90 | hidden_size *= 2 91 | self.fc = nn.Sequential( 92 | nn.Dropout(0.5), 93 | nn.Linear(hidden_size, output_dim), 94 | ) 95 | 96 | def forward(self, input_seq: torch.Tensor): 97 | embeds = self.embeddings(input_seq) # (batch, seq_len, embedding_dim) 98 | embeds = self.dropout(embeds) 99 | lstm_out, _ = self.encoder(embeds) 100 | # outputs [batch, seq_len, hidden*2]; *2 means using bidirectional 101 | final_hidden_state = lstm_out[:, -1] 102 | output = self.fc(final_hidden_state) 103 | return output -------------------------------------------------------------------------------- /fedlab/contrib/algorithm/fedgf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from collections import OrderedDict 4 | 5 | from .basic_server import SyncServerHandler 6 | from .basic_client import SGDClientTrainer, SGDSerialClientTrainer 7 | from .fedavg import FedAvgServerHandler 8 | from .minimizers import SAM 9 | from ...utils import Aggregators 10 | from .bypass_bn import disable_running_stats, enable_running_stats 11 | from statistics import mean 12 | from collections import defaultdict 13 | 14 | ################## 15 | # 16 | # Server 17 | # 18 | ################## 19 | 20 | 21 | class FedGfServerHandler(FedAvgServerHandler): 22 | """FedGF server handler.""" 23 | @property 24 | def downlink_package(self): 25 | return [self.model_parameters, self.perturbed_model_parameters, self.c] 26 | 27 | def setup_optim(self, g_rho, T_D, W): 28 | self.perturbed_model_parameters = None 29 | self.c = 0 30 | self.g_rho = g_rho 31 | 32 | self.window = [] 33 | self.T_D = T_D 34 | self.W = W 35 | self.pseudo_gradient = None 36 | 37 | def global_update(self, buffer, upload_res=False): 38 | self.calc_c(buffer) 39 | pseudo_gradient = self.model_parameters 40 | super().global_update(buffer) 41 | # subtract the updated average model to obtain the pseudo-gradient 42 | pseudo_gradient.sub_(self.model_parameters) 43 | self.pseudo_gradient = pseudo_gradient 44 | self.calc_perturbation(pseudo_gradient) 45 | 46 | def calc_c(self, buffer): 47 | Divergence_metric = torch.tensor([torch.norm(torch.sub(self.model_parameters, ele[0])).item() for ele in buffer]) 48 | tot_norm = torch.div(torch.sum(Divergence_metric), len(Divergence_metric)).item() 49 | self.append_grad_norm(tot_norm) 50 | self.norm_grad = tot_norm 51 | self.c = mean(self.window) 52 | 53 | # NOTE: returns the average L2 norm as a float 54 | def calc_avg_norm_grad(self, parameters_list): 55 |
total_norm = 0 56 | for param in parameters_list: 57 | total_norm += param.norm(2) 58 | return total_norm.item() / len(parameters_list) 59 | 60 | def append_grad_norm(self, grad_norm): 61 | x = 1 if grad_norm > self.T_D else 0 62 | self.window.append(x) 63 | if len(self.window) > self.W: 64 | del self.window[0] 65 | 66 | def calc_perturbation(self, grad): 67 | # Calculate the perturbation using parameters (always) 68 | self.perturbed_model_parameters = copy.deepcopy(self.model_parameters) 69 | grad.div_(grad.norm(2)).mul_(self.g_rho) 70 | self.perturbed_model_parameters.add_(grad) 71 | 72 | 73 | ################## 74 | # 75 | # Client 76 | # 77 | ################## 78 | 79 | 80 | class FedGfSerialClientTrainer(SGDSerialClientTrainer): 81 | def __init__(self, model, num_clients, rho, cuda=True, device=None, logger=None, personal=False) -> None: 82 | super().__init__(model, num_clients, cuda, device, logger, personal) 83 | self.rho = rho 84 | 85 | def local_process(self, payload, id_list): 86 | model_parameters = payload[0] 87 | perturb_parameters = payload[1] 88 | c = payload[2] 89 | 90 | for id in id_list: 91 | data_loader = self.dataset.get_dataloader(id, self.batch_size) 92 | minimizer = SAM(self.optimizer, self.model, self.rho) 93 | pack = self.train(id, model_parameters, minimizer, data_loader, perturb_parameters, c) 94 | self.cache.append(pack) 95 | 96 | def train(self, id, model_parameters, minimizer, train_loader, perturb_parameters, c): 97 | self.set_model(model_parameters) 98 | 99 | data_size = 0 100 | for _ in range(self.epochs): 101 | for data, target in train_loader: 102 | if self.cuda: 103 | data = data.cuda(self.device) 104 | target = target.cuda(self.device) 105 | 106 | # Ascent Step 107 | init_model = None 108 | 109 | output = self.model(data) 110 | if perturb_parameters is not None: 111 | init_model = self.model_parameters 112 | loss = self.criterion(output, target) 113 | 114 | loss.backward() 115 | minimizer.ascent_step() 116 | if perturb_parameters is not None: 117 | self.weighted_sum(perturb_parameters, c) 118 | 119 | # Descent Step 120 | self.criterion(self.model(data), target).backward() 121 | if perturb_parameters is not None: 122 | self.set_model(init_model) 123 | 124 | minimizer.descent_step(init_model) 125 | 126 | data_size += len(target) 127 | return [self.model_parameters, data_size] 128 | 129 | def weighted_sum(self, perturb_model, c): 130 | """ 131 | Shift the ascent point: scale `perturb_model` (the perturbed global model) by c and the local model by (1 - c), then load their sum into the model. 132 | """ 133 | 134 | model_parameters = perturb_model*c + self.model_parameters * (1-c) 135 | self.set_model(model_parameters) 136 | -------------------------------------------------------------------------------- /fedlab/contrib/dataset/basic_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group) 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | from torch.utils.data import Dataset
16 | import os
17 | 
18 | from PIL import Image
19 | import numpy as np
20 | 
21 | 
22 | 
23 | class BaseDataset(Dataset):
24 |     """Base dataset iterator"""
25 | 
26 |     def __init__(self, x, y, train_transform):
27 |         self.x = x
28 |         self.y = y
29 |         self.train_transform = train_transform
30 |     def __len__(self):
31 |         return len(self.y)
32 | 
33 |     def __getitem__(self, index):
34 |         return self.train_transform(self.x[index]), self.y[index]
35 | 
36 | 
37 | class Subset(Dataset):
38 |     """For data subset with different augmentation for different client.
39 | 
40 |     Args:
41 |         dataset (Dataset): The whole Dataset
42 |         indices (List[int]): Indices of the sub-dataset to extract from ``dataset``.
43 |         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version.
44 |         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
45 |     """
46 | 
47 |     def __init__(self, dataset, indices, transform=None, target_transform=None):
48 |         self.data = []
49 |         for idx in indices:
50 |             self.data.append(dataset.data[idx])
51 | 
52 |         if not isinstance(dataset.targets, np.ndarray):
53 |             dataset.targets = np.array(dataset.targets)
54 | 
55 |         self.targets = dataset.targets[indices].tolist()
56 | 
57 |         self.transform = transform
58 |         self.target_transform = target_transform
59 | 
60 |     def __getitem__(self, index):
61 |         """Get item
62 | 
63 |         Args:
64 |             index (int): index
65 | 
66 |         Returns:
67 |             (image, target) where target is index of the target class.
68 |         """
69 |         img, label = self.data[index], self.targets[index]
70 | 
71 |         if self.transform is not None:
72 |             img = self.transform(img)
73 |         if self.target_transform is not None:
74 |             label = self.target_transform(label)
75 | 
76 |         return img, label
77 | 
78 |     def __len__(self):
79 |         return len(self.targets)
80 | 
81 | class CIFARSubset(Subset):
82 |     """For data subset with different augmentation for different client.
83 | 
84 |     Args:
85 |         dataset (Dataset): The whole Dataset
86 |         indices (List[int]): Indices of the sub-dataset to extract from ``dataset``.
87 |         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version.
88 |         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
89 |     """
90 |     def __init__(self,
91 |                  dataset,
92 |                  indices,
93 |                  transform=None,
94 |                  target_transform=None,
95 |                  to_image=True):
96 |         self.data = []
97 |         for idx in indices:
98 |             # keep raw arrays when to_image=False instead of leaving self.data empty
99 |             self.data.append(Image.fromarray(dataset.data[idx]) if to_image else dataset.data[idx])
100 |         if not isinstance(dataset.targets, np.ndarray):
101 |             dataset.targets = np.array(dataset.targets)
102 |         self.targets = dataset.targets[indices].tolist()
103 |         self.transform = transform
104 |         self.target_transform = target_transform
105 | 
106 | 
107 | class FedDataset(object):
108 |     def __init__(self) -> None:
109 |         self.num = None  # the number of sub-datasets, indexed from 0 to num-1.
110 |         self.root = None  # root of the raw dataset.
111 |         self.path = None  # path to save the partitioned datasets.
112 | 
113 |     def preprocess(self):
114 |         """Define the dataset partition process"""
115 |         if not os.path.exists(self.path):
116 |             os.mkdir(self.path)
117 |             os.mkdir(os.path.join(self.path, "train"))
118 |             os.mkdir(os.path.join(self.path, "val"))
119 |             os.mkdir(os.path.join(self.path, "test"))
120 | 
121 |     def get_dataset(self, id, type="train"):
122 |         """Get dataset class
123 | 
124 |         Args:
125 |             id (int): Client ID for the partial dataset to retrieve.
126 |             type (str, optional): Type of dataset, can be chosen from ``["train", "val", "test"]``. Defaults to ``"train"``.
127 | 
128 |         Raises:
129 |             NotImplementedError
130 |         """
131 |         raise NotImplementedError()
132 | 
133 |     def get_dataloader(self, id, batch_size, type="train"):
134 |         """Get data loader"""
135 |         raise NotImplementedError()
136 | 
137 |     def __len__(self):
138 |         return self.num
139 | 
--------------------------------------------------------------------------------
/fedlab/core/communicator/processor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import numpy as np
16 | 
17 | import torch
18 | import torch.distributed as dist
19 | 
20 | from .package import Package
21 | from . import HEADER_DATA_TYPE_IDX, HEADER_SIZE, HEADER_RECEIVER_RANK_IDX, HEADER_SLICE_SIZE_IDX, dtype_flab2torch, dtype_torch2flab
22 | 
23 | 
24 | class PackageProcessor(object):
25 |     """Provide more flexible distributed tensor communication functions based on
26 |     :func:`torch.distributed.send` and :func:`torch.distributed.recv`.
27 | 
28 |     :class:`PackageProcessor` defines the details of point-to-point package communication.
29 | 
30 |     EVERYTHING is :class:`torch.Tensor` in FedLab.
31 |     """
32 | 
33 |     @staticmethod
34 |     def send_package(package, dst):
35 |         """Three-segment tensor communication pattern based on ``torch.distributed``
36 | 
37 |         Pattern is shown as follows:
38 |             1.1 sender: send a header tensor containing ``slice_size`` to receiver
39 | 
40 |             1.2 receiver: receive the header, and get the value of ``slice_size`` and create a buffer for incoming slices of content
41 | 
42 |             2.1 sender: send a list of slices indicating the size of each content tensor.
43 | 
44 |             2.2 receiver: receive the slices list.
45 | 
46 |             3.1 sender: send a content tensor composed of a list of tensors.
47 | 
48 |             3.2 receiver: receive the content tensor, and split it back into a list of tensors using the slices list
49 |         """
50 | 
51 |         def send_header(header, dst):
52 |             header[HEADER_RECEIVER_RANK_IDX] = dst
53 |             dist.send(header, dst=dst)
54 | 
55 |         def send_slices(slices, dst):
56 |             np_slices = np.array(slices, dtype=np.int32)
57 |             tensor_slices = torch.from_numpy(np_slices)
58 |             dist.send(tensor_slices, dst=dst)
59 | 
60 |         def send_content(content, dst):
61 |             dist.send(content, dst=dst)
62 | 
63 |         # body
64 |         if package.dtype is not None:
65 |             package.header[HEADER_DATA_TYPE_IDX] = dtype_torch2flab(
66 |                 package.dtype)
67 | 
68 |         # send the header first
69 |         send_header(header=package.header, dst=dst)
70 | 
71 |         # if the package carries content, send the remaining parts
72 |         if package.header[HEADER_SLICE_SIZE_IDX] > 0:
73 |             send_slices(slices=package.slices, dst=dst)
74 |             send_content(content=package.content, dst=dst)
75 | 
76 |     @staticmethod
77 |     def recv_package(src=None):
78 |         """Three-segment tensor communication pattern based on ``torch.distributed``
79 | 
80 |         Pattern is shown as follows:
81 |             1.1 sender: send a header tensor containing ``slice_size`` to receiver
82 | 
83 |             1.2 receiver: receive the header, and get the value of ``slice_size`` and create a buffer for incoming slices of content
84 | 
85 |             2.1 sender: send a list of slices indicating the size of each content tensor.
86 | 
87 |             2.2 receiver: receive the slices list.
88 | 
89 |             3.1 sender: send a content tensor composed of a list of tensors.
90 | 
91 |             3.2 receiver: receive the content tensor, and split it back into a list of tensors using the slices list
92 |         """
93 | 
94 |         def recv_header(src=src, parse=True):
95 |             buffer = torch.zeros(size=(HEADER_SIZE, ), dtype=torch.int32)
96 |             dist.recv(buffer, src=src)
97 |             if parse is True:
98 |                 return Package.parse_header(buffer)
99 |             else:
100 |                 return buffer
101 | 
102 |         def recv_slices(slices_size, src):
103 |             buffer_slices = torch.zeros(size=(slices_size, ),
104 |                                         dtype=torch.int32)
105 |             dist.recv(buffer_slices, src=src)
106 |             slices = [slc.item() for slc in buffer_slices]
107 |             return slices
108 | 
109 |         def recv_content(slices, data_type, src):
110 |             content_size = sum(slices)
111 |             dtype = dtype_flab2torch(data_type)
112 |             buffer = torch.zeros(size=(content_size, ), dtype=dtype)
113 | 
114 |             dist.recv(buffer, src=src)
115 |             return Package.parse_content(slices, buffer)
116 | 
117 |         # body
118 |         sender_rank, _, slices_size, message_code, data_type = recv_header(
119 |             src=src)
120 | 
121 |         if slices_size > 0:
122 |             slices = recv_slices(slices_size=slices_size, src=sender_rank)
123 |             content = recv_content(slices, data_type, src=sender_rank)
124 |         else:
125 |             content = None
126 | 
127 |         return sender_rank, message_code, content
128 | 
--------------------------------------------------------------------------------
/fedlab/core/model_maintainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from typing import List
16 | import torch
17 | from copy import deepcopy
18 | from ..utils.serialization import SerializationTool
19 | from ..utils.functional import get_best_gpu
20 | 
21 | 
22 | class ModelMaintainer(object):
23 |     """Maintain PyTorch model.
24 | 
25 |     Provide necessary attributes and operation methods. More features with local or global model
26 |     will be implemented here.
27 | 
28 |     Args:
29 |         model (torch.nn.Module): PyTorch model.
30 |         cuda (bool): Use GPUs or not.
31 |         device (str, optional): Assign model/data to the given GPU, e.g., 'cuda:0'. Defaults to None. If device is None and cuda is True, FedLab will set the gpu with the largest memory as default.
32 |     """
33 |     def __init__(self,
34 |                  model: torch.nn.Module,
35 |                  cuda: bool,
36 |                  device: str = None) -> None:
37 |         self.cuda = cuda
38 | 
39 |         if cuda:
40 |             # dynamic gpu acquire.
41 |             if device is None:
42 |                 self.device = get_best_gpu()
43 |             else:
44 |                 self.device = device
45 |             self._model = deepcopy(model).cuda(self.device)
46 |         else:
47 |             self._model = deepcopy(model).cpu()
48 | 
49 |     def set_model(self, parameters: torch.Tensor):
50 |         """Assign parameters to self._model."""
51 |         SerializationTool.deserialize_model(self._model, parameters)
52 | 
53 |     def set_model_np(self, parameters: torch.Tensor):
54 |         """Assign trainable parameters to self._model."""
55 |         SerializationTool.deserialize_trainable_model(self._model, parameters)
56 | 
57 |     @property
58 |     def model(self) -> torch.nn.Module:
59 |         """Return :class:`torch.nn.Module`."""
60 |         return self._model
61 | 
62 |     @property
63 |     def model_parameters(self) -> torch.Tensor:
64 |         """Return serialized model parameters."""
65 |         return SerializationTool.serialize_model(self._model)
66 | 
67 |     @property
68 |     def model_parameters_np(self) -> torch.Tensor:
69 |         """Return serialized trainable model parameters."""
70 |         return SerializationTool.serialize_trainable_model(self._model)
71 | 
72 |     @property
73 |     def model_gradients(self) -> torch.Tensor:
74 |         """Return serialized model gradients."""
75 |         return SerializationTool.serialize_model_gradients(self._model)
76 | 
77 |     @property
78 |     def shape_list(self) -> List[torch.Tensor]:
79 |         """Return shape of model parameters.
80 | 
81 |         Currently, this attribute is used in tensor compression.
82 |         """
83 |         shape_list = [param.shape for param in self._model.parameters()]
84 |         return shape_list
85 | 
86 | 
87 | class SerialModelMaintainer(ModelMaintainer):
88 |     """Maintain PyTorch model.
89 | 
90 |     Provide necessary attributes and operation methods. More features with local or global model
91 |     will be implemented here.
92 | 
93 |     Args:
94 |         model (torch.nn.Module): PyTorch model.
95 |         num_clients (int): The number of independent models.
96 |         cuda (bool): Use GPUs or not.
97 |         device (str, optional): Assign model/data to the given GPU, e.g., 'cuda:0'. Defaults to None. If device is None and cuda is True, FedLab will set the gpu with the largest idle memory as default.
98 |         personal (bool, optional): If True is passed, SerialModelMaintainer keeps a separate copy of the model parameters for each client, indexed by [0, num_clients-1]. Defaults to False.
99 |     """
100 |     def __init__(self,
101 |                  model: torch.nn.Module,
102 |                  num_clients: int,
103 |                  cuda: bool,
104 |                  device: str = None,
105 |                  personal: bool = False) -> None:
106 |         super().__init__(model, cuda, device)
107 |         if personal:
108 |             self.parameters = [
109 |                 self.model_parameters for _ in range(num_clients)
110 |             ]  # A list of Tensor
111 |         else:
112 |             self.parameters = None
113 | 
114 |     def set_model(self, parameters: torch.Tensor = None, id: int = None):
115 |         """Assign parameters to self._model.
116 | 
117 |         Note:
118 |             parameters and id cannot both be None.
119 |             If id is None, this function loads the given parameters.
120 |             If id is not None, this function loads the parameters stored for client id, and the parameters argument is ignored.
121 | 
122 |         Args:
123 |             parameters (torch.Tensor, optional): Model parameters. Defaults to None.
124 |             id (int, optional): Load the model parameters of client id. Defaults to None.
125 |         """
126 |         if id is None:
127 |             super().set_model(parameters)
128 |         else:
129 |             super().set_model(self.parameters[id])
--------------------------------------------------------------------------------
/fedlab/models/cnn.py:
--------------------------------------------------------------------------------
1 | """CNN model in pytorch
2 | References:
3 |     [1] Reddi S, Charles Z, Zaheer M, et al.
4 |     Adaptive Federated Optimization. ICML 2020.
5 |     https://arxiv.org/pdf/2003.00295.pdf
6 | """
7 | 
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | 
12 | 
13 | class CNN_FEMNIST(nn.Module):
14 |     """Used for EMNIST experiments in reference [1]
15 |     Args:
16 |         only_digits (bool, optional): If True, uses a final layer with 10 outputs, for use with the
17 |             digits-only MNIST dataset (http://yann.lecun.com/exdb/mnist/).
18 |             If False, uses 62 outputs for Federated Extended MNIST (FEMNIST).
19 |             EMNIST: Extending MNIST to handwritten letters: https://arxiv.org/abs/1702.05373
20 |             Defaults to ``False``.
21 |     Returns:
22 |         A `torch.nn.Module`.
23 | """ 24 | def __init__(self, only_digits=False): 25 | super(CNN_FEMNIST, self).__init__() 26 | self.conv2d_1 = nn.Conv2d(1, 32, kernel_size=3) 27 | self.max_pooling = nn.MaxPool2d(2, stride=2) 28 | self.conv2d_2 = nn.Conv2d(32, 64, kernel_size=3) 29 | self.dropout_1 = nn.Dropout(0.25) 30 | self.flatten = nn.Flatten() 31 | self.linear_1 = nn.Linear(9216, 128) 32 | self.dropout_2 = nn.Dropout(0.5) 33 | self.linear_2 = nn.Linear(128, 10 if only_digits else 62) 34 | self.relu = nn.ReLU() 35 | # self.softmax = nn.Softmax(dim=1) 36 | 37 | def forward(self, x): 38 | x = self.conv2d_1(x) 39 | x = self.relu(x) 40 | x = self.conv2d_2(x) 41 | x = self.relu(x) 42 | x = self.max_pooling(x) 43 | x = self.dropout_1(x) 44 | x = self.flatten(x) 45 | x = self.linear_1(x) 46 | x = self.relu(x) 47 | x = self.dropout_2(x) 48 | x = self.linear_2(x) 49 | # x = self.softmax(x) 50 | return x 51 | 52 | 53 | class CNN_MNIST(nn.Module): 54 | def __init__(self): 55 | super(CNN_MNIST, self).__init__() 56 | self.conv1 = nn.Conv2d(1, 32, kernel_size=(5, 5)) 57 | self.conv2 = nn.Conv2d(32, 64, kernel_size=(5, 5)) 58 | self.pool = nn.MaxPool2d(kernel_size=(2, 2)) 59 | self.fc1 = nn.Linear(in_features=1024, out_features=512) 60 | self.relu = nn.ReLU() 61 | self.fc2 = nn.Linear(512, 10) 62 | 63 | def forward(self, x): 64 | x = self.pool(self.conv1(x)) 65 | x = self.pool(self.conv2(x)) 66 | x = x.view(x.shape[0], -1) 67 | x = self.relu(self.fc1(x)) 68 | x = self.fc2(x) 69 | return x 70 | 71 | 72 | class CNN_CIFAR10(nn.Module): 73 | """from torch tutorial 74 | https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 75 | """ 76 | def __init__(self): 77 | super(CNN_CIFAR10,self).__init__() 78 | self.conv1 = nn.Conv2d(3, 6, 5) 79 | self.pool = nn.MaxPool2d(2, 2) 80 | self.conv2 = nn.Conv2d(6, 16, 5) 81 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 82 | self.fc2 = nn.Linear(120, 84) 83 | self.fc3 = nn.Linear(84, 10) 84 | 85 | def forward(self, x): 86 | x = self.pool(F.relu(self.conv1(x))) 87 | x = self.pool(F.relu(self.conv2(x))) 88 | x = torch.flatten(x, 1) # flatten all dimensions except batch 89 | x = F.relu(self.fc1(x)) 90 | x = F.relu(self.fc2(x)) 91 | x = self.fc3(x) 92 | return x 93 | 94 | class CNN_CIFAR100(nn.Module): 95 | """from torch tutorial 96 | https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 97 | """ 98 | def __init__(self): 99 | super(CNN_CIFAR100,self).__init__() 100 | self.conv1 = nn.Conv2d(3, 6, 5) 101 | self.pool = nn.MaxPool2d(2, 2) 102 | self.conv2 = nn.Conv2d(6, 16, 5) 103 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 104 | self.fc2 = nn.Linear(120, 84) 105 | self.fc3 = nn.Linear(84, 100) 106 | 107 | def forward(self, x): 108 | x = self.pool(F.relu(self.conv1(x))) 109 | x = self.pool(F.relu(self.conv2(x))) 110 | x = torch.flatten(x, 1) # flatten all dimensions except batch 111 | x = F.relu(self.fc1(x)) 112 | x = F.relu(self.fc2(x)) 113 | x = self.fc3(x) 114 | return x 115 | 116 | class AlexNet_CIFAR10(nn.Module): 117 | def __init__(self, num_classes=10): 118 | super(AlexNet_CIFAR10, self).__init__() 119 | self.features = nn.Sequential( 120 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 121 | nn.ReLU(inplace=True), 122 | nn.MaxPool2d(kernel_size=2), 123 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.MaxPool2d(kernel_size=2), 126 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 127 | nn.ReLU(inplace=True), 128 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 129 | nn.ReLU(inplace=True), 130 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 
131 |             nn.ReLU(inplace=True),
132 |             nn.MaxPool2d(kernel_size=2),
133 |         )
134 |         self.classifier = nn.Sequential(
135 |             #nn.Dropout(),
136 |             nn.Linear(256 * 2 * 2, 4096),
137 |             nn.ReLU(inplace=True),
138 |             #nn.Dropout(),
139 |             nn.Linear(4096, 4096),
140 |             nn.ReLU(inplace=True),
141 |             nn.Linear(4096, num_classes),
142 |         )
143 | 
144 |     def forward(self, x):
145 |         x = self.features(x)
146 |         x = x.view(x.size(0), 256 * 2 * 2)
147 |         x = self.classifier(x)
148 |         return x
--------------------------------------------------------------------------------
/fedlab/contrib/dataset/fcube.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import numpy as np
16 | import random
17 | import os
18 | 
19 | from torch.utils.data import Dataset
20 | 
21 | 
22 | class FCUBE(Dataset):
23 |     """FCUBE data set.
24 | 
25 |     From paper `Federated Learning on Non-IID Data Silos: An Experimental Study `_.
26 | 
27 |     Args:
28 |         root (str): Root for data file.
29 |         train (bool, optional): Training set or test set. Default as ``True``.
30 |         generate (bool, optional): Whether to generate the synthetic dataset. If ``True``, new synthetic FCUBE data are generated even if files already exist. Default as ``True``.
31 |         transform (callable, optional): A function/transform that takes in a ``numpy.ndarray`` and returns a transformed version.
32 |         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
33 |         num_samples (int, optional): Total number of samples to generate. We suggest 4000 for the training set and 1000 for the test set. Defaults to ``4000``.
34 |     """
35 |     train_files = {'data': "fcube_train_X", 'targets': "fcube_train_y"}
36 |     test_files = {'data': "fcube_test_X", 'targets': "fcube_test_y"}
37 |     num_clients = 4  # only for 4 clients
38 | 
39 |     def __init__(self, root, train=True, generate=True, transform=None, target_transform=None,
40 |                  num_samples=4000):
41 |         self.data = None
42 |         self.targets = None
43 |         self.train = train
44 |         self.generate = generate
45 |         self.num_samples = num_samples
46 |         self.transform = transform
47 |         self.target_transform = target_transform
48 | 
49 |         if not os.path.exists(root):
50 |             os.makedirs(root)
51 | 
52 |         files = self.train_files if train else self.test_files
53 |         files = {key: files[key] + f"_{num_samples}.npy" for key in files}
54 | 
55 |         full_file_paths = {key: os.path.join(root, files[key]) for key in files}
56 |         self.full_file_paths = full_file_paths
57 | 
58 |         if not generate:
59 |             if os.path.exists(full_file_paths['data']) and os.path.exists(
60 |                     full_file_paths['targets']):
61 |                 print(
62 |                     f"FCUBE data already generated. Load from {full_file_paths['data']} "
63 |                     f"and {full_file_paths['targets']}...")
64 |                 self.data = np.load(full_file_paths['data'])
65 |                 self.targets = np.load(full_file_paths['targets'])
66 |                 print("FCUBE data loaded from local file.")
67 |             else:
68 |                 raise RuntimeError(
69 |                     "FCUBE data file not found. You can use generate=True to generate it.")
70 |         else:
71 |             # Generate file by force
72 |             print("Generate FCUBE data now...")
73 |             if train:
74 |                 self._generate_train()
75 |             else:
76 |                 self._generate_test()
77 | 
78 |             self._save_data()  # save to local npy files
79 | 
80 |     def _generate_train(self):
81 |         X_train, y_train = [], []
82 |         for loc in range(4):  # one region per client: loc sets the signs of (p2, p3)
83 |             for i in range(int(self.num_samples / 4)):
84 |                 p1 = random.random()
85 |                 p2 = random.random()
86 |                 p3 = random.random()
87 |                 if loc > 1:
88 |                     p2 = -p2
89 |                 if loc % 2 == 1:
90 |                     p3 = -p3
91 |                 if i % 2 == 0:
92 |                     X_train.append([p1, p2, p3])
93 |                     y_train.append(0)
94 |                 else:
95 |                     X_train.append([-p1, -p2, -p3])
96 |                     y_train.append(1)
97 | 
98 |         self.data = np.array(X_train, dtype=np.float32)
99 |         self.targets = np.array(y_train, dtype=np.int64)  # int64 to match the test set and torch loss functions
100 | 
101 |     def _generate_test(self):
102 |         X_test, y_test = [], []
103 |         for i in range(self.num_samples):
104 |             p1 = random.random() * 2 - 1
105 |             p2 = random.random() * 2 - 1
106 |             p3 = random.random() * 2 - 1
107 |             X_test.append([p1, p2, p3])
108 |             if p1 > 0:
109 |                 y_test.append(0)
110 |             else:
111 |                 y_test.append(1)
112 | 
113 |         self.data = np.array(X_test, dtype=np.float32)
114 |         self.targets = np.array(y_test, dtype=np.int64)
115 | 
116 |     def _save_data(self):
117 |         np.save(self.full_file_paths['data'], self.data)
118 |         print(f"{self.full_file_paths['data']} generated.")
119 |         np.save(self.full_file_paths['targets'], self.targets)
120 |         print(f"{self.full_file_paths['targets']} generated.")
121 | 
122 |     def __len__(self):
123 |         return self.num_samples
124 | 
125 |     def __getitem__(self, index):
126 |         """
127 |         Args:
128 |             index (int): Index
129 | 
130 |         Returns:
131 |             tuple: (features, target) where target is index of the target class.
132 |         """
133 |         X = self.data[index]
134 |         y = self.targets[index]
135 | 
136 |         if self.transform is not None:
137 |             X = self.transform(X)
138 | 
139 |         if self.target_transform is not None:
140 |             y = self.target_transform(y)
141 | 
142 |         return X, y
143 | 
--------------------------------------------------------------------------------
/fedlab/board/utils/io.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import json
3 | import os
4 | import pickle
5 | import shutil
6 | from os import path
7 | from typing import Any
8 | 
9 | from fedlab.board.utils.roles import is_client_holder, is_server
10 | 
11 | 
12 | def _update_meta_file(file_root: str, section: str, dct: dict):
13 |     config_file = path.join(file_root, 'experiment.ini')
14 |     os.makedirs(file_root, exist_ok=True)
15 |     if not os.path.exists(config_file):
16 |         with open(config_file, 'w') as file:
17 |             file.write('')
18 |     config = configparser.ConfigParser()
19 |     config.read(config_file)
20 |     if not config.has_section(section):
21 |         config.add_section(section)
22 |     for key, value in dct.items():
23 |         config.set(section, key, str(value))
24 |     with open(config_file, 'w') as configfile:
25 |         config.write(configfile)
26 | 
27 | 
28 | def register_client(dir: str, role_id: str, client_ids: list[str]):
29 |     register_role(dir, role_id)
30 |     fn = path.join(dir, f'roles/{role_id}/clients')
31 |     json.dump(client_ids, open(fn, 'w+'))
32 | 
33 | 
34 | def register_role(dir: str, role_id: str):
35 |     os.makedirs(path.join(dir, f'roles/{role_id}'), exist_ok=True)
36 | 
37 | 
38 | def get_client_ids(dir: str) -> dict[str, list]:
39 |     pt = path.join(dir, 'roles/')
40 |     res = {}
41 |     for role_id in os.listdir(pt):
42 |         roles = int(role_id.split('-')[0])
43 |         if not is_client_holder(roles):
44 |             continue
45 |         fn = path.join(pt, role_id, 'clients')
46 |         res[role_id] = json.load(open(fn))
47 |     return res
48 | 
49 | 
50 | def get_roles_tree(dir: str) -> dict[str, list]:
51 |     pt = path.join(dir, 'roles/')
52 |     res: dict[str, list] = {}
53 |     for role_id in os.listdir(pt):
54 |         roles = int(role_id.split('-')[0])
55 |         res.setdefault(role_id, [])
56 |         if is_server(roles):
57 |             res[role_id].append({'role': 'server'})
58 |         fn = path.join(pt, role_id, 'clients')
59 |         if os.path.exists(fn):
60 |             client_ids = json.load(open(fn))
61 |             res[role_id].append({'role': 'client_holder', 'client_ids': client_ids})
62 |     return res
63 | 
64 | 
65 | def get_server_role_ids(dir: str) -> list[str]:
66 |     pt = path.join(dir, 'roles/')
67 |     res = []
68 |     for role_id in os.listdir(pt):
69 |         roles = int(role_id.split('-')[0])
70 |         if is_server(roles):
71 |             res.append(role_id)
72 |     return res
73 | 
74 | 
75 | def get_role_ids(dir: str) -> list[str]:
76 |     pt = path.join(dir, 'roles/')
77 |     res = []
78 |     for role_id in os.listdir(pt):
79 |         res.append(role_id)
80 |     return res
81 | 
82 | 
83 | def clear_log(dir):
84 |     shutil.rmtree(path.join(dir, 'log/'), ignore_errors=True)
85 |     shutil.rmtree(path.join(dir, 'cache/'), ignore_errors=True)
86 | 
87 | 
88 | def clear_roles(dir):
89 |     shutil.rmtree(path.join(dir, 'roles/'), ignore_errors=True)
90 | 
91 | 
92 | def _read_meta_file(file_root: str, section: str, keys):
93 |     config_file = path.join(file_root, 'experiment.ini')
94 |     if not os.path.isfile(config_file):
95 |         return None
96 |     config = configparser.ConfigParser()
97 |     config.read(config_file)
98 |     if not config.has_section(section):
99 |         return None
100 |     res = {key: config.get(section, key, fallback=None) for key in keys}
101 |     return res
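# Note (added for clarity): the read/write helpers in this module assume the
# following on-disk layout under `file_root`, inferred from the path handling
# in the functions above and below:
#
#   experiment.ini                                      meta info, one section per category
#   roles/{role_id}/clients                             JSON list of client ids held by a role
#   roles/{role_id}/log/{type}/[{sub_type}/]{name}.pkl  pickled log objects
#   roles/{role_id}/log/{section}/{name}.log            appended lines of the form "round==<json>"
#   cache/{type}/{sub_type}/{name}.pkl                  pickled cache entries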
102 | 
103 | 
104 | def _log_to_fs(file_root: str, role_id: str, type: str, name: str, obj: Any, sub_type: str = None):
105 |     pt = path.join(file_root, f'roles/{role_id}/log/{type}/')
106 |     if sub_type:
107 |         pt = path.join(pt, f'{sub_type}/')
108 |     os.makedirs(pt, exist_ok=True)
109 |     pickle.dump(obj, open(path.join(pt, f'{name}.pkl'), 'wb+'))
110 | 
111 | 
112 | def _log_to_role_fs_append(file_root: str, role_id: str, section: str, name: str, round: int, obj: Any):
113 |     pt = path.join(file_root, f'roles/{role_id}/log/{section}')
114 |     os.makedirs(pt, exist_ok=True)
115 |     with open(path.join(pt, f'{name}.log'), 'a+') as f:
116 |         f.write(f'{round}==' + json.dumps(obj) + '\n')
117 | 
118 | 
119 | def _read_log_from_role_fs_appended(file_root: str, role_id: str, section: str, name: str):
120 |     target = path.join(file_root, f'roles/{role_id}/log/{section}/{name}.log')
121 |     if not os.path.exists(target):
122 |         return []
123 |     return open(target).readlines()
124 | 
125 | 
126 | # def _log_to_fs_append(file_root: str, type: str, name: str, obj: Any, sub_type: str = None):
127 | #     pt = path.join(file_root, f'log/{type}/')
128 | #     if sub_type:
129 | #         pt = path.join(pt, f'{sub_type}/')
130 | #     os.makedirs(pt, exist_ok=True)
131 | #     with open(path.join(pt, f'{name}.log'), 'a+') as f:
132 | #         f.write(json.dumps(obj) + '\n')
133 | 
134 | 
135 | def _read_log_from_fs(file_root: str, role_id, type: str, name: str, sub_type: str = None):
136 |     target = path.join(file_root, f'roles/{role_id}/log/{type}/')
137 |     if sub_type:
138 |         target = path.join(target, f'{sub_type}/')
139 |     target = path.join(target, f'{name}.pkl')
140 |     try:
141 |         return pickle.load(open(target, 'rb'))
142 |     except Exception:
143 |         return None
144 | 
145 | 
146 | def _read_log_from_fs_appended(file_root: str, type: str, name: str, sub_type: str = None):
147 |     target = path.join(file_root, f'log/{type}/')
148 |     if sub_type:
149 |         target = path.join(target, f'{sub_type}/')
150 |     target = path.join(target, f'{name}.log')
151 |     if not os.path.exists(target):
152 |         return []
153 |     return open(target).readlines()
154 | 
155 | 
156 | def _read_cached_from_fs(file_root: str, type: str, sub_type: str, name: str):
157 |     target = path.join(file_root, f'cache/{type}/{sub_type}/{name}.pkl')
158 |     try:
159 |         return pickle.load(open(target, 'rb'))
160 |     except Exception:
161 |         return None
162 | 
163 | 
164 | def _cache_to_fs(obj, file_root: str, type: str, sub_type: str, name: str):
165 |     os.makedirs(path.join(file_root, f'cache/{type}/{sub_type}/'), exist_ok=True)
166 |     target = path.join(file_root, f'cache/{type}/{sub_type}/{name}.pkl')
167 |     pickle.dump(obj, open(target, 'wb+'))
168 | 
--------------------------------------------------------------------------------
/fedlab/core/client/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Peng Cheng Laboratory (http://www.szpclab.com/) and FedLab Authors (smilelab.group)
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from abc import abstractmethod
16 | from random import randint
17 | from typing import List
18 | 
19 | import torch
20 | 
21 | from fedlab.contrib.dataset.basic_dataset import FedDataset
22 | 
23 | from ..client import ORDINARY_TRAINER, SERIAL_TRAINER
24 | from ..model_maintainer import ModelMaintainer, SerialModelMaintainer
25 | from ...utils import Logger, SerializationTool
26 | 
27 | 
28 | class ClientTrainer(ModelMaintainer):
29 |     """An abstract class representing a client trainer.
30 | 
31 |     In FedLab, the client trainer is the backend that manages the local model.
32 |     It should have a function to update its model called :meth:`local_process`.
33 | 
34 |     If you use our framework to define the activities of a client, please make sure that your
35 |     self-defined class subclasses it. All subclasses should overwrite :meth:`local_process` and property ``uplink_package``.
36 | 
37 |     Args:
38 |         model (torch.nn.Module): PyTorch model.
39 |         cuda (bool): Use GPUs or not.
40 |         device (str, optional): Assign model/data to the given GPU, e.g., 'cuda:0'. Defaults to ``None``.
41 |     """
42 | 
43 |     def __init__(self,
44 |                  model: torch.nn.Module,
45 |                  cuda: bool,
46 |                  device: str = None) -> None:
47 |         super().__init__(model, cuda, device)
48 | 
49 |         self.num_clients = 1  # default is 1.
50 |         self.dataset = FedDataset()  # or Dataset
51 |         self.type = ORDINARY_TRAINER
52 | 
53 |     def setup_dataset(self):
54 |         """Set up local dataset ``self.dataset`` for clients."""
55 |         raise NotImplementedError()
56 | 
57 |     def setup_optim(self):
58 |         """Set up variables for optimization algorithms."""
59 |         raise NotImplementedError()
60 | 
61 |     @property
62 |     @abstractmethod
63 |     def uplink_package(self) -> List[torch.Tensor]:
64 |         """Return a tensor list for uploading to server.
65 | 
66 |         This attribute will be called by client manager.
67 |         Customize it for new algorithms.
68 |         """
69 |         raise NotImplementedError()
70 | 
71 |     @abstractmethod
72 |     def local_process(self, payload: List[torch.Tensor]):
73 |         """The manager of the upper layer calls this function with the received payload.
74 | 
75 |         In synchronous mode, return True to end the current FL round.
76 |         """
77 |         raise NotImplementedError()
78 | 
79 |     def train(self):
80 |         """Override this method to define the training procedure. This function should manipulate :attr:`self._model`."""
81 |         raise NotImplementedError()
82 | 
83 |     def validate(self):
84 |         """Validate quality of local model."""
85 |         raise NotImplementedError()
86 | 
87 |     def evaluate(self):
88 |         """Evaluate quality of local model."""
89 |         raise NotImplementedError()
90 | 
91 | 
92 | class SerialClientTrainer(SerialModelMaintainer):
93 |     """Base class. Simulate multiple clients in sequence in a single process.
94 | 
95 |     Args:
96 |         model (torch.nn.Module): Model used in this federation.
97 |         num_clients (int): Number of clients in current trainer.
98 |         cuda (bool): Use GPUs or not.
99 |         device (str, optional): Assign model/data to the given GPU, e.g., 'cuda:0'. Defaults to None.
100 |         personal (bool, optional): If True is passed, SerialModelMaintainer keeps a separate copy of the model parameters for each client, indexed by [0, num_clients-1]. Defaults to False.
101 |     """
102 | 
103 |     def __init__(self,
104 |                  model: torch.nn.Module,
105 |                  num_clients: int,
106 |                  cuda: bool,
107 |                  device: str = None,
108 |                  personal: bool = False) -> None:
109 |         super().__init__(model, num_clients, cuda, device, personal)
110 | 
111 |         self.num_clients = num_clients
112 |         self.dataset = FedDataset()
113 |         self.type = SERIAL_TRAINER  # represent serial trainer
114 | 
115 |     def setup_dataset(self):
116 |         """Override this function to set up local dataset for clients"""
117 |         raise NotImplementedError()
118 | 
119 |     def setup_optim(self):
120 |         """Override this function to set up variables for optimization algorithms."""
121 |         raise NotImplementedError()
122 | 
123 |     @property
124 |     @abstractmethod
125 |     def uplink_package(self) -> List[List[torch.Tensor]]:
126 |         """Return a tensor list for uploading to server.
127 | 
128 |         This attribute will be called by client manager.
129 |         Customize it for new algorithms.
130 |         """
131 |         raise NotImplementedError()
132 | 
133 |     @abstractmethod
134 |     def local_process(self, id_list: list, payload: List[torch.Tensor]):
135 |         """Define the local main process.
136 | 
137 |         Args:
138 |             id_list (list): A list of client ids.
139 |             payload (List[torch.Tensor]): The information that the server broadcasts to clients.
140 |         """
141 |         raise NotImplementedError()
142 | 
143 |     def train(self):
144 |         """Override this method to define the algorithm of training your model. This function should manipulate :attr:`self._model`"""
145 |         raise NotImplementedError()
146 | 
147 |     def evaluate(self):
148 |         """Evaluate quality of local model."""
149 |         raise NotImplementedError()
150 | 
151 |     def validate(self):
152 |         """Validate quality of local model."""
153 |         raise NotImplementedError()
154 | 
--------------------------------------------------------------------------------
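The pieces above can be wired together in FedLab's standalone mode. The sketch below is illustrative only and not part of the repository: the handler constructor is assumed to follow FedLab's SyncServerHandler (model, global_round, num_clients, sample_ratio), setup_dataset()/setup_optim(epochs, batch_size, lr) are assumed from FedLab's SGDSerialClientTrainer, and `partitioned_dataset` is a placeholder for any FedDataset implementation from fedlab/contrib/dataset. Only setup_optim(g_rho, T_D, W), rho, downlink_package, and local_process(payload, id_list) are taken directly from fedgf.py above.

import torch
from fedlab.models.cnn import CNN_CIFAR10
from fedlab.contrib.algorithm.fedgf import FedGfServerHandler, FedGfSerialClientTrainer

model = CNN_CIFAR10()

# Server: g_rho is the global perturbation radius, T_D the divergence threshold,
# W the window length over which the mixing weight c is averaged.
handler = FedGfServerHandler(model, global_round=100, num_clients=100,
                             sample_ratio=0.1)  # signature assumed from SyncServerHandler
handler.setup_optim(g_rho=0.1, T_D=1.0, W=5)

# Clients: rho is the local SAM perturbation radius.
trainer = FedGfSerialClientTrainer(model, num_clients=100, rho=0.05,
                                   cuda=torch.cuda.is_available())
trainer.setup_dataset(partitioned_dataset)             # placeholder: any FedDataset subclass
trainer.setup_optim(epochs=5, batch_size=64, lr=0.01)  # assumed from SGDSerialClientTrainer

# Same loop as fedlab/core/standalone.py: broadcast, local training, aggregation.
while not handler.if_stop:
    sampled_clients = handler.sample_clients()
    trainer.local_process(handler.downlink_package, sampled_clients)
    for pack in trainer.uplink_package:
        handler.load(pack)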