├── LICENSE
├── README.md
├── commands
│   ├── leaf
│   │   ├── fedavg_celeba.sh
│   │   ├── fedavg_femnist.sh
│   │   ├── fedavg_reddit.sh
│   │   ├── fedavg_sent140.sh
│   │   ├── fedavg_shakespeare.sh
│   │   └── leaf.md
│   ├── original_cifar10
│   │   ├── cifar10.md
│   │   ├── fedavg_cifar10.sh
│   │   └── fedsgd_cifar10.sh
│   ├── original_mnist
│   │   ├── fedavg_mnist_2nn.sh
│   │   ├── fedavg_mnist_cnn.sh
│   │   ├── fedsgd_mnist_2nn.sh
│   │   ├── fedsgd_mnist_cnn.sh
│   │   └── mnist.md
│   └── original_play_n_role
│       ├── Shakespeare.md
│       ├── fedavg_play_n_role.sh
│       └── fedsgd_play_n_role.sh
├── main.py
├── requirments.txt
└── src
    ├── __init__.py
    ├── algorithm
    │   ├── __init__.py
    │   ├── basealgorithm.py
    │   ├── fedadagrad.py
    │   ├── fedadam.py
    │   ├── fedavg.py
    │   ├── fedavgm.py
    │   ├── fedprox.py
    │   ├── fedsgd.py
    │   ├── fedyogi.py
    │   └── sent140.sh
    ├── client
    │   ├── __init__.py
    │   ├── baseclient.py
    │   ├── fedadagradclient.py
    │   ├── fedadamclient.py
    │   ├── fedavgclient.py
    │   ├── fedavgmclient.py
    │   ├── fedproxclient.py
    │   ├── fedsgdclient.py
    │   └── fedyogiclient.py
    ├── datasets
    │   ├── __init__.py
    │   ├── adult.py
    │   ├── beerreviews.py
    │   ├── cinic10.py
    │   ├── cover.py
    │   ├── gleam.py
    │   ├── heart.py
    │   ├── leaf
    │   │   ├── __init__.py
    │   │   ├── leaf_utils.py
    │   │   ├── postprocess
    │   │   │   ├── __init__.py
    │   │   │   ├── filter.py
    │   │   │   ├── postprocess.py
    │   │   │   ├── sample.py
    │   │   │   └── split.py
    │   │   └── preprocess
    │   │       ├── __init__.py
    │   │       ├── celeba.py
    │   │       ├── femnist.py
    │   │       ├── reddit.py
    │   │       ├── sent140.py
    │   │       └── shakespeare.py
    │   ├── leafparser.py
    │   ├── speechcommands.py
    │   ├── tinyimagenet.py
    │   ├── torchtextparser.py
    │   └── torchvisionparser.py
    ├── loaders
    │   ├── __init__.py
    │   ├── data.py
    │   ├── model.py
    │   └── split.py
    ├── metrics
    │   ├── __init__.py
    │   ├── basemetric.py
    │   └── metricszoo.py
    ├── models
    │   ├── __init__.py
    │   ├── distilbert.py
    │   ├── femnistcnn.py
    │   ├── lenet.py
    │   ├── logreg.py
    │   ├── m5.py
    │   ├── mobilebert.py
    │   ├── mobilenet.py
    │   ├── mobilenext.py
    │   ├── mobilevit.py
    │   ├── model_utils.py
    │   ├── resnet.py
    │   ├── sent140lstm.py
    │   ├── shufflenet.py
    │   ├── simplecnn.py
    │   ├── squeezebert.py
    │   ├── squeezenet.py
    │   ├── squeezenext.py
    │   ├── stackedlstm.py
    │   ├── stackedtransformer.py
    │   ├── twocnn.py
    │   ├── twonn.py
    │   └── vgg.py
    ├── server
    │   ├── __init__.py
    │   ├── baseserver.py
    │   ├── fedadagradserver.py
    │   ├── fedadamserver.py
    │   ├── fedavgmserver.py
    │   ├── fedavgserver.py
    │   ├── fedproxserver.py
    │   ├── fedsgdserver.py
    │   └── fedyogiserver.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Seok-Ju Hahn (GitHub: vaseline555; Email: sjhahn11512@gmail.com)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Federated Learning in PyTorch
3 | Implementations of various Federated Learning (FL) algorithms in PyTorch, especially for research purposes.
4 |
5 | ## Implementation Details
6 | ### Datasets
7 | * Supports all image classification datasets in `torchvision.datasets`.
8 | * Supports all text classification datasets in `torchtext.datasets`.
9 | * Supports all datasets in the [LEAF benchmark](https://leaf.cmu.edu/) (*NO need to prepare raw data manually*).
10 | * Supports additional image classification datasets ([`TinyImageNet`](https://www.kaggle.com/c/tiny-imagenet), [`CINIC10`](https://datashare.ed.ac.uk/handle/10283/3192)).
11 | * Supports additional text classification datasets ([`BeerReviews`](https://snap.stanford.edu/data/web-BeerAdvocate.html)).
12 | * Supports tabular datasets ([`Heart`, `Adult`, `Cover`](https://archive.ics.uci.edu/ml/index.php)).
13 | * Supports a temporal dataset ([`GLEAM`](http://www.skleinberg.org/data.html)).
14 | * __NOTE__: there is no need to hunt down raw data files yourself; each dataset is downloaded automatically to the designated path when you pass its name!
15 | ### Statistical Heterogeneity Simulations
16 | * `IID` (i.e., statistical homogeneity)
17 | * `Unbalanced` (i.e., sample counts heterogeneity)
18 | * `Pathological Non-IID` ([McMahan et al., 2016](https://arxiv.org/abs/1602.05629))
19 | * `Dirichlet distribution-based Non-IID` ([Hsu et al., 2019](https://arxiv.org/abs/1909.06335)) (see the sketch after this list)
20 | * `Pre-defined` (for datasets having natural semantic separation, including `LEAF` benchmark ([Caldas et al., 2018](https://arxiv.org/abs/1812.01097)))
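
For the Dirichlet-based simulation, each class's samples are divided among clients according to proportions drawn from $\text{Dir}(\alpha)$, so a smaller $\alpha$ yields more skewed clients. Below is a minimal sketch of the idea (an illustration only; the repository's actual splitting logic lives in `src/loaders/split.py`, whose interface may differ):

```python
import numpy as np

def dirichlet_split(labels: np.ndarray, num_clients: int, alpha: float, seed: int = 42):
    """Toy Dir(alpha) label-skew partition; returns per-client sample indices."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in range(labels.max() + 1):
        idx_c = np.flatnonzero(labels == c)
        rng.shuffle(idx_c)
        # draw this class's per-client proportions, then cut its indices accordingly
        proportions = rng.dirichlet([alpha] * num_clients)
        cuts = (np.cumsum(proportions)[:-1] * len(idx_c)).astype(int)
        for client_id, shard in enumerate(np.split(idx_c, cuts)):
            client_indices[client_id].extend(shard.tolist())
    return client_indices
```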
21 | ### Models
22 | * `LogReg` (logistic regression), `StackedTransformer` (TransformerEncoder-based classifier)
23 | * `TwoNN`, `TwoCNN`, `SimpleCNN` ([McMahan et al., 2016](https://arxiv.org/abs/1602.05629))
24 | * `FEMNISTCNN`, `Sent140LSTM` ([Caldas et al., 2018](https://arxiv.org/abs/1812.01097))
25 | * `LeNet` ([LeCun et al., 1998](https://ieeexplore.ieee.org/document/726791/)), `MobileNet` ([Howard et al., 2019](https://arxiv.org/abs/1905.02244)), `SqueezeNet` ([Iandola et al., 2016](https://arxiv.org/abs/1602.07360)), `VGG` ([Simonyan et al., 2014](https://arxiv.org/abs/1409.1556)), `ResNet` ([He et al., 2015](https://arxiv.org/abs/1512.03385))
26 | * `MobileNeXt` ([Daquan et al., 2020](https://arxiv.org/abs/2007.02269)), `SqueezeNeXt` ([Gholami et al., 2018](https://arxiv.org/abs/1803.10615)), `MobileViT` ([Mehta et al., 2021](https://arxiv.org/abs/2110.02178))
27 | * `DistilBERT` ([Sanh et al., 2019](https://arxiv.org/abs/1910.01108)), `SqueezeBERT` ([Iandola et al., 2020](https://arxiv.org/abs/2006.11316)), `MobileBERT` ([Sun et al., 2020](https://arxiv.org/abs/2004.02984))
28 | * `M5` ([Dai et al., 2016](https://arxiv.org/abs/1610.00087))
29 | ### Algorithms
30 | * `FedAvg` and `FedSGD` (McMahan et al., 2016) Communication-Efficient Learning of Deep Networks from Decentralized Data
31 | * `FedAvgM` (Hsu et al., 2019) Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification
32 | * `FedProx` (Li et al., 2018) Federated Optimization in Heterogeneous Networks
33 | * `FedOpt` (`FedAdam`, `FedYogi`, `FedAdaGrad`) (Reddi et al., 2020) Adaptive Federated Optimization
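
All members of the `FedOpt` family share the same server-side template (Reddi et al., 2020): given the aggregated pseudo-gradient $\Delta_t$,

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\Delta_t, \qquad x_{t+1} = x_t + \eta\,\frac{m_t}{\sqrt{v_t} + \tau},$$

where the variants differ only in the second moment: `FedAdaGrad` uses $v_t = v_{t-1} + \Delta_t^2$, `FedAdam` uses $v_t = \beta_2 v_{t-1} + (1 - \beta_2)\Delta_t^2$, and `FedYogi` uses $v_t = v_{t-1} - (1 - \beta_2)\Delta_t^2\,\mathrm{sign}(v_{t-1} - \Delta_t^2)$. Plain `FedAvg` corresponds to a non-adaptive server update with $m_t = \Delta_t$ and $\eta = 1$.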
34 |
35 | ### Evaluation schemes
36 | * `local`: evaluate the FL algorithm using holdout sets of (some/all) clients NOT participating in the current round (i.e., the evaluation setting of personalized federated learning).
37 | * `global`: evaluate the FL algorithm using a global holdout set located at the server. (*ONLY available if the raw dataset provides a pre-defined validation/test set*).
38 | * `both`: evaluate the FL algorithm using both `local` and `global` schemes.
39 | ### Metrics
40 | * Top-1 Accuracy, Top-5 Accuracy, Precision, Recall, F1
41 | * Area under ROC, Area under PRC, Youden's J
42 | * Seq2Seq Accuracy
43 | * MSE, RMSE, MAE, MAPE
44 | * $R^2$, $D^2$
45 |
46 | ## Requirements
47 | * See `requirements.txt`. (I recommend building an independent environment for this project, using e.g., `Docker` or `conda`)
48 | * When you install `torchtext`, please check the version compatibility with `torch`. (See [official repository](https://github.com/pytorch/text#installation))
49 | * Plus, please install `torch`-related packages using one command provided by the official guide (See [official installation guide](https://pytorch.org/get-started/locally/)); e.g., `conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 torchtext==0.13.0 cudatoolkit=11.6 -c pytorch -c conda-forge`
50 |
51 | ## Configurations
52 | * See `python3 main.py -h`.
53 |
54 | ## Example Commands
55 | * See shell files prepared in `commands` directory.
56 |
57 | ## TODO
58 | - [ ] Support more models, especially lightweight ones for the cross-device FL setting. (e.g., [`EdgeNeXt`](https://github.com/mmaaz60/EdgeNeXt))
59 | - [ ] Support more structured datasets, including temporal and tabular data, along with datasets suited for the cross-silo FL setting. (e.g., [`MedMNIST`](https://github.com/MedMNIST/MedMNIST))
60 | - [ ] Add other popular FL algorithms including personalized FL algorithms (e.g., [`SuPerFed`](https://arxiv.org/abs/2109.07628)).
61 | - [ ] Attach benchmark results of sample commands.
62 |
63 | ## Contact
64 | Should you have any feedback, please create a thread in the __Issues__ tab. Thank you :)
65 |
--------------------------------------------------------------------------------
/commands/leaf/fedavg_celeba.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments for LEAF CelebA dataset
4 | python3 main.py \
5 | --exp_name FedAvg_LEAF_CelebA --seed 42 --device cuda \
6 | --dataset CelebA \
7 | --split_type pre --test_size 0.1 \
8 | --model_name TwoCNN --resize 84 --hidden_size 32 \
9 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics acc1 \
10 | --R 5000 --E 5 --C 0.001 --B 10 --beta1 0 \
11 | --optimizer SGD --lr 0.1 --lr_decay 1 --lr_decay_step 1 --criterion BCEWithLogitsLoss
12 |
--------------------------------------------------------------------------------
/commands/leaf/fedavg_femnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments for LEAF FEMNIST dataset
4 | python3 main.py \
5 | --exp_name FedAvg_LEAF_FEMNIST --seed 42 --device cuda \
6 | --dataset FEMNIST \
7 | --split_type pre --test_size 0.1 \
8 | --model_name FEMNISTCNN --resize 28 --hidden_size 64 \
9 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics acc1 acc5 \
10 | --R 5000 --E 5 --C 0.003 --B 10 --beta1 0 \
11 | --optimizer SGD --lr 0.0003 --lr_decay 1 --lr_decay_step 1 --criterion CrossEntropyLoss
12 |
--------------------------------------------------------------------------------
/commands/leaf/fedavg_reddit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments for LEAF Reddit dataset
4 | python3 main.py \
5 | --exp_name FedAvg_LEAF_Reddit --seed 42 --device cuda \
6 | --dataset Reddit \
7 | --split_type pre --test_size 0.1 \
8 | --model_name NextWordLSTM --num_layers 2 --num_embeddings 10000 --embedding_size 256 --hidden_size 256 \
9 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics seqacc \
10 | --R 5000 --E 5 --C 0.013 --B 50 --beta1 0 \
11 | --optimizer SGD --lr 0.0003 --lr_decay 1 --lr_decay_step 1 --criterion Seq2SeqLoss
12 |
--------------------------------------------------------------------------------
/commands/leaf/fedavg_sent140.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments for LEAF Sent140 dataset
4 | ## Samples 5% of the raw dataset, as stated in the original paper, since the total number of clients exceeds 200K...!
5 | python3 main.py \
6 | --exp_name FedAvg_LEAF_Sent140 --seed 42 --device cuda \
7 | --dataset Sent140 \
8 | --split_type pre --test_size 0.1 \
9 | --model_name NextCharLSTM --embedding_size 300 --hidden_size 128 --num_layers 2 \
10 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics acc1 \
11 | --R 5000 --E 5 --C 0.0001 --B 10 --beta1 0 \
12 | --optimizer SGD --lr 0.0003 --lr_decay 1 --lr_decay_step 1 --criterion BCEWithLogitsLoss
13 |
--------------------------------------------------------------------------------
/commands/leaf/fedavg_shakespeare.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments for LEAF Shakespeare dataset
4 | python3 main.py \
5 | --exp_name FedAvg_LEAF_Shakespeare --seed 42 --device cuda \
6 | --dataset Shakespeare \
7 | --split_type pre --test_size 0.1 \
8 | --model_name NextCharLSTM --num_embeddings 80 --embedding_size 8 --hidden_size 256 --num_layers 2 \
9 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 50 --eval_metrics acc1 acc5 \
10 | --R 5000 --E 5 --C 0.016 --B 10 --beta1 0 \
11 | --optimizer SGD --lr 0.0003 --lr_decay 1 --lr_decay_step 1 --criterion CrossEntropyLoss
12 |
--------------------------------------------------------------------------------
/commands/leaf/leaf.md:
--------------------------------------------------------------------------------
1 | # Replication Study - LEAF
2 | ---
3 |
--------------------------------------------------------------------------------
/commands/original_cifar10/cifar10.md:
--------------------------------------------------------------------------------
1 | # Replication Study - CIFAR10
2 | ---
3 | ## CIFAR10
4 | | IID `FedSGD` ($E=1$) || IID `FedAvg` ($E=5$)||
5 | |:------------:|:----------:|:----------:|:--------:|
6 | | $\eta$ | | $\eta$ | |
7 | | 0.45 | | 0.05 | |
8 | | 0.6 | | 0.15 | |
9 | | 0.7 | | 0.25 | |
10 | * Corresponds to Table 3 and Figure 4 from ([McMahan et al., 2016](https://arxiv.org/abs/1602.05629)). (Note that $B=50, C=0.1, K=100$)
--------------------------------------------------------------------------------
/commands/original_cifar10/fedavg_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments in Figure 4, 9 of (McMahan et al., 2016)
4 | ## IID split
5 | python3 main.py \
6 | --exp_name FedAvg_CIFAR10_CNN_IID --seed 42 --device cuda \
7 | --dataset CIFAR10 \
8 | --split_type iid --test_size -1 \
9 | --model_name SimpleCNN --crop 24 --randhf 0.5 --randjit 0.5 --imnorm --hidden_size 64 \
10 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 1 --eval_metrics acc1 acc5 \
11 | --K 100 --R 1000 --C 0.1 --E 5 --B 50 --beta1 0 \
12 | --optimizer SGD --lr 0.25 --lr_decay 0.99 --lr_decay_step 1 --criterion CrossEntropyLoss
13 |
14 | ## Pathological Non-IID split
15 | python3 main.py \
16 | --exp_name FedAvg_CIFAR10_CNN_Patho --seed 42 --device cuda \
17 | --dataset CIFAR10 \
18 | --split_type patho --test_size -1 \
19 | --model_name SimpleCNN --crop 24 --randhf 0.5 --randjit 0.5 --imnorm --hidden_size 64 \
20 | --algorithm fedavg --eval_type local --eval_every 1 --eval_metrics acc1 acc5 \
21 | --K 100 --R 1000 --C 0.1 --E 5 --B 50 --beta1 0 \
22 | --optimizer SGD --lr 0.1 --lr_decay 0.99 --lr_decay_step 1 --criterion CrossEntropyLoss
23 |
--------------------------------------------------------------------------------
/commands/original_cifar10/fedsgd_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedSGD experiments in Figure 4, 9 of (McMahan et al., 2016)
4 | ## IID split
5 | python3 main.py \
6 | --exp_name FedSGD_CIFAR10_CNN_IID --seed 42 --device cuda \
7 | --dataset CIFAR10 \
8 | --split_type iid --test_size -1 \
9 | --model_name SimpleCNN --crop 24 --randhf 0.5 --randjit 0.5 --imnorm --hidden_size 64 \
10 | --algorithm fedsgd --eval_fraction 1 --eval_type local --eval_every 1 --eval_metrics acc1 acc5 \
11 | --K 100 --R 10000 --C 0.1 --B 0 --beta1 0 \
12 | --optimizer SGD --lr 0.6 --lr_decay 0.9934 --lr_decay_step 1 --criterion CrossEntropyLoss
13 |
14 | ## Pathological Non-IID split
15 | python3 main.py \
16 | --exp_name FedSGD_CIFAR10_CNN_Patho --seed 42 --device cuda \
17 | --dataset CIFAR10 \
18 | --split_type patho --test_size -1 \
19 | --model_name SimpleCNN --crop 24 --randhf 0.5 --randjit 0.5 --imnorm --hidden_size 64 \
20 | --algorithm fedsgd --eval_type local --eval_every 1 --eval_metrics acc1 acc5 \
21 | --K 100 --R 10000 --C 0.1 --B 0 --beta1 0 \
22 | --optimizer SGD --lr 0.45 --lr_decay 0.9934 --lr_decay_step 1 --criterion CrossEntropyLoss
23 |
--------------------------------------------------------------------------------
/commands/original_mnist/fedavg_mnist_2nn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments in Table 1 of (McMahan et al., 2016)
4 | ## IID split
5 | for b in 0 10
6 | do
7 | for c in 0.0 0.1 0.2 0.5 1.0
8 | do
9 | python3 main.py \
10 | --exp_name "FedAvg_MNIST_2NN_IID_C${c}_B${b}" --seed 42 --device cuda \
11 | --dataset MNIST \
12 | --split_type iid --test_size 0 \
13 | --model_name TwoNN --resize 28 --hidden_size 200 \
14 | --algorithm fedavg --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \
15 | --K 100 --R 1000 --E 1 --C $c --B $b --beta1 0 \
16 | --optimizer SGD --lr 0.1 --lr_decay 0.99 --lr_decay_step 25 --criterion CrossEntropyLoss
17 | done
18 | done
19 |
20 | ## Pathological Non-IID split
21 | for b in 0 10
22 | do
23 | for c in 0.0 0.1 0.2 0.5 1.0
24 | do
25 | python3 main.py \
26 | --exp_name "FedAvg_MNIST_2NN_Patho_C${c}_B${b}" --seed 42 --device cuda \
27 | --dataset MNIST \
28 | --split_type patho --test_size 0 \
29 | --model_name TwoNN --resize 28 --hidden_size 200 \
30 | --algorithm fedavg --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \
31 | --K 100 --R 1000 --E 1 --C $c --B $b --beta1 0 \
32 | --optimizer SGD --lr 0.01 --lr_decay 0.999 --lr_decay_step 10 --criterion CrossEntropyLoss
33 | done
34 | done
35 |
--------------------------------------------------------------------------------
/commands/original_mnist/fedavg_mnist_cnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments in Table 1 of (McMahan et al., 2016)
4 | ## IID split
5 | for b in 0 10
6 | do
7 | for c in 0.0 0.1 0.2 0.5 1.0
8 | do
9 | python3 main.py \
10 | --exp_name "FedAvg_MNIST_CNN_IID_C${c}_B${b}" --seed 42 --device cuda \
11 | --dataset MNIST \
12 | --split_type iid --test_size 0 \
13 | --model_name TwoCNN --resize 28 --hidden_size 200 \
14 | --algorithm fedavg --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \
15 | --K 100 --R 1000 --E 5 --C $c --B $b --beta1 0 \
16 | --optimizer SGD --lr 0.215 --lr_decay 0.95 --lr_decay_step 25 --criterion CrossEntropyLoss
17 | done
18 | done
19 |
20 | ## Pathological Non-IID split
21 | for b in 0 10
22 | do
23 | for c in 0.0 0.1 0.2 0.5 1.0
24 | do
25 | python3 main.py \
26 | --exp_name "FedAvg_MNIST_CNN_Patho_C${c}_B${b}" --seed 42 --device cuda \
27 | --dataset MNIST \
28 | --split_type patho --test_size 0 \
29 | --model_name TwoCNN --resize 28 --hidden_size 200 \
30 | --algorithm fedavg --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \
31 | --K 100 --R 1000 --E 5 --C $c --B $b --beta1 0 \
32 | --optimizer SGD --lr 0.1 --lr_decay 0.99 --lr_decay_step 10 --criterion CrossEntropyLoss
33 | done
34 | done
35 |
--------------------------------------------------------------------------------
/commands/original_mnist/fedsgd_mnist_2nn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedSGD experiments in Table 1 of (McMahan et al., 2016)
4 | ## IID split
5 | for c in 0.0 0.1 0.2 0.5 1.0
6 | do
7 | python3 main.py \
8 | --exp_name "FedSGD_MNIST_2NN_IID_C${c}_B0" --seed 42 --device cuda \
9 | --dataset MNIST \
10 | --split_type iid --test_size 0 \
11 | --model_name TwoNN --resize 28 --hidden_size 200 \
12 | --algorithm fedsgd --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \
13 | --K 100 --R 1000 --C $c --B 0 --beta1 0 \
14 | --optimizer SGD --lr 1.0 --lr_decay 0.99 --lr_decay_step 25 --criterion CrossEntropyLoss
15 | done
16 |
17 | ## Pathological Non-IID split
18 | for c in 0.0 0.1 0.2 0.5 1.0
19 | do
20 | python3 main.py \
21 | --exp_name "FedSGD_MNIST_2NN_Patho_C${c}_B0" --seed 42 --device cuda \
22 | --dataset MNIST \
23 | --split_type patho --test_size 0 \
24 | --model_name TwoNN --resize 28 --hidden_size 200 \
25 | --algorithm fedsgd --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \
26 | --K 100 --R 1000 --C $c --B 0 --beta1 0 \
27 | --optimizer SGD --lr 0.1 --lr_decay 0.95 --lr_decay_step 10 --criterion CrossEntropyLoss
28 | done
29 |
--------------------------------------------------------------------------------
/commands/original_mnist/fedsgd_mnist_cnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedSGD experiments in Table 1 of (McMahan et al., 2016)
4 | ## IID split
5 | for c in 0.0 0.1 0.2 0.5 1.0
6 | do
7 | python3 main.py \
8 | --exp_name "FedSGD_MNIST_CNN_IID_C${c}_B0" --seed 42 --device cuda \
9 | --dataset MNIST \
10 | --split_type iid --test_size 0 \
11 | --model_name TwoCNN --resize 28 --hidden_size 200 \
12 | --algorithm fedsgd --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \
13 | --K 100 --R 1000 --C $c --B 0 --beta1 0 \
14 | --optimizer SGD --lr 0.5 --lr_decay 0.95 --lr_decay_step 25 --criterion CrossEntropyLoss
15 | done
16 |
17 |
18 | ## Pathological Non-IID split
19 | for c in 0.0 0.1 0.2 0.5 1.0
20 | do
21 | python3 main.py \
22 | --exp_name "FedSGD_MNIST_CNN_Patho_C${c}_B0" --seed 42 --device cuda \
23 | --dataset MNIST \
24 | --split_type patho --test_size 0 \
25 | --model_name TwoCNN --resize 28 --hidden_size 200 \
26 | --algorithm fedsgd --eval_fraction 1 --eval_type both --eval_every 1 --eval_metrics acc1 acc5 \
27 | --K 100 --R 1000 --C $c --B 0 --beta1 0 \
28 | --optimizer SGD --lr 0.25 --lr_decay 0.99 --lr_decay_step 10 --criterion CrossEntropyLoss
29 | done
30 |
--------------------------------------------------------------------------------
/commands/original_mnist/mnist.md:
--------------------------------------------------------------------------------
1 | # Replication Study - MNIST
2 | ---
3 | ## 2NN and CNN
4 | | $\text{2NN}$ | IID `FedSGD`($E=1$)|| Non-IID `FedSGD` ($E=1$)|| IID `FedAvg` ($E=5$)|| Non-IID `FedAvg` ($E=5$)||
5 | |:------------:|:----------:|:------:|:----------:|:------:|:----------:|:------:|:----------:|:------:|
6 | | $C$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ |
7 | | 0.0 | | | | | | | | |
8 | | 0.1 | | | | | | | | |
9 | | 0.2 | | | | | | | | |
10 | | 0.5 | | | | | | | | |
11 | | 1.0 | | | | | | | | |
12 |
13 | | $\text{CNN}$ | IID `FedSGD` ($E=1$)|| Non-IID `FedSGD` ($E=1$)|| IID `FedAvg` ($E=5$)|| Non-IID `FedAvg` ($E=5$)||
14 | |:------------:|:----------:|:-------:|:----------:|:-----------:|:----------:|:--------:|:----------:|:-----------:|
15 | | $C$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ | $B=\infty$ | $B=10$ |
16 | | 0.0 | | | | | | | | |
17 | | 0.1 | | | | | | | | |
18 | | 0.2 | | | | | | | | |
19 | | 0.5 | | | | | | | | |
20 | | 1.0 | | | | | | | | |
21 | * Corresponds to Table 1 from ([McMahan et al., 2016](https://arxiv.org/abs/1602.05629)). (Note that $K=100$)
22 |
--------------------------------------------------------------------------------
/commands/original_play_n_role/Shakespeare.md:
--------------------------------------------------------------------------------
1 | # Replication Study - Shakespeare
2 | ---
3 |
--------------------------------------------------------------------------------
/commands/original_play_n_role/fedavg_play_n_role.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedAvg experiments in Table 2, Figure 2, 3 of (McMahan et al., 2016)
4 | ## Note: this is equivalent to the Shakespeare dataset under the LEAF benchmark
5 | ## Role & Play Non-IID split
6 | python3 main.py \
7 | --exp_name FedAvg_Shakespeare_NextCharLSTM --seed 42 --device cuda \
8 | --dataset Shakespeare \
9 | --split_type pre --test_size 0.2 \
10 | --model_name NextCharLSTM --num_embeddings 80 --embedding_size 8 --hidden_size 256 --num_layers 2 \
11 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 100 --eval_metrics acc1 acc5 \
12 | --R 2000 --C 0.002 --E 1 --B 10 --beta1 0 \
13 | --optimizer SGD --lr 1.47 --lr_decay 1 --lr_decay_step 1 --criterion CrossEntropyLoss
14 |
--------------------------------------------------------------------------------
/commands/original_play_n_role/fedsgd_play_n_role.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # FedSGD experiments in Table 2, Figure 2, 3 of (McMahan et al., 2016)
4 | ## Note: this is equivalent to the Shakespeare dataset under the LEAF benchmark
5 | ## Role & Play Non-IID split
6 | python3 main.py \
7 | --exp_name FedSGD_Shakespeare_NextCharLSTM --seed 42 --device cuda \
8 | --dataset Shakespeare \
9 | --split_type pre --test_size 0.2 \
10 | --model_name NextCharLSTM --num_embeddings 80 --embedding_size 8 --hidden_size 256 --num_layers 2 \
11 | --algorithm fedsgd --eval_fraction 1 --eval_type local --eval_every 500 --eval_metrics acc1 acc5 \
12 | --R 5000 --C 0.002 --B 0 --beta1 0 \
13 | --optimizer SGD --lr 1.47 --lr_decay 1 --lr_decay_step 1 --criterion CrossEntropyLoss
14 |
--------------------------------------------------------------------------------
/requirments.txt:
--------------------------------------------------------------------------------
1 | tensorboard==2.12.1
2 | numpy==1.24.2
3 | pandas==1.5.3
4 | scikit-learn==1.2.2
5 | transformers==4.27.4
6 | tqdm==4.65.0
7 | einops==0.6.1
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import transformers
3 |
4 | # turn off unnecessary logging
5 | transformers.logging.set_verbosity_error()
6 |
7 | from .utils import set_seed, Range, TensorBoardRunner, check_args, init_weights, TqdmToLogger, MetricManager, stratified_split
8 | from .loaders import load_dataset, load_model
9 |
10 |
11 |
12 | # for logger initialization
13 | def set_logger(path, args):
14 | # initialize logger
15 | logger = logging.getLogger(__name__)
16 | logging_format = logging.Formatter(
17 | fmt='[%(levelname)s] (%(asctime)s) %(message)s',
18 | datefmt='%Y/%m/%d %I:%M:%S %p'
19 | )
20 | stream_handler = logging.StreamHandler()
21 | file_handler = logging.FileHandler(path)
22 |
23 | stream_handler.setFormatter(logging_format)
24 | file_handler.setFormatter(logging_format)
25 |
26 | logger.addHandler(stream_handler)
27 | logger.addHandler(file_handler)
28 | logger.setLevel(level=logging.INFO)
29 |
30 | # print welcome message
31 | logger.info('[WELCOME] Initialize...')
32 | welcome_message = """
33 | _______ _______ ______ _______ ______ _______ _______ _______ ______
34 | |______ |______ | \ |______ |_____/ |_____| | |______ | \\
35 | | |_______|_____/_|_______|__ _\_ |_ ___|_ __| _ |______ |_____/
36 | | |______ |_____| |_____/ | \ | | | \ | | ____
37 | |_____ |______ | | | \_ | \_| __|__ | \_| |_____|
38 |
39 | By. vaseline555 (Adam)
40 | """
41 | logger.info(welcome_message)
42 |
--------------------------------------------------------------------------------
/src/algorithm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vaseline555/Federated-Learning-in-PyTorch/6c07b19c6810c82bd9455bf7364808a568376bf4/src/algorithm/__init__.py
--------------------------------------------------------------------------------
/src/algorithm/basealgorithm.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 |
4 |
5 | class BaseOptimizer(metaclass=ABCMeta):
6 | """Federated optimization algorithm.
7 | """
8 | @abstractmethod
9 | def step(self, closure=None):
10 | raise NotImplementedError
11 |
12 | @abstractmethod
13 | def accumulate(self, **kwargs):
14 | raise NotImplementedError
15 |
16 |
--------------------------------------------------------------------------------
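
The two abstract methods above define the server-side contract used throughout `src/algorithm`: `accumulate(...)` folds each client's weighted delta into the parameters' `.grad` buffers, and `step()` turns the accumulated pseudo-gradient into a global update. A minimal conforming subclass, as a sketch only (it assumes the repository root is importable; the real implementations follow in this directory):

    import torch
    from src.algorithm.basealgorithm import BaseOptimizer

    class SketchOptimizer(BaseOptimizer, torch.optim.Optimizer):
        def __init__(self, params, lr=1.0):
            BaseOptimizer.__init__(self)
            torch.optim.Optimizer.__init__(self, params, defaults=dict(lr=lr))

        def step(self, closure=None):
            loss = closure() if closure is not None else None
            for group in self.param_groups:
                for param in group['params']:
                    if param.grad is None:
                        continue
                    # w <- w - lr * (accumulated weighted delta)
                    param.data.sub_(param.grad.data.mul(group['lr']))
            return loss

        def accumulate(self, mixing_coefficient, local_layers_iterator):
            for group in self.param_groups:
                for server_param, (_, local_param) in zip(group['params'], local_layers_iterator):
                    local_delta = (server_param - local_param).mul(mixing_coefficient).data
                    if server_param.grad is None:  # grad buffer accumulates weighted local updates
                        server_param.grad = local_delta
                    else:
                        server_param.grad.add_(local_delta)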
/src/algorithm/fedadagrad.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .basealgorithm import BaseOptimizer
4 |
5 |
6 |
7 | class FedadagradOptimizer(BaseOptimizer, torch.optim.Optimizer):
8 | def __init__(self, params, **kwargs):
9 | lr = kwargs.get('lr')
10 | v0 = kwargs.get('v0')
11 | tau = kwargs.get('tau')
12 | momentum = kwargs.get('beta')
13 | defaults = dict(lr=lr, momentum=momentum, v0=v0, tau=tau)
14 | BaseOptimizer.__init__(self); torch.optim.Optimizer.__init__(self, params=params, defaults=defaults)
15 |
16 | def step(self, closure=None):
17 | loss = None
18 | if closure is not None:
19 | loss = closure()
20 |
21 | for idx, group in enumerate(self.param_groups):
22 | beta = group['momentum']
23 | tau = group['tau']
24 | lr = group['lr']
25 | v0 = group['v0']
26 | for param in group['params']:
27 | if param.grad is None:
28 | continue
29 | # get (\Delta_t)
30 | delta = -param.grad.data
31 |
32 | if idx == 0: # idx == 0: parameters; optimize according to algorithm
33 | # calculate m_t
34 | if 'momentum_buffer1' not in self.state[param]:
35 | self.state[param]['momentum_buffer1'] = torch.zeros_like(param).detach()
36 | self.state[param]['momentum_buffer1'].mul_(beta).add_(delta.mul(1. - beta)) # \beta * m_t + (1 - \beta) * \Delta_t
37 | m_new = self.state[param]['momentum_buffer1']
38 |
39 |                     # calculate v_t (FedAdaGrad: v_t = v_{t-1} + \Delta_t^2, with v_{-1} = v0)
40 |                     if 'momentum_buffer2' not in self.state[param]:
41 |                         self.state[param]['momentum_buffer2'] = torch.full_like(param, v0)
42 |                     self.state[param]['momentum_buffer2'].add_(delta.pow(2)) # v_{t-1} + \Delta_t^2
43 |
44 | # update parameters
45 | param.data.add_((m_new.div(self.state[param]['momentum_buffer2'].pow(0.5).add(tau))).mul(lr))
46 | elif idx == 1: # idx == 1: buffers; just averaging
47 | param.data.add_(delta)
48 | return loss
49 |
50 | def accumulate(self, mixing_coefficient, local_layers_iterator, check_if=lambda name: 'num_batches_tracked' in name):
51 | for group in self.param_groups:
52 | for server_param, (name, local_signals) in zip(group['params'], local_layers_iterator):
53 | if check_if(name):
54 | server_param.data.zero_()
55 |                     server_param.grad = torch.zeros_like(server_param)
56 | continue
57 | local_delta = (server_param - local_signals).mul(mixing_coefficient).data.type(server_param.dtype)
58 | if server_param.grad is None: # NOTE: grad buffer is used to accumulate local updates!
59 | server_param.grad = local_delta
60 | else:
61 | server_param.grad.data.add_(local_delta)
62 |
--------------------------------------------------------------------------------
/src/algorithm/fedadam.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .basealgorithm import BaseOptimizer
4 |
5 |
6 |
7 | class FedadamOptimizer(BaseOptimizer, torch.optim.Optimizer):
8 | def __init__(self, params, **kwargs):
9 | lr = kwargs.get('lr')
10 | v0 = kwargs.get('v0')
11 | tau = kwargs.get('tau')
12 | momentum = kwargs.get('betas')
13 | defaults = dict(lr=lr, momentum=momentum, v0=v0, tau=tau)
14 | BaseOptimizer.__init__(self); torch.optim.Optimizer.__init__(self, params=params, defaults=defaults)
15 |
16 | def step(self, closure=None):
17 | loss = None
18 | if closure is not None:
19 | loss = closure()
20 |
21 | for idx, group in enumerate(self.param_groups):
22 | (beta1, beta2) = group['momentum']
23 | tau = group['tau']
24 | lr = group['lr']
25 | v0 = group['v0']
26 | for param in group['params']:
27 | if param.grad is None:
28 | continue
29 | # get (\Delta_t)
30 | delta = -param.grad.data
31 |
32 | if idx == 0: # idx == 0: parameters; optimize according to algorithm
33 | # calculate m_t
34 | if 'momentum_buffer1' not in self.state[param]:
35 | self.state[param]['momentum_buffer1'] = torch.zeros_like(param).detach()
36 | self.state[param]['momentum_buffer1'].mul_(beta1).add_(delta.mul(1. - beta1)) # \beta1 * m_t + (1 - \beta1) * \Delta_t
37 | m_new = self.state[param]['momentum_buffer1']
38 |
39 |                     # calculate v_t (FedAdam: v_t = \beta2 * v_{t-1} + (1 - \beta2) * \Delta_t^2, with v_{-1} = v0)
40 |                     if 'momentum_buffer2' not in self.state[param]:
41 |                         self.state[param]['momentum_buffer2'] = torch.full_like(param, v0)
42 |                     self.state[param]['momentum_buffer2'].mul_(beta2).add_(delta.pow(2).mul(1. - beta2)) # \beta2 * v_{t-1} + (1 - \beta2) * \Delta_t^2
43 |                     v_new = self.state[param]['momentum_buffer2']
44 |
45 | # update parameters
46 | param.data.add_(m_new.div(v_new.pow(0.5).add(tau)).mul(lr))
47 | elif idx == 1: # idx == 1: buffers; just averaging
48 | param.data.add_(delta)
49 | return loss
50 |
51 | def accumulate(self, mixing_coefficient, local_layers_iterator, check_if=lambda name: 'num_batches_tracked' in name):
52 | for group in self.param_groups:
53 | for server_param, (name, local_signals) in zip(group['params'], local_layers_iterator):
54 | if check_if(name):
55 | server_param.data.zero_()
56 |                     server_param.grad = torch.zeros_like(server_param)
57 | continue
58 | local_delta = (server_param - local_signals).mul(mixing_coefficient).data.type(server_param.dtype)
59 | if server_param.grad is None: # NOTE: grad buffer is used to accumulate local updates!
60 | server_param.grad = local_delta
61 | else:
62 | server_param.grad.data.add_(local_delta)
63 |
--------------------------------------------------------------------------------
/src/algorithm/fedavg.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .basealgorithm import BaseOptimizer
4 |
5 |
6 |
7 | class FedavgOptimizer(BaseOptimizer, torch.optim.Optimizer):
8 | def __init__(self, params, **kwargs):
9 | self.lr = kwargs.get('lr')
10 | self.momentum = kwargs.get('momentum', 0.)
11 | defaults = dict(lr=self.lr, momentum=self.momentum)
12 | BaseOptimizer.__init__(self); torch.optim.Optimizer.__init__(self, params=params, defaults=defaults)
13 |
14 | def step(self, closure=None):
15 | loss = None
16 | if closure is not None:
17 | loss = closure()
18 |
19 | for idx, group in enumerate(self.param_groups):
20 | beta = group['momentum']
21 | for param in group['params']:
22 | if param.grad is None:
23 | continue
24 | delta = param.grad.data
25 |
26 | if idx == 0: # idx == 0: parameters; optimize according to algorithm // idx == 1: buffers; just averaging
27 | if beta > 0.:
28 | if 'momentum_buffer' not in self.state[param]:
29 | self.state[param]['momentum_buffer'] = torch.zeros_like(param).detach()
30 | self.state[param]['momentum_buffer'].mul_(beta).add_(delta.mul(1. - beta)) # \beta * v + (1 - \beta) * grad
31 | delta = self.state[param]['momentum_buffer']
32 | param.data.sub_(delta)
33 | return loss
34 |
35 | def accumulate(self, mixing_coefficient, local_layers_iterator, check_if=lambda name: 'num_batches_tracked' in name):
36 | for group in self.param_groups:
37 | for server_param, (name, local_signals) in zip(group['params'], local_layers_iterator):
38 | if check_if(name):
39 | server_param.data.zero_()
40 |                     server_param.grad = torch.zeros_like(server_param)
41 | continue
42 | local_delta = (server_param - local_signals).mul(mixing_coefficient).data.type(server_param.dtype)
43 | if server_param.grad is None: # NOTE: grad buffer is used to accumulate local updates!
44 | server_param.grad = local_delta
45 | else:
46 | server_param.grad.data.add_(local_delta)
47 |
--------------------------------------------------------------------------------
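
Taken together, `accumulate` and `step` realize one communication round on the server: every sampled client's uploaded layers are folded into the `.grad` buffers with mixing coefficient n_k / n, and a single `step()` applies the weighted-average update. A sketch of that round (`global_model` and `clients` are placeholder names; the actual orchestration lives in `src/server/`):

    import random

    # hypothetical driver for one FedAvg round; buffers are omitted for brevity,
    # whereas the real server also handles them (see `check_if` in `accumulate`)
    server_optimizer = FedavgOptimizer(global_model.parameters(), lr=1.0, momentum=0.)

    selected = random.sample(clients, max(1, int(0.1 * len(clients))))  # C = 0.1
    total_samples = sum(len(client) for client in selected)

    server_optimizer.zero_grad()          # clear previously accumulated deltas
    for client in selected:
        client.download(global_model)     # broadcast current global weights
        client.update()                   # E local epochs on private data
        server_optimizer.accumulate(len(client) / total_samples, client.upload())
    server_optimizer.step()               # apply the aggregated update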
/src/algorithm/fedavgm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .fedavg import FedavgOptimizer
4 |
5 |
6 |
7 | class FedavgmOptimizer(FedavgOptimizer):
8 | def __init__(self, params, **kwargs):
9 | super(FedavgmOptimizer, self).__init__(params=params, **kwargs)
10 |
--------------------------------------------------------------------------------
/src/algorithm/fedprox.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .fedavg import FedavgOptimizer
4 |
5 |
6 |
7 | class FedproxOptimizer(FedavgOptimizer):
8 | def __init__(self, params, **kwargs):
9 | super(FedproxOptimizer, self).__init__(params=params, **kwargs)
10 |
--------------------------------------------------------------------------------
/src/algorithm/fedsgd.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .fedavg import FedavgOptimizer
4 |
5 |
6 |
7 | class FedsgdOptimizer(FedavgOptimizer):
8 | def __init__(self, params, **kwargs):
9 | super(FedsgdOptimizer, self).__init__(params=params, **kwargs)
10 |
11 | def step(self, closure=None):
12 | loss = None
13 | if closure is not None:
14 | loss = closure()
15 |
16 | for idx, group in enumerate(self.param_groups):
17 | if idx == 1: continue
18 | beta = group['momentum']
19 | for param in group['params']:
20 | if param.grad is None:
21 | continue
22 | delta = param.grad.mul(group['lr'])
23 | if beta > 0.:
24 | if 'momentum_buffer' not in self.state[param]:
25 | self.state[param]['momentum_buffer'] = torch.zeros_like(param).detach()
26 | self.state[param]['momentum_buffer'].mul_(beta).add_(delta.mul(1. - beta)) # \beta * v + (1 - \beta) * (lr * grad)
27 | delta = self.state[param]['momentum_buffer']
28 | param.data.sub_(delta)
29 | return loss
30 |
31 | def accumulate(self, mixing_coefficient, local_layers_iterator):
32 | for idx, group in enumerate(self.param_groups):
33 | if idx == 1: continue # idx == 0: parameters; idx == 1: buffers - ignore as these are not parameters
34 | for server_param, (_, local_param) in zip(group['params'], local_layers_iterator):
35 | local_delta = local_param.grad.mul(mixing_coefficient).data.type(server_param.dtype)
36 | if server_param.grad is None:
37 | server_param.grad = local_delta
38 | else:
39 | server_param.grad.data.add_(local_delta)
40 |
--------------------------------------------------------------------------------
/src/algorithm/fedyogi.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .basealgorithm import BaseOptimizer
4 |
5 |
6 |
7 | class FedyogiOptimizer(BaseOptimizer, torch.optim.Optimizer):
8 | def __init__(self, params, **kwargs):
9 | lr = kwargs.get('lr')
10 | v0 = kwargs.get('v0')
11 | tau = kwargs.get('tau')
12 | momentum = kwargs.get('betas')
13 | defaults = dict(lr=lr, momentum=momentum, v0=v0, tau=tau)
14 | BaseOptimizer.__init__(self); torch.optim.Optimizer.__init__(self, params=params, defaults=defaults)
15 |
16 | def step(self, closure=None):
17 | loss = None
18 | if closure is not None:
19 | loss = closure()
20 |
21 | for idx, group in enumerate(self.param_groups):
22 | (beta1, beta2) = group['momentum']
23 | tau = group['tau']
24 | lr = group['lr']
25 | v0 = group['v0']
26 | for param in group['params']:
27 | if param.grad is None:
28 | continue
29 | # get (\Delta_t)
30 | delta = -param.grad.data
31 |
32 | if idx == 0: # idx == 0: parameters; optimize according to algorithm
33 | # calculate m_t
34 | if 'momentum_buffer1' not in self.state[param]:
35 | self.state[param]['momentum_buffer1'] = torch.zeros_like(param).detach()
36 | self.state[param]['momentum_buffer1'].mul_(beta1).add_(delta.mul(1. - beta1)) # \beta1 * m_t + (1 - \beta1) * \Delta_t
37 | m_new = self.state[param]['momentum_buffer1']
38 |
39 |                     # calculate v_t (FedYogi: v_t = v_{t-1} - (1 - \beta2) * \Delta_t^2 * sgn(v_{t-1} - \Delta_t^2), with v_{-1} = v0)
40 |                     if 'momentum_buffer2' not in self.state[param]:
41 |                         self.state[param]['momentum_buffer2'] = torch.full_like(param, v0)
42 |                     v_curr = self.state[param]['momentum_buffer2']
43 |                     self.state[param]['momentum_buffer2'].sub_(delta.pow(2).mul(1. - beta2).mul(v_curr.sub(delta.pow(2)).sign())) # v_{t-1} - (1 - \beta2) * \Delta_t^2 * sgn(v_{t-1} - \Delta_t^2)
44 |                     v_new = self.state[param]['momentum_buffer2']
45 |
46 | # update parameters
47 | param.data.add_((m_new.div(v_new.pow(0.5).add(tau))).mul(lr))
48 | elif idx == 1: # idx == 1: buffers; just averaging
49 | param.data.add_(delta)
50 | return loss
51 |
52 | def accumulate(self, mixing_coefficient, local_layers_iterator, check_if=lambda name: 'num_batches_tracked' in name):
53 | for group in self.param_groups:
54 | for server_param, (name, local_signals) in zip(group['params'], local_layers_iterator):
55 | if check_if(name):
56 | server_param.data.zero_()
57 |                     server_param.grad = torch.zeros_like(server_param)
58 | continue
59 | local_delta = (server_param - local_signals).mul(mixing_coefficient).data.type(server_param.dtype)
60 | if server_param.grad is None: # NOTE: grad buffer is used to accumulate local updates!
61 | server_param.grad = local_delta
62 | else:
63 | server_param.grad.data.add_(local_delta)
64 |
--------------------------------------------------------------------------------
/src/algorithm/sent140.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | ## 2,460 clients, 2 classes
3 |
4 | for s in 42 1023 59999
5 | do
6 | for a in 0 0.1 1 10
7 | do
8 | python3 main.py \
9 | --exp_name "Sent140_FedAvg_Fixed_${a} (${s})" --seed $s --device cuda:0 \
10 | --dataset Sent140 --learner fixed --alpha $a \
11 | --split_type pre --rawsmpl 0.01 --test_size 0.2 \
12 | --model_name Sent140LSTM --embedding_size 300 --hidden_size 256 --num_layers 2 \
13 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 5000 --eval_metrics acc1 \
14 | --R 500 --C 0.0021 --E 5 --B 10 \
15 | --optimizer SGD --lr 0.3 --lr_decay 0.98 --lr_decay_step 1 --criterion BCEWithLogitsLoss &
16 |
17 | python3 main.py \
18 | --exp_name "Sent140_FedAvg_AdaHedge_${a} (${s})" --seed $s --device cuda:0 \
19 | --dataset Sent140 --learner ah --alpha $a \
20 | --split_type pre --rawsmpl 0.01 --test_size 0.2 \
21 | --model_name Sent140LSTM --embedding_size 300 --hidden_size 256 --num_layers 2 \
22 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 5000 --eval_metrics acc1 \
23 | --R 500 --C 0.0021 --E 5 --B 10 \
24 | --optimizer SGD --lr 0.3 --lr_decay 0.98 --lr_decay_step 1 --criterion BCEWithLogitsLoss &
25 |
26 | python3 main.py \
27 | --exp_name "Sent140_FedAvg_SoftBayes_${a} (${s})" --seed $s --device cuda:0 \
28 | --dataset Sent140 --learner sb --alpha $a \
29 | --split_type pre --rawsmpl 0.01 --test_size 0.2 \
30 | --model_name Sent140LSTM --embedding_size 300 --hidden_size 256 --num_layers 2 \
31 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 100 --eval_metrics acc1 \
32 | --R 500 --C 0.0021 --E 5 --B 10 \
33 | --optimizer SGD --lr 0.3 --lr_decay 0.98 --lr_decay_step 1 --criterion BCEWithLogitsLoss
34 | done
35 | wait
36 | done
37 |
38 | python3 main.py \
39 | --exp_name Sent140_FedAvg_Fixed --seed 42 --device cuda:2 \
40 | --dataset Sent140 --learner fixed \
41 | --split_type pre --rawsmpl 0.01 --test_size 0.2 \
42 | --model_name Sent140LSTM --embedding_size 300 --hidden_size 80 --num_layers 2 \
43 | --algorithm fedavg --eval_fraction 1 --eval_type local --eval_every 1000 --eval_metrics acc1 \
44 | --R 1000 --C 0.0021 --E 5 --B 10 \
45 | --optimizer SGD --lr 0.0003 --lr_decay 0.9995 --lr_decay_step 1 --criterion BCEWithLogitsLoss
46 |
47 |
--------------------------------------------------------------------------------
/src/client/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vaseline555/Federated-Learning-in-PyTorch/6c07b19c6810c82bd9455bf7364808a568376bf4/src/client/__init__.py
--------------------------------------------------------------------------------
/src/client/baseclient.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 |
4 |
5 | class BaseClient(metaclass=ABCMeta):
6 | """Class for client object having its own (private) data and resources to train a model.
7 | """
8 | def __init__(self, **kwargs):
9 | self.__identifier = None
10 | self.__model = None
11 |
12 | @property
13 | def id(self):
14 | return self.__identifier
15 |
16 | @id.setter
17 | def id(self, identifier):
18 | self.__identifier = identifier
19 |
20 | @property
21 | def model(self):
22 | return self.__model
23 |
24 | @model.setter
25 | def model(self, model):
26 | self.__model = model
27 |
28 | @abstractmethod
29 | def update(self):
30 | raise NotImplementedError
31 |
32 | @abstractmethod
33 | def evaluate(self):
34 | raise NotImplementedError
35 |
36 | @abstractmethod
37 | def download(self):
38 | raise NotImplementedError
39 |
40 | @abstractmethod
41 | def upload(self):
42 | raise NotImplementedError
43 |
44 | @abstractmethod
45 | def __len__(self):
46 | raise NotImplementedError
47 |
48 | @abstractmethod
49 | def __repr__(self):
50 | raise NotImplementedError
51 |
--------------------------------------------------------------------------------
/src/client/fedadagradclient.py:
--------------------------------------------------------------------------------
1 | from .fedavgclient import FedavgClient
2 |
3 |
4 |
5 | class FedadagradClient(FedavgClient):
6 | def __init__(self, **kwargs):
7 | super(FedadagradClient, self).__init__(**kwargs)
--------------------------------------------------------------------------------
/src/client/fedadamclient.py:
--------------------------------------------------------------------------------
1 | from .fedavgclient import FedavgClient
2 |
3 |
4 |
5 | class FedadamClient(FedavgClient):
6 | def __init__(self, **kwargs):
7 | super(FedadamClient, self).__init__(**kwargs)
--------------------------------------------------------------------------------
/src/client/fedavgclient.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import inspect
4 | import itertools
5 |
6 | from .baseclient import BaseClient
7 | from src import MetricManager
8 |
9 |
10 | class FedavgClient(BaseClient):
11 | def __init__(self, args, training_set, test_set):
12 | super(FedavgClient, self).__init__()
13 | self.args = args
14 | self.training_set = training_set
15 | self.test_set = test_set
16 |
17 | self.optim = torch.optim.__dict__[self.args.optimizer]
18 | self.criterion = torch.nn.__dict__[self.args.criterion]
19 |
20 | self.train_loader = self._create_dataloader(self.training_set, shuffle=not self.args.no_shuffle)
21 | self.test_loader = self._create_dataloader(self.test_set, shuffle=False)
22 |
23 | def _refine_optim_args(self, args):
24 | required_args = inspect.getfullargspec(self.optim)[0]
25 |
26 |         # collect entered arguments
27 | refined_args = {}
28 | for argument in required_args:
29 | if hasattr(args, argument):
30 | refined_args[argument] = getattr(args, argument)
31 | return refined_args
32 |
33 | def _create_dataloader(self, dataset, shuffle):
34 |         if self.args.B == 0: # B == 0 means full-batch training (e.g., for FedSGD)
35 | self.args.B = len(self.training_set)
36 | return torch.utils.data.DataLoader(dataset=dataset, batch_size=self.args.B, shuffle=shuffle)
37 |
38 | def update(self):
39 | mm = MetricManager(self.args.eval_metrics)
40 | self.model.train()
41 | self.model.to(self.args.device)
42 |
43 | optimizer = self.optim(self.model.parameters(), **self._refine_optim_args(self.args))
44 | for e in range(self.args.E):
45 | for inputs, targets in self.train_loader:
46 | inputs, targets = inputs.to(self.args.device), targets.to(self.args.device)
47 |
48 | outputs = self.model(inputs)
49 | loss = self.criterion()(outputs, targets)
50 |
51 | for param in self.model.parameters():
52 | param.grad = None
53 | loss.backward()
54 | if self.args.max_grad_norm > 0:
55 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
56 | optimizer.step()
57 |
58 | mm.track(loss.item(), outputs, targets)
59 | else:
60 | mm.aggregate(len(self.training_set), e + 1)
61 | else:
62 | self.model.to('cpu')
63 | return mm.results
64 |
65 | @torch.inference_mode()
66 | def evaluate(self):
67 | if self.args.train_only: # `args.test_size` == 0
68 | return {'loss': -1, 'metrics': {'none': -1}}
69 |
70 | mm = MetricManager(self.args.eval_metrics)
71 | self.model.eval()
72 | self.model.to(self.args.device)
73 |
74 | for inputs, targets in self.test_loader:
75 | inputs, targets = inputs.to(self.args.device), targets.to(self.args.device)
76 |
77 | outputs = self.model(inputs)
78 | loss = self.criterion()(outputs, targets)
79 |
80 | mm.track(loss.item(), outputs, targets)
81 | else:
82 | self.model.to('cpu')
83 | mm.aggregate(len(self.test_set))
84 | return mm.results
85 |
86 | def download(self, model):
87 | self.model = copy.deepcopy(model)
88 |
89 | def upload(self):
90 | return itertools.chain.from_iterable([self.model.named_parameters(), self.model.named_buffers()])
91 |
92 | def __len__(self):
93 | return len(self.training_set)
94 |
95 | def __repr__(self):
96 | return f'CLIENT < {self.id} >'
97 |
--------------------------------------------------------------------------------
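
The `_refine_optim_args` helper above is what lets any `torch.optim` optimizer be chosen by name from the command line: only those attributes of `args` whose names match the optimizer constructor's parameters are forwarded. A standalone illustration (the `Namespace` values are made up):

    import inspect
    import torch
    from argparse import Namespace

    args = Namespace(optimizer='SGD', lr=0.1, momentum=0.9, E=5, B=10)
    optim_cls = torch.optim.__dict__[args.optimizer]
    accepted = inspect.getfullargspec(optim_cls)[0]  # constructor parameter names
    refined = {name: getattr(args, name) for name in accepted if hasattr(args, name)}
    print(refined)  # {'lr': 0.1, 'momentum': 0.9} -- E and B are ignored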
/src/client/fedavgmclient.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | from .fedavgclient import FedavgClient
4 | from src import MetricManager
5 |
6 |
7 | class FedavgmClient(FedavgClient):
8 | def __init__(self, **kwargs):
9 | super(FedavgmClient, self).__init__(**kwargs)
--------------------------------------------------------------------------------
/src/client/fedproxclient.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 |
4 | from .fedavgclient import FedavgClient
5 | from src import MetricManager
6 |
7 |
8 | class FedproxClient(FedavgClient):
9 | def __init__(self, **kwargs):
10 | super(FedproxClient, self).__init__(**kwargs)
11 |
12 | def update(self):
13 | mm = MetricManager(self.args.eval_metrics)
14 | self.model.train()
15 | self.model.to(self.args.device)
16 |
17 | global_model = copy.deepcopy(self.model)
18 | for param in global_model.parameters():
19 | param.requires_grad = False
20 |
21 | optimizer = self.optim(self.model.parameters(), **self._refine_optim_args(self.args))
22 | for e in range(self.args.E):
23 | for inputs, targets in self.train_loader:
24 | inputs, targets = inputs.to(self.args.device), targets.to(self.args.device)
25 |
26 | outputs = self.model(inputs)
27 | loss = self.criterion()(outputs, targets)
28 |
29 |                 prox = 0. # proximal term: (\mu / 2) * ||w - w^t||^2 (Li et al., 2018)
30 |                 for name, param in self.model.named_parameters():
31 |                     prox += (param - global_model.get_parameter(name)).norm(2).pow(2)
32 | loss += self.args.mu * (0.5 * prox)
33 |
34 | for param in self.model.parameters():
35 | param.grad = None
36 | loss.backward()
37 | if self.args.max_grad_norm > 0:
38 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
39 | optimizer.step()
40 |
41 | mm.track(loss.item(), outputs, targets)
42 | else:
43 | mm.aggregate(len(self.training_set), e + 1)
44 | else:
45 | self.model.to('cpu')
46 | return mm.results
47 |
--------------------------------------------------------------------------------
/src/client/fedsgdclient.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .fedavgclient import FedavgClient
4 | from src import MetricManager
5 |
6 |
7 |
8 | class FedsgdClient(FedavgClient):
9 | def __init__(self, **kwargs):
10 | super(FedsgdClient, self).__init__(**kwargs)
11 |
12 | def update(self):
13 | mm = MetricManager(self.args.eval_metrics)
14 | self.model.train()
15 | self.model.to(self.args.device)
16 |
17 | for inputs, targets in self.train_loader:
18 | inputs, targets = inputs.to(self.args.device), targets.to(self.args.device)
19 |
20 | outputs = self.model(inputs)
21 | loss = self.criterion()(outputs, targets)
22 |
23 | for param in self.model.parameters():
24 | param.grad = None
25 | loss.backward()
26 | if self.args.max_grad_norm > 0:
27 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
28 | mm.track(loss.item(), outputs, targets)
29 | else:
30 | self.model.to('cpu')
31 | mm.aggregate(len(self.training_set), 1)
32 | return mm.results
33 |
34 | def upload(self):
35 | return self.model.named_parameters()
36 |
--------------------------------------------------------------------------------
/src/client/fedyogiclient.py:
--------------------------------------------------------------------------------
1 | from .fedavgclient import FedavgClient
2 |
3 |
4 |
5 | class FedyogiClient(FedavgClient):
6 | def __init__(self, **kwargs):
7 | super(FedyogiClient, self).__init__(**kwargs)
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .torchvisionparser import fetch_torchvision_dataset
2 | from .torchtextparser import fetch_torchtext_dataset
3 | from .leafparser import fetch_leaf
4 | from .tinyimagenet import fetch_tinyimagenet
5 | from .cinic10 import fetch_cinic10
6 | from .beerreviews import fetch_beerreviews
7 | from .heart import fetch_heart
8 | from .adult import fetch_adult
9 | from .cover import fetch_cover
10 | from .gleam import fetch_gleam
11 | from .speechcommands import fetch_speechcommands
--------------------------------------------------------------------------------
/src/datasets/adult.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchtext
5 | import pandas as pd
6 |
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import MinMaxScaler
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 |
14 | class Adult(torch.utils.data.Dataset):
15 | def __init__(self, education, inputs, targets, scaler):
16 | self.identifier = education
17 | self.inputs, self.targets = inputs, targets
18 | self.scaler = scaler
19 |
20 |     # NOTE: instance method; it relies on the fitted `self.scaler`
21 |     def inverse_transform(self, inputs):
22 |         return self.scaler.inverse_transform(inputs)
23 |
24 | def __len__(self):
25 | return len(self.inputs)
26 |
27 | def __getitem__(self, index):
28 | inputs, targets = torch.tensor(self.inputs[index]).float(), torch.tensor(self.targets[index]).long()
29 | return inputs, targets
30 |
31 | def __repr__(self):
32 | return self.identifier
33 |
34 | # helper method to fetch Adult dataset
35 | def fetch_adult(args, root, seed, test_size):
36 | URL = [
37 | 'http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data',
38 | 'http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test'
39 | ]
40 | MD5 = [
41 | '5d7c39d7b8804f071cdd1f2a7c460872',
42 | '35238206dfdf7f1fe215bbb874adecdc'
43 | ]
44 | COL_NAME = [
45 | 'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status',\
46 | 'occupation', 'relationship', 'race', 'sex', 'capital_gain', 'capital_loss',\
47 | 'house_per_week', 'native_country', 'targets'
48 | ]
49 | NUM_COL = ['age', 'fnlwgt', 'capital_gain', 'capital_loss', 'house_per_week', 'education_num']
50 |
51 | def _download(root):
52 | for idx, (url, md5) in enumerate(zip(URL, MD5)):
53 | _ = torchtext.utils.download_from_url(
54 | url=url,
55 | root=root,
56 | hash_value=md5,
57 | hash_type='md5'
58 | )
59 | os.rename(os.path.join(root, url.split('/')[-1]), os.path.join(root, f"adult_{'train' if idx == 0 else 'test'}.csv"))
60 |
61 | def _munge_and_create_clients(root):
62 |         # load data
63 | df = pd.read_csv(os.path.join(root, 'adult_train.csv'), header=None, names=COL_NAME, na_values='?').dropna().reset_index(drop=True)
64 | df = df.drop(columns=['education'])
65 |
66 | # encode categorical data
67 | for col in df.columns:
68 | if col not in NUM_COL:
69 | replace_map = {key: value for value, key in enumerate(sorted(df[col].unique()))}
70 | df[col] = df[col].replace(replace_map)
71 |
72 | # adjust dtype
73 | for col in df.columns:
74 | if col in NUM_COL:
75 | df[col] = df[col].astype('float')
76 | else:
77 | df[col] = df[col].astype('category')
78 |
79 | # get one-hot encoded dummy columns for categorical data
80 | df = pd.concat([pd.get_dummies(df.iloc[:, :-1], columns=[col for col in df.columns if col not in NUM_COL][:-1], drop_first=True, dtype=int), df[['targets']]], axis=1)
81 |
82 |         # create clients by education level
83 | clients = {}
84 | for edu in df['education_num'].unique():
85 | clients[edu] = df.loc[df['education_num'] == edu]
86 | return clients
87 |
88 | def _process_client_datasets(dataset, seed, test_size):
89 | # remove identifier column
90 | edu = int(dataset['education_num'].unique()[0])
91 | df = dataset.drop(columns=['education_num'])
92 | inputs, targets = df.iloc[:, :-1].values, df.iloc[:, -1].values
93 |
94 | # train-test split with stratified manner
95 | train_inputs, test_inputs, train_targets, test_targets = train_test_split(inputs, targets, test_size=test_size, random_state=seed, stratify=targets)
96 |
97 | # scaling inputs
98 | scaler = MinMaxScaler()
99 | train_inputs[:, :5] = scaler.fit_transform(train_inputs[:, :5])
100 | test_inputs[:, :5] = scaler.transform(test_inputs[:, :5])
101 | return (
102 | Adult(f'[ADULT] CLIENT < Edu{str(edu).zfill(2)} > (train)', train_inputs, train_targets, scaler),
103 | Adult(f'[ADULT] CLIENT < Edu{str(edu).zfill(2)} > (test)', test_inputs, test_targets, scaler)
104 | )
105 |
106 | logger.info(f'[LOAD] [ADULT] Check if raw data exists; if not, start downloading!')
107 | if not os.path.exists(os.path.join(root, 'adult')):
108 | _download(root=os.path.join(root, 'adult'))
109 | logger.info(f'[LOAD] [ADULT] ...raw data is successfully downloaded!')
110 | else:
111 | logger.info(f'[LOAD] [ADULT] ...raw data already exists!')
112 |
113 |     logger.info(f'[LOAD] [ADULT] Munging dataset and creating clients!')
114 | raw_clients = _munge_and_create_clients(os.path.join(root, 'adult'))
115 | logger.info('[LOAD] [ADULT] ...munged dataset and created clients!!')
116 |
117 |     logger.info(f'[LOAD] [ADULT] Processing client datasets!')
118 | client_datasets = []
119 | for dataset in raw_clients.values():
120 | client_datasets.append(_process_client_datasets(dataset, seed, test_size))
121 | logger.info('[LOAD] [ADULT] ...processed client datasets!')
122 |
123 | args.in_features = 84
124 | args.num_classes = 2
125 | args.K = 16
126 | return {}, client_datasets, args
127 |
--------------------------------------------------------------------------------
/src/datasets/beerreviews.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import logging
5 | import torchtext
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 |
11 | class BeerReviews(torch.utils.data.Dataset):
12 | URL = {
13 | 'look': 'http://people.csail.mit.edu/yujia/files/ls/beer/look.json',
14 | 'aroma': 'http://people.csail.mit.edu/yujia/files/ls/beer/aroma.json'
15 | }
16 | MD5 = {
17 | 'look': '4ad6dd806554ec50ad80b2acedda38d4',
18 | 'aroma': 'd5bc425fb075198a2d4792d33690a3fd'
19 | }
20 | ASPECT = ('look', 'aroma')
21 |
22 | def __init__(self, root, aspect, tokenizer=None, download=True):
23 | assert aspect in self.ASPECT, f'Unknown aspect {aspect}!'
24 |
25 | self.root = os.path.expanduser(root)
26 | self.aspect = aspect
27 |
28 | if download:
29 | self.download()
30 | if not self._check_exists():
31 | err = 'Dataset not found or corrupted. You can use download=True to download it'
32 | logger.exception(err)
33 | raise RuntimeError(err)
34 |
35 | # get the data and binary targets
36 | self.inputs, self.targets = self.load_json()
37 |
38 | # Set the maximum sequence length and dataset length
39 | self.max_seq_len = max([len(text) for text in self.inputs])
40 | self.length = len(self.targets)
41 |
42 | # get word embeddings from fasttext
43 | self.emb_bag = self.get_vocab('FastText')
44 |
45 | # set tokenizer
46 | self.tokenizer = tokenizer
47 |
48 | def _check_exists(self):
49 | return os.path.exists(os.path.join(self.root, 'beer'))
50 |
51 | def download(self):
52 | if self._check_exists():
53 | return
54 | _ = torchtext.utils.download_from_url(
55 | url=self.URL[self.aspect],
56 | root=os.path.join(self.root, 'beer'),
57 | hash_value=self.MD5[self.aspect],
58 | hash_type='md5'
59 | )
60 |
61 | def load_json(self):
62 | inputs, targets = [], []
63 | path = f'{self.root}/beer/{self.aspect}.json'
64 | with open(path, 'r') as f:
65 | for line in f:
66 | example = json.loads(line)
67 | targets.append(example['y'])
68 | inputs.append(example['text'])
69 | targets = torch.tensor(targets)
70 | return inputs, targets
71 |
72 | def get_vocab(self, name='FastText'):
73 | vocab = getattr(torchtext.vocab, name)()
74 |
75 | # Add the pad token
76 |         specials = ['<pad>']
77 | for token in specials:
78 | vocab.stoi[token] = len(vocab.itos)
79 | vocab.itos.append(token)
80 | vocab.vectors = torch.cat([vocab.vectors, torch.zeros(1, 300)], dim=0)
81 | return vocab
82 |
83 | def __len__(self):
84 | return self.length
85 |
86 | def __getitem__(self, index):
87 | text = self.inputs[index]
88 | if self.tokenizer is None:
89 |             padded_text = self.emb_bag.stoi['<pad>'] * torch.ones(self.max_seq_len)
90 | padded_text[:len(text)] = torch.tensor([
91 | self.emb_bag.stoi[token] if token in self.emb_bag.stoi
92 | else self.emb_bag.stoi['unk'] for token in text
93 | ])
94 | inputs = torch.nn.functional.embedding(padded_text.long(), self.emb_bag.vectors).detach()
95 | else:
96 | inputs = self.tokenizer(
97 | text,
98 | return_tensors='pt',
99 | is_split_into_words=True,
100 | max_length=self.max_seq_len,
101 | return_attention_mask=False,
102 | truncation=True,
103 | padding='max_length'
104 | )['input_ids']
105 | targets = self.targets[index]
106 | return inputs, targets
107 |
108 | def __repr__(self):
109 | return f'[BeerReviews ({self.aspect})] CLIENT'
110 |
111 | # helper method to fetch Beer Reviews dataset
112 | def fetch_beerreviews(args, root, aspect='look', tokenizer=None):
113 | logger.info(f'[LOAD] [BEERREVIEWS] Fetching dataset!')
114 |
115 | # create training dataset instance
116 | raw_train = BeerReviews(root, aspect, tokenizer)
117 |
118 | # create test dataset instance
119 | raw_test = None
120 |
121 | logger.info('[LOAD] [BEERREVIEWS] ...fetched dataset!')
122 |
123 | # adjust argument
124 | args.in_features = 300
125 | args.num_classes = 2
126 | if tokenizer is None: # use FastText embedding
127 |         args.num_embeddings = len(raw_train.emb_bag)
128 | args.embedding_size = raw_train.emb_bag.dim
129 | return raw_train, raw_test, args
130 |
--------------------------------------------------------------------------------
/src/datasets/cinic10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchvision
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 |
10 | # dataset wrapper module
11 | class CINIC10(torchvision.datasets.ImageFolder):
12 | base_folder = 'cinic-10-batches-py'
13 |     zip_md5 = '6ee4d0c996905fe93221de577967a372'
14 | splits = ('train', 'val', 'test')
15 | filename = 'CINIC-10.tar.gz'
16 | url = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz'
17 |
18 | def __init__(self, root, split='train', download=True, transform=None, **kwargs):
19 | self.data_root = os.path.expanduser(root)
20 | self.split = torchvision.datasets.utils.verify_str_arg(split, 'split', self.splits)
21 | if download:
22 | self.download()
23 | if not self._check_exists():
24 | err = 'Dataset not found or corrupted. You can use download=True to download it'
25 | logger.exception(err)
26 | raise RuntimeError(err)
27 | super().__init__(root=self.split_folder, transform=transform, **kwargs)
28 |
29 | @property
30 | def dataset_folder(self):
31 | return os.path.join(self.data_root, self.base_folder)
32 |
33 | @property
34 | def split_folder(self):
35 | return os.path.join(self.dataset_folder, self.split)
36 |
37 | def _check_exists(self):
38 | return os.path.exists(self.split_folder)
39 |
40 | def download(self):
41 | if self._check_exists():
42 | return
43 | torchvision.datasets.utils.download_and_extract_archive(
44 | self.url, self.dataset_folder, filename=self.filename,
45 | remove_finished=True, md5=self.zip_md5
46 | )
47 |
48 | def __repr__(self):
49 |         rep_str = {'train': 'CLIENT', 'val': 'SERVER', 'test': 'SERVER'} # cover every valid split
50 | return f'[CINIC10] {rep_str[self.split]}'
51 |
52 | # helper method to fetch CINIC-10 dataset
53 | def fetch_cinic10(args, root, transforms):
54 | logger.info('[LOAD] [CINIC10] Fetching dataset!')
55 |
56 | # default arguments
57 | DEFAULT_ARGS = {'root': root, 'transform': None, 'download': True}
58 |
59 | # configure arguments for training/test dataset
60 | train_args = DEFAULT_ARGS.copy()
61 | train_args['split'] = 'train'
62 | train_args['transform'] = transforms[0]
63 |
64 | # create training dataset instance
65 | raw_train = CINIC10(**train_args)
66 |
67 | # for global holdout set
68 | test_args = DEFAULT_ARGS.copy()
69 | test_args['transform'] = transforms[1]
70 | test_args['split'] = 'test'
71 |
72 | # create test dataset instance
73 | raw_test = CINIC10(**test_args)
74 |
75 | logger.info('[LOAD] [CINIC10] ...fetched dataset!')
76 |
77 | # adjust arguments
78 | args.in_channels = 3
79 | args.num_classes = len(torch.unique(torch.as_tensor(raw_train.targets)))
80 | return raw_train, raw_test, args
81 |
--------------------------------------------------------------------------------
/src/datasets/cover.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchvision
5 | import pandas as pd
6 |
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import StandardScaler
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 |
14 | class Cover(torch.utils.data.Dataset):
15 | def __init__(self, groupby, inputs, targets, scaler):
16 | self.identifier = groupby
17 | self.inputs, self.targets = inputs, targets
18 | self.scaler = scaler
19 |
20 |     def inverse_transform(self, inputs):
21 |         # undo the standardization fitted on this client's training split
22 |         return self.scaler.inverse_transform(inputs)
23 |
24 | def __len__(self):
25 | return len(self.inputs)
26 |
27 | def __getitem__(self, index):
28 | inputs, targets = torch.tensor(self.inputs[index]).float(), torch.tensor(self.targets[index]).long()
29 | return inputs, targets
30 |
31 | def __repr__(self):
32 | return self.identifier
33 |
34 | # helper method to fetch Cover type classification dataset
35 | # NOTE: data is grouped and split by `wilderness_area`
36 | def fetch_cover(args, root, seed, test_size):
37 | URL = 'http://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz'
38 | MD5 = '99670d8d942f09d459c7d4486fca8af5'
39 | COL_NAME = [
40 | 'elevation', 'aspect', 'slope',\
41 | 'horizontal_distance_to_hydrology', 'vertical_distance_to_hydrology', 'horizontal_distance_to_roadways',\
42 |         'hillshade_9am', 'hillshade_noon', 'hillshade_3pm',\
43 | 'horizontal_distance_to_fire_points',\
44 | 'wilderness_area',\
45 | 'soil_type_0', 'soil_type_1', 'soil_type_2', 'soil_type_3', 'soil_type_4',\
46 | 'soil_type_5', 'soil_type_6', 'soil_type_7', 'soil_type_8', 'soil_type_9',\
47 | 'soil_type_10', 'soil_type_11', 'soil_type_12', 'soil_type_13', 'soil_type_14',\
48 | 'soil_type_15', 'soil_type_16', 'soil_type_17', 'soil_type_18', 'soil_type_19',\
49 | 'soil_type_20', 'soil_type_21', 'soil_type_22', 'soil_type_23', 'soil_type_24',\
50 | 'soil_type_25', 'soil_type_26', 'soil_type_27', 'soil_type_28', 'soil_type_29',\
51 | 'soil_type_30', 'soil_type_31', 'soil_type_32', 'soil_type_33', 'soil_type_34',\
52 | 'soil_type_35', 'soil_type_36', 'soil_type_37', 'soil_type_38', 'soil_type_39',\
53 | 'cover_type'
54 | ]
55 | AREA = ['Rawah', 'Neota', 'Comanche Peak', 'Cache la Poudre']
56 |
57 | def _download(root):
58 | torchvision.datasets.utils.download_and_extract_archive(
59 | URL, root, filename= URL.split('/')[-1],
60 | remove_finished=True, md5=MD5
61 | )
62 | os.rename(os.path.join(root, 'covtype.data'), os.path.join(root, 'covtype.csv'))
63 |
64 | def _munge_and_split(root, seed, test_size):
65 | # load data
66 | df = pd.read_csv(os.path.join(root, 'covtype.csv'), header=None)
67 |
68 | # reverse one-hot encoded columns
69 | wilderness_area = pd.Series(df.iloc[:, 10:14].values.argmax(1))
70 |
71 | # concatenate into one dataframe
72 | df_raw = pd.concat([df.iloc[:, :10], wilderness_area, df.iloc[:, 14:-1], df.iloc[:, -1].sub(1)], axis=1)
73 |
74 | # rename column
75 | df_raw.columns = COL_NAME
76 |
77 | # split by wilderness area
78 | client_datasets = []
79 | for idx, name in enumerate(AREA):
80 | # get dataframe
81 | df_temp = df_raw[df_raw['wilderness_area'] == idx].reset_index(drop=True)
82 |
83 | # get inputs and targets
84 | inputs, targets = df_temp.iloc[:, :-1].values.astype('float'), df_temp.iloc[:, -1].values.astype('float')
85 |
86 | # train-test split with stratified manner
87 | train_inputs, test_inputs, train_targets, test_targets = train_test_split(inputs, targets, test_size=test_size, random_state=seed, stratify=targets)
88 |
89 | # scaling inputs
90 | scaler = StandardScaler()
91 | train_inputs[:, :-40] = scaler.fit_transform(train_inputs[:, :-40]) # exclude last 40 columns (`soil type` - one-hot encoded categorical variable)
92 | test_inputs[:, :-40] = scaler.transform(test_inputs[:, :-40])
93 |
94 | # assign as a client dataset
95 | client_datasets.append(
96 | (
97 | Cover(f'[COVER] CLIENT < {name} > (train)', train_inputs, train_targets, scaler),
98 | Cover(f'[COVER] CLIENT < {name} > (test)', test_inputs, test_targets, scaler)
99 | )
100 | )
101 | return client_datasets
102 |
103 | logger.info(f'[LOAD] [COVER] Check if raw data exists; if not, start downloading!')
104 | if not os.path.exists(os.path.join(root, 'covertype')):
105 | _download(root=os.path.join(root, 'covertype'))
106 | logger.info(f'[LOAD] [COVER] ...raw data is successfully downloaded!')
107 | else:
108 | logger.info(f'[LOAD] [COVER] ...raw data already exists!')
109 |
110 | logger.info(f'[LOAD] [COVER] Munging and splitting dataset!')
111 | client_datasets = _munge_and_split(os.path.join(root, 'covertype'), seed, test_size)
112 |     logger.info('[LOAD] [COVER] ...munged and split dataset!')
113 |
114 | args.in_features = 51
115 | args.num_classes = 7
116 | args.K = 4
117 | return {}, client_datasets, args
118 |
--------------------------------------------------------------------------------
/src/datasets/heart.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchtext
5 | import pandas as pd
6 |
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import StandardScaler
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 |
14 | class Heart(torch.utils.data.Dataset):
15 | def __init__(self, hospital, inputs, targets, scaler):
16 | self.identifier = hospital
17 | self.inputs, self.targets = inputs, targets
18 | self.scaler = scaler
19 |
20 |     def inverse_transform(self, inputs):
21 |         # undo the standardization fitted on this client's training split
22 |         return self.scaler.inverse_transform(inputs)
23 |
24 | def __len__(self):
25 | return len(self.inputs)
26 |
27 | def __getitem__(self, index):
28 | inputs, targets = torch.tensor(self.inputs[index]).float(), torch.tensor(self.targets[index]).long()
29 | return inputs, targets
30 |
31 | def __repr__(self):
32 | return self.identifier
33 |
34 | # helper method to fetch Heart disease classification dataset
35 | def fetch_heart(args, root, seed, test_size):
36 | URL = {
37 | 'cleveland': 'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data',
38 | 'hungarian': 'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.hungarian.data',
39 | 'switzerland': 'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.switzerland.data',
40 | 'va': 'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.va.data'
41 | }
42 | MD5 = {
43 | 'cleveland': '2d91a8ff69cfd9616aa47b59d6f843db',
44 | 'hungarian': '22e96bee155b5973568101c93b3705f6',
45 | 'switzerland': '9a87f7577310b3917730d06ba9349e20',
46 | 'va': '4249d03ca7711e84f4444768c9426170'
47 | }
48 | COL_NAME = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'targets']
49 |
50 | def _download(root):
51 | for hospital in URL.keys():
52 | _ = torchtext.utils.download_from_url(
53 | url=URL[hospital],
54 | root=root,
55 | hash_value=MD5[hospital],
56 | hash_type='md5'
57 | )
58 | os.rename(os.path.join(root, URL[hospital].split('/')[-1]), os.path.join(root, f'HEART ({hospital}).csv'))
59 |
60 | def _munge_and_split(root, hospital, seed, test_size):
61 | # load data
62 | to_drop = [
63 | 10, # the slope of the peak exercise ST segment
64 | 11, # number of major vessels (0-3) colored by flourosopy
65 | 12 # thalassemia background
66 | ]
67 | df = pd.read_csv(os.path.join(root, f'HEART ({hospital}).csv'), header=None, na_values='?', usecols=[i for i in range(14) if i not in to_drop]).apply(lambda x: x.fillna(x.mean()),axis=0).reset_index(drop=True)
68 |
69 | # rename column
70 | df.columns = COL_NAME
71 |
72 | # adjust dtypes
73 | df['targets'] = df['targets'].where(df['targets'] == 0, 1)
74 | df['age'] = df['age'].astype(float)
75 | df['sex'] = df['sex'].astype(int)
76 | df['cp'] = df['cp'].astype(int)
77 | df['trestbps'] = df['trestbps'].astype(float)
78 | df['chol'] = df['chol'].astype(float)
79 | df['restecg'] = df['restecg'].astype(int)
80 |         df['fbs'] = df['fbs'].astype(int)
81 | df['thalach'] = df['thalach'].astype(float)
82 | df['exang'] = df['exang'].astype(int)
83 | df['oldpeak'] = df['oldpeak'].astype(float)
84 |
85 | # get one-hot encoded dummy columns for categorical data
86 | df = pd.concat([pd.get_dummies(df.iloc[:, :-1], columns=['cp', 'restecg'], drop_first=True, dtype=int), df[['targets']]], axis=1)
87 |
88 | # get inputs and targets
89 | inputs, targets = df.iloc[:, :-1].values, df.iloc[:, -1].values
90 |
91 | # train-test split with stratified manner
92 | train_inputs, test_inputs, train_targets, test_targets = train_test_split(inputs, targets, test_size=test_size, random_state=seed, stratify=targets)
93 |
94 | # scaling inputs
95 | scaler = StandardScaler()
96 | train_inputs = scaler.fit_transform(train_inputs)
97 | test_inputs = scaler.transform(test_inputs)
98 | return (
99 | Heart(f'[HEART] CLIENT < {hospital} > (train)', train_inputs, train_targets, scaler),
100 | Heart(f'[HEART] CLIENT < {hospital} > (test)', test_inputs, test_targets, scaler)
101 | )
102 |
103 | logger.info(f'[LOAD] [HEART] Check if raw data exists; if not, start downloading!')
104 | if not os.path.exists(os.path.join(root, 'heart')):
105 | _download(root=os.path.join(root, 'heart'))
106 | logger.info(f'[LOAD] [HEART] ...raw data is successfully downloaded!')
107 | else:
108 | logger.info(f'[LOAD] [HEART] ...raw data already exists!')
109 |
110 | logger.info(f'[LOAD] [HEART] Munging and splitting dataset!')
111 | client_datasets = []
112 | for hospital in URL.keys():
113 | client_datasets.append(_munge_and_split(os.path.join(root, 'heart'), hospital, seed, test_size))
114 |     logger.info('[LOAD] [HEART] ...munged and split dataset!')
115 |
116 | args.in_features = 13
117 | args.num_classes = 2
118 | args.K = 4
119 | return {}, client_datasets, args
120 |
--------------------------------------------------------------------------------
/src/datasets/leaf/__init__.py:
--------------------------------------------------------------------------------
1 | from .postprocess import postprocess_leaf
2 | from .leaf_utils import *
--------------------------------------------------------------------------------
/src/datasets/leaf/leaf_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | import logging
4 | import requests
5 | import torchvision
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | __all__ = ['download_data']
10 |
11 |
12 |
13 | URL = {
14 | 'femnist': [
15 | 'https://s3.amazonaws.com/nist-srd/SD19/by_class.zip',
16 | 'https://s3.amazonaws.com/nist-srd/SD19/by_write.zip'
17 | ],
18 | 'shakespeare': ['http://www.gutenberg.org/files/100/old/1994-01-100.zip'],
19 | 'sent140': [
20 | 'http://cs.stanford.edu/people/alecmgo/trainingandtestdata.zip',
21 | 'http://nlp.stanford.edu/data/glove.6B.zip' # GloVe embedding for vocabularies
22 | ],
23 | 'celeba': [
24 | '1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS', # Google Drive link ID
25 | '0B7EVK8r0v71pblRyaVFSWGxPY0U', # Google Drive link ID
26 | 'https://cseweb.ucsd.edu/~weijian/static/datasets/celeba/img_align_celeba.zip' # img_align_celeba.zip
27 | ],
28 | 'reddit': ['1ISzp69JmaIJqBpQCX-JJ8-kVyUns8M7o'] # Google Drive link ID
29 | }
30 | OPT = { # md5 checksum if direct URL link is provided, file name if Google Drive link ID is provided
31 | 'femnist': ['79572b1694a8506f2b722c7be54130c4', 'a29f21babf83db0bb28a2f77b2b456cb'],
32 | 'shakespeare': ['b8d60664a90939fa7b5d9f4dd064a1d5'],
33 | 'sent140': ['1647eb110dd2492512e27b9a70d5d1bc', '056ea991adb4740ac6bf1b6d9b50408b'],
34 | 'celeba': ['identity_CelebA.txt', 'list_attr_celeba.txt', '00d2c5bc6d35e252742224ab0c1e8fcb'],
35 | 'reddit': ['reddit_subsampled.zip']
36 | }
37 |
38 | def download_data(download_root, dataset_name):
39 |     """Download data from a direct URL or Google Drive link and extract if it is archived.
40 | """
41 | def _get_confirm_token(response):
42 | for key, value in response.cookies.items():
43 | if key.startswith('download_warning'):
44 | return value
45 | return None
46 |
47 | def _save_response_content(download_root, response):
48 | CHUNK_SIZE = 32768
49 | with open(download_root, 'wb') as file:
50 | for chunk in response.iter_content(CHUNK_SIZE):
51 |                 if chunk: # filter out keep-alive new chunks
52 | file.write(chunk)
53 |
54 | def _download_file_from_google_drive(download_root, file_name, identifier):
55 | BASE_URL = 'https://docs.google.com/uc?export=download'
56 |
57 | session = requests.Session()
58 | response = session.get(BASE_URL, params={'id': identifier, 'confirm': 1}, stream=True)
59 | token = _get_confirm_token(response)
60 |
61 | if token:
62 | params = {'id': identifier, 'confirm': token }
63 | response = session.get(BASE_URL, params=params, stream=True)
64 | _save_response_content(os.path.join(download_root, file_name), response)
65 | print(f'...successfully downloaded file `{file_name}` at `{download_root}`!')
66 |
67 | if '.zip' in file_name:
68 | with zipfile.ZipFile(os.path.join(download_root, file_name), 'r', compression=zipfile.ZIP_STORED) as zip_file:
69 | zip_file.extractall(download_root)
70 | print(f'...successfully extracted `{file_name}` at `{download_root}`!')
71 |
72 | # download data from web
73 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Start downloading data...!')
74 | try:
75 | for (url, opt) in zip(URL[dataset_name], OPT[dataset_name]):
76 | if 'http' not in url:
77 | _download_file_from_google_drive(download_root, opt, url)
78 | else:
79 | torchvision.datasets.utils.download_and_extract_archive(
80 | url=url,
81 | download_root=download_root,
82 | md5=opt,
83 | remove_finished=True
84 | )
85 | else:
86 | logger.info(f'[LEAF - {dataset_name.upper()}] ...finished downloading data!')
87 | except:
88 | logger.exception(url)
89 | raise Exception(url)
90 |
--------------------------------------------------------------------------------
/src/datasets/leaf/postprocess/__init__.py:
--------------------------------------------------------------------------------
1 | from .postprocess import postprocess_leaf
--------------------------------------------------------------------------------
/src/datasets/leaf/postprocess/filter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 |
5 |
6 | def filter_clients(dataset_name, root, min_samples_per_clients):
7 | # set path
8 | data_dir = os.path.join(root, dataset_name)
9 | subdir = os.path.join(data_dir, 'sampled_data')
10 |
11 | # collect sampled files
12 | files = []
13 | if os.path.exists(subdir):
14 | files = os.listdir(subdir)
15 | if len(files) == 0:
16 | subdir = os.path.join(data_dir, 'all_data')
17 | files = os.listdir(subdir)
18 | files = [f for f in files if f.endswith('.json')]
19 |
20 | # calculate remaining clients data
21 | for f in files:
22 | users = []
23 | hierarchies = []
24 | num_samples = []
25 | user_data = {}
26 |
27 | file_dir = os.path.join(subdir, f)
28 | with open(file_dir, 'r') as file:
29 | data = json.load(file)
30 |
31 | num_users = len(data['users'])
32 | for i in range(num_users):
33 | curr_user = data['users'][i]
34 | curr_hierarchy = None
35 | if 'hierarchies' in data:
36 | curr_hierarchy = data['hierarchies'][i]
37 | curr_num_samples = data['num_samples'][i]
38 | if curr_num_samples >= min_samples_per_clients:
39 | user_data[curr_user] = data['user_data'][curr_user]
40 | users.append(curr_user)
41 | if curr_hierarchy is not None:
42 | hierarchies.append(curr_hierarchy)
43 | num_samples.append(data['num_samples'][i])
44 |
45 | # create json file
46 | all_data = {}
47 | all_data['users'] = users
48 | if len(hierarchies) == len(users):
49 | all_data['hierarchies'] = hierarchies
50 | all_data['num_samples'] = num_samples
51 | all_data['user_data'] = user_data
52 |
53 | # save file
54 | with open(os.path.join(data_dir, 'rem_clients_data', f'{f[:-5]}_keep_{str(min_samples_per_clients).zfill(4)}.json'), 'w') as outfile:
55 | json.dump(all_data, outfile)
56 |
--------------------------------------------------------------------------------
/src/datasets/leaf/postprocess/postprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 |
5 | from src.datasets.leaf.postprocess.sample import sample_clients
6 | from src.datasets.leaf.postprocess.filter import filter_clients
7 | from src.datasets.leaf.postprocess.split import split_datasets
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 |
13 | def postprocess_leaf(dataset_name, root, seed, raw_data_fraction, min_samples_per_clients, test_size):
14 | # check if raw data is prepared
15 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Check pre-processing data...!')
16 | if not os.path.exists(f'{root}/{dataset_name}/all_data'):
17 | err = f'[LOAD] [LEAF - {dataset_name.upper()}] Please check if the raw data is correctly prepared in `{root}`!'
18 | raise AssertionError(err)
19 |     logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...data pre-processing has been completed!')
20 |
21 | # create client datasets
22 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Sample clients from raw data...!')
23 | if not os.path.exists(f'{root}/{dataset_name}/sampled_data'):
24 | os.makedirs(f'{root}/{dataset_name}/sampled_data')
25 | sample_clients(dataset_name, root, seed, raw_data_fraction)
26 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done sampling clients from raw data!')
27 |
28 | # remove clients with less than given `min_samples_per_clients`
29 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Filter out remaining clients...!')
30 | if not os.path.exists(f'{root}/{dataset_name}/rem_clients_data') and (raw_data_fraction < 1.):
31 | os.makedirs(f'{root}/{dataset_name}/rem_clients_data')
32 | filter_clients(dataset_name, root, min_samples_per_clients)
33 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done filtering remaining clients!')
34 |
35 | # create train-test split
36 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Split into training & test sets...!')
37 | if (not os.path.exists(f'{root}/{dataset_name}/train')) or (not os.path.exists(f'{root}/{dataset_name}/test')):
38 | if not os.path.exists(f'{root}/{dataset_name}/train'):
39 | os.makedirs(f'{root}/{dataset_name}/train')
40 | if not os.path.exists(f'{root}/{dataset_name}/test'):
41 | os.makedirs(f'{root}/{dataset_name}/test')
42 | split_datasets(dataset_name, root, seed, test_size)
43 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done splitting into training & test sets!')
44 |
45 | # get number of clients
46 | train_data = [file for file in os.listdir(os.path.join(root, dataset_name, 'train')) if file.endswith('.json')][0]
47 | num_clients = len(json.load(open(f'{root}/{dataset_name}/train/{train_data}', 'r'))['users'])
48 | return num_clients
49 |
--------------------------------------------------------------------------------
/src/datasets/leaf/postprocess/sample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 |
5 | from collections import OrderedDict
6 |
7 |
8 |
9 | def sample_clients(dataset_name, root, seed, used_raw_data_fraction):
10 | # set path
11 | data_dir = os.path.join(root, dataset_name)
12 | subdir = os.path.join(data_dir, 'all_data')
13 | files = os.listdir(subdir)
14 | files = [f for f in files if f.endswith('.json')]
15 |
16 | # set seed
17 | rng = random.Random(seed)
18 |
19 | # split data in non-IID manner
20 | for f in files:
21 | file_dir = os.path.join(subdir, f)
22 | with open(file_dir, 'r') as file:
23 | data = json.load(file, object_pairs_hook=OrderedDict)
24 |
25 | # get meta data
26 | num_users = len(data['users'])
27 | tot_num_samples = sum(data['num_samples'])
28 | num_new_samples = int(used_raw_data_fraction * tot_num_samples)
29 | hierarchies = None
30 |
31 | # non-IID split
32 | ctot_num_samples = 0
33 | users = data['users']
34 | users_and_hiers = None
35 | if 'hierarchies' in data:
36 | users_and_hiers = list(zip(users, data['hierarchies']))
37 | rng.shuffle(users_and_hiers)
38 | else:
39 | rng.shuffle(users)
40 | user_i = 0
41 | num_samples = []
42 | user_data = {}
43 |
44 | if 'hierarchies' in data:
45 | hierarchies = []
46 |
47 | while(ctot_num_samples < num_new_samples):
48 | hierarchy = None
49 | if users_and_hiers is not None:
50 | user, hier = users_and_hiers[user_i]
51 | else:
52 | user = users[user_i]
53 | cdata = data['user_data'][user]
54 | cnum_samples = len(data['user_data'][user]['y'])
55 |
56 | if (ctot_num_samples + cnum_samples) > num_new_samples:
57 | cnum_samples = num_new_samples - ctot_num_samples
58 | indices = [i for i in range(cnum_samples)]
59 | new_indices = rng.sample(indices, cnum_samples)
60 | x, y = [], []
61 | for i in new_indices:
62 | x.append(data['user_data'][user]['x'][i])
63 | y.append(data['user_data'][user]['y'][i])
64 | cdata = {'x': x, 'y': y}
65 |
66 | if 'hierarchies' in data:
67 | hierarchies.append(hier)
68 |
69 | num_samples.append(cnum_samples)
70 | user_data[user] = cdata
71 | ctot_num_samples += cnum_samples
72 | user_i += 1
73 |
74 | if 'hierarchies' in data:
75 | users = [u for u, h in users_and_hiers][:user_i]
76 | else:
77 | users = users[:user_i]
78 |
79 | # create json file
80 | all_data = {}
81 | all_data['users'] = users
82 | if hierarchies is not None:
83 | all_data['hierarchies'] = hierarchies
84 | all_data['num_samples'] = num_samples
85 | all_data['user_data'] = user_data
86 |
87 | # save file
88 | with open(os.path.join(data_dir, 'sampled_data', f'{f[:-5]}_niid_0{str(used_raw_data_fraction)[2:]}.json'), 'w') as out_file:
89 | json.dump(all_data, out_file)
90 |
--------------------------------------------------------------------------------
/src/datasets/leaf/postprocess/split.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import random
5 | import logging
6 |
7 | from collections import OrderedDict
8 |
9 |
10 |
11 | def split_datasets(dataset_name, root, seed, test_size):
12 | # set path
13 | data_dir = os.path.join(root, dataset_name)
14 | subdir = os.path.join(data_dir, 'rem_clients_data')
15 |
16 | # collect sampled files
17 | files = []
18 | if os.path.exists(subdir):
19 | files = os.listdir(subdir)
20 | if len(files) == 0:
21 | subdir = os.path.join(data_dir, 'sampled_data')
22 | if os.path.exists(subdir):
23 | files = os.listdir(subdir)
24 | if len(files) == 0:
25 | subdir = os.path.join(data_dir, 'all_data')
26 | files = os.listdir(subdir)
27 | files = [f for f in files if f.endswith('.json')]
28 |
29 | # set seed
30 | rng = random.Random(seed)
31 |
32 | # check if data contains information on hierarchies
33 | file_dir = os.path.join(subdir, files[0])
34 | with open(file_dir, 'r') as file:
35 | data = json.load(file)
36 |
37 | # split training/test data inside clients
38 | for f in files:
39 | file_dir = os.path.join(subdir, f)
40 | with open(file_dir, 'r') as file:
41 | data = json.load(file, object_pairs_hook=OrderedDict)
42 |
43 | num_samples_train, num_samples_test = [], []
44 | user_data_train, user_data_test = {}, {}
45 | user_indices = [] # indices of users in `data['users']` that are not deleted
46 |
47 | for i, u in enumerate(data['users']):
48 | curr_num_samples = len(data['user_data'][u]['y'])
49 | if curr_num_samples >= 2:
50 | # ensures number of train and test samples both >= 1
51 | num_train_samples = max(1, int((1. - test_size) * curr_num_samples))
52 | if curr_num_samples == 2:
53 | num_train_samples = 1
54 | num_test_samples = curr_num_samples - num_train_samples
55 |
56 | indices = [j for j in range(curr_num_samples)]
57 | if dataset_name == 'shakespeare':
58 | train_indices = [i for i in range(num_train_samples)]
59 | test_indices = [i for i in range(num_train_samples + 80 - 1, curr_num_samples)]
60 | else:
61 | train_indices = rng.sample(indices, num_train_samples)
62 | test_indices = [i for i in range(curr_num_samples) if i not in train_indices]
63 |
64 | if len(train_indices) >= 1 and len(test_indices) >= 1:
65 | user_indices.append(i)
66 | num_samples_train.append(num_train_samples)
67 | num_samples_test.append(num_test_samples)
68 | user_data_train[u] = {'x': [], 'y': []}
69 | user_data_test[u] = {'x': [], 'y': []}
70 |
71 | train_blist = [False for _ in range(curr_num_samples)]
72 | test_blist = [False for _ in range(curr_num_samples)]
73 |
74 | for j in train_indices: train_blist[j] = True
75 |                 for j in test_indices: test_blist[j] = True
76 |
77 | for j in range(curr_num_samples):
78 | if train_blist[j]:
79 | user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
80 | user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
81 | elif test_blist[j]:
82 | user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
83 | user_data_test[u]['y'].append(data['user_data'][u]['y'][j])
84 | users = [data['users'][i] for i in user_indices]
85 |
86 | # create json file of training set
87 | all_data_train = {}
88 | all_data_train['users'] = users
89 | all_data_train['num_samples'] = num_samples_train
90 | all_data_train['user_data'] = user_data_train
91 |
92 | # save file of training set
93 | with open(os.path.join(data_dir, 'train', f'{f[:-5]}_train_0{str(1. - test_size)[2:]}.json'), 'w') as outfile:
94 | json.dump(all_data_train, outfile)
95 |
96 | # create json file of test set
97 | all_data_test = {}
98 | all_data_test['users'] = users
99 | all_data_test['num_samples'] = num_samples_test
100 | all_data_test['user_data'] = user_data_test
101 |
102 | # save file of test set
103 | with open(os.path.join(data_dir, 'test', f'{f[:-5]}_test_0{str(test_size)[2:]}.json'), 'w') as outfile:
104 | json.dump(all_data_test, outfile)
105 |
--------------------------------------------------------------------------------
/src/datasets/leaf/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vaseline555/Federated-Learning-in-PyTorch/6c07b19c6810c82bd9455bf7364808a568376bf4/src/datasets/leaf/preprocess/__init__.py
--------------------------------------------------------------------------------
/src/datasets/leaf/preprocess/celeba.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 |
9 | def preprocess(root):
10 | TARGET_NAME = 'Smiling'
11 |
12 | def _get_metadata(path):
13 | with open(os.path.join(path, 'raw', 'identity_CelebA.txt'), 'r') as f_identities:
14 | identities = f_identities.read().split('\n')
15 | with open(os.path.join(path, 'raw', 'list_attr_celeba.txt'), 'r') as f_attributes:
16 | attributes = f_attributes.read().split('\n')
17 | return identities, attributes
18 |
19 | def _get_celebrities_and_images(identities, min_samples=5):
20 | all_celebs = {}
21 | for line in identities:
22 | info = line.split()
23 | if len(info) < 2:
24 | continue
25 | image, celeb = info[0], info[1]
26 | if celeb not in all_celebs:
27 | all_celebs[celeb] = []
28 | all_celebs[celeb].append(image)
29 |
30 | # ignore all celebrities with less than `min_samples` images.
31 | good_celebs = {c: all_celebs[c] for c in all_celebs if len(all_celebs[c]) >= min_samples}
32 | return good_celebs
33 |
34 | def _get_celebrities_and_target(celebrities, attributes, attribute_name=TARGET_NAME):
35 | def _get_celebrities_by_image(identities):
36 | good_images = {}
37 | for c in identities:
38 | images = identities[c]
39 | for img in images:
40 | good_images[img] = c
41 | return good_images
42 |
43 | col_names = attributes[1]
44 | col_idx = col_names.split().index(attribute_name)
45 |
46 | celeb_attributes = {}
47 | good_images = _get_celebrities_by_image(celebrities)
48 |
49 | for line in attributes[2:]:
50 | info = line.split()
51 | if len(info) == 0:
52 | continue
53 | image = info[0]
54 | if image not in good_images:
55 | continue
56 | celeb = good_images[image]
57 | att = (int(info[1:][col_idx]) + 1) / 2
58 | if celeb not in celeb_attributes:
59 | celeb_attributes[celeb] = []
60 | celeb_attributes[celeb].append(att)
61 | return celeb_attributes
62 |
63 | def _convert_to_json(path, celebrities, targets):
64 | all_data = {}
65 |
66 | celeb_keys = [c for c in celebrities]
67 | num_samples = [len(celebrities[c]) for c in celeb_keys]
68 | data = {c: {'x': celebrities[c], 'y': targets[c]} for c in celebrities}
69 |
70 | all_data['users'] = celeb_keys
71 | all_data['num_samples'] = num_samples
72 | all_data['user_data'] = data
73 |
74 | with open(os.path.join(path, 'all_data', 'all_data.json'), 'w') as outfile:
75 | json.dump(all_data, outfile)
76 |
77 | # set path
78 | DATASET_NAME = __file__.split('/')[-1].split('.')[0]
79 | path = os.path.join(os.path.expanduser(root), DATASET_NAME)
80 |
81 |     # check if preprocessing has already been done
82 | if not os.path.exists(os.path.join(path, 'all_data')):
83 | os.makedirs(os.path.join(path, 'all_data'))
84 | else:
85 | return
86 |
87 | # get IDs and attributes
88 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Get meta-data...!')
89 | identities, attributes = _get_metadata(path)
90 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished parsing meta-data (IDs and attributes)!')
91 |
92 | # filter out celebs
93 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Construct celeb-image hashmap...!')
94 | celebrities = _get_celebrities_and_images(identities)
95 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...constructed celeb-image hashmap!')
96 |
97 | # filter out targets
98 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Construct inputs-targets hashmap...!')
99 | targets = _get_celebrities_and_target(celebrities, attributes)
100 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...constructed inputs-targets hashmap!')
101 |
102 | # convert to json format
103 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Convert data to json format...!')
104 | _convert_to_json(path, celebrities, targets)
105 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished converting data to json format!')
106 |
--------------------------------------------------------------------------------
/src/datasets/leaf/preprocess/femnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import logging
5 | import hashlib
6 |
7 | import numpy as np
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 |
13 | def preprocess(root):
14 | MAX_WRITERS = 100 # max number of writers per json file.
15 |
16 | def _save_obj(obj, name):
17 | with open(f'{name}.pkl', 'wb') as f:
18 | pickle.dump(obj, f)
19 |
20 | def _load_obj(name):
21 | with open(f'{name}.pkl', 'rb') as f:
22 | file = pickle.load(f)
23 | return file
24 |
25 | def _parse_files(path):
26 | def _parse_class_files(path):
27 | class_files = [] # (class, file directory)
28 | class_dir = os.path.join(path, 'raw', 'by_class')
29 | rel_class_dir = os.path.join(path, 'raw', 'by_class')
30 | classes = os.listdir(class_dir)
31 | classes = [c for c in classes if len(c) == 2]
32 |
33 | for cl in classes:
34 | cldir = os.path.join(class_dir, cl)
35 | rel_cldir = os.path.join(rel_class_dir, cl)
36 | subcls = os.listdir(cldir)
37 |
38 | subcls = [s for s in subcls if (('hsf' in s) and ('mit' not in s))]
39 |
40 | for subcl in subcls:
41 | subcldir = os.path.join(cldir, subcl)
42 | rel_subcldir = os.path.join(rel_cldir, subcl)
43 | images = os.listdir(subcldir)
44 | image_dirs = [os.path.join(rel_subcldir, i) for i in images]
45 | for image_dir in image_dirs:
46 | class_files.append((cl, image_dir))
47 | _save_obj(class_files, os.path.join(path, 'intermediate', 'class_file_dirs'))
48 |
49 | def _parse_write_files(path):
50 | write_files = [] # (writer, file directory)
51 | write_dir = os.path.join(path, 'raw', 'by_write')
52 | rel_write_dir = os.path.join(path, 'raw', 'by_write')
53 | write_parts = os.listdir(write_dir)
54 |
55 | for write_part in write_parts:
56 | writers_dir = os.path.join(write_dir, write_part)
57 | rel_writers_dir = os.path.join(rel_write_dir, write_part)
58 | writers = os.listdir(writers_dir)
59 |
60 | for writer in writers:
61 | writer_dir = os.path.join(writers_dir, writer)
62 | rel_writer_dir = os.path.join(rel_writers_dir, writer)
63 | wtypes = os.listdir(writer_dir)
64 |
65 | for wtype in wtypes:
66 | type_dir = os.path.join(writer_dir, wtype)
67 | rel_type_dir = os.path.join(rel_writer_dir, wtype)
68 | images = os.listdir(type_dir)
69 | image_dirs = [os.path.join(rel_type_dir, i) for i in images]
70 |
71 | for image_dir in image_dirs:
72 | write_files.append((writer, image_dir))
73 | _save_obj(write_files, os.path.join(path, 'intermediate', 'write_file_dirs'))
74 |
75 | # parse class & write files sequentially
76 | _parse_class_files(path)
77 | _parse_write_files(path)
78 |
79 | def _get_hashes(path):
80 | def _get_class_hashes(path):
81 | cfd = os.path.join(path, 'intermediate', 'class_file_dirs')
82 | class_file_dirs = _load_obj(cfd)
83 | class_file_hashes = []
84 |
85 | count = 0
86 | for cclass, cfile in class_file_dirs:
87 | chash = hashlib.md5(open(cfile, 'rb').read()).hexdigest()
88 | class_file_hashes.append((cclass, cfile, chash))
89 | count += 1
90 |
91 | cfhd = os.path.join(path, 'intermediate', 'class_file_hashes')
92 | _save_obj(class_file_hashes, cfhd)
93 |
94 | def _get_write_hashes(path):
95 | wfd = os.path.join(path, 'intermediate', 'write_file_dirs')
96 | write_file_dirs = _load_obj(wfd)
97 | write_file_hashes = []
98 |
99 | count = 0
100 | for cclass, cfile in write_file_dirs:
101 | chash = hashlib.md5(open(cfile, 'rb').read()).hexdigest()
102 | write_file_hashes.append((cclass, cfile, chash))
103 | count += 1
104 |
105 | wfhd = os.path.join(path, 'intermediate', 'write_file_hashes')
106 | _save_obj(write_file_hashes, wfhd)
107 |
108 | # get class & write hashes sequentially
109 | _get_class_hashes(path)
110 | _get_write_hashes(path)
111 |
112 | def _match_by_hashes(path):
113 | # read class file hash
114 | cfhd = os.path.join(path, 'intermediate', 'class_file_hashes')
115 | class_file_hashes = _load_obj(cfhd) # each element is (class, file dir, hash)
116 | class_hash_dict = {}
117 | for i in range(len(class_file_hashes)):
118 | c, f, h = class_file_hashes[len(class_file_hashes)-i-1]
119 | class_hash_dict[h] = (c, f)
120 |
121 | # read write file hash
122 | wfhd = os.path.join(path, 'intermediate', 'write_file_hashes')
123 | write_file_hashes = _load_obj(wfhd) # each element is (writer, file dir, hash)
124 |
125 | # match
126 | write_classes = []
127 | for tup in write_file_hashes:
128 | w, f, h = tup
129 | write_classes.append((w, f, class_hash_dict[h][0]))
130 | wwcd = os.path.join(path, 'intermediate', 'write_with_class')
131 | _save_obj(write_classes, wwcd)
132 |
133 | def _group_by_write(path):
134 | wwcd = os.path.join(path, 'intermediate', 'write_with_class')
135 | write_class = _load_obj(wwcd)
136 |
137 | writers, cimages = [], [] # each entry is a (writer, [list of (file, class)]) tuple
138 | cw, _, _ = write_class[0]
139 | for w, f, c in write_class:
140 | if w != cw:
141 | writers.append((cw, cimages))
142 | cw = w
143 | cimages = [(f, c)]
144 | cimages.append((f, c))
145 | writers.append((cw, cimages))
146 |
147 | ibwd = os.path.join(path, 'intermediate', 'images_by_write')
148 | _save_obj(writers, ibwd)
149 |
150 | def _convert_to_json(path):
151 | def _relabel_femnist_class(c):
152 | """
153 | Maps hexadecimal class value (string) to a decimal number.
154 |
155 | Args:
156 | c: class indices represented by hexadecimal values
157 |
158 | Returns:
159 | - 0 through 9 for classes representing respective numbers
160 | - 10 through 35 for classes representing respective uppercase letters
161 | - 36 through 61 for classes representing respective lowercase letters
162 | """
163 | if c.isdigit() and int(c) < 40: # digit
164 | return (int(c) - 30)
165 | elif int(c, 16) <= 90: # uppercase
166 | return (int(c, 16) - 55)
167 | else: # lowercase
168 | return (int(c, 16) - 61)
169 |
170 | by_writer_dir = os.path.join(path, 'intermediate', 'images_by_write')
171 | writers = _load_obj(by_writer_dir)
172 |
173 | users, num_samples, user_data = [], [], {}
174 | writer_count, all_writers = 0, 0
175 |
176 | # assign data
177 | for w, l in writers:
178 | users.append(w)
179 | num_samples.append(len(l))
180 | user_data[w] = {'x': [], 'y': []}
181 |
182 | for f, c in l:
183 | #gray = PIL.Image.open(f).convert('L')
184 | #vec = 1 - np.array(gray) / 255 # scale all pixel values to between 0 and 1
185 | #vec = vec.tolist()
186 |
187 | nc = _relabel_femnist_class(c)
188 | user_data[w]['x'].append(str(f))
189 | user_data[w]['y'].append(nc)
190 | writer_count += 1
191 | all_writers += 1
192 | else:
193 | all_data = {}
194 | all_data['users'] = users
195 | all_data['num_samples'] = num_samples
196 | all_data['user_data'] = user_data
197 |
198 | file_name = f'all_data.json'
199 | file_path = os.path.join(path, 'all_data', file_name)
200 |
201 | with open(file_path, 'w') as outfile:
202 | json.dump(all_data, outfile)
203 |
204 | # set path
205 | DATASET_NAME = __file__.split('/')[-1].split('.')[0]
206 | path = os.path.join(os.path.expanduser(root), DATASET_NAME)
207 |
208 |     # check if preprocessing has already been done
209 | if not os.path.exists(os.path.join(path, 'all_data')):
210 | os.makedirs(os.path.join(path, 'all_data'))
211 | else:
212 | return
213 |
214 | # create intermediate file directories
215 | if not os.path.exists(os.path.join(path, 'intermediate')):
216 | os.makedirs(os.path.join(path, 'intermediate'))
217 |
218 | # parse files
219 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Extract files of raw images...!')
220 | _parse_files(path)
221 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished extracting files of raw images!')
222 |
223 | # get file hashes
224 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Calculate image hashes...!')
225 | _get_hashes(path)
226 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished calculating image hashes!')
227 |
228 | # match by hashes
229 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Assign class labels to write images...!')
230 | _match_by_hashes(path)
231 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished assigning class labels to write images!')
232 |
233 | # group images by writer
234 |     logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Group images by writer...!')
235 | _group_by_write(path)
236 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished grouping images by writer!')
237 |
238 | # convert to json format
239 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Convert data to json format...!')
240 | _convert_to_json(path)
241 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished converting data to json format!')
242 |
--------------------------------------------------------------------------------
/src/datasets/leaf/preprocess/reddit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 |
5 | from collections import Counter, defaultdict
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 |
11 | def preprocess(root):
12 | def _make_path(path, dir_name):
13 | if not os.path.exists(os.path.join(path, dir_name)):
14 | os.makedirs(os.path.join(path, dir_name))
15 |
16 | def _load_data(path):
17 | with open(path, 'r') as file:
18 | data = json.load(file)
19 | return data
20 |
21 | def _save_data(path, data):
22 | with open(path, 'w') as file:
23 | json.dump(data, file)
24 |
25 | def _refine_data(data):
26 | num_samples = []
27 | for user in data['users']:
28 | num_samples.append(len(data['user_data'][user]['x'])) # get correct sample counts
29 | data['user_data'][user]['y'] = [original if (type(original) is list) else original['target_tokens'] for original in data['user_data'][user]['y']] # don't know why... but some samples are not parsed (i.e., in `dict` format, not `list`)
30 | else:
31 | data['num_samples'] = num_samples
32 | return data
33 |
34 | def _build_counter(train_data):
35 | all_tokens = []
36 | for u in train_data:
37 | for c in train_data[u]['x']:
38 | for s in c:
39 | all_tokens.extend(s)
40 | counter = Counter()
41 | counter.update(all_tokens)
42 | return counter
43 |
44 | def _build_vocab(counter):
45 | vocab_size = 10000
46 | pad_symbol, unk_symbol, bos_symbol, eos_symbol = 0, 1, 2, 3
47 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
48 | count_pairs = count_pairs[:vocab_size - 1]
49 |
50 | words, _ = list(zip(*count_pairs))
51 | words = list(words)
52 | vocab = {}
53 |         vocab['<pad>'] = pad_symbol
54 |         vocab['<unk>'] = unk_symbol
55 |         vocab['<bos>'] = bos_symbol
56 |         vocab['<eos>'] = eos_symbol
57 |
58 | idx = 4 # due to special tokens
59 | while len(words) > 0:
60 | w = words.pop()
61 |             if w in ['<pad>', '<unk>', '<bos>', '<eos>']:
62 | continue
63 | vocab[w] = idx
64 | idx += 1
65 | vocab = {'vocab': vocab, 'size': vocab_size, 'unk_symbol': unk_symbol, 'pad_symbol': pad_symbol, 'bos_symbol': bos_symbol, 'eos_symbol': eos_symbol}
66 | return vocab
67 |
68 | def _convert_to_ids_and_get_length(raw, vocab):
69 | def _tokens_to_word_ids(tokens, vocab):
70 | return [vocab[word] for word in tokens]
71 |
72 | def _convert_to_id(container, key):
73 | transformed = []
74 | for data in container[key]:
75 | for sent in data:
76 | idx = _tokens_to_word_ids(sent, vocab)
77 | transformed.append([idx])
78 | return transformed
79 |
80 | for user in raw['users']:
81 | raw['user_data'][user]['x'] = _convert_to_id(raw['user_data'][user], 'x')
82 | raw['user_data'][user]['y'] = _convert_to_id(raw['user_data'][user], 'y')
83 | return raw
84 |
85 | # set path
86 | DATASET_NAME = __file__.split('/')[-1].split('.')[0]
87 | path = os.path.join(os.path.expanduser(root), DATASET_NAME)
88 |
89 |     # check if preprocessing has already been done
90 | _make_path(path, 'all_data')
91 | _make_path(path, 'vocab')
92 | _make_path(path, 'intermediate')
93 | _make_path(path, 'train')
94 | _make_path(path, 'test')
95 |
96 | # adjust path since preprocessed json files are already prepared
97 | if os.path.exists(os.path.join(path, 'raw', 'new_small_data')):
98 | for file in os.listdir(os.path.join(path, 'raw', 'new_small_data')):
99 | if 'train' in file:
100 | os.replace(os.path.join(path, 'raw', 'new_small_data', file), os.path.join(path, 'intermediate', file))
101 | elif 'test' in file:
102 | os.replace(os.path.join(path, 'raw', 'new_small_data', file), os.path.join(path, 'intermediate', file))
103 | else:
104 | os.remove(os.path.join(path, 'raw', 'new_small_data', file)) # `val` data is not required...
105 |
106 | # edit `num_samples`: don't know why but it is not correct...
107 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Refine raw data...!')
108 | train_data = _load_data(os.path.join(path, 'intermediate', 'train_data.json'))
109 | test_data = _load_data(os.path.join(path, 'intermediate', 'test_data.json'))
110 |
111 | # correct number of samples and filter tokenized samples only
112 | train_data = _refine_data(train_data)
113 | test_data = _refine_data(test_data)
114 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished refining raw data!')
115 |
116 |     # aggregate data
117 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Combine raw data...!')
118 | _save_data(os.path.join(path, 'all_data', 'train_data_refined.json'), train_data)
119 | _save_data(os.path.join(path, 'all_data', 'test_data_refined.json'), test_data)
120 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished combining raw data!')
121 |
122 | # build vocabulary
123 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Build vocabulary...!')
124 | counter = _build_counter(train_data['user_data'])
125 | vocab_raw = _build_vocab(counter)
126 | _save_data(os.path.join(path, 'vocab', 'reddit_vocab.json'), vocab_raw)
127 |
128 | vocab = defaultdict(lambda: vocab_raw['unk_symbol'])
129 | vocab.update(vocab_raw['vocab'])
130 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...vocabulary is successfully created!')
131 |
132 | # convert tokens to index
133 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Convert tokens into indices using vocabulary...!')
134 | train_data = _convert_to_ids_and_get_length(train_data, vocab)
135 | test_data = _convert_to_ids_and_get_length(test_data, vocab)
136 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...all tokens are converted into indices!')
137 |
138 | # save processed data
139 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Split into training & test sets...!')
140 | _save_data(os.path.join(path, 'train', 'all_data_niid_00_train.json'), train_data)
141 | _save_data(os.path.join(path, 'test', 'all_data_niid_00_test.json'), test_data)
142 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...done splitting into training & test sets!')
143 |
--------------------------------------------------------------------------------
/src/datasets/leaf/preprocess/sent140.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import logging
5 |
6 | import pandas as pd
7 |
8 | from collections import defaultdict, OrderedDict
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | pd.set_option('mode.chained_assignment', None)
13 |
14 |
15 |
16 | def preprocess(root):
17 | RAW_TRAINING = 'training.1600000.processed.noemoticon.csv'
18 | RAW_TEST = 'testdata.manual.2009.06.14.csv'
19 | def _get_glove_vocab(path):
20 | # read GloVe embeddings (300-dim)
21 | lines = []
22 | with open(os.path.join(path, 'raw', 'glove.6B.300d.txt'), 'r') as file:
23 | lines = file.readlines()
24 |
25 | # process embeddings
26 | lines = [l.split() for l in lines]
27 | vocab = [l[0] for l in lines]
28 |
29 | # get word indices
30 | vocab_indices = defaultdict(int)
31 | for i, w in enumerate(vocab):
32 | vocab_indices[w] = i + 1
33 |
34 | # get index - embedding map
35 | embs = [[float(n) for n in l[1:]] for l in lines]
36 | embs.insert(0, [0. for _ in range(300)]) # for unknown token
37 |
38 | # save into file
39 | with open(os.path.join(path, 'vocab', 'glove.6B.300d.json'), 'w') as file:
40 | json.dump(embs, file)
41 | return vocab_indices
42 |
43 | def _combine_data(path):
44 | raw_train = pd.read_csv(
45 | os.path.join(path, RAW_TRAINING),
46 | encoding='ISO-8859-1',
47 | header=None,
48 | names=['target', 'id', 'date', 'flag', 'user', 'text'],
49 | usecols=['target', 'user', 'text'],
50 | index_col='user'
51 | )
52 | raw_test = pd.read_csv(
53 | os.path.join(path, RAW_TEST),
54 | encoding='ISO-8859-1',
55 | header=None,
56 | names=['target', 'id', 'date', 'flag', 'user', 'text'],
57 | usecols=['target', 'user', 'text'],
58 | index_col='user'
59 | )
60 | raw_all = pd.concat([raw_train, raw_test]).sort_index(kind='mergesort')
61 | return raw_all
62 |
63 | def _convert_to_json(path, raw_all, indices):
64 | def _cleanse(df):
65 | # dictionary containing all emojis with their meanings.
66 | emojis = {
67 | ':)': 'smile', ':-)': 'smile', ';d': 'wink', ':-E': 'vampire', ':(': 'sad',
68 | ':-(': 'sad', ':-<': 'sad', ':P': 'raspberry', ':O': 'surprised',
69 | ':-@': 'shocked', ':@': 'shocked',':-$': 'confused', ':\\': 'annoyed',
70 | ':#': 'mute', ':X': 'mute', ':^)': 'smile', ':-&': 'confused', '$_$': 'greedy',
71 | '@@': 'eyeroll', ':-!': 'confused', ':-D': 'smile', ':-0': 'yell', 'O.o': 'confused',
72 | '<(-_-)>': 'robot', 'd[-_-]b': 'dj', ":'-)": 'sadsmile', ';)': 'wink',
73 | ';-)': 'wink', 'O:-)': 'angel','O*-)': 'angel','(:-D': 'gossip', '=^.^=': 'cat'
74 | }
75 |
76 |             # replace mentions
77 |             df = df.apply(lambda x: re.sub(r"@[^\s]+", "USER", str(x).strip()))
78 | 
79 |             # replace links
80 |             df = df.apply(lambda x: re.sub(r"((http://)[^ ]*|(https://)[^ ]*|( www\.)[^ ]*)", "URL", str(x).strip()))
81 | 
82 |             # replace emojis (before stripping punctuation, which would destroy them)
83 |             for emoji, name in emojis.items():
84 |                 df = df.str.replace(emoji, name, regex=False)
85 | 
86 |             # remove ordinals (before bare numbers, which would otherwise strip their digits first)
87 |             df = df.apply(lambda x: re.sub(r"[0-9]+(?:st| st|nd| nd|rd| rd|th| th)", " ", str(x).strip()))
88 | 
89 |             # remove numbers
90 |             df = df.apply(lambda x: re.sub(r"[0-9]+", " ", str(x).strip()))
91 | 
92 |             # remove punctuation
93 |             df = df.str.replace(r'[^\w\s]', '', regex=True)
94 | 
95 |             # replace remaining non-alphanumeric characters with a space
96 |             df = df.apply(lambda x: re.sub(r"[^A-Za-z0-9]+", " ", str(x).strip()))
97 | 
98 |             # replace 3 or more consecutive occurrences of a letter with 2 occurrences
99 |             df = df.apply(lambda x: re.sub(r"(.)\1\1+", r"\1\1", str(x).strip()))
100 | 
101 |             # normalize whitespaces
102 |             df = df.str.replace(r'\s+', ' ', regex=True)
103 | 
104 |             # to lowercase
105 |             df = df.str.lower()
106 |             return df
107 |
108 | def _split_line(line):
109 | """Split given line/phrase into list of words
110 | """
111 | return re.findall(r"[\w']+|[.,!?;]", line)
112 |
113 | def _line_to_indices(line, word2id, max_words=25):
114 | """Converts given phrase into list of word indices.
115 |
116 | - If the phrase has more than `max_words` words,
117 | returns a list containing indices of the first `max_words` words.
118 |
119 |         - If the phrase has fewer than `max_words` words,
120 |           repeatedly appends the padding index (0) to the returned list
121 |           until the list's length is `max_words`.
122 |
123 | Args:
124 | line: string representing phrase/sequence of words
125 | word2id: dictionary with string words as keys and int indices as values
126 | max_words: maximum number of word indices in returned list
127 |
128 | Returns:
129 | indices: list of word indices, one index for each word in phrase
130 | """
131 | unk_id = 0
132 | line_list = _split_line(line) # split phrase in words
133 | indices = [word2id[w] for w in line_list[:max_words]]
134 | indices += [unk_id] * (max_words - len(indices))
135 | return indices
136 |
137 |         # map string user IDs to consecutive integer IDs
138 | user_id_map = {str_id: int_id for int_id, str_id in enumerate(raw_all.index.unique().tolist())}
139 | curr_ids = raw_all.index.to_series()
140 | raw_all.index = curr_ids.map(user_id_map)
141 |
142 | # refine raw data
143 |         raw_all['target'] = raw_all['target'].replace({4: 1, 2: 0}) # binarize labels: 4 (positive) -> 1, 2 (neutral) -> 0; avoids chained assignment
144 | raw_all.loc[:, 'text'] = _cleanse(raw_all.loc[:, 'text'])
145 | raw_all.loc[:, 'text'] = raw_all['text'].apply(lambda x: _line_to_indices(x, indices))
146 | raw_all = raw_all.reset_index().groupby('user').agg({'text': lambda x: [i for i in x], 'target': lambda y: [l for l in y]}).rename(columns={'text': 'x', 'target': 'y'})
147 | raw_all.index = raw_all.index.astype(str)
148 |
149 | # get required elements
150 | users = raw_all.index.tolist()
151 | num_samples = raw_all['y'].apply(len).values.tolist()
152 | user_data = raw_all.T.to_dict()
153 |
154 | # create json file
155 | all_data = OrderedDict()
156 | all_data['users'] = users
157 | all_data['num_samples'] = num_samples
158 | all_data['user_data'] = user_data
159 |
160 | # save file
161 | with open(os.path.join(path, 'all_data', 'all_data.json'), 'w') as outfile:
162 | json.dump(all_data, outfile)
163 |
164 | # set path
165 | DATASET_NAME = __file__.split('/')[-1].split('.')[0]
166 | path = os.path.join(os.path.expanduser(root), DATASET_NAME)
167 |
168 |     # check if preprocessing has already been done
169 | if not os.path.exists(os.path.join(path, 'all_data')):
170 | os.makedirs(os.path.join(path, 'all_data'))
171 | else:
172 | return
173 |
174 | # make path for GloVe vocabulary
175 | if not os.path.exists(os.path.join(path, 'vocab')):
176 | os.makedirs(os.path.join(path, 'vocab'))
177 |
178 | # get GloVe vocabulary
179 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Process GloVe embeddings (300 dim.)...!')
180 | glove_vocab_indices = _get_glove_vocab(path)
181 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished processing GloVe embeddings!')
182 |
183 | # combine raw data
184 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Combine raw data...!')
185 | raw_all = _combine_data(os.path.join(path, 'raw'))
186 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished combining raw data!')
187 |
188 | # convert to json format
189 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] Convert data to json format... (this may take several minutes)!')
190 | _convert_to_json(path, raw_all, glove_vocab_indices)
191 | logger.info(f'[LOAD] [LEAF - {DATASET_NAME.upper()}] ...finished converting data to json format!')
192 |
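A minimal, self-contained sketch of the padding/truncation contract that `_line_to_indices` documents above; the toy `word2id` mapping is an assumption, not data from the repository:

import re
from collections import defaultdict

def line_to_indices(line, word2id, max_words=25):
    words = re.findall(r"[\w']+|[.,!?;]", line)        # same tokenizer as `_split_line`
    indices = [word2id[w] for w in words[:max_words]]  # truncate to `max_words`
    return indices + [0] * (max_words - len(indices))  # right-pad with the unknown/padding index 0

word2id = defaultdict(int, {'hello': 1, 'world': 2})   # toy vocabulary; index 0 is unknown/padding
print(line_to_indices('hello brave new world', word2id, max_words=6))  # [1, 0, 0, 2, 0, 0]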
--------------------------------------------------------------------------------
/src/datasets/leafparser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import sys
4 | import json
5 | import torch
6 | import logging
7 | import importlib
8 | import concurrent.futures
9 |
10 | from abc import abstractmethod
11 |
12 | from src import TqdmToLogger
13 | from src.datasets.leaf import *
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 |
19 | class LEAFDataset(torch.utils.data.Dataset):
20 | """Base dataset class for LEAF benchmark dataset.
21 | """
22 | def __init__(self):
23 | super(LEAFDataset, self).__init__()
24 | self.identifier = None
25 | self.num_samples = 0
26 |
27 | @abstractmethod
28 | def make_dataset(self):
29 | raise NotImplementedError
30 |
31 | @abstractmethod
32 | def __getitem__(self, index):
33 | raise NotImplementedError
34 |
35 | def __len__(self):
36 | return self.num_samples
37 |
38 | def __repr__(self):
39 | return str(self.identifier)
40 |
41 | # LEAF - FEMNIST
42 | class FEMNIST(LEAFDataset):
43 | def __init__(self, in_channels, num_classes, transform=None):
44 | super(FEMNIST, self).__init__()
45 | self.in_channels = in_channels
46 | self.transform = transform
47 | self.num_classes = num_classes
48 |
49 | def _process(self, raw_path):
50 | inputs = PIL.Image.open(raw_path).convert('L')
51 | return inputs
52 |
53 | def make_dataset(self):
54 | inputs, targets = self.data['x'], self.data['y']
55 | self.inputs = [raw_path for raw_path in inputs]
56 | self.targets = torch.tensor(targets).long()
57 | self.num_samples = len(self.inputs)
58 |
59 | def __getitem__(self, index):
60 | inputs, targets = self._process(self.inputs[index]), self.targets[index]
61 | if self.transform is not None:
62 | inputs = self.transform(inputs)
63 | return inputs, targets
64 |
65 | # LEAF - Shakespeare
66 | class Shakespeare(LEAFDataset):
67 | def __init__(self, num_embeddings, num_classes):
68 | super(Shakespeare, self).__init__()
69 | self.num_embeddings = num_embeddings
70 | self.num_classes = num_classes
71 |
72 | def make_dataset(self):
73 | self.inputs, self.targets = torch.tensor(self.data['x']), torch.tensor(self.data['y'])
74 | self.num_samples = len(self.inputs)
75 |
76 | def __getitem__(self, index):
77 | return self.inputs[index], self.targets[index]
78 |
79 | # LEAF - Sent140
80 | class Sent140(LEAFDataset):
81 | def __init__(self, num_embeddings, seq_len, num_classes):
82 | super(Sent140, self).__init__()
83 | self.num_embeddings = num_embeddings
84 | self.seq_len = seq_len
85 | self.num_classes = num_classes
86 |
87 | def make_dataset(self):
88 | self.inputs, self.targets = torch.tensor(self.data['x']).long(), torch.tensor(self.data['y']).long()
89 | self.num_samples = len(self.inputs)
90 |
91 | def __getitem__(self, index):
92 | return self.inputs[index], self.targets[index]
93 |
94 | # LEAF - CelebA
95 | class CelebA(LEAFDataset):
96 | def __init__(self, in_channels, img_path, num_classes, transform=None):
97 | super(CelebA, self).__init__()
98 | self.in_channels = in_channels
99 | self.img_path = img_path
100 | self.num_classes = num_classes
101 | self.transform = transform
102 |
103 | def _process(self, path):
104 | inputs = PIL.Image.open(os.path.join(self.img_path, path)).convert('RGB')
105 | return inputs
106 |
107 | def make_dataset(self):
108 | inputs, targets = self.data['x'], self.data['y']
109 | self.inputs = [fname for fname in inputs]
110 | self.targets = torch.tensor(targets).long()
111 | self.num_samples = len(self.inputs)
112 |
113 | def __getitem__(self, index):
114 | inputs, targets = self._process(self.inputs[index]), self.targets[index]
115 | if self.transform is not None:
116 | inputs = self.transform(inputs)
117 | return inputs, targets
118 |
119 | # LEAF - Reddit
120 | class Reddit(LEAFDataset):
121 | def __init__(self, num_embeddings, seq_len, num_classes):
122 | super(Reddit, self).__init__()
123 | self.num_embeddings = num_embeddings
124 | self.seq_len = seq_len
125 | self.num_classes = num_classes
126 |
127 | def make_dataset(self):
128 | self.inputs, self.targets = torch.tensor(self.data['x']).squeeze(1), torch.tensor(self.data['y']).squeeze(1)
129 | self.num_samples = len(self.inputs)
130 |
131 | def __getitem__(self, index):
132 | return self.inputs[index], self.targets[index]
133 |
134 | def fetch_leaf(args, dataset_name, root, seed, raw_data_fraction, test_size, transforms):
135 | CONFIG = {
136 | 'femnist': {'in_channels': 1, 'num_classes': 62},
137 | 'shakespeare': {'num_embeddings': 80, 'num_classes': 80},
138 | 'sent140': {'num_embeddings': 400000 + 1 if args.model_name == 'Sent140LSTM' else None, 'seq_len': 25, 'num_classes': 2}, # using GloVe 300-dim embeddings
139 | 'celeba': {'in_channels': 3, 'img_path': f'{root}/celeba/raw/img_align_celeba', 'num_classes': 2},
140 | 'reddit': {'num_embeddings': 10000, 'seq_len': 10, 'num_classes': 10000} # + 1 for an unknown token
141 | }
142 |
143 | def _load_processed(path, mode):
144 | file = os.listdir(os.path.join(path, mode))[0]
145 | with open(os.path.join(path, mode, file), 'r') as f:
146 | proc = json.load(f)
147 | return proc
148 |
149 | def _assign_to_clients(dataset_name, dataset_class, raw_train, raw_test, transforms):
150 | def _construct_dataset(idx, user):
151 | # instantiate module for each training set and test set
152 | tr_dset, te_dset = dataset_class(**CONFIG[dataset_name]), dataset_class(**CONFIG[dataset_name])
153 |
154 | # set essential attributes for training
155 | tr_dset.identifier = f'[LOAD] [{dataset_name.upper()}] CLIENT < {str(user).zfill(8)} > (train)'
156 | tr_dset.data = raw_train['user_data'][user]
157 | tr_dset.make_dataset()
158 |
159 | # set essential attributes for test
160 | te_dset.identifier = f'[LOAD] [{dataset_name.upper()}] CLIENT < {str(user).zfill(8)} > (test)'
161 | te_dset.data = raw_test['user_data'][user]
162 | te_dset.make_dataset()
163 |
164 | # transplant transform method
165 | tr_dset.transform = transforms[0]
166 | te_dset.transform = transforms[1]
167 | return (tr_dset, te_dset)
168 |
169 |         futures = []
170 |         with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count() - 1) as workhorse:
171 |             for idx, user in TqdmToLogger(
172 |                 enumerate(raw_train['users']),
173 |                 logger=logger,
174 |                 desc=f'[LOAD] [LEAF - {dataset_name.upper()}] ...assigning... ',
175 |                 total=len(raw_train['users'])
176 |             ):
177 |                 futures.append(workhorse.submit(_construct_dataset, idx, user)) # submit without blocking so workers run concurrently...
178 |         return [future.result() for future in futures] # ...then gather results in submission order
179 |
180 | # retrieve appropriate dataset module
181 | dataset_class = getattr(sys.modules[__name__], dataset_name)
182 |
183 | # download data
184 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Check if raw data exists; if not, start downloading!')
185 | if not os.path.exists(f'{root}/{dataset_name.lower()}/raw'):
186 | os.makedirs(f'{root}/{dataset_name.lower()}/raw')
187 | download_data(download_root=f'{root}/{dataset_name.lower()}/raw', dataset_name=dataset_name.lower())
188 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...raw data is successfully downloaded!')
189 | else:
190 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...raw data already exists!')
191 |
192 | # pre-process raw data (fetch all raw data into json format)
193 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Pre-process raw data into json format!')
194 | importlib.import_module(f'.leaf.preprocess.{dataset_name.lower()}', package=__package__).preprocess(root=root)
195 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done pre-processing raw data into json format!')
196 |
197 | # post-process raw data (split data)
198 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Post-process raw data to be split into train & test!')
199 | args.num_clients = postprocess_leaf(dataset_name.lower(), root, seed, raw_data_fraction=raw_data_fraction, min_samples_per_clients=0, test_size=test_size)
200 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done post-processing raw data into train & test splits!')
201 |
202 | # get raw data
203 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Load training & test datasets...!')
204 | raw_train = _load_processed(os.path.join(root, dataset_name.lower()), 'train')
205 | raw_test = _load_processed(os.path.join(root, dataset_name.lower()), 'test')
206 |     logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...done parsing training & test datasets!')
207 |
208 | # make dataset for each client
209 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] Instantiate client datasets and create split hashmap...!')
210 | client_datasets = _assign_to_clients(dataset_name.lower(), dataset_class, raw_train, raw_test, transforms)
211 | logger.info(f'[LOAD] [LEAF - {dataset_name.upper()}] ...instantiated client datasets and created split hashmap!')
212 |
213 | # adjust arguments
214 | args.num_classes = CONFIG[dataset_name.lower()]['num_classes']
215 | args.K = len(client_datasets)
216 | if 'in_channels' in CONFIG[dataset_name.lower()].keys():
217 | args.in_channels = CONFIG[dataset_name.lower()]['in_channels']
218 | if 'seq_len' in CONFIG[dataset_name.lower()].keys():
219 | args.seq_len = CONFIG[dataset_name.lower()]['seq_len']
220 | if 'num_embeddings' in CONFIG[dataset_name.lower()].keys():
221 | args.num_embeddings = CONFIG[dataset_name.lower()]['num_embeddings']
222 |
223 |     # return the server-side holdout set (empty, as LEAF assigns all data to clients), client datasets, and adjusted arguments
224 |     return {}, client_datasets, args
225 |
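For reference, a hand-written miniature of the LEAF json layout that `_load_processed` returns and `_assign_to_clients` consumes; all values here are made up:

raw_train = {
    'users': ['00000000', '00000001'],    # client identifiers
    'num_samples': [2, 1],                # per-client sample counts
    'user_data': {                        # per-client inputs `x` and labels `y`
        '00000000': {'x': [[1, 2, 3], [4, 5, 6]], 'y': [0, 1]},
        '00000001': {'x': [[7, 8, 9]], 'y': [1]},
    },
}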
--------------------------------------------------------------------------------
/src/datasets/speechcommands.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import torchaudio
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | LABELS = sorted([
10 | 'on', 'learn', 'tree', 'down', 'forward',
11 | 'backward', 'happy', 'off', 'nine', 'eight',
12 | 'left', 'four', 'one', 'visual', 'sheila',
13 | 'no', 'six', 'dog', 'up', 'five',
14 | 'marvin', 'cat', 'yes', 'zero', 'house',
15 | 'bird', 'go', 'seven', 'stop', 'wow',
16 | 'three', 'follow', 'right', 'bed', 'two'
17 | ])
18 |
19 | # dataset wrapper module
20 | class AudioClassificationDataset(torch.utils.data.Dataset):
21 | def __init__(self, dataset, dataset_name, suffix):
22 | self.dataset = dataset
23 | self.dataset_name = dataset_name
24 | self.suffix = suffix
25 | self.targets = self.dataset.targets
26 |
27 | def __getitem__(self, index):
28 | def _label_to_index(word):
29 | # Return the position of the word in labels
30 | return torch.tensor(LABELS.index(word))
31 |
32 | def _pad_sequence(batch, max_len=16000):
33 | # Make all tensor in a batch the same length by padding with zeros
34 | batch = [torch.nn.functional.pad(item.t(), (0, 0, 0, max_len - len(item.t())), value=0.) for item in batch]
35 | batch = torch.cat(batch)
36 | return batch.t()
37 |
38 | # get raw batch by index
39 | batch = self.dataset[index]
40 |
41 | # gather in lists, and encode labels as indices
42 | inputs, targets = [], []
43 | for waveform, _, label, *_ in (batch,):
44 | inputs += [waveform]
45 | targets += [_label_to_index(label)]
46 |
47 | # group the list of tensors into a batched tensor
48 | inputs = _pad_sequence(inputs)
49 | targets = torch.stack(targets).squeeze()
50 | return inputs, targets
51 |
52 | def __len__(self):
53 | return len(self.dataset)
54 |
55 | def __repr__(self):
56 | return f'[{self.dataset_name}] {self.suffix}'
57 |
58 | class SpeechCommands(torchaudio.datasets.SPEECHCOMMANDS):
59 |     def __init__(self, root, split, download):
60 |         self.data_root, self.split = os.path.expanduser(root), split # keep `split` around for `__repr__`
61 |         super(SpeechCommands, self).__init__(root=self.data_root, subset=split, download=download)
62 | 
63 |     def __repr__(self):
64 |         rep_str = {'training': 'CLIENT', 'testing': 'SERVER'} # subset names expected by torchaudio
65 |         return f'[SpeechCommands] {rep_str[self.split]}'
66 |
67 | # helper method to fetch SpeechCommands dataset
68 | def fetch_speechcommands(args, root):
69 | logger.info('[LOAD] [SpeechCommands] Fetching dataset!')
70 |
71 | # default arguments
72 | DEFAULT_ARGS = {'root': root, 'download': True}
73 |
74 | # configure arguments for training/test dataset
75 | train_args = DEFAULT_ARGS.copy()
76 | train_args['split'] = 'training'
77 |
78 | # create training dataset instance
79 | raw_train = SpeechCommands(**train_args)
80 | train_targets = torch.tensor([LABELS.index(filename.split('/')[3]) for filename in raw_train._walker]).long()
81 | setattr(raw_train, 'targets', train_targets)
82 | raw_train = AudioClassificationDataset(raw_train, 'SpeechCommands'.upper(), 'CLIENT')
83 |
84 | # for global holdout set
85 | test_args = DEFAULT_ARGS.copy()
86 | test_args['split'] = 'testing'
87 |
88 | # create test dataset instance
89 | raw_test = SpeechCommands(**test_args)
90 | test_targets = torch.tensor([LABELS.index(filename.split('/')[3]) for filename in raw_test._walker]).long()
91 | setattr(raw_test, 'targets', test_targets)
92 | raw_test = AudioClassificationDataset(raw_test, 'SpeechCommands'.upper(), 'SERVER')
93 | logger.info('[LOAD] [SpeechCommands] ...fetched dataset!')
94 |
95 | # adjust arguments
96 | args.in_channels = 1
97 | args.num_classes = len(torch.unique(torch.as_tensor(raw_train.dataset.targets)))
98 | return raw_train, raw_test, args
99 |
100 |
101 |
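A usage sketch under assumed arguments; `args` here is a bare stand-in namespace (the real one comes from the project's argument parser), and the first call downloads the SpeechCommands archive:

from types import SimpleNamespace

args = SimpleNamespace()                                  # hypothetical; normally built by the argument parser
train_set, test_set, args = fetch_speechcommands(args, root='~/data')
print(args.num_classes)                                   # 35 keyword classes
waveform, label = train_set[0]                            # waveform is zero-padded to 16000 samples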
--------------------------------------------------------------------------------
/src/datasets/tinyimagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 | import logging
5 | import torchvision
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 |
11 | # dataset wrapper module
12 | class TinyImageNet(torchvision.datasets.ImageFolder):
13 | base_folder = 'tiny-imagenet-200'
14 | zip_md5 = '90528d7ca1a48142e341f4ef8d21d0de'
15 | splits = ('train', 'val', 'test')
16 | filename = 'tiny-imagenet-200.zip'
17 | url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
18 |
19 | def __init__(self, root, split='train', download=True, transform=None, **kwargs):
20 | self.data_root = os.path.expanduser(root)
21 | self.split = torchvision.datasets.utils.verify_str_arg(split, 'split', self.splits)
22 | if download:
23 | self.download()
24 | if not self._check_exists():
25 | err = 'Dataset not found or corrupted. You can use download=True to download it'
26 | logger.exception(err)
27 | raise RuntimeError(err)
28 | super().__init__(root=self.split_folder, transform=transform, **kwargs)
29 |
30 | def normalize_tin_val_folder_structure(self, path, images_folder='images', annotations_file='val_annotations.txt'):
31 | images_folder = os.path.join(path, images_folder)
32 | annotations_file = os.path.join(path, annotations_file)
33 |
34 |         # if neither the images folder nor the annotations file exists, the split has already been normalized
35 |         if not os.path.exists(images_folder) and not os.path.exists(annotations_file):
36 |             if not os.listdir(path):
37 |                 err = 'Validation folder is empty!'
38 |                 logger.exception(err)
39 |                 raise RuntimeError(err)
40 |             return
41 | # parse the annotations
42 | with open(annotations_file) as f:
43 | for line in f:
44 | values = line.split()
45 | img, label = values[:2]
46 |
47 | img_file = os.path.join(images_folder, img)
48 | label_folder = os.path.join(path, label)
49 |
50 | os.makedirs(label_folder, exist_ok=True)
51 |                 try:
52 |                     shutil.move(img_file, os.path.join(label_folder, img)) # move each image into its class (label) folder
53 |                 except FileNotFoundError:
54 |                     continue
55 |         if os.listdir(images_folder): # every image should have been moved out by now
56 |             raise AssertionError('Validation images folder is not empty!')
57 | shutil.rmtree(images_folder)
58 | os.remove(annotations_file)
59 |
60 | @property
61 | def dataset_folder(self):
62 | return os.path.join(self.data_root, self.base_folder)
63 |
64 | @property
65 | def split_folder(self):
66 | return os.path.join(self.dataset_folder, self.split)
67 |
68 | def _check_exists(self):
69 | return os.path.exists(self.split_folder)
70 |
71 | def download(self):
72 | if self._check_exists():
73 | return
74 | torchvision.datasets.utils.download_and_extract_archive(
75 | self.url, self.data_root, filename=self.filename,
76 | remove_finished=True, md5=self.zip_md5
77 | )
78 | assert 'val' in self.splits
79 | self.normalize_tin_val_folder_structure(os.path.join(self.dataset_folder, 'val'))
80 |
81 | def __repr__(self):
82 | rep_str = {'train': 'CLIENT', 'test': 'SERVER'}
83 | return f'[TinyImageNet] {rep_str[self.split]}'
84 |
85 | # helper method to fetch Tiny ImageNet dataset
86 | def fetch_tinyimagenet(args, root, transforms):
87 | logger.info('[LOAD] [TINYIMAGENET] Fetching dataset!')
88 |
89 | # default arguments
90 | DEFAULT_ARGS = {'root': root, 'transform': None, 'download': True}
91 |
92 | # configure arguments for training/test dataset
93 | train_args = DEFAULT_ARGS.copy()
94 | train_args['split'] = 'train'
95 | train_args['transform'] = transforms[0]
96 |
97 | # create training dataset instance
98 | raw_train = TinyImageNet(**train_args)
99 |
100 | # for global holdout set
101 | test_args = DEFAULT_ARGS.copy()
102 | test_args['transform'] = transforms[1]
103 | test_args['split'] = 'test'
104 |
105 | # create test dataset instance
106 | raw_test = TinyImageNet(**test_args)
107 |
108 |     logger.info('[LOAD] [TINYIMAGENET] ...fetched dataset!')
109 |
110 | # adjust argument
111 | args.in_channels = 3
112 | args.num_classes = len(torch.unique(torch.as_tensor(raw_train.targets)))
113 | return raw_train, raw_test, args
114 |
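The `val` folder rewrite above exists because the archive ships validation images in a single flat `images/` directory plus an annotations file, while `torchvision.datasets.ImageFolder` expects one sub-directory per class. A usage sketch with an illustrative root path; the first call downloads and restructures the archive:

train_set = TinyImageNet(root='~/data', split='train', download=True)
val_set = TinyImageNet(root='~/data', split='val', download=False)
print(len(train_set.classes))   # 200 classes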
--------------------------------------------------------------------------------
/src/datasets/torchtextparser.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import sys
4 | import csv
5 | import torch
6 | import logging
7 | import torchtext
8 |
9 | from src import TqdmToLogger
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 |
15 | # dataset wrapper module
16 | class TextClassificationDataset(torch.utils.data.Dataset):
17 | def __init__(self, dataset_name, inputs, targets):
18 | self.identifier = dataset_name
19 | self.inputs = inputs
20 | self.targets = targets
21 |
22 | def __len__(self):
23 | return len(self.inputs)
24 |
25 | def __getitem__(self, index):
26 | inputs = self.inputs[index]
27 | targets = self.targets[index]
28 | return inputs, targets
29 |
30 | def __repr__(self):
31 | return str(self.identifier)
32 |
33 | # helper method to fetch dataset from `torchtext.datasets`
34 | def fetch_torchtext_dataset(args, dataset_name, root, tokenizer, seq_len, num_embeddings):
35 | URL = {
36 | 'AG_NEWS': 'https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz',
37 | 'SogouNews': 'https://s3.amazonaws.com/fast-ai-nlp/sogou_news_csv.tgz',
38 | 'DBpedia': 'https://s3.amazonaws.com/fast-ai-nlp/dbpedia_csv.tgz',
39 | 'YelpReviewPolarity': 'https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz',
40 | 'YelpReviewFull': 'https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz',
41 | 'YahooAnswers': 'https://s3.amazonaws.com/fast-ai-nlp/yahoo_answers_csv.tgz',
42 | 'AmazonReviewPolarity': 'https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz',
43 | 'AmazonReviewFull': 'https://s3.amazonaws.com/fast-ai-nlp/amazon_review_full_csv.tgz'
44 | }
45 | MD5 = {
46 | 'AG_NEWS': '2c2d85915f1ca34b29e754ce3b403c81',
47 | 'SogouNews': '45c19a17716907a7a3bde8c7398f7542',
48 | 'DBpedia': '531dab845dd933d0d7b04e82bffe6d96',
49 | 'YelpReviewPolarity': '0f09b3af1a79c136ef9ca5f29df9ed9a',
50 | 'YelpReviewFull': 'a4acce1892d0f927165488f62860cabe',
51 | 'YahooAnswers': '672a634b0a8a1314138321e7d075a64e',
52 | 'AmazonReviewPolarity': 'ee221fbfc3bf7d49dd7dee8064b2404c',
53 | 'AmazonReviewFull': '289723722b64d337a40f809edd29d0f0'
54 | }
55 | NUM_CLASSES = {
56 | 'AG_NEWS': 4,
57 | 'SogouNews': 5,
58 | 'DBpedia': 14,
59 | 'YelpReviewPolarity': 2,
60 | 'YelpReviewFull': 5,
61 | 'YahooAnswers': 10,
62 | 'AmazonReviewPolarity': 2,
63 | 'AmazonReviewFull': 5
64 | }
65 |
66 | if dataset_name not in URL.keys():
67 | err = f'Dataset ({dataset_name}) is not supported!'
68 | logger.exception(err)
69 | raise Exception(err)
70 |
71 | def _unicode_csv_reader(unicode_csv_data, **kwargs):
72 | maxInt = sys.maxsize
73 | while True:
74 | try:
75 | csv.field_size_limit(maxInt)
76 | break
77 | except OverflowError:
78 | maxInt = int(maxInt / 10)
79 | csv.field_size_limit(maxInt)
80 |
81 | for line in csv.reader(unicode_csv_data, **kwargs):
82 | yield line
83 |
84 | def _csv_iterator(data_path, yield_cls=False):
85 | tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
86 | with io.open(data_path, encoding='utf8') as f:
87 | reader = _unicode_csv_reader(f)
88 | for row in reader:
89 | tokens = ' '.join(row[1:])
90 | tokens = tokenizer(tokens)
91 | if yield_cls:
92 | yield int(row[0]) - 1, torchtext.data.utils.ngrams_iterator(tokens, ngrams=1)
93 | else:
94 | yield torchtext.data.utils.ngrams_iterator(tokens, ngrams=1)
95 |
96 | def _create_data_from_iterator(vocab, iterator, max_len):
97 | inputs, targets = [], []
98 | for label, tokens in TqdmToLogger(iterator, logger=logger, desc=f'[LOAD] [{dataset_name.upper()}] ...prepare raw data!'):
99 | tokens = [vocab[token] for token in tokens]
100 |             # pad tokens to have max length
101 |             pad_len = max_len - len(tokens)
102 |             if pad_len > 0:
103 |                 tokens.extend([vocab['<pad>'] for _ in range(pad_len)])
104 |
105 | # slice tokens up to max length
106 | tokens = tokens[:max_len]
107 |
108 | # collect processed pairs
109 | inputs.append(tokens)
110 | targets.append(label)
111 | return torch.tensor(inputs).long(), torch.tensor(targets).long()
112 |
113 | def _create_data_from_tokenizer(tokenizer, iterator, max_len):
114 | inputs, targets = [], []
115 | for label, tokens in TqdmToLogger(iterator, logger=logger, desc=f'[LOAD] [{dataset_name.upper()}] ...prepare raw data!'):
116 | tokens = tokenizer(
117 | list(tokens),
118 | return_tensors='pt',
119 | is_split_into_words=True,
120 | max_length=max_len,
121 | return_attention_mask=False,
122 | truncation=True,
123 | padding='max_length'
124 | )['input_ids']
125 |
126 | inputs.append(*tokens)
127 | targets.append(label)
128 | return inputs, targets
129 |
130 | # download files
131 | logger.info(f'[LOAD] [{dataset_name.upper()}] Start downloading files!')
132 | root = os.path.expanduser(root)
133 | raw_files = torchtext.utils.download_from_url(
134 | url=URL[dataset_name],
135 | root=root,
136 | hash_value=MD5[dataset_name],
137 | hash_type='md5'
138 | )
139 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...downloaded files!')
140 |
141 |
142 | # extract archive
143 | logger.info(f'[LOAD] [{dataset_name.upper()}] Extract archived files!')
144 | raw_files = torchtext.utils.extract_archive(raw_files)
145 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...successfully extracted archived files!')
146 |
147 | # retrieve split files
148 | for fname in raw_files:
149 | if fname.endswith('train.csv'):
150 | train_csv_path = fname
151 | if fname.endswith('test.csv'):
152 | test_csv_path = fname
153 |
154 | # build vocabularies using training set
155 | if tokenizer is None:
156 | logger.info(f'[LOAD] [{dataset_name.upper()}] Build vocabularies!')
157 |         vocab = torchtext.vocab.build_vocab_from_iterator(_csv_iterator(train_csv_path), specials=['<unk>'], max_tokens=num_embeddings)
158 |         vocab.vocab.insert_token('<pad>', 0) # reserve index 0 for padding before fixing the default index
159 |         vocab.set_default_index(vocab['<unk>']) # out-of-vocabulary tokens fall back to <unk>
160 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...vocabularies are built!')
161 |
162 | # tokenize training & test data and prepare inputs/targets
163 |     logger.info(f'[LOAD] [{dataset_name.upper()}] Create training & test set!')
164 | if tokenizer is None:
165 | tr_inputs, tr_targets = _create_data_from_iterator(vocab, _csv_iterator(train_csv_path, yield_cls=True), seq_len)
166 | te_inputs, te_targets = _create_data_from_iterator(vocab, _csv_iterator(test_csv_path, yield_cls=True), seq_len)
167 | else:
168 | tr_inputs, tr_targets = _create_data_from_tokenizer(tokenizer, _csv_iterator(train_csv_path, yield_cls=True), seq_len)
169 | te_inputs, te_targets = _create_data_from_tokenizer(tokenizer, _csv_iterator(test_csv_path, yield_cls=True), seq_len)
170 |
171 | # adjust labels
172 | min_label_tr, min_label_te = min(tr_targets), min(te_targets)
173 | tr_targets = torch.tensor([l - min_label_tr for l in tr_targets]).long()
174 | te_targets = torch.tensor([l - min_label_te for l in te_targets]).long()
175 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...created training & test set!')
176 |
177 | # adjust arguments
178 | args.num_embeddings = len(vocab) + 1 if tokenizer is None else tokenizer.vocab_size
179 | args.num_classes = NUM_CLASSES[dataset_name]
180 | return TextClassificationDataset(f'[{dataset_name}] CLIENT', tr_inputs, tr_targets), TextClassificationDataset(f'[{dataset_name}] SERVER', te_inputs, te_targets), args
181 |
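A self-contained sketch of the vocabulary convention used above, with `<pad>` forced to index 0 and `<unk>` as the fallback for out-of-vocabulary tokens; the toy token lists stand in for `_csv_iterator`:

import torchtext

tokens = [['hello', 'world'], ['hello', 'again']]   # toy corpus
vocab = torchtext.vocab.build_vocab_from_iterator(tokens, specials=['<unk>'])
vocab.insert_token('<pad>', 0)                      # reserve index 0 for padding
vocab.set_default_index(vocab['<unk>'])             # unseen tokens fall back to <unk>
print(vocab(['hello', 'neverseen', '<pad>']))       # [2, 1, 0]: known word, <unk>, <pad>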
--------------------------------------------------------------------------------
/src/datasets/torchvisionparser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import torchvision
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 |
9 | # dataset wrapper module
10 | class VisionClassificationDataset(torch.utils.data.Dataset):
11 | def __init__(self, dataset, dataset_name, suffix):
12 | self.dataset = dataset
13 | self.dataset_name = dataset_name
14 | self.suffix = suffix
15 | self.targets = self.dataset.targets
16 |
17 | def __getitem__(self, index):
18 | inputs, targets = self.dataset[index]
19 | return inputs, targets
20 |
21 | def __len__(self):
22 | return len(self.dataset)
23 |
24 | def __repr__(self):
25 | return f'[{self.dataset_name}] {self.suffix}'
26 |
27 | # helper method to fetch dataset from `torchvision.datasets`
28 | def fetch_torchvision_dataset(args, dataset_name, root, transforms):
29 | logger.info(f'[LOAD] [{dataset_name.upper()}] Fetching dataset!')
30 |
31 | # default arguments
32 | DEFAULT_ARGS = {'root': root, 'transform': None, 'download': True}
33 |
34 | if dataset_name in [
35 | 'MNIST', 'FashionMNIST', 'QMNIST', 'KMNIST', 'EMNIST',\
36 | 'CIFAR10', 'CIFAR100', 'USPS'
37 | ]:
38 | # configure arguments for training/test dataset
39 | train_args = DEFAULT_ARGS.copy()
40 | train_args['train'] = True
41 | train_args['transform'] = transforms[0]
42 |
43 | # special case - EMNIST
44 | if dataset_name == 'EMNIST':
45 | train_args['split'] = 'byclass'
46 |
47 | # create training dataset instance
48 | raw_train = torchvision.datasets.__dict__[dataset_name](**train_args)
49 | raw_train = VisionClassificationDataset(raw_train, dataset_name.upper(), 'CLIENT')
50 |
51 | # for global holdout set
52 | test_args = DEFAULT_ARGS.copy()
53 | test_args['transform'] = transforms[1]
54 | test_args['train'] = False
55 |
56 | # special case - EMNIST
57 | if dataset_name == 'EMNIST':
58 | test_args['split'] = 'byclass'
59 |
60 | # create test dataset instance
61 | raw_test = torchvision.datasets.__dict__[dataset_name](**test_args)
62 | raw_test = VisionClassificationDataset(raw_test, dataset_name.upper(), 'SERVER')
63 |
64 | # adjust arguments
65 | if 'CIFAR' in dataset_name:
66 | args.in_channels = 3
67 | else:
68 | args.in_channels = 1
69 |
70 | elif dataset_name in [
71 | 'Country211',\
72 | 'DTD', 'Flowers102', 'Food101', 'FGVCAircraft',\
73 | 'GTSRB', 'RenderedSST2', 'StanfordCars',\
74 | 'STL10', 'SVHN'
75 | ]:
76 | # configure arguments for training/test dataset
77 | train_args = DEFAULT_ARGS.copy()
78 | train_args['split'] = 'train'
79 | train_args['transform'] = transforms[0]
80 |
81 | # create training dataset instance
82 | raw_train = torchvision.datasets.__dict__[dataset_name](**train_args)
83 |
84 | # for global holdout set
85 | test_args = DEFAULT_ARGS.copy()
86 | test_args['transform'] = transforms[1]
87 | test_args['split'] = 'test'
88 |
89 | # create test dataset instance
90 | raw_test = torchvision.datasets.__dict__[dataset_name](**test_args)
91 |
92 | # for compatibility, create attribute `targets`
93 | if dataset_name in ['DTD', 'Flowers102', 'Food101', 'FGVCAircraft']:
94 | setattr(raw_train, 'targets', raw_train._labels)
95 | setattr(raw_test, 'targets', raw_test._labels)
96 | elif dataset_name in ['GTSRB', 'RenderedSST2', 'StanfordCars']:
97 | setattr(raw_train, 'targets', [*list(zip(*raw_train._samples))[-1]])
98 | setattr(raw_test, 'targets', [*list(zip(*raw_test._samples))[-1]])
99 | elif dataset_name in ['STL10', 'SVHN']:
100 | setattr(raw_train, 'targets', raw_train.labels)
101 | setattr(raw_test, 'targets', raw_test.labels)
102 |
103 | # set raw datasets
104 | raw_train = VisionClassificationDataset(raw_train, dataset_name.upper(), 'CLIENT')
105 | raw_test = VisionClassificationDataset(raw_test, dataset_name.upper(), 'SERVER')
106 |
107 | # adjust arguments
108 | if 'RenderedSST2' in dataset_name:
109 | args.in_channels = 1
110 | else:
111 | args.in_channels = 3
112 |
113 | elif dataset_name in ['Places365', 'INaturalist', 'OxfordIIITPet', 'Omniglot']:
114 | # configure arguments for training/test dataset
115 | train_args = DEFAULT_ARGS.copy()
116 | train_args['transform'] = transforms[0]
117 |
118 | if dataset_name == 'Places365':
119 | train_args['split'] = 'train-standard'
120 | elif dataset_name == 'OxfordIIITPet':
121 | train_args['split'] = 'trainval'
122 | elif dataset_name == 'INaturalist':
123 | train_args['version'] = '2021_train_mini'
124 | elif dataset_name == 'Omniglot':
125 | train_args['background'] = True
126 |
127 | # create training dataset instance
128 | raw_train = torchvision.datasets.__dict__[dataset_name](**train_args)
129 |
130 | # for global holdout set
131 | test_args = DEFAULT_ARGS.copy()
132 | test_args['transform'] = transforms[1]
133 |
134 | if dataset_name == 'Places365':
135 | test_args['split'] = 'val'
136 | elif dataset_name == 'OxfordIIITPet':
137 | test_args['split'] = 'test'
138 | elif dataset_name == 'INaturalist':
139 | test_args['version'] = '2021_valid'
140 | elif dataset_name == 'Omniglot':
141 | test_args['background'] = False
142 |
143 | # create test dataset instance
144 | raw_test = torchvision.datasets.__dict__[dataset_name](**test_args)
145 |
146 | # for compatibility, create attribute `targets`
147 | if dataset_name == 'OxfordIIITPet':
148 | setattr(raw_train, 'targets', raw_train._labels)
149 | elif dataset_name == 'INaturalist':
150 | setattr(raw_train, 'targets', [*list(zip(*raw_train.index))[0]])
151 | elif dataset_name == 'Omniglot':
152 | setattr(raw_train, 'targets', [*list(zip(*raw_train._flat_character_images))[-1]])
153 |
154 | # set raw datasets
155 | raw_train = VisionClassificationDataset(raw_train, dataset_name.upper(), 'CLIENT')
156 | raw_test = VisionClassificationDataset(raw_test, dataset_name.upper(), 'SERVER')
157 |
158 | # adjust arguments
159 | if 'Omniglot' in dataset_name:
160 | args.in_channels = 1
161 | else:
162 | args.in_channels = 3
163 |
164 | elif dataset_name in ['Caltech256', 'SEMEION', 'SUN397']:
165 | # configure arguments for training dataset
166 | # NOTE: these datasets do NOT provide pre-defined split
167 | # Thus, use all datasets as a training dataset
168 | train_args = DEFAULT_ARGS.copy()
169 | train_args['transform'] = transforms[0]
170 |
171 | # create training dataset instance
172 | raw_train = torchvision.datasets.__dict__[dataset_name](**train_args)
173 |
174 | # for compatibility, create attribute `targets`
175 | if dataset_name == 'Caltech256':
176 | setattr(raw_train, 'targets', raw_train.y)
177 | elif dataset_name == 'SEMEION':
178 | setattr(raw_train, 'targets', raw_train.labels)
179 | elif dataset_name == 'SUN397':
180 | setattr(raw_train, 'targets', raw_train._labels)
181 |
182 | # set raw datasets (no holdout set is supported)
183 | raw_train = VisionClassificationDataset(raw_train, dataset_name.upper(), 'CLIENT')
184 | raw_test = None
185 |
186 | # adjust arguments
187 | if 'SEMEION' in dataset_name:
188 | args.in_channels = 1
189 | else:
190 | args.in_channels = 3
191 | else:
192 | err = f'[LOAD] Dataset `{dataset_name}` is not supported!'
193 | logger.exception(err)
194 | raise Exception(err)
195 |
196 | args.num_classes = len(torch.unique(torch.as_tensor(raw_train.dataset.targets)))
197 | logger.info(f'[LOAD] [{dataset_name.upper()}] ...fetched dataset!')
198 | return raw_train, raw_test, args
199 |
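A usage sketch with an assumed `args` namespace and the usual (train, test) transform pair; the first call downloads the archive into `root`:

import torchvision.transforms as T
from types import SimpleNamespace

args = SimpleNamespace()                              # hypothetical; normally built by the argument parser
transforms = (T.ToTensor(), T.ToTensor())             # (train transform, test transform)
train_set, test_set, args = fetch_torchvision_dataset(args, 'CIFAR10', './data', transforms)
print(args.in_channels, args.num_classes)             # 3 10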
--------------------------------------------------------------------------------
/src/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import load_dataset
2 | from .model import load_model
--------------------------------------------------------------------------------
/src/loaders/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import inspect
3 | import logging
4 | import importlib
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 |
10 | def load_model(args):
11 | # retrieve model skeleton
12 | model_class = importlib.import_module('..models', package=__package__).__dict__[args.model_name]
13 |
14 |     # get required model arguments (`inspect.getargspec` was removed in Python 3.11)
15 |     required_args = inspect.getfullargspec(model_class)[0]
16 |
17 |     # collect entered model arguments
18 | model_args = {}
19 | for argument in required_args:
20 | if argument == 'self':
21 | continue
22 | model_args[argument] = getattr(args, argument)
23 |
24 | # get model instance
25 | model = model_class(**model_args)
26 |
27 | # adjust arguments if needed
28 | if args.use_pt_model:
29 | args.num_embeddings = model.num_embeddings
30 | args.embedding_size = model.embedding_size
31 | args.num_hiddens = model.num_hiddens
32 | args.dropout = model.dropout
33 | return model, args
34 |
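The loader matches constructor parameter names against attributes of `args`, so any model added to `src/models` is constructible as long as its `__init__` arguments exist on the namespace. A toy illustration (the class and namespace here are assumptions):

import inspect
from types import SimpleNamespace

class ToyModel:                                        # hypothetical stand-in for an entry in src/models
    def __init__(self, in_channels, num_classes):
        self.in_channels, self.num_classes = in_channels, num_classes

args = SimpleNamespace(in_channels=3, num_classes=10, lr=0.01)   # unrelated attributes are simply ignored
required = [a for a in inspect.getfullargspec(ToyModel)[0] if a != 'self']
model = ToyModel(**{name: getattr(args, name) for name in required})
print(required)                                        # ['in_channels', 'num_classes']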
--------------------------------------------------------------------------------
/src/loaders/split.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 |
4 | from src import TqdmToLogger
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 |
10 | def simulate_split(args, dataset):
11 | """Split data indices using labels.
12 |
13 | Args:
14 | args (argparser): arguments
15 | dataset (dataset): raw dataset instance to be split
16 |
17 |     Returns:
18 |         split_map (dict): dictionary mapping each client index to a list of assigned sample indices
19 |     """
20 | # IID split (i.e., statistical homogeneity)
21 | if args.split_type == 'iid':
22 | # shuffle sample indices
23 | shuffled_indices = np.random.permutation(len(dataset))
24 |
25 | # get adjusted indices
26 | split_indices = np.array_split(shuffled_indices, args.K)
27 |
28 | # construct a hashmap
29 | split_map = {k: split_indices[k] for k in range(args.K)}
30 | return split_map
31 |
32 | # non-IID split by sample unbalancedness
33 | if args.split_type == 'unbalanced':
34 | # shuffle sample indices
35 | shuffled_indices = np.random.permutation(len(dataset))
36 |
37 | # split indices by number of clients
38 | split_indices = np.array_split(shuffled_indices, args.K)
39 |
40 | # randomly remove some proportion (1% ~ 5%) of data
41 | keep_ratio = np.random.uniform(low=0.95, high=0.99, size=len(split_indices))
42 |
43 | # get adjusted indices
44 | split_indices = [indices[:int(len(indices) * ratio)] for indices, ratio in zip(split_indices, keep_ratio)]
45 |
46 | # construct a hashmap
47 | split_map = {k: split_indices[k] for k in range(args.K)}
48 | return split_map
49 |
50 | # Non-IID split proposed in (McMahan et al., 2016); each client has samples from at least two different classes
51 | elif args.split_type == 'patho':
52 | try:
53 | assert args.mincls >= 2
54 | except AssertionError as e:
55 | logger.exception("[SIMULATE] Each client should have samples from at least 2 distinct classes!")
56 | raise e
57 |
58 | # get indices by class labels
59 | _, unique_inverse, unique_counts = np.unique(dataset.targets, return_inverse=True, return_counts=True)
60 | class_indices = np.split(np.argsort(unique_inverse), np.cumsum(unique_counts[:-1]))
61 |
62 | # divide shards
63 | num_shards_per_class = args.K * args.mincls // args.num_classes
64 | if num_shards_per_class < 1:
65 | err = f'[SIMULATE] Increase the number of minimum class (`args.mincls` > {args.mincls}) or the number of participating clients (`args.K` > {args.K})!'
66 | logger.exception(err)
67 | raise Exception(err)
68 |
69 | # split class indices again into groups, each having the designated number of shards
70 | split_indices = [np.array_split(np.random.permutation(indices), num_shards_per_class) for indices in class_indices]
71 |
72 | # make hashmap to track remaining shards to be assigned per client
73 | class_shards_counts = dict(zip([i for i in range(args.num_classes)], [len(split_idx) for split_idx in split_indices]))
74 |
75 | # assign divided shards to clients
76 | assigned_shards = []
77 | for _ in TqdmToLogger(
78 | range(args.K),
79 | logger=logger,
80 | desc='[SIMULATE] ...assigning to clients... '
81 | ):
82 |             # update selection probability according to the count of remaining shards
83 |             # i.e., do NOT sample from a class having no remaining shards
84 | selection_prob = np.where(np.array(list(class_shards_counts.values())) > 0, 1., 0.)
85 | selection_prob /= sum(selection_prob)
86 |
87 | # select classes to be considered
88 | try:
89 | selected_classes = np.random.choice(args.num_classes, args.mincls, replace=False, p=selection_prob)
90 |             except ValueError: # if not enough shards remain, some clients inevitably end up with fewer than `args.mincls` distinct classes
91 | selected_classes = np.random.choice(args.num_classes, args.mincls, replace=True, p=selection_prob)
92 |
93 | # assign shards in randomly selected classes to current client
94 | for it, class_idx in enumerate(selected_classes):
95 | selected_shard_indices = np.random.choice(len(split_indices[class_idx]), 1)[0]
96 | selected_shards = split_indices[class_idx].pop(selected_shard_indices)
97 | if it == 0:
98 | assigned_shards.append([selected_shards])
99 | else:
100 | assigned_shards[-1].append(selected_shards)
101 | class_shards_counts[class_idx] -= 1
102 | else:
103 | assigned_shards[-1] = np.concatenate(assigned_shards[-1])
104 |
105 | # construct a hashmap
106 | split_map = {k: assigned_shards[k] for k in range(args.K)}
107 | return split_map
108 |
109 | # Non-IID split proposed in (Hsu et al., 2019); simulation of non-IID split scenario using Dirichlet distribution
110 | elif args.split_type == 'diri':
111 | MIN_SAMPLES = int(1 / args.test_size)
112 |
113 | # get indices by class labels
114 | total_counts = len(dataset.targets)
115 | _, unique_inverse, unique_counts = np.unique(dataset.targets, return_inverse=True, return_counts=True)
116 | class_indices = np.split(np.argsort(unique_inverse), np.cumsum(unique_counts[:-1]))
117 |
118 | # calculate ideal samples counts per client
119 | ideal_counts = len(dataset.targets) // args.K
120 | if ideal_counts < 1:
121 | err = f'[SIMULATE] Decrease the number of participating clients (`args.K` < {args.K})!'
122 | logger.exception(err)
123 | raise Exception(err)
124 |
125 | # split dataset
126 | ## define temporary container
127 | assigned_indices = []
128 |
129 |     ## NOTE: not all samples may be consumed; this is intended, as each client must keep at least `MIN_SAMPLES` samples per selected class
130 | for k in TqdmToLogger(range(args.K), logger=logger, desc='[SIMULATE] ...assigning to clients... '):
131 |         ### for the current client, whose index is `k`
132 |         curr_indices = []
133 |         satisfied_counts = 0
134 | 
135 |         ### ...keep sampling until the number of collected samples is close to `ideal_counts`
136 |         while satisfied_counts < ideal_counts:
137 |             ### define a Dirichlet distribution whose prior is a uniform distribution
138 |             diri_prior = np.random.uniform(size=args.num_classes)
139 | 
140 |             ### sample a parameter of the corresponding categorical distribution
141 |             cat_param = np.random.dirichlet(alpha=args.cncntrtn * diri_prior)
142 | 
143 |             ### try to sample as many as `ideal_counts` samples
144 |             sampled = np.random.choice(args.num_classes, ideal_counts, p=cat_param)
145 |
146 | ### count per-class samples
147 | unique, counts = np.unique(sampled, return_counts=True)
148 | if len(unique) < args.mincls:
149 | continue
150 |
151 |             ### zero out sampled classes having fewer than `MIN_SAMPLES` samples
152 |             required_counts = counts * (counts > MIN_SAMPLES)
153 | 
154 |             ### assign from population indices split by classes
155 |             for idx, required_class in enumerate(unique):
156 |                 if required_counts[idx] == 0: continue
157 |                 sampled_indices = class_indices[required_class][:required_counts[idx]]
158 |                 curr_indices.append(sampled_indices)
159 |                 class_indices[required_class] = class_indices[required_class][required_counts[idx]:] # remove consumed indices from the population
160 |                 satisfied_counts += required_counts[idx]
161 |
162 |         ### when enough samples are collected, move on to the next client
163 | assigned_indices.append(np.concatenate(curr_indices))
164 |
165 | # construct a hashmap
166 | split_map = {k: assigned_indices[k] for k in range(args.K)}
167 | return split_map
168 | # `leaf` - LEAF benchmark (Caldas et al., 2018); `fedvis` - Federated Vision Datasets (Hsu, Qi and Brown, 2020)
169 | elif args.split_type in ['leaf']:
170 | logger.info('[SIMULATE] Use pre-defined split!')
171 |
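For intuition, the `iid` branch boils down to a shuffled `array_split`; a standalone sketch with toy sizes:

import numpy as np

num_samples, K = 10, 3                                # toy dataset size and client count
shuffled = np.random.permutation(num_samples)
split_map = {k: part for k, part in enumerate(np.array_split(shuffled, K))}
print({k: len(v) for k, v in split_map.items()})      # {0: 4, 1: 3, 2: 3}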
--------------------------------------------------------------------------------
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .metricszoo import *
--------------------------------------------------------------------------------
/src/metrics/basemetric.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 |
4 |
5 | class BaseMetric(metaclass=ABCMeta):
6 | @abstractmethod
7 | def __init__(self):
8 | raise NotImplementedError
9 |
10 | @abstractmethod
11 | def collect(self, pred, true):
12 | raise NotImplementedError
13 |
14 | @abstractmethod
15 | def summarize(self):
16 | raise NotImplementedError
17 |
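A minimal concrete subclass for illustration only (the real implementations live in `metricszoo.py`); it assumes `pred` and `true` are torch tensors:

class Accuracy(BaseMetric):                   # illustrative, not the repository's implementation
    def __init__(self):
        self.correct, self.total = 0, 0

    def collect(self, pred, true):
        self.correct += int((pred.argmax(dim=-1) == true).sum())
        self.total += len(true)

    def summarize(self):
        return self.correct / max(self.total, 1)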
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .twonn import TwoNN
2 | from .twocnn import TwoCNN
3 | from .lenet import LeNet
4 | from .m5 import M5
5 | from .vgg import *
6 | from .resnet import *
7 | from .shufflenet import ShuffleNet
8 | from .mobilenet import MobileNet
9 | from .mobilenext import MobileNeXt
10 | from .squeezenet import SqueezeNet
11 | from .squeezenext import SqueezeNeXt
12 | from .mobilevit import MobileViT
13 | from .stackedlstm import StackedLSTM
14 | from .distilbert import DistilBert
15 | from .mobilebert import MobileBert
16 | from .squeezebert import SqueezeBert
17 | from .logreg import LogReg
18 | from .simplecnn import SimpleCNN
19 | from .femnistcnn import FEMNISTCNN
20 | from .sent140lstm import Sent140LSTM
21 | from .stackedtransformer import StackedTransformer
--------------------------------------------------------------------------------
/src/models/distilbert.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformers import DistilBertModel, DistilBertConfig
4 |
5 |
6 |
7 | class DistilBert(torch.nn.Module):
8 | def __init__(self, num_classes, num_embeddings, embedding_size, hidden_size, dropout, use_pt_model, is_seq2seq):
9 | super(DistilBert, self).__init__()
10 | self.is_seq2seq = is_seq2seq
11 | # define encoder
12 | if use_pt_model: # fine-tuning
13 | self.features = DistilBertModel.from_pretrained('distilbert-base-uncased')
14 | self.num_classes = num_classes
15 | self.num_embeddings = self.features.config.vocab_size
16 | self.embedding_size = self.features.config.dim
17 | self.num_hiddens = self.features.config.hidden_size
18 | self.dropout = self.features.config.dropout
19 |
20 | self.classifier = torch.nn.Linear(self.embedding_size, self.num_classes, bias=True)
21 | else: # from scratch
22 | self.num_classes = num_classes
23 | self.num_embeddings = num_embeddings
24 | self.embedding_size = embedding_size
25 | self.num_hiddens = hidden_size
26 | self.dropout = dropout
27 |
28 | config = DistilBertConfig(
29 | vocab_size=self.num_embeddings,
30 | dim=self.embedding_size,
31 | hidden_size=self.num_hiddens,
32 | hidden_dropout_prob=self.dropout
33 | )
34 | self.features = DistilBertModel(config)
35 | self.classifier = torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True)
36 |
37 |     def forward(self, x):
38 |         x = self.features(x)[0] # last hidden state: (batch, seq_len, dim)
39 |         x = self.classifier(x if self.is_seq2seq else x[:, 0, :]) # use [CLS] token for classification
40 | return x
41 |
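A shape-level sketch of the two head modes in `forward`, with illustrative sizes: sequence classification feeds only the first ([CLS]) position to the linear head, while the seq2seq mode keeps every position:

import torch

dim, num_classes = 768, 2                    # illustrative sizes
hidden = torch.randn(4, 128, dim)            # (batch, seq_len, dim): encoder output
head = torch.nn.Linear(dim, num_classes)
print(head(hidden[:, 0, :]).shape)           # torch.Size([4, 2]), classification head
print(head(hidden).shape)                    # torch.Size([4, 128, 2]), token-level head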
--------------------------------------------------------------------------------
/src/models/femnistcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class FEMNISTCNN(torch.nn.Module): # for the FEMNIST experiment in Caldas et al., 2018 (https://github.com/TalwalkarLab/leaf/blob/master/models/femnist/cnn.py)
6 | def __init__(self, in_channels, hidden_size, num_classes):
7 | super(FEMNISTCNN, self).__init__()
8 | self.in_channels = in_channels
9 | self.hidden_channels = hidden_size
10 | self.num_classes = num_classes
11 |
12 | self.features = torch.nn.Sequential(
13 | torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.hidden_channels, kernel_size=5, padding=1, stride=1, bias=True),
14 | torch.nn.ReLU(),
15 | torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
16 | torch.nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels * 2, kernel_size=5, padding=1, stride=1, bias=True),
17 | torch.nn.ReLU(),
18 | torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
19 | )
20 | self.classifier = torch.nn.Sequential(
21 | torch.nn.AdaptiveAvgPool2d((7, 7)),
22 | torch.nn.Flatten(),
23 | torch.nn.Linear(in_features=self.hidden_channels * 2 * (7 * 7), out_features=2048, bias=True),
24 | torch.nn.ReLU(),
25 | torch.nn.Linear(in_features=2048, out_features=self.num_classes, bias=True)
26 | )
27 |
28 | def forward(self, x):
29 | x = self.features(x)
30 | x = self.classifier(x)
31 | return x
32 |
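A quick shape sanity check; FEMNIST images are 1x28x28 with 62 classes (digits plus upper- and lower-case letters), and `hidden_size=32` is an arbitrary choice:

import torch

model = FEMNISTCNN(in_channels=1, hidden_size=32, num_classes=62)
logits = model(torch.randn(4, 1, 28, 28))
print(logits.shape)   # torch.Size([4, 62])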
--------------------------------------------------------------------------------
/src/models/lenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class LeNet(torch.nn.Module):
6 | def __init__(self, in_channels, num_classes, hidden_size, dropout):
7 | super(LeNet, self).__init__()
8 | self.in_channels = in_channels
9 | self.hidden_channels = hidden_size
10 | self.num_classes = num_classes
11 | self.dropout = dropout
12 |
13 | self.features = torch.nn.Sequential(
14 | torch.nn.Conv2d(self.in_channels, self.hidden_channels, kernel_size=(5, 5), padding=1, stride=1, bias=True),
15 | torch.nn.ReLU(True),
16 | torch.nn.MaxPool2d((2, 2)),
17 | torch.nn.Conv2d(self.hidden_channels, self.hidden_channels * 2, kernel_size=(5, 5), padding=1, stride=1, bias=True),
18 | torch.nn.ReLU(True),
19 | torch.nn.MaxPool2d((2, 2)),
20 | )
21 | self.classifier = torch.nn.Sequential(
22 | torch.nn.AdaptiveAvgPool2d((4, 4)),
23 | torch.nn.Flatten(),
24 | torch.nn.Linear((4 * 4) * (self.hidden_channels * 2), self.hidden_channels, bias=True),
25 | torch.nn.ReLU(True),
26 | torch.nn.Dropout(self.dropout),
27 | torch.nn.Linear(self.hidden_channels, self.hidden_channels // 2, bias=True),
28 | torch.nn.ReLU(True),
29 | torch.nn.Linear(self.hidden_channels // 2, self.num_classes, bias=True)
30 | )
31 |
32 | def forward(self, x):
33 | x = self.features(x)
34 | x = self.classifier(x)
35 | return x
36 |
--------------------------------------------------------------------------------
/src/models/logreg.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class LogReg(torch.nn.Module):
6 | def __init__(self, in_features, num_layers, hidden_size, num_classes):
7 | super(LogReg, self).__init__()
8 | self.in_features = in_features
9 | self.num_classes = num_classes
10 |
11 | if num_layers == 1:
12 | self.features = torch.nn.Identity()
13 | self.classifier = torch.nn.Linear(in_features, num_classes, bias=True)
14 | else:
15 |             features = [torch.nn.Linear(in_features, hidden_size, bias=True), torch.nn.ReLU(True)]
16 |             for _ in range(num_layers - 2): # `num_layers` counts every linear layer, including the final classifier
17 | features.append(torch.nn.Linear(hidden_size, hidden_size, bias=True))
18 | features.append(torch.nn.ReLU(True))
19 | self.features = torch.nn.Sequential(*features)
20 | self.classifier = torch.nn.Linear(hidden_size, num_classes, bias=True)
21 |
22 | def forward(self, x):
23 | x = self.features(x)
24 | x = self.classifier(x)
25 | return x
--------------------------------------------------------------------------------
/src/models/m5.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class M5(torch.nn.Module):
6 | def __init__(self, in_channels, hidden_size, num_classes):
7 | super(M5, self).__init__()
8 | self.in_channels = in_channels
9 | self.num_hiddens = hidden_size
10 | self.num_classes = num_classes
11 |
12 | self.features = torch.nn.Sequential(
13 | torch.nn.Conv1d(self.in_channels, self.num_hiddens, kernel_size=80, stride=16),
14 | torch.nn.BatchNorm1d(self.num_hiddens),
15 | torch.nn.ReLU(True),
16 | torch.nn.MaxPool1d(4),
17 | torch.nn.Conv1d(self.num_hiddens, self.num_hiddens, kernel_size=3),
18 | torch.nn.BatchNorm1d(self.num_hiddens),
19 | torch.nn.ReLU(True),
20 | torch.nn.MaxPool1d(4),
21 | torch.nn.Conv1d(self.num_hiddens, self.num_hiddens * 2, kernel_size=3),
22 | torch.nn.BatchNorm1d(self.num_hiddens * 2),
23 | torch.nn.ReLU(True),
24 | torch.nn.MaxPool1d(4),
25 | torch.nn.Conv1d(self.num_hiddens * 2, self.num_hiddens * 2, kernel_size=3),
26 | torch.nn.BatchNorm1d(self.num_hiddens * 2),
27 | torch.nn.ReLU(True),
28 | torch.nn.MaxPool1d(4),
29 | torch.nn.AdaptiveAvgPool1d(1),
30 | torch.nn.Flatten()
31 | )
32 | self.classifier = torch.nn.Linear(self.num_hiddens * 2, self.num_classes)
33 |
34 | def forward(self, x):
35 | x = self.features(x)
36 | x = self.classifier(x)
37 | return x
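A quick shape sanity check; M5 consumes raw one-channel waveforms such as the 16000-sample clips produced by the SpeechCommands wrapper, and `hidden_size=32` is an arbitrary choice:

import torch

model = M5(in_channels=1, hidden_size=32, num_classes=35)
out = model(torch.randn(8, 1, 16000))
print(out.shape)      # torch.Size([8, 35])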
--------------------------------------------------------------------------------
/src/models/mobilebert.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformers import MobileBertModel, MobileBertConfig
4 |
5 |
6 |
7 | class MobileBert(torch.nn.Module):
8 | def __init__(self, num_classes, num_embeddings, embedding_size, hidden_size, dropout, use_pt_model, is_seq2seq):
9 | super(MobileBert, self).__init__()
10 | self.is_seq2seq = is_seq2seq
11 | # define encoder
12 | if use_pt_model: # fine-tuning
13 | self.features = MobileBertModel.from_pretrained('google/mobilebert-uncased')
14 | self.num_classes = num_classes
15 | self.num_embeddings = self.features.config.vocab_size
16 | self.embedding_size = self.features.config.embedding_size
17 | self.num_hiddens = self.features.config.hidden_size
18 | self.dropout = self.features.config.hidden_dropout_prob
19 |
20 | self.classifier = torch.nn.Linear(self.features.config.hidden_size, self.num_classes, bias=True)
21 |
22 | else: # from scratch
23 | self.num_classes = num_classes
24 | self.num_embeddings = num_embeddings
25 | self.embedding_size = embedding_size
26 | self.num_hiddens = hidden_size
27 | self.dropout = dropout
28 |
29 | config = MobileBertConfig(
30 | vocab_size=self.num_embeddings,
31 | embedding_size=self.embedding_size,
32 | hidden_size=self.num_hiddens,
33 | hidden_dropout_prob=self.dropout
34 | )
35 | self.features = MobileBertModel(config)
36 | self.classifier = torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True)
37 |
38 | def forward(self, x):
39 | x = self.features(x)
40 | x = self.classifier(x['last_hidden_state'] if self.is_seq2seq else x['pooler_output'])
41 | return x
42 |
--------------------------------------------------------------------------------
/src/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.models.model_utils import make_divisible, SELayer, InvertedResidualBlock
4 |
5 |
6 |
7 | class MobileNet(torch.nn.Module): # MobileNetv3-small
8 | CONFIG = [# k, t, c, SE, HS, s
9 | [3, 1, 16, 1, 0, 2],
10 | [3, 4.5, 24, 0, 0, 2],
11 | [3, 3.67, 24, 0, 0, 1],
12 | [5, 4, 40, 1, 1, 2],
13 | [5, 6, 40, 1, 1, 1],
14 | [5, 6, 40, 1, 1, 1],
15 | [5, 3, 48, 1, 1, 1],
16 | [5, 3, 48, 1, 1, 1],
17 | [5, 6, 96, 1, 1, 2],
18 | [5, 6, 96, 1, 1, 1],
19 | [5, 6, 96, 1, 1, 1]
20 | ]
21 |
22 | def __init__(self, in_channels, num_classes, dropout):
23 | super(MobileNet, self).__init__()
24 | self.in_channels = in_channels
25 | self.num_classes = num_classes
26 | self.dropout = dropout
27 |
28 | hidden_channels = make_divisible(16, 8)
29 | layers = [
30 | torch.nn.Sequential(
31 | torch.nn.Conv2d(in_channels, hidden_channels, 3, 2, 1, bias=False),
32 |                 torch.nn.BatchNorm2d(hidden_channels),
33 | torch.nn.Hardswish(True)
34 | )
35 | ]
36 |
37 | # building inverted residual blocks
38 | for k, t, c, use_se, use_hs, s in self.CONFIG:
39 |             out_channels = make_divisible(c, 8)  # width multiplier fixed at 1.0
40 | exp_size = make_divisible(hidden_channels * t, 8)
41 | layers.append(InvertedResidualBlock(hidden_channels, exp_size, out_channels, k, s, use_se, use_hs))
42 |             hidden_channels = out_channels
43 | 
44 |         self.features1 = torch.nn.Sequential(*layers)
45 |
46 | # building last several layers
47 | self.features2 = torch.nn.Sequential(
48 | torch.nn.Conv2d(hidden_channels, exp_size, 1, 1, 0, bias=False),
49 | torch.nn.BatchNorm2d(exp_size),
50 | torch.nn.Hardswish(True),
51 | torch.nn.AdaptiveAvgPool2d((1, 1))
52 | )
53 | out_channels = 1024
54 | self.classifier = torch.nn.Sequential(
55 | torch.nn.Flatten(),
56 | torch.nn.Linear(exp_size, out_channels),
57 | torch.nn.Hardswish(True),
58 | torch.nn.Dropout(self.dropout),
59 | torch.nn.Linear(out_channels, self.num_classes),
60 | )
61 |
62 | for m in self.modules():
63 | if isinstance(m, torch.nn.Conv2d):
64 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out')
65 | if m.bias is not None:
66 | torch.nn.init.zeros_(m.bias)
67 | elif isinstance(m, torch.nn.BatchNorm2d):
68 | torch.nn.init.ones_(m.weight)
69 | torch.nn.init.zeros_(m.bias)
70 | elif isinstance(m, torch.nn.Linear):
71 | torch.nn.init.normal_(m.weight, 0, 0.01)
72 | torch.nn.init.zeros_(m.bias)
73 |
74 | def forward(self, x):
75 | x = self.features1(x)
76 | x = self.features2(x)
77 | x = self.classifier(x)
78 | return x
79 |
--------------------------------------------------------------------------------
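
A quick shape check for the MobileNetV3-small above, assuming the repository root is on PYTHONPATH (make_divisible and InvertedResidualBlock come from src/models/model_utils.py). The AdaptiveAvgPool2d in features2 keeps the head input-size agnostic:

    import torch
    from src.models.mobilenet import MobileNet

    model = MobileNet(in_channels=3, num_classes=10, dropout=0.2)
    print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 10])
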
/src/models/mobilenext.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from src.models.model_utils import make_divisible, SandGlassLayer
5 |
6 |
7 |
8 | class MobileNeXt(torch.nn.Module):
9 | CONFIG = [# t, c, n, s
10 | [2, 96, 1, 2],
11 | [6, 144, 1, 1],
12 | [6, 192, 3, 2],
13 | [6, 288, 3, 2],
14 | [6, 384, 4, 1],
15 | [6, 576, 4, 2],
16 | [6, 960, 3, 1],
17 | [6, 1280, 1, 1]
18 | ]
19 |
20 | def __init__(self, in_channels, num_classes, dropout):
21 | super(MobileNeXt, self).__init__()
22 | self.in_channels = in_channels
23 | self.num_classes = num_classes
24 | self.dropout = dropout
25 |
26 | # building first layer
27 | hidden_channels = make_divisible(32, 8)
28 | layers = [
29 | torch.nn.Sequential(
30 | torch.nn.Conv2d(self.in_channels, hidden_channels, 3, 2, 1, bias=False),
31 | torch.nn.BatchNorm2d(hidden_channels),
32 | torch.nn.ReLU6(True)
33 | )
34 | ]
35 |
36 | # building blocks
37 | for t, c, n, s in self.CONFIG:
38 | out_channels = make_divisible(c, 8)
39 | for i in range(n):
40 | layers.append(SandGlassLayer(hidden_channels, out_channels, s if i == 0 else 1, t))
41 | hidden_channels = out_channels
42 | self.features = torch.nn.Sequential(*layers)
43 |
44 | # building classifier
45 | self.classifier = torch.nn.Sequential(
46 | torch.nn.AdaptiveAvgPool2d((1, 1)),
47 | torch.nn.Flatten(),
48 | torch.nn.Dropout(self.dropout),
49 | torch.nn.Linear(out_channels, self.num_classes),
50 | )
51 |
52 | for m in self.modules():
53 | if isinstance(m, torch.nn.Conv2d):
54 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
55 | m.weight.data.normal_(0, math.sqrt(2. / n))
56 | if m.bias is not None:
57 | m.bias.data.zero_()
58 | elif isinstance(m, torch.nn.BatchNorm2d):
59 | m.weight.data.fill_(1)
60 | m.bias.data.zero_()
61 | elif isinstance(m, torch.nn.Linear):
62 | m.weight.data.normal_(0, 0.01)
63 | m.bias.data.zero_()
64 |
65 | def forward(self, x):
66 | x = self.features(x)
67 | x = self.classifier(x)
68 | return x
69 |
--------------------------------------------------------------------------------
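
CONFIG rows are (expansion t, channels c, repeats n, stride s) as in the MobileNeXt paper (Zhou et al., 2020). A usage sketch under the same PYTHONPATH assumption (SandGlassLayer lives in src/models/model_utils.py):

    import torch
    from src.models.mobilenext import MobileNeXt

    model = MobileNeXt(in_channels=3, num_classes=1000, dropout=0.2)
    print(sum(p.numel() for p in model.parameters()))  # total parameter count
    print(model(torch.randn(1, 3, 224, 224)).shape)    # torch.Size([1, 1000])
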
/src/models/mobilevit.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.models.model_utils import MV2Block, MobileViTBlock
4 |
5 |
6 |
7 | class MobileViT(torch.nn.Module):
8 | L = [2, 4, 3]
9 | DIMS = [64, 80, 96]
10 | CHANNELS = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
11 |
12 | def __init__(self, crop, resize, in_channels, num_classes, dropout):
13 | super(MobileViT, self).__init__()
14 | size = crop if resize is None else resize
15 |         assert (size >= 32) and (size % 32 == 0), 'input size must be a multiple of 32'
16 |
17 | self.size = size
18 | self.in_channels = in_channels
19 | self.num_classes = num_classes
20 | self.dropout = dropout
21 | self.patch_size = 2
22 |
23 |         self.features = torch.nn.Sequential(
24 |             torch.nn.Conv2d(self.in_channels, self.CHANNELS[0], 3, 2, 1, bias=False),
25 |             torch.nn.BatchNorm2d(self.CHANNELS[0]),
26 |             torch.nn.SiLU(True),
27 |             MV2Block(self.CHANNELS[0], self.CHANNELS[1], 1, 4),
28 |             MV2Block(self.CHANNELS[1], self.CHANNELS[2], 2, 4),
29 |             MV2Block(self.CHANNELS[2], self.CHANNELS[3], 1, 4),
30 |             MV2Block(self.CHANNELS[2], self.CHANNELS[3], 1, 4),
31 |             MV2Block(self.CHANNELS[3], self.CHANNELS[4], 2, 4),
32 |             MobileViTBlock(self.DIMS[0], self.L[0], self.CHANNELS[5], 3, self.patch_size, int(self.DIMS[0] * 2)),
33 |             MV2Block(self.CHANNELS[5], self.CHANNELS[6], 2, 4),
34 |             MobileViTBlock(self.DIMS[1], self.L[1], self.CHANNELS[7], 3, self.patch_size, int(self.DIMS[1] * 4)),
35 |             MV2Block(self.CHANNELS[7], self.CHANNELS[8], 2, 4),
36 |             MobileViTBlock(self.DIMS[2], self.L[2], self.CHANNELS[9], 3, self.patch_size, int(self.DIMS[2] * 4)),
37 |             torch.nn.Conv2d(self.CHANNELS[-2], self.CHANNELS[-1], 1, 1, 0, bias=False),
38 |             torch.nn.BatchNorm2d(self.CHANNELS[-1]),
39 |             torch.nn.SiLU(True)
40 |         )
41 |         self.classifier = torch.nn.Sequential(
42 |             torch.nn.AvgPool2d(self.size // 32, 1),
43 |             torch.nn.Flatten(),
44 |             torch.nn.Dropout(self.dropout),
45 |             torch.nn.Linear(self.CHANNELS[-1], self.num_classes, bias=False)
46 |         )
47 | 
48 |     def forward(self, x):
49 |         x = self.features(x)
50 |         x = self.classifier(x)
51 |         return x
52 | 
--------------------------------------------------------------------------------
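
Since the backbone downsamples by 32 and the classifier pools with a (size // 32) kernel, the effective input side must be a multiple of 32 (hence the assert above). A minimal sketch, assuming the repository root is on PYTHONPATH; crop/resize mirror the data-pipeline arguments:

    import torch
    from src.models.mobilevit import MobileViT

    model = MobileViT(crop=256, resize=None, in_channels=3, num_classes=10, dropout=0.1)
    print(model(torch.randn(2, 3, 256, 256)).shape)  # torch.Size([2, 10])
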
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.models.model_utils import ResidualBlock
4 |
5 |
6 |
7 | __all__ = ['ResNet10', 'ResNet18', 'ResNet34']
8 |
9 | CONFIGS = {
10 | 'ResNet10': [1, 1, 1, 1],
11 | 'ResNet18': [2, 2, 2, 2],
12 | 'ResNet34': [3, 4, 6, 3]
13 | }
14 |
15 | class ResNet(torch.nn.Module):
16 | def __init__(self, config, block, in_channels, hidden_size, num_classes):
17 | super(ResNet, self).__init__()
18 | self.in_channels = in_channels
19 | self.hidden_size = hidden_size
20 | self.num_classes = num_classes
21 |
22 | self.features = torch.nn.Sequential(
23 | torch.nn.Conv2d(self.in_channels, self.hidden_size, kernel_size=3, stride=1, padding=1, bias=False),
24 | torch.nn.BatchNorm2d(self.hidden_size),
25 | torch.nn.ReLU(True),
26 | self._make_layers(block, self.hidden_size, config[0], stride=1),
27 | self._make_layers(block, self.hidden_size * 2, config[1], stride=2),
28 | self._make_layers(block, self.hidden_size * 4, config[2], stride=2),
29 | self._make_layers(block, self.hidden_size * 8, config[3], stride=2),
30 | )
31 | self.classifier = torch.nn.Sequential(
32 | torch.nn.AdaptiveAvgPool2d((7, 7)),
33 | torch.nn.Flatten(),
34 | torch.nn.Linear((7 * 7) * self.hidden_size, self.num_classes, bias=True)
35 | )
36 |
37 | def forward(self, x):
38 | x = self.features(x)
39 | x = self.classifier(x)
40 | return x
41 |
42 | def _make_layers(self, block, planes, num_blocks, stride):
43 | strides = [stride] + [1] * (num_blocks - 1)
44 | layers = []
45 | for stride in strides:
46 | layers.append(block(self.hidden_size, planes, stride))
47 | self.hidden_size = planes
48 | return torch.nn.Sequential(*layers)
49 |
50 | class ResNet10(ResNet):
51 | def __init__(self, in_channels, hidden_size, num_classes):
52 | super(ResNet10, self).__init__(CONFIGS['ResNet10'], ResidualBlock, in_channels, hidden_size, num_classes)
53 |
54 | class ResNet18(ResNet):
55 | def __init__(self, in_channels, hidden_size, num_classes):
56 | super(ResNet18, self).__init__(CONFIGS['ResNet18'], ResidualBlock, in_channels, hidden_size, num_classes)
57 |
58 | class ResNet34(ResNet):
59 | def __init__(self, in_channels, hidden_size, num_classes):
60 | super(ResNet34, self).__init__(CONFIGS['ResNet34'], ResidualBlock, in_channels, hidden_size, num_classes)
61 |
--------------------------------------------------------------------------------
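
Note that _make_layers widens self.hidden_size in place as it builds each stage, so by the time the classifier is constructed it holds the final width (8x the constructor argument), which is what sizes the (7 * 7) * self.hidden_size Linear. A usage sketch; AdaptiveAvgPool2d((7, 7)) fixes the classifier input regardless of image size:

    import torch
    from src.models.resnet import ResNet18  # assumes the repository root is on PYTHONPATH

    model = ResNet18(in_channels=3, hidden_size=64, num_classes=10)
    print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
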
/src/models/sent140lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.models.model_utils import Lambda
4 |
5 |
6 |
7 | class Sent140LSTM(torch.nn.Module):
8 | def __init__(self, num_classes, embedding_size, hidden_size, dropout, num_layers, glove_emb):
9 | super(Sent140LSTM, self).__init__()
10 | self.embedding_size = embedding_size
11 | self.num_hiddens = hidden_size
12 | self.num_classes = num_classes
13 | self.dropout = dropout
14 | self.num_layers = num_layers
15 | self.glove_emb = glove_emb
16 |
17 | self.features = torch.nn.Sequential(
18 | torch.nn.Embedding.from_pretrained(self.glove_emb, padding_idx=0),
19 | torch.nn.LSTM(
20 | input_size=self.embedding_size,
21 | hidden_size=self.num_hiddens,
22 | num_layers=self.num_layers,
23 | batch_first=True,
24 | dropout=self.dropout,
25 | bias=True
26 | ),
27 | Lambda(lambda x: x[0])
28 | )
29 | self.classifier = torch.nn.Sequential(
30 | torch.nn.Linear(self.num_hiddens, self.num_hiddens, bias=True),
31 | torch.nn.ReLU(True),
32 | torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True)
33 | )
34 |
35 |     def forward(self, x):
36 |         x = self.features(x)
37 |         x = self.classifier(x[:, -1, :])
38 |         return x
39 | 
--------------------------------------------------------------------------------
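
Sent140LSTM expects a pre-built GloVe matrix whose row 0 is the padding vector, and Embedding.from_pretrained freezes it by default. A sketch with a hypothetical stand-in matrix (the actual pipeline loads 300-dimensional GloVe vectors):

    import torch
    from src.models.sent140lstm import Sent140LSTM

    glove_emb = torch.randn(1000, 300)  # hypothetical stand-in for the real GloVe matrix
    model = Sent140LSTM(num_classes=2, embedding_size=300, hidden_size=128,
                        dropout=0.1, num_layers=2, glove_emb=glove_emb)
    tokens = torch.randint(1, 1000, (4, 25))  # (batch, seq_len) of token indices
    print(model(tokens).shape)                # torch.Size([4, 2])
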
/src/models/shufflenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.models.model_utils import ShuffleNetInvRes
4 |
5 |
6 |
7 | class ShuffleNet(torch.nn.Module):
8 | def __init__(self, in_channels, num_classes, dropout):
9 | super(ShuffleNet, self).__init__()
10 | self.in_channels = in_channels
11 | self.num_classes = num_classes
12 | self.dropout = dropout
13 |
14 | self.stage_repeats = [4, 8, 4]
15 | self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
16 |
17 | # building feature extractor
18 | features = []
19 |
20 | # input layers
21 | hidden_channels = self.stage_out_channels[1]
22 | features.append(
23 | torch.nn.Sequential(
24 | torch.nn.Conv2d(self.in_channels, hidden_channels, 3, 2, 1, bias=False),
25 | torch.nn.BatchNorm2d(hidden_channels),
26 | torch.nn.ReLU(True),
27 | torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
28 | )
29 | )
30 |
31 | # inverted residual layers
32 | for idx, num_repeats in enumerate(self.stage_repeats):
33 | out_channels = self.stage_out_channels[idx + 2]
34 | for i in range(num_repeats):
35 | if i == 0:
36 | features.append(ShuffleNetInvRes(hidden_channels, out_channels, 2, 2))
37 | else:
38 | features.append(ShuffleNetInvRes(hidden_channels, out_channels, 1, 1))
39 | hidden_channels = out_channels
40 |
41 | # pooling layers
42 | features.append(
43 | torch.nn.Sequential(
44 | torch.nn.Conv2d(hidden_channels, self.stage_out_channels[-1], 1, 1, 0, bias=False),
45 | torch.nn.BatchNorm2d(self.stage_out_channels[-1]),
46 | torch.nn.ReLU(True),
47 | torch.nn.AdaptiveAvgPool2d((1, 1)),
48 | torch.nn.Flatten()
49 | )
50 | )
51 | self.features = torch.nn.Sequential(*features)
52 | self.classifier = torch.nn.Linear(self.stage_out_channels[-1], self.num_classes)
53 |
54 | def forward(self, x):
55 | x = self.features(x)
56 | x = self.classifier(x)
57 | return x
--------------------------------------------------------------------------------
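
The stage widths [24, 48, 96, 192] with a 1024-channel head match ShuffleNetV2 at the 0.5x width multiplier (Ma et al., 2018). A quick shape check, assuming the repository root is on PYTHONPATH (ShuffleNetInvRes comes from src/models/model_utils.py):

    import torch
    from src.models.shufflenet import ShuffleNet

    model = ShuffleNet(in_channels=3, num_classes=10, dropout=0.1)
    print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 10])
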
/src/models/simplecnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class SimpleCNN(torch.nn.Module): # for CIFAR10 experiment in McMahan et al., 2016; (https://github.com/tensorflow/tensorflow/blob/r0.11/tensorflow/models/image/cifar10/cifar10.py)
6 | def __init__(self, in_channels, hidden_size, num_classes):
7 | super(SimpleCNN, self).__init__()
8 | self.in_channels = in_channels
9 | self.hidden_channels = hidden_size
10 | self.num_classes = num_classes
11 |
12 | self.features = torch.nn.Sequential(
13 | torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.hidden_channels, kernel_size=5, padding=2, stride=1, bias=True),
14 | torch.nn.ReLU(),
15 | torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
16 | torch.nn.LocalResponseNorm(size=9, alpha=0.001),
17 | torch.nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=5, padding=2, stride=1, bias=True),
18 | torch.nn.ReLU(),
19 | torch.nn.LocalResponseNorm(size=9, alpha=0.001),
20 | torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
21 | )
22 | self.classifier = torch.nn.Sequential(
23 | torch.nn.AdaptiveAvgPool2d((6, 6)),
24 | torch.nn.Flatten(),
25 | torch.nn.Linear(in_features=self.hidden_channels * (6 * 6), out_features=384, bias=True),
26 | torch.nn.ReLU(),
27 | torch.nn.Linear(in_features=384, out_features=192, bias=True),
28 | torch.nn.ReLU(),
29 | torch.nn.Linear(in_features=192, out_features=self.num_classes, bias=True)
30 | )
31 |
32 | def forward(self, x):
33 | x = self.features(x)
34 | x = self.classifier(x)
35 | return x
36 |
--------------------------------------------------------------------------------
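
A shape check for SimpleCNN using the TensorFlow tutorial's settings (24x24 crops and 64 feature maps are the tutorial defaults, assumed here); the AdaptiveAvgPool2d((6, 6)) pins the flattened width either way:

    import torch
    from src.models.simplecnn import SimpleCNN

    model = SimpleCNN(in_channels=3, hidden_size=64, num_classes=10)
    print(model(torch.randn(2, 3, 24, 24)).shape)  # torch.Size([2, 10])
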
/src/models/squeezebert.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformers import SqueezeBertModel, SqueezeBertConfig
4 |
5 |
6 |
7 | class SqueezeBert(torch.nn.Module):
8 | def __init__(self, num_classes, num_embeddings, embedding_size, hidden_size, dropout, use_pt_model, is_seq2seq):
9 | super(SqueezeBert, self).__init__()
10 | self.is_seq2seq = is_seq2seq
11 | # define encoder
12 | if use_pt_model: # fine-tuning
13 | self.features = SqueezeBertModel.from_pretrained('squeezebert/squeezebert-uncased')
14 | self.num_classes = num_classes
15 | self.num_embeddings = self.features.config.vocab_size
16 | self.embedding_size = self.features.config.embedding_size
17 | self.num_hiddens = self.features.config.hidden_size
18 | self.dropout = self.features.config.hidden_dropout_prob
19 |
20 | self.classifier = torch.nn.Linear(self.features.config.hidden_size, self.num_classes, bias=True)
21 |
22 | else: # from scratch
23 | assert embedding_size == hidden_size, 'If you want embedding_size != intermediate hidden_size, please insert a Conv1d layer to adjust the number of channels before the first SqueezeBertModule.'
24 | self.num_classes = num_classes
25 | self.num_embeddings = num_embeddings
26 | self.embedding_size = embedding_size
27 | self.num_hiddens = hidden_size
28 | self.dropout = dropout
29 |
30 | config = SqueezeBertConfig(
31 | vocab_size=self.num_embeddings,
32 | embedding_size=self.embedding_size,
33 | hidden_size=self.num_hiddens,
34 | hidden_dropout_prob=self.dropout
35 | )
36 | self.features = SqueezeBertModel(config)
37 | self.classifier = torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True)
38 |
39 | def forward(self, x):
40 | x = self.features(x)
41 | x = self.classifier(x['last_hidden_state'] if self.is_seq2seq else x['pooler_output'])
42 | return x
43 |
--------------------------------------------------------------------------------
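
From scratch, the assert above requires embedding_size == hidden_size (768 is SqueezeBERT's default width, assumed here); pretrained mode instead pulls squeezebert/squeezebert-uncased. A minimal sketch, assuming the repository root is on PYTHONPATH:

    import torch
    from src.models.squeezebert import SqueezeBert

    model = SqueezeBert(num_classes=2, num_embeddings=30528, embedding_size=768,
                        hidden_size=768, dropout=0.1, use_pt_model=False, is_seq2seq=False)
    input_ids = torch.randint(0, 30528, (2, 32))
    print(model(input_ids).shape)  # torch.Size([2, 2])
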
/src/models/squeezenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from src.models.model_utils import FireBlock
5 |
6 |
7 |
8 | class SqueezeNet(torch.nn.Module): # SqueezeNet v1.1 (Iandola et al., 2016)
9 | def __init__(self, in_channels, num_classes, dropout):
10 | super(SqueezeNet, self).__init__()
11 | self.in_channels = in_channels
12 | self.num_classes = num_classes
13 | self.dropout = dropout
14 |
15 | self.features = torch.nn.Sequential(
16 | torch.nn.Conv2d(self.in_channels, 64, kernel_size=3, stride=2),
17 | torch.nn.ReLU(True),
18 | torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
19 | FireBlock(64, 16, 64, 64),
20 | FireBlock(128, 16, 64, 64),
21 | torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
22 | FireBlock(128, 32, 128, 128),
23 | FireBlock(256, 32, 128, 128),
24 | torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
25 | FireBlock(256, 48, 192, 192),
26 | FireBlock(384, 48, 192, 192),
27 | FireBlock(384, 64, 256, 256),
28 | FireBlock(512, 64, 256, 256),
29 | )
30 | final_conv = torch.nn.Conv2d(512, self.num_classes, kernel_size=1)
31 | self.classifier = torch.nn.Sequential(
32 | torch.nn.Dropout(self.dropout),
33 | final_conv,
34 | torch.nn.ReLU(True),
35 | torch.nn.AdaptiveAvgPool2d((1, 1))
36 | )
37 |
38 | for m in self.modules():
39 | if isinstance(m, torch.nn.Conv2d):
40 | if m is final_conv:
41 | torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
42 | else:
43 | torch.nn.init.kaiming_uniform_(m.weight)
44 | if m.bias is not None:
45 | torch.nn.init.constant_(m.bias, 0)
46 |
47 | def forward(self, x):
48 | x = self.features(x)
49 | x = self.classifier(x)
50 | x = torch.flatten(x, 1)
51 | return x
52 |
--------------------------------------------------------------------------------
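
The classifier here is fully convolutional (a 1x1 conv emits one channel per class), so forward() flattens the pooled 1x1 map into logits rather than using a Linear layer. A quick check under the usual PYTHONPATH assumption:

    import torch
    from src.models.squeezenet import SqueezeNet

    model = SqueezeNet(in_channels=3, num_classes=10, dropout=0.5)
    print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 10])
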
/src/models/squeezenext.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from src.models.model_utils import SNXBlock
5 |
6 |
7 |
8 | class SqueezeNeXt(torch.nn.Module): # SqueezeNeXt (Gholami et al., 2018)
9 | def __init__(self, in_channels, num_classes, dropout):
10 | super(SqueezeNeXt, self).__init__()
11 | self.in_channels = in_channels
12 | self.num_classes = num_classes
13 | self.dropout = dropout
14 |
15 | self.hidden_channels = 64
16 | self.features = torch.nn.Sequential(
17 | torch.nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1, 1, bias=False),
18 | torch.nn.BatchNorm2d(self.hidden_channels),
19 | torch.nn.ReLU(True),
20 | self._make_layer(2, 32, 1),
21 | self._make_layer(4, 64, 2),
22 | self._make_layer(14, 128, 2),
23 | self._make_layer(1, 256, 2),
24 | torch.nn.Conv2d(self.hidden_channels, 128, 1, 1, bias=False),
25 | torch.nn.BatchNorm2d(128),
26 | torch.nn.ReLU(True),
27 | torch.nn.AdaptiveAvgPool2d((1, 1))
28 | )
29 | self.classifier = torch.nn.Sequential(
30 | torch.nn.Flatten(),
31 | torch.nn.Dropout(self.dropout),
32 | torch.nn.Linear(128, self.num_classes)
33 | )
34 |
35 | for m in self.modules():
36 | if isinstance(m, torch.nn.Conv2d):
37 | torch.nn.init.xavier_uniform_(m.weight, gain=math.sqrt(2.))
38 | if m.bias is not None:
39 | torch.nn.init.constant_(m.bias, 0.)
40 | elif isinstance(m, torch.nn.BatchNorm2d):
41 | torch.nn.init.constant_(m.weight, 1.)
42 | torch.nn.init.constant_(m.bias, 0.)
43 |
44 | def _make_layer(self, num_block, out_channels, stride):
45 | strides = [stride] + [1] * (num_block - 1)
46 | layers = []
47 | for s in strides:
48 | layers.append(SNXBlock(self.hidden_channels, out_channels, s))
49 | self.hidden_channels = out_channels
50 | return torch.nn.Sequential(*layers)
51 |
52 | def forward(self, x):
53 | x = self.features(x)
54 | x = self.classifier(x)
55 | return x
56 |
--------------------------------------------------------------------------------
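
The stride-1 stem and [2, 4, 14, 1] stage depths suggest a CIFAR-style SqueezeNeXt variant, so small inputs work directly. A sketch (SNXBlock comes from src/models/model_utils.py):

    import torch
    from src.models.squeezenext import SqueezeNeXt

    model = SqueezeNeXt(in_channels=3, num_classes=10, dropout=0.5)
    print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
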
/src/models/stackedlstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.models.model_utils import Lambda
4 |
5 |
6 |
7 | class StackedLSTM(torch.nn.Module):
8 | def __init__(self, num_classes, embedding_size, num_embeddings, hidden_size, dropout, num_layers, is_seq2seq):
9 | super(StackedLSTM, self).__init__()
10 | self.is_seq2seq = is_seq2seq
11 | self.num_hiddens = hidden_size
12 | self.num_classes = num_classes
13 | self.num_embeddings = num_embeddings
14 | self.embedding_size = embedding_size
15 | self.dropout = dropout
16 | self.num_layers = num_layers
17 |
18 | self.features = torch.nn.Sequential(
19 | torch.nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_size),
20 | torch.nn.LSTM(
21 | input_size=self.embedding_size,
22 | hidden_size=self.num_hiddens,
23 | num_layers=self.num_layers,
24 | batch_first=True,
25 | dropout=self.dropout,
26 | bias=True
27 | ),
28 | Lambda(lambda x: x[0])
29 | )
30 | self.classifier = torch.nn.Linear(self.num_hiddens, self.num_classes, bias=True)
31 |
32 | def forward(self, x):
33 | x = self.features(x)
34 | x = self.classifier(x if self.is_seq2seq else x[:, -1, :])
35 | return x
36 |
--------------------------------------------------------------------------------
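
is_seq2seq toggles between per-position logits (character-level language modeling) and last-step classification. A sketch with LEAF Shakespeare-style settings, assumed here (80-symbol vocabulary, 8-dim embedding, two 256-unit layers):

    import torch
    from src.models.stackedlstm import StackedLSTM

    model = StackedLSTM(num_classes=80, embedding_size=8, num_embeddings=80,
                        hidden_size=256, dropout=0.1, num_layers=2, is_seq2seq=True)
    tokens = torch.randint(0, 80, (4, 80))
    print(model(tokens).shape)  # torch.Size([4, 80, 80]): logits at every position
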
/src/models/stackedtransformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from src.models.model_utils import Lambda, PositionalEncoding
5 |
6 |
7 |
8 | class StackedTransformer(torch.nn.Module):
9 | def __init__(self, num_classes, embedding_size, num_embeddings, hidden_size, seq_len, dropout, num_layers, is_seq2seq):
10 | super(StackedTransformer, self).__init__()
11 | self.is_seq2seq = is_seq2seq
12 | self.num_hiddens = hidden_size
13 | self.num_classes = num_classes
14 | self.num_embeddings = num_embeddings
15 | self.embedding_size = embedding_size
16 | self.seq_len = seq_len
17 | self.dropout = dropout
18 | self.num_layers = num_layers
19 |
20 | self.features = torch.nn.Sequential(
21 | torch.nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_size),
22 | PositionalEncoding(self.embedding_size, self.dropout),
23 | torch.nn.TransformerEncoder(
24 | torch.nn.TransformerEncoderLayer(self.embedding_size, 16, self.num_hiddens, self.dropout, batch_first=True),
25 | self.num_layers
26 | )
27 | )
28 | self.classifier = torch.nn.Linear(self.embedding_size, self.num_classes, bias=True)
29 |
30 | def forward(self, x):
31 | x = self.features(x)
32 | x = self.classifier(x if self.is_seq2seq else x[:, 0, :])
33 | return x
34 |
--------------------------------------------------------------------------------
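
The encoder layer hard-codes 16 attention heads, so embedding_size must be divisible by 16; non-seq2seq mode classifies from the first position, CLS-style. A sketch with assumed sizes (PositionalEncoding comes from src/models/model_utils.py):

    import torch
    from src.models.stackedtransformer import StackedTransformer

    model = StackedTransformer(num_classes=4, embedding_size=256, num_embeddings=10000,
                               hidden_size=512, seq_len=64, dropout=0.1, num_layers=2,
                               is_seq2seq=False)
    tokens = torch.randint(0, 10000, (4, 64))
    print(model(tokens).shape)  # torch.Size([4, 4])
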
/src/models/twocnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class TwoCNN(torch.nn.Module): # McMahan et al., 2016; 1,663,370 parameters
6 | def __init__(self, in_channels, hidden_size, num_classes):
7 | super(TwoCNN, self).__init__()
8 | self.in_channels = in_channels
9 | self.hidden_channels = hidden_size
10 | self.num_classes = num_classes
11 |
12 | self.features = torch.nn.Sequential(
13 | torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.hidden_channels, kernel_size=(5, 5), padding=1, stride=1, bias=True),
14 | torch.nn.ReLU(True),
15 | torch.nn.MaxPool2d(kernel_size=(2, 2), padding=1),
16 | torch.nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels * 2, kernel_size=(5, 5), padding=1, stride=1, bias=True),
17 | torch.nn.ReLU(True),
18 | torch.nn.MaxPool2d(kernel_size=(2, 2), padding=1)
19 | )
20 | self.classifier = torch.nn.Sequential(
21 | torch.nn.AdaptiveAvgPool2d((7, 7)),
22 | torch.nn.Flatten(),
23 | torch.nn.Linear(in_features=(self.hidden_channels * 2) * (7 * 7), out_features=512, bias=True),
24 | torch.nn.ReLU(True),
25 | torch.nn.Linear(in_features=512, out_features=self.num_classes, bias=True)
26 | )
27 |
28 | def forward(self, x):
29 | x = self.features(x)
30 | x = self.classifier(x)
31 | return x
32 |
--------------------------------------------------------------------------------
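
With MNIST-style settings (assumed here: in_channels=1, hidden_size=32, num_classes=10), the parameter count reproduces the 1,663,370 noted in the comment:

    import torch
    from src.models.twocnn import TwoCNN

    model = TwoCNN(in_channels=1, hidden_size=32, num_classes=10)
    print(sum(p.numel() for p in model.parameters()))  # 1663370
    print(model(torch.randn(2, 1, 28, 28)).shape)      # torch.Size([2, 10])
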
/src/models/twonn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class TwoNN(torch.nn.Module): # McMahan et al., 2016; 199,210 parameters
6 | def __init__(self, resize, hidden_size, num_classes):
7 | super(TwoNN, self).__init__()
8 | self.in_features = resize**2
9 | self.num_hiddens = hidden_size
10 | self.num_classes = num_classes
11 |
12 | self.features = torch.nn.Sequential(
13 | torch.nn.Flatten(),
14 | torch.nn.Linear(in_features=self.in_features, out_features=self.num_hiddens, bias=True),
15 | torch.nn.ReLU(True),
16 | torch.nn.Linear(in_features=self.num_hiddens, out_features=self.num_hiddens, bias=True),
17 | torch.nn.ReLU(True)
18 | )
19 | self.classifier = torch.nn.Linear(in_features=self.num_hiddens, out_features=self.num_classes, bias=True)
20 |
21 | def forward(self, x):
22 | x = self.features(x)
23 | x = self.classifier(x)
24 | return x
25 |
--------------------------------------------------------------------------------
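
Likewise, with resize=28 and hidden_size=200 (assumed, per the original 2NN), the count matches the 199,210 in the comment:

    import torch
    from src.models.twonn import TwoNN

    model = TwoNN(resize=28, hidden_size=200, num_classes=10)
    print(sum(p.numel() for p in model.parameters()))  # 199210
    print(model(torch.randn(2, 1, 28, 28)).shape)      # torch.Size([2, 10])
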
/src/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | __all__ = ['VGG9', 'VGG9BN', 'VGG11', 'VGG11BN', 'VGG13', 'VGG13BN']
6 |
7 | CONFIGS = {
8 | 'VGG9': [64, 'mp', 128, 'mp', 256, 256, 'mp', 512, 512, 'mp'],
9 | 'VGG11': [64, 64, 'mp', 128, 128, 'mp', 256, 256, 'mp', 512, 512, 'mp'],
10 | 'VGG13': [64, 64, 'mp', 128, 128, 'mp', 256, 256, 'mp', 512, 512, 'mp', 512, 512, 'mp'],
11 | }
12 |
13 | class VGG(torch.nn.Module):
14 | def __init__(self, config, use_bn, in_channels, num_classes, dropout):
15 | super(VGG, self).__init__()
16 | self.use_bn = use_bn
17 |
18 | self.in_channels = in_channels
19 | self.num_classes = num_classes
20 | self.dropout = dropout
21 |
22 | self.features = torch.nn.Sequential(*self._make_layers(config, use_bn))
23 | self.classifier = torch.nn.Sequential(
24 | torch.nn.AdaptiveAvgPool2d((7, 7)),
25 | torch.nn.Flatten(),
26 | torch.nn.Linear((7 * 7) * 512, 4096, bias=True),
27 | torch.nn.ReLU(True),
28 | torch.nn.Dropout(self.dropout),
29 | torch.nn.Linear(4096, 4096),
30 | torch.nn.ReLU(True),
31 | torch.nn.Dropout(self.dropout),
32 | torch.nn.Linear(4096, self.num_classes)
33 | )
34 |
35 | def forward(self, x):
36 | x = self.features(x)
37 | x = self.classifier(x)
38 | return x
39 |
40 | def _make_layers(self, config, use_bn):
41 | layers = []
42 | in_channels = self.in_channels
43 | for v in config:
44 | if v == 'mp':
45 | layers.append(torch.nn.MaxPool2d(2, 2))
46 | else:
47 | layers.append(torch.nn.Conv2d(in_channels, v, 3, 1, 1))
48 | if use_bn:
49 | layers.append(torch.nn.BatchNorm2d(v))
50 | layers.append(torch.nn.ReLU(True))
51 | in_channels = v
52 | return layers
53 |
54 | class VGG9(VGG):
55 | def __init__(self, in_channels, num_classes, dropout):
56 | super(VGG9, self).__init__(CONFIGS['VGG9'], False, in_channels, num_classes, dropout)
57 |
58 | class VGG9BN(VGG):
59 | def __init__(self, in_channels, num_classes, dropout):
60 | super(VGG9BN, self).__init__(CONFIGS['VGG9'], True, in_channels, num_classes, dropout)
61 |
62 | class VGG11(VGG):
63 | def __init__(self, in_channels, num_classes, dropout):
64 | super(VGG11, self).__init__(CONFIGS['VGG11'], False, in_channels, num_classes, dropout)
65 |
66 | class VGG11BN(VGG):
67 | def __init__(self, in_channels, num_classes, dropout):
68 | super(VGG11BN, self).__init__(CONFIGS['VGG11'], True, in_channels, num_classes, dropout)
69 |
70 | class VGG13(VGG):
71 | def __init__(self, in_channels, num_classes, dropout):
72 | super(VGG13, self).__init__(CONFIGS['VGG13'], False, in_channels, num_classes, dropout)
73 |
74 | class VGG13BN(VGG):
75 | def __init__(self, in_channels, num_classes, dropout):
76 | super(VGG13BN, self).__init__(CONFIGS['VGG13'], True, in_channels, num_classes, dropout)
77 |
--------------------------------------------------------------------------------
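
Each named variant just binds a CONFIG row and the use_bn flag, and AdaptiveAvgPool2d((7, 7)) makes the 4096-wide head work for any input size. A sketch:

    import torch
    from src.models.vgg import VGG9, VGG11BN  # assumes the repository root is on PYTHONPATH

    for cls in (VGG9, VGG11BN):
        model = cls(in_channels=3, num_classes=10, dropout=0.5)
        print(cls.__name__, model(torch.randn(2, 3, 32, 32)).shape)  # (2, 10) logits
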
/src/server/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vaseline555/Federated-Learning-in-PyTorch/6c07b19c6810c82bd9455bf7364808a568376bf4/src/server/__init__.py
--------------------------------------------------------------------------------
/src/server/baseserver.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 |
4 | class BaseServer(metaclass=ABCMeta):
5 |     """Central server orchestrating the whole federated learning process.
6 | """
7 | def __init__(self, **kwargs):
8 | self._round = 0
9 | self._model = None
10 | self._clients = None
11 |
12 | @property
13 | def model(self):
14 | return self._model
15 |
16 | @model.setter
17 | def model(self, model):
18 | self._model = model
19 |
20 | @property
21 | def round(self):
22 | return self._round
23 |
24 | @round.setter
25 | def round(self, round):
26 | self._round = round
27 |
28 | @property
29 | def clients(self):
30 | return self._clients
31 |
32 | @clients.setter
33 | def clients(self, clients):
34 | self._clients = clients
35 |
36 | @abstractmethod
37 | def _init_model(self, model):
38 | raise NotImplementedError
39 |
40 | @abstractmethod
41 | def _get_algorithm(self, model, **kwargs):
42 | raise NotImplementedError
43 |
44 | @abstractmethod
45 | def _create_clients(self, client_datasets):
46 | raise NotImplementedError
47 |
48 | @abstractmethod
49 | def _sample_clients(self):
50 | raise NotImplementedError
51 |
52 | @abstractmethod
53 | def _request(self, indices, eval=False):
54 | raise NotImplementedError
55 |
56 | @abstractmethod
57 | def _aggregate(self, indices, update_sizes):
58 | raise NotImplementedError
59 |
60 | @abstractmethod
61 | def _central_evaluate(self):
62 | raise NotImplementedError
63 |
64 | @abstractmethod
65 | def update(self):
66 | raise NotImplementedError
67 |
68 | @abstractmethod
69 | def evaluate(self):
70 | raise NotImplementedError
71 |
72 | @abstractmethod
73 | def finalize(self):
74 | raise NotImplementedError
75 |
--------------------------------------------------------------------------------
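
A hypothetical minimal subclass illustrating the contract: every abstract hook must be overridden before instantiation, and the properties give managed access to the round counter, model, and client registry (the real FedavgServer fills the hooks with sampling, broadcasting, and weighted aggregation):

    from src.server.baseserver import BaseServer

    class DummyServer(BaseServer):  # hypothetical, for illustration only
        def _init_model(self, model): ...
        def _get_algorithm(self, model, **kwargs): ...
        def _create_clients(self, client_datasets): ...
        def _sample_clients(self): ...
        def _request(self, indices, eval=False): ...
        def _aggregate(self, indices, update_sizes): ...
        def _central_evaluate(self): ...
        def update(self): self.round += 1  # uses the property setter defined above
        def evaluate(self): ...
        def finalize(self): ...

    server = DummyServer()
    server.update()
    print(server.round)  # 1
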
/src/server/fedadagradserver.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from .fedavgserver import FedavgServer
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 |
9 | class FedadagradServer(FedavgServer):
10 | def __init__(self, **kwargs):
11 | super(FedadagradServer, self).__init__(**kwargs)
12 | self.opt_kwargs = dict(
13 | beta=self.args.beta1,
14 | v0=self.args.tau**2,
15 | tau=self.args.tau,
16 | lr=self.args.server_lr
17 | )
18 |
--------------------------------------------------------------------------------
/src/server/fedadamserver.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from .fedavgserver import FedavgServer
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 |
9 | class FedadamServer(FedavgServer):
10 | def __init__(self, **kwargs):
11 | super(FedadamServer, self).__init__(**kwargs)
12 | self.opt_kwargs = dict(
13 | betas=(self.args.beta1, self.args.beta2),
14 | v0=self.args.tau**2,
15 | tau=self.args.tau,
16 | lr=self.args.server_lr
17 | )
18 |
--------------------------------------------------------------------------------
/src/server/fedavgmserver.py:
--------------------------------------------------------------------------------
1 |
2 | import logging
3 |
4 | from .fedavgserver import FedavgServer
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 |
10 | class FedavgmServer(FedavgServer):
11 | def __init__(self, **kwargs):
12 | super(FedavgmServer, self).__init__(**kwargs)
13 |
--------------------------------------------------------------------------------
/src/server/fedproxserver.py:
--------------------------------------------------------------------------------
1 |
2 | import logging
3 |
4 | from .fedavgserver import FedavgServer
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 |
10 | class FedproxServer(FedavgServer):
11 | def __init__(self, **kwargs):
12 | super(FedproxServer, self).__init__(**kwargs)
13 |
--------------------------------------------------------------------------------
/src/server/fedsgdserver.py:
--------------------------------------------------------------------------------
1 |
2 | import logging
3 |
4 | from .fedavgserver import FedavgServer
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 |
10 | class FedsgdServer(FedavgServer):
11 | def __init__(self, **kwargs):
12 | super(FedsgdServer, self).__init__(**kwargs)
13 |
--------------------------------------------------------------------------------
/src/server/fedyogiserver.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from .fedavgserver import FedavgServer
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 |
9 | class FedyogiServer(FedavgServer):
10 | def __init__(self, **kwargs):
11 | super(FedyogiServer, self).__init__(**kwargs)
12 | self.opt_kwargs = dict(
13 | betas=(self.args.beta1, self.args.beta2),
14 | v0=self.args.tau**2,
15 | tau=self.args.tau,
16 | lr=self.args.server_lr
17 | )
18 |
--------------------------------------------------------------------------------
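
The three adaptive servers (FedAdagrad, FedAdam, FedYogi) differ only in opt_kwargs: they parameterize the server-side optimizers of Reddi et al., 2021 ("Adaptive Federated Optimization"), which treat the aggregated client update \Delta_t as a pseudo-gradient. A sketch of the update rules these kwargs feed, with v_0 = \tau^2 matching v0=self.args.tau**2 above:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\,\Delta_t
    v_t = v_{t-1} + \Delta_t^2                                            (FedAdagrad)
    v_t = \beta_2 v_{t-1} + (1 - \beta_2)\,\Delta_t^2                     (FedAdam)
    v_t = v_{t-1} - (1 - \beta_2)\,\Delta_t^2\,\mathrm{sign}(v_{t-1} - \Delta_t^2)  (FedYogi)
    x_{t+1} = x_t + \eta\, m_t / (\sqrt{v_t} + \tau)

Here \eta is server_lr and \tau sets the degree of adaptivity; FedAdagrad takes a single momentum factor, hence the singular beta=self.args.beta1 in its kwargs where the other two pass a betas pair.
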