├── .gitignore ├── README.md ├── configurations ├── femnist_conv_fedavg.json ├── femnist_conv_fedprox.json ├── femnist_conv_fedprox_E20.json ├── femnist_conv_virtual_natural.json ├── femnist_fedprox.json ├── femnist_virtual.json ├── femnist_virtual_natural_sparse_delta.json ├── har_fedprox.json ├── har_virtual.json ├── mnist_fedprox.json ├── mnist_virtual.json ├── pmnist_fedprox.json ├── pmnist_virtual.json ├── shakespeare_fedprox.json ├── shakespeare_virtual_natural.json ├── vsn_fedprox.json └── vsn_virtual.json ├── environment.yml ├── main.py ├── requirements.txt ├── source ├── __init__.py ├── centered_layers.py ├── constants.py ├── data_utils.py ├── experiment_utils.py ├── fed_process.py ├── fed_prox.py ├── federated_devices.py ├── gate_layer.py ├── learning_rate_multipliers_opt.py ├── natural_raparametrization_layer.py ├── normal_natural.py ├── tfp_utils.py ├── utils.py └── virtual_process.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | logs/ 3 | configurations/ 4 | data/ 5 | __pycache__/ 6 | .DS_Store 7 | .python-version 8 | results/ 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VIRTUAL 2 | 3 | The VIRTUAL package implements a model for federated multi-task learning with variational neural networks. 4 | 5 | 6 | ## Getting Started 7 | 8 | We recommend setting up Miniconda and creating a Python environment from the environment file environment.yml: 9 | 10 | ``` 11 | conda env create -f environment.yml 12 | source activate virtual 13 | ``` 14 | 15 | Additionally, install the pip packages in the same environment: 16 | 17 | ``` 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | ## General usage 22 | 23 | 24 | To reproduce the experiments of the paper, run main.py with a configuration file from the configurations folder: 25 | 26 | ``` 27 | python main.py configurations/femnist_virtual.json 28 | ``` 29 | 30 | Then track the experiment using TensorBoard: 31 | 32 | ``` 33 | tensorboard --logdir logs/femnist_virtual_* 34 | ``` 35 | 36 | Hyperparameters and the corresponding metrics are tracked using the HParams API of TensorBoard (see https://www.tensorflow.org/tensorboard/hyperparameter_tuning_with_hparams). 37 | 38 | Note that if no GPU is available, you have to specify "session": {"num_gpus": 0} in the config file. 
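For orientation, each value in the "hp" block of a configuration is a list of candidate values, and main.py runs one experiment per point of the resulting grid. Below is a minimal sketch of that expansion; it mirrors the _gridsearch helper in main.py, and the example values are a subset of those in configurations/femnist_virtual.json:

```
from itertools import product

hp_conf = {
    "epochs_per_round": [20, 100],
    "natural_lr": [1e8, 2e8],
    "clients_per_round": [10],
}

keys, values = zip(*hp_conf.items())
experiments = [dict(zip(keys, v)) for v in product(*values)]
# 2 * 2 * 1 = 4 runs; the first one is
# {'epochs_per_round': 20, 'natural_lr': 100000000.0, 'clients_per_round': 10}
```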
-------------------------------------------------------------------------------- /configurations/femnist_conv_fedavg.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1, 4 | "verbose": 0 5 | }, 6 | "data_set_conf":{ 7 | "name": "femnist", 8 | "num_clients": 100, 9 | "shape": [1, 28, 28] 10 | }, 11 | "training_conf": {"method": "fedprox", 12 | "tot_epochs_per_client": 10000, 13 | "optimizer": "sgd", 14 | "tensorboard_updates": 1 15 | }, 16 | "model_conf": {"layers": [{"name": "Conv2DCentered", 17 | "input_shape": [1, 28, 28], 18 | "filters": 32, 19 | "kernel_size": 5, 20 | "padding": "SAME", 21 | "activation": "relu" 22 | }, 23 | { "name": "MaxPooling2D", 24 | "pool_size": [2, 2], 25 | "strides": [2, 2], 26 | "padding": "SAME" 27 | }, 28 | {"name": "Conv2DCentered", 29 | "input_shape": [1, 28, 28], 30 | "filters": 64, 31 | "kernel_size": 5, 32 | "padding": "SAME", 33 | "activation": "relu" 34 | }, 35 | { "name": "MaxPooling2D", 36 | "pool_size": [2, 2], 37 | "strides": [2, 2], 38 | "padding": "SAME" 39 | }, 40 | { "name": "Flatten" 41 | }, 42 | { "name": "DenseCentered", 43 | "units": 100, 44 | "activation": "relu" 45 | }, 46 | { "name": "DenseCentered", 47 | "units": 100, 48 | "activation": "relu" 49 | }, 50 | { "name": "DenseCentered", 51 | "units": 10, 52 | "activation": "softmax" 53 | } 54 | ] 55 | }, 56 | "hp": {"epochs_per_round": [100], 57 | "clients_per_round": [10], 58 | "learning_rate": [0.001], 59 | "batch_size": [20], 60 | "l2_reg": [0.0], 61 | "server_learning_rate": [1], 62 | "damping_factor": [1] 63 | } 64 | } -------------------------------------------------------------------------------- /configurations/femnist_conv_fedprox.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1, 4 | "verbose": 0 5 | }, 6 | "data_set_conf":{ 7 | "name": "femnist", 8 | "num_clients": 100, 9 | "shape": [1, 28, 28] 10 | }, 11 | "training_conf": {"method": "fedprox", 12 | "tot_epochs_per_client": 500, 13 | "optimizer": "sgd", 14 | "tensorboard_updates": 1 15 | }, 16 | "model_conf": {"layers": [{"name": "Conv2DCentered", 17 | "input_shape": [1, 28, 28], 18 | "filters": 32, 19 | "kernel_size": 5, 20 | "padding": "SAME", 21 | "activation": "relu" 22 | }, 23 | { "name": "MaxPooling2D", 24 | "pool_size": [2, 2], 25 | "strides": [2, 2], 26 | "padding": "SAME" 27 | }, 28 | {"name": "Conv2DCentered", 29 | "input_shape": [1, 28, 28], 30 | "filters": 64, 31 | "kernel_size": 5, 32 | "padding": "SAME", 33 | "activation": "relu" 34 | }, 35 | { "name": "MaxPooling2D", 36 | "pool_size": [2, 2], 37 | "strides": [2, 2], 38 | "padding": "SAME" 39 | }, 40 | { "name": "Flatten" 41 | }, 42 | { "name": "DenseCentered", 43 | "units": 100, 44 | "activation": "relu" 45 | }, 46 | { "name": "DenseCentered", 47 | "units": 100, 48 | "activation": "relu" 49 | }, 50 | { "name": "DenseCentered", 51 | "units": 10, 52 | "activation": "softmax" 53 | } 54 | ] 55 | }, 56 | "hp": {"epochs_per_round": [20, 100], 57 | "clients_per_round": [10], 58 | "learning_rate": [5e-4, 1e-3, 2e-3, 5e-3, 1e-2], 59 | "batch_size": [20], 60 | "l2_reg": [0.0, 1e-5, 1e-4, 1e-3, 1e-2], 61 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1] 62 | } 63 | } -------------------------------------------------------------------------------- /configurations/femnist_conv_fedprox_E20.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1, 4 | 
"verbose": 0 5 | }, 6 | "data_set_conf":{ 7 | "name": "femnist", 8 | "num_clients": 100, 9 | "shape": [1, 28, 28] 10 | }, 11 | "training_conf": {"method": "fedprox", 12 | "tot_epochs_per_client": 10000, 13 | "optimizer": "sgd", 14 | "tensorboard_updates": 1 15 | }, 16 | "model_conf": {"layers": [{"name": "Conv2DCentered", 17 | "input_shape": [1, 28, 28], 18 | "filters": 32, 19 | "kernel_size": 5, 20 | "padding": "SAME", 21 | "activation": "relu" 22 | }, 23 | { "name": "MaxPooling2D", 24 | "pool_size": [2, 2], 25 | "strides": [2, 2], 26 | "padding": "SAME" 27 | }, 28 | {"name": "Conv2DCentered", 29 | "input_shape": [1, 28, 28], 30 | "filters": 64, 31 | "kernel_size": 5, 32 | "padding": "SAME", 33 | "activation": "relu" 34 | }, 35 | { "name": "MaxPooling2D", 36 | "pool_size": [2, 2], 37 | "strides": [2, 2], 38 | "padding": "SAME" 39 | }, 40 | { "name": "Flatten" 41 | }, 42 | { "name": "DenseCentered", 43 | "units": 100, 44 | "activation": "relu" 45 | }, 46 | { "name": "DenseCentered", 47 | "units": 100, 48 | "activation": "relu" 49 | }, 50 | { "name": "DenseCentered", 51 | "units": 10, 52 | "activation": "softmax" 53 | } 54 | ] 55 | }, 56 | "hp": {"epochs_per_round": [20], 57 | "clients_per_round": [10], 58 | "learning_rate": [5e-4, 1e-3, 2e-3, 5e-3, 1e-2], 59 | "batch_size": [20], 60 | "l2_reg": [0.0, 1e-5, 1e-4, 1e-3, 1e-2], 61 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1] 62 | } 63 | } -------------------------------------------------------------------------------- /configurations/femnist_conv_virtual_natural.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1, 4 | "verbose": 0 5 | }, 6 | "data_set_conf":{ 7 | "name": "femnist", 8 | "num_clients": 100, 9 | "shape": [1, 28, 28] 10 | }, 11 | "training_conf": {"method": "virtual", 12 | "tot_epochs_per_client": 500, 13 | "batch_size": 20, 14 | "optimizer": "sgd", 15 | "fed_avg_init": false, 16 | "tensorboard_updates": 1 17 | }, 18 | "model_conf": {"layers": [{"name": "Conv2DVirtualNatural", 19 | "input_shape": [1, 28, 28], 20 | "filters": 32, 21 | "kernel_size": 5, 22 | "padding": "SAME", 23 | "activation": "relu", 24 | "bias_posterior_fn": null 25 | }, 26 | { "name": "MaxPooling2D", 27 | "pool_size": [2, 2], 28 | "strides": [2, 2], 29 | "padding": "SAME" 30 | }, 31 | {"name": "Conv2DVirtualNatural", 32 | "input_shape": [1, 28, 28], 33 | "filters": 64, 34 | "kernel_size": 5, 35 | "padding": "SAME", 36 | "activation": "relu", 37 | "bias_posterior_fn": null 38 | }, 39 | { "name": "MaxPooling2D", 40 | "pool_size": [2, 2], 41 | "strides": [2, 2], 42 | "padding": "SAME" 43 | }, 44 | { "name": "Flatten" 45 | }, 46 | { "name": "DenseReparametrizationNaturalShared", 47 | "units": 100, 48 | "activation": "relu", 49 | "bias_posterior_fn": null 50 | }, 51 | { "name": "DenseReparametrizationNaturalShared", 52 | "units": 100, 53 | "activation": "relu", 54 | "bias_posterior_fn": null 55 | }, 56 | { "name": "DenseReparametrizationNaturalShared", 57 | "units": 10, 58 | "activation": "softmax", 59 | "bias_posterior_fn": null 60 | } 61 | ], 62 | "prior_scale": 1.0 63 | }, 64 | "hp": { 65 | "learning_rate": [0.0001], 66 | "natural_lr": [1e8, 2e8, 5e8, 1e9, 2e9], 67 | "kl_weight": [0.0, 1e-5, 1e-4, 1e-3, 1e-2], 68 | "hierarchical": [false], 69 | "scale_init": [[-4.45, 0.85]], 70 | "loc_init": [[0, 0.65]], 71 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0], 72 | "clients_per_round": [10], 73 | "epochs_per_round": [20, 100] 74 | } 75 | } 
-------------------------------------------------------------------------------- /configurations/femnist_fedprox.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1, 4 | "verbose": 0 5 | }, 6 | "data_set_conf":{ 7 | "name": "femnist", 8 | "num_clients": 100 9 | }, 10 | "training_conf": {"method": "fedprox", 11 | "tot_epochs_per_client": 500, 12 | "optimizer": "sgd", 13 | "tensorboard_updates": 1 14 | }, 15 | "model_conf": { "layers": [{ 16 | "input_shape": [784], 17 | "name": "DenseCentered", 18 | "units": 100, 19 | "activation": "relu", 20 | "use_bias": false 21 | }, 22 | { 23 | "name": "DenseCentered", 24 | "units": 100, 25 | "activation": "relu", 26 | "use_bias": false 27 | }, 28 | { 29 | "name": "DenseCentered", 30 | "units": 10, 31 | "activation": "softmax", 32 | "use_bias": false 33 | }] 34 | }, 35 | "hp": {"epochs_per_round": [20, 100], 36 | "learning_rate": [1e-3, 2e-3, 5e-3, 1e-2, 2e-2], 37 | "batch_size": [20], 38 | "l2_reg": [0, 1e-5, 1e-4, 1e-3, 1e-2], 39 | "clients_per_round": [10], 40 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 41 | } 42 | } -------------------------------------------------------------------------------- /configurations/femnist_virtual.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1, 4 | "verbose": 0 5 | }, 6 | "data_set_conf":{ 7 | "name": "femnist", 8 | "num_clients": 100 9 | }, 10 | "training_conf": {"method": "virtual", 11 | "tot_epochs_per_client": 500, 12 | "fed_avg_init": false, 13 | "tensorboard_updates": 1 14 | }, 15 | "model_conf": { 16 | "layers": [{ 17 | "input_shape": [784], 18 | "name": "DenseReparametrizationNaturalShared", 19 | "units": 100, 20 | "activation": "relu", 21 | "bias_posterior_fn": null 22 | }, 23 | { 24 | "name": "DenseReparametrizationNaturalShared", 25 | "units": 100, 26 | "activation": "relu", 27 | "bias_posterior_fn": null 28 | }, 29 | { 30 | "name": "DenseReparametrizationNaturalShared", 31 | "units": 10, 32 | "activation": "softmax", 33 | "bias_posterior_fn": null 34 | }], 35 | "prior_scale": 1.0 36 | }, 37 | "hp": {"epochs_per_round": [20, 100], 38 | "natural_lr": [1e8, 2e8, 5e8, 1e9, 2e9], 39 | "kl_weight": [0, 1e-5, 1e-4, 1e-3, 1e-2], 40 | "batch_size": [20], 41 | "hierarchical": [false], 42 | "clients_per_round": [10], 43 | "learning_rate": [0.001], 44 | "optimizer": ["sgd"], 45 | "scale_init": [[-4.85, 0.45]], 46 | "loc_init": [[0,0.5]], 47 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 48 | } 49 | } -------------------------------------------------------------------------------- /configurations/femnist_virtual_natural_sparse_delta.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1, 4 | "check_numerics": false 5 | }, 6 | "data_set_conf":{ 7 | "name": "femnist", 8 | "num_clients": 100 9 | }, 10 | "training_conf": {"method": "virtual", 11 | "tot_epochs_per_client": 500, 12 | "tensorboard_updates": 1, 13 | "verbose": 0 14 | }, 15 | "model_conf": { 16 | "layers": [{ 17 | "input_shape": [784], 18 | "name": "DenseLocalReparametrizationNaturalShared", 19 | "units": 100, 20 | "activation": "relu", 21 | "bias_posterior_fn": null 22 | }, 23 | { 24 | "name": "DenseLocalReparametrizationNaturalShared", 25 | "units": 100, 26 | "activation": "relu", 27 | "bias_posterior_fn": null 28 | }, 29 | { 30 | "name": "DenseLocalReparametrizationNaturalShared", 31 | "units": 10, 32 | "activation": 
"softmax", 33 | "bias_posterior_fn": null 34 | }], 35 | "prior_scale": 1.0 36 | }, 37 | "hp": {"epochs_per_round": [20], 38 | "kl_weight": [0.000001], 39 | "batch_size": [20], 40 | "hierarchical": [false], 41 | "clients_per_round": [10], 42 | "learning_rate": [0.01], 43 | "natural_lr": [2e6, 5e6, 1e7, 2e7, 5e7], 44 | "server_learning_rate": [0.2], 45 | "optimizer": ["sgd"], 46 | "scale_init": [[-4.85, 0.45]], 47 | "loc_init": [[0.0, 0.5]], 48 | "fed_avg_init": [0, 1], 49 | "delta_percentile": [0, 50, 75, 90, 95] 50 | } 51 | } -------------------------------------------------------------------------------- /configurations/har_fedprox.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "human_activity", 7 | "num_clients": -1 8 | }, 9 | "training_conf": {"method": "fedprox", 10 | "tot_epochs_per_client": 200, 11 | "optimizer": "sgd", 12 | "tensorboard_updates": 1 13 | }, 14 | "model_conf": {"layers": [{ 15 | "input_shape": [561], 16 | "name": "DenseCentered", 17 | "units": 100, 18 | "activation": "relu", 19 | "use_bias": false 20 | }, 21 | { 22 | "name": "DenseCentered", 23 | "units": 100, 24 | "activation": "relu", 25 | "use_bias": false 26 | }, 27 | { 28 | "name": "DenseCentered", 29 | "units": 12, 30 | "activation": "softmax", 31 | "use_bias": false 32 | }] 33 | }, 34 | "hp": {"epochs_per_round": [20], 35 | "learning_rate": [1e-3, 2e-3, 5e-3, 1e-2, 2e-2], 36 | "batch_size": [20], 37 | "l2_reg": [0, 1e-5, 1e-4, 1e-3, 1e-2], 38 | "clients_per_round": [10], 39 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 40 | } 41 | } -------------------------------------------------------------------------------- /configurations/har_virtual.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "human_activity", 7 | "num_clients": -1 8 | }, 9 | "training_conf": {"method": "virtual", 10 | "tot_epochs_per_client": 200, 11 | "fed_avg_init": false, 12 | "tensorboard_updates": 1 13 | }, 14 | "model_conf": {"layers": [{ 15 | "input_shape": [561], 16 | "name": "DenseReparametrizationNaturalShared", 17 | "units": 100, 18 | "activation": "relu", 19 | "bias_posterior_fn": null 20 | }, 21 | { 22 | "name": "DenseReparametrizationNaturalShared", 23 | "units": 100, 24 | "activation": "relu", 25 | "bias_posterior_fn": null 26 | }, 27 | { 28 | "name": "DenseReparametrizationNaturalShared", 29 | "units": 12, 30 | "activation": "softmax", 31 | "bias_posterior_fn": null 32 | }], 33 | "hierarchical": false, 34 | "prior_scale": 1.0 35 | }, 36 | "hp": {"epochs_per_round": [20], 37 | "natural_lr": [1e8, 2e8, 5e8, 1e9, 2e9], 38 | "kl_weight": [0, 1e-5, 1e-4, 1e-3, 1e-2], 39 | "batch_size": [20], 40 | "hierarchical": [false], 41 | "clients_per_round": [10], 42 | "learning_rate": [0.001], 43 | "optimizer": ["sgd"], 44 | "scale_init": [[-4.8, 0.45]], 45 | "loc_init": [[0.0, 0.3]], 46 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 47 | } 48 | } -------------------------------------------------------------------------------- /configurations/mnist_fedprox.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "mnist", 7 | "num_clients": 100 8 | }, 9 | "training_conf": {"method": "fedprox", 10 | "tot_epochs_per_client": 500, 11 | "optimizer": "sgd", 12 | "tensorboard_updates": 1 13 | }, 14 | 
"model_conf": { 15 | "layers": [ 16 | { 17 | "input_shape": [ 18 | 784 19 | ], 20 | "name": "DenseCentered", 21 | "units": 100, 22 | "activation": "relu", 23 | "use_bias": false 24 | }, 25 | { 26 | "name": "DenseCentered", 27 | "units": 100, 28 | "activation": "relu", 29 | "use_bias": false 30 | }, 31 | { 32 | "name": "DenseCentered", 33 | "units": 10, 34 | "activation": "softmax", 35 | "use_bias": false 36 | } 37 | ] 38 | }, 39 | "hp": {"epochs_per_round": [20], 40 | "learning_rate": [1e-3, 2e-3, 5e-3, 1e-2, 2e-2], 41 | "batch_size": [20], 42 | "l2_reg": [0, 1e-5, 1e-4, 1e-3, 1e-2], 43 | "clients_per_round": [10], 44 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 45 | } 46 | } -------------------------------------------------------------------------------- /configurations/mnist_virtual.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "mnist", 7 | "num_clients": 100 8 | }, 9 | "training_conf": {"method": "virtual", 10 | "tot_epochs_per_client": 500, 11 | "fed_avg_init": false, 12 | "tensorboard_updates": 1 13 | }, 14 | "model_conf": {"layers": [{ 15 | "input_shape": [784], 16 | "name": "DenseReparametrizationNaturalShared", 17 | "units": 100, 18 | "activation": "relu", 19 | "bias_posterior_fn": null 20 | }, 21 | { 22 | "name": "DenseReparametrizationNaturalShared", 23 | "units": 100, 24 | "activation": "relu", 25 | "bias_posterior_fn": null 26 | }, 27 | { 28 | "name": "DenseReparametrizationNaturalShared", 29 | "units": 10, 30 | "activation": "softmax", 31 | "bias_posterior_fn": null 32 | }], 33 | "hierarchical": false, 34 | "prior_scale": 1.0 35 | }, 36 | "hp": {"epochs_per_round": [20], 37 | "natural_lr": [1e8, 2e8, 5e8, 1e9, 2e9], 38 | "kl_weight": [0, 1e-5, 1e-4, 1e-3, 1e-2], 39 | "batch_size": [20], 40 | "hierarchical": [false], 41 | "clients_per_round": [10], 42 | "learning_rate": [0.001], 43 | "optimizer": ["sgd"], 44 | "scale_init": [[-4.85, 0.45]], 45 | "loc_init": [[0,0.5]], 46 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 47 | } 48 | } -------------------------------------------------------------------------------- /configurations/pmnist_fedprox.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "pmnist", 7 | "num_clients": 100 8 | }, 9 | "training_conf": {"method": "fedprox", 10 | "tot_epochs_per_client": 500, 11 | "optimizer": "sgd", 12 | "tensorboard_updates": 1 13 | }, 14 | "model_conf": { 15 | "layers": [ 16 | { 17 | "input_shape": [ 18 | 784 19 | ], 20 | "name": "DenseCentered", 21 | "units": 100, 22 | "activation": "relu", 23 | "use_bias": false 24 | }, 25 | { 26 | "name": "DenseCentered", 27 | "units": 100, 28 | "activation": "relu", 29 | "use_bias": false 30 | }, 31 | { 32 | "name": "DenseCentered", 33 | "units": 10, 34 | "activation": "softmax", 35 | "use_bias": false 36 | } 37 | ] 38 | }, 39 | "hp": {"epochs_per_round": [20, 100], 40 | "learning_rate": [1e-3, 2e-3, 5e-3, 1e-2, 2e-2], 41 | "batch_size": [20], 42 | "l2_reg": [0, 1e-5, 1e-4, 1e-3, 1e-2], 43 | "clients_per_round": [10], 44 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 45 | } 46 | } -------------------------------------------------------------------------------- /configurations/pmnist_virtual.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": 
"pmnist", 7 | "num_clients": 100 8 | }, 9 | "training_conf": {"method": "virtual", 10 | "tot_epochs_per_client": 500, 11 | "fed_avg_init": false, 12 | "tensorboard_updates": 1 13 | }, 14 | "model_conf": {"layers": [{ 15 | "input_shape": [784], 16 | "name": "DenseReparametrizationNaturalShared", 17 | "units": 100, 18 | "activation": "relu", 19 | "bias_posterior_fn": null 20 | }, 21 | { 22 | "name": "DenseReparametrizationNaturalShared", 23 | "units": 100, 24 | "activation": "relu", 25 | "bias_posterior_fn": null 26 | }, 27 | { 28 | "name": "DenseReparametrizationNaturalShared", 29 | "units": 10, 30 | "activation": "softmax", 31 | "bias_posterior_fn": null 32 | }], 33 | "hierarchical": false, 34 | "prior_scale": 1.0 35 | }, 36 | "hp": {"epochs_per_round": [20], 37 | "natural_lr": [1e8, 2e8, 5e8, 1e9, 2e9], 38 | "kl_weight": [0, 1e-5, 1e-4, 1e-3, 1e-2], 39 | "batch_size": [20], 40 | "hierarchical": [false], 41 | "clients_per_round": [10], 42 | "learning_rate": [0.001], 43 | "optimizer": ["sgd"], 44 | "scale_init": [[-4.85, 0.45]], 45 | "loc_init": [[0,0.5]], 46 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 47 | } 48 | } -------------------------------------------------------------------------------- /configurations/shakespeare_fedprox.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "shakespeare", 7 | "num_clients": 1, 8 | "seq_length": 80, 9 | "vocab_size": 86 10 | }, 11 | "training_conf": {"method": "fedprox", 12 | "tot_epochs_per_client": 300, 13 | "optimizer": "sgd", 14 | "fed_avg_init": false, 15 | "tensorboard_updates": 1, 16 | "verbose": 0 17 | }, 18 | "model_conf": {"layers": [{"name": "EmbeddingCentered", 19 | "input_dim": 87, 20 | "output_dim": 8 21 | }, 22 | {"name": "LSTMCellCentered", 23 | "units": 256, 24 | "use_bias": false 25 | }, 26 | {"name": "LSTMCellCentered", 27 | "units": 256, 28 | "use_bias": false 29 | }, 30 | {"name": "DenseCentered", 31 | "units": 87, 32 | "activation": "softmax", 33 | "use_bias": false 34 | } 35 | ], 36 | "architecture": "rnn" 37 | }, 38 | "hp": {"epochs_per_round": [20], 39 | "learning_rate": [1e-1, 2e-1, 5e-1, 1, 2], 40 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0], 41 | "batch_size": [10], 42 | "l2_reg": [0.0, 1e-5, 1e-4, 1e-3, 1e-2], 43 | "clients_per_round": [10] 44 | } 45 | } -------------------------------------------------------------------------------- /configurations/shakespeare_virtual_natural.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "shakespeare", 7 | "num_clients": 1, 8 | "seq_length": 80, 9 | "vocab_size": 86 10 | }, 11 | "training_conf": {"method": "virtual", 12 | "tot_epochs_per_client": 300, 13 | "optimizer": "sgd", 14 | "fed_avg_init": false, 15 | "tensorboard_updates": 1, 16 | "verbose": 0 17 | 18 | }, 19 | "model_conf": {"layers": [{"name": "EmbeddingCentered", 20 | "input_dim": 87, 21 | "output_dim": 8 22 | }, 23 | {"name": "LSTMCellVariationalNatural", 24 | "units": 256, 25 | "use_bias": false 26 | }, 27 | {"name": "LSTMCellVariationalNatural", 28 | "units": 256, 29 | "use_bias": false 30 | }, 31 | {"name": "DenseLocalReparametrizationNaturalShared", 32 | "units": 87, 33 | "activation": "softmax", 34 | "bias_posterior_fn": null 35 | } 36 | ], 37 | "prior_scale": 1.0, 38 | "architecture": "rnn" 39 | }, 40 | "hp": {"epochs_per_round": [20], 41 | "natural_lr": [1e9, 2e9, 5e9, 
1e10, 2e10], 42 | "kl_weight": [0, 1e-5, 1e-4, 1e-3, 1e-2], 43 | "batch_size": [20], 44 | "hierarchical": [false], 45 | "clients_per_round": [10], 46 | "learning_rate": [10.0], 47 | "optimizer": ["sgd"], 48 | "scale_init": [[-4.85, 0.45]], 49 | "loc_init": [[0,0.5]], 50 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 51 | } 52 | } -------------------------------------------------------------------------------- /configurations/vsn_fedprox.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "vehicle_sensor", 7 | "num_clients": -1 8 | }, 9 | "training_conf": {"method": "fedprox", 10 | "tot_epochs_per_client": 200, 11 | "optimizer": "sgd", 12 | "tensorboard_updates": 1 13 | }, 14 | "model_conf": { 15 | "layers": [ 16 | { 17 | "input_shape": [ 18 | 100 19 | ], 20 | "name": "DenseCentered", 21 | "units": 100, 22 | "activation": "relu", 23 | "use_bias": false 24 | }, 25 | { 26 | "name": "DenseCentered", 27 | "units": 100, 28 | "activation": "relu", 29 | "use_bias": false 30 | }, 31 | { 32 | "name": "DenseCentered", 33 | "units": 2, 34 | "activation": "softmax", 35 | "use_bias": false 36 | } 37 | ] 38 | }, 39 | "hp": {"epochs_per_round": [20, 100], 40 | "learning_rate": [1e-3, 2e-3, 5e-3, 1e-2, 2e-2], 41 | "batch_size": [20], 42 | "l2_reg": [0, 1e-5, 1e-4, 1e-3, 1e-2], 43 | "clients_per_round": [10], 44 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 45 | } 46 | } -------------------------------------------------------------------------------- /configurations/vsn_virtual.json: -------------------------------------------------------------------------------- 1 | { 2 | "session":{ 3 | "num_gpus": 1 4 | }, 5 | "data_set_conf":{ 6 | "name": "vehicle_sensor", 7 | "num_clients": -1 8 | }, 9 | "training_conf": {"method": "virtual", 10 | "tot_epochs_per_client": 200, 11 | "fed_avg_init": false, 12 | "tensorboard_updates": 1 13 | }, 14 | "model_conf": {"layers": [{ 15 | "input_shape": [100], 16 | "name": "DenseReparametrizationNaturalShared", 17 | "units": 100, 18 | "activation": "relu", 19 | "bias_posterior_fn": null 20 | }, 21 | { 22 | "name": "DenseReparametrizationNaturalShared", 23 | "units": 100, 24 | "activation": "relu", 25 | "bias_posterior_fn": null 26 | }, 27 | { 28 | "name": "DenseReparametrizationNaturalShared", 29 | "units": 2, 30 | "activation": "softmax", 31 | "bias_posterior_fn": null 32 | }], 33 | "hierarchical": false, 34 | "prior_scale": 1.0 35 | }, 36 | "hp": {"epochs_per_round": [20], 37 | "natural_lr": [1e8, 2e8, 5e8, 1e9, 2e9], 38 | "kl_weight": [0, 1e-5, 1e-4, 1e-3, 1e-2], 39 | "batch_size": [20], 40 | "hierarchical": [false], 41 | "clients_per_round": [10], 42 | "learning_rate": [0.001], 43 | "optimizer": ["sgd"], 44 | "scale_init": [[-4.8, 0.45]], 45 | "loc_init": [[0.0, 0.3]], 46 | "server_learning_rate": [0.2, 0.4, 0.6, 0.8, 1.0] 47 | } 48 | } -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: virtual 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=0_gnu 8 | - ca-certificates=2020.4.5.1=hecc5488_0 9 | - certifi=2020.4.5.1=py37hc8dfbb8_0 10 | - ld_impl_linux-64=2.33.1=h53a641e_7 11 | - libffi=3.2.1=he1b5a44_1006 12 | - libgcc-ng=9.2.0=h24d8f2e_2 13 | - libgomp=9.2.0=h24d8f2e_2 14 | - libstdcxx-ng=9.2.0=hdf63c60_2 15 | - ncurses=6.1=hf484d3e_1002 
16 | - openssl=1.1.1f=h516909a_0 17 | - pip=19.3.1=py37_0 18 | - python=3.7.6=h357f687_2 19 | - python_abi=3.7=1_cp37m 20 | - readline=8.0=hf8c457e_0 21 | - sqlite=3.30.1=hcee41ef_0 22 | - tk=8.6.10=hed695b0_0 23 | - tqdm=4.45.0=pyh9f0ad1d_0 24 | - xz=5.2.4=h14c3975_1001 25 | - zlib=1.2.11=h516909a_1006 26 | - pip: 27 | - absl-py==0.9.0 28 | - astor==0.8.1 29 | - attrs==18.2.0 30 | - cachetools==4.0.0 31 | - chardet==3.0.4 32 | - cloudpickle==1.2.2 33 | - decorator==4.4.1 34 | - dm-tree==0.1.1 35 | - enum34==1.1.6 36 | - gast==0.2.2 37 | - gitdb==4.0.5 38 | - gitpython==3.1.2 39 | - google-auth==1.10.1 40 | - google-auth-oauthlib==0.4.1 41 | - google-pasta==0.1.8 42 | - gpustat==0.6.0 43 | - gputil==1.4.0 44 | - grpcio==1.26.0 45 | - gviz-api==1.9.0 46 | - h5py==2.10.0 47 | - idna==2.8 48 | - joblib==0.14.1 49 | - keras==2.1.0 50 | - keras-applications==1.0.8 51 | - keras-preprocessing==1.1.0 52 | - markdown==3.1.1 53 | - mpmath==1.1.0 54 | - numpy==1.18.1 55 | - oauthlib==3.1.0 56 | - opt-einsum==3.1.0 57 | - pandas==1.0.0 58 | - portpicker==1.3.1 59 | - protobuf==3.11.2 60 | - psutil==5.6.7 61 | - pyasn1==0.4.8 62 | - pyasn1-modules==0.2.8 63 | - python-dateutil==2.8.1 64 | - pytz==2019.3 65 | - pyyaml==5.3.1 66 | - requests==2.22.0 67 | - requests-oauthlib==1.3.0 68 | - retrying==1.3.3 69 | - rsa==4.0 70 | - scikit-learn==0.22.1 71 | - scipy==1.4.1 72 | - setuptools==45.0.0 73 | - six==1.14.0 74 | - smmap==3.0.4 75 | - tensorboard==2.3.0 76 | - tensorboard-plugin-profile==2.3.0 77 | - tensorboard-plugin-wit==1.7.0 78 | - tensorflow==2.1.0 79 | - tensorflow-addons==0.6.0 80 | - tensorflow-estimator==2.1.0 81 | - tensorflow-federated==0.11.0 82 | - tensorflow-gpu==2.0.0 83 | - tensorflow-model-optimization==0.1.3 84 | - tensorflow-privacy==0.2.2 85 | - tensorflow-probability==0.9.0 86 | - termcolor==1.1.0 87 | - urllib3==1.25.7 88 | - werkzeug==0.16.0 89 | - wheel==0.33.6 90 | - wrapt==1.11.2 91 | 92 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from os import environ as os_environ 2 | import sys 3 | from git import Repo 4 | import subprocess 5 | from shutil import copytree 6 | import datetime 7 | from itertools import product 8 | from pathlib import Path 9 | import argparse 10 | import logging 11 | import tensorflow as tf 12 | from tensorboard.plugins.hparams import api as hp 13 | import json 14 | import gc 15 | 16 | from source.data_utils import federated_dataset, batch_dataset 17 | from source.utils import gpu_session 18 | from source.experiment_utils import (run_simulation, 19 | get_compiled_model_fn_from_dict) 20 | from source.constants import ROOT_LOGGER_STR, LOGGER_RESULT_FILE 21 | 22 | 23 | logger = logging.getLogger(ROOT_LOGGER_STR + '.' 
+ __name__) 24 | 25 | 26 | def _setup_logger(results_path, create_stdlog): 27 | """Setup a general logger which saves all logs in the experiment folder""" 28 | 29 | f_format = logging.Formatter( 30 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 31 | f_handler = logging.FileHandler(str(results_path)) 32 | f_handler.setLevel(logging.DEBUG) 33 | f_handler.setFormatter(f_format) 34 | 35 | root_logger = logging.getLogger(ROOT_LOGGER_STR) 36 | root_logger.handlers = [] 37 | root_logger.setLevel(logging.DEBUG) 38 | root_logger.addHandler(f_handler) 39 | 40 | if create_stdlog: 41 | handler = logging.StreamHandler(sys.stdout) 42 | handler.setLevel(logging.DEBUG) 43 | root_logger.addHandler(handler) 44 | 45 | 46 | def create_hparams(hp_conf, data_set_conf, training_conf, 47 | model_conf, logdir): 48 | exp_wide_keys = ["learning_rate", "l2_reg", "kl_weight", "batch_size", 49 | "epochs_per_round", "hierarchical", "prior_scale", "natural_lr", 50 | "server_learning_rate", "damping_factor"] 51 | HP_DICT = {} 52 | for key_0 in list(hp_conf.keys()) + exp_wide_keys: 53 | if (key_0 == 'learning_rate' 54 | or key_0 == 'kl_weight' 55 | or key_0 == 'l2_reg' 56 | or key_0 == 'server_learning_rate' 57 | or key_0 == 'natural_lr' 58 | or key_0 == 'damping_factor' 59 | or key_0 == 'delta_percentile'): 60 | HP_DICT[key_0] = hp.HParam(key_0, hp.RealInterval(0.0, 1e20)) 61 | elif key_0 == 'batch_size': 62 | HP_DICT[key_0] = hp.HParam(key_0, hp.Discrete([1, 5, 10, 20, 40, 50, 63 | 64, 128, 256, 512])) 64 | elif key_0 == 'epochs_per_round': 65 | HP_DICT[key_0] = hp.HParam(key_0, hp.Discrete([1, 2, 5, 10, 15, 20, 66 | 25, 30, 35, 40, 67 | 45, 50, 55, 60, 68 | 65, 70, 75, 80, 69 | 85, 90, 95, 100, 70 | 110, 120, 130, 71 | 140, 150])) 72 | elif key_0 == 'clients_per_round': 73 | HP_DICT[key_0] = hp.HParam(key_0, hp.Discrete([1, 2, 3, 4, 5, 74 | 10, 15, 20, 50])) 75 | elif key_0 == 'method': 76 | HP_DICT[key_0] = hp.HParam(key_0, hp.Discrete(['virtual', 77 | 'fedprox'])) 78 | elif key_0 == 'hierarchical': 79 | HP_DICT[key_0] = hp.HParam(key_0, hp.Discrete([True, False])) 80 | else: 81 | HP_DICT[key_0] = hp.HParam(key_0) 82 | for key, _ in data_set_conf.items(): 83 | if key == 'name': 84 | HP_DICT[f'data_{key}'] = hp.HParam(f'data_{key}', 85 | hp.Discrete(['mnist', 86 | 'pmnist', 87 | 'femnist', 88 | 'shakespeare', 89 | 'human_activity', 90 | 'vehicle_sensor'])) 91 | else: 92 | HP_DICT[f'data_{key}'] = hp.HParam(f'data_{key}') 93 | for key, _ in training_conf.items(): 94 | HP_DICT[f'training_{key}'] = hp.HParam(f'training_{key}') 95 | for key, _ in model_conf.items(): 96 | HP_DICT[f'model_{key}'] = hp.HParam(f'model_{key}') 97 | HP_DICT['run'] = hp.HParam('run') 98 | HP_DICT['config_name'] = hp.HParam('config_name') 99 | HP_DICT['training_num_rounds'] = hp.HParam('num_rounds', 100 | hp.RealInterval(0.0, 1e10)) 101 | 102 | metrics = [hp.Metric('train/sparse_categorical_accuracy', 103 | display_name='train_accuracy'), 104 | hp.Metric('train/max_sparse_categorical_accuracy', 105 | display_name='train_max_accuracy'), 106 | hp.Metric('client_all/sparse_categorical_accuracy', 107 | display_name='client_all_accuracy'), 108 | hp.Metric('client_all/max_sparse_categorical_accuracy', 109 | display_name='max_client_all_accuracy'), 110 | hp.Metric('server/sparse_categorical_accuracy', 111 | display_name='server_accuracy'), 112 | hp.Metric('server/max_sparse_categorical_accuracy', 113 | display_name='max_server_accuracy'), 114 | hp.Metric('client_selected/sparse_categorical_accuracy', 115 | 
display_name='client_selected_accuracy'), 116 | hp.Metric('client_selected/max_sparse_categorical_accuracy', 117 | display_name='max_client_selected_max_accuracy') 118 | ] 119 | 120 | print('create_hp_fun ' + str(logdir)) 121 | 122 | with tf.summary.create_file_writer(str(logdir)).as_default(): 123 | hp.hparams_config(hparams=HP_DICT.values(), 124 | metrics=metrics) 125 | return HP_DICT 126 | 127 | 128 | def write_hparams(hp_dict, session_num, exp_conf, data_set_conf, 129 | training_conf, model_conf, logdir_run, config_name): 130 | 131 | hparams = {'run': int(session_num), 'config_name': config_name} 132 | for key_0, value_0 in exp_conf.items(): 133 | if isinstance(value_0, list): 134 | hparams[hp_dict[key_0]] = str(value_0) 135 | else: 136 | hparams[hp_dict[key_0]] = value_0 137 | for key_1, value_1 in data_set_conf.items(): 138 | hparams[hp_dict[f'data_{key_1}']] = str(value_1) 139 | for key_2, value_2 in training_conf.items(): 140 | if key_2 == 'num_rounds': 141 | hparams[hp_dict[f'training_{key_2}']] = value_2 142 | else: 143 | hparams[hp_dict[f'training_{key_2}']] = str(value_2) 144 | for key_3, value_3 in model_conf.items(): 145 | if key_3 == 'layers': 146 | continue 147 | hparams[hp_dict[f'model_{key_3}']] = str(value_3) 148 | 149 | # Only concatenation of the name of the layers 150 | layers = '' 151 | for layer in model_conf['layers']: 152 | layers = layers + layer['name'] + '_' 153 | hparams[hp_dict['model_layers']] = layers[:-1] 154 | 155 | print('write_hp_fun ' + str(logdir_run)) 156 | 157 | with tf.summary.create_file_writer(str(logdir_run)).as_default(): 158 | hp.hparams(hparams) 159 | 160 | 161 | def _gridsearch(hp_conf): 162 | keys, values = zip(*hp_conf.items()) 163 | experiments = [dict(zip(keys, v)) for v in product(*values)] 164 | return experiments 165 | 166 | 167 | def submit_jobs(configs, root_path, data_dir, hour=12, mem=8000, 168 | use_scratch=False, reps=1): 169 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 170 | config_dir = root_path / f'temp_configs_{current_time}' 171 | config_dir.mkdir(parents=True, exist_ok=True) 172 | 173 | lsf_out_dir = root_path/ 'outs' 174 | lsf_out_dir.mkdir(parents=True, exist_ok=True) 175 | 176 | hp_conf = configs['hp'] 177 | experiments = _gridsearch(hp_conf) 178 | new_config = configs.copy() 179 | for session_num, exp_conf in enumerate(experiments): 180 | for rep in range(reps): 181 | for key, value in exp_conf.items(): 182 | new_config['hp'][key] = [value] 183 | 184 | name = f"{configs['config_name']}_" \ 185 | f"g{current_time}_" \ 186 | f"{session_num}_" \ 187 | f"{rep}" 188 | 189 | # Save the new config file 190 | config_path = config_dir / f"{name}.json" 191 | with config_path.open(mode='w') as config_file: 192 | json.dump(new_config, config_file) 193 | 194 | # Run training with the new config file 195 | command = (f"bsub -n 2 -W {hour}:00 " 196 | f"-R rusage[mem={mem},scratch=80000," 197 | f"ngpus_excl_p=1] " 198 | f"-o {lsf_out_dir / name} " 199 | f"python main.py --result_dir {root_path} " 200 | f"--data_dir {data_dir} " 201 | f"{'--scratch ' if use_scratch else ''}" 202 | f"{config_path}") 203 | subprocess.check_output(command.split()) 204 | 205 | 206 | def run_experiments(configs, root_path, data_dir=None, use_scratch=False): 207 | print('enter experiment function') 208 | if use_scratch: 209 | dir_name = data_dir.name 210 | temp_dir = Path(os_environ['TMPDIR']) / dir_name 211 | logger.info(f"Copying datafiles to the scratch folder {temp_dir}") 212 | copytree(str(data_dir), str(temp_dir)) 213 | data_dir 
= temp_dir 214 | 215 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 216 | print(current_time) 217 | print(configs) 218 | # Configs 219 | data_set_conf = configs['data_set_conf'] 220 | training_conf = configs['training_conf'] 221 | model_conf = configs['model_conf'] 222 | hp_conf = configs['hp'] 223 | session_conf = configs['session'] 224 | if 'input_shape' in model_conf: 225 | model_conf['input_shape'] = tuple(model_conf['input_shape']) 226 | 227 | logdir = root_path / 'logs' / f'{configs["config_name"]}_' \ 228 | f'e{current_time}' 229 | logdir.mkdir(parents=True) 230 | logfile = logdir / LOGGER_RESULT_FILE 231 | _setup_logger(logfile, create_stdlog=True) 232 | commit_id = Repo(Path().absolute()).head.commit 233 | logger.debug(f"Running code on git commit {commit_id}") 234 | 235 | logger.debug(f"Loading dataset") 236 | fede_train_data, fed_test_data, train_size, test_size = federated_dataset( 237 | data_set_conf, data_dir) 238 | logger.debug(f"Dataset loaded") 239 | num_clients = len(fede_train_data) 240 | model_conf['num_clients'] = num_clients 241 | 242 | logger.debug(f"Making grid of hyperparameters") 243 | experiments = _gridsearch(hp_conf) 244 | logger.debug(f"Grid done") 245 | for session_num, exp_conf in enumerate(experiments): 246 | all_params = {**data_set_conf, 247 | **training_conf, 248 | **model_conf, 249 | **exp_conf, 250 | **session_conf} 251 | training_conf['num_rounds'] = \ 252 | int(all_params['tot_epochs_per_client'] * all_params['num_clients'] 253 | / (all_params['clients_per_round'] 254 | * all_params['epochs_per_round'])) 255 | all_params['num_rounds'] = training_conf['num_rounds'] 256 | if all_params.pop('check_numerics', None): 257 | print('enabled check numerics') 258 | tf.debugging.enable_check_numerics() 259 | else: 260 | print('no debugging') 261 | 262 | if 'damping_factor' not in all_params and all_params['method'] == 'virtual': 263 | all_params['damping_factor'] = 1-all_params['server_learning_rate'] 264 | hp_conf['damping_factor'] = all_params['damping_factor'] 265 | exp_conf['damping_factor'] = all_params['damping_factor'] 266 | 267 | print(all_params) 268 | # Log configurations 269 | logdir_run = logdir / f'{session_num}_{current_time}' 270 | logger.info(f"saving results in {logdir_run}") 271 | HP_DICT = create_hparams(hp_conf, data_set_conf, training_conf, 272 | model_conf, logdir_run) 273 | 274 | write_hparams(HP_DICT, session_num, exp_conf, data_set_conf, 275 | training_conf, model_conf, logdir_run, configs[ 276 | 'config_name']) 277 | 278 | with open(logdir_run / 'config.json', 'w') as config_file: 279 | json.dump(configs, config_file, indent=4) 280 | 281 | # Prepare dataset 282 | logger.debug(f'batching datasets') 283 | seq_length = data_set_conf.get('seq_length', None) 284 | federated_train_data_batched = [ 285 | batch_dataset(data, all_params['batch_size'], 286 | padding=data_set_conf['name'] == 'shakespeare', 287 | seq_length=seq_length) 288 | for data in fede_train_data] 289 | federated_test_data_batched = [ 290 | batch_dataset(data, all_params['batch_size'], 291 | padding=data_set_conf['name'] == 'shakespeare', 292 | seq_length=seq_length) 293 | for data in fed_test_data] 294 | 295 | sample_batch = tf.nest.map_structure( 296 | lambda x: x.numpy(), iter(federated_train_data_batched[0]).next()) 297 | 298 | # Run the experiment 299 | logger.info(f'Starting run {session_num} ' 300 | f'with parameters {all_params}...') 301 | model_fn = get_compiled_model_fn_from_dict(all_params, sample_batch) 302 | run_simulation(model_fn, 
federated_train_data_batched, 303 | federated_test_data_batched, train_size, test_size, 304 | all_params, logdir_run) 305 | tf.keras.backend.clear_session() 306 | gc.collect() 307 | 308 | logger.info("Finished experiment successfully") 309 | 310 | 311 | def main(): 312 | # Parse arguments 313 | 314 | parser = argparse.ArgumentParser() 315 | parser.add_argument("config_path", 316 | type=Path, 317 | help="Path to the main json config. " 318 | "Ex: 'configurations/femnist_virtual.json'") 319 | parser.add_argument("--result_dir", 320 | type=Path, 321 | help="Path in which results of training are/will be " 322 | "located") 323 | parser.add_argument("--data_dir", 324 | type=Path, 325 | default=Path('data'), 326 | help="Path in which data is located. This is " 327 | "required if run on Leonhard") 328 | parser.add_argument("--submit_leonhard", action='store_true', 329 | help="Whether to submit jobs to Leonhard for " 330 | "grid search") 331 | 332 | parser.add_argument("-s", "--scratch", action='store_true', 333 | help="Whether to first copy the dataset to the " 334 | "scratch storage of Leonhard. Do not use on " 335 | "other systems than Leonhard.") 336 | parser.add_argument("-m", "--memory", 337 | type=int, 338 | default=8500, 339 | help="Memory allocated for each Leonhard job. This " 340 | "will be ignored if Leonhard is not selected.") 341 | parser.add_argument("-t", "--time", 342 | type=int, 343 | default=24, 344 | help="Number of hours requested for the job on " 345 | "Leonhard. Virtual models usually require " 346 | "more time than this default value.") 347 | parser.add_argument("-r", "--repetitions", 348 | type=int, 349 | default=1, 350 | help="Number of repetitions to run the same " 351 | "experiment") 352 | 353 | args = parser.parse_args() 354 | # Read config files 355 | with args.config_path.absolute().open(mode='r') as config_file: 356 | configs = json.load(config_file) 357 | configs['config_name'] = args.config_path.name.\ 358 | replace(args.config_path.suffix, "") 359 | 360 | if not args.result_dir: 361 | args.result_dir = Path(__file__).parent.absolute() 362 | 363 | if args.scratch and args.data_dir == Path('data'): 364 | logger.warning("WARNING: You cannot use scratch while not on " 365 | "Leonhard. 
Make sure you understand what you are " 366 | "doing.") 367 | 368 | if args.submit_leonhard: 369 | submit_jobs(configs, args.result_dir, args.data_dir, 370 | hour=args.time, mem=args.memory, use_scratch=args.scratch, 371 | reps=args.repetitions) 372 | else: 373 | print('before entering run experiment function') 374 | gpu_session(configs['session']['num_gpus']) 375 | run_experiments(configs, args.result_dir, args.data_dir, args.scratch) 376 | 377 | 378 | if __name__ == "__main__": 379 | main() 380 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | GitPython==3.1.9 2 | pandas==1.1.2 3 | scikit-learn==0.23.2 4 | tensorflow==2.5.3 5 | tensorboard==2.2.2 6 | tensorflow-federated==0.12.0 7 | tensorflow-probability==0.9.0 8 | keras==2.1 9 | GPUtil==1.4.0 -------------------------------------------------------------------------------- /source/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /source/centered_layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.keras.engine.input_spec import InputSpec 3 | from tensorflow.python.framework import tensor_shape 4 | from tensorflow.python.framework import dtypes 5 | from tensorflow.python.keras import backend as K 6 | from tensorflow.python.keras import initializers 7 | from tensorflow.python.keras.layers.recurrent import _caching_device 8 | from tensorflow.python.keras.utils import tf_utils 9 | from tensorflow.python.eager import context 10 | from tensorflow.python.framework import ops 11 | from tensorflow.python.ops import nn_ops 12 | from tensorflow.python.keras.utils import conv_utils 13 | 14 | 15 | RECURRENT_DROPOUT_WARNING_MSG = ( 16 | 'RNN `implementation=2` is not supported when `recurrent_dropout` is set. 
' 17 | 'Using `implementation=1`.') 18 | 19 | 20 | @tf.keras.utils.register_keras_serializable(package='Custom') 21 | class CenteredL2Regularizer(tf.keras.regularizers.Regularizer): 22 | def __init__(self, l2=0.): 23 | self.l2 = l2 24 | self.center = None 25 | 26 | def __call__(self, x): 27 | return self.l2 * tf.math.reduce_sum(tf.math.square(x - self.center)) 28 | 29 | def get_config(self): 30 | return {'l2': float(self.l2), 'center': float(self.center.numpy())} 31 | 32 | 33 | class LayerCentered: 34 | 35 | def compute_delta(self): 36 | delta_dict = {} 37 | for key in self.client_variable_dict.keys(): 38 | delta_dict[key] = ( 39 | self.delta_function(self.client_variable_dict[key], 40 | self.client_center_variable_dict[key])) 41 | return delta_dict 42 | 43 | def renew_center(self, center_to_updated=True): 44 | if 'natural' in self.name or center_to_updated: 45 | for key in self.client_center_variable_dict.keys(): 46 | self.client_center_variable_dict[key].assign( 47 | self.client_variable_dict[key]) 48 | 49 | def apply_delta(self, delta): 50 | for key in self.server_variable_dict.keys(): 51 | add = self.apply_delta_function( 52 | self.server_variable_dict[key], delta[key]) 53 | self.server_variable_dict[key].assign(add) 54 | self.client_variable_dict[key].assign(add) 55 | 56 | def receive_and_save_weights(self, layer_server): 57 | for key in self.server_variable_dict.keys(): 58 | self.server_variable_dict[key].assign( 59 | layer_server.server_variable_dict[key]) 60 | 61 | 62 | class DenseCentered(tf.keras.layers.Dense, LayerCentered): 63 | 64 | def __init__(self, 65 | units, 66 | activation=None, 67 | use_bias=True, 68 | kernel_initializer='glorot_uniform', 69 | bias_initializer='zeros', 70 | kernel_regularizer=None, 71 | bias_regularizer=None, 72 | activity_regularizer=None, 73 | kernel_constraint=None, 74 | bias_constraint=None, 75 | **kwargs): 76 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 77 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 78 | 79 | super(DenseCentered, self).__init__( 80 | units, 81 | activation=activation, 82 | use_bias=use_bias, 83 | kernel_initializer=kernel_initializer, 84 | bias_initializer=bias_initializer, 85 | kernel_regularizer=None, 86 | bias_regularizer=None, 87 | activity_regularizer=activity_regularizer, 88 | kernel_constraint=kernel_constraint, 89 | bias_constraint=bias_constraint, 90 | **kwargs) 91 | 92 | self.kernel_regularizer = kernel_regularizer() 93 | self.bias_regularizer = bias_regularizer() 94 | self.delta_function = tf.subtract 95 | self.apply_delta_function = tf.add 96 | self.client_variable_dict = {} 97 | self.server_variable_dict = {} 98 | self.client_center_variable_dict = {} 99 | 100 | def build(self, input_shape): 101 | dtype = dtypes.as_dtype(self.dtype or K.floatx()) 102 | if not (dtype.is_floating or dtype.is_complex): 103 | raise TypeError( 104 | 'Unable to build `Dense` layer with non-floating point ' 105 | 'dtype %s' % (dtype,)) 106 | input_shape = tensor_shape.TensorShape(input_shape) 107 | if tensor_shape.dimension_value(input_shape[-1]) is None: 108 | raise ValueError('The last dimension of the inputs to `Dense` ' 109 | 'should be defined. 
Found `None`.') 110 | last_dim = tensor_shape.dimension_value(input_shape[-1]) 111 | self.input_spec = InputSpec(min_ndim=2, 112 | axes={-1: last_dim}) 113 | self.kernel_regularizer.center = self.add_weight( 114 | 'kernel_center', 115 | shape=[last_dim, self.units], 116 | initializer=tf.keras.initializers.constant(0.), 117 | dtype=self.dtype, 118 | trainable=False) 119 | self.kernel = self.add_weight('kernel', 120 | shape=[last_dim, self.units], 121 | initializer=self.kernel_initializer, 122 | regularizer=self.kernel_regularizer, 123 | constraint=self.kernel_constraint, 124 | dtype=self.dtype, 125 | trainable=True) 126 | if self.use_bias: 127 | self.bias_regularizer.center = self.add_weight( 128 | 'bias_center', 129 | shape=[self.units, ], 130 | initializer=tf.keras.initializers.constant(0.), 131 | dtype=self.dtype, 132 | trainable=False) 133 | 134 | self.bias = self.add_weight('bias', 135 | shape=[self.units, ], 136 | initializer=self.bias_initializer, 137 | regularizer=self.bias_regularizer, 138 | constraint=self.bias_constraint, 139 | dtype=self.dtype, 140 | trainable=True) 141 | else: 142 | self.bias = None 143 | 144 | self.client_variable_dict['kernel'] = self.kernel 145 | self.server_variable_dict['kernel'] = self.kernel 146 | self.client_center_variable_dict['kernel'] = \ 147 | self.kernel_regularizer.center 148 | 149 | if self.use_bias: 150 | self.client_variable_dict['bias'] = self.bias 151 | self.server_variable_dict['bias'] = self.bias 152 | self.client_center_variable_dict['bias'] = \ 153 | self.bias_regularizer.center 154 | 155 | self.built = True 156 | 157 | 158 | class LSTMCellCentered(tf.keras.layers.LSTMCell, LayerCentered): 159 | 160 | def __init__(self, 161 | units, 162 | activation='tanh', 163 | recurrent_activation='hard_sigmoid', 164 | use_bias=True, 165 | kernel_initializer='glorot_uniform', 166 | recurrent_initializer='orthogonal', 167 | bias_initializer='zeros', 168 | unit_forget_bias=True, 169 | kernel_regularizer=None, 170 | recurrent_regularizer=None, 171 | bias_regularizer=None, 172 | kernel_constraint=None, 173 | recurrent_constraint=None, 174 | bias_constraint=None, 175 | dropout=0., 176 | recurrent_dropout=0., 177 | implementation=1, 178 | **kwargs): 179 | super(LSTMCellCentered, self).__init__( 180 | units, 181 | activation=activation, 182 | recurrent_activation=recurrent_activation, 183 | use_bias=use_bias, 184 | kernel_initializer=kernel_initializer, 185 | recurrent_initializer=recurrent_initializer, 186 | bias_initializer=bias_initializer, 187 | unit_forget_bias=unit_forget_bias, 188 | kernel_regularizer=None, 189 | recurrent_regularizer=None, 190 | bias_regularizer=None, 191 | kernel_constraint=kernel_constraint, 192 | recurrent_constraint=recurrent_constraint, 193 | bias_constraint=bias_constraint, 194 | dropout=dropout, 195 | recurrent_dropout=recurrent_dropout, 196 | implementation=implementation, 197 | **kwargs) 198 | 199 | self.kernel_regularizer = kernel_regularizer() 200 | self.recurrent_regularizer = recurrent_regularizer() 201 | self.bias_regularizer = bias_regularizer() 202 | self.delta_function = tf.subtract 203 | self.apply_delta_function = tf.add 204 | self.client_variable_dict = {} 205 | self.server_variable_dict = {} 206 | self.client_center_variable_dict = {} 207 | 208 | @tf_utils.shape_type_conversion 209 | def build(self, input_shape): 210 | default_caching_device = _caching_device(self) 211 | input_dim = input_shape[-1] 212 | self.kernel_regularizer.center = self.add_weight( 213 | 'kernel_center', 214 | shape=(input_dim, self.units * 
4), 215 | initializer=tf.keras.initializers.constant(0.), 216 | dtype=self.dtype, 217 | trainable=False) 218 | 219 | self.recurrent_regularizer.center = self.add_weight( 220 | 'recurrent_kernel_center', 221 | shape=(self.units, self.units * 4), 222 | initializer=tf.keras.initializers.constant(0.), 223 | dtype=self.dtype, 224 | trainable=False) 225 | 226 | self.kernel = self.add_weight( 227 | shape=(input_dim, self.units * 4), 228 | name='kernel', 229 | initializer=self.kernel_initializer, 230 | regularizer=self.kernel_regularizer, 231 | constraint=self.kernel_constraint, 232 | caching_device=default_caching_device) 233 | self.recurrent_kernel = self.add_weight( 234 | shape=(self.units, self.units * 4), 235 | name='recurrent_kernel', 236 | initializer=self.recurrent_initializer, 237 | regularizer=self.recurrent_regularizer, 238 | constraint=self.recurrent_constraint, 239 | caching_device=default_caching_device) 240 | 241 | if self.use_bias: 242 | if self.unit_forget_bias: 243 | 244 | def bias_initializer(_, *args, **kwargs): 245 | return K.concatenate([ 246 | self.bias_initializer((self.units,), *args, **kwargs), 247 | initializers.Ones()((self.units,), *args, **kwargs), 248 | self.bias_initializer( 249 | (self.units * 2,), *args, **kwargs), 250 | ]) 251 | else: 252 | bias_initializer = self.bias_initializer 253 | self.bias_regularizer.center = self.add_weight( 254 | 'bias_center', 255 | shape=(self.units * 4,), 256 | initializer=tf.keras.initializers.constant(0.), 257 | dtype=self.dtype, 258 | trainable=False) 259 | 260 | self.bias = self.add_weight( 261 | shape=(self.units * 4,), 262 | name='bias', 263 | initializer=bias_initializer, 264 | regularizer=self.bias_regularizer, 265 | constraint=self.bias_constraint, 266 | caching_device=default_caching_device) 267 | else: 268 | self.bias = None 269 | 270 | self.client_variable_dict['kernel'] = self.kernel 271 | self.server_variable_dict['kernel'] = self.kernel 272 | self.client_center_variable_dict['kernel'] = \ 273 | self.kernel_regularizer.center 274 | 275 | self.client_variable_dict['recurrent_kernel'] = self.recurrent_kernel 276 | self.server_variable_dict['recurrent_kernel'] = self.recurrent_kernel 277 | self.client_center_variable_dict['recurrent_kernel'] = \ 278 | self.recurrent_regularizer.center 279 | 280 | if self.use_bias: 281 | self.client_variable_dict['bias'] = self.bias 282 | self.server_variable_dict['bias'] = self.bias 283 | self.client_center_variable_dict['bias'] = \ 284 | self.bias_regularizer.center 285 | 286 | self.built = True 287 | 288 | 289 | class RNNCentered(tf.keras.layers.RNN): 290 | 291 | def compute_delta(self): 292 | return self.cell.compute_delta() 293 | 294 | def renew_center(self, center_to_update=True): 295 | self.cell.renew_center(center_to_update) 296 | 297 | def apply_delta(self, delta): 298 | self.cell.apply_delta(delta) 299 | 300 | def receive_and_save_weights(self, layer_server): 301 | self.cell.receive_and_save_weights(layer_server.cell) 302 | 303 | 304 | class EmbeddingCentered(tf.keras.layers.Embedding, LayerCentered): 305 | 306 | def __init__(self, 307 | input_dim, 308 | output_dim, 309 | embeddings_initializer='uniform', 310 | embeddings_regularizer=None, 311 | activity_regularizer=None, 312 | embeddings_constraint=None, 313 | mask_zero=False, 314 | input_length=None, 315 | **kwargs): 316 | super(EmbeddingCentered, self).__init__( 317 | input_dim, output_dim, 318 | embeddings_initializer=embeddings_initializer, 319 | embeddings_regularizer=None, 320 | activity_regularizer=activity_regularizer, 
321 | embeddings_constraint=embeddings_constraint, 322 | mask_zero=mask_zero, 323 | input_length=input_length, 324 | **kwargs) 325 | 326 | self.embeddings_regularizer = embeddings_regularizer() 327 | self.delta_function = tf.subtract 328 | self.apply_delta_function = tf.add 329 | self.client_variable_dict = {} 330 | self.server_variable_dict = {} 331 | self.client_center_variable_dict = {} 332 | 333 | @tf_utils.shape_type_conversion 334 | def build(self, input_shape): 335 | def create_weights(): 336 | self.embeddings_regularizer.center = self.add_weight( 337 | shape=(self.input_dim, self.output_dim), 338 | name='embeddings_center', 339 | initializer=tf.keras.initializers.constant(0.), 340 | dtype=self.dtype, 341 | trainable=False) 342 | 343 | self.embeddings = self.add_weight( 344 | shape=(self.input_dim, self.output_dim), 345 | initializer=self.embeddings_initializer, 346 | name='embeddings', 347 | regularizer=self.embeddings_regularizer, 348 | constraint=self.embeddings_constraint) 349 | if context.executing_eagerly() and context.context().num_gpus(): 350 | with ops.device('cpu:0'): 351 | create_weights() 352 | else: 353 | create_weights() 354 | 355 | self.client_variable_dict['embeddings'] = self.embeddings 356 | self.server_variable_dict['embeddings'] = self.embeddings 357 | self.client_center_variable_dict['embeddings'] = \ 358 | self.embeddings_regularizer.center 359 | self.built = True 360 | 361 | 362 | class Conv2DCentered(tf.keras.layers.Conv2D, LayerCentered): 363 | def __init__(self, 364 | filters, 365 | kernel_size, 366 | strides=(1, 1), 367 | padding='valid', 368 | data_format=None, 369 | dilation_rate=(1, 1), 370 | activation=None, 371 | use_bias=True, 372 | kernel_initializer='glorot_uniform', 373 | bias_initializer='zeros', 374 | kernel_regularizer=None, 375 | bias_regularizer=None, 376 | activity_regularizer=None, 377 | kernel_constraint=None, 378 | bias_constraint=None, 379 | **kwargs): 380 | super(Conv2DCentered, self).__init__( 381 | filters=filters, 382 | kernel_size=kernel_size, 383 | strides=strides, 384 | padding=padding, 385 | data_format=data_format, 386 | dilation_rate=dilation_rate, 387 | activation=activation, 388 | use_bias=use_bias, 389 | kernel_initializer=kernel_initializer, 390 | bias_initializer=bias_initializer, 391 | kernel_regularizer=kernel_regularizer, 392 | bias_regularizer=bias_regularizer, 393 | activity_regularizer=activity_regularizer, 394 | kernel_constraint=kernel_constraint, 395 | bias_constraint=bias_constraint, 396 | **kwargs) 397 | 398 | self.kernel_regularizer = kernel_regularizer() 399 | self.bias_regularizer = bias_regularizer() 400 | self.delta_function = tf.subtract 401 | self.apply_delta_function = tf.add 402 | self.client_variable_dict = {} 403 | self.server_variable_dict = {} 404 | self.client_center_variable_dict = {} 405 | 406 | def build(self, input_shape): 407 | input_shape = tensor_shape.TensorShape(input_shape) 408 | input_channel = self._get_input_channel(input_shape) 409 | kernel_shape = self.kernel_size + (input_channel, self.filters) 410 | 411 | self.kernel_regularizer.center = \ 412 | self.add_weight('kernel_center', 413 | shape=kernel_shape, 414 | initializer=tf.keras.initializers.constant(0.), 415 | dtype=self.dtype, 416 | trainable=False) 417 | self.kernel = self.add_weight( 418 | name='kernel', 419 | shape=kernel_shape, 420 | initializer=self.kernel_initializer, 421 | regularizer=self.kernel_regularizer, 422 | constraint=self.kernel_constraint, 423 | trainable=True, 424 | dtype=self.dtype) 425 | if self.use_bias: 426 
| self.bias_regularizer.center = \ 427 | self.add_weight('bias_center', 428 | shape=(self.filters,), 429 | initializer=tf.keras.initializers.constant(0.), 430 | dtype=self.dtype, 431 | trainable=False) 432 | self.bias = self.add_weight( 433 | name='bias', 434 | shape=(self.filters,), 435 | initializer=self.bias_initializer, 436 | regularizer=self.bias_regularizer, 437 | constraint=self.bias_constraint, 438 | trainable=True, 439 | dtype=self.dtype) 440 | else: 441 | self.bias = None 442 | channel_axis = self._get_channel_axis() 443 | self.input_spec = InputSpec(ndim=self.rank + 2, 444 | axes={channel_axis: input_channel}) 445 | 446 | self._build_conv_op_input_shape = input_shape 447 | self._build_input_channel = input_channel 448 | self._padding_op = self._get_padding_op() 449 | self._conv_op_data_format = conv_utils.convert_data_format( 450 | self.data_format, self.rank + 2) 451 | self._convolution_op = nn_ops.Convolution( 452 | input_shape, 453 | filter_shape=self.kernel.shape, 454 | dilation_rate=self.dilation_rate, 455 | strides=self.strides, 456 | padding=self._padding_op, 457 | data_format=self._conv_op_data_format) 458 | 459 | self.client_variable_dict['kernel'] = self.kernel 460 | self.server_variable_dict['kernel'] = self.kernel 461 | self.client_center_variable_dict['kernel'] = \ 462 | self.kernel_regularizer.center 463 | 464 | if self.use_bias: 465 | self.client_variable_dict['bias'] = self.bias 466 | self.server_variable_dict['bias'] = self.bias 467 | self.client_center_variable_dict['bias'] = \ 468 | self.bias_regularizer.center 469 | 470 | self.built = True -------------------------------------------------------------------------------- /source/constants.py: -------------------------------------------------------------------------------- 1 | ROOT_LOGGER_STR = 'TFFVirtualLogger' 2 | LOGGER_RESULT_FILE = 'logs.txt' 3 | -------------------------------------------------------------------------------- /source/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | import zipfile 5 | import numpy as np 6 | import pandas as pd 7 | from sklearn.model_selection import train_test_split 8 | import tensorflow as tf 9 | import tensorflow_federated as tff 10 | from tensorflow_federated.python.simulation import hdf5_client_data 11 | from source.constants import ROOT_LOGGER_STR 12 | from source.utils import softmax 13 | 14 | logger = logging.getLogger(ROOT_LOGGER_STR + '.' 
+ __name__) 15 | 16 | 17 | SHUFFLE_BUFFER = 500 18 | BUFFER_SIZE = 10000 19 | 20 | 21 | def post_process_datasets(federated_data, epochs=1): 22 | return [data.repeat(epochs).shuffle(SHUFFLE_BUFFER).prefetch(BUFFER_SIZE) 23 | for data in federated_data] 24 | 25 | 26 | def federated_dataset(dataset_conf, data_dir=Path('data')): 27 | name = dataset_conf['name'] 28 | num_clients = dataset_conf['num_clients'] 29 | if name == 'mnist': 30 | x_train, y_train, x_test, y_test = mnist_preprocess(data_dir) 31 | x_train = np.split(x_train, 100) 32 | y_train = np.split(y_train, 100) 33 | x_test = np.split(x_test, 100) 34 | y_test = np.split(y_test, 100) 35 | 36 | federated_train_data = post_process_datasets( 37 | [tf.data.Dataset.from_tensor_slices(data) 38 | for data in zip(x_train, y_train)]) 39 | 40 | federated_test_data = post_process_datasets( 41 | [tf.data.Dataset.from_tensor_slices(data) 42 | for data in zip(x_test, y_test)]) 43 | 44 | train_size = [x.shape[0] for x in x_train] 45 | test_size = [x.shape[0] for x in x_test] 46 | 47 | federated_train_data = federated_train_data[0:num_clients] 48 | federated_test_data = federated_test_data[0:num_clients] 49 | train_size = train_size[0:num_clients] 50 | test_size = test_size[0:num_clients] 51 | 52 | if name == 'femnist': 53 | if (data_dir 54 | and (data_dir / 'datasets' / 'fed_emnist_digitsonly_train.h5').is_file() 55 | and (data_dir / 'datasets' / 'fed_emnist_digitsonly_test.h5').is_file()): 56 | train_file = data_dir / 'datasets' / 'fed_emnist_digitsonly_train.h5' 57 | test_file = data_dir / 'datasets' / 'fed_emnist_digitsonly_test.h5' 58 | 59 | logger.debug(f"Data already exists, loading from {data_dir}") 60 | emnist_train = hdf5_client_data.HDF5ClientData(str(train_file)) 61 | emnist_test = hdf5_client_data.HDF5ClientData(str(test_file)) 62 | else: 63 | emnist_train, emnist_test = tff.simulation.datasets.emnist.\ 64 | load_data(cache_dir=data_dir) 65 | post_shape = [-1] 66 | if 'shape' in dataset_conf: 67 | post_shape = dataset_conf['shape'] 68 | 69 | def preprocess(dataset): 70 | def element_fn(element): 71 | return (tf.reshape(element['pixels'], post_shape), 72 | (tf.reshape(element['label'], [1]))) 73 | 74 | return dataset.map(element_fn) 75 | 76 | def make_federated_data(client_data, client_ids): 77 | return [preprocess(client_data.create_tf_dataset_for_client(x)) 78 | for x in client_ids] 79 | 80 | sample_clients = emnist_train.client_ids[0:num_clients] 81 | federated_train_data = make_federated_data(emnist_train, sample_clients) 82 | federated_test_data = make_federated_data(emnist_test, sample_clients) 83 | 84 | train_size = [tf.data.experimental.cardinality(data).numpy() 85 | for data in federated_train_data] 86 | test_size = [tf.data.experimental.cardinality(data).numpy() 87 | for data in federated_test_data] 88 | federated_train_data = post_process_datasets(federated_train_data) 89 | federated_test_data = post_process_datasets(federated_test_data) 90 | 91 | if name == 'shakespeare': 92 | federated_train_data, federated_test_data, train_size, test_size = \ 93 | shakspeare(num_clients, dataset_conf['seq_length'], data_dir) 94 | 95 | if name == 'pmnist': 96 | federated_train_data, federated_test_data = permuted_mnist( 97 | num_clients=100, data_dir=data_dir) 98 | train_size = [data[0].shape[0] for data in federated_train_data] 99 | test_size = [data[0].shape[0] for data in federated_test_data] 100 | 101 | if 'shape' in dataset_conf: 102 | def preprocess(dataset): 103 | return dataset.reshape(dataset_conf['shape']) 104 | 
federated_train_data = [(np.array( 105 | list(map(preprocess, client[0]))), client[1]) 106 | for client in federated_train_data] 107 | federated_test_data = [(np.array( 108 | list(map(preprocess, client[0]))), client[1]) 109 | for client in federated_test_data] 110 | 111 | federated_train_data = post_process_datasets( 112 | [tf.data.Dataset.from_tensor_slices(data) 113 | for data in federated_train_data]) 114 | federated_test_data = post_process_datasets( 115 | [tf.data.Dataset.from_tensor_slices(data) 116 | for data in federated_test_data]) 117 | 118 | federated_train_data = federated_train_data[0:num_clients] 119 | federated_test_data = federated_test_data[0:num_clients] 120 | train_size = train_size[0:num_clients] 121 | test_size = test_size[0:num_clients] 122 | 123 | if name == 'human_activity': 124 | x, y = human_activity_preprocess(data_dir) 125 | x, y, x_t, y_t = data_split(x, y) 126 | train_size = [xs.shape[0] for xs in x] 127 | test_size = [xs.shape[0] for xs in x_t] 128 | 129 | federated_train_data = post_process_datasets( 130 | [tf.data.Dataset.from_tensor_slices(data) 131 | for data in zip(x, y)]) 132 | federated_test_data = post_process_datasets( 133 | [tf.data.Dataset.from_tensor_slices(data) 134 | for data in zip(x_t, y_t)]) 135 | 136 | federated_train_data = federated_train_data[0:num_clients] 137 | federated_test_data = federated_test_data[0:num_clients] 138 | train_size = train_size[0:num_clients] 139 | test_size = test_size[0:num_clients] 140 | 141 | if name == 'vehicle_sensor': 142 | x, y = vehicle_sensor_preprocess(data_dir) 143 | x, y, x_t, y_t = data_split(x, y) 144 | train_size = [xs.shape[0] for xs in x] 145 | test_size = [xs.shape[0] for xs in x_t] 146 | 147 | federated_train_data = post_process_datasets( 148 | [tf.data.Dataset.from_tensor_slices(data) 149 | for data in zip(x, y)]) 150 | federated_test_data = post_process_datasets( 151 | [tf.data.Dataset.from_tensor_slices(data) 152 | for data in zip(x_t, y_t)]) 153 | 154 | federated_train_data = federated_train_data[0:num_clients] 155 | federated_test_data = federated_test_data[0:num_clients] 156 | train_size = train_size[0:num_clients] 157 | test_size = test_size[0:num_clients] 158 | 159 | if name == 'synthetic': 160 | x, y = synthetic(num_clients=num_clients, num_class=10, dimension=60, 161 | alpha=dataset_conf['synth_alpha'], 162 | beta=dataset_conf['synth_beta'], 163 | iid=dataset_conf['iid']) 164 | x, y, x_t, y_t = data_split(x, y) 165 | train_size = [xs.shape[0] for xs in x] 166 | test_size = [xs.shape[0] for xs in x_t] 167 | federated_train_data = post_process_datasets( 168 | [tf.data.Dataset.from_tensor_slices(data) 169 | for data in zip(x, y)]) 170 | federated_test_data = post_process_datasets( 171 | [tf.data.Dataset.from_tensor_slices(data) 172 | for data in zip(x_t, y_t)]) 173 | 174 | return federated_train_data, federated_test_data, train_size, test_size 175 | 176 | 177 | def data_split(x, y, test_size=0.25): 178 | x, x_t, y, y_t = zip(*[train_test_split(x_i, y_i, test_size=test_size) 179 | for x_i, y_i in zip(x, y)]) 180 | return x, y, x_t, y_t 181 | 182 | 183 | def mnist_preprocess(data_dir=None): 184 | print(data_dir) 185 | if data_dir and (data_dir / 'datasets' / 'mnist.npz').is_file(): 186 | file_path = data_dir / 'datasets' / 'mnist.npz' 187 | 188 | logger.debug(f"Data already exists, loading from {data_dir}") 189 | with np.load(file_path, allow_pickle=True) as f: 190 | x_train, y_train = f['x_train'], f['y_train'] 191 | x_test, y_test = f['x_test'], f['y_test'] 192 | else: 193 | (x_train, 
y_train), (x_test, y_test) = \
194 | tf.keras.datasets.mnist.load_data()
195 | x_train = x_train.reshape((-1, 784))
196 | x_test = x_test.reshape((-1, 784))
197 | x_train = x_train.astype('float32')
198 | x_test = x_test.astype('float32')
199 | x_train /= 255
200 | x_test /= 255
201 |
202 | return x_train, y_train, x_test, y_test
203 |
204 |
205 | def permute(x):
206 |
207 | def shuffle(a, i):
208 | for j, _ in enumerate(a):
209 | a[j] = (a[j].flatten()[i])
210 | return a
211 |
212 | if isinstance(x, list):
213 | indx = np.random.permutation(x[0].shape[-1])
214 | permuted = []
215 | for el in x:
216 | permuted.append(shuffle(el, indx))
217 | else:
218 | indx = np.random.permutation(x.shape[-1])
219 | permuted = shuffle(x, indx)
220 |
221 | return permuted
222 |
223 |
224 | def permuted_mnist(num_clients=100, data_dir=None):
225 | x_train, y_train, x_test, y_test = mnist_preprocess(data_dir=data_dir)
226 | x_train = np.split(x_train, num_clients)
227 | y_train = np.split(y_train, num_clients)
228 | x_test = np.split(x_test, num_clients)
229 | y_test = np.split(y_test, num_clients)
230 |
231 | federated_train = []
232 | federated_test = []
233 | for x, xt, y, yt in zip(x_train, x_test, y_train, y_test):
234 | x, xt = permute([x, xt])
235 | federated_train.append((x, y))
236 | federated_test.append((xt, yt))
237 | return federated_train, federated_test
238 |
239 |
240 | def download_file(url, filename):
241 | import requests
242 | from tqdm import tqdm
243 |
244 | r = requests.get(url, stream=True)
245 | total_size = int(r.headers.get('content-length', 0))
246 | block_size = 1024 # 1 Kibibyte
247 | t = tqdm(total=total_size, unit='iB', unit_scale=True)
248 | with open(filename, 'wb') as f:
249 | for data in r.iter_content(block_size):
250 | t.update(len(data))
251 | f.write(data)
252 | t.close()
253 |
254 |
255 | # It's important that the following link does not remove the zip file.
256 | # Otherwise the next time the data will be downloaded again.
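# --- Editor's note: illustration only, not part of the original data_utils.py ---
# The preprocess functions below follow the same download-and-cache pattern: the
# zip archive is kept on disk, so a later call finds a non-empty directory and
# skips the download. A minimal sketch of that pattern (`example_url` and
# `example_dir` are hypothetical names used only for this sketch):
#
#     from pathlib import Path
#     import zipfile
#
#     def fetch_once(example_url, example_dir: Path):
#         example_dir.mkdir(parents=True, exist_ok=True)
#         if not any(f.is_file() for f in example_dir.iterdir()):
#             zip_file = example_dir / 'original_data.zip'
#             download_file(example_url, zip_file)  # helper defined above
#             with zipfile.ZipFile(zip_file, 'r') as zip_ref:
#                 zip_ref.extractall(example_dir)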
257 | def human_activity_preprocess(data_dir=None):
258 |
259 | if not data_dir:
260 | data_dir = Path(__file__).parent.absolute().parent
261 | data_dir = data_dir / 'data' / 'human_activity'
262 |
263 | if not data_dir.exists():
264 | data_dir.mkdir(parents=True)
265 |
266 | subdirs = [f for f in data_dir.iterdir() if f.is_file()]
267 | if not subdirs:
268 | url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00240/UCI%20HAR%20Dataset.zip'
269 | zip_file = data_dir / 'original_data.zip'
270 | download_file(url, zip_file)
271 |
272 | with zipfile.ZipFile(zip_file, 'r') as zip_ref:
273 | zip_ref.extractall(data_dir)
274 |
275 | data_dir = data_dir / 'UCI HAR Dataset'
276 | data_dir_train = data_dir / 'train'
277 | data_dir_test = data_dir / 'test'
278 |
279 | x_train = pd.read_csv(data_dir_train / 'X_train.txt',
280 | delim_whitespace=True, header=None).values
281 | y_train = pd.read_csv(data_dir_train / 'y_train.txt',
282 | delim_whitespace=True, header=None).values
283 | task_index_train = pd.read_csv(data_dir_train / 'subject_train.txt',
284 | delim_whitespace=True, header=None).values
285 | x_test = pd.read_csv(data_dir_test / 'X_test.txt',
286 | delim_whitespace=True, header=None).values
287 | y_test = pd.read_csv(data_dir_test / 'y_test.txt',
288 | delim_whitespace=True, header=None).values
289 | task_index_test = pd.read_csv(data_dir_test / 'subject_test.txt',
290 | delim_whitespace=True, header=None).values
291 |
292 | x = np.concatenate((x_train, x_test))
293 | y = np.concatenate((y_train, y_test)).squeeze()
294 | task_index = np.concatenate((task_index_train, task_index_test)).squeeze()
295 | argsort = np.argsort(task_index)
296 | x = x[argsort]
297 | y = np.array(y[argsort])
298 | y = y-1
299 | task_index = task_index[argsort]
300 | split_index = np.where(np.roll(task_index, 1) != task_index)[0][1:]
301 | x = np.split(x, split_index)
302 | y = np.split(y, split_index)
303 |
304 | return x, y
305 |
306 |
307 | # It's important that the following link does not remove the zip file.
308 | # Otherwise the next time the data will be downloaded again.
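# --- Editor's note: illustration only, not part of the original module ---
# human_activity_preprocess above and vehicle_sensor_preprocess below both
# split the pooled samples into per-task chunks with the same idiom: after
# sorting by task id, np.roll compares every id with its predecessor, the
# positions that differ mark chunk boundaries, and the leading index 0 is
# dropped. A self-contained sketch (relies on the module's numpy import):
#
#     task_ids = np.array([3, 1, 2, 1, 3, 2])
#     values = np.arange(6)
#     order = np.argsort(task_ids)
#     task_ids, values = task_ids[order], values[order]  # ids: [1 1 2 2 3 3]
#     split_at = np.where(np.roll(task_ids, 1) != task_ids)[0][1:]  # -> [2 4]
#     chunks = np.split(values, split_at)  # one array per task id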
309 | def vehicle_sensor_preprocess(data_dir=None):
310 | if not data_dir or 'vehicle_sensor' not in str(data_dir):
311 | data_dir = Path(__file__).parent.absolute().parent
312 | data_dir = data_dir / 'data' / 'vehicle_sensor'
313 |
314 | if not data_dir.exists():
315 | data_dir.mkdir(parents=True)
316 | subdirs = [f for f in data_dir.iterdir() if f.is_file()]
317 | if not subdirs:
318 | url = 'http://www.ecs.umass.edu/~mduarte/images/event.zip'
319 | zip_file = data_dir / 'original_data.zip'
320 | download_file(url, zip_file)
321 |
322 | with zipfile.ZipFile(zip_file, 'r') as zip_ref:
323 | zip_ref.extractall(data_dir)
324 | data_dir = data_dir / 'events' / 'runs'
325 |
326 | x = []
327 | y = []
328 | task_index = []
329 | for root, dir, file_names in os.walk(data_dir):
330 | if 'acoustic' not in root and 'seismic' not in root:
331 | x_tmp = []
332 | for file_name in file_names:
333 | if 'feat' in file_name:
334 | dt_tmp = pd.read_csv(
335 | os.path.join(root, file_name), sep=' ',
336 | skipinitialspace=True, header=None).values[:, :50]
337 | x_tmp.append(dt_tmp)
338 | if len(x_tmp) == 2:
339 | x_tmp = np.concatenate(x_tmp, axis=1)
340 | x.append(x_tmp)
341 | task_index.append(
342 | int(os.path.basename(root)[1:])*np.ones(x_tmp.shape[0]))
343 | y.append(
344 | int('aav' in os.path.basename(
345 | os.path.dirname(root)))*np.ones(x_tmp.shape[0]))
346 |
347 | x = np.concatenate(x)
348 | y = np.concatenate(y)
349 | task_index = np.concatenate(task_index)
350 | argsort = np.argsort(task_index)
351 | x = x[argsort]
352 | y = y[argsort]
353 | task_index = task_index[argsort]
354 | split_index = np.where(np.roll(task_index, 1) != task_index)[0][1:]
355 | x = np.split(x, split_index)
356 | y = np.split(y, split_index)
357 | return x, y
358 |
359 |
360 | def synthetic(num_clients=30, num_class=10, dimension=60, alpha=0., beta=0.,
361 | iid=False):
362 | np.random.seed(0)
363 | samples_per_user = np.random.lognormal(
364 | 4, 2, (num_clients)).astype(int) + 50
365 | print(samples_per_user)
366 | num_samples = np.sum(samples_per_user)
367 |
368 | X_split = [[] for _ in range(num_clients)]
369 | y_split = [[] for _ in range(num_clients)]
370 |
371 | #### define some prior ####
372 | mean_W = np.random.normal(0, alpha, num_clients)
373 | mean_b = mean_W
374 | B = np.random.normal(0, beta, num_clients)
375 | mean_x = np.zeros((num_clients, dimension))
376 |
377 | diagonal = np.zeros(dimension)
378 | for j in range(dimension):
379 | diagonal[j] = np.power((j + 1), -1.2)
380 | cov_x = np.diag(diagonal)
381 |
382 | for i in range(num_clients):
383 | if iid == 1:
384 | mean_x[i] = np.ones(dimension) * B[i] # all zeros
385 | else:
386 | mean_x[i] = np.random.normal(B[i], 1, dimension)
387 | print(mean_x[i])
388 |
389 | if iid == 1:
390 | W_global = np.random.normal(0, 1, (dimension, num_class))
391 | b_global = np.random.normal(0, 1, num_class)
392 |
393 | for i in range(num_clients):
394 | W = np.random.normal(mean_W[i], 1, (dimension, num_class))
395 | b = np.random.normal(mean_b[i], 1, num_class)
396 | if iid == 1:
397 | W = W_global
398 | b = b_global
399 |
400 | xx = np.random.multivariate_normal(
401 | mean_x[i], cov_x, samples_per_user[i])
402 | yy = np.zeros(samples_per_user[i])
403 |
404 | for j in range(samples_per_user[i]):
405 | tmp = np.dot(xx[j], W) + b
406 | yy[j] = np.argmax(softmax(tmp))
407 |
408 | X_split[i] = xx
409 | y_split[i] = yy
410 | print("{}-th user has {} examples".format(i, len(y_split[i])))
411 |
412 | return X_split, y_split
413 |
414 |
415 | def shakspeare(num_clients=-1,
seq_lenght=80, data_dir=None): 416 | vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r') 417 | table = tf.lookup.StaticHashTable( 418 | tf.lookup.KeyValueTensorInitializer( 419 | keys=vocab, values=tf.constant(list(range(1, len(vocab) + 1)), 420 | dtype=tf.int64)), 421 | default_value=tf.cast(0, tf.int64)) 422 | 423 | def to_ids(x): 424 | s = tf.reshape(x['snippets'], shape=[1]) 425 | chars = tf.strings.bytes_split(s).values 426 | ids = table.lookup(chars) 427 | return ids 428 | 429 | def preprocess(dataset): 430 | return ( 431 | # Map ASCII chars to int64 indexes using the vocab 432 | dataset.map(to_ids) 433 | # Split into individual chars 434 | .unbatch()) 435 | # Form example sequences of SEQ_LENGTH +1 436 | 437 | def postprocess(dataset): 438 | return (dataset.batch(seq_lenght + 1, drop_remainder=False) 439 | .shuffle(BUFFER_SIZE)) 440 | 441 | def data(client, source): 442 | return postprocess( 443 | preprocess(source.create_tf_dataset_for_client(client))) 444 | 445 | if data_dir: 446 | train_file = data_dir / 'datasets' / 'shakespeare_train.h5' 447 | test_file = data_dir / 'datasets' / 'shakespeare_test.h5' 448 | if data_dir and train_file.is_file() and test_file.is_file(): 449 | logger.debug(f"Data already exists, loading from {data_dir}") 450 | train_data = hdf5_client_data.HDF5ClientData(str(train_file)) 451 | test_data = hdf5_client_data.HDF5ClientData(str(test_file)) 452 | else: 453 | train_data, test_data = tff.simulation.datasets.shakespeare.load_data( 454 | cache_dir=data_dir) 455 | indx = [8, 11, 12, 17, 26, 32, 34, 43, 45, 66, 68, 72, 73, 456 | 85, 92, 93, 98, 105, 106, 108, 110, 130, 132, 143, 150, 153, 457 | 156, 158, 165, 169, 185, 187, 191, 199, 207, 212, 219, 227, 235, 458 | 236, 238, 257, 264, 269, 278, 281, 283, 285, 288, 297, 301, 305, 459 | 310, 324, 331, 340, 351, 362, 370, 373, 374, 375, 376, 383, 388, 460 | 418, 428, 429, 432, 433, 458, 471, 474, 476, 485, 491, 492, 494, 461 | 497, 500, 501, 507, 512, 519, 529, 543, 556, 564, 570, 573, 574, 462 | 579, 580, 581, 593, 600, 601, 603, 604, 613, 622, 626, 627, 632, 463 | 644, 645, 646, 648, 657, 658, 660, 663, 669, 671, 672, 676, 678, 464 | 681, 684, 695] 465 | 466 | clients = [train_data.client_ids[i] for i in indx] 467 | clients = clients[0:num_clients] 468 | 469 | train_size = [len(list(preprocess( 470 | train_data.create_tf_dataset_for_client(client)))) 471 | for client in clients] 472 | test_size = [len(list(preprocess( 473 | test_data.create_tf_dataset_for_client(client)))) 474 | for client in clients] 475 | 476 | federated_train_data = [data(client, train_data) for client in clients] 477 | federated_test_data = [data(client, test_data) for client in clients] 478 | 479 | return federated_train_data, federated_test_data, train_size, test_size 480 | 481 | 482 | def batch_dataset(dataset, batch_size, padding=None, seq_length=None): 483 | if not padding: 484 | return dataset.batch(batch_size) 485 | else: 486 | def split_input_target(chunk): 487 | input_text = tf.map_fn(lambda x: x[:-1], chunk) 488 | target_text = tf.map_fn(lambda x: x[1:], chunk) 489 | return (input_text, target_text) 490 | 491 | return dataset.padded_batch( 492 | batch_size, 493 | padded_shapes=[seq_length + 1], 494 | drop_remainder=True, 495 | padding_values=tf.cast(0, tf.int64)).map(split_input_target) 496 | -------------------------------------------------------------------------------- /source/experiment_utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | 5 | import tensorflow as tf 6 | import tensorflow_federated as tff 7 | import tensorflow_probability as tfp 8 | from tensorflow_probability.python.distributions import kullback_leibler as kl_lib 9 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, RNN, Dense, Embedding, LSTMCell 10 | import gc 11 | 12 | from source.virtual_process import VirtualFedProcess 13 | from source.fed_prox import FedProx 14 | from source.gate_layer import Gate 15 | from source.utils import FlattenedCategoricalAccuracy 16 | from source.federated_devices import _Server 17 | from source.centered_layers import (DenseCentered, CenteredL2Regularizer, 18 | EmbeddingCentered, LSTMCellCentered, 19 | RNNCentered, Conv2DCentered) 20 | from source.natural_raparametrization_layer import RNNVarReparametrized 21 | from source.natural_raparametrization_layer import Conv1DVirtualNatural 22 | from source.natural_raparametrization_layer import Conv2DVirtualNatural 23 | from source.natural_raparametrization_layer import DenseReparametrizationNaturalShared, \ 24 | DenseLocalReparametrizationNaturalShared,\ 25 | DenseSharedNatural, \ 26 | natural_mean_field_normal_fn, \ 27 | natural_tensor_multivariate_normal_fn, \ 28 | natural_initializer_fn, \ 29 | NaturalGaussianEmbedding, LSTMCellVariationalNatural 30 | from source.tfp_utils import precision_from_untransformed_scale 31 | from source.constants import ROOT_LOGGER_STR 32 | from tensorflow_probability.python.layers import DenseReparameterization 33 | from source.learning_rate_multipliers_opt import LR_SGD 34 | from source.federated_devices import _Server 35 | 36 | logger = logging.getLogger(ROOT_LOGGER_STR + '.' + __name__) 37 | 38 | 39 | dir_path = os.path.dirname(os.path.realpath(__file__)) 40 | 41 | 42 | def get_compiled_model_fn_from_dict(dict_conf, sample_batch): 43 | def create_seq_model(model_class=tf.keras.Sequential, train_size=None, 44 | client_weight=None): 45 | # Make sure layer parameters are a list 46 | if not isinstance(dict_conf['layers'], list): 47 | dict_conf['layers'] = [dict_conf['layers']] 48 | 49 | layers = [] 50 | for layer_params in dict_conf['layers']: 51 | layer_params = dict(layer_params) 52 | layer_class = globals()[layer_params['name']] 53 | layer_params.pop('name') 54 | layer_params.pop('scale_init', None) 55 | 56 | def kernel_reg_fn(): 57 | return CenteredL2Regularizer(dict_conf['l2_reg']) 58 | 59 | if not train_size: 60 | train_size = 1. 
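# Editor's note (illustration, not original code): the divergence lambdas
# below implement the usual minibatch ELBO scaling. With posterior q and
# prior p, each variational layer adds
#     kl_weight * KL(q || p) / k_w
# to the loss, where k_w is the client's number of training examples, so the
# KL penalty is amortized over the client's data; for a server model k_w is
# forced to 1 because the server holds no data of its own. A hedged,
# self-contained sketch of the same scaling with tfp distributions
# (kl_weight = 0.5 and k_w = 600 are made-up example values):
#
#     import tensorflow_probability as tfp
#     q = tfp.distributions.Normal(loc=0.1, scale=1.0)
#     p = tfp.distributions.Normal(loc=0.0, scale=1.0)
#     kl_term = 0.5 * tfp.distributions.kl_divergence(q, p) / 600.0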
61 | 62 | k_w = float(train_size) 63 | if issubclass(model_class, _Server): 64 | k_w = 1 65 | 66 | kernel_divergence_fn = (lambda q, p, ignore: 67 | dict_conf['kl_weight'] 68 | * kl_lib.kl_divergence(q, p) / k_w) 69 | reccurrent_divergence_fn = (lambda q, p, ignore: 70 | dict_conf['kl_weight'] 71 | * kl_lib.kl_divergence(q, p) / k_w) 72 | 73 | if ('scale_init' in dict_conf 74 | and (issubclass(layer_class, DenseSharedNatural) 75 | or layer_class == Conv2DVirtualNatural 76 | or layer_class == NaturalGaussianEmbedding 77 | or layer_class == LSTMCellVariationalNatural)): 78 | scale_init = dict_conf['scale_init'] 79 | untransformed_scale = scale_init[0] 80 | if scale_init[0] == 'auto': 81 | untransformed_scale = \ 82 | precision_from_untransformed_scale.inverse( 83 | tf.constant(train_size, dtype=tf.float32)) 84 | layer_params['untransformed_scale_initializer'] = \ 85 | tf.random_normal_initializer(mean=untransformed_scale, 86 | stddev=scale_init[1]) 87 | 88 | if ('loc_init' in dict_conf 89 | and (issubclass(layer_class, DenseSharedNatural) 90 | or layer_class == Conv2DVirtualNatural 91 | or layer_class == NaturalGaussianEmbedding)): 92 | loc_init = dict_conf['loc_init'] 93 | layer_params['loc_initializer'] = \ 94 | tf.random_normal_initializer(mean=loc_init[0], 95 | stddev=loc_init[1]) 96 | 97 | if layer_class == DenseReparameterization: 98 | layer_params['kernel_divergence_fn'] = kernel_divergence_fn 99 | if issubclass(layer_class, DenseSharedNatural): 100 | layer_params['kernel_divergence_fn'] = kernel_divergence_fn 101 | layer_params['client_weight'] = client_weight 102 | layer_params['delta_percentile'] = dict_conf.get('delta_percentile', None) 103 | if layer_class == Conv2DVirtualNatural: 104 | layer_params['kernel_divergence_fn'] = kernel_divergence_fn 105 | layer_params['client_weight'] = client_weight 106 | if layer_class == DenseCentered: 107 | layer_params['kernel_regularizer'] = kernel_reg_fn 108 | layer_params['bias_regularizer'] = kernel_reg_fn 109 | if layer_class == EmbeddingCentered: 110 | layer_params['embeddings_regularizer'] = kernel_reg_fn 111 | layer_params['batch_input_shape'] = [dict_conf['batch_size'], 112 | dict_conf['seq_length']] 113 | layer_params['mask_zero'] = True 114 | if layer_class == Conv2DCentered: 115 | layer_params['kernel_regularizer'] = \ 116 | lambda: CenteredL2Regularizer(dict_conf['l2_reg']) 117 | layer_params['bias_regularizer'] = \ 118 | lambda: CenteredL2Regularizer(dict_conf['l2_reg']) 119 | if layer_class == NaturalGaussianEmbedding: 120 | layer_params['embedding_divergence_fn'] = kernel_divergence_fn 121 | layer_params['batch_input_shape'] = [dict_conf['batch_size'], 122 | dict_conf['seq_length']] 123 | layer_params['mask_zero'] = True 124 | if layer_class == NaturalGaussianEmbedding: 125 | layer_params['client_weight'] = client_weight 126 | 127 | if layer_class == LSTMCellCentered: 128 | cell_params = dict(layer_params) 129 | cell_params['kernel_regularizer'] = kernel_reg_fn 130 | cell_params['recurrent_regularizer'] = kernel_reg_fn 131 | cell_params['bias_regularizer'] = kernel_reg_fn 132 | cell = layer_class(**cell_params) 133 | layer_params = {'cell': cell, 134 | 'return_sequences': True, 135 | 'stateful': True} 136 | layer_class = RNNCentered 137 | if layer_class == LSTMCellVariationalNatural: 138 | cell_params = dict(layer_params) 139 | if layer_class == LSTMCellVariationalNatural: 140 | cell_params['client_weight'] = client_weight 141 | cell_params['kernel_divergence_fn'] = kernel_divergence_fn 142 | 
cell_params['recurrent_kernel_divergence_fn'] = \
143 | reccurrent_divergence_fn
144 | cell = layer_class(**cell_params)
145 |
146 | layer_params = {'cell': cell,
147 | 'return_sequences': True,
148 | 'stateful': True}
149 | layer_class = RNNVarReparametrized
150 |
151 | layer_params.pop('name', None)
152 | layers.append(layer_class(**layer_params))
153 | return model_class(layers)
154 |
155 | def create_model_hierarchical(model_class=tf.keras.Model, train_size=None,
156 | client_weight=None):
157 | if 'architecture' in dict_conf and dict_conf['architecture'] == 'rnn':
158 | b_shape = (dict_conf['batch_size'], dict_conf['seq_length'])
159 | in_layer = tf.keras.layers.Input(batch_input_shape=b_shape)
160 | else:
161 | in_key = ('input_dim' if 'input_dim' in dict_conf['layers'][0]
162 | else 'input_shape')
163 | input_dim = dict_conf['layers'][0][in_key]
164 | in_layer = tf.keras.layers.Input(shape=input_dim)
165 |
166 | client_path = in_layer
167 | server_path = in_layer
168 |
169 | for layer_params in dict_conf['layers']:
170 | layer_params = dict(layer_params)
171 | layer_class = globals()[layer_params['name']]
172 | layer_params.pop('name')
173 |
174 | k_w = float(train_size)
175 | if issubclass(model_class, _Server):
176 | k_w = 1
177 |
178 | server_divergence_fn = (lambda q, p, ignore:
179 | dict_conf['kl_weight']
180 | * kl_lib.kl_divergence(q, p) / k_w)
181 | client_divergence_fn = (lambda q, p, ignore:
182 | dict_conf['kl_weight']
183 | * kl_lib.kl_divergence(q, p) / k_w)
184 |
185 | client_posterior_fn = natural_mean_field_normal_fn
186 | client_prior_fn = natural_tensor_multivariate_normal_fn
187 |
188 | client_reccurrent_divergence_fn = (lambda q, p, ignore:
189 | dict_conf['kl_weight']
190 | * kl_lib.kl_divergence(q, p)
191 | / k_w)
192 | server_reccurrent_divergence_fn = (lambda q, p, ignore:
193 | dict_conf['kl_weight']
194 | * kl_lib.kl_divergence(q, p)
195 | / k_w)
196 |
197 | if ('scale_init' in dict_conf
198 | and (issubclass(layer_class, DenseSharedNatural)
199 | or layer_class == Conv2DVirtualNatural
200 | or layer_class == NaturalGaussianEmbedding)):
201 | scale_init = dict_conf['scale_init']
202 | untransformed_scale = scale_init[0]
203 | if scale_init[0] == 'auto':
204 | untransformed_scale = \
205 | precision_from_untransformed_scale.inverse(
206 | tf.constant(train_size, dtype=tf.float32))
207 | layer_params['untransformed_scale_initializer'] = \
208 | tf.random_normal_initializer(mean=untransformed_scale,
209 | stddev=scale_init[1])
210 | if ('loc_init' in dict_conf
211 | and (issubclass(layer_class, DenseSharedNatural)
212 | or layer_class == Conv2DVirtualNatural
213 | or layer_class == NaturalGaussianEmbedding)):
214 | loc_init = dict_conf['loc_init']
215 | layer_params['loc_initializer'] = \
216 | tf.random_normal_initializer(mean=loc_init[0],
217 | stddev=loc_init[1])
218 |
219 | if issubclass(layer_class, DenseSharedNatural):
220 | server_params = dict(layer_params)
221 | server_params['kernel_divergence_fn'] = server_divergence_fn
222 | if issubclass(layer_class, DenseSharedNatural):
223 | server_params['client_weight'] = client_weight
224 | client_params = dict(layer_params)
225 | client_params['kernel_divergence_fn'] = client_divergence_fn
226 | server_params['activation'] = 'linear'
227 | if issubclass(layer_class, DenseSharedNatural):
228 | natural_initializer = natural_initializer_fn(
229 | untransformed_scale_initializer=tf.random_normal_initializer(mean=-5, stddev=0.1))
230 | client_params['kernel_posterior_fn'] =
client_posterior_fn(natural_initializer)
231 | client_params['kernel_prior_fn'] = client_prior_fn()
232 |
233 | client_params.pop('untransformed_scale_initializer', None)
234 | client_params.pop('loc_initializer', None)
235 | print('client par:', client_params)
236 | client_path = tfp.layers.DenseReparameterization(
237 | **client_params)(client_path)
238 | print('server par:', server_params, 'layer_class', layer_class)
239 | server_path = layer_class(**server_params)(server_path)
240 | gate_initializer = tf.keras.initializers.RandomUniform(minval=0, maxval=0.1)
241 | if issubclass(model_class, _Server):
242 | print('use zero initializer')
243 | gate_initializer = tf.keras.initializers.Constant(0.)
244 | server_path = tf.keras.layers.Activation(
245 | activation=layer_params['activation'])(
246 | tf.keras.layers.Add()([server_path, Gate(gate_initializer)(client_path)]))
247 |
248 | elif issubclass(layer_class, Conv2DVirtualNatural):
249 | client_params = dict(layer_params)
250 | server_params = dict(layer_params)
251 |
252 | if issubclass(layer_class, Conv2DVirtualNatural):
253 | natural_initializer = natural_initializer_fn(
254 | untransformed_scale_initializer=
255 | layer_params['untransformed_scale_initializer'])
256 | client_params['kernel_posterior_fn'] = client_posterior_fn(natural_initializer)
257 | client_params['kernel_prior_fn'] = client_prior_fn()
258 | server_params['client_weight'] = client_weight
259 |
260 | client_params['kernel_divergence_fn'] = client_divergence_fn
261 | client_params['activation'] = 'linear'
262 | client_params.pop('untransformed_scale_initializer', None)
263 | client_params.pop('loc_initializer', None)
264 | client_path = tfp.layers.Convolution2DReparameterization(
265 | **client_params)(client_path)
266 |
267 | server_params['kernel_divergence_fn'] = server_divergence_fn
268 | server_path = layer_class(**server_params)(server_path)
269 | gate_initializer = tf.keras.initializers.RandomUniform(
270 | minval=0, maxval=0.1)
271 | if issubclass(model_class, _Server):
272 | print('use zero initializer')
273 | gate_initializer = tf.keras.initializers.Constant(0.)
274 | server_path = tf.keras.layers.Activation(
275 | activation=layer_params['activation'])(
276 | tf.keras.layers.Add()(
277 | [server_path, Gate(gate_initializer)(client_path)]))
278 | else:
279 | client_path = layer_class(**layer_params)(client_path)
280 | server_path = layer_class(**layer_params)(server_path)
281 |
282 | return model_class(inputs=in_layer, outputs=server_path)
283 |
284 | def compile_model(model, client_weight=None):
285 | if not client_weight:
286 | client_weight = 1.
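# Editor's note (illustration, not original code): loss_fn below adds the
# model's collected regularization losses -- for the variational layers these
# are the KL terms registered through kernel_divergence_fn -- on top of the
# data term, so the model minimizes a negative ELBO. Schematically:
#
#     loss(y, y_hat) = CE(y, y_hat) + sum(model.losses)
#                    ~ -log p(y | x, w) + kl_weight * KL(q || p) / k_w
#
# A hedged standalone sketch of the same pattern (`my_model` is hypothetical):
#
#     def neg_elbo(y_true, y_pred, my_model):
#         nll = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
#         return nll + tf.add_n(my_model.losses)  # KL terms, already / k_w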
287 | 288 | def loss_fn(y_true, y_pred): 289 | return tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred) + sum(model.losses) 290 | 291 | metric = tf.keras.metrics.SparseCategoricalAccuracy() 292 | if 'architecture' in dict_conf: 293 | if dict_conf['architecture'] == 'rnn': 294 | metric = FlattenedCategoricalAccuracy(vocab_size=dict_conf['vocab_size']) 295 | 296 | if "decay_rate" in dict_conf: 297 | lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( 298 | dict_conf['learning_rate'], 299 | decay_steps=dict_conf['decay_steps'], 300 | decay_rate=dict_conf['decay_rate'], 301 | staircase=True) 302 | else: 303 | lr_schedule = dict_conf['learning_rate'] 304 | 305 | if "momentum" in dict_conf: # Case of SGD 306 | optimizer = tf.optimizers.get( 307 | {'class_name': dict_conf['optimizer'], 308 | 'config': {'learning_rate': lr_schedule, 309 | 'momentum': dict_conf['momentum'], 310 | 'nesterov': dict_conf.get('nesterov', False)}}) 311 | elif "beta" in dict_conf: # Case of Adam 312 | optimizer = tf.optimizers.get( 313 | {'class_name': dict_conf['optimizer'], 314 | 'config': {'learning_rate': lr_schedule, 315 | 'beta_1': dict_conf['beta'][0], 316 | 'beta_2': dict_conf['beta'][1], 317 | 'amsgrad': dict_conf.get('amsgrad', False)}}) 318 | else: 319 | optimizer = tf.optimizers.get( 320 | {'class_name': dict_conf['optimizer'], 321 | 'config': {'learning_rate': lr_schedule}}) 322 | 323 | if dict_conf['optimizer'] == 'sgd': 324 | LR_mult_dict = {} 325 | for layer in model.layers: 326 | layer_to_check = layer 327 | if hasattr(layer, 'cell'): 328 | layer_to_check = layer.cell 329 | if 'natural' in layer_to_check.name: 330 | LR_mult_dict[layer.name] = 1 / (lr_schedule * client_weight) * dict_conf['natural_lr'] 331 | elif 'dense_reparameterization' in layer.name: 332 | LR_mult_dict[layer.name] = 1 / lr_schedule * dict_conf['natural_lr'] 333 | 334 | optimizer = LR_SGD(lr=lr_schedule, multipliers=LR_mult_dict) 335 | 336 | model.compile(optimizer=optimizer, 337 | loss=loss_fn, 338 | metrics=[metric], 339 | experimental_run_tf_function=False) 340 | return model 341 | 342 | def model_fn(model_class=tf.keras.Sequential, train_size=None, client_weight=None): 343 | create = create_seq_model 344 | if 'hierarchical' in dict_conf and dict_conf['hierarchical']: 345 | create = create_model_hierarchical 346 | 347 | model = compile_model(create(model_class, train_size, client_weight), client_weight) 348 | if dict_conf['method'] == 'fedavg': 349 | return tff.learning.from_compiled_keras_model(model, sample_batch) 350 | return model 351 | 352 | return model_fn 353 | 354 | 355 | def run_simulation(model_fn, federated_train_data, federated_test_data, 356 | train_size, test_size, cfgs, logdir): 357 | if cfgs['method'] == 'virtual': 358 | virtual_process = VirtualFedProcess(model_fn, cfgs['num_clients'], 359 | damping_factor=cfgs['damping_factor'], 360 | fed_avg_init=cfgs['fed_avg_init']) 361 | virtual_process.fit(federated_train_data, cfgs['num_rounds'], 362 | cfgs['clients_per_round'], 363 | cfgs['epochs_per_round'], 364 | train_size=train_size, test_size=test_size, 365 | federated_test_data=federated_test_data, 366 | tensorboard_updates=cfgs['tensorboard_updates'], 367 | logdir=logdir, hierarchical=cfgs['hierarchical'], 368 | verbose=cfgs['verbose'], 369 | server_learning_rate=cfgs['server_learning_rate'], 370 | MTL=True) 371 | tf.keras.backend.clear_session() 372 | del virtual_process 373 | gc.collect() 374 | elif cfgs['method'] == 'fedavg': 375 | train_log_dir = logdir / 'train' 376 | train_summary_writer 
= tf.summary.create_file_writer(str(train_log_dir)) 377 | test_summary_writer = tf.summary.create_file_writer(str(logdir)) 378 | 379 | tff.framework.set_default_executor(tff.framework.create_local_executor()) 380 | iterative_process = tff.learning.build_federated_averaging_process(model_fn) 381 | evaluation = tff.learning.build_federated_evaluation(model_fn) 382 | state = iterative_process.initialize() 383 | 384 | for round_num in range(cfgs['num_rounds']): 385 | state, metrics = iterative_process.next( 386 | state, 387 | [federated_train_data[indx] 388 | for indx in random.sample(range(cfgs['num_clients']), 389 | cfgs['clients_per_round'])]) 390 | test_metrics = evaluation(state.model, federated_test_data) 391 | logger.info(f'round {round_num:2d}, ' 392 | f'metrics_train={metrics}, ' 393 | f'metrics_test={test_metrics}') 394 | if round_num % cfgs['tensorboard_updates'] == 0: 395 | with train_summary_writer.as_default(): 396 | for name, value in metrics._asdict().items(): 397 | tf.summary.scalar(name, value, step=round_num) 398 | with test_summary_writer.as_default(): 399 | for name, value in test_metrics._asdict().items(): 400 | tf.summary.scalar(name, value, step=round_num) 401 | 402 | elif cfgs['method'] == 'fedprox': 403 | fed_prox_process = FedProx(model_fn, cfgs['num_clients']) 404 | fed_prox_process.fit(federated_train_data, cfgs['num_rounds'], 405 | cfgs['clients_per_round'], 406 | cfgs['epochs_per_round'], 407 | train_size=train_size, test_size=test_size, 408 | federated_test_data=federated_test_data, 409 | tensorboard_updates=cfgs['tensorboard_updates'], 410 | logdir=logdir, 411 | verbose=cfgs['verbose'], 412 | server_learning_rate=cfgs['server_learning_rate'], 413 | MTL=False) 414 | tf.keras.backend.clear_session() 415 | del fed_prox_process 416 | gc.collect() 417 | -------------------------------------------------------------------------------- /source/fed_process.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from source.tfp_utils import loc_prod_from_locprec 3 | eps = 1/tf.float32.max 4 | import random 5 | import logging 6 | from pathlib import Path 7 | import tensorflow as tf 8 | import numpy as np 9 | from source.utils import avg_dict, avg_dict_eval 10 | from source.constants import ROOT_LOGGER_STR 11 | from operator import itemgetter 12 | from source.utils import CustomTensorboard 13 | 14 | logger = logging.getLogger(ROOT_LOGGER_STR + '.' 
+ __name__) 15 | 16 | 17 | class FedProcess: 18 | 19 | def __init__(self, model_fn, num_clients): 20 | self.model_fn = model_fn 21 | self.num_clients = num_clients 22 | self.clients_indx = range(self.num_clients) 23 | self.clients = [] 24 | self.server = None 25 | 26 | self.train_summary_writer = None 27 | self.test_summary_writer = None 28 | self.valid_summary_writer = None 29 | 30 | def build(self, *args, **kwargs): 31 | pass 32 | 33 | def aggregate_deltas_multi_layer(self, deltas, client_weight=None): 34 | aggregated_deltas = [] 35 | deltas = list(map(list, zip(*deltas))) 36 | for delta_layer in deltas: 37 | aggregated_deltas.append( 38 | self.aggregate_deltas_single_layer(delta_layer, client_weight)) 39 | return aggregated_deltas 40 | 41 | def aggregate_deltas_single_layer(self, deltas, client_weight=None): 42 | for i, delta_client in enumerate(deltas): 43 | for key, el in delta_client.items(): 44 | if isinstance(el, tuple): 45 | (loc, prec) = el 46 | if client_weight: 47 | prec = prec*client_weight[i]*self.num_clients 48 | loc = tf.math.multiply(loc, prec) 49 | delta_client[key] = (loc, prec) 50 | else: 51 | if client_weight: 52 | delta_client[key] = (el*client_weight[i], ) 53 | else: 54 | delta_client[key] = (el/self.num_clients, ) 55 | 56 | deltas = {key: [dic[key] for dic in deltas] for key in deltas[0]} 57 | for key, lst in deltas.items(): 58 | lst = zip(*lst) 59 | sum_el = [] 60 | for i, el in enumerate(lst): 61 | add = tf.math.add_n(el) 62 | sum_el.append(add) 63 | 64 | if len(sum_el) == 2: 65 | loc = loc_prod_from_locprec(*sum_el) 66 | deltas[key] = (loc, sum_el[1]) 67 | else: 68 | deltas[key] = sum_el[0] 69 | return deltas 70 | 71 | def fit(self, 72 | federated_train_data, 73 | num_rounds, 74 | clients_per_round, 75 | epochs_per_round, 76 | federated_test_data=None, 77 | tensorboard_updates=1, 78 | logdir=Path(), 79 | callbacks=None, 80 | train_size=None, 81 | test_size=None, 82 | hierarchical=False, 83 | server_learning_rate=1., 84 | verbose=0, 85 | MTL=False): 86 | 87 | print('fed_process ' + str(logdir)) 88 | self.summary_writer = tf.summary.create_file_writer(str(logdir)) 89 | if MTL: 90 | self.build(train_size, hierarchical) 91 | deltas = [client.compute_delta() for client in self.clients] 92 | aggregated_deltas = self.aggregate_deltas_multi_layer( 93 | deltas, [size / sum(train_size) for size in train_size]) 94 | self.server.apply_delta(aggregated_deltas) 95 | else: 96 | self.build() 97 | 98 | history_test = [None] * len(self.clients) 99 | max_train_accuracy = -1.0 100 | max_train_acc_round = None 101 | max_server_accuracy = -1.0 102 | max_server_acc_round = None 103 | max_client_all_accuracy = -1.0 104 | max_client_all_round = None 105 | max_client_selected_accuracy = -1.0 106 | max_client_selected_acc_round = None 107 | server_test_accs = np.zeros(num_rounds) 108 | all_client_test_accs = np.zeros(num_rounds) 109 | selected_client_test_accs = np.zeros(num_rounds) 110 | training_accs = np.zeros(num_rounds) 111 | server_test_losses = np.zeros(num_rounds) 112 | all_client_test_losses = np.zeros(num_rounds) 113 | selected_client_test_losses = np.zeros(num_rounds) 114 | training_losses = np.zeros(num_rounds) 115 | overall_tensorboard = CustomTensorboard(log_dir=str(logdir)+'/selected_client', 116 | histogram_freq=max(0, verbose - 2), 117 | profile_batch=max(0, verbose - 2)) 118 | if verbose >= 2: 119 | if callbacks: 120 | callbacks.append(overall_tensorboard) 121 | else: 122 | callbacks = [overall_tensorboard] 123 | 124 | for round_i in range(num_rounds): 125 | 
clients_sampled = random.sample(self.clients_indx, 126 | clients_per_round) 127 | deltas = [] 128 | history_train = [] 129 | for indx in clients_sampled: 130 | self.clients[indx].receive_and_save_weights(self.server) 131 | self.clients[indx].renew_center(round_i > 0) 132 | 133 | if MTL: 134 | if self.fed_avg_init == 2 or ( 135 | self.fed_avg_init 136 | and round_i > 0): 137 | print('initialize posterior with server') 138 | self.clients[indx].initialize_kernel_posterior() 139 | 140 | history_single = self.clients[indx].fit( 141 | federated_train_data[indx], 142 | verbose=0, 143 | validation_data=federated_test_data[indx], 144 | epochs=epochs_per_round, 145 | callbacks=callbacks 146 | ) 147 | 148 | if MTL: 149 | self.clients[indx].apply_damping(self.damping_factor) 150 | 151 | delta = self.clients[indx].compute_delta() 152 | deltas.append(delta) 153 | 154 | if verbose >= 1: 155 | with self.summary_writer.as_default(): 156 | for layer in self.clients[indx].layers: 157 | layer_to_check = layer 158 | if hasattr(layer, 'cell'): 159 | layer_to_check = layer.cell 160 | for weight in layer_to_check.trainable_weights: 161 | if 'natural' in weight.name + layer.name: 162 | tf.summary.histogram(layer.name + '/' + weight.name + '_gamma', 163 | weight[..., 0], step=round_i) 164 | tf.summary.histogram(layer.name + '/' + weight.name + '_prec', 165 | weight[..., 1], step=round_i) 166 | else: 167 | tf.summary.histogram(layer.name + '/' + weight.name, weight, step=round_i) 168 | if hasattr(layer_to_check, 'kernel_posterior'): 169 | tf.summary.histogram( 170 | layer.name + '/kernel_posterior' + '_gamma_reparametrized', 171 | layer_to_check.kernel_posterior.distribution.gamma, 172 | step=round_i) 173 | tf.summary.histogram( 174 | layer.name + '/kernel_posterior' + '_prec_reparametrized', 175 | layer_to_check.kernel_posterior.distribution.prec, 176 | step=round_i) 177 | if hasattr(layer_to_check, 'recurrent_kernel_posterior'): 178 | tf.summary.histogram( 179 | layer.name + '/recurrent_kernel_posterior' + '_gamma_reparametrized', 180 | layer_to_check.recurrent_kernel_posterior.distribution.gamma, 181 | step=round_i) 182 | tf.summary.histogram( 183 | layer.name + '/recurrent_kernel_posterior' + '_prec_reparametrized', 184 | layer_to_check.recurrent_kernel_posterior.distribution.prec, 185 | step=round_i) 186 | for layer in self.server.layers: 187 | layer_to_check = layer 188 | if hasattr(layer, 'cell'): 189 | layer_to_check = layer.cell 190 | if hasattr(layer_to_check, 'server_variable_dict'): 191 | for key, value in layer_to_check.server_variable_dict.items(): 192 | if 'natural' in layer_to_check.name + value.name: 193 | tf.summary.histogram( 194 | layer.name + '/server_gamma', 195 | value[..., 0], step=round_i) 196 | tf.summary.histogram( 197 | layer.name + '/server_prec', 198 | value[..., 1], step=round_i) 199 | else: 200 | tf.summary.histogram(layer.name, value, step=round_i) 201 | 202 | history_train.append({key: history_single.history[key] 203 | for key in history_single.history.keys() 204 | if 'val' not in key}) 205 | history_test[indx] = \ 206 | {key.replace('val_', ''): history_single.history[key] 207 | for key in history_single.history.keys() 208 | if 'val' in key} 209 | 210 | train_size_sampled = itemgetter(*clients_sampled)(train_size) 211 | if clients_per_round == 1: 212 | train_size_sampled = [train_size_sampled] 213 | 214 | if MTL: 215 | client_weights = [server_learning_rate * train_size[client] / sum(train_size) 216 | for client in clients_sampled] 217 | else: 218 | client_weights = 
[server_learning_rate * train_size[client] / sum(train_size_sampled)
219 | for client in
220 | clients_sampled]
221 |
222 | aggregated_deltas = self.aggregate_deltas_multi_layer(deltas, client_weights)
223 | self.server.apply_delta(aggregated_deltas)
224 |
225 | server_test = [self.server.evaluate(test_data, verbose=0)
226 | for test_data in federated_test_data]
227 |
228 | all_client_test = [self.clients[indx].evaluate(test_data, verbose=0)
229 | for indx, test_data in enumerate(federated_test_data)]
230 | all_client_avg_test = avg_dict_eval(
231 | all_client_test, [size / sum(test_size) for size in test_size])
232 | all_client_test_accs[round_i] = all_client_avg_test[1]
233 | all_client_test_losses[round_i] = all_client_avg_test[0]
234 |
235 | avg_train = avg_dict(history_train,
236 | [train_size[client]
237 | for client in clients_sampled])
238 | selected_client_test = avg_dict(history_test, test_size)
239 | server_avg_test = avg_dict_eval(
240 | server_test, [size / sum(test_size) for size in test_size])
241 |
242 | if server_avg_test[1] > max_server_accuracy:
243 | max_server_accuracy = server_avg_test[1]
244 | max_server_acc_round = round_i
245 | if avg_train['sparse_categorical_accuracy'] > max_train_accuracy:
246 | max_train_accuracy = avg_train['sparse_categorical_accuracy']
247 | max_train_acc_round = round_i
248 | if selected_client_test['sparse_categorical_accuracy'] > max_client_selected_accuracy:
249 | max_client_selected_accuracy = selected_client_test['sparse_categorical_accuracy']
250 | max_client_selected_acc_round = round_i
251 | if all_client_avg_test[1] > max_client_all_accuracy:
252 | max_client_all_accuracy = all_client_avg_test[1]
253 | max_client_all_round = round_i
254 |
255 | server_test_accs[round_i] = server_avg_test[1]
256 | training_accs[round_i] = avg_train['sparse_categorical_accuracy']
257 | selected_client_test_accs[round_i] = selected_client_test['sparse_categorical_accuracy']
258 |
259 | server_test_losses[round_i] = server_avg_test[0]
260 | selected_client_test_losses[round_i] = selected_client_test['loss']
261 | training_losses[round_i] = avg_train['loss']
262 |
263 | debug_string = (f"round: {round_i}, "
264 | f"avg_train: {avg_train}, "
265 | f"selected_client_test: {selected_client_test}, "
266 | f"server_avg_test on whole test data: {server_avg_test} "
267 | f"server max accuracy so far: {max_server_accuracy} reached at "
268 | f"round {max_server_acc_round} "
269 | f"all clients max accuracy so far: {max_client_all_accuracy} reached at "
270 | f"round {max_client_all_round} "
271 | f"all clients avg test: {all_client_avg_test}")
272 | logger.debug(debug_string)
273 |
274 | if round_i % tensorboard_updates == 0:
275 | for i, key in enumerate(avg_train.keys()):
276 | with self.summary_writer.as_default():
277 | tf.summary.scalar('train/' + key, avg_train[key], step=round_i)
278 | tf.summary.scalar('server/' + key, server_avg_test[i], step=round_i)
279 | tf.summary.scalar('client_selected/' + key, selected_client_test[key], step=round_i)
280 | tf.summary.scalar('client_all/' + key, all_client_avg_test[i], step=round_i)
281 | if key == 'sparse_categorical_accuracy':
282 | tf.summary.scalar('train/max_' + key, max_train_accuracy, step=round_i)
283 | tf.summary.scalar('server/max_' + key, max_server_accuracy, step=round_i)
284 | tf.summary.scalar('client_selected/max_' + key, max_client_selected_accuracy, step=round_i)
285 | tf.summary.scalar('client_all/max_' + key, max_client_all_accuracy, step=round_i)
286 |
287 |
288 | # Do this at every round to
make sure to keep the data even if 289 | # the training is interrupted 290 | np.save(Path(logdir).parent / 'server_accs.npy', server_test_accs) 291 | np.save(Path(logdir).parent / 'training_accs.npy', training_accs) 292 | np.save(Path(logdir).parent / 'selected_client_accs.npy', selected_client_test_accs) 293 | np.save(Path(logdir).parent / 'server_losses.npy', server_test_losses) 294 | np.save(Path(logdir).parent / 'training_losses.npy', training_losses) 295 | np.save(Path(logdir).parent / 'selected_client_losses.npy', selected_client_test_losses) 296 | np.save(Path(logdir).parent / 'all_client_accs.npy', all_client_test_accs) 297 | np.save(Path(logdir).parent / 'all_client_losses.npy', all_client_test_losses) 298 | 299 | for i, client in enumerate(self.clients): 300 | client.save_weights(str(Path(logdir) / f'weights_{i}.h5')) 301 | 302 | -------------------------------------------------------------------------------- /source/fed_prox.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from source.federated_devices import ClientSequential, ServerSequential 3 | from source.fed_process import FedProcess 4 | from source.constants import ROOT_LOGGER_STR 5 | logger = logging.getLogger(ROOT_LOGGER_STR + '.' + __name__) 6 | 7 | 8 | class FedProx(FedProcess): 9 | 10 | def __init__(self, model_fn, num_clients): 11 | super(FedProx, self).__init__(model_fn, num_clients) 12 | self.clients = None 13 | 14 | def build(self, *args, **kwargs): 15 | self.clients = [self.model_fn(ClientSequential, 1) 16 | for _ in range(self.num_clients)] 17 | self.server = self.model_fn(ServerSequential, 1) 18 | -------------------------------------------------------------------------------- /source/federated_devices.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class _Client: 5 | 6 | def compute_delta(self): 7 | delta = [] 8 | for layer in self.layers: 9 | if hasattr(layer, 'compute_delta'): 10 | delta.append(layer.compute_delta()) 11 | return delta 12 | 13 | def renew_center(self, center_to_update=True): 14 | for layer in self.layers: 15 | if hasattr(layer, 'renew_center'): 16 | layer.renew_center(center_to_update) 17 | 18 | def receive_and_save_weights(self, server): 19 | for l_c, l_s in zip(self.layers, server.layers): 20 | if hasattr(l_c, 'receive_and_save_weights'): 21 | l_c.receive_and_save_weights(l_s) 22 | 23 | 24 | class _Server: 25 | 26 | def apply_delta(self, delta): 27 | for i, layer in enumerate(x for x in self.layers if hasattr(x, 'apply_delta')): 28 | if hasattr(layer, 'apply_delta'): 29 | layer.apply_delta(delta[i]) 30 | 31 | 32 | class _ClientVirtual(_Client): 33 | 34 | def apply_damping(self, damping_factor): 35 | for layer in self.layers: 36 | if hasattr(layer, 'apply_damping'): 37 | layer.apply_damping(damping_factor) 38 | 39 | def initialize_kernel_posterior(self): 40 | for layer in self.layers: 41 | if hasattr(layer, 'initialize_kernel_posterior'): 42 | layer.initialize_kernel_posterior() 43 | 44 | def call(self, inputs, training=None, mask=None): 45 | if self.num_samples > 1: 46 | sampling = MultiSampleEstimator(self, self.num_samples) 47 | else: 48 | sampling = super(_ClientVirtual, self).call 49 | output = sampling(inputs, training, mask) 50 | return output 51 | 52 | 53 | class ClientSequential(tf.keras.Sequential, _Client): 54 | 55 | def __init__(self, layers=None, name=None, num_samples=1): 56 | super(ClientSequential, self).__init__(layers=layers, name=name) 57 
| self.num_samples = num_samples 58 | 59 | 60 | class ClientModel(tf.keras.Model, _Client): 61 | 62 | def __init__(self, *args, **kwargs): 63 | self.num_samples = kwargs.pop('num_samples', 1) 64 | super(ClientModel, self).__init__(*args, **kwargs) 65 | 66 | 67 | class ServerSequential(tf.keras.Sequential, _Server): 68 | 69 | def __init__(self, layers=None, name=None, num_samples=1): 70 | super(ServerSequential, self).__init__(layers=layers, name=name) 71 | self.num_samples = num_samples 72 | 73 | 74 | class ServerModel(tf.keras.Model, _Server): 75 | 76 | def __init__(self, *args, **kwargs): 77 | self.num_samples = kwargs.pop('num_samples', 1) 78 | super(ServerModel, self).__init__(*args, **kwargs) 79 | 80 | 81 | class ClientVirtualSequential(ClientSequential, _ClientVirtual): 82 | pass 83 | 84 | 85 | class ClientVirtualModel(ClientModel, _ClientVirtual): 86 | pass 87 | 88 | 89 | class MultiSampleEstimator(tf.keras.layers.Layer): 90 | 91 | def __init__(self, model, num_samples): 92 | super(MultiSampleEstimator, self).__init__() 93 | self.model = model 94 | self.num_samples = num_samples 95 | 96 | def call(self, inputs, training=None, mask=None): 97 | output = [] 98 | for _ in range(self.num_samples): 99 | output.append(super(_ClientVirtual, self.model).call(inputs, training, mask)) 100 | output = tf.stack(output) 101 | output = tf.math.reduce_mean(output, axis=0) 102 | return output 103 | -------------------------------------------------------------------------------- /source/gate_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.keras import initializers 3 | from tensorflow.python.framework import tensor_shape 4 | 5 | 6 | class Gate(tf.keras.layers.Layer): 7 | 8 | def __init__(self, 9 | initializer=tf.keras.initializers.RandomUniform( 10 | minval=0, maxval=0.1), 11 | **kwargs): 12 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 13 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 14 | 15 | super(Gate, self).__init__(**kwargs) 16 | self.initializer = initializers.get(initializer) 17 | 18 | def build(self, input_shape): 19 | input_shape = tensor_shape.TensorShape(input_shape) 20 | self.gate = self.add_weight( 21 | 'gate', 22 | shape=input_shape[1:], 23 | initializer=self.initializer, 24 | dtype=self.dtype, 25 | trainable=True, 26 | constraint=tf.keras.constraints.NonNeg()) 27 | self.built = True 28 | 29 | def call(self, inputs): 30 | outputs = tf.math.multiply(inputs, self.gate) 31 | return outputs 32 | 33 | def compute_output_shape(self, input_shape): 34 | return input_shape 35 | 36 | def get_config(self): 37 | config = { 38 | 'initializer': initializers.serialize(self.initializer), 39 | } 40 | base_config = super(Gate, self).get_config() 41 | return dict(list(base_config.items()) + list(config.items())) 42 | 43 | -------------------------------------------------------------------------------- /source/learning_rate_multipliers_opt.py: -------------------------------------------------------------------------------- 1 | from keras.legacy import interfaces 2 | import tensorflow.keras.backend as K 3 | from tensorflow.keras.optimizers import Optimizer 4 | import tensorflow as tf 5 | 6 | 7 | class LR_SGD(Optimizer): 8 | """Stochastic gradient descent optimizer. 9 | 10 | Includes support for momentum, 11 | learning rate decay, and Nesterov momentum. 12 | 13 | # Arguments 14 | lr: float >= 0. Learning rate. 15 | momentum: float >= 0. Parameter updates momentum. 16 | decay: float >= 0. 
Learning rate decay over each update. 17 | nesterov: boolean. Whether to apply Nesterov momentum. 18 | """ 19 | 20 | def __init__(self, lr=0.01, momentum=0., decay=0., 21 | nesterov=False, multipliers=None, **kwargs): 22 | super(LR_SGD, self).__init__(**kwargs, name=self.__class__.__name__) 23 | with K.name_scope(self.__class__.__name__): 24 | self.iterations = tf.Variable(0, dtype='int64', name='iterations') 25 | self.learning_rate = tf.Variable(lr, name='learning_rate', dtype=tf.float32) 26 | self.momentum = tf.Variable(momentum, name='momentum') 27 | self.decay = tf.Variable(decay, name='decay') 28 | 29 | self.initial_decay = decay 30 | self.nesterov = nesterov 31 | self.lr_multipliers = multipliers 32 | 33 | @interfaces.legacy_get_updates_support 34 | def get_updates(self, loss, params): 35 | grads = self.get_gradients(loss, params) 36 | self.updates = [K.update_add(self.iterations, 1)] 37 | 38 | lr = self.learning_rate 39 | if self.initial_decay > 0: 40 | lr *= (1. / (1. + self.decay * K.cast(self.iterations, 41 | K.dtype(self.decay)))) 42 | # momentum 43 | shapes = [K.int_shape(p) for p in params] 44 | moments = [K.zeros(shape) for shape in shapes] 45 | self.weights = [self.iterations] + moments 46 | for p, g, m in zip(params, grads, moments): 47 | 48 | matched_layer = [x for x in self.lr_multipliers.keys() if x in p.name] 49 | if matched_layer: 50 | new_lr = lr * self.lr_multipliers[matched_layer[0]] 51 | else: 52 | new_lr = lr 53 | 54 | v = self.momentum * m - new_lr * g # velocity 55 | self.updates.append(K.update(m, v)) 56 | 57 | if self.nesterov: 58 | new_p = p + self.momentum * v - new_lr * g 59 | else: 60 | new_p = p + v 61 | 62 | # Apply constraints. 63 | if getattr(p, 'constraint', None) is not None: 64 | new_p = p.constraint(new_p) 65 | 66 | self.updates.append(K.update(p, new_p)) 67 | return self.updates 68 | 69 | def get_config(self): 70 | config = {'lr': float(K.get_value(self.learning_rate)), 71 | 'momentum': float(K.get_value(self.momentum)), 72 | 'decay': float(K.get_value(self.decay)), 73 | 'nesterov': self.nesterov} 74 | base_config = super(LR_SGD, self).get_config() 75 | return dict(list(base_config.items()) + list(config.items())) 76 | 77 | def weights(self, value): 78 | self._weights = value 79 | 80 | 81 | #TODO: also adam with lr multipliers 82 | -------------------------------------------------------------------------------- /source/natural_raparametrization_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | from tensorflow_probability.python import distributions as tfd 4 | from tensorflow_probability.python.layers import util as tfp_layers_util 5 | from tensorflow.python.layers import utils as tf_layers_util 6 | from source.centered_layers import LayerCentered 7 | from source.tfp_utils import precision_from_untransformed_scale, sparse_delta_function 8 | from source.normal_natural import NormalNatural, eps 9 | from tensorflow.python.keras.constraints import Constraint 10 | from tensorflow.python.keras import backend as K 11 | from tensorflow.python.ops import math_ops 12 | from tensorflow.python.ops import nn_ops 13 | from tensorflow.python.eager import context 14 | from tensorflow.python.framework import ops 15 | from tensorflow.python.keras.utils import tf_utils 16 | from tensorflow.python.keras.layers.recurrent import _caching_device 17 | 18 | 19 | class NonNegPrec(Constraint): 20 | 21 | def __call__(self, w): 22 | prec = w[..., -1] 23 | prec = 
prec * math_ops.cast( 24 | math_ops.greater_equal(prec, eps), K.floatx()) 25 | return tf.stack([w[..., 0], prec], axis=-1) 26 | 27 | 28 | class NaturalRegularizer(tf.keras.regularizers.Regularizer): 29 | 30 | def __init__(self, regularizer=None): 31 | self.regularizer = regularizer 32 | 33 | def __call__(self, w): 34 | return self.regularizer.call(w[..., 0]) 35 | 36 | 37 | class NaturalConstraint(tf.keras.constraints.Constraint): 38 | 39 | def __init__(self, constraint): 40 | self.constraint = constraint 41 | 42 | def __call__(self, w): 43 | gamma = self.constraint(w[..., 0]) 44 | return tf.stack([gamma, w[..., 1]], axis=-1) 45 | 46 | 47 | def tensor_natural_par_fn(is_singular=False, 48 | natural_initializer=tf.constant_initializer(0.), 49 | natural_regularizer=None, natural_constraint=None, 50 | **kwargs): 51 | def _fn(dtype, shape, name, trainable, add_variable_fn): 52 | """Creates 'natural' parameters.""" 53 | natural = add_variable_fn( 54 | name=name + '_natural', 55 | shape=list(shape) + [2], 56 | initializer=natural_initializer, 57 | regularizer=natural_regularizer, 58 | constraint=natural_constraint, 59 | dtype=dtype, 60 | trainable=trainable, 61 | **kwargs) 62 | return natural 63 | 64 | return _fn 65 | 66 | 67 | class VariationalReparametrizedNatural(LayerCentered): 68 | 69 | def build_posterior_fn_natural(self, shape, dtype, name, posterior_fn, 70 | prior_fn): 71 | natural_par_shape = list(shape) + [2] 72 | server_par = self.add_variable(name=name+'_server_par', 73 | shape=natural_par_shape, 74 | dtype=dtype, trainable=False, 75 | initializer=tf.keras.initializers.zeros) 76 | client_par = self.add_variable(name=name+'_client_par', 77 | shape=natural_par_shape, 78 | dtype=dtype, trainable=False, 79 | initializer=tf.keras.initializers.zeros) 80 | 81 | ratio_par = tfp.util.DeferredTensor( 82 | server_par, lambda x: x - self.client_weight * client_par) 83 | 84 | posterior_fn = posterior_fn(ratio_par) 85 | prior_fn = prior_fn(ratio_par) 86 | 87 | self.server_variable_dict[name] = server_par 88 | self.client_center_variable_dict[name] = client_par 89 | return posterior_fn, prior_fn 90 | 91 | def initialize_kernel_posterior(self): 92 | for key in self.client_variable_dict.keys(): 93 | self.client_variable_dict[key].assign( 94 | self.server_variable_dict[key]) 95 | 96 | def apply_damping(self, damping_factor): 97 | for key in self.server_variable_dict.keys(): 98 | damped = self.apply_delta_function( 99 | self.client_variable_dict[key] * damping_factor, 100 | self.client_center_variable_dict[key] * (1 - damping_factor)) 101 | self.client_variable_dict[key].assign(damped) 102 | 103 | def renormalize_natural_mean_field_normal_fn(self, ratio_par): 104 | 105 | def _fn(dtype, shape, name, trainable, add_variable_fn, 106 | natural_initializer=None, 107 | natural_regularizer=None, natural_constraint=NonNegPrec(), 108 | **kwargs): 109 | natural_par_fn = tensor_natural_par_fn( 110 | natural_initializer=natural_initializer, 111 | natural_regularizer=natural_regularizer, 112 | natural_constraint=natural_constraint, 113 | **kwargs) 114 | natural = natural_par_fn( 115 | dtype, shape, name, trainable, add_variable_fn) 116 | self.client_variable_dict['_'.join(name.split('_')[0:-1])] = natural 117 | natural_reparametrized = tfp.util.DeferredTensor( 118 | natural, lambda x: x * self.client_weight + ratio_par) 119 | gamma = tfp.util.DeferredTensor( 120 | natural_reparametrized, lambda x: x[..., 0], shape=shape) 121 | prec = tfp.util.DeferredTensor( 122 | natural_reparametrized, lambda x: x[..., 1], 
shape=shape)
123 |
124 |             dist = NormalNatural(gamma=gamma, prec=prec)
125 |             batch_ndims = tf.size(dist.batch_shape_tensor())
126 |             return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
127 |
128 |         return _fn
129 |
130 |     def natural_tensor_multivariate_normal_fn(self, ratio_par):
131 |         def _fn(dtype, shape, name, trainable, add_variable_fn,
132 |                 initializer=natural_prior_initializer_fn(),
133 |                 regularizer=None, constraint=None, **kwargs):
134 |             del trainable
135 |             natural_par_fn = tensor_natural_par_fn(
136 |                 natural_initializer=initializer,
137 |                 natural_regularizer=regularizer,
138 |                 natural_constraint=constraint,
139 |                 **kwargs)
140 |             natural = natural_par_fn(dtype, shape, name, False, add_variable_fn)
141 |             natural_reparametrized = tfp.util.DeferredTensor(
142 |                 natural, lambda x: x * self.client_weight + ratio_par)
143 |             gamma = tfp.util.DeferredTensor(
144 |                 natural_reparametrized, lambda x: x[..., 0], shape=shape)
145 |             prec = tfp.util.DeferredTensor(
146 |                 natural_reparametrized, lambda x: x[..., 1], shape=shape)
147 |
148 |             dist = NormalNatural(gamma=gamma, prec=prec)
149 |             batch_ndims = tf.size(input=dist.batch_shape_tensor())
150 |             return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
151 |
152 |         return _fn
153 |
154 |
155 | class DenseSharedNatural(VariationalReparametrizedNatural):
156 |
157 |     def __init__(
158 |             self, units,
159 |             activation=None,
160 |             activity_regularizer=None,
161 |             client_weight=1.,
162 |             trainable=True,
163 |             kernel_posterior_fn=None,
164 |             kernel_posterior_tensor_fn=(lambda d: d.sample()),
165 |             kernel_prior_fn=None,
166 |             kernel_divergence_fn=(
167 |                 lambda q, p, ignore: tfd.kl_divergence(q, p)),
168 |             bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
169 |                 is_singular=True),
170 |             bias_posterior_tensor_fn=(lambda d: d.sample()),
171 |             bias_prior_fn=None,
172 |             bias_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
173 |             **kwargs):
174 |
175 |         self.untransformed_scale_initializer = None
176 |         if 'untransformed_scale_initializer' in kwargs:
177 |             self.untransformed_scale_initializer = \
178 |                 kwargs.pop('untransformed_scale_initializer')
179 |         self.loc_initializer = None
180 |         if 'loc_initializer' in kwargs:
181 |             self.loc_initializer = \
182 |                 kwargs.pop('loc_initializer')
183 |
184 |         self.delta_percentile = kwargs.pop('delta_percentile', None)
185 |
186 |         if kernel_posterior_fn is None:
187 |             kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
188 |         if kernel_prior_fn is None:
189 |             kernel_prior_fn = self.natural_tensor_multivariate_normal_fn
190 |
191 |         super(DenseSharedNatural, self).\
192 |             __init__(units,
193 |                      activation=activation,
194 |                      activity_regularizer=activity_regularizer,
195 |                      trainable=trainable,
196 |                      kernel_posterior_fn=kernel_posterior_fn,
197 |                      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
198 |                      kernel_prior_fn=kernel_prior_fn,
199 |                      kernel_divergence_fn=kernel_divergence_fn,
200 |                      bias_posterior_fn=bias_posterior_fn,
201 |                      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
202 |                      bias_prior_fn=bias_prior_fn,
203 |                      bias_divergence_fn=bias_divergence_fn,
204 |                      **kwargs)
205 |
206 |         self.client_weight = client_weight
207 |         self.delta_function = tf.subtract
208 |         if self.delta_percentile and activation != 'softmax':
209 |             self.delta_function = sparse_delta_function(self.delta_percentile)
210 |             print(self, activation, 'using delta sparsification')
211 |         self.apply_delta_function = tf.add
212 |         self.client_variable_dict = {}
213 |
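        # client_variable_dict above holds the layer's trainable natural
        # parameters; the two dicts below are filled by
        # build_posterior_fn_natural and build with the frozen client center
        # (used when computing deltas) and the last received server copy.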
self.client_center_variable_dict = {} 214 | self.server_variable_dict = {} 215 | 216 | def build(self, input_shape): 217 | input_shape = tf.TensorShape(input_shape) 218 | in_size = tf.compat.dimension_value( 219 | input_shape.with_rank_at_least(2)[-1]) 220 | if in_size is None: 221 | raise ValueError('The last dimension of the inputs to `Dense` ' 222 | 'should be defined. Found `None`.') 223 | self._input_spec = tf.keras.layers.InputSpec( 224 | min_ndim=2, axes={-1: in_size}) 225 | 226 | # If self.dtype is None, build weights using the default dtype. 227 | dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) 228 | shape = [in_size, self.units] 229 | name = 'kernel' 230 | self.kernel_posterior_fn, self.kernel_prior_fn = \ 231 | self.build_posterior_fn_natural(shape, dtype, name, 232 | self.kernel_posterior_fn, 233 | self.kernel_prior_fn) 234 | natural_initializer = natural_initializer_fn( 235 | loc_stdev=0.1, u_scale_init_avg=-5, 236 | u_scale_init_stdev=0.1, 237 | untransformed_scale_initializer=self.untransformed_scale_initializer, 238 | loc_initializer=self.loc_initializer) 239 | 240 | self.kernel_posterior = self.kernel_posterior_fn( 241 | dtype, [in_size, self.units], 'kernel_posterior', 242 | self.trainable, self.add_variable, 243 | natural_initializer=natural_initializer) 244 | 245 | if self.kernel_prior_fn is None: 246 | self.kernel_prior = None 247 | else: 248 | self.kernel_prior = self.kernel_prior_fn( 249 | dtype, [in_size, self.units], 'kernel_prior', 250 | self.trainable, self.add_variable) 251 | 252 | if self.bias_posterior_fn is None: 253 | self.bias_posterior = None 254 | else: 255 | self.bias_posterior = self.bias_posterior_fn( 256 | dtype, [self.units], 'bias_posterior', 257 | self.trainable, self.add_variable) 258 | 259 | if self.bias_prior_fn is None: 260 | self.bias_prior = None 261 | else: 262 | self.bias_prior = self.bias_prior_fn( 263 | dtype, [self.units], 'bias_prior', 264 | self.trainable, self.add_variable) 265 | 266 | if self.bias_posterior: 267 | self.bias_center = self.add_weight( 268 | 'bias_center', 269 | shape=[self.units, ], 270 | initializer=tf.keras.initializers.constant(0.), 271 | dtype=self.dtype, 272 | trainable=False) 273 | self.client_variable_dict['bias'] = self.bias_posterior.distribution.loc 274 | self.server_variable_dict['bias'] = self.bias_posterior.distribution.loc 275 | self.client_center_variable_dict['bias'] = self.bias_center 276 | self.built = True 277 | 278 | 279 | class DenseReparametrizationNaturalShared( 280 | DenseSharedNatural, tfp.layers.DenseReparameterization): 281 | pass 282 | 283 | 284 | class DenseLocalReparametrizationNaturalShared( 285 | DenseSharedNatural, tfp.layers.DenseLocalReparameterization): 286 | def _apply_variational_kernel(self, inputs): 287 | self.kernel_posterior_affine = tfd.Normal( 288 | loc=tf.matmul(inputs, self.kernel_posterior.distribution.loc), 289 | scale=tf.sqrt(tf.matmul(tf.math.square(inputs), tf.math.square( 290 | self.kernel_posterior.distribution.scale)))) 291 | self.kernel_posterior_affine_tensor = ( 292 | self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) 293 | self.kernel_posterior_tensor = None 294 | return self.kernel_posterior_affine_tensor 295 | 296 | 297 | def natural_mean_field_normal_fn(natural_initializer=None): 298 | 299 | def _fn(dtype, shape, name, trainable, add_variable_fn, 300 | natural_initializer=natural_initializer, 301 | natural_regularizer=None, natural_constraint=NonNegPrec(), 302 | **kwargs): 303 | natural_par_fn = tensor_natural_par_fn( 304 | 
natural_initializer=natural_initializer, 305 | natural_regularizer=natural_regularizer, 306 | natural_constraint=natural_constraint, 307 | **kwargs) 308 | natural = natural_par_fn(dtype, shape, name, trainable, add_variable_fn) 309 | gamma = tfp.util.DeferredTensor( 310 | natural, lambda x: x[..., 0], shape=shape) 311 | prec = tfp.util.DeferredTensor( 312 | natural, lambda x: x[..., 1], shape=shape) 313 | 314 | dist = NormalNatural(gamma=gamma, prec=prec) 315 | batch_ndims = tf.size(dist.batch_shape_tensor()) 316 | return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims) 317 | 318 | return _fn 319 | 320 | 321 | def natural_tensor_multivariate_normal_fn(): 322 | def _fn(dtype, shape, name, trainable, add_variable_fn, 323 | initializer=natural_prior_initializer_fn(), 324 | regularizer=None, constraint=None, **kwargs): 325 | del trainable 326 | natural_par_fn = tensor_natural_par_fn(natural_initializer=initializer, 327 | natural_regularizer=regularizer, 328 | natural_constraint=constraint, 329 | **kwargs) 330 | natural = natural_par_fn(dtype, shape, name, False, add_variable_fn) 331 | gamma = tfp.util.DeferredTensor( 332 | natural, lambda x: x[..., 0], shape=shape) 333 | prec = tfp.util.DeferredTensor( 334 | natural, lambda x: x[..., 1], shape=shape) 335 | 336 | dist = NormalNatural(gamma=gamma, prec=prec) 337 | batch_ndims = tf.size(input=dist.batch_shape_tensor()) 338 | return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims) 339 | 340 | return _fn 341 | 342 | 343 | def natural_initializer_fn(loc_stdev=0.1, u_scale_init_avg=-5, 344 | u_scale_init_stdev=0.1, 345 | untransformed_scale_initializer=None, 346 | loc_initializer=None): 347 | if loc_initializer: 348 | loc_init = loc_initializer 349 | else: 350 | loc_init = tf.random_normal_initializer(stddev=loc_stdev) 351 | if untransformed_scale_initializer is None: 352 | untransformed_scale_initializer = tf.random_normal_initializer( 353 | mean=u_scale_init_avg, stddev=u_scale_init_stdev) 354 | 355 | def natural_initializer(shape, dtype=tf.float32): 356 | prec = precision_from_untransformed_scale( 357 | untransformed_scale_initializer(shape[:-1], dtype)) 358 | gamma = loc_init(shape[:-1], dtype) * prec 359 | natural = tf.stack([gamma, prec], axis=-1) 360 | tf.debugging.check_numerics(natural, 'initializer') 361 | return natural 362 | 363 | return natural_initializer 364 | 365 | 366 | def natural_prior_initializer_fn(): 367 | gamma_init = tf.constant_initializer(0.) 368 | precision_init = tf.constant_initializer(1.) 
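    # gamma = 0 with prec = 1 corresponds to a standard normal prior,
    # since loc = gamma / prec = 0 and scale = 1 / sqrt(prec) = 1.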
369 | 370 | def natural_initializer(shape, dtype): 371 | prec = precision_init(shape[:-1], dtype) 372 | gamma = gamma_init(shape[:-1], dtype) 373 | natural = tf.stack([gamma, prec], axis=-1) 374 | return natural 375 | 376 | return natural_initializer 377 | 378 | 379 | class Conv2DVirtualNatural(VariationalReparametrizedNatural, 380 | tfp.layers.Convolution2DReparameterization): 381 | 382 | def __init__( 383 | self, 384 | filters, 385 | kernel_size, 386 | strides=1, 387 | padding='valid', 388 | data_format='channels_last', 389 | dilation_rate=1, 390 | activation=None, 391 | client_weight=1., 392 | activity_regularizer=None, 393 | kernel_posterior_fn=None, 394 | kernel_posterior_tensor_fn=(lambda d: d.sample()), 395 | kernel_prior_fn=None, 396 | kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p), 397 | bias_posterior_fn= 398 | tfp_layers_util.default_mean_field_normal_fn(is_singular=True), 399 | bias_posterior_tensor_fn=lambda d: d.sample(), 400 | bias_prior_fn=None, 401 | bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p), 402 | **kwargs): 403 | 404 | self.untransformed_scale_initializer = None 405 | if 'untransformed_scale_initializer' in kwargs: 406 | self.untransformed_scale_initializer = \ 407 | kwargs.pop('untransformed_scale_initializer') 408 | 409 | self.loc_initializer = None 410 | if 'loc_initializer' in kwargs: 411 | self.loc_initializer = \ 412 | kwargs.pop('loc_initializer') 413 | 414 | if kernel_posterior_fn is None: 415 | kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn 416 | if kernel_prior_fn is None: 417 | kernel_prior_fn = self.natural_tensor_multivariate_normal_fn 418 | 419 | super(Conv2DVirtualNatural, self).__init__( 420 | filters=filters, 421 | kernel_size=kernel_size, 422 | strides=strides, 423 | padding=padding, 424 | data_format=data_format, 425 | dilation_rate=dilation_rate, 426 | activation=tf.keras.activations.get(activation), 427 | activity_regularizer=activity_regularizer, 428 | kernel_posterior_fn=kernel_posterior_fn, 429 | kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, 430 | kernel_prior_fn=kernel_prior_fn, 431 | kernel_divergence_fn=kernel_divergence_fn, 432 | bias_posterior_fn=bias_posterior_fn, 433 | bias_posterior_tensor_fn=bias_posterior_tensor_fn, 434 | bias_prior_fn=bias_prior_fn, 435 | bias_divergence_fn=bias_divergence_fn, 436 | **kwargs) 437 | 438 | self.client_weight= client_weight 439 | self.delta_function = tf.subtract 440 | self.apply_delta_function = tf.add 441 | self.client_variable_dict = {} 442 | self.client_center_variable_dict = {} 443 | self.server_variable_dict = {} 444 | 445 | def build(self, input_shape): 446 | input_shape = tf.TensorShape(input_shape) 447 | if self.data_format == 'channels_first': 448 | channel_axis = 1 449 | else: 450 | channel_axis = -1 451 | input_dim = tf.compat.dimension_value(input_shape[channel_axis]) 452 | if input_dim is None: 453 | raise ValueError('The channel dimension of the inputs ' 454 | 'should be defined. Found `None`.') 455 | kernel_shape = self.kernel_size + (input_dim, self.filters) 456 | 457 | # If self.dtype is None, build weights using the default dtype. 
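        # The natural-parameter variables created below carry a trailing
        # dimension of size 2, (gamma, prec), on top of kernel_shape;
        # see tensor_natural_par_fn.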
458 |         dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
459 |         name = 'kernel'
460 |
461 |         self.kernel_posterior_fn, self.kernel_prior_fn = \
462 |             self.build_posterior_fn_natural(kernel_shape, dtype, name,
463 |                                             self.kernel_posterior_fn,
464 |                                             self.kernel_prior_fn)
465 |
466 |         natural_initializer = natural_initializer_fn(
467 |             loc_stdev=0.1,
468 |             u_scale_init_avg=-5,
469 |             u_scale_init_stdev=0.1,
470 |             untransformed_scale_initializer=self.untransformed_scale_initializer)
471 |
472 |         self.kernel_posterior = self.kernel_posterior_fn(
473 |             dtype, kernel_shape, 'kernel_posterior',
474 |             self.trainable, self.add_variable,
475 |             natural_initializer=natural_initializer)
476 |
477 |         if self.kernel_prior_fn is None:
478 |             self.kernel_prior = None
479 |         else:
480 |             self.kernel_prior = self.kernel_prior_fn(
481 |                 dtype, kernel_shape, 'kernel_prior',
482 |                 self.trainable, self.add_variable)
483 |         self._built_kernel_divergence = False
484 |
485 |         if self.bias_posterior_fn is None:
486 |             self.bias_posterior = None
487 |         else:
488 |             self.bias_posterior = self.bias_posterior_fn(
489 |                 dtype, (self.filters,), 'bias_posterior',
490 |                 self.trainable, self.add_variable)
491 |
492 |         if self.bias_prior_fn is None:
493 |             self.bias_prior = None
494 |         else:
495 |             self.bias_prior = self.bias_prior_fn(
496 |                 dtype, (self.filters,), 'bias_prior',
497 |                 self.trainable, self.add_variable)
498 |         self._built_bias_divergence = False
499 |
500 |         self.input_spec = tf.keras.layers.InputSpec(
501 |             ndim=self.rank + 2, axes={channel_axis: input_dim})
502 |         self._convolution_op = nn_ops.Convolution(
503 |             input_shape,
504 |             filter_shape=tf.TensorShape(kernel_shape),
505 |             dilation_rate=self.dilation_rate,
506 |             strides=self.strides,
507 |             padding=self.padding.upper(),
508 |             data_format=tf_layers_util.convert_data_format(
509 |                 self.data_format, self.rank + 2))
510 |
511 |         if self.bias_posterior:
512 |             self.bias_center = self.add_weight(
513 |                 'bias_center',
514 |                 shape=[self.filters, ],  # one bias per output channel; conv layers define `filters`, not `units`
515 |                 initializer=tf.keras.initializers.constant(0.),
516 |                 dtype=self.dtype,
517 |                 trainable=False)
518 |             self.client_variable_dict['bias'] = self.bias_posterior.distribution.loc
519 |             self.server_variable_dict['bias'] = self.bias_posterior.distribution.loc
520 |             self.client_center_variable_dict['bias'] = self.bias_center
521 |
522 |         self.built = True
523 |
524 |
525 | class Conv1DVirtualNatural(tfp.layers.Convolution1DReparameterization,
526 |                            VariationalReparametrizedNatural):
527 |
528 |     def __init__(
529 |             self,
530 |             filters,
531 |             kernel_size,
532 |             strides=1,
533 |             padding='valid',
534 |             client_weight=1.,
535 |             data_format='channels_last',
536 |             dilation_rate=1,
537 |             activation=None,
538 |             activity_regularizer=None,
539 |             kernel_posterior_fn=None,
540 |             kernel_posterior_tensor_fn=(lambda d: d.sample()),
541 |             kernel_prior_fn=None,
542 |             kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
543 |             bias_posterior_fn=
544 |             tfp_layers_util.default_mean_field_normal_fn(is_singular=True),
545 |             bias_posterior_tensor_fn=lambda d: d.sample(),
546 |             bias_prior_fn=None,
547 |             bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
548 |             **kwargs):
549 |
550 |         self.untransformed_scale_initializer = None
551 |         if 'untransformed_scale_initializer' in kwargs:
552 |             self.untransformed_scale_initializer = \
553 |                 kwargs.pop('untransformed_scale_initializer')
554 |
555 |         if kernel_posterior_fn is None:
556 |             kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
557 |         if kernel_prior_fn is None:
558 |
kernel_prior_fn = self.natural_tensor_multivariate_normal_fn 559 | 560 | super(Conv1DVirtualNatural, self).__init__( 561 | filters=filters, 562 | kernel_size=kernel_size, 563 | strides=strides, 564 | padding=padding, 565 | data_format=data_format, 566 | dilation_rate=dilation_rate, 567 | activation=tf.keras.activations.get(activation), 568 | activity_regularizer=activity_regularizer, 569 | kernel_posterior_fn=kernel_posterior_fn, 570 | kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, 571 | kernel_prior_fn=kernel_prior_fn, 572 | kernel_divergence_fn=kernel_divergence_fn, 573 | bias_posterior_fn=bias_posterior_fn, 574 | bias_posterior_tensor_fn=bias_posterior_tensor_fn, 575 | bias_prior_fn=bias_prior_fn, 576 | bias_divergence_fn=bias_divergence_fn, 577 | **kwargs) 578 | 579 | self.client_weight = client_weight 580 | self.delta_function = tf.subtract 581 | self.apply_delta_function = tf.add 582 | self.client_variable_dict = {} 583 | self.client_center_variable_dict = {} 584 | self.server_variable_dict = {} 585 | 586 | def build(self, input_shape): 587 | input_shape = tf.TensorShape(input_shape) 588 | if self.data_format == 'channels_first': 589 | channel_axis = 1 590 | else: 591 | channel_axis = -1 592 | input_dim = tf.compat.dimension_value(input_shape[channel_axis]) 593 | if input_dim is None: 594 | raise ValueError('The channel dimension of the inputs ' 595 | 'should be defined. Found `None`.') 596 | kernel_shape = self.kernel_size + (input_dim, self.filters) 597 | 598 | # If self.dtype is None, build weights using the default dtype. 599 | dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) 600 | name = 'kernel' 601 | 602 | self.kernel_posterior_fn, self.kernel_prior_fn = \ 603 | self.build_posterior_fn_natural(kernel_shape, dtype, name, 604 | self.kernel_posterior_fn, 605 | self.kernel_prior_fn) 606 | 607 | natural_initializer = natural_initializer_fn( 608 | loc_stdev=0.1, 609 | u_scale_init_avg=-5, 610 | u_scale_init_stdev=0.1, 611 | untransformed_scale_initializer=self.untransformed_scale_initializer) 612 | 613 | # Must have a posterior kernel. 
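        # With the default factories, the call below returns an
        # Independent(NormalNatural) whose natural parameters are deferred as
        # client_weight * local + (server - client_weight * client_center);
        # see renormalize_natural_mean_field_normal_fn.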
614 |         self.kernel_posterior = self.kernel_posterior_fn(
615 |             dtype, kernel_shape, 'kernel_posterior',
616 |             self.trainable, self.add_variable,
617 |             natural_initializer=natural_initializer)
618 |
619 |         if self.kernel_prior_fn is None:
620 |             self.kernel_prior = None
621 |         else:
622 |             self.kernel_prior = self.kernel_prior_fn(
623 |                 dtype, kernel_shape, 'kernel_prior',
624 |                 self.trainable, self.add_variable)
625 |         self._built_kernel_divergence = False
626 |
627 |         if self.bias_posterior_fn is None:
628 |             self.bias_posterior = None
629 |         else:
630 |             self.bias_posterior = self.bias_posterior_fn(
631 |                 dtype, (self.filters,), 'bias_posterior',
632 |                 self.trainable, self.add_variable)
633 |
634 |         if self.bias_prior_fn is None:
635 |             self.bias_prior = None
636 |         else:
637 |             self.bias_prior = self.bias_prior_fn(
638 |                 dtype, (self.filters,), 'bias_prior',
639 |                 self.trainable, self.add_variable)
640 |         self._built_bias_divergence = False
641 |
642 |         self.input_spec = tf.keras.layers.InputSpec(
643 |             ndim=self.rank + 2, axes={channel_axis: input_dim})
644 |         self._convolution_op = nn_ops.Convolution(
645 |             input_shape,
646 |             filter_shape=tf.TensorShape(kernel_shape),
647 |             dilation_rate=self.dilation_rate,
648 |             strides=self.strides,
649 |             padding=self.padding.upper(),
650 |             data_format=tf_layers_util.convert_data_format(
651 |                 self.data_format, self.rank + 2))
652 |
653 |         if self.bias_posterior:
654 |             self.bias_center = self.add_weight(
655 |                 'bias_center',
656 |                 shape=[self.filters, ],  # one bias per output channel; conv layers define `filters`, not `units`
657 |                 initializer=tf.keras.initializers.constant(0.),
658 |                 dtype=self.dtype,
659 |                 trainable=False)
660 |             self.client_variable_dict['bias'] = self.bias_posterior.distribution.loc
661 |             self.server_variable_dict['bias'] = self.bias_posterior.distribution.loc
662 |             self.client_center_variable_dict['bias'] = self.bias_center
663 |
664 |         self.built = True
665 |
666 |
667 | class NaturalGaussianEmbedding(
668 |         tf.keras.layers.Embedding, VariationalReparametrizedNatural):
669 |
670 |     def __init__(self,
671 |                  input_dim,
672 |                  output_dim,
673 |                  mask_zero=False,
674 |                  input_length=None,
675 |                  client_weight=1.,
676 |                  trainable=True,
677 |                  embeddings_initializer=tf.keras.initializers.RandomUniform(
678 |                      -0.01, 0.01),
679 |                  embedding_posterior_fn=None,
680 |                  embedding_posterior_tensor_fn=(lambda d: d.sample()),
681 |                  embedding_prior_fn=None,
682 |                  embedding_divergence_fn=(
683 |                      lambda q, p, ignore: tfd.kl_divergence(q, p)),
684 |                  **kwargs
685 |                  ):
686 |
687 |         self.untransformed_scale_initializer = None
688 |         if 'untransformed_scale_initializer' in kwargs:
689 |             self.untransformed_scale_initializer = \
690 |                 kwargs.pop('untransformed_scale_initializer')
691 |
692 |         if embedding_posterior_fn is None:
693 |             embedding_posterior_fn = self.renormalize_natural_mean_field_normal_fn
694 |         if embedding_prior_fn is None:
695 |             embedding_prior_fn = self.natural_tensor_multivariate_normal_fn
696 |
697 |         super(NaturalGaussianEmbedding, self).__init__(input_dim,
698 |                                                        output_dim,
699 |                                                        mask_zero=mask_zero,
700 |                                                        input_length=input_length,
701 |                                                        trainable=trainable,
702 |                                                        embeddings_initializer=embeddings_initializer,
703 |                                                        **kwargs)
704 |
705 |         self.client_weight = client_weight
706 |         self.delta_function = tf.subtract
707 |         self.apply_delta_function = tf.add
708 |         self.embedding_posterior_fn = embedding_posterior_fn
709 |         self.embedding_prior_fn = embedding_prior_fn
710 |         self.embedding_posterior_tensor_fn = embedding_posterior_tensor_fn
711 |         self.embedding_divergence_fn = embedding_divergence_fn
712 |         self.client_variable_dict = {}
713 |
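        # As in the dense and conv layers, the two dicts below are populated
        # by build_posterior_fn_natural when the layer is built.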
self.client_center_variable_dict = {} 714 | self.server_variable_dict = {} 715 | 716 | def build(self, input_shape): 717 | dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) 718 | shape = (self.input_dim, self.output_dim) 719 | if context.executing_eagerly() and context.context().num_gpus(): 720 | with ops.device('cpu:0'): 721 | self.embedding_posterior_fn, self.embedding_prior_fn = \ 722 | self.build_posterior_fn_natural(shape, dtype, 'embedding', 723 | self.embedding_posterior_fn, 724 | self.embedding_prior_fn) 725 | else: 726 | self.embedding_posterior_fn, self.embedding_prior_fn = \ 727 | self.build_posterior_fn_natural(shape, dtype, 'embedding', 728 | self.embedding_posterior_fn, 729 | self.embedding_prior_fn) 730 | 731 | natural_initializer = natural_initializer_fn( 732 | untransformed_scale_initializer=self.untransformed_scale_initializer, 733 | loc_initializer=self.embeddings_initializer) 734 | 735 | self.embedding_posterior = self.embedding_posterior_fn( 736 | dtype, shape, 'embedding_posterior', 737 | self.trainable, self.add_variable, 738 | natural_initializer=natural_initializer) 739 | 740 | self.embedding_prior = self.embedding_prior_fn( 741 | dtype, shape, 'embedding_prior', 742 | self.trainable, self.add_variable) 743 | 744 | self.built = True 745 | 746 | def _apply_divergence(self, divergence_fn, posterior, prior, 747 | posterior_tensor, name): 748 | if (divergence_fn is None or 749 | posterior is None or 750 | prior is None): 751 | divergence = None 752 | return 753 | divergence = tf.identity( 754 | divergence_fn( 755 | posterior, prior, posterior_tensor), 756 | name=name) 757 | self.add_loss(divergence) 758 | 759 | def call(self, inputs): 760 | self.embeddings = self.embedding_posterior_tensor_fn(self.embedding_posterior) 761 | self._apply_divergence(self.embedding_divergence_fn, 762 | self.embedding_posterior, 763 | self.embedding_prior, 764 | self.embeddings, 765 | name='divergence_embeddings') 766 | return super(NaturalGaussianEmbedding, self).call(inputs) 767 | 768 | 769 | class LSTMCellVariationalNatural(tf.keras.layers.LSTMCell, VariationalReparametrizedNatural): 770 | 771 | def __init__(self, 772 | units, 773 | activation='tanh', 774 | recurrent_activation='hard_sigmoid', 775 | use_bias=True, 776 | kernel_initializer=tf.keras.initializers.VarianceScaling(scale=30.0, 777 | mode='fan_avg', 778 | distribution='uniform',), 779 | recurrent_initializer=tf.keras.initializers.Orthogonal(gain=7.0), 780 | bias_initializer='zeros', 781 | unit_forget_bias=True, 782 | kernel_constraint=None, 783 | recurrent_constraint=None, 784 | bias_constraint=None, 785 | dropout=0., 786 | recurrent_dropout=0., 787 | implementation=1, 788 | kernel_posterior_fn=None, 789 | kernel_posterior_tensor_fn=(lambda d: d.sample()), 790 | recurrent_kernel_posterior_fn=None, 791 | recurrent_kernel_posterior_tensor_fn=(lambda d: d.sample()), 792 | kernel_prior_fn=None, 793 | recurrent_kernel_prior_fn=None, 794 | kernel_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)), 795 | recurrent_kernel_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)), 796 | bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn( 797 | is_singular=True), 798 | bias_posterior_tensor_fn=(lambda d: d.sample()), 799 | bias_prior_fn=None, 800 | bias_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)), 801 | client_weight=1., 802 | **kwargs): 803 | 804 | self.untransformed_scale_initializer = kwargs.pop('untransformed_scale_initializer', None) 805 | 806 | if kernel_posterior_fn is None: 807 
| kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn 808 | if kernel_prior_fn is None: 809 | kernel_prior_fn = self.natural_tensor_multivariate_normal_fn 810 | if recurrent_kernel_posterior_fn is None: 811 | recurrent_kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn 812 | if recurrent_kernel_prior_fn is None: 813 | recurrent_kernel_prior_fn = self.natural_tensor_multivariate_normal_fn 814 | 815 | super(LSTMCellVariationalNatural, self).__init__( 816 | units, 817 | activation=activation, 818 | recurrent_activation=recurrent_activation, 819 | use_bias=use_bias, 820 | kernel_initializer=kernel_initializer, 821 | recurrent_initializer=recurrent_initializer, 822 | bias_initializer=bias_initializer, 823 | unit_forget_bias=unit_forget_bias, 824 | kernel_regularizer=None, 825 | recurrent_regularizer=None, 826 | bias_regularizer=None, 827 | kernel_constraint=kernel_constraint, 828 | recurrent_constraint=recurrent_constraint, 829 | bias_constraint=bias_constraint, 830 | dropout=dropout, 831 | recurrent_dropout=recurrent_dropout, 832 | implementation=implementation, 833 | **kwargs) 834 | 835 | self.kernel_posterior_fn = kernel_posterior_fn 836 | self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn 837 | self.recurrent_kernel_posterior_fn = recurrent_kernel_posterior_fn 838 | self.recurrent_kernel_posterior_tensor_fn = recurrent_kernel_posterior_tensor_fn 839 | self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn 840 | self.kernel_prior_fn = kernel_prior_fn 841 | self.recurrent_kernel_prior_fn = recurrent_kernel_prior_fn 842 | self.kernel_divergence_fn = kernel_divergence_fn 843 | self.recurrent_kernel_divergence_fn = recurrent_kernel_divergence_fn 844 | self.bias_posterior_fn = bias_posterior_fn 845 | self.bias_posterior_tensor_fn = bias_posterior_tensor_fn 846 | self.bias_prior_fn = bias_prior_fn 847 | self.bias_divergence_fn = bias_divergence_fn 848 | self.client_weight = client_weight 849 | self.delta_function = tf.subtract 850 | self.apply_delta_function = tf.add 851 | self.client_variable_dict = {} 852 | self.client_center_variable_dict = {} 853 | self.server_variable_dict = {} 854 | 855 | @tf_utils.shape_type_conversion 856 | def build(self, input_shape): 857 | default_caching_device = _caching_device(self) 858 | input_dim = input_shape[-1] 859 | 860 | shape_kernel = (input_dim, self.units * 4) 861 | shape_recurrent = (self.units, self.units * 4) 862 | dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) 863 | self.kernel_posterior_fn, self.kernel_prior_fn = \ 864 | self.build_posterior_fn_natural(shape_kernel, dtype, 'kernel', 865 | self.kernel_posterior_fn, 866 | self.kernel_prior_fn) 867 | 868 | self.recurrent_kernel_posterior_fn, self.recurrent_kernel_prior_fn = \ 869 | self.build_posterior_fn_natural(shape_recurrent, dtype, 870 | 'recurrent_kernel', 871 | self.recurrent_kernel_posterior_fn, 872 | self.recurrent_kernel_prior_fn) 873 | 874 | kernel_initializer = natural_initializer_fn( 875 | loc_stdev=0.1, 876 | u_scale_init_avg=-5, 877 | u_scale_init_stdev=0.1, 878 | untransformed_scale_initializer=self.untransformed_scale_initializer, 879 | loc_initializer=self.kernel_initializer) 880 | 881 | if self.kernel_regularizer: self.kernel_regularizer = NaturalRegularizer(self.kernel_regularizer) 882 | if self.kernel_constraint: self.kernel_constraint = NaturalConstraint(self.kernel_constraint) 883 | 884 | self.kernel_posterior = self.kernel_posterior_fn( 885 | dtype, shape_kernel, 'kernel_posterior', self.trainable, 886 | 
self.add_variable, 887 | natural_initializer=kernel_initializer, 888 | natural_regularizer=self.kernel_regularizer, 889 | natural_constraint=self.kernel_constraint, 890 | caching_device=default_caching_device) 891 | 892 | if self.kernel_prior_fn is None: 893 | self.kernel_prior = None 894 | else: 895 | self.kernel_prior = self.kernel_prior_fn( 896 | dtype, shape_kernel, 'kernel_prior', 897 | self.trainable, self.add_variable) 898 | 899 | recurrent_initializer = natural_initializer_fn( 900 | loc_stdev=0.1, 901 | u_scale_init_avg=-5, 902 | u_scale_init_stdev=0.1, 903 | untransformed_scale_initializer= 904 | self.untransformed_scale_initializer, 905 | loc_initializer=self.recurrent_initializer) 906 | 907 | if self.recurrent_regularizer: 908 | self.recurrent_regularizer = NaturalRegularizer( 909 | self.recurrent_regularizer) 910 | if self.recurrent_constraint: 911 | self.recurrent_constraint = NaturalConstraint( 912 | self.recurrent_constraint) 913 | 914 | self.recurrent_kernel_posterior = self.recurrent_kernel_posterior_fn( 915 | dtype, shape_recurrent, 'recurrent_kernel_posterior', 916 | self.trainable, 917 | self.add_variable, 918 | natural_initializer=recurrent_initializer, 919 | natural_regularizer=self.recurrent_regularizer, 920 | natural_constraint=self.recurrent_constraint, 921 | caching_device=default_caching_device) 922 | 923 | if self.recurrent_kernel_prior_fn is None: 924 | self.recurrent_kernel_prior = None 925 | else: 926 | self.recurrent_kernel_prior = self.recurrent_kernel_prior_fn( 927 | dtype, shape_recurrent, 'recurrent_kernel_prior', 928 | self.trainable, self.add_variable) 929 | 930 | if self.use_bias: 931 | if self.unit_forget_bias: 932 | 933 | def bias_initializer(_, *args, **kwargs): 934 | return K.concatenate([ 935 | self.bias_initializer((self.units,), *args, **kwargs), 936 | tf.keras.initializers.Ones()((self.units,), *args, **kwargs), 937 | self.bias_initializer((self.units * 2,), *args, **kwargs), 938 | ]) 939 | else: 940 | bias_initializer = self.bias_initializer 941 | 942 | self.bias = self.add_weight( 943 | shape=(self.units * 4,), 944 | name='bias', 945 | initializer=bias_initializer, 946 | regularizer=self.bias_regularizer, 947 | constraint=self.bias_constraint, 948 | caching_device=default_caching_device) 949 | else: 950 | self.bias = None 951 | 952 | 953 | self._apply_divergence( 954 | self.kernel_divergence_fn, 955 | self.kernel_posterior, 956 | self.kernel_prior, 957 | name='divergence_kernel') 958 | self._apply_divergence( 959 | self.recurrent_kernel_divergence_fn, 960 | self.recurrent_kernel_posterior, 961 | self.recurrent_kernel_prior, 962 | name='divergence_recurrent_kernel') 963 | 964 | self.built = True 965 | 966 | def _apply_divergence(self, divergence_fn, posterior, prior, name, 967 | posterior_tensor=None): 968 | divergence = tf.identity( 969 | divergence_fn( 970 | posterior, prior, posterior_tensor), 971 | name=name) 972 | self.add_loss(divergence) 973 | 974 | def sample_weights(self): 975 | self.kernel = self.kernel_posterior_tensor_fn(self.kernel_posterior) 976 | self.recurrent_kernel = self.recurrent_kernel_posterior_tensor_fn( 977 | self.recurrent_kernel_posterior) 978 | 979 | 980 | class LSTMCellReparametrizationNatural(tf.keras.layers.LSTMCell): 981 | 982 | def __init__(self, 983 | units, 984 | activation='tanh', 985 | recurrent_activation='hard_sigmoid', 986 | use_bias=True, 987 | kernel_initializer='glorot_uniform', 988 | recurrent_initializer='orthogonal', 989 | bias_initializer='zeros', 990 | unit_forget_bias=True, 991 | 
kernel_constraint=None,
992 |                  recurrent_constraint=None,
993 |                  bias_constraint=None,
994 |                  dropout=0.,
995 |                  recurrent_dropout=0.,
996 |                  implementation=1,
997 |                  kernel_posterior_fn=None,
998 |                  kernel_posterior_tensor_fn=(lambda d: d.sample()),
999 |                  recurrent_kernel_posterior_fn=None,
1000 |                  recurrent_kernel_posterior_tensor_fn=(lambda d: d.sample()),
1001 |                  kernel_prior_fn=None,
1002 |                  recurrent_kernel_prior_fn=None,
1003 |                  kernel_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
1004 |                  recurrent_kernel_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
1005 |                  bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
1006 |                      is_singular=True),
1007 |                  bias_posterior_tensor_fn=(lambda d: d.sample()),
1008 |                  bias_prior_fn=None,
1009 |                  bias_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
1010 |                  client_weight=1.,
1011 |                  **kwargs):
1012 |
1013 |         self.untransformed_scale_initializer = None
1014 |         if 'untransformed_scale_initializer' in kwargs:
1015 |             self.untransformed_scale_initializer = \
1016 |                 kwargs.pop('untransformed_scale_initializer')
1017 |
1018 |         if kernel_posterior_fn is None:
1019 |             kernel_posterior_fn = natural_mean_field_normal_fn()  # module-level factory; this cell keeps no shared server/client state
1020 |         if kernel_prior_fn is None:
1021 |             kernel_prior_fn = natural_tensor_multivariate_normal_fn()
1022 |         if recurrent_kernel_posterior_fn is None:
1023 |             recurrent_kernel_posterior_fn = natural_mean_field_normal_fn()
1024 |         if recurrent_kernel_prior_fn is None:
1025 |             recurrent_kernel_prior_fn = natural_tensor_multivariate_normal_fn()
1026 |
1027 |         super(LSTMCellReparametrizationNatural, self).__init__(
1028 |             units,
1029 |             activation=activation,
1030 |             recurrent_activation=recurrent_activation,
1031 |             use_bias=use_bias,
1032 |             kernel_initializer=kernel_initializer,
1033 |             recurrent_initializer=recurrent_initializer,
1034 |             bias_initializer=bias_initializer,
1035 |             unit_forget_bias=unit_forget_bias,
1036 |             kernel_regularizer=None,
1037 |             recurrent_regularizer=None,
1038 |             bias_regularizer=None,
1039 |             kernel_constraint=kernel_constraint,
1040 |             recurrent_constraint=recurrent_constraint,
1041 |             bias_constraint=bias_constraint,
1042 |             dropout=dropout,
1043 |             recurrent_dropout=recurrent_dropout,
1044 |             implementation=implementation,
1045 |             **kwargs)
1046 |
1047 |         self.kernel_posterior_fn = kernel_posterior_fn
1048 |         self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
1049 |         self.recurrent_kernel_posterior_fn = recurrent_kernel_posterior_fn
1050 |         self.recurrent_kernel_posterior_tensor_fn = recurrent_kernel_posterior_tensor_fn
1051 |         self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
1052 |         self.kernel_prior_fn = kernel_prior_fn
1053 |         self.recurrent_kernel_prior_fn = recurrent_kernel_prior_fn
1054 |         self.kernel_divergence_fn = kernel_divergence_fn
1055 |         self.recurrent_kernel_divergence_fn = recurrent_kernel_divergence_fn
1056 |         self.bias_posterior_fn = bias_posterior_fn
1057 |         self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
1058 |         self.bias_prior_fn = bias_prior_fn
1059 |         self.bias_divergence_fn = bias_divergence_fn
1060 |         self.client_weight = client_weight
1061 |
1062 |     @tf_utils.shape_type_conversion
1063 |     def build(self, input_shape):
1064 |         default_caching_device = _caching_device(self)
1065 |         input_dim = input_shape[-1]
1066 |
1067 |         shape_kernel = (input_dim, self.units * 4)
1068 |         shape_recurrent = (self.units, self.units * 4)
1069 |         dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
1070 |
1071 |         kernel_initializer =
natural_initializer_fn( 1072 | loc_stdev=0.1, 1073 | u_scale_init_avg=-5, 1074 | u_scale_init_stdev=0.1, 1075 | untransformed_scale_initializer=self.untransformed_scale_initializer, 1076 | loc_initializer=self.kernel_initializer) 1077 | 1078 | self.kernel_posterior = self.kernel_posterior_fn(dtype, shape_kernel, 1079 | 'kernel_posterior', 1080 | self.trainable, 1081 | self.add_variable, 1082 | natural_initializer=kernel_initializer) 1083 | 1084 | if self.kernel_prior_fn is None: 1085 | self.kernel_prior = None 1086 | else: 1087 | self.kernel_prior = self.kernel_prior_fn( 1088 | dtype, shape_kernel, 'kernel_prior', 1089 | self.trainable, self.add_variable) 1090 | 1091 | recurrent_initializer = natural_initializer_fn( 1092 | loc_stdev=0.1, 1093 | u_scale_init_avg=-5, 1094 | u_scale_init_stdev=0.1, 1095 | untransformed_scale_initializer=self.untransformed_scale_initializer, 1096 | loc_initializer=self.recurrent_initializer) 1097 | 1098 | self.recurrent_kernel_posterior = \ 1099 | self.recurrent_kernel_posterior_fn(dtype, shape_recurrent, 1100 | 'recurrent_kernel_posterior', 1101 | self.trainable, 1102 | self.add_variable, 1103 | natural_initializer=recurrent_initializer) 1104 | 1105 | if self.recurrent_kernel_prior_fn is None: 1106 | self.recurrent_kernel_prior = None 1107 | else: 1108 | self.recurrent_kernel_prior = self.recurrent_kernel_prior_fn( 1109 | dtype, shape_recurrent, 'recurrent_kernel_prior', 1110 | self.trainable, self.add_variable) 1111 | 1112 | if self.use_bias: 1113 | if self.unit_forget_bias: 1114 | 1115 | def bias_initializer(_, *args, **kwargs): 1116 | return K.concatenate([ 1117 | self.bias_initializer((self.units,), *args, **kwargs), 1118 | tf.keras.initializers.Ones()((self.units,), *args, **kwargs), 1119 | self.bias_initializer((self.units * 2,), *args, **kwargs), 1120 | ]) 1121 | else: 1122 | bias_initializer = self.bias_initializer 1123 | 1124 | self.bias = self.add_weight( 1125 | shape=(self.units * 4,), 1126 | name='bias', 1127 | initializer=bias_initializer, 1128 | regularizer=self.bias_regularizer, 1129 | constraint=self.bias_constraint, 1130 | caching_device=default_caching_device) 1131 | else: 1132 | self.bias = None 1133 | 1134 | 1135 | self._apply_divergence( 1136 | self.kernel_divergence_fn, 1137 | self.kernel_posterior, 1138 | self.kernel_prior, 1139 | name='divergence_kernel') 1140 | self._apply_divergence( 1141 | self.recurrent_kernel_divergence_fn, 1142 | self.recurrent_kernel_posterior, 1143 | self.recurrent_kernel_prior, 1144 | name='divergence_recurrent_kernel') 1145 | 1146 | self.built = True 1147 | 1148 | def _apply_divergence(self, divergence_fn, posterior, prior, name, 1149 | posterior_tensor=None): 1150 | divergence = tf.identity( 1151 | divergence_fn( 1152 | posterior, prior, posterior_tensor), 1153 | name=name) 1154 | self.add_loss(divergence) 1155 | 1156 | def sample_weights(self): 1157 | self.kernel = self.kernel_posterior_tensor_fn(self.kernel_posterior) 1158 | self.recurrent_kernel = self.recurrent_kernel_posterior_tensor_fn( 1159 | self.recurrent_kernel_posterior) 1160 | 1161 | 1162 | class RNNVarReparametrized(tf.keras.layers.RNN): 1163 | 1164 | def compute_delta(self): 1165 | return self.cell.compute_delta() 1166 | 1167 | def renew_center(self, center_to_update=True): 1168 | self.cell.renew_center(center_to_update) 1169 | 1170 | def apply_delta(self, delta): 1171 | self.cell.apply_delta(delta) 1172 | 1173 | def receive_and_save_weights(self, layer_server): 1174 | self.cell.receive_and_save_weights(layer_server.cell) 1175 | 1176 | def 
initialize_kernel_posterior(self): 1177 | self.cell.initialize_kernel_posterior() 1178 | 1179 | def apply_damping(self, damping_factor): 1180 | self.cell.apply_damping(damping_factor) 1181 | 1182 | def call(self, 1183 | inputs, 1184 | mask=None, 1185 | training=None, 1186 | initial_state=None, 1187 | constants=None): 1188 | self.cell.sample_weights() 1189 | return super(RNNVarReparametrized, self).call( 1190 | inputs, 1191 | mask=mask, 1192 | training=training, 1193 | initial_state=initial_state, 1194 | constants=constants) 1195 | -------------------------------------------------------------------------------- /source/normal_natural.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | from tensorflow_probability.python import distributions as tfd 4 | import numpy as np 5 | from tensorflow_probability.python.bijectors import identity as identity_bijector 6 | from tensorflow_probability.python.internal import assert_util 7 | from tensorflow_probability.python.internal import dtype_util 8 | from tensorflow_probability.python.internal import prefer_static 9 | from tensorflow_probability.python.internal import reparameterization 10 | from tensorflow_probability.python.internal import special_math 11 | from tensorflow_probability.python.internal import tensor_util 12 | 13 | eps = 1e-6 14 | 15 | 16 | class NormalNatural(tfd.Distribution): 17 | 18 | def __init__(self, 19 | gamma, 20 | prec, 21 | validate_args=False, 22 | allow_nan_stats=True, 23 | name='NormalNatural'): 24 | """Construct Normal distributions with natural parameters `gamma` and `prec`. 25 | 26 | The parameters `gamma` and `prec` must be shaped in a way that supports 27 | broadcasting (e.g. `gamma + prec` is a valid operation). 28 | 29 | Args: 30 | gamma: Floating point tensor; the signal to noise ratio of the distribution(s). 31 | prec: Floating point tensor; the precision of the distribution(s). 32 | Must contain only positive values. 33 | validate_args: Python `bool`, default `False`. When `True` distribution 34 | parameters are checked for validity despite possibly degrading runtime 35 | performance. When `False` invalid inputs may silently render incorrect 36 | outputs. 37 | allow_nan_stats: Python `bool`, default `True`. When `True`, 38 | statistics (e.g., mean, mode, variance) use the value "`NaN`" to 39 | indicate the result is undefined. When `False`, an exception is raised 40 | if one or more of the statistic's batch members are undefined. 41 | name: Python `str` name prefixed to Ops created by this class. 42 | 43 | Raises: 44 | TypeError: if `gamma` and `prec` have different `dtype`. 
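    Example (cf. the `loc` and `scale` properties below):
      `NormalNatural(gamma=0., prec=1.)` is a standard normal, since
      `loc = gamma / prec` and `scale = sqrt(1 / prec)`.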
45 | """ 46 | parameters = dict(locals()) 47 | with tf.name_scope(name) as name: 48 | dtype = dtype_util.common_dtype([gamma, prec], dtype_hint=tf.float32) 49 | self._gamma = tensor_util.convert_nonref_to_tensor( 50 | gamma, dtype=dtype, name='gamma') 51 | self._prec = tensor_util.convert_nonref_to_tensor( 52 | prec, dtype=dtype, name='prec') 53 | super(NormalNatural, self).__init__(dtype=dtype, 54 | reparameterization_type=reparameterization.FULLY_REPARAMETERIZED, 55 | validate_args=validate_args, 56 | allow_nan_stats=allow_nan_stats, 57 | parameters=parameters, 58 | name=name) 59 | 60 | @staticmethod 61 | def _param_shapes(sample_shape): 62 | return dict( 63 | zip(('gamma', 'prec'), 64 | ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 2))) 65 | 66 | @classmethod 67 | def _params_event_ndims(cls): 68 | return dict(gamma=0, prec=0) 69 | 70 | @property 71 | def gamma(self): 72 | """Distribution parameter for the gamma.""" 73 | return self._gamma 74 | 75 | @property 76 | def prec(self): 77 | """Distribution parameter for standard deviation.""" 78 | return self._prec 79 | 80 | def _batch_shape_tensor(self, gamma=None, prec=None): 81 | return prefer_static.broadcast_shape( 82 | prefer_static.shape(self.gamma if gamma is None else gamma), 83 | prefer_static.shape(self.prec if prec is None else prec)) 84 | 85 | def _batch_shape(self): 86 | return tf.broadcast_static_shape(self.gamma.shape, self.prec.shape) 87 | 88 | def _event_shape_tensor(self): 89 | return tf.constant([], dtype=tf.int32) 90 | 91 | def _event_shape(self): 92 | return tf.TensorShape([]) 93 | 94 | @property 95 | def loc(self): 96 | return tf.math.multiply_no_nan(self.gamma, tf.math.reciprocal_no_nan(self.prec)) 97 | 98 | @property 99 | def scale(self): 100 | return tf.math.sqrt(tf.math.reciprocal_no_nan(self.prec)) 101 | 102 | def _sample_n(self, n, seed=None): 103 | gamma = tf.convert_to_tensor(self.gamma) 104 | prec = tf.convert_to_tensor(self.prec) 105 | shape = tf.concat([[n], self._batch_shape_tensor(gamma=gamma, prec=prec)], 106 | axis=0) 107 | sampled = tf.random.normal( 108 | shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed) 109 | 110 | return tf.math.multiply(sampled, self.scale) + self.loc 111 | 112 | def _log_prob(self, x): 113 | prec = tf.convert_to_tensor(self.prec) 114 | log_unnormalized = - prec * tf.math.squared_difference(x, self.gamma) 115 | log_normalization = tf.constant( 116 | 0.5 * np.log(2. * np.pi), dtype=self.dtype) - 0.5 * tf.math.log(prec) 117 | return log_unnormalized - log_normalization 118 | 119 | def _log_cdf(self, x): 120 | return special_math.log_ndtr(self._z(x)) 121 | 122 | def _cdf(self, x): 123 | return special_math.ndtr(self._z(x)) 124 | 125 | def _log_survival_function(self, x): 126 | return special_math.log_ndtr(-self._z(x)) 127 | 128 | def _survival_function(self, x): 129 | return special_math.ndtr(-self._z(x)) 130 | 131 | def _entropy(self): 132 | log_normalization = tf.constant( 133 | 0.5 * np.log(2. * np.pi), dtype=self.dtype) - 0.5 * tf.math.log(self.prec) 134 | entropy = 0.5 + log_normalization 135 | return entropy * tf.ones_like(self.gamma) 136 | 137 | def _mean(self): 138 | return self.gamma / self.prec * tf.ones_like(self.prec) 139 | 140 | def _quantile(self, p): 141 | return special_math.ndtri(p) * self._stddev() + self._mean() 142 | 143 | def _stddev(self): 144 | return tf.math.sqrt(1. 
/ self.prec) * tf.ones_like(self.gamma) 145 | 146 | _mode = _mean 147 | 148 | def _z(self, x, prec=None): 149 | """Standardize input `x` to a unit normal.""" 150 | with tf.name_scope('standardize'): 151 | return (self.prec * x - self.gamma) / (tf.math.sqrt(self.prec) if prec is None else tf.math.sqrt(prec)) 152 | 153 | def _default_event_space_bijector(self): 154 | return identity_bijector.Identity(validate_args=self.validate_args) 155 | 156 | def _parameter_control_dependencies(self, is_init): 157 | assertions = [] 158 | 159 | if is_init: 160 | try: 161 | self._batch_shape() 162 | except ValueError: 163 | raise ValueError( 164 | 'Arguments `loc` and `scale` must have compatible shapes; ' 165 | 'loc.shape={}, scale.shape={}.'.format( 166 | self.gamma.shape, self.prec.shape)) 167 | # We don't bother checking the shapes in the dynamic case because 168 | # all member functions access both arguments anyway. 169 | 170 | if not self.validate_args: 171 | assert not assertions # Should never happen. 172 | return [] 173 | 174 | if is_init != tensor_util.is_ref(self.scale): 175 | assertions.append(assert_util.assert_positive( 176 | self.scale, message='Argument `scale` must be positive.')) 177 | 178 | return assertions 179 | 180 | 181 | @tfp.distributions.kullback_leibler.RegisterKL(NormalNatural, NormalNatural) 182 | def _kl_normal_natural(a, b, name=None): 183 | """Calculate the batched KL divergence KL(a || b) with a and b Normal. 184 | 185 | Args: 186 | a: instance of a NormalNatural distribution object. 187 | b: instance of a NormalNatural distribution object. 188 | name: Name to use for created operations. 189 | Default value: `None` (i.e., `'kl_normal_natural'`). 190 | 191 | Returns: 192 | kl_div: Batchwise KL(a || b) 193 | """ 194 | with tf.name_scope(name or 'kl_normal_natural'): 195 | a_prec = tf.convert_to_tensor(a.prec) 196 | b_prec = tf.convert_to_tensor(b.prec) # We'll read it thrice. 
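    # Closed form for normals in natural parameters, with mu = gamma / prec
    # and var = 1 / prec:
    #   KL(a || b) = 0.5 * (prec_b * (mu_a - mu_b)**2
    #                       + prec_b / prec_a - 1 + log(prec_a / prec_b))
    # expm1(-diff_log_prec) below supplies the prec_b / prec_a - 1 term.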
197 |     diff_log_prec = tf.math.log(a_prec + eps) - tf.math.log(b_prec + eps)
198 |     inverse_a_prec = tf.math.reciprocal_no_nan(a_prec)
199 |     inverse_b_prec = tf.math.reciprocal_no_nan(b_prec)  # 1 / prec_b, used for loc_b = gamma_b / prec_b
200 |     return (
201 |         0.5 * tf.multiply(b_prec, tf.math.squared_difference(tf.math.multiply(a.gamma, inverse_a_prec),
202 |                                                              tf.math.multiply(b.gamma, inverse_b_prec))) +
203 |         0.5 * tf.math.expm1(- diff_log_prec) +
204 |         0.5 * diff_log_prec)
205 |
--------------------------------------------------------------------------------
/source/tfp_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import tensorflow as tf
4 | import tensorflow_probability as tfp
5 | from tensorflow_probability.python import distributions as tfd
6 | import tensorflow.compat.v2 as tf
7 | from tensorflow_probability.python.internal import dtype_util
8 | from tensorflow_probability.python.internal import tensorshape_util
9 |
10 |
11 | softplus = tfp.bijectors.Softplus()
12 | precision_from_scale = tfp.bijectors.Chain([tfp.bijectors.Reciprocal(), tfp.bijectors.Square()])
13 | precision_from_untransformed_scale = tfp.bijectors.Chain([precision_from_scale, softplus])
14 | CLIP_VALUE = 1e15
15 |
16 |
17 | class SoftClip(tfp.bijectors.Bijector):
18 |
19 |     def __init__(self, low=None, high=None):
20 |         self.low = low
21 |         if self.low is None:
22 |             self.low = -1e7
23 |         self.high = high
24 |         if self.high is None:
25 |             self.high = 1e7
26 |
27 |     def forward(self, x, name='forward', **kwargs):
28 |         x_type = x.dtype
29 |         x = tf.cast(x, tf.float64)
30 |         self.low = tf.cast(self.low, x.dtype)
31 |         self.high = tf.cast(self.high, x.dtype)
32 |         return tf.cast(-softplus.forward(self.high - self.low - softplus.forward(x - self.low)) * \
33 |                        (self.high - self.low) / (softplus.forward(self.high - self.low)) + self.high, x_type)
34 |
35 |     def inverse(self, y, name='inverse', **kwargs):
36 |         y_type = y.dtype
37 |         y = tf.cast(y, tf.float64)
38 |         return tf.cast(+softplus.inverse(self.high - self.low - softplus.inverse(
39 |             (self.high - y) / (self.high - self.low) * softplus.forward(self.high - self.low))), y_type)
40 |
41 |
42 | def loc_prod_from_locprec(loc_times_prec, sum_prec):
43 |     rec = tf.math.xdivy(1., sum_prec)
44 |     rec = tf.clip_by_value(rec, -CLIP_VALUE, CLIP_VALUE)
45 |     loc = tf.multiply(loc_times_prec, rec)
46 |
47 |     return loc
48 |
49 |
50 | def loc_prod_from_precision(loc1, p1, loc2, p2):
51 |     prec_prod = p1 + p2
52 |     loc1p1 = tf.math.multiply(loc1, p1)
53 |     loc2p2 = tf.math.multiply(loc2, p2)
54 |     return loc_prod_from_locprec(loc1p1 + loc2p2, prec_prod)
55 |
56 |
57 | def compute_gaussian_prod(loc1, p1, loc2, p2):
58 |     loc_prod = loc_prod_from_precision(loc1, p1, loc2, p2)
59 |     return loc_prod, p1 + p2
60 |
61 |
62 | def loc_ratio_from_precision(loc1, p1, loc2, p2):
63 |     return loc_prod_from_precision(loc1, p1, loc2, -p2)
64 |
65 |
66 | def compute_gaussian_ratio(loc1, p1, loc2, p2):
67 |     return compute_gaussian_prod(loc1, p1, loc2, -p2)
68 |
69 |
70 | def renormalize_mean_field_normal_fn(loc_ratio, prec_ratio):
71 |
72 |     def _fn(dtype, shape, name, trainable, add_variable_fn,
73 |             initializer=tf.random_normal_initializer(stddev=0.1),
74 |             regularizer=None, constraint=None, **kwargs):
75 |         loc_scale_fn = tensor_loc_scale_fn(loc_initializer=initializer,
76 |                                            loc_regularizer=regularizer,
77 |                                            loc_constraint=constraint, **kwargs)
78 |
79 |         loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)
80 |         prec = tfp.util.DeferredTensor(scale, precision_from_scale,
name='precision') 81 | if scale is None: 82 | dist = tfd.Deterministic(loc=loc) 83 | else: 84 | loc_reparametrized, scale_reparametrized = \ 85 | reparametrize_loc_scale(loc, prec, loc_ratio, prec_ratio) 86 | dist = tfd.Normal(loc=loc_reparametrized, scale=scale_reparametrized) 87 | batch_ndims = tf.size(dist.batch_shape_tensor()) 88 | return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims) 89 | return _fn 90 | 91 | 92 | def default_tensor_multivariate_normal_fn(loc_ratio, prec_ratio, num_clients, prior_scale=1.): 93 | def _fn(dtype, shape, name, trainable, add_variable_fn, initializer=tf.keras.initializers.constant(0.), 94 | regularizer=None, constraint=None, **kwargs): 95 | del trainable 96 | loc_scale_fn = tensor_loc_scale_fn(loc_initializer=initializer, 97 | loc_regularizer=regularizer, 98 | loc_constraint=constraint, 99 | untransformed_scale_initializer=tf.keras.initializers.constant( 100 | tfp.bijectors.Softplus().inverse(prior_scale*math.sqrt(num_clients)).numpy()), 101 | **kwargs) 102 | loc, scale = loc_scale_fn(dtype, shape, name, False, add_variable_fn) 103 | prec = tfp.util.DeferredTensor(scale, precision_from_scale) 104 | loc_reparametrized, scale_reparametrized = reparametrize_loc_scale(loc, prec, loc_ratio, prec_ratio) 105 | dist = tfd.Normal(loc=loc_reparametrized, scale=scale_reparametrized) 106 | batch_ndims = tf.size(input=dist.batch_shape_tensor()) 107 | return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims) 108 | return _fn 109 | 110 | 111 | def tensor_loc_scale_fn(is_singular=False, 112 | loc_initializer 113 | =tf.random_normal_initializer(stddev=0.1), 114 | untransformed_scale_initializer 115 | =tf.random_normal_initializer(mean=-3., stddev=0.1), 116 | loc_regularizer=None, 117 | untransformed_scale_regularizer=None, 118 | loc_constraint=None, 119 | untransformed_scale_constraint=None, 120 | **kwargs): 121 | def _fn(dtype, shape, name, trainable, add_variable_fn): 122 | """Creates `loc`, `scale` parameters.""" 123 | loc = add_variable_fn( 124 | name=name + '_loc', 125 | shape=shape, 126 | initializer=loc_initializer, 127 | regularizer=loc_regularizer, 128 | constraint=loc_constraint, 129 | dtype=dtype, 130 | trainable=trainable, 131 | **kwargs) 132 | if is_singular: 133 | return loc, None 134 | untransformed_scale = add_variable_fn( 135 | name=name + '_untransformed_scale', 136 | shape=shape, 137 | initializer=untransformed_scale_initializer, 138 | regularizer=untransformed_scale_regularizer, 139 | constraint=untransformed_scale_constraint, 140 | dtype=dtype, 141 | trainable=trainable, 142 | **kwargs) 143 | scale = tfp.util.DeferredTensor(untransformed_scale, tfp.bijectors.Softplus(), name=name + '_scale') 144 | return loc, scale 145 | return _fn 146 | 147 | 148 | def reparametrize_loc_scale(loc, prec, loc_ratio, prec_ratio): 149 | precision_reparametrized = tfp.util.DeferredTensor(prec, lambda x: x + prec_ratio) 150 | 151 | def loc_reparametrization_fn(x): 152 | return loc_prod_from_precision(x, prec, loc_ratio, prec_ratio) 153 | 154 | loc_reparametrized = tfp.util.DeferredTensor(loc, loc_reparametrization_fn) 155 | scale_reparametrized = tfp.util.DeferredTensor(precision_reparametrized, precision_from_scale.inverse) 156 | return loc_reparametrized, scale_reparametrized 157 | 158 | 159 | class LocPrecTuple(tuple): 160 | 161 | def assign(self, loc_prec_tuple): 162 | self[0].assign(loc_prec_tuple[0]) 163 | self[1].variables[0].assign(precision_from_untransformed_scale.inverse(loc_prec_tuple[1])) 164 | 165 | 166 | def 
def sparse_delta_function(percentile):
    """Returns a subtraction that zeroes all but the highest-SNR entries of the delta."""

    def sparse_subtract(n1, n2):
        # n[..., 0] holds gamma and n[..., 1] the precision; keep only entries
        # whose |delta gamma| / sqrt(prec) lies above the given percentile.
        abs_snr = tf.abs(n1[..., 0] - n2[..., 0]) / tf.sqrt(tf.abs(n1[..., 1]))
        condition = abs_snr >= tfp.stats.percentile(abs_snr, percentile)
        sparse_delta = tf.where(tf.expand_dims(condition, -1), n1 - n2, 0.)
        return sparse_delta

    return sparse_subtract
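
To see which entries survive the sparsification, here is a toy example with four natural-parameter entries; the values are hypothetical, and the repository root is assumed importable:

```
import tensorflow as tf
from source.tfp_utils import sparse_delta_function

# Fake natural-parameter tensors: [..., 0] = gamma, [..., 1] = precision.
n1 = tf.constant([[1.0, 4.0], [0.1, 4.0], [5.0, 4.0], [0.2, 4.0]])
n2 = tf.zeros_like(n1)

# SNRs are [0.5, 0.05, 2.5, 0.1]; with percentile=50 only rows 0 and 2
# keep their delta, the rest are zeroed.
delta = sparse_delta_function(50.)(n1, n2)
```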
--------------------------------------------------------------------------------
/source/utils.py:
--------------------------------------------------------------------------------
import os
import logging
import numpy as np
import GPUtil
from source.constants import ROOT_LOGGER_STR
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.ops import summary_ops_v2

logger = logging.getLogger(ROOT_LOGGER_STR + '.' + __name__)


def gpu_session(num_gpus=None, gpus=None):
    if gpus:
        logger.info(f"{gpus}, "
                    f"{tf.config.experimental.list_physical_devices('GPU')}")
        os.environ["CUDA_VISIBLE_DEVICES"] = gpus
    elif num_gpus is not None:
        if num_gpus > 0:
            os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
            os.environ["CUDA_VISIBLE_DEVICES"] = set_free_gpus(num_gpus)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = ''
    # CUDA_VISIBLE_DEVICES is a comma-separated string such as '0, 1'.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", '')
    gpu_ids = [g.strip() for g in visible.split(',') if g.strip()]
    if visible:
        logger.info(f'Cuda devices: {visible}')
    else:
        logger.info('No Cuda devices')

    tf_gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_ids and tf_gpus:
        gpus = [tf_gpus[int(gpu)] for gpu in gpu_ids]
        tf.config.experimental.set_visible_devices(gpus, 'GPU')
        tf.config.set_soft_device_placement(True)
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, enable=True)


def set_free_gpus(num):
    """Returns a comma-separated string of up to `num` currently free GPU ids."""
    list_gpu = GPUtil.getAvailable(limit=num, maxMemory=0.01)
    return str(list_gpu)[1:-1]


def avg_dict(history_list, cards):
    """Cardinality-weighted average of the latest value of each metric."""
    avg = {}
    keys = []
    for el in history_list:
        if hasattr(el, 'keys'):
            keys = el.keys()
            break
    for key in keys:
        lists = list(zip(*[(history[key][-1] * card, card)
                           for history, card in zip(history_list, cards)
                           if history]))
        avg[key] = sum(lists[0]) / sum(lists[1])
    return avg


def avg_dict_eval(eval_fed, cards):
    evals = np.array([np.array(ev) * card for ev, card in zip(eval_fed, cards)])
    return evals.sum(axis=0)


class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):

    def __init__(self, name='sparse_categorical_accuracy', dtype=None, vocab_size=0):
        super().__init__(name, dtype=dtype)
        self.vocab_size = vocab_size

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Flatten the sequence dimension so accuracy is computed per token.
        y_true = tf.reshape(y_true, [-1, 1])
        y_pred = tf.reshape(y_pred, [-1, self.vocab_size + 1, 1])
        if sample_weight is not None:
            sample_weight = tf.reshape(sample_weight, [-1, 1])
        return super().update_state(y_true, y_pred, sample_weight)


class CustomTensorboard(tf.keras.callbacks.TensorBoard):

    def __init__(self, *args, **kwargs):
        super(CustomTensorboard, self).__init__(*args, **kwargs)
        self.epoch = 0

    def _log_distr(self, epoch):
        """Logs the weights of the gaussian distributions to TensorBoard."""
        writer = self._get_writer(self._train_run_name)
        with context.eager_mode(), \
                writer.as_default(), \
                summary_ops_v2.always_record_summaries():
            for layer in self.model.layers:
                layer_to_check = layer
                if hasattr(layer, 'cell'):
                    layer_to_check = layer.cell
                for weight in layer_to_check.trainable_weights:
                    if 'natural' in weight.name + layer.name:
                        # Natural-parameter weights stack gamma and precision
                        # in the last axis.
                        tf.summary.histogram(layer.name + '/' + weight.name + '_gamma',
                                             weight[..., 0], step=epoch)
                        tf.summary.histogram(layer.name + '/' + weight.name + '_prec',
                                             weight[..., 1], step=epoch)
                    else:
                        tf.summary.histogram(layer.name + '/' + weight.name, weight, step=epoch)
                if hasattr(layer_to_check, 'recurrent_kernel_posterior'):
                    tf.summary.histogram(
                        layer.name + '/recurrent_kernel_posterior' + '_gamma_reparametrized',
                        layer_to_check.recurrent_kernel_posterior.distribution.gamma,
                        step=epoch)
                    tf.summary.histogram(
                        layer.name + '/recurrent_kernel_posterior' + '_prec_reparametrized',
                        layer_to_check.recurrent_kernel_posterior.distribution.prec,
                        step=epoch)
                if hasattr(layer_to_check, 'kernel_posterior'):
                    tf.summary.histogram(
                        layer.name + '/kernel_posterior' + '_gamma_reparametrized',
                        layer_to_check.kernel_posterior.distribution.gamma,
                        step=epoch)
                    tf.summary.histogram(
                        layer.name + '/kernel_posterior' + '_prec_reparametrized',
                        layer_to_check.kernel_posterior.distribution.prec,
                        step=epoch)
            writer.flush()

    def on_epoch_end(self, epoch, logs=None):
        """Runs metrics and histogram summaries at epoch end."""
        self.epoch = self.epoch + 1
        epoch = self.epoch
        self._log_metrics(logs, prefix='', step=epoch)

        if self.histogram_freq and epoch % self.histogram_freq == 0:
            self._log_weights(epoch)
            self._log_distr(epoch)

        if self.embeddings_freq and epoch % self.embeddings_freq == 0:
            self._log_embeddings(epoch)


def softmax(x):
    # Shift by the max for numerical stability; the result is unchanged.
    ex = np.exp(x - np.max(x))
    return ex / np.sum(ex)
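
A quick check of the weighted-average semantics of `avg_dict` above; the per-client histories and cardinalities here are hypothetical, Keras-style dicts of metric lists:

```
from source.utils import avg_dict

h1 = {'loss': [0.9, 0.5], 'accuracy': [0.6, 0.8]}
h2 = {'loss': [0.7, 0.3], 'accuracy': [0.7, 0.9]}

# Last values weighted by dataset sizes 100 and 300:
# loss = (0.5*100 + 0.3*300) / 400 = 0.35, accuracy = 0.875.
print(avg_dict([h1, h2], [100, 300]))
```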
--------------------------------------------------------------------------------
/source/virtual_process.py:
--------------------------------------------------------------------------------
import logging
from source.federated_devices import (ClientVirtualSequential,
                                      ClientVirtualModel,
                                      ServerSequential,
                                      ServerModel)
from source.fed_process import FedProcess
from source.constants import ROOT_LOGGER_STR


logger = logging.getLogger(ROOT_LOGGER_STR + '.' + __name__)


class VirtualFedProcess(FedProcess):

    def __init__(self, model_fn, num_clients, damping_factor=1,
                 fed_avg_init=False):
        super(VirtualFedProcess, self).__init__(model_fn, num_clients)
        self.damping_factor = damping_factor
        self.fed_avg_init = fed_avg_init

    def build(self, cards_train, hierarchical):
        if hierarchical:
            client_model_class = ClientVirtualModel
            server_model_class = ServerModel
        else:
            client_model_class = ClientVirtualSequential
            server_model_class = ServerSequential
        # Each client model is built with its own dataset cardinality and its
        # share of the total training data; the server sees the full cardinality.
        for indx in self.clients_indx:
            model = self.model_fn(client_model_class, cards_train[indx],
                                  cards_train[indx] / sum(cards_train))
            self.clients.append(model)
        self.server = self.model_fn(server_model_class, sum(cards_train), 1.)
--------------------------------------------------------------------------------
/test.py:
https://raw.githubusercontent.com/lucori/virtual/c3c4abb43ff2001c06849b02bc4dfc5ef1ea50e5/test.py
--------------------------------------------------------------------------------
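
For orientation, a minimal sketch of how `VirtualFedProcess` is wired up. The `model_fn` below is a stand-in, and the argument names `num_samples` and `client_weight` are descriptive guesses for the second and third positional arguments seen in `build`; the real factory is constructed elsewhere in the code base from the JSON configuration:

```
from source.virtual_process import VirtualFedProcess

def model_fn(model_class, num_samples, client_weight):
    # Placeholder: build and return a model of the requested client or
    # server class for a shard with `num_samples` examples.
    ...

fed_process = VirtualFedProcess(model_fn, num_clients=10, damping_factor=1)
fed_process.build(cards_train=[600] * 10, hierarchical=False)
```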