├── .gitignore ├── README.md ├── assets ├── init_models │ ├── README.md │ ├── resnet20_cifar10_random_1.h5 │ └── yamnet_spcm_random_1.h5 ├── results │ └── README.md └── saved_models │ └── README.md ├── configs.py ├── requirements.txt ├── run.py └── src ├── architectures ├── Resnets.py └── YAMNet.py ├── client.py ├── compress.py ├── data.py ├── data_utils.py ├── distiller.py ├── main.py ├── network.py ├── server.py ├── strategy.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Federated Weight Clustering with Adaptive clustering 2 | 3 | Federated Learning (FL) is a promising technique for the collaborative training of deep neural networks across multiple devices while preserving data privacy. Despite its potential benefits, FL is hindered by excessive communication costs due to repeated server-client communication during training. 
To address this challenge, model compression techniques such as sparsification and weight clustering are applied; however, these often require modifying the underlying model aggregation scheme or involve cumbersome hyperparameter tuning, where the latter not only fixes the model's compression rate but also limits the model's potential for continuous improvement over growing data. In this paper, we propose FedCompress, a novel approach that combines dynamic weight clustering and server-side knowledge distillation to reduce communication costs while learning highly generalizable models. Through a comprehensive evaluation on diverse public datasets, we demonstrate the efficacy of our approach compared to baselines in terms of communication costs and inference speed. 4 | 5 | A complete description of our work can be found in our [paper](https://ieeexplore.ieee.org/document/10447174) (and in our [arxiv](https://arxiv.org/pdf/2401.14211.pdf) version). 6 | 7 | ## Dependencies 8 | 9 | Create a new Python environment (virtualenv, Anaconda, etc.) and install all required packages via: 10 | 11 | ```console 12 | foo@bar:~$ pip install -r requirements.txt 13 | ``` 14 | 15 | ## Executing experiments 16 | 17 | From the `root` directory of this repo, run: 18 | 19 | ```console 20 | # Standard FedAvg 21 | foo@bar:~$ python3 run.py --datasets cifar10 --methods fedavg 22 | # FedAvg + Client-side compression via weight clustering 23 | foo@bar:~$ python3 run.py --datasets cifar10 --methods client 24 | # FedCompress (Ours) 25 | foo@bar:~$ python3 run.py --datasets cifar10 --methods fedcompress 26 | ``` 27 | 28 | > **_NOTE:_** You can configure all federated parameters (e.g. number of federated rounds) by adjusting them in the `configs.py` file. 29 | 30 | ## Reference 31 | 32 | If you use this repository, please consider citing: 33 | 34 |
@inproceedings{tsouvalas2024communicationefficient,
35 |   author={Tsouvalas, Vasileios and Saeed, Aaqib and Ozcelebi, Tanir and Meratnia, Nirvana},
36 |   booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 
37 |   title={Communication-Efficient Federated Learning Through Adaptive Weight Clustering And Server-Side Distillation}, 
38 |   year={2024},
39 |   pages={5805-5809},
40 |   doi={10.1109/ICASSP48485.2024.10447174}
41 | }
42 | 
43 | -------------------------------------------------------------------------------- /assets/init_models/README.md: -------------------------------------------------------------------------------- 1 | # Store Ramdon Initialization Weights Dir -------------------------------------------------------------------------------- /assets/init_models/resnet20_cifar10_random_1.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FederatedML/FedCompress/86102531097155cdc0bd94f2eac186bc4cdde197/assets/init_models/resnet20_cifar10_random_1.h5 -------------------------------------------------------------------------------- /assets/init_models/yamnet_spcm_random_1.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FederatedML/FedCompress/86102531097155cdc0bd94f2eac186bc4cdde197/assets/init_models/yamnet_spcm_random_1.h5 -------------------------------------------------------------------------------- /assets/results/README.md: -------------------------------------------------------------------------------- 1 | # Store Results Dir 2 | -------------------------------------------------------------------------------- /assets/saved_models/README.md: -------------------------------------------------------------------------------- 1 | # Store Models Dir 2 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from dataclasses import dataclass, field 4 | from typing import List 5 | 6 | @dataclass 7 | class Params: 8 | # Approach 9 | method: str = 'standard' 10 | # Datasets 11 | dataset: str = "cifar10" 12 | ood_dataset: str = "stylegan" 13 | # Model 14 | model: str = 'resnet20' 15 | init_model_fp: str = '' 16 | # Training params 17 | client_epochs: int = 5 18 | client_compress_epochs: int = 5 19 | client_batch_size: int = 128 20 | server_epochs: int = 10 21 | server_batch_size: int = 256 22 | client_normal_lr: float = 1e-3 23 | client_compress_lr: float = 5e-4 24 | server_compress_lr: float = 1e-4 25 | server_compress_temperature: float = 2.0 26 | # Cluster params 27 | cluster_space: List = field(default_factory=lambda: [8,30]) 28 | cluster_init_patience = 3 29 | cluster_patience = 4 30 | cluster_window = 3 31 | cluster_limit = 1 32 | cluster_step = 1 33 | # Federated params 34 | num_rounds: int = 30 35 | num_clients: int = 20 36 | participation: float = 1.0 37 | cpu_usage: float = 1.0 38 | gpu_usage: float = 0.08 39 | seed: int = 0 40 | # Logging params 41 | random_id: str = str(uuid.uuid4())[-10:] 42 | results_dir: str = os.path.abspath('./assets/results') 43 | model_dir: str = os.path.abspath('./assets/saved_models') 44 | 45 | def __post_init__(self): 46 | # Adjust batch size 47 | if self.dataset in ['spcm']: 48 | self.client_batch_size = 128 49 | self.server_batch_size = 256 50 | # Adjust ood data 51 | if self.dataset in ['spcm']: 52 | self.ood_dataset = "librispeech" 53 | self.model = 'yamnet' 54 | # Adjust local training epochs 55 | if self.method == 'standard': 56 | self.client_epochs += self.client_compress_epochs 57 | self.client_compress_epochs = 0 58 | if self.method in ['standard','client']: 59 | self.server_epochs = 0 60 | # Keep this last 61 | self.init_model_fp = os.path.join(os.path.abspath('./assets/init_models/'),"{}_{}_random_1.h5".format(self.model,self.dataset)) 62 | 63 | 
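# Illustrative sketch (not part of the repository): it shows how Params.__post_init__
# above derives dependent settings. The expected values follow directly from configs.py;
# saving this as a separate script (e.g. a hypothetical demo_configs.py at the repo root)
# is only an assumption.
from configs import Params

audio = Params(method='fedcompress', dataset='spcm')
print(audio.model, audio.ood_dataset)    # -> yamnet librispeech
print(audio.init_model_fp)               # -> .../assets/init_models/yamnet_spcm_random_1.h5

baseline = Params(method='standard', dataset='cifar10')
# The no-compression baseline folds the compression epochs into the plain local epochs
# and zeroes out the server-side distillation epochs.
print(baseline.client_epochs, baseline.client_compress_epochs, baseline.server_epochs)  # -> 10 0 0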
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | aiohttp==3.8.3 3 | aiohttp-cors==0.7.0 4 | aiosignal==1.3.1 5 | asttokens==2.2.1 6 | astunparse==1.6.3 7 | async-timeout==4.0.2 8 | attrs==22.1.0 9 | backcall==0.2.0 10 | blessed==1.19.1 11 | cachetools==5.2.0 12 | certifi==2022.9.24 13 | charset-normalizer==2.1.1 14 | click==8.0.4 15 | colorful==0.5.5 16 | comm==0.1.1 17 | contourpy==1.0.6 18 | cycler==0.11.0 19 | debugpy==1.6.4 20 | decorator==5.1.1 21 | dill==0.3.6 22 | distlib==0.3.6 23 | dm-tree==0.1.7 24 | entrypoints==0.4 25 | etils==0.9.0 26 | executing==1.2.0 27 | filelock==3.8.2 28 | flatbuffers==22.12.6 29 | flwr==1.1.0 30 | fonttools==4.38.0 31 | frozenlist==1.3.3 32 | future==0.18.3 33 | gast==0.4.0 34 | google-api-core==2.11.0 35 | google-auth==2.15.0 36 | google-auth-oauthlib==0.4.6 37 | google-pasta==0.2.0 38 | googleapis-common-protos==1.57.0 39 | gpustat==1.0.0 40 | grpcio==1.43.0 41 | h5py==3.7.0 42 | idna==3.4 43 | importlib-resources==5.10.1 44 | ipykernel==6.19.0 45 | ipython==8.7.0 46 | iterators==0.0.2 47 | jedi==0.18.2 48 | joblib==1.2.0 49 | jsonschema==4.17.3 50 | jupyter_client==7.4.8 51 | jupyter_core==5.1.0 52 | keras==2.10.0 53 | Keras-Preprocessing==1.1.2 54 | kiwisolver==1.4.4 55 | libclang==14.0.6 56 | Markdown==3.4.1 57 | MarkupSafe==2.1.1 58 | matplotlib==3.6.2 59 | matplotlib-inline==0.1.6 60 | msgpack==1.0.4 61 | multidict==6.0.3 62 | nest-asyncio==1.5.6 63 | numpy==1.23.5 64 | nvidia-ml-py==11.495.46 65 | oauthlib==3.2.2 66 | opencensus==0.11.0 67 | opencensus-context==0.1.3 68 | opt-einsum==3.3.0 69 | packaging==21.3 70 | pandas==1.5.2 71 | parso==0.8.3 72 | pexpect==4.8.0 73 | pickleshare==0.7.5 74 | Pillow==9.3.0 75 | platformdirs==2.6.0 76 | prometheus-client==0.13.1 77 | promise==2.3 78 | prompt-toolkit==3.0.36 79 | protobuf==3.19.6 80 | psutil==5.9.4 81 | ptyprocess==0.7.0 82 | pure-eval==0.2.2 83 | py-spy==0.3.14 84 | pyasn1==0.4.8 85 | pyasn1-modules==0.2.8 86 | pydantic==1.10.2 87 | Pygments==2.13.0 88 | pyparsing==3.0.9 89 | pyrsistent==0.19.2 90 | python-dateutil==2.8.2 91 | pytz==2022.6 92 | PyYAML==6.0 93 | pyzmq==24.0.1 94 | ray==2.1.0 95 | requests==2.28.1 96 | requests-oauthlib==1.3.1 97 | rsa==4.9 98 | scikit-learn==1.2.0 99 | scipy==1.9.3 100 | seaborn==0.12.1 101 | six==1.16.0 102 | smart-open==6.2.0 103 | stack-data==0.6.2 104 | tensorboard==2.10.1 105 | tensorboard-data-server==0.6.1 106 | tensorboard-plugin-wit==1.8.1 107 | tensorflow-addons==0.19.0 108 | tensorflow-datasets==4.8.1 109 | tensorflow-estimator==2.10.0 110 | tensorflow-gpu==2.10.1 111 | tensorflow-io==0.29.0 112 | tensorflow-io-gcs-filesystem==0.29.0 113 | tensorflow-metadata==1.12.0 114 | tensorflow-model-optimization==0.7.3 115 | termcolor==2.1.1 116 | threadpoolctl==3.1.0 117 | toml==0.10.2 118 | tornado==6.2 119 | tqdm==4.64.1 120 | traitlets==5.6.0 121 | typeguard==2.13.3 122 | typing_extensions==4.4.0 123 | urllib3==1.26.13 124 | virtualenv==20.17.1 125 | wcwidth==0.2.5 126 | Werkzeug==2.2.2 127 | wrapt==1.14.1 128 | yarl==1.8.2 129 | zipp==3.11.0 130 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | from configs import Params 5 | os.environ['TF_CPP_MIN_LOG_LEVEL']="3" 6 | os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' 7 | 
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID" 8 | 9 | def create_parser(): 10 | # Create the parser 11 | parser = argparse.ArgumentParser(description='FedCompress arguments.') 12 | parser.add_argument('--datasets', nargs='+', default=['cifar10'], choices=['cifar10'], help="List of datasets to use. Default is ['cifar10'].") 13 | parser.add_argument('--methods', nargs=1, default=['fedavg'], choices=['fedcompress', 'client', 'fedavg'], help="Method to use. One of 'fedcompress', 'client', 'fedavg'. Default is 'fedavg'.") 14 | return parser 15 | 16 | def main(args): 17 | datasets = args.datasets 18 | methods = args.methods 19 | rnds = {'cifar10':100,'cifar100':500,'pathmnist':100,'spcm':50,'voxforge':40} 20 | 21 | for dataset in datasets: 22 | for method in methods: 23 | assert method in ['fedcompress','client','fedavg'], 'Parameter `method` must be one of [`fedcompress`, `client`, `fedavg`]. Provided {}.'.format(method) 24 | params = Params(method=method, dataset=dataset, num_rounds=rnds[dataset]) 25 | 26 | call_cmd = ["python3", "./src/main.py", 27 | "--random_id", params.random_id, 28 | "--rounds", str(params.num_rounds), 29 | "--clients",str(params.num_clients), 30 | "--dataset", dataset, 31 | "--model_name", params.model, 32 | "--participation", str(params.participation), 33 | "--cpu_percentage", str(params.cpu_usage), 34 | "--gpu_percentage", str(params.gpu_usage), 35 | "--seed", str(params.seed), 36 | "--server_compression" if method in ['fedcompress'] else None, 37 | "--client_compression" if method in ['fedcompress', 'client'] else None, 38 | # Training params 39 | "--epochs", str(params.client_epochs), 40 | "--learning_rate", str(params.client_normal_lr), 41 | "--client_compression_epochs", str(params.client_compress_epochs), 42 | "--client_compression_lr", str(params.client_compress_lr), 43 | "--batch_size", str(params.client_batch_size), 44 | "--server_compression_epochs", str(params.server_epochs), 45 | "--server_compression_lr", str(params.server_compress_lr), 46 | "--server_compression_temperature", str(params.server_compress_temperature), 47 | "--server_compression_batch", str(params.server_batch_size), 48 | # Logging 49 | "--results_dir", params.results_dir, 50 | "--model_dir", params.model_dir, 51 | "--init_model_fp", params.init_model_fp, 52 | # Compression parameters 53 | "--init_num_clusters", str(params.cluster_space[0]), 54 | "--max_num_clusters", str(params.cluster_space[1]), 55 | "--cluster_update_step", str(params.cluster_step), 56 | "--cluster_search_init_rounds", str(params.cluster_init_patience), 57 | "--cluster_search_window", str(params.cluster_window), 58 | "--cluster_search_patience", str(params.cluster_patience), 59 | "--cluster_search_metric_limit", str(params.cluster_limit), 60 | ] 61 | call_cmd = [c for c in call_cmd if c is not None] 62 | print('\nCalling command:') 63 | print(' '.join(call_cmd),'\n') 64 | subprocess.call(call_cmd) 65 | 66 | if __name__ == "__main__": 67 | args = create_parser().parse_args() 68 | main(args) --------------------------------------------------------------------------------
/src/architectures/Resnets.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def regularized_padded_conv(*args, **kwargs): 4 | return tf.keras.layers.Conv2D(*args, **kwargs, padding='same', kernel_regularizer=_regularizer, 5 | kernel_initializer='he_normal', use_bias=False) 6 | 7 | def bn_relu(x): 8 | x = tf.keras.layers.BatchNormalization()(x) 9 | return tf.keras.layers.ReLU()(x) 10 |
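# Note: regularized_padded_conv() above and the block builders below rely on module-level
# globals (_regularizer, _shortcut_type, _dropout, _bootleneck_width, _preact_shortcuts)
# that are only assigned inside Resnet() further down; they are not meant to be called
# standalone before Resnet() has configured those globals.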
def shortcut(x, filters, stride, mode): 12 | if x.shape[-1] == filters: 13 | return x 14 | elif mode == 'B': 15 | return regularized_padded_conv(filters, 1, strides=stride)(x) 16 | elif mode == 'B_original': 17 | x = regularized_padded_conv(filters, 1, strides=stride)(x) 18 | return tf.keras.layers.BatchNormalization()(x) 19 | elif mode == 'A': 20 | return tf.pad(tf.keras.layers.MaxPool2D(1, stride)(x) if stride>1 else x, 21 | paddings=[(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])]) 22 | else: 23 | raise KeyError("Parameter shortcut_type not recognized!") 24 | 25 | def original_block(x, filters, stride=1, **kwargs): 26 | c1 = regularized_padded_conv(filters, 3, strides=stride)(x) 27 | c2 = regularized_padded_conv(filters, 3)(bn_relu(c1)) 28 | c2 = tf.keras.layers.BatchNormalization()(c2) 29 | mode = 'B_original' if _shortcut_type == 'B' else _shortcut_type 30 | x = shortcut(x, filters, stride, mode=mode) 31 | x = tf.keras.layers.Add()([x, c2]) 32 | return tf.keras.layers.ReLU()(x) 33 | 34 | def preactivation_block(x, filters, stride=1, preact_block=False): 35 | flow = bn_relu(x) 36 | if preact_block: 37 | x = flow 38 | c1 = regularized_padded_conv(filters, 3, strides=stride)(flow) 39 | if _dropout: 40 | c1 = tf.keras.layers.Dropout(_dropout)(c1) 41 | c2 = regularized_padded_conv(filters, 3)(bn_relu(c1)) 42 | x = shortcut(x, filters, stride, mode=_shortcut_type) 43 | return x + c2 44 | 45 | def bootleneck_block(x, filters, stride=1, preact_block=False): 46 | flow = bn_relu(x) 47 | if preact_block: 48 | x = flow 49 | c1 = regularized_padded_conv(filters//_bootleneck_width, 1)(flow) 50 | c2 = regularized_padded_conv(filters//_bootleneck_width, 3, strides=stride)(bn_relu(c1)) 51 | c3 = regularized_padded_conv(filters, 1)(bn_relu(c2)) 52 | x = shortcut(x, filters, stride, mode=_shortcut_type) 53 | return x + c3 54 | 55 | def group_of_blocks(x, block_type, num_blocks, filters, stride, block_idx=0): 56 | global _preact_shortcuts 57 | preact_block = True if _preact_shortcuts or block_idx == 0 else False 58 | x = block_type(x, filters, stride, preact_block=preact_block) 59 | for i in range(num_blocks-1): 60 | x = block_type(x, filters) 61 | return x 62 | 63 | def Resnet(input_shape, n_classes, l2_reg=1e-4, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2), 64 | shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 65 | dropout=0, cardinality=1, bootleneck_width=4, preact_shortcuts=True, name='resnet'): 66 | 67 | global _regularizer, _shortcut_type, _preact_projection, _dropout, _cardinality, _bootleneck_width, _preact_shortcuts 68 | _bootleneck_width = bootleneck_width 69 | _regularizer = tf.keras.regularizers.l2(l2_reg) 70 | _shortcut_type = shortcut_type 71 | _cardinality = cardinality 72 | _dropout = dropout 73 | _preact_shortcuts = preact_shortcuts 74 | block_types = {'preactivated': preactivation_block, 'bootleneck': bootleneck_block, 'original': original_block} 75 | 76 | selected_block = block_types[block_type] 77 | inputs = tf.keras.layers.Input(shape=input_shape) 78 | flow = regularized_padded_conv(**first_conv)(inputs) 79 | 80 | if block_type == 'original': 81 | flow = bn_relu(flow) 82 | 83 | for block_idx, (group_size, feature, stride) in enumerate(zip(group_sizes, features, strides)): 84 | flow = group_of_blocks(flow, block_type=selected_block, 85 | num_blocks=group_size, block_idx=block_idx, 86 | filters=feature, stride=stride) 87 | if block_type != 'original': 88 | flow = bn_relu(flow) 89 | flow = 
tf.keras.layers.GlobalAveragePooling2D(name='GMP_layer')(flow) 90 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer)(flow) 91 | model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name) 92 | return model 93 | 94 | def resnet20(input_shape=(32,32,3), num_classes=10, weights_dir=None, 95 | block_type='original', shortcut_type='A', l2_reg=1e-4, name="resnet20"): 96 | model = Resnet(input_shape=input_shape, n_classes=num_classes, l2_reg=l2_reg, group_sizes=(3, 3, 3), features=(16, 32, 64), 97 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, shortcut_type=shortcut_type, 98 | block_type=block_type, preact_shortcuts=False, name=name) 99 | if weights_dir is not None: model.load_weights(weights_dir).expect_partial() 100 | return model -------------------------------------------------------------------------------- /src/architectures/YAMNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from dataclasses import dataclass 3 | 4 | @dataclass(frozen=True) 5 | class Params: 6 | conv_padding: str = "same" 7 | batchnorm_center: bool = True 8 | batchnorm_scale: bool = False 9 | batchnorm_epsilon: float = 1e-4 10 | classifier_activation: str = "linear" 11 | 12 | def _batch_norm(name, params): 13 | def _bn_layer(layer_input): 14 | return tf.keras.layers.BatchNormalization(name=name, 15 | center=params.batchnorm_center, 16 | scale=params.batchnorm_scale, 17 | epsilon=params.batchnorm_epsilon)(layer_input) 18 | return _bn_layer 19 | 20 | def _conv(name, kernel, stride, filters, params): 21 | def _conv_layer(layer_input): 22 | output = tf.keras.layers.Conv2D( 23 | name="{}/conv".format(name), 24 | filters=filters, kernel_size=kernel, strides=stride, 25 | padding=params.conv_padding, use_bias=False, activation=None)(layer_input) 26 | output = _batch_norm("{}/conv/bn".format(name), params)(output) 27 | output = tf.keras.layers.ReLU(name="{}/relu".format(name))(output) 28 | return output 29 | return _conv_layer 30 | 31 | def _separable_conv(name, kernel, stride, filters, params): 32 | def _separable_conv_layer(layer_input): 33 | output = tf.keras.layers.DepthwiseConv2D( 34 | name="{}/depthwise_conv".format(name), 35 | kernel_size=kernel, strides=stride, depth_multiplier=1, 36 | padding=params.conv_padding, use_bias=False, activation=None)(layer_input) 37 | output = _batch_norm("{}/depthwise_conv/bn".format(name), params)(output) 38 | output = tf.keras.layers.ReLU(name="{}/depthwise_conv/relu".format(name))(output) 39 | output = tf.keras.layers.Conv2D( 40 | name="{}/pointwise_conv".format(name), 41 | filters=filters, kernel_size=(1, 1), strides=1, 42 | padding=params.conv_padding, use_bias=False, activation=None)(output) 43 | output = _batch_norm("{}/pointwise_conv/bn".format(name), params)(output) 44 | output = tf.keras.layers.ReLU(name="{}/pointwise_conv/relu".format(name))(output) 45 | return output 46 | return _separable_conv_layer 47 | 48 | _YAMNET_LAYER_DEFS = [ 49 | # (layer_function, kernel, stride, num_filters) 50 | (_conv, [3, 3], 2, 32), 51 | (_separable_conv, [3, 3], 1, 64), 52 | (_separable_conv, [3, 3], 2, 128), 53 | (_separable_conv, [3, 3], 1, 128), 54 | (_separable_conv, [3, 3], 2, 256), 55 | (_separable_conv, [3, 3], 1, 256), 56 | (_separable_conv, [3, 3], 2, 512), 57 | (_separable_conv, [3, 3], 1, 512), 58 | (_separable_conv, [3, 3], 1, 512), 59 | (_separable_conv, [3, 3], 1, 512), 60 | (_separable_conv, [3, 3], 1, 512), 61 | (_separable_conv, [3, 3], 1, 512), 62 
| (_separable_conv, [3, 3], 2, 1024), 63 | (_separable_conv, [3, 3], 1, 1024) 64 | ] 65 | 66 | """Define the core YAMNet model in Keras.""" 67 | def YAMNet(inputs, num_classes, params): 68 | net = inputs 69 | for (i, (layer_fun, kernel, stride, filters)) in enumerate(_YAMNET_LAYER_DEFS): 70 | net = layer_fun("layer{}".format(i + 1), kernel, stride, filters, params)(net) 71 | x = tf.keras.layers.GlobalAveragePooling2D(name='GMP_layer')(net) 72 | logits = tf.keras.layers.Dense(units=num_classes, use_bias=True, activation=params.classifier_activation)(x) 73 | return logits 74 | 75 | """Defines the YAMNet.""" 76 | def create_yamnet_model(input_shape=(None,0,1), num_classes=12, weights_dir=None): 77 | params = Params() 78 | inputs = tf.keras.layers.Input(input_shape, dtype=tf.float32) 79 | predictions = YAMNet(inputs=inputs, num_classes=num_classes, params=params) 80 | model = tf.keras.Model(inputs=inputs, outputs=predictions, name="YAMNet") 81 | if weights_dir is not None: model.load_weights(weights_dir).expect_partial() 82 | return model -------------------------------------------------------------------------------- /src/client.py: -------------------------------------------------------------------------------- 1 | import flwr as fl 2 | import tensorflow as tf 3 | import compress 4 | import utils 5 | 6 | class _Client(fl.client.NumPyClient): 7 | 8 | def __init__(self, cid, num_clients, model_loader, data_loader, batch_size, client_compression, seed): 9 | self.cid = cid 10 | self.data_loader = data_loader 11 | self.num_clients = num_clients 12 | self.batch_size = batch_size 13 | self.seed = seed 14 | (self.data,self.val_data), self.num_classes, (self.num_samples,self.num_val_samples) = data_loader(cid, num_clients, batch_size=batch_size, seed=seed) 15 | self.model_loader = model_loader 16 | self.input_shape = self.data.element_spec[0].shape 17 | self.client_compression = client_compression 18 | self.combined_data = False 19 | 20 | def set_parameters(self, parameters, config): 21 | """ Set model weights """ 22 | if not hasattr(self, 'model'): 23 | self.model = self.model_loader(input_shape=self.input_shape[1:], num_classes=self.num_classes) 24 | 25 | self.model.compile( 26 | optimizer=tf.keras.optimizers.Adam(learning_rate=config['lr']), 27 | loss=tf.keras.losses.SparseCategoricalCrossentropy(name='loss', from_logits=True), 28 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')] 29 | ) 30 | 31 | if parameters is not None: 32 | self.model.set_weights(parameters) 33 | 34 | def get_parameters(self, config={}): 35 | """ Get model weights """ 36 | return self.model.get_weights() 37 | 38 | def fit(self, parameters, config): 39 | results = {} 40 | 41 | # Set parameters 42 | self.set_parameters(parameters, config) 43 | 44 | # Perform standard local update step 45 | if config['epochs'] != 0: 46 | h = self.model.fit(self.data, epochs=config['epochs'], verbose=0) 47 | print(f"[Client {self.cid}] - Accurary: {float(h.history['accuracy'][-1]):0.4f}.") 48 | 49 | # Perform compressed local update step 50 | if self.client_compression: 51 | self.model, h = compress.self_compress_with_data(model=self.model, 52 | data=self.data, 53 | epochs=config['compression_epochs'], 54 | nb_clusters=config['num_clusters'], 55 | learning_rate=config['compression_lr'], 56 | ) 57 | print(f"[Client {self.cid}] - Compr. 
Accurary: {float(h.history['accuracy'][-1]):0.4f} (Clusters: {config['num_clusters']}).") 58 | # Store results 59 | results['model_size'] = utils.get_gzipped_model_size_from_model(self.model) 60 | results['train_loss'] = float(h.history['loss'][-1]) 61 | results['train_accuracy'] = float(h.history['accuracy'][-1]) 62 | 63 | # Measure validation accuracy for elbow method. 64 | metrics = self.model.evaluate(self.val_data,verbose=0) 65 | results['val_loss'] = metrics[0] 66 | results['val_accuracy'] = metrics[1] 67 | results['num_val_samples'] = self.num_val_samples 68 | 69 | # Measure model embeddings 70 | results['embeddings'] = utils.compute_embeddings(self.model_loader, self.val_data, self.num_classes, weights=self.get_parameters()) 71 | 72 | return self.get_parameters(), self.num_samples, results 73 | 74 | def evaluate(self, parameters, config): 75 | raise NotImplementedError('Client-side evaluation is not implemented!') -------------------------------------------------------------------------------- /src/compress.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import tensorflow as tf 3 | import tensorflow_model_optimization as tfmot 4 | import distiller 5 | 6 | def _apply_weight_clustering(layer, clustering_params): 7 | if isinstance(layer, tf.keras.layers.Conv2D): 8 | return tfmot.clustering.keras.cluster_weights(layer, **clustering_params) 9 | return layer 10 | 11 | def compress_image_model(model, nb_clusters=50, verbose=True): 12 | clustering_params = {"number_of_clusters": nb_clusters,"cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.KMEANS_PLUS_PLUS} 13 | if verbose: 14 | print(f"Weight clustering with {nb_clusters} clusters", flush=True) 15 | return tf.keras.models.clone_model(model, clone_function=functools.partial(_apply_weight_clustering, clustering_params=clustering_params)) 16 | 17 | def self_compress(model, num_classes, model_loader, data_loader, data_shape=(32,32,3), nb_clusters=30, epochs=5, batch_size=128, learning_rate=1e-4, temperature=2.0, seed=0, verbose=0): 18 | ood_data = data_loader(batch_size=batch_size, seed=seed, reshape_size=data_shape[1:-1]) 19 | model.trainable = False 20 | student = model_loader(input_shape=data_shape[1:], num_classes=num_classes) 21 | student.set_weights(model.get_weights()) 22 | student_model = compress_image_model(student, nb_clusters=nb_clusters, verbose=False) 23 | trainer = distiller.Distiller(student=student_model, teacher=model, has_labels=False) 24 | trainer.compile( 25 | optimizer=tf.keras.optimizers.Adam(learning_rate), distillation_loss_fn=tf.keras.losses.KLDivergence(), 26 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')], temperature=temperature) 27 | trainer.fit(ood_data, epochs=epochs, verbose=0) 28 | student_compressed = tfmot.clustering.keras.strip_clustering(trainer.student) 29 | student_compressed.compile(metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]) 30 | return student_compressed 31 | 32 | def self_compress_with_data(model, data, nb_clusters=50, epochs=5, learning_rate=1e-4): 33 | student_model = compress_image_model(model, nb_clusters=nb_clusters, verbose=False) 34 | student_model.compile( 35 | optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), 36 | loss=tf.keras.losses.SparseCategoricalCrossentropy(name='loss',from_logits=True), 37 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')] 38 | ) 39 | h = student_model.fit(data, epochs=epochs, verbose=0) 40 | 
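    # Note: strip_clustering below removes the tfmot clustering wrappers added by
    # compress_image_model(), leaving an ordinary Keras model whose clustered Conv2D
    # kernels hold at most nb_clusters distinct values, which is what makes the gzipped
    # weights reported back by the client substantially smaller.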
student_compressed = tfmot.clustering.keras.strip_clustering(student_model) 41 | student_compressed.compile(metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]) 42 | return student_compressed, h 43 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import functools 3 | import numpy as np 4 | import tensorflow as tf 5 | from data_utils import LoadDataset, SplitDataset, MaskDataset, AugmentDataset 6 | from sklearn.model_selection import train_test_split 7 | 8 | DATASETS = { 9 | 'cifar10': { 10 | 'get_data_fn': LoadDataset.get_cifar10, 11 | 'train_fn':AugmentDataset.image_train_prep, 12 | 'test_fn': AugmentDataset.image_valid_prep, 13 | 'augment_fn': AugmentDataset.cutout, 14 | }, 15 | 'stylegan': { 16 | 'get_data_fn': LoadDataset.get_stylegan, 17 | 'train_fn':AugmentDataset.image_ood_prep, 18 | 'test_fn': None, 19 | 'augment_fn': None, 20 | }, 21 | 'spcm': { 22 | 'get_data_fn': LoadDataset.get_speech_commands, 23 | 'train_fn':AugmentDataset.audio_train_prep, 24 | 'test_fn': AugmentDataset.audio_valid_prep, 25 | 'augment_fn': None, 26 | }, 27 | 'librispeech': { 28 | 'get_data_fn': LoadDataset.get_librispeech, 29 | 'train_fn':AugmentDataset.audio_ood_prep, 30 | 'test_fn': None, 31 | 'augment_fn': None, 32 | }, 33 | } 34 | 35 | def get_cifar10(id=0, num_clients=10, return_eval_ds=False, batch_size=128, valid_split=0.1, seed=0, name='cifar10'): 36 | 37 | # Fix seed for reproducability. 38 | np.random.seed(seed) 39 | 40 | # Prepare evaluation set. 41 | if return_eval_ds: 42 | ds, info = DATASETS[name]['get_data_fn'](split='test', with_info=True) 43 | ds = tf.data.Dataset.from_tensor_slices(ds)\ 44 | .map(DATASETS[name]['test_fn']).batch(batch_size*4).prefetch(-1) 45 | ds = (ds,None) 46 | num_samples, num_classes = info['num_examples'], info['num_classes'] 47 | 48 | else: 49 | # Load data 50 | ds, info = DATASETS[name]['get_data_fn'](split='train', with_info=True) 51 | num_samples, num_classes = info['num_examples'], info['num_classes'] 52 | # Get client data 53 | samples_to_ids = np.random.choice(a=np.arange(0,num_clients),size=num_samples).astype(int) 54 | mask = np.in1d(samples_to_ids, [int(id)]) 55 | train_images,train_labels = (ds[0][mask,::], ds[1][mask,::]) 56 | 57 | val_ds = None 58 | num_samples = (train_images.shape[0], None) 59 | 60 | # Validation set. 61 | if valid_split>0.0: 62 | train_images, valid_images, train_labels, valid_labels = train_test_split(train_images, train_labels, test_size=valid_split, random_state=seed) 63 | val_ds = tf.data.Dataset.from_tensor_slices((valid_images,valid_labels))\ 64 | .map(DATASETS[name]['test_fn']).batch(batch_size).prefetch(-1) 65 | num_samples = (train_images.shape[0], valid_images.shape[0]) 66 | 67 | # Train set. 68 | ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))\ 69 | .map(DATASETS[name]['train_fn'], num_parallel_calls=tf.data.AUTOTUNE).shuffle(10000,seed,True).batch(batch_size) 70 | if DATASETS[name]['augment_fn'] is not None: 71 | ds = ds.map(DATASETS[name]['augment_fn'], num_parallel_calls=tf.data.AUTOTUNE) 72 | ds = ds.prefetch(-1) 73 | ds = (ds,val_ds) 74 | 75 | return ds, num_classes, num_samples 76 | 77 | def get_stylegan(batch_size=128, size=50000, reshape_size=(32,32), seed=0, name='stylegan'): 78 | # Fix seed for reproducability. 
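    # Note: this loader supplies the out-of-distribution (OOD) transfer set used for
    # server-side distillation on the image tasks (configs.py sets ood_dataset='stylegan'
    # for cifar10); images are read from $TFDS_DATA_DIR/raw/stylegan_oriented/ and the
    # default size=50000 keeps a 50k-image subset via ds.take(size).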
79 | np.random.seed(seed) 80 | # Load data 81 | ds, info = DATASETS[name]['get_data_fn'](split=None, with_info=True, reshape_size=reshape_size, color_mode='rgb') 82 | ds = ds.take(size).shuffle(10000,seed,True).batch(batch_size, drop_remainder=True) 83 | # Take subset of dataset 84 | ds = ds.map(functools.partial(DATASETS[name]['train_fn'], color_mode='rgb', size=reshape_size), num_parallel_calls=tf.data.AUTOTUNE) 85 | # Add augmmentations 86 | if DATASETS[name]['augment_fn'] is not None: # and dataset_name != 'emnist': 87 | ds = ds.map(functools.partial(DATASETS[name]['augment_fn'], mask_size=tuple([x//2 for x in reshape_size])), num_parallel_calls=tf.data.AUTOTUNE) 88 | ds = ds.prefetch(-1) 89 | return ds 90 | 91 | def get_spcm(id=0, num_clients=10, return_eval_ds=False, batch_size=128, valid_split=0.1, seed=0, num_mels=40, name='spcm',): 92 | 93 | # Fix seed for reproducability. 94 | np.random.seed(seed) 95 | 96 | # Prepare evaluation set. 97 | if return_eval_ds: 98 | ds, info = DATASETS[name]['get_data_fn'](split='test', with_info=True) 99 | ds = ds.map(functools.partial(DATASETS[name]['test_fn'], bins=num_mels), num_parallel_calls=tf.data.AUTOTUNE).\ 100 | batch(1).prefetch(-1) 101 | ds = (ds,None) 102 | num_classes, num_samples = info['num_classes'], info['num_examples'] 103 | 104 | else: 105 | # Load data 106 | ds, info = DATASETS[name]['get_data_fn'](split='train', with_info=True) 107 | num_classes = info['num_classes'] 108 | # Load labels 109 | labels = np.load(os.path.join(os.environ['TFDS_DATA_DIR'],f'speech_commands/train_labels.npy')) 110 | indexes = np.arange(labels.shape[0]) 111 | # Get client partition 112 | (indexes,labels) = SplitDataset.create_dataset_partition(data=(indexes,labels), cid=int(id), num_clients=num_clients, 113 | data_skew=0.0, class_skew=0.0, seed=seed) 114 | # Validation set 115 | val_ds = None 116 | num_samples = (indexes.shape[0], None) 117 | 118 | if valid_split>0.0: 119 | indexes, val_indexes, _, _ = train_test_split(indexes, labels, test_size=valid_split, random_state=seed) 120 | val_ds = MaskDataset.filter_dataset(ds, indexes=np.sort(val_indexes)) 121 | val_ds = val_ds.map(functools.partial(DATASETS[name]['test_fn'], bins=num_mels),num_parallel_calls=tf.data.AUTOTUNE).\ 122 | batch(1).prefetch(-1) 123 | num_samples = (indexes.shape[0], val_indexes.shape[0]) 124 | # Train set. 125 | ds = MaskDataset.filter_dataset(ds, indexes=np.sort(indexes)) 126 | ds = ds.map(functools.partial(DATASETS[name]['train_fn'], bins=num_mels), num_parallel_calls=tf.data.AUTOTUNE)\ 127 | .shuffle(10000,seed,True).batch(batch_size).prefetch(-1) 128 | ds = (ds,val_ds) 129 | 130 | return ds, num_classes, num_samples 131 | 132 | def get_librispeech(batch_size=128, size=-1, seed=0, num_mels=40, reshape_size=None, name='librispeech'): 133 | # Fix seed for reproducability. 
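    # Note: LibriSpeech plays the same OOD role for audio that the StyleGAN images play
    # above; configs.py switches ood_dataset to 'librispeech' for the 'spcm' task. The
    # default size=-1 makes ds.take(size) keep the entire dataset.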
134 | np.random.seed(seed) 135 | # Load data 136 | ds, info = DATASETS[name]['get_data_fn'](split=None, with_info=True) 137 | ds = ds.take(size).shuffle(10000,seed,True) 138 | # Take subset of dataset 139 | ds = ds.map(functools.partial(DATASETS[name]['train_fn'], bins=num_mels),num_parallel_calls=tf.data.AUTOTUNE)\ 140 | .batch(batch_size, drop_remainder=True).prefetch(-1) 141 | return ds 142 | -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import contextlib 3 | import numpy as np 4 | import random 5 | import math 6 | import scipy 7 | import functools 8 | import tensorflow as tf 9 | import tensorflow_addons as tfa 10 | import tensorflow_datasets as tfds 11 | from typing import List 12 | from sklearn.utils.class_weight import compute_class_weight 13 | 14 | class SplitDataset: 15 | 16 | @staticmethod 17 | def get_class_stats(labels): 18 | classes, samples_per_class = np.unique(labels, return_counts=True) 19 | indexes_per_class = [np.flatnonzero(labels==c) for c in classes] 20 | class_weights = 1/compute_class_weight(class_weight="balanced",classes=classes,y=labels) 21 | return classes, samples_per_class, indexes_per_class, class_weights 22 | 23 | @staticmethod 24 | def get_class_distribution(num_classes, num_partitions=10, concentration=0.0, seed=0): 25 | 26 | def _scale(x, low_range=(5e-4,1e-1), high_range=(50,1e+5)): 27 | x=1.-x 28 | assert 0.0<=x<=1.0, 'Concentration must be a float in [0.0,1.0] range.' 29 | if x <= 0.05: 30 | # Scale values in the bottom 5% of the range exponentially to low_range 31 | return math.pow(x / 0.05, 2.0) * 0.05 * (low_range[1] - low_range[0]) + low_range[0] 32 | elif x >= 0.95: 33 | # Scale values in the top 5% of the range exponentially to high_range 34 | return (1 - math.pow((1 - x) / 0.05, 2.0)) * 0.95 * (high_range[1] - high_range[0]) + high_range[0] 35 | else: 36 | # Scale values in the middle 90% of the range linearly to (1e-4, 500) 37 | return (x - 0.05) / 0.9 * (high_range[0] - low_range[1]) + low_range[1] 38 | 39 | return np.random.default_rng(seed).dirichlet(alpha=np.repeat(_scale(concentration), num_classes), size=num_partitions) 40 | 41 | @staticmethod 42 | def get_label_distribution(num_samples, num_partitions=10, concentration=0.0, min_length=1, seed=0): 43 | 44 | def _scale(x, min_value=0.01, max_value=1000): 45 | # Reverse the value of x so that it increases from 1.0 to 0.0 46 | x = 1.0 - x 47 | # Scale the value of x to the range [0, pi] 48 | x = x * math.pi 49 | # Calculate the scaled value using the reversed cosine curve 50 | y = (max_value - min_value) * (1.0 - math.cos(x)) / 2.0 + min_value 51 | return y 52 | 53 | # Fix seed 54 | random.seed(seed) 55 | np.random.seed(seed) 56 | # Ensure that min_chunk_length is valid 57 | assert min_length >= 0 and min_length * num_partitions <= num_samples, "Invalid value for min_chunk_length" 58 | assert concentration >= 0 and concentration <= 1.0, "Invalid value for concentration" 59 | 60 | if concentration == 0: 61 | # If concentration is zero, return equal sized chunks 62 | chunk_size = num_samples // num_partitions 63 | remainder = num_samples % num_partitions 64 | chunks = [chunk_size] * num_partitions 65 | for i in range(remainder): chunks[i] += 1 66 | return np.asarray(chunks) 67 | else: 68 | # Otherwise, sample chunk sizes from a dirichlet distribution 69 | chunk_sizes = scipy.stats.dirichlet.rvs([_scale(concentration)] * num_partitions, 
size=1)[0] 70 | # Round chunk sizes to the nearest integer and ensure that each chunk 71 | # has at least min_chunk_length samples 72 | chunks = [max(int(round(size * num_samples)), min_length) for size in chunk_sizes] 73 | # Ensure that the total number of samples is equal to num_samples 74 | while sum(chunks) != num_samples: 75 | # Calculate the difference between the total number of samples and num_samples 76 | diff = num_samples - sum(chunks) 77 | # Calculate the amount to increase or decrease each chunk by 78 | sign = lambda a: (a>0) - (a<0) 79 | chunk_size_diff = sign(diff)*(abs(diff) // num_partitions) 80 | # Increase or decrease the size of each chunk by chunk_size_diff 81 | chunks = [max(chunk + chunk_size_diff, min_length) for chunk in chunks] 82 | # If the total number of samples is still not equal to num_samples, 83 | # increase or decrease a random chunk by the remaining difference 84 | if sum(chunks) != num_samples: 85 | while sum(chunks) != num_samples: 86 | remainder = num_samples - sum(chunks) 87 | chunk_id = random.choice(range(num_partitions)) 88 | if chunks[chunk_id]+remainder >= min_length: chunks[chunk_id] += remainder 89 | assert sum(chunks) == num_samples, "Chunks do not sum up to the total number of samples." 90 | return np.asarray(chunks) 91 | 92 | @staticmethod 93 | def create_partitions(num_classes, num_partitions, class_distribution, samples_distribution, class_weights, class_samples=None, align=True, min_length=10, max_iter=10, seed=0): 94 | # Create empty partitions array 95 | parts = np.zeros((num_partitions,num_classes), dtype=np.int64) 96 | # Distribute class samples across partitions based on distributions. 97 | for i,part_class_distibution in enumerate(class_distribution): 98 | parts[i] = np.around(part_class_distibution * class_weights * samples_distribution[i]) 99 | # (Try to) Fix rounding errors. 100 | if align and (class_samples is not None): 101 | i=0 102 | while not __class__.check_alignment(parts, class_samples, samples_distribution, verbose=0): 103 | if i==max_iter: break; 104 | parts = __class__.align_samples_distribution(parts, class_distribution, class_weights, samples_distribution, min_length, seed+i) 105 | parts = __class__.align_class_distribution(parts, class_samples, samples_distribution, min_length, seed+i+1) 106 | i+=1 107 | 108 | assert __class__.valid_partitions(parts, class_samples), 'A valid partitioning could not be constructed.' 
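        # Note: at this point `parts` is a (num_partitions x num_classes) integer matrix
        # whose [i, c] entry is the number of class-c samples allotted to partition i;
        # create_mask() below turns it into concrete per-partition index arrays.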
109 | 110 | return parts 111 | 112 | @staticmethod 113 | def align_class_distribution(partitions, class_samples, samples_distribution, min_length=10, seed=0): 114 | # Fix random seed 115 | np.random.seed(seed) 116 | # Local functions 117 | compute_diff = lambda x: (class_samples - x.sum(axis=0)) 118 | compute_replace_classes = lambda x: (True if abs(x)>partitions.shape[0] else False) 119 | # Available swappable partitions 120 | swappable_partitions = np.arange(partitions.shape[0]) 121 | num_samples_to_replace = compute_diff(partitions) 122 | 123 | for i,n in enumerate(num_samples_to_replace): 124 | # Create random partitions to modify classes 125 | indexes = np.random.choice(swappable_partitions, 126 | size=abs(n), 127 | replace=compute_replace_classes(compute_diff(partitions)[i]), 128 | p=scipy.special.softmax((samples_distribution - partitions.sum(axis=1)))) 129 | # Remove/Add sample for current partition 130 | for idx in indexes: 131 | if partitions[idx][i] + np.sign(n)>min_length: 132 | partitions[idx][i] = partitions[idx][i] + np.sign(n) 133 | return partitions 134 | 135 | @staticmethod 136 | def normalize_probabilities(prob): 137 | from sklearn.preprocessing import normalize 138 | if sum(prob) != 1.0: 139 | prob = normalize(prob[:,np.newaxis], axis=0, norm='l1',).ravel() 140 | return prob 141 | 142 | @staticmethod 143 | def align_samples_distribution(partitions, class_distribution, class_weights, samples_distribution, min_length=10, seed=0): 144 | # Fix random seed 145 | np.random.seed(seed) 146 | # Local functions 147 | compute_diff = lambda x: (samples_distribution - x.sum(axis=1)) 148 | compute_replace_partitions = lambda x: (True if abs(x)>partitions.shape[0]-1 else False) 149 | compute_replace_classes = lambda x: (True if abs(x)>partitions.shape[1] else False) 150 | # Available swappable classes 151 | swappable_classes = np.arange(partitions.shape[1]) 152 | 153 | for i in range(partitions.shape[0]): 154 | # Available swappable partitions 155 | swappable_partitions = np.setdiff1d(np.arange(partitions.shape[0]),[i]) 156 | # Local variables 157 | num_samples_to_replace = compute_diff(partitions) 158 | 159 | partition_replace = compute_replace_partitions(num_samples_to_replace[i]) 160 | class_replace = compute_replace_classes(num_samples_to_replace[i]) 161 | # Create random partitions swaps 162 | part_idxs = np.random.choice( 163 | swappable_partitions, 164 | size=abs(num_samples_to_replace[i]), 165 | replace=partition_replace, 166 | p=__class__.normalize_probabilities(scipy.special.softmax(np.delete(num_samples_to_replace,[i], axis=0))+1e-9), 167 | ) 168 | # Create random class swaps 169 | class_idxs = np.random.choice(swappable_classes, 170 | size=abs(num_samples_to_replace[i]), 171 | replace=class_replace, 172 | p=scipy.special.softmax(class_distribution[i]*class_weights), 173 | ) 174 | for j,c in zip(part_idxs,class_idxs): 175 | if (partitions[i][c] + np.sign(num_samples_to_replace[i])>min_length) and (partitions[j][c] - np.sign(num_samples_to_replace[i])>min_length): 176 | # Remove sample for current partition 177 | partitions[i][c] = partitions[i][c] + np.sign(num_samples_to_replace[i]) 178 | # Move it to another (random) partition 179 | partitions[j][c] = partitions[j][c] - np.sign(num_samples_to_replace[i]) 180 | return partitions 181 | 182 | @staticmethod 183 | def check_alignment(partitions, class_samples, samples_distribution, verbose=0): 184 | 185 | if verbose: 186 | __class__.print_diff(partitions, class_samples, samples_distribution) 187 | 188 | if 
((samples_distribution - partitions.sum(1)).var()==.0) and ((class_samples - partitions.sum(0)).var()==.0): 189 | if ((samples_distribution - partitions.sum(1)).sum()==.0) and ((class_samples - partitions.sum(0)).sum()==.0): 190 | return True 191 | else: 192 | return False 193 | 194 | @staticmethod 195 | def valid_partitions(partitions, class_samples): 196 | return (class_samples - partitions.sum(0)).sum()==0.0 197 | 198 | @staticmethod 199 | def print_diff(partitions, class_samples, samples_distribution): 200 | x = (samples_distribution - partitions.sum(1)) 201 | y = (class_samples - partitions.sum(0)) 202 | 203 | print(f"Sum is [{x.sum()}, {y.sum()}]") 204 | print(f"Var is [{x.var():0.2f}, {y.var():0.2f}]") 205 | print(f"Samples error: {x}") 206 | print(f"Class error: {y}\n") 207 | 208 | @staticmethod 209 | def create_mask(partitions, class_indexes, seed=0): 210 | # Fix random seed 211 | np.random.seed(seed) 212 | # Number of classes and partitions 213 | num_classes = len(class_indexes) 214 | num_partitions = partitions.shape[0] 215 | # Create mask array 216 | partitions_class_indexes = np.empty_like(partitions, dtype=object) 217 | # Iterate over classes 218 | for c_idx in range(len(class_indexes)): 219 | # Create (un)availabe indexes set 220 | available_class_indexes = set(class_indexes[c_idx]) 221 | unavailable_class_indexes = set() 222 | # Iterate over partitions 223 | for i in range(num_partitions): 224 | # Create set of indexes 225 | available_class_indexes -= unavailable_class_indexes 226 | num_samples = partitions[i][c_idx] 227 | # Choose indexes for each partition 228 | indexes = np.random.choice(np.array(list(available_class_indexes)), size=num_samples, replace=False) 229 | partitions_class_indexes[i][c_idx] = indexes 230 | # Update unavailable indexes 231 | unavailable_class_indexes.update(set(indexes)) 232 | # Create one index array per partition. 
233 | indexes = np.empty(num_partitions, dtype=object) 234 | for i in range(num_partitions): 235 | indexes[i] = np.concatenate(partitions_class_indexes[i]) 236 | return indexes, partitions_class_indexes 237 | 238 | @staticmethod 239 | def create_dataset_partition(data, cid=0, num_clients=10, data_skew=0.0, class_skew=0.0, min_samples=1, max_iter=1000, seed=0): 240 | # Read dataset statistics 241 | classes, class_samples, class_indexes, class_weights = __class__.get_class_stats(labels=data[1]) 242 | # Create samples distribution (# samples per partition) 243 | samples_distribution = __class__.get_label_distribution(num_samples=data[1].shape[0], num_partitions=num_clients, concentration=data_skew, min_length=min_samples, seed=seed) 244 | # Create class distribution (% classes per partition) 245 | class_distribution = __class__.get_class_distribution(num_classes=classes.size, num_partitions=num_clients, concentration=class_skew, seed=seed) 246 | # Create partition arrays (# samples per class for each partition) 247 | partitions = __class__.create_partitions(num_classes=classes.size, num_partitions=num_clients, 248 | class_distribution=class_distribution, samples_distribution=samples_distribution, 249 | class_weights=class_weights, class_samples=class_samples, 250 | align=True, min_length=min_samples, max_iter=max_iter, seed=seed) 251 | # Create indexes masks based on partitions 252 | indexes,_ = __class__.create_mask(partitions, class_indexes=class_indexes, seed=seed) 253 | # Create mask 254 | mask = np.zeros(data[1].shape[0], dtype=bool) 255 | mask[indexes[int(cid)]] = True 256 | # Return tuple of (indexes,labels) 257 | return (data[0][mask], data[1][mask]) 258 | 259 | @staticmethod 260 | def concatenate_client_data(ds, clients_ids): 261 | client_ds = ds.create_tf_dataset_for_client(clients_ids[0]) 262 | if len(clients_ids)>1: 263 | for i in clients_ids[1:]: 264 | client_ds = client_ds.concatenate(ds.create_tf_dataset_for_client(i)) 265 | return client_ds 266 | 267 | @staticmethod 268 | def concatenate_client_data_from_numpy(ds, clients_ids): 269 | images, labels = [],[] 270 | for id in clients_ids: 271 | (x,y) = ds[id][()].values() 272 | images.extend(x) 273 | labels.extend(y) 274 | return (np.array(images), np.array(labels)) 275 | 276 | @staticmethod 277 | def split_clients_ids_to_partitions(cid, num_clients, available_ids, seed=0): 278 | # Fix seed for reproducability. 279 | np.random.seed(seed) 280 | # Ensure split is possible. 281 | assert num_clients <= len(available_ids), 'Number of clients exceeds avaialable clients ids.' 
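        # Note: each available client id is assigned uniformly at random to one of
        # `num_clients` partitions and only the ids landing in partition `cid` are
        # returned, so with a fixed seed every client reconstructs the same split.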
282 | partitions = np.random.choice(a=num_clients, size=len(available_ids), replace=True) 283 | return [available_ids[i] for i in np.argwhere(np.isin(partitions, [int(cid)])).ravel()] 284 | 285 | class MaskDataset: 286 | 287 | @staticmethod 288 | def create_lookup_table(mask): 289 | keys = tf.constant(mask) 290 | values = tf.ones_like(keys) 291 | return tf.lookup.StaticHashTable(tf.lookup.KeyValueTensorInitializer(keys, values), default_value=0) 292 | 293 | @staticmethod 294 | def hash_table_filter(index, _, table): 295 | return tf.cast(table.lookup(index), tf.bool) 296 | 297 | @staticmethod 298 | def filter_dataset(ds, indexes): 299 | # Ensure that dataset is a tf.dataset object 300 | if isinstance(ds,list): ds = ds[0] 301 | # Construct lookup table from indexes 302 | table = __class__.create_lookup_table(np.asarray(indexes)) 303 | # Convert to enumerated dataset 304 | ds = ds.enumerate() 305 | # Filter dataset based on lookup table 306 | ds = ds.filter(functools.partial(__class__.hash_table_filter, table=table)) 307 | # Convert back to original dataset 308 | ds = ds.map(lambda _,x: x) 309 | return ds 310 | 311 | class AugmentDataset: 312 | 313 | IMAGENET_STD = tf.reshape((0.2023, 0.1994, 0.2010), shape=(1, 1, 3)) 314 | IMAGENET_MEAN = tf.reshape((0.4914, 0.4822, 0.4465), shape=(1, 1, 3)) 315 | 316 | @staticmethod 317 | def image_valid_prep(x,y): 318 | x = tf.cast(x, tf.float32) / 255. 319 | x = (x - __class__.IMAGENET_MEAN) / __class__.IMAGENET_STD 320 | return x, y 321 | 322 | @staticmethod 323 | def image_train_prep(x,y=None, size=(32,32)): 324 | x = tf.cast(x, tf.float32) / 255. 325 | x = tf.image.random_flip_left_right(x) 326 | x = tf.image.pad_to_bounding_box(x, 4, 4, size[0]+8, size[1]+8) 327 | x = tf.image.random_crop(x, (size[0], size[1], 3)) 328 | x = (x - __class__.IMAGENET_MEAN) / __class__.IMAGENET_STD 329 | return x if y is None else (x,y) 330 | 331 | # emnist support 332 | @staticmethod 333 | def image_valid_prep_v2(x,y): 334 | x = tf.cast(x, tf.float32) / 255. 335 | return x, y 336 | 337 | # emnist support 338 | @staticmethod 339 | def image_train_prep_v2(x,y=None, size=(32,32)): 340 | x = tf.cast(x, tf.float32) / 255. 
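    # Note: this *_v2 variant (see the '# emnist support' markers) only rescales and
    # flips; it skips the pad/random-crop and the 3-channel ImageNet-style mean/std
    # normalization used in image_train_prep above, both of which assume RGB inputs.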
341 | x = tf.image.random_flip_left_right(x) 342 | return x if y is None else (x,y) 343 | 344 | @staticmethod 345 | def image_ood_prep(x, color_mode='rgb', size=(32,32)): 346 | if color_mode=='grayscale': return tf.map_fn(__class__.image_train_prep_v2, x) 347 | return tf.map_fn(functools.partial(__class__.image_train_prep, size=size), x) 348 | 349 | @staticmethod 350 | def audio_valid_prep(x,y, bins=64, sr=16000, to_float=True): 351 | # Convert to float 352 | if to_float: x = tf.cast(x, tf.float32) / float(tf.int16.max) 353 | # Compute mel spectrograms 354 | x = __class__.logmelspectogram(x, bins=bins, sr=sr) 355 | return x, y 356 | 357 | @staticmethod 358 | def audio_train_prep(x, y=None, seconds=1, bins=64, sr=16000, to_float=True): 359 | # Convert to float 360 | if to_float: x = tf.cast(x, tf.float32) / float(tf.int16.max) 361 | # Pad to seconds length 362 | x = __class__.pad(x, sequence_length=int(sr*seconds)) 363 | # Random crop (if larger) 364 | x = tf.image.random_crop(x, [int(sr*seconds)]) 365 | # Compute mel spectrograms 366 | x = __class__.logmelspectogram(x, bins=bins, sr=sr) 367 | return x if y is None else (x,y) 368 | 369 | @staticmethod 370 | def audio_ood_prep(x, seconds=1, bins=64, sr=16000, to_float=True): 371 | return __class__.audio_train_prep(tf.squeeze(x), seconds=seconds, bins=bins, sr=sr, to_float=to_float) 372 | 373 | ###################### 374 | # Image Augmentations 375 | ###################### 376 | 377 | " Cutout" 378 | def cutout(images, labels=None, mask_size=(16,16)): 379 | _images = tfa.image.cutout(images, mask_size=mask_size) 380 | if labels is None: return _images 381 | return _images, labels 382 | 383 | " Mixup" 384 | def mixup(x, y=None): 385 | alpha = tf.random.uniform([], 0, 1) 386 | mixedup_x = alpha * x + (1 - alpha) * tf.reverse(x, axis=[0]) 387 | if y is None: return mixedup_x 388 | return mixedup_x, y 389 | 390 | ###################### 391 | # Audio Augmentations 392 | ###################### 393 | 394 | @staticmethod 395 | def read_audio(fp, label=None): 396 | waveform, _ = tf.audio.decode_wav(tf.io.read_file(fp)) 397 | return waveform[Ellipsis, 0], label 398 | 399 | @staticmethod 400 | def pad(waveform, sequence_length=16000): 401 | padding = tf.maximum(sequence_length - tf.shape(waveform)[0], 0) 402 | left_pad = padding // 2 403 | right_pad = padding - left_pad 404 | return tf.pad(waveform, paddings=[[left_pad, right_pad]]) 405 | 406 | @staticmethod 407 | def logmelspectogram(x, bins=64, sr=16000, fmin=60.0, fmax=7800.0, fft_length=1024): 408 | # Spectrogram extraction 409 | s = tf.signal.stft(x, frame_length=400, frame_step=160, fft_length=fft_length) 410 | x = tf.abs(s) 411 | w = tf.signal.linear_to_mel_weight_matrix(bins, s.shape[-1], sr, fmin, fmax) 412 | x = tf.tensordot(x, w, 1) 413 | x.set_shape(x.shape[:-1].concatenate(w.shape[-1:])) 414 | x = tf.math.log(x+1e-6) 415 | return x[Ellipsis, tf.newaxis] 416 | 417 | class LoadDataset: 418 | 419 | @staticmethod 420 | def get_cifar10(split, with_info=True): 421 | if isinstance(split,List): split = split[0] 422 | (train_images, train_labels), (test_images,test_labels) = tf.keras.datasets.cifar10.load_data() 423 | info = {'num_classes': len(np.unique(train_labels)), 'num_examples': test_images.shape[0] if split=='test' else train_images.shape[0]} 424 | ds = (test_images,test_labels) if split=='test' else (train_images, train_labels) 425 | return (ds, info) if with_info else ds 426 | 427 | @staticmethod 428 | def get_stylegan(split=None, with_info=True, reshape_size=(32,32), color_mode='rgb', 
name='stylegan_oriented'): 429 | with contextlib.redirect_stdout(None): # Read data from dir 430 | ds = tf.keras.preprocessing.image_dataset_from_directory( 431 | directory=os.path.join(os.environ['TFDS_DATA_DIR'],f'raw/{name}/'), 432 | label_mode=None, shuffle=False, batch_size=None, image_size=reshape_size, color_mode=color_mode, seed=0) 433 | return (ds, None) if with_info else ds 434 | 435 | @staticmethod 436 | def get_speech_commands(split, with_info=True): 437 | ds, info = tfds.load('speech_commands', split=split, with_info=True, as_supervised=True, shuffle_files=False) 438 | info = {'num_classes': info.features['label'].num_classes, 'num_examples': info.splits[split].num_examples} 439 | return (ds, info) if with_info else ds 440 | 441 | @staticmethod 442 | def get_librispeech(split=None, with_info=True, reshape_size=None, name='librispeech'): 443 | with contextlib.redirect_stdout(None): # Read data from dir 444 | ds = tf.keras.utils.audio_dataset_from_directory( 445 | directory=os.path.join(os.environ['TFDS_DATA_DIR'],f'raw/{name}/'), 446 | label_mode=None, shuffle=False, batch_size=None, seed=0) 447 | return (ds, None) if with_info else ds 448 | -------------------------------------------------------------------------------- /src/distiller.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class Distiller(tf.keras.Model): 4 | 5 | def __init__(self, student, teacher, has_labels=False): 6 | super(Distiller, self).__init__() 7 | self.teacher = teacher 8 | self.student = student 9 | self.is_source_dataset_has_labels = has_labels 10 | 11 | def save(self, *args, **kwargs): 12 | return self.student.save(include_optimizer=False, *args, **kwargs) 13 | 14 | def save_weights(self, *args, **kwargs): 15 | return self.student.save_weights(*args, **kwargs) 16 | 17 | def load_weights(self, *args, **kwargs): 18 | return self.student.load_weights(*args, **kwargs) 19 | 20 | def compile(self, optimizer, metrics, distillation_loss_fn, temperature=8., energy_confidence=0.): 21 | super(Distiller, self).compile(optimizer=optimizer, metrics=metrics) 22 | self.distillation_loss_fn = distillation_loss_fn 23 | self.temperature = temperature 24 | self.energy_confidence = energy_confidence 25 | 26 | def train_step(self, data): 27 | 28 | if self.is_source_dataset_has_labels: # If knowledge distillation from source data 29 | x, _ = data 30 | else: # If knowledge distillation from single image source 31 | x = data 32 | 33 | # Forward pass of teacher 34 | teacher_predictions = self.teacher(x, training=False) 35 | 36 | if self.energy_confidence > 0.: 37 | pseudo_mask = tf.cast(-1. 
* tf.math.reduce_logsumexp(teacher_predictions, axis=-1) >= self.energy_confidence, dtype=tf.float32)
38 | 
39 |         with tf.GradientTape() as tape:
40 |             # Forward pass of student
41 |             student_predictions = self.student(x, training=True)
42 | 
43 |             # Compute losses
44 |             distillation_loss = self.distillation_loss_fn(
45 |                 tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
46 |                 tf.nn.softmax(student_predictions / self.temperature, axis=1),
47 |             )
48 | 
49 |             if self.energy_confidence > 0.:
50 |                 loss = tf.reduce_mean(tf.cast(distillation_loss, dtype=tf.float32) * pseudo_mask)
51 |             else:
52 |                 loss = distillation_loss
53 | 
54 |         # Compute gradients
55 |         trainable_vars = self.student.trainable_variables
56 |         gradients = tape.gradient(loss, trainable_vars)
57 |         # Update weights
58 |         self.optimizer.apply_gradients(zip(gradients, trainable_vars))
59 |         # Return a dict of performance
60 |         results = {m.name: m.result() for m in self.metrics}
61 |         results.update({"distillation_loss": distillation_loss})
62 |         return results
63 | 
64 |     def test_step(self, data):
65 |         # Unpack the data
66 |         x, y = data
67 |         # Compute predictions
68 |         y_prediction = self.student(x, training=False)
69 |         # Update the metrics.
70 |         self.compiled_metrics.update_state(y, y_prediction)
71 |         # Return a dict of performance
72 |         results = {m.name: m.result() for m in self.metrics}
73 |         return results
-------------------------------------------------------------------------------- /src/main.py: --------------------------------------------------------------------------------
1 | import os
2 | import utils
3 | import argparse
4 | import shutil
5 | import flwr as fl
6 | from time import sleep
7 | import network
8 | 
9 | parser = argparse.ArgumentParser(description="Federated Model Compression")
10 | parser.add_argument("--random_id", default='0a6809e9cb', type=str, help="unique id of experiment")
11 | parser.add_argument("--rounds", default=30, type=int, help="number of total federated rounds to run")
12 | parser.add_argument("--timeout", default=None, type=utils.none_or_int, nargs='?', help="maximum seconds of round until timeout")
13 | parser.add_argument("--clients", default=10, type=int, help="number of clients")
14 | parser.add_argument("--cpu_percentage", default=0.5, type=float, help="percentage of cpu resources (cores) to be used by each client")
15 | parser.add_argument("--gpu_percentage", default=0.08, type=float, help="percentage of gpu resources to be used by each client")
16 | parser.add_argument("--participation", default=1.0, type=float, help="participation percentage of clients in each round")
17 | parser.add_argument("--epochs", default=1, type=int, help="number of local client epochs to run per federated round")
18 | parser.add_argument("--batch_size", default=32, type=int, help="batch size for local client training")
19 | parser.add_argument("--learning_rate", default=1e-1, type=float, help="learning rate")
20 | parser.add_argument("--dataset", default="cifar10", type=str, help="dataset name to use for training")
21 | parser.add_argument("--model_name", default="resnet20", type=str, help="model architecture name to use for training")
22 | parser.add_argument("--init_model_fp", default=None, type=utils.none_or_str, nargs='?', help="initialization model path")
23 | parser.add_argument("--results_dir", default="../assets/results", type=str, help="store results directory")
24 | parser.add_argument("--model_dir", default="../assets/saved_models", type=str, help="store model directory")
25 | parser.add_argument("--seed", default=0, type=int, help="seed used to achieve reproducibility")
26 | # Compression Parameters
27 | parser.add_argument("--init_num_clusters", default=8, type=int, help="initial number of clusters for weight clustering")
28 | parser.add_argument("--max_num_clusters", default=20, type=int, help="maximum number of clusters for weight clustering")
29 | parser.add_argument("--cluster_update_step", default=1, type=int, help="search step for optimal number of clusters for weight clustering")
30 | parser.add_argument("--cluster_search_init_rounds", default=1, type=int, help="initial rounds before starting cluster search")
31 | parser.add_argument("--cluster_search_window", default=3, type=int, help="window of metric average for optimal number of clusters for weight clustering")
32 | parser.add_argument("--cluster_search_patience", default=3, type=int, help="patience step for optimal number of clusters for weight clustering")
33 | parser.add_argument("--cluster_search_metric_limit", default=1.0, type=float, help="minimum metric improvement limit for switching the number of clusters")
34 | parser.add_argument("--client_compression", dest="client_compression", action="store_true", help="use client-side weight clustering")
35 | parser.add_argument("--client_compression_epochs", default=5, type=int, help="local train epochs with weight clustering")
36 | parser.add_argument("--client_compression_lr", default=1e-4, type=float, help="learning rate for local train epochs with weight clustering")
37 | parser.add_argument("--server_compression", dest="server_compression", action="store_true", help="use server-side ood weight clustering with KD")
38 | parser.add_argument("--server_compression_step", default=1, type=int, help="frequency step for applying server-side ood weight clustering with KD")
39 | parser.add_argument("--server_compression_epochs", default=5, type=int, help="train epochs with server-side ood weight clustering with KD")
40 | parser.add_argument("--server_compression_lr", default=1e-4, type=float, help="learning rate for server-side ood weight clustering with KD")
41 | parser.add_argument("--server_compression_batch", default=256, type=int, help="batch size for server-side ood weight clustering with KD")
42 | parser.add_argument("--server_compression_temperature", default=8.0, type=float, help="temperature for server-side ood weight clustering with KD")
43 | args = parser.parse_args()
44 | 
45 | model_save_dir_fn = lambda x: os.path.abspath(os.path.join(args.model_dir,f"{args.model_name}_{args.dataset}_{args.server_compression}_{args.client_compression}_{x}_{args.random_id}.h5"))
46 | store_dir_fn = lambda x: os.path.abspath(os.path.join(args.results_dir, f"{args.model_name}_{args.dataset}_{args.server_compression}_{args.client_compression}_{x[0]}_{args.random_id}.{x[1]}"))
47 | 
48 | def get_clients_config():
49 |     return {
50 |         "lr": args.learning_rate,
51 |         "epochs": args.epochs,
52 |         "compression_lr": args.client_compression_lr,
53 |         "compression_epochs": args.client_compression_epochs,
54 |     }
55 | 
56 | def get_server_compression_config():
57 |     return {
58 |         'compression_step': args.server_compression_step,
59 |         'epochs':args.server_compression_epochs,
60 |         'learning_rate':args.server_compression_lr,
61 |         'batch_size': args.server_compression_batch,
62 |         'temperature': args.server_compression_temperature,
63 |         'data_loader': utils.create_dataloader_fn(ood=True)[args.dataset],
64 |         'dataset_name': args.dataset,
65 |         'seed': args.seed,
66 |     }
67 | 
68 | def get_cluster_search_config():
69 |     return {
70 |         "init_num_clusters":args.init_num_clusters,
71 |         "max_num_clusters":args.max_num_clusters,
72 |         "cluster_update_step":args.cluster_update_step,
73 |         "init_rounds": args.cluster_search_init_rounds,
74 |         "window":args.cluster_search_window,
75 |         "patience":args.cluster_search_patience,
76 |         "limit": args.cluster_search_metric_limit,
77 |     }
78 | 
79 | def create_client(cid):
80 |     import utils
81 |     # Sleep a few seconds to allow the GPU to be set up for each client
82 |     sleep(int(cid)*0.75)
83 |     # Assign free GPU
84 |     os.environ['CUDA_VISIBLE_DEVICES']="0"
85 |     # Start client
86 |     from client import _Client
87 |     return _Client(int(cid),
88 |                    num_clients=args.clients,
89 |                    model_loader=network.get_model(args.model_name),
90 |                    data_loader=utils.create_dataloader_fn()[args.dataset],
91 |                    batch_size = args.batch_size,
92 |                    client_compression=args.client_compression,
93 |                    seed=args.seed,
94 |                    )
95 | 
96 | def create_server(run_id=0):
97 |     # Assign free GPU
98 |     os.environ['CUDA_VISIBLE_DEVICES']="0"
99 |     # Start server
100 |     from server import _Server
101 |     return _Server(run_id=run_id,
102 |                    num_rounds=args.rounds,
103 |                    num_clients=args.clients,
104 |                    participation=args.participation,
105 |                    model_loader=network.get_model(args.model_name),
106 |                    data_loader=utils.create_dataloader_fn()[args.dataset],
107 |                    init_model_fp=args.init_model_fp,
108 |                    model_save_dir=model_save_dir_fn,
109 |                    clients_config=get_clients_config(),
110 |                    # Compression parameters
111 |                    server_compression=args.server_compression,
112 |                    client_compression=args.client_compression,
113 |                    server_compression_config=get_server_compression_config(),
114 |                    cluster_search_config=get_cluster_search_config(),
115 |                    )
116 | 
117 | def main(run_id=0):
118 | 
119 |     server = create_server(run_id=run_id)
120 |     history = fl.simulation.start_simulation(
121 |         client_fn = create_client,
122 |         server = server,
123 |         num_clients=args.clients,
124 |         ray_init_args= {
125 |             "ignore_reinit_error": True, "include_dashboard": False,
126 |             "dashboard_host": "127.0.0.1", "dashboard_port": 8265,
127 |             # By setting `num_cpus` to match the number of available cores, we ensure that clients are terminated after being executed
128 |             # This way gpu resources are released in every round. 
129 | "num_cpus": min(args.clients,7) 130 | }, 131 | client_resources={ 132 | "num_cpus":args.cpu_percentage, 133 | "num_gpus": args.gpu_percentage 134 | }, 135 | config=fl.server.ServerConfig(num_rounds=args.rounds, round_timeout=args.timeout), 136 | ) 137 | return history 138 | 139 | if __name__ == "__main__": 140 | 141 | parsed_args = '\t' + '\t'.join(f'{k} = {v}\n' for k, v in vars(args).items()) 142 | print('Parameters:') 143 | print(parsed_args) 144 | 145 | history = main() 146 | df = utils.store_history(history, args, store_dir_fn=store_dir_fn) 147 | print(df) 148 | -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | 2 | def get_resnet20_network(input_shape=(32,32,3), num_classes=10, weights_dir=None): 3 | from architectures import Resnets 4 | return Resnets.resnet20(input_shape=input_shape, num_classes=num_classes, weights_dir=weights_dir) 5 | 6 | def get_yamnet_network(input_shape=(None,40,1), num_classes=10, weights_dir=None): 7 | from architectures import YAMNet 8 | return YAMNet.create_yamnet_model(input_shape=input_shape, num_classes=num_classes, weights_dir=weights_dir) 9 | 10 | def get_model(name='resnet20'): 11 | if name=='resnet20': 12 | return get_resnet20_network 13 | elif name=='yamnet': 14 | return get_yamnet_network 15 | else: 16 | raise Exception("Model `{}` is not available. Please provide one of [`resnet20`,`cnn`,`yamnet`].".format(name)) 17 | -------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import flwr as fl 3 | import tensorflow as tf 4 | import compress 5 | import strategy 6 | import utils 7 | import pandas as pd 8 | 9 | class _Server(fl.server.Server): 10 | 11 | def __init__(self, model_loader, data_loader, num_rounds, run_id=0, 12 | num_clients=10, participation=1.0, 13 | init_model_fp=None, model_save_dir='./', 14 | server_compression=True, client_compression = False, 15 | server_compression_config={"compression_step":1, "learning_rate":1e-4, "epochs":10}, 16 | cluster_search_config = { 17 | "init_num_clusters":8, "max_num_clusters":20, "cluster_update_step":1, 18 | "init_rounds":1, "window":3, "patience":3, "limit": 1, 19 | }, 20 | clients_config={"learning_rate":1e-4, "epochs":1}, 21 | log_level=logging.INFO, 22 | ): 23 | self.run_id = run_id 24 | (self.data, self.valid_data), self.num_classes, self.num_samples = data_loader(return_eval_ds=True) 25 | self.model_loader = model_loader 26 | self.input_shape = self.data.element_spec[0].shape 27 | self.init_model_fp = init_model_fp 28 | self.clients_config = clients_config 29 | self.num_clients = num_clients 30 | self.participation = participation 31 | self.set_strategy(self) 32 | self._client_manager = fl.server.client_manager.SimpleClientManager() 33 | self.max_workers = None 34 | self.num_rounds = num_rounds 35 | self.model_save_dir = model_save_dir 36 | logging.getLogger("flower").setLevel(log_level) 37 | 38 | # Self-compress parameters 39 | self.server_compression = server_compression 40 | self.client_compression = client_compression 41 | self.server_compression_config = server_compression_config 42 | self.cluster_search_config = cluster_search_config 43 | self.compression_rnds = utils.set_compression_rnds(server_compression_config['compression_step'], num_rounds) 44 | # Local Variables 45 | self.num_clusters = 
cluster_search_config['init_num_clusters'] 46 | self.compression = False 47 | self.compression_metric = 'val_accuracy' 48 | self.compression_metric_dict = {self.compression_metric:[], 'nb_clusters':[], 'val_score':[]} 49 | 50 | def _is_final_round(self,rnd): 51 | return self.num_rounds==rnd 52 | 53 | " Get number of epochs to train on ood data." 54 | def _num_compression_epochs(self, rnd): 55 | # NOTE: To recover accuracy from large compression, we might consider to increase it gradually. 56 | return self.server_compression_config['epochs'] 57 | 58 | " Set number of clusters based on allowed performance drop." 59 | def set_num_clusters(self, metrics): 60 | 61 | if self.client_compression or self.server_compression: 62 | # Check if metric for cluster selection exists 63 | if self.compression_metric in metrics.keys(): 64 | metric = metrics[self.compression_metric] 65 | # Cluster reduction check 66 | flag = utils.compression_flag(df=pd.DataFrame(self.compression_metric_dict), 67 | init_num_clusters=self.cluster_search_config["init_num_clusters"], 68 | max_num_clusters=self.cluster_search_config["max_num_clusters"], 69 | init_rounds=self.cluster_search_config["init_rounds"], 70 | window=self.cluster_search_config["window"], 71 | patience=self.cluster_search_config["patience"], 72 | limit=self.cluster_search_config["limit"]/100,) 73 | 74 | if flag: self.num_clusters+=self.cluster_search_config["cluster_update_step"] 75 | print(f"[Server] - Best Score: {max(self.compression_metric_dict[self.compression_metric]):0.4f}, Current Score: {metric:0.4f}, Flag: {flag}.") 76 | return flag 77 | return False 78 | 79 | " Set best-seen performance across training." 80 | def set_performance(self, metrics): 81 | if self.compression_metric in metrics.keys(): 82 | self.compression_metric_dict[self.compression_metric].append(metrics[self.compression_metric]) 83 | self.compression_metric_dict['val_score'].append(metrics['val_score']) 84 | self.compression_metric_dict['nb_clusters'].append(self.num_clusters) 85 | 86 | " Set the max_workers used by ThreadPoolExecutor. " 87 | def set_max_workers(self, *args, **kwargs): 88 | return super(_Server, self).set_max_workers(*args, **kwargs) 89 | 90 | " Set server-side model aggregation strategy. " 91 | def set_strategy(self, *_): 92 | self.strategy = strategy.FedCustom( 93 | min_available_clients=self.num_clients, 94 | fraction_fit=self.participation, 95 | min_fit_clients=int(self.participation*self.num_clients), 96 | fraction_evaluate=0.0, 97 | min_evaluate_clients=0, 98 | evaluate_fn=self.get_evaluation_fn(), 99 | on_fit_config_fn=self.get_client_config_fn(), 100 | initial_parameters=self.get_initial_parameters(), 101 | fit_metrics_aggregation_fn=utils.weighted_average_train_metrics, 102 | ) 103 | 104 | " Return ClientManager. " 105 | def client_manager(self, *args, **kwargs): 106 | return super(_Server, self).client_manager(*args, **kwargs) 107 | 108 | " Get model parameters. " 109 | def get_parameters(self, config={}): 110 | return self.model.get_weights() 111 | 112 | " Set model parameters" 113 | def set_parameters(self, parameters, config={}): 114 | if not hasattr(self, 'model'): 115 | self.model = self.model_loader(input_shape=self.input_shape[1:],num_classes=self.num_classes) 116 | self.model.compile(metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]) 117 | if parameters is not None: 118 | self.model.set_weights(parameters) 119 | 120 | " Get initial model parameters to be used for training." 
121 | def get_initial_parameters(self, *_): 122 | if self.init_model_fp is not None: 123 | self.init_weights = tf.keras.models.load_model(self.init_model_fp, compile=False).get_weights() 124 | else: 125 | self.init_weights = self.model_loader(input_shape=self.input_shape[1:], num_classes=self.num_classes).get_weights() 126 | return fl.common.ndarrays_to_parameters(self.init_weights) 127 | 128 | " Get evaluation function to perform server-side evalation." 129 | def get_evaluation_fn(self): 130 | def evaluation_fn(rnd, parameters, config): 131 | 132 | metrics=None 133 | # Add train model information 134 | results = utils.set_train_metrics(config) 135 | results['num_clusters'] = self.num_clusters 136 | # Update model parameters 137 | self.set_parameters(parameters, config) 138 | 139 | # Centralized evaluation 140 | metrics = self.model.evaluate(self.data, verbose=0) 141 | results['model_size'] = utils.get_gzipped_model_size_from_model(self.model) 142 | results["accuracy"] = metrics[1] 143 | if self._is_final_round(rnd): 144 | self.model.save(self.model_save_dir(False), include_optimizer=False) 145 | print(f"[Server] - Round {rnd}: For {self.num_clusters} clusters, accuracy {metrics[1]:0.4f} " + 146 | f"(Val. Accuracy: {results['val_accuracy'] if 'val_accuracy' in results.keys() else 0.0 :0.4f}).") 147 | 148 | # Update best-seen performance (BEFORE setting number of clusters!) 149 | self.set_performance(results) 150 | # NOTE: Update number of clusters based on round performance 151 | results['compression'] = self.set_num_clusters(results) 152 | 153 | # Self compression with OOD + Re-evaluation 154 | if self.server_compression: 155 | if rnd in self.compression_rnds and self._num_compression_epochs(rnd)>0: 156 | self.model = compress.self_compress( 157 | self.model, self.num_classes, 158 | model_loader=self.model_loader, 159 | data_loader=self.server_compression_config['data_loader'], 160 | data_shape=self.data.element_spec[0].shape, 161 | nb_clusters=self.num_clusters, 162 | epochs=self._num_compression_epochs(rnd), 163 | batch_size=self.server_compression_config['batch_size'], 164 | learning_rate=self.server_compression_config['learning_rate'], 165 | temperature=self.server_compression_config['temperature'], 166 | seed=self.server_compression_config['seed']) 167 | metrics = self.model.evaluate(self.data, verbose=0) 168 | results['compressed_model_size'] = utils.get_gzipped_model_size_from_model(self.model) 169 | results['compressed_accuracy'] = metrics[1] 170 | 171 | if self._is_final_round(rnd): 172 | self.model.save(self.model_save_dir(True), include_optimizer=False) 173 | 174 | print(f"[Server] - Round {rnd}: Next training round will be executed with {self.num_clusters} clusters (Compression {results['compression']}).") 175 | return (metrics[0], results), self.get_parameters() 176 | return evaluation_fn 177 | 178 | " Get clients fit configuration function." 
179 | def get_client_config_fn(self): 180 | def get_on_fit_config_fn(rnd): 181 | self.clients_config["round"] = rnd 182 | self.clients_config["num_clusters"] = self.num_clusters 183 | return self.clients_config 184 | return get_on_fit_config_fn 185 | -------------------------------------------------------------------------------- /src/strategy.py: -------------------------------------------------------------------------------- 1 | import flwr as fl 2 | import logging 3 | 4 | class FedCustom(fl.server.strategy.fedavg.FedAvg): 5 | 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.train_metrics_aggregated = {} 9 | self.parameters = None 10 | 11 | """Configure the next round of training.""" 12 | def configure_fit(self, server_round, parameters, client_manager): 13 | config = {} 14 | if self.on_fit_config_fn is not None: # Custom fit config function provided 15 | config = self.on_fit_config_fn(server_round) 16 | if self.parameters is not None: 17 | parameters = self.parameters 18 | fit_ins = fl.common.FitIns(parameters, config) 19 | # Sample clients 20 | sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available()) 21 | clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients) 22 | # Return client/config pairs 23 | return [(client, fit_ins) for client in clients] 24 | 25 | """Aggregate fit results using weighted average.""" 26 | def aggregate_fit(self, server_round, results, failures): 27 | if not results: 28 | return None, {} 29 | # Do not aggregate if there are failures and failures are not accepted 30 | if not self.accept_failures and failures: 31 | return None, {} 32 | 33 | # Convert results 34 | weights_results = [(fl.common.parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results] 35 | parameters_aggregated = fl.common.ndarrays_to_parameters(fl.server.strategy.aggregate.aggregate(weights_results)) 36 | 37 | # Aggregate custom metrics if aggregation fn was provided 38 | metrics_aggregated = {} 39 | if self.fit_metrics_aggregation_fn: 40 | fit_metrics = [(res.num_examples, res.metrics) for _, res in results] 41 | metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) 42 | self.train_metrics_aggregated = metrics_aggregated 43 | elif server_round == 1: # Only log this warning once 44 | fl.common.logger.log(logging.WARNING, "No fit_metrics_aggregation_fn provided") 45 | return parameters_aggregated, metrics_aggregated 46 | 47 | """Evaluate model parameters using an evaluation function.""" 48 | def evaluate(self, server_round, parameters): 49 | if self.evaluate_fn is None: 50 | # No evaluation function provided 51 | return None 52 | parameters_ndarrays = fl.common.parameters_to_ndarrays(parameters) 53 | eval_res, new_parameters = self.evaluate_fn(server_round, parameters_ndarrays, config=self.train_metrics_aggregated) 54 | self.parameters = fl.common.ndarrays_to_parameters(new_parameters) 55 | if eval_res is None: 56 | return None 57 | loss, metrics = eval_res 58 | return loss, metrics -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import tempfile 4 | import errno 5 | import contextlib 6 | import numpy as np 7 | import pickle 8 | import pandas as pd 9 | from time import sleep 10 | 11 | " Compute embedding from GMP layer." 
12 | def compute_embeddings(model_loader, data, num_classes, weights): 13 | import tensorflow as tf 14 | # Need to create a new model to avoid errors. 15 | model = model_loader(input_shape=data.element_spec[0].shape[1:], num_classes=num_classes) 16 | model.set_weights(weights) 17 | # Set new model up to 'GMP_layer'. 18 | new_model = tf.keras.Model(inputs=model.input, outputs=model.get_layer("GMP_layer").output) 19 | new_model.compile(metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]) 20 | # Compute embeddings 21 | embeddings = new_model.predict(data, verbose=0) 22 | return embeddings 23 | 24 | " Compute embedding score from SVD analysis. " 25 | def compute_score(embeddings, normalize=True): 26 | import tensorflow as tf 27 | if normalize: 28 | embeddings = tf.math.l2_normalize(embeddings, axis=1) 29 | s = tf.linalg.svd(embeddings, compute_uv=False) 30 | s = s/(tf.math.reduce_sum(tf.math.abs(s))) 31 | s = tf.math.exp(-(tf.math.reduce_sum(s*tf.math.log(s)))) 32 | return s.numpy() 33 | 34 | " Compute energy score on data." 35 | def compute_energy_score(model, data, temp=-1.0): 36 | import tensorflow as tf 37 | energy_score = [] 38 | for (x,_) in data: 39 | preds = model(x,training=False) 40 | score = tf.math.reduce_logsumexp(preds/temp, axis=-1) 41 | energy_score.extend(score.numpy()) 42 | return np.asarray(energy_score).mean() 43 | 44 | "Agreggate metrics from clients fit results." 45 | def weighted_average_train_metrics(metrics): 46 | results = {} 47 | # Store metrics names 48 | metrics_keys = [m.keys() for _, m in metrics] 49 | # Compute standard metrics 50 | train_examples = [num_examples for num_examples, _ in metrics] 51 | results['train_loss'] = sum([num_examples * m["train_loss"] for num_examples, m in metrics]) / sum(train_examples) 52 | results['train_accuracy'] = sum([num_examples * m["train_accuracy"] for num_examples, m in metrics]) / sum(train_examples) 53 | # Compute validation metrics 54 | if all([set(['val_loss', 'val_accuracy', 'num_val_samples']) <= set(list(keys)) for keys in metrics_keys]): 55 | val_examples = [m['num_val_samples'] for _, m in metrics] 56 | results['val_loss'] = sum([num_examples * m["val_loss"] for num_examples,(_,m) in zip(val_examples,metrics)]) / sum(val_examples) 57 | results['val_accuracy'] = sum([num_examples * m["val_accuracy"] for num_examples,(_,m) in zip(val_examples,metrics)]) / sum(val_examples) 58 | # Compute model size 59 | if all('model_size' in keys for keys in metrics_keys): 60 | results['models_sizes'] = [m["model_size"] for _, m in metrics] 61 | results['models_size'] = sum(results['models_sizes'])/len(results['models_sizes']) 62 | # Compute validation/ood embedding scores 63 | if all('embeddings' in keys for keys in metrics_keys): 64 | results['val_score'] = compute_score(np.concatenate([m['embeddings'] for _,m in metrics], axis=0)) 65 | if all('odd_embeddings' in keys for keys in metrics_keys): 66 | results['odd_score'] = compute_score(np.concatenate([m['odd_embeddings'] for _,m in metrics], axis=0)) 67 | return results 68 | 69 | " Extract train metrics for a given round." 
70 | def set_train_metrics(config={}): 71 | results={} 72 | # Store metrics names 73 | config_keys = list(config.keys()) 74 | # Store train loss/accuracy 75 | if set(['train_loss','train_accuracy']) <= set(config_keys): 76 | results['train_loss'] = config['train_loss'] 77 | results['train_accuracy'] = config['train_accuracy'] 78 | # Store validation loss/accuracy 79 | if set(['val_loss','val_accuracy']) <= set(config_keys): 80 | results['val_loss'] = config['val_loss'] 81 | results['val_accuracy'] = config['val_accuracy'] 82 | # Store train model size(s) 83 | if set(['model_size','models_sizes']) <= set(config_keys): 84 | results['client_model_size'] = config['model_size'] 85 | results['clients_model_size'] = config['models_sizes'] 86 | # Score embedding scores(s) 87 | if set(['val_score']) <= set(config_keys): 88 | results['val_score'] = config['val_score'] 89 | if set(['odd_score']) <= set(config_keys): 90 | results['odd_score'] = config['odd_score'] 91 | return results 92 | 93 | " Set number of federated rounds to perform server-side compression." 94 | def set_compression_rnds(compression_step, num_rounds): 95 | compression_rnds = [compression_step*i for i in range(1,(num_rounds+1)//compression_step)] 96 | # Force initial compression 97 | if 0 not in compression_rnds: 98 | compression_rnds = [0,] + compression_rnds 99 | # Force final compression 100 | if num_rounds not in compression_rnds: 101 | compression_rnds = compression_rnds + [num_rounds,] 102 | return compression_rnds 103 | 104 | " Check mechanism for increasing number of clusters." 105 | def compression_flag(df, init_num_clusters=8, max_num_clusters=20, init_rounds=1, patience=3, window=3, base=2, limit=0.01): 106 | # Compute rolling average of metric 107 | df['rolling_avg_acc'] = df['val_accuracy'].rolling(window=window, min_periods=1).mean() 108 | df['rolling_avg_score'] = df['val_score'].rolling(window=window, min_periods=1).mean() 109 | df['mask'] = df['rolling_avg_acc'].notna() 110 | # Get current number of clusters. 111 | num_clusters = df['nb_clusters'].iloc[-1] 112 | # Compute metric, when enough samples exist and maximum number of clusters is not reached. 113 | if df['mask'].any() and (num_clusters < max_num_clusters): 114 | df['threshold_acc'] = (limit / (base*(df['nb_clusters']-(init_num_clusters))).replace(0,1)) 115 | df['threshold_score'] = ((10*limit) / (base*(df['nb_clusters']-(init_num_clusters))).replace(0,1)) 116 | df['metric_acc'] = df['rolling_avg_acc'].diff() 117 | df['metric_score'] = df['rolling_avg_score'].diff() 118 | df['result_acc'] = df['metric_acc'] <= df['threshold_acc'] 119 | df['result_score'] = df['metric_score'] <= df['threshold_score'] 120 | df['result'] = df['result_acc'] & df['result_score'] 121 | print(df) 122 | if ((df['nb_clusters']==num_clusters).sum()>=patience) and (len(df) >= max(init_rounds,2)): 123 | return df['result'][df['mask']].iloc[-1] 124 | return False 125 | 126 | " Get number of cluster between two boundaries based on current round." 
127 | def get_num_clusters(rnd, init_num_clusters, compression_rnds, min_num_clusters=1):
128 |     # Find closest (upper) index of rnd in compression_rnds
129 |     upper_closest = lambda r: compression_rnds.index(min([i for i in compression_rnds if i>=r], key=lambda x:abs(x-r)))
130 |     # Decrease step of clusters per new compression
131 |     decrease_step = (init_num_clusters-min_num_clusters)/len(compression_rnds)
132 |     # Number of cluster in rnd
133 |     num_clusters = int(init_num_clusters-(upper_closest(rnd)+1)*decrease_step)
134 |     return (max(num_clusters, min_num_clusters))
135 | 
136 | " None or Str datatype."
137 | def none_or_str(value):
138 |     if value == 'None':
139 |         return None
140 |     return value
141 | 
142 | " None or Int datatype."
143 | def none_or_int(value):
144 |     if value == 'None':
145 |         return None
146 |     return int(value)  # Cast to int so downstream consumers (e.g. round_timeout) receive a number
147 | 
148 | " Create directory, without warning if it already exists."
149 | def silentcreate(filename):
150 |     try:
151 |         os.makedirs(filename)
152 |     except FileExistsError:
153 |         pass
154 | 
155 | " Remove file, without warning if it does not exist."
156 | def silentremove(filename):
157 |     try:
158 |         os.remove(filename)
159 |     except OSError as e:
160 |         if e.errno != errno.ENOENT:
161 |             raise
162 | 
163 | def setup(id='0', mem=12000):
164 |     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
165 |     os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
166 |     os.environ['CUDA_VISIBLE_DEVICES'] = id
167 |     import tensorflow as tf
168 |     gpus = tf.config.list_physical_devices("GPU")
169 |     if gpus:
170 |         try:
171 |             tf.config.set_logical_device_configuration(gpus[0],[tf.config.LogicalDeviceConfiguration(memory_limit=mem)])
172 |         except RuntimeError as e:
173 |             print(e)
174 |     else:
175 |         print('No gpu available')
176 | 
177 | def get_gzipped_model_size(file):
178 |     _, zipped_file = tempfile.mkstemp(".zip")
179 |     with zipfile.ZipFile(zipped_file, "w", compression=zipfile.ZIP_DEFLATED) as f:
180 |         f.write(file)
181 |     return os.path.getsize(zipped_file)/1000
182 | 
183 | def get_gzipped_model_size_from_model(model):
184 |     with contextlib.redirect_stdout(None):
185 |         _, file = tempfile.mkstemp(".h5")
186 |         model.save(file, include_optimizer=False)
187 |         _, zipped_file = tempfile.mkstemp(".zip")
188 |         with zipfile.ZipFile(zipped_file, "w", compression=zipfile.ZIP_DEFLATED) as f:
189 |             f.write(file)
190 |         return os.path.getsize(zipped_file)/1000
191 | 
192 | def store_history(history, args, store_dir_fn=None, score_files=None):
193 | 
194 |     exceptions = ['num_clusters','model_size','accuracy','compression','compressed_model_size','compressed_accuracy']
195 |     data = history.metrics_centralized
196 |     losses = history.losses_centralized
197 |     data['loss'] = list(zip(*losses))[1][1:]
198 |     data['rnd'] = list(zip(*losses))[0][1:]
199 |     for key in data.keys():
200 |         if key in ['rnd','loss']: break
201 |         data[key] = list(zip(*data[key]))[1][(1 if key in exceptions else 0):]
202 | 
203 |     # Add scores
204 |     if score_files:
205 |         types = ['val_score','val_norm_score','ood_score','odd_norm_score']
206 |         names = ['embeddings','w_embeddings','ood_embeddings','w_ood_embeddings']
207 |         rnd = lambda x: int(x.split('_')[-1].split('.')[0])
208 |         # Initialize
209 |         for t in types: data[t] = {r:[] for r in data['rnd']}
210 |         # Load
211 |         for f in score_files:
212 |             temp = pickle.load(open(f,'rb'))
213 |             for i,j in zip(names,types):
214 |                 if i in temp.keys(): data[j][rnd(f)].extend(temp[i].numpy() if not isinstance(temp[i],np.ndarray) else temp[i])
215 |         # Combine
216 |         for i in types:
217 |             data[i] = [compute_score(np.asarray(data[i][j])) for j in data['rnd']]
218 | 
219 | 
df = pd.DataFrame(data) 220 | df = df.rename(columns={"accuracy": "test_accuracy", "loss": "test_loss"}) 221 | 222 | if (not args.server_compression) and (not args.client_compression): 223 | df = df.drop('num_clusters', axis=1) 224 | 225 | if store_dir_fn is not None: 226 | # Store results 227 | df.to_pickle(store_dir_fn(('metrics','pkl'))) 228 | pickle.dump(args, open(store_dir_fn(('args','pkl')), 'wb'), protocol=pickle.HIGHEST_PROTOCOL) 229 | 230 | return df 231 | 232 | def create_dataloader_fn(ood=False): 233 | import data 234 | if ood: 235 | return {"cifar10": data.get_stylegan, "spcm": data.get_librispeech,} 236 | return {"cifar10": data.get_cifar10, "spcm": data.get_spcm,} 237 | --------------------------------------------------------------------------------
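For readers wiring these pieces together, here is a minimal usage sketch of the `Distiller` class from `src/distiller.py`. The teacher/student models, the file path, the optimizer, and the choice of `KLDivergence` as the distillation loss are illustrative assumptions, not code taken from this repository.

```python
import tensorflow as tf
from distiller import Distiller  # src/distiller.py

# Hypothetical teacher: any trained Keras classifier saved as .h5 (path is a placeholder).
teacher = tf.keras.models.load_model("assets/saved_models/teacher.h5", compile=False)
# Student with the same architecture and output shape as the teacher.
student = tf.keras.models.clone_model(teacher)

# has_labels=False: train_step treats each batch as inputs only (no labels required).
distiller = Distiller(student=student, teacher=teacher, has_labels=False)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
    distillation_loss_fn=tf.keras.losses.KLDivergence(),  # assumed loss choice
    temperature=8.0,
)

# ood_ds: a tf.data.Dataset yielding unlabeled batches; eval_ds: a labeled (x, y) dataset.
# distiller.fit(ood_ds, epochs=5)
# distiller.evaluate(eval_ds)
```

With `has_labels=False` the distiller consumes unlabeled batches, which matches the server-side distillation on out-of-distribution data used in this project; the evaluation step still expects labeled `(x, y)` pairs.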