├── README.md ├── codes └── FedDF-code │ ├── .gitignore │ ├── README.md │ ├── auto_extract.py │ ├── generate_data.py │ ├── hostfile │ ├── main.py │ ├── parameters.py │ ├── pcode │ ├── __init__.py │ ├── aggregation │ │ ├── __init__.py │ │ ├── adv_knowledge_transfer.py │ │ ├── fedavg.py │ │ ├── knowledge_transfer.py │ │ ├── learn2agg.py │ │ ├── noise_knowledge_transfer.py │ │ ├── optimal_transport.py │ │ ├── server_adaptive.py │ │ ├── server_momentum.py │ │ ├── swa_knowledge_transfer.py │ │ ├── swag_utils │ │ │ ├── __init__.py │ │ │ └── swag.py │ │ ├── unlabeled_training.py │ │ └── utils.py │ ├── create_aggregator.py │ ├── create_coordinator.py │ ├── create_dataset.py │ ├── create_metrics.py │ ├── create_model.py │ ├── create_optimizer.py │ ├── create_scheduler.py │ ├── datasets │ │ ├── __init__.py │ │ ├── loader │ │ │ ├── __init__.py │ │ │ ├── epsilon_or_rcv1_folder.py │ │ │ ├── femnist.py │ │ │ ├── imagenet_folder.py │ │ │ ├── preprocess_toolkit.py │ │ │ ├── pseudo_imagenet_folder.py │ │ │ ├── serialize.py │ │ │ ├── svhn_folder.py │ │ │ └── utils.py │ │ ├── mixup_data.py │ │ ├── partition_data.py │ │ └── prepare_data.py │ ├── local_training │ │ ├── __init__.py │ │ ├── compressor.py │ │ └── random_reinit.py │ ├── master.py │ ├── master_utils.py │ ├── models │ │ ├── __init__.py │ │ ├── densenet.py │ │ ├── efficientnet.py │ │ ├── lenet.py │ │ ├── mlp.py │ │ ├── mobilenetv2.py │ │ ├── moderate_cnns.py │ │ ├── regnet.py │ │ ├── resnet.py │ │ ├── resnet_evonorm.py │ │ ├── shufflenetv2.py │ │ ├── simple_cnns.py │ │ ├── vgg.py │ │ └── wideresnet.py │ ├── tools │ │ ├── __init__.py │ │ ├── build_downsampled_imagenet.py │ │ ├── db.py │ │ ├── plot.py │ │ ├── plot_utils.py │ │ └── show_results.py │ ├── utils │ │ ├── __init__.py │ │ ├── auxiliary.py │ │ ├── checkpoint.py │ │ ├── communication.py │ │ ├── cross_entropy.py │ │ ├── early_stopping.py │ │ ├── error_handler.py │ │ ├── logging.py │ │ ├── mathdict.py │ │ ├── misc.py │ │ ├── module_state.py │ │ ├── op_files.py │ │ ├── 
op_paths.py │ │ ├── param_parser.py │ │ ├── sparsification.py │ │ ├── stat_tracker.py │ │ ├── tensor_buffer.py │ │ ├── timer.py │ │ └── topology.py │ └── worker.py │ └── run.py └── environments ├── base ├── .screenrc ├── .tmux.conf ├── Dockerfile ├── entrypoint.sh └── fix-permissions ├── docker-compose.yml └── pytorch-mpi ├── .condarc └── Dockerfile /README.md: -------------------------------------------------------------------------------- 1 | This repository maintains a codebase for Federated Learning research. It supports: 2 | * PyTorch with MPI backend for a Master-Worker computation/communication topology. 3 | * Local training can be efficiently executed in a parallel-fashion over GPUs for randomly sampled clients. 4 | * Different FL algorithms, e.g., FedAvg, FedProx, FedAvg with Server Momentum, and FedDF, are implemented as the baselines. 5 | 6 | # Code Usage 7 | ## Requirements 8 | We rely on `Docker` for our experimental environments. Please refer to the folder `environments` for more details. 9 | 10 | ## Usage 11 | The current repository includes 12 | * the methods evaluated in the paper `FedDF: Ensemble Distillation for Robust Model Fusion in Federated Learning`. For the detailed instructions and more examples, please refer to the file `codes/FedDF-code/README.md`. 
13 | 14 | # Reference 15 | If you use the code in this repository, please consider citing the following paper: 16 | ``` 17 | @inproceedings{lin2020ensemble, 18 | title={Ensemble Distillation for Robust Model Fusion in Federated Learning}, 19 | author={Lin, Tao and Kong, Lingjing and Stich, Sebastian U and Jaggi, Martin}, 20 | booktitle = {Advances in Neural Information Processing Systems}, 21 | year = {2020} 22 | } 23 | ``` 24 | -------------------------------------------------------------------------------- /codes/FedDF-code/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 89 | # install all needed dependencies. 
def get_args():
    """Parse and validate the command-line arguments of the extraction script.

    Returns:
        argparse.Namespace: holds `in_dir`, `out_name` and the derived
        `out_path` that `check_args` attaches.

    Raises:
        ValueError: if `--in_dir` was not supplied (raised by `check_args`).
    """
    # feed them to the parser.
    parser = argparse.ArgumentParser(description="Extract results.")

    # add arguments.
    parser.add_argument("--in_dir", type=str)
    parser.add_argument("--out_name", type=str, default="summary.pickle")

    # parse args.
    args = parser.parse_args()

    # an argument safety check.
    check_args(args)
    return args


def check_args(args):
    """Validate the parsed arguments and derive the output path in place.

    Args:
        args: namespace carrying `in_dir` and `out_name`.

    Raises:
        ValueError: if `args.in_dir` is None. An explicit raise is used instead
            of `assert` so the check still runs under `python -O`.
    """
    if args.in_dir is None:
        raise ValueError("--in_dir is required: no experiment directory given.")

    # define out path.
    args.out_path = os.path.join(args.in_dir, args.out_name)
def init_config(conf):
    """Populate `conf` with the runtime state every process needs before training.

    In order, this function:
      * builds the computation graph via `topology.define_graph_topology` and
        stores this process's rank on it;
      * seeds numpy / torch (and CUDA when enabled);
      * derives `conf.arch_info` and parses `conf.fl_aggregate`;
      * initialises the checkpoint directory and the logger;
      * synchronises all processes with a barrier.

    Args:
        conf: the parsed argument namespace; it is mutated in place.

    NOTE(review): `dist.get_rank()` and `dist.barrier()` are called
    unconditionally, so this presumably requires an initialised process
    group -- confirm how the "failed to init" fallback in `main` is meant to work.
    """
    # define the graph for the computation.
    conf.graph = topology.define_graph_topology(
        world=conf.world,
        world_conf=conf.world_conf,
        n_participated=conf.n_participated,
        on_cuda=conf.on_cuda,
    )
    conf.graph.rank = dist.get_rank()

    # init related to randomness on cpu.
    # unless all processes share one seed, derive a distinct seed per rank.
    if not conf.same_seed_process:
        conf.manual_seed = 1000 * conf.manual_seed + conf.graph.rank
    conf.random_state = np.random.RandomState(conf.manual_seed)
    torch.manual_seed(conf.manual_seed)

    # configure cuda related.
    if conf.graph.on_cuda:
        assert torch.cuda.is_available()
        torch.cuda.manual_seed(conf.manual_seed)
        torch.cuda.set_device(conf.graph.primary_device)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        # NOTE(review): determinism is tied to `train_fast` as in the original;
        # confirm this coupling is intended.
        torch.backends.cudnn.deterministic = bool(conf.train_fast)

    # init the model arch info.
    conf.arch_info = (
        param_parser.dict_parser(conf.complex_arch)
        if conf.complex_arch is not None
        else {"master": conf.arch, "worker": conf.arch}
    )
    # the worker entry may list several architectures separated by ":".
    conf.arch_info["worker"] = conf.arch_info["worker"].split(":")

    # parse the fl_aggregate scheme; keep the raw string in `_fl_aggregate`.
    conf._fl_aggregate = conf.fl_aggregate
    if conf.fl_aggregate is not None:
        conf.fl_aggregate = param_parser.dict_parser(conf.fl_aggregate)
        # mirror every parsed entry as a flat `fl_aggregate_<key>` attribute.
        # (skipped when no scheme is given: the original called `.items()` on None.)
        for key, value in conf.fl_aggregate.items():
            setattr(conf, f"fl_aggregate_{key}", value)

    # define checkpoint for logging (for federated learning server).
    checkpoint.init_checkpoint(conf, rank=str(conf.graph.rank))

    # configure logger.
    conf.logger = logging.Logger(conf.checkpoint_dir)

    # display the arguments' info.
    if conf.graph.rank == 0:
        logging.display_args(conf)

    # sync the processes.
    dist.barrier()
def fedavg(
    conf,
    clientid2arch,
    n_selected_clients,
    flatten_local_models,
    client_models,
    criterion,
    metrics,
    val_data_loader,
):
    """Uniformly average the client models, optionally dropping weak ones first.

    When `conf.fl_aggregate["server_teaching_scheme"]` contains "drop", every
    recovered client model is validated on `val_data_loader`; clients whose
    "top1" score is below `agg_utils.get_random_guess_perf(conf)` are popped
    from `flatten_local_models` (the caller's dict is mutated) before averaging.

    Returns:
        whatever `_fedavg` returns for the remaining clients.
    """
    scheme_key = "server_teaching_scheme"
    drop_requested = (
        scheme_key in conf.fl_aggregate and "drop" in conf.fl_aggregate[scheme_key]
    )
    if not drop_requested:
        # plain averaging over everything that was received.
        conf.logger.log("No indices to be removed.")
        return _fedavg(
            clientid2arch, n_selected_clients, flatten_local_models, client_models
        )

    # rebuild the client models so that they can be evaluated.
    _, recovered = agg_utils.recover_models(conf, client_models, flatten_local_models)

    # score each client; positions in `client_ids` index the score list.
    threshold = agg_utils.get_random_guess_perf(conf)
    client_ids = list(recovered)
    top1_scores = [
        validate(
            conf,
            model=recovered[client_id],
            criterion=criterion,
            metrics=metrics,
            data_loader=val_data_loader,
        )["top1"]
        for client_id in client_ids
    ]
    weak_positions = [
        position for position, score in enumerate(top1_scores) if score < threshold
    ]

    conf.logger.log(
        f"Indices to be removed for FedAvg: {weak_positions}; the original performance is: {top1_scores}."
    )
    for position in reversed(weak_positions):
        flatten_local_models.pop(client_ids[position])
    return _fedavg(
        clientid2arch,
        n_selected_clients - len(weak_positions),
        flatten_local_models,
        client_models,
    )


def validate(conf, model, data_loader, criterion, metrics):
    """Evaluate `model` on `data_loader` through `master_utils.validate`.

    The call is made without a coordinator or label and with display turned off;
    the performance object produced by `master_utils.validate` is returned as is.
    """
    return master_utils.validate(
        conf=conf,
        coordinator=None,
        model=model,
        criterion=criterion,
        metrics=metrics,
        data_loader=data_loader,
        label=None,
        display=False,
    )
def learning2aggregate(
    conf, fedavg_model, client_models, flatten_local_models, criterion, data_loader
):
    """Learn aggregation weights over the client models and return the mixture.

    For `conf.fl_aggregate["epochs"]` passes over `data_loader`, a mixed model
    is rebuilt from the current aggregation weights for every batch, the loss
    of that model is back-propagated and the optimizer created by
    `_get_init_agg_weights` takes a step.

    Returns:
        the final mixed model (built with `display_agg_weights=True`), moved to CPU.
    """
    # rebuild the client models from their flattened form.
    num_models, local_models = agg_utils.recover_models(
        conf, client_models, flatten_local_models
    )

    # create the trainable aggregation weights and their optimizer.
    if conf.graph.on_cuda:
        fedavg_model = fedavg_model.cuda()
    agg_weights, optimizer, is_layerwise = _get_init_agg_weights(
        conf, fedavg_model, num_models
    )

    # arguments shared by every call that builds a mixed model.
    mix_kwargs = dict(
        conf=conf,
        model=fedavg_model,
        local_models=local_models,
        agg_weights=agg_weights,
        is_layerwise=is_layerwise,
    )

    # fit the aggregation weights.
    n_epochs = int(conf.fl_aggregate["epochs"])
    for _ in range(n_epochs):
        for batch_input, batch_target in data_loader:
            if conf.graph.on_cuda:
                batch_input, batch_target = batch_input.cuda(), batch_target.cuda()

            # the mixture has to be rebuilt from the freshly updated weights.
            candidate = get_mixed_model(**mix_kwargs)
            candidate.train()

            optimizer.zero_grad()
            batch_loss = criterion(candidate(batch_input), batch_target)
            batch_loss.backward()
            optimizer.step()

    # build the final mixture and log the learned weights.
    weighted_avg_model = get_mixed_model(display_agg_weights=True, **mix_kwargs)

    # drop the references to the recovered client models before returning.
    del local_models, mix_kwargs
    return weighted_avg_model.cpu()
def aggregate(conf, master_model, fedavg_model, client_models, flatten_local_models):
    """Apply an adaptive (second-moment scaled) server step on top of FedAvg.

    The difference between the previous master parameters and the averaged
    parameters is treated as the update direction. A running average of its
    square is kept on `conf.second_server_momentum_buffer` (decay `beta_2`,
    0.99 unless given in `conf.fl_aggregate`), and the update is scaled by
    `server_lr / (sqrt(second_moment) + 0.01)`. No first-moment term is used here.

    The stepped parameters are written into `fedavg_model`, which is then
    moved to CPU and returned once per entry of `conf.used_client_archs`.
    `client_models` and `flatten_local_models` are not read.

    Returns:
        dict mapping each used client arch to the updated model.
    """
    settings = conf.fl_aggregate
    assert "server_lr" in settings
    beta_2 = 0.99 if "beta_2" not in settings else settings["beta_2"]

    # wrap both parameter lists.
    # presumably TensorBuffer exposes them as one flat `buffer` -- TODO confirm.
    averaged_tb = TensorBuffer(list(fedavg_model.parameters()))
    server_tb = TensorBuffer(list(master_model.parameters()))

    # update direction: previous server parameters minus the averaged ones.
    update = server_tb.buffer - averaged_tb.buffer

    # running second moment, created lazily on the first call.
    if not hasattr(conf, "second_server_momentum_buffer"):
        conf.second_server_momentum_buffer = torch.zeros_like(update)
    second_moment = conf.second_server_momentum_buffer
    second_moment.mul_(beta_2).add_((1 - beta_2) * (update ** 2))

    # take the scaled step from the previous server parameters.
    server_tb.buffer.add_(
        -settings["server_lr"] * update / (torch.sqrt(second_moment) + 0.01)
    )

    # write the result into fedavg_model: only its parameters are overwritten,
    # everything else on that module is kept as it is.
    server_tb.unpack(list(fedavg_model.parameters()))

    # free the memory.
    torch.cuda.empty_cache()

    # a temp hack (only for debug reason): every arch shares the same model.
    return {arch: fedavg_model.cpu() for arch in conf.used_client_archs}
