├── LICENSE
├── README.md
├── config.py
├── criteo-dis.npy
├── datasets.py
├── experiments.py
├── femnist-dis.npy
├── figures
│   ├── 100parties
│   │   ├── cifar10-homo.png
│   │   ├── cifar10-lb2.png
│   │   ├── cifar10-lbdir.png
│   │   └── cifar10-quan.png
│   ├── 10parties
│   │   ├── cifar10-iid-diff-quantity.png
│   │   ├── cifar10-noise.png
│   │   ├── cifar10-noniid-label2.png
│   │   └── cifar10-noniid-labeldir.png
│   └── heavy-model
│       ├── resnet-noise.png
│       └── vgg-lbdir.png
├── model.py
├── models
│   ├── celeba_model.py
│   ├── mnist_model.py
│   └── svhn_model.py
├── partition.py
├── partition_to_file.sh
├── requirements.txt
├── resnetcifar.py
├── run.sh
├── scripts
│   ├── 100-parties.sh
│   ├── adult&covtype.sh
│   ├── batch-size.sh
│   ├── fcube.sh
│   ├── femnist.sh
│   ├── image-data-with-noise.sh
│   ├── image-data-without-noise.sh
│   ├── rcv1.sh
│   └── vgg&resnet.sh
├── utils.py
└── vggmodel.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Yiqun Diao, Qinbin Li
4 |
5 | Copyright (c) 2020 International Business Machines
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NIID-Bench
2 |
3 | [Paper](https://arxiv.org/pdf/2102.02079.pdf)
4 |
5 | [NIID-Bench Challenge](https://niidbench.xtra.science/)
6 |
7 |
8 | This is the code of the paper [Federated Learning on Non-IID Data Silos: An Experimental Study](https://arxiv.org/pdf/2102.02079.pdf).
9 |
10 |
11 | This code runs a benchmark for federated learning algorithms under non-IID data distribution scenarios. Specifically, we implement 4 federated learning algorithms (FedAvg, FedProx, SCAFFOLD & FedNova), 3 types of non-IID settings (label distribution skew, feature distribution skew & quantity skew) and 9 datasets (MNIST, Cifar-10, Fashion-MNIST, SVHN, Generated 3D dataset, FEMNIST, adult, rcv1, covtype).
12 |
13 |
14 | ## Updates on NIID-Bench
15 | Our follow-up works based on NIID-Bench:
16 |
17 | * [FedOV](https://github.com/Xtra-Computing/FedOV): Towards Addressing Label Skews in One-Shot Federated Learning (ICLR 2023)
18 |
19 | * [FedConcat](https://github.com/sjtudyq/FedConcat): Exploiting Label Skew in Federated Learning with Model Concatenation (AAAI 2024)
20 |
21 | We have published the NIID-Bench challenge (https://niidbench.xtra.science), a benchmark for comparing federated learning algorithms under comprehensive non-IID data settings. Researchers are welcome to test their algorithms on these settings, upload their code, and join our leaderboard!
22 |
23 | Run `partition.py` to divide tabular datasets (CSV format) into multiple files using our non-IID partitioning strategies. The column named `Class` in the header is recognized as the label. See a running example in `partition_to_file.sh`. The example dataset is [Credit Card Fraud Detection](https://www.kaggle.com/datasets/mlg-ulb/creditcardfraud).
24 |
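As a quick sanity check before partitioning, you can verify that your CSV exposes the expected label column. A minimal sketch, assuming pandas is available and using the Kaggle file name `creditcard.csv` as a stand-in for your own data:

```
import pandas as pd

# partition.py recognizes the column named "Class" as the label.
df = pd.read_csv("creditcard.csv")  # stand-in for your own CSV
assert "Class" in df.columns
print(df["Class"].value_counts())
```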
25 | To adapt ``partition.py`` to your own tabular dataset, take the following steps:
26 |
27 | 1. Load your own dataset into arrays, replacing Lines 117-126.
28 | 2. The whole tabular dataset is stored in ``dataset`` and the label column ID in ``class_id``. Change Line 130 to your own label identifier.
29 |
30 | If your dataset is an image dataset, ``partition.py`` is not applicable. Refer instead to our function ``partition_data`` in ``utils.py``. You need to design your own dataloader like Lines 183-198: for example, like ``load_mnist_data`` (Line 40), write a dataloader that returns your dataset as the tuple (X_train, y_train, X_test, y_test). For the dataloader format itself, see the class ``MNIST_truncated`` (Line 60 in ``dataset.py``). Once you have (X_train, y_train, X_test, y_test), the ``partition_data`` function will return the ``net_dataidx_map``.
31 |
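A minimal sketch of such a loader, with hypothetical `.npy` file names; the only contract ``partition_data`` relies on is the returned tuple:

```
import numpy as np

def load_my_dataset(datadir):
    # Hypothetical file names; replace with however your data is stored.
    X_train = np.load(datadir + "X_train.npy")
    y_train = np.load(datadir + "y_train.npy")
    X_test = np.load(datadir + "X_test.npy")
    y_test = np.load(datadir + "y_test.npy")
    return X_train, y_train, X_test, y_test
```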
32 | To support more settings and facilitate future research, we have integrated MOON and added CIFAR-100 and Tiny-ImageNet.
33 |
34 | ### Tiny-ImageNet
35 | You can download Tiny-ImageNet [here](http://cs231n.stanford.edu/tiny-imagenet-200.zip). Then, you can follow the [instructions](https://github.com/AI-secure/DBA/blob/master/utils/tinyimagenet_reformat.py) to reformat the validation folder.
36 |
37 | ## Non-IID Settings
38 | ### Label Distribution Skew
39 | * **Quantity-based label imbalance**: each party owns data samples of a fixed number of labels.
40 | * **Distribution-based label imbalance**: each party is allocated a proportion of the samples of each label according to a Dirichlet distribution (see the sketch below).
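
The following is a minimal sketch of the distribution-based label imbalance idea, not the repository's exact implementation (that lives in ``partition_data`` in ``utils.py``):

```
import numpy as np

def dirichlet_label_partition(y, n_parties, beta, seed=0):
    # For every class, split its sample indices across parties with
    # proportions drawn from Dirichlet(beta).
    rng = np.random.default_rng(seed)
    net_dataidx_map = {i: [] for i in range(n_parties)}
    for c in range(int(y.max()) + 1):
        idx_c = np.where(y == c)[0]
        rng.shuffle(idx_c)
        proportions = rng.dirichlet(np.repeat(beta, n_parties))
        cuts = (np.cumsum(proportions) * len(idx_c)).astype(int)[:-1]
        for party, chunk in enumerate(np.split(idx_c, cuts)):
            net_dataidx_map[party].extend(chunk.tolist())
    return net_dataidx_map

# e.g. net_dataidx_map = dirichlet_label_partition(np.array(y_train), 10, 0.5)
```

Smaller `beta` concentrates each class in fewer parties, i.e. stronger label skew.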
41 | ### Feature Distribution Skew
42 | * **Noise-based feature imbalance**: We first divide the whole dataset into multiple parties randomly and equally. For each party, we add a different level of Gaussian noise (see the sketch after this list).
43 | * **Synthetic feature imbalance**: For the generated 3D dataset, we allocate to each party a subset consisting of two parts that are symmetric about (0,0,0).
44 | * **Real-world feature imbalance**: For FEMNIST, we divide and assign the writers (and their characters) into each party randomly and equally.
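
A minimal sketch of noise-based feature imbalance; the per-party schedule here is one plausible choice for illustration, while the benchmark's actual schedule is driven by the `--noise` flag in the training code:

```
import numpy as np

def add_feature_noise(X, party_id, n_parties, max_sigma, seed=0):
    # Hypothetical schedule: party i gets noise with standard deviation
    # max_sigma * (i + 1) / n_parties, so later parties are noisier.
    sigma = max_sigma * (party_id + 1) / n_parties
    rng = np.random.default_rng(seed + party_id)
    return X + rng.normal(0.0, sigma, size=X.shape)
```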
45 | ### Quantity Skew
46 | * While the data distribution may still be consistent among the parties, the size of each local dataset varies according to a Dirichlet distribution (see the sketch below).
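
A minimal sketch of quantity skew, again only illustrative of the idea behind `iid-diff-quantity`:

```
import numpy as np

def quantity_skew_partition(n_samples, n_parties, beta, seed=0):
    # IID data, non-IID sizes: local dataset sizes follow Dirichlet(beta).
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    proportions = rng.dirichlet(np.repeat(beta, n_parties))
    cuts = (np.cumsum(proportions) * n_samples).astype(int)[:-1]
    return {i: part.tolist() for i, part in enumerate(np.split(idx, cuts))}
```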
47 |
48 |
49 |
50 | ## Usage
51 | Here is one example to run this code:
52 | ```
53 | python experiments.py --model=simple-cnn \
54 | --dataset=cifar10 \
55 | --alg=fedprox \
56 | --lr=0.01 \
57 | --batch-size=64 \
58 | --epochs=10 \
59 | --n_parties=10 \
60 | --mu=0.01 \
61 | --rho=0.9 \
62 | --comm_round=50 \
63 | --partition=noniid-labeldir \
64 | --beta=0.5 \
65 | --device='cuda:0' \
66 | --datadir='./data/' \
67 | --logdir='./logs/' \
68 | --noise=0 \
69 | --sample=1 \
70 | --init_seed=0
71 | ```
72 |
73 | | Parameter | Description |
74 | | ----------------------------- | ---------------------------------------- |
75 | | `model` | The model architecture. Options: `simple-cnn`, `vgg`, `resnet`, `mlp`. Default = `mlp`. |
76 | | `dataset` | Dataset to use. Options: `mnist`, `cifar10`, `fmnist`, `svhn`, `generated`, `femnist`, `a9a`, `rcv1`, `covtype`. Default = `mnist`. |
77 | | `alg` | The training algorithm. Options: `fedavg`, `fedprox`, `scaffold`, `fednova`, `moon`. Default = `fedavg`. |
78 | | `lr` | Learning rate for the local models, default = `0.01`. |
79 | | `batch-size` | Batch size, default = `64`. |
80 | | `epochs` | Number of local training epochs, default = `5`. |
81 | | `n_parties` | Number of parties, default = `2`. |
82 | | `mu` | The proximal term parameter for FedProx, default = `0.001`. |
83 | | `rho` | The parameter controlling the momentum SGD, default = `0`. |
84 | | `comm_round` | Number of communication rounds to use, default = `50`. |
85 | | `partition` | The data partitioning strategy. Options: `homo`, `noniid-labeldir`, `noniid-#label1` (or 2, 3, ..., i.e., the fixed number of labels each party owns), `real`, `iid-diff-quantity`. Default = `homo`. |
86 | | `beta` | The concentration parameter of the Dirichlet distribution for heterogeneous partition, default = `0.5`. |
87 | | `device` | Specify the device to run the program, default = `cuda:0`. |
88 | | `datadir` | The path of the dataset, default = `./data/`. |
89 | | `logdir` | The path to store the logs, default = `./logs/`. |
90 | | `noise` | Maximum variance of the Gaussian noise added to the local parties, default = `0`. |
91 | | `sample` | Ratio of parties that participate in each communication round, default = `1`. |
92 | | `init_seed` | The initial seed, default = `0`. |
93 |
94 |
95 |
96 | ## Data Partition Map
97 | You can call the function `get_partition_dict()` in `experiments.py` to access `net_dataidx_map`. `net_dataidx_map` is a dictionary: its keys are party IDs, and the value of each key is a list of the indices of the data assigned to that party. For our experiments, we usually set `init_seed=0`; when we repeat experiments of some setting, we change `init_seed` to 1 or 2. The default value of `noise` is 0 unless stated. We list the ways to reproduce our data partitions as follows.
98 | * **Quantity-based label imbalance**: `partition`=`noniid-#label1`, `noniid-#label2` or `noniid-#label3`
99 | * **Distribution-based label imbalance**: `partition`=`noniid-labeldir`, `beta`=`0.5` or `0.1`
100 | * **Noise-based feature imbalance**: `partition`=`homo`, `noise`=`0.1` (actually noise does not affect `net_dataidx_map`)
101 | * **Synthetic feature imbalance & Real-world feature imbalance**: `partition`=`real`
102 | * **Quantity Skew**: `partition`=`iid-diff-quantity`, `beta`=`0.5` or `0.1`
103 | * **IID Setting**: `partition`=`homo`
104 | * **Mixed skew**: `partition` = `mixed` for mixture of distribution-based label imbalance and quantity skew; `partition` = `noniid-labeldir` and `noise` = `0.1` for mixture of distribution-based label imbalance and noise-based feature imbalance.
105 |
106 | Here is an explanation of the parameters of `get_partition_dict()`; a usage sketch follows the table.
107 |
108 | | Parameter | Description |
109 | | ----------------------------- | ---------------------------------------- |
110 | | `dataset` | Dataset to use. Options: `mnist`, `cifar10`, `fmnist`, `svhn`, `generated`, `femnist`, `a9a`, `rcv1`, `covtype`. |
111 | | `partition` | The data partitioning strategy. Options: `homo`, `noniid-labeldir`, `noniid-#label1` (or 2, 3, ..., i.e., the fixed number of labels each party owns), `real`, `iid-diff-quantity` |
112 | | `n_parties` | Number of parties. |
113 | | `init_seed` | The initial seed. |
114 | | `datadir` | The path of the dataset. |
115 | | `logdir` | The path to store the logs. |
116 | | `beta` | The concentration parameter of the Dirichlet distribution for heterogeneous partition. |
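
A usage sketch, assuming keyword arguments that mirror the table above; check the definition of `get_partition_dict()` in `experiments.py` for the exact signature before relying on this:

```
from experiments import get_partition_dict

net_dataidx_map = get_partition_dict(dataset='cifar10',
                                     partition='noniid-labeldir',
                                     n_parties=10,
                                     init_seed=0,
                                     datadir='./data/',
                                     logdir='./logs/',
                                     beta=0.5)
print(len(net_dataidx_map[0]))  # number of samples assigned to party 0
```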
117 |
118 | ## Leaderboard
119 |
120 | Note that each reported accuracy is the average over three experiments, while each training curve is based on a single experiment, so there may be some differences between them. We show the training curves to compare the convergence rates of the different algorithms.
121 |
122 | ### Quantity-based label imbalance
123 | * Cifar-10, 10 parties, sample rate = 1, batch size = 64, learning rate = 0.01
124 |
125 | | Partition | Model | Round | Algorithm | Accuracy |
126 | | --------------|--------------- | -------------- | ------------ | -------------- |
127 | | `noniid-#label2` | `simple-cnn` |50| FedProx (`mu=0.01`) | 50.7% |
128 | | `noniid-#label2` | `simple-cnn` |50| FedAvg | 49.8% |
129 | | `noniid-#label2` | `simple-cnn` |50| SCAFFOLD | 49.1% |
130 | | `noniid-#label2` | `simple-cnn` |50| FedNova | 46.5% |
131 |
132 | 
133 |
134 |
135 | * Cifar-10, 100 parties, sample rate = 0.1, batch size = 64, learning rate = 0.01
136 |
137 | | Partition | Model | Round | Algorithm | Accuracy |
138 | | --------------|--------------- | -------------- | ------------ | -------------- |
139 | | `noniid-#label2` | `simple-cnn` |500| FedNova | 48.0% |
140 | | `noniid-#label2` | `simple-cnn` |500| FedAvg | 45.3% |
141 | | `noniid-#label2` | `simple-cnn` |500| FedProx (`mu=0.001`) | 39.3% |
142 | | `noniid-#label2` | `simple-cnn` |500| SCAFFOLD | 10.0% |
143 |
144 | 
145 |
146 | ### Distribution-based label imbalance
147 | * Cifar-10, 10 parties, sample rate = 1, batch size = 64, learning rate = 0.01
148 |
149 | | Partition | Model | Round | Algorithm | Accuracy |
150 | | --------------|--------------- | -------------- | ------------ | -------------- |
151 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |50| SCAFFOLD | 69.8% |
152 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |50| FedAvg | 68.2% |
153 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |50| FedProx (`mu=0.001`) | 67.9% |
154 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |50| FedNova | 66.8% |
155 |
156 | 
157 |
158 | | Partition | Model | Round | Algorithm | Accuracy |
159 | | --------------|--------------- | -------------- | ------------ | -------------- |
160 | | `noniid-labeldir` with `beta=0.1` | `vgg` |100| SCAFFOLD | 85.5% |
161 | | `noniid-labeldir` with `beta=0.1` | `vgg` |100| FedNova | 84.4% |
162 | | `noniid-labeldir` with `beta=0.1` | `vgg` |100| FedProx (`mu=0.01`) | 84.4% |
163 | | `noniid-labeldir` with `beta=0.1` | `vgg` |100| FedAvg | 84.0% |
164 |
165 | 
166 |
167 | * Cifar-10, 100 parties, sample rate = 0.1, batch size = 64, learning rate = 0.01
168 |
169 | | Partition | Model | Round | Algorithm | Accuracy |
170 | | --------------|--------------- | -------------- | ------------ | -------------- |
171 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |500| FedNova | 60.0% |
172 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |500| FedAvg | 59.4% |
173 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |500| FedProx (`mu=0.001`) | 58.8% |
174 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |500| SCAFFOLD | 10.0% |
175 |
176 | 
177 |
178 | ### Noise-based feature imbalance
179 | * Cifar-10, 10 parties, sample rate = 1, batch size = 64, learning rate = 0.01
180 |
181 | | Partition | Model | Round | Algorithm | Accuracy |
182 | | --------------|--------------- | -------------- | ------------ | -------------- |
183 | | `homo` with `noise=0.1` | `simple-cnn` |50| SCAFFOLD | 70.1% |
184 | | `homo` with `noise=0.1` | `simple-cnn` |50| FedProx (`mu=0.01`) | 69.3% |
185 | | `homo` with `noise=0.1` | `simple-cnn` |50| FedAvg | 68.9% |
186 | | `homo` with `noise=0.1` | `simple-cnn` |50| FedNova | 68.5% |
187 |
188 | 
189 |
190 | | Partition | Model | Round | Algorithm | Accuracy |
191 | | --------------|--------------- | -------------- | ------------ | -------------- |
192 | | `homo` with `noise=0.1` | `resnet` |100| SCAFFOLD | 90.2% |
193 | | `homo` with `noise=0.1` | `resnet` |100| FedNova | 89.4% |
194 | | `homo` with `noise=0.1` | `resnet` |100| FedProx (`mu=0.01`) | 89.2% |
195 | | `homo` with `noise=0.1` | `resnet` |100| FedAvg | 89.1% |
196 |
197 | 
198 |
199 | ### Quantity Skew
200 | * Cifar-10, 10 parties, sample rate = 1, batch size = 64, learning rate = 0.01
201 |
202 | | Partition | Model | Round | Algorithm | Accuracy |
203 | | --------------|--------------- | -------------- | ------------ | -------------- |
204 | | `iid-diff-quantity` with `beta=0.5` | `simple-cnn` |50| FedAvg | 72.0% |
205 | | `iid-diff-quantity` with `beta=0.5` | `simple-cnn` |50| FedProx (`mu=0.01`) | 71.2% |
206 | | `iid-diff-quantity` with `beta=0.5` | `simple-cnn` |50| SCAFFOLD | 62.4% |
207 | | `iid-diff-quantity` with `beta=0.5` | `simple-cnn` |50| FedNova | 10.0% |
208 |
209 | 
210 |
211 | ### IID Setting
212 | * Cifar-10, 100 parties, sample rate = 0.1, batch size = 64, learning rate = 0.01
213 |
214 | | Partition | Model | Round | Algorithm | Accuracy |
215 | | --------------|--------------- | -------------- | ------------ | -------------- |
216 | |`homo`| `simple-cnn` |500| FedNova | 66.1% |
217 | |`homo`| `simple-cnn` |500| FedProx (`mu=0.01`) | 66.0% |
218 | |`homo`| `simple-cnn` |500| FedAvg | 65.6% |
219 | |`homo`| `simple-cnn` |500| SCAFFOLD | 10.0% |
220 |
221 | 
222 |
223 | ## Citation
224 | If you find this repository useful, please cite our paper:
225 |
226 | ```
227 | @inproceedings{li2022federated,
228 | title={Federated Learning on Non-IID Data Silos: An Experimental Study},
229 | author={Li, Qinbin and Diao, Yiqun and Chen, Quan and He, Bingsheng},
230 | booktitle={IEEE International Conference on Data Engineering},
231 | year={2022}
232 | }
233 | ```
234 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Dictionary storing network parameters.
2 | params = {
3 | 'num_epochs': 100,  # Number of epochs to train for.
4 | 'learning_rate': 2e-4,  # Learning rate.
5 | 'beta1': 0.5,
6 | 'beta2': 0.999,
7 | 'save_epoch': 10,  # After how many epochs to save checkpoints and generate test output.
8 | }
9 |
--------------------------------------------------------------------------------
/criteo-dis.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/criteo-dis.npy
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | from PIL import Image
4 | import numpy as np
5 | from torchvision.datasets import MNIST, CIFAR10, SVHN, FashionMNIST, CIFAR100, ImageFolder, DatasetFolder, utils
6 | from torchvision.datasets.vision import VisionDataset
7 | from torchvision.datasets.utils import download_file_from_google_drive, check_integrity
8 | from functools import partial
9 | from typing import Optional, Callable
10 | from torch.utils.model_zoo import tqdm
11 | import PIL
12 | import tarfile, gzip, zipfile  # gzip and zipfile are used by extract_archive below
13 | import torchvision
14 |
15 | import os
16 | import os.path
17 | import logging
18 | import torchvision.datasets.utils as utils
19 |
20 | logging.basicConfig()
21 | logger = logging.getLogger()
22 | logger.setLevel(logging.INFO)
23 |
24 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
25 |
26 | def mkdirs(dirpath):
27 | try:
28 | os.makedirs(dirpath)
29 | except Exception as _:
30 | pass
31 |
32 | def accimage_loader(path):
33 | import accimage
34 | try:
35 | return accimage.Image(path)
36 | except IOError:
37 | # Potentially a decoding problem, fall back to PIL.Image
38 | return pil_loader(path)
39 |
40 |
41 | def pil_loader(path):
42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
43 | with open(path, 'rb') as f:
44 | img = Image.open(f)
45 | return img.convert('RGB')
46 |
47 |
48 | def default_loader(path):
49 | from torchvision import get_image_backend
50 | if get_image_backend() == 'accimage':
51 | return accimage_loader(path)
52 | else:
53 | return pil_loader(path)
54 |
55 | class CustomTensorDataset(data.TensorDataset):
56 | def __getitem__(self, index):
57 | return tuple(tensor[index] for tensor in self.tensors) + (index,)
58 |
59 |
60 | class MNIST_truncated(data.Dataset):
61 |
62 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
63 |
64 | self.root = root
65 | self.dataidxs = dataidxs
66 | self.train = train
67 | self.transform = transform
68 | self.target_transform = target_transform
69 | self.download = download
70 |
71 | self.data, self.target = self.__build_truncated_dataset__()
72 |
73 | def __build_truncated_dataset__(self):
74 |
75 | mnist_dataobj = MNIST(self.root, self.train, self.transform, self.target_transform, self.download)
76 |
77 | # if self.train:
78 | # data = mnist_dataobj.train_data
79 | # target = mnist_dataobj.train_labels
80 | # else:
81 | # data = mnist_dataobj.test_data
82 | # target = mnist_dataobj.test_labels
83 |
84 | data = mnist_dataobj.data
85 | target = mnist_dataobj.targets
86 |
87 | if self.dataidxs is not None:
88 | data = data[self.dataidxs]
89 | target = target[self.dataidxs]
90 |
91 | return data, target
92 |
93 | def __getitem__(self, index):
94 | """
95 | Args:
96 | index (int): Index
97 |
98 | Returns:
99 | tuple: (image, target) where target is index of the target class.
100 | """
101 | img, target = self.data[index], self.target[index]
102 |
103 | # doing this so that it is consistent with all other datasets
104 | # to return a PIL Image
105 | img = Image.fromarray(img.numpy(), mode='L')
106 |
107 | # print("mnist img:", img)
108 | # print("mnist target:", target)
109 |
110 | if self.transform is not None:
111 | img = self.transform(img)
112 |
113 | if self.target_transform is not None:
114 | target = self.target_transform(target)
115 |
116 | return img, target
117 |
118 | def __len__(self):
119 | return len(self.data)
120 |
121 | class FashionMNIST_truncated(data.Dataset):
122 |
123 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
124 |
125 | self.root = root
126 | self.dataidxs = dataidxs
127 | self.train = train
128 | self.transform = transform
129 | self.target_transform = target_transform
130 | self.download = download
131 |
132 | self.data, self.target = self.__build_truncated_dataset__()
133 |
134 | def __build_truncated_dataset__(self):
135 |
136 | mnist_dataobj = FashionMNIST(self.root, self.train, self.transform, self.target_transform, self.download)
137 |
138 | # if self.train:
139 | # data = mnist_dataobj.train_data
140 | # target = mnist_dataobj.train_labels
141 | # else:
142 | # data = mnist_dataobj.test_data
143 | # target = mnist_dataobj.test_labels
144 |
145 | data = mnist_dataobj.data
146 | target = mnist_dataobj.targets
147 |
148 | if self.dataidxs is not None:
149 | data = data[self.dataidxs]
150 | target = target[self.dataidxs]
151 |
152 | return data, target
153 |
154 | def __getitem__(self, index):
155 | """
156 | Args:
157 | index (int): Index
158 |
159 | Returns:
160 | tuple: (image, target) where target is index of the target class.
161 | """
162 | img, target = self.data[index], self.target[index]
163 |
164 | # doing this so that it is consistent with all other datasets
165 | # to return a PIL Image
166 | img = Image.fromarray(img.numpy(), mode='L')
167 |
168 | # print("mnist img:", img)
169 | # print("mnist target:", target)
170 |
171 | if self.transform is not None:
172 | img = self.transform(img)
173 |
174 | if self.target_transform is not None:
175 | target = self.target_transform(target)
176 |
177 | return img, target
178 |
179 | def __len__(self):
180 | return len(self.data)
181 |
182 | class SVHN_custom(data.Dataset):
183 |
184 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
185 |
186 | self.root = root
187 | self.dataidxs = dataidxs
188 | self.train = train
189 | self.transform = transform
190 | self.target_transform = target_transform
191 | self.download = download
192 |
193 | self.data, self.target = self.__build_truncated_dataset__()
194 |
195 | def __build_truncated_dataset__(self):
196 | if self.train is True:
197 | # svhn_dataobj1 = SVHN(self.root, 'train', self.transform, self.target_transform, self.download)
198 | # svhn_dataobj2 = SVHN(self.root, 'extra', self.transform, self.target_transform, self.download)
199 | # data = np.concatenate((svhn_dataobj1.data, svhn_dataobj2.data), axis=0)
200 | # target = np.concatenate((svhn_dataobj1.labels, svhn_dataobj2.labels), axis=0)
201 |
202 | svhn_dataobj = SVHN(self.root, 'train', self.transform, self.target_transform, self.download)
203 | data = svhn_dataobj.data
204 | target = svhn_dataobj.labels
205 | else:
206 | svhn_dataobj = SVHN(self.root, 'test', self.transform, self.target_transform, self.download)
207 | data = svhn_dataobj.data
208 | target = svhn_dataobj.labels
209 |
210 | if self.dataidxs is not None:
211 | data = data[self.dataidxs]
212 | target = target[self.dataidxs]
213 | # print("svhn data:", data)
214 | # print("len svhn data:", len(data))
215 | # print("type svhn data:", type(data))
216 | # print("svhn target:", target)
217 | # print("type svhn target", type(target))
218 | return data, target
219 |
220 | # def truncate_channel(self, index):
221 | # for i in range(index.shape[0]):
222 | # gs_index = index[i]
223 | # self.data[gs_index, :, :, 1] = 0.0
224 | # self.data[gs_index, :, :, 2] = 0.0
225 |
226 | def __getitem__(self, index):
227 | """
228 | Args:
229 | index (int): Index
230 |
231 | Returns:
232 | tuple: (image, target) where target is index of the target class.
233 | """
234 | img, target = self.data[index], self.target[index]
235 | # print("svhn img:", img)
236 | # print("svhn target:", target)
237 | # doing this so that it is consistent with all other datasets
238 | # to return a PIL Image
239 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
240 |
241 | if self.transform is not None:
242 | img = self.transform(img)
243 |
244 | if self.target_transform is not None:
245 | target = self.target_transform(target)
246 |
247 | return img, target
248 |
249 | def __len__(self):
250 | return len(self.data)
251 |
252 |
253 | # torchvision CelebA
254 | class CelebA_custom(VisionDataset):
255 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset.
256 |
257 | Args:
258 | root (string): Root directory where images are downloaded to.
259 | split (string): One of {'train', 'valid', 'test', 'all'}.
260 | Accordingly dataset is selected.
261 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
262 | or ``landmarks``. Can also be a list to output a tuple with all specified target types.
263 | The targets represent:
264 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
265 | ``identity`` (int): label for each person (data points with the same identity are the same person)
266 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
267 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
268 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
269 | Defaults to ``attr``. If empty, ``None`` will be returned as target.
270 | transform (callable, optional): A function/transform that takes in an PIL image
271 | and returns a transformed version. E.g, ``transforms.ToTensor``
272 | target_transform (callable, optional): A function/transform that takes in the
273 | target and transforms it.
274 | download (bool, optional): If true, downloads the dataset from the internet and
275 | puts it in root directory. If dataset is already downloaded, it is not
276 | downloaded again.
277 | """
278 |
279 | base_folder = "celeba"
280 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
281 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
282 | # right now.
283 | file_list = [
284 | # File ID MD5 Hash Filename
285 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
286 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
287 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
288 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
289 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
290 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
291 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
292 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
293 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
294 | ]
295 |
296 | def __init__(self, root, dataidxs=None, split="train", target_type="attr", transform=None,
297 | target_transform=None, download=False):
298 | import pandas
299 | super(CelebA_custom, self).__init__(root, transform=transform,
300 | target_transform=target_transform)
301 | self.split = split
302 | if isinstance(target_type, list):
303 | self.target_type = target_type
304 | else:
305 | self.target_type = [target_type]
306 |
307 | if not self.target_type and self.target_transform is not None:
308 | raise RuntimeError('target_transform is specified but target_type is empty')
309 |
310 | if download:
311 | self.download()
312 |
313 | if not self._check_integrity():
314 | raise RuntimeError('Dataset not found or corrupted.' +
315 | ' You can use download=True to download it')
316 |
317 | split_map = {
318 | "train": 0,
319 | "valid": 1,
320 | "test": 2,
321 | "all": None,
322 | }
323 | split = split_map[split.lower()]
324 |
325 | fn = partial(os.path.join, self.root, self.base_folder)
326 | splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
327 | identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
328 | bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
329 | landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
330 | attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
331 |
332 | mask = slice(None) if split is None else (splits[1] == split)
333 |
334 | self.filename = splits[mask].index.values
335 | self.identity = torch.as_tensor(identity[mask].values)
336 | self.bbox = torch.as_tensor(bbox[mask].values)
337 | self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
338 | self.attr = torch.as_tensor(attr[mask].values)
339 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
340 | self.attr_names = list(attr.columns)
341 | self.gender_index = self.attr_names.index('Male')
342 | self.dataidxs = dataidxs
343 | if self.dataidxs is None:
344 | self.target = self.attr[:, self.gender_index:self.gender_index + 1].reshape(-1)
345 | else:
346 | self.target = self.attr[self.dataidxs, self.gender_index:self.gender_index + 1].reshape(-1)
347 |
348 | def _check_integrity(self):
349 | for (_, md5, filename) in self.file_list:
350 | fpath = os.path.join(self.root, self.base_folder, filename)
351 | _, ext = os.path.splitext(filename)
352 | # Allow original archive to be deleted (zip and 7z)
353 | # Only need the extracted images
354 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
355 | return False
356 |
357 | # Should check a hash of the images
358 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
359 |
360 | def download(self):
361 | import zipfile
362 |
363 | if self._check_integrity():
364 | print('Files already downloaded and verified')
365 | return
366 |
367 | for (file_id, md5, filename) in self.file_list:
368 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
369 |
370 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
371 | f.extractall(os.path.join(self.root, self.base_folder))
372 |
373 | def __getitem__(self, index):
374 | if self.dataidxs is None:
375 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
376 |
377 | target = []
378 | for t in self.target_type:
379 | if t == "attr":
380 | target.append(self.attr[index, self.gender_index])
381 | elif t == "identity":
382 | target.append(self.identity[index, 0])
383 | elif t == "bbox":
384 | target.append(self.bbox[index, :])
385 | elif t == "landmarks":
386 | target.append(self.landmarks_align[index, :])
387 | else:
388 | # TODO: refactor with utils.verify_str_arg
389 | raise ValueError("Target type \"{}\" is not recognized.".format(t))
390 | else:
391 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[self.dataidxs[index]]))
392 |
393 | target = []
394 | for t in self.target_type:
395 | if t == "attr":
396 | target.append(self.attr[self.dataidxs[index], self.gender_index])
397 | elif t == "identity":
398 | target.append(self.identity[self.dataidxs[index], 0])
399 | elif t == "bbox":
400 | target.append(self.bbox[self.dataidxs[index], :])
401 | elif t == "landmarks":
402 | target.append(self.landmarks_align[self.dataidxs[index], :])
403 | else:
404 | # TODO: refactor with utils.verify_str_arg
405 | raise ValueError("Target type \"{}\" is not recognized.".format(t))
406 |
407 | if self.transform is not None:
408 | X = self.transform(X)
409 | #print("target[0]:", target[0])
410 | if target:
411 | target = tuple(target) if len(target) > 1 else target[0]
412 |
413 | if self.target_transform is not None:
414 | target = self.target_transform(target)
415 | else:
416 | target = None
417 | #print("celeba target:", target)
418 | return X, target
419 |
420 | def __len__(self):
421 | if self.dataidxs is None:
422 | return len(self.attr)
423 | else:
424 | return len(self.dataidxs)
425 |
426 | def extra_repr(self):
427 | lines = ["Target type: {target_type}", "Split: {split}"]
428 | return '\n'.join(lines).format(**self.__dict__)
429 |
430 |
431 |
432 | class CIFAR10_truncated(data.Dataset):
433 |
434 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
435 |
436 | self.root = root
437 | self.dataidxs = dataidxs
438 | self.train = train
439 | self.transform = transform
440 | self.target_transform = target_transform
441 | self.download = download
442 |
443 | self.data, self.target = self.__build_truncated_dataset__()
444 |
445 | def __build_truncated_dataset__(self):
446 |
447 | cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download)
448 |
449 | data = cifar_dataobj.data
450 | target = np.array(cifar_dataobj.targets)
451 |
452 | if self.dataidxs is not None:
453 | data = data[self.dataidxs]
454 | target = target[self.dataidxs]
455 |
456 | return data, target
457 |
458 | def truncate_channel(self, index):
459 | for i in range(index.shape[0]):
460 | gs_index = index[i]
461 | self.data[gs_index, :, :, 1] = 0.0
462 | self.data[gs_index, :, :, 2] = 0.0
463 |
464 | def __getitem__(self, index):
465 | """
466 | Args:
467 | index (int): Index
468 |
469 | Returns:
470 | tuple: (image, target) where target is index of the target class.
471 | """
472 | img, target = self.data[index], self.target[index]
473 |
474 | # print("cifar10 img:", img)
475 | # print("cifar10 target:", target)
476 |
477 | if self.transform is not None:
478 | img = self.transform(img)
479 |
480 | if self.target_transform is not None:
481 | target = self.target_transform(target)
482 |
483 | return img, target
484 |
485 | def __len__(self):
486 | return len(self.data)
487 |
488 | def gen_bar_updater() -> Callable[[int, int, int], None]:
489 | pbar = tqdm(total=None)
490 |
491 | def bar_update(count, block_size, total_size):
492 | if pbar.total is None and total_size:
493 | pbar.total = total_size
494 | progress_bytes = count * block_size
495 | pbar.update(progress_bytes - pbar.n)
496 |
497 | return bar_update
498 |
499 |
500 | def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
501 | """Download a file from a url and place it in root.
502 | Args:
503 | url (str): URL to download file from
504 | root (str): Directory to place downloaded file in
505 | filename (str, optional): Name to save the file under. If None, use the basename of the URL
506 | md5 (str, optional): MD5 checksum of the download. If None, do not check
507 | """
508 | import urllib
509 |
510 | root = os.path.expanduser(root)
511 | if not filename:
512 | filename = os.path.basename(url)
513 | fpath = os.path.join(root, filename)
514 |
515 | os.makedirs(root, exist_ok=True)
516 |
517 | # check if file is already present locally
518 | if check_integrity(fpath, md5):
519 | print('Using downloaded and verified file: ' + fpath)
520 | else: # download the file
521 | try:
522 | print('Downloading ' + url + ' to ' + fpath)
523 | urllib.request.urlretrieve(
524 | url, fpath,
525 | reporthook=gen_bar_updater()
526 | )
527 | except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
528 | if url[:5] == 'https':
529 | url = url.replace('https:', 'http:')
530 | print('Failed download. Trying https -> http instead.'
531 | ' Downloading ' + url + ' to ' + fpath)
532 | urllib.request.urlretrieve(
533 | url, fpath,
534 | reporthook=gen_bar_updater()
535 | )
536 | else:
537 | raise e
538 | # check integrity of downloaded file
539 | if not check_integrity(fpath, md5):
540 | raise RuntimeError("File not found or corrupted.")
541 |
542 | def _is_tarxz(filename: str) -> bool:
543 | return filename.endswith(".tar.xz")
544 |
545 |
546 | def _is_tar(filename: str) -> bool:
547 | return filename.endswith(".tar")
548 |
549 |
550 | def _is_targz(filename: str) -> bool:
551 | return filename.endswith(".tar.gz")
552 |
553 |
554 | def _is_tgz(filename: str) -> bool:
555 | return filename.endswith(".tgz")
556 |
557 |
558 | def _is_gzip(filename: str) -> bool:
559 | return filename.endswith(".gz") and not filename.endswith(".tar.gz")
560 |
561 |
562 | def _is_zip(filename: str) -> bool:
563 | return filename.endswith(".zip")
564 |
565 |
566 | def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
567 | if to_path is None:
568 | to_path = os.path.dirname(from_path)
569 |
570 | if _is_tar(from_path):
571 | with tarfile.open(from_path, 'r') as tar:
572 | def is_within_directory(directory, target):
573 |
574 | abs_directory = os.path.abspath(directory)
575 | abs_target = os.path.abspath(target)
576 |
577 | prefix = os.path.commonprefix([abs_directory, abs_target])
578 |
579 | return prefix == abs_directory
580 |
581 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
582 |
583 | for member in tar.getmembers():
584 | member_path = os.path.join(path, member.name)
585 | if not is_within_directory(path, member_path):
586 | raise Exception("Attempted Path Traversal in Tar File")
587 |
588 | tar.extractall(path, members, numeric_owner=numeric_owner)
589 |
590 |
591 | safe_extract(tar, path=to_path)
592 | elif _is_targz(from_path) or _is_tgz(from_path):
593 | with tarfile.open(from_path, 'r:gz') as tar:
594 | def is_within_directory(directory, target):
595 |
596 | abs_directory = os.path.abspath(directory)
597 | abs_target = os.path.abspath(target)
598 |
599 | prefix = os.path.commonprefix([abs_directory, abs_target])
600 |
601 | return prefix == abs_directory
602 |
603 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
604 |
605 | for member in tar.getmembers():
606 | member_path = os.path.join(path, member.name)
607 | if not is_within_directory(path, member_path):
608 | raise Exception("Attempted Path Traversal in Tar File")
609 |
610 | tar.extractall(path, members, numeric_owner=numeric_owner)
611 |
612 |
613 | safe_extract(tar, path=to_path)
614 | elif _is_tarxz(from_path):
615 | with tarfile.open(from_path, 'r:xz') as tar:
616 | def is_within_directory(directory, target):
617 |
618 | abs_directory = os.path.abspath(directory)
619 | abs_target = os.path.abspath(target)
620 |
621 | prefix = os.path.commonprefix([abs_directory, abs_target])
622 |
623 | return prefix == abs_directory
624 |
625 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
626 |
627 | for member in tar.getmembers():
628 | member_path = os.path.join(path, member.name)
629 | if not is_within_directory(path, member_path):
630 | raise Exception("Attempted Path Traversal in Tar File")
631 |
632 | tar.extractall(path, members, numeric_owner=numeric_owner)
633 |
634 |
635 | safe_extract(tar, path=to_path)
636 | elif _is_gzip(from_path):
637 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
638 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
639 | out_f.write(zip_f.read())
640 | elif _is_zip(from_path):
641 | with zipfile.ZipFile(from_path, 'r') as z:
642 | z.extractall(to_path)
643 | else:
644 | raise ValueError("Extraction of {} not supported".format(from_path))
645 |
646 | if remove_finished:
647 | os.remove(from_path)
648 |
649 |
650 | def download_and_extract_archive(
651 | url: str,
652 | download_root: str,
653 | extract_root: Optional[str] = None,
654 | filename: Optional[str] = None,
655 | md5: Optional[str] = None,
656 | remove_finished: bool = False,
657 | ) -> None:
658 | download_root = os.path.expanduser(download_root)
659 | if extract_root is None:
660 | extract_root = download_root
661 | if not filename:
662 | filename = os.path.basename(url)
663 |
664 | download_url(url, download_root, filename, md5)
665 |
666 | archive = os.path.join(download_root, filename)
667 | print("Extracting {} to {}".format(archive, extract_root))
668 | extract_archive(archive, extract_root, remove_finished)
669 |
670 | class FEMNIST(MNIST):
671 | """
672 | This dataset is derived from the Leaf repository
673 | (https://github.com/TalwalkarLab/leaf) pre-processing of the Extended MNIST
674 | dataset, grouping examples by writer. Details about Leaf were published in
675 | "LEAF: A Benchmark for Federated Settings" https://arxiv.org/abs/1812.01097.
676 | """
677 | resources = [
678 | ('https://raw.githubusercontent.com/tao-shen/FEMNIST_pytorch/master/femnist.tar.gz',
679 | '59c65cec646fc57fe92d27d83afdf0ed')]
680 |
681 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None,
682 | download=False):
683 | super(MNIST, self).__init__(root, transform=transform,
684 | target_transform=target_transform)
685 | self.train = train
686 | self.dataidxs = dataidxs
687 |
688 | if download:
689 | self.download()
690 |
691 | if not self._check_exists():
692 | raise RuntimeError('Dataset not found.' +
693 | ' You can use download=True to download it')
694 | if self.train:
695 | data_file = self.training_file
696 | else:
697 | data_file = self.test_file
698 |
699 | self.data, self.targets, self.users_index = torch.load(os.path.join(self.processed_folder, data_file))
700 |
701 | if self.dataidxs is not None:
702 | self.data = self.data[self.dataidxs]
703 | self.targets = self.targets[self.dataidxs]
704 |
705 |
706 | def __getitem__(self, index):
707 | img, target = self.data[index], int(self.targets[index])
708 | img = Image.fromarray(img.numpy(), mode='F')
709 | if self.transform is not None:
710 | img = self.transform(img)
711 | if self.target_transform is not None:
712 | target = self.target_transform(target)
713 | return img, target
714 |
715 | def download(self):
716 | """Download the FEMNIST data if it doesn't exist in processed_folder already."""
717 | import shutil
718 |
719 | if self._check_exists():
720 | return
721 |
722 | mkdirs(self.raw_folder)
723 | mkdirs(self.processed_folder)
724 |
725 | # download files
726 | for url, md5 in self.resources:
727 | filename = url.rpartition('/')[2]
728 | download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
729 |
730 | # process and save as torch files
731 | print('Processing...')
732 | shutil.move(os.path.join(self.raw_folder, self.training_file), self.processed_folder)
733 | shutil.move(os.path.join(self.raw_folder, self.test_file), self.processed_folder)
734 |
735 | def __len__(self):
736 | return len(self.data)
737 |
738 | def _check_exists(self) -> bool:
739 | return all(
740 | check_integrity(os.path.join(self.raw_folder, os.path.basename(url)))
741 | for url, _ in self.resources
742 | )
743 |
744 |
745 | class Generated(MNIST):
746 |
747 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None,
748 | download=False):
749 | super(MNIST, self).__init__(root, transform=transform,
750 | target_transform=target_transform)
751 | self.train = train
752 | self.dataidxs = dataidxs
753 |
754 | if self.train:
755 | self.data = np.load("data/generated/X_train.npy")
756 | self.targets = np.load("data/generated/y_train.npy")
757 | else:
758 | self.data = np.load("data/generated/X_test.npy")
759 | self.targets = np.load("data/generated/y_test.npy")
760 |
761 | if self.dataidxs is not None:
762 | self.data = self.data[self.dataidxs]
763 | self.targets = self.targets[self.dataidxs]
764 |
765 |
766 | def __getitem__(self, index):
767 | data, target = self.data[index], self.targets[index]
768 | return data, target
769 |
770 | def __len__(self):
771 | return len(self.data)
772 |
773 |
774 |
775 | class genData(MNIST):
776 | def __init__(self, data, targets):
777 | self.data = data
778 | self.targets = targets
779 | def __getitem__(self,index):
780 | data, target = self.data[index], self.targets[index]
781 | return data, target
782 | def __len__(self):
783 | return len(self.data)
784 |
785 | class CIFAR100_truncated(data.Dataset):
786 |
787 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
788 |
789 | self.root = root
790 | self.dataidxs = dataidxs
791 | self.train = train
792 | self.transform = transform
793 | self.target_transform = target_transform
794 | self.download = download
795 |
796 | self.data, self.target = self.__build_truncated_dataset__()
797 |
798 | def __build_truncated_dataset__(self):
799 |
800 | cifar_dataobj = CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download)
801 |
802 | if torchvision.__version__ == '0.2.1':
803 | if self.train:
804 | data, target = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels)
805 | else:
806 | data, target = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels)
807 | else:
808 | data = cifar_dataobj.data
809 | target = np.array(cifar_dataobj.targets)
810 |
811 | if self.dataidxs is not None:
812 | data = data[self.dataidxs]
813 | target = target[self.dataidxs]
814 |
815 | return data, target
816 |
817 | def __getitem__(self, index):
818 | """
819 | Args:
820 | index (int): Index
821 | Returns:
822 | tuple: (image, target) where target is index of the target class.
823 | """
824 | img, target = self.data[index], self.target[index]
825 | img = Image.fromarray(img)
826 | # print("cifar10 img:", img)
827 | # print("cifar10 target:", target)
828 |
829 | if self.transform is not None:
830 | img = self.transform(img)
831 |
832 | if self.target_transform is not None:
833 | target = self.target_transform(target)
834 |
835 | return img, target
836 |
837 | def __len__(self):
838 | return len(self.data)
839 |
840 |
841 |
842 |
843 | class ImageFolder_custom(DatasetFolder):
844 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=None):
845 | self.root = root
846 | self.dataidxs = dataidxs
847 | self.train = train
848 | self.transform = transform
849 | self.target_transform = target_transform
850 |
851 | imagefolder_obj = ImageFolder(self.root, self.transform, self.target_transform)
852 | self.loader = imagefolder_obj.loader
853 | if self.dataidxs is not None:
854 | self.samples = np.array(imagefolder_obj.samples)[self.dataidxs]
855 | else:
856 | self.samples = np.array(imagefolder_obj.samples)
857 |
858 | def __getitem__(self, index):
859 | path = self.samples[index][0]
860 | target = self.samples[index][1]
861 | target = int(target)
862 | sample = self.loader(path)
863 | if self.transform is not None:
864 | sample = self.transform(sample)
865 | if self.target_transform is not None:
866 | target = self.target_transform(target)
867 |
868 | return sample, target
869 |
870 | def __len__(self):
871 | if self.dataidxs is None:
872 | return len(self.samples)
873 | else:
874 | return len(self.dataidxs)
875 |
--------------------------------------------------------------------------------
/experiments.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import torch
4 | import torch.optim as optim
5 | import torch.nn as nn
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | from torch.autograd import Variable
9 | import torch.utils.data as data
10 | import argparse
11 | import logging
12 | import os
13 | import copy
14 | from math import *
15 | import random
16 |
17 | import datetime
18 | #from torch.utils.tensorboard import SummaryWriter
19 |
20 | from model import *
21 | from utils import *
22 | from vggmodel import *
23 | from resnetcifar import *
24 |
25 | def get_args():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--model', type=str, default='mlp', help='neural network used in training')
28 | parser.add_argument('--dataset', type=str, default='mnist', help='dataset used for training')
29 | parser.add_argument('--net_config', type=lambda x: list(map(int, x.split(', '))))
30 | parser.add_argument('--partition', type=str, default='homo', help='the data partitioning strategy')
31 | parser.add_argument('--batch-size', type=int, default=64, help='input batch size for training (default: 64)')
32 | parser.add_argument('--lr', type=float, default=0.01, help='learning rate (default: 0.01)')
33 | parser.add_argument('--epochs', type=int, default=5, help='number of local epochs')
34 | parser.add_argument('--n_parties', type=int, default=2, help='number of workers in a distributed cluster')
35 | parser.add_argument('--alg', type=str, default='fedavg',
36 | help='fl algorithms: fedavg/fedprox/scaffold/fednova/moon')
37 | parser.add_argument('--use_projection_head', type=bool, default=False, help='whether add an additional header to model or not (see MOON)')
38 | parser.add_argument('--out_dim', type=int, default=256, help='the output dimension for the projection layer')
39 | parser.add_argument('--loss', type=str, default='contrastive', help='for moon')
40 | parser.add_argument('--temperature', type=float, default=0.5, help='the temperature parameter for contrastive loss')
41 | parser.add_argument('--comm_round', type=int, default=50, help='number of maximum communication rounds')
42 | parser.add_argument('--is_same_initial', type=int, default=1, help='Whether to initialize all the models with the same parameters in fedavg')
43 | parser.add_argument('--init_seed', type=int, default=0, help="Random seed")
44 | parser.add_argument('--dropout_p', type=float, required=False, default=0.0, help="Dropout probability. Default=0.0")
45 | parser.add_argument('--datadir', type=str, required=False, default="./data/", help="Data directory")
46 | parser.add_argument('--reg', type=float, default=1e-5, help="L2 regularization strength")
47 | parser.add_argument('--logdir', type=str, required=False, default="./logs/", help='Log directory path')
48 | parser.add_argument('--modeldir', type=str, required=False, default="./models/", help='Model directory path')
49 | parser.add_argument('--beta', type=float, default=0.5, help='The parameter for the dirichlet distribution for data partitioning')
50 | parser.add_argument('--device', type=str, default='cuda:0', help='The device to run the program')
51 | parser.add_argument('--log_file_name', type=str, default=None, help='The log file name')
52 | parser.add_argument('--optimizer', type=str, default='sgd', help='the optimizer')
53 | parser.add_argument('--mu', type=float, default=0.001, help='the mu parameter for fedprox')
54 | parser.add_argument('--noise', type=float, default=0, help='how much noise we add to some party')
55 | parser.add_argument('--noise_type', type=str, default='level', help='Different level of noise or different space of noise')
56 | parser.add_argument('--rho', type=float, default=0, help='Parameter controlling the momentum SGD')
57 | parser.add_argument('--sample', type=float, default=1, help='Sample ratio for each communication round')
58 | args = parser.parse_args()
59 | return args
60 |
61 | def init_nets(net_configs, dropout_p, n_parties, args):
62 |
63 | nets = {net_i: None for net_i in range(n_parties)}
64 |
65 | if args.dataset in {'mnist', 'cifar10', 'svhn', 'fmnist'}:
66 | n_classes = 10
67 | elif args.dataset == 'celeba':
68 | n_classes = 2
69 | elif args.dataset == 'cifar100':
70 | n_classes = 100
71 | elif args.dataset == 'tinyimagenet':
72 | n_classes = 200
73 | elif args.dataset == 'femnist':
74 | n_classes = 62
75 | elif args.dataset == 'emnist':
76 | n_classes = 47
77 | elif args.dataset in {'a9a', 'covtype', 'rcv1', 'SUSY'}:
78 | n_classes = 2
79 | if args.use_projection_head:
80 | add = ""
81 | if "mnist" in args.dataset and args.model == "simple-cnn":
82 | add = "-mnist"
83 | for net_i in range(n_parties):
84 | net = ModelFedCon(args.model+add, args.out_dim, n_classes, net_configs)
85 | nets[net_i] = net
86 | else:
87 | if args.alg == 'moon':
88 | add = ""
89 | if "mnist" in args.dataset and args.model == "simple-cnn":
90 | add = "-mnist"
91 | for net_i in range(n_parties):
92 | net = ModelFedCon_noheader(args.model+add, args.out_dim, n_classes, net_configs)
93 | nets[net_i] = net
94 | else:
95 | for net_i in range(n_parties):
96 | if args.dataset == "generated":
97 | net = PerceptronModel()
98 | elif args.model == "mlp":
99 | if args.dataset == 'covtype':
100 | input_size = 54
101 | output_size = 2
102 | hidden_sizes = [32,16,8]
103 | elif args.dataset == 'a9a':
104 | input_size = 123
105 | output_size = 2
106 | hidden_sizes = [32,16,8]
107 | elif args.dataset == 'rcv1':
108 | input_size = 47236
109 | output_size = 2
110 | hidden_sizes = [32,16,8]
111 | elif args.dataset == 'SUSY':
112 | input_size = 18
113 | output_size = 2
114 | hidden_sizes = [16,8]
115 | net = FcNet(input_size, hidden_sizes, output_size, dropout_p)
116 | elif args.model == "vgg":
117 | net = vgg11()
118 | elif args.model == "simple-cnn":
119 | if args.dataset in ("cifar10", "cinic10", "svhn"):
120 | net = SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=10)
121 | elif args.dataset in ("mnist", 'femnist', 'fmnist'):
122 | net = SimpleCNNMNIST(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=10)
123 | elif args.dataset == 'celeba':
124 | net = SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=2)
125 | elif args.model == "vgg-9":
126 | if args.dataset in ("mnist", 'femnist'):
127 | net = ModerateCNNMNIST()
128 | elif args.dataset in ("cifar10", "cinic10", "svhn"):
129 | # print("in moderate cnn")
130 | net = ModerateCNN()
131 | elif args.dataset == 'celeba':
132 | net = ModerateCNN(output_dim=2)
133 | elif args.model == "resnet":
134 | net = ResNet50_cifar10()
135 | elif args.model == "vgg16":
136 | net = vgg16()
137 | else:
138 | print("not supported yet")
139 | exit(1)
140 | nets[net_i] = net
141 |
142 | model_meta_data = []
143 | layer_type = []
144 | for (k, v) in nets[0].state_dict().items():
145 | model_meta_data.append(v.shape)
146 | layer_type.append(k)
147 | return nets, model_meta_data, layer_type
148 |
149 |
150 | def train_net(net_id, net, train_dataloader, test_dataloader, epochs, lr, args_optimizer, device="cpu"):
151 | logger.info('Training network %s' % str(net_id))
152 |
153 | train_acc = compute_accuracy(net, train_dataloader, device=device)
154 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
155 |
156 | logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
157 | logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))
158 |
159 | if args_optimizer == 'adam':
160 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg)
161 | elif args_optimizer == 'amsgrad':
162 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg,
163 | amsgrad=True)
164 | elif args_optimizer == 'sgd':
165 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=args.rho, weight_decay=args.reg)
166 | criterion = nn.CrossEntropyLoss().to(device)
167 |
168 | cnt = 0
169 | if type(train_dataloader) == type([1]):
170 | pass
171 | else:
172 | train_dataloader = [train_dataloader]
173 |
174 | #writer = SummaryWriter()
175 |
176 | for epoch in range(epochs):
177 | epoch_loss_collector = []
178 | for tmp in train_dataloader:
179 | for batch_idx, (x, target) in enumerate(tmp):
180 | x, target = x.to(device), target.to(device)
181 |
182 | optimizer.zero_grad()
183 | x.requires_grad = True
184 | target.requires_grad = False
185 | target = target.long()
186 |
187 | out = net(x)
188 | loss = criterion(out, target)
189 |
190 | loss.backward()
191 | optimizer.step()
192 |
193 | cnt += 1
194 | epoch_loss_collector.append(loss.item())
195 |
196 | epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
197 | logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))
198 |
199 | #train_acc = compute_accuracy(net, train_dataloader, device=device)
200 | #test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
201 |
202 | #writer.add_scalar('Accuracy/train', train_acc, epoch)
203 | #writer.add_scalar('Accuracy/test', test_acc, epoch)
204 |
205 | # if epoch % 10 == 0:
206 | # logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))
207 | # train_acc = compute_accuracy(net, train_dataloader, device=device)
208 | # test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
209 | #
210 | # logger.info('>> Training accuracy: %f' % train_acc)
211 | # logger.info('>> Test accuracy: %f' % test_acc)
212 |
213 | train_acc = compute_accuracy(net, train_dataloader, device=device)
214 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
215 |
216 | logger.info('>> Training accuracy: %f' % train_acc)
217 | logger.info('>> Test accuracy: %f' % test_acc)
218 |
219 | net.to('cpu')
220 | logger.info(' ** Training complete **')
221 | return train_acc, test_acc
222 |
223 |
224 |
225 | def train_net_fedprox(net_id, net, global_net, train_dataloader, test_dataloader, epochs, lr, args_optimizer, mu, device="cpu"):
226 | logger.info('Training network %s' % str(net_id))
227 | logger.info('n_training: %d' % len(train_dataloader))
228 | logger.info('n_test: %d' % len(test_dataloader))
229 |
230 | train_acc = compute_accuracy(net, train_dataloader, device=device)
231 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
232 |
233 | logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
234 | logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))
235 |
236 |
237 | if args_optimizer == 'adam':
238 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg)
239 | elif args_optimizer == 'amsgrad':
240 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg,
241 | amsgrad=True)
242 | elif args_optimizer == 'sgd':
243 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=args.rho, weight_decay=args.reg)
244 |
245 | criterion = nn.CrossEntropyLoss().to(device)
246 |
247 | cnt = 0
248 | # mu = 0.001
249 | global_weight_collector = list(global_net.to(device).parameters())
250 |
251 | for epoch in range(epochs):
252 | epoch_loss_collector = []
253 | for batch_idx, (x, target) in enumerate(train_dataloader):
254 | x, target = x.to(device), target.to(device)
255 |
256 | optimizer.zero_grad()
257 | x.requires_grad = True
258 | target.requires_grad = False
259 | target = target.long()
260 |
261 | out = net(x)
262 | loss = criterion(out, target)
263 |
264 |             # FedProx proximal term: (mu / 2) * ||w - w_global||^2 keeps the local model close to the global model received this round
265 | fed_prox_reg = 0.0
266 | for param_index, param in enumerate(net.parameters()):
267 | fed_prox_reg += ((mu / 2) * torch.norm((param - global_weight_collector[param_index]))**2)
268 | loss += fed_prox_reg
269 |
270 |
271 | loss.backward()
272 | optimizer.step()
273 |
274 | cnt += 1
275 | epoch_loss_collector.append(loss.item())
276 |
277 | epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
278 | logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))
279 |
280 | # if epoch % 10 == 0:
281 | # train_acc = compute_accuracy(net, train_dataloader, device=device)
282 | # test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
283 | #
284 | # logger.info('>> Training accuracy: %f' % train_acc)
285 | # logger.info('>> Test accuracy: %f' % test_acc)
286 |
287 | train_acc = compute_accuracy(net, train_dataloader, device=device)
288 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
289 |
290 | logger.info('>> Training accuracy: %f' % train_acc)
291 | logger.info('>> Test accuracy: %f' % test_acc)
292 |
293 | net.to('cpu')
294 | logger.info(' ** Training complete **')
295 | return train_acc, test_acc
296 |
297 | def train_net_scaffold(net_id, net, global_model, c_local, c_global, train_dataloader, test_dataloader, epochs, lr, args_optimizer, device="cpu"):
298 | logger.info('Training network %s' % str(net_id))
299 |
300 | train_acc = compute_accuracy(net, train_dataloader, device=device)
301 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
302 |
303 | logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
304 | logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))
305 |
306 | if args_optimizer == 'adam':
307 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg)
308 | elif args_optimizer == 'amsgrad':
309 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg,
310 | amsgrad=True)
311 | elif args_optimizer == 'sgd':
312 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=args.rho, weight_decay=args.reg)
313 | criterion = nn.CrossEntropyLoss().to(device)
314 |
315 | cnt = 0
316 |     if not isinstance(train_dataloader, list):
317 |         train_dataloader = [train_dataloader]
320 |
321 | #writer = SummaryWriter()
322 |
323 | c_local.to(device)
324 | c_global.to(device)
325 | global_model.to(device)
326 |
327 | c_global_para = c_global.state_dict()
328 | c_local_para = c_local.state_dict()
329 |
330 | for epoch in range(epochs):
331 | epoch_loss_collector = []
332 | for tmp in train_dataloader:
333 | for batch_idx, (x, target) in enumerate(tmp):
334 | x, target = x.to(device), target.to(device)
335 |
336 | optimizer.zero_grad()
337 | x.requires_grad = True
338 | target.requires_grad = False
339 | target = target.long()
340 |
341 | out = net(x)
342 | loss = criterion(out, target)
343 |
344 | loss.backward()
345 | optimizer.step()
346 |
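                # SCAFFOLD drift correction applied after every optimizer step:
                # w_local <- w_local - lr * (c_global - c_local)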
347 | net_para = net.state_dict()
348 | for key in net_para:
349 |                     net_para[key] = net_para[key] - lr * (c_global_para[key] - c_local_para[key])
350 | net.load_state_dict(net_para)
351 |
352 | cnt += 1
353 | epoch_loss_collector.append(loss.item())
354 |
355 |
356 | epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
357 | logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))
358 |
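    # SCAFFOLD control-variate update (option II of the paper):
    #   c_local^+ = c_local - c_global + (w_global - w_local) / (cnt * lr)
    # c_delta = c_local^+ - c_local is returned so the server can update c_global.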
359 | c_new_para = c_local.state_dict()
360 | c_delta_para = copy.deepcopy(c_local.state_dict())
361 | global_model_para = global_model.state_dict()
362 | net_para = net.state_dict()
363 | for key in net_para:
364 |             c_new_para[key] = c_new_para[key] - c_global_para[key] + (global_model_para[key] - net_para[key]) / (cnt * lr)
365 | c_delta_para[key] = c_new_para[key] - c_local_para[key]
366 | c_local.load_state_dict(c_new_para)
367 |
368 |
369 | train_acc = compute_accuracy(net, train_dataloader, device=device)
370 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
371 |
372 | logger.info('>> Training accuracy: %f' % train_acc)
373 | logger.info('>> Test accuracy: %f' % test_acc)
374 |
375 | net.to('cpu')
376 | logger.info(' ** Training complete **')
377 | return train_acc, test_acc, c_delta_para
378 |
379 | def train_net_fednova(net_id, net, global_model, train_dataloader, test_dataloader, epochs, lr, args_optimizer, device="cpu"):
380 | logger.info('Training network %s' % str(net_id))
381 |
382 | train_acc = compute_accuracy(net, train_dataloader, device=device)
383 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
384 |
385 | logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
386 | logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))
387 |
388 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=args.rho, weight_decay=args.reg)
389 | criterion = nn.CrossEntropyLoss().to(device)
390 |
391 |     if not isinstance(train_dataloader, list):
392 |         train_dataloader = [train_dataloader]
395 |
396 | #writer = SummaryWriter()
397 |
398 |
399 | tau = 0
400 |
401 | for epoch in range(epochs):
402 | epoch_loss_collector = []
403 | for tmp in train_dataloader:
404 | for batch_idx, (x, target) in enumerate(tmp):
405 | x, target = x.to(device), target.to(device)
406 |
407 | optimizer.zero_grad()
408 | x.requires_grad = True
409 | target.requires_grad = False
410 | target = target.long()
411 |
412 | out = net(x)
413 | loss = criterion(out, target)
414 |
415 | loss.backward()
416 | optimizer.step()
417 |
418 | tau = tau + 1
419 |
420 | epoch_loss_collector.append(loss.item())
421 |
422 |
423 | epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
424 | logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))
425 |
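    # FedNova: tau counts the local SGD steps. With momentum rho, the effective number of
    # steps is a_i = [tau - rho * (1 - rho^tau) / (1 - rho)] / (1 - rho), and the locally
    # accumulated change (w_global - w_local) is normalized by a_i before aggregation.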
426 |     global_model.to(device)
427 |     a_i = (tau - args.rho * (1 - pow(args.rho, tau)) / (1 - args.rho)) / (1 - args.rho)
429 | global_model_para = global_model.state_dict()
430 | net_para = net.state_dict()
431 | norm_grad = copy.deepcopy(global_model.state_dict())
432 | for key in norm_grad:
433 | #norm_grad[key] = (global_model_para[key] - net_para[key]) / a_i
434 | norm_grad[key] = torch.true_divide(global_model_para[key]-net_para[key], a_i)
435 | train_acc = compute_accuracy(net, train_dataloader, device=device)
436 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)
437 |
438 | logger.info('>> Training accuracy: %f' % train_acc)
439 | logger.info('>> Test accuracy: %f' % test_acc)
440 |
441 | net.to('cpu')
442 | logger.info(' ** Training complete **')
443 | return train_acc, test_acc, a_i, norm_grad
444 |
445 |
446 | def train_net_moon(net_id, net, global_net, previous_nets, train_dataloader, test_dataloader, epochs, lr, args_optimizer, mu, temperature, args,
447 | round, device="cpu"):
448 |
449 | logger.info('Training network %s' % str(net_id))
450 |
451 | train_acc = compute_accuracy(net, train_dataloader, moon_model=True, device=device)
452 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, moon_model=True, device=device)
453 |
454 | logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
455 | logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))
456 |
457 | # conloss = ContrastiveLoss(temperature)
458 |
459 | if args_optimizer == 'adam':
460 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg)
461 | elif args_optimizer == 'amsgrad':
462 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg,
463 | amsgrad=True)
464 | elif args_optimizer == 'sgd':
465 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=0.9,
466 | weight_decay=args.reg)
467 |
468 | criterion = nn.CrossEntropyLoss().to(device)
469 | # global_net.to(device)
470 |
471 | if args.loss != 'l2norm':
472 | for previous_net in previous_nets:
473 | previous_net.to(device)
474 | global_w = global_net.state_dict()
475 | # oppsi_nets = copy.deepcopy(previous_nets)
476 | # for net_id, oppsi_net in enumerate(oppsi_nets):
477 | # oppsi_w = oppsi_net.state_dict()
478 | # prev_w = previous_nets[net_id].state_dict()
479 | # for key in oppsi_w:
480 | # oppsi_w[key] = 2*global_w[key] - prev_w[key]
481 | # oppsi_nets.load_state_dict(oppsi_w)
482 | cnt = 0
483 | cos=torch.nn.CosineSimilarity(dim=-1).to(device)
484 | # mu = 0.001
485 |
486 | for epoch in range(epochs):
487 | epoch_loss_collector = []
488 | epoch_loss1_collector = []
489 | epoch_loss2_collector = []
490 | for batch_idx, (x, target) in enumerate(train_dataloader):
491 | x, target = x.to(device), target.to(device)
492 | if target.shape[0] == 1:
493 | continue
494 |
495 | optimizer.zero_grad()
496 | x.requires_grad = True
497 | target.requires_grad = False
498 | target = target.long()
499 |
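            # MOON model-contrastive loss: pro1/pro2 are the projection-head outputs of the
            # local and global models. Cosine similarity to the global model forms the
            # positive logit, similarities to previous local models the negatives, and the
            # InfoNCE-style cross-entropy below is scaled by the temperature.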
500 | _, pro1, out = net(x)
501 | _, pro2, _ = global_net(x)
502 | if args.loss == 'l2norm':
503 | loss2 = mu * torch.mean(torch.norm(pro2-pro1, dim=1))
504 |
505 | elif args.loss == 'only_contrastive' or args.loss == 'contrastive':
506 | posi = cos(pro1, pro2)
507 | logits = posi.reshape(-1,1)
508 |
509 | for previous_net in previous_nets:
510 | previous_net.to(device)
511 | _, pro3, _ = previous_net(x)
512 | nega = cos(pro1, pro3)
513 | logits = torch.cat((logits, nega.reshape(-1,1)), dim=1)
514 |
515 | # previous_net.to('cpu')
516 |
517 | logits /= temperature
518 | labels = torch.zeros(x.size(0)).to(device).long()
519 |
520 | # loss = criterion(out, target) + mu * ContraLoss(pro1, pro2, pro3)
521 |
522 | loss2 = mu * criterion(logits, labels)
523 |
524 |                 if args.loss == 'only_contrastive':
525 |                     loss1, loss = torch.zeros_like(loss2), loss2  # keep loss1 defined for the loss logging below
526 | else:
527 | loss1 = criterion(out, target)
528 | loss = loss1 + loss2
529 |
530 | loss.backward()
531 | optimizer.step()
532 |
533 | cnt += 1
534 | epoch_loss_collector.append(loss.item())
535 | epoch_loss1_collector.append(loss1.item())
536 | epoch_loss2_collector.append(loss2.item())
537 |
538 | epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
539 | epoch_loss1 = sum(epoch_loss1_collector) / len(epoch_loss1_collector)
540 | epoch_loss2 = sum(epoch_loss2_collector) / len(epoch_loss2_collector)
541 | logger.info('Epoch: %d Loss: %f Loss1: %f Loss2: %f' % (epoch, epoch_loss, epoch_loss1, epoch_loss2))
542 |
543 |
544 | if args.loss != 'l2norm':
545 | for previous_net in previous_nets:
546 | previous_net.to('cpu')
547 | train_acc = compute_accuracy(net, train_dataloader, moon_model=True, device=device)
548 | test_acc, conf_matrix = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, moon_model=True, device=device)
549 |
550 | logger.info('>> Training accuracy: %f' % train_acc)
551 | logger.info('>> Test accuracy: %f' % test_acc)
552 | net.to('cpu')
553 | logger.info(' ** Training complete **')
554 | return train_acc, test_acc
555 |
556 |
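# Debug helper: saves the first batch of images to img.npy, prints its shape, and exits.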
557 | def view_image(train_dataloader):
558 | for (x, target) in train_dataloader:
559 | np.save("img.npy", x)
560 | print(x.shape)
561 | exit(0)
562 |
563 |
564 | def local_train_net(nets, selected, args, net_dataidx_map, test_dl = None, device="cpu"):
565 | avg_acc = 0.0
566 |
567 | for net_id, net in nets.items():
568 | if net_id not in selected:
569 | continue
570 | dataidxs = net_dataidx_map[net_id]
571 |
572 | logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
573 | # move the model to cuda device:
574 | net.to(device)
575 |
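        # Feature distribution skew: in 'space' mode each party gets its own spatial noise
        # pattern (the last party stays clean); otherwise party i receives Gaussian noise of
        # level args.noise * i / (n_parties - 1).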
576 | noise_level = args.noise
577 | if net_id == args.n_parties - 1:
578 | noise_level = 0
579 |
580 | if args.noise_type == 'space':
581 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
582 | else:
583 | noise_level = args.noise / (args.n_parties - 1) * net_id
584 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
585 | train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
586 | n_epoch = args.epochs
587 |
588 |
589 | trainacc, testacc = train_net(net_id, net, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, device=device)
590 | logger.info("net %d final test acc %f" % (net_id, testacc))
591 | avg_acc += testacc
592 | # saving the trained models here
593 | # save_model(net, net_id, args)
594 | # else:
595 | # load_model(net, net_id, device=device)
596 | avg_acc /= len(selected)
597 | if args.alg == 'local_training':
598 | logger.info("avg test acc %f" % avg_acc)
599 |
600 | nets_list = list(nets.values())
601 | return nets_list
602 |
603 |
604 | def local_train_net_fedprox(nets, selected, global_model, args, net_dataidx_map, test_dl = None, device="cpu"):
605 | avg_acc = 0.0
606 |
607 | for net_id, net in nets.items():
608 | if net_id not in selected:
609 | continue
610 | dataidxs = net_dataidx_map[net_id]
611 |
612 | logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
613 | # move the model to cuda device:
614 | net.to(device)
615 |
616 | noise_level = args.noise
617 | if net_id == args.n_parties - 1:
618 | noise_level = 0
619 |
620 | if args.noise_type == 'space':
621 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
622 | else:
623 | noise_level = args.noise / (args.n_parties - 1) * net_id
624 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
625 | train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
626 | n_epoch = args.epochs
627 |
628 | trainacc, testacc = train_net_fedprox(net_id, net, global_model, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, args.mu, device=device)
629 | logger.info("net %d final test acc %f" % (net_id, testacc))
630 | avg_acc += testacc
631 | avg_acc /= len(selected)
632 | if args.alg == 'local_training':
633 | logger.info("avg test acc %f" % avg_acc)
634 |
635 | nets_list = list(nets.values())
636 | return nets_list
637 |
638 | def local_train_net_scaffold(nets, selected, global_model, c_nets, c_global, args, net_dataidx_map, test_dl = None, device="cpu"):
639 | avg_acc = 0.0
640 |
641 | total_delta = copy.deepcopy(global_model.state_dict())
642 | for key in total_delta:
643 | total_delta[key] = 0.0
644 | c_global.to(device)
645 | global_model.to(device)
646 | for net_id, net in nets.items():
647 | if net_id not in selected:
648 | continue
649 | dataidxs = net_dataidx_map[net_id]
650 |
651 | logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
652 | # move the model to cuda device:
653 | net.to(device)
654 |
655 | c_nets[net_id].to(device)
656 |
657 | noise_level = args.noise
658 | if net_id == args.n_parties - 1:
659 | noise_level = 0
660 |
661 | if args.noise_type == 'space':
662 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
663 | else:
664 | noise_level = args.noise / (args.n_parties - 1) * net_id
665 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
666 | train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
667 | n_epoch = args.epochs
668 |
669 |
670 | trainacc, testacc, c_delta_para = train_net_scaffold(net_id, net, global_model, c_nets[net_id], c_global, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, device=device)
671 |
672 | c_nets[net_id].to('cpu')
673 | for key in total_delta:
674 | total_delta[key] += c_delta_para[key]
675 |
676 |
677 | logger.info("net %d final test acc %f" % (net_id, testacc))
678 | avg_acc += testacc
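    # Server-side control variate update: c_global <- c_global + (sum_i c_delta_i) / n_parties,
    # i.e. the client deltas are averaged over all parties, not only the sampled ones.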
679 | for key in total_delta:
680 | total_delta[key] /= args.n_parties
681 | c_global_para = c_global.state_dict()
682 | for key in c_global_para:
683 | if c_global_para[key].type() == 'torch.LongTensor':
684 | c_global_para[key] += total_delta[key].type(torch.LongTensor)
685 | elif c_global_para[key].type() == 'torch.cuda.LongTensor':
686 | c_global_para[key] += total_delta[key].type(torch.cuda.LongTensor)
687 | else:
688 | #print(c_global_para[key].type())
689 | c_global_para[key] += total_delta[key]
690 | c_global.load_state_dict(c_global_para)
691 |
692 | avg_acc /= len(selected)
693 | if args.alg == 'local_training':
694 | logger.info("avg test acc %f" % avg_acc)
695 |
696 | nets_list = list(nets.values())
697 | return nets_list
698 |
699 | def local_train_net_fednova(nets, selected, global_model, args, net_dataidx_map, test_dl = None, device="cpu"):
700 | avg_acc = 0.0
701 |
702 | a_list = []
703 | d_list = []
704 | n_list = []
705 | global_model.to(device)
706 | for net_id, net in nets.items():
707 | if net_id not in selected:
708 | continue
709 | dataidxs = net_dataidx_map[net_id]
710 |
711 | logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
712 | # move the model to cuda device:
713 | net.to(device)
714 |
715 | noise_level = args.noise
716 | if net_id == args.n_parties - 1:
717 | noise_level = 0
718 |
719 | if args.noise_type == 'space':
720 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
721 | else:
722 | noise_level = args.noise / (args.n_parties - 1) * net_id
723 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
724 | train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
725 | n_epoch = args.epochs
726 |
727 |
728 | trainacc, testacc, a_i, d_i = train_net_fednova(net_id, net, global_model, train_dl_local, test_dl, n_epoch, args.lr, args.optimizer, device=device)
729 |
730 | a_list.append(a_i)
731 | d_list.append(d_i)
732 | n_i = len(train_dl_local.dataset)
733 | n_list.append(n_i)
734 | logger.info("net %d final test acc %f" % (net_id, testacc))
735 | avg_acc += testacc
736 |
737 |
738 | avg_acc /= len(selected)
739 | if args.alg == 'local_training':
740 | logger.info("avg test acc %f" % avg_acc)
741 |
742 | nets_list = list(nets.values())
743 | return nets_list, a_list, d_list, n_list
744 |
745 | def local_train_net_moon(nets, selected, args, net_dataidx_map, test_dl=None, global_model = None, prev_model_pool = None, round=None, device="cpu"):
746 | avg_acc = 0.0
747 | global_model.to(device)
748 | for net_id, net in nets.items():
749 | if net_id not in selected:
750 | continue
751 | dataidxs = net_dataidx_map[net_id]
752 |
753 | logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
754 | net.to(device)
755 |
756 | noise_level = args.noise
757 | if net_id == args.n_parties - 1:
758 | noise_level = 0
759 |
760 | if args.noise_type == 'space':
761 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
762 | else:
763 | noise_level = args.noise / (args.n_parties - 1) * net_id
764 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
765 | train_dl_global, test_dl_global, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)
766 | n_epoch = args.epochs
767 |
768 | prev_models=[]
769 | for i in range(len(prev_model_pool)):
770 | prev_models.append(prev_model_pool[i][net_id])
771 | trainacc, testacc = train_net_moon(net_id, net, global_model, prev_models, train_dl_local, test_dl, n_epoch, args.lr,
772 | args.optimizer, args.mu, args.temperature, args, round, device=device)
773 | logger.info("net %d final test acc %f" % (net_id, testacc))
774 | avg_acc += testacc
775 |
776 | avg_acc /= len(selected)
777 | if args.alg == 'local_training':
778 | logger.info("avg test acc %f" % avg_acc)
779 | global_model.to('cpu')
780 | nets_list = list(nets.values())
781 | return nets_list
782 |
783 |
784 |
785 | def get_partition_dict(dataset, partition, n_parties, init_seed=0, datadir='./data', logdir='./logs', beta=0.5):
786 | seed = init_seed
787 | np.random.seed(seed)
788 | torch.manual_seed(seed)
789 | random.seed(seed)
790 | X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data(
791 | dataset, datadir, logdir, partition, n_parties, beta=beta)
792 |
793 | return net_dataidx_map
794 |
795 | if __name__ == '__main__':
796 | # torch.set_printoptions(profile="full")
797 | args = get_args()
798 | mkdirs(args.logdir)
799 | mkdirs(args.modeldir)
800 | if args.log_file_name is None:
801 | argument_path='experiment_arguments-%s.json' % datetime.datetime.now().strftime("%Y-%m-%d-%H:%M-%S")
802 | else:
803 | argument_path=args.log_file_name+'.json'
804 | with open(os.path.join(args.logdir, argument_path), 'w') as f:
805 | json.dump(str(args), f)
806 | device = torch.device(args.device)
807 | # logging.basicConfig(filename='test.log', level=logger.info, filemode='w')
808 | # logging.info("test")
809 | for handler in logging.root.handlers[:]:
810 | logging.root.removeHandler(handler)
811 |
812 | if args.log_file_name is None:
813 | args.log_file_name = 'experiment_log-%s' % (datetime.datetime.now().strftime("%Y-%m-%d-%H:%M-%S"))
814 | log_path=args.log_file_name+'.log'
815 | logging.basicConfig(
816 | filename=os.path.join(args.logdir, log_path),
817 | # filename='/home/qinbin/test.log',
818 | format='%(asctime)s %(levelname)-8s %(message)s',
819 | datefmt='%m-%d %H:%M', level=logging.DEBUG, filemode='w')
820 |
821 | logger = logging.getLogger()
822 | logger.setLevel(logging.DEBUG)
823 | logger.info(device)
824 |
825 | seed = args.init_seed
826 | logger.info("#" * 100)
827 | np.random.seed(seed)
828 | torch.manual_seed(seed)
829 | random.seed(seed)
830 | logger.info("Partitioning data")
831 | X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data(
832 | args.dataset, args.datadir, args.logdir, args.partition, args.n_parties, beta=args.beta)
833 |
834 | n_classes = len(np.unique(y_train))
835 |
836 | train_dl_global, test_dl_global, train_ds_global, test_ds_global = get_dataloader(args.dataset,
837 | args.datadir,
838 | args.batch_size,
839 | 32)
840 |
841 |     print("len train_ds_global:", len(train_ds_global))
842 |
843 |
844 | data_size = len(test_ds_global)
845 |
846 | # test_dl = data.DataLoader(dataset=test_ds_global, batch_size=32, shuffle=False)
847 |
848 | train_all_in_list = []
849 | test_all_in_list = []
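    # Feature-noise setup: rebuild the global train/test loaders from the union of the
    # parties' noisy local datasets, so global accuracy is measured on the same
    # distribution the parties actually train on.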
850 | if args.noise > 0:
851 | for party_id in range(args.n_parties):
852 | dataidxs = net_dataidx_map[party_id]
853 |
854 | noise_level = args.noise
855 | if party_id == args.n_parties - 1:
856 | noise_level = 0
857 |
858 | if args.noise_type == 'space':
859 | train_dl_local, test_dl_local, train_ds_local, test_ds_local = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, party_id, args.n_parties-1)
860 | else:
861 | noise_level = args.noise / (args.n_parties - 1) * party_id
862 | train_dl_local, test_dl_local, train_ds_local, test_ds_local = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
863 | train_all_in_list.append(train_ds_local)
864 | test_all_in_list.append(test_ds_local)
865 | train_all_in_ds = data.ConcatDataset(train_all_in_list)
866 | train_dl_global = data.DataLoader(dataset=train_all_in_ds, batch_size=args.batch_size, shuffle=True)
867 | test_all_in_ds = data.ConcatDataset(test_all_in_list)
868 | test_dl_global = data.DataLoader(dataset=test_all_in_ds, batch_size=32, shuffle=False)
869 |
870 |
871 |
872 | if args.alg == 'fedavg':
873 | logger.info("Initializing nets")
874 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
875 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args)
876 | global_model = global_models[0]
877 |
878 | global_para = global_model.state_dict()
879 | if args.is_same_initial:
880 | for net_id, net in nets.items():
881 | net.load_state_dict(global_para)
882 |
883 | for round in range(args.comm_round):
884 | logger.info("in comm round:" + str(round))
885 |
886 | arr = np.arange(args.n_parties)
887 | np.random.shuffle(arr)
888 | selected = arr[:int(args.n_parties * args.sample)]
889 |
890 | global_para = global_model.state_dict()
891 | if round == 0:
892 | if args.is_same_initial:
893 | for idx in selected:
894 | nets[idx].load_state_dict(global_para)
895 | else:
896 | for idx in selected:
897 | nets[idx].load_state_dict(global_para)
898 |
899 | local_train_net(nets, selected, args, net_dataidx_map, test_dl = test_dl_global, device=device)
900 | # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device)
901 |
902 | # update global model
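            # FedAvg aggregation: w_global <- sum_k (n_k / n) * w_k over the selected
            # parties, where n_k is party k's local sample count.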
903 | total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
904 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]
905 |
906 | for idx in range(len(selected)):
907 | net_para = nets[selected[idx]].cpu().state_dict()
908 | if idx == 0:
909 | for key in net_para:
910 | global_para[key] = net_para[key] * fed_avg_freqs[idx]
911 | else:
912 | for key in net_para:
913 | global_para[key] += net_para[key] * fed_avg_freqs[idx]
914 | global_model.load_state_dict(global_para)
915 |
916 | logger.info('global n_training: %d' % len(train_dl_global))
917 | logger.info('global n_test: %d' % len(test_dl_global))
918 |
919 | global_model.to(device)
920 | train_acc = compute_accuracy(global_model, train_dl_global, device=device)
921 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, device=device)
922 |
923 |
924 | logger.info('>> Global Model Train accuracy: %f' % train_acc)
925 | logger.info('>> Global Model Test accuracy: %f' % test_acc)
926 |
927 |
928 | elif args.alg == 'fedprox':
929 | logger.info("Initializing nets")
930 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
931 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args)
932 | global_model = global_models[0]
933 |
934 | global_para = global_model.state_dict()
935 |
936 | if args.is_same_initial:
937 | for net_id, net in nets.items():
938 | net.load_state_dict(global_para)
939 |
940 | for round in range(args.comm_round):
941 | logger.info("in comm round:" + str(round))
942 |
943 | arr = np.arange(args.n_parties)
944 | np.random.shuffle(arr)
945 | selected = arr[:int(args.n_parties * args.sample)]
946 |
947 | global_para = global_model.state_dict()
948 | if round == 0:
949 | if args.is_same_initial:
950 | for idx in selected:
951 | nets[idx].load_state_dict(global_para)
952 | else:
953 | for idx in selected:
954 | nets[idx].load_state_dict(global_para)
955 |
956 | local_train_net_fedprox(nets, selected, global_model, args, net_dataidx_map, test_dl = test_dl_global, device=device)
957 | global_model.to('cpu')
958 |
959 | # update global model
960 | total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
961 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]
962 |
963 | for idx in range(len(selected)):
964 | net_para = nets[selected[idx]].cpu().state_dict()
965 | if idx == 0:
966 | for key in net_para:
967 | global_para[key] = net_para[key] * fed_avg_freqs[idx]
968 | else:
969 | for key in net_para:
970 | global_para[key] += net_para[key] * fed_avg_freqs[idx]
971 | global_model.load_state_dict(global_para)
972 |
973 |
974 | logger.info('global n_training: %d' % len(train_dl_global))
975 | logger.info('global n_test: %d' % len(test_dl_global))
976 |
977 |
978 | global_model.to(device)
979 | train_acc = compute_accuracy(global_model, train_dl_global, device=device)
980 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, device=device)
981 |
982 |
983 | logger.info('>> Global Model Train accuracy: %f' % train_acc)
984 | logger.info('>> Global Model Test accuracy: %f' % test_acc)
985 |
986 | elif args.alg == 'scaffold':
987 | logger.info("Initializing nets")
988 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
989 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args)
990 | global_model = global_models[0]
991 |
992 | c_nets, _, _ = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
993 | c_globals, _, _ = init_nets(args.net_config, 0, 1, args)
994 | c_global = c_globals[0]
995 | c_global_para = c_global.state_dict()
996 | for net_id, net in c_nets.items():
997 | net.load_state_dict(c_global_para)
998 |
999 | global_para = global_model.state_dict()
1000 | if args.is_same_initial:
1001 | for net_id, net in nets.items():
1002 | net.load_state_dict(global_para)
1003 |
1004 |
1005 | for round in range(args.comm_round):
1006 | logger.info("in comm round:" + str(round))
1007 |
1008 | arr = np.arange(args.n_parties)
1009 | np.random.shuffle(arr)
1010 | selected = arr[:int(args.n_parties * args.sample)]
1011 |
1012 | global_para = global_model.state_dict()
1013 | if round == 0:
1014 | if args.is_same_initial:
1015 | for idx in selected:
1016 | nets[idx].load_state_dict(global_para)
1017 | else:
1018 | for idx in selected:
1019 | nets[idx].load_state_dict(global_para)
1020 |
1021 | local_train_net_scaffold(nets, selected, global_model, c_nets, c_global, args, net_dataidx_map, test_dl = test_dl_global, device=device)
1022 | # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device)
1023 |
1024 | # update global model
1025 | total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
1026 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]
1027 |
1028 | for idx in range(len(selected)):
1029 | net_para = nets[selected[idx]].cpu().state_dict()
1030 | if idx == 0:
1031 | for key in net_para:
1032 | global_para[key] = net_para[key] * fed_avg_freqs[idx]
1033 | else:
1034 | for key in net_para:
1035 | global_para[key] += net_para[key] * fed_avg_freqs[idx]
1036 | global_model.load_state_dict(global_para)
1037 |
1038 |
1039 | logger.info('global n_training: %d' % len(train_dl_global))
1040 | logger.info('global n_test: %d' % len(test_dl_global))
1041 |
1042 | global_model.to(device)
1043 | train_acc = compute_accuracy(global_model, train_dl_global, device=device)
1044 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, device=device)
1045 |
1046 | logger.info('>> Global Model Train accuracy: %f' % train_acc)
1047 | logger.info('>> Global Model Test accuracy: %f' % test_acc)
1048 |
1049 | elif args.alg == 'fednova':
1050 | logger.info("Initializing nets")
1051 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
1052 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args)
1053 | global_model = global_models[0]
1054 |
1055 | d_list = [copy.deepcopy(global_model.state_dict()) for i in range(args.n_parties)]
1056 | d_total_round = copy.deepcopy(global_model.state_dict())
1057 | for i in range(args.n_parties):
1058 | for key in d_list[i]:
1059 | d_list[i][key] = 0
1060 | for key in d_total_round:
1061 | d_total_round[key] = 0
1062 |
1063 | data_sum = 0
1064 | for i in range(args.n_parties):
1065 | data_sum += len(traindata_cls_counts[i])
1066 | portion = []
1067 | for i in range(args.n_parties):
1068 | portion.append(len(traindata_cls_counts[i]) / data_sum)
1069 |
1070 | global_para = global_model.state_dict()
1071 | if args.is_same_initial:
1072 | for net_id, net in nets.items():
1073 | net.load_state_dict(global_para)
1074 |
1075 | for round in range(args.comm_round):
1076 | logger.info("in comm round:" + str(round))
1077 |
1078 | arr = np.arange(args.n_parties)
1079 | np.random.shuffle(arr)
1080 | selected = arr[:int(args.n_parties * args.sample)]
1081 |
1082 | global_para = global_model.state_dict()
1083 | if round == 0:
1084 | if args.is_same_initial:
1085 | for idx in selected:
1086 | nets[idx].load_state_dict(global_para)
1087 | else:
1088 | for idx in selected:
1089 | nets[idx].load_state_dict(global_para)
1090 |
1091 | _, a_list, d_list, n_list = local_train_net_fednova(nets, selected, global_model, args, net_dataidx_map, test_dl = test_dl_global, device=device)
1092 | total_n = sum(n_list)
1093 | #print("total_n:", total_n)
1094 | d_total_round = copy.deepcopy(global_model.state_dict())
1095 | for key in d_total_round:
1096 | d_total_round[key] = 0.0
1097 |
1098 | for i in range(len(selected)):
1099 | d_para = d_list[i]
1100 | for key in d_para:
1101 | #if d_total_round[key].type == 'torch.LongTensor':
1102 | # d_total_round[key] += (d_para[key] * n_list[i] / total_n).type(torch.LongTensor)
1103 | #else:
1104 | d_total_round[key] += d_para[key] * n_list[i] / total_n
1105 |
1106 |
1107 | # for i in range(len(selected)):
1108 | # d_total_round = d_total_round + d_list[i] * n_list[i] / total_n
1109 |
1110 | # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device)
1111 |
1112 | # update global model
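            # FedNova server update: coeff = sum_i (n_i / n) * a_i is the weighted effective
            # step count and d_total_round the sample-weighted average of the normalized
            # local updates, so w_global <- w_global - coeff * d_total_round.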
1113 | coeff = 0.0
1114 | for i in range(len(selected)):
1115 | coeff = coeff + a_list[i] * n_list[i]/total_n
1116 |
1117 | updated_model = global_model.state_dict()
1118 | for key in updated_model:
1119 | #print(updated_model[key])
1120 | if updated_model[key].type() == 'torch.LongTensor':
1121 | updated_model[key] -= (coeff * d_total_round[key]).type(torch.LongTensor)
1122 | elif updated_model[key].type() == 'torch.cuda.LongTensor':
1123 | updated_model[key] -= (coeff * d_total_round[key]).type(torch.cuda.LongTensor)
1124 | else:
1125 | #print(updated_model[key].type())
1126 | #print((coeff*d_total_round[key].type()))
1127 | updated_model[key] -= coeff * d_total_round[key]
1128 | global_model.load_state_dict(updated_model)
1129 |
1130 |
1131 | logger.info('global n_training: %d' % len(train_dl_global))
1132 | logger.info('global n_test: %d' % len(test_dl_global))
1133 |
1134 | global_model.to(device)
1135 | train_acc = compute_accuracy(global_model, train_dl_global, device=device)
1136 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, device=device)
1137 |
1138 |
1139 | logger.info('>> Global Model Train accuracy: %f' % train_acc)
1140 | logger.info('>> Global Model Test accuracy: %f' % test_acc)
1141 |
1142 | elif args.alg == 'moon':
1143 | logger.info("Initializing nets")
1144 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
1145 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args)
1146 | global_model = global_models[0]
1147 |
1148 | global_para = global_model.state_dict()
1149 | if args.is_same_initial:
1150 | for net_id, net in nets.items():
1151 | net.load_state_dict(global_para)
1152 |
1153 | old_nets_pool = []
1154 | old_nets = copy.deepcopy(nets)
1155 | for _, net in old_nets.items():
1156 | net.eval()
1157 | for param in net.parameters():
1158 | param.requires_grad = False
1159 |
1160 | for round in range(args.comm_round):
1161 | logger.info("in comm round:" + str(round))
1162 |
1163 | arr = np.arange(args.n_parties)
1164 | np.random.shuffle(arr)
1165 | selected = arr[:int(args.n_parties * args.sample)]
1166 |
1167 | global_para = global_model.state_dict()
1168 | if round == 0:
1169 | if args.is_same_initial:
1170 | for idx in selected:
1171 | nets[idx].load_state_dict(global_para)
1172 | else:
1173 | for idx in selected:
1174 | nets[idx].load_state_dict(global_para)
1175 |
1176 | local_train_net_moon(nets, selected, args, net_dataidx_map, test_dl = test_dl_global, global_model=global_model,
1177 | prev_model_pool=old_nets_pool, round=round, device=device)
1178 | # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device)
1179 |
1180 | # update global model
1181 | total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
1182 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]
1183 |
1184 | for idx in range(len(selected)):
1185 | net_para = nets[selected[idx]].cpu().state_dict()
1186 | if idx == 0:
1187 | for key in net_para:
1188 | global_para[key] = net_para[key] * fed_avg_freqs[idx]
1189 | else:
1190 | for key in net_para:
1191 | global_para[key] += net_para[key] * fed_avg_freqs[idx]
1192 | global_model.load_state_dict(global_para)
1193 |
1194 | logger.info('global n_training: %d' % len(train_dl_global))
1195 | logger.info('global n_test: %d' % len(test_dl_global))
1196 |
1197 | global_model.to(device)
1198 | train_acc = compute_accuracy(global_model, train_dl_global, moon_model=True, device=device)
1199 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, moon_model=True, device=device)
1200 |
1201 |
1202 | logger.info('>> Global Model Train accuracy: %f' % train_acc)
1203 | logger.info('>> Global Model Test accuracy: %f' % test_acc)
1204 |
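            # Freeze a snapshot of this round's local models to serve as contrastive
            # negatives in the next round; the pool keeps only the most recent snapshot.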
1205 | old_nets = copy.deepcopy(nets)
1206 | for _, net in old_nets.items():
1207 | net.eval()
1208 | for param in net.parameters():
1209 | param.requires_grad = False
1210 | if len(old_nets_pool) < 1:
1211 | old_nets_pool.append(old_nets)
1212 | else:
1213 | old_nets_pool[0] = old_nets
1214 |
1215 | elif args.alg == 'local_training':
1216 | logger.info("Initializing nets")
1217 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args)
1218 | arr = np.arange(args.n_parties)
1219 | local_train_net(nets, arr, args, net_dataidx_map, test_dl = test_dl_global, device=device)
1220 |
1221 | elif args.alg == 'all_in':
1222 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, 1, args)
1223 | n_epoch = args.epochs
1224 | nets[0].to(device)
1225 | trainacc, testacc = train_net(0, nets[0], train_dl_global, test_dl_global, n_epoch, args.lr, args.optimizer, device=device)
1226 |
1227 | logger.info("All in test acc: %f" % testacc)
1228 |
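# Example invocation (a sketch; the exact flag spellings are defined by get_args in utils.py):
#   python experiments.py --dataset=cifar10 --model=simple-cnn --alg=fedavg \
#       --partition=noniid-labeldir --beta=0.5 --n_parties=10 --comm_round=50 \
#       --epochs=10 --lr=0.01 --batch-size=64 --device=cuda:0 \
#       --datadir=./data/ --logdir=./logs/ --init_seed=0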
1229 |
1230 |
--------------------------------------------------------------------------------
/femnist-dis.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/femnist-dis.npy
--------------------------------------------------------------------------------
/figures/100parties/cifar10-homo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/100parties/cifar10-homo.png
--------------------------------------------------------------------------------
/figures/100parties/cifar10-lb2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/100parties/cifar10-lb2.png
--------------------------------------------------------------------------------
/figures/100parties/cifar10-lbdir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/100parties/cifar10-lbdir.png
--------------------------------------------------------------------------------
/figures/100parties/cifar10-quan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/100parties/cifar10-quan.png
--------------------------------------------------------------------------------
/figures/10parties/cifar10-iid-diff-quantity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/10parties/cifar10-iid-diff-quantity.png
--------------------------------------------------------------------------------
/figures/10parties/cifar10-noise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/10parties/cifar10-noise.png
--------------------------------------------------------------------------------
/figures/10parties/cifar10-noniid-label2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/10parties/cifar10-noniid-label2.png
--------------------------------------------------------------------------------
/figures/10parties/cifar10-noniid-labeldir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/10parties/cifar10-noniid-labeldir.png
--------------------------------------------------------------------------------
/figures/heavy-model/resnet-noise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/heavy-model/resnet-noise.png
--------------------------------------------------------------------------------
/figures/heavy-model/vgg-lbdir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/heavy-model/vgg-lbdir.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import torchvision.models as models
6 | from resnetcifar import ResNet18_cifar10, ResNet50_cifar10
7 |
8 |
9 | # class MLP_header(nn.Module):
10 | # def __init__(self,):
11 | # super(MLP_header, self).__init__()
12 | # self.fc1 = nn.Linear(28*28, 512)
13 | # self.fc2 = nn.Linear(512, 512)
14 | # self.relu = nn.ReLU()
15 | # #projection
16 | # # self.fc3 = nn.Linear(512, 10)
17 | #
18 | # def forward(self, x):
19 | # x = x.view(-1, 28*28)
20 | # x = self.fc1(x)
21 | # x = self.relu(x)
22 | # x = self.fc2(x)
23 | # x = self.relu(x)
24 | # return x
25 |
26 | class MLP_header(nn.Module):
27 | def __init__(self, input_dim, hidden_dims, dropout_p=0.0):
28 |
29 | super().__init__()
30 |
31 | self.input_dim = input_dim
32 | self.hidden_dims = hidden_dims
33 |         self.dropout_p = dropout_p  # stored but not applied in forward
34 |
35 | self.dims = [self.input_dim]
36 | self.dims.extend(hidden_dims)
37 |
38 | self.layers = nn.ModuleList([])
39 |
40 | for i in range(len(self.dims) - 1):
41 | ip_dim = self.dims[i]
42 | op_dim = self.dims[i + 1]
43 | self.layers.append(
44 | nn.Linear(ip_dim, op_dim, bias=True)
45 | )
46 |
47 | self.__init_net_weights__()
48 |
49 | def __init_net_weights__(self):
50 |
51 | for m in self.layers:
52 | m.weight.data.normal_(0.0, 0.1)
53 | m.bias.data.fill_(0.1)
54 |
55 | def forward(self, x):
56 | x = x.view(-1, self.input_dim)
57 | for i, layer in enumerate(self.layers):
58 | x = layer(x)
59 | x = F.relu(x)
60 |
61 |
62 | return x
63 |
64 |
65 | class FcNet(nn.Module):
66 | """
67 | Fully connected network for MNIST classification
68 | """
69 |
70 | def __init__(self, input_dim, hidden_dims, output_dim, dropout_p=0.0):
71 |
72 | super().__init__()
73 |
74 | self.input_dim = input_dim
75 | self.hidden_dims = hidden_dims
76 | self.output_dim = output_dim
77 | self.dropout_p = dropout_p
78 |
79 | self.dims = [self.input_dim]
80 | self.dims.extend(hidden_dims)
81 | self.dims.append(self.output_dim)
82 |
83 | self.layers = nn.ModuleList([])
84 |
85 | for i in range(len(self.dims) - 1):
86 | ip_dim = self.dims[i]
87 | op_dim = self.dims[i + 1]
88 | self.layers.append(
89 | nn.Linear(ip_dim, op_dim, bias=True)
90 | )
91 |
92 | self.__init_net_weights__()
93 |
94 | def __init_net_weights__(self):
95 |
96 | for m in self.layers:
97 | m.weight.data.normal_(0.0, 0.1)
98 | m.bias.data.fill_(0.1)
99 |
100 | def forward(self, x):
101 |
102 | x = x.view(-1, self.input_dim)
103 |
104 | for i, layer in enumerate(self.layers):
105 | x = layer(x)
106 |
107 |             # Do not apply ReLU or dropout on the final (output) layer
108 |             if i < (len(self.layers) - 1):
109 |                 x = F.relu(x)
110 |                 x = F.dropout(x, p=self.dropout_p, training=self.training)
113 |
114 | return x
115 |
116 |
117 | class ConvBlock(nn.Module):
118 | def __init__(self):
119 | super(ConvBlock, self).__init__()
120 | self.conv1 = nn.Conv2d(3, 6, 5)
121 | self.pool = nn.MaxPool2d(2, 2)
122 | self.conv2 = nn.Conv2d(6, 16, 5)
123 |
124 | def forward(self, x):
125 | x = self.pool(F.relu(self.conv1(x)))
126 | x = self.pool(F.relu(self.conv2(x)))
127 | x = x.view(-1, 16 * 5 * 5)
128 | return x
129 |
130 |
131 | class FCBlock(nn.Module):
132 | def __init__(self, input_dim, hidden_dims, output_dim=10):
133 | super(FCBlock, self).__init__()
134 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
135 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
136 | self.fc3 = nn.Linear(hidden_dims[1], output_dim)
137 |
138 | def forward(self, x):
139 | x = F.relu(self.fc1(x))
140 | x = F.relu(self.fc2(x))
141 | x = self.fc3(x)
142 | return x
143 |
144 |
145 | class VGGConvBlocks(nn.Module):
146 | '''
147 | VGG model
148 | '''
149 |
150 | def __init__(self, features, num_classes=10):
151 | super(VGGConvBlocks, self).__init__()
152 | self.features = features
153 | # Initialize weights
154 | for m in self.modules():
155 | if isinstance(m, nn.Conv2d):
156 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
157 | m.weight.data.normal_(0, math.sqrt(2. / n))
158 | m.bias.data.zero_()
159 |
160 | def forward(self, x):
161 | x = self.features(x)
162 | x = x.view(x.size(0), -1)
163 | return x
164 |
165 |
166 | class FCBlockVGG(nn.Module):
167 | def __init__(self, input_dim, hidden_dims, output_dim=10):
168 | super(FCBlockVGG, self).__init__()
169 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
170 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
171 | self.fc3 = nn.Linear(hidden_dims[1], output_dim)
172 |
173 | def forward(self, x):
174 | x = F.dropout(x)
175 | x = F.relu(self.fc1(x))
176 | x = F.dropout(x)
177 | x = F.relu(self.fc2(x))
178 | x = self.fc3(x)
179 | return x
180 |
181 | class SimpleCNN_header(nn.Module):
182 | def __init__(self, input_dim, hidden_dims, output_dim=10):
183 | super(SimpleCNN_header, self).__init__()
184 | self.conv1 = nn.Conv2d(3, 6, 5)
185 | self.relu = nn.ReLU()
186 | self.pool = nn.MaxPool2d(2, 2)
187 | self.conv2 = nn.Conv2d(6, 16, 5)
188 |
189 | # for now, we hard coded this network
190 | # i.e. we fix the number of hidden layers i.e. 2 layers
191 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
192 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
193 | #self.fc3 = nn.Linear(hidden_dims[1], output_dim)
194 |
195 | def forward(self, x):
196 |
197 | x = self.pool(self.relu(self.conv1(x)))
198 | x = self.pool(self.relu(self.conv2(x)))
199 | x = x.view(-1, 16 * 5 * 5)
200 |
201 | x = self.relu(self.fc1(x))
202 | x = self.relu(self.fc2(x))
203 | # x = self.fc3(x)
204 | return x
205 |
206 | class SimpleCNN(nn.Module):
207 | def __init__(self, input_dim, hidden_dims, output_dim=10):
208 | super(SimpleCNN, self).__init__()
209 | self.conv1 = nn.Conv2d(3, 6, 5)
210 | self.pool = nn.MaxPool2d(2, 2)
211 | self.conv2 = nn.Conv2d(6, 16, 5)
212 |
213 | # for now, we hard coded this network
214 | # i.e. we fix the number of hidden layers i.e. 2 layers
215 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
216 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
217 | self.fc3 = nn.Linear(hidden_dims[1], output_dim)
218 |
219 | def forward(self, x):
220 | x = self.pool(F.relu(self.conv1(x)))
221 | x = self.pool(F.relu(self.conv2(x)))
222 | x = x.view(-1, 16 * 5 * 5)
223 |
224 | x = F.relu(self.fc1(x))
225 | x = F.relu(self.fc2(x))
226 | x = self.fc3(x)
227 | return x
228 |
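# Shape check for SimpleCNN on 32x32 inputs (CIFAR-10/SVHN): 5x5 conv -> 28x28, 2x2 pool
# -> 14x14, 5x5 conv -> 10x10, 2x2 pool -> 5x5, hence the 16 * 5 * 5 flatten above. The
# same stack on 28x28 MNIST-style inputs ends at 4x4, which is why SimpleCNNMNIST
# flattens to 16 * 4 * 4.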
229 |
230 | # a simple perceptron model for generated 3D data
231 | class PerceptronModel(nn.Module):
232 | def __init__(self, input_dim=3, output_dim=2):
233 | super(PerceptronModel, self).__init__()
234 |
235 | self.fc1 = nn.Linear(input_dim, output_dim)
236 |
237 | def forward(self, x):
238 |
239 | x = self.fc1(x)
240 | return x
241 |
242 | class SimpleCNNMNIST_header(nn.Module):
243 | def __init__(self, input_dim, hidden_dims, output_dim=10):
244 | super(SimpleCNNMNIST_header, self).__init__()
245 | self.conv1 = nn.Conv2d(1, 6, 5)
246 | self.relu = nn.ReLU()
247 | self.pool = nn.MaxPool2d(2, 2)
248 | self.conv2 = nn.Conv2d(6, 16, 5)
249 |
250 | # for now, we hard coded this network
251 | # i.e. we fix the number of hidden layers i.e. 2 layers
252 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
253 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
254 | #self.fc3 = nn.Linear(hidden_dims[1], output_dim)
255 |
256 | def forward(self, x):
257 |
258 | x = self.pool(self.relu(self.conv1(x)))
259 | x = self.pool(self.relu(self.conv2(x)))
260 | x = x.view(-1, 16 * 4 * 4)
261 |
262 | x = self.relu(self.fc1(x))
263 | x = self.relu(self.fc2(x))
264 | # x = self.fc3(x)
265 | return x
266 |
267 | class SimpleCNNMNIST(nn.Module):
268 | def __init__(self, input_dim, hidden_dims, output_dim=10):
269 | super(SimpleCNNMNIST, self).__init__()
270 | self.conv1 = nn.Conv2d(1, 6, 5)
271 | self.pool = nn.MaxPool2d(2, 2)
272 | self.conv2 = nn.Conv2d(6, 16, 5)
273 |
274 | # for now, we hard coded this network
275 | # i.e. we fix the number of hidden layers i.e. 2 layers
276 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
277 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
278 | self.fc3 = nn.Linear(hidden_dims[1], output_dim)
279 |
280 | def forward(self, x):
281 | x = self.pool(F.relu(self.conv1(x)))
282 | x = self.pool(F.relu(self.conv2(x)))
283 | x = x.view(-1, 16 * 4 * 4)
284 |
285 | x = F.relu(self.fc1(x))
286 | x = F.relu(self.fc2(x))
287 | x = self.fc3(x)
288 | return x
289 |
290 |
291 | class SimpleCNNContainer(nn.Module):
292 | def __init__(self, input_channel, num_filters, kernel_size, input_dim, hidden_dims, output_dim=10):
293 |         '''
294 |         A testing CNN container, which allows initializing a CNN with given dims
295 |
296 |         num_filters (list) :: number of convolution filters
297 |         hidden_dims (list) :: number of neurons in hidden layers
298 |
299 |         Assumptions:
300 |         i) we use only two conv layers and three hidden layers (including the output layer)
301 |         ii) kernel sizes in the two conv layers are identical
302 |         '''
303 |         super(SimpleCNNContainer, self).__init__()
304 | self.conv1 = nn.Conv2d(input_channel, num_filters[0], kernel_size)
305 | self.pool = nn.MaxPool2d(2, 2)
306 | self.conv2 = nn.Conv2d(num_filters[0], num_filters[1], kernel_size)
307 |
308 | # for now, we hard coded this network
309 | # i.e. we fix the number of hidden layers i.e. 2 layers
310 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
311 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
312 | self.fc3 = nn.Linear(hidden_dims[1], output_dim)
313 |
314 | def forward(self, x):
315 | x = self.pool(F.relu(self.conv1(x)))
316 | x = self.pool(F.relu(self.conv2(x)))
317 | x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
318 | x = F.relu(self.fc1(x))
319 | x = F.relu(self.fc2(x))
320 | x = self.fc3(x)
321 | return x
322 |
323 |
324 | ############## LeNet for MNIST ###################
325 | class LeNet(nn.Module):
326 | def __init__(self):
327 | super(LeNet, self).__init__()
328 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
329 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
330 | self.fc1 = nn.Linear(4 * 4 * 50, 500)
331 | self.fc2 = nn.Linear(500, 10)
332 |         self.criterion = nn.CrossEntropyLoss()
333 |
334 | def forward(self, x):
335 | x = self.conv1(x)
336 | x = F.max_pool2d(x, 2, 2)
337 | x = F.relu(x)
338 | x = self.conv2(x)
339 | x = F.max_pool2d(x, 2, 2)
340 | x = F.relu(x)
341 | x = x.view(-1, 4 * 4 * 50)
342 | x = self.fc1(x)
343 | x = self.fc2(x)
344 | return x
345 |
346 |
347 | class LeNetContainer(nn.Module):
348 | def __init__(self, num_filters, kernel_size, input_dim, hidden_dims, output_dim=10):
349 | super(LeNetContainer, self).__init__()
350 | self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size, 1)
351 | self.conv2 = nn.Conv2d(num_filters[0], num_filters[1], kernel_size, 1)
352 |
353 | self.fc1 = nn.Linear(input_dim, hidden_dims[0])
354 | self.fc2 = nn.Linear(hidden_dims[0], output_dim)
355 |
356 | def forward(self, x):
357 | x = self.conv1(x)
358 | x = F.max_pool2d(x, 2, 2)
359 | x = F.relu(x)
360 | x = self.conv2(x)
361 | x = F.max_pool2d(x, 2, 2)
362 | x = F.relu(x)
363 | x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
364 | x = self.fc1(x)
365 | x = self.fc2(x)
366 | return x
367 |
368 |
369 |
370 | ### Moderately sized CNN for the CIFAR-10 dataset
371 | class ModerateCNN(nn.Module):
372 | def __init__(self, output_dim=10):
373 | super(ModerateCNN, self).__init__()
374 | self.conv_layer = nn.Sequential(
375 | # Conv Layer block 1
376 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
377 | nn.ReLU(inplace=True),
378 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
379 | nn.ReLU(inplace=True),
380 | nn.MaxPool2d(kernel_size=2, stride=2),
381 |
382 | # Conv Layer block 2
383 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
384 | nn.ReLU(inplace=True),
385 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
386 | nn.ReLU(inplace=True),
387 | nn.MaxPool2d(kernel_size=2, stride=2),
388 | nn.Dropout2d(p=0.05),
389 |
390 | # Conv Layer block 3
391 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
392 | nn.ReLU(inplace=True),
393 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
394 | nn.ReLU(inplace=True),
395 | nn.MaxPool2d(kernel_size=2, stride=2),
396 | )
397 |
398 | self.fc_layer = nn.Sequential(
399 | nn.Dropout(p=0.1),
400 | # nn.Linear(4096, 1024),
401 | nn.Linear(4096, 512),
402 | nn.ReLU(inplace=True),
403 | # nn.Linear(1024, 512),
404 | nn.Linear(512, 512),
405 | nn.ReLU(inplace=True),
406 | nn.Dropout(p=0.1),
407 | nn.Linear(512, output_dim)
408 | )
409 |
410 | def forward(self, x):
411 | x = self.conv_layer(x)
412 | x = x.view(x.size(0), -1)
413 | x = self.fc_layer(x)
414 | return x
415 |
416 |
417 | ### Moderate size of CNN for the CelebA dataset (2-class output)
418 | class ModerateCNNCeleba(nn.Module):
419 | def __init__(self):
420 | super(ModerateCNNCeleba, self).__init__()
421 | self.conv_layer = nn.Sequential(
422 | # Conv Layer block 1
423 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
424 | nn.ReLU(inplace=True),
425 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
426 | nn.ReLU(inplace=True),
427 | nn.MaxPool2d(kernel_size=2, stride=2),
428 |
429 | # Conv Layer block 2
430 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
431 | nn.ReLU(inplace=True),
432 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
433 | nn.ReLU(inplace=True),
434 | nn.MaxPool2d(kernel_size=2, stride=2),
435 | # nn.Dropout2d(p=0.05),
436 |
437 | # Conv Layer block 3
438 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
439 | nn.ReLU(inplace=True),
440 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
441 | nn.ReLU(inplace=True),
442 | nn.MaxPool2d(kernel_size=2, stride=2),
443 | )
444 |
445 | self.fc_layer = nn.Sequential(
446 | nn.Dropout(p=0.1),
447 | # nn.Linear(4096, 1024),
448 | nn.Linear(4096, 512),
449 | nn.ReLU(inplace=True),
450 | # nn.Linear(1024, 512),
451 | nn.Linear(512, 512),
452 | nn.ReLU(inplace=True),
453 | nn.Dropout(p=0.1),
454 | nn.Linear(512, 2)
455 | )
456 |
457 | def forward(self, x):
458 | x = self.conv_layer(x)
459 | # x = x.view(x.size(0), -1)
460 | x = x.view(-1, 4096)
461 | x = self.fc_layer(x)
462 | return x
463 |
464 |
465 | class ModerateCNNMNIST(nn.Module):
466 | def __init__(self):
467 | super(ModerateCNNMNIST, self).__init__()
468 | self.conv_layer = nn.Sequential(
469 | # Conv Layer block 1
470 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
471 | nn.ReLU(inplace=True),
472 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
473 | nn.ReLU(inplace=True),
474 | nn.MaxPool2d(kernel_size=2, stride=2),
475 |
476 | # Conv Layer block 2
477 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
478 | nn.ReLU(inplace=True),
479 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
480 | nn.ReLU(inplace=True),
481 | nn.MaxPool2d(kernel_size=2, stride=2),
482 | nn.Dropout2d(p=0.05),
483 |
484 | # Conv Layer block 3
485 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
486 | nn.ReLU(inplace=True),
487 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
488 | nn.ReLU(inplace=True),
489 | nn.MaxPool2d(kernel_size=2, stride=2),
490 | )
491 |
492 | self.fc_layer = nn.Sequential(
493 | nn.Dropout(p=0.1),
494 | nn.Linear(2304, 1024),
495 | nn.ReLU(inplace=True),
496 | nn.Linear(1024, 512),
497 | nn.ReLU(inplace=True),
498 | nn.Dropout(p=0.1),
499 | nn.Linear(512, 10)
500 | )
501 |
502 | def forward(self, x):
503 | x = self.conv_layer(x)
504 | x = x.view(x.size(0), -1)
505 | x = self.fc_layer(x)
506 | return x
507 |
508 |
509 | class ModerateCNNContainer(nn.Module):
510 | def __init__(self, input_channels, num_filters, kernel_size, input_dim, hidden_dims, output_dim=10):
511 | super(ModerateCNNContainer, self).__init__()
512 |
513 | ##
514 | self.conv_layer = nn.Sequential(
515 | # Conv Layer block 1
516 | nn.Conv2d(in_channels=input_channels, out_channels=num_filters[0], kernel_size=kernel_size, padding=1),
517 | nn.ReLU(inplace=True),
518 | nn.Conv2d(in_channels=num_filters[0], out_channels=num_filters[1], kernel_size=kernel_size, padding=1),
519 | nn.ReLU(inplace=True),
520 | nn.MaxPool2d(kernel_size=2, stride=2),
521 |
522 | # Conv Layer block 2
523 | nn.Conv2d(in_channels=num_filters[1], out_channels=num_filters[2], kernel_size=kernel_size, padding=1),
524 | nn.ReLU(inplace=True),
525 | nn.Conv2d(in_channels=num_filters[2], out_channels=num_filters[3], kernel_size=kernel_size, padding=1),
526 | nn.ReLU(inplace=True),
527 | nn.MaxPool2d(kernel_size=2, stride=2),
528 | nn.Dropout2d(p=0.05),
529 |
530 | # Conv Layer block 3
531 | nn.Conv2d(in_channels=num_filters[3], out_channels=num_filters[4], kernel_size=kernel_size, padding=1),
532 | nn.ReLU(inplace=True),
533 | nn.Conv2d(in_channels=num_filters[4], out_channels=num_filters[5], kernel_size=kernel_size, padding=1),
534 | nn.ReLU(inplace=True),
535 | nn.MaxPool2d(kernel_size=2, stride=2),
536 | )
537 |
538 | self.fc_layer = nn.Sequential(
539 | nn.Dropout(p=0.1),
540 | nn.Linear(input_dim, hidden_dims[0]),
541 | nn.ReLU(inplace=True),
542 | nn.Linear(hidden_dims[0], hidden_dims[1]),
543 | nn.ReLU(inplace=True),
544 | nn.Dropout(p=0.1),
545 | nn.Linear(hidden_dims[1], output_dim)
546 | )
547 |
548 | def forward(self, x):
549 | x = self.conv_layer(x)
550 | x = x.view(x.size(0), -1)
551 | x = self.fc_layer(x)
552 | return x
553 |
554 | def forward_conv(self, x):
555 | x = self.conv_layer(x)
556 | x = x.view(x.size(0), -1)
557 | return x
558 |
559 |
560 | class ModelFedCon(nn.Module):
561 |
562 | def __init__(self, base_model, out_dim, n_classes, net_configs=None):
563 | super(ModelFedCon, self).__init__()
564 |
565 | # if base_model == "resnet50":
566 | # basemodel = models.resnet50(pretrained=False)
567 | # self.features = nn.Sequential(*list(basemodel.children())[:-1])
568 | # num_ftrs = basemodel.fc.in_features
569 | # elif base_model == "resnet18":
570 | # basemodel = models.resnet18(pretrained=False)
571 | # self.features = nn.Sequential(*list(basemodel.children())[:-1])
572 | # num_ftrs = basemodel.fc.in_features
573 | if base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel" or base_model == "resnet50":
574 | basemodel = ResNet50_cifar10()
575 | self.features = nn.Sequential(*list(basemodel.children())[:-1])
576 | num_ftrs = basemodel.fc.in_features
577 | elif base_model == "resnet18-cifar10" or base_model == "resnet18":
578 | basemodel = ResNet18_cifar10()
579 | self.features = nn.Sequential(*list(basemodel.children())[:-1])
580 | num_ftrs = basemodel.fc.in_features
581 | elif base_model == "mlp":
582 | self.features = MLP_header(input_dim=net_configs[0], hidden_dims=net_configs[1:-1])
583 | num_ftrs = net_configs[-2]
584 | elif base_model == 'simple-cnn':
585 | self.features = SimpleCNN_header(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=n_classes)
586 | num_ftrs = 84
587 | elif base_model == 'simple-cnn-mnist':
588 | self.features = SimpleCNNMNIST_header(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=n_classes)
589 | num_ftrs = 84
590 |
591 | #summary(self.features.to('cuda:0'), (3,32,32))
592 | #print("features:", self.features)
593 | # projection MLP
594 | self.l1 = nn.Linear(num_ftrs, num_ftrs)
595 | self.l2 = nn.Linear(num_ftrs, out_dim)
596 |
597 | # last layer
598 | self.l3 = nn.Linear(out_dim, n_classes)
599 | self.num_ftrs = num_ftrs
600 |
601 | def _get_basemodel(self, model_name):
602 | try:
603 | model = self.model_dict[model_name]
604 | #print("Feature extractor:", model_name)
605 | return model
606 | except KeyError:
607 | raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
608 |
609 | def forward(self, x):
610 | h = self.features(x)
611 | h = h.reshape(-1, self.num_ftrs)
612 | #print("h before:", h)
613 | #print("h size:", h.size())
614 | #h = h.squeeze()
615 | #print("h after:", h)
616 | x = self.l1(h)
617 | x = F.relu(x)
618 | x = self.l2(x)
619 |
620 | y = self.l3(x)
621 | return h, x, y
622 |
623 |
624 | class ModelFedCon_noheader(nn.Module):
625 |
626 | def __init__(self, base_model, out_dim, n_classes, net_configs=None):
627 | super(ModelFedCon_noheader, self).__init__()
628 |
629 | # if base_model == "resnet":
630 | # basemodel = models.resnet50(pretrained=False)
631 | # self.features = nn.Sequential(*list(basemodel.children())[:-1])
632 | # num_ftrs = basemodel.fc.in_features
633 | if base_model == "resnet18":
634 | basemodel = models.resnet18(pretrained=False)
635 | self.features = nn.Sequential(*list(basemodel.children())[:-1])
636 | num_ftrs = basemodel.fc.in_features
637 | elif base_model == "resnet" or base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel":
638 | basemodel = ResNet50_cifar10()
639 | self.features = nn.Sequential(*list(basemodel.children())[:-1])
640 | num_ftrs = basemodel.fc.in_features
641 | elif base_model == "resnet18-cifar10":
642 | basemodel = ResNet18_cifar10()
643 | self.features = nn.Sequential(*list(basemodel.children())[:-1])
644 | num_ftrs = basemodel.fc.in_features
645 | elif base_model == "mlp":
646 | self.features = MLP_header(input_dim=net_configs[0], hidden_dims=net_configs[1:-1])
647 | num_ftrs = net_configs[-2]
648 | elif base_model == 'simple-cnn':
649 | self.features = SimpleCNN_header(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=n_classes)
650 | num_ftrs = 84
651 | elif base_model == 'simple-cnn-mnist':
652 | self.features = SimpleCNNMNIST_header(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=n_classes)
653 | num_ftrs = 84
654 |
655 | #summary(self.features.to('cuda:0'), (3,32,32))
656 | #print("features:", self.features)
657 | # projection MLP
658 | # self.l1 = nn.Linear(num_ftrs, num_ftrs)
659 | # self.l2 = nn.Linear(num_ftrs, out_dim)
660 |
661 | # last layer
662 | self.l3 = nn.Linear(num_ftrs, n_classes)
663 | self.num_ftrs = num_ftrs
664 |
665 | def _get_basemodel(self, model_name):
666 | try:
667 | model = self.model_dict[model_name]
668 | #print("Feature extractor:", model_name)
669 | return model
670 | except KeyError:
671 | raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
672 |
673 | def forward(self, x):
674 | h = self.features(x)
675 | h = h.reshape(-1, self.num_ftrs)
676 | #print("h before:", h)
677 | #print("h size:", h.size())
678 | #h = h.squeeze()
679 | #print("h after:", h)
680 | # x = self.l1(h)
681 | # x = F.relu(x)
682 | # x = self.l2(x)
683 |
684 | y = self.l3(h)
685 | return h, h, y
686 |
687 |
--------------------------------------------------------------------------------
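Usage sketch for model.py (illustrative, not part of the repo): ModelFedCon.forward returns the backbone representation h, the projection x and the logits y; with the simple-cnn base, which expects CIFAR-10-shaped inputs, the shapes come out as below. The import path assumes model.py sits at the repo root.

import torch
from model import ModelFedCon

net = ModelFedCon(base_model='simple-cnn', out_dim=256, n_classes=10)
h, proj, logits = net(torch.randn(4, 3, 32, 32))   # dummy CIFAR-10 batch
print(h.shape, proj.shape, logits.shape)           # (4, 84) (4, 256) (4, 10)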
/models/celeba_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Architecture based on InfoGAN paper.
7 | """
8 |
9 | class Generator(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | self.tconv1 = nn.ConvTranspose2d(228, 448, 2, 1, bias=False)
14 | self.bn1 = nn.BatchNorm2d(448)
15 |
16 | self.tconv2 = nn.ConvTranspose2d(448, 256, 4, 2, padding=1, bias=False)
17 | self.bn2 = nn.BatchNorm2d(256)
18 |
19 | self.tconv3 = nn.ConvTranspose2d(256, 128, 4, 2, padding=1, bias=False)
20 |
21 | self.tconv4 = nn.ConvTranspose2d(128, 64, 4, 2, padding=1, bias=False)
22 |
23 | self.tconv5 = nn.ConvTranspose2d(64, 3, 4, 2, padding=1, bias=False)
24 |
25 | def forward(self, x):
26 | x = F.relu(self.bn1(self.tconv1(x)))
27 | x = F.relu(self.bn2(self.tconv2(x)))
28 | x = F.relu(self.tconv3(x))
29 | x = F.relu(self.tconv4(x))
30 |
31 | img = torch.tanh(self.tconv5(x))
32 |
33 | return img
34 |
35 | class Discriminator(nn.Module):
36 | def __init__(self):
37 | super().__init__()
38 |
39 | self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
40 |
41 | self.conv2 = nn.Conv2d(64, 128, 4, 2, 1, bias=False)
42 | self.bn2 = nn.BatchNorm2d(128)
43 |
44 | self.conv3 = nn.Conv2d(128, 256, 4, 2, 1, bias=False)
45 | self.bn3 = nn.BatchNorm2d(256)
46 |
47 | def forward(self, x):
48 | x = F.leaky_relu(self.conv1(x), 0.1, inplace=True)
49 | x = F.leaky_relu(self.bn2(self.conv2(x)), 0.1, inplace=True)
50 | x = F.leaky_relu(self.bn3(self.conv3(x)), 0.1, inplace=True)
51 |
52 | return x
53 |
54 | class DHead(nn.Module):
55 | def __init__(self):
56 | super().__init__()
57 |
58 | self.conv = nn.Conv2d(256, 1, 4)
59 |
60 | def forward(self, x):
61 | output = torch.sigmoid(self.conv(x))
62 |
63 | return output
64 |
65 | class QHead(nn.Module):
66 | def __init__(self):
67 | super().__init__()
68 |
69 | self.conv1 = nn.Conv2d(256, 128, 4, bias=False)
70 | self.bn1 = nn.BatchNorm2d(128)
71 |
72 | self.conv_disc = nn.Conv2d(128, 100, 1)
73 |
74 | self.conv_mu = nn.Conv2d(128, 1, 1)
75 | self.conv_var = nn.Conv2d(128, 1, 1)
76 |
77 | def forward(self, x):
78 | x = F.leaky_relu(self.bn1(self.conv1(x)), 0.1, inplace=True)
79 |
80 | disc_logits = self.conv_disc(x).squeeze()
81 |
82 | # Not used during training for celeba dataset.
83 | mu = self.conv_mu(x).squeeze()
84 | var = torch.exp(self.conv_var(x).squeeze())
85 |
86 | return disc_logits, mu, var
87 |
--------------------------------------------------------------------------------
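Shape-check sketch for models/celeba_model.py (illustrative, not part of the repo; the 228-channel 1x1 latent and 32x32 RGB output are read off the layer definitions above -- models/svhn_model.py below follows the same pattern with a 138-channel latent, a 40-way categorical head and 4 continuous outputs):

import torch
from models.celeba_model import Generator, Discriminator, DHead, QHead

latent = torch.randn(8, 228, 1, 1)       # 1x1 spatial latent, 228 channels
img = Generator()(latent)                # (8, 3, 32, 32), in [-1, 1] via tanh
feat = Discriminator()(img)              # (8, 256, 4, 4) shared D features
realness = DHead()(feat)                 # (8, 1, 1, 1) real/fake probability
disc_logits, mu, var = QHead()(feat)     # (8, 100), (8,), (8,)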
/models/mnist_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Architecture based on InfoGAN paper.
7 | """
8 |
9 | class Generator(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | self.tconv1 = nn.ConvTranspose2d(74, 1024, 1, 1, bias=False)
14 | self.bn1 = nn.BatchNorm2d(1024)
15 |
16 | self.tconv2 = nn.ConvTranspose2d(1024, 128, 7, 1, bias=False)
17 | self.bn2 = nn.BatchNorm2d(128)
18 |
19 | self.tconv3 = nn.ConvTranspose2d(128, 64, 4, 2, padding=1, bias=False)
20 | self.bn3 = nn.BatchNorm2d(64)
21 |
22 | self.tconv4 = nn.ConvTranspose2d(64, 1, 4, 2, padding=1, bias=False)
23 |
24 | def forward(self, x):
25 | x = F.relu(self.bn1(self.tconv1(x)))
26 | x = F.relu(self.bn2(self.tconv2(x)))
27 | x = F.relu(self.bn3(self.tconv3(x)))
28 |
29 | img = torch.sigmoid(self.tconv4(x))
30 |
31 | return img
32 |
33 | class Discriminator(nn.Module):
34 | def __init__(self):
35 | super().__init__()
36 |
37 | self.conv1 = nn.Conv2d(1, 64, 4, 2, 1)
38 |
39 | self.conv2 = nn.Conv2d(64, 128, 4, 2, 1, bias=False)
40 | self.bn2 = nn.BatchNorm2d(128)
41 |
42 | self.conv3 = nn.Conv2d(128, 1024, 7, bias=False)
43 | self.bn3 = nn.BatchNorm2d(1024)
44 |
45 | def forward(self, x):
46 | x = F.leaky_relu(self.conv1(x), 0.1, inplace=True)
47 | x = F.leaky_relu(self.bn2(self.conv2(x)), 0.1, inplace=True)
48 | x = F.leaky_relu(self.bn3(self.conv3(x)), 0.1, inplace=True)
49 |
50 | return x
51 |
52 | class DHead(nn.Module):
53 | def __init__(self):
54 | super().__init__()
55 |
56 | self.conv = nn.Conv2d(1024, 1, 1)
57 |
58 | def forward(self, x):
59 | output = torch.sigmoid(self.conv(x))
60 |
61 | return output
62 |
63 | class QHead(nn.Module):
64 | def __init__(self):
65 | super().__init__()
66 |
67 | self.conv1 = nn.Conv2d(1024, 128, 1, bias=False)
68 | self.bn1 = nn.BatchNorm2d(128)
69 |
70 | self.conv_disc = nn.Conv2d(128, 10, 1)
71 | self.conv_mu = nn.Conv2d(128, 2, 1)
72 | self.conv_var = nn.Conv2d(128, 2, 1)
73 |
74 | def forward(self, x):
75 | x = F.leaky_relu(self.bn1(self.conv1(x)), 0.1, inplace=True)
76 |
77 | disc_logits = self.conv_disc(x).squeeze()
78 |
79 | mu = self.conv_mu(x).squeeze()
80 | var = torch.exp(self.conv_var(x).squeeze())
81 |
82 | return disc_logits, mu, var
83 |
--------------------------------------------------------------------------------
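Latent-assembly sketch for models/mnist_model.py (illustrative, not part of the repo; the 62 + 10 + 2 = 74 split is the standard InfoGAN-MNIST layout and is an assumption here, inferred from the 74 generator input channels and QHead's 10 categorical / 2 continuous outputs):

import torch
import torch.nn.functional as F
from models.mnist_model import Generator, Discriminator, DHead, QHead

z = torch.randn(8, 62, 1, 1)                        # incompressible noise
c_disc = F.one_hot(torch.randint(0, 10, (8,)), 10).float().view(8, 10, 1, 1)
c_cont = torch.rand(8, 2, 1, 1) * 2 - 1             # continuous codes in [-1, 1]
latent = torch.cat([z, c_disc, c_cont], dim=1)      # (8, 74, 1, 1)

img = Generator()(latent)                # (8, 1, 28, 28)
feat = Discriminator()(img)              # (8, 1024, 1, 1)
realness = DHead()(feat)                 # (8, 1, 1, 1)
disc_logits, mu, var = QHead()(feat)     # (8, 10), (8, 2), (8, 2)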
/models/svhn_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Architecture based on InfoGAN paper.
7 | """
8 |
9 | class Generator(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | self.tconv1 = nn.ConvTranspose2d(138, 448, 2, 1, bias=False)
14 | self.bn1 = nn.BatchNorm2d(448)
15 |
16 | self.tconv2 = nn.ConvTranspose2d(448, 256, 4, 2, padding=1, bias=False)
17 | self.bn2 = nn.BatchNorm2d(256)
18 |
19 | self.tconv3 = nn.ConvTranspose2d(256, 128, 4, 2, padding=1, bias=False)
20 |
21 | self.tconv4 = nn.ConvTranspose2d(128, 64, 4, 2, padding=1, bias=False)
22 |
23 | self.tconv5 = nn.ConvTranspose2d(64, 3, 4, 2, padding=1, bias=False)
24 |
25 | def forward(self, x):
26 | x = F.relu(self.bn1(self.tconv1(x)))
27 | x = F.relu(self.bn2(self.tconv2(x)))
28 | x = F.relu(self.tconv3(x))
29 | x = F.relu(self.tconv4(x))
30 |
31 | img = torch.tanh(self.tconv5(x))
32 |
33 | return img
34 |
35 | class Discriminator(nn.Module):
36 | def __init__(self):
37 | super().__init__()
38 |
39 | self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
40 |
41 | self.conv2 = nn.Conv2d(64, 128, 4, 2, 1, bias=False)
42 | self.bn2 = nn.BatchNorm2d(128)
43 |
44 | self.conv3 = nn.Conv2d(128, 256, 4, 2, 1, bias=False)
45 | self.bn3 = nn.BatchNorm2d(256)
46 |
47 | def forward(self, x):
48 | x = F.leaky_relu(self.conv1(x), 0.1, inplace=True)
49 | x = F.leaky_relu(self.bn2(self.conv2(x)), 0.1, inplace=True)
50 | x = F.leaky_relu(self.bn3(self.conv3(x)), 0.1, inplace=True)
51 |
52 | return x
53 |
54 | class DHead(nn.Module):
55 | def __init__(self):
56 | super().__init__()
57 |
58 | self.conv = nn.Conv2d(256, 1, 4)
59 |
60 | def forward(self, x):
61 | output = torch.sigmoid(self.conv(x))
62 |
63 | return output
64 |
65 | class QHead(nn.Module):
66 | def __init__(self):
67 | super().__init__()
68 |
69 | self.conv1 = nn.Conv2d(256, 128, 4, bias=False)
70 | self.bn1 = nn.BatchNorm2d(128)
71 |
72 | self.conv_disc = nn.Conv2d(128, 40, 1)
73 | self.conv_mu = nn.Conv2d(128, 4, 1)
74 | self.conv_var = nn.Conv2d(128, 4, 1)
75 |
76 | def forward(self, x):
77 | x = F.leaky_relu(self.bn1(self.conv1(x)), 0.1, inplace=True)
78 |
79 | disc_logits = self.conv_disc(x).squeeze()
80 |
81 | mu = self.conv_mu(x).squeeze()
82 | var = torch.exp(self.conv_var(x).squeeze())
83 |
84 | return disc_logits, mu, var
85 |
--------------------------------------------------------------------------------
/partition.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import random
5 | import argparse
6 | import csv
7 |
8 | from utils import mkdirs
9 | def partition_data(dataset, class_id, K, partition, n_parties, beta, seed): # dataset: 2-D ndarray; class_id: column holding the labels; K: number of classes
10 | np.random.seed(seed)
11 | random.seed(seed)
12 |
13 | n_train = dataset.shape[0]
14 | y_train = dataset[:,class_id]
15 |
16 | if partition == "homo":
17 | idxs = np.random.permutation(n_train)
18 | batch_idxs = np.array_split(idxs, n_parties)
19 | net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)}
20 |
21 | elif partition == "noniid-labeldir":
22 | min_size = 0
23 | min_require_size = 10
24 |
25 | N = dataset.shape[0]
26 | net_dataidx_map = {}
27 |
28 | while min_size < min_require_size:
29 | idx_batch = [[] for _ in range(n_parties)]
30 | for k in range(K):
31 | idx_k = np.where(y_train == k)[0]
32 | np.random.shuffle(idx_k)
33 | proportions = np.random.dirichlet(np.repeat(beta, n_parties))
34 | # proportions: one Dirichlet(beta) sample over parties for class k;
35 | # smaller beta -> more skewed label distribution
36 | ## Balance: zero out parties that already hold at least an equal share
37 | proportions = np.array([p * (len(idx_j) < N / n_parties) for p, idx_j in zip(proportions, idx_batch)])
38 | # renormalize the remaining weights
39 | proportions = proportions / proportions.sum()
40 | # turn cumulative proportions into split points over the shuffled idx_k
41 | proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
42 | # append each party's slice of class k
43 | idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
44 | min_size = min([len(idx_j) for idx_j in idx_batch])
45 | # if K == 2 and n_parties <= 10:
46 | # if np.min(proportions) < 200:
47 | # min_size = 0
48 | # break
49 |
50 |
51 | for j in range(n_parties):
52 | np.random.shuffle(idx_batch[j])
53 | net_dataidx_map[j] = idx_batch[j]
54 |
55 | elif partition > "noniid-#label0" and partition <= "noniid-#label9": # lexicographic match for "noniid-#labelX": each party gets exactly X labels
56 | num = int(partition[13:])
57 |
58 | times=[0 for i in range(K)]   # times[k]: how many parties hold label k
59 | contain=[]                    # contain[i]: the labels assigned to party i
60 | for i in range(n_parties):
61 | current=[i%K]                 # seed party i with label i%K so every label is used
62 | times[i%K]+=1
63 | j=1
64 | while (j<num):
--------------------------------------------------------------------------------
/resnetcifar.py:
--------------------------------------------------------------------------------
151 | if dilation > 1:
152 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
153 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
154 | self.conv1 = conv3x3(inplanes, planes, stride)
155 | self.bn1 = norm_layer(planes)
156 | self.relu = nn.ReLU(inplace=True)
157 | self.conv2 = conv3x3(planes, planes)
158 | self.bn2 = norm_layer(planes)
159 | self.downsample = downsample
160 | self.stride = stride
161 |
162 | def forward(self, x):
163 | identity = x
164 |
165 | out = self.conv1(x)
166 | out = self.bn1(out)
167 | out = self.relu(out)
168 |
169 | out = self.conv2(out)
170 | out = self.bn2(out)
171 |
172 | if self.downsample is not None:
173 | identity = self.downsample(x)
174 |
175 | out += identity
176 | out = self.relu(out)
177 |
178 | return out
179 |
180 |
181 | class Bottleneck(nn.Module):
182 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
183 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
184 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
185 | # This variant is also known as ResNet V1.5 and improves accuracy according to
186 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
187 |
188 | expansion = 4
189 |
190 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
191 | base_width=64, dilation=1, norm_layer=None):
192 | super(Bottleneck, self).__init__()
193 | if norm_layer is None:
194 | norm_layer = nn.BatchNorm2d
195 | width = int(planes * (base_width / 64.)) * groups
196 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
197 | self.conv1 = conv1x1(inplanes, width)
198 | self.bn1 = norm_layer(width)
199 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
200 | self.bn2 = norm_layer(width)
201 | self.conv3 = conv1x1(width, planes * self.expansion)
202 | self.bn3 = norm_layer(planes * self.expansion)
203 | self.relu = nn.ReLU(inplace=True)
204 | self.downsample = downsample
205 | self.stride = stride
206 |
207 | def forward(self, x):
208 | identity = x
209 |
210 | out = self.conv1(x)
211 | out = self.bn1(out)
212 | out = self.relu(out)
213 |
214 | out = self.conv2(out)
215 | out = self.bn2(out)
216 | out = self.relu(out)
217 |
218 | out = self.conv3(out)
219 | out = self.bn3(out)
220 |
221 | if self.downsample is not None:
222 | identity = self.downsample(x)
223 |
224 | out += identity
225 | out = self.relu(out)
226 |
227 | return out
228 |
229 |
230 | class ResNetCifar10(nn.Module):
231 |
232 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
233 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
234 | norm_layer=None):
235 | super(ResNetCifar10, self).__init__()
236 | if norm_layer is None:
237 | norm_layer = nn.BatchNorm2d
238 | self._norm_layer = norm_layer
239 |
240 | self.inplanes = 64
241 | self.dilation = 1
242 | if replace_stride_with_dilation is None:
243 | # each element in the tuple indicates if we should replace
244 | # the 2x2 stride with a dilated convolution instead
245 | replace_stride_with_dilation = [False, False, False]
246 | if len(replace_stride_with_dilation) != 3:
247 | raise ValueError("replace_stride_with_dilation should be None "
248 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
249 | self.groups = groups
250 | self.base_width = width_per_group
251 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
252 | bias=False)
253 | self.bn1 = norm_layer(self.inplanes)
254 | self.relu = nn.ReLU(inplace=True)
255 | self.layer1 = self._make_layer(block, 64, layers[0])
256 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
257 | dilate=replace_stride_with_dilation[0])
258 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
259 | dilate=replace_stride_with_dilation[1])
260 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
261 | dilate=replace_stride_with_dilation[2])
262 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
263 | self.fc = nn.Linear(512 * block.expansion, num_classes)
264 |
265 | for m in self.modules():
266 | if isinstance(m, nn.Conv2d):
267 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
268 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
269 | nn.init.constant_(m.weight, 1)
270 | nn.init.constant_(m.bias, 0)
271 |
272 | # Zero-initialize the last BN in each residual branch,
273 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
274 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
275 | if zero_init_residual:
276 | for m in self.modules():
277 | if isinstance(m, Bottleneck):
278 | nn.init.constant_(m.bn3.weight, 0)
279 | elif isinstance(m, BasicBlock):
280 | nn.init.constant_(m.bn2.weight, 0)
281 |
282 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
283 | norm_layer = self._norm_layer
284 | downsample = None
285 | previous_dilation = self.dilation
286 | if dilate:
287 | self.dilation *= stride
288 | stride = 1
289 | if stride != 1 or self.inplanes != planes * block.expansion:
290 | downsample = nn.Sequential(
291 | conv1x1(self.inplanes, planes * block.expansion, stride),
292 | norm_layer(planes * block.expansion),
293 | )
294 |
295 | layers = []
296 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
297 | self.base_width, previous_dilation, norm_layer))
298 | self.inplanes = planes * block.expansion
299 | for _ in range(1, blocks):
300 | layers.append(block(self.inplanes, planes, groups=self.groups,
301 | base_width=self.base_width, dilation=self.dilation,
302 | norm_layer=norm_layer))
303 |
304 | return nn.Sequential(*layers)
305 |
306 | def _forward_impl(self, x):
307 | # See note [TorchScript super()]
308 | x = self.conv1(x)
309 | x = self.bn1(x)
310 | x = self.relu(x)
311 |
312 | x = self.layer1(x)
313 | x = self.layer2(x)
314 | x = self.layer3(x)
315 | x = self.layer4(x)
316 |
317 | x = self.avgpool(x)
318 | x = torch.flatten(x, 1)
319 | x = self.fc(x)
320 |
321 | return x
322 |
323 | def forward(self, x):
324 | return self._forward_impl(x)
325 |
326 |
327 | def ResNet18_cifar10(**kwargs):
328 | r"""ResNet-18 model from
329 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_,
330 | adapted for 32x32 CIFAR inputs (3x3 stride-1 stem, no max-pool).
331 |
332 | Args:
333 | **kwargs: forwarded to ResNetCifar10 (e.g. num_classes, default 1000)
334 | """
335 | return ResNetCifar10(BasicBlock, [2, 2, 2, 2], **kwargs)
336 |
337 |
338 |
339 | def ResNet50_cifar10(**kwargs):
340 | r"""ResNet-50 model from
341 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_,
342 | adapted for 32x32 CIFAR inputs (3x3 stride-1 stem, no max-pool).
343 |
344 | Args:
345 | **kwargs: forwarded to ResNetCifar10 (e.g. num_classes, default 1000)
346 | """
347 | return ResNetCifar10(Bottleneck, [3, 4, 6, 3], **kwargs)
--------------------------------------------------------------------------------
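Quick shape check for the CIFAR-style ResNets above (illustrative, not part of the repo; note that num_classes defaults to 1000, so pass it explicitly for CIFAR-10):

import torch
from resnetcifar import ResNet18_cifar10

net = ResNet18_cifar10(num_classes=10)
logits = net(torch.randn(2, 3, 32, 32))   # 3x3 stride-1 stem, so no early downsampling
print(logits.shape)                       # torch.Size([2, 10])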
/run.sh:
--------------------------------------------------------------------------------
1 | python experiments.py --model=simple-cnn \
2 | --dataset=cifar10 \
3 | --alg=fednova \
4 | --lr=0.01 \
5 | --batch-size=64 \
6 | --epochs=10 \
7 | --n_parties=10 \
8 | --rho=0.9 \
9 | --comm_round=50 \
10 | --partition=noniid-labeldir \
11 | --beta=0.5\
12 | --device='cpu'\
13 | --datadir='./data/' \
14 | --logdir='./logs/' \
15 | --noise=0\
16 | --init_seed=0
--------------------------------------------------------------------------------
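run.sh above trains FedNova under the Dirichlet label-skew partition. A standalone sketch (not repo code) of the Dirichlet draw behind --partition=noniid-labeldir --beta=0.5, as used in partition.py and utils.py:

import numpy as np

beta, n_parties = 0.5, 5
for k in range(3):                                    # pretend 3 classes
    p = np.random.dirichlet(np.repeat(beta, n_parties))
    print(k, np.round(p, 2))                          # e.g. [0.55 0.01 0.21 0.13 0.1]
# each class is split across parties by its own draw: smaller beta -> stronger
# label skew, larger beta -> closer to IID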
/scripts/100-parties.sh:
--------------------------------------------------------------------------------
1 | for partition in noniid-labeldir noniid-#label1 noniid-#label2 noniid-#label3 iid-diff-quantity homo
2 | do
3 | for alg in fedavg scaffold fednova
4 | do
5 | python experiments.py --model=simple-cnn \
6 | --dataset=cifar10 \
7 | --alg=$alg \
8 | --lr=0.01 \
9 | --batch-size=64 \
10 | --epochs=10 \
11 | --n_parties=100 \
12 | --rho=0.9 \
13 | --comm_round=500 \
14 | --partition=$partition \
15 | --beta=0.5\
16 | --device='cuda:0'\
17 | --datadir='./data/' \
18 | --logdir='./logs/' \
19 | --noise=0\
20 | --sample=0.1\
21 | --init_seed=0
22 | done
23 |
24 | for mu in 0.001 0.01 0.1 1
25 | do
26 | python experiments.py --model=simple-cnn \
27 | --dataset=cifar10 \
28 | --alg=fedprox \
29 | --lr=0.01 \
30 | --batch-size=64 \
31 | --epochs=10 \
32 | --n_parties=100 \
33 | --rho=0.9 \
34 | --mu=$mu \
35 | --comm_round=500 \
36 | --partition=$partition \
37 | --beta=0.5\
38 | --device='cuda:0'\
39 | --datadir='./data/' \
40 | --logdir='./logs/' \
41 | --noise=0\
42 | --sample=0.1\
43 | --init_seed=0
44 | done
45 | done
46 |
--------------------------------------------------------------------------------
/scripts/adult&covtype.sh:
--------------------------------------------------------------------------------
1 | for init_seed in 0 1 2
2 | do
3 | for partition in noniid-labeldir noniid-#label1 iid-diff-quantity homo
4 | do
5 | for dataset in a9a covtype
6 | do
7 | for alg in fedavg scaffold fednova
8 | do
9 | python experiments.py --model=mlp \
10 | --dataset=$dataset \
11 | --alg=$alg \
12 | --lr=0.01 \
13 | --batch-size=64 \
14 | --epochs=10 \
15 | --n_parties=10 \
16 | --rho=0.9 \
17 | --comm_round=50 \
18 | --partition=$partition \
19 | --beta=0.5\
20 | --device='cuda:0'\
21 | --datadir='./data/' \
22 | --logdir='./logs/' \
23 | --noise=0\
24 | --init_seed=$init_seed
25 | done
26 |
27 | for mu in 0.001 0.01 0.1 1
28 | do
29 | python experiments.py --model=mlp \
30 | --dataset=$dataset \
31 | --alg=fedprox \
32 | --lr=0.01 \
33 | --batch-size=64 \
34 | --epochs=10 \
35 | --n_parties=10 \
36 | --rho=0.9 \
37 | --mu=$mu \
38 | --comm_round=50 \
39 | --partition=$partition \
40 | --beta=0.5\
41 | --device='cuda:0'\
42 | --datadir='./data/' \
43 | --logdir='./logs/' \
44 | --noise=0\
45 | --init_seed=$init_seed
46 | done
47 |
48 | done
49 | done
50 | done
51 |
--------------------------------------------------------------------------------
/scripts/batch-size.sh:
--------------------------------------------------------------------------------
1 | for size in 16 32 64 128 256
2 | do
3 | for alg in fedavg fedprox scaffold fednova
4 | do
5 | python experiments.py --model=simple-cnn \
6 | --dataset=cifar10 \
7 | --alg=$alg \
8 | --lr=0.01 \
9 | --batch-size=$size \
10 | --epochs=10 \
11 | --n_parties=10 \
12 | --rho=0.9 \
13 | --mu=0.01 \
14 | --comm_round=50 \
15 | --partition=noniid-labeldir \
16 | --beta=0.5\
17 | --device='cuda:0'\
18 | --datadir='./data/' \
19 | --logdir='./logs/' \
20 | --noise=0\
21 | --sample=0\
22 | --init_seed=0
23 | done
24 | done
25 |
26 |
--------------------------------------------------------------------------------
/scripts/fcube.sh:
--------------------------------------------------------------------------------
1 | for init_seed in 0 1 2
2 | do
3 | for partition in real homo
4 | do
5 | for alg in fedavg scaffold fednova
6 | do
7 | python experiments.py --model=mlp \
8 | --dataset=generated \
9 | --alg=$alg \
10 | --lr=0.01 \
11 | --batch-size=64 \
12 | --epochs=10 \
13 | --n_parties=4 \
14 | --rho=0.9 \
15 | --comm_round=50 \
16 | --partition=$partition \
17 | --beta=0.5\
18 | --device='cuda:0'\
19 | --datadir='./data/' \
20 | --logdir='./logs/' \
21 | --noise=0\
22 | --init_seed=$init_seed
23 | done
24 |
25 | for mu in 0.001 0.01 0.1 1
26 | do
27 | python experiments.py --model=mlp \
28 | --dataset=generated \
29 | --alg=fedprox \
30 | --lr=0.01 \
31 | --batch-size=64 \
32 | --epochs=10 \
33 | --n_parties=4 \
34 | --rho=0.9 \
35 | --mu=$mu \
36 | --comm_round=50 \
37 | --partition=$partition \
38 | --beta=0.5\
39 | --device='cuda:0'\
40 | --datadir='./data/' \
41 | --logdir='./logs/' \
42 | --noise=0\
43 | --init_seed=$init_seed
44 | done
45 |
46 | done
47 | done
48 |
--------------------------------------------------------------------------------
/scripts/femnist.sh:
--------------------------------------------------------------------------------
1 | for init_seed in 0 1 2
2 | do
3 | for partition in real homo
4 | do
5 | for alg in fedavg scaffold fednova
6 | do
7 | python experiments.py --model=mlp \
8 | --dataset=femnist \
9 | --alg=$alg \
10 | --lr=0.01 \
11 | --batch-size=64 \
12 | --epochs=10 \
13 | --n_parties=10 \
14 | --rho=0.9 \
15 | --comm_round=50 \
16 | --partition=$partition \
17 | --beta=0.5\
18 | --device='cuda:0'\
19 | --datadir='./data/' \
20 | --logdir='./logs/' \
21 | --noise=0\
22 | --init_seed=$init_seed
23 | done
24 |
25 | for mu in 0.001 0.01 0.1 1
26 | do
27 | python experiments.py --model=mlp \
28 | --dataset=femnist \
29 | --alg=fedprox \
30 | --lr=0.01 \
31 | --batch-size=64 \
32 | --epochs=10 \
33 | --n_parties=10 \
34 | --rho=0.9 \
35 | --mu=$mu \
36 | --comm_round=50 \
37 | --partition=$partition \
38 | --beta=0.5\
39 | --device='cuda:0'\
40 | --datadir='./data/' \
41 | --logdir='./logs/' \
42 | --noise=0\
43 | --init_seed=$init_seed
44 | done
45 |
46 | done
47 | done
48 |
--------------------------------------------------------------------------------
/scripts/image-data-with-noise.sh:
--------------------------------------------------------------------------------
1 | for init_seed in 0 1 2
2 | do
3 | for dataset in cifar10 mnist fmnist svhn
4 | do
5 | for alg in fedavg scaffold fednova
6 | do
7 | python experiments.py --model=simple-cnn \
8 | --dataset=$dataset \
9 | --alg=$alg \
10 | --lr=0.01 \
11 | --batch-size=64 \
12 | --epochs=10 \
13 | --n_parties=10 \
14 | --rho=0.9 \
15 | --comm_round=50 \
16 | --partition=homo \
17 | --beta=0.5\
18 | --device='cuda:0'\
19 | --datadir='./data/' \
20 | --logdir='./logs/' \
21 | --noise=0.1\
22 | --init_seed=$init_seed
23 | done
24 |
25 | for mu in 0.001 0.01 0.1 1
26 | do
27 | python experiments.py --model=simple-cnn \
28 | --dataset=$dataset \
29 | --alg=fedprox \
30 | --lr=0.01 \
31 | --batch-size=64 \
32 | --epochs=10 \
33 | --n_parties=10 \
34 | --rho=0.9 \
35 | --mu=$mu \
36 | --comm_round=50 \
37 | --partition=homo \
38 | --beta=0.5\
39 | --device='cuda:0'\
40 | --datadir='./data/' \
41 | --logdir='./logs/' \
42 | --noise=0.1\
43 | --init_seed=$init_seed
44 | done
45 |
46 | done
47 | done
48 |
49 |
50 |
--------------------------------------------------------------------------------
/scripts/image-data-without-noise.sh:
--------------------------------------------------------------------------------
1 | for init_seed in 0 1 2
2 | do
3 | for partition in noniid-labeldir noniid-#label1 noniid-#label2 noniid-#label3 iid-diff-quantity homo
4 | do
5 | for dataset in cifar10 mnist fmnist svhn
6 | do
7 | for alg in fedavg scaffold fednova
8 | do
9 | python experiments.py --model=simple-cnn \
10 | --dataset=$dataset \
11 | --alg=$alg \
12 | --lr=0.01 \
13 | --batch-size=64 \
14 | --epochs=10 \
15 | --n_parties=10 \
16 | --rho=0.9 \
17 | --comm_round=50 \
18 | --partition=$partition \
19 | --beta=0.5\
20 | --device='cuda:0'\
21 | --datadir='./data/' \
22 | --logdir='./logs/' \
23 | --noise=0\
24 | --init_seed=$init_seed
25 | done
26 |
27 | for mu in 0.001 0.01 0.1 1
28 | do
29 | python experiments.py --model=simple-cnn \
30 | --dataset=$dataset \
31 | --alg=fedprox \
32 | --lr=0.01 \
33 | --batch-size=64 \
34 | --epochs=10 \
35 | --n_parties=10 \
36 | --rho=0.9 \
37 | --mu=$mu \
38 | --comm_round=50 \
39 | --partition=$partition \
40 | --beta=0.5\
41 | --device='cuda:0'\
42 | --datadir='./data/' \
43 | --logdir='./logs/' \
44 | --noise=0\
45 | --init_seed=$init_seed
46 | done
47 |
48 | done
49 | done
50 | done
51 |
52 |
53 |
--------------------------------------------------------------------------------
/scripts/rcv1.sh:
--------------------------------------------------------------------------------
1 | for init_seed in 0 1 2
2 | do
3 | for partition in noniid-labeldir noniid-#label1 iid-diff-quantity homo
4 | do
5 | for dataset in rcv1
6 | do
7 | for alg in fedavg scaffold fednova
8 | do
9 | python experiments.py --model=mlp \
10 | --dataset=$dataset \
11 | --alg=$alg \
12 | --lr=0.1 \
13 | --batch-size=64 \
14 | --epochs=10 \
15 | --n_parties=10 \
16 | --rho=0.9 \
17 | --comm_round=50 \
18 | --partition=$partition \
19 | --beta=0.5\
20 | --device='cuda:0'\
21 | --datadir='./data/' \
22 | --logdir='./logs/' \
23 | --noise=0\
24 | --init_seed=$init_seed
25 | done
26 |
27 | for mu in 0.001 0.01 0.1 1
28 | do
29 | python experiments.py --model=mlp \
30 | --dataset=$dataset \
31 | --alg=fedprox \
32 | --lr=0.1 \
33 | --batch-size=64 \
34 | --epochs=10 \
35 | --n_parties=10 \
36 | --rho=0.9 \
37 | --mu=$mu \
38 | --comm_round=50 \
39 | --partition=$partition \
40 | --beta=0.5\
41 | --device='cuda:0'\
42 | --datadir='./data/' \
43 | --logdir='./logs/' \
44 | --noise=0\
45 | --init_seed=$init_seed
46 | done
47 |
48 | done
49 | done
50 | done
51 |
--------------------------------------------------------------------------------
/scripts/vgg&resnet.sh:
--------------------------------------------------------------------------------
1 | for model in vgg resnet
2 | do
3 | for partition in noniid-labeldir iid-diff-quantity
4 | do
5 | for alg in fedavg fedprox scaffold fednova
6 | do
7 | python experiments.py --model=$model \
8 | --dataset=cifar10 \
9 | --alg=$alg \
10 | --lr=0.01 \
11 | --batch-size=64 \
12 | --epochs=10 \
13 | --n_parties=10 \
14 | --rho=0.9 \
15 | --mu=0.01 \
16 | --comm_round=100 \
17 | --partition=$partition \
18 | --beta=0.1\
19 | --device='cuda:0'\
20 | --datadir='./data/' \
21 | --logdir='./logs/' \
22 | --noise=0\
23 | --init_seed=0
24 | done
25 | done
26 |
27 | for alg in fedavg fedprox scaffold fednova
28 | do
29 | python experiments.py --model=$model \
30 | --dataset=cifar10 \
31 | --alg=$alg \
32 | --lr=0.01 \
33 | --batch-size=64 \
34 | --epochs=10 \
35 | --n_parties=10 \
36 | --rho=0.9 \
37 | --mu=0.01 \
38 | --comm_round=100 \
39 | --partition=homo \
40 | --beta=0.1\
41 | --device='cuda:0'\
42 | --datadir='./data/' \
43 | --logdir='./logs/' \
44 | --noise=0.1\
45 | --init_seed=0
46 | done
47 | done
48 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torch.utils.data as data
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | import random
10 | from sklearn.metrics import confusion_matrix
11 | from torch.utils.data import DataLoader
12 | import copy
13 |
14 | from model import *
15 | from datasets import MNIST_truncated, CIFAR10_truncated, CIFAR100_truncated, ImageFolder_custom, SVHN_custom, FashionMNIST_truncated, CustomTensorDataset, CelebA_custom, FEMNIST, Generated, genData
16 | from math import sqrt
17 |
18 | import torch.nn as nn
19 |
20 | import torch.optim as optim
21 | import torchvision.utils as vutils
22 | import time
23 | import random
24 |
25 | from models.mnist_model import Generator, Discriminator, DHead, QHead
26 | from config import params
27 | import sklearn.datasets as sk
28 | from sklearn.datasets import load_svmlight_file
29 |
30 | logging.basicConfig()
31 | logger = logging.getLogger()
32 | logger.setLevel(logging.INFO)
33 |
34 | def mkdirs(dirpath):
35 | try:
36 | os.makedirs(dirpath)
37 | except Exception as _:
38 | pass
39 |
40 | def load_mnist_data(datadir):
41 |
42 | transform = transforms.Compose([transforms.ToTensor()])
43 |
44 | mnist_train_ds = MNIST_truncated(datadir, train=True, download=True, transform=transform)
45 | mnist_test_ds = MNIST_truncated(datadir, train=False, download=True, transform=transform)
46 |
47 | X_train, y_train = mnist_train_ds.data, mnist_train_ds.target
48 | X_test, y_test = mnist_test_ds.data, mnist_test_ds.target
49 |
50 | X_train = X_train.data.numpy()
51 | y_train = y_train.data.numpy()
52 | X_test = X_test.data.numpy()
53 | y_test = y_test.data.numpy()
54 |
55 | return (X_train, y_train, X_test, y_test)
56 |
57 | def load_fmnist_data(datadir):
58 |
59 | transform = transforms.Compose([transforms.ToTensor()])
60 |
61 | mnist_train_ds = FashionMNIST_truncated(datadir, train=True, download=True, transform=transform)
62 | mnist_test_ds = FashionMNIST_truncated(datadir, train=False, download=True, transform=transform)
63 |
64 | X_train, y_train = mnist_train_ds.data, mnist_train_ds.target
65 | X_test, y_test = mnist_test_ds.data, mnist_test_ds.target
66 |
67 | X_train = X_train.data.numpy()
68 | y_train = y_train.data.numpy()
69 | X_test = X_test.data.numpy()
70 | y_test = y_test.data.numpy()
71 |
72 | return (X_train, y_train, X_test, y_test)
73 |
74 | def load_svhn_data(datadir):
75 |
76 | transform = transforms.Compose([transforms.ToTensor()])
77 |
78 | svhn_train_ds = SVHN_custom(datadir, train=True, download=True, transform=transform)
79 | svhn_test_ds = SVHN_custom(datadir, train=False, download=True, transform=transform)
80 |
81 | X_train, y_train = svhn_train_ds.data, svhn_train_ds.target
82 | X_test, y_test = svhn_test_ds.data, svhn_test_ds.target
83 |
84 | # X_train = X_train.data.numpy()
85 | # y_train = y_train.data.numpy()
86 | # X_test = X_test.data.numpy()
87 | # y_test = y_test.data.numpy()
88 |
89 | return (X_train, y_train, X_test, y_test)
90 |
91 |
92 | def load_cifar10_data(datadir):
93 |
94 | transform = transforms.Compose([transforms.ToTensor()])
95 |
96 | cifar10_train_ds = CIFAR10_truncated(datadir, train=True, download=True, transform=transform)
97 | cifar10_test_ds = CIFAR10_truncated(datadir, train=False, download=True, transform=transform)
98 |
99 | X_train, y_train = cifar10_train_ds.data, cifar10_train_ds.target
100 | X_test, y_test = cifar10_test_ds.data, cifar10_test_ds.target
101 |
102 | # y_train = y_train.numpy()
103 | # y_test = y_test.numpy()
104 |
105 | return (X_train, y_train, X_test, y_test)
106 |
107 | def load_celeba_data(datadir):
108 |
109 | transform = transforms.Compose([transforms.ToTensor()])
110 |
111 | celeba_train_ds = CelebA_custom(datadir, split='train', target_type="attr", download=True, transform=transform)
112 | celeba_test_ds = CelebA_custom(datadir, split='test', target_type="attr", download=True, transform=transform)
113 |
114 | gender_index = celeba_train_ds.attr_names.index('Male')
115 | y_train = celeba_train_ds.attr[:,gender_index:gender_index+1].reshape(-1)
116 | y_test = celeba_test_ds.attr[:,gender_index:gender_index+1].reshape(-1)
117 |
118 | # y_train = y_train.numpy()
119 | # y_test = y_test.numpy()
120 |
121 | return (None, y_train, None, y_test)
122 |
123 | def load_femnist_data(datadir):
124 | transform = transforms.Compose([transforms.ToTensor()])
125 |
126 | mnist_train_ds = FEMNIST(datadir, train=True, transform=transform, download=True)
127 | mnist_test_ds = FEMNIST(datadir, train=False, transform=transform, download=True)
128 |
129 | X_train, y_train, u_train = mnist_train_ds.data, mnist_train_ds.targets, mnist_train_ds.users_index
130 | X_test, y_test, u_test = mnist_test_ds.data, mnist_test_ds.targets, mnist_test_ds.users_index
131 |
132 | X_train = X_train.data.numpy()
133 | y_train = y_train.data.numpy()
134 | u_train = np.array(u_train)
135 | X_test = X_test.data.numpy()
136 | y_test = y_test.data.numpy()
137 | u_test = np.array(u_test)
138 |
139 | return (X_train, y_train, u_train, X_test, y_test, u_test)
140 |
141 | def load_cifar100_data(datadir):
142 | transform = transforms.Compose([transforms.ToTensor()])
143 |
144 | cifar100_train_ds = CIFAR100_truncated(datadir, train=True, download=True, transform=transform)
145 | cifar100_test_ds = CIFAR100_truncated(datadir, train=False, download=True, transform=transform)
146 |
147 | X_train, y_train = cifar100_train_ds.data, cifar100_train_ds.target
148 | X_test, y_test = cifar100_test_ds.data, cifar100_test_ds.target
149 |
150 | # y_train = y_train.numpy()
151 | # y_test = y_test.numpy()
152 |
153 | return (X_train, y_train, X_test, y_test)
154 |
155 |
156 | def load_tinyimagenet_data(datadir):
157 | transform = transforms.Compose([transforms.ToTensor()])
158 | xray_train_ds = ImageFolder_custom(datadir+'./train/', transform=transform)
159 | xray_test_ds = ImageFolder_custom(datadir+'./val/', transform=transform)
160 |
161 | X_train, y_train = np.array([s[0] for s in xray_train_ds.samples]), np.array([int(s[1]) for s in xray_train_ds.samples])
162 | X_test, y_test = np.array([s[0] for s in xray_test_ds.samples]), np.array([int(s[1]) for s in xray_test_ds.samples])
163 |
164 | return (X_train, y_train, X_test, y_test)
165 |
166 | def record_net_data_stats(y_train, net_dataidx_map, logdir):
167 |
168 | net_cls_counts = {}
169 |
170 | for net_i, dataidx in net_dataidx_map.items():
171 | unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True)
172 | tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
173 | net_cls_counts[net_i] = tmp
174 |
175 | logger.info('Data statistics: %s' % str(net_cls_counts))
176 |
177 | return net_cls_counts
178 |
179 | def partition_data(dataset, datadir, logdir, partition, n_parties, beta=0.4):
180 | #np.random.seed(2020)
181 | #torch.manual_seed(2020)
182 |
183 | if dataset == 'mnist':
184 | X_train, y_train, X_test, y_test = load_mnist_data(datadir)
185 | elif dataset == 'fmnist':
186 | X_train, y_train, X_test, y_test = load_fmnist_data(datadir)
187 | elif dataset == 'cifar10':
188 | X_train, y_train, X_test, y_test = load_cifar10_data(datadir)
189 | elif dataset == 'svhn':
190 | X_train, y_train, X_test, y_test = load_svhn_data(datadir)
191 | elif dataset == 'celeba':
192 | X_train, y_train, X_test, y_test = load_celeba_data(datadir)
193 | elif dataset == 'femnist':
194 | X_train, y_train, u_train, X_test, y_test, u_test = load_femnist_data(datadir)
195 | elif dataset == 'cifar100':
196 | X_train, y_train, X_test, y_test = load_cifar100_data(datadir)
197 | elif dataset == 'tinyimagenet':
198 | X_train, y_train, X_test, y_test = load_tinyimagenet_data(datadir)
199 | elif dataset == 'generated': # synthetic 3-D two-class dataset (see scripts/fcube.sh); the label follows the sign of the first coordinate
200 | X_train, y_train = [], []
201 | for loc in range(4):
202 | for i in range(1000):
203 | p1 = random.random()
204 | p2 = random.random()
205 | p3 = random.random()
206 | if loc > 1:
207 | p2 = -p2
208 | if loc % 2 ==1:
209 | p3 = -p3
210 | if i % 2 == 0:
211 | X_train.append([p1, p2, p3])
212 | y_train.append(0)
213 | else:
214 | X_train.append([-p1, -p2, -p3])
215 | y_train.append(1)
216 | X_test, y_test = [], []
217 | for i in range(1000):
218 | p1 = random.random() * 2 - 1
219 | p2 = random.random() * 2 - 1
220 | p3 = random.random() * 2 - 1
221 | X_test.append([p1, p2, p3])
222 | if p1>0:
223 | y_test.append(0)
224 | else:
225 | y_test.append(1)
226 | X_train = np.array(X_train, dtype=np.float32)
227 | X_test = np.array(X_test, dtype=np.float32)
228 | y_train = np.array(y_train, dtype=np.int32)
229 | y_test = np.array(y_test, dtype=np.int64)
230 | idxs = np.linspace(0,3999,4000,dtype=np.int64)
231 | batch_idxs = np.array_split(idxs, n_parties)
232 | net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)}
233 | mkdirs("data/generated/")
234 | np.save("data/generated/X_train.npy",X_train)
235 | np.save("data/generated/X_test.npy",X_test)
236 | np.save("data/generated/y_train.npy",y_train)
237 | np.save("data/generated/y_test.npy",y_test)
238 |
239 | #elif dataset == 'covtype':
240 | # cov_type = sk.fetch_covtype('./data')
241 | # num_train = int(581012 * 0.75)
242 | # idxs = np.random.permutation(581012)
243 | # X_train = np.array(cov_type['data'][idxs[:num_train]], dtype=np.float32)
244 | # y_train = np.array(cov_type['target'][idxs[:num_train]], dtype=np.int32) - 1
245 | # X_test = np.array(cov_type['data'][idxs[num_train:]], dtype=np.float32)
246 | # y_test = np.array(cov_type['target'][idxs[num_train:]], dtype=np.int32) - 1
247 | # mkdirs("data/generated/")
248 | # np.save("data/generated/X_train.npy",X_train)
249 | # np.save("data/generated/X_test.npy",X_test)
250 | # np.save("data/generated/y_train.npy",y_train)
251 | # np.save("data/generated/y_test.npy",y_test)
252 |
253 | elif dataset in ('rcv1', 'SUSY', 'covtype'):
254 | X_train, y_train = load_svmlight_file(datadir+dataset)
255 | X_train = X_train.todense()
256 | num_train = int(X_train.shape[0] * 0.75)
257 | if dataset == 'covtype':
258 | y_train = y_train-1
259 | else:
260 | y_train = (y_train+1)/2
261 | idxs = np.random.permutation(X_train.shape[0])
262 |
263 | X_test = np.array(X_train[idxs[num_train:]], dtype=np.float32)
264 | y_test = np.array(y_train[idxs[num_train:]], dtype=np.int32)
265 | X_train = np.array(X_train[idxs[:num_train]], dtype=np.float32)
266 | y_train = np.array(y_train[idxs[:num_train]], dtype=np.int32)
267 |
268 | mkdirs("data/generated/")
269 | np.save("data/generated/X_train.npy",X_train)
270 | np.save("data/generated/X_test.npy",X_test)
271 | np.save("data/generated/y_train.npy",y_train)
272 | np.save("data/generated/y_test.npy",y_test)
273 |
274 | elif dataset == 'a9a':
275 | X_train, y_train = load_svmlight_file(datadir+"a9a")
276 | X_test, y_test = load_svmlight_file(datadir+"a9a.t")
277 | X_train = X_train.todense()
278 | X_test = X_test.todense()
279 | X_test = np.c_[X_test, np.zeros((len(y_test), X_train.shape[1] - np.size(X_test[0, :])))] # zero-pad test features to the train dimensionality
280 |
281 | X_train = np.array(X_train, dtype=np.float32)
282 | X_test = np.array(X_test, dtype=np.float32)
283 | y_train = (y_train+1)/2
284 | y_test = (y_test+1)/2
285 | y_train = np.array(y_train, dtype=np.int32)
286 | y_test = np.array(y_test, dtype=np.int32)
287 |
288 | mkdirs("data/generated/")
289 | np.save("data/generated/X_train.npy",X_train)
290 | np.save("data/generated/X_test.npy",X_test)
291 | np.save("data/generated/y_train.npy",y_train)
292 | np.save("data/generated/y_test.npy",y_test)
293 |
294 |
295 | n_train = y_train.shape[0]
296 |
297 | if partition == "homo":
298 | idxs = np.random.permutation(n_train)
299 | batch_idxs = np.array_split(idxs, n_parties)
300 | net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)}
301 |
302 |
303 | elif partition == "noniid-labeldir":
304 | min_size = 0
305 | min_require_size = 10
306 | K = 10
307 | if dataset in ('celeba', 'covtype', 'a9a', 'rcv1', 'SUSY'):
308 | K = 2
309 | # min_require_size = 100
310 | if dataset == 'cifar100':
311 | K = 100
312 | elif dataset == 'tinyimagenet':
313 | K = 200
314 |
315 | N = y_train.shape[0]
316 | #np.random.seed(2020)
317 | net_dataidx_map = {}
318 |
319 | while min_size < min_require_size:
320 | idx_batch = [[] for _ in range(n_parties)]
321 | for k in range(K):
322 | idx_k = np.where(y_train == k)[0]
323 | np.random.shuffle(idx_k)
324 | proportions = np.random.dirichlet(np.repeat(beta, n_parties))
325 | # proportions: one Dirichlet(beta) sample over parties for class k;
326 | # smaller beta -> more skewed label distribution
327 | ## Balance: zero out parties that already hold at least an equal share
328 | proportions = np.array([p * (len(idx_j) < N / n_parties) for p, idx_j in zip(proportions, idx_batch)])
329 | # renormalize the remaining weights
330 | proportions = proportions / proportions.sum()
331 | # turn cumulative proportions into split points over the shuffled idx_k
332 | proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
333 | # append each party's slice of class k
334 | idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
335 | min_size = min([len(idx_j) for idx_j in idx_batch])
336 | # if K == 2 and n_parties <= 10:
337 | # if np.min(proportions) < 200:
338 | # min_size = 0
339 | # break
340 |
341 |
342 | for j in range(n_parties):
343 | np.random.shuffle(idx_batch[j])
344 | net_dataidx_map[j] = idx_batch[j]
345 |
346 | elif partition > "noniid-#label0" and partition <= "noniid-#label9": # lexicographic match for "noniid-#labelX": each party gets exactly X labels
347 | num = int(partition[13:])
348 | if dataset in ('celeba', 'covtype', 'a9a', 'rcv1', 'SUSY'):
349 | num = 1
350 | K = 2
351 | else:
352 | K = 10
353 | if dataset == "cifar100":
354 | K = 100
355 | elif dataset == "tinyimagenet":
356 | K = 200
357 | if num == 10:
358 | net_dataidx_map ={i:np.ndarray(0,dtype=np.int64) for i in range(n_parties)}
359 | for i in range(10):
360 | idx_k = np.where(y_train==i)[0]
361 | np.random.shuffle(idx_k)
362 | split = np.array_split(idx_k,n_parties)
363 | for j in range(n_parties):
364 | net_dataidx_map[j]=np.append(net_dataidx_map[j],split[j])
365 | else:
366 | times=[0 for i in range(K)]   # times[k]: how many parties hold label k
367 | contain=[]                    # contain[i]: the labels assigned to party i
368 | for i in range(n_parties):
369 | current=[i%K]                 # seed party i with label i%K so every label is used
370 | times[i%K]+=1
371 | j=1
372 | while (j<num):
507 | if stat[i,j]>0:
508 | check[j]=1
509 | flag=False
510 | for i in range(10):
511 | if check[i]==0:
512 | flag=True
513 | break
514 |
515 |
516 | if dataset in ('celeba', 'covtype', 'a9a', 'rcv1', 'SUSY'):
517 | K = 2
518 | stat[:,0]=np.sum(stat[:,:5],axis=1)
519 | stat[:,1]=np.sum(stat[:,5:],axis=1)
520 | else:
521 | K = 10
522 |
523 | N = y_train.shape[0]
524 | #np.random.seed(2020)
525 | net_dataidx_map = {}
526 |
527 | idx_batch = [[] for _ in range(n_parties)]
528 | for k in range(K):
529 | idx_k = np.where(y_train == k)[0]
530 | np.random.shuffle(idx_k)
531 | proportions = stat[:,k]
532 | # logger.info("proportions2: ", proportions)
533 | proportions = proportions / proportions.sum()
534 | # logger.info("proportions3: ", proportions)
535 | proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
536 | # logger.info("proportions4: ", proportions)
537 | idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
538 |
539 |
540 | for j in range(n_parties):
541 | np.random.shuffle(idx_batch[j])
542 | net_dataidx_map[j] = idx_batch[j]
543 |
544 | traindata_cls_counts = record_net_data_stats(y_train, net_dataidx_map, logdir)
545 | return (X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts)
546 |
547 |
548 | def get_trainable_parameters(net):
549 | 'return trainable parameter values as a vector (only the first parameter set)'
550 | trainable=filter(lambda p: p.requires_grad, net.parameters())
551 | # logger.info("net.parameter.data:", list(net.parameters()))
552 | paramlist=list(trainable)
553 | N=0
554 | for params in paramlist:
555 | N+=params.numel()
556 | # logger.info("params.data:", params.data)
557 | X=torch.empty(N,dtype=torch.float64)
558 | X.fill_(0.0)
559 | offset=0
560 | for params in paramlist:
561 | numel=params.numel()
562 | with torch.no_grad():
563 | X[offset:offset+numel].copy_(params.data.view_as(X[offset:offset+numel].data))
564 | offset+=numel
565 | # logger.info("get trainable x:", X)
566 | return X
567 |
568 |
569 | def put_trainable_parameters(net,X):
570 | 'replace trainable parameter values by the given vector (only the first parameter set)'
571 | trainable=filter(lambda p: p.requires_grad, net.parameters())
572 | paramlist=list(trainable)
573 | offset=0
574 | for params in paramlist:
575 | numel=params.numel()
576 | with torch.no_grad():
577 | params.data.copy_(X[offset:offset+numel].data.view_as(params.data))
578 | offset+=numel
579 |
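# Round-trip usage sketch for the two helpers above (illustrative, not repo code):
#     vec = get_trainable_parameters(net)    # flat float64 vector of every trainable tensor
#     vec = vec * 0.5                        # any manipulation in flat-vector space
#     put_trainable_parameters(net, vec)     # copy the modified values back into the net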
580 | def compute_accuracy(model, dataloader, get_confusion_matrix=False, moon_model=False, device="cpu"):
581 |
582 | was_training = False
583 | if model.training:
584 | model.eval()
585 | was_training = True
586 |
587 | true_labels_list, pred_labels_list = np.array([]), np.array([])
588 |
589 | # accept either a single dataloader or a list of dataloaders
590 | if not isinstance(dataloader, list):
591 | dataloader = [dataloader]
592 |
593 |
594 | correct, total = 0, 0
595 | with torch.no_grad():
596 | for tmp in dataloader:
597 | for batch_idx, (x, target) in enumerate(tmp):
598 | x, target = x.to(device), target.to(device,dtype=torch.int64)
599 | if moon_model:
600 | _, _, out = model(x)
601 | else:
602 | out = model(x)
603 | _, pred_label = torch.max(out.data, 1)
604 |
605 | total += x.data.size()[0]
606 | correct += (pred_label == target.data).sum().item()
607 |
608 | if device == "cpu":
609 | pred_labels_list = np.append(pred_labels_list, pred_label.numpy())
610 | true_labels_list = np.append(true_labels_list, target.data.numpy())
611 | else:
612 | pred_labels_list = np.append(pred_labels_list, pred_label.cpu().numpy())
613 | true_labels_list = np.append(true_labels_list, target.data.cpu().numpy())
614 |
615 | if get_confusion_matrix:
616 | conf_matrix = confusion_matrix(true_labels_list, pred_labels_list)
617 |
618 | if was_training:
619 | model.train()
620 |
621 | if get_confusion_matrix:
622 | return correct/float(total), conf_matrix
623 |
624 | return correct/float(total)
625 |
626 |
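# Illustrative sketch (not part of the original file): a typical evaluation call,
# assuming test_dl comes from get_dataloader below.
def _example_evaluate(model, test_dl, device="cpu"):
    acc, conf = compute_accuracy(model, test_dl, get_confusion_matrix=True, device=device)
    logger.info("test accuracy: {}".format(acc))
    return acc, conf
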
627 | def save_model(model, model_index, args):
628 | logger.info("saving local model-{}".format(model_index))
629 | with open(args.modeldir+"trained_local_model"+str(model_index), "wb") as f_:
630 | torch.save(model.state_dict(), f_)
631 | return
632 |
633 | def load_model(model, model_index, device="cpu"):
634 |     # load a checkpoint produced by save_model (reads from the working directory)
635 | with open("trained_local_model"+str(model_index), "rb") as f_:
636 | model.load_state_dict(torch.load(f_))
637 | model.to(device)
638 | return model
639 |
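# Illustrative sketch (not part of the original file): note the asymmetry above:
# save_model writes into args.modeldir while load_model reads from the current
# working directory, so this hypothetical round trip assumes args.modeldir == './'.
def _example_save_load_roundtrip(model, args):
    save_model(model, 0, args)                  # writes args.modeldir + 'trained_local_model0'
    return load_model(model, 0, device="cpu")   # reads './trained_local_model0'
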
640 | class AddGaussianNoise(object):
641 | def __init__(self, mean=0., std=1., net_id=None, total=0):
642 | self.std = std
643 | self.mean = mean
644 | self.net_id = net_id
645 | self.num = int(sqrt(total))
646 | if self.num * self.num < total:
647 | self.num = self.num + 1
648 |
649 | def __call__(self, tensor):
650 | if self.net_id is None:
651 | return tensor + torch.randn(tensor.size()) * self.std + self.mean
652 | else:
653 | tmp = torch.randn(tensor.size())
654 | filt = torch.zeros(tensor.size())
655 |             size = 28 // self.num                 # side length of each square patch (28x28 input)
656 |             row = self.net_id // self.num         # patch row in the num x num grid
657 |             col = self.net_id % self.num          # patch column in the num x num grid
658 |             # keep noise only inside this party's patch
659 |             filt[:, row * size:(row + 1) * size,
660 |                     col * size:(col + 1) * size] = 1
661 | tmp = tmp * filt
662 | return tensor + tmp * self.std + self.mean
663 |
664 | def __repr__(self):
665 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
666 |
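# Illustrative sketch (not part of the original file): with `total` parties, a
# 28x28 image is tiled into a num x num grid of patches, and each net_id adds
# Gaussian noise only inside its own patch (party-specific feature skew).
def _example_noise_patch():
    t = AddGaussianNoise(mean=0., std=1., net_id=3, total=9)  # num = 3, patch side 9
    noisy = t(torch.zeros(1, 28, 28))
    return (noisy != 0).nonzero()  # indices fall inside party 3's 9x9 patch
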
667 | def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None, noise_level=0, net_id=None, total=0):
668 | if dataset in ('mnist', 'femnist', 'fmnist', 'cifar10', 'svhn', 'generated', 'covtype', 'a9a', 'rcv1', 'SUSY', 'cifar100', 'tinyimagenet'):
669 | if dataset == 'mnist':
670 | dl_obj = MNIST_truncated
671 |
672 | transform_train = transforms.Compose([
673 | transforms.ToTensor(),
674 | AddGaussianNoise(0., noise_level, net_id, total)])
675 |
676 | transform_test = transforms.Compose([
677 | transforms.ToTensor(),
678 | AddGaussianNoise(0., noise_level, net_id, total)])
679 |
680 | elif dataset == 'femnist':
681 | dl_obj = FEMNIST
682 | transform_train = transforms.Compose([
683 | transforms.ToTensor(),
684 | AddGaussianNoise(0., noise_level, net_id, total)])
685 | transform_test = transforms.Compose([
686 | transforms.ToTensor(),
687 | AddGaussianNoise(0., noise_level, net_id, total)])
688 |
689 | elif dataset == 'fmnist':
690 | dl_obj = FashionMNIST_truncated
691 | transform_train = transforms.Compose([
692 | transforms.ToTensor(),
693 | AddGaussianNoise(0., noise_level, net_id, total)])
694 | transform_test = transforms.Compose([
695 | transforms.ToTensor(),
696 | AddGaussianNoise(0., noise_level, net_id, total)])
697 |
698 | elif dataset == 'svhn':
699 | dl_obj = SVHN_custom
700 | transform_train = transforms.Compose([
701 | transforms.ToTensor(),
702 | AddGaussianNoise(0., noise_level, net_id, total)])
703 | transform_test = transforms.Compose([
704 | transforms.ToTensor(),
705 | AddGaussianNoise(0., noise_level, net_id, total)])
706 |
707 |
708 | elif dataset == 'cifar10':
709 | dl_obj = CIFAR10_truncated
710 |
711 | transform_train = transforms.Compose([
712 | transforms.ToTensor(),
713 |             transforms.Lambda(lambda x: F.pad(
714 |                 x.unsqueeze(0),  # the Variable wrapper is a no-op since torch 0.4
715 |                 (4, 4, 4, 4), mode='reflect').squeeze()),
716 | transforms.ToPILImage(),
717 | transforms.RandomCrop(32),
718 | transforms.RandomHorizontalFlip(),
719 | transforms.ToTensor(),
720 | AddGaussianNoise(0., noise_level, net_id, total)
721 | ])
722 | # data prep for test set
723 | transform_test = transforms.Compose([
724 | transforms.ToTensor(),
725 | AddGaussianNoise(0., noise_level, net_id, total)])
726 |
727 | elif dataset == 'cifar100':
728 | dl_obj = CIFAR100_truncated
729 |
730 | normalize = transforms.Normalize(mean=[0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
731 | std=[0.2673342858792401, 0.2564384629170883, 0.27615047132568404])
732 | # transform_train = transforms.Compose([
733 | # transforms.RandomCrop(32),
734 | # transforms.RandomHorizontalFlip(),
735 | # transforms.ToTensor(),
736 | # normalize
737 | # ])
738 | transform_train = transforms.Compose([
739 | # transforms.ToPILImage(),
740 | transforms.RandomCrop(32, padding=4),
741 | transforms.RandomHorizontalFlip(),
742 | transforms.RandomRotation(15),
743 | transforms.ToTensor(),
744 | normalize
745 | ])
746 | # data prep for test set
747 | transform_test = transforms.Compose([
748 | transforms.ToTensor(),
749 | normalize])
750 | elif dataset == 'tinyimagenet':
751 | dl_obj = ImageFolder_custom
752 | transform_train = transforms.Compose([
753 | transforms.Resize(32),
754 | transforms.RandomCrop(32, padding=4),
755 | transforms.RandomHorizontalFlip(),
756 | transforms.RandomRotation(15),
757 | transforms.ToTensor(),
758 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
759 | ])
760 | transform_test = transforms.Compose([
761 | transforms.Resize(32),
762 | transforms.ToTensor(),
763 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
764 | ])
765 |
766 | else:
767 | dl_obj = Generated
768 | transform_train = None
769 | transform_test = None
770 |
771 |
772 | if dataset == "tinyimagenet":
773 |         train_ds = dl_obj(datadir + 'train/', dataidxs=dataidxs, transform=transform_train)  # datadir is expected to end with '/'
774 |         test_ds = dl_obj(datadir + 'val/', transform=transform_test)
775 | else:
776 | train_ds = dl_obj(datadir, dataidxs=dataidxs, train=True, transform=transform_train, download=True)
777 | test_ds = dl_obj(datadir, train=False, transform=transform_test, download=True)
778 |
779 | train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=False)
780 | test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=False)
781 |
782 | return train_dl, test_dl, train_ds, test_ds
783 |
784 |
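# Illustrative sketch (not part of the original file): a hypothetical call that
# builds party-0 loaders over its own index subset with party-specific noise.
def _example_party_loader(net_dataidx_map):
    train_dl, test_dl, _, _ = get_dataloader(
        'mnist', './data/', train_bs=64, test_bs=32,
        dataidxs=net_dataidx_map[0], noise_level=0.1, net_id=0, total=10)
    return train_dl, test_dl
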
785 | def weights_init(m):
786 | """
787 | Initialise weights of the model.
788 | """
789 | if(type(m) == nn.ConvTranspose2d or type(m) == nn.Conv2d):
790 | nn.init.normal_(m.weight.data, 0.0, 0.02)
791 | elif(type(m) == nn.BatchNorm2d):
792 | nn.init.normal_(m.weight.data, 1.0, 0.02)
793 | nn.init.constant_(m.bias.data, 0)
794 |
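# Illustrative sketch (not part of the original file): weights_init is written
# for nn.Module.apply, which visits every submodule recursively, giving the
# DCGAN-style initialisation (N(0, 0.02) conv weights, N(1, 0.02) BN scales).
def _example_apply_init(net):
    net.apply(weights_init)
    return net
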
795 | class NormalNLLLoss:
796 | """
797 |     Calculate the negative log-likelihood of a normal distribution,
798 |     treating Q(c_j | x) as a factored Gaussian.
799 |
800 |     This quantity is to be minimised.
801 |
802 | """
803 | def __call__(self, x, mu, var):
804 |
805 | logli = -0.5 * (var.mul(2 * np.pi) + 1e-6).log() - (x - mu).pow(2).div(var.mul(2.0) + 1e-6)
806 | nll = -(logli.sum(1).mean())
807 |
808 | return nll
809 |
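# Illustrative sketch (not part of the original file): up to the 1e-6
# stabilisers, the log-term above is the factored-Gaussian log-density, so the
# loss should agree with torch.distributions.Normal to within ~1e-5.
def _example_normal_nll_check():
    x, mu, var = torch.randn(4, 3), torch.zeros(4, 3), torch.ones(4, 3)
    nll = NormalNLLLoss()(x, mu, var)
    ref = -torch.distributions.Normal(mu, var.sqrt()).log_prob(x).sum(1).mean()
    return nll, ref
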
810 |
811 | def noise_sample(choice, n_dis_c, dis_c_dim, n_con_c, n_z, batch_size, device):
812 | """
813 | Sample random noise vector for training.
814 |
815 | INPUT
816 | --------
817 |     n_dis_c : Number of discrete latent codes.
818 |     dis_c_dim : Dimension of each discrete latent code.
819 |     n_con_c : Number of continuous latent codes.
820 |     n_z : Dimension of the incompressible noise part z.
821 |     batch_size : Batch size.
822 |     device : GPU/CPU.
823 | """
824 |
825 | z = torch.randn(batch_size, n_z, 1, 1, device=device)
826 |     idx = np.zeros((n_dis_c, batch_size), dtype=int)  # integer class choices, used as indices below
827 | if(n_dis_c != 0):
828 | dis_c = torch.zeros(batch_size, n_dis_c, dis_c_dim, device=device)
829 |
830 | c_tmp = np.array(choice)
831 |
832 | for i in range(n_dis_c):
833 | idx[i] = np.random.randint(len(choice), size=batch_size)
834 | for j in range(batch_size):
835 | idx[i][j] = c_tmp[int(idx[i][j])]
836 |
837 | dis_c[torch.arange(0, batch_size), i, idx[i]] = 1.0
838 |
839 | dis_c = dis_c.view(batch_size, -1, 1, 1)
840 |
841 | if(n_con_c != 0):
842 | # Random uniform between -1 and 1.
843 | con_c = torch.rand(batch_size, n_con_c, 1, 1, device=device) * 2 - 1
844 |
845 | noise = z
846 | if(n_dis_c != 0):
847 | noise = torch.cat((z, dis_c), dim=1)
848 | if(n_con_c != 0):
849 | noise = torch.cat((noise, con_c), dim=1)
850 |
851 | return noise, idx
852 |
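# Illustrative sketch (not part of the original file): an InfoGAN-style noise
# vector with one 10-way discrete code and two continuous codes; the channels
# concatenate as z (62) + one-hot code (10) + continuous codes (2) = 74.
def _example_noise_sample():
    noise, idx = noise_sample(choice=list(range(10)), n_dis_c=1, dis_c_dim=10,
                              n_con_c=2, n_z=62, batch_size=8, device='cpu')
    return noise.shape  # torch.Size([8, 74, 1, 1])
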
853 |
854 |
--------------------------------------------------------------------------------
/vggmodel.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 |
6 | __all__ = [
7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
8 | 'vgg19_bn', 'vgg19',
9 | ]
10 |
11 |
12 | class VGG(nn.Module):
13 | '''
14 | VGG model
15 | '''
16 | def __init__(self, features):
17 | super(VGG, self).__init__()
18 | self.features = features
19 | self.classifier = nn.Sequential(
20 | nn.Dropout(),
21 | nn.Linear(512, 512),
22 | nn.ReLU(True),
23 | nn.Dropout(),
24 | nn.Linear(512, 512),
25 | nn.ReLU(True),
26 | nn.Linear(512, 10),
27 | )
28 | # Initialize weights
29 | for m in self.modules():
30 | if isinstance(m, nn.Conv2d):
31 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
32 | m.weight.data.normal_(0, math.sqrt(2. / n))
33 | m.bias.data.zero_()
34 |
35 |
36 | def forward(self, x):
37 | x = self.features(x)
38 | x = x.view(x.size(0), -1)
39 | x = self.classifier(x)
40 | return x
41 |
42 |
43 | def make_layers(cfg, batch_norm=False):
44 | layers = []
45 | in_channels = 3
46 | for v in cfg:
47 | if v == 'M':
48 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
49 | else:
50 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
51 | if batch_norm:
52 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
53 | else:
54 | layers += [conv2d, nn.ReLU(inplace=True)]
55 | in_channels = v
56 | return nn.Sequential(*layers)
57 |
58 |
59 | cfg = {
60 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
61 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
62 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
63 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
64 | 512, 512, 512, 512, 'M'],
65 | }
66 |
67 |
68 | def vgg11():
69 | """VGG 11-layer model (configuration "A")"""
70 | return VGG(make_layers(cfg['A']))
71 |
72 |
73 | def vgg11_bn():
74 | """VGG 11-layer model (configuration "A") with batch normalization"""
75 | return VGG(make_layers(cfg['A'], batch_norm=True))
76 |
77 |
78 | def vgg13():
79 | """VGG 13-layer model (configuration "B")"""
80 | return VGG(make_layers(cfg['B']))
81 |
82 |
83 | def vgg13_bn():
84 | """VGG 13-layer model (configuration "B") with batch normalization"""
85 | return VGG(make_layers(cfg['B'], batch_norm=True))
86 |
87 |
88 | def vgg16():
89 | """VGG 16-layer model (configuration "D")"""
90 | return VGG(make_layers(cfg['D']))
91 |
92 |
93 | def vgg16_bn():
94 | """VGG 16-layer model (configuration "D") with batch normalization"""
95 | return VGG(make_layers(cfg['D'], batch_norm=True))
96 |
97 |
98 | def vgg19():
99 | """VGG 19-layer model (configuration "E")"""
100 | return VGG(make_layers(cfg['E']))
101 |
102 |
103 | def vgg19_bn():
104 |     """VGG 19-layer model (configuration "E") with batch normalization"""
105 | return VGG(make_layers(cfg['E'], batch_norm=True))
106 |
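# Illustrative sketch (not part of the original file): these models assume
# 32x32 inputs, since the five 2x2 poolings reduce 32 -> 1 and the classifier
# expects a flat 512-vector.
def _example_smoke_test():
    import torch  # local import; the module itself only needs math/nn/init
    net = vgg11()
    out = net(torch.randn(2, 3, 32, 32))
    return out.shape  # torch.Size([2, 10])
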
--------------------------------------------------------------------------------