├── conf ├── __init__.py ├── global.yaml ├── fed_avg │ ├── mnist.yaml │ ├── cifar10.yaml │ ├── cifar100.yaml │ ├── imagenet.yaml │ └── imdb.yaml ├── fed_avg_qat │ ├── mnist.yaml │ └── cifar10.yaml ├── gtg_sv │ ├── mnist.yaml │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml ├── multiround_sv │ ├── cifar100.yaml │ └── cifar10.yaml ├── qsgd │ └── mnist.yaml ├── sign_sgd │ ├── cifar100.yaml │ ├── cifar10.yaml │ └── imdb.yaml ├── fed_paq │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml ├── fed_dropout_avg │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml ├── smafd │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml ├── fed_obd │ ├── cifar100_sq.yaml │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml └── fed_obd_sq │ └── cifar100.yaml ├── gtg_shapley_train.sh ├── __init__.py ├── exp_analyzer.py ├── method ├── fed_obd │ ├── phase.py │ ├── __init__.py │ ├── server.py │ ├── worker.py │ └── obd_algorithm.py ├── shapley_value │ ├── shapley_value_server.py │ ├── GTG_shapley_value_server.py │ ├── GTG_shapley_value_algorithm.py │ ├── multiround_shapley_value_server.py │ ├── multiround_shapley_value_algorithm.py │ └── __init__.py ├── fed_avg_qat │ ├── __init__.py │ └── worker.py ├── sign_sgd │ ├── __init__.py │ ├── algorithm.py │ └── worker.py ├── __init__.py ├── fed_dropout_avg │ ├── __init__.py │ ├── algorithm.py │ └── worker.py ├── qsgd │ └── __init__.py └── fed_paq │ └── __init__.py ├── analysis ├── __init__.py ├── analyze_round.py └── analyze_log.py ├── fed_obd_train.sh ├── algorithm ├── __init__.py ├── block_algorithm.py └── shapley_value_algorithm.py ├── exp_analyzer.sh ├── fed_aas.sh ├── other_method_test.sh ├── test.sh ├── simulator.py ├── pyproject.toml ├── .gitignore └── README.md /conf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gtg_shapley_train.sh: -------------------------------------------------------------------------------- 1 | python3 ./simulator.py --config-name gtg_sv/mnist.yaml 2 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__))) 5 | -------------------------------------------------------------------------------- /exp_analyzer.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation.analysis.document import dump_analysis 2 | 3 | if __name__ == "__main__": 4 | dump_analysis() 5 | -------------------------------------------------------------------------------- /conf/global.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | cache_transforms: cpu 3 | log_level: INFO 4 | save_performance_metric: false 5 | use_amp: false 6 | use_slow_performance_metrics: false 7 | -------------------------------------------------------------------------------- /method/fed_obd/phase.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum, auto 2 | 3 | 4 | class Phase(StrEnum): 5 | STAGE_ONE = auto() 6 | STAGE_TWO = auto() 7 | END = auto() 8 | -------------------------------------------------------------------------------- /analysis/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 
| 4 | currentdir = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | sys.path.insert(0, os.path.join(currentdir, "..")) 7 | -------------------------------------------------------------------------------- /fed_obd_train.sh: -------------------------------------------------------------------------------- 1 | python3 ./simulator.py --config-name fed_obd/cifar10.yaml 2 | python3 ./simulator.py --config-name fed_obd/cifar100.yaml 3 | python3 ./simulator.py --config-name fed_obd/imdb.yaml 4 | -------------------------------------------------------------------------------- /algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | from .block_algorithm import BlockAlgorithmMixin 2 | from .shapley_value_algorithm import ShapleyValueAlgorithm 3 | 4 | __all__ = ["BlockAlgorithmMixin", "ShapleyValueAlgorithm"] 5 | -------------------------------------------------------------------------------- /conf/fed_avg/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: MNIST 3 | model_name: LeNet5 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 20 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 2 11 | learning_rate: 0.01 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg/imagenet.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: ImageNet 3 | model_name: Resnet50 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 128 11 | learning_rate: 0.01 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg_qat/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: MNIST 3 | model_name: LeNet5 4 | distributed_algorithm: fed_avg_qat 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 20 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 2 11 | learning_rate: 0.01 12 | ... 
13 | -------------------------------------------------------------------------------- /conf/gtg_sv/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: MNIST 3 | model_name: LeNet5 4 | distributed_algorithm: GTG_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 20 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 2 11 | learning_rate: 0.01 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg_qat/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_avg_qat 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /method/shapley_value/shapley_value_server.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation import AggregationServer 2 | 3 | 4 | class ShapleyValueServer(AggregationServer): 5 | def __init__(self, *args, **kwargs) -> None: 6 | super().__init__(*args, **kwargs) 7 | self._need_init_performance = True 8 | -------------------------------------------------------------------------------- /conf/gtg_sv/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: GTG_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/gtg_sv/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: GTG_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/multiround_sv/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: multiround_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 
13 | -------------------------------------------------------------------------------- /conf/qsgd/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: MNIST 3 | model_name: lenet5 4 | distributed_algorithm: QSGD 5 | worker_number: 2 6 | round: 1 7 | learning_rate_scheduler_name: CosineAnnealingLR 8 | epoch: 10 9 | batch_size: 64 10 | learning_rate: 0.001 11 | algorithm_kwargs: 12 | distribute_init_parameters: false 13 | -------------------------------------------------------------------------------- /conf/multiround_sv/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: multiround_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /exp_analyzer.sh: -------------------------------------------------------------------------------- 1 | for session_path in $(find session -name server); do 2 | base_dir=$(dirname $session_path) 3 | if ! test -d ".real_${base_dir}"; then 4 | mkdir -p ".real_${base_dir}" 5 | cp -r ${base_dir} $(dirname ".real_${base_dir}") 6 | fi 7 | env session_path=".real_${session_path}" python3 exp_analyzer.py 8 | done 9 | -------------------------------------------------------------------------------- /conf/sign_sgd/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: sign_SGD 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 1 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 100 10 | learning_rate: 0.1 11 | algorithm_kwargs: 12 | distribute_init_parameters: false 13 | -------------------------------------------------------------------------------- /conf/fed_paq/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_paq 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | batch_size: 64 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | ... 15 | -------------------------------------------------------------------------------- /conf/fed_paq/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_paq 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | ... 
15 | -------------------------------------------------------------------------------- /conf/sign_sgd/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: sign_SGD 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 1 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 100 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | distribute_init_parameters: false 14 | -------------------------------------------------------------------------------- /method/shapley_value/GTG_shapley_value_server.py: -------------------------------------------------------------------------------- 1 | from .GTG_shapley_value_algorithm import GTGShapleyValueAlgorithm 2 | from .shapley_value_server import ShapleyValueServer 3 | 4 | 5 | class GTGShapleyValueServer(ShapleyValueServer): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(**kwargs, algorithm=GTGShapleyValueAlgorithm(server=self)) 8 | -------------------------------------------------------------------------------- /conf/fed_dropout_avg/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_dropout_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | dropout_rate: 0.3 14 | random_client_number: 5 15 | -------------------------------------------------------------------------------- /conf/smafd/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: single_model_afd 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | batch_size: 64 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | dropout_rate: 0.3 15 | ... 16 | -------------------------------------------------------------------------------- /conf/smafd/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: single_model_afd 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | dropout_rate: 0.3 15 | ... 
16 | -------------------------------------------------------------------------------- /method/shapley_value/GTG_shapley_value_algorithm.py: -------------------------------------------------------------------------------- 1 | from cyy_torch_algorithm.shapely_value.gtg_shapley_value import GTGShapleyValue 2 | 3 | from ..algorithm.shapley_value_algorithm import ShapleyValueAlgorithm 4 | 5 | 6 | class GTGShapleyValueAlgorithm(ShapleyValueAlgorithm): 7 | def __init__(self, *args, **kwargs) -> None: 8 | super().__init__(GTGShapleyValue, *args, **kwargs) 9 | -------------------------------------------------------------------------------- /conf/fed_dropout_avg/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_dropout_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | dropout_rate: 0.3 14 | random_client_number: 5 15 | ... 16 | -------------------------------------------------------------------------------- /method/fed_avg_qat/__init__.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation import ( 2 | AggregationServer, 3 | AlgorithmRepository, 4 | FedAVGAlgorithm, 5 | ) 6 | 7 | from .worker import QATWorker 8 | 9 | AlgorithmRepository.register_algorithm( 10 | algorithm_name="fed_avg_qat", 11 | client_cls=QATWorker, 12 | server_cls=AggregationServer, 13 | algorithm_cls=FedAVGAlgorithm, 14 | ) 15 | -------------------------------------------------------------------------------- /conf/fed_obd/cifar100_sq.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_obd_sq 5 | optimizer_name: SGD 6 | worker_number: 50 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | second_phase_epoch: 10 14 | dropout_rate: 0.3 15 | random_client_number: 25 16 | ... 17 | -------------------------------------------------------------------------------- /conf/fed_obd_sq/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_obd_sq 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | second_phase_epoch: 10 14 | dropout_rate: 0.9 15 | random_client_number: 5 16 | ... 
17 | -------------------------------------------------------------------------------- /method/shapley_value/multiround_shapley_value_server.py: -------------------------------------------------------------------------------- 1 | from .multiround_shapley_value_algorithm import MultiRoundShapleyValueAlgorithm 2 | from .shapley_value_server import ShapleyValueServer 3 | 4 | 5 | class MultiRoundShapleyValueServer(ShapleyValueServer): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | **kwargs, algorithm=MultiRoundShapleyValueAlgorithm(server=self) 9 | ) 10 | -------------------------------------------------------------------------------- /method/shapley_value/multiround_shapley_value_algorithm.py: -------------------------------------------------------------------------------- 1 | from cyy_torch_algorithm.shapely_value.multiround_shapley_value import ( 2 | MultiRoundShapleyValue, 3 | ) 4 | 5 | from ...algorithm.shapley_value_algorithm import ShapleyValueAlgorithm 6 | 7 | 8 | class MultiRoundShapleyValueAlgorithm(ShapleyValueAlgorithm): 9 | def __init__(self, *args, **kwargs) -> None: 10 | super().__init__(MultiRoundShapleyValue, *args, **kwargs) 11 | -------------------------------------------------------------------------------- /fed_aas.sh: -------------------------------------------------------------------------------- 1 | # fedaas 2 | for configname in cs.yaml PubMed.yaml reddit.yaml yelp.yaml; do 3 | python3 ./simulator.py --config-name fed_aas/${configname} ++fed_aas.worker_number=50 ++fed_aas.round=50 ++fed_aas.model_name=TwoGATCN ++fed_aas.epoch=1 ++fed_aas.dataloader_kwargs.batch_number=10 ++fed_aas.learning_rate=0.001 ++fed_aas.algorithm_kwargs.edge_drop_rate=0.99 ++fed_aas.weight_decay=0.001 ++fed_aas.exp_name="fed_aas" ++fed_aas.algorithm_kwargs.min_sharing_interval=2 4 | done 5 | -------------------------------------------------------------------------------- /other_method_test.sh: -------------------------------------------------------------------------------- 1 | # Fed dropout avg 2 | python3 ./simulator.py --config-name fed_dropout_avg/cifar100.yaml ++fed_dropout_avg.round=1 ++fed_dropout_avg.epoch=1 ++fed_dropout_avg.worker_number=2 ++fed_dropout_avg.debug=True ++fed_dropout_avg.algorithm_kwargs.random_client_number=2 3 | 4 | python3 ./simulator.py --config-name fed_paq/cifar100.yaml ++fed_paq.round=1 ++fed_paq.epoch=1 ++fed_paq.worker_number=2 ++fed_paq.debug=True ++fed_paq.algorithm_kwargs.random_client_number=2 5 | -------------------------------------------------------------------------------- /conf/fed_obd/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_obd 5 | optimizer_name: SGD 6 | worker_number: 100 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | learning_rate: 0.1 11 | batch_size: 64 12 | endpoint_kwargs: 13 | server: 14 | weight: 0.001 15 | worker: 16 | weight: 0.001 17 | algorithm_kwargs: 18 | second_phase_epoch: 10 19 | dropout_rate: 0.3 20 | random_client_number: 50 21 | ... 
22 | -------------------------------------------------------------------------------- /conf/fed_obd/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_obd 5 | optimizer_name: SGD 6 | worker_number: 50 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | endpoint_kwargs: 13 | server: 14 | weight: 0.001 15 | worker: 16 | weight: 0.001 17 | algorithm_kwargs: 18 | second_phase_epoch: 10 19 | dropout_rate: 0.3 20 | random_client_number: 25 21 | ... 22 | -------------------------------------------------------------------------------- /method/sign_sgd/__init__.py: -------------------------------------------------------------------------------- 1 | """signSGD: Compressed Optimisation for Non-Convex Problems https://arxiv.org/abs/1802.04434""" 2 | 3 | from distributed_learning_simulation import ( 4 | AggregationServer, 5 | AlgorithmRepository, 6 | ) 7 | 8 | from .algorithm import SignSGDAlgorithm 9 | from .worker import SignSGDWorker 10 | 11 | AlgorithmRepository.register_algorithm( 12 | algorithm_name="sign_SGD", 13 | client_cls=SignSGDWorker, 14 | server_cls=AggregationServer, 15 | algorithm_cls=SignSGDAlgorithm, 16 | ) 17 | -------------------------------------------------------------------------------- /method/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | for entry in os.scandir(os.path.dirname(os.path.abspath(__file__))): 5 | if not entry.is_dir(): 6 | continue 7 | if entry.name == "__pycache__": 8 | continue 9 | if entry.name.startswith("."): 10 | continue 11 | try: 12 | importlib.import_module(f".{entry.name}", "method") 13 | except ModuleNotFoundError: 14 | importlib.import_module( 15 | f".{entry.name}", "distributed_learning_simulator.method" 16 | ) 17 | -------------------------------------------------------------------------------- /method/sign_sgd/algorithm.py: -------------------------------------------------------------------------------- 1 | """signSGD: Compressed Optimisation for Non-Convex Problems https://arxiv.org/abs/1802.04434""" 2 | 3 | from distributed_learning_simulation import FedAVGAlgorithm, ParameterMessage 4 | 5 | 6 | class SignSGDAlgorithm(FedAVGAlgorithm): 7 | def aggregate_worker_data(self) -> ParameterMessage: 8 | message = super().aggregate_worker_data() 9 | assert isinstance(message, ParameterMessage) 10 | message.parameter = {k: v.sign() for k, v in message.parameter.items()} 11 | return message 12 | -------------------------------------------------------------------------------- /conf/fed_avg/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.01 12 | dataset_kwargs: 13 | input_max_len: 300 14 | dataset_type: text 15 | tokenizer: 16 | type: spacy 17 | model_kwargs: 18 | word_vector_name: glove.6B.100d 19 | num_encoder_layer: 2 20 | d_model: 100 21 | nhead: 5 22 | frozen_modules: 23 | names: [embedding] 24 | -------------------------------------------------------------------------------- /conf/gtg_sv/imdb.yaml: 
-------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: GTG_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.01 12 | dataset_kwargs: 13 | input_max_len: 300 14 | dataset_type: text 15 | tokenizer: 16 | type: spacy 17 | model_kwargs: 18 | word_vector_name: glove.6B.100d 19 | num_encoder_layer: 2 20 | d_model: 100 21 | nhead: 5 22 | frozen_modules: 23 | names: [embedding] 24 | -------------------------------------------------------------------------------- /method/fed_dropout_avg/__init__.py: -------------------------------------------------------------------------------- 1 | """FedDropoutAvg: Generalizable federated learning for histopathology image classification (https://arxiv.org/pdf/2111.13230.pdf)""" 2 | 3 | from distributed_learning_simulation import ( 4 | AggregationServer, 5 | AlgorithmRepository, 6 | ) 7 | 8 | from .algorithm import FedDropoutAvgAlgorithm 9 | from .worker import FedDropoutAvgWorker 10 | 11 | AlgorithmRepository.register_algorithm( 12 | algorithm_name="fed_dropout_avg", 13 | client_cls=FedDropoutAvgWorker, 14 | server_cls=AggregationServer, 15 | algorithm_cls=FedDropoutAvgAlgorithm, 16 | ) 17 | -------------------------------------------------------------------------------- /conf/fed_paq/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_paq 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.01 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | dataset_kwargs: 15 | input_max_len: 300 16 | dataset_type: text 17 | tokenizer: 18 | type: spacy 19 | model_kwargs: 20 | word_vector_name: glove.6B.100d 21 | num_encoder_layer: 2 22 | d_model: 100 23 | nhead: 5 24 | frozen_modules: 25 | names: [embedding] 26 | -------------------------------------------------------------------------------- /conf/sign_sgd/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: sign_SGD 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 1 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 100 10 | batch_size: 64 11 | learning_rate: 0.01 12 | dataset_kwargs: 13 | input_max_len: 300 14 | dataset_type: text 15 | tokenizer: 16 | type: spacy 17 | model_kwargs: 18 | word_vector_name: glove.6B.100d 19 | num_encoder_layer: 2 20 | d_model: 100 21 | nhead: 5 22 | frozen_modules: 23 | names: [embedding] 24 | algorithm_kwargs: 25 | distribute_init_parameters: false 26 | -------------------------------------------------------------------------------- /method/shapley_value/__init__.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation import ( 2 | AggregationWorker, 3 | AlgorithmRepository, 4 | ) 5 | 6 | from .GTG_shapley_value_server import GTGShapleyValueServer 7 | from .multiround_shapley_value_server import MultiRoundShapleyValueServer 8 | 9 | AlgorithmRepository.register_algorithm( 10 | algorithm_name="multiround_shapley_value", 11 | 
client_cls=AggregationWorker, 12 | server_cls=MultiRoundShapleyValueServer, 13 | ) 14 | AlgorithmRepository.register_algorithm( 15 | algorithm_name="GTG_shapley_value", 16 | client_cls=AggregationWorker, 17 | server_cls=GTGShapleyValueServer, 18 | ) 19 | -------------------------------------------------------------------------------- /conf/smafd/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: single_model_afd 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.01 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | dropout_rate: 0.3 15 | dataset_kwargs: 16 | input_max_len: 300 17 | dataset_type: text 18 | tokenizer: 19 | type: spacy 20 | model_kwargs: 21 | word_vector_name: glove.6B.100d 22 | num_encoder_layer: 2 23 | d_model: 100 24 | nhead: 5 25 | frozen_modules: 26 | names: [embedding] 27 | -------------------------------------------------------------------------------- /conf/fed_dropout_avg/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_dropout_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.01 12 | algorithm_kwargs: 13 | dropout_rate: 0.3 14 | random_client_number: 5 15 | dataset_kwargs: 16 | input_max_len: 300 17 | dataset_type: text 18 | tokenizer: 19 | type: spacy 20 | model_kwargs: 21 | word_vector_name: glove.6B.100d 22 | num_encoder_layer: 2 23 | d_model: 100 24 | nhead: 5 25 | frozen_modules: 26 | names: [embedding] 27 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # CV 2 | python3 ./simulator.py --config-name fed_avg/mnist.yaml ++fed_avg.round=2 ++fed_avg.epoch=1 ++fed_avg.worker_number=10 ++fed_avg.debug=True 3 | # NLP 4 | python3 ./simulator.py --config-name fed_avg/imdb.yaml ++fed_avg.round=1 ++fed_avg.epoch=1 ++fed_avg.worker_number=2 ++fed_avg.debug=True 5 | # GTG 6 | python3 ./simulator.py --config-name gtg_sv/mnist.yaml ++gtg_sv.round=1 ++gtg_sv.epoch=1 ++gtg_sv.worker_number=2 ++gtg_sv.debug=False 7 | # OBD 8 | python3 ./simulator.py --config-name fed_obd/cifar10.yaml ++fed_obd.round=2 ++fed_obd.epoch=1 ++fed_obd.worker_number=10 ++fed_obd.algorithm_kwargs.random_client_number=10 ++fed_obd.algorithm_kwargs.second_phase_epoch=1 ++fed_obd.debug=False 9 | -------------------------------------------------------------------------------- /method/qsgd/__init__.py: -------------------------------------------------------------------------------- 1 | """QSGD: Communication-Efficient SGD via Gradient Quantization and Encoding https://arxiv.org/abs/1610.02132""" 2 | 3 | from distributed_learning_simulation import ( 4 | AggregationServer, 5 | AlgorithmRepository, 6 | FedAVGAlgorithm, 7 | GradientWorker, 8 | StochasticQuantClientEndpoint, 9 | StochasticQuantServerEndpoint, 10 | ) 11 | 12 | AlgorithmRepository.register_algorithm( 13 | algorithm_name="QSGD", 14 | client_cls=GradientWorker, 15 | server_cls=AggregationServer, 16 | algorithm_cls=FedAVGAlgorithm, 17 | client_endpoint_cls=StochasticQuantClientEndpoint, 18 | 
server_endpoint_cls=StochasticQuantServerEndpoint, 19 | ) 20 | -------------------------------------------------------------------------------- /method/fed_paq/__init__.py: -------------------------------------------------------------------------------- 1 | # FedPAQ: A Communication-Efficient Federated Learning Method with Periodic Averaging and Quantization (https://arxiv.org/abs/1909.13014) 2 | from distributed_learning_simulation import ( 3 | AggregationServer, 4 | AggregationWorker, 5 | AlgorithmRepository, 6 | FedAVGAlgorithm, 7 | StochasticQuantClientEndpoint, 8 | StochasticQuantServerEndpoint, 9 | ) 10 | 11 | AlgorithmRepository.register_algorithm( 12 | algorithm_name="fed_paq", 13 | client_cls=AggregationWorker, 14 | server_cls=AggregationServer, 15 | client_endpoint_cls=StochasticQuantClientEndpoint, 16 | server_endpoint_cls=StochasticQuantServerEndpoint, 17 | algorithm_cls=FedAVGAlgorithm, 18 | ) 19 | -------------------------------------------------------------------------------- /simulator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import cyy_huggingface_toolbox # noqa: F401 5 | import cyy_torch_text # noqa: F401 6 | import cyy_torch_vision # noqa: F401 7 | from cyy_naive_lib.log import redirect_stdout_to_logger 8 | from distributed_learning_simulation import load_config, train 9 | 10 | sys.path.insert(0, os.path.abspath(".")) 11 | import method # noqa: F401 12 | 13 | if __name__ == "__main__": 14 | with redirect_stdout_to_logger(): 15 | config_path = os.path.join(os.path.dirname(__file__), "conf") 16 | config = load_config( 17 | config_path=config_path, 18 | global_conf_path=os.path.join(config_path, "global.yaml"), 19 | ) 20 | train(config=config, single_task=True) 21 | -------------------------------------------------------------------------------- /conf/fed_obd/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_obd 5 | optimizer_name: SGD 6 | worker_number: 100 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.01 12 | endpoint_kwargs: 13 | server: 14 | weight: 0.0001 15 | worker: 16 | weight: 0.0001 17 | algorithm_kwargs: 18 | second_phase_epoch: 10 19 | dropout_rate: 0.3 20 | random_client_number: 50 21 | dataset_kwargs: 22 | input_max_len: 300 23 | dataset_type: text 24 | tokenizer: 25 | type: spacy 26 | model_kwargs: 27 | word_vector_name: glove.6B.100d 28 | num_encoder_layer: 2 29 | d_model: 100 30 | nhead: 5 31 | frozen_modules: 32 | names: [embedding] 33 | -------------------------------------------------------------------------------- /method/sign_sgd/worker.py: -------------------------------------------------------------------------------- 1 | """signSGD: Compressed Optimisation for Non-Convex Problems https://arxiv.org/abs/1802.04434""" 2 | 3 | from cyy_torch_toolbox import ModelGradient, TensorDict 4 | from distributed_learning_simulation import GradientWorker, ParameterMessage 5 | 6 | 7 | class SignSGDWorker(GradientWorker): 8 | def _process_gradient(self, gradient_dict: ModelGradient) -> TensorDict: 9 | self._send_data_to_server( 10 | ParameterMessage( 11 | parameter={k: v.sign() for k, v in gradient_dict.items()}, 12 | in_round=True, 13 | aggregation_weight=self.trainer.dataset_size, 14 | ) 15 | ) 16 | result = self._get_data_from_server() 17 | assert 
isinstance(result, ParameterMessage) 18 | return result.parameter 19 | -------------------------------------------------------------------------------- /method/fed_dropout_avg/algorithm.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from distributed_learning_simulation import FedAVGAlgorithm 5 | 6 | 7 | class FedDropoutAvgAlgorithm(FedAVGAlgorithm): 8 | def _get_weight(self, worker_data, name: str, parameter: torch.Tensor) -> Any: 9 | weight = super()._get_weight( 10 | worker_data=worker_data, name=name, parameter=parameter 11 | ) 12 | return (parameter != 0).float() * weight 13 | 14 | def _apply_total_weight( 15 | self, name: str, parameter: torch.Tensor, total_weight: Any 16 | ) -> torch.Tensor: 17 | # avoid dividing by zero, in such case we set weight to 1 18 | total_weight[total_weight == 0] = 1 19 | return super()._apply_total_weight( 20 | name=name, parameter=parameter, total_weight=total_weight 21 | ) 22 | -------------------------------------------------------------------------------- /method/fed_obd/__init__.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation import ( 2 | AlgorithmRepository, 3 | FedAVGAlgorithm, 4 | NNADQClientEndpoint, 5 | NNADQServerEndpoint, 6 | StochasticQuantClientEndpoint, 7 | StochasticQuantServerEndpoint, 8 | ) 9 | 10 | from .server import FedOBDServer 11 | from .worker import FedOBDWorker 12 | 13 | AlgorithmRepository.register_algorithm( 14 | algorithm_name="fed_obd", 15 | client_cls=FedOBDWorker, 16 | server_cls=FedOBDServer, 17 | client_endpoint_cls=NNADQClientEndpoint, 18 | server_endpoint_cls=NNADQServerEndpoint, 19 | algorithm_cls=FedAVGAlgorithm, 20 | ) 21 | 22 | AlgorithmRepository.register_algorithm( 23 | algorithm_name="fed_obd_sq", 24 | client_cls=FedOBDWorker, 25 | server_cls=FedOBDServer, 26 | client_endpoint_cls=StochasticQuantClientEndpoint, 27 | server_endpoint_cls=StochasticQuantServerEndpoint, 28 | algorithm_cls=FedAVGAlgorithm, 29 | ) 30 | -------------------------------------------------------------------------------- /method/fed_dropout_avg/worker.py: -------------------------------------------------------------------------------- 1 | """FedDropoutAvg: Generalizable federated learning for histopathology image classification (https://arxiv.org/pdf/2111.13230.pdf)""" 2 | 3 | import torch 4 | from cyy_naive_lib.log import log_info 5 | from distributed_learning_simulation import AggregationWorker, ParameterMessage 6 | 7 | 8 | class FedDropoutAvgWorker(AggregationWorker): 9 | def _get_sent_data(self) -> ParameterMessage: 10 | dropout_rate: float = self.config.algorithm_kwargs["dropout_rate"] 11 | if self.hold_log_lock: 12 | log_info("use dropout_rate %s", dropout_rate) 13 | self._send_parameter_diff = False 14 | sent_data = super()._get_sent_data() 15 | assert isinstance(sent_data, ParameterMessage) 16 | parameter = sent_data.parameter 17 | total_num: float = 0 18 | send_num: float = 0 19 | for k, v in parameter.items(): 20 | weight = torch.bernoulli(torch.full_like(v, 1 - dropout_rate)) 21 | parameter[k] = v * weight 22 | total_num += parameter[k].numel() 23 | send_num += torch.count_nonzero(parameter[k]).item() 24 | log_info("send_num %s", send_num) 25 | log_info("total_num %s", total_num) 26 | return sent_data 27 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 63.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "distributed_learning_simulator" 7 | version = "1.0" 8 | requires-python = ">=3.11" 9 | readme = {file = "README.md", content-type = "text/markdown"} 10 | authors = [ 11 | {name = "cyy", email = "cyyever@outlook.com"}, 12 | ] 13 | license = {text = "BSD License"} 14 | classifiers = [ 15 | "Programming Language :: Python" 16 | ] 17 | 18 | dependencies = [ 19 | "distributed_learning_simulation@git+https://github.com/cyyever/distributed_learning_simulation_lib.git", 20 | "cyy_torch_text@git+https://github.com/cyyever/torch_text.git", 21 | "cyy_torch_vision@git+https://github.com/cyyever/torch_vision.git" 22 | ] 23 | 24 | 25 | 26 | [tool.setuptools.package-dir] 27 | "distributed_learning_simulator.conf"= "./conf" 28 | "distributed_learning_simulator.method"= "./method" 29 | "distributed_learning_simulator.algorithm"= "./algorithm" 30 | 31 | [project.urls] 32 | Repository = "https://github.com/cyyever/distributed_learning_simulator" 33 | 34 | 35 | [tool.ruff] 36 | target-version = "py312" 37 | src = ["method", "algorithm"] 38 | 39 | [tool.ruff.lint] 40 | select = [ 41 | # pycodestyle 42 | "E", 43 | # Pyflakes 44 | "F", 45 | # pyupgrade 46 | "UP", 47 | # flake8-bugbear 48 | "B", 49 | # flake8-simplify 50 | "SIM", 51 | # isort 52 | "I", 53 | ] 54 | ignore = ["F401","E501","F403"] 55 | -------------------------------------------------------------------------------- /method/fed_obd/server.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from cyy_naive_lib.log import log_warning 4 | from distributed_learning_simulation import ( 5 | AggregationServer, 6 | ParameterMessage, 7 | ParameterMessageBase, 8 | QuantServerEndpoint, 9 | ) 10 | 11 | from .phase import Phase 12 | 13 | 14 | class FedOBDServer(AggregationServer): 15 | def __init__(self, **kwargs: Any) -> None: 16 | super().__init__(**kwargs) 17 | self.__phase: Phase = Phase.STAGE_ONE 18 | assert isinstance(self._endpoint, QuantServerEndpoint) 19 | self._endpoint.use_quant() 20 | self._compute_stat = True 21 | 22 | def select_workers(self) -> set: 23 | if self.__phase != Phase.STAGE_ONE: 24 | return set(range(self.worker_number)) 25 | return super().select_workers() 26 | 27 | def _get_stat_key(self, message: ParameterMessage): 28 | if self.__phase == Phase.STAGE_TWO: 29 | return max(self.performance_stat.keys()) + 1 30 | return super()._get_stat_key(message) 31 | 32 | def _aggregate_worker_data(self) -> ParameterMessageBase: 33 | result: ParameterMessageBase = super()._aggregate_worker_data() 34 | assert result 35 | match self.__phase: 36 | case Phase.STAGE_ONE: 37 | if self.round_index >= self.config.round or ( 38 | self.early_stop and not self.__has_improvement() 39 | ): 40 | log_warning("switch to phase 2") 41 | self.__phase = Phase.STAGE_TWO 42 | result.other_data["phase_two"] = True 43 | case Phase.STAGE_TWO: 44 | if self.early_stop and not self.__has_improvement(): 45 | log_warning("stop aggregation") 46 | result.end_training = True 47 | case _: 48 | raise NotImplementedError(f"unknown phase {self.__phase}") 49 | if result.end_training: 50 | self.__phase = Phase.END 51 | return result 52 | 53 | def _stopped(self) -> bool: 54 | return self.__phase == Phase.END 55 | 56 | def __has_improvement(self) -> bool: 57 | if self.__phase == Phase.STAGE_TWO: 58 | return True 59 
| return not self.convergent() 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | outputs 131 | session 132 | -------------------------------------------------------------------------------- /method/fed_avg_qat/worker.py: -------------------------------------------------------------------------------- 1 | import torch.ao.quantization 2 | from cyy_naive_lib.log import log_error 3 | from cyy_torch_algorithm.quantization.qat import QuantizationAwareTraining 4 | from cyy_torch_toolbox import TensorDict 5 | from distributed_learning_simulation import ( 6 | AggregationWorker, 7 | ) 8 | from torch.ao.nn.quantized.modules import Linear 9 | 10 | 11 | class QATWorker(AggregationWorker): 12 | parameter_name_set: set = set() 13 | 14 | def _get_parameters(self) -> TensorDict: 15 | assert isinstance(self.trainer.model, torch.ao.quantization.QuantWrapper) 16 | self.trainer.model.eval() 17 | old_model = self.trainer.model 18 | model_int8 = torch.ao.quantization.convert(old_model.module.cpu()) 19 | self.trainer.replace_model(lambda *args: model_int8) 20 | new_state_dict = {} 21 | for name, p in self.trainer.model_util.get_modules(): 22 | if isinstance(p, Linear): 23 | weight, bias = p._packed_params._weight_bias() 24 | weight_name = name + ".weight" 25 | assert weight_name in self.parameter_name_set 26 | new_state_dict[weight_name] = weight.detach().dequantize() 27 | bias_name = name + ".bias" 28 | assert bias_name in self.parameter_name_set 29 | new_state_dict[bias_name] = bias.detach().dequantize() 30 | 31 | for name, p in self.trainer.model.state_dict().items(): 32 | log_error("%s %s", name, p) 33 | if name.endswith(".zero_point") or name.endswith(".scale"): 34 | continue 35 | if name.startswith("module."): 36 | name = name[len("module.") :] 37 | if isinstance(p, torch.Tensor | torch.nn.Parameter): 38 | new_state_dict[name] = p.detach().dequantize() 39 | log_error("%s %s", new_state_dict.keys(), self.parameter_name_set) 40 | assert sorted(new_state_dict.keys()) == sorted(self.parameter_name_set) 41 | if self._model_loading_fun is None: 42 | self._model_loading_fun = self.load_model 43 | return new_state_dict 44 | 45 | def load_model(self, state_dict) -> None: 46 | self.trainer.remove_model() 47 | self.trainer.model.load_state_dict(state_dict) 48 | 49 | def _before_training(self) -> None: 50 | super()._before_training() 51 | self.parameter_name_set = set(self.trainer.model_util.get_parameters().keys()) 52 | self.trainer.append_hook(QuantizationAwareTraining(), "QAT") 53 | -------------------------------------------------------------------------------- /analysis/analyze_round.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import seaborn as sns 7 | from cyy_naive_lib.fs.path import find_directories 8 | from distributed_learning_simulation.config import load_config 9 | 10 | 11 | def extract_data( 12 | session_path: str, algorithm: str, aggregated_performance_metric: dict 
13 | ) -> dict: 14 | for server_session in find_directories(session_path, "visualizer"): 15 | if "server" not in server_session: 16 | continue 17 | for round_no in range(1, config.round + 1): 18 | round_str = f"round_{round_no}" 19 | 20 | metric_file = os.path.join( 21 | server_session, round_str, "test", "performance_metric.json" 22 | ) 23 | if not os.path.isfile(metric_file): 24 | continue 25 | with open( 26 | metric_file, 27 | encoding="utf8", 28 | ) as f: 29 | performance_metric = json.load(f) 30 | for k, v in performance_metric.items(): 31 | if k not in aggregated_performance_metric: 32 | aggregated_performance_metric[k] = pd.DataFrame( 33 | columns=["round", k] 34 | ) 35 | aggregated_performance_metric[k] = pd.concat( 36 | [ 37 | aggregated_performance_metric[k], 38 | pd.DataFrame( 39 | [[round_no, list(v.values())[0], algorithm]], 40 | columns=["round", k, "algorithm"], 41 | ), 42 | ] 43 | ) 44 | return aggregated_performance_metric 45 | 46 | 47 | if __name__ == "__main__": 48 | aggregated_performance_metric: dict = {} 49 | config_files = os.getenv("config_files") 50 | assert config_files is not None 51 | for config_file in config_files.split(): 52 | config = load_config(config_file) 53 | session_path = ( 54 | f"session/{config.distributed_algorithm}/{config.dc_config.dataset_name}/" 55 | ) 56 | extract_data( 57 | session_path, config.distributed_algorithm, aggregated_performance_metric 58 | ) 59 | for metric, metric_df in aggregated_performance_metric.items(): 60 | print(f"deal with {metric}") 61 | ax = sns.lineplot( 62 | data=metric_df, x="round", y=metric, hue="algorithm", errorbar="sd" 63 | ) 64 | plt.tight_layout() 65 | plt.savefig(f"{metric}.png") 66 | plt.clf() 67 | -------------------------------------------------------------------------------- /method/fed_obd/worker.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from cyy_naive_lib.log import log_debug, log_warning 4 | from cyy_torch_toolbox import ExecutorHookPoint, ModelUtil 5 | from distributed_learning_simulation import ( 6 | AggregationWorker, 7 | Message, 8 | ParameterMessage, 9 | QuantClientEndpoint, 10 | ) 11 | from distributed_learning_simulation.context import ClientEndpointInCoroutine 12 | 13 | from .obd_algorithm import OpportunisticBlockDropoutAlgorithmMixin 14 | from .phase import Phase 15 | 16 | 17 | class FedOBDWorker(AggregationWorker, OpportunisticBlockDropoutAlgorithmMixin): 18 | __phase = Phase.STAGE_ONE 19 | 20 | def __init__(self, *args, **kwargs): 21 | AggregationWorker.__init__(self, *args, **kwargs) 22 | OpportunisticBlockDropoutAlgorithmMixin.__init__(self) 23 | assert isinstance( 24 | self._endpoint, QuantClientEndpoint | ClientEndpointInCoroutine 25 | ) 26 | self._endpoint.dequant_server_data() 27 | self._send_parameter_diff = False 28 | self._keep_model_cache = True 29 | 30 | def _load_result_from_server(self, result: Message) -> None: 31 | if "phase_two" in result.other_data: 32 | assert isinstance(result, ParameterMessage) 33 | # result.other_data.pop("phase_two") 34 | self.__phase = Phase.STAGE_TWO 35 | log_warning("switch to phase 2") 36 | self.set_reuse_learning_rate(True) 37 | self._send_parameter_diff = True 38 | self.disable_choosing_model_by_validation() 39 | self.trainer.hyper_parameter.epoch = self.config.algorithm_kwargs[ 40 | "second_phase_epoch" 41 | ] 42 | self.config.round = self._round_index + 1 43 | self._aggregation_time = ExecutorHookPoint.AFTER_EPOCH 44 | self._register_aggregation() 45 | 46 | 
super()._load_result_from_server(result=result) 47 | 48 | def _get_model_util(self) -> ModelUtil: 49 | return self.trainer.model_util 50 | 51 | def _aggregation(self, sent_data: Message, **kwargs: Any) -> None: 52 | if self.__phase == Phase.STAGE_TWO: 53 | executor = kwargs["executor"] 54 | if kwargs["epoch"] == executor.hyper_parameter.epoch: 55 | sent_data.end_training = True 56 | self._force_stop = True 57 | log_debug("end training") 58 | super()._aggregation(sent_data=sent_data, **kwargs) 59 | 60 | def _stopped(self) -> bool: 61 | return self._force_stop 62 | 63 | def _get_sent_data(self): 64 | assert self._model_cache is not None 65 | data = super()._get_sent_data() 66 | if self.__phase == Phase.STAGE_ONE: 67 | assert isinstance(data, ParameterMessage) 68 | block_parameter = self.get_block_parameter( 69 | parameter=data.parameter, 70 | ) 71 | data.parameter = self._model_cache.get_parameter_diff(block_parameter) 72 | return data 73 | 74 | data.in_round = True 75 | log_warning("phase 2 keys %s", data.other_data.keys()) 76 | return data 77 | -------------------------------------------------------------------------------- /method/fed_obd/obd_algorithm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from cyy_naive_lib.log import log_debug, log_info 3 | from cyy_preprocessing_pipeline import cat_tensors_to_vector 4 | from cyy_torch_toolbox import ModelParameter 5 | 6 | from ..algorithm.block_algorithm import BlockAlgorithmMixin 7 | 8 | 9 | class OpportunisticBlockDropoutAlgorithmMixin(BlockAlgorithmMixin): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | self.__dropout_rate = self.config.algorithm_kwargs["dropout_rate"] 13 | log_debug("use dropout rate %s", self.__dropout_rate) 14 | self.__parameter_num: int = 0 15 | 16 | def get_block_parameter(self, parameter: ModelParameter) -> ModelParameter: 17 | if self.__parameter_num == 0: 18 | parameter_list = self.trainer.model_util.get_parameter_list() 19 | self.__parameter_num = len(parameter_list) 20 | threshold = (1 - self.__dropout_rate) * self.__parameter_num 21 | partial_parameter_num = 0 22 | new_parameter: dict = {} 23 | 24 | block_delta: dict = {} 25 | for block in self.blocks: 26 | block_dict, delta, block_size = self.__analyze_block(parameter, block) 27 | mean_delta = delta / block_size 28 | if mean_delta not in block_delta: 29 | block_delta[mean_delta] = [] 30 | block_delta[mean_delta].append((block_dict, block_size)) 31 | log_debug("block_delta is %s", sorted(block_delta.keys(), reverse=True)) 32 | 33 | for mean_delta in sorted(block_delta.keys(), reverse=True): 34 | if partial_parameter_num > threshold: 35 | break 36 | for block_dict, block_size in block_delta[mean_delta]: 37 | if partial_parameter_num + block_size > threshold: 38 | continue 39 | partial_parameter_num += block_size 40 | new_parameter |= block_dict 41 | log_debug("choose blocks %s", new_parameter.keys()) 42 | log_info( 43 | "partial_parameter_num %s threshold %s parameter_num %s", 44 | partial_parameter_num, 45 | threshold, 46 | self.__parameter_num, 47 | ) 48 | 49 | return new_parameter 50 | 51 | def __analyze_block(self, parameter: ModelParameter, block: list) -> tuple: 52 | cur_block_parameters = [] 53 | prev_block_parameters = [] 54 | block_dict = {} 55 | for submodule_name, submodule in block: 56 | for p_name, _ in submodule.named_parameters(): 57 | parameter_name = submodule_name + "." 
+ p_name 58 | cur_block_parameters.append(parameter[parameter_name]) 59 | prev_block_parameters.append(self.model_cache.parameter[parameter_name]) 60 | block_dict[parameter_name] = parameter[parameter_name] 61 | 62 | cur_block_parameter = cat_tensors_to_vector(cur_block_parameters) 63 | prev_block_parameter = cat_tensors_to_vector(prev_block_parameters) 64 | delta = torch.linalg.vector_norm( 65 | cur_block_parameter.cpu() - prev_block_parameter.cpu() 66 | ).item() 67 | return (block_dict, delta, cur_block_parameter.nelement()) 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # distributed_learning_simulator 2 | 3 | This is a simulator for distributed Machine Learning and Federated Learning on a single host. It implements common algorithms as well as our works. 4 | 5 | ## Installation 6 | 7 | This is a Python project. The third-party dependencies are listed in **pyproject.toml**. 8 | 9 | Use pip to install it: 10 | 11 | ``` 12 | python3 -m pip install . --upgrade --force-reinstall --user 13 | ``` 14 | 15 | ## Our Works 16 | 17 | ### GTG-Shapley 18 | 19 | To run the experiments of [GTG-Shapley: Efficient and Accurate Participant Contribution Evaluation in Federated Learning](https://dl.acm.org/doi/pdf/10.1145/3501811), use this command: 20 | 21 | ``` 22 | bash gtg_shapley_train.sh 23 | ``` 24 | 25 | #### Reference 26 | 27 | If you find our work useful, feel free to cite it: 28 | 29 | ``` 30 | @article{10.1145/3501811, 31 | author = {Liu, Zelei and Chen, Yuanyuan and Yu, Han and Liu, Yang and Cui, Lizhen}, 32 | title = {GTG-Shapley: Efficient and Accurate Participant Contribution Evaluation in Federated Learning}, 33 | year = {2022}, 34 | issue_date = {August 2022}, 35 | publisher = {Association for Computing Machinery}, 36 | address = {New York, NY, USA}, 37 | volume = {13}, 38 | number = {4}, 39 | issn = {2157-6904}, 40 | url = {https://doi.org/10.1145/3501811}, 41 | doi = {10.1145/3501811}, 42 | journal = {ACM Trans. Intell. Syst. 
Technol.}, 43 | month = {may}, 44 | articleno = {60}, 45 | numpages = {21}, 46 | keywords = {Federated learning, contribution assessment, Shapley value} 47 | } 48 | ``` 49 | 50 | ### FedOBD 51 | 52 | To run the experiments of [FedOBD: Opportunistic Block Dropout for Efficiently Training Large-scale Neural Networks through Federated Learning](https://arxiv.org/abs/2208.05174), use this command: 53 | 54 | ``` 55 | bash fed_obd_train.sh 56 | ``` 57 | 58 | #### Reference 59 | 60 | If you find our work useful, feel free to cite it: 61 | 62 | ``` 63 | @inproceedings{ijcai2023p394, 64 | title = {FedOBD: Opportunistic Block Dropout for Efficiently Training Large-scale Neural Networks through Federated Learning}, 65 | author = {Chen, Yuanyuan and Chen, Zichen and Wu, Pengcheng and Yu, Han}, 66 | booktitle = {Proceedings of the Thirty-Second International Joint Conference on 67 | Artificial Intelligence, {IJCAI-23}}, 68 | publisher = {International Joint Conferences on Artificial Intelligence Organization}, 69 | editor = {Edith Elkind}, 70 | pages = {3541--3549}, 71 | year = {2023}, 72 | month = {8}, 73 | note = {Main Track}, 74 | doi = {10.24963/ijcai.2023/394}, 75 | url = {https://doi.org/10.24963/ijcai.2023/394}, 76 | } 77 | ``` 78 | 79 | ### Historical Embedding-Guided Efficient Large-Scale Federated Graph Learning 80 | 81 | The implementation has been moved to another [GitHub repository](https://github.com/cyyever/distributed_graph_learning_simulator). 82 | 83 | 84 | #### Reference 85 | 86 | If you find this work useful, feel free to cite it: 87 | 88 | ``` 89 | @article{li2024historical, 90 | title={Historical Embedding-Guided Efficient Large-Scale Federated Graph Learning}, 91 | author={Li, Anran and Chen, Yuanyuan and Zhang, Jian and Cheng, Mingfei and Huang, Yihao and Wu, Yueming and Luu, Anh Tuan and Yu, Han}, 92 | journal={Proceedings of the ACM on Management of Data}, 93 | volume={2}, 94 | number={3}, 95 | pages={1--24}, 96 | year={2024}, 97 | publisher={ACM New York, NY, USA} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /algorithm/block_algorithm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | from cyy_naive_lib.log import log_info 4 | from cyy_preprocessing_pipeline import cat_tensors_to_vector 5 | from cyy_torch_toolbox import BlockType, ModelUtil 6 | from distributed_learning_simulation.worker.protocol import AggregationWorkerProtocol 7 | 8 | 9 | class BlockAlgorithmMixin(AggregationWorkerProtocol): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | self.__blocks: list[BlockType] | None = None 13 | self._block_types = { 14 | ("AlbertTransformer",), 15 | ("AlbertEmbeddings",), 16 | ("Bottleneck",), 17 | ("TransformerEncoderLayer",), 18 | (torch.nn.BatchNorm2d, torch.nn.ReLU, torch.nn.Conv2d), 19 | (torch.nn.BatchNorm2d, torch.nn.Conv2d), 20 | (torch.nn.Conv2d, torch.nn.BatchNorm2d), 21 | } 22 | 23 | @property 24 | def blocks(self) -> list[BlockType]: 25 | if self.__blocks is None: 26 | self._find_blocks() 27 | assert self.__blocks is not None 28 | return self.__blocks 29 | 30 | def _get_model_util(self) -> ModelUtil: 31 | return self.trainer.model_util 32 | 33 | def _find_blocks(self) -> None: 34 | model_util = self._get_model_util() 35 | blocks = model_util.get_module_blocks(block_types=self._block_types) 36 | self.__blocks = [] 37 | modules = list(model_util.get_modules()) 38 | while modules: 39 | submodule_name, submodule = modules[0] 40 | del 
modules[0] 41 | if not submodule_name: 42 | continue 43 | if len(list(submodule.parameters())) == 0: 44 | continue 45 | part_of_block = False 46 | in_block = False 47 | tmp_blocks = [] 48 | 49 | if blocks: 50 | tmp_blocks.append(blocks[0]) 51 | if self.__blocks: 52 | tmp_blocks.append(self.__blocks[-1]) 53 | for block in tmp_blocks: 54 | for block_submodule_name, _ in block: 55 | if block_submodule_name == submodule_name: 56 | part_of_block = True 57 | self.__blocks.append(block) 58 | for _ in range(len(block) - 1): 59 | del modules[0] 60 | del blocks[0] 61 | break 62 | if submodule_name.startswith( 63 | f"{block_submodule_name}." 64 | ) or block_submodule_name.startswith(f"{submodule_name}."): 65 | in_block = True 66 | break 67 | if part_of_block or in_block: 68 | break 69 | if part_of_block or in_block: 70 | continue 71 | self.__blocks.append([(submodule_name, submodule)]) 72 | if self.hold_log_lock: 73 | log_info("identify a submodule:%s", submodule_name) 74 | 75 | if self.hold_log_lock: 76 | log_info("identify these blocks in model:") 77 | for block in self.__blocks: 78 | log_info( 79 | "%s", 80 | [f"{name}" for name, _ in block], 81 | ) 82 | 83 | # check the parameter numbers are the same 84 | tmp_parameter_list = [] 85 | tmp_parameter_name = set() 86 | for block in self.__blocks: 87 | for submodule_name, submodule in block: 88 | for p_name, p in submodule.named_parameters(): 89 | tmp_parameter_list.append(p) 90 | if submodule_name: 91 | tmp_parameter_name.add(submodule_name + "." + p_name) 92 | else: 93 | tmp_parameter_name.add(p_name) 94 | parameter_dict = model_util.get_parameters() 95 | if tmp_parameter_name != set(parameter_dict.keys()): 96 | for a in tmp_parameter_name: 97 | if a not in parameter_dict: 98 | raise RuntimeError(a + " not in model") 99 | for a in parameter_dict: 100 | if a not in tmp_parameter_name: 101 | raise RuntimeError(a + " not in block") 102 | parameter_list = model_util.get_parameter_list() 103 | assert cat_tensors_to_vector(tmp_parameter_list).shape == parameter_list.shape 104 | -------------------------------------------------------------------------------- /algorithm/shapley_value_algorithm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any 4 | 5 | from cyy_naive_lib.concurrency import batch_process 6 | from cyy_naive_lib.log import log_warning 7 | from cyy_torch_algorithm.shapely_value.shapley_value import RoundBasedShapleyValue 8 | from cyy_torch_toolbox import TorchProcessTaskQueue 9 | from distributed_learning_simulation import ( 10 | AggregationServer, 11 | FedAVGAlgorithm, 12 | ParameterMessage, 13 | ) 14 | 15 | 16 | class ShapleyValueAlgorithm(FedAVGAlgorithm): 17 | def __init__( 18 | self, sv_algorithm_cls: type, server: AggregationServer, **kwargs: Any 19 | ) -> None: 20 | super().__init__(**kwargs) 21 | self._server: AggregationServer = server 22 | self.accumulate = False 23 | self.metric_type: str = "accuracy" 24 | self.__sv_algorithm: None | RoundBasedShapleyValue = None 25 | self.sv_algorithm_cls = sv_algorithm_cls 26 | 27 | @property 28 | def server(self) -> AggregationServer: 29 | return self._server 30 | 31 | @property 32 | def sv_algorithm(self) -> RoundBasedShapleyValue: 33 | if self.__sv_algorithm is None: 34 | assert self._all_worker_data 35 | assert self.server.round_index == 1 36 | self.__sv_algorithm = self.sv_algorithm_cls( 37 | players=sorted(self._all_worker_data.keys()), 38 | initial_metric=self.server.performance_stat[ 39 | 
self.server.round_index - 1 40 | ][f"test_{self.metric_type}"], 41 | algorithm_kwargs=self.config.algorithm_kwargs, 42 | ) 43 | assert isinstance(self.__sv_algorithm, RoundBasedShapleyValue) 44 | if ( 45 | self.config.algorithm_kwargs.get("round_trunc_threshold", None) 46 | is not None 47 | ): 48 | self.__sv_algorithm.set_round_truncation_threshold( 49 | self.config.algorithm_kwargs["round_trunc_threshold"] 50 | ) 51 | self.sv_algorithm.set_batch_metric_function(self._get_batch_metric) 52 | # For client selection in each round 53 | self.__sv_algorithm.set_players( 54 | sorted({k for k, v in self._all_worker_data.items() if v is not None}) 55 | ) 56 | return self.__sv_algorithm 57 | 58 | @property 59 | def choose_best_subset(self) -> bool: 60 | return self.config.algorithm_kwargs.get("choose_best_subset", False) 61 | 62 | def aggregate_worker_data(self) -> ParameterMessage: 63 | self.sv_algorithm.compute(round_index=self.server.round_index) 64 | if self.choose_best_subset: 65 | assert hasattr(self.sv_algorithm, "shapley_values_S") 66 | best_players = self.sv_algorithm.get_best_players( 67 | round_index=self.server.round_index 68 | ) 69 | assert best_players is not None 70 | log_warning("use players %s", best_players) 71 | self._all_worker_data = {k: self._all_worker_data[k] for k in best_players} 72 | return super().aggregate_worker_data() 73 | 74 | def _batch_metric_worker(self, task, **kwargs) -> dict: 75 | return {task: self._get_subset_metric(subset=task)} 76 | 77 | def _get_batch_metric(self, subsets) -> dict: 78 | if len(subsets) == 1: 79 | return {list(subsets)[0]: self._get_subset_metric(list(subsets)[0])} 80 | queue = TorchProcessTaskQueue( 81 | worker_num=self.config.algorithm_kwargs.get("sv_worker_number", None) 82 | ) 83 | queue.disable_logger() 84 | queue.start(worker_fun=self._batch_metric_worker) 85 | res = batch_process(queue, subsets) 86 | queue.stop() 87 | return res 88 | 89 | def _get_subset_metric(self, subset) -> float: 90 | assert subset 91 | aggregated_parameter = super()._aggregate_parameter( 92 | chosen_worker_ids=set(self.sv_algorithm.get_players(subset)) 93 | ) 94 | assert aggregated_parameter 95 | return self.server.get_metric( 96 | aggregated_parameter, log_performance_metric=False 97 | )[self.metric_type] 98 | 99 | def exit(self) -> None: 100 | assert self.sv_algorithm is not None 101 | self.sv_algorithm.exit() 102 | with open( 103 | os.path.join(self.config.save_dir, "shapley_values.json"), 104 | "w", 105 | encoding="utf8", 106 | ) as f: 107 | json.dump(self.sv_algorithm.get_result(), f) 108 | -------------------------------------------------------------------------------- /analysis/analyze_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import torch 5 | from distributed_learning_simulation import load_config 6 | 7 | 8 | def compute_acc(paths: list) -> None: 9 | final_test_acc = [] 10 | worker_acc: dict = {} 11 | for path in paths: 12 | assert os.path.isfile(path) 13 | lines = None 14 | with open(path, encoding="utf8") as f: 15 | lines = f.readlines() 16 | for line in reversed(lines): 17 | if config.distributed_algorithm == "sign_SGD": 18 | if "test loss" in line: 19 | res = re.findall("[0-9.]+%", line) 20 | assert len(res) == 1 21 | acc = float(res[0].replace("%", "")) 22 | final_test_acc.append(acc) 23 | break 24 | elif config.distributed_algorithm in ( 25 | "fed_obd_first_stage", 26 | "fed_obd_layer", 27 | ): 28 | if ( 29 | "test in" in line 30 | and "accuracy" in line 31 | and 
f"round: {config.round}" in line 32 | ): 33 | print("line is", line) 34 | res = re.findall("[0-9.]+%", line) 35 | assert len(res) == 1 36 | acc = float(res[0].replace("%", "")) 37 | final_test_acc.append(acc) 38 | break 39 | else: 40 | if "test in" in line and "accuracy" in line: 41 | res = re.findall("[0-9.]+%", line) 42 | assert len(res) == 1 43 | acc = float(res[0].replace("%", "")) 44 | print(line) 45 | final_test_acc.append(acc) 46 | break 47 | for worker_id in range(config.worker_number): 48 | for line in reversed(lines): 49 | res = re.findall(f"worker {worker_id}.*train.*accuracy", line) 50 | if res: 51 | res = re.findall("[0-9.]+%", line) 52 | assert len(res) == 1 53 | acc = float(res[0].replace("%", "")) 54 | if worker_id not in worker_acc: 55 | worker_acc[worker_id] = [] 56 | worker_acc[worker_id].append(acc) 57 | break 58 | assert len(final_test_acc) == len(paths) 59 | std, mean = torch.std_mean(torch.tensor(final_test_acc)) 60 | print("test acc", round(mean.item(), 2), round(std.item(), 2)) 61 | 62 | 63 | def compute_data_amount(paths: list) -> dict: 64 | trainer = config.create_trainer() 65 | parameter_list = trainer.model_util.get_parameter_list() 66 | distributed_algorithm = config.distributed_algorithm.lower() 67 | print("worker_number is", config.worker_number) 68 | print("model is", config.model_config.model_name) 69 | uploaded_msg_num = config.round * config.algorithm_kwargs.get( 70 | "random_client_number", config.worker_number 71 | ) 72 | uploaded_parameter_num = uploaded_msg_num * parameter_list.nelement() 73 | downloaded_msg_num = uploaded_msg_num 74 | downloaded_parameter_num = uploaded_parameter_num 75 | distributed_msg_num = config.worker_number 76 | distributed_parameter_num = distributed_msg_num * parameter_list.nelement() 77 | msg_num = uploaded_msg_num + downloaded_msg_num + distributed_msg_num 78 | data_amount: float | tuple = 0 79 | match distributed_algorithm: 80 | case "fed_avg": 81 | data_amount = ( 82 | parameter_list.nelement() 83 | * parameter_list.element_size() 84 | * msg_num 85 | / (1024 * 1024) 86 | ) 87 | case "fed_obd": 88 | msg_num += ( 89 | config.algorithm_kwargs["second_phase_epoch"] * config.worker_number * 2 90 | ) 91 | 92 | data_amounts = [] 93 | for path in paths: 94 | remain_msg = msg_num 95 | lines = None 96 | compressed_part = 0 97 | rnd_cnt = 0 98 | with open(path, encoding="utf8") as f: 99 | lines = f.readlines() 100 | stage_one = True 101 | for line in lines: 102 | if "broadcast NNABQ compression ratio" in line: 103 | res = re.findall("[0-9.]+$", line) 104 | assert len(res) == 1 105 | broadcast_ratio = float( 106 | res[0].replace("(", "").replace(",", "") 107 | ) 108 | # print("broadcast_ratio", broadcast_ratio) 109 | rnd_cnt += 1 110 | if rnd_cnt <= config.round: 111 | compressed_part += ( 112 | broadcast_ratio 113 | * config.algorithm_kwargs["random_client_number"] 114 | ) 115 | remain_msg -= config.algorithm_kwargs[ 116 | "random_client_number" 117 | ] 118 | else: 119 | stage_one = False 120 | if remain_msg > config.worker_number: 121 | compressed_part += ( 122 | broadcast_ratio * config.worker_number 123 | ) 124 | remain_msg -= config.worker_number 125 | if "worker NNABQ compression ratio" in line: 126 | res = re.findall("[0-9.]+$", line) 127 | assert len(res) == 1 128 | worker_ratio = float(res[0].replace("(", "").replace(",", "")) 129 | # print("worker_ratio is ", worker_ratio) 130 | if stage_one: 131 | worker_ratio *= 1 - config.algorithm_kwargs["dropout_rate"] 132 | compressed_part += worker_ratio 133 | remain_msg -= 1 134 | 
# assert remain_msg == 0 135 | print(remain_msg) 136 | assert remain_msg == config.worker_number 137 | compressed_part += remain_msg 138 | data_amounts.append( 139 | parameter_list.nelement() 140 | * parameter_list.element_size() 141 | * compressed_part 142 | / (1024 * 1024) 143 | ) 144 | assert len(data_amounts) == len(paths) 145 | std, mean = torch.std_mean(torch.tensor(data_amounts)) 146 | data_amount = {"mean": round(mean.item(), 2), "std": round(std.item(), 2)} 147 | case "fed_obd_sq": 148 | msg_num += ( 149 | config.algorithm_kwargs["second_phase_epoch"] * config.worker_number * 2 150 | ) 151 | 152 | data_amount = ( 153 | uploaded_parameter_num * (1 - config.algorithm_kwargs["dropout_rate"]) 154 | + downloaded_parameter_num 155 | + config.algorithm_kwargs["second_phase_epoch"] 156 | * config.worker_number 157 | * 2 158 | * parameter_list.nelement() 159 | + distributed_parameter_num * parameter_list.element_size() 160 | ) / (1024 * 1024) 161 | case "fed_dropout_avg": 162 | data_amounts = [] 163 | for path in paths: 164 | lines = None 165 | with open(path, encoding="utf8") as f: 166 | lines = f.readlines() 167 | uploaded_parameter_num = 0 168 | for line in lines: 169 | if "send_num" in line: 170 | res = re.findall("[0-9.]+$", line) 171 | assert len(res) == 1 172 | uploaded_parameter_num += float(res[0]) 173 | assert uploaded_parameter_num > 0 174 | assert downloaded_parameter_num > 0 175 | data_amounts.append( 176 | ( 177 | uploaded_parameter_num 178 | + (downloaded_parameter_num + distributed_parameter_num) 179 | ) 180 | * parameter_list.element_size() 181 | / (1024 * 1024) 182 | ) 183 | std, mean = torch.std_mean(torch.tensor(data_amounts)) 184 | data_amount = {"mean": round(mean.item(), 2), "std": round(std.item(), 2)} 185 | case "single_model_afd": 186 | data_amounts = [] 187 | for path in paths: 188 | transfer_number = 0 189 | lines = None 190 | with open(path, encoding="utf8") as f: 191 | lines = f.readlines() 192 | for line in lines: 193 | if "send_num" in line: 194 | res = re.findall("[0-9.]+$", line) 195 | assert len(res) == 1 196 | transfer_number += float(res[0]) 197 | data_amounts.append( 198 | (transfer_number + distributed_parameter_num) 199 | * parameter_list.element_size() 200 | / (1024 * 1024) 201 | ) 202 | std, mean = torch.std_mean(torch.tensor(data_amounts)) 203 | data_amount = {"mean": round(mean.item(), 2), "std": round(std.item(), 2)} 204 | 205 | case "fed_obd_first_stage": 206 | data_amounts = [] 207 | for path in paths: 208 | remain_msg = msg_num 209 | lines = None 210 | compressed_part = 0 211 | rnd_cnt = 0 212 | with open(path, encoding="utf8") as f: 213 | lines = f.readlines() 214 | stage_one = True 215 | for line in lines: 216 | if "broadcast NNABQ compression ratio" in line: 217 | res = re.findall("[0-9.]+$", line) 218 | assert len(res) == 1 219 | broadcast_ratio = float( 220 | res[0].replace("(", "").replace(",", "") 221 | ) 222 | # print("broadcast_ratio", broadcast_ratio) 223 | rnd_cnt += 1 224 | if rnd_cnt <= config.round: 225 | compressed_part += ( 226 | broadcast_ratio 227 | * config.algorithm_kwargs["random_client_number"] 228 | ) 229 | remain_msg -= config.algorithm_kwargs[ 230 | "random_client_number" 231 | ] 232 | else: 233 | break 234 | if "worker NNABQ compression ratio" in line: 235 | res = re.findall("[0-9.]+$", line) 236 | assert len(res) == 1 237 | worker_ratio = float(res[0].replace("(", "").replace(",", "")) 238 | # print("worker_ratio is ", worker_ratio) 239 | if stage_one: 240 | worker_ratio *= 1 - 
config.algorithm_kwargs["dropout_rate"] 240 | compressed_part += worker_ratio 241 | remain_msg -= 1 242 | # assert remain_msg == 0 243 | print(remain_msg) 244 | # assert remain_msg == config.worker_number 245 | compressed_part += config.worker_number 246 | data_amounts.append( 247 | parameter_list.nelement() 248 | * parameter_list.element_size() 249 | * compressed_part 250 | / (1024 * 1024) 251 | ) 252 | assert len(data_amounts) == len(paths) 253 | std, mean = torch.std_mean(torch.tensor(data_amounts)) 254 | data_amount = {"mean": round(mean.item(), 2), "std": round(std.item(), 2)} 255 | 256 | 257 | case "fed_paq": 258 | msg_num = ( 259 | config.round * config.algorithm_kwargs["random_client_number"] * 2 260 | + config.worker_number 261 | ) 262 | data_amount = ( 263 | uploaded_parameter_num * 1 264 | + downloaded_parameter_num * parameter_list.element_size() 265 | + distributed_parameter_num * parameter_list.element_size() 266 | ) / (1024 * 1024) 267 | match data_amount: 268 | case float(): 269 | data_amount = round(data_amount, 2) 270 | case dict(): 271 | data_amount = {k: round(v, 2) for k, v in data_amount.items()} 272 | 273 | return {"msg_num": msg_num, "data_amount": data_amount} 274 | 275 | 276 | if __name__ == "__main__": 277 | config = load_config()  # was commented out, leaving the module-level name `config` undefined; assumes load_config() returns the experiment config 278 | # config.distributed_algorithm = "fed_obd_first_stage" 279 | paths = os.getenv("logfiles").strip().split(" ") 280 | assert paths 281 | compute_acc(paths) 282 | res = compute_data_amount(paths) 283 | print("msg_num is", res["msg_num"]) 284 | print("data_amount is", res["data_amount"]) 285 | --------------------------------------------------------------------------------
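The greedy block-selection rule implemented in `method/fed_obd/obd_algorithm.py` (`OpportunisticBlockDropoutAlgorithmMixin.get_block_parameter`) can be summarized as: rank the model blocks by the mean L2 norm of their parameter change since the cached model, then keep blocks in descending order until `(1 - dropout_rate)` of the total parameter count is filled, skipping any block that would exceed the budget. The sketch below is a minimal, self-contained illustration of that rule on toy tensors; the block names and values are made up and are not part of the repository.

```
import torch


def select_blocks(cur_blocks: dict, prev_blocks: dict, dropout_rate: float) -> dict:
    # Budget: keep at most (1 - dropout_rate) of the total parameter count.
    total = sum(v.nelement() for v in cur_blocks.values())
    threshold = (1 - dropout_rate) * total
    # Rank blocks by mean change per parameter (L2 norm of the delta / block size).
    mean_delta = {
        name: torch.linalg.vector_norm(cur - prev_blocks[name]).item() / cur.nelement()
        for name, cur in cur_blocks.items()
    }
    chosen: dict = {}
    kept = 0
    for name in sorted(mean_delta, key=mean_delta.get, reverse=True):
        size = cur_blocks[name].nelement()
        if kept > threshold:
            break
        if kept + size > threshold:
            # Block does not fit the remaining budget; try smaller ones.
            continue
        chosen[name] = cur_blocks[name]
        kept += size
    return chosen


if __name__ == "__main__":
    prev = {"conv1": torch.zeros(100), "conv2": torch.zeros(50), "fc": torch.zeros(10)}
    cur = {
        "conv1": torch.full((100,), 0.01),  # barely changed
        "conv2": torch.ones(50),
        "fc": torch.full((10,), 2.0),  # changed the most per parameter
    }
    # With dropout_rate=0.3, the least-changed block (conv1) is dropped.
    print(sorted(select_blocks(cur, prev, dropout_rate=0.3)))
```

Running the sketch prints `['conv2', 'fc']`: the block with the smallest mean change (`conv1`) is the one left out, which mirrors how the mixin keeps only the most-changed blocks within the dropout budget.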