def recover_models(conf, client_models, flatten_local_models, use_cuda=True):
    """Rebuild one frozen model per client from its flattened state.

    Args:
        conf: configuration; `conf.clientid2arch` maps a client id to its
            architecture name and `conf.graph.on_cuda` tells whether CUDA is used.
        client_models: dict mapping architecture name to a template model.
            Templates are deep-copied and never modified.
        flatten_local_models: dict mapping client id to an object whose
            `unpack(tensors)` writes the client's values into the given tensors.
        use_cuda: when False the recovered models stay on their current device
            even if `conf.graph.on_cuda` is set. (This flag used to be ignored.)

    Returns:
        tuple: `(num_models, local_models)`, where `local_models` maps client id
        to the recovered model with `requires_grad` disabled on all parameters.
    """
    num_models = len(flatten_local_models)
    local_models = {}
    place_on_cuda = conf.graph.on_cuda and use_cuda

    for client_idx, flatten_local_model in flatten_local_models.items():
        arch = conf.clientid2arch[client_idx]
        _model = deepcopy(client_models[arch])

        # fill the copy's own state with the client's values.
        _model_state_dict = _model.state_dict()
        flatten_local_model.unpack(_model_state_dict.values())
        _model.load_state_dict(_model_state_dict)

        if place_on_cuda:
            _model = _model.cuda()

        # turn off the grad for local models.
        for param in _model.parameters():
            param.requires_grad = False
        local_models[client_idx] = _model
    return num_models, local_models
def filter_models_by_weights(normalized_weights, detect_fn_name=None):
    """Split the model weights into outliers to drop and weights to keep.

    Args:
        normalized_weights: sequence of per-model scores.
        detect_fn_name: optional name of the detector to use
            (see `detect_outlier_and_remain`).

    Returns:
        tuple: `(indices_to_remove, remained_weights)`; the indices are sorted
        ascending and the weights keep their original order.
    """
    remained_indices_weights = detect_outlier_and_remain(
        normalized_weights, fn_name=detect_fn_name
    )
    kept_indices = {index for index, _ in remained_indices_weights}
    remained_weights = [weight for _, weight in remained_indices_weights]
    indices_to_remove = [
        index for index in range(len(normalized_weights)) if index not in kept_indices
    ]
    return indices_to_remove, remained_weights


def detect_outlier_and_remain(values, fn_name=None):
    """Return the `(index, value)` pairs that are not low outliers.

    Args:
        values: sequence of scores.
        fn_name: name of a detector defined in this module
            ("detect_outlier_and_remain_v1" or "detect_outlier_and_remain_v2");
            None selects v1.

    Raises:
        ValueError: if `fn_name` is not a known detector. The name is looked up
            in a fixed table rather than passed to `eval`, so a configuration
            string can no longer execute arbitrary code.
    """
    if fn_name is None:
        return detect_outlier_and_remain_v1(values)

    detector = _OUTLIER_DETECTORS.get(fn_name)
    if detector is None:
        raise ValueError(
            f"unknown outlier detector {fn_name!r}; "
            f"expected one of {sorted(_OUTLIER_DETECTORS)}."
        )
    return detector(values)


def _reference_values(values):
    """Return a copy of `values` without (one occurrence of) its maximum."""
    reference = list(values)
    reference.remove(max(reference))
    return reference


def _keep_not_below(values, lower):
    """Return the `(index, value)` pairs whose value is at least `lower`."""
    return [(idx, value) for idx, value in enumerate(values) if lower <= value]


def detect_outlier_and_remain_v1(values):
    """Keep values not more than 1.5 standard deviations below the mean.

    Mean and standard deviation are computed with the best value left out.
    Only a lower bound is applied. With fewer than two values there is nothing
    to compare against, so everything is kept (previously a single value was
    dropped through a NaN comparison and an empty input raised).
    """
    values = list(values)
    if len(values) < 2:
        return list(enumerate(values))

    reference = _reference_values(values)
    # calculate summary statistics
    data_mean, data_std = np.mean(reference), np.std(reference)
    # identify outliers
    cut_off = data_std * 1.5
    return _keep_not_below(values, data_mean - cut_off)


def detect_outlier_and_remain_v2(values):
    """Keep values not more than 1.5 IQR below the first quartile.

    The quartiles are computed with the best value left out. Only a lower
    bound is applied. With fewer than two values everything is kept.
    """
    values = list(values)
    if len(values) < 2:
        return list(enumerate(values))

    reference = _reference_values(values)
    q25, q75 = np.quantile(reference, 0.25), np.quantile(reference, 0.75)
    cut_off = (q75 - q25) * 1.5
    return _keep_not_below(values, q25 - cut_off)


# detectors selectable by name through `detect_outlier_and_remain`.
_OUTLIER_DETECTORS = {
    "detect_outlier_and_remain_v1": detect_outlier_and_remain_v1,
    "detect_outlier_and_remain_v2": detect_outlier_and_remain_v2,
}


# tolerance factor applied on top of the pure random-guess accuracy.
SCALING_FACTOR = 1.2


def get_random_guess_perf(conf):
    """Return the scaled random-guess accuracy, in percent, for `conf.data`.

    The value is `100 / num_classes * SCALING_FACTOR`. The imagenet branch used
    to omit the factor 100, returning a fraction while the cifar branches
    returned a percentage.

    Raises:
        NotImplementedError: for datasets other than cifar10, cifar100 and
            names containing "imagenet".
    """
    if conf.data == "cifar10":
        num_classes = 10
    elif conf.data == "cifar100":
        num_classes = 100
    elif "imagenet" in conf.data:
        num_classes = 1000
    else:
        raise NotImplementedError
    return 1 / num_classes * 100 * SCALING_FACTOR
def define_dataset(conf, data, display_log=True):
    """Load the train/test datasets and split off the validation part.

    Side effect: sets `conf.partitioned_by_user`, which is True only when
    `conf.data` is "femnist".

    Args:
        conf: configuration namespace (reads `data`, `data_dir`, `logger`).
        data: dataset name passed on to `get_dataset`.
        display_log: whether to log the sizes of the resulting splits.

    Returns:
        dict with keys "train", "val" and "test"; "val" may be None.
    """
    # femnist is the only dataset treated as already split per user.
    conf.partitioned_by_user = conf.data == "femnist"

    # prepare general train/test.
    raw_train = get_dataset(conf, data, conf.data_dir, split="train")
    raw_test = get_dataset(conf, data, conf.data_dir, split="test")

    # create the validation from train.
    train_part, val_part, test_part = define_val_dataset(conf, raw_train, raw_test)
    splits = {"train": train_part, "val": val_part, "test": test_part}

    if display_log:
        n_val = 0 if val_part is None else len(val_part)
        conf.logger.log(
            "Data stat for original dataset: we have {} samples for train, {} samples for val, {} samples for test.".format(
                len(train_part), n_val, len(test_part)
            )
        )
    return splits
78 | if conf.val_data_ratio > 0: 79 | assert conf.partitioned_by_user is False 80 | 81 | val_dataset = data_partitioner.use(2) 82 | return train_dataset, val_dataset, test_dataset 83 | else: 84 | return train_dataset, None, test_dataset 85 | 86 | 87 | def define_data_loader( 88 | conf, dataset, localdata_id=None, is_train=True, shuffle=True, data_partitioner=None 89 | ): 90 | # determine the data to load, 91 | # either the whole dataset, or a subset specified by partition_type. 92 | if is_train: 93 | world_size = conf.n_clients 94 | partition_sizes = [1.0 / world_size for _ in range(world_size)] 95 | assert localdata_id is not None 96 | 97 | if conf.partitioned_by_user: # partitioned by "users". 98 | # in case our dataset is already partitioned by the client. 99 | # and here we need to load the dataset based on the client id. 100 | dataset.set_user(localdata_id) 101 | data_to_load = dataset 102 | else: # (general) partitioned by "labels". 103 | # in case we have a global dataset and want to manually partition them. 104 | if data_partitioner is None: 105 | # update the data_partitioner. 106 | data_partitioner = DataPartitioner( 107 | conf, dataset, partition_sizes, partition_type=conf.partition_data 108 | ) 109 | # note that the master node will not consume the training dataset. 110 | data_to_load = data_partitioner.use(localdata_id) 111 | conf.logger.log( 112 | f"Data partition for train (client_id={localdata_id + 1}): partitioned data and use subdata." 113 | ) 114 | else: 115 | if conf.partitioned_by_user: # partitioned by "users". 116 | # in case our dataset is already partitioned by the client. 117 | # and here we need to load the dataset based on the client id. 118 | dataset.set_user(localdata_id) 119 | data_to_load = dataset 120 | else: 121 | data_to_load = dataset 122 | conf.logger.log("Data partition for validation/test.") 123 | 124 | # use Dataloader. 
class Metrics(object):
    """Pick and compute the evaluation metrics that belong to a task.

    Args:
        model: the model being evaluated. For classification, its optional
            ``num_classes`` attribute decides whether top-5 accuracy is
            reported in addition to top-1.
        task: one of "classification", "language_modeling" or
            "transformer_nmt".

    Raises:
        NotImplementedError: for any other ``task``.
    """

    def __init__(self, model, task="classification"):
        self.model = model
        self.task = task
        self.metric_names = None
        self.metrics_fn = self._infer()

    def evaluate(self, loss, output, target, **kwargs):
        """Return the list of metric values, ordered like ``metric_names``."""
        return self.metrics_fn(loss, output, target, **kwargs)

    def _infer(self):
        """Set ``metric_names`` for the task and return the metric function."""
        if self.task == "classification":
            num_classes = getattr(self.model, "num_classes", None)
            self.topks = (
                (1, 5) if num_classes is not None and num_classes >= 5 else (1,)
            )
            self.metric_names = ["top{}".format(topk) for topk in self.topks]
            metrics_fn = self._accuracy
        elif self.task == "language_modeling":
            self.metric_names = ["ppl"]
            metrics_fn = self._ppl
        elif self.task == "transformer_nmt":
            self.metric_names = ["ppl", "top1"]
            metrics_fn = self._transformer_nmt
        else:
            raise NotImplementedError

        # some safety check (it used to sit after the returns and never ran).
        assert self.metric_names is not None
        return metrics_fn

    def _accuracy(self, loss, output, target, **kwargs):
        """Computes the precision@k (in percent) for the specified values of k."""
        res = []

        if len(self.topks) > 0:
            maxk = max(self.topks)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            for topk in self.topks:
                # `correct` is derived from a transposed tensor, so a slice of
                # it may be non-contiguous; reshape copies when view cannot.
                correct_k = correct[:topk].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size).item())
        else:
            res += [0]
        return res

    def _ppl(self, loss, output, target, **kwargs):
        """Return the perplexity, i.e. ``exp(loss)``, as a one-element list."""
        return [math.exp(loss)]

    def _transformer_nmt(self, loss, output, target, **kwargs):
        """Return ``[perplexity, accuracy]``.

        Requires ``non_pad_mask`` and ``n_samples`` in ``kwargs``; correct
        predictions are counted only at positions selected by the mask.
        """
        pred = output.max(1)[1]
        n_correct = pred.eq(target)
        n_correct = n_correct.masked_select(kwargs["non_pad_mask"]).sum().item()
        return [math.exp(loss), n_correct / kwargs["n_samples"]]
def define_cv_classification_model(conf, client_id, use_complex_arch, arch):
    """Build the classification model that matches an architecture name.

    When ``arch`` is ``None`` the name is resolved through ``determine_arch``.
    The name is then matched by substring against the known model families,
    in a fixed order; a name matching none of them is used directly as the
    key into ``models``.

    Returns:
        ``(arch, model)``.
    """
    if arch is None:
        arch = determine_arch(conf, client_id, use_complex_arch)

    # (substring, text it is searched in, key in `models`, forward `arch`?)
    # Order matters: "wideresnet" and "resnet_evonorm" both contain "resnet".
    rules = (
        ("wideresnet", arch, "wideresnet", False),
        ("resnet_evonorm", arch, "resnet_evonorm", True),
        ("resnet", arch, "resnet", True),
        ("regnet", arch.lower(), "regnet", True),
        ("densenet", arch, "densenet", False),
        ("vgg", arch, "vgg", False),
        ("mobilenetv2", arch, "mobilenetv2", False),
        ("shufflenetv2", arch, "shufflenetv2", True),
        ("efficientnet", arch, "efficientnet", False),
        ("federated_averaging_cnn", arch, "simple_cnn", False),
        ("moderate_cnn", arch, "moderate_cnn", False),
    )

    for needle, haystack, factory_key, forward_arch in rules:
        if needle in haystack:
            factory = models.__dict__[factory_key]
            model = factory(conf, arch=arch) if forward_arch else factory(conf)
            break
    else:
        # no family matched: the name itself is the factory key.
        model = models.__dict__[arch](conf)
    return arch, model
def define_optimizer(conf, model, optimizer_name, lr=None):
    """Create an SGD or Adam optimizer with one parameter group per tensor.

    Each group records the parameter's name, size and element count. Weight
    decay is ``conf.weight_decay``, except for parameters whose name contains
    "bn", which get 0.

    Args:
        conf: configuration providing the optimizer hyper-parameters.
        model: module whose ``named_parameters()`` are optimized.
        optimizer_name: "sgd" or "adam".
        lr: learning rate; falls back to ``conf.lr`` when ``None``.

    Raises:
        NotImplementedError: for any other ``optimizer_name``.
    """
    param_groups = []
    for name, tensor in model.named_parameters():
        decay = 0.0 if "bn" in name else conf.weight_decay
        param_groups.append(
            {
                "params": [tensor],
                "name": name,
                "weight_decay": decay,
                "param_size": tensor.size(),
                "nelement": tensor.nelement(),
            }
        )

    if optimizer_name == "sgd":
        return torch.optim.SGD(
            param_groups,
            lr=lr if lr is not None else conf.lr,
            momentum=conf.momentum_factor,
            nesterov=conf.use_nesterov,
        )
    if optimizer_name == "adam":
        return torch.optim.Adam(
            param_groups,
            lr=lr if lr is not None else conf.lr,
            betas=(conf.adam_beta_1, conf.adam_beta_2),
            eps=conf.adam_eps,
        )
    raise NotImplementedError
class ImageNetDS(data.Dataset):
    """Downsampled ImageNet dataset stored as pickled numpy batches.

    Args:
        root (string): Root directory of dataset where directory
            ``imagenet<img_size>`` exists.
        img_size (int): Side length of the (square) images, e.g. 64, 32, 16, 8.
        train (bool, optional): If True, creates dataset from the training
            batches, otherwise from the validation file.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
    """

    base_folder = "imagenet{}"
    train_list = [
        ["train_data_batch_1", ""],
        ["train_data_batch_2", ""],
        ["train_data_batch_3", ""],
        ["train_data_batch_4", ""],
        ["train_data_batch_5", ""],
        ["train_data_batch_6", ""],
        ["train_data_batch_7", ""],
        ["train_data_batch_8", ""],
        ["train_data_batch_9", ""],
        ["train_data_batch_10", ""],
    ]

    test_list = [["val_data", ""]]

    def __init__(
        self, root, img_size, train=True, transform=None, target_transform=None
    ):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.img_size = img_size

        self.base_folder = self.base_folder.format(img_size)

        # now load the pickled numpy arrays.
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point `root` at trusted dataset files.
        if self.train:
            self.data = []
            self.targets = []
            for fentry in self.train_list:
                f = fentry[0]
                file = os.path.join(self.root, self.base_folder, f)
                with open(file, "rb") as fo:
                    entry = pickle.load(fo)
                    self.data.append(entry["data"])
                    # stored labels start at 1; shift them to start at 0.
                    self.targets += [label - 1 for label in entry["labels"]]
                    self.mean = entry["mean"]

            self.data = np.concatenate(self.data)
        else:
            f = self.test_list[0][0]
            file = os.path.join(self.root, self.base_folder, f)
            with open(file, "rb") as fo:
                entry = pickle.load(fo)
                self.data = entry["data"]
                self.targets = [label - 1 for label in entry["labels"]]

        # use the requested resolution instead of a hard-coded 32x32, so that
        # the 8/16/64 variants are reshaped correctly as well.
        self.data = self.data.reshape(
            (self.data.shape[0], 3, self.img_size, self.img_size)
        )
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        """Return True when every listed batch file passes ``check_integrity``."""
        root = self.root
        for fentry in self.train_list + self.test_list:
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True
def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
    """Compose a padded random crop, a random horizontal flip and ToTensor.

    The padding is half of the difference between ``scale_size`` and
    ``input_size``, so ``scale_size`` must be a number here. A Normalize step
    is appended unless ``normalize`` is ``None``.
    """
    margin = int((scale_size - input_size) / 2)
    steps = []
    steps.append(transforms.RandomCrop(input_size, padding=margin))
    steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    if normalize is not None:
        steps.append(transforms.Normalize(**normalize))
    return transforms.Compose(steps)
