├── LICENSE
├── README.md
├── commands
│   ├── leaf
│   │   ├── fedavg_celeba.sh
│   │   ├── fedavg_femnist.sh
│   │   ├── fedavg_reddit.sh
│   │   ├── fedavg_sent140.sh
│   │   ├── fedavg_shakespeare.sh
│   │   └── leaf.md
│   ├── original_cifar10
│   │   ├── cifar10.md
│   │   ├── fedavg_cifar10.sh
│   │   └── fedsgd_cifar10.sh
│   ├── original_mnist
│   │   ├── fedavg_mnist_2nn.sh
│   │   ├── fedavg_mnist_cnn.sh
│   │   ├── fedsgd_mnist_2nn.sh
│   │   ├── fedsgd_mnist_cnn.sh
│   │   └── mnist.md
│   └── original_play_n_role
│       ├── Shakespeare.md
│       ├── fedavg_play_n_role.sh
│       └── fedsgd_play_n_role.sh
├── main.py
├── requirments.txt
└── src
    ├── __init__.py
    ├── algorithm
    │   ├── __init__.py
    │   ├── basealgorithm.py
    │   ├── fedadagrad.py
    │   ├── fedadam.py
    │   ├── fedavg.py
    │   ├── fedavgm.py
    │   ├── fedprox.py
    │   ├── fedsgd.py
    │   ├── fedyogi.py
    │   └── sent140.sh
    ├── client
    │   ├── __init__.py
    │   ├── baseclient.py
    │   ├── fedadagradclient.py
    │   ├── fedadamclient.py
    │   ├── fedavgclient.py
    │   ├── fedavgmclient.py
    │   ├── fedproxclient.py
    │   ├── fedsgdclient.py
    │   └── fedyogiclient.py
    ├── datasets
    │   ├── __init__.py
    │   ├── adult.py
    │   ├── beerreviews.py
    │   ├── cinic10.py
    │   ├── cover.py
    │   ├── gleam.py
    │   ├── heart.py
    │   ├── leaf
    │   │   ├── __init__.py
    │   │   ├── leaf_utils.py
    │   │   ├── postprocess
    │   │   │   ├── __init__.py
    │   │   │   ├── filter.py
    │   │   │   ├── postprocess.py
    │   │   │   ├── sample.py
    │   │   │   └── split.py
    │   │   └── preprocess
    │   │       ├── __init__.py
    │   │       ├── celeba.py
    │   │       ├── femnist.py
    │   │       ├── reddit.py
    │   │       ├── sent140.py
    │   │       └── shakespeare.py
    │   ├── leafparser.py
    │   ├── speechcommands.py
    │   ├── tinyimagenet.py
    │   ├── torchtextparser.py
    │   └── torchvisionparser.py
    ├── loaders
    │   ├── __init__.py
    │   ├── data.py
    │   ├── model.py
    │   └── split.py
    ├── metrics
    │   ├── __init__.py
    │   ├── basemetric.py
    │   └── metricszoo.py
    ├── models
    │   ├── __init__.py
    │   ├── distilbert.py
    │   ├── femnistcnn.py
    │   ├── lenet.py
    │   ├── logreg.py
    │   ├── m5.py
    │   ├── mobilebert.py
    │   ├── mobilenet.py
    │   ├── mobilenext.py
    │   ├── mobilevit.py
    │   ├── model_utils.py
    │   ├── resnet.py
    │   ├── sent140lstm.py
    │   ├── shufflenet.py
    │   ├── simplecnn.py
    │   ├── squeezebert.py
    │   ├── squeezenet.py
    │   ├── squeezenext.py
    │   ├── stackedlstm.py
    │   ├── stackedtransformer.py
    │   ├── twocnn.py
    │   ├── twonn.py
    │   └── vgg.py
    ├── server
    │   ├── __init__.py
    │   ├── baseserver.py
    │   ├── fedadagradserver.py
    │   ├── fedadamserver.py
    │   ├── fedavgmserver.py
    │   ├── fedavgserver.py
    │   ├── fedproxserver.py
    │   ├── fedsgdserver.py
    │   └── fedyogiserver.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Seok-Ju Hahn (GitHub: vaseline555; Email: sjhahn11512@gmail.com)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # Federated Learning in PyTorch
3 | Implementations of various Federated Learning (FL) algorithms in PyTorch, intended primarily for research purposes.
4 | 
5 | ## Implementation Details
6 | ### Datasets
7 | * Supports all image classification datasets in `torchvision.datasets`.
8 | * Supports all text classification datasets in `torchtext.datasets`.
9 | * Supports all datasets in the [LEAF benchmark](https://leaf.cmu.edu/) (*NO need to prepare raw data manually*).
10 | * Supports additional image classification datasets ([`TinyImageNet`](https://www.kaggle.com/c/tiny-imagenet), [`CINIC10`](https://datashare.ed.ac.uk/handle/10283/3192)).
11 | * Supports additional text classification datasets ([`BeerReviews`](https://snap.stanford.edu/data/web-BeerAdvocate.html)).
12 | * Supports tabular datasets ([`Heart`, `Adult`, `Cover`](https://archive.ics.uci.edu/ml/index.php)).
13 | * Supports a temporal dataset ([`GLEAM`](http://www.skleinberg.org/data.html)).
14 | * __NOTE__: there is no need to search for raw data files; each dataset is automatically downloaded to the designated path once you pass its name!
15 | ### Statistical Heterogeneity Simulations
16 | * `IID` (i.e., statistical homogeneity)
17 | * `Unbalanced` (i.e., heterogeneity in sample counts)
18 | * `Pathological Non-IID` ([McMahan et al., 2016](https://arxiv.org/abs/1602.05629))
19 | * `Dirichlet distribution-based Non-IID` ([Hsu et al., 2019](https://arxiv.org/abs/1909.06335)); see the sketch after this list
20 | * `Pre-defined` (for datasets having a natural semantic separation, including the `LEAF` benchmark ([Caldas et al., 2018](https://arxiv.org/abs/1812.01097)))
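For intuition, the Dirichlet-based simulation draws, for every class, a vector of client proportions from $\text{Dir}(\alpha)$ and splits that class's samples accordingly, so a smaller $\alpha$ produces a stronger label skew. Below is a minimal sketch of the scheme (assuming only `numpy`; the function name and signature are illustrative, not this repository's actual API):

```python
import numpy as np

def dirichlet_split(targets, num_clients, alpha, seed=42):
    """Partition sample indices across clients; each class is divided
    according to proportions drawn from Dir(alpha)."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(targets):
        # shuffle this class's sample indices, then cut them into shards
        cls_idx = rng.permutation(np.flatnonzero(targets == cls))
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(cls_idx)).astype(int)
        for client, shard in enumerate(np.split(cls_idx, cuts)):
            client_indices[client].extend(shard.tolist())
    return [np.asarray(idx) for idx in client_indices]
```

For example, `dirichlet_split(np.asarray(train_set.targets), num_clients=100, alpha=0.5)` yields 100 index arrays, one per simulated client.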
21 | ### Models
22 | * `LogReg` (logistic regression), `StackedTransformer` (TransformerEncoder-based classifier)
23 | * `TwoNN`, `TwoCNN`, `SimpleCNN` ([McMahan et al., 2016](https://arxiv.org/abs/1602.05629))
24 | * `FEMNISTCNN`, `Sent140LSTM` ([Caldas et al., 2018](https://arxiv.org/abs/1812.01097))
25 | * `LeNet` ([LeCun et al., 1998](https://ieeexplore.ieee.org/document/726791/)), `MobileNet` ([Howard et al., 2019](https://arxiv.org/abs/1905.02244)), `SqueezeNet` ([Iandola et al., 2016](https://arxiv.org/abs/1602.07360)), `VGG` ([Simonyan et al., 2014](https://arxiv.org/abs/1409.1556)), `ResNet` ([He et al., 2015](https://arxiv.org/abs/1512.03385))
26 | * `MobileNeXt` ([Daquan et al., 2020](https://arxiv.org/abs/2007.02269)), `SqueezeNeXt` ([Gholami et al., 2018](https://arxiv.org/abs/1803.10615)), `MobileViT` ([Mehta et al., 2021](https://arxiv.org/abs/2110.02178))
27 | * `DistilBERT` ([Sanh et al., 2019](https://arxiv.org/abs/1910.01108)), `SqueezeBERT` ([Iandola et al., 2020](https://arxiv.org/abs/2006.11316)), `MobileBERT` ([Sun et al., 2020](https://arxiv.org/abs/2004.02984))
28 | * `M5` ([Dai et al., 2016](https://arxiv.org/abs/1610.00087))
29 | ### Algorithms
30 | * `FedAvg` and `FedSGD` (McMahan et al., 2016): *Communication-Efficient Learning of Deep Networks from Decentralized Data*
31 | * `FedAvgM` (Hsu et al., 2019): *Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification*
32 | * `FedProx` (Li et al., 2018): *Federated Optimization in Heterogeneous Networks*
33 | * `FedOpt` (`FedAdam`, `FedYogi`, `FedAdaGrad`) (Reddi et al., 2020): *Adaptive Federated Optimization*; see the sketch below
34 | 
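All of the `FedOpt` variants share the same skeleton: the server treats the weighted average of client updates $\Delta_t$ as a pseudo-gradient and feeds it to an adaptive optimizer. Here is a minimal sketch of the `FedAdam` server step in plain `torch`, mirroring the logic of `src/algorithm/fedadam.py` further below (the standalone function itself is illustrative, not the repository's API):

```python
import torch

def fedadam_server_step(weight, delta, m, v, lr=1e-2, betas=(0.9, 0.99), tau=1e-3):
    """One server-side update; `delta` is the averaged client update."""
    m.mul_(betas[0]).add_(delta, alpha=1. - betas[0])         # m_t = b1 * m_{t-1} + (1 - b1) * Delta_t
    v.mul_(betas[1]).add_(delta.pow(2), alpha=1. - betas[1])  # v_t = b2 * v_{t-1} + (1 - b2) * Delta_t^2
    weight.add_(m.div(v.sqrt().add(tau)), alpha=lr)           # w_{t+1} = w_t + lr * m_t / (sqrt(v_t) + tau)
    return weight, m, v
```

`FedAdaGrad` and `FedYogi` differ only in how $v_t$ is accumulated (a running sum and a sign-corrected update, respectively), while `FedAvgM` keeps only the momentum term.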
35 | ### Evaluation schemes
36 | * `local`: evaluate the FL algorithm on the holdout sets of (some or all) clients NOT participating in the current round (i.e., the personalized federated learning setting).
37 | * `global`: evaluate the FL algorithm on a global holdout set located at the server. (*ONLY available if the raw dataset provides a pre-defined validation/test set*.)
38 | * `both`: evaluate the FL algorithm using both the `local` and `global` schemes.
39 | ### Metrics
40 | * Top-1 Accuracy, Top-5 Accuracy, Precision, Recall, F1
41 | * Area under ROC, Area under PRC, Youden's J
42 | * Seq2Seq Accuracy
43 | * MSE, RMSE, MAE, MAPE
44 | * $R^2$, $D^2$
45 | 
46 | ## Requirements
47 | * See `requirements.txt`. (I recommend building an independent environment for this project, using, e.g., `Docker` or `conda`.)
48 | * When you install `torchtext`, please check its version compatibility with `torch`. (See the [official repository](https://github.com/pytorch/text#installation).)
49 | * Plus, please install all `torch`-related packages with the single command provided by the official guide (see the [official installation guide](https://pytorch.org/get-started/locally/)); e.g., `conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 torchtext==0.13.0 cudatoolkit=11.6 -c pytorch -c conda-forge`
50 | 
51 | ## Configurations
52 | * See `python3 main.py -h`.
53 | 
54 | ## Example Commands
55 | * See the shell files prepared in the `commands` directory.
56 | 
57 | ## TODO
58 | - [ ] Support more models, especially lightweight ones for the cross-device FL setting. (e.g., [`EdgeNeXt`](https://github.com/mmaaz60/EdgeNeXt))
59 | - [ ] Support more structured datasets, including temporal and tabular data, along with datasets suited to the cross-silo FL setting. (e.g., [`MedMNIST`](https://github.com/MedMNIST/MedMNIST))
60 | - [ ] Add other popular FL algorithms, including personalized FL algorithms. (e.g., [`SuPerFed`](https://arxiv.org/abs/2109.07628))
61 | - [ ] Attach benchmark results of the sample commands.
62 | 
63 | ## Contact
64 | Should you have any feedback, please create a thread in the __Issues__ tab. Thank you :)
65 | 
--------------------------------------------------------------------------------
/commands/leaf/fedavg_celeba.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # FedAvg experiments for the LEAF CelebA dataset
4 | python3 main.py \
5 |     --exp_name FedAvg_LEAF_CelebA --seed 42 --device cuda \
6 |     --dataset CelebA \
7 |     --split_type pre --test_size 0.1 \
8 |     --model_name TwoCNN --resize 84 --hidden_size 32 \
9 |     --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics acc1 \
10 |     --R 5000 --E 5 --C 0.001 --B 10 --beta1 0 \
11 |     --optimizer SGD --lr 0.1 --lr_decay 1 --lr_decay_step 1 --criterion BCEWithLogitsLoss
12 | 
--------------------------------------------------------------------------------
/commands/leaf/fedavg_femnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # FedAvg experiments for the LEAF FEMNIST dataset
4 | python3 main.py \
5 |     --exp_name FedAvg_LEAF_FEMNIST --seed 42 --device cuda \
6 |     --dataset FEMNIST \
7 |     --split_type pre --test_size 0.1 \
8 |     --model_name FEMNISTCNN --resize 28 --hidden_size 64 \
9 |     --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics acc1 acc5 \
10 |     --R 5000 --E 5 --C 0.003 --B 10 --beta1 0 \
11 |     --optimizer SGD --lr 0.0003 --lr_decay 1 --lr_decay_step 1 --criterion CrossEntropyLoss
12 | 
--------------------------------------------------------------------------------
/commands/leaf/fedavg_reddit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # FedAvg experiments for the LEAF Reddit dataset
4 | python3 main.py \
5 |     --exp_name FedAvg_LEAF_Reddit --seed 42 --device cuda \
6 |     --dataset Reddit \
7 |     --split_type pre --test_size 0.1 \
8 |     --model_name NextWordLSTM --num_layers 2 --num_embeddings 10000 --embedding_size 256 --hidden_size 256 \
9 |     --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics seqacc \
10 |     --R 5000 --E 5 --C 0.013 --B 50 --beta1 0 \
11 |     --optimizer SGD --lr 0.0003 --lr_decay 1 --lr_decay_step 1 --criterion Seq2SeqLoss
12 | 
--------------------------------------------------------------------------------
/commands/leaf/fedavg_sent140.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # FedAvg experiments for the LEAF Sent140 dataset
4 | ## Sampled 5% of the raw dataset, as stated in the original paper, since there are over 200K clients in total...!
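## (NOTE: C is the fraction of clients sampled per round, so C=0.0001 over 200K+ clients selects roughly 20 clients per round)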
5 | python3 main.py \ 6 | --exp_name FedAvg_LEAF_Sent140 --seed 42 --device cuda \ 7 | --dataset Sent140 \ 8 | --split_type pre --test_size 0.1 \ 9 | --model_name NextCharLSTM --embedding_size 300 --hidden_size 128 --num_layers 2 \ 10 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics acc1 \ 11 | --R 5000 --E 5 --C 0.0001 --B 10 --beta1 0 \ 12 | --optimizer SGD --lr 0.0003 --lr_decay 1 --lr_decay_step 1 --criterion BCEWithLogitsLoss 13 | -------------------------------------------------------------------------------- /commands/leaf/fedavg_shakespeare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedAvg experiments for LEAF Shakespeare dataset 4 | python3 main.py \ 5 | --exp_name FedAvg_LEAF_Shakespeare --seed 42 --device cuda \ 6 | --dataset Shakespeare \ 7 | --split_type pre --test_size 0.1 \ 8 | --model_name NextCharLSTM --num_embeddings 80 --embedding_size 8 --hidden_size 256 --num_layers 2 \ 9 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics acc1 acc5 \ 10 | --R 5000 --E 5 --C 0.016 --B 10 --beta1 0 \ 11 | --optimizer SGD --lr 0.0003 --lr_decay 1 --lr_decay_step 1 --criterion CrossEntropyLoss 12 | -------------------------------------------------------------------------------- /commands/leaf/leaf.md: -------------------------------------------------------------------------------- 1 | # Replication Study - LEAF 2 | --- 3 | -------------------------------------------------------------------------------- /commands/original_cifar10/cifar10.md: -------------------------------------------------------------------------------- 1 | # Replication Study - CIFAR10 2 | --- 3 | ## CIFAR10 4 | | IID `FedSGD` ($E=1$) || IID `FedAvg` ($E=5$)|| 5 | |:------------:|:----------:|:----------:|:--------:| 6 | | $\eta$ | | $\eta$ | | 7 | | 0.45 | | 0.05 | | 8 | | 0.6 | | 0.15 | | 9 | | 0.7 | | 0.25 | | 10 | * Corresponds to Table 3 and Figure 4 from ([McMahan et al., 2016](https://arxiv.org/abs/1602.05629)). 
(Note that $B=50, C=0.1, K=100$) -------------------------------------------------------------------------------- /commands/original_cifar10/fedavg_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedAvg experiments in Figure 4, 9 of (McMahan et al., 2016) 4 | ## IID split 5 | python3 main.py \ 6 | --exp_name FedAvg_CIFAR10_CNN_IID --seed 42 --device cuda \ 7 | --dataset CIFAR10 \ 8 | --split_type iid --test_size -1 \ 9 | --model_name SimpleCNN --crop 24 --randhf 0.5 --randjit 0.5 --imnorm --hidden_size 64 \ 10 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 1 --eval_metrics acc1 acc5 \ 11 | --K 100 --R 1000 --C 0.1 --E 5 --B 50 --beta1 0 \ 12 | --optimizer SGD --lr 0.25 --lr_decay 0.99 --lr_decay_step 1 --criterion CrossEntropyLoss 13 | 14 | ## Pathological Non-IID split 15 | python3 main.py \ 16 | --exp_name FedAvg_CIFAR10_CNN_Patho --seed 42 --device cuda \ 17 | --dataset CIFAR10 \ 18 | --split_type patho --test_size -1 \ 19 | --model_name SimpleCNN --crop 24 --randhf 0.5 --randjit 0.5 --imnorm --hidden_size 64 \ 20 | --algorithm fedavg --eval_type local --eval_every 1 --eval_metrics acc1 acc5 \ 21 | --K 100 --R 1000 --C 0.1 --E 5 --B 50 --beta1 0 \ 22 | --optimizer SGD --lr 0.1 --lr_decay 0.99 --lr_decay_step 1 --criterion CrossEntropyLoss 23 | -------------------------------------------------------------------------------- /commands/original_cifar10/fedsgd_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedSGD experiments in Figure 4, 9 of (McMahan et al., 2016) 4 | ## IID split 5 | python3 main.py \ 6 | --exp_name FedSGD_CIFAR10_CNN_IID --seed 42 --device cuda \ 7 | --dataset CIFAR10 \ 8 | --split_type iid --test_size -1 \ 9 | --model_name SimpleCNN --crop 24 --randhf 0.5 --randjit 0.5 --imnorm --hidden_size 64 \ 10 | --algorithm fedsgd --eval_fraction 1 --eval_type local --eval_every 1 --eval_metrics acc1 acc5 \ 11 | --K 100 --R 10000 --C 0.1 --B 0 --beta1 0 \ 12 | --optimizer SGD --lr 0.6 --lr_decay 0.9934 --lr_decay_step 1 --criterion CrossEntropyLoss 13 | 14 | ## Pathological Non-IID split 15 | python3 main.py \ 16 | --exp_name FedSGD_CIFAR10_CNN_Patho --seed 42 --device cuda \ 17 | --dataset CIFAR10 \ 18 | --split_type patho --test_size -1 \ 19 | --model_name SimpleCNN --crop 24 --randhf 0.5 --randjit 0.5 --imnorm --hidden_size 64 \ 20 | --algorithm fedsgd --eval_type local --eval_every 1 --eval_metrics acc1 acc5 \ 21 | --K 100 --R 10000 --C 0.1 --B 0 --beta1 0 \ 22 | --optimizer SGD --lr 0.45 --lr_decay 0.9934 --lr_decay_step 1 --criterion CrossEntropyLoss 23 | -------------------------------------------------------------------------------- /commands/original_mnist/fedavg_mnist_2nn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedAvg experiments in Table 1 of (McMahan et al., 2016) 4 | ## IID split 5 | for b in 0 10 6 | do 7 | for c in 0.0 0.1 0.2 0.5 1.0 8 | do 9 | python3 main.py \ 10 | --exp_name "FedAvg_MNIST_2NN_IID_C${c}_B${b}" --seed 42 --device cuda \ 11 | --dataset MNIST \ 12 | --split_type iid --test_size 0 \ 13 | --model_name TwoNN --resize 28 --hidden_size 200 \ 14 | --algorithm fedavg --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \ 15 | --K 100 --R 1000 --E 1 --C $c --B $b --beta1 0 \ 16 | --optimizer SGD --lr 0.1 --lr_decay 0.99 --lr_decay_step 25 --criterion CrossEntropyLoss 17 | done 18 | done 19 | 20 | ## 
Pathological Non-IID split 21 | for b in 0 10 22 | do 23 | for c in 0.0 0.1 0.2 0.5 1.0 24 | do 25 | python3 main.py \ 26 | --exp_name "FedAvg_MNIST_2NN_Patho_C${c}_B${b}" --seed 42 --device cuda \ 27 | --dataset MNIST \ 28 | --split_type patho --test_size 0 \ 29 | --model_name TwoNN --resize 28 --hidden_size 200 \ 30 | --algorithm fedavg --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \ 31 | --K 100 --R 1000 --E 1 --C $c --B $b --beta1 0 \ 32 | --optimizer SGD --lr 0.01 --lr_decay 0.999 --lr_decay_step 10 --criterion CrossEntropyLoss 33 | done 34 | done 35 | -------------------------------------------------------------------------------- /commands/original_mnist/fedavg_mnist_cnn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedAvg experiments in Table 1 of (McMahan et al., 2016) 4 | ## IID split 5 | for b in 0 10 6 | do 7 | for c in 0.0 0.1 0.2 0.5 1.0 8 | do 9 | python3 main.py \ 10 | --exp_name "FedAvg_MNIST_CNN_IID_C${c}_B${b}" --seed 42 --device cuda \ 11 | --dataset MNIST \ 12 | --split_type iid --test_size 0 \ 13 | --model_name TwoCNN --resize 28 --hidden_size 200 \ 14 | --algorithm fedavg --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \ 15 | --K 100 --R 1000 --E 5 --C $c --B $b --beta1 0 \ 16 | --optimizer SGD --lr 0.215 --lr_decay 0.95 --lr_decay_step 25 --criterion CrossEntropyLoss 17 | done 18 | done 19 | 20 | ## Pathological Non-IID split 21 | for b in 0 10 22 | do 23 | for c in 0.0 0.1 0.2 0.5 1.0 24 | do 25 | python3 main.py \ 26 | --exp_name "FedAvg_MNIST_CNN_Patho_C${c}_B${b}" --seed 42 --device cuda \ 27 | --dataset MNIST \ 28 | --split_type patho --test_size 0 \ 29 | --model_name TwoCNN --resize 28 --hidden_size 200 \ 30 | --algorithm fedavg --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \ 31 | --K 100 --R 1000 --E 5 --C $c --B $b --beta1 0 \ 32 | --optimizer SGD --lr 0.1 --lr_decay 0.99 --lr_decay_step 10 --criterion CrossEntropyLoss 33 | done 34 | done 35 | -------------------------------------------------------------------------------- /commands/original_mnist/fedsgd_mnist_2nn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedSGD experiments in Table 1 of (McMahan et al., 2016) 4 | ## IID split 5 | for c in 0.0 0.1 0.2 0.5 1.0 6 | do 7 | python3 main.py \ 8 | --exp_name "FedSGD_MNIST_2NN_IID_C${c}_B0" --seed 42 --device cuda \ 9 | --dataset MNIST \ 10 | --split_type iid --test_size 0 \ 11 | --model_name TwoNN --resize 28 --hidden_size 200 \ 12 | --algorithm fedsgd --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \ 13 | --K 100 --R 1000 --C $c --B 0 --beta1 0 \ 14 | --optimizer SGD --lr 1.0 --lr_decay 0.99 --lr_decay_step 25 --criterion CrossEntropyLoss 15 | done 16 | 17 | ## Pathological Non-IID split 18 | for c in 0.0 0.1 0.2 0.5 1.0 19 | do 20 | python3 main.py \ 21 | --exp_name "FedSGD_MNIST_2NN_Patho_C${c}_B0" --seed 42 --device cuda \ 22 | --dataset MNIST \ 23 | --split_type patho --test_size 0 \ 24 | --model_name TwoNN --resize 28 --hidden_size 200 \ 25 | --algorithm fedsgd --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \ 26 | --K 100 --R 1000 --C $c --B 0 --beta1 0 \ 27 | --optimizer SGD --lr 0.1 --lr_decay 0.95 --lr_decay_step 10 --criterion CrossEntropyLoss 28 | done 29 | -------------------------------------------------------------------------------- /commands/original_mnist/fedsgd_mnist_cnn.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedSGD experiments in Table 1 of (McMahan et al., 2016) 4 | ## IID split 5 | for c in 0.0 0.1 0.2 0.5 1.0 6 | do 7 | python3 main.py \ 8 | --exp_name "FedSGD_MNIST_CNN_IID_C${c}_B0" --seed 42 --device cuda \ 9 | --dataset MNIST \ 10 | --split_type iid --test_size 0 \ 11 | --model_name TwoCNN --resize 28 --hidden_size 200 \ 12 | --algorithm fedsgd --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \ 13 | --K 100 --R 1000 --C $c --B 0 --beta1 0 \ 14 | --optimizer SGD --lr 0.5 --lr_decay 0.95 --lr_decay_step 25 --criterion CrossEntropyLoss 15 | done 16 | 17 | 18 | ## Pathological Non-IID split 19 | for c in 0.0 0.1 0.2 0.5 1.0 20 | do 21 | python3 main.py \ 22 | --exp_name "FedSGD_MNIST_CNN_Patho_C${c}_B0" --seed 42 --device cuda \ 23 | --dataset MNIST \ 24 | --split_type patho --test_size 0 \ 25 | --model_name TwoCNN --resize 28 --hidden_size 200 \ 26 | --algorithm fedsgd --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \ 27 | --K 100 --R 1000 --C $c --B 0 --beta1 0 \ 28 | --optimizer SGD --lr 0.25 --lr_decay 0.99 --lr_decay_step 10 --criterion CrossEntropyLoss 29 | done 30 | -------------------------------------------------------------------------------- /commands/original_mnist/mnist.md: -------------------------------------------------------------------------------- 1 | # Replication Study - MNIST 2 | --- 3 | ## 2NN and CNN 4 | | $\text{2NN}$ | IID `FedSGD`($E=1$)|| Non-IID `FedSGD` ($E=1$)|| IID `FedAvg` ($E=5$)|| Non-IID `FedAvg` ($E=5$)|| 5 | |:------------:|:----------:|:------:|:----------:|:------:|:----------:|:------:|:----------:|:------:| 6 | | $C$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | 7 | | 0.0 | | | | | | | | | 8 | | 0.1 | | | | | | | | | 9 | | 0.2 | | | | | | | | | 10 | | 0.5 | | | | | | | | | 11 | | 1.0 | | | | | | | | | 12 | 13 | | $\text{CNN}$ | IID `FedSGD` ($E=1$)|| Non-IID `FedSGD` ($E=1$)|| IID `FedAvg` ($E=5$)|| Non-IID `FedAvg` ($E=5$)|| 14 | |:------------:|:----------:|:-------:|:----------:|:-----------:|:----------:|:--------:|:----------:|:-----------:| 15 | | $C$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | 16 | | 0.0 | | | | | | | | | 17 | | 0.1 | | | | | | | | | 18 | | 0.2 | | | | | | | | | 19 | | 0.5 | | | | | | | | | 20 | | 1.0 | | | | | | | | | 21 | * Corresponds to Table 1 from ([McMahan et al., 2016](https://arxiv.org/abs/1602.05629)). 
(Note that $K=100$) 22 | -------------------------------------------------------------------------------- /commands/original_play_n_role/Shakespeare.md: -------------------------------------------------------------------------------- 1 | # Replication Study - Shakespeare 2 | --- 3 | -------------------------------------------------------------------------------- /commands/original_play_n_role/fedavg_play_n_role.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedAvg experiments in Table 2, Figure 2, 3 of (McMahan et al., 2016) 4 | ## Note: this is equivalent to Shakespeare dataset under LEAF benchmark 5 | ## Role & Play Non-IID split 6 | python3 main.py \ 7 | --exp_name FedAvg_Shakespeare_NextCharLSTM --seed 42 --device cuda \ 8 | --dataset Shakespeare \ 9 | --split_type pre --test_size 0.2 \ 10 | --model_name NextCharLSTM --num_embeddings 80 --embedding_size 8 --hidden_size 256 --num_layers 2 \ 11 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 100 --eval_metrics acc1 acc5 \ 12 | --R 2000 --C 0.002 --E 1 --B 10 --beta1 0 \ 13 | --optimizer SGD --lr 1.47 --lr_decay 1 --lr_decay_step 1 --criterion CrossEntropyLoss 14 | -------------------------------------------------------------------------------- /commands/original_play_n_role/fedsgd_play_n_role.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # FedSGD experiments in Table 2, Figure 2, 3 of (McMahan et al., 2016) 4 | ## Note: this is equivalent to Shakespeare dataset under LEAF benchmark 5 | ## Role & Play Non-IID split 6 | python3 main.py \ 7 | --exp_name FedSGD_Shakespeare_NextCharLSTM --seed 42 --device cuda \ 8 | --dataset Shakespeare \ 9 | --split_type pre --test_size 0.2 \ 10 | --model_name NextCharLSTM --num_embeddings 80 --embedding_size 8 --hidden_size 256 --num_layers 2 \ 11 | --algorithm fedsgd --eval_fraction 1 --eval_type local --eval_every 500 --eval_metrics acc1 acc5 \ 12 | --R 5000 --C 0.002 --B 0 --beta1 0 \ 13 | --optimizer SGD --lr 1.47 --lr_decay 1 --lr_decay_step 1 --criterion CrossEntropyLoss 14 | -------------------------------------------------------------------------------- /requirments.txt: -------------------------------------------------------------------------------- 1 | tensorboard==2.12.1 2 | numpy==1.24.2 3 | pandas==1.5.3 4 | scikit-learn==1.2.2 5 | transformers==4.27.4 6 | tqdm==4.65.0 7 | einops==0.6.1 -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import transformers 3 | 4 | # turn off unnecessary logging 5 | transformers.logging.set_verbosity_error() 6 | 7 | from .utils import set_seed, Range, TensorBoardRunner, check_args, init_weights, TqdmToLogger, MetricManager, stratified_split 8 | from .loaders import load_dataset, load_model 9 | 10 | 11 | 12 | # for logger initialization 13 | def set_logger(path, args): 14 | # initialize logger 15 | logger = logging.getLogger(__name__) 16 | logging_format = logging.Formatter( 17 | fmt='[%(levelname)s] (%(asctime)s) %(message)s', 18 | datefmt='%Y/%m/%d %I:%M:%S %p' 19 | ) 20 | stream_handler = logging.StreamHandler() 21 | file_handler = logging.FileHandler(path) 22 | 23 | stream_handler.setFormatter(logging_format) 24 | file_handler.setFormatter(logging_format) 25 | 26 | logger.addHandler(stream_handler) 27 | logger.addHandler(file_handler) 28 | 
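# emit INFO-level records and above to both the console and the log file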
logger.setLevel(level=logging.INFO) 29 | 30 | # print welcome message 31 | logger.info('[WELCOME] Initialize...') 32 | welcome_message = """ 33 | _______ _______ ______ _______ ______ _______ _______ _______ ______ 34 | |______ |______ | \ |______ |_____/ |_____| | |______ | \\ 35 | | |_______|_____/_|_______|__ _\_ |_ ___|_ __| _ |______ |_____/ 36 | | |______ |_____| |_____/ | \ | | | \ | | ____ 37 | |_____ |______ | | | \_ | \_| __|__ | \_| |_____| 38 | 39 | By. vaseline555 (Adam) 40 | """ 41 | logger.info(welcome_message) 42 | -------------------------------------------------------------------------------- /src/algorithm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaseline555/Federated-Learning-in-PyTorch/6c07b19c6810c82bd9455bf7364808a568376bf4/src/algorithm/__init__.py -------------------------------------------------------------------------------- /src/algorithm/basealgorithm.py: -------------------------------------------------------------------------------- 1 | from abc import * 2 | 3 | 4 | 5 | class BaseOptimizer(metaclass=ABCMeta): 6 | """Federated optimization algorithm. 7 | """ 8 | @abstractmethod 9 | def step(self, closure=None): 10 | raise NotImplementedError 11 | 12 | @abstractmethod 13 | def accumulate(self, **kwargs): 14 | raise NotImplementedError 15 | 16 | -------------------------------------------------------------------------------- /src/algorithm/fedadagrad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .basealgorithm import BaseOptimizer 4 | 5 | 6 | 7 | class FedadagradOptimizer(BaseOptimizer, torch.optim.Optimizer): 8 | def __init__(self, params, **kwargs): 9 | lr = kwargs.get('lr') 10 | v0 = kwargs.get('v0') 11 | tau = kwargs.get('tau') 12 | momentum = kwargs.get('beta') 13 | defaults = dict(lr=lr, momentum=momentum, v0=v0, tau=tau) 14 | BaseOptimizer.__init__(self); torch.optim.Optimizer.__init__(self, params=params, defaults=defaults) 15 | 16 | def step(self, closure=None): 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for idx, group in enumerate(self.param_groups): 22 | beta = group['momentum'] 23 | tau = group['tau'] 24 | lr = group['lr'] 25 | v0 = group['v0'] 26 | for param in group['params']: 27 | if param.grad is None: 28 | continue 29 | # get (\Delta_t) 30 | delta = -param.grad.data 31 | 32 | if idx == 0: # idx == 0: parameters; optimize according to algorithm 33 | # calculate m_t 34 | if 'momentum_buffer1' not in self.state[param]: 35 | self.state[param]['momentum_buffer1'] = torch.zeros_like(param).detach() 36 | self.state[param]['momentum_buffer1'].mul_(beta).add_(delta.mul(1. 
- beta)) # \beta * m_t + (1 - \beta) * \Delta_t 37 | m_new = self.state[param]['momentum_buffer1'] 38 | 39 | # calculate v_t 40 | if 'momentum_buffer2' not in self.state[param]: 41 | self.state[param]['momentum_buffer2'] = v0 + delta.pow(2) 42 | self.state[param]['momentum_buffer2'].add_(delta.pow(2)) # v_t + \Delta_t^2 43 | 44 | # update parameters 45 | param.data.add_((m_new.div(self.state[param]['momentum_buffer2'].pow(0.5).add(tau))).mul(lr)) 46 | elif idx == 1: # idx == 1: buffers; just averaging 47 | param.data.add_(delta) 48 | return loss 49 | 50 | def accumulate(self, mixing_coefficient, local_layers_iterator, check_if=lambda name: 'num_batches_tracked' in name): 51 | for group in self.param_groups: 52 | for server_param, (name, local_signals) in zip(group['params'], local_layers_iterator): 53 | if check_if(name): 54 | server_param.data.zero_() 55 | server_param.data.grad = torch.zeros_like(server_param) 56 | continue 57 | local_delta = (server_param - local_signals).mul(mixing_coefficient).data.type(server_param.dtype) 58 | if server_param.grad is None: # NOTE: grad buffer is used to accumulate local updates! 59 | server_param.grad = local_delta 60 | else: 61 | server_param.grad.data.add_(local_delta) 62 | -------------------------------------------------------------------------------- /src/algorithm/fedadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .basealgorithm import BaseOptimizer 4 | 5 | 6 | 7 | class FedadamOptimizer(BaseOptimizer, torch.optim.Optimizer): 8 | def __init__(self, params, **kwargs): 9 | lr = kwargs.get('lr') 10 | v0 = kwargs.get('v0') 11 | tau = kwargs.get('tau') 12 | momentum = kwargs.get('betas') 13 | defaults = dict(lr=lr, momentum=momentum, v0=v0, tau=tau) 14 | BaseOptimizer.__init__(self); torch.optim.Optimizer.__init__(self, params=params, defaults=defaults) 15 | 16 | def step(self, closure=None): 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for idx, group in enumerate(self.param_groups): 22 | (beta1, beta2) = group['momentum'] 23 | tau = group['tau'] 24 | lr = group['lr'] 25 | v0 = group['v0'] 26 | for param in group['params']: 27 | if param.grad is None: 28 | continue 29 | # get (\Delta_t) 30 | delta = -param.grad.data 31 | 32 | if idx == 0: # idx == 0: parameters; optimize according to algorithm 33 | # calculate m_t 34 | if 'momentum_buffer1' not in self.state[param]: 35 | self.state[param]['momentum_buffer1'] = torch.zeros_like(param).detach() 36 | self.state[param]['momentum_buffer1'].mul_(beta1).add_(delta.mul(1. - beta1)) # \beta1 * m_t + (1 - \beta1) * \Delta_t 37 | m_new = self.state[param]['momentum_buffer1'] 38 | 39 | # calculate v_t 40 | if 'momentum_buffer2' not in self.state[param]: 41 | self.state[param]['momentum_buffer2'] = v0 * beta2 + delta.pow(2).mul(1. - beta2) 42 | self.state[param]['momentum_buffer2'].mul_(beta2).add_(delta.pow(2).mul(1. 
- beta2)) # \beta2 * v_t + (1 - \beta2) * \Delta_t^2 43 | v_new = self.state[param]['momentum_buffer2'] 44 | 45 | # update parameters 46 | param.data.add_(m_new.div(v_new.pow(0.5).add(tau)).mul(lr)) 47 | elif idx == 1: # idx == 1: buffers; just averaging 48 | param.data.add_(delta) 49 | return loss 50 | 51 | def accumulate(self, mixing_coefficient, local_layers_iterator, check_if=lambda name: 'num_batches_tracked' in name): 52 | for group in self.param_groups: 53 | for server_param, (name, local_signals) in zip(group['params'], local_layers_iterator): 54 | if check_if(name): 55 | server_param.data.zero_() 56 | server_param.data.grad = torch.zeros_like(server_param) 57 | continue 58 | local_delta = (server_param - local_signals).mul(mixing_coefficient).data.type(server_param.dtype) 59 | if server_param.grad is None: # NOTE: grad buffer is used to accumulate local updates! 60 | server_param.grad = local_delta 61 | else: 62 | server_param.grad.data.add_(local_delta) 63 | -------------------------------------------------------------------------------- /src/algorithm/fedavg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .basealgorithm import BaseOptimizer 4 | 5 | 6 | 7 | class FedavgOptimizer(BaseOptimizer, torch.optim.Optimizer): 8 | def __init__(self, params, **kwargs): 9 | self.lr = kwargs.get('lr') 10 | self.momentum = kwargs.get('momentum', 0.) 11 | defaults = dict(lr=self.lr, momentum=self.momentum) 12 | BaseOptimizer.__init__(self); torch.optim.Optimizer.__init__(self, params=params, defaults=defaults) 13 | 14 | def step(self, closure=None): 15 | loss = None 16 | if closure is not None: 17 | loss = closure() 18 | 19 | for idx, group in enumerate(self.param_groups): 20 | beta = group['momentum'] 21 | for param in group['params']: 22 | if param.grad is None: 23 | continue 24 | delta = param.grad.data 25 | 26 | if idx == 0: # idx == 0: parameters; optimize according to algorithm // idx == 1: buffers; just averaging 27 | if beta > 0.: 28 | if 'momentum_buffer' not in self.state[param]: 29 | self.state[param]['momentum_buffer'] = torch.zeros_like(param).detach() 30 | self.state[param]['momentum_buffer'].mul_(beta).add_(delta.mul(1. - beta)) # \beta * v + (1 - \beta) * grad 31 | delta = self.state[param]['momentum_buffer'] 32 | param.data.sub_(delta) 33 | return loss 34 | 35 | def accumulate(self, mixing_coefficient, local_layers_iterator, check_if=lambda name: 'num_batches_tracked' in name): 36 | for group in self.param_groups: 37 | for server_param, (name, local_signals) in zip(group['params'], local_layers_iterator): 38 | if check_if(name): 39 | server_param.data.zero_() 40 | server_param.data.grad = torch.zeros_like(server_param) 41 | continue 42 | local_delta = (server_param - local_signals).mul(mixing_coefficient).data.type(server_param.dtype) 43 | if server_param.grad is None: # NOTE: grad buffer is used to accumulate local updates! 
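# (the server's step() later applies this accumulated delta as if it were a gradient)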
44 | server_param.grad = local_delta 45 | else: 46 | server_param.grad.data.add_(local_delta) 47 | -------------------------------------------------------------------------------- /src/algorithm/fedavgm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .fedavg import FedavgOptimizer 4 | 5 | 6 | 7 | class FedavgmOptimizer(FedavgOptimizer): 8 | def __init__(self, params, **kwargs): 9 | super(FedavgmOptimizer, self).__init__(params=params, **kwargs) 10 | -------------------------------------------------------------------------------- /src/algorithm/fedprox.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .fedavg import FedavgOptimizer 4 | 5 | 6 | 7 | class FedproxOptimizer(FedavgOptimizer): 8 | def __init__(self, params, **kwargs): 9 | super(FedproxOptimizer, self).__init__(params=params, **kwargs) 10 | -------------------------------------------------------------------------------- /src/algorithm/fedsgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .fedavg import FedavgOptimizer 4 | 5 | 6 | 7 | class FedsgdOptimizer(FedavgOptimizer): 8 | def __init__(self, params, **kwargs): 9 | super(FedsgdOptimizer, self).__init__(params=params, **kwargs) 10 | 11 | def step(self, closure=None): 12 | loss = None 13 | if closure is not None: 14 | loss = closure() 15 | 16 | for idx, group in enumerate(self.param_groups): 17 | if idx == 1: continue 18 | beta = group['momentum'] 19 | for param in group['params']: 20 | if param.grad is None: 21 | continue 22 | delta = param.grad.mul(group['lr']) 23 | if beta > 0.: 24 | if 'momentum_buffer' not in self.state[param]: 25 | self.state[param]['momentum_buffer'] = torch.zeros_like(param).detach() 26 | self.state[param]['momentum_buffer'].mul_(beta).add_(delta.mul(1. 
- beta)) # \beta * v + (1 - \beta) * (lr * grad) 27 | delta = self.state[param]['momentum_buffer'] 28 | param.data.sub_(delta) 29 | return loss 30 | 31 | def accumulate(self, mixing_coefficient, local_layers_iterator): 32 | for idx, group in enumerate(self.param_groups): 33 | if idx == 1: continue # idx == 0: parameters; idx == 1: buffers - ignore as these are not parameters 34 | for server_param, (_, local_param) in zip(group['params'], local_layers_iterator): 35 | local_delta = local_param.grad.mul(mixing_coefficient).data.type(server_param.dtype) 36 | if server_param.grad is None: 37 | server_param.grad = local_delta 38 | else: 39 | server_param.grad.data.add_(local_delta) 40 | -------------------------------------------------------------------------------- /src/algorithm/fedyogi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .basealgorithm import BaseOptimizer 4 | 5 | 6 | 7 | class FedyogiOptimizer(BaseOptimizer, torch.optim.Optimizer): 8 | def __init__(self, params, **kwargs): 9 | lr = kwargs.get('lr') 10 | v0 = kwargs.get('v0') 11 | tau = kwargs.get('tau') 12 | momentum = kwargs.get('betas') 13 | defaults = dict(lr=lr, momentum=momentum, v0=v0, tau=tau) 14 | BaseOptimizer.__init__(self); torch.optim.Optimizer.__init__(self, params=params, defaults=defaults) 15 | 16 | def step(self, closure=None): 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for idx, group in enumerate(self.param_groups): 22 | (beta1, beta2) = group['momentum'] 23 | tau = group['tau'] 24 | lr = group['lr'] 25 | v0 = group['v0'] 26 | for param in group['params']: 27 | if param.grad is None: 28 | continue 29 | # get (\Delta_t) 30 | delta = -param.grad.data 31 | 32 | if idx == 0: # idx == 0: parameters; optimize according to algorithm 33 | # calculate m_t 34 | if 'momentum_buffer1' not in self.state[param]: 35 | self.state[param]['momentum_buffer1'] = torch.zeros_like(param).detach() 36 | self.state[param]['momentum_buffer1'].mul_(beta1).add_(delta.mul(1. - beta1)) # \beta1 * m_t + (1 - \beta1) * \Delta_t 37 | m_new = self.state[param]['momentum_buffer1'] 38 | 39 | # calculate v_t 40 | if 'momentum_buffer2' not in self.state[param]: 41 | self.state[param]['momentum_buffer2'] = v0 - delta.pow(2).mul(1. - beta2).mul((v0 - delta).sign()) 42 | v_curr = self.state[param]['momentum_buffer2'] 43 | self.state[param]['momentum_buffer2'].sub_(delta.pow(2).mul(1. - beta2).mul(v_curr.sub(delta.pow(2)).sign())) # v_t - (1 - \beta2) * \Delta_t^2 * sgn(v_t - \Delta_t) 44 | v_new = self.state[param]['momentum_buffer2'] 45 | 46 | # update parameters 47 | param.data.add_((m_new.div(v_new.pow(0.5).add(tau))).mul(lr)) 48 | elif idx == 1: # idx == 1: buffers; just averaging 49 | param.data.add_(delta) 50 | return loss 51 | 52 | def accumulate(self, mixing_coefficient, local_layers_iterator, check_if=lambda name: 'num_batches_tracked' in name): 53 | for group in self.param_groups: 54 | for server_param, (name, local_signals) in zip(group['params'], local_layers_iterator): 55 | if check_if(name): 56 | server_param.data.zero_() 57 | server_param.data.grad = torch.zeros_like(server_param) 58 | continue 59 | local_delta = (server_param - local_signals).mul(mixing_coefficient).data.type(server_param.dtype) 60 | if server_param.grad is None: # NOTE: grad buffer is used to accumulate local updates! 
61 | server_param.grad = local_delta 62 | else: 63 | server_param.grad.data.add_(local_delta) 64 | -------------------------------------------------------------------------------- /src/algorithm/sent140.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | ## 2,460 clients, 2 classes 3 | 4 | for s in 42 1023 59999 5 | do 6 | for a in 0 0.1 1 10 7 | do 8 | python3 main.py \ 9 | --exp_name "Sent140_FedAvg_Fixed_${a} (${s})" --seed $s --device cuda:0 \ 10 | --dataset Sent140 --learner fixed --alpha $a \ 11 | --split_type pre --rawsmpl 0.01 --test_size 0.2 \ 12 | --model_name Sent140LSTM --embedding_size 300 --hidden_size 256 --num_layers 2 \ 13 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 5000 --eval_metrics acc1 \ 14 | --R 500 --C 0.0021 --E 5 --B 10 \ 15 | --optimizer SGD --lr 0.3 --lr_decay 0.98 --lr_decay_step 1 --criterion BCEWithLogitsLoss & 16 | 17 | python3 main.py \ 18 | --exp_name "Sent140_FedAvg_AdaHedge_${a} (${s})" --seed $s --device cuda:0 \ 19 | --dataset Sent140 --learner ah --alpha $a \ 20 | --split_type pre --rawsmpl 0.01 --test_size 0.2 \ 21 | --model_name Sent140LSTM --embedding_size 300 --hidden_size 256 --num_layers 2 \ 22 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 5000 --eval_metrics acc1 \ 23 | --R 500 --C 0.0021 --E 5 --B 10 \ 24 | --optimizer SGD --lr 0.3 --lr_decay 0.98 --lr_decay_step 1 --criterion BCEWithLogitsLoss & 25 | 26 | python3 main.py \ 27 | --exp_name "Sent140_FedAvg_SoftBayes_${a} (${s})" --seed $s --device cuda:0 \ 28 | --dataset Sent140 --learner sb --alpha $a \ 29 | --split_type pre --rawsmpl 0.01 --test_size 0.2 \ 30 | --model_name Sent140LSTM --embedding_size 300 --hidden_size 256 --num_layers 2 \ 31 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 100 --eval_metrics acc1 \ 32 | --R 500 --C 0.0021 --E 5 --B 10 \ 33 | --optimizer SGD --lr 0.3 --lr_decay 0.98 --lr_decay_step 1 --criterion BCEWithLogitsLoss 34 | done 35 | wait 36 | done 37 | 38 | python3 main.py \ 39 | --exp_name Sent140_FedAvg_Fixed --seed 42 --device cuda:2 \ 40 | --dataset Sent140 --learner fixed \ 41 | --split_type pre --rawsmpl 0.01 --test_size 0.2 \ 42 | --model_name Sent140LSTM --embedding_size 300 --hidden_size 80 --num_layers 2 \ 43 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 1000 --eval_metrics acc1 \ 44 | --R 1000 --C 0.0021 --E 5 --B 10 \ 45 | --optimizer SGD --lr 0.0003 --lr_decay 0.9995 --lr_decay_step 1 --criterion BCEWithLogitsLoss 46 | 47 | -------------------------------------------------------------------------------- /src/client/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaseline555/Federated-Learning-in-PyTorch/6c07b19c6810c82bd9455bf7364808a568376bf4/src/client/__init__.py -------------------------------------------------------------------------------- /src/client/baseclient.py: -------------------------------------------------------------------------------- 1 | from abc import * 2 | 3 | 4 | 5 | class BaseClient(metaclass=ABCMeta): 6 | """Class for client object having its own (private) data and resources to train a model. 
7 | """ 8 | def __init__(self, **kwargs): 9 | self.__identifier = None 10 | self.__model = None 11 | 12 | @property 13 | def id(self): 14 | return self.__identifier 15 | 16 | @id.setter 17 | def id(self, identifier): 18 | self.__identifier = identifier 19 | 20 | @property 21 | def model(self): 22 | return self.__model 23 | 24 | @model.setter 25 | def model(self, model): 26 | self.__model = model 27 | 28 | @abstractmethod 29 | def update(self): 30 | raise NotImplementedError 31 | 32 | @abstractmethod 33 | def evaluate(self): 34 | raise NotImplementedError 35 | 36 | @abstractmethod 37 | def download(self): 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def upload(self): 42 | raise NotImplementedError 43 | 44 | @abstractmethod 45 | def __len__(self): 46 | raise NotImplementedError 47 | 48 | @abstractmethod 49 | def __repr__(self): 50 | raise NotImplementedError 51 | -------------------------------------------------------------------------------- /src/client/fedadagradclient.py: -------------------------------------------------------------------------------- 1 | from .fedavgclient import FedavgClient 2 | 3 | 4 | 5 | class FedadagradClient(FedavgClient): 6 | def __init__(self, **kwargs): 7 | super(FedadagradClient, self).__init__(**kwargs) -------------------------------------------------------------------------------- /src/client/fedadamclient.py: -------------------------------------------------------------------------------- 1 | from .fedavgclient import FedavgClient 2 | 3 | 4 | 5 | class FedadamClient(FedavgClient): 6 | def __init__(self, **kwargs): 7 | super(FedadamClient, self).__init__(**kwargs) -------------------------------------------------------------------------------- /src/client/fedavgclient.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import inspect 4 | import itertools 5 | 6 | from .baseclient import BaseClient 7 | from src import MetricManager 8 | 9 | 10 | class FedavgClient(BaseClient): 11 | def __init__(self, args, training_set, test_set): 12 | super(FedavgClient, self).__init__() 13 | self.args = args 14 | self.training_set = training_set 15 | self.test_set = test_set 16 | 17 | self.optim = torch.optim.__dict__[self.args.optimizer] 18 | self.criterion = torch.nn.__dict__[self.args.criterion] 19 | 20 | self.train_loader = self._create_dataloader(self.training_set, shuffle=not self.args.no_shuffle) 21 | self.test_loader = self._create_dataloader(self.test_set, shuffle=False) 22 | 23 | def _refine_optim_args(self, args): 24 | required_args = inspect.getfullargspec(self.optim)[0] 25 | 26 | # collect eneterd arguments 27 | refined_args = {} 28 | for argument in required_args: 29 | if hasattr(args, argument): 30 | refined_args[argument] = getattr(args, argument) 31 | return refined_args 32 | 33 | def _create_dataloader(self, dataset, shuffle): 34 | if self.args.B == 0 : 35 | self.args.B = len(self.training_set) 36 | return torch.utils.data.DataLoader(dataset=dataset, batch_size=self.args.B, shuffle=shuffle) 37 | 38 | def update(self): 39 | mm = MetricManager(self.args.eval_metrics) 40 | self.model.train() 41 | self.model.to(self.args.device) 42 | 43 | optimizer = self.optim(self.model.parameters(), **self._refine_optim_args(self.args)) 44 | for e in range(self.args.E): 45 | for inputs, targets in self.train_loader: 46 | inputs, targets = inputs.to(self.args.device), targets.to(self.args.device) 47 | 48 | outputs = self.model(inputs) 49 | loss = self.criterion()(outputs, targets) 50 | 51 | 
for param in self.model.parameters(): 52 | param.grad = None 53 | loss.backward() 54 | if self.args.max_grad_norm > 0: 55 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 56 | optimizer.step() 57 | 58 | mm.track(loss.item(), outputs, targets) 59 | else: 60 | mm.aggregate(len(self.training_set), e + 1) 61 | else: 62 | self.model.to('cpu') 63 | return mm.results 64 | 65 | @torch.inference_mode() 66 | def evaluate(self): 67 | if self.args.train_only: # `args.test_size` == 0 68 | return {'loss': -1, 'metrics': {'none': -1}} 69 | 70 | mm = MetricManager(self.args.eval_metrics) 71 | self.model.eval() 72 | self.model.to(self.args.device) 73 | 74 | for inputs, targets in self.test_loader: 75 | inputs, targets = inputs.to(self.args.device), targets.to(self.args.device) 76 | 77 | outputs = self.model(inputs) 78 | loss = self.criterion()(outputs, targets) 79 | 80 | mm.track(loss.item(), outputs, targets) 81 | else: 82 | self.model.to('cpu') 83 | mm.aggregate(len(self.test_set)) 84 | return mm.results 85 | 86 | def download(self, model): 87 | self.model = copy.deepcopy(model) 88 | 89 | def upload(self): 90 | return itertools.chain.from_iterable([self.model.named_parameters(), self.model.named_buffers()]) 91 | 92 | def __len__(self): 93 | return len(self.training_set) 94 | 95 | def __repr__(self): 96 | return f'CLIENT < {self.id} >' 97 | -------------------------------------------------------------------------------- /src/client/fedavgmclient.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from .fedavgclient import FedavgClient 4 | from src import MetricManager 5 | 6 | 7 | class FedavgmClient(FedavgClient): 8 | def __init__(self, **kwargs): 9 | super(FedavgmClient, self).__init__(**kwargs) -------------------------------------------------------------------------------- /src/client/fedproxclient.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | 4 | from .fedavgclient import FedavgClient 5 | from src import MetricManager 6 | 7 | 8 | class FedproxClient(FedavgClient): 9 | def __init__(self, **kwargs): 10 | super(FedproxClient, self).__init__(**kwargs) 11 | 12 | def update(self): 13 | mm = MetricManager(self.args.eval_metrics) 14 | self.model.train() 15 | self.model.to(self.args.device) 16 | 17 | global_model = copy.deepcopy(self.model) 18 | for param in global_model.parameters(): 19 | param.requires_grad = False 20 | 21 | optimizer = self.optim(self.model.parameters(), **self._refine_optim_args(self.args)) 22 | for e in range(self.args.E): 23 | for inputs, targets in self.train_loader: 24 | inputs, targets = inputs.to(self.args.device), targets.to(self.args.device) 25 | 26 | outputs = self.model(inputs) 27 | loss = self.criterion()(outputs, targets) 28 | 29 | prox = 0. 
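# accumulate the FedProx proximal term (mu / 2) * ||w - w_global||^2, which penalizes drift from the global model (Li et al., 2018)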
30 |                 for name, param in self.model.named_parameters():
31 |                     prox += (param - global_model.get_parameter(name)).norm(2).pow(2)  # squared L2 norm, following the FedProx objective
32 |                 loss += self.args.mu * (0.5 * prox)
33 | 
34 |                 for param in self.model.parameters():
35 |                     param.grad = None
36 |                 loss.backward()
37 |                 if self.args.max_grad_norm > 0:
38 |                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
39 |                 optimizer.step()
40 | 
41 |                 mm.track(loss.item(), outputs, targets)
42 |             else:
43 |                 mm.aggregate(len(self.training_set), e + 1)
44 |         else:
45 |             self.model.to('cpu')
46 |         return mm.results
47 | 
--------------------------------------------------------------------------------
/src/client/fedsgdclient.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .fedavgclient import FedavgClient
4 | from src import MetricManager
5 | 
6 | 
7 | 
8 | class FedsgdClient(FedavgClient):
9 |     def __init__(self, **kwargs):
10 |         super(FedsgdClient, self).__init__(**kwargs)
11 | 
12 |     def update(self):
13 |         mm = MetricManager(self.args.eval_metrics)
14 |         self.model.train()
15 |         self.model.to(self.args.device)
16 | 
17 |         for inputs, targets in self.train_loader:
18 |             inputs, targets = inputs.to(self.args.device), targets.to(self.args.device)
19 | 
20 |             outputs = self.model(inputs)
21 |             loss = self.criterion()(outputs, targets)
22 | 
23 |             for param in self.model.parameters():
24 |                 param.grad = None
25 |             loss.backward()  # NOTE: no optimizer.step() here; the raw gradients are kept and uploaded instead
26 |             if self.args.max_grad_norm > 0:
27 |                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
28 |             mm.track(loss.item(), outputs, targets)
29 |         else:
30 |             self.model.to('cpu')
31 |             mm.aggregate(len(self.training_set), 1)
32 |         return mm.results
33 | 
34 |     def upload(self):
35 |         return self.model.named_parameters()
--------------------------------------------------------------------------------
/src/client/fedyogiclient.py:
--------------------------------------------------------------------------------
1 | from .fedavgclient import FedavgClient
2 | 
3 | 
4 | 
5 | class FedyogiClient(FedavgClient):
6 |     def __init__(self, **kwargs):
7 |         super(FedyogiClient, self).__init__(**kwargs)
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .torchvisionparser import fetch_torchvision_dataset
2 | from .torchtextparser import fetch_torchtext_dataset
3 | from .leafparser import fetch_leaf
4 | from .tinyimagenet import fetch_tinyimagenet
5 | from .cinic10 import fetch_cinic10
6 | from .beerreviews import fetch_beerreviews
7 | from .heart import fetch_heart
8 | from .adult import fetch_adult
9 | from .cover import fetch_cover
10 | from .gleam import fetch_gleam
11 | from .speechcommands import fetch_speechcommands
--------------------------------------------------------------------------------
/src/datasets/adult.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchtext
5 | import pandas as pd
6 | 
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import MinMaxScaler
9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | 
14 | class Adult(torch.utils.data.Dataset):
15 |     def __init__(self, education, inputs, targets, scaler):
16 |         self.identifier = education
17 |         self.inputs, self.targets = inputs, targets
18 |         self.scaler = scaler
19 | 
20 | 
21 |     def inverse_transform(self, inputs):
22 |         return self.scaler.inverse_transform(inputs)
23 | 
24 |     def __len__(self):
25 |         return len(self.inputs)
26 | 
27 |     def __getitem__(self, index):
28 |         inputs, targets = torch.tensor(self.inputs[index]).float(), torch.tensor(self.targets[index]).long()
29 |         return inputs, targets
30 | 
31 |     def __repr__(self):
32 |         return self.identifier
33 | 
34 | # helper method to fetch Adult dataset
35 | def fetch_adult(args, root, seed, test_size):
36 |     URL = [
37 |         'http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data',
38 |         'http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test'
39 |     ]
40 |     MD5 = [
41 |         '5d7c39d7b8804f071cdd1f2a7c460872',
42 |         '35238206dfdf7f1fe215bbb874adecdc'
43 |     ]
44 |     COL_NAME = [
45 |         'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status',\
46 |         'occupation', 'relationship', 'race', 'sex', 'capital_gain', 'capital_loss',\
47 |         'house_per_week', 'native_country', 'targets'
48 |     ]
49 |     NUM_COL = ['age', 'fnlwgt', 'capital_gain', 'capital_loss', 'house_per_week', 'education_num']
50 | 
51 |     def _download(root):
52 |         for idx, (url, md5) in enumerate(zip(URL, MD5)):
53 |             _ = torchtext.utils.download_from_url(
54 |                 url=url,
55 |                 root=root,
56 |                 hash_value=md5,
57 |                 hash_type='md5'
58 |             )
59 |             os.rename(os.path.join(root, url.split('/')[-1]), os.path.join(root, f"adult_{'train' if idx == 0 else 'test'}.csv"))
60 | 
61 |     def _munge_and_create_clients(root):
62 |         # load data
63 |         df = pd.read_csv(os.path.join(root, 'adult_train.csv'), header=None, names=COL_NAME, na_values='?').dropna().reset_index(drop=True)
64 |         df = df.drop(columns=['education'])
65 | 
66 |         # encode categorical data
67 |         for col in df.columns:
68 |             if col not in NUM_COL:
69 |                 replace_map = {key: value for value, key in enumerate(sorted(df[col].unique()))}
70 |                 df[col] = df[col].replace(replace_map)
71 | 
72 |         # adjust dtype
73 |         for col in df.columns:
74 |             if col in NUM_COL:
75 |                 df[col] = df[col].astype('float')
76 |             else:
77 |                 df[col] = df[col].astype('category')
78 | 
79 |         # get one-hot encoded dummy columns for categorical data
80 |         df = pd.concat([pd.get_dummies(df.iloc[:, :-1], columns=[col for col in df.columns if col not in NUM_COL][:-1], drop_first=True, dtype=int), df[['targets']]], axis=1)
81 | 
82 |         # create clients by education level
83 |         clients = {}
84 |         for edu in df['education_num'].unique():
85 |             clients[edu] = df.loc[df['education_num'] == edu]
86 |         return clients
87 | 
88 |     def _process_client_datasets(dataset, seed, test_size):
89 |         # remove identifier column
90 |         edu = int(dataset['education_num'].unique()[0])
91 |         df = dataset.drop(columns=['education_num'])
92 |         inputs, targets = df.iloc[:, :-1].values, df.iloc[:, -1].values
93 | 
94 |         # train-test split in a stratified manner
95 |         train_inputs, test_inputs, train_targets, test_targets = train_test_split(inputs, targets, test_size=test_size, random_state=seed, stratify=targets)
96 | 
97 |         # scale inputs
98 |         scaler = MinMaxScaler()
99 |         train_inputs[:, :5] = scaler.fit_transform(train_inputs[:, :5])
100 |         test_inputs[:, :5] = scaler.transform(test_inputs[:, :5])
101 |         return (
102 |             Adult(f'[ADULT] CLIENT < Edu{str(edu).zfill(2)} > (train)', train_inputs, train_targets, scaler),
103 |             Adult(f'[ADULT] CLIENT < Edu{str(edu).zfill(2)} > (test)', test_inputs, test_targets, scaler)
104 |         )
105 | 
106 |     logger.info(f'[LOAD] [ADULT] Check if raw data exists; if not, start downloading!')
107 |     if not os.path.exists(os.path.join(root, 'adult')):
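        # on the first call, fetch adult.data / adult.test from the UCI repository (files are verified via MD5)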
-------------------------------------------------------------------------------- /src/datasets/beerreviews.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import logging
5 | import torchtext
6 | 
7 | logger = logging.getLogger(__name__)
8 | 
9 | 
10 | 
11 | class BeerReviews(torch.utils.data.Dataset):
12 | URL = {
13 | 'look': 'http://people.csail.mit.edu/yujia/files/ls/beer/look.json',
14 | 'aroma': 'http://people.csail.mit.edu/yujia/files/ls/beer/aroma.json'
15 | }
16 | MD5 = {
17 | 'look': '4ad6dd806554ec50ad80b2acedda38d4',
18 | 'aroma': 'd5bc425fb075198a2d4792d33690a3fd'
19 | }
20 | ASPECT = ('look', 'aroma')
21 | 
22 | def __init__(self, root, aspect, tokenizer=None, download=True):
23 | assert aspect in self.ASPECT, f'Unknown aspect {aspect}!'
24 | 
25 | self.root = os.path.expanduser(root)
26 | self.aspect = aspect
27 | 
28 | if download:
29 | self.download()
30 | if not self._check_exists():
31 | err = 'Dataset not found or corrupted. You can use download=True to download it'
32 | logger.exception(err)
33 | raise RuntimeError(err)
34 | 
35 | # get the data and binary targets
36 | self.inputs, self.targets = self.load_json()
37 | 
38 | # set the maximum sequence length and dataset length
39 | self.max_seq_len = max([len(text) for text in self.inputs])
40 | self.length = len(self.targets)
41 | 
42 | # get word embeddings from FastText
43 | self.emb_bag = self.get_vocab('FastText')
44 | 
45 | # set tokenizer
46 | self.tokenizer = tokenizer
47 | 
48 | def _check_exists(self):
49 | return os.path.exists(os.path.join(self.root, 'beer'))
50 | 
51 | def download(self):
52 | if self._check_exists():
53 | return
54 | _ = torchtext.utils.download_from_url(
55 | url=self.URL[self.aspect],
56 | root=os.path.join(self.root, 'beer'),
57 | hash_value=self.MD5[self.aspect],
58 | hash_type='md5'
59 | )
60 | 
61 | def load_json(self):
62 | inputs, targets = [], []
63 | path = f'{self.root}/beer/{self.aspect}.json'
64 | with open(path, 'r') as f:
65 | for line in f:
66 | example = json.loads(line)
67 | targets.append(example['y'])
68 | inputs.append(example['text'])
69 | targets = torch.tensor(targets)
70 | return inputs, targets
71 | 
72 | def get_vocab(self, name='FastText'):
73 | vocab = getattr(torchtext.vocab, name)()
74 | 
75 | # add the pad token
76 | specials = ['<pad>']
77 | for token in specials:
78 | vocab.stoi[token] = len(vocab.itos)
79 | vocab.itos.append(token)
80 | vocab.vectors = torch.cat([vocab.vectors, torch.zeros(1, 300)], dim=0)
81 | return vocab
82 | 
83 | def __len__(self):
84 | return self.length
85 | 
86 | def __getitem__(self, index):
87 | text = self.inputs[index]
88 | if self.tokenizer is None:
89 | padded_text = self.emb_bag.stoi['<pad>'] * torch.ones(self.max_seq_len)
90 | padded_text[:len(text)] = torch.tensor([
91 | self.emb_bag.stoi[token] if token in self.emb_bag.stoi
92 | else self.emb_bag.stoi['unk'] for token in text
93 | ])
94 | inputs = torch.nn.functional.embedding(padded_text.long(), self.emb_bag.vectors).detach()
95 | else:
96 | inputs = self.tokenizer(
97 | text,
98 | return_tensors='pt',
99 | is_split_into_words=True,
100 | max_length=self.max_seq_len,
101 | return_attention_mask=False,
102 | truncation=True,
103 | padding='max_length'
104 | )['input_ids']
105 | targets = self.targets[index]
106 | return inputs, targets
107 | 
108 | def __repr__(self):
109 | return f'[BeerReviews ({self.aspect})] CLIENT'
110 | 
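The `get_vocab` trick above appends one all-zero row to the pretrained vectors so that the newly registered `<pad>` index embeds to zeros. A tiny self-contained illustration (with random stand-in vectors):

```python
import torch

vectors = torch.randn(5, 300)                       # stand-in for FastText vectors
vectors = torch.cat([vectors, torch.zeros(1, 300)], dim=0)
pad_idx = 5                                         # index of the appended '<pad>' token
out = torch.nn.functional.embedding(torch.tensor([pad_idx]), vectors)
assert torch.equal(out[0], torch.zeros(300))        # padding positions embed to zeros
```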
111 | # helper method to fetch Beer Reviews dataset
112 | def fetch_beerreviews(args, root, aspect='look', tokenizer=None):
113 | logger.info(f'[LOAD] [BEERREVIEWS] Fetching dataset!')
114 | 
115 | # create training dataset instance
116 | raw_train = BeerReviews(root, aspect, tokenizer)
117 | 
118 | # create test dataset instance
119 | raw_test = None
120 | 
121 | logger.info('[LOAD] [BEERREVIEWS] ...fetched dataset!')
122 | 
123 | # adjust argument
124 | args.in_features = 300
125 | args.num_classes = 2
126 | if tokenizer is None: # use FastText embedding
127 | args.num_embeddings = len(raw_train.emb_bag)
128 | args.embedding_size = raw_train.emb_bag.dim
129 | return raw_train, raw_test, args
130 | 
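When a tokenizer is passed, `BeerReviews.__getitem__` returns padded token ids instead of pre-computed FastText embeddings; the call signature matches HuggingFace-style tokenizers. A hypothetical example, assuming the `transformers` package and this checkpoint are available:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
raw_train, raw_test, args = fetch_beerreviews(args, root='./data', aspect='aroma', tokenizer=tokenizer)
inputs, target = raw_train[0]  # inputs: LongTensor of token ids, padded/truncated to max_seq_len
```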
-------------------------------------------------------------------------------- /src/datasets/cinic10.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchvision
5 | 
6 | logger = logging.getLogger(__name__)
7 | 
8 | 
9 | 
10 | # dataset wrapper module
11 | class CINIC10(torchvision.datasets.ImageFolder):
12 | base_folder = 'cinic-10-batches-py'
13 | zip_md5 = '6ee4d0c996905fe93221de577967a372'
14 | splits = ('train', 'val', 'test')
15 | filename = 'CINIC-10.tar.gz'
16 | url = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz'
17 | 
18 | def __init__(self, root, split='train', download=True, transform=None, **kwargs):
19 | self.data_root = os.path.expanduser(root)
20 | self.split = torchvision.datasets.utils.verify_str_arg(split, 'split', self.splits)
21 | if download:
22 | self.download()
23 | if not self._check_exists():
24 | err = 'Dataset not found or corrupted. You can use download=True to download it'
25 | logger.exception(err)
26 | raise RuntimeError(err)
27 | super().__init__(root=self.split_folder, transform=transform, **kwargs)
28 | 
29 | @property
30 | def dataset_folder(self):
31 | return os.path.join(self.data_root, self.base_folder)
32 | 
33 | @property
34 | def split_folder(self):
35 | return os.path.join(self.dataset_folder, self.split)
36 | 
37 | def _check_exists(self):
38 | return os.path.exists(self.split_folder)
39 | 
40 | def download(self):
41 | if self._check_exists():
42 | return
43 | torchvision.datasets.utils.download_and_extract_archive(
44 | self.url, self.dataset_folder, filename=self.filename,
45 | remove_finished=True, md5=self.zip_md5
46 | )
47 | 
48 | def __repr__(self):
49 | rep_str = {'train': 'CLIENT', 'test': 'SERVER'}
50 | return f'[CINIC10] {rep_str[self.split]}'
51 | 
52 | # helper method to fetch CINIC-10 dataset
53 | def fetch_cinic10(args, root, transforms):
54 | logger.info('[LOAD] [CINIC10] Fetching dataset!')
55 | 
56 | # default arguments
57 | DEFAULT_ARGS = {'root': root, 'transform': None, 'download': True}
58 | 
59 | # configure arguments for training/test dataset
60 | train_args = DEFAULT_ARGS.copy()
61 | train_args['split'] = 'train'
62 | train_args['transform'] = transforms[0]
63 | 
64 | # create training dataset instance
65 | raw_train = CINIC10(**train_args)
66 | 
67 | # for global holdout set
68 | test_args = DEFAULT_ARGS.copy()
69 | test_args['transform'] = transforms[1]
70 | test_args['split'] = 'test'
71 | 
72 | # create test dataset instance
73 | raw_test = CINIC10(**test_args)
74 | 
75 | logger.info('[LOAD] [CINIC10] ...fetched dataset!')
76 | 
77 | # adjust arguments
78 | args.in_channels = 3
79 | args.num_classes = len(torch.unique(torch.as_tensor(raw_train.targets)))
80 | return raw_train, raw_test, args
81 | 
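A hypothetical call to `fetch_cinic10` above, assuming an argparse-style `args` namespace from the training script; the transform pipeline is illustrative (any `(train_transform, test_transform)` pair works):

```python
import torchvision.transforms as T

train_tf = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.ToTensor()])
test_tf = T.ToTensor()
raw_train, raw_test, args = fetch_cinic10(args, root='./data', transforms=(train_tf, test_tf))
print(len(raw_train), args.num_classes)  # CINIC-10 has 90,000 training images across 10 classes
```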
-------------------------------------------------------------------------------- /src/datasets/cover.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchvision
5 | import pandas as pd
6 | 
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import StandardScaler
9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | 
14 | class Cover(torch.utils.data.Dataset):
15 | def __init__(self, groupby, inputs, targets, scaler):
16 | self.identifier = groupby
17 | self.inputs, self.targets = inputs, targets
18 | self.scaler = scaler
19 | 
20 | # NOTE: an instance method, since it needs the per-client scaler fitted during preprocessing
21 | def inverse_transform(self, inputs):
22 | return self.scaler.inverse_transform(inputs)
23 | 
24 | def __len__(self):
25 | return len(self.inputs)
26 | 
27 | def __getitem__(self, index):
28 | inputs, targets = torch.tensor(self.inputs[index]).float(), torch.tensor(self.targets[index]).long()
29 | return inputs, targets
30 | 
31 | def __repr__(self):
32 | return self.identifier
33 | 
34 | # helper method to fetch Cover type classification dataset
35 | # NOTE: data is grouped and split by `wilderness_area`
36 | def fetch_cover(args, root, seed, test_size):
37 | URL = 'http://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz'
38 | MD5 = '99670d8d942f09d459c7d4486fca8af5'
39 | COL_NAME = [
40 | 'elevation', 'aspect', 'slope',\
41 | 'horizontal_distance_to_hydrology', 'vertical_distance_to_hydrology', 'horizontal_distance_to_roadways',\
42 | 'hillshade_9am', 'hillshade_noon', 'hillshade_3pm',\
43 | 'horizontal_distance_to_fire_points',\
44 | 'wilderness_area',\
45 | 'soil_type_0', 'soil_type_1', 'soil_type_2', 'soil_type_3', 'soil_type_4',\
46 | 'soil_type_5', 'soil_type_6', 'soil_type_7', 'soil_type_8', 'soil_type_9',\
47 | 'soil_type_10', 'soil_type_11', 'soil_type_12', 'soil_type_13', 'soil_type_14',\
48 | 'soil_type_15', 'soil_type_16', 'soil_type_17', 'soil_type_18', 'soil_type_19',\
49 | 'soil_type_20', 'soil_type_21', 'soil_type_22', 'soil_type_23', 'soil_type_24',\
50 | 'soil_type_25', 'soil_type_26', 'soil_type_27', 'soil_type_28', 'soil_type_29',\
51 | 'soil_type_30', 'soil_type_31', 'soil_type_32', 'soil_type_33', 'soil_type_34',\
52 | 'soil_type_35', 'soil_type_36', 'soil_type_37', 'soil_type_38', 'soil_type_39',\
53 | 'cover_type'
54 | ]
55 | AREA = ['Rawah', 'Neota', 'Comanche Peak', 'Cache la Poudre']
56 | 
57 | def _download(root):
58 | torchvision.datasets.utils.download_and_extract_archive(
59 | URL, root, filename=URL.split('/')[-1],
60 | remove_finished=True, md5=MD5
61 | )
62 | os.rename(os.path.join(root, 'covtype.data'), os.path.join(root, 'covtype.csv'))
63 | 
64 | def _munge_and_split(root, seed, test_size):
65 | # load data
66 | df = pd.read_csv(os.path.join(root, 'covtype.csv'), header=None)
67 | 
68 | # reverse one-hot encoded columns
69 | wilderness_area = pd.Series(df.iloc[:, 10:14].values.argmax(1))
70 | 
71 | # concatenate into one dataframe
72 | df_raw = pd.concat([df.iloc[:, :10], wilderness_area, df.iloc[:, 14:-1], df.iloc[:, -1].sub(1)], axis=1)
73 | 
74 | # rename column
75 | df_raw.columns = COL_NAME
76 | 
77 | # split by wilderness area
78 | client_datasets = []
79 | for idx, name in enumerate(AREA):
80 | # get dataframe
81 | df_temp = df_raw[df_raw['wilderness_area'] == idx].reset_index(drop=True)
82 | 
83 | # get inputs and targets
84 | inputs, targets = df_temp.iloc[:, :-1].values.astype('float'), df_temp.iloc[:, -1].values.astype('float')
85 | 
86 | # train-test split in a stratified manner
87 | train_inputs, test_inputs, train_targets, test_targets = train_test_split(inputs, targets, test_size=test_size, random_state=seed, stratify=targets)
88 | 
89 | # scale inputs
90 | scaler = StandardScaler()
91 | train_inputs[:, :-40] = scaler.fit_transform(train_inputs[:, :-40]) # exclude last 40 columns (`soil_type` - one-hot encoded categorical variables)
92 | test_inputs[:, :-40] = scaler.transform(test_inputs[:, :-40])
93 | 
94 | # assign as a client dataset
95 | client_datasets.append(
96 | (
97 | Cover(f'[COVER] CLIENT < {name} > (train)', train_inputs, train_targets, scaler),
98 | Cover(f'[COVER] CLIENT < {name} > (test)', test_inputs, test_targets, scaler)
99 | )
100 | )
101 | return client_datasets
102 | 
103 | logger.info(f'[LOAD] [COVER] Check if raw data exists; if not, start downloading!')
104 | if not os.path.exists(os.path.join(root, 'covertype')):
105 | _download(root=os.path.join(root, 'covertype'))
106 | logger.info(f'[LOAD] [COVER] ...raw data is successfully downloaded!')
107 | else:
108 | logger.info(f'[LOAD] [COVER] ...raw data already exists!')
109 | 
110 | logger.info(f'[LOAD] [COVER] Munging and splitting dataset!')
111 | client_datasets = _munge_and_split(os.path.join(root, 'covertype'), seed, test_size)
112 | logger.info('[LOAD] [COVER] ...munged and split dataset!')
113 | 
114 | args.in_features = 51
115 | args.num_classes = 7
116 | args.K = 4
117 | return {}, client_datasets, args
118 | 
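In `_munge_and_split` above, the four wilderness-area indicator columns (10-13) are collapsed back into a single categorical column via `argmax(1)`, which then drives the per-client grouping. A toy check of that reverse one-hot trick:

```python
import pandas as pd

# two rows of one-hot wilderness-area indicators
df = pd.DataFrame([[1, 0, 0, 0],
                   [0, 0, 1, 0]])
print(df.values.argmax(1))  # -> [0 2], i.e., 'Rawah' and 'Comanche Peak'
```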
-------------------------------------------------------------------------------- /src/datasets/heart.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchtext
5 | import pandas as pd
6 | 
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import StandardScaler
9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | 
14 | class Heart(torch.utils.data.Dataset):
15 | def __init__(self, hospital, inputs, targets, scaler):
16 | self.identifier = hospital
17 | self.inputs, self.targets = inputs, targets
18 | self.scaler = scaler
19 | 
20 | # NOTE: an instance method, since it needs the per-client scaler fitted during preprocessing
21 | def inverse_transform(self, inputs):
22 | return self.scaler.inverse_transform(inputs)
23 | 
24 | def __len__(self):
25 | return len(self.inputs)
26 | 
27 | def __getitem__(self, index):
28 | inputs, targets = torch.tensor(self.inputs[index]).float(), torch.tensor(self.targets[index]).long()
29 | return inputs, targets
30 | 
31 | def __repr__(self):
32 | return self.identifier
33 | 
34 | # helper method to fetch Heart disease classification dataset
35 | def fetch_heart(args, root, seed, test_size):
36 | URL = {
37 | 'cleveland': 'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data',
38 | 'hungarian': 'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.hungarian.data',
39 | 'switzerland': 'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.switzerland.data',
40 | 'va': 'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.va.data'
41 | }
42 | MD5 = {
43 | 'cleveland': '2d91a8ff69cfd9616aa47b59d6f843db',
44 | 'hungarian': '22e96bee155b5973568101c93b3705f6',
45 | 'switzerland': '9a87f7577310b3917730d06ba9349e20',
46 | 'va': '4249d03ca7711e84f4444768c9426170'
47 | }
48 | COL_NAME = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'targets']
49 | 
50 | def _download(root):
51 | for hospital in URL.keys():
52 | _ = torchtext.utils.download_from_url(
53 | url=URL[hospital],
54 | root=root,
55 | hash_value=MD5[hospital],
56 | hash_type='md5'
57 | )
58 | os.rename(os.path.join(root, URL[hospital].split('/')[-1]), os.path.join(root, f'HEART ({hospital}).csv'))
59 | 
60 | def _munge_and_split(root, hospital, seed, test_size):
61 | # load data
62 | to_drop = [
63 | 10, # the slope of the peak exercise ST segment
64 | 11, # number of major vessels (0-3) colored by fluoroscopy
65 | 12 # thalassemia background
66 | ]
67 | df = pd.read_csv(os.path.join(root, f'HEART ({hospital}).csv'), header=None, na_values='?', usecols=[i for i in range(14) if i not in to_drop]).apply(lambda x: x.fillna(x.mean()), axis=0).reset_index(drop=True)
68 | 
69 | # rename column
70 | df.columns = COL_NAME
71 | 
72 | # adjust dtypes
73 | df['targets'] = df['targets'].where(df['targets'] == 0, 1)
74 | df['age'] = df['age'].astype(float)
75 | df['sex'] = df['sex'].astype(int)
76 | df['cp'] = df['cp'].astype(int)
77 | df['trestbps'] = df['trestbps'].astype(float)
78 | df['chol'] = df['chol'].astype(float)
79 | df['restecg'] = df['restecg'].astype(int)
80 | df['cp'] = df['cp'].astype(int)
81 | df['thalach'] = df['thalach'].astype(float)
82 | df['exang'] = df['exang'].astype(int)
83 | df['oldpeak'] = df['oldpeak'].astype(float)
84 | 
85 | # get one-hot encoded dummy columns for categorical data
86 | df = pd.concat([pd.get_dummies(df.iloc[:, :-1], columns=['cp', 'restecg'], drop_first=True, dtype=int), df[['targets']]], axis=1)
87 | 
88 | # get inputs and targets
89 | inputs, targets = df.iloc[:, :-1].values, df.iloc[:, -1].values
90 | 
91 | # train-test split in a stratified manner
92 | train_inputs, test_inputs, train_targets, test_targets = train_test_split(inputs, targets, test_size=test_size, random_state=seed, stratify=targets)
93 | 
94 | # scale inputs
95 | scaler = StandardScaler()
96 | train_inputs = scaler.fit_transform(train_inputs)
97 | test_inputs = scaler.transform(test_inputs)
98 | return (
99 | Heart(f'[HEART] CLIENT < {hospital} > (train)', train_inputs, train_targets, scaler),
100 | Heart(f'[HEART] CLIENT < {hospital} > (test)', test_inputs, test_targets, scaler)
101 | )
102 | 
103 | logger.info(f'[LOAD] [HEART] Check if raw data exists; if not, start downloading!')
104 | if not os.path.exists(os.path.join(root, 'heart')):
105 | _download(root=os.path.join(root, 'heart'))
106 | logger.info(f'[LOAD] [HEART] ...raw data is successfully downloaded!')
107 | else:
108 | logger.info(f'[LOAD] [HEART] ...raw data already exists!')
109 | 
110 | logger.info(f'[LOAD] [HEART] Munging and splitting dataset!')
111 | client_datasets = []
112 | for hospital in URL.keys():
113 | client_datasets.append(_munge_and_split(os.path.join(root, 'heart'), hospital, seed, test_size))
114 | logger.info('[LOAD] [HEART] ...munged and split dataset!')
115 | 
116 | args.in_features = 13
117 | args.num_classes = 2
118 | args.K = 4
119 | return {}, client_datasets, args
120 | 
-------------------------------------------------------------------------------- /src/datasets/leaf/__init__.py: --------------------------------------------------------------------------------
1 | from .postprocess import postprocess_leaf
2 | from .leaf_utils import *
-------------------------------------------------------------------------------- /src/datasets/leaf/leaf_utils.py: --------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | import logging
4 | import requests
5 | import torchvision
6 | 
7 | logger = logging.getLogger(__name__)
8 | 
9 | __all__ = ['download_data']
10 | 
11 | 
12 | 
13 | URL = {
14 | 'femnist': [
15 | 'https://s3.amazonaws.com/nist-srd/SD19/by_class.zip',
16 | 'https://s3.amazonaws.com/nist-srd/SD19/by_write.zip'
17 | ],
18 | 'shakespeare': ['http://www.gutenberg.org/files/100/old/1994-01-100.zip'],
19 | 'sent140': [
20 | 'http://cs.stanford.edu/people/alecmgo/trainingandtestdata.zip',
21 | 'http://nlp.stanford.edu/data/glove.6B.zip' # GloVe embedding for vocabularies
22 | ],
23 | 'celeba': [
24 | '1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS', # Google Drive link ID
25 | '0B7EVK8r0v71pblRyaVFSWGxPY0U', # Google Drive link ID
26 | 'https://cseweb.ucsd.edu/~weijian/static/datasets/celeba/img_align_celeba.zip' # img_align_celeba.zip
27 | ],
28 | 'reddit': ['1ISzp69JmaIJqBpQCX-JJ8-kVyUns8M7o'] # Google Drive link ID
29 | }
30 | OPT = { # md5 checksum if direct URL link is provided, file name if Google Drive link ID is provided
31 | 'femnist': ['79572b1694a8506f2b722c7be54130c4', 'a29f21babf83db0bb28a2f77b2b456cb'],
32 | 'shakespeare': ['b8d60664a90939fa7b5d9f4dd064a1d5'],
33 | 'sent140': 
['1647eb110dd2492512e27b9a70d5d1bc', '056ea991adb4740ac6bf1b6d9b50408b'],
34 | 'celeba': ['identity_CelebA.txt', 'list_attr_celeba.txt', '00d2c5bc6d35e252742224ab0c1e8fcb'],
35 | 'reddit': ['reddit_subsampled.zip']
36 | }
37 | 
38 | def download_data(download_root, dataset_name):
39 | """Download data from Google Drive and extract if it is archived.
40 | """
41 | def _get_confirm_token(response):
42 | for key, value in response.cookies.items():
43 | if key.startswith('download_warning'):
44 | return value
45 | return None
46 | 
47 | def _save_response_content(download_root, response):
48 | CHUNK_SIZE = 32768
49 | with open(download_root, 'wb') as file:
50 | for chunk in response.iter_content(CHUNK_SIZE):
51 | if chunk: # filter out keep-alive new chunks
52 | file.write(chunk)
53 | 
54 | def _download_file_from_google_drive(download_root, file_name, identifier):
55 | BASE_URL = 'https://docs.google.com/uc?export=download'
56 | 
57 | session = requests.Session()
58 | response = session.get(BASE_URL, params={'id': identifier, 'confirm': 1}, stream=True)
59 | token = _get_confirm_token(response)
60 | 
61 | if token:
62 | params = {'id': identifier, 'confirm': token}
63 | response = session.get(BASE_URL, params=params, stream=True)
64 | _save_response_content(os.path.join(download_root, file_name), response)
65 | print(f'...successfully downloaded file `{file_name}` at `{download_root}`!')
66 | 
67 | if '.zip' in file_name:
68 | with zipfile.ZipFile(os.path.join(download_root, file_name), 'r', compression=zipfile.ZIP_STORED) as zip_file:
69 | zip_file.extractall(download_root)
70 | print(f'...successfully extracted `{file_name}` at `{download_root}`!')
71 | 
72 | # download data from web
73 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Start downloading data...!')
74 | try:
75 | for (url, opt) in zip(URL[dataset_name], OPT[dataset_name]):
76 | if 'http' not in url:
77 | _download_file_from_google_drive(download_root, opt, url)
78 | else:
79 | torchvision.datasets.utils.download_and_extract_archive(
80 | url=url,
81 | download_root=download_root,
82 | md5=opt,
83 | remove_finished=True
84 | )
85 | else:
86 | logger.info(f'[LEAF - {dataset_name.upper()}] ...finished downloading data!')
87 | except Exception:
88 | logger.exception(url)
89 | raise Exception(url)
90 | 
-------------------------------------------------------------------------------- /src/datasets/leaf/postprocess/__init__.py: --------------------------------------------------------------------------------
1 | from .postprocess import postprocess_leaf
-------------------------------------------------------------------------------- /src/datasets/leaf/postprocess/filter.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | 
4 | 
5 | 
6 | def filter_clients(dataset_name, root, min_samples_per_clients):
7 | # set path
8 | data_dir = os.path.join(root, dataset_name)
9 | subdir = os.path.join(data_dir, 'sampled_data')
10 | 
11 | # collect sampled files
12 | files = []
13 | if os.path.exists(subdir):
14 | files = os.listdir(subdir)
15 | if len(files) == 0:
16 | subdir = os.path.join(data_dir, 'all_data')
17 | files = os.listdir(subdir)
18 | files = [f for f in files if f.endswith('.json')]
19 | 
20 | # collect data of remaining clients
21 | for f in files:
22 | users = []
23 | hierarchies = []
24 | num_samples = []
25 | user_data = {}
26 | 
27 | file_dir = os.path.join(subdir, f)
28 | with open(file_dir, 'r') as file:
29 | data = json.load(file)
30 | 
31 | num_users = len(data['users'])
32 | 
for i in range(num_users): 33 | curr_user = data['users'][i] 34 | curr_hierarchy = None 35 | if 'hierarchies' in data: 36 | curr_hierarchy = data['hierarchies'][i] 37 | curr_num_samples = data['num_samples'][i] 38 | if curr_num_samples >= min_samples_per_clients: 39 | user_data[curr_user] = data['user_data'][curr_user] 40 | users.append(curr_user) 41 | if curr_hierarchy is not None: 42 | hierarchies.append(curr_hierarchy) 43 | num_samples.append(data['num_samples'][i]) 44 | 45 | # create json file 46 | all_data = {} 47 | all_data['users'] = users 48 | if len(hierarchies) == len(users): 49 | all_data['hierarchies'] = hierarchies 50 | all_data['num_samples'] = num_samples 51 | all_data['user_data'] = user_data 52 | 53 | # save file 54 | with open(os.path.join(data_dir, 'rem_clients_data', f'{f[:-5]}_keep_{str(min_samples_per_clients).zfill(4)}.json'), 'w') as outfile: 55 | json.dump(all_data, outfile) 56 | -------------------------------------------------------------------------------- /src/datasets/leaf/postprocess/postprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | 5 | from src.datasets.leaf.postprocess.sample import sample_clients 6 | from src.datasets.leaf.postprocess.filter import filter_clients 7 | from src.datasets.leaf.postprocess.split import split_datasets 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | 13 | def postprocess_leaf(dataset_name, root, seed, raw_data_fraction, min_samples_per_clients, test_size): 14 | # check if raw data is prepared 15 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Check pre-processing data...!') 16 | if not os.path.exists(f'{root}/{dataset_name}/all_data'): 17 | err = f'[LOAD] [LEAF - {dataset_name.upper()}] Please check if the raw data is correctly prepared in `{root}`!' 
18 | raise AssertionError(err)
19 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...data pre-processing has been completed!')
20 | 
21 | # create client datasets
22 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Sample clients from raw data...!')
23 | if not os.path.exists(f'{root}/{dataset_name}/sampled_data'):
24 | os.makedirs(f'{root}/{dataset_name}/sampled_data')
25 | sample_clients(dataset_name, root, seed, raw_data_fraction)
26 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done sampling clients from raw data!')
27 | 
28 | # remove clients with less than given `min_samples_per_clients`
29 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Filter out remaining clients...!')
30 | if not os.path.exists(f'{root}/{dataset_name}/rem_clients_data') and (raw_data_fraction < 1.):
31 | os.makedirs(f'{root}/{dataset_name}/rem_clients_data')
32 | filter_clients(dataset_name, root, min_samples_per_clients)
33 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done filtering remaining clients!')
34 | 
35 | # create train-test split
36 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Split into training & test sets...!')
37 | if (not os.path.exists(f'{root}/{dataset_name}/train')) or (not os.path.exists(f'{root}/{dataset_name}/test')):
38 | if not os.path.exists(f'{root}/{dataset_name}/train'):
39 | os.makedirs(f'{root}/{dataset_name}/train')
40 | if not os.path.exists(f'{root}/{dataset_name}/test'):
41 | os.makedirs(f'{root}/{dataset_name}/test')
42 | split_datasets(dataset_name, root, seed, test_size)
43 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done splitting into training & test sets!')
44 | 
45 | # get number of clients
46 | train_data = [file for file in os.listdir(os.path.join(root, dataset_name, 'train')) if file.endswith('.json')][0]
47 | num_clients = len(json.load(open(f'{root}/{dataset_name}/train/{train_data}', 'r'))['users'])
48 | return num_clients
49 | 
-------------------------------------------------------------------------------- /src/datasets/leaf/postprocess/sample.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | 
5 | from collections import OrderedDict
6 | 
7 | 
8 | 
9 | def sample_clients(dataset_name, root, seed, used_raw_data_fraction):
10 | # set path
11 | data_dir = os.path.join(root, dataset_name)
12 | subdir = os.path.join(data_dir, 'all_data')
13 | files = os.listdir(subdir)
14 | files = [f for f in files if f.endswith('.json')]
15 | 
16 | # set seed
17 | rng = random.Random(seed)
18 | 
19 | # split data in non-IID manner
20 | for f in files:
21 | file_dir = os.path.join(subdir, f)
22 | with open(file_dir, 'r') as file:
23 | data = json.load(file, object_pairs_hook=OrderedDict)
24 | 
25 | # get meta data
26 | num_users = len(data['users'])
27 | tot_num_samples = sum(data['num_samples'])
28 | num_new_samples = int(used_raw_data_fraction * tot_num_samples)
29 | hierarchies = None
30 | 
31 | # non-IID split
32 | ctot_num_samples = 0
33 | users = data['users']
34 | users_and_hiers = None
35 | if 'hierarchies' in data:
36 | users_and_hiers = list(zip(users, data['hierarchies']))
37 | rng.shuffle(users_and_hiers)
38 | else:
39 | rng.shuffle(users)
40 | user_i = 0
41 | num_samples = []
42 | user_data = {}
43 | 
44 | if 'hierarchies' in data:
45 | hierarchies = []
46 | 
47 | while ctot_num_samples < num_new_samples:
48 | hierarchy = None
49 | if users_and_hiers is not None:
50 | user, hier = users_and_hiers[user_i]
51 | else:
52 | user = 
users[user_i] 53 | cdata = data['user_data'][user] 54 | cnum_samples = len(data['user_data'][user]['y']) 55 | 56 | if (ctot_num_samples + cnum_samples) > num_new_samples: 57 | cnum_samples = num_new_samples - ctot_num_samples 58 | indices = [i for i in range(cnum_samples)] 59 | new_indices = rng.sample(indices, cnum_samples) 60 | x, y = [], [] 61 | for i in new_indices: 62 | x.append(data['user_data'][user]['x'][i]) 63 | y.append(data['user_data'][user]['y'][i]) 64 | cdata = {'x': x, 'y': y} 65 | 66 | if 'hierarchies' in data: 67 | hierarchies.append(hier) 68 | 69 | num_samples.append(cnum_samples) 70 | user_data[user] = cdata 71 | ctot_num_samples += cnum_samples 72 | user_i += 1 73 | 74 | if 'hierarchies' in data: 75 | users = [u for u, h in users_and_hiers][:user_i] 76 | else: 77 | users = users[:user_i] 78 | 79 | # create json file 80 | all_data = {} 81 | all_data['users'] = users 82 | if hierarchies is not None: 83 | all_data['hierarchies'] = hierarchies 84 | all_data['num_samples'] = num_samples 85 | all_data['user_data'] = user_data 86 | 87 | # save file 88 | with open(os.path.join(data_dir, 'sampled_data', f'{f[:-5]}_niid_0{str(used_raw_data_fraction)[2:]}.json'), 'w') as out_file: 89 | json.dump(all_data, out_file) 90 | -------------------------------------------------------------------------------- /src/datasets/leaf/postprocess/split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import random 5 | import logging 6 | 7 | from collections import OrderedDict 8 | 9 | 10 | 11 | def split_datasets(dataset_name, root, seed, test_size): 12 | # set path 13 | data_dir = os.path.join(root, dataset_name) 14 | subdir = os.path.join(data_dir, 'rem_clients_data') 15 | 16 | # collect sampled files 17 | files = [] 18 | if os.path.exists(subdir): 19 | files = os.listdir(subdir) 20 | if len(files) == 0: 21 | subdir = os.path.join(data_dir, 'sampled_data') 22 | if os.path.exists(subdir): 23 | files = os.listdir(subdir) 24 | if len(files) == 0: 25 | subdir = os.path.join(data_dir, 'all_data') 26 | files = os.listdir(subdir) 27 | files = [f for f in files if f.endswith('.json')] 28 | 29 | # set seed 30 | rng = random.Random(seed) 31 | 32 | # check if data contains information on hierarchies 33 | file_dir = os.path.join(subdir, files[0]) 34 | with open(file_dir, 'r') as file: 35 | data = json.load(file) 36 | 37 | # split training/test data inside clients 38 | for f in files: 39 | file_dir = os.path.join(subdir, f) 40 | with open(file_dir, 'r') as file: 41 | data = json.load(file, object_pairs_hook=OrderedDict) 42 | 43 | num_samples_train, num_samples_test = [], [] 44 | user_data_train, user_data_test = {}, {} 45 | user_indices = [] # indices of users in `data['users']` that are not deleted 46 | 47 | for i, u in enumerate(data['users']): 48 | curr_num_samples = len(data['user_data'][u]['y']) 49 | if curr_num_samples >= 2: 50 | # ensures number of train and test samples both >= 1 51 | num_train_samples = max(1, int((1. 
- test_size) * curr_num_samples))
52 | if curr_num_samples == 2:
53 | num_train_samples = 1
54 | num_test_samples = curr_num_samples - num_train_samples
55 | 
56 | indices = [j for j in range(curr_num_samples)]
57 | if dataset_name == 'shakespeare':
58 | train_indices = [i for i in range(num_train_samples)]
59 | test_indices = [i for i in range(num_train_samples + 80 - 1, curr_num_samples)] # Shakespeare samples are 80-character windows; the offset keeps training and test windows from overlapping
60 | else:
61 | train_indices = rng.sample(indices, num_train_samples)
62 | test_indices = [i for i in range(curr_num_samples) if i not in train_indices]
63 | 
64 | if len(train_indices) >= 1 and len(test_indices) >= 1:
65 | user_indices.append(i)
66 | num_samples_train.append(num_train_samples)
67 | num_samples_test.append(num_test_samples)
68 | user_data_train[u] = {'x': [], 'y': []}
69 | user_data_test[u] = {'x': [], 'y': []}
70 | 
71 | train_blist = [False for _ in range(curr_num_samples)]
72 | test_blist = [False for _ in range(curr_num_samples)]
73 | 
74 | for j in train_indices: train_blist[j] = True
75 | for j in test_indices: test_blist[j] = True
76 | 
77 | for j in range(curr_num_samples):
78 | if train_blist[j]:
79 | user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
80 | user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
81 | elif test_blist[j]:
82 | user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
83 | user_data_test[u]['y'].append(data['user_data'][u]['y'][j])
84 | users = [data['users'][i] for i in user_indices]
85 | 
86 | # create json file of training set
87 | all_data_train = {}
88 | all_data_train['users'] = users
89 | all_data_train['num_samples'] = num_samples_train
90 | all_data_train['user_data'] = user_data_train
91 | 
92 | # save file of training set
93 | with open(os.path.join(data_dir, 'train', f'{f[:-5]}_train_0{str(1. 
- test_size)[2:]}.json'), 'w') as outfile: 94 | json.dump(all_data_train, outfile) 95 | 96 | # create json file of test set 97 | all_data_test = {} 98 | all_data_test['users'] = users 99 | all_data_test['num_samples'] = num_samples_test 100 | all_data_test['user_data'] = user_data_test 101 | 102 | # save file of test set 103 | with open(os.path.join(data_dir, 'test', f'{f[:-5]}_test_0{str(test_size)[2:]}.json'), 'w') as outfile: 104 | json.dump(all_data_test, outfile) 105 | -------------------------------------------------------------------------------- /src/datasets/leaf/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaseline555/Federated-Learning-in-PyTorch/6c07b19c6810c82bd9455bf7364808a568376bf4/src/datasets/leaf/preprocess/__init__.py -------------------------------------------------------------------------------- /src/datasets/leaf/preprocess/celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | 9 | def preprocess(root): 10 | TARGET_NAME = 'Smiling' 11 | 12 | def _get_metadata(path): 13 | with open(os.path.join(path, 'raw', 'identity_CelebA.txt'), 'r') as f_identities: 14 | identities = f_identities.read().split('\n') 15 | with open(os.path.join(path, 'raw', 'list_attr_celeba.txt'), 'r') as f_attributes: 16 | attributes = f_attributes.read().split('\n') 17 | return identities, attributes 18 | 19 | def _get_celebrities_and_images(identities, min_samples=5): 20 | all_celebs = {} 21 | for line in identities: 22 | info = line.split() 23 | if len(info) < 2: 24 | continue 25 | image, celeb = info[0], info[1] 26 | if celeb not in all_celebs: 27 | all_celebs[celeb] = [] 28 | all_celebs[celeb].append(image) 29 | 30 | # ignore all celebrities with less than `min_samples` images. 
31 | good_celebs = {c: all_celebs[c] for c in all_celebs if len(all_celebs[c]) >= min_samples} 32 | return good_celebs 33 | 34 | def _get_celebrities_and_target(celebrities, attributes, attribute_name=TARGET_NAME): 35 | def _get_celebrities_by_image(identities): 36 | good_images = {} 37 | for c in identities: 38 | images = identities[c] 39 | for img in images: 40 | good_images[img] = c 41 | return good_images 42 | 43 | col_names = attributes[1] 44 | col_idx = col_names.split().index(attribute_name) 45 | 46 | celeb_attributes = {} 47 | good_images = _get_celebrities_by_image(celebrities) 48 | 49 | for line in attributes[2:]: 50 | info = line.split() 51 | if len(info) == 0: 52 | continue 53 | image = info[0] 54 | if image not in good_images: 55 | continue 56 | celeb = good_images[image] 57 | att = (int(info[1:][col_idx]) + 1) / 2 58 | if celeb not in celeb_attributes: 59 | celeb_attributes[celeb] = [] 60 | celeb_attributes[celeb].append(att) 61 | return celeb_attributes 62 | 63 | def _convert_to_json(path, celebrities, targets): 64 | all_data = {} 65 | 66 | celeb_keys = [c for c in celebrities] 67 | num_samples = [len(celebrities[c]) for c in celeb_keys] 68 | data = {c: {'x': celebrities[c], 'y': targets[c]} for c in celebrities} 69 | 70 | all_data['users'] = celeb_keys 71 | all_data['num_samples'] = num_samples 72 | all_data['user_data'] = data 73 | 74 | with open(os.path.join(path, 'all_data', 'all_data.json'), 'w') as outfile: 75 | json.dump(all_data, outfile) 76 | 77 | # set path 78 | DATASET_NAME = __file__.split('/')[-1].split('.')[0] 79 | path = os.path.join(os.path.expanduser(root), DATASET_NAME) 80 | 81 | # check if preprocessing has already done 82 | if not os.path.exists(os.path.join(path, 'all_data')): 83 | os.makedirs(os.path.join(path, 'all_data')) 84 | else: 85 | return 86 | 87 | # get IDs and attributes 88 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Get meta-data...!') 89 | identities, attributes = _get_metadata(path) 90 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished parsing meta-data (IDs and attributes)!') 91 | 92 | # filter out celebs 93 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Construct celeb-image hashmap...!') 94 | celebrities = _get_celebrities_and_images(identities) 95 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...constructed celeb-image hashmap!') 96 | 97 | # filter out targets 98 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Construct inputs-targets hashmap...!') 99 | targets = _get_celebrities_and_target(celebrities, attributes) 100 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...constructed inputs-targets hashmap!') 101 | 102 | # convert to json format 103 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Convert data to json format...!') 104 | _convert_to_json(path, celebrities, targets) 105 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished converting data to json format!') 106 | -------------------------------------------------------------------------------- /src/datasets/leaf/preprocess/femnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import logging 5 | import hashlib 6 | 7 | import numpy as np 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | 13 | def preprocess(root): 14 | MAX_WRITERS = 100 # max number of writers per json file. 
15 | 16 | def _save_obj(obj, name): 17 | with open(f'{name}.pkl', 'wb') as f: 18 | pickle.dump(obj, f) 19 | 20 | def _load_obj(name): 21 | with open(f'{name}.pkl', 'rb') as f: 22 | file = pickle.load(f) 23 | return file 24 | 25 | def _parse_files(path): 26 | def _parse_class_files(path): 27 | class_files = [] # (class, file directory) 28 | class_dir = os.path.join(path, 'raw', 'by_class') 29 | rel_class_dir = os.path.join(path, 'raw', 'by_class') 30 | classes = os.listdir(class_dir) 31 | classes = [c for c in classes if len(c) == 2] 32 | 33 | for cl in classes: 34 | cldir = os.path.join(class_dir, cl) 35 | rel_cldir = os.path.join(rel_class_dir, cl) 36 | subcls = os.listdir(cldir) 37 | 38 | subcls = [s for s in subcls if (('hsf' in s) and ('mit' not in s))] 39 | 40 | for subcl in subcls: 41 | subcldir = os.path.join(cldir, subcl) 42 | rel_subcldir = os.path.join(rel_cldir, subcl) 43 | images = os.listdir(subcldir) 44 | image_dirs = [os.path.join(rel_subcldir, i) for i in images] 45 | for image_dir in image_dirs: 46 | class_files.append((cl, image_dir)) 47 | _save_obj(class_files, os.path.join(path, 'intermediate', 'class_file_dirs')) 48 | 49 | def _parse_write_files(path): 50 | write_files = [] # (writer, file directory) 51 | write_dir = os.path.join(path, 'raw', 'by_write') 52 | rel_write_dir = os.path.join(path, 'raw', 'by_write') 53 | write_parts = os.listdir(write_dir) 54 | 55 | for write_part in write_parts: 56 | writers_dir = os.path.join(write_dir, write_part) 57 | rel_writers_dir = os.path.join(rel_write_dir, write_part) 58 | writers = os.listdir(writers_dir) 59 | 60 | for writer in writers: 61 | writer_dir = os.path.join(writers_dir, writer) 62 | rel_writer_dir = os.path.join(rel_writers_dir, writer) 63 | wtypes = os.listdir(writer_dir) 64 | 65 | for wtype in wtypes: 66 | type_dir = os.path.join(writer_dir, wtype) 67 | rel_type_dir = os.path.join(rel_writer_dir, wtype) 68 | images = os.listdir(type_dir) 69 | image_dirs = [os.path.join(rel_type_dir, i) for i in images] 70 | 71 | for image_dir in image_dirs: 72 | write_files.append((writer, image_dir)) 73 | _save_obj(write_files, os.path.join(path, 'intermediate', 'write_file_dirs')) 74 | 75 | # parse class & write files sequentially 76 | _parse_class_files(path) 77 | _parse_write_files(path) 78 | 79 | def _get_hashes(path): 80 | def _get_class_hashes(path): 81 | cfd = os.path.join(path, 'intermediate', 'class_file_dirs') 82 | class_file_dirs = _load_obj(cfd) 83 | class_file_hashes = [] 84 | 85 | count = 0 86 | for cclass, cfile in class_file_dirs: 87 | chash = hashlib.md5(open(cfile, 'rb').read()).hexdigest() 88 | class_file_hashes.append((cclass, cfile, chash)) 89 | count += 1 90 | 91 | cfhd = os.path.join(path, 'intermediate', 'class_file_hashes') 92 | _save_obj(class_file_hashes, cfhd) 93 | 94 | def _get_write_hashes(path): 95 | wfd = os.path.join(path, 'intermediate', 'write_file_dirs') 96 | write_file_dirs = _load_obj(wfd) 97 | write_file_hashes = [] 98 | 99 | count = 0 100 | for cclass, cfile in write_file_dirs: 101 | chash = hashlib.md5(open(cfile, 'rb').read()).hexdigest() 102 | write_file_hashes.append((cclass, cfile, chash)) 103 | count += 1 104 | 105 | wfhd = os.path.join(path, 'intermediate', 'write_file_hashes') 106 | _save_obj(write_file_hashes, wfhd) 107 | 108 | # get class & write hashes sequentially 109 | _get_class_hashes(path) 110 | _get_write_hashes(path) 111 | 112 | def _match_by_hashes(path): 113 | # read class file hash 114 | cfhd = os.path.join(path, 'intermediate', 'class_file_hashes') 115 | 
class_file_hashes = _load_obj(cfhd) # each element is (class, file dir, hash) 116 | class_hash_dict = {} 117 | for i in range(len(class_file_hashes)): 118 | c, f, h = class_file_hashes[len(class_file_hashes)-i-1] 119 | class_hash_dict[h] = (c, f) 120 | 121 | # read write file hash 122 | wfhd = os.path.join(path, 'intermediate', 'write_file_hashes') 123 | write_file_hashes = _load_obj(wfhd) # each element is (writer, file dir, hash) 124 | 125 | # match 126 | write_classes = [] 127 | for tup in write_file_hashes: 128 | w, f, h = tup 129 | write_classes.append((w, f, class_hash_dict[h][0])) 130 | wwcd = os.path.join(path, 'intermediate', 'write_with_class') 131 | _save_obj(write_classes, wwcd) 132 | 133 | def _group_by_write(path): 134 | wwcd = os.path.join(path, 'intermediate', 'write_with_class') 135 | write_class = _load_obj(wwcd) 136 | 137 | writers, cimages = [], [] # each entry is a (writer, [list of (file, class)]) tuple 138 | cw, _, _ = write_class[0] 139 | for w, f, c in write_class: 140 | if w != cw: 141 | writers.append((cw, cimages)) 142 | cw = w 143 | cimages = [(f, c)] 144 | cimages.append((f, c)) 145 | writers.append((cw, cimages)) 146 | 147 | ibwd = os.path.join(path, 'intermediate', 'images_by_write') 148 | _save_obj(writers, ibwd) 149 | 150 | def _convert_to_json(path): 151 | def _relabel_femnist_class(c): 152 | """ 153 | Maps hexadecimal class value (string) to a decimal number. 154 | 155 | Args: 156 | c: class indices represented by hexadecimal values 157 | 158 | Returns: 159 | - 0 through 9 for classes representing respective numbers 160 | - 10 through 35 for classes representing respective uppercase letters 161 | - 36 through 61 for classes representing respective lowercase letters 162 | """ 163 | if c.isdigit() and int(c) < 40: # digit 164 | return (int(c) - 30) 165 | elif int(c, 16) <= 90: # uppercase 166 | return (int(c, 16) - 55) 167 | else: # lowercase 168 | return (int(c, 16) - 61) 169 | 170 | by_writer_dir = os.path.join(path, 'intermediate', 'images_by_write') 171 | writers = _load_obj(by_writer_dir) 172 | 173 | users, num_samples, user_data = [], [], {} 174 | writer_count, all_writers = 0, 0 175 | 176 | # assign data 177 | for w, l in writers: 178 | users.append(w) 179 | num_samples.append(len(l)) 180 | user_data[w] = {'x': [], 'y': []} 181 | 182 | for f, c in l: 183 | #gray = PIL.Image.open(f).convert('L') 184 | #vec = 1 - np.array(gray) / 255 # scale all pixel values to between 0 and 1 185 | #vec = vec.tolist() 186 | 187 | nc = _relabel_femnist_class(c) 188 | user_data[w]['x'].append(str(f)) 189 | user_data[w]['y'].append(nc) 190 | writer_count += 1 191 | all_writers += 1 192 | else: 193 | all_data = {} 194 | all_data['users'] = users 195 | all_data['num_samples'] = num_samples 196 | all_data['user_data'] = user_data 197 | 198 | file_name = f'all_data.json' 199 | file_path = os.path.join(path, 'all_data', file_name) 200 | 201 | with open(file_path, 'w') as outfile: 202 | json.dump(all_data, outfile) 203 | 204 | # set path 205 | DATASET_NAME = __file__.split('/')[-1].split('.')[0] 206 | path = os.path.join(os.path.expanduser(root), DATASET_NAME) 207 | 208 | # check if preprocessing has already done 209 | if not os.path.exists(os.path.join(path, 'all_data')): 210 | os.makedirs(os.path.join(path, 'all_data')) 211 | else: 212 | return 213 | 214 | # create intermediate file directories 215 | if not os.path.exists(os.path.join(path, 'intermediate')): 216 | os.makedirs(os.path.join(path, 'intermediate')) 217 | 218 | # parse files 219 | logger.info(f'[LOAD] [LEAF - 
{DATASET_NAME.upper()}] Extract files of raw images...!') 220 | _parse_files(path) 221 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished extracting files of raw images!') 222 | 223 | # get file hashes 224 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Calculate image hashes...!') 225 | _get_hashes(path) 226 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished calculating image hashes!') 227 | 228 | # match by hashes 229 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Assign class labels to write images...!') 230 | _match_by_hashes(path) 231 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished assigning class labels to write images!') 232 | 233 | # group images by writer 234 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Group images by write...!') 235 | _group_by_write(path) 236 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished grouping images by writer!') 237 | 238 | # convert to json format 239 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Convert data to json format...!') 240 | _convert_to_json(path) 241 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished converting data to json format!') 242 | -------------------------------------------------------------------------------- /src/datasets/leaf/preprocess/reddit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | 5 | from collections import Counter, defaultdict 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | 11 | def preprocess(root): 12 | def _make_path(path, dir_name): 13 | if not os.path.exists(os.path.join(path, dir_name)): 14 | os.makedirs(os.path.join(path, dir_name)) 15 | 16 | def _load_data(path): 17 | with open(path, 'r') as file: 18 | data = json.load(file) 19 | return data 20 | 21 | def _save_data(path, data): 22 | with open(path, 'w') as file: 23 | json.dump(data, file) 24 | 25 | def _refine_data(data): 26 | num_samples = [] 27 | for user in data['users']: 28 | num_samples.append(len(data['user_data'][user]['x'])) # get correct sample counts 29 | data['user_data'][user]['y'] = [original if (type(original) is list) else original['target_tokens'] for original in data['user_data'][user]['y']] # don't know why... 
but some samples are stored as a `dict` (with tokens under 'target_tokens') rather than a `list`, so normalize them here
30 | else:
31 | data['num_samples'] = num_samples
32 | return data
33 | 
34 | def _build_counter(train_data):
35 | all_tokens = []
36 | for u in train_data:
37 | for c in train_data[u]['x']:
38 | for s in c:
39 | all_tokens.extend(s)
40 | counter = Counter()
41 | counter.update(all_tokens)
42 | return counter
43 | 
44 | def _build_vocab(counter):
45 | vocab_size = 10000
46 | pad_symbol, unk_symbol, bos_symbol, eos_symbol = 0, 1, 2, 3
47 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
48 | count_pairs = count_pairs[:vocab_size - 1]
49 | 
50 | words, _ = list(zip(*count_pairs))
51 | words = list(words)
52 | vocab = {}
53 | vocab['<pad>'] = pad_symbol
54 | vocab['<unk>'] = unk_symbol
55 | vocab['<bos>'] = bos_symbol
56 | vocab['<eos>'] = eos_symbol
57 | 
58 | idx = 4 # due to special tokens
59 | while len(words) > 0:
60 | w = words.pop()
61 | if w in ['<pad>', '<unk>', '<bos>', '<eos>']:
62 | continue
63 | vocab[w] = idx
64 | idx += 1
65 | vocab = {'vocab': vocab, 'size': vocab_size, 'unk_symbol': unk_symbol, 'pad_symbol': pad_symbol, 'bos_symbol': bos_symbol, 'eos_symbol': eos_symbol}
66 | return vocab
67 | 
68 | def _convert_to_ids_and_get_length(raw, vocab):
69 | def _tokens_to_word_ids(tokens, vocab):
70 | return [vocab[word] for word in tokens]
71 | 
72 | def _convert_to_id(container, key):
73 | transformed = []
74 | for data in container[key]:
75 | for sent in data:
76 | idx = _tokens_to_word_ids(sent, vocab)
77 | transformed.append([idx])
78 | return transformed
79 | 
80 | for user in raw['users']:
81 | raw['user_data'][user]['x'] = _convert_to_id(raw['user_data'][user], 'x')
82 | raw['user_data'][user]['y'] = _convert_to_id(raw['user_data'][user], 'y')
83 | return raw
84 | 
85 | # set path
86 | DATASET_NAME = __file__.split('/')[-1].split('.')[0]
87 | path = os.path.join(os.path.expanduser(root), DATASET_NAME)
88 | 
89 | # check if preprocessing has already done
90 | _make_path(path, 'all_data')
91 | _make_path(path, 'vocab')
92 | _make_path(path, 'intermediate')
93 | _make_path(path, 'train')
94 | _make_path(path, 'test')
95 | 
96 | # adjust path since preprocessed json files are already prepared
97 | if os.path.exists(os.path.join(path, 'raw', 'new_small_data')):
98 | for file in os.listdir(os.path.join(path, 'raw', 'new_small_data')):
99 | if 'train' in file:
100 | os.replace(os.path.join(path, 'raw', 'new_small_data', file), os.path.join(path, 'intermediate', file))
101 | elif 'test' in file:
102 | os.replace(os.path.join(path, 'raw', 'new_small_data', file), os.path.join(path, 'intermediate', file))
103 | else:
104 | os.remove(os.path.join(path, 'raw', 'new_small_data', file)) # `val` data is not required...
105 | 
106 | # fix `num_samples`: counts in the raw files are unreliable, so recompute them from the data
107 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Refine raw data...!')
108 | train_data = _load_data(os.path.join(path, 'intermediate', 'train_data.json'))
109 | test_data = _load_data(os.path.join(path, 'intermediate', 'test_data.json'))
110 | 
111 | # correct number of samples and filter tokenized samples only
112 | train_data = _refine_data(train_data)
113 | test_data = _refine_data(test_data)
114 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished refining raw data!')
115 | 
116 | # aggregate data
117 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Combine raw data...!')
118 | _save_data(os.path.join(path, 'all_data', 'train_data_refined.json'), train_data)
119 | _save_data(os.path.join(path, 'all_data', 'test_data_refined.json'), test_data)
120 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished combining raw data!')
121 | 
122 | # build vocabulary
123 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Build vocabulary...!')
124 | counter = _build_counter(train_data['user_data'])
125 | vocab_raw = _build_vocab(counter)
126 | _save_data(os.path.join(path, 'vocab', 'reddit_vocab.json'), vocab_raw)
127 | 
128 | vocab = defaultdict(lambda: vocab_raw['unk_symbol'])
129 | vocab.update(vocab_raw['vocab'])
130 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...vocabulary is successfully created!')
131 | 
132 | # convert tokens to index
133 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Convert tokens into indices using vocabulary...!')
134 | train_data = _convert_to_ids_and_get_length(train_data, vocab)
135 | test_data = _convert_to_ids_and_get_length(test_data, vocab)
136 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...all tokens are converted into indices!')
137 | 
138 | # save processed data
139 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Split into training & test sets...!')
140 | _save_data(os.path.join(path, 'train', 'all_data_niid_00_train.json'), train_data)
141 | _save_data(os.path.join(path, 'test', 'all_data_niid_00_test.json'), test_data)
142 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...done splitting into training & test sets!')
143 | 
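The `defaultdict` used near the end of `preprocess` above maps any out-of-vocabulary token to the `<unk>` id without raising a `KeyError`. A toy illustration:

```python
from collections import defaultdict

vocab = defaultdict(lambda: 1)                    # 1 == unk_symbol
vocab.update({'<pad>': 0, '<unk>': 1, 'the': 4})
print(vocab['the'])               # -> 4
print(vocab['never-seen-token'])  # -> 1, falls back to <unk>
```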
-------------------------------------------------------------------------------- /src/datasets/leaf/preprocess/sent140.py: --------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import logging
5 | 
6 | import pandas as pd
7 | 
8 | from collections import defaultdict, OrderedDict
9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | pd.set_option('mode.chained_assignment', None)
13 | 
14 | 
15 | 
16 | def preprocess(root):
17 | RAW_TRAINING = 'training.1600000.processed.noemoticon.csv'
18 | RAW_TEST = 'testdata.manual.2009.06.14.csv'
19 | def _get_glove_vocab(path):
20 | # read GloVe embeddings (300-dim)
21 | lines = []
22 | with open(os.path.join(path, 'raw', 'glove.6B.300d.txt'), 'r') as file:
23 | lines = file.readlines()
24 | 
25 | # process embeddings
26 | lines = [l.split() for l in lines]
27 | vocab = [l[0] for l in lines]
28 | 
29 | # get word indices
30 | vocab_indices = defaultdict(int)
31 | for i, w in enumerate(vocab):
32 | vocab_indices[w] = i + 1
33 | 
34 | # get index - embedding map
35 | embs = [[float(n) for n in l[1:]] for l in lines]
36 | embs.insert(0, [0. 
for _ in range(300)]) # for unknown token 37 | 38 | # save into file 39 | with open(os.path.join(path, 'vocab', 'glove.6B.300d.json'), 'w') as file: 40 | json.dump(embs, file) 41 | return vocab_indices 42 | 43 | def _combine_data(path): 44 | raw_train = pd.read_csv( 45 | os.path.join(path, RAW_TRAINING), 46 | encoding='ISO-8859-1', 47 | header=None, 48 | names=['target', 'id', 'date', 'flag', 'user', 'text'], 49 | usecols=['target', 'user', 'text'], 50 | index_col='user' 51 | ) 52 | raw_test = pd.read_csv( 53 | os.path.join(path, RAW_TEST), 54 | encoding='ISO-8859-1', 55 | header=None, 56 | names=['target', 'id', 'date', 'flag', 'user', 'text'], 57 | usecols=['target', 'user', 'text'], 58 | index_col='user' 59 | ) 60 | raw_all = pd.concat([raw_train, raw_test]).sort_index(kind='mergesort') 61 | return raw_all 62 | 63 | def _convert_to_json(path, raw_all, indices): 64 | def _cleanse(df): 65 | # dictionary containing all emojis with their meanings. 66 | emojis = { 67 | ':)': 'smile', ':-)': 'smile', ';d': 'wink', ':-E': 'vampire', ':(': 'sad', 68 | ':-(': 'sad', ':-<': 'sad', ':P': 'raspberry', ':O': 'surprised', 69 | ':-@': 'shocked', ':@': 'shocked',':-$': 'confused', ':\\': 'annoyed', 70 | ':#': 'mute', ':X': 'mute', ':^)': 'smile', ':-&': 'confused', '$_$': 'greedy', 71 | '@@': 'eyeroll', ':-!': 'confused', ':-D': 'smile', ':-0': 'yell', 'O.o': 'confused', 72 | '<(-_-)>': 'robot', 'd[-_-]b': 'dj', ":'-)": 'sadsmile', ';)': 'wink', 73 | ';-)': 'wink', 'O:-)': 'angel','O*-)': 'angel','(:-D': 'gossip', '=^.^=': 'cat' 74 | } 75 | 76 | # replace mentions 77 | df = df.apply(lambda x: re.sub(r"@[^\s]+", "USER", str(x).strip())) 78 | 79 | # replace links 80 | df = df.apply(lambda x: re.sub(r"((http://)[^ ]*|(https://)[^ ]*|( www\.)[^ ]*)", "URL", str(x).strip())) 81 | 82 | # remove non-alphabetical characters 83 | df = df.apply(lambda x: re.sub(r"[^A-Za-z0-9]+", " ", str(x).strip())) 84 | 85 | # remove numbers 86 | df = df.apply(lambda x: re.sub(r"[0-9]+", " ", str(x).strip())) 87 | 88 | # remove ordinals 89 | df = df.apply(lambda x: re.sub(r"[0-9]+(?:st| st|nd| nd|rd| rd|th| th)" , " ", str(x).strip())) 90 | 91 | # remove punctuations 92 | df = df.str.replace(r'[^\w\s]', '', regex=True) 93 | 94 | # replace 3 or more consecutive letters by 2 letter 95 | df = df.apply(lambda x: re.sub(r"(.)\1\1+", r"\1\1", str(x).strip())) 96 | 97 | # replace emojis 98 | for emoji, name in emojis.items(): 99 | df = df.str.replace(emoji, name, regex=False) 100 | 101 | # normalize whitespaces 102 | df = df.str.replace(r'\s+', ' ', regex=True) 103 | 104 | # to lowercase 105 | df = df.str.lower() 106 | return df 107 | 108 | def _split_line(line): 109 | """Split given line/phrase into list of words 110 | """ 111 | return re.findall(r"[\w']+|[.,!?;]", line) 112 | 113 | def _line_to_indices(line, word2id, max_words=25): 114 | """Converts given phrase into list of word indices. 115 | 116 | - If the phrase has more than `max_words` words, 117 | returns a list containing indices of the first `max_words` words. 118 | 119 | - If the phrase has less than `max_words` words, 120 | repeatedly appends integer representing padding index to returned list 121 | until the list's length is `max_words`. 
122 | 123 | Args: 124 | line: string representing phrase/sequence of words 125 | word2id: dictionary with string words as keys and int indices as values 126 | max_words: maximum number of word indices in returned list 127 | 128 | Returns: 129 | indices: list of word indices, one index for each word in phrase 130 | """ 131 | unk_id = 0 132 | line_list = _split_line(line) # split phrase into words 133 | indices = [word2id[w] for w in line_list[:max_words]] 134 | indices += [unk_id] * (max_words - len(indices)) 135 | return indices 136 | 137 | # convert user ID into digits 138 | user_id_map = {str_id: int_id for int_id, str_id in enumerate(raw_all.index.unique().tolist())} 139 | curr_ids = raw_all.index.to_series() 140 | raw_all.index = curr_ids.map(user_id_map) 141 | 142 | # refine raw data 143 | raw_all.loc[:, 'target'].replace({4: 1, 2: 0}, inplace=True) 144 | raw_all.loc[:, 'text'] = _cleanse(raw_all.loc[:, 'text']) 145 | raw_all.loc[:, 'text'] = raw_all['text'].apply(lambda x: _line_to_indices(x, indices)) 146 | raw_all = raw_all.reset_index().groupby('user').agg({'text': lambda x: [i for i in x], 'target': lambda y: [l for l in y]}).rename(columns={'text': 'x', 'target': 'y'}) 147 | raw_all.index = raw_all.index.astype(str) 148 | 149 | # get required elements 150 | users = raw_all.index.tolist() 151 | num_samples = raw_all['y'].apply(len).values.tolist() 152 | user_data = raw_all.T.to_dict() 153 | 154 | # create json file 155 | all_data = OrderedDict() 156 | all_data['users'] = users 157 | all_data['num_samples'] = num_samples 158 | all_data['user_data'] = user_data 159 | 160 | # save file 161 | with open(os.path.join(path, 'all_data', 'all_data.json'), 'w') as outfile: 162 | json.dump(all_data, outfile) 163 | 164 | # set path 165 | DATASET_NAME = __file__.split('/')[-1].split('.')[0] 166 | path = os.path.join(os.path.expanduser(root), DATASET_NAME) 167 | 168 | # check if preprocessing has already been done 169 | if not os.path.exists(os.path.join(path, 'all_data')): 170 | os.makedirs(os.path.join(path, 'all_data')) 171 | else: 172 | return 173 | 174 | # make path for GloVe vocabulary 175 | if not os.path.exists(os.path.join(path, 'vocab')): 176 | os.makedirs(os.path.join(path, 'vocab')) 177 | 178 | # get GloVe vocabulary 179 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Process GloVe embeddings (300 dim.)...!') 180 | glove_vocab_indices = _get_glove_vocab(path) 181 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished processing GloVe embeddings!') 182 | 183 | # combine raw data 184 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Combine raw data...!') 185 | raw_all = _combine_data(os.path.join(path, 'raw')) 186 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished combining raw data!') 187 | 188 | # convert to json format 189 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Convert data to json format...
(this may take several minutes)!') 190 | _convert_to_json(path, raw_all, glove_vocab_indices) 191 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished converting data to json format!') 192 | -------------------------------------------------------------------------------- /src/datasets/leafparser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import sys 4 | import json 5 | import torch 6 | import logging 7 | import importlib 8 | import concurrent.futures 9 | 10 | from abc import abstractmethod 11 | 12 | from src import TqdmToLogger 13 | from src.datasets.leaf import * 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | 19 | class LEAFDataset(torch.utils.data.Dataset): 20 | """Base dataset class for LEAF benchmark dataset. 21 | """ 22 | def __init__(self): 23 | super(LEAFDataset, self).__init__() 24 | self.identifier = None 25 | self.num_samples = 0 26 | 27 | @abstractmethod 28 | def make_dataset(self): 29 | raise NotImplementedError 30 | 31 | @abstractmethod 32 | def __getitem__(self, index): 33 | raise NotImplementedError 34 | 35 | def __len__(self): 36 | return self.num_samples 37 | 38 | def __repr__(self): 39 | return str(self.identifier) 40 | 41 | # LEAF - FEMNIST 42 | class FEMNIST(LEAFDataset): 43 | def __init__(self, in_channels, num_classes, transform=None): 44 | super(FEMNIST, self).__init__() 45 | self.in_channels = in_channels 46 | self.transform = transform 47 | self.num_classes = num_classes 48 | 49 | def _process(self, raw_path): 50 | inputs = PIL.Image.open(raw_path).convert('L') 51 | return inputs 52 | 53 | def make_dataset(self): 54 | inputs, targets = self.data['x'], self.data['y'] 55 | self.inputs = [raw_path for raw_path in inputs] 56 | self.targets = torch.tensor(targets).long() 57 | self.num_samples = len(self.inputs) 58 | 59 | def __getitem__(self, index): 60 | inputs, targets = self._process(self.inputs[index]), self.targets[index] 61 | if self.transform is not None: 62 | inputs = self.transform(inputs) 63 | return inputs, targets 64 | 65 | # LEAF - Shakespeare 66 | class Shakespeare(LEAFDataset): 67 | def __init__(self, num_embeddings, num_classes): 68 | super(Shakespeare, self).__init__() 69 | self.num_embeddings = num_embeddings 70 | self.num_classes = num_classes 71 | 72 | def make_dataset(self): 73 | self.inputs, self.targets = torch.tensor(self.data['x']), torch.tensor(self.data['y']) 74 | self.num_samples = len(self.inputs) 75 | 76 | def __getitem__(self, index): 77 | return self.inputs[index], self.targets[index] 78 | 79 | # LEAF - Sent140 80 | class Sent140(LEAFDataset): 81 | def __init__(self, num_embeddings, seq_len, num_classes): 82 | super(Sent140, self).__init__() 83 | self.num_embeddings = num_embeddings 84 | self.seq_len = seq_len 85 | self.num_classes = num_classes 86 | 87 | def make_dataset(self): 88 | self.inputs, self.targets = torch.tensor(self.data['x']).long(), torch.tensor(self.data['y']).long() 89 | self.num_samples = len(self.inputs) 90 | 91 | def __getitem__(self, index): 92 | return self.inputs[index], self.targets[index] 93 | 94 | # LEAF - CelebA 95 | class CelebA(LEAFDataset): 96 | def __init__(self, in_channels, img_path, num_classes, transform=None): 97 | super(CelebA, self).__init__() 98 | self.in_channels = in_channels 99 | self.img_path = img_path 100 | self.num_classes = num_classes 101 | self.transform = transform 102 | 103 | def _process(self, path): 104 | inputs = PIL.Image.open(os.path.join(self.img_path, path)).convert('RGB') 105 | 
return inputs 106 | 107 | def make_dataset(self): 108 | inputs, targets = self.data['x'], self.data['y'] 109 | self.inputs = [fname for fname in inputs] 110 | self.targets = torch.tensor(targets).long() 111 | self.num_samples = len(self.inputs) 112 | 113 | def __getitem__(self, index): 114 | inputs, targets = self._process(self.inputs[index]), self.targets[index] 115 | if self.transform is not None: 116 | inputs = self.transform(inputs) 117 | return inputs, targets 118 | 119 | # LEAF - Reddit 120 | class Reddit(LEAFDataset): 121 | def __init__(self, num_embeddings, seq_len, num_classes): 122 | super(Reddit, self).__init__() 123 | self.num_embeddings = num_embeddings 124 | self.seq_len = seq_len 125 | self.num_classes = num_classes 126 | 127 | def make_dataset(self): 128 | self.inputs, self.targets = torch.tensor(self.data['x']).squeeze(1), torch.tensor(self.data['y']).squeeze(1) 129 | self.num_samples = len(self.inputs) 130 | 131 | def __getitem__(self, index): 132 | return self.inputs[index], self.targets[index] 133 | 134 | def fetch_leaf(args, dataset_name, root, seed, raw_data_fraction, test_size, transforms): 135 | CONFIG = { 136 | 'femnist': {'in_channels': 1, 'num_classes': 62}, 137 | 'shakespeare': {'num_embeddings': 80, 'num_classes': 80}, 138 | 'sent140': {'num_embeddings': 400000 + 1 if args.model_name == 'Sent140LSTM' else None, 'seq_len': 25, 'num_classes': 2}, # using GloVe 300-dim embeddings 139 | 'celeba': {'in_channels': 3, 'img_path': f'{root}/celeba/raw/img_align_celeba', 'num_classes': 2}, 140 | 'reddit': {'num_embeddings': 10000, 'seq_len': 10, 'num_classes': 10000} # + 1 for an unknown token 141 | } 142 | 143 | def _load_processed(path, mode): 144 | file = os.listdir(os.path.join(path, mode))[0] 145 | with open(os.path.join(path, mode, file), 'r') as f: 146 | proc = json.load(f) 147 | return proc 148 | 149 | def _assign_to_clients(dataset_name, dataset_class, raw_train, raw_test, transforms): 150 | def _construct_dataset(idx, user): 151 | # instantiate module for each training set and test set 152 | tr_dset, te_dset = dataset_class(**CONFIG[dataset_name]), dataset_class(**CONFIG[dataset_name]) 153 | 154 | # set essential attributes for training 155 | tr_dset.identifier = f'[LOAD] [{dataset_name.upper()}] CLIENT < {str(user).zfill(8)} > (train)' 156 | tr_dset.data = raw_train['user_data'][user] 157 | tr_dset.make_dataset() 158 | 159 | # set essential attributes for test 160 | te_dset.identifier = f'[LOAD] [{dataset_name.upper()}] CLIENT < {str(user).zfill(8)} > (test)' 161 | te_dset.data = raw_test['user_data'][user] 162 | te_dset.make_dataset() 163 | 164 | # transplant transform method 165 | tr_dset.transform = transforms[0] 166 | te_dset.transform = transforms[1] 167 | return (tr_dset, te_dset) 168 | 169 | datasets = [] 170 | with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count() - 1) as workhorse: 171 | for idx, user in TqdmToLogger( 172 | enumerate(raw_train['users']), 173 | logger=logger, 174 | desc=f'[LOAD] [LEAF - {dataset_name.upper()}] ...assigning... 
', 175 | total=len(raw_train['users']) 176 | ): 177 | datasets.append(workhorse.submit(_construct_dataset, idx, user).result()) 178 | return datasets 179 | 180 | # retrieve appropriate dataset module 181 | dataset_class = getattr(sys.modules[__name__], dataset_name) 182 | 183 | # download data 184 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Check if raw data exists; if not, start downloading!') 185 | if not os.path.exists(f'{root}/{dataset_name.lower()}/raw'): 186 | os.makedirs(f'{root}/{dataset_name.lower()}/raw') 187 | download_data(download_root=f'{root}/{dataset_name.lower()}/raw', dataset_name=dataset_name.lower()) 188 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...raw data is successfully downloaded!') 189 | else: 190 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...raw data already exists!') 191 | 192 | # pre-process raw data (fetch all raw data into json format) 193 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Pre-process raw data into json format!') 194 | importlib.import_module(f'.leaf.preprocess.{dataset_name.lower()}', package=__package__).preprocess(root=root) 195 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done pre-processing raw data into json format!') 196 | 197 | # post-process raw data (split data) 198 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Post-process raw data to be split into train & test!') 199 | args.num_clients = postprocess_leaf(dataset_name.lower(), root, seed, raw_data_fraction=raw_data_fraction, min_samples_per_clients=0, test_size=test_size) 200 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done post-processing raw data into train & test splits!') 201 | 202 | # get raw data 203 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Load training & test datasets...!') 204 | raw_train = _load_processed(os.path.join(root, dataset_name.lower()), 'train') 205 | raw_test = _load_processed(os.path.join(root, dataset_name.lower()), 'test') 206 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done parsing training & test datasets!') 207 | 208 | # make dataset for each client 209 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Instantiate client datasets and create split hashmap...!') 210 | client_datasets = _assign_to_clients(dataset_name.lower(), dataset_class, raw_train, raw_test, transforms) 211 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...instantiated client datasets and created split hashmap!') 212 | 213 | # adjust arguments 214 | args.num_classes = CONFIG[dataset_name.lower()]['num_classes'] 215 | args.K = len(client_datasets) 216 | if 'in_channels' in CONFIG[dataset_name.lower()].keys(): 217 | args.in_channels = CONFIG[dataset_name.lower()]['in_channels'] 218 | if 'seq_len' in CONFIG[dataset_name.lower()].keys(): 219 | args.seq_len = CONFIG[dataset_name.lower()]['seq_len'] 220 | if 'num_embeddings' in CONFIG[dataset_name.lower()].keys(): 221 | args.num_embeddings = CONFIG[dataset_name.lower()]['num_embeddings'] 222 | 223 | # return empty split map, per-client datasets, and adjusted arguments 224 | return {}, client_datasets, args 225 |
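# NOTE: a hedged usage sketch of `fetch_leaf` — it returns an empty server-side
# split map, one (train, test) dataset pair per client, and the adjusted `args`;
# the namespace fields and transform pair below are illustrative assumptions:
#
#   import argparse
#   import torchvision.transforms as T
#   args = argparse.Namespace(model_name='FEMNISTCNN')
#   _, client_datasets, args = fetch_leaf(
#       args, dataset_name='FEMNIST', root='./data', seed=42,
#       raw_data_fraction=0.05, test_size=0.2,
#       transforms=(T.ToTensor(), T.ToTensor())
#   )
#   tr_dset, te_dset = client_datasets[0]  # datasets of the first client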
-------------------------------------------------------------------------------- /src/datasets/speechcommands.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import torchaudio 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | LABELS = sorted([ 10 | 'on', 'learn', 'tree', 'down', 'forward', 11 | 'backward', 'happy', 'off', 'nine', 'eight', 12 | 'left', 'four', 'one', 'visual',
'sheila', 13 | 'no', 'six', 'dog', 'up', 'five', 14 | 'marvin', 'cat', 'yes', 'zero', 'house', 15 | 'bird', 'go', 'seven', 'stop', 'wow', 16 | 'three', 'follow', 'right', 'bed', 'two' 17 | ]) 18 | 19 | # dataset wrapper module 20 | class AudioClassificationDataset(torch.utils.data.Dataset): 21 | def __init__(self, dataset, dataset_name, suffix): 22 | self.dataset = dataset 23 | self.dataset_name = dataset_name 24 | self.suffix = suffix 25 | self.targets = self.dataset.targets 26 | 27 | def __getitem__(self, index): 28 | def _label_to_index(word): 29 | # Return the position of the word in labels 30 | return torch.tensor(LABELS.index(word)) 31 | 32 | def _pad_sequence(batch, max_len=16000): 33 | # Make all tensors in a batch the same length by padding with zeros 34 | batch = [torch.nn.functional.pad(item.t(), (0, 0, 0, max_len - len(item.t())), value=0.) for item in batch] 35 | batch = torch.cat(batch) 36 | return batch.t() 37 | 38 | # get raw batch by index 39 | batch = self.dataset[index] 40 | 41 | # gather in lists, and encode labels as indices 42 | inputs, targets = [], [] 43 | for waveform, _, label, *_ in (batch,): 44 | inputs += [waveform] 45 | targets += [_label_to_index(label)] 46 | 47 | # group the list of tensors into a batched tensor 48 | inputs = _pad_sequence(inputs) 49 | targets = torch.stack(targets).squeeze() 50 | return inputs, targets 51 | 52 | def __len__(self): 53 | return len(self.dataset) 54 | 55 | def __repr__(self): 56 | return f'[{self.dataset_name}] {self.suffix}' 57 | 58 | class SpeechCommands(torchaudio.datasets.SPEECHCOMMANDS): 59 | def __init__(self, root, split, download): 60 | self.data_root = os.path.expanduser(root) 61 | super(SpeechCommands, self).__init__(root=self.data_root, subset=split, download=download) 62 | self.split = split 63 | def __repr__(self): 64 | rep_str = {'training': 'CLIENT', 'testing': 'SERVER'} 65 | return f'[SpeechCommands] {rep_str[self.split]}' 66 | 67 | # helper method to fetch SpeechCommands dataset 68 | def fetch_speechcommands(args, root): 69 | logger.info('[LOAD] [SpeechCommands] Fetching dataset!') 70 | 71 | # default arguments 72 | DEFAULT_ARGS = {'root': root, 'download': True} 73 | 74 | # configure arguments for training/test dataset 75 | train_args = DEFAULT_ARGS.copy() 76 | train_args['split'] = 'training' 77 | 78 | # create training dataset instance 79 | raw_train = SpeechCommands(**train_args) 80 | train_targets = torch.tensor([LABELS.index(filename.split('/')[3]) for filename in raw_train._walker]).long() 81 | setattr(raw_train, 'targets', train_targets) 82 | raw_train = AudioClassificationDataset(raw_train, 'SpeechCommands'.upper(), 'CLIENT') 83 | 84 | # for global holdout set 85 | test_args = DEFAULT_ARGS.copy() 86 | test_args['split'] = 'testing' 87 | 88 | # create test dataset instance 89 | raw_test = SpeechCommands(**test_args) 90 | test_targets = torch.tensor([LABELS.index(filename.split('/')[3]) for filename in raw_test._walker]).long() 91 | setattr(raw_test, 'targets', test_targets) 92 | raw_test = AudioClassificationDataset(raw_test, 'SpeechCommands'.upper(), 'SERVER') 93 | logger.info('[LOAD] [SpeechCommands] ...fetched dataset!') 94 | 95 | # adjust arguments 96 | args.in_channels = 1 97 | args.num_classes = len(torch.unique(torch.as_tensor(raw_train.dataset.targets))) 98 | return raw_train, raw_test, args 99 | 100 | 101 |
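# NOTE: a minimal sketch of the label-encoding and padding conventions used by
# the wrapper above — labels become positions in the sorted LABELS list, and
# every waveform is zero-padded to 16000 samples (one second at 16 kHz):
#
#   import torch
#   label_idx = torch.tensor(LABELS.index('yes'))   # class index in [0, 34]
#   waveform = torch.randn(1, 12800)                # dummy 0.8-second clip
#   padded = torch.nn.functional.pad(waveform, (0, 16000 - waveform.shape[-1]))
#   assert padded.shape == (1, 16000)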
-------------------------------------------------------------------------------- /src/datasets/tinyimagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import
torch 3 | import shutil 4 | import logging 5 | import torchvision 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | 11 | # dataset wrapper module 12 | class TinyImageNet(torchvision.datasets.ImageFolder): 13 | base_folder = 'tiny-imagenet-200' 14 | zip_md5 = '90528d7ca1a48142e341f4ef8d21d0de' 15 | splits = ('train', 'val', 'test') 16 | filename = 'tiny-imagenet-200.zip' 17 | url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip' 18 | 19 | def __init__(self, root, split='train', download=True, transform=None, **kwargs): 20 | self.data_root = os.path.expanduser(root) 21 | self.split = torchvision.datasets.utils.verify_str_arg(split, 'split', self.splits) 22 | if download: 23 | self.download() 24 | if not self._check_exists(): 25 | err = 'Dataset not found or corrupted. You can use download=True to download it' 26 | logger.exception(err) 27 | raise RuntimeError(err) 28 | super().__init__(root=self.split_folder, transform=transform, **kwargs) 29 | 30 | def normalize_tin_val_folder_structure(self, path, images_folder='images', annotations_file='val_annotations.txt'): 31 | images_folder = os.path.join(path, images_folder) 32 | annotations_file = os.path.join(path, annotations_file) 33 | 34 | # check if file exists; if neither remains, the folder is already normalized 35 | if not os.path.exists(images_folder) and not os.path.exists(annotations_file): 36 | if not os.listdir(path): 37 | err = 'Validation folder is empty!' 38 | logger.exception(err) 39 | raise RuntimeError(err) 40 | return 41 | # parse the annotations 42 | with open(annotations_file) as f: 43 | for line in f: 44 | values = line.split() 45 | img, label = values[:2] 46 | 47 | img_file = os.path.join(images_folder, img) 48 | label_folder = os.path.join(path, label) 49 | 50 | os.makedirs(label_folder, exist_ok=True) 51 | try: 52 | shutil.move(img_file, os.path.join(label_folder, img)) 53 | except FileNotFoundError: 54 | continue 55 | if os.listdir(images_folder): 56 | raise AssertionError 57 | shutil.rmtree(images_folder) 58 | os.remove(annotations_file) 59 | 60 | @property 61 | def dataset_folder(self): 62 | return os.path.join(self.data_root, self.base_folder) 63 | 64 | @property 65 | def split_folder(self): 66 | return os.path.join(self.dataset_folder, self.split) 67 | 68 | def _check_exists(self): 69 | return os.path.exists(self.split_folder) 70 | 71 | def download(self): 72 | if self._check_exists(): 73 | return 74 | torchvision.datasets.utils.download_and_extract_archive( 75 | self.url, self.data_root, filename=self.filename, 76 | remove_finished=True, md5=self.zip_md5 77 | ) 78 | assert 'val' in self.splits 79 | self.normalize_tin_val_folder_structure(os.path.join(self.dataset_folder, 'val')) 80 | 81 | def __repr__(self): 82 | rep_str = {'train': 'CLIENT', 'test': 'SERVER'} 83 | return f'[TinyImageNet] {rep_str[self.split]}' 84 |
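# NOTE: `normalize_tin_val_folder_structure` rearranges the flat validation
# folder shipped in the archive into the per-class layout expected by
# `torchvision.datasets.ImageFolder`; file and class names below are examples:
#
#   val/images/val_0.JPEG + val/val_annotations.txt ('val_0.JPEG n01443537 ...')
#     --> val/n01443537/val_0.JPEG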
85 | # helper method to fetch Tiny ImageNet dataset 86 | def fetch_tinyimagenet(args, root, transforms): 87 | logger.info('[LOAD] [TINYIMAGENET] Fetching dataset!') 88 | 89 | # default arguments 90 | DEFAULT_ARGS = {'root': root, 'transform': None, 'download': True} 91 | 92 | # configure arguments for training/test dataset 93 | train_args = DEFAULT_ARGS.copy() 94 | train_args['split'] = 'train' 95 | train_args['transform'] = transforms[0] 96 | 97 | # create training dataset instance 98 | raw_train = TinyImageNet(**train_args) 99 | 100 | # for global holdout set 101 | test_args = DEFAULT_ARGS.copy() 102 | test_args['transform'] = transforms[1] 103 | test_args['split'] = 'test' 104 | 105 | # create test dataset instance 106 | raw_test =
TinyImageNet(**test_args) 107 | 108 | logger.info('[LOAD] [TINYIMAGENET] ...fetched dataset!') 109 | 110 | # adjust arguments 111 | args.in_channels = 3 112 | args.num_classes = len(torch.unique(torch.as_tensor(raw_train.targets))) 113 | return raw_train, raw_test, args 114 | -------------------------------------------------------------------------------- /src/datasets/torchtextparser.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import sys 4 | import csv 5 | import torch 6 | import logging 7 | import torchtext 8 | 9 | from src import TqdmToLogger 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | 15 | # dataset wrapper module 16 | class TextClassificationDataset(torch.utils.data.Dataset): 17 | def __init__(self, dataset_name, inputs, targets): 18 | self.identifier = dataset_name 19 | self.inputs = inputs 20 | self.targets = targets 21 | 22 | def __len__(self): 23 | return len(self.inputs) 24 | 25 | def __getitem__(self, index): 26 | inputs = self.inputs[index] 27 | targets = self.targets[index] 28 | return inputs, targets 29 | 30 | def __repr__(self): 31 | return str(self.identifier) 32 | 33 | # helper method to fetch dataset from `torchtext.datasets` 34 | def fetch_torchtext_dataset(args, dataset_name, root, tokenizer, seq_len, num_embeddings): 35 | URL = { 36 | 'AG_NEWS': 'https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz', 37 | 'SogouNews': 'https://s3.amazonaws.com/fast-ai-nlp/sogou_news_csv.tgz', 38 | 'DBpedia': 'https://s3.amazonaws.com/fast-ai-nlp/dbpedia_csv.tgz', 39 | 'YelpReviewPolarity': 'https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz', 40 | 'YelpReviewFull': 'https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz', 41 | 'YahooAnswers': 'https://s3.amazonaws.com/fast-ai-nlp/yahoo_answers_csv.tgz', 42 | 'AmazonReviewPolarity': 'https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz', 43 | 'AmazonReviewFull': 'https://s3.amazonaws.com/fast-ai-nlp/amazon_review_full_csv.tgz' 44 | } 45 | MD5 = { 46 | 'AG_NEWS': '2c2d85915f1ca34b29e754ce3b403c81', 47 | 'SogouNews': '45c19a17716907a7a3bde8c7398f7542', 48 | 'DBpedia': '531dab845dd933d0d7b04e82bffe6d96', 49 | 'YelpReviewPolarity': '0f09b3af1a79c136ef9ca5f29df9ed9a', 50 | 'YelpReviewFull': 'a4acce1892d0f927165488f62860cabe', 51 | 'YahooAnswers': '672a634b0a8a1314138321e7d075a64e', 52 | 'AmazonReviewPolarity': 'ee221fbfc3bf7d49dd7dee8064b2404c', 53 | 'AmazonReviewFull': '289723722b64d337a40f809edd29d0f0' 54 | } 55 | NUM_CLASSES = { 56 | 'AG_NEWS': 4, 57 | 'SogouNews': 5, 58 | 'DBpedia': 14, 59 | 'YelpReviewPolarity': 2, 60 | 'YelpReviewFull': 5, 61 | 'YahooAnswers': 10, 62 | 'AmazonReviewPolarity': 2, 63 | 'AmazonReviewFull': 5 64 | } 65 | 66 | if dataset_name not in URL.keys(): 67 | err = f'Dataset ({dataset_name}) is not supported!' 68 | logger.exception(err) 69 | raise Exception(err) 70 |
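# NOTE: a hedged usage sketch — when `tokenizer` is None, the parser builds its
# own torchtext vocabulary (see `_create_data_from_iterator` below) and returns
# fixed-length index sequences; the namespace below is an illustrative stub:
#
#   import argparse
#   args = argparse.Namespace()
#   tr_set, te_set, args = fetch_torchtext_dataset(
#       args, dataset_name='AG_NEWS', root='./data',
#       tokenizer=None, seq_len=128, num_embeddings=20000
#   )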
71 | def _unicode_csv_reader(unicode_csv_data, **kwargs): 72 | maxInt = sys.maxsize 73 | while True: 74 | try: 75 | csv.field_size_limit(maxInt) 76 | break 77 | except OverflowError: 78 | maxInt = int(maxInt / 10) 79 | csv.field_size_limit(maxInt) 80 | 81 | for line in csv.reader(unicode_csv_data, **kwargs): 82 | yield line 83 | 84 | def _csv_iterator(data_path, yield_cls=False): 85 | tokenizer = torchtext.data.utils.get_tokenizer('basic_english') 86 | with io.open(data_path, encoding='utf8') as f: 87 | reader = _unicode_csv_reader(f) 88 | for row in reader: 89 | tokens = ' '.join(row[1:]) 90 | tokens = tokenizer(tokens) 91 | if yield_cls: 92 | yield int(row[0]) - 1, torchtext.data.utils.ngrams_iterator(tokens, ngrams=1) 93 | else: 94 | yield torchtext.data.utils.ngrams_iterator(tokens, ngrams=1) 95 | 96 | def _create_data_from_iterator(vocab, iterator, max_len): 97 | inputs, targets = [], [] 98 | for label, tokens in TqdmToLogger(iterator, logger=logger, desc=f'[LOAD] [{dataset_name.upper()}] ...prepare raw data!'): 99 | tokens = [vocab[token] for token in tokens] 100 | # pad tokens to have max length 101 | pad_len = max_len - len(tokens) % max_len 102 | if pad_len > 0: 103 | tokens.extend([vocab['<pad>'] for _ in range(pad_len)]) 104 | 105 | # slice tokens up to max length 106 | tokens = tokens[:max_len] 107 | 108 | # collect processed pairs 109 | inputs.append(tokens) 110 | targets.append(label) 111 | return torch.tensor(inputs).long(), torch.tensor(targets).long() 112 | 113 | def _create_data_from_tokenizer(tokenizer, iterator, max_len): 114 | inputs, targets = [], [] 115 | for label, tokens in TqdmToLogger(iterator, logger=logger, desc=f'[LOAD] [{dataset_name.upper()}] ...prepare raw data!'): 116 | tokens = tokenizer( 117 | list(tokens), 118 | return_tensors='pt', 119 | is_split_into_words=True, 120 | max_length=max_len, 121 | return_attention_mask=False, 122 | truncation=True, 123 | padding='max_length' 124 | )['input_ids'] 125 | 126 | inputs.append(*tokens) 127 | targets.append(label) 128 | return inputs, targets 129 | 130 | # download files 131 | logger.info(f'[LOAD] [{dataset_name.upper()}] Start downloading files!') 132 | root = os.path.expanduser(root) 133 | raw_files = torchtext.utils.download_from_url( 134 | url=URL[dataset_name], 135 | root=root, 136 | hash_value=MD5[dataset_name], 137 | hash_type='md5' 138 | ) 139 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...downloaded files!') 140 | 141 | 142 | # extract archive 143 | logger.info(f'[LOAD] [{dataset_name.upper()}] Extract archived files!') 144 | raw_files = torchtext.utils.extract_archive(raw_files) 145 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...successfully extracted archived files!') 146 | 147 | # retrieve split files 148 | for fname in raw_files: 149 | if fname.endswith('train.csv'): 150 | train_csv_path = fname 151 | if fname.endswith('test.csv'): 152 | test_csv_path = fname 153 | 154 | # build vocabularies using training set 155 | if tokenizer is None: 156 | logger.info(f'[LOAD] [{dataset_name.upper()}] Build vocabularies!') 157 | vocab = torchtext.vocab.build_vocab_from_iterator(_csv_iterator(train_csv_path), specials=['<unk>'], max_tokens=num_embeddings) 158 | vocab.set_default_index(vocab['<unk>']) 159 | vocab.vocab.insert_token('<pad>', 0) 160 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...vocabularies are built!') 161 | 162 | # tokenize training & test data and prepare inputs/targets 163 | logger.info(f'[LOAD] [{dataset_name.upper()}] Create training & test
set!') 164 | if tokenizer is None: 165 | tr_inputs, tr_targets = _create_data_from_iterator(vocab, _csv_iterator(train_csv_path, yield_cls=True), seq_len) 166 | te_inputs, te_targets = _create_data_from_iterator(vocab, _csv_iterator(test_csv_path, yield_cls=True), seq_len) 167 | else: 168 | tr_inputs, tr_targets = _create_data_from_tokenizer(tokenizer, _csv_iterator(train_csv_path, yield_cls=True), seq_len) 169 | te_inputs, te_targets = _create_data_from_tokenizer(tokenizer, _csv_iterator(test_csv_path, yield_cls=True), seq_len) 170 | 171 | # adjust labels 172 | min_label_tr, min_label_te = min(tr_targets), min(te_targets) 173 | tr_targets = torch.tensor([l - min_label_tr for l in tr_targets]).long() 174 | te_targets = torch.tensor([l - min_label_te for l in te_targets]).long() 175 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...created training & test set!') 176 | 177 | # adjust arguments 178 | args.num_embeddings = len(vocab) + 1 if tokenizer is None else tokenizer.vocab_size 179 | args.num_classes = NUM_CLASSES[dataset_name] 180 | return TextClassificationDataset(f'[{dataset_name}] CLIENT', tr_inputs, tr_targets), TextClassificationDataset(f'[{dataset_name}] SERVER', te_inputs, te_targets), args 181 | -------------------------------------------------------------------------------- /src/datasets/torchvisionparser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import torchvision 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | 9 | # dataset wrapper module 10 | class VisionClassificationDataset(torch.utils.data.Dataset): 11 | def __init__(self, dataset, dataset_name, suffix): 12 | self.dataset = dataset 13 | self.dataset_name = dataset_name 14 | self.suffix = suffix 15 | self.targets = self.dataset.targets 16 | 17 | def __getitem__(self, index): 18 | inputs, targets = self.dataset[index] 19 | return inputs, targets 20 | 21 | def __len__(self): 22 | return len(self.dataset) 23 | 24 | def __repr__(self): 25 | return f'[{self.dataset_name}] {self.suffix}' 26 | 27 | # helper method to fetch dataset from `torchvision.datasets` 28 | def fetch_torchvision_dataset(args, dataset_name, root, transforms): 29 | logger.info(f'[LOAD] [{dataset_name.upper()}] Fetching dataset!') 30 | 31 | # default arguments 32 | DEFAULT_ARGS = {'root': root, 'transform': None, 'download': True} 33 | 34 | if dataset_name in [ 35 | 'MNIST', 'FashionMNIST', 'QMNIST', 'KMNIST', 'EMNIST',\ 36 | 'CIFAR10', 'CIFAR100', 'USPS' 37 | ]: 38 | # configure arguments for training/test dataset 39 | train_args = DEFAULT_ARGS.copy() 40 | train_args['train'] = True 41 | train_args['transform'] = transforms[0] 42 | 43 | # special case - EMNIST 44 | if dataset_name == 'EMNIST': 45 | train_args['split'] = 'byclass' 46 | 47 | # create training dataset instance 48 | raw_train = torchvision.datasets.__dict__[dataset_name](**train_args) 49 | raw_train = VisionClassificationDataset(raw_train, dataset_name.upper(), 'CLIENT') 50 | 51 | # for global holdout set 52 | test_args = DEFAULT_ARGS.copy() 53 | test_args['transform'] = transforms[1] 54 | test_args['train'] = False 55 | 56 | # special case - EMNIST 57 | if dataset_name == 'EMNIST': 58 | test_args['split'] = 'byclass' 59 | 60 | # create test dataset instance 61 | raw_test = torchvision.datasets.__dict__[dataset_name](**test_args) 62 | raw_test = VisionClassificationDataset(raw_test, dataset_name.upper(), 'SERVER') 63 | 64 | # adjust arguments 65 | if 'CIFAR' in dataset_name: 66 | args.in_channels = 3 67 | 
else: 68 | args.in_channels = 1 69 | 70 | elif dataset_name in [ 71 | 'Country211',\ 72 | 'DTD', 'Flowers102', 'Food101', 'FGVCAircraft',\ 73 | 'GTSRB', 'RenderedSST2', 'StanfordCars',\ 74 | 'STL10', 'SVHN' 75 | ]: 76 | # configure arguments for training/test dataset 77 | train_args = DEFAULT_ARGS.copy() 78 | train_args['split'] = 'train' 79 | train_args['transform'] = transforms[0] 80 | 81 | # create training dataset instance 82 | raw_train = torchvision.datasets.__dict__[dataset_name](**train_args) 83 | 84 | # for global holdout set 85 | test_args = DEFAULT_ARGS.copy() 86 | test_args['transform'] = transforms[1] 87 | test_args['split'] = 'test' 88 | 89 | # create test dataset instance 90 | raw_test = torchvision.datasets.__dict__[dataset_name](**test_args) 91 | 92 | # for compatibility, create attribute `targets` 93 | if dataset_name in ['DTD', 'Flowers102', 'Food101', 'FGVCAircraft']: 94 | setattr(raw_train, 'targets', raw_train._labels) 95 | setattr(raw_test, 'targets', raw_test._labels) 96 | elif dataset_name in ['GTSRB', 'RenderedSST2', 'StanfordCars']: 97 | setattr(raw_train, 'targets', [*list(zip(*raw_train._samples))[-1]]) 98 | setattr(raw_test, 'targets', [*list(zip(*raw_test._samples))[-1]]) 99 | elif dataset_name in ['STL10', 'SVHN']: 100 | setattr(raw_train, 'targets', raw_train.labels) 101 | setattr(raw_test, 'targets', raw_test.labels) 102 | 103 | # set raw datasets 104 | raw_train = VisionClassificationDataset(raw_train, dataset_name.upper(), 'CLIENT') 105 | raw_test = VisionClassificationDataset(raw_test, dataset_name.upper(), 'SERVER') 106 | 107 | # adjust arguments 108 | if 'RenderedSST2' in dataset_name: 109 | args.in_channels = 1 110 | else: 111 | args.in_channels = 3 112 | 113 | elif dataset_name in ['Places365', 'INaturalist', 'OxfordIIITPet', 'Omniglot']: 114 | # configure arguments for training/test dataset 115 | train_args = DEFAULT_ARGS.copy() 116 | train_args['transform'] = transforms[0] 117 | 118 | if dataset_name == 'Places365': 119 | train_args['split'] = 'train-standard' 120 | elif dataset_name == 'OxfordIIITPet': 121 | train_args['split'] = 'trainval' 122 | elif dataset_name == 'INaturalist': 123 | train_args['version'] = '2021_train_mini' 124 | elif dataset_name == 'Omniglot': 125 | train_args['background'] = True 126 | 127 | # create training dataset instance 128 | raw_train = torchvision.datasets.__dict__[dataset_name](**train_args) 129 | 130 | # for global holdout set 131 | test_args = DEFAULT_ARGS.copy() 132 | test_args['transform'] = transforms[1] 133 | 134 | if dataset_name == 'Places365': 135 | test_args['split'] = 'val' 136 | elif dataset_name == 'OxfordIIITPet': 137 | test_args['split'] = 'test' 138 | elif dataset_name == 'INaturalist': 139 | test_args['version'] = '2021_valid' 140 | elif dataset_name == 'Omniglot': 141 | test_args['background'] = False 142 | 143 | # create test dataset instance 144 | raw_test = torchvision.datasets.__dict__[dataset_name](**test_args) 145 | 146 | # for compatibility, create attribute `targets` 147 | if dataset_name == 'OxfordIIITPet': 148 | setattr(raw_train, 'targets', raw_train._labels) 149 | elif dataset_name == 'INaturalist': 150 | setattr(raw_train, 'targets', [*list(zip(*raw_train.index))[0]]) 151 | elif dataset_name == 'Omniglot': 152 | setattr(raw_train, 'targets', [*list(zip(*raw_train._flat_character_images))[-1]]) 153 | 154 | # set raw datasets 155 | raw_train = VisionClassificationDataset(raw_train, dataset_name.upper(), 'CLIENT') 156 | raw_test = VisionClassificationDataset(raw_test, 
dataset_name.upper(), 'SERVER') 157 | 158 | # adjust arguments 159 | if 'Omniglot' in dataset_name: 160 | args.in_channels = 1 161 | else: 162 | args.in_channels = 3 163 | 164 | elif dataset_name in ['Caltech256', 'SEMEION', 'SUN397']: 165 | # configure arguments for training dataset 166 | # NOTE: these datasets do NOT provide pre-defined split 167 | # Thus, use all datasets as a training dataset 168 | train_args = DEFAULT_ARGS.copy() 169 | train_args['transform'] = transforms[0] 170 | 171 | # create training dataset instance 172 | raw_train = torchvision.datasets.__dict__[dataset_name](**train_args) 173 | 174 | # for compatibility, create attribute `targets` 175 | if dataset_name == 'Caltech256': 176 | setattr(raw_train, 'targets', raw_train.y) 177 | elif dataset_name == 'SEMEION': 178 | setattr(raw_train, 'targets', raw_train.labels) 179 | elif dataset_name == 'SUN397': 180 | setattr(raw_train, 'targets', raw_train._labels) 181 | 182 | # set raw datasets (no holdout set is supported) 183 | raw_train = VisionClassificationDataset(raw_train, dataset_name.upper(), 'CLIENT') 184 | raw_test = None 185 | 186 | # adjust arguments 187 | if 'SEMEION' in dataset_name: 188 | args.in_channels = 1 189 | else: 190 | args.in_channels = 3 191 | else: 192 | err = f'[LOAD] Dataset `{dataset_name}` is not supported!' 193 | logger.exception(err) 194 | raise Exception(err) 195 | 196 | args.num_classes = len(torch.unique(torch.as_tensor(raw_train.dataset.targets))) 197 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...fetched dataset!') 198 | return raw_train, raw_test, args 199 | -------------------------------------------------------------------------------- /src/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import load_dataset 2 | from .model import load_model -------------------------------------------------------------------------------- /src/loaders/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import inspect 3 | import logging 4 | import importlib 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | 10 | def load_model(args): 11 | # retrieve model skeleton 12 | model_class = importlib.import_module('..models', package=__package__).__dict__[args.model_name] 13 | 14 | # get required model arguments 15 | required_args = inspect.getfullargspec(model_class)[0] 16 | 17 | # collect entered model arguments 18 | model_args = {} 19 | for argument in required_args: 20 | if argument == 'self': 21 | continue 22 | model_args[argument] = getattr(args, argument) 23 | 24 | # get model instance 25 | model = model_class(**model_args) 26 | 27 | # adjust arguments if needed 28 | if args.use_pt_model: 29 | args.num_embeddings = model.num_embeddings 30 | args.embedding_size = model.embedding_size 31 | args.num_hiddens = model.num_hiddens 32 | args.dropout = model.dropout 33 | return model, args 34 |
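# NOTE: a hedged usage sketch of `load_model` — it reads the chosen model's
# constructor signature and pulls each required argument from `args` by name,
# so the namespace must carry matching attributes (the values below are assumed
# for illustration, not verified against the actual `TwoNN` signature):
#
#   import argparse
#   args = argparse.Namespace(
#       model_name='TwoNN', in_features=784, num_classes=10,
#       hidden_size=200, use_pt_model=False
#   )
#   model, args = load_model(args)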
-------------------------------------------------------------------------------- /src/loaders/split.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | 4 | from src import TqdmToLogger 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | 10 | def simulate_split(args, dataset): 11 | """Split data indices using labels.
12 | 13 | Args: 14 | args (argparser): arguments 15 | dataset (dataset): raw dataset instance to be split 16 | 17 | Returns: 18 | split_map (dict): dictionary mapping each client index to a list of assigned sample indices 19 | """ 20 | # IID split (i.e., statistical homogeneity) 21 | if args.split_type == 'iid': 22 | # shuffle sample indices 23 | shuffled_indices = np.random.permutation(len(dataset)) 24 | 25 | # get adjusted indices 26 | split_indices = np.array_split(shuffled_indices, args.K) 27 | 28 | # construct a hashmap 29 | split_map = {k: split_indices[k] for k in range(args.K)} 30 | return split_map 31 | 32 | # non-IID split by sample unbalancedness 33 | if args.split_type == 'unbalanced': 34 | # shuffle sample indices 35 | shuffled_indices = np.random.permutation(len(dataset)) 36 | 37 | # split indices by number of clients 38 | split_indices = np.array_split(shuffled_indices, args.K) 39 | 40 | # randomly remove some proportion (1% ~ 5%) of data 41 | keep_ratio = np.random.uniform(low=0.95, high=0.99, size=len(split_indices)) 42 | 43 | # get adjusted indices 44 | split_indices = [indices[:int(len(indices) * ratio)] for indices, ratio in zip(split_indices, keep_ratio)] 45 | 46 | # construct a hashmap 47 | split_map = {k: split_indices[k] for k in range(args.K)} 48 | return split_map 49 | 50 | # Non-IID split proposed in (McMahan et al., 2016); each client has samples from at least two different classes 51 | elif args.split_type == 'patho': 52 | try: 53 | assert args.mincls >= 2 54 | except AssertionError as e: 55 | logger.exception("[SIMULATE] Each client should have samples from at least 2 distinct classes!") 56 | raise e 57 | 58 | # get indices by class labels 59 | _, unique_inverse, unique_counts = np.unique(dataset.targets, return_inverse=True, return_counts=True) 60 | class_indices = np.split(np.argsort(unique_inverse), np.cumsum(unique_counts[:-1])) 61 | 62 | # divide shards 63 | num_shards_per_class = args.K * args.mincls // args.num_classes 64 | if num_shards_per_class < 1: 65 | err = f'[SIMULATE] Increase the number of minimum class (`args.mincls` > {args.mincls}) or the number of participating clients (`args.K` > {args.K})!' 66 | logger.exception(err) 67 | raise Exception(err) 68 | 69 | # split class indices again into groups, each having the designated number of shards 70 | split_indices = [np.array_split(np.random.permutation(indices), num_shards_per_class) for indices in class_indices] 71 | 72 | # make hashmap to track remaining shards to be assigned per client 73 | class_shards_counts = dict(zip([i for i in range(args.num_classes)], [len(split_idx) for split_idx in split_indices])) 74 |
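# NOTE: a worked example of the shard arithmetic above (values illustrative):
# with args.K = 100 clients, args.mincls = 2, and args.num_classes = 10,
# num_shards_per_class = 100 * 2 // 10 = 20, so each class is cut into 20
# shards and the 200 shards are handed out 2 per client, reproducing the
# pathological split of McMahan et al. (2016).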
75 | # assign divided shards to clients 76 | assigned_shards = [] 77 | for _ in TqdmToLogger( 78 | range(args.K), 79 | logger=logger, 80 | desc='[SIMULATE] ...assigning to clients... ' 81 | ): 82 | # update selection probability according to the count of remaining shards 83 | # i.e., do NOT sample from class having no remaining shards 84 | selection_prob = np.where(np.array(list(class_shards_counts.values())) > 0, 1., 0.)
85 | selection_prob /= sum(selection_prob) 86 | 87 | # select classes to be considered 88 | try: 89 | selected_classes = np.random.choice(args.num_classes, args.mincls, replace=False, p=selection_prob) 90 | except ValueError: # if there are not enough shards left, some clients may inevitably have samples from fewer than `mincls` classes 91 | selected_classes = np.random.choice(args.num_classes, args.mincls, replace=True, p=selection_prob) 92 | 93 | # assign shards in randomly selected classes to current client 94 | for it, class_idx in enumerate(selected_classes): 95 | selected_shard_indices = np.random.choice(len(split_indices[class_idx]), 1)[0] 96 | selected_shards = split_indices[class_idx].pop(selected_shard_indices) 97 | if it == 0: 98 | assigned_shards.append([selected_shards]) 99 | else: 100 | assigned_shards[-1].append(selected_shards) 101 | class_shards_counts[class_idx] -= 1 102 | else: 103 | assigned_shards[-1] = np.concatenate(assigned_shards[-1]) 104 | 105 | # construct a hashmap 106 | split_map = {k: assigned_shards[k] for k in range(args.K)} 107 | return split_map 108 | 109 | # Non-IID split proposed in (Hsu et al., 2019); simulation of non-IID split scenario using Dirichlet distribution 110 | elif args.split_type == 'diri': 111 | MIN_SAMPLES = int(1 / args.test_size) 112 | 113 | # get indices by class labels 114 | total_counts = len(dataset.targets) 115 | _, unique_inverse, unique_counts = np.unique(dataset.targets, return_inverse=True, return_counts=True) 116 | class_indices = np.split(np.argsort(unique_inverse), np.cumsum(unique_counts[:-1])) 117 | 118 | # calculate ideal samples counts per client 119 | ideal_counts = len(dataset.targets) // args.K 120 | if ideal_counts < 1: 121 | err = f'[SIMULATE] Decrease the number of participating clients (`args.K` < {args.K})!' 122 | logger.exception(err) 123 | raise Exception(err) 124 | 125 | # split dataset 126 | ## define temporary container 127 | assigned_indices = [] 128 |
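# NOTE: a minimal standalone sketch of the per-client sampling step used in the
# loop below (the concentration value is illustrative):
#
#   import numpy as np
#   num_classes, cncntrtn = 10, 0.5
#   diri_prior = np.random.uniform(size=num_classes)         # random base measure
#   cat_param = np.random.dirichlet(alpha=cncntrtn * diri_prior)
#   sampled = np.random.choice(num_classes, 500, p=cat_param)
#   counts = np.bincount(sampled, minlength=num_classes)     # skewed class counts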
129 | ## NOTE: it is possible that not all samples are consumed, since each client is required to have at least `MIN_SAMPLES` samples per sampled class 130 | for k in TqdmToLogger(range(args.K), logger=logger, desc='[SIMULATE] ...assigning to clients...
'): 131 | ### for current client of which index is `k` 132 | curr_indices = [] 133 | satisfied_counts = 0 134 | 135 | ### ...until the number of samples close to ideal counts is filled 136 | while satisfied_counts < ideal_counts: 137 | ### define a Dirichlet distribution whose prior is a uniform distribution 138 | diri_prior = np.random.uniform(size=args.num_classes) 139 | 140 | ### sample a parameter corresponding to that of a categorical distribution 141 | cat_param = np.random.dirichlet(alpha=args.cncntrtn * diri_prior) 142 | 143 | ### try to sample as many as `ideal_counts` samples 144 | sampled = np.random.choice(args.num_classes, ideal_counts, p=cat_param) 145 | 146 | ### count per-class samples 147 | unique, counts = np.unique(sampled, return_counts=True) 148 | if len(unique) < args.mincls: 149 | continue 150 | 151 | ### filter out sampled classes not having as many as `MIN_SAMPLES` 152 | required_counts = counts * (counts > MIN_SAMPLES) 153 | 154 | ### assign from population indices split by classes 155 | for idx, required_class in enumerate(unique): 156 | if required_counts[idx] == 0: continue 157 | sampled_indices = class_indices[required_class][:required_counts[idx]] 158 | curr_indices.append(sampled_indices) 159 | class_indices[required_class] = class_indices[required_class][required_counts[idx]:] 160 | satisfied_counts += sum(required_counts) 161 | 162 | ### when enough samples are collected, go to the next client! 163 | assigned_indices.append(np.concatenate(curr_indices)) 164 | 165 | # construct a hashmap 166 | split_map = {k: assigned_indices[k] for k in range(args.K)} 167 | return split_map 168 | # `leaf` - LEAF benchmark (Caldas et al., 2018); `fedvis` - Federated Vision Datasets (Hsu, Qi and Brown, 2020) 169 | elif args.split_type in ['leaf']: 170 | logger.info('[SIMULATE] Use pre-defined split!') 171 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metricszoo import * -------------------------------------------------------------------------------- /src/metrics/basemetric.py: -------------------------------------------------------------------------------- 1 | from abc import * 2 | 3 | 4 | 5 | class BaseMetric(metaclass=ABCMeta): 6 | @abstractmethod 7 | def __init__(self): 8 | raise NotImplementedError 9 | 10 | @abstractmethod 11 | def collect(self, pred, true): 12 | raise NotImplementedError 13 | 14 | @abstractmethod 15 | def summarize(self): 16 | raise NotImplementedError 17 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .twonn import TwoNN 2 | from .twocnn import TwoCNN 3 | from .lenet import LeNet 4 | from .m5 import M5 5 | from .vgg import * 6 | from .resnet import * 7 | from .shufflenet import ShuffleNet 8 | from .mobilenet import MobileNet 9 | from .mobilenext import MobileNeXt 10 | from .squeezenet import SqueezeNet 11 | from .squeezenext import SqueezeNeXt 12 | from .mobilevit import MobileViT 13 | from .stackedlstm import StackedLSTM 14 | from .distilbert import DistilBert 15 | from .mobilebert import MobileBert 16 | from .squeezebert import SqueezeBert 17 | from .logreg import LogReg 18 | from .simplecnn import SimpleCNN 19 | from .femnistcnn import FEMNISTCNN 20 | from .sent140lstm import Sent140LSTM 21 | from .stackedtransformer import StackedTransformer
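# NOTE: every model exported above is a plain torch.nn.Module, so it can be
# built directly as well as through src.loaders.load_model; the constructor
# arguments below are assumptions for illustration, not verified signatures:
#
#   import torch
#   from src.models import TwoCNN
#   model = TwoCNN(in_channels=1, hidden_size=32, num_classes=10)
#   logits = model(torch.randn(8, 1, 28, 28))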
-------------------------------------------------------------------------------- /src/models/distilbert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import DistilBertModel, DistilBertConfig 4 | 5 | 6 | 7 | class DistilBert(torch.nn.Module): 8 | def __init__(self, num_classes, num_embeddings, embedding_size, hidden_size, dropout, use_pt_model, is_seq2seq): 9 | super(DistilBert, self).__init__() 10 | self.is_seq2seq = is_seq2seq 11 | # define encoder 12 | if use_pt_model: # fine-tuning 13 | self.features = DistilBertModel.from_pretrained('distilbert-base-uncased') 14 | self.num_classes = num_classes 15 | self.num_embeddings = self.features.config.vocab_size 16 | self.embedding_size = self.features.config.dim 17 | self.num_hiddens = self.features.config.hidden_size 18 | self.dropout = self.features.config.dropout 19 | 20 | self.classifier = torch.nn.Linear(self.embedding_size, self.num_classes, bias=True) 21 | else: # from scratch 22 | self.num_classes = num_classes 23 | self.num_embeddings = num_embeddings 24 | self.embedding_size = embedding_size 25 | self.num_hiddens = hidden_size 26 | self.dropout = dropout 27 | 28 | config = DistilBertConfig( 29 | vocab_size=self.num_embeddings, 30 | dim=self.embedding_size, 31 | hidden_size=self.num_hiddens, 32 | hidden_dropout_prob=self.dropout 33 | ) 34 | self.features = DistilBertModel(config) 35 | self.classifier = torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True) 36 | 37 | def forward(self, x): 38 | x = self.features(x)[0] # [0] is already the last hidden state tensor 39 | x = self.classifier(x if self.is_seq2seq else x[:, 0, :]) # use [CLS] token for classification 40 | return x 41 | -------------------------------------------------------------------------------- /src/models/femnistcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class FEMNISTCNN(torch.nn.Module): # for FEMNIST experiment in Caldas et al., 2018; (https://github.com/TalwalkarLab/leaf/blob/master/models/femnist/cnn.py) 6 | def __init__(self, in_channels, hidden_size, num_classes): 7 | super(FEMNISTCNN, self).__init__() 8 | self.in_channels = in_channels 9 | self.hidden_channels = hidden_size 10 | self.num_classes = num_classes 11 | 12 | self.features = torch.nn.Sequential( 13 | torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.hidden_channels, kernel_size=5, padding=1, stride=1, bias=True), 14 | torch.nn.ReLU(), 15 | torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1), 16 | torch.nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels * 2, kernel_size=5, padding=1, stride=1, bias=True), 17 | torch.nn.ReLU(), 18 | torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1), 19 | ) 20 | self.classifier = torch.nn.Sequential( 21 | torch.nn.AdaptiveAvgPool2d((7, 7)), 22 | torch.nn.Flatten(), 23 | torch.nn.Linear(in_features=self.hidden_channels * 2 * (7 * 7), out_features=2048, bias=True), 24 | torch.nn.ReLU(), 25 | torch.nn.Linear(in_features=2048, out_features=self.num_classes, bias=True) 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.features(x) 30 | x = self.classifier(x) 31 | return x 32 | -------------------------------------------------------------------------------- /src/models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class LeNet(torch.nn.Module): 6 | def __init__(self, in_channels, num_classes, hidden_size, dropout): 7
| super(LeNet, self).__init__() 8 | self.in_channels = in_channels 9 | self.hidden_channels = hidden_size 10 | self.num_classes = num_classes 11 | self.dropout = dropout 12 | 13 | self.features = torch.nn.Sequential( 14 | torch.nn.Conv2d(self.in_channels, self.hidden_channels, kernel_size=(5, 5), padding=1, stride=1, bias=True), 15 | torch.nn.ReLU(True), 16 | torch.nn.MaxPool2d((2, 2)), 17 | torch.nn.Conv2d(self.hidden_channels, self.hidden_channels * 2, kernel_size=(5, 5), padding=1, stride=1, bias=True), 18 | torch.nn.ReLU(True), 19 | torch.nn.MaxPool2d((2, 2)), 20 | ) 21 | self.classifier = torch.nn.Sequential( 22 | torch.nn.AdaptiveAvgPool2d((4, 4)), 23 | torch.nn.Flatten(), 24 | torch.nn.Linear((4 * 4) * (self.hidden_channels * 2), self.hidden_channels, bias=True), 25 | torch.nn.ReLU(True), 26 | torch.nn.Dropout(self.dropout), 27 | torch.nn.Linear(self.hidden_channels, self.hidden_channels // 2, bias=True), 28 | torch.nn.ReLU(True), 29 | torch.nn.Linear(self.hidden_channels // 2, self.num_classes, bias=True) 30 | ) 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = self.classifier(x) 35 | return x 36 | -------------------------------------------------------------------------------- /src/models/logreg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class LogReg(torch.nn.Module): 6 | def __init__(self, in_features, num_layers, hidden_size, num_classes): 7 | super(LogReg, self).__init__() 8 | self.in_features = in_features 9 | self.num_classes = num_classes 10 | 11 | if num_layers == 1: 12 | self.features = torch.nn.Identity() 13 | self.classifier = torch.nn.Linear(in_features, num_classes, bias=True) 14 | else: 15 | features = [torch.nn.Linear(in_features, hidden_size, bias=True)] 16 | for _ in range(num_layers - 1): 17 | features.append(torch.nn.Linear(hidden_size, hidden_size, bias=True)) 18 | features.append(torch.nn.ReLU(True)) 19 | self.features = torch.nn.Sequential(*features) 20 | self.classifier = torch.nn.Linear(hidden_size, num_classes, bias=True) 21 | 22 | def forward(self, x): 23 | x = self.features(x) 24 | x = self.classifier(x) 25 | return x -------------------------------------------------------------------------------- /src/models/m5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class M5(torch.nn.Module): 6 | def __init__(self, in_channels, hidden_size, num_classes): 7 | super(M5, self).__init__() 8 | self.in_channels = in_channels 9 | self.num_hiddens = hidden_size 10 | self.num_classes = num_classes 11 | 12 | self.features = torch.nn.Sequential( 13 | torch.nn.Conv1d(self.in_channels, self.num_hiddens, kernel_size=80, stride=16), 14 | torch.nn.BatchNorm1d(self.num_hiddens), 15 | torch.nn.ReLU(True), 16 | torch.nn.MaxPool1d(4), 17 | torch.nn.Conv1d(self.num_hiddens, self.num_hiddens, kernel_size=3), 18 | torch.nn.BatchNorm1d(self.num_hiddens), 19 | torch.nn.ReLU(True), 20 | torch.nn.MaxPool1d(4), 21 | torch.nn.Conv1d(self.num_hiddens, self.num_hiddens * 2, kernel_size=3), 22 | torch.nn.BatchNorm1d(self.num_hiddens * 2), 23 | torch.nn.ReLU(True), 24 | torch.nn.MaxPool1d(4), 25 | torch.nn.Conv1d(self.num_hiddens * 2, self.num_hiddens * 2, kernel_size=3), 26 | torch.nn.BatchNorm1d(self.num_hiddens * 2), 27 | torch.nn.ReLU(True), 28 | torch.nn.MaxPool1d(4), 29 | torch.nn.AdaptiveAvgPool1d(1), 30 | torch.nn.Flatten() 31 | ) 32 | self.classifier = torch.nn.Linear(self.num_hiddens * 2, self.num_classes) 33 | 34 | def 
forward(self, x): 35 | x = self.features(x) 36 | x = self.classifier(x) 37 | return x -------------------------------------------------------------------------------- /src/models/mobilebert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import MobileBertModel, MobileBertConfig 4 | 5 | 6 | 7 | class MobileBert(torch.nn.Module): 8 | def __init__(self, num_classes, num_embeddings, embedding_size, hidden_size, dropout, use_pt_model, is_seq2seq): 9 | super(MobileBert, self).__init__() 10 | self.is_seq2seq = is_seq2seq 11 | # define encoder 12 | if use_pt_model: # fine-tuning 13 | self.features = MobileBertModel.from_pretrained('google/mobilebert-uncased') 14 | self.num_classes = num_classes 15 | self.num_embeddings = self.features.config.vocab_size 16 | self.embedding_size = self.features.config.embedding_size 17 | self.num_hiddens = self.features.config.hidden_size 18 | self.dropout = self.features.config.hidden_dropout_prob 19 | 20 | self.classifier = torch.nn.Linear(self.features.config.hidden_size, self.num_classes, bias=True) 21 | 22 | else: # from scratch 23 | self.num_classes = num_classes 24 | self.num_embeddings = num_embeddings 25 | self.embedding_size = embedding_size 26 | self.num_hiddens = hidden_size 27 | self.dropout = dropout 28 | 29 | config = MobileBertConfig( 30 | vocab_size=self.num_embeddings, 31 | embedding_size=self.embedding_size, 32 | hidden_size=self.num_hiddens, 33 | hidden_dropout_prob=self.dropout 34 | ) 35 | self.features = MobileBertModel(config) 36 | self.classifier = torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True) 37 | 38 | def forward(self, x): 39 | x = self.features(x) 40 | x = self.classifier(x['last_hidden_state'] if self.is_seq2seq else x['pooler_output']) 41 | return x 42 | -------------------------------------------------------------------------------- /src/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.models.model_utils import make_divisible, SELayer, InvertedResidualBlock 4 | 5 | 6 | 7 | class MobileNet(torch.nn.Module): # MobileNetv3-small 8 | CONFIG = [# k, t, c, SE, HS, s 9 | [3, 1, 16, 1, 0, 2], 10 | [3, 4.5, 24, 0, 0, 2], 11 | [3, 3.67, 24, 0, 0, 1], 12 | [5, 4, 40, 1, 1, 2], 13 | [5, 6, 40, 1, 1, 1], 14 | [5, 6, 40, 1, 1, 1], 15 | [5, 3, 48, 1, 1, 1], 16 | [5, 3, 48, 1, 1, 1], 17 | [5, 6, 96, 1, 1, 2], 18 | [5, 6, 96, 1, 1, 1], 19 | [5, 6, 96, 1, 1, 1] 20 | ] 21 | 22 | def __init__(self, in_channels, num_classes, dropout): 23 | super(MobileNet, self).__init__() 24 | self.in_channels = in_channels 25 | self.num_classes = num_classes 26 | self.dropout = dropout 27 | 28 | hidden_channels = make_divisible(16, 8) 29 | layers = [ 30 | torch.nn.Sequential( 31 | torch.nn.Conv2d(in_channels, hidden_channels, 3, 2, 1, bias=False), 32 | torch.nn.BatchNorm2d(make_divisible(16, 8)), 33 | torch.nn.Hardswish(True) 34 | ) 35 | ] 36 | 37 | # building inverted residual blocks 38 | for k, t, c, use_se, use_hs, s in self.CONFIG: 39 | out_channels = make_divisible(c * 1, 8) 40 | exp_size = make_divisible(hidden_channels * t, 8) 41 | layers.append(InvertedResidualBlock(hidden_channels, exp_size, out_channels, k, s, use_se, use_hs)) 42 | hidden_channels = out_channels 43 | else: 44 | self.features1 = torch.nn.Sequential(*layers) 45 | 46 | # building last several layers 47 | self.features2 = torch.nn.Sequential( 48 | torch.nn.Conv2d(hidden_channels, exp_size, 1, 1, 0, bias=False), 49 | 
torch.nn.BatchNorm2d(exp_size), 50 | torch.nn.Hardswish(True), 51 | torch.nn.AdaptiveAvgPool2d((1, 1)) 52 | ) 53 | out_channels = 1024 54 | self.classifier = torch.nn.Sequential( 55 | torch.nn.Flatten(), 56 | torch.nn.Linear(exp_size, out_channels), 57 | torch.nn.Hardswish(True), 58 | torch.nn.Dropout(self.dropout), 59 | torch.nn.Linear(out_channels, self.num_classes), 60 | ) 61 | 62 | for m in self.modules(): 63 | if isinstance(m, torch.nn.Conv2d): 64 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out') 65 | if m.bias is not None: 66 | torch.nn.init.zeros_(m.bias) 67 | elif isinstance(m, torch.nn.BatchNorm2d): 68 | torch.nn.init.ones_(m.weight) 69 | torch.nn.init.zeros_(m.bias) 70 | elif isinstance(m, torch.nn.Linear): 71 | torch.nn.init.normal_(m.weight, 0, 0.01) 72 | torch.nn.init.zeros_(m.bias) 73 | 74 | def forward(self, x): 75 | x = self.features1(x) 76 | x = self.features2(x) 77 | x = self.classifier(x) 78 | return x 79 | -------------------------------------------------------------------------------- /src/models/mobilenext.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from src.models.model_utils import make_divisible, SandGlassLayer 5 | 6 | 7 | 8 | class MobileNeXt(torch.nn.Module): 9 | CONFIG = [# t, c, n, s 10 | [2, 96, 1, 2], 11 | [6, 144, 1, 1], 12 | [6, 192, 3, 2], 13 | [6, 288, 3, 2], 14 | [6, 384, 4, 1], 15 | [6, 576, 4, 2], 16 | [6, 960, 3, 1], 17 | [6, 1280, 1, 1] 18 | ] 19 | 20 | def __init__(self, in_channels, num_classes, dropout): 21 | super(MobileNeXt, self).__init__() 22 | self.in_channels = in_channels 23 | self.num_classes = num_classes 24 | self.dropout = dropout 25 | 26 | # building first layer 27 | hidden_channels = make_divisible(32, 8) 28 | layers = [ 29 | torch.nn.Sequential( 30 | torch.nn.Conv2d(self.in_channels, hidden_channels, 3, 2, 1, bias=False), 31 | torch.nn.BatchNorm2d(hidden_channels), 32 | torch.nn.ReLU6(True) 33 | ) 34 | ] 35 | 36 | # building blocks 37 | for t, c, n, s in self.CONFIG: 38 | out_channels = make_divisible(c, 8) 39 | for i in range(n): 40 | layers.append(SandGlassLayer(hidden_channels, out_channels, s if i == 0 else 1, t)) 41 | hidden_channels = out_channels 42 | self.features = torch.nn.Sequential(*layers) 43 | 44 | # building classifier 45 | self.classifier = torch.nn.Sequential( 46 | torch.nn.AdaptiveAvgPool2d((1, 1)), 47 | torch.nn.Flatten(), 48 | torch.nn.Dropout(self.dropout), 49 | torch.nn.Linear(out_channels, self.num_classes), 50 | ) 51 | 52 | for m in self.modules(): 53 | if isinstance(m, torch.nn.Conv2d): 54 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 55 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 56 | if m.bias is not None: 57 | m.bias.data.zero_() 58 | elif isinstance(m, torch.nn.BatchNorm2d): 59 | m.weight.data.fill_(1) 60 | m.bias.data.zero_() 61 | elif isinstance(m, torch.nn.Linear): 62 | m.weight.data.normal_(0, 0.01) 63 | m.bias.data.zero_() 64 | 65 | def forward(self, x): 66 | x = self.features(x) 67 | x = self.classifier(x) 68 | return x 69 | -------------------------------------------------------------------------------- /src/models/mobilevit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.models.model_utils import MV2Block, MobileViTBlock 4 | 5 | 6 | 7 | class MobileViT(torch.nn.Module): 8 | L = [2, 4, 3] 9 | DIMS = [64, 80, 96] 10 | CHANNELS = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320] 11 | 12 | def __init__(self, crop, resize, in_channels, num_classes, dropout): 13 | super(MobileViT, self).__init__() 14 | size = crop if resize is None else resize 15 | assert (size >= 16) and (size % 2 == 0) 16 | 17 | self.size = size 18 | self.in_channels = in_channels 19 | self.num_classes = num_classes 20 | self.dropout = dropout 21 | self.patch_size = 2 22 | 23 | # building feature extractor (total stride of 32) 24 | self.features = torch.nn.Sequential( 25 | torch.nn.Conv2d(self.in_channels, self.CHANNELS[0], 3, 2, 1, bias=False), 26 | torch.nn.BatchNorm2d(self.CHANNELS[0]), 27 | torch.nn.SiLU(True), 28 | MV2Block(self.CHANNELS[0], self.CHANNELS[1], 1, 4), 29 | MV2Block(self.CHANNELS[1], self.CHANNELS[2], 2, 4), 30 | MV2Block(self.CHANNELS[2], self.CHANNELS[3], 1, 4), 31 | MV2Block(self.CHANNELS[2], self.CHANNELS[3], 1, 4), 32 | MV2Block(self.CHANNELS[3], self.CHANNELS[4], 2, 4), 33 | MobileViTBlock(self.DIMS[0], self.L[0], self.CHANNELS[5], 3, self.patch_size, int(self.DIMS[0] * 2)), 34 | MV2Block(self.CHANNELS[5], self.CHANNELS[6], 2, 4), 35 | MobileViTBlock(self.DIMS[1], self.L[1], self.CHANNELS[7], 3, self.patch_size, int(self.DIMS[1] * 4)), 36 | MV2Block(self.CHANNELS[7], self.CHANNELS[8], 2, 4), 37 | MobileViTBlock(self.DIMS[2], self.L[2], self.CHANNELS[9], 3, self.patch_size, int(self.DIMS[2] * 4)), 38 | torch.nn.Conv2d(self.CHANNELS[-2], self.CHANNELS[-1], 1, 1, 0, bias=False), 39 | torch.nn.BatchNorm2d(self.CHANNELS[-1]), 40 | torch.nn.SiLU(True) 41 | ) 42 | self.classifier = torch.nn.Sequential( 43 | torch.nn.AvgPool2d(self.size // 32, 1), 44 | torch.nn.Flatten(), 45 | torch.nn.Dropout(self.dropout), 46 | torch.nn.Linear(self.CHANNELS[-1], self.num_classes, bias=False) 47 | ) 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | x = self.classifier(x) 52 | return x 53 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.models.model_utils import ResidualBlock 4 | 5 | 6 | 7 | __all__ = ['ResNet10', 'ResNet18', 'ResNet34'] 8 | 9 | CONFIGS = { 10 | 'ResNet10': [1, 1, 1, 1], 11 | 'ResNet18': [2, 2, 2, 2], 12 | 'ResNet34': [3, 4, 6, 3] 13 | } 14 | 15 | class ResNet(torch.nn.Module): 16 | def __init__(self, config, block, in_channels, hidden_size, num_classes): 17 | super(ResNet, self).__init__() 18 | self.in_channels = in_channels 19 | self.hidden_size = hidden_size 20 | self.num_classes = num_classes 21 | 22 | self.features = torch.nn.Sequential( 23 | torch.nn.Conv2d(self.in_channels, self.hidden_size, kernel_size=3, stride=1, padding=1, bias=False), 24 | torch.nn.BatchNorm2d(self.hidden_size), 25 | torch.nn.ReLU(True), 26 | self._make_layers(block, self.hidden_size, config[0], stride=1), 27 | self._make_layers(block, self.hidden_size * 2, config[1], stride=2), 28 | self._make_layers(block, self.hidden_size * 4, config[2], stride=2), 29 | self._make_layers(block, self.hidden_size * 8, config[3], stride=2), 30 | ) 31 | self.classifier = torch.nn.Sequential( 32 | torch.nn.AdaptiveAvgPool2d((7, 7)), 33 | torch.nn.Flatten(), 34 | torch.nn.Linear((7 * 7) * self.hidden_size, self.num_classes, bias=True) # NOTE: _make_layers has already updated self.hidden_size to (initial hidden_size * 8) by this point 35 | ) 36 | 37 | def forward(self, x): 38 | x = self.features(x) 39 | x = self.classifier(x) 40 | return x 41 | 42 | def _make_layers(self, block, planes, num_blocks, stride): 43 | strides = [stride] + [1] * (num_blocks - 1) 44 | layers = [] 45 | for stride in strides: 46 | layers.append(block(self.hidden_size, planes, stride)) 47 | self.hidden_size = planes 48 | return torch.nn.Sequential(*layers) 49 | 50 | class ResNet10(ResNet): 51 | def __init__(self, in_channels, hidden_size, num_classes): 52 | super(ResNet10, self).__init__(CONFIGS['ResNet10'], ResidualBlock, in_channels, hidden_size, num_classes) 53 | 54 | class ResNet18(ResNet): 55 | def __init__(self, in_channels, hidden_size, num_classes): 56 | super(ResNet18, self).__init__(CONFIGS['ResNet18'], ResidualBlock, in_channels, hidden_size, num_classes) 57 | 58 | class ResNet34(ResNet): 59 | def __init__(self, in_channels, hidden_size, num_classes): 60 | super(ResNet34, self).__init__(CONFIGS['ResNet34'], ResidualBlock, in_channels, hidden_size, num_classes) 61 | -------------------------------------------------------------------------------- /src/models/sent140lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.models.model_utils import Lambda 4 | 5 | 6 | 7 | class Sent140LSTM(torch.nn.Module): 8 | def __init__(self, num_classes, embedding_size, hidden_size, dropout, num_layers, glove_emb): 9 | super(Sent140LSTM, self).__init__() 10 | self.embedding_size = embedding_size 11 | self.num_hiddens = hidden_size 12 | self.num_classes = num_classes 13 | self.dropout = dropout 14 | self.num_layers = num_layers 15 | self.glove_emb = glove_emb 16 | 17 | self.features = torch.nn.Sequential( 18 | torch.nn.Embedding.from_pretrained(self.glove_emb, padding_idx=0), 19 | torch.nn.LSTM( 20 | input_size=self.embedding_size, 21 | hidden_size=self.num_hiddens, 22 | num_layers=self.num_layers, 23 | batch_first=True, 24 | dropout=self.dropout, 25 | 
bias=True 26 | ), 27 | Lambda(lambda x: x[0]) 28 | ) 29 | self.classifier = torch.nn.Sequential( 30 | torch.nn.Linear(self.num_hiddens, self.num_hiddens, bias=True), 31 | torch.nn.ReLU(True), 32 | torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True) 33 | ) 34 | 35 | 36 | def forward(self, x): 37 | x = self.features(x) 38 | x = self.classifier(x[:, -1, :]) 39 | return x 40 | -------------------------------------------------------------------------------- /src/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.models.model_utils import ShuffleNetInvRes 4 | 5 | 6 | 7 | class ShuffleNet(torch.nn.Module): 8 | def __init__(self, in_channels, num_classes, dropout): 9 | super(ShuffleNet, self).__init__() 10 | self.in_channels = in_channels 11 | self.num_classes = num_classes 12 | self.dropout = dropout 13 | 14 | self.stage_repeats = [4, 8, 4] 15 | self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] 16 | 17 | # building feature extractor 18 | features = [] 19 | 20 | # input layers 21 | hidden_channels = self.stage_out_channels[1] 22 | features.append( 23 | torch.nn.Sequential( 24 | torch.nn.Conv2d(self.in_channels, hidden_channels, 3, 2, 1, bias=False), 25 | torch.nn.BatchNorm2d(hidden_channels), 26 | torch.nn.ReLU(True), 27 | torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 28 | ) 29 | ) 30 | 31 | # inverted residual layers 32 | for idx, num_repeats in enumerate(self.stage_repeats): 33 | out_channels = self.stage_out_channels[idx + 2] 34 | for i in range(num_repeats): 35 | if i == 0: 36 | features.append(ShuffleNetInvRes(hidden_channels, out_channels, 2, 2)) 37 | else: 38 | features.append(ShuffleNetInvRes(hidden_channels, out_channels, 1, 1)) 39 | hidden_channels = out_channels 40 | 41 | # pooling layers 42 | features.append( 43 | torch.nn.Sequential( 44 | torch.nn.Conv2d(hidden_channels, self.stage_out_channels[-1], 1, 1, 0, bias=False), 45 | torch.nn.BatchNorm2d(self.stage_out_channels[-1]), 46 | torch.nn.ReLU(True), 47 | torch.nn.AdaptiveAvgPool2d((1, 1)), 48 | torch.nn.Flatten() 49 | ) 50 | ) 51 | self.features = torch.nn.Sequential(*features) 52 | self.classifier = torch.nn.Linear(self.stage_out_channels[-1], self.num_classes) 53 | 54 | def forward(self, x): 55 | x = self.features(x) 56 | x = self.classifier(x) 57 | return x -------------------------------------------------------------------------------- /src/models/simplecnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class SimpleCNN(torch.nn.Module): # for CIFAR10 experiment in McMahan et al., 2016; (https://github.com/tensorflow/tensorflow/blob/r0.11/tensorflow/models/image/cifar10/cifar10.py) 6 | def __init__(self, in_channels, hidden_size, num_classes): 7 | super(SimpleCNN, self).__init__() 8 | self.in_channels = in_channels 9 | self.hidden_channels = hidden_size 10 | self.num_classes = num_classes 11 | 12 | self.features = torch.nn.Sequential( 13 | torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.hidden_channels, kernel_size=5, padding=2, stride=1, bias=True), 14 | torch.nn.ReLU(), 15 | torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 16 | torch.nn.LocalResponseNorm(size=9, alpha=0.001), 17 | torch.nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=5, padding=2, stride=1, bias=True), 18 | torch.nn.ReLU(), 19 | torch.nn.LocalResponseNorm(size=9, alpha=0.001), 20 | torch.nn.MaxPool2d(kernel_size=3, 
stride=2, padding=1), 21 | ) 22 | self.classifier = torch.nn.Sequential( 23 | torch.nn.AdaptiveAvgPool2d((6, 6)), 24 | torch.nn.Flatten(), 25 | torch.nn.Linear(in_features=self.hidden_channels * (6 * 6), out_features=384, bias=True), 26 | torch.nn.ReLU(), 27 | torch.nn.Linear(in_features=384, out_features=192, bias=True), 28 | torch.nn.ReLU(), 29 | torch.nn.Linear(in_features=192, out_features=self.num_classes, bias=True) 30 | ) 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = self.classifier(x) 35 | return x 36 | -------------------------------------------------------------------------------- /src/models/squeezebert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import SqueezeBertModel, SqueezeBertConfig 4 | 5 | 6 | 7 | class SqueezeBert(torch.nn.Module): 8 | def __init__(self, num_classes, num_embeddings, embedding_size, hidden_size, dropout, use_pt_model, is_seq2seq): 9 | super(SqueezeBert, self).__init__() 10 | self.is_seq2seq = is_seq2seq 11 | # define encoder 12 | if use_pt_model: # fine-tuning 13 | self.features = SqueezeBertModel.from_pretrained('squeezebert/squeezebert-uncased') 14 | self.num_classes = num_classes 15 | self.num_embeddings = self.features.config.vocab_size 16 | self.embedding_size = self.features.config.embedding_size 17 | self.num_hiddens = self.features.config.hidden_size 18 | self.dropout = self.features.config.hidden_dropout_prob 19 | 20 | self.classifier = torch.nn.Linear(self.features.config.hidden_size, self.num_classes, bias=True) 21 | 22 | else: # from scratch 23 | assert embedding_size == hidden_size, 'If you want embedding_size != intermediate hidden_size, please insert a Conv1d layer to adjust the number of channels before the first SqueezeBertModule.' 
24 | self.num_classes = num_classes 25 | self.num_embeddings = num_embeddings 26 | self.embedding_size = embedding_size 27 | self.num_hiddens = hidden_size 28 | self.dropout = dropout 29 | 30 | config = SqueezeBertConfig( 31 | vocab_size=self.num_embeddings, 32 | embedding_size=self.embedding_size, 33 | hidden_size=self.num_hiddens, 34 | hidden_dropout_prob=self.dropout 35 | ) 36 | self.features = SqueezeBertModel(config) 37 | self.classifier = torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True) 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = self.classifier(x['last_hidden_state'] if self.is_seq2seq else x['pooler_output']) 42 | return x 43 | -------------------------------------------------------------------------------- /src/models/squeezenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | from src.models.model_utils import FireBlock 5 | 6 | 7 | 8 | class SqueezeNet(torch.nn.Module): # SqueezeNet v1.1 9 | def __init__(self, in_channels, num_classes, dropout): 10 | super(SqueezeNet, self).__init__() 11 | self.in_channels = in_channels 12 | self.num_classes = num_classes 13 | self.dropout = dropout 14 | 15 | self.features = torch.nn.Sequential( 16 | torch.nn.Conv2d(self.in_channels, 64, kernel_size=3, stride=2), 17 | torch.nn.ReLU(True), 18 | torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 19 | FireBlock(64, 16, 64, 64), 20 | FireBlock(128, 16, 64, 64), 21 | torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 22 | FireBlock(128, 32, 128, 128), 23 | FireBlock(256, 32, 128, 128), 24 | torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 25 | FireBlock(256, 48, 192, 192), 26 | FireBlock(384, 48, 192, 192), 27 | FireBlock(384, 64, 256, 256), 28 | FireBlock(512, 64, 256, 256), 29 | ) 30 | final_conv = torch.nn.Conv2d(512, self.num_classes, kernel_size=1) 31 | self.classifier = torch.nn.Sequential( 32 | torch.nn.Dropout(self.dropout), 33 | final_conv, 34 | torch.nn.ReLU(True), 35 | torch.nn.AdaptiveAvgPool2d((1, 1)) 36 | ) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, torch.nn.Conv2d): 40 | if m is final_conv: 41 | torch.nn.init.normal_(m.weight, mean=0.0, std=0.01) 42 | else: 43 | torch.nn.init.kaiming_uniform_(m.weight) 44 | if m.bias is not None: 45 | torch.nn.init.constant_(m.bias, 0) 46 | 47 | def forward(self, x): 48 | x = self.features(x) 49 | x = self.classifier(x) 50 | x = torch.flatten(x, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /src/models/squeezenext.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from src.models.model_utils import SNXBlock 5 | 6 | 7 | 8 | class SqueezeNeXt(torch.nn.Module): # SqueezeNeXt (SqNxt-23) 9 | def __init__(self, in_channels, num_classes, dropout): 10 | super(SqueezeNeXt, self).__init__() 11 | self.in_channels = in_channels 12 | self.num_classes = num_classes 13 | self.dropout = dropout 14 | 15 | self.hidden_channels = 64 16 | self.features = torch.nn.Sequential( 17 | torch.nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1, 1, bias=False), 18 | torch.nn.BatchNorm2d(self.hidden_channels), 19 | torch.nn.ReLU(True), 20 | self._make_layer(2, 32, 1), 21 | self._make_layer(4, 64, 2), 22 | self._make_layer(14, 128, 2), 23 | self._make_layer(1, 256, 2), 24 | torch.nn.Conv2d(self.hidden_channels, 128, 1, 1, bias=False), 25 | torch.nn.BatchNorm2d(128), 26 | torch.nn.ReLU(True), 27 | 
torch.nn.AdaptiveAvgPool2d((1, 1)) 28 | ) 29 | self.classifier = torch.nn.Sequential( 30 | torch.nn.Flatten(), 31 | torch.nn.Dropout(self.dropout), 32 | torch.nn.Linear(128, self.num_classes) 33 | ) 34 | 35 | for m in self.modules(): 36 | if isinstance(m, torch.nn.Conv2d): 37 | torch.nn.init.xavier_uniform_(m.weight, gain=math.sqrt(2.)) 38 | if m.bias is not None: 39 | torch.nn.init.constant_(m.bias, 0.) 40 | elif isinstance(m, torch.nn.BatchNorm2d): 41 | torch.nn.init.constant_(m.weight, 1.) 42 | torch.nn.init.constant_(m.bias, 0.) 43 | 44 | def _make_layer(self, num_block, out_channels, stride): 45 | strides = [stride] + [1] * (num_block - 1) 46 | layers = [] 47 | for s in strides: 48 | layers.append(SNXBlock(self.hidden_channels, out_channels, s)) 49 | self.hidden_channels = out_channels 50 | return torch.nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | x = self.features(x) 54 | x = self.classifier(x) 55 | return x 56 | -------------------------------------------------------------------------------- /src/models/stackedlstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.models.model_utils import Lambda 4 | 5 | 6 | 7 | class StackedLSTM(torch.nn.Module): 8 | def __init__(self, num_classes, embedding_size, num_embeddings, hidden_size, dropout, num_layers, is_seq2seq): 9 | super(StackedLSTM, self).__init__() 10 | self.is_seq2seq = is_seq2seq 11 | self.num_hiddens = hidden_size 12 | self.num_classes = num_classes 13 | self.num_embeddings = num_embeddings 14 | self.embedding_size = embedding_size 15 | self.dropout = dropout 16 | self.num_layers = num_layers 17 | 18 | self.features = torch.nn.Sequential( 19 | torch.nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_size), 20 | torch.nn.LSTM( 21 | input_size=self.embedding_size, 22 | hidden_size=self.num_hiddens, 23 | num_layers=self.num_layers, 24 | batch_first=True, 25 | dropout=self.dropout, 26 | bias=True 27 | ), 28 | Lambda(lambda x: x[0]) 29 | ) 30 | self.classifier = torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True) 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = self.classifier(x if self.is_seq2seq else x[:, -1, :]) 35 | return x 36 | -------------------------------------------------------------------------------- /src/models/stackedtransformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from src.models.model_utils import Lambda, PositionalEncoding 5 | 6 | 7 | 8 | class StackedTransformer(torch.nn.Module): 9 | def __init__(self, num_classes, embedding_size, num_embeddings, hidden_size, seq_len, dropout, num_layers, is_seq2seq): 10 | super(StackedTransformer, self).__init__() 11 | self.is_seq2seq = is_seq2seq 12 | self.num_hiddens = hidden_size 13 | self.num_classes = num_classes 14 | self.num_embeddings = num_embeddings 15 | self.embedding_size = embedding_size 16 | self.seq_len = seq_len 17 | self.dropout = dropout 18 | self.num_layers = num_layers 19 | 20 | self.features = torch.nn.Sequential( 21 | torch.nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_size), 22 | PositionalEncoding(self.embedding_size, self.dropout), 23 | torch.nn.TransformerEncoder( 24 | torch.nn.TransformerEncoderLayer(self.embedding_size, 16, self.num_hiddens, self.dropout, batch_first=True), 25 | self.num_layers 26 | ) 27 | ) 28 | self.classifier = torch.nn.Linear(self.embedding_size, self.num_classes, bias=True) 
29 | 30 | def forward(self, x): 31 | x = self.features(x) 32 | x = self.classifier(x if self.is_seq2seq else x[:, 0, :]) 33 | return x 34 | -------------------------------------------------------------------------------- /src/models/twocnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class TwoCNN(torch.nn.Module): # McMahan et al., 2016; 1,663,370 parameters 6 | def __init__(self, in_channels, hidden_size, num_classes): 7 | super(TwoCNN, self).__init__() 8 | self.in_channels = in_channels 9 | self.hidden_channels = hidden_size 10 | self.num_classes = num_classes 11 | 12 | self.features = torch.nn.Sequential( 13 | torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.hidden_channels, kernel_size=(5, 5), padding=1, stride=1, bias=True), 14 | torch.nn.ReLU(True), 15 | torch.nn.MaxPool2d(kernel_size=(2, 2), padding=1), 16 | torch.nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels * 2, kernel_size=(5, 5), padding=1, stride=1, bias=True), 17 | torch.nn.ReLU(True), 18 | torch.nn.MaxPool2d(kernel_size=(2, 2), padding=1) 19 | ) 20 | self.classifier = torch.nn.Sequential( 21 | torch.nn.AdaptiveAvgPool2d((7, 7)), 22 | torch.nn.Flatten(), 23 | torch.nn.Linear(in_features=(self.hidden_channels * 2) * (7 * 7), out_features=512, bias=True), 24 | torch.nn.ReLU(True), 25 | torch.nn.Linear(in_features=512, out_features=self.num_classes, bias=True) 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.features(x) 30 | x = self.classifier(x) 31 | return x 32 | -------------------------------------------------------------------------------- /src/models/twonn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class TwoNN(torch.nn.Module): # McMahan et al., 2016; 199,210 parameters 6 | def __init__(self, resize, hidden_size, num_classes): 7 | super(TwoNN, self).__init__() 8 | self.in_features = resize**2 9 | self.num_hiddens = hidden_size 10 | self.num_classes = num_classes 11 | 12 | self.features = torch.nn.Sequential( 13 | torch.nn.Flatten(), 14 | torch.nn.Linear(in_features=self.in_features, out_features=self.num_hiddens, bias=True), 15 | torch.nn.ReLU(True), 16 | torch.nn.Linear(in_features=self.num_hiddens, out_features=self.num_hiddens, bias=True), 17 | torch.nn.ReLU(True) 18 | ) 19 | self.classifier = torch.nn.Linear(in_features=self.num_hiddens, out_features=self.num_classes, bias=True) 20 | 21 | def forward(self, x): 22 | x = self.features(x) 23 | x = self.classifier(x) 24 | return x 25 | -------------------------------------------------------------------------------- /src/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | __all__ = ['VGG9', 'VGG9BN', 'VGG11', 'VGG11BN', 'VGG13', 'VGG13BN'] 6 | 7 | CONFIGS = { 8 | 'VGG9': [64, 'mp', 128, 'mp', 256, 256, 'mp', 512, 512, 'mp'], 9 | 'VGG11': [64, 64, 'mp', 128, 128, 'mp', 256, 256, 'mp', 512, 512, 'mp'], 10 | 'VGG13': [64, 64, 'mp', 128, 128, 'mp', 256, 256, 'mp', 512, 512, 'mp', 512, 512, 'mp'], 11 | } 12 | 13 | class VGG(torch.nn.Module): 14 | def __init__(self, config, use_bn, in_channels, num_classes, dropout): 15 | super(VGG, self).__init__() 16 | self.use_bn = use_bn 17 | 18 | self.in_channels = in_channels 19 | self.num_classes = num_classes 20 | self.dropout = dropout 21 | 22 | self.features = torch.nn.Sequential(*self._make_layers(config, use_bn)) 23 | self.classifier = torch.nn.Sequential( 24 | 
torch.nn.AdaptiveAvgPool2d((7, 7)), 25 | torch.nn.Flatten(), 26 | torch.nn.Linear((7 * 7) * 512, 4096, bias=True), 27 | torch.nn.ReLU(True), 28 | torch.nn.Dropout(self.dropout), 29 | torch.nn.Linear(4096, 4096), 30 | torch.nn.ReLU(True), 31 | torch.nn.Dropout(self.dropout), 32 | torch.nn.Linear(4096, self.num_classes) 33 | ) 34 | 35 | def forward(self, x): 36 | x = self.features(x) 37 | x = self.classifier(x) 38 | return x 39 | 40 | def _make_layers(self, config, use_bn): 41 | layers = [] 42 | in_channels = self.in_channels 43 | for v in config: 44 | if v == 'mp': 45 | layers.append(torch.nn.MaxPool2d(2, 2)) 46 | else: 47 | layers.append(torch.nn.Conv2d(in_channels, v, 3, 1, 1)) 48 | if use_bn: 49 | layers.append(torch.nn.BatchNorm2d(v)) 50 | layers.append(torch.nn.ReLU(True)) 51 | in_channels = v 52 | return layers 53 | 54 | class VGG9(VGG): 55 | def __init__(self, in_channels, num_classes, dropout): 56 | super(VGG9, self).__init__(CONFIGS['VGG9'], False, in_channels, num_classes, dropout) 57 | 58 | class VGG9BN(VGG): 59 | def __init__(self, in_channels, num_classes, dropout): 60 | super(VGG9BN, self).__init__(CONFIGS['VGG9'], True, in_channels, num_classes, dropout) 61 | 62 | class VGG11(VGG): 63 | def __init__(self, in_channels, num_classes, dropout): 64 | super(VGG11, self).__init__(CONFIGS['VGG11'], False, in_channels, num_classes, dropout) 65 | 66 | class VGG11BN(VGG): 67 | def __init__(self, in_channels, num_classes, dropout): 68 | super(VGG11BN, self).__init__(CONFIGS['VGG11'], True, in_channels, num_classes, dropout) 69 | 70 | class VGG13(VGG): 71 | def __init__(self, in_channels, num_classes, dropout): 72 | super(VGG13, self).__init__(CONFIGS['VGG13'], False, in_channels, num_classes, dropout) 73 | 74 | class VGG13BN(VGG): 75 | def __init__(self, in_channels, num_classes, dropout): 76 | super(VGG13BN, self).__init__(CONFIGS['VGG13'], True, in_channels, num_classes, dropout) 77 | -------------------------------------------------------------------------------- /src/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaseline555/Federated-Learning-in-PyTorch/6c07b19c6810c82bd9455bf7364808a568376bf4/src/server/__init__.py -------------------------------------------------------------------------------- /src/server/baseserver.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class BaseServer(metaclass=ABCMeta): 5 | """Central server orchestrating the whole process of federated learning. 
6 | """ 7 | def __init__(self, **kwargs): 8 | self._round = 0 9 | self._model = None 10 | self._clients = None 11 | 12 | @property 13 | def model(self): 14 | return self._model 15 | 16 | @model.setter 17 | def model(self, model): 18 | self._model = model 19 | 20 | @property 21 | def round(self): 22 | return self._round 23 | 24 | @round.setter 25 | def round(self, round): 26 | self._round = round 27 | 28 | @property 29 | def clients(self): 30 | return self._clients 31 | 32 | @clients.setter 33 | def clients(self, clients): 34 | self._clients = clients 35 | 36 | @abstractmethod 37 | def _init_model(self, model): 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def _get_algorithm(self, model, **kwargs): 42 | raise NotImplementedError 43 | 44 | @abstractmethod 45 | def _create_clients(self, client_datasets): 46 | raise NotImplementedError 47 | 48 | @abstractmethod 49 | def _sample_clients(self): 50 | raise NotImplementedError 51 | 52 | @abstractmethod 53 | def _request(self, indices, eval=False): 54 | raise NotImplementedError 55 | 56 | @abstractmethod 57 | def _aggregate(self, indices, update_sizes): 58 | raise NotImplementedError 59 | 60 | @abstractmethod 61 | def _central_evaluate(self): 62 | raise NotImplementedError 63 | 64 | @abstractmethod 65 | def update(self): 66 | raise NotImplementedError 67 | 68 | @abstractmethod 69 | def evaluate(self): 70 | raise NotImplementedError 71 | 72 | @abstractmethod 73 | def finalize(self): 74 | raise NotImplementedError 75 | -------------------------------------------------------------------------------- /src/server/fedadagradserver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .fedavgserver import FedavgServer 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | 9 | class FedadagradServer(FedavgServer): 10 | def __init__(self, **kwargs): 11 | super(FedadagradServer, self).__init__(**kwargs) 12 | self.opt_kwargs = dict( 13 | beta=self.args.beta1, 14 | v0=self.args.tau**2, 15 | tau=self.args.tau, 16 | lr=self.args.server_lr 17 | ) 18 | -------------------------------------------------------------------------------- /src/server/fedadamserver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .fedavgserver import FedavgServer 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | 9 | class FedadamServer(FedavgServer): 10 | def __init__(self, **kwargs): 11 | super(FedadamServer, self).__init__(**kwargs) 12 | self.opt_kwargs = dict( 13 | betas=(self.args.beta1, self.args.beta2), 14 | v0=self.args.tau**2, 15 | tau=self.args.tau, 16 | lr=self.args.server_lr 17 | ) 18 | -------------------------------------------------------------------------------- /src/server/fedavgmserver.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | 4 | from .fedavgserver import FedavgServer 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | 10 | class FedavgmServer(FedavgServer): 11 | def __init__(self, **kwargs): 12 | super(FedavgmServer, self).__init__(**kwargs) 13 | -------------------------------------------------------------------------------- /src/server/fedproxserver.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | 4 | from .fedavgserver import FedavgServer 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | 10 | class FedproxServer(FedavgServer): 11 | def __init__(self, **kwargs): 12 | 
super(FedproxServer, self).__init__(**kwargs) 13 | -------------------------------------------------------------------------------- /src/server/fedsgdserver.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | 4 | from .fedavgserver import FedavgServer 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | 10 | class FedsgdServer(FedavgServer): 11 | def __init__(self, **kwargs): 12 | super(FedsgdServer, self).__init__(**kwargs) 13 | -------------------------------------------------------------------------------- /src/server/fedyogiserver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .fedavgserver import FedavgServer 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | 9 | class FedyogiServer(FedavgServer): 10 | def __init__(self, **kwargs): 11 | super(FedyogiServer, self).__init__(**kwargs) 12 | self.opt_kwargs = dict( 13 | betas=(self.args.beta1, self.args.beta2), 14 | v0=self.args.tau**2, 15 | tau=self.args.tau, 16 | lr=self.args.server_lr 17 | ) 18 | --------------------------------------------------------------------------------
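All of the vision models above expose the same two-stage interface (a `features` extractor followed by a `classifier` head), so any of them can be smoke-tested with a dummy batch. A minimal sketch, assuming the repository root is on `sys.path`; the batch size, image size, and hyperparameter values are illustrative choices, not values fixed by the code:

import torch

from src.models.resnet import ResNet18
from src.models.twocnn import TwoCNN

# hypothetical hyperparameters for a CIFAR-10-like task: 3x32x32 inputs, 10 classes
resnet = ResNet18(in_channels=3, hidden_size=64, num_classes=10)
twocnn = TwoCNN(in_channels=3, hidden_size=32, num_classes=10)

x = torch.randn(4, 3, 32, 32)  # dummy batch of four RGB images
for model in (resnet, twocnn):
    logits = model(x)  # runs features, then classifier, as in each forward()
    assert logits.shape == (4, 10)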
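The text models follow the same convention, with sequence-position selection handled in `forward`: `StackedLSTM` and `StackedTransformer` classify from the last timestep (or first token) unless `is_seq2seq` is set, and the BERT wrappers choose between `pooler_output` and `last_hidden_state` analogously. A minimal sketch for `StackedLSTM`; the vocabulary and layer sizes here are illustrative, not prescribed:

import torch

from src.models.stackedlstm import StackedLSTM

# hypothetical sizes: 1,000-token vocabulary, binary classification
model = StackedLSTM(
    num_classes=2, embedding_size=64, num_embeddings=1000,
    hidden_size=128, dropout=0.1, num_layers=2, is_seq2seq=False
)
tokens = torch.randint(0, 1000, (4, 20))  # dummy batch: four sequences of 20 token ids
logits = model(tokens)  # Embedding -> LSTM -> last hidden state -> Linear head
assert logits.shape == (4, 2)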
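On the server side, `BaseServer` pins down only the orchestration contract, and the adaptive variants (`FedadagradServer`, `FedadamServer`, `FedyogiServer`) reuse `FedavgServer` wholesale, differing only in the `opt_kwargs` they hand to the server-side optimizer (momentum coefficients `beta1`/`beta2`, the adaptivity constant `tau`, and the server learning rate), in line with the FedOpt formulation of Reddi et al., 2020. A new algorithm therefore only has to fill in the abstract surface; a sketch of the required overrides (bodies elided, method names exactly as declared in `baseserver.py`):

from src.server.baseserver import BaseServer

class MyAlgorithmServer(BaseServer):
    def __init__(self, **kwargs):
        super(MyAlgorithmServer, self).__init__(**kwargs)

    def _init_model(self, model): ...
    def _get_algorithm(self, model, **kwargs): ...
    def _create_clients(self, client_datasets): ...
    def _sample_clients(self): ...
    def _request(self, indices, eval=False): ...
    def _aggregate(self, indices, update_sizes): ...
    def _central_evaluate(self): ...
    def update(self): ...  # one communication round
    def evaluate(self): ...
    def finalize(self): ...

`ABCMeta` refuses to instantiate the class until every abstract method above is overridden, which is why the concrete servers here inherit from `FedavgServer` rather than from `BaseServer` directly.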