├── src ├── dynafed.png ├── results.png ├── results_table.png └── trajectorymatch_dynafed_final1.pdf ├── experiments ├── cifar10 │ ├── cifar10_0.01_serverdistill.sh │ ├── cifar10_0.04_serverdistill.sh │ └── cifar10_0.16_serverdistill.sh ├── cifar100 │ ├── cifar100_0.01_serverdistill.sh │ ├── cifar100_0.04_serverdistill.sh │ └── cifar100_0.16_serverdistill.sh └── cinic10 │ ├── cinic_0.01_serverdistill.sh │ ├── cinic_0.04_serverdistill.sh │ └── cinic_0.16_serverdistill.sh ├── README.md ├── client.py ├── experiment_manager.py ├── server.py ├── reparam_module.py ├── run_end2end.py ├── data.py ├── image_synthesizer.py ├── models.py └── utils.py /src/dynafed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pipilurj/DynaFed/HEAD/src/dynafed.png -------------------------------------------------------------------------------- /src/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pipilurj/DynaFed/HEAD/src/results.png -------------------------------------------------------------------------------- /src/results_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pipilurj/DynaFed/HEAD/src/results_table.png -------------------------------------------------------------------------------- /src/trajectorymatch_dynafed_final1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pipilurj/DynaFed/HEAD/src/trajectorymatch_dynafed_final1.pdf -------------------------------------------------------------------------------- /experiments/cifar10/cifar10_0.01_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='1' 7 | hyperparameters01='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cifar10"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.01], 18 | "eta" : [0.4], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [20], 31 | "distill_lr": [1e-4], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters01" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. 
--Iteration 3000 --project dynafed --runs_name hyperparameters01 51 | -------------------------------------------------------------------------------- /experiments/cifar10/cifar10_0.04_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='0' 7 | hyperparameters04='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cifar10"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.04], 18 | "eta" : [0.4], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [20], 31 | "distill_lr": [1e-4], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters04" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. 
--Iteration 3000 --project dynafed --runs_name hyperparameters04 51 | -------------------------------------------------------------------------------- /experiments/cifar10/cifar10_0.16_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='0' 7 | hyperparameters016='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cifar10"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.04], 18 | "eta" : [0.4], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [15], 31 | "distill_lr": [1e-4], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters016" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. 
--Iteration 3000 --project dynafed --runs_name hyperparameters016 51 | -------------------------------------------------------------------------------- /experiments/cifar100/cifar100_0.01_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='1' 7 | hyperparameters01='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cifar100"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.01], 18 | "eta" : [0.4], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [8], 31 | "distill_lr": [1e-4], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters01" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. 
--Iteration 3000 --project dynafed_cifar100 --runs_name hyperparameters01 51 | -------------------------------------------------------------------------------- /experiments/cifar100/cifar100_0.04_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='0' 7 | hyperparameters04='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cifar100"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.04], 18 | "eta" : [0.4], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [8], 31 | "distill_lr": [1e-4], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters04" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. 
--Iteration 3000 --project dynafed_cifar100 --runs_name hyperparameters04 51 | -------------------------------------------------------------------------------- /experiments/cinic10/cinic_0.01_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='1' 7 | hyperparameters01='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cinic10"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.01], 18 | "eta" : [0.8], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [20], 31 | "distill_lr": [0.00025], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters01" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. 
--Iteration 3000 --project dynafed_cinic10 --runs_name hyperparameters01 51 | -------------------------------------------------------------------------------- /experiments/cinic10/cinic_0.04_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='1' 7 | hyperparameters04='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cinic10"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.01], 18 | "eta" : [0.8], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [20], 31 | "distill_lr": [0.00025], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters04" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. 
--Iteration 3000 --project dynafed_cinic10 --runs_name hyperparameters04 51 | -------------------------------------------------------------------------------- /experiments/cinic10/cinic_0.16_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='1' 7 | hyperparameters16='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cinic10"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.01], 18 | "eta" : [0.8], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [20], 31 | "distill_lr": [0.00025], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters16" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. 
--Iteration 3000 --project dynafed_cinic10 --runs_name hyperparameters16 51 | -------------------------------------------------------------------------------- /experiments/cifar100/cifar100_0.16_serverdistill.sh: -------------------------------------------------------------------------------- 1 | cmdargs=$1 2 | 3 | # `gpu=$1` 4 | # `echo "export CUDA_VISIBLE_DEVICES=${gpu}"` 5 | #export CUDA_VISIBLE_DEVICES='0,1' 6 | export CUDA_VISIBLE_DEVICES='0' 7 | hyperparameters016='[{ 8 | "random_seed" : [4], 9 | 10 | "dataset" : ["cifar100"], 11 | "models" : [{"ConvNet" : 80}], 12 | 13 | "attack_rate" : [0], 14 | "attack_method": ["-"], 15 | "participation_rate" : [0.4], 16 | 17 | "alpha" : [0.04], 18 | "eta" : [0.4], 19 | "client_mode": ["normal"], 20 | "minimum_trajectory_length": [[25]], 21 | "maximum_distill_round": [1], 22 | "distill_interval": [1], 23 | "start_round": [0], 24 | "communication_rounds" : [200], 25 | "local_epochs" : [1], 26 | "batch_size" : [32], 27 | "val_size" : [32], 28 | "val_batch_size": [32], 29 | "local_optimizer" : [ ["Adam", {"lr": 0.001}]], 30 | "distill_iter": [8], 31 | "distill_lr": [1e-4], 32 | 33 | "aggregation_mode" : ["datadistill"], 34 | 35 | "sample_size": [0], 36 | "save_scores" : [false], 37 | 38 | "pretrained" : [null], 39 | "save_model" : [null], 40 | "log_frequency" : [1], 41 | "log_path" : ["new_noniid/"]}] 42 | 43 | ' 44 | 45 | 46 | RESULTS_PATH="results/" 47 | DATA_PATH="../data/" 48 | CHECKPOINT_PATH="checkpoints/" 49 | 50 | python -u run_end2end.py --hp="$hyperparameters016" --RESULTS_PATH="$RESULTS_PATH" --DATA_PATH="$DATA_PATH" --CHECKPOINT_PATH="$CHECKPOINT_PATH" $cmdargs --dataset=cifar10 --ipc=15 --syn_steps=20 --expert_epochs=3 --max_start_epoch=50 --min_start_epoch=0 --lr_img=5e-2 --lr_lr=1e-05 --lr_teacher=0.01 --pix_init noise --img_optim adam --lr_teacher 0.04 --weight_averaging --least_ave_num 2 --start_learning_label 0 --label_init 0. --Iteration 3000 --project dynafed_cifar100 --runs_name hyperparameters016 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DYNAFED: Tackling Client Data Heterogeneity with Global Dynamics 2 | ![Main Figure](src/dynafed.png) 3 | 4 | This repository contains the source code for the paper DYNAFED: Tackling Client Data Heterogeneity with Global Dynamics. 5 | Our paper is accepted by CVPR2023 and is available on arXiv: [link](https://arxiv.org/abs/2211.10878). 6 | 7 | ## Table of Contents 8 | 9 | - [Installation](#installation) 10 | - [Reproducing Results](#reproducing-results) 11 | - [Citation](#citation) 12 | 13 | ## Installation 14 | To use this project, you will need to install the following packages: 15 | 16 | - PyTorch: `pip install torch` 17 | - wandb: `pip install wandb` 18 | - scikit-learn: `pip install scikit-learn` 19 | 20 | ## Reproducing Results 21 | 22 | To reproduce the results from our paper, follow these steps: 23 | 24 | 1. Download the datasets (fmnist, cifar, cinic10). 25 | 2. 
Train the model by running the commands below (the 0.04 and 0.16 variants of each script live in the same directories):
26 | 
27 | ```
28 | # cifar10 experiments
29 | bash experiments/cifar10/cifar10_0.01_serverdistill.sh
30 | # cifar100 experiments
31 | bash experiments/cifar100/cifar100_0.01_serverdistill.sh
32 | # cinic10 experiments
33 | bash experiments/cinic10/cinic_0.01_serverdistill.sh
34 | 
35 | ```
36 | 
37 | ### Example Results
38 | 
39 | ![Results 1](src/results.png)
40 | ![Results 2](src/results_table.png)
41 | 
42 | ### Credits
43 | 
44 | We would like to give credit to the following repository for the code and resources that we used in our project:
45 | 
46 | - [Dataset Distillation by Matching Training Trajectories](https://github.com/GeorgeCazenavette/mtt-distillation) - we were inspired by its source code for distilling data from expert trajectories.
47 | 
48 | 
49 | ## Citation
50 | 
51 | If you use our code or data in your research, please cite our paper. You can use the following BibTeX entry:
52 | ```bibtex
53 | @article{pi2022dynafed,
54 | title={DYNAFED: Tackling Client Data Heterogeneity with Global Dynamics},
55 | author={Pi, Renjie and Zhang, Weizhong and Xie, Yueqi and Gao, Jiahui and Wang, Xiaoyu and Kim, Sunghun and Chen, Qifeng},
56 | journal={arXiv preprint arXiv:2211.10878},
57 | year={2022}
58 | }
59 | ```
--------------------------------------------------------------------------------
/client.py:
--------------------------------------------------------------------------------
1 | import random
2 | from tqdm import tqdm
3 | from functools import partial
4 | from collections import OrderedDict
5 | import torch
6 | import torch.optim as optim
7 | #from torchcontrib.optim import SWA
8 | import torch.nn as nn
9 | import numpy as np
10 | from utils import *
11 | import models as model_utils
12 | from sklearn.linear_model import LogisticRegression
13 | 
14 | import os
15 | 
16 | # from gmm_torch.gmm import GaussianMixture
17 | from math import sqrt
18 | 
19 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
20 | 
21 | class Device(object):
22 |     def __init__(self, loader):
23 | 
24 |         self.loader = loader
25 | 
26 |     def evaluate(self, loader=None):
27 |         return eval_op(self.model, self.loader if not loader else loader)
28 | 
29 |     def save_model(self, path=None, name=None, verbose=True):
30 |         if name:
31 |             torch.save(self.model.state_dict(), path+name)
32 |             if verbose: print("Saved model to", path+name)
33 | 
34 |     def load_model(self, path=None, name=None, verbose=True):
35 |         if name:
36 |             self.model.load_state_dict(torch.load(path+name))
37 |             if verbose: print("Loaded model from", path+name)
38 | 
39 | class Client(Device):
40 |     def __init__(self, model_name, optimizer_fn, loader, idnum=0, num_classes=10, images_train=None, labels_train=None, eta=0.5, dataset = 'cifar10'):
41 |         super().__init__(loader)
42 |         self.id = idnum
43 |         print(f"dataset client {dataset}")
44 |         self.model_name = model_name
45 |         self.model_fn = partial(model_utils.get_model(self.model_name)[0], num_classes=num_classes , dataset = dataset)
46 |         self.model = self.model_fn().to(device)
47 | 
48 |         self.W = {key : value for key, value in self.model.named_parameters()}
49 | 
50 |         self.optimizer_fn = optimizer_fn
51 |         self.optimizer = self.optimizer_fn(self.model.parameters())
52 |         self.images_train, self.labels_train = images_train, labels_train
53 |         self.eta = eta
54 | 
55 | 
56 |     def synchronize_with_server(self, server):
57 |         server_state = server.model_dict[self.model_name].state_dict()
58 |         self.model.load_state_dict(server_state, strict=False)
59 |         # server_state,server.parameter_dict['resnet8'], self.model.state_dict()
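    # compute_weight_update (below) dispatches on optional hp entries: when
    # "privacy_sigma" is provided, local training runs through train_op_private
    # (clip_bound and privacy_sigma are forwarded to it; see utils.py);
    # otherwise it falls back to plain train_op.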
60 | 
61 |     def compute_weight_update(self, epochs=1, loader=None, lambda_fedprox=0.0, print_train_loss=False, hp=None):
62 |         clip_bound, privacy_sigma = None, None
63 |         if hp is not None:
64 |             clip_bound, privacy_sigma = hp.get("clip_bound", None), hp.get("privacy_sigma", None)
65 |         if privacy_sigma is not None:
66 |             train_stats = train_op_private(self.model, self.loader if not loader else loader, self.optimizer, epochs, lambda_fedprox=lambda_fedprox, print_train_loss=print_train_loss, clip_bound=clip_bound, privacy_sigma=privacy_sigma)
67 |         else:
68 |             train_stats = train_op(self.model, self.loader if not loader else loader, self.optimizer, epochs, lambda_fedprox=lambda_fedprox, print_train_loss=print_train_loss)
69 |         return train_stats
70 | 
71 |     def compute_weight_update_datadistill(self, epochs=1, loader=None, lambda_fedprox=0.0, current_round=0, start_round=0):
72 |         print(f"current round {current_round}, start round {start_round}")
73 |         if self.images_train is not None and self.labels_train is not None:
74 |             train_stats = train_op_datadistill(self.model, self.loader if not loader else loader, self.optimizer, epochs, self.images_train, self.labels_train, eta=self.eta, current_round=current_round, start_round=start_round)
75 |         else:
76 |             train_stats = train_op(self.model, self.loader if not loader else loader, self.optimizer, epochs, lambda_fedprox=lambda_fedprox)
77 |         return train_stats
78 | 
79 |     def compute_weight_update_datadistill_soft(self, epochs=1, loader=None, lambda_fedprox=0.0, current_round=0, start_round=0, dsa=True, args=None):
80 |         # print(f"soft distill, current round {current_round}, start round {start_round}")
81 |         if self.images_train is not None and self.labels_train is not None:
82 |             train_stats = train_op_datadistill_soft(self.model, self.loader if not loader else loader, self.optimizer, epochs, self.images_train, self.labels_train, eta=self.eta, current_round=current_round, start_round=start_round, dsa=dsa, args=args)
83 |         else:
84 |             train_stats = train_op(self.model, self.loader if not loader else loader, self.optimizer, epochs, lambda_fedprox=lambda_fedprox)
85 |         return train_stats
86 | 
87 |     def compute_weight_update_datadistill_later(self, epochs=1, loader=None, lambda_fedprox=0.0, finetune_lr=1e-3, finetune_epoch=1, current_round=0, start_round=0, dsa=None, args=None):
88 |         if self.images_train is not None and self.labels_train is not None:
89 |             train_stats = train_op_datadistill_later(self.model, self.loader if not loader else loader, self.optimizer, epochs, self.images_train, self.labels_train, finetune_epoch=finetune_epoch, finetune_lr=finetune_lr, current_round=current_round, start_round=start_round, dsa=dsa, args=args)
90 |         else:
91 |             train_stats = train_op(self.model, self.loader if not loader else loader, self.optimizer, epochs, lambda_fedprox=lambda_fedprox)
92 |         return train_stats
93 | 
94 | 
95 |     def predict_logit(self, x):
96 |         """Raw logit prediction on input (model kept in train mode)"""
97 |         self.model.train()
98 | 
99 |         with torch.no_grad():
100 |             y_ = self.model(x)
101 | 
102 |         return y_
103 | 
104 |     def predict_logit_eval(self, x):
105 |         """Raw logit prediction on input (model in eval mode)"""
106 |         self.model.eval()
107 |         with torch.no_grad():
108 |             y_ = self.model(x)
109 | 
110 |         return y_
111 | 
112 | 
--------------------------------------------------------------------------------
/experiment_manager.py:
--------------------------------------------------------------------------------
1 | import glob, os, time
2 | 
3 | import numpy as np
4 | 
5 | import itertools as it
6 | import matplotlib.pyplot as plt
7 | 
8 | 
9 | 
10 | def save_results(results_dict, path, name, verbose=True):
11 |     results_numpy = {key : np.array(value) for key, value in results_dict.items()}
12 | 
13 |     if not os.path.exists(path):
14 |         os.makedirs(path)
15 |     np.savez(path+name, **results_numpy)
16 |     if verbose:
17 |         print("Saved results to ", path+name+".npz")
18 | 
19 | 
20 | def load_results(path, filename, verbose=True):
21 |     results_dict = np.load(path+filename, allow_pickle=True)
22 | 
23 |     if verbose:
24 |         print("Loaded results from "+path+filename)
25 |     return results_dict
26 | 
27 | 
28 | class Experiment():
29 |     '''Class that contains logic to store hyperparameters and results of an experiment'''
30 |     hyperparameters = {}
31 |     results = {}
32 |     parameters = {}
33 | 
34 |     def __init__(self, hyperparameters=None, hp_dict=None):
35 |         if hp_dict is not None:
36 |             self.from_dict(hp_dict)
37 |         else:
38 |             self.hyperparameters = hyperparameters
39 |             self.hyperparameters_ = {}
40 |             self.results = {}
41 |             self.parameters = {}
42 |             self.hyperparameters['finished'] = False
43 |             self.hyperparameters['log_id'] = np.random.randint(10000000)
44 | 
45 | 
46 | 
47 | 
48 |     def __str__(self):
49 |         selfname = "Hyperparameters: \n"
50 |         for key, value in self.hyperparameters.items():
51 |             selfname += " - "+key+" "*(24-len(key))+str(value)+"\n"
52 |         return selfname
53 | 
54 |     def __repr__(self):
55 |         return self.__str__()
56 | 
57 | 
58 | 
59 | 
60 |     def log(self, update_dict, printout=True, override=False):
61 |         # update a result
62 |         for key, value in update_dict.items():
63 |             if (not key in self.results) or override:
64 |                 self.results[key] = [value]
65 |             else:
66 |                 self.results[key] += [value]
67 | 
68 |         if printout:
69 |             print(update_dict)
70 | 
71 | 
72 |     def is_log_round(self, c_round):
73 |         log_freq = self.hyperparameters['log_frequency']
74 |         if log_freq < 0:
75 |             log_freq = np.ceil(self.hyperparameters['communication_rounds']/(-log_freq)).astype('int')
76 |         if c_round == self.hyperparameters['communication_rounds']:
77 |             self.hyperparameters['finished'] = True
78 | 
79 |         return (c_round == 1) or (c_round % log_freq == 0) or (c_round == self.hyperparameters['communication_rounds'])
80 | 
81 |     def save_parameters(self, parameters):
82 |         self.parameters = parameters
83 | 
84 |     def to_dict(self):
85 |         # turns an experiment into a dict that can be saved to disc
86 |         return {'hyperparameters' : self.hyperparameters, 'hyperparameters_' : self.hyperparameters_,
87 |                 'parameters' : self.parameters, **self.results}
88 | 
89 |     def from_dict(self, hp_dict):
90 |         # takes a dict and turns it into an experiment
91 |         self.results = dict(hp_dict)
92 | 
93 |         self.hyperparameters = hp_dict['hyperparameters'][np.newaxis][0]
94 | 
95 |         if 'parameters' in hp_dict:
96 |             self.parameters = hp_dict['parameters'][np.newaxis][0]
97 |             del self.results['parameters']
98 |         else:
99 |             self.parameters = {}
100 | 
101 |         if 'hyperparameters_' in hp_dict:
102 |             self.hyperparameters_ = hp_dict['hyperparameters_'][np.newaxis][0]
103 |             del self.results['hyperparameters_']
104 |         else:
105 |             self.hyperparameters_ = {}
106 | 
107 | 
108 |     def save_to_disc(self, path, name):
109 |         if path:
110 |             save_results(self.to_dict(), os.path.join(path, name), 'xp_'+str(self.hyperparameters['log_id']))
111 | 
112 | 
113 | 
114 | def get_all_hp_combinations(hp):
115 |     '''Turns a dict of lists into a list of dicts'''
116 |     combinations = it.product(*(hp[name] for name in hp))
117 |     hp_dicts = [{key : value[i] for i,key in enumerate(hp)} for 
value in combinations] 118 | return hp_dicts 119 | 120 | 121 | def list_of_dicts_to_dict(hp_dicts): 122 | '''Turns a list of dicts into one dict of lists containing all individual values''' 123 | one_dict = {} 124 | for hp in hp_dicts: 125 | for key, value in hp.items(): 126 | if not key in one_dict: 127 | one_dict[key] = [value] 128 | elif value not in one_dict[key]: 129 | one_dict[key] += [value] 130 | return one_dict 131 | 132 | 133 | def get_list_of_experiments(path, only_finished=False, verbose=True): 134 | '''Returns all the results saved at location path''' 135 | list_of_experiments = [] 136 | 137 | os.chdir(path) 138 | for file in glob.glob("*.npz"): 139 | list_of_experiments += [Experiment(hp_dict=load_results(path+"/",file, verbose=False))] 140 | 141 | if only_finished: 142 | list_of_experiments = [xp for xp in list_of_experiments if 'finished' in xp.hyperparameters and xp.hyperparameters['finished']] 143 | 144 | if list_of_experiments and verbose: 145 | print("Loaded ",len(list_of_experiments), " Results from ", path) 146 | print() 147 | get_experiments_metadata(list_of_experiments) 148 | 149 | if not list_of_experiments: 150 | print("No finished Experiments. Consider setting only_finished to False") 151 | 152 | return list_of_experiments 153 | 154 | 155 | def get_experiment(path, name, verbose=False): 156 | '''Returns one result saved at location path''' 157 | experiment = Experiment(hp_dict=load_results(path+"/",name+".npz", verbose=False)) 158 | 159 | if verbose: 160 | print("Loaded ",1, " Result from ", path) 161 | print() 162 | get_experiments_metadata([experiment]) 163 | 164 | return experiment 165 | 166 | 167 | def get_experiments_metadata(list_of_experiments): 168 | hp_dicts = [experiment.hyperparameters for experiment in list_of_experiments] 169 | 170 | print('Hyperparameters: \n' ,list_of_dicts_to_dict(hp_dicts)) 171 | print() 172 | print('Tracked Variables: \n', list(list_of_experiments[0].results.keys())) 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import random 2 | import models as model_utils 3 | from utils import * 4 | from client import Device 5 | from utils import kd_loss 6 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 7 | 8 | 9 | 10 | class Server(Device): 11 | def __init__(self, model_names, loader, val_loader, num_classes=10, images_train=None, labels_train=None, eta=0.5 , dataset = 'cifar10', client_loaders=None): 12 | super().__init__(loader) 13 | self.val_loader = val_loader 14 | 15 | print(f"dataset server {dataset}") 16 | self.model_dict = {model_name : partial(model_utils.get_model(model_name)[0], num_classes=num_classes, dataset = dataset)().to(device) for model_name in model_names} 17 | self.parameter_dict = {model_name : {key : value for key, value in model.named_parameters()} for model_name, model in self.model_dict.items()} 18 | self.client_loaders = client_loaders 19 | self.images_train, self.labels_train = images_train, labels_train 20 | self.eta = eta 21 | 22 | 23 | self.models = list(self.model_dict.values()) 24 | 25 | 26 | def evaluate_ensemble(self): 27 | return eval_op_ensemble(self.models, self.loader, self.val_loader) 28 | 29 | 30 | def select_clients(self, clients, frac=1.0, unbalance_rate=1, sample_mode="uniform"): 31 | return random.sample(clients, int(len(clients)*frac)) 32 | 33 | def fedavg(self, clients): 34 | unique_client_model_names = np.unique([client.model_name 
for client in clients]) 35 | self.weights = torch.Tensor([1. / len(clients)] * len(clients)) 36 | for model_name in unique_client_model_names: 37 | reduce_average(target=self.parameter_dict[model_name], sources=[client.W for client in clients if client.model_name == model_name]) 38 | 39 | def distill(self, clients, optimizer_fn, epochs=1, mode="mean_logits", num_classes=10): 40 | optimizer_dict = {model_name: optimizer_fn( 41 | model.parameters()) for model_name, model in self.model_dict.items()} 42 | for model_name in self.model_dict: 43 | print("Distilling {} ...".format(model_name)) 44 | 45 | model = self.model_dict[model_name] 46 | optimizer = optimizer_dict[model_name] 47 | 48 | model.train() 49 | 50 | for ep in range(epochs): 51 | running_loss, samples = 0.0, 0 52 | for x,_ in tqdm(self.val_loader): 53 | x = x.to(device) 54 | 55 | if mode == "mean_logits": 56 | y = torch.zeros([x.shape[0], num_classes], device="cuda") 57 | for i, client in enumerate(clients): 58 | y_p = client.predict_logit(x) 59 | y += (y_p/len(clients)).detach() 60 | 61 | y = nn.Softmax(1)(y) 62 | 63 | optimizer.zero_grad() 64 | 65 | y_ = nn.LogSoftmax(1)(model(x)) 66 | 67 | loss = torch.nn.KLDivLoss(reduction="batchmean")(y_, y.detach()) 68 | 69 | running_loss += loss.item()*y.shape[0] 70 | samples += y.shape[0] 71 | 72 | loss.backward() 73 | optimizer.step() 74 | 75 | return {"loss": running_loss / samples, "epochs": ep} 76 | 77 | 78 | def abavg(self, clients): 79 | unique_client_model_names = np.unique([client.model_name for client in clients]) 80 | acc = torch.zeros([len(clients)], device="cuda") 81 | for x, true_y in self.val_loader: 82 | x = x.to(device) 83 | true_y = true_y.to(device) 84 | samples = x.shape[0] 85 | for i, client in enumerate(clients): 86 | y_ = client.predict_logit(x) 87 | _, predicted = torch.max(y_.detach(), 1) 88 | acc[i] = (predicted == true_y).sum().item()/ samples 89 | self.weights = acc/ acc.sum() 90 | print(self.weights) 91 | for model_name in unique_client_model_names: 92 | reduce_weighted(target=self.parameter_dict[model_name], sources=[client.W for client in clients if client.model_name == model_name], weights = self.weights) 93 | 94 | def datadistill(self, clients, distill_iter, distill_lr, dsa, args, current_round=0, start_round=0, ifsoft=True, test_client = False): 95 | if self.images_train is None or self.labels_train is None or current_round < start_round: 96 | self.fedavg(clients) 97 | else: 98 | unique_client_model_names = np.unique( 99 | [client.model_name for client in clients]) 100 | for model_name in unique_client_model_names: 101 | reduce_average(target=self.parameter_dict[model_name], sources=[ 102 | client.W for client in clients if client.model_name == model_name]) 103 | distilled_dataset = TensorDataset(self.images_train, self.labels_train) 104 | distilled_loader = torch.utils.data.DataLoader(distilled_dataset, batch_size=256, shuffle=True) 105 | client_test_losses = [[], [], []] 106 | print(f"num of loaders {len(clients)}") 107 | for model_name in self.model_dict: 108 | model = self.model_dict[model_name] 109 | model.train() 110 | with torch.no_grad(): 111 | for _ in range(3): 112 | for (x_dis, y_dis) in distilled_loader: 113 | x_dis , y_dis = x_dis.to(device), y_dis.to(device) 114 | model(x_dis) 115 | optimizer = torch.optim.Adam(model.parameters(), lr=distill_lr) 116 | loss_avg = 0 117 | for _ in range(distill_iter): 118 | if test_client: 119 | with torch.no_grad(): 120 | model.eval() 121 | for i, client_loader in enumerate(self.client_loaders): 122 | samples, 
correct, loss_c = 0, 0, 0 123 | for x_c, y_c in client_loader: 124 | x_c, y_c = x_c.to(device), y_c.to(device) 125 | out_c = model(x_c) 126 | _, predicted = torch.max(out_c.detach(), 1) 127 | l = F.cross_entropy(out_c, y_c).item()*y_c.shape[0] 128 | samples += y_c.shape[0] 129 | loss_c += l 130 | test_loss_c = loss_c/samples 131 | client_test_losses[i].append(round(test_loss_c, 2)) 132 | model.train() 133 | for (x_dis, y_dis) in distilled_loader: 134 | x_dis , y_dis = x_dis.to(device), y_dis.to(device) 135 | if dsa: 136 | x_dis = DiffAugment(x_dis, args.dsa_strategy, param=args.dsa_param) 137 | optimizer.zero_grad() 138 | if ifsoft: 139 | loss_distill = kd_loss(model(x_dis), y_dis) 140 | else: 141 | loss_distill = nn.CrossEntropyLoss()(model(x_dis), y_dis) 142 | loss_distill.backward() 143 | loss_avg += loss_distill.item() 144 | optimizer.step() 145 | print("Server client losses:") 146 | print(client_test_losses) 147 | print(f"length of client losses {[len(x) for x in client_test_losses]}") 148 | 149 | 150 | def sync_bn(self): 151 | for model in self.models: 152 | model.train() 153 | for x, _ in self.val_loader: 154 | x = x.to(device) 155 | y = model(x) 156 | 157 | 158 | -------------------------------------------------------------------------------- /reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from contextlib import contextmanager 4 | 5 | 6 | class ReparamModule(nn.Module): 7 | def _get_module_from_name(self, mn): 8 | if mn == '': 9 | return self 10 | m = self 11 | for p in mn.split('.'): 12 | m = getattr(m, p) 13 | return m 14 | 15 | def __init__(self, module): 16 | super(ReparamModule, self).__init__() 17 | self.module = module 18 | 19 | param_infos = [] # (module name/path, param name) 20 | shared_param_memo = {} 21 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name) 22 | params = [] 23 | param_numels = [] 24 | param_shapes = [] 25 | for mn, m in self.named_modules(): 26 | for n, p in m.named_parameters(recurse=False): 27 | if p is not None: 28 | if p in shared_param_memo: 29 | shared_mn, shared_n = shared_param_memo[p] 30 | shared_param_infos.append((mn, n, shared_mn, shared_n)) 31 | else: 32 | shared_param_memo[p] = (mn, n) 33 | param_infos.append((mn, n)) 34 | params.append(p.detach()) 35 | param_numels.append(p.numel()) 36 | param_shapes.append(p.size()) 37 | 38 | assert len(set(p.dtype for p in params)) <= 1, \ 39 | "expects all parameters in module to have same dtype" 40 | 41 | # store the info for unflatten 42 | self._param_infos = tuple(param_infos) 43 | self._shared_param_infos = tuple(shared_param_infos) 44 | self._param_numels = tuple(param_numels) 45 | self._param_shapes = tuple(param_shapes) 46 | 47 | # flatten 48 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0)) 49 | self.register_parameter('flat_param', flat_param) 50 | self.param_numel = flat_param.numel() 51 | del params 52 | del shared_param_memo 53 | 54 | # deregister the names as parameters 55 | for mn, n in self._param_infos: 56 | delattr(self._get_module_from_name(mn), n) 57 | for mn, n, _, _ in self._shared_param_infos: 58 | delattr(self._get_module_from_name(mn), n) 59 | 60 | # register the views as plain attributes 61 | self._unflatten_param(self.flat_param) 62 | 63 | # now buffers 64 | # they are not reparametrized. 
just store info as (module, name, buffer) 65 | buffer_infos = [] 66 | for mn, m in self.named_modules(): 67 | for n, b in m.named_buffers(recurse=False): 68 | if b is not None: 69 | buffer_infos.append((mn, n, b)) 70 | 71 | self._buffer_infos = tuple(buffer_infos) 72 | self._traced_self = None 73 | 74 | def trace(self, example_input, **trace_kwargs): 75 | assert self._traced_self is None, 'This ReparamModule is already traced' 76 | 77 | if isinstance(example_input, torch.Tensor): 78 | example_input = (example_input,) 79 | example_input = tuple(example_input) 80 | example_param = (self.flat_param.detach().clone(),) 81 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),) 82 | 83 | self._traced_self = torch.jit.trace_module( 84 | self, 85 | inputs=dict( 86 | _forward_with_param=example_param + example_input, 87 | _forward_with_param_and_buffers=example_param + example_buffers + example_input, 88 | ), 89 | **trace_kwargs, 90 | ) 91 | 92 | # replace forwards with traced versions 93 | self._forward_with_param = self._traced_self._forward_with_param 94 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers 95 | return self 96 | 97 | def clear_views(self): 98 | for mn, n in self._param_infos: 99 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr 100 | 101 | def _apply(self, *args, **kwargs): 102 | if self._traced_self is not None: 103 | self._traced_self._apply(*args, **kwargs) 104 | return self 105 | return super(ReparamModule, self)._apply(*args, **kwargs) 106 | 107 | def _unflatten_param(self, flat_param): 108 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) 109 | for (mn, n), p in zip(self._param_infos, ps): 110 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr 111 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 112 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 113 | 114 | @contextmanager 115 | def unflattened_param(self, flat_param): 116 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos] 117 | self._unflatten_param(flat_param) 118 | yield 119 | # Why not just `self._unflatten_param(self.flat_param)`? 120 | # 1. because of https://github.com/pytorch/pytorch/issues/17583 121 | # 2. 
slightly faster since it does not require reconstructing the split+view
122 |         #    graph
123 |         for (mn, n), p in zip(self._param_infos, saved_views):
124 |             setattr(self._get_module_from_name(mn), n, p)
125 |         for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
126 |             setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))
127 | 
128 |     @contextmanager
129 |     def replaced_buffers(self, buffers):
130 |         for (mn, n, _), new_b in zip(self._buffer_infos, buffers):
131 |             setattr(self._get_module_from_name(mn), n, new_b)
132 |         yield
133 |         for mn, n, old_b in self._buffer_infos:
134 |             setattr(self._get_module_from_name(mn), n, old_b)
135 | 
136 |     def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs):
137 |         with self.unflattened_param(flat_param):
138 |             with self.replaced_buffers(buffers):
139 |                 return self.module(*inputs, **kwinputs)
140 | 
141 |     def _forward_with_param(self, flat_param, *inputs, **kwinputs):
142 |         with self.unflattened_param(flat_param):
143 |             return self.module(*inputs, **kwinputs)
144 | 
145 |     def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs):
146 |         # check for None before squeezing: squeezing first would crash on the
147 |         # default flat_param=None path
148 |         if flat_param is None:
149 |             flat_param = self.flat_param
150 |         else:
151 |             flat_param = torch.squeeze(flat_param)
152 |         if buffers is None:
153 |             return self._forward_with_param(flat_param, *inputs, **kwinputs)
154 |         else:
155 |             return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs)
--------------------------------------------------------------------------------
/run_end2end.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import random
3 | from client import Client
4 | from utils import *
5 | from server import Server
6 | from image_synthesizer import Synthesizer
7 | import resource
8 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
9 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
10 | os.system("wandb login --relogin 8bb1fef7b4815daa3cb2ec7c5b0b9ee40d7ea6ed")
11 | np.set_printoptions(precision=4, suppress=True)
12 | def reduce_average(target, sources):
13 |     for name in target:
14 |         target[name].data = torch.mean(torch.stack([source[name].detach() for source in sources]), dim=0).clone()
15 | 
16 | channel_dict = {
17 |     "cifar10": 3,
18 |     "cinic10": 3,
19 |     "cifar100": 3,
20 |     "mnist": 1,
21 |     "fmnist": 1,
22 | }
23 | imsize_dict = {
24 |     "cifar10": (32, 32),
25 |     "cinic10": (32, 32),
26 |     "cifar100": (32, 32),
27 |     "mnist": (28, 28),
28 |     "fmnist": (28, 28),
29 | }
30 | import os
31 | 
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--start", default=0, type=int)
34 | parser.add_argument("--end", default=None, type=int)
35 | parser.add_argument("--hp", default=None, type=str)
36 | parser.add_argument("--project", default=None, type=str)
37 | parser.add_argument("--DATA_PATH", default=None, type=str)
38 | parser.add_argument("--runs_name", default=None, type=str)
39 | parser.add_argument("--RESULTS_PATH", default=None, type=str)
40 | parser.add_argument("--ACC_PATH", default=None, type=str)
41 | parser.add_argument("--CHECKPOINT_PATH", default=None, type=str)
42 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
43 | parser.add_argument('--model', type=str, default='ConvNet', help='model')
44 | parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
45 | parser.add_argument('--Iteration', type=int, default=5000, help='how many distillation steps to perform')
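# The distillation flags below configure the server-side trajectory-matching
# synthesis in image_synthesizer.py (adapted from the MTT repository credited
# in the README): starting from a stored checkpoint of the global model,
# --syn_steps optimizer steps are taken on the synthetic set and the endpoint
# is matched against the checkpoint --expert_epochs rounds further along the
# trajectory; --min_start_epoch/--max_start_epoch bound which checkpoint each
# match starts from, and --ipc sets the number of synthetic images per class.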
46 | parser.add_argument('--lr_img', type=float, default=5e-2, help='learning rate for updating synthetic images')
47 | parser.add_argument('--lr_label', type=float, default=1e-2, help='learning rate for updating synthetic labels')
48 | parser.add_argument('--least_ave_num', type=int, default=1, help='minimum number of checkpoints to average when --weight_averaging is set')
49 | parser.add_argument('--max_ave_num', type=int, default=10, help='maximum number of checkpoints to average when --weight_averaging is set')
50 | parser.add_argument('--lr_lr', type=float, default=1e-05, help='learning rate for updating... learning rate')
51 | parser.add_argument('--lr_teacher', type=float, default=0.01, help='initialization for synthetic learning rate')
52 | parser.add_argument('--lr_init', type=float, default=0.01, help='how to init lr (alpha)')
53 | parser.add_argument('--label_init', type=float, default=10, help='initial value for the synthetic labels')
54 | parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
55 | parser.add_argument('--batch_syn', type=int, default=None, help='should only use this if you run out of VRAM')
56 | parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
57 | parser.add_argument('--r', type=str, default='real', choices=["noise", "real"],
58 |                     help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
59 | parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
60 |                     help='whether to use differentiable Siamese augmentation.')
61 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
62 |                     help='differentiable Siamese augmentation strategy')
63 | parser.add_argument('--pix_init', type=str, default='real',
64 |                     help='noise/real: how to initialize the synthetic pixels')
65 | parser.add_argument('--data_path', type=str, default='data', help='dataset path')
66 | parser.add_argument('--img_optim', type=str, default='adam', help='optimizer for the synthetic images')
67 | parser.add_argument('--lr_optim', type=str, default='adam', help='optimizer for the synthetic learning rate')
68 | parser.add_argument('--buffer_path', type=str, default=None, help='buffer path')
69 | parser.add_argument('--expert_dir', type=str, default='./buffers', help='directory of expert trajectory buffers')
70 | parser.add_argument('--start_learning_label', type=int, default=0, help='iteration from which the synthetic labels start being learned')
71 | parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs the target params are')
72 | parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data')
73 | parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at')
74 | parser.add_argument('--min_start_epoch', type=int, default=25, help='min epoch we can start at')
75 | parser.add_argument('--max_epoch_incre', type=int, default=5, help='maximum increment of the start epoch')
76 | parser.add_argument('--classes', type=int, default=None, nargs='+', help='optional subset of class indices to distill')
77 | parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM")
78 | parser.add_argument('--random_weights', action='store_true', help="will distill textures instead")
79 | parser.add_argument('--weight_averaging', action='store_true', help="average the weights of neighboring trajectory checkpoints")
80 | parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)')
81 | parser.add_argument('--max_experts', 
type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)') 82 | parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc') 83 | 84 | args = parser.parse_args() 85 | 86 | args.RESULTS_PATH = os.path.join(args.RESULTS_PATH, args.dataset, args.runs_name, str(random.randint(0,1000))) 87 | if not os.path.exists(args.RESULTS_PATH): 88 | os.makedirs(args.RESULTS_PATH) 89 | 90 | def run_experiment(xp, xp_count, n_experiments): 91 | t0 = time.time() 92 | print(xp) 93 | hp = xp.hyperparameters 94 | run = wandb.init(project = args.project, config = hp, reinit = True, name=args.runs_name) 95 | print(wandb.config) 96 | args.dsa = True 97 | args.dsa_param = ParamDiffAug() 98 | 99 | num_classes = {"mnist" : 10, "fmnist" : 10, "cifar10" : 10,"cinic10" : 10, "cifar100" : 100, "nlp" : 4, 'news20': 20}[hp["dataset"]] 100 | if hp.get("loader_mode", "normal") != "normal": 101 | num_classes = 3 102 | args.dsa = True 103 | args.dsa_param = ParamDiffAug() 104 | args.num_classes = num_classes 105 | args.channel = channel_dict[hp['dataset']] 106 | args.imsize = imsize_dict[hp['dataset']] 107 | if args.batch_syn is None: 108 | args.batch_syn = num_classes * args.ipc 109 | print(f"num classes {num_classes}, dsa mode {hp.get('dsa', True)}") 110 | model_names = [model_name for model_name, k in hp["models"].items() for _ in range(k)] 111 | optimizer, optimizer_hp = getattr(torch.optim, hp["local_optimizer"][0]), hp["local_optimizer"][1] 112 | optimizer_fn = lambda x : optimizer(x, **{k : hp[k] if k in hp else v for k, v in optimizer_hp.items()}) 113 | print(f"dataset : {hp['dataset']}") 114 | train_data_all, test_data = data.get_data(hp["dataset"], args.DATA_PATH) 115 | # Creating data indices for training and validation splits: 116 | np.random.seed(hp["random_seed"]) 117 | torch.manual_seed(hp["random_seed"]) 118 | train_data = train_data_all 119 | if hp.get("loader_mode", "normal") == "normal": 120 | client_loaders, test_loader = data.get_loaders(train_data, test_data, n_clients=len(model_names), 121 | alpha=hp["alpha"], batch_size=hp["batch_size"], n_data=None, num_workers=4, seed=hp["random_seed"]) 122 | else: 123 | indices = torch.load("checkpoints/cifar10/ConvNet/0.01/823/sampled_indices.pth") 124 | client_loaders, test_loader, class_indices = data.get_loaders_classes(train_data, test_data, n_clients=len(model_names), 125 | alpha=hp["alpha"], batch_size=hp["batch_size"], n_data=None, num_workers=4, seed=hp["random_seed"], classes = [6,7,9], total_num = 6000, indices=indices) 126 | images_train, labels_train = None, None 127 | # initialize server and clients 128 | server = Server(np.unique(model_names), test_loader,test_loader,num_classes=num_classes, images_train=images_train, labels_train=labels_train, eta=hp.get('eta', 0) , dataset = hp['dataset']) 129 | clients = [Client(model_name, optimizer_fn, loader, idnum=i, num_classes=num_classes, images_train=images_train, labels_train=labels_train, eta=hp.get('eta', 0), dataset = hp['dataset']) for i, (loader, model_name) in enumerate(zip(client_loaders, model_names))] 130 | print(clients[0].model) 131 | # initialize data synthesizer 132 | synthesizer = Synthesizer(deepcopy(clients[0].model), test_loader, args) 133 | server.number_client_all = len(client_loaders) 134 | models.print_model(clients[0].model) 135 | # Start Distributed Training Process 136 | print("Start Distributed Training..\n") 137 | t1 = time.time() 138 | xp.log({"prep_time" : t1-t0}) 139 | maximum_acc_test, 
maximum_acc_val = 0, 0 140 | xp.log({"server_val_{}".format(key) : value for key, value in server.evaluate_ensemble().items()}) 141 | test_accs, val_accs = [], [] 142 | trajectories_list = [] 143 | distilled_rounds = 0 144 | trajectories_list.append([]) 145 | trajectories_list[-1].append([p.cpu() for p in server.model_dict[list(server.model_dict.keys())[0]].parameters()]) 146 | print(f"model key {list(server.model_dict.keys())[0]}") 147 | for c_round in range(1, hp["communication_rounds"]+1): 148 | if distilled_rounds < hp["maximum_distill_round"]: 149 | if len(trajectories_list[distilled_rounds]) >= hp["minimum_trajectory_length"][distilled_rounds]: 150 | print(f"{c_round+1}th iteration, update synthesized data ...") 151 | synthesizer.synthesize(trajectories_list=trajectories_list, args=args) 152 | synthesizer.evaluate(c_round+1, args=args) 153 | distilled_rounds += 1 154 | trajectories_list.append([]) 155 | server.images_train, server.labels_train = synthesizer.image_syn.cpu().detach(), synthesizer.label_syn.cpu().detach() 156 | 157 | participating_clients = server.select_clients(clients, hp["participation_rate"], hp.get('unbalance_rate', 1), hp.get('sample_mode', "uniform")) 158 | xp.log({"participating_clients" : np.array([c.id for c in participating_clients])}) 159 | for client in participating_clients: 160 | client.synchronize_with_server(server) 161 | train_stats = client.compute_weight_update(hp["local_epochs"], lambda_fedprox=hp["lambda_fedprox"] if "PROX" in hp["aggregation_mode"] else 0.0) 162 | if hp["aggregation_mode"] == "FedAVG": 163 | server.fedavg(participating_clients) 164 | elif hp["aggregation_mode"] == "ABAVG": 165 | server.abavg(participating_clients) 166 | elif hp["aggregation_mode"] == "datadistill": 167 | distill_iter = hp.get("distill_iter", None) 168 | distill_lr = hp.get("distill_lr", None) 169 | server.datadistill(participating_clients, distill_iter, distill_lr, dsa=hp.get("dsa", True), args=args) 170 | elif "PROX" in hp["aggregation_mode"]: 171 | server.fedavg(participating_clients) 172 | else: 173 | import pdb; pdb.set_trace() 174 | if xp.is_log_round(c_round): 175 | xp.log({'communication_round' : c_round, 'epochs' : c_round*hp['local_epochs']}) 176 | xp.log({key : clients[0].optimizer.__dict__['param_groups'][0][key] for key in optimizer_hp}) 177 | if server.weights != None: 178 | xp.log({"weights": np.array(server.weights.cpu())}) 179 | for key, value in server.evaluate_ensemble().items(): 180 | if key == "test_accuracy": 181 | if value > maximum_acc_test: 182 | maximum_acc_test = value 183 | wandb.log({"maximum_acc_{}_a_{}_test".format("accuracy", hp["alpha"]): maximum_acc_test}, step=c_round) 184 | elif key == "val_accuracy": 185 | if value > maximum_acc_val: 186 | maximum_acc_val = value 187 | wandb.log({"maximum_acc_{}_a_{}_val".format("accuracy", hp["alpha"]): maximum_acc_val}, step=c_round) 188 | xp.log({"server_val_{}".format(key) : value for key, value in server.evaluate_ensemble().items()}) 189 | wandb.log({"server_{}_a_{}".format(key, hp["alpha"]) : value for key, value in server.evaluate_ensemble().items()}, step=c_round) 190 | xp.log({"epoch_time" : (time.time()-t1)/c_round}) 191 | stats = server.evaluate_ensemble() 192 | test_accs.append(stats['test_accuracy']) 193 | val_accs.append(stats['val_accuracy']) 194 | # Save results to Disk 195 | xp.save_to_disc(path=args.RESULTS_PATH, name="logfiles") 196 | e = int((time.time()-t1)/c_round*(hp['communication_rounds']-c_round)) 197 | print("Remaining Time (approx.):", '{:02d}:{:02d}:{:02d}'.format(e 
197 | print("Remaining Time (approx.):", '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60), e % 60),
198 | "[{:.2f}%]\n".format(c_round/hp['communication_rounds']*100))
199 | trajectories_list[-1].append([p.cpu() for p in server.model_dict[list(server.model_dict.keys())[0]].parameters()])
200 |
201 | # Save model to disk
202 | server.save_model(path=args.CHECKPOINT_PATH, name=hp["save_model"])
203 | # Delete objects to free up GPU memory
204 | del server; clients.clear()
205 | torch.cuda.empty_cache()
206 | run.finish()
207 |
208 | def run():
209 | experiments_raw = json.loads(args.hp)
210 | hp_dicts = [hp for x in experiments_raw for hp in xpm.get_all_hp_combinations(x)][args.start:args.end]
211 | experiments = [xpm.Experiment(hyperparameters=hp) for hp in hp_dicts]
212 |
213 | print("Running {} Experiments..\n".format(len(experiments)))
214 | for xp_count, experiment in enumerate(experiments):
215 | run_experiment(experiment, xp_count, len(experiments))
216 |
217 |
218 | if __name__ == "__main__":
219 | import wandb
220 |
221 | run()
222 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch, torchvision
2 | import numpy as np
3 | import os, pickle
4 | from sklearn.model_selection import train_test_split
5 | from sklearn.preprocessing import MinMaxScaler
6 | from sklearn.datasets import fetch_20newsgroups_vectorized
7 |
8 |
9 | class News20Dataset(torch.utils.data.Dataset):
10 | def __init__(self, data, targets):
11 | self.data = data
12 | self.targets = targets
13 |
14 | def __len__(self):
15 | return len(self.data)
16 |
17 | def __getitem__(self, idx):
18 | return self.data[idx], self.targets[idx]
19 |
20 | def get_news20group(path):
21 | def frnp(x): return torch.from_numpy(x).float()
22 | def from_sparse(x):
23 | x = x.tocoo()
24 | values = x.data
25 | indices = np.vstack((x.row, x.col))
26 |
27 | i = torch.LongTensor(indices)
28 | v = torch.FloatTensor(values)
29 | shape = x.shape
30 |
31 | return torch.sparse.FloatTensor(i, v, torch.Size(shape))
32 |
33 | X, y = fetch_20newsgroups_vectorized(subset='train', return_X_y=True,
34 | )
35 | x_test, y_test = fetch_20newsgroups_vectorized(subset='test', return_X_y=True,
36 | )
37 | ys = [frnp(y).long(), frnp(y_test).long()]
38 | xs = [X, x_test]
39 | xs = [from_sparse(x).cuda() for x in xs]
40 | train_data = News20Dataset(xs[0], ys[0])
41 | test_data = News20Dataset(xs[1], ys[1])
42 | return train_data, test_data
43 |
44 | def get_cinic10(path):
45 | cinic_directory = '../../data/cinic10'
46 | cinic_mean = [0.47889522, 0.47227842, 0.43047404]
47 | cinic_std = [0.24205776, 0.23828046, 0.25874835]
48 | train_data = torchvision.datasets.ImageFolder(cinic_directory + '/train', transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=cinic_mean,std=cinic_std)]))
49 | test_data = torchvision.datasets.ImageFolder(cinic_directory + '/test', transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=cinic_mean,std=cinic_std)]))
50 | return train_data, test_data
51 |
52 | def get_mnist(path):
53 | mnist_transform = torchvision.transforms.Compose([
54 | torchvision.transforms.ToTensor(),
55 | torchvision.transforms.Normalize((0.1307,), (0.3081,))
56 | ])
57 | train_data = torchvision.datasets.MNIST(root=path+"mnist", train=True, transform=mnist_transform, download=True)
58 | test_data = torchvision.datasets.MNIST(root=path+"mnist", train=False,
transform=mnist_transform, download=True)  # fixed: evaluate on the held-out test split (this previously reloaded the training set)
59 | return train_data, test_data
60 |
61 |
62 | def get_cifar10(path):
63 | transforms = torchvision.transforms.Compose([
64 | torchvision.transforms.ToTensor(),
65 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
66 | (0.2023, 0.1994, 0.2010))
67 | ])
68 | train_data = torchvision.datasets.CIFAR10(root=path+"CIFAR", train=True, download=True, transform=transforms)
69 | test_data = torchvision.datasets.CIFAR10(root=path+"CIFAR", train=False, download=True, transform=transforms)
70 |
71 | return train_data, test_data
72 |
73 | def get_fmnist(path):
74 | transforms = torchvision.transforms.Compose([
75 | torchvision.transforms.ToTensor(),
76 | torchvision.transforms.Normalize((0.1307,), (0.3081,))
77 | ])
78 | train_data = torchvision.datasets.FashionMNIST(root=path+"FMNIST", train=True, download=True, transform=transforms)
79 | test_data = torchvision.datasets.FashionMNIST(root=path+"FMNIST", train=False, download=True, transform=transforms)
80 |
81 | return train_data, test_data
82 |
83 |
84 | def get_cifar100(path):
85 | transforms = torchvision.transforms.Compose([
86 | torchvision.transforms.ToTensor(),
87 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),  # NOTE: these are the CIFAR-10 statistics, reused as-is for CIFAR-100
88 | (0.2023, 0.1994, 0.2010))
89 | ])
90 | train_data = torchvision.datasets.CIFAR100(root=path+"CIFAR100", train=True, download=True, transform=transforms)
91 | test_data = torchvision.datasets.CIFAR100(root=path+"CIFAR100", train=False, download=True, transform=transforms)
92 |
93 | return train_data, test_data
94 |
95 |
96 | def get_cifar100_distill(path):
97 | transforms = torchvision.transforms.Compose([
98 | torchvision.transforms.ToTensor(),
99 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
100 | (0.2023, 0.1994, 0.2010))
101 | ])
102 | train_data = torchvision.datasets.CIFAR100(root=path+"CIFAR100", train=True, download=True, transform=transforms)
103 | test_data = torchvision.datasets.CIFAR100(root=path+"CIFAR100", train=False, download=True, transform=transforms)
104 |
105 | return torch.utils.data.ConcatDataset([train_data, test_data])
106 |
107 |
108 |
109 | def get_stl10(path):
110 | transforms = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),
111 | torchvision.transforms.ToTensor(),
112 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
113 | (0.2023, 0.1994, 0.2010))
114 | ])
115 |
116 | data = torchvision.datasets.STL10(root=path+"STL10", split='unlabeled', folds=None,
117 | transform=transforms,
118 | download=True)
119 | return data
120 |
121 |
122 |
123 |
124 |
125 | def get_data(dataset, path):
126 | return {"mnist" : get_mnist, "fmnist": get_fmnist, "cifar10" : get_cifar10, "cinic10" : get_cinic10, "stl10" : get_stl10, "cifar100" : get_cifar100, "news20" : get_news20group, "cifar100_distill" : get_cifar100_distill}[dataset](path)
127 |
128 |
129 | def get_loaders(train_data, test_data, n_clients=10, alpha=0, batch_size=128, n_data=None, num_workers=0, seed=0):
130 | # import pdb; pdb.set_trace()
131 | subset_idcs = split_dirichlet(train_data.targets, n_clients, n_data, alpha, seed=seed)
132 | client_data = [torch.utils.data.Subset(train_data, subset_idcs[i]) for i in range(n_clients)]
133 |
134 |
135 | client_loaders = [torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for subset in client_data]
136 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=256, num_workers=num_workers)
137 |
138 | return client_loaders, test_loader
139 |
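# ------------------------------------------------------------------------------
# Minimal usage sketch for the loaders above (values and paths are illustrative;
# the shipped experiments drive this through run_end2end.py instead):
def _demo_dirichlet_loaders(path="../data/"):
    train_data, test_data = get_data("cifar10", path)
    # Smaller alpha -> more heterogeneous split: as alpha -> 0 each client ends
    # up holding only a few classes, while a large alpha approaches i.i.d. data.
    client_loaders, test_loader = get_loaders(train_data, test_data, n_clients=20, alpha=0.1, batch_size=32)
    return client_loaders, test_loader
# ------------------------------------------------------------------------------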
140 | def get_loaders_classes(train_data, test_data, n_clients=10, alpha=0, batch_size=128, n_data=None, num_workers=0, seed=0, classes=[0, 2, 4], total_num=1500, indices=None):
141 | print(f"number of clients {n_clients}")
142 | if indices is None:
143 | num_per_class = int(total_num/len(classes))
144 | n_clients = len(classes)
145 | classwise_indices = [[i for i in range(len(train_data)) if train_data.targets[i] == j] for j in classes]
146 | for i, class_ind in enumerate(classwise_indices):
147 | for j in class_ind:
148 | train_data.targets[j] = i
149 | classwise_indices_sampled = [np.random.choice(class_ind, num_per_class, replace=False) for class_ind in classwise_indices]  # loop variable renamed; it previously shadowed the `indices` argument
150 | else:
151 | classwise_indices_sampled = indices
152 | for i, class_ind in enumerate(classwise_indices_sampled):  # relabel the sampled indices (required when `indices` is passed in)
153 | for j in class_ind:
154 | train_data.targets[j] = i
155 | client_data = [torch.utils.data.Subset(train_data, classwise_indices_sampled[i]) for i in range(n_clients)]
156 | # client_data = [torch.utils.data.Subset(train_data, np.concatenate(classwise_indices_sampled)) for i in range(n_clients)]
157 | classwise_indices_test = [i for i in range(len(test_data)) if test_data.targets[i] in classes]
158 | for i in classwise_indices_test:
159 | test_data.targets[i] = classes.index(test_data.targets[i])
160 | test_data = torch.utils.data.Subset(test_data, classwise_indices_test)
161 | client_loaders = [torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for subset in client_data]
162 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=256, num_workers=num_workers)
163 | print(f"number of data per class: {[len(x) for x in classwise_indices_sampled]}")
164 | print("train class sampled indices:")
165 | print(classwise_indices_sampled)
166 | print("test class sampled indices:")
167 | print(classwise_indices_test)
168 | return client_loaders, test_loader, classwise_indices_sampled
169 |
170 |
171 |
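# ------------------------------------------------------------------------------
# Usage sketch for get_loaders_classes (illustrative values, not taken from the
# shipped experiment scripts): builds one single-class client per entry of
# `classes`, relabels those classes 0..len(classes)-1, and filters the test set
# to the same classes.
def _demo_class_loaders(path="../data/"):
    train_data, test_data = get_data("cifar10", path)
    loaders, test_loader, sampled = get_loaders_classes(train_data, test_data, classes=[0, 2, 4], total_num=1500)
    return loaders, test_loader, sampled
# ------------------------------------------------------------------------------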
172 | from torch.utils.data import Dataset
173 | class my_subset(Dataset):
174 | r"""
175 | Subset of a dataset at specified indices.
176 |
177 | Arguments:
178 | dataset (Dataset): The whole Dataset
179 | indices (sequence): Indices in the whole set selected for subset
180 | labels (sequence): targets for the given indices; must be the same length as indices
181 | """
182 | def __init__(self, dataset, indices, labels):
183 | self.dataset = dataset
184 | self.indices = indices
185 | labels_hold = torch.ones(len(dataset)).type(torch.long) * 300  # sentinel label (300) that does not occur among the real labels
186 | # import pdb; pdb.set_trace()
187 | labels_hold[self.indices] = torch.LongTensor(labels)
188 | self.labels = labels_hold
189 | self.targets = torch.LongTensor(labels)
190 | def __getitem__(self, idx):
191 | image = self.dataset[self.indices[idx]][0]
192 | label = self.labels[self.indices[idx]]
193 | return (image, label)
194 |
195 | def __len__(self):
196 | return len(self.indices)
197 |
198 |
199 | def split_dirichlet(labels, n_clients, n_data, alpha, double_stochastic=True, seed=0):
200 | '''Splits data among the clients according to a Dirichlet distribution with concentration parameter alpha'''
201 |
202 | np.random.seed(seed)
203 |
204 | if isinstance(labels, torch.Tensor):
205 | labels = labels.numpy()
206 |
207 | n_classes = np.max(labels)+1
208 | # import pdb; pdb.set_trace()
209 | label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
210 |
211 | if double_stochastic:
212 | label_distribution = make_double_stochastic(label_distribution)
213 |
214 | class_idcs = [np.argwhere(np.array(labels)==y).flatten()
215 | for y in range(n_classes)]
216 |
217 | client_idcs = [[] for _ in range(n_clients)]
218 | for c, fracs in zip(class_idcs, label_distribution):
219 | for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))):
220 | client_idcs[i] += [idcs]
221 |
222 | client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
223 |
224 | print_split(client_idcs, labels)
225 |
226 | return client_idcs
227 |
228 | def unbalanced_dataset(dataset, imbalanced_factor=-1, num_classes=10):
229 | if imbalanced_factor > 0:
230 | imbalanced_num_list = []
231 | sample_num = int(len(dataset.targets) / num_classes)
232 | for class_index in range(num_classes):
233 | imbalanced_num = sample_num / (imbalanced_factor ** (class_index / (num_classes - 1)))
234 | imbalanced_num_list.append(int(imbalanced_num))
235 | np.random.shuffle(imbalanced_num_list)
236 | print(imbalanced_num_list)
237 | else:
238 | imbalanced_num_list = None
239 | index_to_train = []
240 | for class_index in range(num_classes):
241 | index_to_class = [index for index, label in enumerate(dataset.targets) if label == class_index]
242 | np.random.shuffle(index_to_class)
243 |
244 | if imbalanced_num_list is not None:
245 | index_to_class = index_to_class[:imbalanced_num_list[class_index]]
246 |
247 | index_to_train.extend(index_to_class)
248 | print(f"class_index {class_index}, samples {len(index_to_class)}")
249 | dataset.data = dataset.data[index_to_train]
250 | dataset.targets = list(np.array(dataset.targets)[index_to_train])
251 | return dataset
252 |
253 | def make_double_stochastic(x):
254 | rsum = None
255 | csum = None
256 |
257 | n = 0
258 | while n < 1000 and (np.any(rsum != 1) or np.any(csum != 1)):
259 | x /= x.sum(0)
260 | x = x / x.sum(1)[:, np.newaxis]
261 | rsum = x.sum(1)
262 | csum = x.sum(0)
263 | n += 1
264 |
265 | return x
266 |
267 |
268 |
269 | def print_split(idcs, labels):
270 | n_labels = np.max(labels) + 1
271 | print("Data split:")
272 | splits = []
273 | for i, client_idcs in enumerate(idcs):
274 | split = np.sum(np.array(labels)[client_idcs].reshape(1,-1)==np.arange(n_labels).reshape(-1,1), axis=1)
275 | splits += [split]
276 | if len(idcs) < 30 or i < 10 or i > len(idcs)-10:
277 | print(" - Client {}: {:55} -> 
sum={}".format(i,str(split), np.sum(split)), flush=True) 278 | elif i==len(idcs)-10: 279 | print(". "*10+"\n"+". "*10+"\n"+". "*10) 280 | 281 | print(" - Total: {}".format(np.stack(splits, axis=0).sum(axis=0))) 282 | print() 283 | 284 | 285 | 286 | 287 | class IdxSubset(torch.utils.data.Dataset): 288 | 289 | def __init__(self, dataset, indices, return_index): 290 | self.dataset = dataset 291 | self.indices = indices 292 | self.return_index = return_index 293 | 294 | def __getitem__(self, idx): 295 | if self.return_index: 296 | return self.dataset[self.indices[idx]], idx 297 | else: 298 | return self.dataset[self.indices[idx]]#, idx 299 | 300 | def __len__(self): 301 | return len(self.indices) 302 | 303 | 304 | 305 | -------------------------------------------------------------------------------- /image_synthesizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.utils 8 | import tqdm 9 | from utils import kd_loss, DiffAugment 10 | import wandb 11 | import copy 12 | from torch.utils.data import Dataset 13 | from copy import deepcopy 14 | import time 15 | import random 16 | from reparam_module import ReparamModule 17 | import warnings 18 | warnings.filterwarnings("ignore", category=DeprecationWarning) 19 | mean_dataset={ 20 | "cifar10": [0.4914, 0.4822, 0.4465], 21 | "mnist": [0.1307], 22 | "fmnist": [0.1307], 23 | } 24 | std_dataset = { 25 | "cifar10" : [0.2023, 0.1994, 0.2010], 26 | "mnist" : [0.3081], 27 | "fmnist" : [0.3081], 28 | } 29 | 30 | class TensorDataset(Dataset): 31 | def __init__(self, images, labels): # images: n x c x h x w tensor 32 | self.images = images 33 | self.labels = labels 34 | 35 | def __getitem__(self, index): 36 | return self.images[index], self.labels[index] 37 | 38 | def __len__(self): 39 | return self.images.shape[0] 40 | 41 | def reduce_params(sources, weights): 42 | targets = [] 43 | for i in range(len(sources[0])): 44 | target = torch.sum(weights * torch.stack([source[i].cuda() for source in sources], dim = -1), dim=-1) 45 | targets.append(target) 46 | return targets 47 | 48 | def epoch(mode, dataloader, net, optimizer, criterion, aug=True, args=None): 49 | loss_avg, acc_avg, num_exp = 0, 0, 0 50 | net = net.cuda() 51 | if mode == 'train': 52 | net.train() 53 | else: 54 | net.eval() 55 | 56 | for i_batch, datum in enumerate(dataloader): 57 | img = datum[0].float().cuda() 58 | lab = datum[1].cuda() 59 | if aug: 60 | img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param) 61 | n_b = lab.shape[0] 62 | output = net(img) 63 | loss = criterion(output, lab) 64 | if mode == 'train': 65 | acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), np.argmax(lab.cpu().data.numpy(), axis=-1))) 66 | else: 67 | acc = np.sum(np.equal(np.argmax(output.cpu().data.numpy(), axis=-1), lab.cpu().data.numpy())) 68 | loss_avg += loss.item()*n_b 69 | acc_avg += acc 70 | num_exp += n_b 71 | if mode == 'train': 72 | optimizer.zero_grad() 73 | loss.backward() 74 | optimizer.step() 75 | 76 | loss_avg /= num_exp 77 | acc_avg /= num_exp 78 | 79 | return loss_avg, acc_avg 80 | 81 | 82 | def evaluate_synset(it_eval, net, lr_net, images_train, labels_train, testloader, args): 83 | net = net.cuda() 84 | images_train = images_train.cuda() 85 | labels_train = labels_train.cuda() 86 | lr = float(lr_net) 87 | Epoch = 500 88 | lr_schedule = [Epoch//2+1] 89 | optimizer = 
82 | def evaluate_synset(it_eval, net, lr_net, images_train, labels_train, testloader, args):
83 | net = net.cuda()
84 | images_train = images_train.cuda()
85 | labels_train = labels_train.cuda()
86 | lr = float(lr_net)
87 | Epoch = 500
88 | lr_schedule = [Epoch//2+1]
89 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
90 | dst_train = TensorDataset(images_train, labels_train)
91 | trainloader = torch.utils.data.DataLoader(dst_train, batch_size=256, shuffle=True, num_workers=0)
92 | start = time.time()
93 | acc_train_list = []
94 | loss_train_list = []
95 | for ep in tqdm.tqdm(range(Epoch+1)):
96 | loss_train, acc_train = epoch('train', trainloader, net, optimizer, kd_loss, aug=True, args=args)
97 | acc_train_list.append(acc_train)
98 | loss_train_list.append(loss_train)
99 | if ep == Epoch:
100 | with torch.no_grad():
101 | loss_test, acc_test = epoch('test', testloader, net, optimizer, nn.CrossEntropyLoss().cuda(), aug=False, args=args)
102 | if ep in lr_schedule:
103 | lr *= 0.1
104 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
105 | time_train = time.time() - start
106 | print('Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test))
107 | return net, acc_train_list, acc_test
108 |
109 | class Synthesizer:
110 | def __init__(self, network, test_loader, args):
111 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
112 | self.dataset = args.dataset
113 | self.testloader = test_loader
114 | self.batch_syn = args.batch_syn
115 | self.save_path = args.RESULTS_PATH
116 | self.iteration = args.Iteration
117 | self.channel = args.channel
118 | hard_label = [np.ones(args.ipc, dtype=np.int64)*i for i in range(args.num_classes)]  # np.long was removed in NumPy 1.24; np.int64 is the equivalent
119 | label_syn = torch.nn.functional.one_hot(torch.tensor(hard_label).reshape(-1), num_classes=args.num_classes).float()
120 | label_syn = label_syn * args.label_init
121 | label_syn = label_syn.detach().to(self.device).requires_grad_(True)
122 | image_syn = torch.randn(size=(args.num_classes * args.ipc, args.channel, args.imsize[0], args.imsize[1]), dtype=torch.float)
123 | syn_lr = torch.tensor(args.lr_teacher).to(self.device)
124 | image_syn = image_syn.detach().to(self.device).requires_grad_(True)
125 | syn_lr = syn_lr.detach().to(self.device).requires_grad_(True)
126 | if args.img_optim == "sgd":
127 | optimizer_img = torch.optim.SGD([image_syn], lr=args.lr_img, momentum=0.5)
128 | optimizer_label = torch.optim.SGD([label_syn], lr=args.lr_label, momentum=0.5)
129 | else:
130 | optimizer_img = torch.optim.Adam([image_syn], lr=args.lr_img)
131 | optimizer_label = torch.optim.Adam([label_syn], lr=args.lr_label)
132 | if args.lr_optim == "sgd":
133 | optimizer_lr = torch.optim.SGD([syn_lr], lr=args.lr_lr, momentum=0.5)
134 | else:
135 | optimizer_lr = torch.optim.Adam([syn_lr], lr=args.lr_lr)
136 | self.test_loader = test_loader
137 | self.label_syn, self.image_syn, self.syn_lr = label_syn, image_syn, syn_lr
138 | self.optimizer_img, self.optimizer_label, self.optimizer_lr = optimizer_img, optimizer_label, optimizer_lr
139 | self.network = network.cuda()
140 | self.weight_averaging, self.least_ave_num, self.max_ave_num, self.random_weights = args.weight_averaging, args.least_ave_num, args.max_ave_num, args.random_weights
141 | self.distributed = torch.cuda.device_count() > 1
142 | self.syn_steps, self.min_start_epoch, self.max_start_epoch, self.expert_epochs = args.syn_steps, args.min_start_epoch, args.max_start_epoch, args.expert_epochs
143 |
144 |
145 | def synthesize(self, trajectories_list, args):
146 | for it in range(0, self.iteration):
147 | trajectories = trajectories_list[random.randint(0, len(trajectories_list)-1)]
148 | # trajectories = 
trajectories_list[-1] 149 | student_net = ReparamModule(copy.deepcopy(self.network)) 150 | if self.distributed: 151 | student_net = torch.nn.DataParallel(student_net) 152 | student_net.train() 153 | num_params = sum([np.prod(p.size()) for p in (student_net.parameters())]) 154 | curr_max_start_epoch = min([self.max_start_epoch, len(trajectories) - 1 - self.expert_epochs]) 155 | if curr_max_start_epoch == 0: 156 | start_epoch = 0 157 | else: 158 | start_epoch = np.random.randint(self.min_start_epoch, curr_max_start_epoch+1) 159 | # print(f"max start epoch {curr_max_start_epoch}, min start epoch {self.min_start_epoch}, expert epoch {self.expert_epochs}") 160 | # print(f"sampled start epoch {start_epoch}") 161 | starting_params = trajectories[start_epoch] 162 | if not self.weight_averaging: 163 | target_params = trajectories[start_epoch+self.expert_epochs] 164 | else: 165 | max_ave_num = self.max_ave_num+1 if self.max_ave_num < self.expert_epochs else self.expert_epochs+1 166 | averaging_num = random.choice(list(range(self.least_ave_num, max_ave_num))) 167 | candidate_params = random.choices(trajectories[start_epoch+1: start_epoch+self.expert_epochs+1],k=averaging_num) 168 | if not self.random_weights: 169 | weights = torch.full([len(candidate_params)], 1./len(candidate_params), dtype=torch.float, device="cuda") 170 | else: 171 | weights = torch.rand(len(candidate_params)).to(self.device) 172 | weights = torch.softmax(weights, dim=0) 173 | target_params = reduce_params(candidate_params, weights) 174 | 175 | target_params = torch.cat([p.data.to(self.device).reshape(-1) for p in target_params], 0) 176 | student_params = [torch.cat([p.data.to(self.device).reshape(-1) for p in starting_params], 0).requires_grad_(True)] 177 | starting_params = torch.cat([p.data.to(self.device).reshape(-1) for p in starting_params], 0) 178 | syn_images = self.image_syn 179 | y_hat = self.label_syn 180 | param_loss_list = [] 181 | param_dist_list = [] 182 | indices_chunks = [] 183 | 184 | for step in range(self.syn_steps): 185 | if not indices_chunks: 186 | indices = torch.randperm(len(syn_images)) 187 | indices_chunks = list(torch.split(indices, self.batch_syn)) 188 | 189 | these_indices = indices_chunks.pop() 190 | x = syn_images[these_indices] 191 | this_y = y_hat[these_indices] 192 | if args.dsa: 193 | x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param) 194 | if self.distributed: 195 | forward_params = student_params[-1].unsqueeze(0).expand(torch.cuda.device_count(), -1) 196 | else: 197 | forward_params = student_params[-1] 198 | x = student_net(x, flat_param=forward_params) 199 | ce_loss = kd_loss(x, this_y) 200 | grad = torch.autograd.grad(ce_loss, student_params[-1], create_graph=True)[0] 201 | student_params.append(student_params[-1] - self.syn_lr * grad) 202 | 203 | param_loss = torch.tensor(0.0).to(self.device) 204 | param_dist = torch.tensor(0.0).to(self.device) 205 | 206 | param_loss += torch.nn.functional.mse_loss(student_params[-1], target_params, reduction="sum") 207 | param_dist += torch.nn.functional.mse_loss(starting_params, target_params, reduction="sum") 208 | 209 | param_loss_list.append(param_loss) 210 | param_dist_list.append(param_dist) 211 | 212 | 213 | param_loss /= num_params 214 | param_dist /= num_params 215 | 216 | param_loss /= param_dist 217 | 218 | grand_loss = param_loss 219 | 220 | self.optimizer_img.zero_grad() 221 | self.optimizer_label.zero_grad() 222 | self.optimizer_lr.zero_grad() 223 | 224 | grand_loss.backward() 225 | 226 | self.optimizer_img.step() 227 | 
self.optimizer_lr.step()
228 | self.optimizer_label.step()
229 |
230 | # wandb.log({"Grand_Loss": grand_loss.detach().cpu()})
231 |
232 | del student_params  # drop the unrolled parameter list so its computation graph can be freed
233 |
234 |
235 | if it%10 == 0:
236 | print('iter = %04d, loss = %.4f' % (it, grand_loss.item()))
237 | print(f"syn_labels = {F.softmax(self.label_syn, dim=1)}")
238 | if (it+1)%500 == 0:
239 | self.evaluate(0, upload_wandb=False, args=args)
240 |
241 | def evaluate(self, c_round, upload_wandb=True, args=None):
242 | accs_test = []
243 | accs_train = []
244 | for it_eval in range(3):
245 | net_eval = copy.deepcopy(self.network).to(self.device)  # evaluate on a fresh copy of the network
246 | eval_labs = self.label_syn.detach()
247 | with torch.no_grad():
248 | image_save = self.image_syn
249 | image_syn_eval, label_syn_eval = copy.deepcopy(image_save.detach()), copy.deepcopy(eval_labs.detach()) # avoid any unaware modification
250 | lr_net = self.syn_lr.item()
251 | _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, lr_net, image_syn_eval, label_syn_eval, self.testloader, args)
252 | accs_test.append(acc_test)
253 | accs_train.append(acc_train)
254 | accs_test = np.array(accs_test)
255 | acc_test_mean = np.mean(accs_test)
256 | acc_test_std = np.std(accs_test)
257 | print('Evaluate %d, mean = %.4f std = %.4f\n-------------------------'%(len(accs_test), acc_test_mean, acc_test_std))
258 |
259 | # uploading images to wandb
260 | upsampled = image_save
261 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
262 | upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
263 | grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
264 | if upload_wandb:
265 | wandb.log({'Accuracy/{}'.format("ConvNet"): acc_test_mean}, step=c_round)
266 | wandb.log({'Std/{}'.format("ConvNet"): acc_test_std}, step=c_round)
267 | wandb.log({"Synthetic_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=c_round)
268 | wandb.log({'Synthetic_Pixels': wandb.Histogram(torch.nan_to_num(image_save.detach().cpu()))}, step=c_round)
269 |
270 | image_syn_vis = copy.deepcopy(image_save.detach().cpu())
271 | for ch in range(self.channel):
272 | image_syn_vis[:, ch] = image_syn_vis[:, ch] * std_dataset.get(self.dataset)[ch] + mean_dataset.get(self.dataset)[ch]
273 | image_syn_vis[image_syn_vis<0] = 0.0
274 | image_syn_vis[image_syn_vis>1] = 1.0
275 | grid = torchvision.utils.make_grid(image_syn_vis, nrow=10, normalize=True, scale_each=True)
276 | wandb.log({"Synthetic vis_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=c_round)
277 | save_data_path = os.path.join(self.save_path, "syn_data")
278 | if not os.path.exists(save_data_path):
279 | os.makedirs(save_data_path)
280 | torch.save(self.image_syn.detach().cpu(), os.path.join(save_data_path, "images_best.pt"))  # removed a no-op .format(c_round); the filename has no placeholder
281 | torch.save(self.label_syn.detach().cpu(), os.path.join(save_data_path, "labels_best.pt"))
282 | print(f"saved synthetic data at {save_data_path}")
283 |
284 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | from torchvision.models.resnet import ResNet, BasicBlock, conv3x3, conv1x1
6 |
7 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
8 |
9 | channel_dict = {
10 | "cifar10": 3,
11 | "cinic10": 3,
12 | "cifar100": 3,
13 | "mnist": 1,
14 | "fmnist": 1,
15 | } 16 | 17 | ############################################################################################################ 18 | # MOBILENET 19 | ############################################################################################################ 20 | 21 | class MLP(nn.Module): 22 | def __init__(self, num_classes=10, net_width=128, im_size = (28,28), dataset = 'cifar10'): 23 | super(MLP, self).__init__() 24 | channel = channel_dict.get(dataset) 25 | self.fc1 = nn.Linear(im_size[0]*im_size[1]*channel, net_width) 26 | self.fc2 = nn.Linear(net_width, net_width) 27 | self.fc3 = nn.Linear(net_width, num_classes) 28 | 29 | def forward(self, x): 30 | x = x.view(x.size(0), -1) 31 | x = F.relu(self.fc1(x)) 32 | x = F.relu(self.fc2(x)) 33 | return self.fc3(x) 34 | 35 | class Block(nn.Module): 36 | '''expand + depthwise + pointwise''' 37 | def __init__(self, in_planes, out_planes, expansion, stride, norm_layer): 38 | super(Block, self).__init__() 39 | self.stride = stride 40 | 41 | planes = expansion * in_planes 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 43 | self.bn1 = norm_layer(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 45 | self.bn2 = norm_layer(planes) 46 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 47 | self.bn3 = norm_layer(out_planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride == 1 and in_planes != out_planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 53 | norm_layer(out_planes), 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out = out + self.shortcut(x) if self.stride==1 else out 61 | return out 62 | 63 | 64 | class MobileNetV2(nn.Module): 65 | # (expansion, out_planes, num_blocks, stride) 66 | 67 | def __init__(self, num_classes=10, norm_layer=nn.BatchNorm2d,shrink=1, dataset = 'cifar10'): 68 | super(MobileNetV2, self).__init__() 69 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 70 | self.dataset = dataset 71 | channel = channel_dict.get(dataset) 72 | self.norm_layer = norm_layer 73 | self.cfg = [(1, 16//shrink, 1, 1), 74 | (6, 24//shrink, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 75 | (6, 32//shrink, 3, 2), 76 | (6, 64//shrink, 4, 2), 77 | (6, 96//shrink, 3, 1), 78 | (6, 160//shrink, 3, 2), 79 | (6, 320//shrink, 1, 1)] 80 | 81 | 82 | self.conv1 = nn.Conv2d(channel, 32, kernel_size=3, stride=1, padding=1, bias=False) 83 | self.bn1 = self.norm_layer(32) 84 | self.layers = self._make_layers(in_planes=32) 85 | self.conv2 = nn.Conv2d(self.cfg[-1][1], 1280//shrink, kernel_size=1, stride=1, padding=0, bias=False) 86 | self.bn2 = self.norm_layer(1280//shrink) 87 | 88 | 89 | self.classification_layer = nn.Linear(1280//shrink, num_classes) 90 | 91 | 92 | def _make_layers(self, in_planes): 93 | layers = [] 94 | for expansion, out_planes, num_blocks, stride in self.cfg: 95 | strides = [stride] + [1]*(num_blocks-1) 96 | for stride in strides: 97 | layers.append(Block(in_planes, out_planes, expansion, stride, self.norm_layer)) 98 | in_planes = out_planes 99 | return nn.Sequential(*layers) 100 | 101 | 102 | def extract_features(self, x): 103 | out = F.relu(self.bn1(self.conv1(x))) 104 | out = self.layers(out) 105 | out = F.relu(self.bn2(self.conv2(out))) 106 | out = F.avg_pool2d(out, 4) 107 | out = 
out.view(out.size(0), -1)
108 | return out
109 |
110 |
111 | def forward(self, x):
112 | feature = self.extract_features(x)
113 | out = self.classification_layer(feature)
114 | return out
115 |
116 |
117 |
118 |
119 | def mobilenetv2(num_classes=10, dataset='cifar10'):
120 | return MobileNetV2(norm_layer=nn.BatchNorm2d, shrink=2, num_classes=num_classes, dataset=dataset)  # fixed: forward the dataset argument instead of hard-coding 'cifar10'
121 |
122 |
123 |
124 |
125 | ############################################################################################################
126 | # RESNET
127 | ############################################################################################################
128 | class basic_noskip(BasicBlock):
129 | expansion: int = 1
130 | def __init__(
131 | self,
132 | *args,
133 | **kwargs
134 | ) -> None:
135 | super(basic_noskip, self).__init__(*args, **kwargs)
136 |
137 | def forward(self, x):
138 | out = self.conv1(x)
139 | # out = self.bn1(out)
140 | out = self.relu(out)
141 |
142 | out = self.conv2(out)
143 | # out = self.bn2(out)
144 | out = self.relu(out)
145 |
146 | return out
147 |
148 | class Model_noskip(nn.Module):
149 | def __init__(self, channel=3, feature_dim=128, group_norm=False):
150 | super(Model_noskip, self).__init__()
151 |
152 | self.f = []
153 | for name, module in ResNet(basic_noskip, [1,1,1,1], num_classes=10).named_children():
154 | if name == 'conv1':
155 | module = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
156 | if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
157 | self.f.append(module)
158 | # encoder
159 | self.f = nn.Sequential(*self.f)
160 | # projection head
161 | self.g = nn.Sequential(nn.Linear(512, 512, bias=False), nn.BatchNorm1d(512),
162 | nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))
163 |
164 | if group_norm:
165 | apply_gn(self)
166 |
167 | def forward(self, x):
168 | x = self.f(x)
169 | feature = torch.flatten(x, start_dim=1)
170 | out = self.g(feature)
171 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
172 |
173 |
174 | class resnet8_noskip(nn.Module):
175 | def __init__(self, num_classes=10, pretrained_path=None, group_norm=False, dataset='cifar10'):
176 | super(resnet8_noskip, self).__init__()
177 | channel = channel_dict.get(dataset)
178 | # encoder
179 | self.f = Model_noskip(channel=channel, group_norm=group_norm).f
180 | # classifier
181 | self.classification_layer = nn.Linear(512, num_classes, bias=True)
182 |
183 |
184 | if pretrained_path:
185 | self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)
186 |
187 |
188 | def extract_features(self, x):
189 | return torch.flatten(self.f(x), start_dim=1)
190 |
191 |
192 | def forward(self, x):
193 | feature = self.extract_features(x)
194 | out = self.classification_layer(feature)
195 | return out
196 |
197 | class Model(nn.Module):
198 | def __init__(self, channel=3, feature_dim=128, group_norm=False):
199 | super(Model, self).__init__()
200 |
201 | self.f = []
202 | for name, module in ResNet(BasicBlock, [1,1,1,1], num_classes=10).named_children():
203 | if name == 'conv1':
204 | module = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
205 | if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
206 | self.f.append(module)
207 | # encoder
208 | self.f = nn.Sequential(*self.f)
209 | # projection head
210 | self.g = nn.Sequential(nn.Linear(512, 512, bias=False), nn.BatchNorm1d(512),
211 | nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))
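# ------------------------------------------------------------------------------
# apply_gn (referenced just below, defined elsewhere in the repo) swaps the
# normalization layers for GroupNorm; BatchNorm statistics are known to behave
# poorly when averaging models trained on non-i.i.d. federated clients. A
# plausible implementation is sketched here as an assumption about its
# behaviour, not the repo's actual code:
def _apply_gn_sketch(model, num_groups=32):
    import torch.nn as nn
    for name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # assumes num_features is divisible by num_groups
            setattr(model, name, nn.GroupNorm(num_groups, child.num_features, affine=True))
        else:
            _apply_gn_sketch(child, num_groups)
# ------------------------------------------------------------------------------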
212 | 213 | if group_norm: 214 | apply_gn(self) 215 | 216 | def forward(self, x): 217 | x = self.f(x) 218 | feature = torch.flatten(x, start_dim=1) 219 | out = self.g(feature) 220 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1) 221 | 222 | 223 | class resnet8(nn.Module): 224 | def __init__(self, num_classes=10, pretrained_path=None, group_norm=False, dataset = 'cifar10'): 225 | super(resnet8, self).__init__() 226 | channel = channel_dict.get(dataset) 227 | # encoder 228 | self.f = Model(channel = channel, group_norm=group_norm).f 229 | # classifier 230 | self.classification_layer = nn.Linear(512, num_classes, bias=True) 231 | 232 | 233 | if pretrained_path: 234 | self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False) 235 | 236 | 237 | def extract_features(self, x): 238 | return torch.flatten(self.f(x), start_dim=1) 239 | 240 | 241 | def forward(self, x): 242 | feature = self.extract_features(x) 243 | out = self.classification_layer(feature) 244 | return out 245 | 246 | 247 | 248 | 249 | ############################################################################################################ 250 | # SHUFFLENET 251 | ############################################################################################################ 252 | 253 | 254 | 255 | class ShuffleBlock(nn.Module): 256 | def __init__(self, groups): 257 | super(ShuffleBlock, self).__init__() 258 | self.groups = groups 259 | 260 | def forward(self, x): 261 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 262 | N,C,H,W = x.size() 263 | g = self.groups 264 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 265 | 266 | 267 | class Bottleneck(nn.Module): 268 | def __init__(self, in_planes, out_planes, stride, groups): 269 | super(Bottleneck, self).__init__() 270 | self.stride = stride 271 | 272 | mid_planes = out_planes//4 273 | g = 1 if in_planes==24 else groups 274 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 275 | self.bn1 = nn.BatchNorm2d(mid_planes) 276 | self.shuffle1 = ShuffleBlock(groups=g) 277 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 278 | self.bn2 = nn.BatchNorm2d(mid_planes) 279 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 280 | self.bn3 = nn.BatchNorm2d(out_planes) 281 | 282 | self.shortcut = nn.Sequential() 283 | if stride == 2: 284 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 285 | 286 | def forward(self, x): 287 | out = F.relu(self.bn1(self.conv1(x))) 288 | out = self.shuffle1(out) 289 | out = F.relu(self.bn2(self.conv2(out))) 290 | out = self.bn3(self.conv3(out)) 291 | res = self.shortcut(x) 292 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res) 293 | return out 294 | 295 | 296 | class ShuffleNet(nn.Module): 297 | def __init__(self, num_classes=10): 298 | super(ShuffleNet, self).__init__() 299 | cfg = {'out_planes': [200,400,800],'num_blocks': [4,8,4],'groups': 2} 300 | 301 | out_planes = cfg['out_planes'] 302 | num_blocks = cfg['num_blocks'] 303 | groups = cfg['groups'] 304 | 305 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 306 | self.bn1 = nn.BatchNorm2d(24) 307 | self.in_planes = 24 308 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 309 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 310 | self.layer3 = self._make_layer(out_planes[2], 
num_blocks[2], groups) 311 | self.classification_layer = nn.Linear(out_planes[2], num_classes) 312 | 313 | 314 | def _make_layer(self, out_planes, num_blocks, groups): 315 | layers = [] 316 | for i in range(num_blocks): 317 | stride = 2 if i == 0 else 1 318 | cat_planes = self.in_planes if i == 0 else 0 319 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 320 | self.in_planes = out_planes 321 | return nn.Sequential(*layers) 322 | 323 | 324 | def extract_features(self, x): 325 | out = F.relu(self.bn1(self.conv1(x))) 326 | out = self.layer1(out) 327 | out = self.layer2(out) 328 | out = self.layer3(out) 329 | out = F.avg_pool2d(out, 4) 330 | feature = out.view(out.size(0), -1) 331 | return feature 332 | 333 | 334 | def forward(self, x): 335 | feature = self.extract_features(x) 336 | out = self.classification_layer(feature) 337 | return out 338 | 339 | 340 | ''' ConvNet ''' 341 | class ConvNet(nn.Module): 342 | def __init__(self, num_classes=10, net_width=128, net_depth=3, net_act='relu', net_norm='instancenorm', net_pooling='avgpooling', im_size = (32,32), dataset = 'cifar10'): 343 | super(ConvNet, self).__init__() 344 | channel = channel_dict.get(dataset) 345 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 346 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 347 | print(f"num feat {num_feat}") 348 | self.classifier = nn.Linear(num_feat, num_classes) 349 | 350 | def forward(self, x): 351 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 352 | out = self.get_feature(x) 353 | out = self.classifier(out) 354 | return out 355 | 356 | def get_feature(self,x): 357 | out = self.features(x) 358 | out = out.view(out.size(0), -1) 359 | return out 360 | 361 | def _get_activation(self, net_act): 362 | if net_act == 'sigmoid': 363 | return nn.Sigmoid() 364 | elif net_act == 'relu': 365 | return nn.ReLU(inplace=True) 366 | elif net_act == 'leakyrelu': 367 | return nn.LeakyReLU(negative_slope=0.01) 368 | else: 369 | exit('unknown activation function: %s'%net_act) 370 | 371 | def _get_pooling(self, net_pooling): 372 | if net_pooling == 'maxpooling': 373 | return nn.MaxPool2d(kernel_size=2, stride=2) 374 | elif net_pooling == 'avgpooling': 375 | return nn.AvgPool2d(kernel_size=2, stride=2) 376 | elif net_pooling == 'none': 377 | return None 378 | else: 379 | exit('unknown net_pooling: %s'%net_pooling) 380 | 381 | def _get_normlayer(self, net_norm, shape_feat): 382 | # shape_feat = (c*h*w) 383 | if net_norm == 'batchnorm': 384 | return nn.BatchNorm2d(shape_feat[0], affine=True) 385 | elif net_norm == 'layernorm': 386 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 387 | elif net_norm == 'instancenorm': 388 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 389 | elif net_norm == 'groupnorm': 390 | return nn.GroupNorm(4, shape_feat[0], affine=True) 391 | elif net_norm == 'none': 392 | return None 393 | else: 394 | exit('unknown net_norm: %s'%net_norm) 395 | 396 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 397 | layers = [] 398 | in_channels = channel 399 | if im_size[0] == 28: 400 | im_size = (32, 32) 401 | shape_feat = [in_channels, im_size[0], im_size[1]] 402 | for d in range(net_depth): 403 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 404 | shape_feat[0] = net_width 405 | if net_norm != 
'none': 406 | layers += [self._get_normlayer(net_norm, shape_feat)] 407 | layers += [self._get_activation(net_act)] 408 | in_channels = net_width 409 | if net_pooling != 'none': 410 | layers += [self._get_pooling(net_pooling)] 411 | shape_feat[1] //= 2 412 | shape_feat[2] //= 2 413 | 414 | 415 | return nn.Sequential(*layers), shape_feat 416 | 417 | 418 | 419 | class TextModel(nn.Module): 420 | 421 | def __init__(self, vocab_size=95811, embed_dim=64, num_classes=4): 422 | super(TextModel, self).__init__() 423 | self.embedding = nn.EmbeddingBag(vocab_size, embed_dim) 424 | self.fc = nn.Linear(embed_dim, num_classes) 425 | self.init_weights() 426 | 427 | def init_weights(self): 428 | initrange = 0.5 429 | self.embedding.weight.data.uniform_(-initrange, initrange) 430 | self.fc.weight.data.uniform_(-initrange, initrange) 431 | self.fc.bias.data.zero_() 432 | 433 | def forward(self, text, offsets): 434 | embedded = self.embedding(text, offsets) 435 | return self.fc(embedded) 436 | 437 | class LogisticRegression(nn.Module): 438 | def __init__(self, input_dim=130107, num_classes=20): 439 | super(LogisticRegression, self).__init__() 440 | self.fc = torch.nn.Parameter(torch.zeros(input_dim, num_classes)) 441 | 442 | 443 | def forward(self, x): 444 | out = x @ self.fc 445 | return out 446 | 447 | def get_model(model): 448 | 449 | return { "mobilenetv2" : (mobilenetv2, optim.Adam, {"lr" : 0.001}), 450 | "shufflenet" : (ShuffleNet, optim.Adam, {"lr" : 0.001}), 451 | "resnet8" : (resnet8, optim.Adam, {"lr" : 0.001}), 452 | "resnet8_noskip" : (resnet8_noskip, optim.Adam, {"lr" : 0.001}), 453 | "ConvNet" : (ConvNet, optim.Adam, {"lr" : 0.001}), 454 | "MLP" : (MLP, optim.Adam, {"lr" : 0.001}), 455 | "TextModel" : (TextModel, optim.Adam, {"lr" : 1}), 456 | "LogisticRegression" : (LogisticRegression, optim.Adam, {"lr" : 0.001}), 457 | }[model] 458 | 459 | 460 | def print_model(model): 461 | n = 0 462 | print("Model:") 463 | for key, value in model.named_parameters(): 464 | print(' -', '{:30}'.format(key), list(value.shape), "Requires Grad:", value.requires_grad) 465 | n += value.numel() 466 | print("Total number of Parameters: ", n) 467 | print() 468 | 469 | 470 | 471 | 472 | 473 | 474 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, argparse, json, copy, time 2 | from tqdm import tqdm 3 | from functools import partial 4 | import torch, torchvision 5 | import numpy as np 6 | import torch.nn as nn 7 | import data, models 8 | import experiment_manager as xpm 9 | # from fl_devices import Client, Server, Client_flip, Client_target, Client_LIE 10 | from collections import OrderedDict 11 | from torch.utils.data import Dataset 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset 14 | from torchvision import datasets, transforms 15 | from scipy.ndimage.interpolation import rotate as scipyrotate 16 | 17 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | 19 | 20 | class ParamDiffAug(): 21 | def __init__(self): 22 | self.aug_mode = 'S' # 'multiple or single' 23 | self.prob_flip = 0.5 24 | self.ratio_scale = 1.2 25 | self.ratio_rotate = 15.0 26 | self.ratio_crop_pad = 0.125 27 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5 28 | self.ratio_noise = 0.05 29 | self.brightness = 1.0 30 | self.saturation = 2.0 31 | self.contrast = 0.5 32 | 33 | 34 | def set_seed_DiffAug(param): 35 | if param.latestseed == -1: 36 | return 37 | else: 38 | 
torch.random.manual_seed(param.latestseed) 39 | param.latestseed += 1 40 | 41 | 42 | def DiffAugment(x, strategy='', seed=-1, param=None): 43 | if seed == -1: 44 | param.batchmode = False 45 | else: 46 | param.batchmode = True 47 | 48 | param.latestseed = seed 49 | 50 | if strategy == 'None' or strategy == 'none': 51 | return x 52 | 53 | if strategy: 54 | if param.aug_mode == 'M': # original 55 | for p in strategy.split('_'): 56 | for f in AUGMENT_FNS[p]: 57 | x = f(x, param) 58 | elif param.aug_mode == 'S': 59 | pbties = strategy.split('_') 60 | set_seed_DiffAug(param) 61 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()] 62 | for f in AUGMENT_FNS[p]: 63 | x = f(x, param) 64 | else: 65 | exit('Error ZH: unknown augmentation mode.') 66 | x = x.contiguous() 67 | return x 68 | 69 | 70 | # We implement the following differentiable augmentation strategies based on the codes provided in https://github.com/mit-han-lab/data-efficient-gans. 71 | def rand_scale(x, param): 72 | # x>1, max scale 73 | # sx, sy: (0, +oo), 1: orignial size, 0.5: enlarge 2 times 74 | ratio = param.ratio_scale 75 | set_seed_DiffAug(param) 76 | sx = torch.rand(x.shape[0]) * (ratio - 1.0 / ratio) + 1.0 / ratio 77 | set_seed_DiffAug(param) 78 | sy = torch.rand(x.shape[0]) * (ratio - 1.0 / ratio) + 1.0 / ratio 79 | theta = [[[sx[i], 0, 0], 80 | [0, sy[i], 0], ] for i in range(x.shape[0])] 81 | theta = torch.tensor(theta, dtype=torch.float) 82 | if param.batchmode: # batch-wise: 83 | theta[:] = theta[0] 84 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 85 | x = F.grid_sample(x, grid, align_corners=True) 86 | return x 87 | 88 | 89 | def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degree 90 | ratio = param.ratio_rotate 91 | set_seed_DiffAug(param) 92 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi) 93 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0], 94 | [torch.sin(theta[i]), torch.cos(theta[i]), 0], ] for i in range(x.shape[0])] 95 | theta = torch.tensor(theta, dtype=torch.float) 96 | if param.batchmode: # batch-wise: 97 | theta[:] = theta[0] 98 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 99 | x = F.grid_sample(x, grid, align_corners=True) 100 | return x 101 | 102 | 103 | def rand_flip(x, param): 104 | prob = param.prob_flip 105 | set_seed_DiffAug(param) 106 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device) 107 | if param.batchmode: # batch-wise: 108 | randf[:] = randf[0] 109 | return torch.where(randf < prob, x.flip(3), x) 110 | 111 | 112 | def rand_brightness(x, param): 113 | ratio = param.brightness 114 | set_seed_DiffAug(param) 115 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 116 | if param.batchmode: # batch-wise: 117 | randb[:] = randb[0] 118 | x = x + (randb - 0.5) * ratio 119 | return x 120 | 121 | 122 | def rand_saturation(x, param): 123 | ratio = param.saturation 124 | x_mean = x.mean(dim=1, keepdim=True) 125 | set_seed_DiffAug(param) 126 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 127 | if param.batchmode: # batch-wise: 128 | rands[:] = rands[0] 129 | x = (x - x_mean) * (rands * ratio) + x_mean 130 | return x 131 | 132 | 133 | def rand_contrast(x, param): 134 | ratio = param.contrast 135 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 136 | set_seed_DiffAug(param) 137 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 138 | if param.batchmode: # batch-wise: 139 | randc[:] = randc[0] 140 | x = (x - x_mean) * (randc + 
ratio) + x_mean 141 | return x 142 | 143 | 144 | def rand_crop(x, param): 145 | # The image is padded on its surrounding and then cropped. 146 | ratio = param.ratio_crop_pad 147 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 148 | set_seed_DiffAug(param) 149 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 150 | set_seed_DiffAug(param) 151 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 152 | if param.batchmode: # batch-wise: 153 | translation_x[:] = translation_x[0] 154 | translation_y[:] = translation_y[0] 155 | grid_batch, grid_x, grid_y = torch.meshgrid( 156 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 157 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 158 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 159 | ) 160 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 161 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 162 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 163 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 164 | return x 165 | 166 | 167 | def rand_cutout(x, param): 168 | ratio = param.ratio_cutout 169 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 170 | set_seed_DiffAug(param) 171 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 172 | set_seed_DiffAug(param) 173 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 174 | if param.batchmode: # batch-wise: 175 | offset_x[:] = offset_x[0] 176 | offset_y[:] = offset_y[0] 177 | grid_batch, grid_x, grid_y = torch.meshgrid( 178 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 179 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 180 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 181 | ) 182 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 183 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 184 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 185 | mask[grid_batch, grid_x, grid_y] = 0 186 | x = x * mask.unsqueeze(1) 187 | return x 188 | 189 | 190 | AUGMENT_FNS = { 191 | 'color': [rand_brightness, rand_saturation, rand_contrast], 192 | 'crop': [rand_crop], 193 | 'cutout': [rand_cutout], 194 | 'flip': [rand_flip], 195 | 'scale': [rand_scale], 196 | 'rotate': [rand_rotate], 197 | } 198 | 199 | 200 | class TensorDataset(Dataset): 201 | def __init__(self, images, labels): # images: n x c x h x w tensor 202 | self.images = images.detach().float() 203 | self.labels = labels.detach() 204 | 205 | def __getitem__(self, index): 206 | return self.images[index], self.labels[index] 207 | 208 | def __len__(self): 209 | return self.images.shape[0] 210 | 211 | 212 | def get_benign_updates(mali_clients, server): 213 | # import pdb; pdb.set_trace() 214 | mal_user_grad_sum = {} 215 | mal_user_grad_pow = {} 216 | user_grad = {} 217 | server_weights = server.parameter_dict[mali_clients[0].model_name] 218 | for client in mali_clients: 219 | # import pdb; pdb.set_trace() 220 | for name in client.W: 221 | user_grad[name] = client.W[name].detach() - server_weights[name].detach() 222 | # import pdb; pdb.set_trace() 223 | if name not in mal_user_grad_sum: 224 | mal_user_grad_sum[name] = 
user_grad[name].clone() 225 | mal_user_grad_pow[name] = torch.pow(user_grad[name], 2) 226 | else: 227 | mal_user_grad_sum[name] += user_grad[name].clone() 228 | mal_user_grad_pow[name] += torch.pow(user_grad[name], 2) 229 | mal_user_grad_mean2 = OrderedDict() 230 | mal_user_grad_std2 = OrderedDict() 231 | 232 | for name in mali_clients[0].W: 233 | mal_user_grad_mean2[name] = mal_user_grad_sum[name] / len(mali_clients) 234 | mal_user_grad_std2[name] = torch.sqrt( 235 | (mal_user_grad_pow[name] / len(mali_clients) - torch.pow(mal_user_grad_mean2[name], 2))) 236 | 237 | return mal_user_grad_mean2, mal_user_grad_std2 238 | 239 | 240 | def plot_1d(benign_zscores, mali_zscores, mu, var, pi, save_name): 241 | from matplotlib import pyplot as plt 242 | import seaborn as sns 243 | from scipy.stats import multivariate_normal 244 | 245 | mu = mu.cpu() 246 | var = var.cpu() 247 | pi = pi.cpu() 248 | 249 | benign = np.array(benign_zscores) 250 | mali = np.array(mali_zscores) 251 | # import pdb; pdb.set_trace() 252 | min_X = np.concatenate([benign, mali]).min() 253 | max_X = np.concatenate([benign, mali]).max() 254 | X = np.linspace(min_X - 0.1, max_X + 0.1, 1000) 255 | G_benign = multivariate_normal(mean=mu[0], cov=var[0]) 256 | G_mali = multivariate_normal(mean=mu[1], cov=var[1]) 257 | y_benign = G_benign.pdf(X) 258 | y_mali = G_mali.pdf(X) 259 | y_ = y_mali + y_benign 260 | 261 | sns.distplot(benign, norm_hist=True, kde=False) 262 | sns.distplot(mali, norm_hist=True, kde=False) 263 | plt.plot(X, y_benign) 264 | plt.plot(X, y_mali) 265 | 266 | plt.tight_layout() 267 | plt.savefig(save_name) 268 | plt.clf() 269 | 270 | 271 | def plot_2d(data, y, real, save_name): 272 | import matplotlib.pyplot as plt 273 | import seaborn as sns 274 | # import pdb; pdb.set_trace() 275 | data = data.cpu() 276 | # y = np.array(y) 277 | # real = np.array(real) 278 | n = data.shape[0] 279 | colors = sns.color_palette("Paired", n_colors=12).as_hex() 280 | 281 | fig, ax = plt.subplots(1, 1, figsize=(1.61803398875 * 4, 4)) 282 | ax.set_facecolor("#bbbbbb") 283 | ax.set_xlabel("KL") 284 | ax.set_ylabel("CE") 285 | 286 | # plot the locations of all data points .. 287 | for i, point in enumerate(data.data): 288 | if real[i] == 0: 289 | # .. separating them by ground truth .. 290 | ax.scatter(*point, color="#000000", s=3, alpha=.75, zorder=n + i) 291 | else: 292 | ax.scatter(*point, color="#ffffff", s=3, alpha=.75, zorder=n + i) 293 | 294 | if y[i] == 0: 295 | # .. 
as well as their predicted class 296 | ax.scatter(*point, zorder=i, color="#dbe9ff", alpha=.6, edgecolors=colors[5]) 297 | else: 298 | ax.scatter(*point, zorder=i, color="#ffdbdb", alpha=.6, edgecolors=colors[1]) 299 | 300 | handles = [plt.Line2D([0], [0], color="w", lw=4, label="Ground Truth Benign"), 301 | plt.Line2D([0], [0], color="black", lw=4, label="Ground Truth Malicious"), 302 | plt.Line2D([0], [0], color=colors[1], lw=4, label="Predicted Benign"), 303 | plt.Line2D([0], [0], color=colors[5], lw=4, label="Predicted Malicious"), ] 304 | 305 | legend = ax.legend(loc="best", handles=handles) 306 | 307 | plt.tight_layout() 308 | plt.savefig(save_name) 309 | 310 | 311 | def train_op_target(model, loader, optimizer, epochs, lambda_fedprox=0.0, class_num=10): 312 | model.train() 313 | 314 | W0 = {k: v.detach().clone() for k, v in model.named_parameters()} 315 | 316 | running_loss, samples = 0.0, 0 317 | for ep in range(epochs): 318 | for x, y in loader: 319 | # import pdb; pdb.set_trace() 320 | # print(y) 321 | y = torch.tensor([1] * len(y)) 322 | # print(y) 323 | # import pdb; pdb.set_trace() 324 | x, y = x.to(device), y.to(device) 325 | 326 | optimizer.zero_grad() 327 | 328 | loss = nn.CrossEntropyLoss()(model(x), y) 329 | 330 | if lambda_fedprox > 0.0: 331 | loss += lambda_fedprox * torch.sum( 332 | (flatten(W0).cuda() - flatten(dict(model.named_parameters())).cuda()) ** 2) 333 | 334 | running_loss += loss.item() * y.shape[0] 335 | samples += y.shape[0] 336 | 337 | loss.backward() 338 | optimizer.step() 339 | 340 | return {"loss": running_loss / samples} 341 | 342 | 343 | def train_op_flip(model, loader, optimizer, epochs, lambda_fedprox=0.0, class_num=10): 344 | model.train() 345 | 346 | W0 = {k: v.detach().clone() for k, v in model.named_parameters()} 347 | 348 | running_loss, samples = 0.0, 0 349 | for ep in range(epochs): 350 | for x, y in loader: 351 | 352 | # print(y) 353 | y += 1 354 | y = y % class_num 355 | # print(y) 356 | # import pdb; pdb.set_trace() 357 | x, y = x.to(device), y.to(device) 358 | 359 | optimizer.zero_grad() 360 | 361 | loss = nn.CrossEntropyLoss()(model(x), y) 362 | 363 | if lambda_fedprox > 0.0: 364 | loss += lambda_fedprox * torch.sum( 365 | (flatten(W0).cuda() - flatten(dict(model.named_parameters())).cuda()) ** 2) 366 | 367 | running_loss += loss.item() * y.shape[0] 368 | samples += y.shape[0] 369 | 370 | loss.backward() 371 | optimizer.step() 372 | 373 | return {"loss": running_loss / samples} 374 | 375 | def eval_epoch(model, loader): 376 | running_loss, samples = 0.0, 0 377 | with torch.no_grad(): 378 | for x, y in loader: 379 | x, y = x.to(device), y.to(device) 380 | loss = nn.CrossEntropyLoss()(model(x), y) 381 | running_loss += loss.item() * y.shape[0] 382 | samples += y.shape[0] 383 | running_loss = running_loss / samples 384 | return running_loss 385 | 386 | 387 | 388 | def gaussian_noise(data_shape, s, sigma, device=None): 389 | """ 390 | Gaussian noise 391 | """ 392 | return torch.normal(0, sigma * s, data_shape).to(device) 393 | 394 | 395 | def train_op(model, loader, optimizer, epochs, lambda_fedprox=0.0, print_train_loss=False): 396 | model.train() 397 | 398 | W0 = {k: v.detach().clone() for k, v in model.named_parameters()} 399 | losses = [] 400 | running_loss, samples = 0.0, 0 401 | for ep in range(epochs): 402 | for it, (x, y) in enumerate(loader): 403 | if print_train_loss and it % 2 == 0: 404 | losses.append(round(eval_epoch(model, loader), 2)) 405 | x, y = x.to(device), y.to(device) 406 | optimizer.zero_grad() 407 | loss = 
def train_op_private(model, loader, optimizer, epochs, lambda_fedprox=0.0, print_train_loss=False,
                     privacy_sigma=1, clip_bound=5):
    model.train()

    W0 = {k: v.detach().clone() for k, v in model.named_parameters()}
    losses = []
    running_loss, samples = 0.0, 0
    for ep in range(epochs):
        # gradients are clipped to clip_bound and accumulated across the epoch;
        # Gaussian noise scaled by clip_bound * privacy_sigma is added, and the
        # optimizer steps, at every iteration
        clipped_grads = {name: torch.zeros_like(param) for name, param in model.named_parameters()}
        for it, (x, y) in enumerate(loader):
            if print_train_loss and it % 2 == 0:
                losses.append(round(eval_epoch(model, loader), 2))
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = nn.CrossEntropyLoss()(model(x), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_bound)
            for name, param in model.named_parameters():
                clipped_grads[name] += param.grad
            model.zero_grad()
            # add Gaussian noise
            for name, param in model.named_parameters():
                clipped_grads[name] += gaussian_noise(clipped_grads[name].shape, clip_bound,
                                                      privacy_sigma, device='cuda')
            for name, param in model.named_parameters():
                param.grad = clipped_grads[name]
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]
            optimizer.step()
    if print_train_loss:
        print(losses)

    return {"loss": running_loss / samples}


def train_op_datadistill(model, loader, optimizer, epochs, images_train, labels_train, eta=0.5, current_round=0,
                         start_round=0):
    model.train()
    distilled_dataset = TensorDataset(images_train, labels_train)
    distilled_loader = torch.utils.data.DataLoader(distilled_dataset, batch_size=32, shuffle=True, num_workers=4)

    running_loss, samples = 0.0, 0
    for ep in range(epochs):
        # zip truncates to the shorter of the two loaders
        for (x, y), (x_dis, y_dis) in zip(loader, distilled_loader):
            x, y = x.to(device), y.to(device)
            x_dis, y_dis = x_dis.cuda(), y_dis.cuda()
            optimizer.zero_grad()
            loss = nn.CrossEntropyLoss()(model(x), y)
            if current_round >= start_round:
                loss_distill = nn.CrossEntropyLoss()(model(x_dis), y_dis)
                if eta > 0.0:
                    loss_total = loss + eta * loss_distill
                else:
                    loss_total = loss
            else:
                loss_distill = 0
                loss_total = loss
            print(f"loss {loss}, loss_distill {loss_distill}")
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]

            loss_total.backward()
            optimizer.step()

    return {"loss": running_loss / samples}


def kd_loss(output, y):
    """Soft-label cross-entropy: targets are softmax(y); the loss is the mean
    over both the batch and the class dimension."""
    soft_label = F.softmax(y, dim=1)
    logsoftmax = torch.nn.LogSoftmax(dim=1)
    return torch.mean(-soft_label * logsoftmax(output))


def train_op_datadistill_soft(model, loader, optimizer, epochs, images_train, labels_train, eta=0.5, current_round=0,
                              start_round=0, dsa=True, args=None):
    model.train()
    distilled_dataset = TensorDataset(images_train, labels_train)
    distilled_loader = torch.utils.data.DataLoader(distilled_dataset, batch_size=32, shuffle=True, num_workers=4)
    distilled_iter = iter(distilled_loader)
    running_loss, samples = 0.0, 0
    for ep in range(epochs):
        for (x, y) in loader:
            x, y = x.to(device), y.to(device)
            try:
                x_dis, y_dis = next(distilled_iter)
            except StopIteration:
                # restart the distilled loader once it is exhausted
                distilled_iter = iter(distilled_loader)
                x_dis, y_dis = next(distilled_iter)
            x_dis, y_dis = x_dis.cuda(), y_dis.cuda()
            if dsa:
                x_dis = DiffAugment(x_dis, args.dsa_strategy, param=args.dsa_param)
            optimizer.zero_grad()
            loss = nn.CrossEntropyLoss()(model(x), y)
            if current_round >= start_round:
                loss_distill = kd_loss(model(x_dis), y_dis)
                if eta > 0.0:
                    # convex combination of the local loss and the distillation loss
                    loss_total = (1 - eta) * loss + eta * loss_distill
                else:
                    loss_total = loss
            else:
                loss_distill = 0
                loss_total = loss
            print(f"eta {eta}, loss {loss}, loss_distill {loss_distill}")
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]

            loss_total.backward()
            optimizer.step()

    return {"loss": running_loss / samples}
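
# --- Illustrative sketch: what kd_loss computes -------------------------------
# kd_loss is soft-label cross-entropy: the (possibly learned) labels are
# softened with softmax and matched against the student's log-probabilities,
# averaged over both the batch and the class dimension. The toy logits below
# are assumptions for illustration only.
#
#     import torch
#     import torch.nn.functional as F
#
#     student_logits = torch.randn(8, 10)
#     soft_targets = torch.randn(8, 10)  # e.g. learned labels of distilled images
#     loss = torch.mean(-F.softmax(soft_targets, dim=1) * F.log_softmax(student_logits, dim=1))
#     print(loss.item())
# -----------------------------------------------------------------------------
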
def train_op_datadistill_later(model, loader, optimizer, epochs, images_train, labels_train, finetune_epoch=1,
                               finetune_lr=1e-3, current_round=0, start_round=0, dsa=None, args=None):
    model.train()
    distilled_dataset = TensorDataset(images_train, labels_train)
    distilled_loader = torch.utils.data.DataLoader(distilled_dataset, batch_size=256, shuffle=True, num_workers=4)

    running_loss, samples = 0.0, 0
    for ep in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = nn.CrossEntropyLoss()(model(x), y)
            print(f"loss {loss}")
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]

            loss.backward()
            optimizer.step()
    # after local training, fine-tune on the distilled data with a separate optimizer
    if current_round >= start_round:
        optimizer_finetune = torch.optim.Adam(model.parameters(), lr=finetune_lr)
        for finetune_ep in range(finetune_epoch):
            for x_dis, y_dis in distilled_loader:
                x_dis, y_dis = x_dis.cuda(), y_dis.cuda()
                if dsa:
                    x_dis = DiffAugment(x_dis, args.dsa_strategy, param=args.dsa_param)
                optimizer_finetune.zero_grad()
                loss_distill = kd_loss(model(x_dis), y_dis)
                loss_distill.backward()
                optimizer_finetune.step()
                print(f"loss_distill {loss_distill}")

    return {"loss": running_loss / samples}


def train_op_nlp(model, loader, optimizer, epochs, lambda_fedprox=0.0):
    model.train()

    W0 = {k: v.detach().clone() for k, v in model.named_parameters()}

    running_loss, samples = 0.0, 0
    for ep in range(epochs):
        for label, text, offsets in loader:
            label, text, offsets = label.to(device), text.to(device), offsets.to(device)

            optimizer.zero_grad()
            prediction = model(text, offsets)

            loss = nn.CrossEntropyLoss()(prediction, label)

            if lambda_fedprox > 0.0:
                loss += lambda_fedprox * torch.sum(
                    (flatten(W0).cuda() - flatten(dict(model.named_parameters())).cuda()) ** 2)
            try:
                running_loss += loss.item() * label.shape[0]
                samples += label.shape[0]

                loss.backward()
                optimizer.step()
            except Exception:
                print(f"labels {label}")
                print(f"prediction {prediction}")
                print(f"loss {loss}")
                print("error")

    return {"loss": running_loss / samples}
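
# --- Illustrative sketch: the (label, text, offsets) batch format -------------
# train_op_nlp above consumes batches in the torchtext EmbeddingBag style:
# `text` is one 1-D tensor of concatenated token ids and `offsets` marks where
# each sequence starts. The toy vocabulary and sizes are assumptions for
# illustration only.
#
#     import torch
#     import torch.nn as nn
#
#     bag = nn.EmbeddingBag(num_embeddings=100, embedding_dim=8)
#     text = torch.tensor([3, 7, 9, 1, 4])   # two sequences: [3, 7, 9] and [1, 4]
#     offsets = torch.tensor([0, 3])         # sequence i starts at offsets[i]
#     print(bag(text, offsets).shape)        # torch.Size([2, 8]) -> one vector per sequence
# -----------------------------------------------------------------------------
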
def eval_op(model, loader):
    """Top-1 accuracy of a single model over a loader."""
    model.eval()
    samples, correct = 0, 0

    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            x, y = x.to(device), y.to(device)

            y_ = model(x)
            _, predicted = torch.max(y_.detach(), 1)

            samples += y.shape[0]
            correct += (predicted == y).sum().item()

    return {"accuracy": correct / samples}


def eval_op_ensemble(models, test_loader, val_loader):
    """Top-1 accuracy of an ensemble, averaging the logits of all models."""
    for model in models:
        model.eval()

    samples, correct = 0, 0

    with torch.no_grad():
        for i, (x, y) in enumerate(test_loader):
            x, y = x.to(device), y.to(device)

            y_ = torch.mean(torch.stack([model(x) for model in models], dim=0), dim=0)
            _, predicted = torch.max(y_.detach(), 1)

            samples += y.shape[0]
            correct += (predicted == y).sum().item()
    test_acc = correct / samples

    samples, correct = 0, 0

    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            x, y = x.to(device), y.to(device)

            y_ = torch.mean(torch.stack([model(x) for model in models], dim=0), dim=0)
            _, predicted = torch.max(y_.detach(), 1)

            samples += y.shape[0]
            correct += (predicted == y).sum().item()
    val_acc = correct / samples

    return {"test_accuracy": test_acc, "val_accuracy": val_acc}


def eval_op_ensemble_nlp(models, test_loader, val_loader):
    for model in models:
        model.eval()

    samples, correct = 0, 0

    with torch.no_grad():
        for label, text, offsets in test_loader:
            label, text, offsets = label.to(device), text.to(device), offsets.to(device)

            y_ = torch.mean(torch.stack([model(text, offsets) for model in models], dim=0), dim=0)
            _, predicted = torch.max(y_.detach(), 1)

            samples += label.shape[0]
            correct += (predicted == label).sum().item()
    test_acc = correct / samples

    samples, correct = 0, 0

    with torch.no_grad():
        for label, text, offsets in val_loader:
            label, text, offsets = label.to(device), text.to(device), offsets.to(device)

            y_ = torch.mean(torch.stack([model(text, offsets) for model in models], dim=0), dim=0)
            _, predicted = torch.max(y_.detach(), 1)

            samples += label.shape[0]
            correct += (predicted == label).sum().item()
    val_acc = correct / samples

    return {"test_accuracy": test_acc, "val_accuracy": val_acc}


def reduce_average(target, sources):
    # elementwise mean of the source weight dicts
    for name in target:
        target[name].data = torch.mean(torch.stack([source[name].detach() for source in sources]), dim=0).clone()


def reduce_median(target, sources):
    # coordinate-wise median of the source weight dicts
    for name in target:
        target[name].data = torch.median(torch.stack([source[name].detach() for source in sources]),
                                         dim=0).values.clone()


def reduce_trimmed_mean(target, sources, mali_ratio):
    import math
    # drop the trimmed_mean_beta largest and smallest values per coordinate,
    # then average the remaining ones
    trimmed_mean_beta = math.ceil(mali_ratio * len(sources)) + 1
    for name in target:
        stacked_weights = torch.stack([source[name].detach() for source in sources])
        user_num = stacked_weights.size(0)
        largest_value, _ = torch.topk(stacked_weights, k=trimmed_mean_beta, dim=0)
        smallest_value, _ = torch.topk(stacked_weights, k=trimmed_mean_beta, dim=0, largest=False)
        target[name].data = ((
            torch.sum(stacked_weights, dim=0)
            - torch.sum(largest_value, dim=0)
            - torch.sum(smallest_value, dim=0)
        ) / (user_num - 2 * trimmed_mean_beta)).clone()
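
# --- Illustrative sketch: coordinate-wise trimmed mean ------------------------
# reduce_trimmed_mean above drops the k largest and k smallest values per
# coordinate (k = trimmed_mean_beta) before averaging. A toy check; the values
# below are assumptions for illustration only.
#
#     import torch
#
#     stacked = torch.tensor([[1.0], [2.0], [3.0], [100.0]])  # 4 "clients", 1 coordinate
#     k = 1
#     hi, _ = torch.topk(stacked, k=k, dim=0)
#     lo, _ = torch.topk(stacked, k=k, dim=0, largest=False)
#     tm = (stacked.sum(dim=0) - hi.sum(dim=0) - lo.sum(dim=0)) / (stacked.size(0) - 2 * k)
#     print(tm)  # tensor([2.5]) -- the outlier 100.0 is trimmed away
# -----------------------------------------------------------------------------
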
def reduce_krum(target, sources, mali_ratio):
    import math
    krum_mal_num = math.ceil(mali_ratio * len(sources)) + 1
    user_num = len(sources)
    # flatten every client's weights into a single vector
    user_flatten_grad = []
    for source in sources:
        user_flatten_grad_i = []
        for name in target:
            user_flatten_grad_i.append(torch.flatten(source[name].detach()))
        user_flatten_grad_i = torch.cat(user_flatten_grad_i)
        user_flatten_grad.append(user_flatten_grad_i)
    user_flatten_grad = torch.stack(user_flatten_grad)

    # compute pairwise l2 distances between users
    user_scores = torch.zeros((user_num, user_num), device=user_flatten_grad.device)
    for u_i, source in enumerate(sources):
        user_scores[u_i] = torch.norm(
            user_flatten_grad - user_flatten_grad[u_i],
            dim=list(range(len(user_flatten_grad.shape)))[1:],
        )
        user_scores[u_i, u_i] = float('inf')
    topk_user_scores, _ = torch.topk(
        user_scores, k=user_num - krum_mal_num - 2, dim=1, largest=False
    )
    sm_user_scores = torch.sum(topk_user_scores, dim=1)

    # the user with the smallest score is selected, and its weights become the update
    u_score, select_ui = torch.topk(sm_user_scores, k=1, largest=False)
    select_ui = select_ui.cpu().numpy()[0]
    print(select_ui)
    for name in target:
        target[name].data = sources[select_ui][name].detach().clone()


def reduce_residual(source_1, source_2):
    # elementwise difference between two weight dicts
    tmp_dict = {}
    for name in source_1:
        tmp_dict[name] = (source_1[name].detach() - source_2[name].detach()).clone()
    return tmp_dict


def reduce_weighted(target, sources, weights):
    # weighted sum of the source weight dicts
    for name in target:
        target[name].data = torch.sum(weights * torch.stack([source[name].detach() for source in sources], dim=-1),
                                      dim=-1).clone()


def flatten(source):
    return torch.cat([value.flatten() for value in source.values()])


def copy(target, source):
    for name in target:
        target[name].data = source[name].detach().clone()


def olr(mu, var):
    """Estimate the overlap between two 1-D Gaussian components by walking from
    the lower mean towards the higher one until the densities cross, then
    summing the two tail masses beyond the crossover point."""
    from scipy.stats import multivariate_normal
    # order the components so that new_mu[0] <= new_mu[1]
    if mu[0] > mu[1]:
        new_mu = [mu[1], mu[0]]
        new_var = [var[1], var[0]]
    else:
        new_mu = mu
        new_var = var
    step = 500
    x_step = (new_mu[1] - new_mu[0]) / step

    G_m = multivariate_normal(mean=new_mu[0], cov=new_var[0])
    G_b = multivariate_normal(mean=new_mu[1], cov=new_var[1])

    index = 0
    x = new_mu[0]
    while index < step:
        x = new_mu[0] + x_step * index
        if G_b.pdf(x) > G_m.pdf(x):
            break
        index += 1
    overlap = (1 - G_m.cdf(x)) + G_b.cdf(x)
    return overlap
--------------------------------------------------------------------------------
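
# --- Illustrative sketch: Krum selection (reduce_krum above) ------------------
# Krum scores each update by the sum of the distances to its n - f - 2 nearest
# neighbours and keeps the lowest-scoring one. The toy updates below are
# assumptions for illustration only.
#
#     import torch
#
#     flat = torch.tensor([[0.0], [0.1], [0.2], [5.0]])  # last client is an outlier
#     n, f = flat.size(0), 1
#     scores = torch.zeros(n)
#     for i in range(n):
#         d = torch.norm(flat - flat[i], dim=1)
#         d[i] = float("inf")
#         nearest, _ = torch.topk(d, k=n - f - 2, largest=False)
#         scores[i] = nearest.sum()
#     print(torch.argmin(scores))  # selects one of the clustered benign updates
# -----------------------------------------------------------------------------
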