class Grayscale(object):
    """Replace the three channels of an image tensor by their weighted sum.

    The first three entries along dimension 0 are combined with the weights
    0.299, 0.587 and 0.114, and the result is written to all three. The input
    is left untouched; a modified clone is returned.
    """

    def __call__(self, img):
        gs = img.clone()
        # keyword `alpha` replaces the deprecated positional form
        # `add_(scalar, tensor)`.
        gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)
        gs[1].copy_(gs[0])
        gs[2].copy_(gs[0])
        return gs
class ImageNetDS(data.Dataset):
    """Image dataset read from a ``root/<class>/<file>`` directory tree.

    Args:
        root (string): Root directory of dataset; each entry in it is treated
            as one class directory.
        train (bool, optional): Stored on the instance; it does not change
            which files are listed.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        # get the filenames. os.listdir returns entries in arbitrary order,
        # so sort them to make the index -> sample mapping reproducible.
        self.class_paths = [
            (_class, os.path.join(self.root, _class))
            for _class in sorted(os.listdir(self.root))
        ]
        self.filenames = []
        self.filename2target = {}
        for _class, class_path in self.class_paths:
            for file_path in sorted(os.listdir(class_path)):
                abs_file_path = os.path.join(class_path, file_path)
                self.filenames.append(abs_file_path)
                self.filename2target[abs_file_path] = _class

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the name of the class
            directory the file was found in (before ``target_transform``).
        """
        filename = self.filenames[index]
        img = Image.open(filename)
        target = self.filename2target[filename]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.filenames)
