├── .gitignore ├── README.md ├── __init__.py ├── algorithm ├── __init__.py ├── block_algorithm.py └── shapley_value_algorithm.py ├── analysis ├── __init__.py ├── analyze_log.py └── analyze_round.py ├── conf ├── __init__.py ├── fed_avg │ ├── cifar10.yaml │ ├── cifar100.yaml │ ├── imagenet.yaml │ ├── imdb.yaml │ └── mnist.yaml ├── fed_avg_qat │ ├── cifar10.yaml │ └── mnist.yaml ├── fed_dropout_avg │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml ├── fed_obd │ ├── cifar10.yaml │ ├── cifar100.yaml │ ├── cifar100_sq.yaml │ └── imdb.yaml ├── fed_obd_sq │ └── cifar100.yaml ├── fed_paq │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml ├── global.yaml ├── gtg_sv │ ├── cifar10.yaml │ ├── cifar100.yaml │ ├── imdb.yaml │ └── mnist.yaml ├── multiround_sv │ ├── cifar10.yaml │ └── cifar100.yaml ├── qsgd │ └── mnist.yaml ├── sign_sgd │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml └── smafd │ ├── cifar10.yaml │ ├── cifar100.yaml │ └── imdb.yaml ├── exp_analyzer.py ├── exp_analyzer.sh ├── fed_aas.sh ├── fed_obd_train.sh ├── gtg_shapley_train.sh ├── method ├── __init__.py ├── fed_avg_qat │ ├── __init__.py │ └── worker.py ├── fed_dropout_avg │ ├── __init__.py │ ├── algorithm.py │ └── worker.py ├── fed_obd │ ├── __init__.py │ ├── obd_algorithm.py │ ├── phase.py │ ├── server.py │ └── worker.py ├── fed_paq │ └── __init__.py ├── qsgd │ └── __init__.py ├── shapley_value │ ├── GTG_shapley_value_algorithm.py │ ├── GTG_shapley_value_server.py │ ├── __init__.py │ ├── multiround_shapley_value_algorithm.py │ ├── multiround_shapley_value_server.py │ └── shapley_value_server.py └── sign_sgd │ ├── __init__.py │ ├── algorithm.py │ └── worker.py ├── other_method_test.sh ├── pyproject.toml ├── simulator.py └── test.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | outputs 131 | session 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # distributed_learning_simulator 2 | 3 | This is a simulator for distributed Machine Learning and Federated Learning on a single host. It implements common algorithms as well as our own research works. 4 | 5 | ## Installation 6 | 7 | This is a Python project. The third-party dependencies are listed in **pyproject.toml**. 8 | 9 | Use pip to set it up: 10 | 11 | ``` 12 | python3 -m pip install . --upgrade --force-reinstall --user 13 | ``` 14 | 15 | ## Our Works 16 | 17 | ### GTG-Shapley 18 | 19 | To run the experiments of [GTG-Shapley: Efficient and Accurate Participant Contribution Evaluation in Federated Learning](https://dl.acm.org/doi/pdf/10.1145/3501811), use this command: 20 | 21 | ``` 22 | bash gtg_shapley_train.sh 23 | ``` 24 | 25 | #### Reference 26 | 27 | If you find our work useful, feel free to cite it: 28 | 29 | ``` 30 | @article{10.1145/3501811, 31 | author = {Liu, Zelei and Chen, Yuanyuan and Yu, Han and Liu, Yang and Cui, Lizhen}, 32 | title = {GTG-Shapley: Efficient and Accurate Participant Contribution Evaluation in Federated Learning}, 33 | year = {2022}, 34 | issue_date = {August 2022}, 35 | publisher = {Association for Computing Machinery}, 36 | address = {New York, NY, USA}, 37 | volume = {13}, 38 | number = {4}, 39 | issn = {2157-6904}, 40 | url = {https://doi.org/10.1145/3501811}, 41 | doi = {10.1145/3501811}, 42 | journal = {ACM Trans. Intell. Syst. 
Technol.}, 43 | month = {may}, 44 | articleno = {60}, 45 | numpages = {21}, 46 | keywords = {Federated learning, contribution assessment, Shapley value} 47 | } 48 | ``` 49 | 50 | ### FedOBD 51 | 52 | To run the experiments of [FedOBD: Opportunistic Block Dropout for Efficiently Training Large-scale Neural Networks through Federated Learning](https://arxiv.org/abs/2208.05174), use this command: 53 | 54 | ``` 55 | bash fed_obd_train.sh 56 | ``` 57 | 58 | #### Reference 59 | 60 | If you find our work useful, feel free to cite it: 61 | 62 | ``` 63 | @inproceedings{ijcai2023p394, 64 | title = {FedOBD: Opportunistic Block Dropout for Efficiently Training Large-scale Neural Networks through Federated Learning}, 65 | author = {Chen, Yuanyuan and Chen, Zichen and Wu, Pengcheng and Yu, Han}, 66 | booktitle = {Proceedings of the Thirty-Second International Joint Conference on 67 | Artificial Intelligence, {IJCAI-23}}, 68 | publisher = {International Joint Conferences on Artificial Intelligence Organization}, 69 | editor = {Edith Elkind}, 70 | pages = {3541--3549}, 71 | year = {2023}, 72 | month = {8}, 73 | note = {Main Track}, 74 | doi = {10.24963/ijcai.2023/394}, 75 | url = {https://doi.org/10.24963/ijcai.2023/394}, 76 | } 77 | ``` 78 | 79 | ### Historical Embedding-Guided Efficient Large-Scale Federated Graph Learning 80 | 81 | The implementation has been moved to another [GitHub repository](https://github.com/cyyever/distributed_graph_learning_simulator). 82 | 83 | 84 | #### Reference 85 | 86 | If you find this work useful, feel free to cite it: 87 | 88 | ``` 89 | @article{li2024historical, 90 | title={Historical Embedding-Guided Efficient Large-Scale Federated Graph Learning}, 91 | author={Li, Anran and Chen, Yuanyuan and Zhang, Jian and Cheng, Mingfei and Huang, Yihao and Wu, Yueming and Luu, Anh Tuan and Yu, Han}, 92 | journal={Proceedings of the ACM on Management of Data}, 93 | volume={2}, 94 | number={3}, 95 | pages={1--24}, 96 | year={2024}, 97 | publisher={ACM New York, NY, USA} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__))) 5 | -------------------------------------------------------------------------------- /algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | from .block_algorithm import BlockAlgorithmMixin 2 | from .shapley_value_algorithm import ShapleyValueAlgorithm 3 | 4 | __all__ = ["BlockAlgorithmMixin", "ShapleyValueAlgorithm"] 5 | -------------------------------------------------------------------------------- /algorithm/block_algorithm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | from cyy_naive_lib.log import log_info 4 | from cyy_torch_toolbox import BlockType, ModelUtil 5 | from cyy_torch_toolbox.tensor import cat_tensors_to_vector 6 | from distributed_learning_simulation.worker.protocol import AggregationWorkerProtocol 7 | 8 | 9 | class BlockAlgorithmMixin(AggregationWorkerProtocol): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | self.__blocks: list[BlockType] | None = None 13 | self._block_types = { 14 | ("AlbertTransformer",), 15 | ("AlbertEmbeddings",), 16 | ("Bottleneck",), 17 | ("TransformerEncoderLayer",), 18 | (torch.nn.BatchNorm2d, torch.nn.ReLU, torch.nn.Conv2d), 19 | 
(torch.nn.BatchNorm2d, torch.nn.Conv2d), 20 | (torch.nn.Conv2d, torch.nn.BatchNorm2d), 21 | } 22 | 23 | @property 24 | def blocks(self) -> list[BlockType]: 25 | if self.__blocks is None: 26 | self._find_blocks() 27 | assert self.__blocks is not None 28 | return self.__blocks 29 | 30 | def _get_model_util(self) -> ModelUtil: 31 | return self.trainer.model_util 32 | 33 | def _find_blocks(self) -> None: 34 | model_util = self._get_model_util() 35 | blocks = model_util.get_module_blocks(block_types=self._block_types) 36 | self.__blocks = [] 37 | modules = list(model_util.get_modules()) 38 | while modules: 39 | submodule_name, submodule = modules[0] 40 | del modules[0] 41 | if not submodule_name: 42 | continue 43 | if len(list(submodule.parameters())) == 0: 44 | continue 45 | part_of_block = False 46 | in_block = False 47 | tmp_blocks = [] 48 | 49 | if blocks: 50 | tmp_blocks.append(blocks[0]) 51 | if self.__blocks: 52 | tmp_blocks.append(self.__blocks[-1]) 53 | for block in tmp_blocks: 54 | for block_submodule_name, _ in block: 55 | if block_submodule_name == submodule_name: 56 | part_of_block = True 57 | self.__blocks.append(block) 58 | for _ in range(len(block) - 1): 59 | del modules[0] 60 | del blocks[0] 61 | break 62 | if submodule_name.startswith( 63 | f"{block_submodule_name}." 64 | ) or block_submodule_name.startswith(f"{submodule_name}."): 65 | in_block = True 66 | break 67 | if part_of_block or in_block: 68 | break 69 | if part_of_block or in_block: 70 | continue 71 | self.__blocks.append([(submodule_name, submodule)]) 72 | if self.hold_log_lock: 73 | log_info("identify a submodule:%s", submodule_name) 74 | 75 | if self.hold_log_lock: 76 | log_info("identify these blocks in model:") 77 | for block in self.__blocks: 78 | log_info( 79 | "%s", 80 | [f"{name}" for name, _ in block], 81 | ) 82 | 83 | # check the parameter numbers are the same 84 | tmp_parameter_list = [] 85 | tmp_parameter_name = set() 86 | for block in self.__blocks: 87 | for submodule_name, submodule in block: 88 | for p_name, p in submodule.named_parameters(): 89 | tmp_parameter_list.append(p) 90 | if submodule_name: 91 | tmp_parameter_name.add(submodule_name + "." 
+ p_name) 92 | else: 93 | tmp_parameter_name.add(p_name) 94 | parameter_dict = model_util.get_parameters() 95 | if tmp_parameter_name != set(parameter_dict.keys()): 96 | for a in tmp_parameter_name: 97 | if a not in parameter_dict: 98 | raise RuntimeError(a + " not in model") 99 | for a in parameter_dict: 100 | if a not in tmp_parameter_name: 101 | raise RuntimeError(a + " not in block") 102 | parameter_list = model_util.get_parameter_list() 103 | assert cat_tensors_to_vector(tmp_parameter_list).shape == parameter_list.shape 104 | -------------------------------------------------------------------------------- /algorithm/shapley_value_algorithm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any 4 | 5 | from cyy_naive_lib.concurrency import batch_process 6 | from cyy_naive_lib.log import log_warning 7 | from cyy_torch_algorithm.shapely_value.shapley_value import RoundBasedShapleyValue 8 | from cyy_torch_toolbox import TorchProcessTaskQueue 9 | from distributed_learning_simulation import ( 10 | AggregationServer, 11 | FedAVGAlgorithm, 12 | ParameterMessage, 13 | ) 14 | 15 | 16 | class ShapleyValueAlgorithm(FedAVGAlgorithm): 17 | def __init__( 18 | self, sv_algorithm_cls: type, server: AggregationServer, **kwargs: Any 19 | ) -> None: 20 | super().__init__(**kwargs) 21 | self._server: AggregationServer = server 22 | self.accumulate = False 23 | self.metric_type: str = "accuracy" 24 | self.__sv_algorithm: None | RoundBasedShapleyValue = None 25 | self.sv_algorithm_cls = sv_algorithm_cls 26 | 27 | @property 28 | def server(self) -> AggregationServer: 29 | return self._server 30 | 31 | @property 32 | def sv_algorithm(self) -> RoundBasedShapleyValue: 33 | if self.__sv_algorithm is None: 34 | assert self._all_worker_data 35 | assert self.server.round_index == 1 36 | self.__sv_algorithm = self.sv_algorithm_cls( 37 | players=sorted(self._all_worker_data.keys()), 38 | initial_metric=self.server.performance_stat[ 39 | self.server.round_index - 1 40 | ][f"test_{self.metric_type}"], 41 | algorithm_kwargs=self.config.algorithm_kwargs, 42 | ) 43 | assert isinstance(self.__sv_algorithm, RoundBasedShapleyValue) 44 | if ( 45 | self.config.algorithm_kwargs.get("round_trunc_threshold", None) 46 | is not None 47 | ): 48 | self.__sv_algorithm.set_round_truncation_threshold( 49 | self.config.algorithm_kwargs["round_trunc_threshold"] 50 | ) 51 | self.sv_algorithm.set_batch_metric_function(self._get_batch_metric) 52 | # For client selection in each round 53 | self.__sv_algorithm.set_players( 54 | sorted({k for k, v in self._all_worker_data.items() if v is not None}) 55 | ) 56 | return self.__sv_algorithm 57 | 58 | @property 59 | def choose_best_subset(self) -> bool: 60 | return self.config.algorithm_kwargs.get("choose_best_subset", False) 61 | 62 | def aggregate_worker_data(self) -> ParameterMessage: 63 | self.sv_algorithm.compute(round_index=self.server.round_index) 64 | if self.choose_best_subset: 65 | assert hasattr(self.sv_algorithm, "shapley_values_S") 66 | best_players = self.sv_algorithm.get_best_players( 67 | round_index=self.server.round_index 68 | ) 69 | assert best_players is not None 70 | log_warning("use players %s", best_players) 71 | self._all_worker_data = {k: self._all_worker_data[k] for k in best_players} 72 | return super().aggregate_worker_data() 73 | 74 | def _batch_metric_worker(self, task, **kwargs) -> dict: 75 | return {task: self._get_subset_metric(subset=task)} 76 | 77 | def _get_batch_metric(self, 
subsets) -> dict: 78 | if len(subsets) == 1: 79 | return {list(subsets)[0]: self._get_subset_metric(list(subsets)[0])} 80 | queue = TorchProcessTaskQueue( 81 | worker_num=self.config.algorithm_kwargs.get("sv_worker_number", None) 82 | ) 83 | queue.disable_logger() 84 | queue.start(worker_fun=self._batch_metric_worker) 85 | res = batch_process(queue, subsets) 86 | queue.stop() 87 | return res 88 | 89 | def _get_subset_metric(self, subset) -> float: 90 | assert subset 91 | aggregated_parameter = super()._aggregate_parameter( 92 | chosen_worker_ids=set(self.sv_algorithm.get_players(subset)) 93 | ) 94 | assert aggregated_parameter 95 | return self.server.get_metric( 96 | aggregated_parameter, log_performance_metric=False 97 | )[self.metric_type] 98 | 99 | def exit(self) -> None: 100 | assert self.sv_algorithm is not None 101 | self.sv_algorithm.exit() 102 | with open( 103 | os.path.join(self.config.save_dir, "shapley_values.json"), 104 | "w", 105 | encoding="utf8", 106 | ) as f: 107 | json.dump(self.sv_algorithm.get_result(), f) 108 | -------------------------------------------------------------------------------- /analysis/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | currentdir = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | sys.path.insert(0, os.path.join(currentdir, "..")) 7 | -------------------------------------------------------------------------------- /analysis/analyze_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | 5 | currentdir = os.path.dirname(os.path.realpath(__file__)) 6 | 7 | sys.path.insert(0, os.path.join(currentdir, "..")) 8 | 9 | import torch 10 | from config import global_config as config 11 | from config import load_config 12 | 13 | 14 | def compute_acc(paths: list) -> None: 15 | final_test_acc = [] 16 | worker_acc: dict = {} 17 | for path in paths: 18 | assert os.path.isfile(path) 19 | lines = None 20 | with open(path, encoding="utf8") as f: 21 | lines = f.readlines() 22 | for line in reversed(lines): 23 | if config.distributed_algorithm == "sign_SGD": 24 | if "test loss" in line: 25 | res = re.findall("[0-9.]+%", line) 26 | assert len(res) == 1 27 | acc = float(res[0].replace("%", "")) 28 | final_test_acc.append(acc) 29 | break 30 | elif config.distributed_algorithm in ( 31 | "fed_obd_first_stage", 32 | "fed_obd_layer", 33 | ): 34 | if ( 35 | "test in" in line 36 | and "accuracy" in line 37 | and f"round: {config.round}" in line 38 | ): 39 | print("line is", line) 40 | res = re.findall("[0-9.]+%", line) 41 | assert len(res) == 1 42 | acc = float(res[0].replace("%", "")) 43 | final_test_acc.append(acc) 44 | break 45 | else: 46 | if "test in" in line and "accuracy" in line: 47 | res = re.findall("[0-9.]+%", line) 48 | assert len(res) == 1 49 | acc = float(res[0].replace("%", "")) 50 | print(line) 51 | final_test_acc.append(acc) 52 | break 53 | for worker_id in range(config.worker_number): 54 | for line in reversed(lines): 55 | res = re.findall(f"worker {worker_id}.*train.*accuracy", line) 56 | if res: 57 | res = re.findall("[0-9.]+%", line) 58 | assert len(res) == 1 59 | acc = float(res[0].replace("%", "")) 60 | if worker_id not in worker_acc: 61 | worker_acc[worker_id] = [] 62 | worker_acc[worker_id].append(acc) 63 | break 64 | assert len(final_test_acc) == len(paths) 65 | std, mean = torch.std_mean(torch.tensor(final_test_acc)) 66 | print("test acc", round(mean.item(), 2), round(std.item(), 
2)) 67 | 68 | 69 | def compute_data_amount(paths: list) -> dict: 70 | trainer = config.create_trainer() 71 | parameter_list = trainer.model_util.get_parameter_list() 72 | distributed_algorithm = config.distributed_algorithm.lower() 73 | print("worker_number is", config.worker_number) 74 | print("model is", config.model_config.model_name) 75 | uploaded_msg_num = config.round * config.algorithm_kwargs.get( 76 | "random_client_number", config.worker_number 77 | ) 78 | uploaded_parameter_num = uploaded_msg_num * parameter_list.nelement() 79 | downloaded_msg_num = uploaded_msg_num 80 | downloaded_parameter_num = uploaded_parameter_num 81 | distributed_msg_num = config.worker_number 82 | distributed_parameter_num = distributed_msg_num * parameter_list.nelement() 83 | msg_num = uploaded_msg_num + downloaded_msg_num + distributed_msg_num 84 | data_amount: float | tuple = 0 85 | match distributed_algorithm: 86 | case "fed_avg": 87 | data_amount = ( 88 | parameter_list.nelement() 89 | * parameter_list.element_size() 90 | * msg_num 91 | / (1024 * 1024) 92 | ) 93 | case "fed_obd": 94 | msg_num += ( 95 | config.algorithm_kwargs["second_phase_epoch"] * config.worker_number * 2 96 | ) 97 | 98 | data_amounts = [] 99 | for path in paths: 100 | remain_msg = msg_num 101 | lines = None 102 | compressed_part = 0 103 | rnd_cnt = 0 104 | with open(path, encoding="utf8") as f: 105 | lines = f.readlines() 106 | stage_one = True 107 | for line in lines: 108 | if "broadcast NNABQ compression ratio" in line: 109 | res = re.findall("[0-9.]+$", line) 110 | assert len(res) == 1 111 | broadcast_ratio = float( 112 | res[0].replace("(", "").replace(",", "") 113 | ) 114 | # print("broadcast_ratio", broadcast_ratio) 115 | rnd_cnt += 1 116 | if rnd_cnt <= config.round: 117 | compressed_part += ( 118 | broadcast_ratio 119 | * config.algorithm_kwargs["random_client_number"] 120 | ) 121 | remain_msg -= config.algorithm_kwargs[ 122 | "random_client_number" 123 | ] 124 | else: 125 | stage_one = False 126 | if remain_msg > config.worker_number: 127 | compressed_part += ( 128 | broadcast_ratio * config.worker_number 129 | ) 130 | remain_msg -= config.worker_number 131 | if "worker NNABQ compression ratio" in line: 132 | res = re.findall("[0-9.]+$", line) 133 | assert len(res) == 1 134 | worker_ratio = float(res[0].replace("(", "").replace(",", "")) 135 | # print("worker_ratio is ", worker_ratio) 136 | if stage_one: 137 | worker_ratio *= 1 - config.algorithm_kwargs["dropout_rate"] 138 | compressed_part += worker_ratio 139 | remain_msg -= 1 140 | # assert remain_msg == 0 141 | print(remain_msg) 142 | assert remain_msg == config.worker_number 143 | compressed_part += remain_msg 144 | data_amounts.append( 145 | parameter_list.nelement() 146 | * parameter_list.element_size() 147 | * compressed_part 148 | / (1024 * 1024) 149 | ) 150 | assert len(data_amounts) == len(paths) 151 | std, mean = torch.std_mean(torch.tensor(data_amounts)) 152 | data_amount = {"mean": round(mean.item(), 2), "std": round(std.item(), 2)} 153 | case "fed_obd_sq": 154 | msg_num += ( 155 | config.algorithm_kwargs["second_phase_epoch"] * config.worker_number * 2 156 | ) 157 | 158 | data_amount = ( 159 | uploaded_parameter_num * (1 - config.algorithm_kwargs["dropout_rate"]) 160 | + downloaded_parameter_num 161 | + config.algorithm_kwargs["second_phase_epoch"] 162 | * config.worker_number 163 | * 2 164 | * parameter_list.nelement() 165 | + distributed_parameter_num * parameter_list.element_size() 166 | ) / (1024 * 1024) 167 | case "fed_dropout_avg": 168 | data_amounts = 
[] 169 | for path in paths: 170 | lines = None 171 | with open(path, encoding="utf8") as f: 172 | lines = f.readlines() 173 | uploaded_parameter_num = 0 174 | for line in lines: 175 | if "send_num" in line: 176 | res = re.findall("[0-9.]+$", line) 177 | assert len(res) == 1 178 | uploaded_parameter_num += float(res[0]) 179 | assert uploaded_parameter_num > 0 180 | assert downloaded_parameter_num > 0 181 | data_amounts.append( 182 | ( 183 | uploaded_parameter_num 184 | + (downloaded_parameter_num + distributed_parameter_num) 185 | ) 186 | * parameter_list.element_size() 187 | / (1024 * 1024) 188 | ) 189 | std, mean = torch.std_mean(torch.tensor(data_amounts)) 190 | data_amount = {"mean": round(mean.item(), 2), "std": round(std.item(), 2)} 191 | case "single_model_afd": 192 | data_amounts = [] 193 | for path in paths: 194 | transfer_number = 0 195 | lines = None 196 | with open(path, encoding="utf8") as f: 197 | lines = f.readlines() 198 | for line in lines: 199 | if "send_num" in line: 200 | res = re.findall("[0-9.]+$", line) 201 | assert len(res) == 1 202 | transfer_number += float(res[0]) 203 | data_amounts.append( 204 | (transfer_number + distributed_parameter_num) 205 | * parameter_list.element_size() 206 | / (1024 * 1024) 207 | ) 208 | std, mean = torch.std_mean(torch.tensor(data_amounts)) 209 | data_amount = {"mean": round(mean.item(), 2), "std": round(std.item(), 2)} 210 | 211 | case "fed_obd_first_stage": 212 | data_amounts = [] 213 | for path in paths: 214 | remain_msg = msg_num 215 | lines = None 216 | compressed_part = 0 217 | rnd_cnt = 0 218 | with open(path, encoding="utf8") as f: 219 | lines = f.readlines() 220 | stage_one = True 221 | for line in lines: 222 | if "broadcast NNABQ compression ratio" in line: 223 | res = re.findall("[0-9.]+$", line) 224 | assert len(res) == 1 225 | broadcast_ratio = float( 226 | res[0].replace("(", "").replace(",", "") 227 | ) 228 | # print("broadcast_ratio", broadcast_ratio) 229 | rnd_cnt += 1 230 | if rnd_cnt <= config.round: 231 | compressed_part += ( 232 | broadcast_ratio 233 | * config.algorithm_kwargs["random_client_number"] 234 | ) 235 | remain_msg -= config.algorithm_kwargs[ 236 | "random_client_number" 237 | ] 238 | else: 239 | break 240 | if "worker NNABQ compression ratio" in line: 241 | res = re.findall("[0-9.]+$", line) 242 | assert len(res) == 1 243 | worker_ratio = float(res[0].replace("(", "").replace(",", "")) 244 | # print("worker_ratio is ", worker_ratio) 245 | if stage_one: 246 | worker_ratio *= 1 - config.algorithm_kwargs["dropout_rate"] 247 | compressed_part += worker_ratio 248 | remain_msg -= 1 249 | # assert remain_msg == 0 250 | print(remain_msg) 251 | # assert remain_msg == config.worker_number 252 | compressed_part += config.worker_number 253 | data_amounts.append( 254 | parameter_list.nelement() 255 | * parameter_list.element_size() 256 | * compressed_part 257 | / (1024 * 1024) 258 | ) 259 | assert len(data_amounts) == len(paths) 260 | std, mean = torch.std_mean(torch.tensor(data_amounts)) 261 | data_amount = {"mean": round(mean.item(), 2), "std": round(std.item(), 2)} 262 | 263 | case "fed_paq": 264 | msg_num = ( 265 | config.round * config.algorithm_kwargs["random_client_number"] * 2 266 | + config.worker_number 267 | ) 268 | data_amount = ( 269 | uploaded_parameter_num * 1 270 | + downloaded_parameter_num * parameter_list.element_size() 271 | + distributed_parameter_num * parameter_list.element_size() 272 | ) / (1024 * 1024) 273 | match data_amount: 274 | case float(): 275 | data_amount = round(data_amount, 2) 276 
| case dict(): 277 | data_amount = {k: round(v, 2) for k, v in data_amount.items()} 278 | 279 | return {"msg_num": msg_num, "data_amount": data_amount} 280 | 281 | 282 | if __name__ == "__main__": 283 | load_config() 284 | # config.distributed_algorithm = "fed_obd_first_stage" 285 | paths = os.getenv("logfiles").strip().split(" ") 286 | assert paths 287 | compute_acc(paths) 288 | res = compute_data_amount(paths) 289 | print("msg_num is", res["msg_num"]) 290 | print("data_amount is", res["data_amount"]) 291 | -------------------------------------------------------------------------------- /analysis/analyze_round.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import seaborn as sns 7 | from cyy_naive_lib.fs.path import find_directories 8 | from simulation_lib.config import load_config_from_file 9 | 10 | 11 | def extract_data( 12 | session_path: str, algorithm: str, aggregated_performance_metric: dict 13 | ) -> dict: 14 | for server_session in find_directories(session_path, "visualizer"): 15 | if "server" not in server_session: 16 | continue 17 | for round_no in range(1, config.round + 1): 18 | round_str = f"round_{round_no}" 19 | 20 | metric_file = os.path.join( 21 | server_session, round_str, "test", "performance_metric.json" 22 | ) 23 | if not os.path.isfile(metric_file): 24 | continue 25 | with open( 26 | metric_file, 27 | encoding="utf8", 28 | ) as f: 29 | performance_metric = json.load(f) 30 | for k, v in performance_metric.items(): 31 | if k not in aggregated_performance_metric: 32 | aggregated_performance_metric[k] = pd.DataFrame( 33 | columns=["round", k] 34 | ) 35 | aggregated_performance_metric[k] = pd.concat( 36 | [ 37 | aggregated_performance_metric[k], 38 | pd.DataFrame( 39 | [[round_no, list(v.values())[0], algorithm]], 40 | columns=["round", k, "algorithm"], 41 | ), 42 | ] 43 | ) 44 | return aggregated_performance_metric 45 | 46 | 47 | if __name__ == "__main__": 48 | aggregated_performance_metric: dict = {} 49 | config_files = os.getenv("config_files") 50 | assert config_files is not None 51 | for config_file in config_files.split(): 52 | config = load_config_from_file(config_file=config_file) 53 | session_path = ( 54 | f"session/{config.distributed_algorithm}/{config.dc_config.dataset_name}/" 55 | ) 56 | extract_data( 57 | session_path, config.distributed_algorithm, aggregated_performance_metric 58 | ) 59 | for metric, metric_df in aggregated_performance_metric.items(): 60 | print(f"deal with {metric}") 61 | ax = sns.lineplot( 62 | data=metric_df, x="round", y=metric, hue="algorithm", errorbar="sd" 63 | ) 64 | plt.tight_layout() 65 | plt.savefig(f"{metric}.png") 66 | plt.clf() 67 | -------------------------------------------------------------------------------- /conf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cyyever/distributed_learning_simulator/8193d3f817450ebd89a0896784d37191743860c8/conf/__init__.py -------------------------------------------------------------------------------- /conf/fed_avg/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 
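# Usage note (illustrative, inferred from this repo's own launch scripts such as
# fed_obd_train.sh and gtg_shapley_train.sh): a config file like this one is passed to the
# simulator when starting a run, e.g.
#   python3 ./simulator.py --config-name fed_avg/cifar10.yaml
# The keys above (worker_number, round, epoch, learning_rate, ...) define the simulated
# FedAvg experiment; conf/global.yaml appears to hold options shared by all experiments.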
13 | -------------------------------------------------------------------------------- /conf/fed_avg/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg/imagenet.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: ImageNet 3 | model_name: Resnet50 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 128 11 | learning_rate: 0.01 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.01 12 | dataset_kwargs: 13 | input_max_len: 300 14 | dataset_type: text 15 | tokenizer: 16 | type: spacy 17 | model_kwargs: 18 | word_vector_name: glove.6B.100d 19 | num_encoder_layer: 2 20 | d_model: 100 21 | nhead: 5 22 | frozen_modules: 23 | names: [embedding] 24 | -------------------------------------------------------------------------------- /conf/fed_avg/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: MNIST 3 | model_name: LeNet5 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 20 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 2 11 | learning_rate: 0.01 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg_qat/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_avg_qat 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/fed_avg_qat/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: MNIST 3 | model_name: LeNet5 4 | distributed_algorithm: fed_avg_qat 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 20 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 2 11 | learning_rate: 0.01 12 | ... 
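# Descriptive note: the fed_avg_qat algorithm name selected above is registered in
# method/fed_avg_qat/__init__.py, which pairs the quantization-aware-training client
# (QATWorker) with the standard AggregationServer and FedAVGAlgorithm.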
13 | -------------------------------------------------------------------------------- /conf/fed_dropout_avg/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_dropout_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | dropout_rate: 0.3 14 | random_client_number: 5 15 | -------------------------------------------------------------------------------- /conf/fed_dropout_avg/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_dropout_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | dropout_rate: 0.3 14 | random_client_number: 5 15 | ... 16 | -------------------------------------------------------------------------------- /conf/fed_dropout_avg/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_dropout_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.01 12 | algorithm_kwargs: 13 | dropout_rate: 0.3 14 | random_client_number: 5 15 | dataset_kwargs: 16 | input_max_len: 300 17 | dataset_type: text 18 | tokenizer: 19 | type: spacy 20 | model_kwargs: 21 | word_vector_name: glove.6B.100d 22 | num_encoder_layer: 2 23 | d_model: 100 24 | nhead: 5 25 | frozen_modules: 26 | names: [embedding] 27 | -------------------------------------------------------------------------------- /conf/fed_obd/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_obd 5 | optimizer_name: SGD 6 | worker_number: 100 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | learning_rate: 0.1 11 | batch_size: 64 12 | endpoint_kwargs: 13 | server: 14 | weight: 0.001 15 | worker: 16 | weight: 0.001 17 | algorithm_kwargs: 18 | second_phase_epoch: 10 19 | dropout_rate: 0.3 20 | random_client_number: 50 21 | ... 22 | -------------------------------------------------------------------------------- /conf/fed_obd/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_obd 5 | optimizer_name: SGD 6 | worker_number: 50 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | endpoint_kwargs: 13 | server: 14 | weight: 0.001 15 | worker: 16 | weight: 0.001 17 | algorithm_kwargs: 18 | second_phase_epoch: 10 19 | dropout_rate: 0.3 20 | random_client_number: 25 21 | ... 
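# Descriptive note (based on the FedOBD sources under method/fed_obd): dropout_rate drives
# the opportunistic block dropout in OpportunisticBlockDropoutAlgorithmMixin,
# second_phase_epoch is read by FedOBDWorker when the server switches to phase two, and
# endpoint_kwargs.*.weight is presumably forwarded to the NNADQ quantization endpoints
# registered for fed_obd; random_client_number appears to set how many clients participate
# per round. This is one of the configs run by fed_obd_train.sh.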
22 | -------------------------------------------------------------------------------- /conf/fed_obd/cifar100_sq.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_obd_sq 5 | optimizer_name: SGD 6 | worker_number: 50 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | second_phase_epoch: 10 14 | dropout_rate: 0.3 15 | random_client_number: 25 16 | ... 17 | -------------------------------------------------------------------------------- /conf/fed_obd/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_obd 5 | optimizer_name: SGD 6 | worker_number: 100 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.01 12 | endpoint_kwargs: 13 | server: 14 | weight: 0.0001 15 | worker: 16 | weight: 0.0001 17 | algorithm_kwargs: 18 | second_phase_epoch: 10 19 | dropout_rate: 0.3 20 | random_client_number: 50 21 | dataset_kwargs: 22 | input_max_len: 300 23 | dataset_type: text 24 | tokenizer: 25 | type: spacy 26 | model_kwargs: 27 | word_vector_name: glove.6B.100d 28 | num_encoder_layer: 2 29 | d_model: 100 30 | nhead: 5 31 | frozen_modules: 32 | names: [embedding] 33 | -------------------------------------------------------------------------------- /conf/fed_obd_sq/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_obd_sq 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | second_phase_epoch: 10 14 | dropout_rate: 0.9 15 | random_client_number: 5 16 | ... 17 | -------------------------------------------------------------------------------- /conf/fed_paq/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: fed_paq 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | batch_size: 64 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | ... 15 | -------------------------------------------------------------------------------- /conf/fed_paq/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: fed_paq 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | ... 
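# Descriptive note: fed_paq corresponds to FedPAQ (periodic averaging and quantization,
# arXiv:1909.13014); see method/fed_paq/__init__.py. random_client_number presumably sets
# how many of the worker_number clients are sampled in each round.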
15 | -------------------------------------------------------------------------------- /conf/fed_paq/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_paq 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.01 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | dataset_kwargs: 15 | input_max_len: 300 16 | dataset_type: text 17 | tokenizer: 18 | type: spacy 19 | model_kwargs: 20 | word_vector_name: glove.6B.100d 21 | num_encoder_layer: 2 22 | d_model: 100 23 | nhead: 5 24 | frozen_modules: 25 | names: [embedding] 26 | -------------------------------------------------------------------------------- /conf/global.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | cache_transforms: cpu 3 | log_level: INFO 4 | save_performance_metric: false 5 | use_amp: false 6 | use_slow_performance_metrics: false 7 | -------------------------------------------------------------------------------- /conf/gtg_sv/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: GTG_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/gtg_sv/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: GTG_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/gtg_sv/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: fed_avg 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.01 12 | dataset_kwargs: 13 | input_max_len: 300 14 | dataset_type: text 15 | tokenizer: 16 | type: spacy 17 | model_kwargs: 18 | word_vector_name: glove.6B.100d 19 | num_encoder_layer: 2 20 | d_model: 100 21 | nhead: 5 22 | frozen_modules: 23 | names: [embedding] 24 | -------------------------------------------------------------------------------- /conf/gtg_sv/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: MNIST 3 | model_name: LeNet5 4 | distributed_algorithm: GTG_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 20 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 2 11 | learning_rate: 0.01 12 | ... 
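# This is the config exercised by gtg_shapley_train.sh (the GTG-Shapley experiment
# described in the README):
#   python3 ./simulator.py --config-name gtg_sv/mnist.yaml
# The GTG_shapley_value algorithm name presumably resolves to the servers/algorithms under
# method/shapley_value (GTG_shapley_value_algorithm.py, GTG_shapley_value_server.py).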
13 | -------------------------------------------------------------------------------- /conf/multiround_sv/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: multiround_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/multiround_sv/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: GTG_shapley_value 5 | optimizer_name: SGD 6 | worker_number: 10 7 | batch_size: 64 8 | round: 100 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | ... 13 | -------------------------------------------------------------------------------- /conf/qsgd/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: MNIST 3 | model_name: lenet5 4 | distributed_algorithm: QSGD 5 | worker_number: 2 6 | round: 1 7 | learning_rate_scheduler_name: CosineAnnealingLR 8 | epoch: 10 9 | batch_size: 64 10 | learning_rate: 0.001 11 | algorithm_kwargs: 12 | distribute_init_parameters: false 13 | -------------------------------------------------------------------------------- /conf/sign_sgd/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 | distributed_algorithm: sign_SGD 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 1 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 100 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | distribute_init_parameters: false 14 | -------------------------------------------------------------------------------- /conf/sign_sgd/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: sign_SGD 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 1 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 100 10 | learning_rate: 0.1 11 | algorithm_kwargs: 12 | distribute_init_parameters: false 13 | -------------------------------------------------------------------------------- /conf/sign_sgd/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: sign_SGD 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 1 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 100 10 | batch_size: 64 11 | learning_rate: 0.01 12 | dataset_kwargs: 13 | input_max_len: 300 14 | dataset_type: text 15 | tokenizer: 16 | type: spacy 17 | model_kwargs: 18 | word_vector_name: glove.6B.100d 19 | num_encoder_layer: 2 20 | d_model: 100 21 | nhead: 5 22 | frozen_modules: 23 | names: [embedding] 24 | algorithm_kwargs: 25 | distribute_init_parameters: false 26 | -------------------------------------------------------------------------------- /conf/smafd/cifar10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR10 3 | model_name: densenet40 4 
| distributed_algorithm: single_model_afd 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | batch_size: 64 9 | learning_rate_scheduler_name: CosineAnnealingLR 10 | epoch: 5 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | dropout_rate: 0.3 15 | ... 16 | -------------------------------------------------------------------------------- /conf/smafd/cifar100.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: CIFAR100 3 | model_name: densenet40 4 | distributed_algorithm: single_model_afd 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.1 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | dropout_rate: 0.3 15 | ... 16 | -------------------------------------------------------------------------------- /conf/smafd/imdb.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_name: imdb 3 | model_name: TransformerClassificationModel 4 | distributed_algorithm: single_model_afd 5 | optimizer_name: SGD 6 | worker_number: 10 7 | round: 100 8 | learning_rate_scheduler_name: CosineAnnealingLR 9 | epoch: 5 10 | batch_size: 64 11 | learning_rate: 0.01 12 | algorithm_kwargs: 13 | random_client_number: 5 14 | dropout_rate: 0.3 15 | dataset_kwargs: 16 | input_max_len: 300 17 | dataset_type: text 18 | tokenizer: 19 | type: spacy 20 | model_kwargs: 21 | word_vector_name: glove.6B.100d 22 | num_encoder_layer: 2 23 | d_model: 100 24 | nhead: 5 25 | frozen_modules: 26 | names: [embedding] 27 | -------------------------------------------------------------------------------- /exp_analyzer.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation.analysis.document import dump_analysis 2 | 3 | if __name__ == "__main__": 4 | dump_analysis() 5 | -------------------------------------------------------------------------------- /exp_analyzer.sh: -------------------------------------------------------------------------------- 1 | for session_path in $(find session -name server); do 2 | base_dir=$(dirname $session_path) 3 | if ! 
test -d ".real_${base_dir}"; then 4 | mkdir -p ".real_${base_dir}" 5 | cp -r ${base_dir} $(dirname ".real_${base_dir}") 6 | fi 7 | env session_path=".real_${session_path}" python3 exp_analyzer.py 8 | done 9 | -------------------------------------------------------------------------------- /fed_aas.sh: -------------------------------------------------------------------------------- 1 | # fedaas 2 | for configname in cs.yaml PubMed.yaml reddit.yaml yelp.yaml; do 3 | python3 ./simulator.py --config-name fed_aas/${configname} ++fed_aas.worker_number=50 ++fed_aas.round=50 ++fed_aas.model_name=TwoGATCN ++fed_aas.epoch=1 ++fed_aas.dataloader_kwargs.batch_number=10 ++fed_aas.learning_rate=0.001 ++fed_aas.algorithm_kwargs.edge_drop_rate=0.99 ++fed_aas.weight_decay=0.001 ++fed_aas.exp_name="fed_aas" ++fed_aas.algorithm_kwargs.min_sharing_interval=2 4 | done 5 | -------------------------------------------------------------------------------- /fed_obd_train.sh: -------------------------------------------------------------------------------- 1 | python3 ./simulator.py --config-name fed_obd/cifar10.yaml 2 | python3 ./simulator.py --config-name fed_obd/cifar100.yaml 3 | python3 ./simulator.py --config-name fed_obd/imdb.yaml 4 | -------------------------------------------------------------------------------- /gtg_shapley_train.sh: -------------------------------------------------------------------------------- 1 | python3 ./simulator.py --config-name gtg_sv/mnist.yaml 2 | -------------------------------------------------------------------------------- /method/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | for entry in os.scandir(os.path.dirname(os.path.abspath(__file__))): 5 | if not entry.is_dir(): 6 | continue 7 | if entry.name == "__pycache__": 8 | continue 9 | if entry.name.startswith("."): 10 | continue 11 | try: 12 | importlib.import_module(f".{entry.name}", "method") 13 | except ModuleNotFoundError: 14 | importlib.import_module( 15 | f".{entry.name}", "distributed_learning_simulator.method" 16 | ) 17 | -------------------------------------------------------------------------------- /method/fed_avg_qat/__init__.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation import ( 2 | AggregationServer, 3 | AlgorithmRepository, 4 | FedAVGAlgorithm, 5 | ) 6 | 7 | from .worker import QATWorker 8 | 9 | AlgorithmRepository.register_algorithm( 10 | algorithm_name="fed_avg_qat", 11 | client_cls=QATWorker, 12 | server_cls=AggregationServer, 13 | algorithm_cls=FedAVGAlgorithm, 14 | ) 15 | -------------------------------------------------------------------------------- /method/fed_avg_qat/worker.py: -------------------------------------------------------------------------------- 1 | import torch.ao.quantization 2 | from cyy_naive_lib.log import log_error 3 | from cyy_torch_algorithm.quantization.qat import QuantizationAwareTraining 4 | from cyy_torch_toolbox import TensorDict 5 | from distributed_learning_simulation import ( 6 | AggregationWorker, 7 | ) 8 | from torch.ao.nn.quantized.modules import Linear 9 | 10 | 11 | class QATWorker(AggregationWorker): 12 | parameter_name_set: set = set() 13 | 14 | def _get_parameters(self) -> TensorDict: 15 | assert isinstance(self.trainer.model, torch.ao.quantization.QuantWrapper) 16 | self.trainer.model.eval() 17 | old_model = self.trainer.model 18 | model_int8 = torch.ao.quantization.convert(old_model.module.cpu()) 19 
| self.trainer.replace_model(lambda *args: model_int8) 20 | new_state_dict = {} 21 | for name, p in self.trainer.model_util.get_modules(): 22 | if isinstance(p, Linear): 23 | weight, bias = p._packed_params._weight_bias() 24 | weight_name = name + ".weight" 25 | assert weight_name in self.parameter_name_set 26 | new_state_dict[weight_name] = weight.detach().dequantize() 27 | bias_name = name + ".bias" 28 | assert bias_name in self.parameter_name_set 29 | new_state_dict[bias_name] = bias.detach().dequantize() 30 | 31 | for name, p in self.trainer.model.state_dict().items(): 32 | log_error("%s %s", name, p) 33 | if name.endswith(".zero_point") or name.endswith(".scale"): 34 | continue 35 | if name.startswith("module."): 36 | name = name[len("module.") :] 37 | if isinstance(p, torch.Tensor | torch.nn.Parameter): 38 | new_state_dict[name] = p.detach().dequantize() 39 | log_error("%s %s", new_state_dict.keys(), self.parameter_name_set) 40 | assert sorted(new_state_dict.keys()) == sorted(self.parameter_name_set) 41 | if self._model_loading_fun is None: 42 | self._model_loading_fun = self.load_model 43 | return new_state_dict 44 | 45 | def load_model(self, state_dict) -> None: 46 | self.trainer.remove_model() 47 | self.trainer.model.load_state_dict(state_dict) 48 | 49 | def _before_training(self) -> None: 50 | super()._before_training() 51 | self.parameter_name_set = set(self.trainer.model_util.get_parameters().keys()) 52 | self.trainer.append_hook(QuantizationAwareTraining(), "QAT") 53 | -------------------------------------------------------------------------------- /method/fed_dropout_avg/__init__.py: -------------------------------------------------------------------------------- 1 | """FedDropoutAvg: Generalizable federated learning for histopathology image classification (https://arxiv.org/pdf/2111.13230.pdf)""" 2 | 3 | from distributed_learning_simulation import ( 4 | AggregationServer, 5 | AlgorithmRepository, 6 | ) 7 | 8 | from .algorithm import FedDropoutAvgAlgorithm 9 | from .worker import FedDropoutAvgWorker 10 | 11 | AlgorithmRepository.register_algorithm( 12 | algorithm_name="fed_dropout_avg", 13 | client_cls=FedDropoutAvgWorker, 14 | server_cls=AggregationServer, 15 | algorithm_cls=FedDropoutAvgAlgorithm, 16 | ) 17 | -------------------------------------------------------------------------------- /method/fed_dropout_avg/algorithm.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from distributed_learning_simulation import FedAVGAlgorithm 5 | 6 | 7 | class FedDropoutAvgAlgorithm(FedAVGAlgorithm): 8 | def _get_weight(self, worker_data, name: str, parameter: torch.Tensor) -> Any: 9 | weight = super()._get_weight( 10 | worker_data=worker_data, name=name, parameter=parameter 11 | ) 12 | return (parameter != 0).float() * weight 13 | 14 | def _apply_total_weight( 15 | self, name: str, parameter: torch.Tensor, total_weight: Any 16 | ) -> torch.Tensor: 17 | # avoid dividing by zero, in such case we set weight to 1 18 | total_weight[total_weight == 0] = 1 19 | return super()._apply_total_weight( 20 | name=name, parameter=parameter, total_weight=total_weight 21 | ) 22 | -------------------------------------------------------------------------------- /method/fed_dropout_avg/worker.py: -------------------------------------------------------------------------------- 1 | """FedDropoutAvg: Generalizable federated learning for histopathology image classification (https://arxiv.org/pdf/2111.13230.pdf)""" 2 | 3 | 
import torch 4 | from cyy_naive_lib.log import log_info 5 | from distributed_learning_simulation import AggregationWorker, ParameterMessage 6 | 7 | 8 | class FedDropoutAvgWorker(AggregationWorker): 9 | def _get_sent_data(self) -> ParameterMessage: 10 | dropout_rate: float = self.config.algorithm_kwargs["dropout_rate"] 11 | if self.hold_log_lock: 12 | log_info("use dropout_rate %s", dropout_rate) 13 | self._send_parameter_diff = False 14 | sent_data = super()._get_sent_data() 15 | assert isinstance(sent_data, ParameterMessage) 16 | parameter = sent_data.parameter 17 | total_num: float = 0 18 | send_num: float = 0 19 | for k, v in parameter.items(): 20 | weight = torch.bernoulli(torch.full_like(v, 1 - dropout_rate)) 21 | parameter[k] = v * weight 22 | total_num += parameter[k].numel() 23 | send_num += torch.count_nonzero(parameter[k]).item() 24 | log_info("send_num %s", send_num) 25 | log_info("total_num %s", total_num) 26 | return sent_data 27 | -------------------------------------------------------------------------------- /method/fed_obd/__init__.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation import ( 2 | AlgorithmRepository, 3 | FedAVGAlgorithm, 4 | NNADQClientEndpoint, 5 | NNADQServerEndpoint, 6 | StochasticQuantClientEndpoint, 7 | StochasticQuantServerEndpoint, 8 | ) 9 | 10 | from .server import FedOBDServer 11 | from .worker import FedOBDWorker 12 | 13 | AlgorithmRepository.register_algorithm( 14 | algorithm_name="fed_obd", 15 | client_cls=FedOBDWorker, 16 | server_cls=FedOBDServer, 17 | client_endpoint_cls=NNADQClientEndpoint, 18 | server_endpoint_cls=NNADQServerEndpoint, 19 | algorithm_cls=FedAVGAlgorithm, 20 | ) 21 | 22 | AlgorithmRepository.register_algorithm( 23 | algorithm_name="fed_obd_sq", 24 | client_cls=FedOBDWorker, 25 | server_cls=FedOBDServer, 26 | client_endpoint_cls=StochasticQuantClientEndpoint, 27 | server_endpoint_cls=StochasticQuantServerEndpoint, 28 | algorithm_cls=FedAVGAlgorithm, 29 | ) 30 | -------------------------------------------------------------------------------- /method/fed_obd/obd_algorithm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from algorithm.block_algorithm import BlockAlgorithmMixin 3 | from cyy_naive_lib.log import log_debug, log_info 4 | from cyy_torch_toolbox import ModelParameter 5 | from cyy_torch_toolbox.tensor import cat_tensors_to_vector 6 | 7 | 8 | class OpportunisticBlockDropoutAlgorithmMixin(BlockAlgorithmMixin): 9 | def __init__(self) -> None: 10 | super().__init__() 11 | self.__dropout_rate = self.config.algorithm_kwargs["dropout_rate"] 12 | log_debug("use dropout rate %s", self.__dropout_rate) 13 | self.__parameter_num: int = 0 14 | 15 | def get_block_parameter(self, parameter: ModelParameter) -> ModelParameter: 16 | if self.__parameter_num == 0: 17 | parameter_list = self.trainer.model_util.get_parameter_list() 18 | self.__parameter_num = len(parameter_list) 19 | threshold = (1 - self.__dropout_rate) * self.__parameter_num 20 | partial_parameter_num = 0 21 | new_parameter: dict = {} 22 | 23 | block_delta: dict = {} 24 | for block in self.blocks: 25 | block_dict, delta, block_size = self.__analyze_block(parameter, block) 26 | mean_delta = delta / block_size 27 | if mean_delta not in block_delta: 28 | block_delta[mean_delta] = [] 29 | block_delta[mean_delta].append((block_dict, block_size)) 30 | log_debug("block_delta is %s", sorted(block_delta.keys(), reverse=True)) 31 | 32 | for mean_delta 
in sorted(block_delta.keys(), reverse=True): 33 | if partial_parameter_num > threshold: 34 | break 35 | for block_dict, block_size in block_delta[mean_delta]: 36 | if partial_parameter_num + block_size > threshold: 37 | continue 38 | partial_parameter_num += block_size 39 | new_parameter |= block_dict 40 | log_debug("choose blocks %s", new_parameter.keys()) 41 | log_info( 42 | "partial_parameter_num %s threshold %s parameter_num %s", 43 | partial_parameter_num, 44 | threshold, 45 | self.__parameter_num, 46 | ) 47 | 48 | return new_parameter 49 | 50 | def __analyze_block(self, parameter: ModelParameter, block: list) -> tuple: 51 | cur_block_parameters = [] 52 | prev_block_parameters = [] 53 | block_dict = {} 54 | for submodule_name, submodule in block: 55 | for p_name, _ in submodule.named_parameters(): 56 | parameter_name = submodule_name + "." + p_name 57 | cur_block_parameters.append(parameter[parameter_name]) 58 | prev_block_parameters.append(self.model_cache.parameter[parameter_name]) 59 | block_dict[parameter_name] = parameter[parameter_name] 60 | 61 | cur_block_parameter = cat_tensors_to_vector(cur_block_parameters) 62 | prev_block_parameter = cat_tensors_to_vector(prev_block_parameters) 63 | delta = torch.linalg.vector_norm( 64 | cur_block_parameter.cpu() - prev_block_parameter.cpu() 65 | ).item() 66 | return (block_dict, delta, cur_block_parameter.nelement()) 67 | -------------------------------------------------------------------------------- /method/fed_obd/phase.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum, auto 2 | 3 | 4 | class Phase(StrEnum): 5 | STAGE_ONE = auto() 6 | STAGE_TWO = auto() 7 | END = auto() 8 | -------------------------------------------------------------------------------- /method/fed_obd/server.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from cyy_naive_lib.log import log_warning 4 | from distributed_learning_simulation import ( 5 | AggregationServer, 6 | ParameterMessage, 7 | ParameterMessageBase, 8 | QuantServerEndpoint, 9 | ) 10 | 11 | from .phase import Phase 12 | 13 | 14 | class FedOBDServer(AggregationServer): 15 | def __init__(self, **kwargs: Any) -> None: 16 | super().__init__(**kwargs) 17 | self.__phase: Phase = Phase.STAGE_ONE 18 | assert isinstance(self._endpoint, QuantServerEndpoint) 19 | self._endpoint.use_quant() 20 | self._compute_stat = True 21 | 22 | def select_workers(self) -> set: 23 | if self.__phase != Phase.STAGE_ONE: 24 | return set(range(self.worker_number)) 25 | return super().select_workers() 26 | 27 | def _get_stat_key(self, message: ParameterMessage): 28 | if self.__phase == Phase.STAGE_TWO: 29 | return max(self.performance_stat.keys()) + 1 30 | return super()._get_stat_key(message) 31 | 32 | def _aggregate_worker_data(self) -> ParameterMessageBase: 33 | result: ParameterMessageBase = super()._aggregate_worker_data() 34 | assert result 35 | match self.__phase: 36 | case Phase.STAGE_ONE: 37 | if self.round_index >= self.config.round or ( 38 | self.early_stop and not self.__has_improvement() 39 | ): 40 | log_warning("switch to phase 2") 41 | self.__phase = Phase.STAGE_TWO 42 | result.other_data["phase_two"] = True 43 | case Phase.STAGE_TWO: 44 | if self.early_stop and not self.__has_improvement(): 45 | log_warning("stop aggregation") 46 | result.end_training = True 47 | case _: 48 | raise NotImplementedError(f"unknown phase {self.__phase}") 49 | if result.end_training: 50 | self.__phase = 
Phase.END 51 | return result 52 | 53 | def _stopped(self) -> bool: 54 | return self.__phase == Phase.END 55 | 56 | def __has_improvement(self) -> bool: 57 | if self.__phase == Phase.STAGE_TWO: 58 | return True 59 | return not self.convergent() 60 | -------------------------------------------------------------------------------- /method/fed_obd/worker.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from cyy_naive_lib.log import log_debug, log_warning 4 | from cyy_torch_toolbox import ExecutorHookPoint, ModelUtil 5 | from distributed_learning_simulation import ( 6 | AggregationWorker, 7 | Message, 8 | ParameterMessage, 9 | QuantClientEndpoint, 10 | ) 11 | from distributed_learning_simulation.context import ClientEndpointInCoroutine 12 | 13 | from .obd_algorithm import OpportunisticBlockDropoutAlgorithmMixin 14 | from .phase import Phase 15 | 16 | 17 | class FedOBDWorker(AggregationWorker, OpportunisticBlockDropoutAlgorithmMixin): 18 | __phase = Phase.STAGE_ONE 19 | 20 | def __init__(self, *args, **kwargs): 21 | AggregationWorker.__init__(self, *args, **kwargs) 22 | OpportunisticBlockDropoutAlgorithmMixin.__init__(self) 23 | assert isinstance( 24 | self._endpoint, QuantClientEndpoint | ClientEndpointInCoroutine 25 | ) 26 | self._endpoint.dequant_server_data() 27 | self._send_parameter_diff = False 28 | self._keep_model_cache = True 29 | 30 | def _load_result_from_server(self, result: Message) -> None: 31 | if "phase_two" in result.other_data: 32 | assert isinstance(result, ParameterMessage) 33 | # result.other_data.pop("phase_two") 34 | self.__phase = Phase.STAGE_TWO 35 | log_warning("switch to phase 2") 36 | self.set_reuse_learning_rate(True) 37 | self._send_parameter_diff = True 38 | self.disable_choosing_model_by_validation() 39 | self.trainer.hyper_parameter.epoch = self.config.algorithm_kwargs[ 40 | "second_phase_epoch" 41 | ] 42 | self.config.round = self._round_index + 1 43 | self._aggregation_time = ExecutorHookPoint.AFTER_EPOCH 44 | self._register_aggregation() 45 | 46 | super()._load_result_from_server(result=result) 47 | 48 | def _get_model_util(self) -> ModelUtil: 49 | return self.trainer.model_util 50 | 51 | def _aggregation(self, sent_data: Message, **kwargs: Any) -> None: 52 | if self.__phase == Phase.STAGE_TWO: 53 | executor = kwargs["executor"] 54 | if kwargs["epoch"] == executor.hyper_parameter.epoch: 55 | sent_data.end_training = True 56 | self._force_stop = True 57 | log_debug("end training") 58 | super()._aggregation(sent_data=sent_data, **kwargs) 59 | 60 | def _stopped(self) -> bool: 61 | return self._force_stop 62 | 63 | def _get_sent_data(self): 64 | assert self._model_cache is not None 65 | data = super()._get_sent_data() 66 | if self.__phase == Phase.STAGE_ONE: 67 | assert isinstance(data, ParameterMessage) 68 | block_parameter = self.get_block_parameter( 69 | parameter=data.parameter, 70 | ) 71 | data.parameter = self._model_cache.get_parameter_diff(block_parameter) 72 | return data 73 | 74 | data.in_round = True 75 | log_warning("phase 2 keys %s", data.other_data.keys()) 76 | return data 77 | -------------------------------------------------------------------------------- /method/fed_paq/__init__.py: -------------------------------------------------------------------------------- 1 | # FedPAQ: A Communication-Efficient Federated Learning Method with Periodic Averaging and Quantization (https://arxiv.org/abs/1909.13014) 2 | from distributed_learning_simulation import ( 3 | AggregationServer, 4 | 
AggregationWorker, 5 | AlgorithmRepository, 6 | FedAVGAlgorithm, 7 | StochasticQuantClientEndpoint, 8 | StochasticQuantServerEndpoint, 9 | ) 10 | 11 | AlgorithmRepository.register_algorithm( 12 | algorithm_name="fed_paq", 13 | client_cls=AggregationWorker, 14 | server_cls=AggregationServer, 15 | client_endpoint_cls=StochasticQuantClientEndpoint, 16 | server_endpoint_cls=StochasticQuantServerEndpoint, 17 | algorithm_cls=FedAVGAlgorithm, 18 | ) 19 | -------------------------------------------------------------------------------- /method/qsgd/__init__.py: -------------------------------------------------------------------------------- 1 | """QSGD: Communication-Efficient SGD via Gradient Quantization and Encoding https://arxiv.org/abs/1610.02132""" 2 | 3 | from distributed_learning_simulation import ( 4 | AggregationServer, 5 | AlgorithmRepository, 6 | FedAVGAlgorithm, 7 | GradientWorker, 8 | StochasticQuantClientEndpoint, 9 | StochasticQuantServerEndpoint, 10 | ) 11 | 12 | AlgorithmRepository.register_algorithm( 13 | algorithm_name="QSGD", 14 | client_cls=GradientWorker, 15 | server_cls=AggregationServer, 16 | algorithm_cls=FedAVGAlgorithm, 17 | client_endpoint_cls=StochasticQuantClientEndpoint, 18 | server_endpoint_cls=StochasticQuantServerEndpoint, 19 | ) 20 | -------------------------------------------------------------------------------- /method/shapley_value/GTG_shapley_value_algorithm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from cyy_torch_algorithm.shapely_value.gtg_shapley_value import GTGShapleyValue 5 | 6 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) 7 | from algorithm.shapley_value_algorithm import ShapleyValueAlgorithm 8 | 9 | 10 | class GTGShapleyValueAlgorithm(ShapleyValueAlgorithm): 11 | def __init__(self, *args, **kwargs) -> None: 12 | super().__init__(GTGShapleyValue, *args, **kwargs) 13 | -------------------------------------------------------------------------------- /method/shapley_value/GTG_shapley_value_server.py: -------------------------------------------------------------------------------- 1 | from .GTG_shapley_value_algorithm import GTGShapleyValueAlgorithm 2 | from .shapley_value_server import ShapleyValueServer 3 | 4 | 5 | class GTGShapleyValueServer(ShapleyValueServer): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(**kwargs, algorithm=GTGShapleyValueAlgorithm(server=self)) 8 | -------------------------------------------------------------------------------- /method/shapley_value/__init__.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation import ( 2 | AggregationWorker, 3 | AlgorithmRepository, 4 | ) 5 | 6 | from .GTG_shapley_value_server import GTGShapleyValueServer 7 | from .multiround_shapley_value_server import MultiRoundShapleyValueServer 8 | 9 | AlgorithmRepository.register_algorithm( 10 | algorithm_name="multiround_shapley_value", 11 | client_cls=AggregationWorker, 12 | server_cls=MultiRoundShapleyValueServer, 13 | ) 14 | AlgorithmRepository.register_algorithm( 15 | algorithm_name="GTG_shapley_value", 16 | client_cls=AggregationWorker, 17 | server_cls=GTGShapleyValueServer, 18 | ) 19 | -------------------------------------------------------------------------------- /method/shapley_value/multiround_shapley_value_algorithm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from 
cyy_torch_algorithm.shapely_value.multiround_shapley_value import ( 5 | MultiRoundShapleyValue, 6 | ) 7 | 8 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) 9 | from algorithm.shapley_value_algorithm import ShapleyValueAlgorithm 10 | 11 | 12 | class MultiRoundShapleyValueAlgorithm(ShapleyValueAlgorithm): 13 | def __init__(self, *args, **kwargs) -> None: 14 | super().__init__(MultiRoundShapleyValue, *args, **kwargs) 15 | -------------------------------------------------------------------------------- /method/shapley_value/multiround_shapley_value_server.py: -------------------------------------------------------------------------------- 1 | from .multiround_shapley_value_algorithm import MultiRoundShapleyValueAlgorithm 2 | from .shapley_value_server import ShapleyValueServer 3 | 4 | 5 | class MultiRoundShapleyValueServer(ShapleyValueServer): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | **kwargs, algorithm=MultiRoundShapleyValueAlgorithm(server=self) 9 | ) 10 | -------------------------------------------------------------------------------- /method/shapley_value/shapley_value_server.py: -------------------------------------------------------------------------------- 1 | from distributed_learning_simulation import AggregationServer 2 | 3 | 4 | class ShapleyValueServer(AggregationServer): 5 | def __init__(self, *args, **kwargs) -> None: 6 | super().__init__(*args, **kwargs) 7 | self._need_init_performance = True 8 | -------------------------------------------------------------------------------- /method/sign_sgd/__init__.py: -------------------------------------------------------------------------------- 1 | """signSGD: Compressed Optimisation for Non-Convex Problems https://arxiv.org/abs/1802.04434""" 2 | 3 | from distributed_learning_simulation import ( 4 | AggregationServer, 5 | AlgorithmRepository, 6 | ) 7 | 8 | from .algorithm import SignSGDAlgorithm 9 | from .worker import SignSGDWorker 10 | 11 | AlgorithmRepository.register_algorithm( 12 | algorithm_name="sign_SGD", 13 | client_cls=SignSGDWorker, 14 | server_cls=AggregationServer, 15 | algorithm_cls=SignSGDAlgorithm, 16 | ) 17 | -------------------------------------------------------------------------------- /method/sign_sgd/algorithm.py: -------------------------------------------------------------------------------- 1 | """signSGD: Compressed Optimisation for Non-Convex Problems https://arxiv.org/abs/1802.04434""" 2 | 3 | from distributed_learning_simulation import FedAVGAlgorithm, ParameterMessage 4 | 5 | 6 | class SignSGDAlgorithm(FedAVGAlgorithm): 7 | def aggregate_worker_data(self) -> ParameterMessage: 8 | message = super().aggregate_worker_data() 9 | assert isinstance(message, ParameterMessage) 10 | message.parameter = {k: v.sign() for k, v in message.parameter.items()} 11 | return message 12 | -------------------------------------------------------------------------------- /method/sign_sgd/worker.py: -------------------------------------------------------------------------------- 1 | """signSGD: Compressed Optimisation for Non-Convex Problems https://arxiv.org/abs/1802.04434""" 2 | 3 | from cyy_torch_toolbox import ModelGradient, TensorDict 4 | from distributed_learning_simulation import GradientWorker, ParameterMessage 5 | 6 | 7 | class SignSGDWorker(GradientWorker): 8 | def _process_gradient(self, gradient_dict: ModelGradient) -> TensorDict: 9 | self._send_data_to_server( 10 | ParameterMessage( 11 | parameter={k: v.sign() for k, v in gradient_dict.items()}, 12 | in_round=True, 13 
| aggregation_weight=self.trainer.dataset_size, 14 | ) 15 | ) 16 | result = self._get_data_from_server() 17 | assert isinstance(result, ParameterMessage) 18 | return result.parameter 19 | -------------------------------------------------------------------------------- /other_method_test.sh: -------------------------------------------------------------------------------- 1 | # Fed dropout avg 2 | python3 ./simulator.py --config-name fed_dropout_avg/cifar100.yaml ++fed_dropout_avg.round=1 ++fed_dropout_avg.epoch=1 ++fed_dropout_avg.worker_number=2 ++fed_dropout_avg.debug=True ++fed_dropout_avg.algorithm_kwargs.random_client_number=2 3 | 4 | python3 ./simulator.py --config-name fed_paq/cifar100.yaml ++fed_paq.round=1 ++fed_paq.epoch=1 ++fed_paq.worker_number=2 ++fed_paq.debug=True ++fed_paq.algorithm_kwargs.random_client_number=2 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 63.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "distributed_learning_simulator" 7 | version = "1.0" 8 | requires-python = ">=3.11" 9 | readme = {file = "README.md", content-type = "text/markdown"} 10 | authors = [ 11 | {name = "cyy", email = "cyyever@outlook.com"}, 12 | ] 13 | license = {text = "BSD License"} 14 | classifiers = [ 15 | "Programming Language :: Python" 16 | ] 17 | 18 | dependencies = [ 19 | "distributed_learning_simulation@git+https://github.com/cyyever/distributed_learning_simulation_lib.git", 20 | "cyy_torch_text@git+https://github.com/cyyever/torch_text.git", 21 | "cyy_torch_vision@git+https://github.com/cyyever/torch_vision.git" 22 | ] 23 | 24 | 25 | 26 | [tool.setuptools.package-dir] 27 | "distributed_learning_simulator.conf"= "./conf" 28 | "distributed_learning_simulator.method"= "./method" 29 | "distributed_learning_simulator.algorithm"= "./algorithm" 30 | 31 | [project.urls] 32 | Repository = "https://github.com/cyyever/distributed_learning_simulator" 33 | 34 | 35 | [tool.ruff] 36 | target-version = "py312" 37 | src = ["method", "algorithm"] 38 | 39 | [tool.ruff.lint] 40 | select = [ 41 | # pycodestyle 42 | "E", 43 | # Pyflakes 44 | "F", 45 | # pyupgrade 46 | "UP", 47 | # flake8-bugbear 48 | "B", 49 | # flake8-simplify 50 | "SIM", 51 | # isort 52 | "I", 53 | ] 54 | ignore = ["F401","E501","F403"] 55 | -------------------------------------------------------------------------------- /simulator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from cyy_naive_lib.log import redirect_stdout_to_logger 5 | from distributed_learning_simulation import import_dependencies, load_config, train 6 | 7 | sys.path.insert(0, os.path.abspath(".")) 8 | import method # noqa: F401 9 | 10 | import_dependencies() 11 | if __name__ == "__main__": 12 | with redirect_stdout_to_logger(): 13 | config_path = os.path.join(os.path.dirname(__file__), "conf") 14 | config = load_config( 15 | config_path=config_path, 16 | global_conf_path=os.path.join(config_path, "global.yaml"), 17 | ) 18 | train(config=config, single_task=True) 19 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # CV 2 | python3 ./simulator.py --config-name fed_avg/mnist.yaml ++fed_avg.round=2 ++fed_avg.epoch=1 ++fed_avg.worker_number=10 
++fed_avg.debug=True 3 | # NLP 4 | python3 ./simulator.py --config-name fed_avg/imdb.yaml ++fed_avg.round=1 ++fed_avg.epoch=1 ++fed_avg.worker_number=2 ++fed_avg.debug=True 5 | # GTG 6 | python3 ./simulator.py --config-name gtg_sv/mnist.yaml ++gtg_sv.round=1 ++gtg_sv.epoch=1 ++gtg_sv.worker_number=2 ++gtg_sv.debug=False 7 | # OBD 8 | python3 ./simulator.py --config-name fed_obd/cifar10.yaml ++fed_obd.round=2 ++fed_obd.epoch=1 ++fed_obd.worker_number=10 ++fed_obd.algorithm_kwargs.random_client_number=10 ++fed_obd.algorithm_kwargs.second_phase_epoch=1 ++fed_obd.debug=False 9 | --------------------------------------------------------------------------------
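The `method/*/__init__.py` modules above all register their method through the same `AlgorithmRepository.register_algorithm` call. A minimal sketch of wiring in a new method the same way is shown below; it is not part of the repository, the name `my_method` and the matching `conf/my_method/` configuration are hypothetical, and it assumes the module is imported by the `method` package so that `import method` in `simulator.py` picks the registration up.

```
# Hypothetical method/my_method/__init__.py -- a sketch only, reusing the stock
# building blocks that the methods above import from distributed_learning_simulation.
from distributed_learning_simulation import (
    AggregationServer,
    AggregationWorker,
    AlgorithmRepository,
    FedAVGAlgorithm,
)

AlgorithmRepository.register_algorithm(
    algorithm_name="my_method",    # hypothetical name, assumed to match a conf/my_method/ directory
    client_cls=AggregationWorker,  # plain per-client local training and upload
    server_cls=AggregationServer,  # plain aggregation server
    algorithm_cls=FedAVGAlgorithm, # FedAVG-style parameter averaging
)
```

Assuming a matching YAML exists, it would presumably be launched like the commands in `test.sh`, e.g. `python3 ./simulator.py --config-name my_method/mnist.yaml ++my_method.round=1 ++my_method.epoch=1 ++my_method.worker_number=2`, with any method-specific options read from `config.algorithm_kwargs` the way the FedOBD and FedDropoutAvg workers do.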