class LMDBPT(data.Dataset):
    """A dataset that chains one or several LMDB files into one index space.

    Args:
        root (string): Either a directory holding the database files, or a
            path ending in ".lmdb" that points to a single database.
        transform (callable, optional): A function/transform that
            takes in an PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional):
            A function/transform that takes in the target and transforms it.
        is_image (bool, optional): Forwarded unchanged to every
            ``LMDBPTClass``.
    """

    def __init__(self, root, transform=None, target_transform=None, is_image=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        # materialize the generator so the list can be inspected again later.
        self.lmdb_files = list(self._get_valid_lmdb_files())

        # for each database file, create an LMDBPTClass.
        self.dbs = [
            LMDBPTClass(
                root=lmdb_file,
                transform=transform,
                target_transform=target_transform,
                is_image=is_image,
            )
            for lmdb_file in self.lmdb_files
        ]

        # build up indices: indices[i] is the number of samples in dbs[0..i].
        self.indices = np.cumsum([len(db) for db in self.dbs])
        self.length = int(self.indices[-1])
        self._get_index_zones = self._build_indices()

    def _get_valid_lmdb_files(self):
        """Yield the valid lmdb paths for the given root.

        For a directory root, entries are visited in sorted order so the
        order of the databases does not depend on the file system.
        """
        if not self.root.endswith(".lmdb"):
            for entry in sorted(os.listdir(self.root)):
                if "_" in entry and "-lock" not in entry:
                    yield os.path.join(self.root, entry)
        else:
            yield self.root

    def _build_indices(self):
        """Return a function mapping a global index to (db index, local index).

        The lookup is stateless, so it can be called any number of times. The
        previous version closed over a one-shot iterator and only worked with
        a single database.
        """
        ends = np.asarray(self.indices)
        total = int(ends[-1]) if len(ends) > 0 else 0

        def f(x):
            x = int(x)
            if x < 0:
                # support python-style negative indices.
                x += total
            if x < 0 or x >= total:
                raise IndexError("index out of range for dataset of size {}".format(total))
            # first database whose cumulative end lies strictly beyond x;
            # side="right" also skips over empty databases.
            block = int(np.searchsorted(ends, x, side="right"))
            start = int(ends[block - 1]) if block > 0 else 0
            return block, x - start

        return f

    def _get_matched_index(self, index):
        return self._get_index_zones(index)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, target)
        """
        block_index, item_index = self._get_matched_index(index)
        image, target = self.dbs[block_index][item_index]
        return image, target

    def __len__(self):
        return self.length

    def __repr__(self):
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(self.__len__())
        fmt_str += "    Root Location: {}\n".format(self.root)
        tmp = "    Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        tmp = "    Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        return fmt_str
133 | self._get_length() 134 | 135 | # prepare cache_file 136 | self._prepare_cache() 137 | 138 | def _open_lmdb(self): 139 | return lmdb.open( 140 | self.root, 141 | subdir=os.path.isdir(self.root), 142 | readonly=True, 143 | lock=False, 144 | readahead=False, 145 | map_size=1099511627776 * 2, 146 | max_readers=1, 147 | meminit=False, 148 | ) 149 | 150 | def _get_length(self): 151 | with self.env.begin(write=False) as txn: 152 | self.length = txn.stat()["entries"] 153 | 154 | if txn.get(b"__keys__") is not None: 155 | self.length -= 1 156 | 157 | def _prepare_cache(self): 158 | cache_file = self.root + "_cache_" 159 | if os.path.isfile(cache_file): 160 | self.keys = pickle.load(open(cache_file, "rb")) 161 | else: 162 | with self.env.begin(write=False) as txn: 163 | self.keys = [key for key, _ in txn.cursor() if key != b"__keys__"] 164 | pickle.dump(self.keys, open(cache_file, "wb")) 165 | 166 | def _image_decode(self, x): 167 | image = cv2.imdecode(x, cv2.IMREAD_COLOR).astype("uint8") 168 | return Image.fromarray(image, "RGB") 169 | 170 | def __getitem__(self, index): 171 | env = self.env 172 | with env.begin(write=False) as txn: 173 | bin_file = txn.get(self.keys[index]) 174 | 175 | image, target = serialize.loads(bin_file) 176 | if self.is_image: 177 | image = cv2.imdecode(image, cv2.IMREAD_COLOR).astype("uint8") 178 | else: 179 | if "img_size" not in dir(self): 180 | self.img_size = int(np.sqrt(image.shape[0] / 3)) 181 | image = image.reshape(3, 32, 32).transpose((1, 2, 0)).astype("uint8") 182 | image = Image.fromarray(image, "RGB") 183 | 184 | if self.transform is not None: 185 | image = self.transform(image) 186 | if self.target_transform is not None: 187 | target = self.target_transform(target) 188 | return image, target 189 | 190 | def __len__(self): 191 | return self.length 192 | 193 | def __repr__(self): 194 | return self.__class__.__name__ + " (" + self.root + ")" 195 | -------------------------------------------------------------------------------- 
def mixup_data(x, y, alpha=1.0, assist_non_iid=False, use_cuda=True):
    """Returns mixed inputs, pairs of targets, and lambda.

    Args:
        x: input batch; the first dimension is the batch dimension.
        y: targets, one per sample (indexed with the same permutation as x).
        alpha: Beta(alpha, alpha) parameter for the mixing coefficient;
            a non-positive value disables mixing (lambda = 1).
        assist_non_iid: if True, partners are sampled with replacement with
            probability proportional to ``1 - count(class) / batch_size``,
            i.e. samples of rare classes are picked more often. Otherwise a
            random permutation of the batch is used.
        use_cuda: move the partner index to the GPU.

    Returns:
        ``(mixed_x, y_a, y_b, _lambda)`` where ``y_a`` is ``y`` and ``y_b``
        holds the targets of the sampled partners.
    """
    _lambda = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size()[0]
    if not assist_non_iid:
        index = torch.randperm(batch_size)
    else:
        # build the sampling probability for target.
        # `counts[inverse]` gives, for every sample, the size of its own
        # class. (Substituting counts into a copy of `y` label-by-label is
        # wrong: a count that equals a later label gets substituted again.)
        _, inverse, counts = torch.unique(
            y, sorted=True, return_inverse=True, return_counts=True
        )
        replaced_counts = counts[inverse]
        prob_y = 1.0 - 1.0 * replaced_counts / batch_size

        # a single-class batch yields all-zero weights, which multinomial
        # rejects; fall back to uniform sampling in that case.
        if prob_y.sum() <= 0:
            prob_y = torch.ones_like(prob_y)

        # get index.
        index = torch.multinomial(
            input=prob_y, num_samples=batch_size, replacement=True
        )

    if use_cuda:
        index = index.cuda()

    # mixup.
    mixed_x = _lambda * x + (1 - _lambda) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, _lambda
def inference(
    conf, model, criterion, metrics, data_batch, tracker=None, is_training=True
):
    """Run a forward pass and compute the loss and the tracked metrics.

    With ``conf.use_mixup`` enabled during training, the loss and each metric
    are the ``mixup_lambda``-weighted combination of their values against
    ``target_a`` and ``target_b``; otherwise they are computed against
    ``target``. When a tracker is given it is updated with the loss value
    followed by the metrics.

    Returns:
        The tuple ``(loss, output)``.
    """
    inputs = data_batch["input"]
    output = model(inputs)

    mixup_active = conf.use_mixup and is_training
    if not mixup_active:
        # plain objective against the single target.
        loss = criterion(output, data_batch["target"])
        performance = metrics.evaluate(loss, output, data_batch["target"])
    else:
        lam = data_batch["mixup_lambda"]
        loss = mixup.mixup_criterion(
            criterion, output, data_batch["target_a"], data_batch["target_b"], lam
        )
        perf_a = metrics.evaluate(loss, output, data_batch["target_a"])
        perf_b = metrics.evaluate(loss, output, data_batch["target_b"])
        # interpolate each metric with the same coefficient as the loss.
        performance = []
        for metric_a, metric_b in zip(perf_a, perf_b):
            performance.append(
                data_batch["mixup_lambda"] * metric_a
                + (1 - data_batch["mixup_lambda"]) * metric_b
            )

    # update tracker.
    if tracker is not None:
        tracker.update_metrics([loss.item()] + performance, n_samples=inputs.size(0))
    return loss, output
def get_avg_perf_on_dataloaders(
    conf, coordinator, model, criterion, metrics, data_loaders, label
):
    """Validate the model on every loader and return the averaged statistics.

    Each loader is evaluated with ``validate``; the per-loader results are
    wrapped in ``MathDict`` so that they can be summed and divided by the
    number of loaders.
    """
    print(f"\tGet averaged performance from {len(data_loaders)} data_loaders.")

    per_loader_stats = []
    for loader_idx, loader in enumerate(data_loaders):
        # distinguish the loaders in the logs when a label prefix is given.
        if label is not None:
            loader_label = f"{label}-{loader_idx}"
        else:
            loader_label = "test_loader"

        loader_perf = validate(
            conf, coordinator, model, criterion, metrics, loader, label=loader_label
        )
        per_loader_stats.append(MathDict(loader_perf))

    summed = functools.reduce(lambda acc, item: acc + item, per_loader_stats)
    return summed / len(per_loader_stats)
def ensembled_validate(
    conf,
    coordinator,
    models,
    criterion,
    metrics,
    data_loader,
    label="test_loader",
    ensemble_scheme=None,
):
    """Evaluate an ensemble of models on ``data_loader``.

    The per-model outputs are averaged for every batch: raw logits for
    ``ensemble_scheme`` in ``(None, "avg_losses", "avg_logits")`` and softmax
    probabilities for ``"avg_probs"``. No real loss is computed (``criterion``
    is not used); a zero placeholder is tracked instead.

    Returns:
        The tracker summary, or ``None`` when ``data_loader`` is ``None``.

    Raises:
        NotImplementedError: for an unknown ``ensemble_scheme`` (previously
            this surfaced as an undefined ``output`` variable).
    """
    if data_loader is None:
        return None

    if ensemble_scheme not in (None, "avg_losses", "avg_logits", "avg_probs"):
        raise NotImplementedError(
            f"unsupported ensemble_scheme={ensemble_scheme!r}."
        )
    average_probs = ensemble_scheme == "avg_probs"

    # switch to evaluation mode and place every model on the device.
    for model in models:
        model.eval()
        if conf.graph.on_cuda:
            model.cuda()

    # evaluate on test_loader.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    for _input, _target in data_loader:
        # load data and check performance.
        data_batch = create_dataset.load_data_batch(
            conf, _input, _target, is_training=False
        )

        with torch.no_grad():
            # ensemble.
            outputs = []
            for model in models:
                logits = model(data_batch["input"])
                # explicit dim=1 (the class dimension of (batch, classes)
                # logits) matches the deprecated implicit choice for 2-D
                # inputs and avoids the warning.
                outputs.append(F.softmax(logits, dim=1) if average_probs else logits)
            output = sum(outputs) / len(outputs)

            # eval the performance (placeholder loss, see docstring).
            loss = torch.FloatTensor([0])
            performance = metrics.evaluate(loss, output, data_batch["target"])

        # update the tracker.
        tracker_te.update_metrics(
            [loss.item()] + performance, n_samples=data_batch["input"].size(0)
        )

    # place back model to the cpu.
    if conf.graph.on_cuda:
        for model in models:
            model.cpu()

    # display the test stat.
    if label is not None:
        display_test_stat(conf, coordinator, tracker_te, label)
    perf = tracker_te()
    conf.logger.log(f"The performance of the ensenmbled model: {perf}.")
    return perf
def swish(x):
    """Return the swish activation ``x * sigmoid(x)``."""
    return x * x.sigmoid()


class Block(nn.Module):
    """expansion + depthwise + pointwise + squeeze-excitation

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        expand_ratio: channel multiplier of the expansion convolution.
        se_ratio: ratio used to size the squeeze-excitation bottleneck
            (at least one channel is always kept).
        drop_rate: channel-dropout probability applied to the block output
            while training.
    """

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride,
        expand_ratio=1,
        se_ratio=0.0,
        drop_rate=0.0,
    ):
        super(Block, self).__init__()
        self.stride = stride
        self.drop_rate = drop_rate

        # Expansion
        planes = expand_ratio * in_planes
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)

        # Depthwise conv ("same" padding for odd kernel sizes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=planes,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # SE layers
        se_planes = max(1, int(planes * se_ratio))
        self.se1 = nn.Conv2d(planes, se_planes, kernel_size=1)
        self.se2 = nn.Conv2d(se_planes, planes, kernel_size=1)

        # Output
        self.conv3 = nn.Conv2d(
            planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_planes)

        # identity shortcut, or a 1x1 projection when only the width changes.
        self.shortcut = nn.Sequential()
        if stride == 1 and in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        out = swish(self.bn1(self.conv1(x)))
        out = swish(self.bn2(self.conv2(out)))
        # Squeeze-Excitation: pool over the full feature map (window equals
        # the feature-map height) and rescale the channels.
        w = F.avg_pool2d(out, out.size(2))
        w = swish(self.se1(w))
        w = self.se2(w).sigmoid()
        out = out * w
        # Output
        out = self.bn3(self.conv3(out))
        if self.drop_rate > 0:
            # the functional form defaults to training=True; tie it to the
            # module mode so that dropout is disabled under .eval().
            out = F.dropout2d(out, self.drop_rate, training=self.training)
        # NOTE(review): for stride != 1 the block output is added to itself
        # (i.e. doubled) instead of skipping the residual; kept unchanged so
        # that existing checkpoints behave identically -- confirm intent.
        shortcut = self.shortcut(x) if self.stride == 1 else out
        out = out + shortcut
        return out
def efficientnet(conf):
    """Build the EfficientNet-B0 variant for the dataset named in ``conf.data``.

    Only CIFAR datasets are supported; anything else raises
    ``NotImplementedError``.
    """
    if "cifar" not in conf.data:
        raise NotImplementedError

    model = EfficientNetB0()
    if conf.data == "cifar100":
        # EfficientNetB0() always builds a 10-way head; cifar100 needs 100
        # outputs, so swap in a classifier of the right width.
        model.linear = nn.Linear(model.linear.in_features, 100)
    return model
class MLP(nn.Module):
    """Fully connected network: ``num_layers`` blocks of
    Linear -> BatchNorm1d -> ReLU -> Dropout followed by a bias-free linear
    classifier.

    Args:
        dataset: dataset name; decides the input feature size and the number
            of classes.
        num_layers: number of hidden blocks.
        hidden_size: width of every hidden block.
        drop_rate: dropout probability inside every hidden block.

    Raises:
        NotImplementedError: if the dataset is not supported.
    """

    def __init__(self, dataset, num_layers, hidden_size, drop_rate):
        super(MLP, self).__init__()
        self.dataset = dataset

        # init
        self.num_layers = num_layers
        self.num_classes = self._decide_num_classes()
        input_size = self._decide_input_feature_size()

        # define layers (registered as layer1 ... layerN).
        for i in range(1, self.num_layers + 1):
            in_features = input_size if i == 1 else hidden_size
            out_features = hidden_size

            layer = nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.ReLU(),
                nn.Dropout(p=drop_rate),
            )
            setattr(self, "layer{}".format(i), layer)

        self.classifier = nn.Linear(hidden_size, self.num_classes, bias=False)

    def _decide_num_classes(self):
        """Return the number of target classes of ``self.dataset``."""
        if self.dataset in ("cifar10", "mnist"):
            return 10
        elif self.dataset == "cifar100":
            return 100
        elif self.dataset == "femnist":
            return 62
        # fail early and explicitly instead of handing None to nn.Linear.
        raise NotImplementedError(
            f"unknown number of classes for dataset={self.dataset}."
        )

    def _decide_input_feature_size(self):
        """Return the flattened input size of one sample of ``self.dataset``."""
        if "cifar" in self.dataset:
            return 32 * 32 * 3
        elif "mnist" in self.dataset:
            return 28 * 28
        else:
            raise NotImplementedError

    def forward(self, x):
        # flatten everything but the batch dimension.
        out = x.view(x.size(0), -1)

        for i in range(1, self.num_layers + 1):
            out = getattr(self, "layer{}".format(i))(out)
        out = self.classifier(out)
        return out
class MobileNetV2(nn.Module):
    """MobileNetV2 adapted to 32x32 inputs.

    Args:
        dataset: dataset name used to size the classifier. Defaults to
            "cifar10" so that the model can be built without arguments.
        save_activations: keep the output of every block of the last forward
            pass in ``self.activations``.

    Raises:
        NotImplementedError: if the number of classes of ``dataset`` is
            unknown.
    """

    # (expansion, out_planes, num_blocks, stride)
    cfg = [
        (1, 16, 1, 1),
        (6, 24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
        (6, 32, 3, 2),
        (6, 64, 4, 2),
        (6, 96, 3, 1),
        (6, 160, 3, 2),
        (6, 320, 1, 1),
    ]

    def __init__(self, dataset="cifar10", save_activations=False):
        super(MobileNetV2, self).__init__()

        # init.
        self.dataset = dataset

        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        # `self.layers` is a plain list, so every block is additionally
        # registered as attribute layer<idx> to expose its parameters.
        for idx, layer in enumerate(self.layers):
            setattr(self, f"layer{idx}", layer)
        self.conv2 = nn.Conv2d(
            320, 1280, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn2 = nn.BatchNorm2d(1280)
        self.linear = nn.Linear(1280, self._decide_num_classes())

        # a placeholder for activations in the intermediate layers.
        self.save_activations = save_activations
        self.activations = None

    def _decide_num_classes(self):
        """Return the number of target classes of ``self.dataset``."""
        if self.dataset == "cifar10" or self.dataset == "svhn":
            return 10
        elif self.dataset == "cifar100":
            return 100
        elif "imagenet" in self.dataset:
            return 1000
        # fail explicitly instead of handing None to nn.Linear.
        raise NotImplementedError(
            f"unknown number of classes for dataset={self.dataset}."
        )

    def _make_layers(self, in_planes):
        """Build the list of blocks described by ``cfg``."""
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            # only the first block of a stage may downsample.
            for block_stride in [stride] + [1] * (num_blocks - 1):
                layers.append(Block(in_planes, out_planes, expansion, block_stride))
                in_planes = out_planes
        return layers

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))

        self.activations = []
        for layer in self.layers:
            out = layer(out)
            if self.save_activations:
                self.activations.append(out)

        out = F.relu(self.bn2(self.conv2(out)))
        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
class ModerateCNN(nn.Module):
    """Three conv stages (two 3x3 conv + ReLU each, followed by 2x2 max
    pooling) and a three-layer fully connected classifier.

    The first linear layer expects 4096 flattened features, which matches
    3x32x32 inputs.

    Args:
        w_conv_bias: use a bias term in the convolutions.
        w_fc_bias: use a bias term in the linear layers.
        save_activations: keep the output of the three conv stages of the
            last forward pass in ``self.activations``.
        num_classes: number of outputs of the classifier.
    """

    def __init__(
        self, w_conv_bias=False, w_fc_bias=True, save_activations=True, num_classes=10
    ):
        super(ModerateCNN, self).__init__()

        # Conv Layer block 1
        self.conv_layer1 = nn.Sequential(
            *self._conv_relu(3, 32, w_conv_bias),
            *self._conv_relu(32, 64, w_conv_bias),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Conv Layer block 2
        self.conv_layer2 = nn.Sequential(
            *self._conv_relu(64, 128, w_conv_bias),
            *self._conv_relu(128, 128, w_conv_bias),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.05),
        )

        # Conv Layer block 3
        self.conv_layer3 = nn.Sequential(
            *self._conv_relu(128, 256, w_conv_bias),
            *self._conv_relu(256, 256, w_conv_bias),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.fc_layer = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(4096, 512, bias=w_fc_bias),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512, bias=w_fc_bias),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, num_classes, bias=w_fc_bias),
        )

        # a placeholder for activations in the intermediate layers.
        self.save_activations = save_activations
        self.activations = None

    @staticmethod
    def _conv_relu(in_channels, out_channels, bias):
        """Return a 3x3 'same' convolution followed by an in-place ReLU."""
        return [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                bias=bias,
            ),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x):
        x = self.conv_layer1(x)
        activation1 = x
        x = self.conv_layer2(x)
        activation2 = x
        x = self.conv_layer3(x)
        activation3 = x

        x = x.view(x.size(0), -1)
        x = self.fc_layer(x)

        if self.save_activations:
            self.activations = [activation1, activation2, activation3]
        return x


def moderate_cnn(conf):
    """Build a ModerateCNN for the dataset named in ``conf.data``.

    Supports CIFAR datasets and svhn; the classifier gets 100 outputs for
    cifar100 and 10 otherwise.
    """
    dataset = conf.data

    if "cifar" in dataset or dataset == "svhn":
        return ModerateCNN(
            w_conv_bias=conf.w_conv_bias,
            w_fc_bias=conf.w_fc_bias,
            num_classes=100 if dataset == "cifar100" else 10,
        )
    else:
        raise NotImplementedError(f"not supported yet.")
class Block(nn.Module):
    """Residual bottleneck block: 1x1 conv, grouped 3x3 conv, optional
    squeeze-excitation, 1x1 conv, plus a (projected) shortcut."""

    def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio):
        super(Block, self).__init__()
        bottleneck_width = int(round(w_out * bottleneck_ratio))

        # 1x1
        self.conv1 = nn.Conv2d(w_in, bottleneck_width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottleneck_width)

        # 3x3 (grouped; carries the stride of the block)
        self.conv2 = nn.Conv2d(
            bottleneck_width,
            bottleneck_width,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=bottleneck_width // group_width,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(bottleneck_width)

        # se (its bottleneck is sized from the block input width)
        self.with_se = se_ratio > 0
        if self.with_se:
            self.se = SE(bottleneck_width, int(round(w_in * se_ratio)))

        # 1x1
        self.conv3 = nn.Conv2d(bottleneck_width, w_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(w_out)

        # identity unless the resolution or the width changes.
        needs_projection = stride != 1 or w_in != w_out
        if needs_projection:
            shortcut = nn.Sequential(
                nn.Conv2d(w_in, w_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(w_out),
            )
        else:
            shortcut = nn.Sequential()
        self.shortcut = shortcut

    def forward(self, x):
        residual = self.shortcut(x)

        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = F.relu(self.bn2(out))
        if self.with_se:
            out = self.se(out)
        out = self.bn3(self.conv3(out))

        out += residual
        return F.relu(out)
def regnet(conf, arch=None):
    """Build a RegNet model for the dataset named by ``conf.data``.

    Args:
        conf: configuration object; ``conf.data`` names the dataset and
            ``conf.arch`` is the default architecture key.
        arch: optional architecture key (e.g. ``"RegNetX_200MF"``). When given
            it takes precedence over ``conf.arch``, matching the behaviour of
            the ``shufflenetv2`` factory.

    Returns:
        The model built by ``regnet_confs``.

    Raises:
        NotImplementedError: if the dataset is neither a CIFAR variant nor SVHN.
    """
    dataset = conf.data
    # Previously ``arch`` was accepted but silently ignored.
    net_name = arch if arch is not None else conf.arch

    if "cifar" in dataset or "svhn" in dataset:
        return regnet_confs(net_name, dataset)
    raise NotImplementedError(
        f"regnet does not support the dataset ({dataset}) yet."
    )
class ShuffleBlock(nn.Module):
    """Channel shuffle: interleave the channels of ``groups`` equal-sized groups."""

    def __init__(self, groups=2):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Shuffle channels: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]."""
        batch, channels, height, width = x.shape
        per_group = channels // self.groups
        grouped = x.view(batch, self.groups, per_group, height, width)
        # swapping the group axis with the within-group axis interleaves groups.
        interleaved = grouped.transpose(1, 2)
        return interleaved.reshape(batch, channels, height, width)
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 backbone followed by a linear classification head.

    Args:
        net_size: key into the module-level ``configs`` table
            (``"0.5"``, ``"1"``, ``"1.5"`` or ``"2"``).
        dataset: dataset name; it decides the number of output classes.
        save_activations: when True, ``forward`` keeps the outputs of the three
            stages in ``self.activations``.

    Raises:
        NotImplementedError: if ``dataset`` is not a recognised dataset name.
    """

    def __init__(self, net_size, dataset, save_activations=False):
        super(ShuffleNetV2, self).__init__()
        out_channels = configs[net_size]["out_channels"]
        num_blocks = configs[net_size]["num_blocks"]
        self.dataset = dataset
        # resolve the head size up front so an unsupported dataset fails with
        # a clear message instead of an error inside nn.Linear.
        self.num_classes = self._decide_num_classes()

        self.conv1 = nn.Conv2d(3, 24, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_channels = 24
        self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
        self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
        self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
        self.conv2 = nn.Conv2d(
            out_channels[2],
            out_channels[3],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels[3])
        self.linear = nn.Linear(out_channels[3], self.num_classes)

        # a placeholder for activations in the intermediate layers.
        self.save_activations = save_activations
        self.activations = None

    def _decide_num_classes(self):
        """Return the number of classes for ``self.dataset``."""
        if self.dataset == "cifar10" or self.dataset == "svhn":
            return 10
        elif self.dataset == "cifar100":
            return 100
        elif "imagenet" in self.dataset:
            return 1000
        raise NotImplementedError(
            f"this dataset ({self.dataset}) is not supported yet."
        )

    def _make_layer(self, out_channels, num_blocks):
        """Build one stage: a downsampling block followed by ``num_blocks`` basic blocks."""
        layers = [DownBlock(self.in_channels, out_channels)]
        layers.extend(BasicBlock(out_channels) for _ in range(num_blocks))
        # the next stage consumes this stage's output width.
        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        # out = F.max_pool2d(out, 3, stride=2, padding=1)
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = F.relu(self.bn2(self.conv2(out3)))
        # fixed 4x4 pooling window: each of the three stages halves the
        # resolution, so a 32x32 input reaches this point as a 4x4 map.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        if self.save_activations:
            self.activations = [out1, out2, out3]
        return out
if __name__ == "__main__":
    # Smoke test: build the smallest variant and run one forward pass.
    # ``dataset`` is a required constructor argument; omitting it (as this
    # script used to) raises TypeError before anything runs.
    net = ShuffleNetV2(net_size="0.5", dataset="cifar10")
    net.save_activations = True
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    print(y.shape)
    print([activation.shape for activation in net.activations])
class CNNfemnist(nn.Module):
    """Two-conv, two-FC CNN for single-channel images.

    ``forward`` flattens to 1024 features (64 channels x 4 x 4), which is what
    a 28x28 input produces after the two 5x5 convolutions and 2x2 poolings.
    When ``save_activations`` is set, the raw (pre-ReLU) outputs of both
    convolutions are kept in ``self.activations`` after every forward pass.
    """

    def __init__(
        self, dataset, w_conv_bias=True, w_fc_bias=True, save_activations=True
    ):
        super().__init__()

        # number of output classes follows from the dataset name.
        self.num_classes = _decide_num_classes(dataset)

        # feature extractor.
        self.conv1 = nn.Conv2d(1, 32, 5, bias=w_conv_bias)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5, bias=w_conv_bias)
        # classifier head.
        self.fc1 = nn.Linear(1024, 2048, bias=w_fc_bias)
        self.classifier = nn.Linear(2048, self.num_classes, bias=w_fc_bias)

        # storage for the intermediate conv outputs.
        self.save_activations = save_activations
        self.activations = None

    def forward(self, x):
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(self.pool(F.relu(conv1_out)))

        flat = self.pool(F.relu(conv2_out)).view(-1, 1024)
        logits = self.classifier(F.relu(self.fc1(flat)))

        if self.save_activations:
            self.activations = [conv1_out, conv2_out]
        return logits
def simple_cnn(conf):
    """Build the simple CNN that matches ``conf.data``.

    Args:
        conf: configuration object providing ``data`` (dataset name) and the
            ``w_conv_bias`` / ``w_fc_bias`` flags forwarded to the model.

    Returns:
        ``CNNCifar`` for CIFAR variants and SVHN, ``CNNMnist`` for ``"mnist"``,
        ``CNNfemnist`` for ``"femnist"``.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    dataset = conf.data

    if "cifar" in dataset or dataset == "svhn":
        model_cls = CNNCifar
    elif dataset == "mnist":
        model_cls = CNNMnist
    elif dataset == "femnist":
        model_cls = CNNfemnist
    else:
        # name the offending dataset, as ``_decide_num_classes`` does.
        raise NotImplementedError(f"this dataset ({dataset}) is not supported yet.")
    return model_cls(dataset, w_conv_bias=conf.w_conv_bias, w_fc_bias=conf.w_fc_bias)
[64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 14 | "D": [ 15 | 64, 16 | 64, 17 | "M", 18 | 128, 19 | 128, 20 | "M", 21 | 256, 22 | 256, 23 | 256, 24 | "M", 25 | 512, 26 | 512, 27 | 512, 28 | "M", 29 | 512, 30 | 512, 31 | 512, 32 | "M", 33 | ], 34 | "E": [ 35 | 64, 36 | 64, 37 | "M", 38 | 128, 39 | 128, 40 | "M", 41 | 256, 42 | 256, 43 | 256, 44 | 256, 45 | "M", 46 | 512, 47 | 512, 48 | 512, 49 | 512, 50 | "M", 51 | 512, 52 | 512, 53 | 512, 54 | 512, 55 | "M", 56 | ], 57 | } 58 | 59 | 60 | class VGG(nn.Module): 61 | def __init__(self, nn_arch, dataset, use_bn=True): 62 | super(VGG, self).__init__() 63 | 64 | # init parameters. 65 | self.use_bn = use_bn 66 | self.nn_arch = nn_arch 67 | self.dataset = dataset 68 | self.num_classes = self._decide_num_classes() 69 | 70 | # init models. 71 | self.features = self._make_layers() 72 | self.intermediate_classifier = nn.Sequential( 73 | nn.Dropout(), 74 | nn.Linear(512, 512), 75 | nn.ReLU(True), 76 | nn.Dropout(), 77 | nn.Linear(512, 512), 78 | nn.ReLU(True), 79 | ) 80 | self.classifier = nn.Linear(512, self.num_classes) 81 | 82 | # weight initialization. 
class VGG_S(nn.Module):
    """Width-scalable VGG: the feature stack of ``ARCHITECTURES[nn_arch]`` with
    every channel count multiplied by ``width``, followed by one linear layer.

    Args:
        nn_arch: key into ``ARCHITECTURES``.
        dataset: ``"cifar10"``, ``"svhn"`` or ``"cifar100"``.
        width: multiplier applied to every convolution's output channels.
        use_bn: insert ``BatchNorm2d`` after each convolution.
        save_activations: stored on the instance.
            NOTE(review): ``forward`` never fills ``self.activations`` here.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, nn_arch, dataset, width=1, use_bn=True, save_activations=False):
        super(VGG_S, self).__init__()

        # init parameters.
        self.use_bn = use_bn
        self.nn_arch = nn_arch
        self.width = width
        self.dataset = dataset
        self.num_classes = self._decide_num_classes()

        # init models.
        self.features = self._make_layers()
        # the classifier expects 32 * width features after flattening.
        self.classifier = nn.Linear(int(32 * width), self.num_classes)

        # weight initialization.
        self._weight_initialization()

        # a placeholder for activations in the intermediate layers.
        self.save_activations = save_activations
        self.activations = None

    def _weight_initialization(self):
        """He-style normal init for convolutions; BatchNorm starts as identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _decide_num_classes(self):
        """Return the number of classes for ``self.dataset``."""
        if self.dataset == "cifar10" or self.dataset == "svhn":
            return 10
        elif self.dataset == "cifar100":
            return 100
        else:
            raise ValueError("not allowed dataset.")

    def _make_layers(self):
        """Build the conv / (BN) / ReLU / max-pool feature extractor."""
        layers = []
        in_channels = 3
        for v in ARCHITECTURES[self.nn_arch]:
            if v == "M":
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                out_planes = int(v * self.width)
                conv2d = nn.Conv2d(in_channels, out_planes, kernel_size=3, padding=1)
                if self.use_bn:
                    # BatchNorm must match the *scaled* channel count; sizing it
                    # with the unscaled ``v`` breaks whenever width != 1.
                    layers += [
                        conv2d,
                        nn.BatchNorm2d(out_planes),
                        nn.ReLU(inplace=True),
                    ]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
class BasicBlock(nn.Module):
    """Pre-activation residual block used by WideResNet.

    Applies BN-ReLU-conv twice (with optional dropout in between) and adds the
    result to a shortcut. When the channel count changes, the shortcut is a
    1x1 convolution applied to the *activated* input; otherwise it is the raw
    input itself.
    """

    def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.droprate = drop_rate
        self.equal_in_out = in_planes == out_planes

        # a projection is only needed when the channel count changes.
        if self.equal_in_out:
            self.conv_shortcut = None
        else:
            self.conv_shortcut = nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False,
            )

    def forward(self, x):
        # both cases feed the activated input into the residual branch.
        pre_act = self.relu1(self.bn1(x))

        branch = self.relu2(self.bn2(self.conv1(pre_act)))
        if self.droprate > 0:
            branch = F.dropout(branch, p=self.droprate, training=self.training)
        branch = self.conv2(branch)

        if self.equal_in_out:
            skip = x
        else:
            skip = self.conv_shortcut(pre_act)
        return torch.add(skip, branch)
def wideresnet(conf):
    """Build a WideResNet from ``conf``.

    The depth is parsed from ``conf.arch`` (``"wideresnet<depth>"``). Supported
    datasets are CIFAR variants, SVHN, and imagenet names longer than eight
    characters (e.g. ``"imagenet32"``); anything else raises
    ``NotImplementedError``.
    """
    depth = int(conf.arch.replace("wideresnet", ""))
    name = conf.data

    has_suffixed_imagenet_name = "imagenet" in name and len(name) > 8
    supported = "cifar" in name or "svhn" in name or has_suffixed_imagenet_name
    if not supported:
        raise NotImplementedError

    return WideResNet(
        dataset=name,
        net_depth=depth,
        widen_factor=conf.wideresnet_widen_factor,
        drop_rate=conf.drop_rate,
    )
def get_args():
    """Parse the command-line options of the downsampled-ImageNet builder.

    Returns:
        argparse.Namespace with ``data_dir`` (required), ``data_type``
        (default ``"train"``), ``img_size`` (default 32) and ``force_delete``
        (default 0).

    Raises:
        SystemExit: raised by argparse, with a usage message, when
            ``--data_dir`` is missing or an option is malformed.
    """
    parser = argparse.ArgumentParser(description="aug data.")

    # define arguments.
    # ``required=True`` replaces a post-parse ``assert``: asserts vanish under
    # ``python -O`` and report a traceback rather than a usage error.
    parser.add_argument("--data_dir", required=True)
    parser.add_argument("--data_type", default="train", type=str)
    parser.add_argument("--img_size", default=32, type=int)
    parser.add_argument("--force_delete", default=0, type=int)

    # parse args.
    return parser.parse_args()
56 | list_of_data = [ 57 | unpickle(os.path.join(folder_path, file)) 58 | for file in os.listdir(folder_path) 59 | if ("train" if "train" in data_type else "val") in file 60 | ] 61 | mean_of_image = unpickle( 62 | [ 63 | os.path.join(folder_path, file) 64 | for file in os.listdir(folder_path) 65 | if "train" in file 66 | ][0] 67 | )["mean"] 68 | 69 | # extract features. 70 | self.features, self.labels = self._get_images_and_labels( 71 | list_of_data, mean_of_image=mean_of_image 72 | ) 73 | 74 | def _get_images_and_labels(self, list_of_data, mean_of_image): 75 | def _helper(_feature, _target, _mean): 76 | # process data. 77 | # _feature = _feature - _mean 78 | _target = [x - 1 for x in _target] 79 | return _feature, _target 80 | 81 | features, labels = [], [] 82 | for _data in list_of_data: 83 | # extract raw data. 84 | _feature = _data["data"] 85 | _target = _data["labels"] 86 | _mean = mean_of_image 87 | 88 | # get data. 89 | feature, target = _helper(_feature, _target, _mean) 90 | 91 | # store data. 
def main(args):
    """Serialize the downsampled ImageNet split described by ``args`` to LMDB.

    Thin wrapper around ``sequential_downsampled_imagenet``; ``args`` is the
    namespace returned by ``get_args``.
    """
    sequential_downsampled_imagenet(args)
def announce_job_termination_to_mongo(conf):
    """Mark the job referenced by ``conf.mongo_job_id`` as finished.

    Sets ``status`` to ``"FINISHED"`` and records ``end_time`` (naive UTC, like
    the ``start_time`` written by ``init_mongo``) and ``lasted_time``, the
    whole number of seconds between the stored ``start_time`` and now.

    Args:
        conf: configuration object carrying ``mongo_job_id``, the filter
            created by ``init_mongo``.
    """
    end_time = datetime.datetime.utcnow()
    start_time = find_record_from_mongo(conf.mongo_job_id)[0]["start_time"]

    # ``timedelta.seconds`` is only the sub-day remainder, so a job running
    # longer than 24h would be under-reported; use the full duration instead.
    lasted_time = int((end_time - start_time).total_seconds())

    # go through the shared ``update_one`` helper rather than the legacy
    # ``Collection.update`` API.
    update_mongo_record(
        conf.mongo_job_id,
        {
            "$set": {
                "status": "FINISHED",
                "end_time": end_time,
                "lasted_time": lasted_time,
            }
        },
    )
def _get_non_duplicated_time(records):
    """Return the GPU-seconds covered by ``records``, counting overlaps once.

    Records are grouped by ``"host"``. Within a host, each (gpu, second) cell
    is counted a single time however many records cover it; the per-host
    totals are then added up.

    Args:
        records: iterable of dicts with ``"host"``, ``"start_time"`` and
            ``"end_time"`` (datetimes). A record may carry a ``"conf"`` dict
            whose ``"world"`` is a comma-separated list of GPU ids, of which
            the first ``conf["blocks"]`` are used; without ``"conf"`` the
            record is attributed to GPU 0.
            NOTE(review): ``init_mongo`` stores the configuration under
            ``"config"``, not ``"conf"`` - confirm which key is intended.
            Each record is updated in place: a ``"gpus"`` entry is added and
            both timestamps are truncated to whole seconds.

    Returns:
        Total number of distinct (gpu, second) cells; ``0`` for no records.
    """

    def _get_used_gpu_ids(record):
        if "conf" not in record:
            return {0}
        conf = record["conf"]
        gpus = conf["world"].split(",")[: conf["blocks"]]
        return {int(gpu) for gpu in gpus}

    def _organize_results_per_host(_records):
        # preprocessing: collect gpu ids and drop the microseconds.
        used_gpus = set()
        for record in _records:
            record["gpus"] = _get_used_gpu_ids(record)
            used_gpus.update(record["gpus"])
            record["start_time"] = record["start_time"].replace(microsecond=0)
            record["end_time"] = record["end_time"].replace(microsecond=0)

        # build time matrix: one row per distinct gpu, one column per second.
        start_time = min(record["start_time"] for record in _records)
        end_time = max(record["end_time"] for record in _records)
        time_steps = int((end_time - start_time).total_seconds())
        # gpu ids need not be 0..n-1, so map each id to its own row instead of
        # indexing the matrix with the raw id.
        gpu_to_row = {gpu: row for row, gpu in enumerate(sorted(used_gpus))}
        time_matrix = np.zeros((len(gpu_to_row), time_steps))

        # fill in the time matrix.
        for record in _records:
            start_idx = int((record["start_time"] - start_time).total_seconds())
            end_idx = int((record["end_time"] - start_time).total_seconds())
            for gpu in record["gpus"]:
                time_matrix[gpu_to_row[gpu], start_idx:end_idx] = 1

        # merge results
        return time_matrix.sum()

    # groupby only merges adjacent items, hence the sort on the same key.
    records = sorted(records, key=lambda x: x["host"])
    return sum(
        _organize_results_per_host(list(values))
        for _, values in groupby(records, key=lambda x: x["host"])
    )
def get_measurement(cli, measurement=None, tags=None):
    """Fetch all points of one measurement, optionally filtered by tags.

    :param cli: client object exposing `query(str)`; the returned result must
        expose `get_points(measurement=..., tags=...)`.
    :param measurement: name of the measurement to select from.
    :param tags: dictionary of tag filters; `None` means no filtering.
    :return: list with the points yielded by `get_points`.
    """
    # avoid a shared mutable default: build a fresh dict per call.
    tags = {} if tags is None else tags

    rs = cli.query("select * from {}".format(measurement))
    return list(rs.get_points(measurement=measurement, tags=tags))
47 | num_records = len(records) 48 | distinct_conf_set = set() 49 | 50 | # re-order the records. 51 | if reorder_record_item is not None: 52 | records = reorder_records(records, based_on=reorder_record_item) 53 | 54 | count = 0 55 | for ind, (args, info) in enumerate(records): 56 | # build legend. 57 | _legend = build_legend(args, legend) 58 | if _legend in distinct_conf_set and remove_duplicate: 59 | continue 60 | else: 61 | distinct_conf_set.add(_legend) 62 | 63 | # split the y_wrt_sth if it can be splitted. 64 | if ";" in y_wrt_sth: 65 | has_multiple_y = True 66 | list_of_y_wrt_sth = y_wrt_sth.split(";") 67 | else: 68 | has_multiple_y = False 69 | list_of_y_wrt_sth = [y_wrt_sth] 70 | 71 | for _y_wrt_sth in list_of_y_wrt_sth: 72 | # determine the style of line, color and marker. 73 | line_style, color_style, mark_style = determine_color_and_lines( 74 | num_rows=num_records // num_cols, num_cols=num_cols, ind=count 75 | ) 76 | if markevery_list is not None: 77 | mark_every = markevery_list[count] 78 | else: 79 | mark_style = None 80 | mark_every = None 81 | 82 | # update the counter. 83 | count += 1 84 | 85 | # determine if we want to smooth the curve. 86 | if "tr_step" in x_wrt_sth or "tr_epoch" in x_wrt_sth: 87 | info["tr_step"] = list(range(1, 1 + len(info["tr_loss"]))) 88 | if "tr_epoch" == x_wrt_sth: 89 | x = info["tr_step"] 90 | x = [ 91 | 1.0 * _x / args["num_batches_train_per_device_per_epoch"] 92 | for _x in x 93 | ] 94 | else: 95 | x = info[x_wrt_sth] 96 | if "time" in x_wrt_sth: 97 | x = [(time - x[0]).seconds + 1 for time in x] 98 | y = info[_y_wrt_sth] 99 | 100 | if is_smooth: 101 | x, y = smoothing_func(x, y, smooth_space) 102 | 103 | # only plot subtset. 
def smoothing_func(x, y, smooth_length=10):
    """Smooth `y` with a trailing moving average and align it with `x`.

    The smoothed value at position `i` is the mean of the (at most)
    `smooth_length` values of `y` that precede position `i`; the very first
    point has no history and is kept unchanged.

    Both series are cut to their common length, so a `y` that is shorter than
    `x` (or empty) no longer triggers an IndexError.

    :param x: sequence of x-coordinates.
    :param y: sequence of numeric y-values.
    :param smooth_length: size of the trailing window; a value <= 0 disables
        smoothing and only aligns the two series.
    :return: tuple `(x, y)`; slices of the inputs when smoothing is disabled,
        two new lists otherwise.
    """
    num_points = min(len(x), len(y))

    # no smoothing requested: only cut both series to the same length.
    if smooth_length <= 0:
        return x[:num_points], y[:num_points]

    # smooth curve
    x_, y_ = [], []
    for end_ind in range(num_points):
        window = y[max(0, end_ind - smooth_length):end_ind]
        x_.append(x[end_ind])
        if len(window) == 0:
            # only happens for the first point: nothing to average over.
            y_.append(y[end_ind])
        else:
            y_.append(1.0 * sum(window) / len(window))
    return x_, y_
def rebuild_runtime_record(times):
    """Convert absolute timestamps into elapsed seconds since the first one.

    Uses `timedelta.total_seconds()` rather than `timedelta.seconds`: the
    latter only holds the seconds component (0..86399) and therefore wraps
    around for runs that last longer than one day.

    :param times: sequence of timestamps whose differences are timedeltas.
    :return: list of ints, `elapsed whole seconds + 1` for every timestamp
        (the first entry is always 1); an empty list for empty input.
    """
    if len(times) == 0:
        return []

    start = times[0]
    return [int((stamp - start).total_seconds()) + 1 for stamp in times]
def configure_figure(
    ax,
    xlabel,
    ylabel,
    title=None,
    has_legend=True,
    legend_loc="lower right",
    legend_ncol=2,
    bbox_to_anchor=[0, 0],
):
    """Apply the shared legend, label, title and tick styling to `ax`.

    :param ax: axes-like object to decorate; it is returned afterwards.
    :param xlabel: text of the x-axis label.
    :param ylabel: text of the y-axis label.
    :param title: optional title; skipped when None.
    :param has_legend: whether a legend is drawn.
    :param legend_loc: forwarded to `ax.legend` as `loc`.
    :param legend_ncol: forwarded to `ax.legend` as `ncol`.
    :param bbox_to_anchor: forwarded to `ax.legend` as `bbox_to_anchor`.
    """
    if has_legend:
        legend_options = dict(
            loc=legend_loc,
            bbox_to_anchor=bbox_to_anchor,
            ncol=legend_ncol,
            shadow=True,
            fancybox=True,
            fontsize=20,
        )
        ax.legend(**legend_options)

    # both axis labels share the same font size and padding.
    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, fontsize=24, labelpad=18)

    if title is not None:
        ax.set_title(title, fontsize=24)

    # tick labels of both axes use the same size.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_tick_params(labelsize=22)
    return ax
def build_legend(args, legend):
    """Build a legend label of the form "key1=value1, key2=value2".

    :param args: mapping-like object supporting `in` and item access.
    :param legend: comma-separated string of keys to display; `None` or an
        empty string yields an empty label instead of raising.
    :return: the label string. Keys missing from `args` are shown with the
        placeholder value -1.
    """
    # no legend requested: return an empty label rather than crash on `.split`.
    if not legend:
        return ""

    my_legend = []
    for _legend in legend.split(","):
        _legend_content = args[_legend] if _legend in args else -1
        if "pandas" in str(type(_legend_content)):
            # pandas containers are reduced to their first element.
            _legend_content = list(_legend_content)[0]
        my_legend.append("{}={}".format(_legend, _legend_content))
    return ", ".join(my_legend)
def is_float(value):
    """Return True when `float(value)` succeeds, False otherwise.

    Only the conversion failures raised by `float()` are treated as "not a
    float"; anything else (e.g. KeyboardInterrupt) propagates to the caller
    instead of being silently swallowed.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type, ValueError: malformed string,
        # OverflowError: integer too large for a float.
        return False
    return True
def get_checkpoint_folder_name(conf):
    """Build the hyper-parameter suffix used in the checkpoint folder name.

    Side effect: stores the number of participating clients per round in
    `conf.n_participated`.
    """
    # get optimizer info.
    optim_info = f"{conf.optimizer}"

    # get n_participated: n_clients * participation_ratio + 0.5, truncated.
    conf.n_participated = int(conf.n_clients * conf.participation_ratio + 0.5)

    # concat them together.
    name_parts = (
        f"_l2-{conf.weight_decay}",
        f"_lr-{conf.lr}",
        f"_n_comm_rounds-{conf.n_comm_rounds}",
        f"_local_n_epochs-{conf.local_n_epochs}",
        f"_batchsize-{conf.batch_size}",
        f"_n_clients_{conf.n_clients}",
        f"_n_participated-{conf.n_participated}",
        f"_optim-{optim_info}",
        f"_agg_scheme-{conf.fl_aggregate_scheme}",
    )
    return "".join(name_parts)
class EarlyStoppingTracker(object):
    """Signal when a monitored value has stopped improving.

    Calling the tracker with a new value returns True once the value has
    failed to improve on the best one (by more than `delta`) more than
    `patience` times in a row. A `patience` of None or <= 0 disables it.
    """

    def __init__(self, patience, delta=0, mode="max"):
        self.patience = patience
        self.delta = delta
        self.mode = mode
        self.best_value = None
        self.counter = 0

    def __call__(self, value):
        # tracking is switched off.
        if self.patience is None or self.patience <= 0:
            return False

        # the first observation only initializes the reference value.
        if self.best_value is None:
            self.best_value = value
            self.counter = 0
            return False

        if self.mode == "max":
            improved = value > self.best_value + self.delta
        elif self.mode == "min":
            improved = value < self.best_value - self.delta
        else:
            raise ValueError(f"Illegal mode for early stopping: {self.mode}")

        update = self._positive_update if improved else self._negative_update
        return update(value)

    def _positive_update(self, value):
        # new best value: restart counting.
        self.best_value = value
        self.counter = 0
        return False

    def _negative_update(self, value):
        # one more round without improvement.
        self.counter += 1
        return bool(self.counter > self.patience)
class Logger:
    """
    Very simple prototype logger that will store the values to a JSON file
    """

    def __init__(self, file_folder):
        """
        :param file_folder: folder that receives "log-1.json" (metrics) and
            "log.txt" (text log).
        """
        self.file_folder = file_folder
        self.file_json = os.path.join(file_folder, "log-1.json")
        self.file_txt = os.path.join(file_folder, "log.txt")
        self.values = []

    def log_metric(self, name, values, tags, display=False):
        """
        Store a scalar metric

        :param name: measurement, like 'accuracy'
        :param values: dictionary, like { epoch: 3, value: 0.23 }
        :param tags: dictionary, like { split: train }
        :param display: also print the metric to stdout when True
        """
        self.values.append({"measurement": name, **values, **tags})
        if display:
            print(
                "{name}: {values} ({tags})".format(name=name, values=values, tags=tags)
            )

    def log(self, value, display=True):
        """
        Append a timestamped line to the text log

        :param value: message string
        :param display: also print the line to stdout when True
        """
        content = time.strftime("%Y-%m-%d %H:%M:%S") + "\t" + value
        # honor `display`; the line is written to the text log either way.
        if display:
            print(content)
        self.save_txt(content)

    def save_json(self):
        """
        Save the internal memory to a file
        """
        with open(self.file_json, "w") as fp:
            json.dump(self.values, fp, indent=" ")

        # reset 'values' and redirect the json file to other name.
        if self.meet_cache_limit():
            self.values = []
            self.redirect_new_json()

    def save_txt(self, value):
        """Append `value` plus a newline to the text log file."""
        write_txt(value + "\n", self.file_txt, type="a")

    def redirect_new_json(self):
        """get the number of existing json files under the current folder."""
        existing_json_files = [
            file for file in os.listdir(self.file_folder) if "json" in file
        ]
        self.file_json = os.path.join(
            self.file_folder, "log-{}.json".format(len(existing_json_files) + 1)
        )

    def meet_cache_limit(self):
        """Return True once more than 1e4 metric records are held in memory."""
        return len(self.values) > 1e4
class MathDict:
    """Wrapper around a dictionary whose values are operated on key by key.

    `keys` is a snapshot (set) of the wrapped dictionary's keys taken at
    construction time; all per-key operations iterate over that snapshot.
    """

    def __init__(self, dictionary):
        self.dictionary = dictionary
        self.keys = set(dictionary.keys())

    def __str__(self):
        return f"MathDict({self.dictionary!s})"

    def __repr__(self):
        return f"MathDict({self.dictionary!r})"

    def map(self, mapfun):
        """Return a new MathDict with `mapfun` applied to every value."""
        mapped = {name: mapfun(self.dictionary[name]) for name in self.keys}
        return MathDict(mapped)

    def filter(self, condfun):
        """Return a new MathDict holding the entries whose key passes `condfun`."""
        kept = {name: self.dictionary[name] for name in self.keys if condfun(name)}
        return MathDict(kept)

    def detach(self):
        """Replace every value, in place, by the result of its `detach()`."""
        for name in self.keys:
            entry = self.dictionary[name]
            self.dictionary[name] = entry.detach()

    def values(self):
        """Return the values view of the wrapped dictionary."""
        return self.dictionary.values()

    def items(self):
        """Return the items view of the wrapped dictionary."""
        return self.dictionary.items()
def onehot(indexes, N=None, ignore_index=None):
    """
    Create a one-hot representation of `indexes` with N possible entries.

    indexes is a long-tensor of indexes; the result is a uint8 tensor on the
    same device with one extra trailing dimension of size N.
    If N is not specified, it is set to the maximum index appearing plus one.
    Positions equal to a non-negative `ignore_index` are all-zero in the
    one-hot representation, even when `ignore_index` itself is not a valid
    class (e.g. ignore_index=255 with N=3).
    """
    num_classes = int(indexes.max()) + 1 if N is None else int(N)

    ignored = None
    if ignore_index is not None and ignore_index >= 0:
        ignored = indexes.eq(ignore_index)

    output = torch.zeros(
        *indexes.size(), num_classes, dtype=torch.uint8, device=indexes.device
    )

    # scatter_ rejects indices outside [0, num_classes), so ignored positions
    # are pointed at class 0 for the scatter and cleared again right after.
    targets = indexes if ignored is None else indexes.masked_fill(ignored, 0)
    output.scatter_(-1, targets.unsqueeze(-1), 1)
    if ignored is not None:
        output.masked_fill_(ignored.unsqueeze(-1), 0)
    return output
class ModuleState:
    """Wrap a state dictionary so that it supports entry-wise arithmetic.

    Addition and subtraction combine two states that share the same keys;
    multiplication and division scale every entry by a scalar or tensor.
    Entries whose dtype is ``torch.int64`` are never scaled.
    """

    def __init__(self, state_dict):
        self.state_dict = state_dict
        self.keys = set(state_dict.keys())

    def _check_compatible(self, other):
        """Assert that `other` is a ModuleState over the same keys."""
        # The type is checked first so that a foreign operand fails with an
        # assertion instead of an AttributeError on `.keys`.
        assert isinstance(other, ModuleState)
        assert other.keys == self.keys

    def __add__(self, other):
        self._check_compatible(other)
        return ModuleState(
            {key: self.state_dict[key] + other.state_dict[key] for key in self.keys}
        )

    def __iadd__(self, other):
        self._check_compatible(other)
        for key in self.keys:
            self.state_dict[key] += other.state_dict[key]
        return self

    def __sub__(self, other):
        self._check_compatible(other)
        return ModuleState(
            {key: self.state_dict[key] - other.state_dict[key] for key in self.keys}
        )

    def __mul__(self, factor):
        """Return a new state with every non-int64 entry multiplied by `factor`."""
        # Plain ints are accepted as well, so `state * 2` behaves like
        # `state * 2.0`.
        assert isinstance(factor, (int, float, torch.Tensor))
        new_dict = {}
        for key in self.keys:
            data = self.state_dict[key]
            # do nothing for integers
            new_dict[key] = data if data.dtype == torch.int64 else data * factor
        return ModuleState(new_dict)

    def mul_by_key(self, factor, by_key):
        """Return a new state in which only the entry `by_key` is scaled.

        All other entries (and int64 entries in any case) are carried over
        unchanged; `by_key=None` therefore scales nothing.
        """
        assert isinstance(factor, (int, float, torch.Tensor))
        new_dict = {}
        for key in self.keys:
            data = self.state_dict[key]
            if data.dtype != torch.int64 and by_key is not None and by_key == key:
                new_dict[key] = data * factor
            else:
                new_dict[key] = data
        return ModuleState(new_dict)

    def __div__(self, factor):
        return self.__mul__(1.0 / factor)

    def copy_to_module(self, module):
        """
        Use this to copy the state to a module object when you need to maintain the
        computation graph that led to this particular state. This does break the model
        for normal optimizers down the line.
        """
        for name, submodule in module.named_modules():
            params = submodule._parameters
            for key in params:
                # named_modules() reports the root module under the empty
                # name; its own parameters are stored without a prefix.
                param_name = f"{name}.{key}" if name else key
                if param_name in self.keys:
                    params[key] = self.state_dict[param_name]

    __rmul__ = __mul__
    __truediv__ = __div__
def get_current_path(conf, rank):
    """Return the entry of `conf.resume` that belongs to `rank`.

    `conf.resume` is a comma-separated list of paths. Each path is keyed by
    the part of its last "/"-separated component that precedes the first "-",
    and the path whose key equals `str(rank)` is returned. If several paths
    share a key, the last one wins.

    Raises:
        KeyError: if no path is keyed by `rank`.
    """
    paths = conf.resume.split(",")
    # The key must be the prefix string itself: a one-element list slice is
    # unhashable and cannot be used as a dictionary key.
    path_by_prefix = {path.split("/")[-1].split("-")[0]: path for path in paths}
    return path_by_prefix[str(rank)]
def dict_parser(values):
    """Parse a "k1=v1,k2=v2" string into a dictionary.

    Each value is converted to a float when possible; anything else is passed
    through `str2bool`, which maps the recognised yes/no spellings to booleans
    and returns every other string unchanged. Only the first "=" of an item
    separates key from value, so values may themselves contain "=".

    Raises:
        ValueError: if an item does not contain "=".
    """
    local_dict = {}
    for kv in values.split(","):
        key, separator, value = kv.partition("=")
        if not separator:
            raise ValueError("Expected 'key=value' but got %r" % kv)
        try:
            local_dict[key] = float(value)
        except ValueError:
            local_dict[key] = str2bool(value)
    return local_dict
class SparsificationCompressor(object):
    """Compressors that keep only a subset of a tensor's entries (top-k or random-k)."""

    def get_top_k(self, x, ratio):
        """Select the largest-magnitude entries of `x`.

        `max(1, int(n * (1 - ratio)))` entries are kept, where n is the number
        of elements of `x`. Returns the selected (signed) values and their
        indices into the flattened tensor.
        """
        x_data = x.view(-1)
        x_len = x_data.nelement()
        top_k = max(1, int(x_len * (1 - ratio)))

        # get indices and the corresponding values
        if top_k == 1:
            _, selected_indices = torch.max(x_data.abs(), dim=0, keepdim=True)
        else:
            _, selected_indices = torch.topk(
                x_data.abs(), top_k, largest=True, sorted=False
            )
        return x_data[selected_indices], selected_indices

    def get_mask(self, flatten_arr, indices):
        """Return float masks `(selected, not_selected)` shaped like `flatten_arr`.

        The first mask is 1.0 at `indices` and 0.0 elsewhere; the second is its
        complement.
        """
        # A boolean mask is required here: `~` on a uint8 tensor is a bitwise
        # NOT and would turn 0/1 into 255/254 instead of 1/0.
        mask = torch.zeros_like(flatten_arr, dtype=torch.bool)
        mask[indices] = True
        return mask.float(), (~mask).float()

    def get_random_k(self, x, ratio, is_biased=True):
        """Select `max(1, int(n * (1 - ratio)))` entries of `x` uniformly at random.

        Sampling is without replacement. With `is_biased=False` the selected
        values are rescaled by `n / k`. Returns the values and their indices
        into the flattened tensor.
        """
        # get tensor size.
        x_data = x.view(-1)
        x_len = x_data.nelement()
        top_k = max(1, int(x_len * (1 - ratio)))

        # random sample the k indices.
        selected_indices = np.random.choice(x_len, top_k, replace=False)
        selected_indices = torch.LongTensor(selected_indices).to(x.device)

        if is_biased:
            return x_data[selected_indices], selected_indices
        else:
            return x_len / top_k * x_data[selected_indices], selected_indices

    def compress(self, arr, op, compress_ratio, is_biased):
        """Sparsify `arr` with the scheme named in `op`.

        `op` must contain "top_k" or "random_k" ("top_k" is checked first).
        `is_biased` only affects random-k, where it is forwarded to
        `get_random_k`.

        Raises:
            NotImplementedError: if `op` names neither scheme.
        """
        if "top_k" in op:
            values, indices = self.get_top_k(arr, compress_ratio)
        elif "random_k" in op:
            values, indices = self.get_random_k(
                arr, compress_ratio, is_biased=is_biased
            )
        else:
            raise NotImplementedError

        # n_bits = get_n_bits(values) + get_n_bits(indices)
        return values, indices

    def uncompress(self, values, indices, selected_shapes, original_shapes):
        """Split concatenated per-parameter values/indices and re-offset the indices.

        `selected_shapes[i]` is the number of sparse entries of parameter i.
        The indices of parameter i are shifted by the running sum of
        `original_shapes[j][1]` over the preceding parameters.
        NOTE(review): `original_shapes[j][1]` is presumably the element count
        of parameter j -- confirm against the caller.
        """
        # apply each param.
        sync_pointer = 0
        pointer = 0

        _q_values, _q_indices = [], []
        for idx, n_sparse_value in enumerate(selected_shapes):
            # get value and indice for the current param.
            _q_value = values[sync_pointer : sync_pointer + n_sparse_value]
            _q_indice = pointer + indices[sync_pointer : sync_pointer + n_sparse_value]
            _q_values += [_q_value]
            _q_indices += [_q_indice]

            # update the pointers.
            sync_pointer += n_sparse_value
            pointer += original_shapes[idx][1]
        return torch.cat(_q_values), torch.cat(_q_indices).long()
dtype=torch.float32, device=src_tensor.device 138 | ) 139 | src_tensor = torch.cat((src_tensor, new_tensor), 0) 140 | src_tensor = src_tensor.view(32, -1) 141 | src_tensor = src_tensor.to(dtype=torch.int32) 142 | dst_tensor = bit2byte.packing(src_tensor) 143 | dst_tensor = dst_tensor.to(dtype=torch.int32) 144 | return dst_tensor, src_tensor_size 145 | 146 | def unpacking(self, src_tensor, src_tensor_size): 147 | src_element_num = self.element_num(src_tensor_size) 148 | add_elm = 32 - (src_element_num % 32) 149 | if src_element_num % 32 == 0: 150 | add_elm = 0 151 | src_tensor = src_tensor.int() 152 | new_tensor = torch.ones( 153 | src_element_num + add_elm, device=src_tensor.device, dtype=torch.int32 154 | ) 155 | new_tensor = new_tensor.view(32, -1) 156 | new_tensor = bit2byte.unpacking(src_tensor, new_tensor) 157 | new_tensor = new_tensor.view(-1) 158 | new_tensor = new_tensor[:src_element_num] 159 | new_tensor = new_tensor.view(src_tensor_size) 160 | new_tensor = -new_tensor.add_(-1) 161 | new_tensor = new_tensor.float() 162 | return new_tensor 163 | 164 | def majority_vote(self, src_tensor_list): 165 | voter_num = len(src_tensor_list) 166 | src_tensor = torch.stack(src_tensor_list) 167 | src_tensor = src_tensor.view(-1) 168 | full_size = 32 * len(src_tensor) 169 | new_tensor = torch.ones(full_size, device=src_tensor.device, dtype=torch.int32) 170 | new_tensor = new_tensor.view(32, -1) 171 | new_tensor = bit2byte.unpacking(src_tensor, new_tensor) 172 | new_tensor = -new_tensor.add_(-1) 173 | # sum 174 | new_tensor = new_tensor.permute(1, 0).contiguous().view(voter_num, -1) 175 | new_tensor = torch.sum(new_tensor, 0) 176 | new_tensor = new_tensor.view(-1, 32).permute(1, 0) 177 | new_tensor = torch.sign(new_tensor) 178 | new_tensor = bit2byte.packing(new_tensor) 179 | new_tensor = new_tensor.to(dtype=torch.int32) 180 | return new_tensor 181 | 182 | def element_num(self, size): 183 | num = 1 184 | for i in range(len(size)): 185 | num *= size[i] 186 | return num 187 
class MaxMeter(object):
    """
    Remembers the largest value passed to `update` so far.
    """

    def __init__(self):
        self.max = None

    def update(self, value):
        """
        Offer a value to the meter.
        :return: `True` if the provided value became the new max
        """
        if self.max is not None and not value > self.max:
            return False
        # A copy is stored so that later changes to `value` do not alter the record.
        self.max = deepcopy(value)
        return True

    def value(self):
        """Return the largest value seen so far (`None` before any update)."""
        return self.max
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.max = -float("inf")
        self.min = float("inf")
        self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` (e.g. the mean of a batch of n samples).

        `avg` is the weighted mean `sum / count`; `max` and `min` track the
        extreme `val` passed in so far.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With a zero total weight there is nothing to average yet; keep the
        # previous average instead of dividing by zero.
        if self.count != 0:
            self.avg = self.sum / self.count
        self.max = val if val > self.max else self.max
        self.min = val if val < self.min else self.min
class BestPerf(object):
    """Tracks the best performance seen so far and where it was reached.

    Args:
        best_perf: optional starting value for the best performance. When
            given, it is seeded into the meter so that later updates only
            count as "best" if they improve on it.
        larger_is_better: whether a larger value is an improvement
            (a `MaxMeter` is used) or a smaller one (a `MinMeter` is used).
    """

    def __init__(self, best_perf=None, larger_is_better=True):
        self.best_perf = best_perf
        self.cur_perf = None
        self.best_perf_locs = []
        self.larger_is_better = larger_is_better
        # Defined up front so the attribute can be read before any update.
        self.is_best = False

        # define meter
        self._define_meter()
        if self.best_perf is not None:
            # Without seeding, the first update would always replace the
            # provided starting value, even with a worse performance.
            self.meter.update(self.best_perf)

    def _define_meter(self):
        self.meter = MaxMeter() if self.larger_is_better else MinMeter()

    def update(self, perf, perf_location):
        """Record `perf`; if it is the best so far, remember `perf_location`."""
        self.is_best = self.meter.update(perf)
        self.cur_perf = perf

        if self.is_best:
            self.best_perf = perf
            self.best_perf_locs += [perf_location]

    @property
    def get_best_perf_loc(self):
        """Location of the most recent best performance, or 0 if none was recorded."""
        return self.best_perf_locs[-1] if len(self.best_perf_locs) != 0 else 0
class Timer:
    """
    Timer for PyTorch code
    Comes in the form of a contextmanager:

    Example:
    >>> timer = Timer()
    ... for i in range(10):
    ...     with timer("expensive operation"):
    ...         x = torch.randn(100)
    ... print(timer.summary())
    """

    def __init__(self, verbosity_level=1, log_fn=None, skip_first=True):
        """
        :param verbosity_level: timings requested with a higher `verbosity` are not measured
        :param log_fn: callable `(name, values, tags)` receiving sampled timings;
            defaults to printing them
        :param skip_first: if true, the first occurrence of each label is not
            added to the totals
        """
        self.verbosity_level = verbosity_level
        self.log_fn = log_fn if log_fn is not None else self._default_log_fn
        self.skip_first = skip_first
        self.cuda_available = torch.cuda.is_available()

        self.reset()

    def reset(self):
        """Reset the timer"""
        self.totals = {}  # Total time per label
        self.first_time = {}  # First occurrence of a label (start time)
        self.last_time = {}  # Last occurrence of a label (end time)
        self.call_counts = {}  # Number of times a label occurred

    @contextmanager
    def __call__(self, label, epoch=-1.0, verbosity=1):
        # Don't measure this if the verbosity level is too high
        if verbosity > self.verbosity_level:
            yield
            return

        # Measure the time
        self._cuda_sync()
        start = time.time()
        yield
        self._cuda_sync()
        end = time.time()

        # Update first and last occurrence of this label
        if label not in self.first_time:
            self.first_time[label] = start
        self.last_time[label] = end

        # Update the totals and call counts
        if label not in self.totals and self.skip_first:
            # The skipped first call must not count towards the label's
            # runtime window either, so its start time is dropped again.
            self.totals[label] = 0.0
            del self.first_time[label]
            self.call_counts[label] = 0
        elif label not in self.totals and not self.skip_first:
            self.totals[label] = end - start
            self.call_counts[label] = 1
        else:
            self.totals[label] += end - start
            self.call_counts[label] += 1

        if self.call_counts[label] > 0:
            # We will reduce the probability of logging a timing
            # linearly with the number of time we have seen it.
            # It will always be recorded in the totals, though.
            if np.random.rand() < 1 / self.call_counts[label]:
                self.log_fn(
                    "timer", {"epoch": epoch, "value": end - start}, {"event": label}
                )

    def summary(self):
        """
        Return a summary in string-form of all the timings recorded so far
        (`None` if nothing has been recorded)
        """
        if len(self.totals) > 0:
            with StringIO() as buffer:
                total_avg_time = 0
                print("--- Timer summary ------------------------", file=buffer)
                print(" Event | Count | Average time | Frac.", file=buffer)
                for event_label in sorted(self.totals):
                    total = self.totals[event_label]
                    count = self.call_counts[event_label]
                    if count == 0:
                        continue
                    avg_duration = total / count
                    total_runtime = (
                        self.last_time[event_label] - self.first_time[event_label]
                    )
                    # A label timed once within the clock resolution has a
                    # zero-length window; report 0% instead of dividing by zero.
                    runtime_percentage = (
                        100 * total / total_runtime if total_runtime > 0 else 0.0
                    )
                    # Labels containing "." are left out of the overall average.
                    total_avg_time += avg_duration if "." not in event_label else 0
                    print(
                        f"- {event_label:30s} | {count:6d} | {avg_duration:11.5f}s | {runtime_percentage:5.1f}%",
                        file=buffer,
                    )
                print("-------------------------------------------", file=buffer)
                event_label = "total_averaged_time"
                # NOTE(review): `count` is the call count of the last label in
                # sorted order, not an aggregate -- confirm this is intended.
                print(
                    f"- {event_label:30s}| {count:6d} | {total_avg_time:11.5f}s |",
                    file=buffer,
                )
                print("-------------------------------------------", file=buffer)
                return buffer.getvalue()

    def _cuda_sync(self):
        """Finish all asynchronous GPU computations to get correct timings"""
        if self.cuda_available:
            torch.cuda.synchronize()

    def _default_log_fn(self, _, values, tags):
        """Print one sampled timing."""
        # `__call__` passes the label under the "event" key; "label" is still
        # accepted so that tags built for the previous lookup keep working.
        label = tags["event"] if "event" in tags else tags["label"]
        epoch = values["epoch"]
        duration = values["value"]
        print(f"Timer: {label:30s} @ {epoch:4.1f} - {duration:8.5f}s")
def configure_gpu(world_conf):
    """Expand a compact ``"a,b,c,d,e"`` spec into a flat list of device ids.

    The spec reads: take the ids ``a, a + c, ...`` up to and including ``b``,
    repeat each id ``d`` times in a row, then repeat that whole block ``e``
    times. For example ``"0,1,1,2,2"`` yields ``[0, 0, 1, 1, 0, 0, 1, 1]``.

    Args:
        world_conf: Comma-separated string holding exactly five integers
            (start, stop, interval, local_repeat, block_repeat).

    Returns:
        The expanded list of ints. It is empty when the spec describes no ids
        (empty range or a non-positive repeat count); the previous
        ``functools.reduce`` based version raised ``TypeError`` in that case.

    Raises:
        ValueError: If the spec does not hold exactly five integers, or the
            interval is zero while at least one block is requested.
    """
    fields = [int(x) for x in world_conf.split(",")]
    if len(fields) != 5:
        raise ValueError(
            "world_conf must look like 'start,stop,interval,local_repeat,"
            f"block_repeat' (5 integers), got {world_conf!r}"
        )
    start, stop, interval, local_repeat, block_repeat = fields

    # A single flat comprehension avoids both the empty-sequence failure of
    # reduce() and the repeated list copies of pairwise concatenation.
    return [
        device
        for _block in range(block_repeat)
        for device in range(start, stop + 1, interval)
        for _copy in range(local_repeat)
    ]
def build_mpi_script(conf, replacement=None):
    """Assemble the shell command that launches ``main.py`` under ``mpirun``.

    Side effects: before the options are rendered, two attributes are stored
    on ``conf`` and therefore forwarded like any other option:
    ``n_participated`` (``int(n_clients * participation_ratio + 0.5)``) and
    ``timestamp`` (current Unix time in whole seconds, as a string).

    Args:
        conf: Object whose attributes are forwarded as ``--name value``
            options; attributes that are ``None`` are left out. It must
            provide ``n_clients``, ``participation_ratio``, ``hostfile``,
            ``mpi_path``, ``mpi_env`` and ``python_path``.
        replacement: Optional mapping of option name to value. For any
            attribute name found in it, the mapped value is emitted instead
            of the one stored on ``conf``.

    Returns:
        The complete command line as a single string.

    Raises:
        AssertionError: If no client would participate.
    """
    # number of clients taking part in each communication round.
    conf.n_participated = int(conf.n_clients * conf.participation_ratio + 0.5)
    conf.timestamp = str(int(time.time()))
    assert conf.n_participated > 0

    # mpirun prefix: one process per participating client plus one extra
    # (presumably the master; confirm against the launcher's consumers).
    launcher = ""
    if conf.n_participated >= 1:
        launcher = (
            f"mpirun -n {conf.n_participated + 1} --hostfile {conf.hostfile} "
            "--mca orte_base_help_aggregate 0 --mca btl_tcp_if_exclude docker0,lo "
            f"--prefix {conf.mpi_path} "
        )
        if conf.mpi_env is not None and len(conf.mpi_env) > 0:
            launcher += f" -x {conf.mpi_env}"

    # python invocation followed by one "--name value" chunk per attribute.
    overrides = () if replacement is None else replacement
    options = []
    for name, value in conf.__dict__.items():
        if name in overrides:
            options.append(" --{} {} ".format(name, overrides[name]))
        elif value is not None:
            options.append(" --{} {} ".format(name, value))
    return launcher + " {} main.py ".format(conf.python_path) + "".join(options)
%{-}%+w %=%{g}][%{B}%m/%d %{W}%C%A%{g}]' 4 | 5 | # huge scrollback buffer 6 | defscrollback 5000 7 | 8 | # no welcome message 9 | startup_message off 10 | 11 | # 256 colors 12 | attrcolor b ".I" 13 | termcapinfo xterm 'Co#256:AB=\E[48;5;%dm:AF=\E[38;5;%dm' 14 | defbce on 15 | 16 | # mouse tracking allows to switch region focus by clicking 17 | mousetrack on 18 | 19 | # default windows 20 | screen -t Shell1 1 bash 21 | screen -t Shell2 2 bash 22 | screen -t Python 3 python 23 | screen -t Media 4 bash 24 | select 0 25 | bind c screen 1 # window numbering starts at 1 not 0 26 | bind 0 select 10 27 | 28 | # get rid of silly xoff stuff 29 | bind s split 30 | 31 | # layouts 32 | layout autosave on 33 | layout new one 34 | select 1 35 | layout new two 36 | select 1 37 | split 38 | resize -v +8 39 | focus down 40 | select 4 41 | focus up 42 | layout new three 43 | select 1 44 | split 45 | resize -v +7 46 | focus down 47 | select 3 48 | split -v 49 | resize -h +10 50 | focus right 51 | select 4 52 | focus up 53 | 54 | layout attach one 55 | layout select one 56 | 57 | # navigating regions with Ctrl-arrows 58 | bindkey "^[[1;5D" focus left 59 | bindkey "^[[1;5C" focus right 60 | bindkey "^[[1;5A" focus up 61 | bindkey "^[[1;5B" focus down 62 | 63 | # switch windows with F3 (prev) and F4 (next) 64 | bindkey "^[OR" prev 65 | bindkey "^[OS" next 66 | 67 | # switch layouts with Ctrl+F3 (prev layout) and Ctrl+F4 (next) 68 | bindkey "^[O1;5R" layout prev 69 | bindkey "^[O1;5S" layout next 70 | 71 | # F2 puts Screen into resize mode. Resize regions using hjkl keys. 
72 | bindkey "^[OQ" eval "command -c rsz" # enter resize mode 73 | 74 | # use hjkl keys to resize regions 75 | bind -c rsz h eval "resize -h -5" "command -c rsz" 76 | bind -c rsz j eval "resize -v -5" "command -c rsz" 77 | bind -c rsz k eval "resize -v +5" "command -c rsz" 78 | bind -c rsz l eval "resize -h +5" "command -c rsz" 79 | 80 | # quickly switch between regions using tab and arrows 81 | bind -c rsz \t eval "focus" "command -c rsz" # Tab 82 | bind -c rsz -k kl eval "focus left" "command -c rsz" # Left 83 | bind -c rsz -k kr eval "focus right" "command -c rsz" # Right 84 | bind -c rsz -k ku eval "focus up" "command -c rsz" # Up 85 | bind -c rsz -k kd eval "focus down" "command -c rsz" # Down 86 | -------------------------------------------------------------------------------- /environments/base/.tmux.conf: -------------------------------------------------------------------------------- 1 | # 0 is too far from ` ;) 2 | set -g base-index 1 3 | 4 | # Automatically set window title 5 | set-window-option -g automatic-rename on 6 | set-option -g set-titles on 7 | set-option -g mouse on 8 | 9 | #set -g default-terminal screen-256color 10 | set -g status-keys vi 11 | set -g history-limit 10000 12 | 13 | setw -g mode-keys vi 14 | setw -g monitor-activity on 15 | 16 | bind-key v split-window -h 17 | bind-key s split-window -v 18 | 19 | bind-key J resize-pane -D 5 20 | bind-key K resize-pane -U 5 21 | bind-key H resize-pane -L 5 22 | bind-key L resize-pane -R 5 23 | 24 | bind-key M-j resize-pane -D 25 | bind-key M-k resize-pane -U 26 | bind-key M-h resize-pane -L 27 | bind-key M-l resize-pane -R 28 | 29 | # Vim style pane selection 30 | bind h select-pane -L 31 | bind j select-pane -D 32 | bind k select-pane -U 33 | bind l select-pane -R 34 | 35 | # Use Alt-vim keys without prefix key to switch panes 36 | bind -n M-h select-pane -L 37 | bind -n M-j select-pane -D 38 | bind -n M-k select-pane -U 39 | bind -n M-l select-pane -R 40 | 41 | # Use Alt-arrow keys without 
prefix key to switch panes 42 | bind -n M-Left select-pane -L 43 | bind -n M-Right select-pane -R 44 | bind -n M-Up select-pane -U 45 | bind -n M-Down select-pane -D 46 | 47 | # Shift arrow to switch windows 48 | bind -n S-Left previous-window 49 | bind -n S-Right next-window 50 | 51 | # No delay for escape key press 52 | set -sg escape-time 0 53 | 54 | # Reload tmux config 55 | bind r source-file ~/.tmux.conf 56 | 57 | # THEME 58 | set -g status-bg black 59 | set -g status-fg white 60 | set -g window-status-current-bg white 61 | set -g window-status-current-fg black 62 | set -g window-status-current-attr bold 63 | set -g status-interval 60 64 | set -g status-left-length 30 65 | set -g status-left '#[fg=green](#S) #(whoami)' 66 | set -g status-right '#[fg=yellow]#(cut -d " " -f 1-3 /proc/loadavg)#[default] #[fg=white]%H:%M#[default]' 67 | -------------------------------------------------------------------------------- /environments/base/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04 2 | 3 | # Some important environment variables in Dockerfile 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | ENV TZ=Asia/Shanghai LANG=C.UTF-8 LC_ALL=C.UTF-8 PIP_NO_CACHE_DIR=1 6 | 7 | # install some necessary tools. 
8 | # RUN echo "deb http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list 9 | RUN sed -i "s/archive.ubuntu.com/mirrors.ustc.edu.cn/g" /etc/apt/sources.list && \ 10 | sed -i "s/security.ubuntu.com/mirrors.ustc.edu.cn/g" /etc/apt/sources.list && \ 11 | rm -f /etc/apt/sources.list.d/* && \ 12 | apt-get update \ 13 | && apt-get install -y --no-install-recommends \ 14 | build-essential \ 15 | ca-certificates \ 16 | pkg-config \ 17 | software-properties-common \ 18 | g++ \ 19 | sudo \ 20 | ccache \ 21 | cmake \ 22 | libjpeg-dev \ 23 | libpng-dev 24 | RUN apt-get install -y \ 25 | rsync \ 26 | swig \ 27 | curl \ 28 | git \ 29 | vim \ 30 | wget \ 31 | unzip \ 32 | zsh \ 33 | screen \ 34 | tmux \ 35 | openssh-server 36 | RUN apt-get update && \ 37 | apt-get install -y pciutils net-tools iputils-ping && \ 38 | apt-get install -y htop && \ 39 | rm -rf /var/lib/apt/lists/* 40 | 41 | RUN /usr/sbin/update-ccache-symlinks 42 | RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache 43 | 44 | # install good vim. 45 | RUN curl http://j.mp/spf13-vim3 -L -o - | sh 46 | 47 | # configure environments. 48 | RUN apt-get update && apt-get install -y locales && rm -rf /var/lib/apt/lists/* 49 | RUN echo "en_US.UTF-8 UTF-8" > /etc/locale.gen && locale-gen 50 | 51 | # configure user. 
52 | ENV SHELL=/bin/bash \ 53 | NB_USER=user \ 54 | NB_UID=1000 \ 55 | NB_GROUP=user \ 56 | NB_GID=1000 57 | ENV HOME=/home/$NB_USER 58 | 59 | ADD base/fix-permissions /usr/local/bin/fix-permissions 60 | RUN chmod +x /usr/local/bin/fix-permissions 61 | ADD base/entrypoint.sh /usr/local/bin/entrypoint.sh 62 | RUN chmod +x /usr/local/bin/entrypoint.sh 63 | RUN groupadd $NB_GROUP -g $NB_GID 64 | RUN useradd -m -s /bin/bash -N -u $NB_UID -g $NB_GID $NB_USER && \ 65 | echo "${NB_USER}:${NB_USER}" | chpasswd && \ 66 | usermod -aG sudo,adm,root ${NB_USER} && \ 67 | fix-permissions $HOME 68 | RUN echo 'user ALL=(ALL) NOPASSWD: ALL' | sudo EDITOR='tee -a' visudo 69 | 70 | # Default ssh config file that skips (yes/no) question when first login to the host 71 | RUN mkdir /var/run/sshd 72 | RUN sed -i "s/#PasswordAuthentication.*/PasswordAuthentication no/g" /etc/ssh/sshd_config \ 73 | && sed -i "s/#PermitRootLogin.*/PermitRootLogin yes/g" /etc/ssh/sshd_config \ 74 | && sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config \ 75 | && sed -i "s/#AuthorizedKeysFile/AuthorizedKeysFile/g" /etc/ssh/sshd_config 76 | RUN /usr/bin/ssh-keygen -A 77 | 78 | ENV SSHDIR $HOME/.ssh 79 | RUN mkdir -p $SSHDIR \ 80 | && chmod go-w $HOME/ \ 81 | && chmod 700 $SSHDIR \ 82 | && touch $SSHDIR/authorized_keys \ 83 | && chmod 600 $SSHDIR/authorized_keys \ 84 | && chown -R ${NB_USER}:${NB_GROUP} ${SSHDIR} \ 85 | && chown -R ${NB_USER}:${NB_GROUP} /etc/ssh/* 86 | 87 | ###### switch to user and compile test example. 88 | USER ${NB_USER} 89 | RUN ssh-keygen -b 2048 -t rsa -f $SSHDIR/id_rsa -q -N "" 90 | RUN cat ${SSHDIR}/*.pub >> ${SSHDIR}/authorized_keys 91 | RUN echo "StrictHostKeyChecking no" > ${SSHDIR}/config 92 | 93 | # configure screen and tmux 94 | ADD base/.tmux.conf $HOME/ 95 | ADD base/.screenrc $HOME/ 96 | 97 | # expose port for ssh and start ssh service. 98 | EXPOSE 22 99 | # expose port for notebook. 100 | EXPOSE 8888 101 | # expose port for tensorboard. 
#!/bin/bash
# set permissions on a directory
# after any installation, if a directory needs to be (human) user-writable,
# run this script on it.
# It will make everything in the directory owned by the group $NB_GID
# and writable by that group.
# Deployments that want to set a specific user id can preserve permissions
# by adding the `--group-add users` line to `docker run`.

# uses find to avoid touching files that already have the right permissions,
# which would cause massive image explosion

# right permissions are:
# group=$NB_GID
# AND permissions include group rwX (directory-execute)
# AND directories have setuid,setgid bits set

set -e

# "$@" (quoted) keeps each directory argument intact; the unquoted form
# word-splits and glob-expands paths containing spaces or wildcards.
for d in "$@"; do
    find "$d" \
        ! \( \
            -group "$NB_GID" \
            -a -perm -g+rwX \
        \) \
        -exec chgrp "$NB_GID" {} \; \
        -exec chmod g+rwX {} \;
    # setuid,setgid *on directories only*
    find "$d" \
        \( \
            -type d \
            -a ! -perm -6000 \
        \) \
        -exec chmod +6000 {} \;
done
11 | dockerfile: pytorch-mpi/Dockerfile 12 | image: user/pytorch-mpi 13 | depends_on: 14 | - base 15 | -------------------------------------------------------------------------------- /environments/pytorch-mpi/.condarc: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | show_channel_urls: true 4 | default_channels: 5 | - https://mirrors.bfsu.edu.cn/anaconda/pkgs/main 6 | - https://mirrors.bfsu.edu.cn/anaconda/pkgs/r 7 | - https://mirrors.bfsu.edu.cn/anaconda/pkgs/msys2 8 | custom_channels: 9 | conda-forge: https://mirrors.bfsu.edu.cn/anaconda/cloud 10 | msys2: https://mirrors.bfsu.edu.cn/anaconda/cloud 11 | bioconda: https://mirrors.bfsu.edu.cn/anaconda/cloud 12 | menpo: https://mirrors.bfsu.edu.cn/anaconda/cloud 13 | pytorch: https://mirrors.bfsu.edu.cn/anaconda/cloud 14 | pytorch-lts: https://mirrors.bfsu.edu.cn/anaconda/cloud 15 | simpleitk: https://mirrors.bfsu.edu.cn/anaconda/cloud -------------------------------------------------------------------------------- /environments/pytorch-mpi/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM user/base 2 | 3 | USER $NB_USER 4 | WORKDIR $HOME 5 | 6 | # install openMPI 7 | RUN mkdir $HOME/.openmpi/ 8 | RUN wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.4.tar.gz 9 | RUN gunzip -c openmpi-4.1.4.tar.gz | tar xf - \ 10 | && cd openmpi-4.1.4 \ 11 | && ./configure --prefix=$HOME/.openmpi/ --with-cuda \ 12 | && make all install 13 | 14 | ENV PATH $HOME/.openmpi/bin:$PATH 15 | ENV LD_LIBRARY_PATH $HOME/.openmpi/lib:$LD_LIBRARY_PATH 16 | 17 | # install conda 18 | ENV PYTHON_VERSION=3.8.12 19 | ENV PYTHONUNBUFFERED=1 PYTHONFAULTHANDLER=1 PYTHONHASHSEED=0 20 | RUN curl -fsSL -v -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 21 | bash ~/miniconda.sh -b -p $HOME/conda && \ 22 | rm ~/miniconda.sh 23 | COPY pytorch-mpi/.condarc $HOME/conda/.condarc 24 
| RUN $HOME/conda/bin/conda update -n base conda 25 | RUN $HOME/conda/bin/conda create -y --name pytorch-py$PYTHON_VERSION python=$PYTHON_VERSION 26 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip config set global.index-url https://mirrors.bfsu.edu.cn/pypi/web/simple 27 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -y conda-build cython typing typing_extensions numpy pyyaml scipy ipython mkl mkl-include 28 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -y astunparse ninja setuptools cmake cffi future six requests dataclasses 29 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -c pytorch magma-cuda112 30 | ENV PATH $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin:$PATH 31 | RUN $HOME/conda/bin/conda clean -ya 32 | 33 | # install pytorch, torchvision, torchtext. 34 | RUN git clone --recursive https://github.com/pytorch/pytorch 35 | RUN cd pytorch && \ 36 | git checkout tags/v1.11.0 && \ 37 | git submodule sync && \ 38 | git submodule update --init --recursive --jobs 0 && \ 39 | TORCH_CUDA_ARCH_LIST="3.7+PTX;5.0;6.0;6.1;7.0;7.5;8.0;8.6" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \ 40 | CMAKE_PREFIX_PATH="$(dirname $(which $HOME/conda/bin/conda))/../" \ 41 | $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/python setup.py install 42 | RUN git clone https://github.com/pytorch/vision.git && cd vision && git checkout tags/v0.12.0 && $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/python setup.py install 43 | RUN git clone https://github.com/pytorch/text torchtext \ 44 | && cd torchtext \ 45 | && git checkout tags/v0.12.0 \ 46 | && git submodule update --init --recursive \ 47 | && $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/python setup.py clean install 48 | 49 | # install other packages. 
50 | RUN $HOME/conda/bin/conda install -y --name pytorch-py$PYTHON_VERSION -c conda-forge av ffmpeg tabulate python-blosc 51 | RUN $HOME/conda/bin/conda install -y --name pytorch-py$PYTHON_VERSION scikit-learn protobuf networkx 52 | RUN $HOME/conda/bin/conda install -y --name pytorch-py$PYTHON_VERSION -c anaconda pandas 53 | 54 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install spacy 55 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/python -m spacy download en_core_web_sm 56 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/python -m spacy download de_core_news_sm 57 | 58 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install nltk==3.4.5 transformers==2.2.2 59 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install pytelegraf pymongo influxdb kubernetes jinja2 60 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install lmdb tensorboard_logger pyarrow msgpack msgpack_numpy mpi4py 61 | 62 | RUN sudo apt-get update && sudo apt install -y libgl1-mesa-glx && sudo rm -rf /var/lib/apt/lists/* 63 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install POT opencv-python 64 | 65 | # clean. 66 | RUN $HOME/conda/bin/conda clean -ya 67 | --------------------------------------------------------------------------------