├── LICENSE ├── README.md ├── config.py ├── criteo-dis.npy ├── datasets.py ├── experiments.py ├── femnist-dis.npy ├── figures ├── 100parties │ ├── cifar10-homo.png │ ├── cifar10-lb2.png │ ├── cifar10-lbdir.png │ └── cifar10-quan.png ├── 10parties │ ├── cifar10-iid-diff-quantity.png │ ├── cifar10-noise.png │ ├── cifar10-noniid-label2.png │ └── cifar10-noniid-labeldir.png └── heavy-model │ ├── resnet-noise.png │ └── vgg-lbdir.png ├── model.py ├── models ├── celeba_model.py ├── mnist_model.py └── svhn_model.py ├── partition.py ├── partition_to_file.sh ├── requirements.txt ├── resnetcifar.py ├── run.sh ├── scripts ├── 100-parties.sh ├── adult&covtype.sh ├── batch-size.sh ├── fcube.sh ├── femnist.sh ├── image-data-with-noise.sh ├── image-data-without-noise.sh ├── rcv1.sh └── vgg&resnet.sh ├── utils.py └── vggmodel.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yiqun Diao, Qinbin Li 4 | 5 | Copyright (c) 2020 International Business Machines 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NIID-Bench 2 | 3 | [![paper](https://img.shields.io/badge/PAPER-arXiv-yellowgreen?style=for-the-badge)](https://arxiv.org/pdf/2102.02079.pdf) 4 |     5 | [![paper](https://img.shields.io/badge/leaderboard-5%2B%20Methods-228c22?style=for-the-badge)](https://niidbench.xtra.science/) 6 |     7 | 8 | This is the code of paper [Federated Learning on Non-IID Data Silos: An Experimental Study](https://arxiv.org/pdf/2102.02079.pdf). 9 | 10 | 11 | This code runs a benchmark for federated learning algorithms under non-IID data distribution scenarios. Specifically, we implement 4 federated learning algorithms (FedAvg, FedProx, SCAFFOLD & FedNova), 3 types of non-IID settings (label distribution skew, feature distribution skew & quantity skew) and 9 datasets (MNIST, Cifar-10, Fashion-MNIST, SVHN, Generated 3D dataset, FEMNIST, adult, rcv1, covtype). 12 | 13 | 14 | ## Updates on NIID-Bench 15 | Our follow-up works based on NIID-Bench: 16 | 17 | * [FedOV](https://github.com/Xtra-Computing/FedOV): Towards Addressing Label Skews in One-Shot Federated Learning (ICLR 2023) 18 | 19 | * [FedConcat](https://github.com/sjtudyq/FedConcat): Exploiting Label Skew in Federated Learning with Model Concatenation (AAAI 2024) 20 | 21 | We publish NIID-Bench challenge https://niidbench.xtra.science, a benchmark to compare federated learning algorithms on comprehensive non-IID data settings. Researchers are welcome to test their algorithms on these settings, upload their codes and participate in our leaderboard! 
22 | 23 | We implement `partition.py` to divide tabular datasets (csv format) into multiple files using our non-IID partitioning strategies. Column `Class` in the header is recognized as the label. See a running example in `partition_to_file.sh`. The example dataset is [Credit Card Fraud Detection](https://www.kaggle.com/datasets/mlg-ulb/creditcardfraud). 24 | 25 | To adapt to your own tabular dataset in ``partition.py``, you need the following steps: 26 | 27 | 1. Load your own dataset in arrays. Replace Lines 117-126. 28 | 2. The whole tabular dataset is stored in ``dataset``. The label column ID is stored in ``class_id``. Change Line 130 to your own label identifier. 29 | 30 | If your dataset is an image dataset, ``partition.py`` is no longer applicable. You can refer to our function ``partition_data`` in ``utils.py``. You need to design your own dataloader like Lines 183-198. For example, in load_mnist_data (Line 40), you need to write a dataloader to return your dataset as a tuple (X_train, y_train, X_test, y_test). In terms of the dataloader format, you can refer to class ``MNIST_truncated`` (Line 60 in ``datasets.py``). After you get (X_train, y_train, X_test, y_test), the ``partition_data`` function will return the ``net_dataidx_map``. 31 | 32 | To support more settings and facilitate future research, we now integrate MOON. We add CIFAR-100 and Tiny-ImageNet. 33 | 34 | ### Tiny-ImageNet 35 | You can download Tiny-ImageNet [here](http://cs231n.stanford.edu/tiny-imagenet-200.zip). Then, you can follow the [instructions](https://github.com/AI-secure/DBA/blob/master/utils/tinyimagenet_reformat.py) to reformat the validation folder. 36 | 37 | ## Non-IID Settings 38 | ### Label Distribution Skew 39 | * **Quantity-based label imbalance**: each party owns data samples of a fixed number of labels. 40 | * **Distribution-based label imbalance**: each party is allocated a proportion of the samples of each label according to Dirichlet distribution.
41 | ### Feature Distribution Skew 42 | * **Noise-based feature imbalance**: We first divide the whole dataset into multiple parties randomly and equally. For each party, we add different levels of Gaussian noise. 43 | * **Synthetic feature imbalance**: For the generated 3D dataset, we allocate two parts which are symmetric about (0,0,0) to a subset for each party. 44 | * **Real-world feature imbalance**: For FEMNIST, we divide and assign the writers (and their characters) into each party randomly and equally. 45 | ### Quantity Skew 46 | * While the data distribution may still be consistent among the parties, the size of the local dataset varies according to a Dirichlet distribution. 47 | 48 | 49 | 50 | ## Usage 51 | Here is one example to run this code: 52 | ``` 53 | python experiments.py --model=simple-cnn \ 54 | --dataset=cifar10 \ 55 | --alg=fedprox \ 56 | --lr=0.01 \ 57 | --batch-size=64 \ 58 | --epochs=10 \ 59 | --n_parties=10 \ 60 | --mu=0.01 \ 61 | --rho=0.9 \ 62 | --comm_round=50 \ 63 | --partition=noniid-labeldir \ 64 | --beta=0.5\ 65 | --device='cuda:0'\ 66 | --datadir='./data/' \ 67 | --logdir='./logs/' \ 68 | --noise=0 \ 69 | --sample=1 \ 70 | --init_seed=0 71 | ``` 72 | 73 | | Parameter | Description | 74 | | ----------------------------- | ---------------------------------------- | 75 | | `model` | The model architecture. Options: `simple-cnn`, `vgg`, `resnet`, `mlp`. Default = `mlp`. | 76 | | `dataset` | Dataset to use. Options: `mnist`, `cifar10`, `fmnist`, `svhn`, `generated`, `femnist`, `a9a`, `rcv1`, `covtype`. Default = `mnist`. | 77 | | `alg` | The training algorithm. Options: `fedavg`, `fedprox`, `scaffold`, `fednova`, `moon`. Default = `fedavg`. | 78 | | `lr` | Learning rate for the local models, default = `0.01`. | 79 | | `batch-size` | Batch size, default = `64`. | 80 | | `epochs` | Number of local training epochs, default = `5`. | 81 | | `n_parties` | Number of parties, default = `2`.
| 82 | | `mu` | The proximal term parameter for FedProx, default = `0.001`. | 83 | | `rho` | The parameter controlling the momentum SGD, default = `0`. | 84 | | `comm_round` | Number of communication rounds to use, default = `50`. | 85 | | `partition` | The partition way. Options: `homo`, `noniid-labeldir`, `noniid-#label1` (or 2, 3, ..., which means the fixed number of labels each party owns), `real`, `iid-diff-quantity`. Default = `homo` | 86 | | `beta` | The concentration parameter of the Dirichlet distribution for heterogeneous partition, default = `0.5`. | 87 | | `device` | Specify the device to run the program, default = `cuda:0`. | 88 | | `datadir` | The path of the dataset, default = `./data/`. | 89 | | `logdir` | The path to store the logs, default = `./logs/`. | 90 | | `noise` | Maximum variance of Gaussian noise we add to local party, default = `0`. | 91 | | `sample` | Ratio of parties that participate in each communication round, default = `1`. | 92 | | `init_seed` | The initial seed, default = `0`. | 93 | 94 | 95 | 96 | ## Data Partition Map 97 | You can call function `get_partition_dict()` in `experiments.py` to access `net_dataidx_map`. `net_dataidx_map` is a dictionary. Its keys are party IDs, and the value of each key is a list containing the indices of the data assigned to this party. For our experiments, we usually set `init_seed=0`. When we repeat experiments of some setting, we change `init_seed` to 1 or 2. The default value of `noise` is 0 unless stated otherwise. We list the way to get our data partition as follows.
98 | * **Quantity-based label imbalance**: `partition`=`noniid-#label1`, `noniid-#label2` or `noniid-#label3` 99 | * **Distribution-based label imbalance**: `partition`=`noniid-labeldir`, `beta`=`0.5` or `0.1` 100 | * **Noise-based feature imbalance**: `partition`=`homo`, `noise`=`0.1` (actually noise does not affect `net_dataidx_map`) 101 | * **Synthetic feature imbalance & Real-world feature imbalance**: `partition`=`real` 102 | * **Quantity Skew**: `partition`=`iid-diff-quantity`, `beta`=`0.5` or `0.1` 103 | * **IID Setting**: `partition`=`homo` 104 | * **Mixed skew**: `partition` = `mixed` for a mixture of distribution-based label imbalance and quantity skew; `partition` = `noniid-labeldir` and `noise` = `0.1` for a mixture of distribution-based label imbalance and noise-based feature imbalance. 105 | 106 | Here is an explanation of the parameters of function `get_partition_dict()`. 107 | 108 | | Parameter | Description | 109 | | ----------------------------- | ---------------------------------------- | 110 | | `dataset` | Dataset to use. Options: `mnist`, `cifar10`, `fmnist`, `svhn`, `generated`, `femnist`, `a9a`, `rcv1`, `covtype`. | 111 | | `partition` | The partition way. Options: `homo`, `noniid-labeldir`, `noniid-#label1` (or 2, 3, ..., which means the fixed number of labels each party owns), `real`, `iid-diff-quantity` | 112 | | `n_parties` | Number of parties. | 113 | | `init_seed` | The initial seed. | 114 | | `datadir` | The path of the dataset. | 115 | | `logdir` | The path to store the logs. | 116 | | `beta` | The concentration parameter of the Dirichlet distribution for heterogeneous partition. | 117 | 118 | ## Leader Board 119 | 120 | Note that the accuracy shows the average of three experiments, while the training curve is based on only one experiment. Thus, there may be some differences between them. We show the training curve to compare the convergence rates of different algorithms.
121 | 122 | ### Quantity-based label imbalance 123 | * Cifar-10, 10 parties, sample rate = 1, batch size = 64, learning rate = 0.01 124 | 125 | | Partition | Model | Round | Algorithm | Accuracy | 126 | | --------------|--------------- | -------------- | ------------ | -------------- | 127 | | `noniid-#label2` | `simple-cnn` |50| FedProx (`mu=0.01`) | 50.7% | 128 | | `noniid-#label2` | `simple-cnn` |50| FedAvg | 49.8% | 129 | | `noniid-#label2` | `simple-cnn` |50| SCAFFOLD | 49.1% | 130 | | `noniid-#label2` | `simple-cnn` |50| FedNova | 46.5% | 131 | 132 |
133 | 134 | 135 | * Cifar-10, 100 parties, sample rate = 0.1, batch size = 64, learning rate = 0.01 136 | 137 | | Partition | Model | Round | Algorithm | Accuracy | 138 | | --------------|--------------- | -------------- | ------------ | -------------- | 139 | | `noniid-#label2` | `simple-cnn` |500| FedNova | 48.0% | 140 | | `noniid-#label2` | `simple-cnn` |500| FedAvg | 45.3% | 141 | | `noniid-#label2` | `simple-cnn` |500| FedProx (`mu=0.001`) | 39.3% | 142 | | `noniid-#label2` | `simple-cnn` |500| SCAFFOLD | 10.0% | 143 | 144 |
145 | 146 | ### Distribution-based label imbalance 147 | * Cifar-10, 10 parties, sample rate = 1, batch size = 64, learning rate = 0.01 148 | 149 | | Partition | Model | Round | Algorithm | Accuracy | 150 | | --------------|--------------- | -------------- | ------------ | -------------- | 151 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |50| SCAFFOLD | 69.8% | 152 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |50| FedAvg | 68.2% | 153 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |50| FedProx (`mu=0.001`) | 67.9% | 154 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |50| FedNova | 66.8% | 155 | 156 |
157 | 158 | | Partition | Model | Round | Algorithm | Accuracy | 159 | | --------------|--------------- | -------------- | ------------ | -------------- | 160 | | `noniid-labeldir` with `beta=0.1` | `vgg` |100| SCAFFOLD | 85.5% | 161 | | `noniid-labeldir` with `beta=0.1` | `vgg` |100| FedNova | 84.4% | 162 | | `noniid-labeldir` with `beta=0.1` | `vgg` |100| FedProx (`mu=0.01`) | 84.4% | 163 | | `noniid-labeldir` with `beta=0.1` | `vgg` |100| FedAvg | 84.0% | 164 | 165 |
166 | 167 | * Cifar-10, 100 parties, sample rate = 0.1, batch size = 64, learning rate = 0.01 168 | 169 | | Partition | Model | Round | Algorithm | Accuracy | 170 | | --------------|--------------- | -------------- | ------------ | -------------- | 171 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |500| FedNova | 60.0% | 172 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |500| FedAvg | 59.4% | 173 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |500| FedProx (`mu=0.001`) | 58.8% | 174 | | `noniid-labeldir` with `beta=0.5` | `simple-cnn` |500| SCAFFOLD | 10.0% | 175 | 176 |
177 | 178 | ### Noise-based feature imbalance 179 | * Cifar-10, 10 parties, sample rate = 1, batch size = 64, learning rate = 0.01 180 | 181 | | Partition | Model | Round | Algorithm | Accuracy | 182 | | --------------|--------------- | -------------- | ------------ | -------------- | 183 | | `homo` with `noise=0.1` | `simple-cnn` |50| SCAFFOLD | 70.1% | 184 | | `homo` with `noise=0.1` | `simple-cnn` |50| FedProx (`mu=0.01`) | 69.3% | 185 | | `homo` with `noise=0.1` | `simple-cnn` |50| FedAvg | 68.9% | 186 | | `homo` with `noise=0.1` | `simple-cnn` |50| FedNova | 68.5% | 187 | 188 |
189 | 190 | | Partition | Model | Round | Algorithm | Accuracy | 191 | | --------------|--------------- | -------------- | ------------ | -------------- | 192 | | `homo` with `noise=0.1` | `resnet` |100| SCAFFOLD | 90.2% | 193 | | `homo` with `noise=0.1` | `resnet` |100| FedNova | 89.4% | 194 | | `homo` with `noise=0.1` | `resnet` |100| FedProx (`mu=0.01`) | 89.2% | 195 | | `homo` with `noise=0.1` | `resnet` |100| FedAvg | 89.1% | 196 | 197 |
198 | 199 | ### Quantity Skew 200 | * Cifar-10, 10 parties, sample rate = 1, batch size = 64, learning rate = 0.01 201 | 202 | | Partition | Model | Round | Algorithm | Accuracy | 203 | | --------------|--------------- | -------------- | ------------ | -------------- | 204 | | `iid-diff-quantity` with `beta=0.5` | `simple-cnn` |50| FedAvg | 72.0% | 205 | | `iid-diff-quantity` with `beta=0.5` | `simple-cnn` |50| FedProx (`mu=0.01`) | 71.2% | 206 | | `iid-diff-quantity` with `beta=0.5` | `simple-cnn` |50| SCAFFOLD | 62.4% | 207 | | `iid-diff-quantity` with `beta=0.5` | `simple-cnn` |50| FedNova | 10.0% | 208 | 209 |
210 | 211 | ### IID Setting 212 | * Cifar-10, 100 parties, sample rate = 0.1, batch size = 64, learning rate = 0.01 213 | 214 | | Partition | Model | Round | Algorithm | Accuracy | 215 | | --------------|--------------- | -------------- | ------------ | -------------- | 216 | |`homo`| `simple-cnn` |500| FedNova | 66.1% | 217 | |`homo`| `simple-cnn` |500| FedProx (`mu=0.01`) | 66.0% | 218 | |`homo`| `simple-cnn` |500| FedAvg | 65.6% | 219 | |`homo`| `simple-cnn` |500| SCAFFOLD | 10.0% | 220 | 221 |
222 | 223 | ## Citation 224 | If you find this repository useful, please cite our paper: 225 | 226 | ``` 227 | @inproceedings{li2022federated, 228 | title={Federated Learning on Non-IID Data Silos: An Experimental Study}, 229 | author={Li, Qinbin and Diao, Yiqun and Chen, Quan and He, Bingsheng}, 230 | booktitle={IEEE International Conference on Data Engineering}, 231 | year={2022} 232 | } 233 | ``` 234 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Dictionary storing network parameters. 2 | params = { 3 | 'num_epochs': 100,# Number of epochs to train for. 4 | 'learning_rate': 2e-4,# Learning rate. 5 | 'beta1': 0.5, 6 | 'beta2': 0.999, 7 | 'save_epoch' : 10,# After how many epochs to save checkpoints and generate test output. 8 | } 9 | -------------------------------------------------------------------------------- /criteo-dis.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/criteo-dis.npy -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | from PIL import Image 4 | import numpy as np 5 | from torchvision.datasets import MNIST, CIFAR10, SVHN, FashionMNIST, CIFAR100, ImageFolder, DatasetFolder, utils 6 | from torchvision.datasets.vision import VisionDataset 7 | from torchvision.datasets.utils import download_file_from_google_drive, check_integrity 8 | from functools import partial 9 | from typing import Optional, Callable 10 | from torch.utils.model_zoo import tqdm 11 | import PIL 12 | import tarfile 13 | import torchvision 14 | 15 | import os 16 | import os.path 17 | import logging 18 | import torchvision.datasets.utils 
as utils 19 | 20 | logging.basicConfig() 21 | logger = logging.getLogger() 22 | logger.setLevel(logging.INFO) 23 | 24 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 25 | 26 | def mkdirs(dirpath): 27 | try: 28 | os.makedirs(dirpath) 29 | except Exception as _: 30 | pass 31 | 32 | def accimage_loader(path): 33 | import accimage 34 | try: 35 | return accimage.Image(path) 36 | except IOError: 37 | # Potentially a decoding problem, fall back to PIL.Image 38 | return pil_loader(path) 39 | 40 | 41 | def pil_loader(path): 42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 43 | with open(path, 'rb') as f: 44 | img = Image.open(f) 45 | return img.convert('RGB') 46 | 47 | 48 | def default_loader(path): 49 | from torchvision import get_image_backend 50 | if get_image_backend() == 'accimage': 51 | return accimage_loader(path) 52 | else: 53 | return pil_loader(path) 54 | 55 | class CustomTensorDataset(data.TensorDataset): 56 | def __getitem__(self, index): 57 | return tuple(tensor[index] for tensor in self.tensors) + (index,) 58 | 59 | 60 | class MNIST_truncated(data.Dataset): 61 | 62 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): 63 | 64 | self.root = root 65 | self.dataidxs = dataidxs 66 | self.train = train 67 | self.transform = transform 68 | self.target_transform = target_transform 69 | self.download = download 70 | 71 | self.data, self.target = self.__build_truncated_dataset__() 72 | 73 | def __build_truncated_dataset__(self): 74 | 75 | mnist_dataobj = MNIST(self.root, self.train, self.transform, self.target_transform, self.download) 76 | 77 | # if self.train: 78 | # data = mnist_dataobj.train_data 79 | # target = mnist_dataobj.train_labels 80 | # else: 81 | # data = mnist_dataobj.test_data 82 | # target = mnist_dataobj.test_labels 83 | 84 | data = mnist_dataobj.data 85 | target = mnist_dataobj.targets 86 | 87 | 
if self.dataidxs is not None: 88 | data = data[self.dataidxs] 89 | target = target[self.dataidxs] 90 | 91 | return data, target 92 | 93 | def __getitem__(self, index): 94 | """ 95 | Args: 96 | index (int): Index 97 | 98 | Returns: 99 | tuple: (image, target) where target is index of the target class. 100 | """ 101 | img, target = self.data[index], self.target[index] 102 | 103 | # doing this so that it is consistent with all other datasets 104 | # to return a PIL Image 105 | img = Image.fromarray(img.numpy(), mode='L') 106 | 107 | # print("mnist img:", img) 108 | # print("mnist target:", target) 109 | 110 | if self.transform is not None: 111 | img = self.transform(img) 112 | 113 | if self.target_transform is not None: 114 | target = self.target_transform(target) 115 | 116 | return img, target 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | class FashionMNIST_truncated(data.Dataset): 122 | 123 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): 124 | 125 | self.root = root 126 | self.dataidxs = dataidxs 127 | self.train = train 128 | self.transform = transform 129 | self.target_transform = target_transform 130 | self.download = download 131 | 132 | self.data, self.target = self.__build_truncated_dataset__() 133 | 134 | def __build_truncated_dataset__(self): 135 | 136 | mnist_dataobj = FashionMNIST(self.root, self.train, self.transform, self.target_transform, self.download) 137 | 138 | # if self.train: 139 | # data = mnist_dataobj.train_data 140 | # target = mnist_dataobj.train_labels 141 | # else: 142 | # data = mnist_dataobj.test_data 143 | # target = mnist_dataobj.test_labels 144 | 145 | data = mnist_dataobj.data 146 | target = mnist_dataobj.targets 147 | 148 | if self.dataidxs is not None: 149 | data = data[self.dataidxs] 150 | target = target[self.dataidxs] 151 | 152 | return data, target 153 | 154 | def __getitem__(self, index): 155 | """ 156 | Args: 157 | index (int): Index 158 | 
159 | Returns: 160 | tuple: (image, target) where target is index of the target class. 161 | """ 162 | img, target = self.data[index], self.target[index] 163 | 164 | # doing this so that it is consistent with all other datasets 165 | # to return a PIL Image 166 | img = Image.fromarray(img.numpy(), mode='L') 167 | 168 | # print("mnist img:", img) 169 | # print("mnist target:", target) 170 | 171 | if self.transform is not None: 172 | img = self.transform(img) 173 | 174 | if self.target_transform is not None: 175 | target = self.target_transform(target) 176 | 177 | return img, target 178 | 179 | def __len__(self): 180 | return len(self.data) 181 | 182 | class SVHN_custom(data.Dataset): 183 | 184 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): 185 | 186 | self.root = root 187 | self.dataidxs = dataidxs 188 | self.train = train 189 | self.transform = transform 190 | self.target_transform = target_transform 191 | self.download = download 192 | 193 | self.data, self.target = self.__build_truncated_dataset__() 194 | 195 | def __build_truncated_dataset__(self): 196 | if self.train is True: 197 | # svhn_dataobj1 = SVHN(self.root, 'train', self.transform, self.target_transform, self.download) 198 | # svhn_dataobj2 = SVHN(self.root, 'extra', self.transform, self.target_transform, self.download) 199 | # data = np.concatenate((svhn_dataobj1.data, svhn_dataobj2.data), axis=0) 200 | # target = np.concatenate((svhn_dataobj1.labels, svhn_dataobj2.labels), axis=0) 201 | 202 | svhn_dataobj = SVHN(self.root, 'train', self.transform, self.target_transform, self.download) 203 | data = svhn_dataobj.data 204 | target = svhn_dataobj.labels 205 | else: 206 | svhn_dataobj = SVHN(self.root, 'test', self.transform, self.target_transform, self.download) 207 | data = svhn_dataobj.data 208 | target = svhn_dataobj.labels 209 | 210 | if self.dataidxs is not None: 211 | data = data[self.dataidxs] 212 | target = target[self.dataidxs] 213 | 
# print("svhn data:", data) 214 | # print("len svhn data:", len(data)) 215 | # print("type svhn data:", type(data)) 216 | # print("svhn target:", target) 217 | # print("type svhn target", type(target)) 218 | return data, target 219 | 220 | # def truncate_channel(self, index): 221 | # for i in range(index.shape[0]): 222 | # gs_index = index[i] 223 | # self.data[gs_index, :, :, 1] = 0.0 224 | # self.data[gs_index, :, :, 2] = 0.0 225 | 226 | def __getitem__(self, index): 227 | """ 228 | Args: 229 | index (int): Index 230 | 231 | Returns: 232 | tuple: (image, target) where target is index of the target class. 233 | """ 234 | img, target = self.data[index], self.target[index] 235 | # print("svhn img:", img) 236 | # print("svhn target:", target) 237 | # doing this so that it is consistent with all other datasets 238 | # to return a PIL Image 239 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 240 | 241 | if self.transform is not None: 242 | img = self.transform(img) 243 | 244 | if self.target_transform is not None: 245 | target = self.target_transform(target) 246 | 247 | return img, target 248 | 249 | def __len__(self): 250 | return len(self.data) 251 | 252 | 253 | # torchvision CelebA 254 | class CelebA_custom(VisionDataset): 255 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 256 | 257 | Args: 258 | root (string): Root directory where images are downloaded to. 259 | split (string): One of {'train', 'valid', 'test', 'all'}. 260 | Accordingly dataset is selected. 261 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 262 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 
263 | The targets represent: 264 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 265 | ``identity`` (int): label for each person (data points with the same identity are the same person) 266 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 267 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 268 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 269 | Defaults to ``attr``. If empty, ``None`` will be returned as target. 270 | transform (callable, optional): A function/transform that takes in an PIL image 271 | and returns a transformed version. E.g, ``transforms.ToTensor`` 272 | target_transform (callable, optional): A function/transform that takes in the 273 | target and transforms it. 274 | download (bool, optional): If true, downloads the dataset from the internet and 275 | puts it in root directory. If dataset is already downloaded, it is not 276 | downloaded again. 277 | """ 278 | 279 | base_folder = "celeba" 280 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 281 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 282 | # right now. 
283 | file_list = [ 284 | # File ID MD5 Hash Filename 285 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 286 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 287 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 288 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 289 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 290 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 291 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 292 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 293 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 294 | ] 295 | 296 | def __init__(self, root, dataidxs=None, split="train", target_type="attr", transform=None, 297 | target_transform=None, download=False): 298 | import pandas 299 | super(CelebA_custom, self).__init__(root, transform=transform, 300 | target_transform=target_transform) 301 | self.split = split 302 | if isinstance(target_type, list): 303 | self.target_type = target_type 304 | else: 305 | self.target_type = [target_type] 306 | 307 | if not self.target_type and self.target_transform is not None: 308 | raise RuntimeError('target_transform is specified but target_type is empty') 309 | 310 | if download: 311 | self.download() 312 | 313 | if not self._check_integrity(): 314 | raise RuntimeError('Dataset not found or corrupted.' 
+ 315 | ' You can use download=True to download it') 316 | 317 | split_map = { 318 | "train": 0, 319 | "valid": 1, 320 | "test": 2, 321 | "all": None, 322 | } 323 | split = split_map[split.lower()] 324 | 325 | fn = partial(os.path.join, self.root, self.base_folder) 326 | splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0) 327 | identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0) 328 | bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0) 329 | landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1) 330 | attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1) 331 | 332 | mask = slice(None) if split is None else (splits[1] == split) 333 | 334 | self.filename = splits[mask].index.values 335 | self.identity = torch.as_tensor(identity[mask].values) 336 | self.bbox = torch.as_tensor(bbox[mask].values) 337 | self.landmarks_align = torch.as_tensor(landmarks_align[mask].values) 338 | self.attr = torch.as_tensor(attr[mask].values) 339 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 340 | self.attr_names = list(attr.columns) 341 | self.gender_index = self.attr_names.index('Male') 342 | self.dataidxs = dataidxs 343 | if self.dataidxs is None: 344 | self.target = self.attr[:, self.gender_index:self.gender_index + 1].reshape(-1) 345 | else: 346 | self.target = self.attr[self.dataidxs, self.gender_index:self.gender_index + 1].reshape(-1) 347 | 348 | def _check_integrity(self): 349 | for (_, md5, filename) in self.file_list: 350 | fpath = os.path.join(self.root, self.base_folder, filename) 351 | _, ext = os.path.splitext(filename) 352 | # Allow original archive to be deleted (zip and 7z) 353 | # Only need the extracted images 354 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 355 | return False 356 | 357 | # Should check 
a hash of the images 358 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 359 | 360 | def download(self): 361 | import zipfile 362 | 363 | if self._check_integrity(): 364 | print('Files already downloaded and verified') 365 | return 366 | 367 | for (file_id, md5, filename) in self.file_list: 368 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 369 | 370 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 371 | f.extractall(os.path.join(self.root, self.base_folder)) 372 | 373 | def __getitem__(self, index): 374 | if self.dataidxs is None: 375 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 376 | 377 | target = [] 378 | for t in self.target_type: 379 | if t == "attr": 380 | target.append(self.attr[index, self.gender_index]) 381 | elif t == "identity": 382 | target.append(self.identity[index, 0]) 383 | elif t == "bbox": 384 | target.append(self.bbox[index, :]) 385 | elif t == "landmarks": 386 | target.append(self.landmarks_align[index, :]) 387 | else: 388 | # TODO: refactor with utils.verify_str_arg 389 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 390 | else: 391 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[self.dataidxs[index]])) 392 | 393 | target = [] 394 | for t in self.target_type: 395 | if t == "attr": 396 | target.append(self.attr[self.dataidxs[index], self.gender_index]) 397 | elif t == "identity": 398 | target.append(self.identity[self.dataidxs[index], 0]) 399 | elif t == "bbox": 400 | target.append(self.bbox[self.dataidxs[index], :]) 401 | elif t == "landmarks": 402 | target.append(self.landmarks_align[self.dataidxs[index], :]) 403 | else: 404 | # TODO: refactor with utils.verify_str_arg 405 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 406 | 407 | if self.transform 
class CIFAR10_truncated(data.Dataset):
    """CIFAR-10 restricted to an optional subset of sample indices.

    The complete split is loaded through ``CIFAR10`` and, when ``dataidxs`` is
    given, both the image array and the label array are sliced down to those
    rows. ``transform`` / ``target_transform`` are applied per sample in
    ``__getitem__``; the stored image is handed to ``transform`` as-is (no
    PIL conversion happens in this class).
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load the underlying CIFAR10 split and return ``(images, labels)``.

        Both arrays are indexed with ``self.dataidxs`` when it is not None.
        """
        cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download)

        # Named ``images``/``labels`` so the module alias ``data``
        # (torch.utils.data) is not shadowed inside this method.
        images = cifar_dataobj.data
        labels = np.array(cifar_dataobj.targets)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def truncate_channel(self, index):
        """Zero channels 1 and 2 (last axis) of the samples listed in ``index``.

        Args:
            index: sequence/array of integer sample positions into ``self.data``.
                An empty sequence is a no-op.
        """
        positions = np.asarray(index)
        if positions.size == 0:
            # An empty array defaults to a float dtype, which is not a valid
            # fancy index; there is nothing to do anyway.
            return
        # One vectorized assignment instead of a Python loop over positions.
        self.data[positions, :, :, 1:3] = 0

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.target[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples kept after truncation."""
        return len(self.data)
def _is_tarxz(filename: str) -> bool:
    """Return True if *filename* ends with ``.tar.xz``."""
    return filename.endswith(".tar.xz")


def _is_tar(filename: str) -> bool:
    """Return True if *filename* ends with ``.tar`` (uncompressed tarball)."""
    return filename.endswith(".tar")


def _is_targz(filename: str) -> bool:
    """Return True if *filename* ends with ``.tar.gz``."""
    return filename.endswith(".tar.gz")


def _is_tgz(filename: str) -> bool:
    """Return True if *filename* ends with ``.tgz``."""
    return filename.endswith(".tgz")


def _is_gzip(filename: str) -> bool:
    """Return True for a plain ``.gz`` file that is not a ``.tar.gz``."""
    return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename: str) -> bool:
    """Return True if *filename* ends with ``.zip``."""
    return filename.endswith(".zip")


def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if *target* resolves to *directory* or a path below it."""
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    try:
        # commonpath compares whole path components. commonprefix compares
        # characters, so "/x/out" would wrongly "contain" "/x/out_evil/f".
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # Raised e.g. for paths on different drives: cannot be inside.
        return False


def _safe_extract_tar(from_path: str, mode: str, to_path: str) -> None:
    """Extract a tarball after checking that no member escapes *to_path*."""
    with tarfile.open(from_path, mode) as tar:
        for member in tar.getmembers():
            member_path = os.path.join(to_path, member.name)
            if not _is_within_directory(to_path, member_path):
                raise Exception("Attempted Path Traversal in Tar File")
        tar.extractall(to_path)


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
    """Extract the archive at *from_path* into *to_path*.

    Supported suffixes: ``.tar``, ``.tar.gz``, ``.tgz``, ``.tar.xz``, ``.gz``
    and ``.zip``. A plain ``.gz`` file is decompressed to a single file named
    after the archive without its ``.gz`` suffix.

    Args:
        from_path: Path of the archive file.
        to_path: Destination directory. Defaults to the directory that
            contains *from_path*.
        remove_finished: If True, delete the archive after extraction.

    Raises:
        ValueError: If the file suffix is not one of the supported formats.
        Exception: If a tar member would be written outside *to_path*.
    """
    import shutil

    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        _safe_extract_tar(from_path, 'r', to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        _safe_extract_tar(from_path, 'r:gz', to_path)
    elif _is_tarxz(from_path):
        _safe_extract_tar(from_path, 'r:xz', to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            # Stream in chunks instead of reading the whole payload at once.
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)
class FEMNIST(MNIST):
    """
    Extended MNIST grouped by writer, as pre-processed by the Leaf repository
    (https://github.com/TalwalkarLab/leaf). Leaf is described in
    "LEAF: A Benchmark for Federated Settings" https://arxiv.org/abs/1812.01097.
    """
    resources = [
        ('https://raw.githubusercontent.com/tao-shen/FEMNIST_pytorch/master/femnist.tar.gz',
         '59c65cec646fc57fe92d27d83afdf0ed')]

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None,
                 download=False):
        # super(MNIST, self) resolves to the class after MNIST in the MRO, so
        # MNIST.__init__ itself is not run here.
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train
        self.dataidxs = dataidxs

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        data_file = self.training_file if self.train else self.test_file
        loaded = torch.load(os.path.join(self.processed_folder, data_file))
        # The processed file holds a 3-tuple: samples, labels, per-user index.
        self.data, self.targets, self.users_index = loaded

        if self.dataidxs is not None:
            self.data = self.data[self.dataidxs]
            self.targets = self.targets[self.dataidxs]

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample *index*, after the optional transforms."""
        raw = self.data[index]
        target = int(self.targets[index])
        img = Image.fromarray(raw.numpy(), mode='F')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def download(self):
        """Download the FEMNIST data if it doesn't exist in processed_folder already."""
        import shutil

        if self._check_exists():
            return

        mkdirs(self.raw_folder)
        mkdirs(self.processed_folder)

        # Fetch and unpack every archive listed in ``resources``.
        for url, md5 in self.resources:
            archive_name = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=archive_name, md5=md5)

        # Move the extracted torch files into the processed folder.
        print('Processing...')
        for file_name in (self.training_file, self.test_file):
            shutil.move(os.path.join(self.raw_folder, file_name), self.processed_folder)

    def __len__(self):
        """Return the number of samples currently held."""
        return len(self.data)

    def _check_exists(self) -> bool:
        """Return True when every resource archive is present in ``raw_folder``."""
        for url, _ in self.resources:
            archive_path = os.path.join(self.raw_folder, os.path.basename(url))
            if not check_integrity(archive_path):
                return False
        return True
class CIFAR100_truncated(data.Dataset):
    """CIFAR-100 restricted to an optional subset of sample indices.

    The split is loaded through ``CIFAR100``; when ``dataidxs`` is given the
    image and label arrays are sliced to those rows. ``__getitem__`` converts
    the stored array to a PIL image before applying ``transform``.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        self.data, self.target = self.__build_truncated_dataset__()

    @staticmethod
    def _read_arrays(dataobj, train):
        """Return ``(images, labels)`` from a loaded CIFAR dataset object.

        Objects exposing ``data``/``targets`` are read directly. Otherwise the
        legacy ``train_data``/``train_labels`` or ``test_data``/``test_labels``
        attributes are used, chosen by *train*. Detecting the attributes
        replaces an exact ``torchvision.__version__ == '0.2.1'`` comparison,
        so any build with the legacy layout is handled.
        """
        if hasattr(dataobj, 'data'):
            return dataobj.data, np.array(dataobj.targets)
        if train:
            return dataobj.train_data, np.array(dataobj.train_labels)
        return dataobj.test_data, np.array(dataobj.test_labels)

    def __build_truncated_dataset__(self):
        """Load the underlying CIFAR100 split and apply ``self.dataidxs``."""
        cifar_dataobj = CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download)

        # Local names avoid shadowing the module alias ``data``.
        images, labels = self._read_arrays(cifar_dataobj, self.train)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.target[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples kept after truncation."""
        return len(self.data)
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is True for
    every non-empty string; this parser gives the intended result instead.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def _int_list(value):
    """Parse ``"1, 2, 3"`` or ``"1,2,3"`` into ``[1, 2, 3]``."""
    # int() ignores surrounding whitespace, so both separators are accepted.
    return [int(item) for item in value.split(',')]


def get_args(argv=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the experiment settings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='mlp', help='neural network used in training')
    parser.add_argument('--dataset', type=str, default='mnist', help='dataset used for training')
    parser.add_argument('--net_config', type=_int_list)
    parser.add_argument('--partition', type=str, default='homo', help='the data partitioning strategy')
    parser.add_argument('--batch-size', type=int, default=64, help='input batch size for training (default: 64)')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate (default: 0.01)')
    parser.add_argument('--epochs', type=int, default=5, help='number of local epochs')
    parser.add_argument('--n_parties', type=int, default=2, help='number of workers in a distributed cluster')
    parser.add_argument('--alg', type=str, default='fedavg',
                        help='fl algorithms: fedavg/fedprox/scaffold/fednova/moon')
    # type=bool would turn "--use_projection_head False" into True.
    parser.add_argument('--use_projection_head', type=_str2bool, default=False, help='whether add an additional header to model or not (see MOON)')
    parser.add_argument('--out_dim', type=int, default=256, help='the output dimension for the projection layer')
    parser.add_argument('--loss', type=str, default='contrastive', help='for moon')
    parser.add_argument('--temperature', type=float, default=0.5, help='the temperature parameter for contrastive loss')
    parser.add_argument('--comm_round', type=int, default=50, help='number of maximum communication roun')
    parser.add_argument('--is_same_initial', type=int, default=1, help='Whether initial all the models with the same parameters in fedavg')
    parser.add_argument('--init_seed', type=int, default=0, help="Random seed")
    parser.add_argument('--dropout_p', type=float, required=False, default=0.0, help="Dropout probability. Default=0.0")
    parser.add_argument('--datadir', type=str, required=False, default="./data/", help="Data directory")
    parser.add_argument('--reg', type=float, default=1e-5, help="L2 regularization strength")
    parser.add_argument('--logdir', type=str, required=False, default="./logs/", help='Log directory path')
    parser.add_argument('--modeldir', type=str, required=False, default="./models/", help='Model directory path')
    parser.add_argument('--beta', type=float, default=0.5, help='The parameter for the dirichlet distribution for data partitioning')
    parser.add_argument('--device', type=str, default='cuda:0', help='The device to run the program')
    parser.add_argument('--log_file_name', type=str, default=None, help='The log file name')
    parser.add_argument('--optimizer', type=str, default='sgd', help='the optimizer')
    parser.add_argument('--mu', type=float, default=0.001, help='the mu parameter for fedprox')
    parser.add_argument('--noise', type=float, default=0, help='how much noise we add to some party')
    parser.add_argument('--noise_type', type=str, default='level', help='Different level of noise or different space of noise')
    parser.add_argument('--rho', type=float, default=0, help='Parameter controlling the momentum SGD')
    parser.add_argument('--sample', type=float, default=1, help='Sample ratio for each communication round')
    args = parser.parse_args(argv)
    return args
def train_net(net_id, net, train_dataloader, test_dataloader, epochs, lr, args_optimizer, device="cpu"):
    """Train ``net`` locally with cross-entropy loss and report its accuracy.

    Reads ``args.reg`` and ``args.rho`` (weight decay and SGD momentum) and
    ``logger`` from module scope; they are not passed as parameters.

    Args:
        net_id: Identifier used only in log messages.
        net: Model to train in place.
        train_dataloader: One loader or a list of loaders; each is iterated
            once per epoch.
        test_dataloader: Loader passed to ``compute_accuracy`` for evaluation.
        epochs: Number of passes over ``train_dataloader``.
        lr: Learning rate given to the optimizer.
        args_optimizer: ``'adam'``, ``'amsgrad'`` or ``'sgd'``.
        device: Device the batches are moved to.

    Returns:
        ``(train_acc, test_acc)`` as returned by ``compute_accuracy`` after
        training.

    Raises:
        ValueError: If ``args_optimizer`` is not a supported name (previously
            this surfaced as an unbound ``optimizer`` variable).
    """
    logger.info('Training network %s' % str(net_id))

    train_acc = compute_accuracy(net, train_dataloader, device=device)
    test_acc, _ = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)

    logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
    logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))

    trainable = filter(lambda p: p.requires_grad, net.parameters())
    if args_optimizer == 'adam':
        optimizer = optim.Adam(trainable, lr=lr, weight_decay=args.reg)
    elif args_optimizer == 'amsgrad':
        optimizer = optim.Adam(trainable, lr=lr, weight_decay=args.reg,
                               amsgrad=True)
    elif args_optimizer == 'sgd':
        optimizer = optim.SGD(trainable, lr=lr, momentum=args.rho, weight_decay=args.reg)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(args_optimizer))
    criterion = nn.CrossEntropyLoss().to(device)

    # Normalise to a list so a single loader and a list share one code path.
    if not isinstance(train_dataloader, list):
        train_dataloader = [train_dataloader]

    for epoch in range(epochs):
        epoch_loss_collector = []
        for loader in train_dataloader:
            for x, target in loader:
                x, target = x.to(device), target.to(device)

                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                target = target.long()

                out = net(x)
                loss = criterion(out, target)

                loss.backward()
                optimizer.step()

                epoch_loss_collector.append(loss.item())

        # An empty loader yields no batches; report NaN instead of dividing by zero.
        if epoch_loss_collector:
            epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
        else:
            epoch_loss = float('nan')
        logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))

    train_acc = compute_accuracy(net, train_dataloader, device=device)
    test_acc, _ = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)

    logger.info('>> Training accuracy: %f' % train_acc)
    logger.info('>> Test accuracy: %f' % test_acc)

    net.to('cpu')
    logger.info(' ** Training complete **')
    return train_acc, test_acc
def train_net_scaffold(net_id, net, global_model, c_local, c_global, train_dataloader, test_dataloader, epochs, lr, args_optimizer, device="cpu"):
    """Train ``net`` locally with SCAFFOLD control-variate correction.

    After every optimizer step the weights are shifted by
    ``-lr * (c_global - c_local)``. When training ends, ``c_local`` is updated
    to ``c_local - c_global + (global_model - net) / (steps * lr)`` and the
    change in ``c_local`` is returned.

    Reads ``args.reg``, ``args.rho`` and ``logger`` from module scope.

    Args:
        net_id: Identifier used only in log messages.
        net: Local model, trained in place.
        global_model: Model whose state is the reference point for the
            control-variate update.
        c_local: Local control variate (a model); updated in place.
        c_global: Global control variate (a model); read only.
        train_dataloader: One loader or a list of loaders.
        test_dataloader: Loader passed to ``compute_accuracy`` for evaluation.
        epochs: Number of local epochs.
        lr: Local learning rate. It is used for the optimizer and for the
            correction step and control-variate update. The latter two used
            the global ``args.lr`` before, which only matches when the caller
            passes ``lr=args.lr``.
        args_optimizer: ``'adam'``, ``'amsgrad'`` or ``'sgd'``.
        device: Device used for training.

    Returns:
        ``(train_acc, test_acc, c_delta_para)`` where ``c_delta_para`` maps
        each state-dict key to the change applied to ``c_local``.

    Raises:
        ValueError: If ``args_optimizer`` is not a supported name.
    """
    logger.info('Training network %s' % str(net_id))

    train_acc = compute_accuracy(net, train_dataloader, device=device)
    test_acc, _ = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)

    logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
    logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))

    trainable = filter(lambda p: p.requires_grad, net.parameters())
    if args_optimizer == 'adam':
        optimizer = optim.Adam(trainable, lr=lr, weight_decay=args.reg)
    elif args_optimizer == 'amsgrad':
        optimizer = optim.Adam(trainable, lr=lr, weight_decay=args.reg,
                               amsgrad=True)
    elif args_optimizer == 'sgd':
        optimizer = optim.SGD(trainable, lr=lr, momentum=args.rho, weight_decay=args.reg)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(args_optimizer))
    criterion = nn.CrossEntropyLoss().to(device)

    if not isinstance(train_dataloader, list):
        train_dataloader = [train_dataloader]

    c_local.to(device)
    c_global.to(device)
    global_model.to(device)

    c_global_para = c_global.state_dict()
    c_local_para = c_local.state_dict()

    cnt = 0  # number of local steps taken; divides the control-variate update
    for epoch in range(epochs):
        epoch_loss_collector = []
        for loader in train_dataloader:
            for x, target in loader:
                x, target = x.to(device), target.to(device)

                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                target = target.long()

                out = net(x)
                loss = criterion(out, target)

                loss.backward()
                optimizer.step()

                # SCAFFOLD drift correction applied on top of the optimizer step.
                net_para = net.state_dict()
                for key in net_para:
                    net_para[key] = net_para[key] - lr * (c_global_para[key] - c_local_para[key])
                net.load_state_dict(net_para)

                cnt += 1
                epoch_loss_collector.append(loss.item())

        # An empty loader yields no batches; report NaN instead of dividing by zero.
        if epoch_loss_collector:
            epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
        else:
            epoch_loss = float('nan')
        logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))

    c_new_para = c_local.state_dict()
    c_delta_para = copy.deepcopy(c_local.state_dict())
    if cnt == 0:
        # No local step was taken: dividing by cnt * lr would give 0/0 = NaN
        # control variates. Leave c_local unchanged and report a zero delta.
        for key in c_delta_para:
            c_delta_para[key] = torch.zeros_like(c_delta_para[key])
    else:
        global_model_para = global_model.state_dict()
        net_para = net.state_dict()
        for key in net_para:
            c_new_para[key] = c_new_para[key] - c_global_para[key] + (global_model_para[key] - net_para[key]) / (cnt * lr)
            # c_local_para still holds the pre-update tensors at this point;
            # load_state_dict below overwrites them in place.
            c_delta_para[key] = c_new_para[key] - c_local_para[key]
        c_local.load_state_dict(c_new_para)

    train_acc = compute_accuracy(net, train_dataloader, device=device)
    test_acc, _ = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, device=device)

    logger.info('>> Training accuracy: %f' % train_acc)
    logger.info('>> Test accuracy: %f' % test_acc)

    net.to('cpu')
    logger.info(' ** Training complete **')
    return train_acc, test_acc, c_delta_para
def train_net_moon(net_id, net, global_net, previous_nets, train_dataloader, test_dataloader, epochs, lr, args_optimizer, mu, temperature, args,
                   round, device="cpu"):
    """Train one party's model with a supervised loss plus a MOON-style auxiliary loss.

    ``net``, ``global_net`` and every model in ``previous_nets`` are called as
    ``model(x)`` and must return a 3-tuple; the second element is compared
    between models and the third element of ``net`` is fed to the
    cross-entropy criterion together with the targets.

    The auxiliary term (``loss2``) is selected by ``args.loss``:

    * ``'l2norm'``: ``mu`` times the mean L2 distance between the second
      outputs of ``global_net`` and ``net``.
    * ``'contrastive'`` / ``'only_contrastive'``: ``mu`` times the
      cross-entropy over temperature-scaled cosine similarities, where the
      similarity to ``global_net`` is class 0 and the similarity to each
      model in ``previous_nets`` is a further class.

    With ``'only_contrastive'`` only ``loss2`` is back-propagated; otherwise
    the supervised loss (``loss1``) is added to it.

    Args:
        net_id: Identifier of the party, used only for logging.
        net: Model to train. It is moved to CPU before returning.
        global_net: Reference model. It is NOT moved to ``device`` here, so
            the caller has to place it on the right device beforehand.
        previous_nets: Iterable of earlier local models used as negatives.
            They are moved to ``device`` for training and back to CPU
            afterwards unless ``args.loss == 'l2norm'``.
        train_dataloader: Iterable of ``(x, target)`` batches.
        test_dataloader: Loader passed to ``compute_accuracy`` for testing.
        epochs: Number of passes over ``train_dataloader``.
        lr: Learning rate of the optimizer.
        args_optimizer: One of ``'adam'``, ``'amsgrad'`` or ``'sgd'``.
        mu: Weight of the auxiliary loss.
        temperature: Divisor applied to the similarity logits.
        args: Namespace providing ``reg`` (weight decay) and ``loss``.
        round: Unused; kept so existing callers keep working.
        device: Device on which batches are processed.

    Returns:
        Tuple ``(train_acc, test_acc)`` measured after training.

    Raises:
        ValueError: If ``args.loss`` or ``args_optimizer`` is not supported.
    """
    # Fail fast: an unknown loss would otherwise surface as an unbound local
    # in the middle of the first batch.
    if args.loss not in ('l2norm', 'contrastive', 'only_contrastive'):
        raise ValueError("unsupported args.loss for MOON training: %r" % (args.loss,))

    logger.info('Training network %s' % str(net_id))

    train_acc = compute_accuracy(net, train_dataloader, moon_model=True, device=device)
    test_acc, _ = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, moon_model=True, device=device)

    logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
    logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))

    trainable_params = filter(lambda p: p.requires_grad, net.parameters())
    if args_optimizer == 'adam':
        optimizer = optim.Adam(trainable_params, lr=lr, weight_decay=args.reg)
    elif args_optimizer == 'amsgrad':
        optimizer = optim.Adam(trainable_params, lr=lr, weight_decay=args.reg, amsgrad=True)
    elif args_optimizer == 'sgd':
        optimizer = optim.SGD(trainable_params, lr=lr, momentum=0.9, weight_decay=args.reg)
    else:
        raise ValueError("unsupported optimizer for MOON training: %r" % (args_optimizer,))

    criterion = nn.CrossEntropyLoss().to(device)
    cos = torch.nn.CosineSimilarity(dim=-1).to(device)

    uses_previous_nets = args.loss != 'l2norm'
    if uses_previous_nets:
        # Moved once up front; the batch loop relies on them staying here.
        for previous_net in previous_nets:
            previous_net.to(device)

    for epoch in range(epochs):
        epoch_loss_collector = []
        epoch_loss1_collector = []
        epoch_loss2_collector = []
        for batch_idx, (x, target) in enumerate(train_dataloader):
            x, target = x.to(device), target.to(device)
            # Single-sample batches are skipped, as in the original loop.
            if target.shape[0] == 1:
                continue

            optimizer.zero_grad()
            x.requires_grad = True
            target.requires_grad = False
            target = target.long()

            _, pro1, out = net(x)
            _, pro2, _ = global_net(x)

            if args.loss == 'l2norm':
                loss2 = mu * torch.mean(torch.norm(pro2 - pro1, dim=1))
            else:
                # Column 0 holds the similarity to the global model; the
                # all-zero labels below therefore mark it as the positive.
                logits = cos(pro1, pro2).reshape(-1, 1)
                for previous_net in previous_nets:
                    _, pro3, _ = previous_net(x)
                    logits = torch.cat((logits, cos(pro1, pro3).reshape(-1, 1)), dim=1)

                logits /= temperature
                labels = torch.zeros(x.size(0)).to(device).long()
                loss2 = mu * criterion(logits, labels)

            # Always computed so it can be logged; it only contributes to the
            # optimised loss when the supervised term is enabled.
            loss1 = criterion(out, target)
            if args.loss == 'only_contrastive':
                loss = loss2
            else:
                loss = loss1 + loss2

            loss.backward()
            optimizer.step()

            epoch_loss_collector.append(loss.item())
            epoch_loss1_collector.append(loss1.item())
            epoch_loss2_collector.append(loss2.item())

        n_batches = len(epoch_loss_collector)
        if n_batches == 0:
            # Every batch was skipped (or the loader was empty): there is
            # nothing to average, so report it instead of dividing by zero.
            logger.warning('Epoch: %d processed no batches' % epoch)
            continue

        epoch_loss = sum(epoch_loss_collector) / n_batches
        epoch_loss1 = sum(epoch_loss1_collector) / n_batches
        epoch_loss2 = sum(epoch_loss2_collector) / n_batches
        logger.info('Epoch: %d Loss: %f Loss1: %f Loss2: %f' % (epoch, epoch_loss, epoch_loss1, epoch_loss2))

    if uses_previous_nets:
        for previous_net in previous_nets:
            previous_net.to('cpu')

    train_acc = compute_accuracy(net, train_dataloader, moon_model=True, device=device)
    test_acc, _ = compute_accuracy(net, test_dataloader, get_confusion_matrix=True, moon_model=True, device=device)

    logger.info('>> Training accuracy: %f' % train_acc)
    logger.info('>> Test accuracy: %f' % test_acc)
    net.to('cpu')
    logger.info(' ** Training complete **')
    return train_acc, test_acc
def local_train_net_fedprox(nets, selected, global_model, args, net_dataidx_map, test_dl = None, device="cpu"):
    """Run one round of local FedProx training for every selected party.

    For each ``net_id`` in ``nets`` that is also in ``selected``, a local
    training loader is built with ``get_dataloader`` and the model is
    trained through ``train_net_fedprox`` against ``global_model``.

    Noise handling:

    * ``args.noise_type == 'space'``: every party uses ``args.noise`` except
      the last one (``net_id == args.n_parties - 1``), which uses 0; the
      party id and ``args.n_parties - 1`` are forwarded to ``get_dataloader``.
    * otherwise: the noise level grows linearly with the party id,
      ``args.noise / (args.n_parties - 1) * net_id``.

    Args:
        nets: Mapping ``net_id -> model`` of all local models.
        selected: Collection of party ids taking part in this round.
        global_model: Current global model handed to ``train_net_fedprox``.
        args: Namespace providing ``noise``, ``noise_type``, ``n_parties``,
            ``dataset``, ``datadir``, ``batch_size``, ``epochs``, ``lr``,
            ``optimizer``, ``mu`` and ``alg``.
        net_dataidx_map: Mapping ``net_id -> sample indices`` of each party.
        test_dl: Test loader used to evaluate every trained model.
        device: Device on which each model is trained.

    Returns:
        List of all models in ``nets`` (selected or not), in mapping order.
    """
    avg_acc = 0.0

    for net_id, net in nets.items():
        if net_id not in selected:
            continue
        dataidxs = net_dataidx_map[net_id]

        logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
        # move the model to cuda device:
        net.to(device)

        # Only the local training loader is needed; evaluation uses the
        # caller-provided test_dl, so no full-dataset loader is built here.
        # The literal 32 is forwarded unchanged (presumably the test batch
        # size of get_dataloader - TODO confirm).
        if args.noise_type == 'space':
            noise_level = 0 if net_id == args.n_parties - 1 else args.noise
            train_dl_local, _, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
        else:
            # max(..., 1) avoids a division by zero with a single party,
            # whose only id is 0 and therefore gets noise level 0.
            noise_level = args.noise / max(args.n_parties - 1, 1) * net_id
            train_dl_local, _, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)

        trainacc, testacc = train_net_fedprox(net_id, net, global_model, train_dl_local, test_dl, args.epochs, args.lr, args.optimizer, args.mu, device=device)
        logger.info("net %d final test acc %f" % (net_id, testacc))
        avg_acc += testacc

    # An empty selection leaves the average at 0.0 instead of raising.
    if len(selected) > 0:
        avg_acc /= len(selected)
    if args.alg == 'local_training':
        logger.info("avg test acc %f" % avg_acc)

    return list(nets.values())
def local_train_net_fednova(nets, selected, global_model, args, net_dataidx_map, test_dl = None, device="cpu"):
    """Run one round of local FedNova training for every selected party.

    For each ``net_id`` in ``nets`` that is also in ``selected``, a local
    training loader is built with ``get_dataloader`` and the model is
    trained through ``train_net_fednova``. The per-party values returned by
    that call are collected in the order in which parties are visited.

    Noise handling:

    * ``args.noise_type == 'space'``: every party uses ``args.noise`` except
      the last one (``net_id == args.n_parties - 1``), which uses 0; the
      party id and ``args.n_parties - 1`` are forwarded to ``get_dataloader``.
    * otherwise: the noise level grows linearly with the party id,
      ``args.noise / (args.n_parties - 1) * net_id``.

    Args:
        nets: Mapping ``net_id -> model`` of all local models.
        selected: Collection of party ids taking part in this round.
        global_model: Current global model; moved to ``device`` here.
        args: Namespace providing ``noise``, ``noise_type``, ``n_parties``,
            ``dataset``, ``datadir``, ``batch_size``, ``epochs``, ``lr``,
            ``optimizer`` and ``alg``.
        net_dataidx_map: Mapping ``net_id -> sample indices`` of each party.
        test_dl: Test loader used to evaluate every trained model.
        device: Device on which each model is trained.

    Returns:
        Tuple ``(nets_list, a_list, d_list, n_list)``: all models in
        ``nets``, followed by one entry per trained party holding the third
        and fourth return values of ``train_net_fednova`` and the size of
        the party's local training dataset.
    """
    avg_acc = 0.0

    a_list = []
    d_list = []
    n_list = []
    global_model.to(device)
    for net_id, net in nets.items():
        if net_id not in selected:
            continue
        dataidxs = net_dataidx_map[net_id]

        logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
        # move the model to cuda device:
        net.to(device)

        # Only the local training loader is needed; evaluation uses the
        # caller-provided test_dl, so no full-dataset loader is built here.
        # The literal 32 is forwarded unchanged (presumably the test batch
        # size of get_dataloader - TODO confirm).
        if args.noise_type == 'space':
            noise_level = 0 if net_id == args.n_parties - 1 else args.noise
            train_dl_local, _, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
        else:
            # max(..., 1) avoids a division by zero with a single party,
            # whose only id is 0 and therefore gets noise level 0.
            noise_level = args.noise / max(args.n_parties - 1, 1) * net_id
            train_dl_local, _, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)

        trainacc, testacc, a_i, d_i = train_net_fednova(net_id, net, global_model, train_dl_local, test_dl, args.epochs, args.lr, args.optimizer, device=device)

        a_list.append(a_i)
        d_list.append(d_i)
        n_list.append(len(train_dl_local.dataset))
        logger.info("net %d final test acc %f" % (net_id, testacc))
        avg_acc += testacc

    # An empty selection leaves the average at 0.0 instead of raising.
    if len(selected) > 0:
        avg_acc /= len(selected)
    if args.alg == 'local_training':
        logger.info("avg test acc %f" % avg_acc)

    nets_list = list(nets.values())
    return nets_list, a_list, d_list, n_list
def get_partition_dict(dataset, partition, n_parties, init_seed=0, datadir='./data', logdir='./logs', beta=0.5):
    """Seed the random generators and return the party-to-indices mapping.

    NumPy, PyTorch and the standard ``random`` module are all seeded with
    ``init_seed`` before ``partition_data`` is called; only the fifth element
    of its six-element result, the ``net_dataidx_map``, is returned.
    """
    # Same seeding order as before: NumPy, then PyTorch, then ``random``.
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(init_seed)

    _, _, _, _, net_dataidx_map, _ = partition_data(
        dataset, datadir, logdir, partition, n_parties, beta=beta)
    return net_dataidx_map
799 | mkdirs(args.modeldir) 800 | if args.log_file_name is None: 801 | argument_path='experiment_arguments-%s.json' % datetime.datetime.now().strftime("%Y-%m-%d-%H:%M-%S") 802 | else: 803 | argument_path=args.log_file_name+'.json' 804 | with open(os.path.join(args.logdir, argument_path), 'w') as f: 805 | json.dump(str(args), f) 806 | device = torch.device(args.device) 807 | # logging.basicConfig(filename='test.log', level=logger.info, filemode='w') 808 | # logging.info("test") 809 | for handler in logging.root.handlers[:]: 810 | logging.root.removeHandler(handler) 811 | 812 | if args.log_file_name is None: 813 | args.log_file_name = 'experiment_log-%s' % (datetime.datetime.now().strftime("%Y-%m-%d-%H:%M-%S")) 814 | log_path=args.log_file_name+'.log' 815 | logging.basicConfig( 816 | filename=os.path.join(args.logdir, log_path), 817 | # filename='/home/qinbin/test.log', 818 | format='%(asctime)s %(levelname)-8s %(message)s', 819 | datefmt='%m-%d %H:%M', level=logging.DEBUG, filemode='w') 820 | 821 | logger = logging.getLogger() 822 | logger.setLevel(logging.DEBUG) 823 | logger.info(device) 824 | 825 | seed = args.init_seed 826 | logger.info("#" * 100) 827 | np.random.seed(seed) 828 | torch.manual_seed(seed) 829 | random.seed(seed) 830 | logger.info("Partitioning data") 831 | X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data( 832 | args.dataset, args.datadir, args.logdir, args.partition, args.n_parties, beta=args.beta) 833 | 834 | n_classes = len(np.unique(y_train)) 835 | 836 | train_dl_global, test_dl_global, train_ds_global, test_ds_global = get_dataloader(args.dataset, 837 | args.datadir, 838 | args.batch_size, 839 | 32) 840 | 841 | print("len train_dl_global:", len(train_ds_global)) 842 | 843 | 844 | data_size = len(test_ds_global) 845 | 846 | # test_dl = data.DataLoader(dataset=test_ds_global, batch_size=32, shuffle=False) 847 | 848 | train_all_in_list = [] 849 | test_all_in_list = [] 850 | if args.noise > 0: 851 | for 
party_id in range(args.n_parties): 852 | dataidxs = net_dataidx_map[party_id] 853 | 854 | noise_level = args.noise 855 | if party_id == args.n_parties - 1: 856 | noise_level = 0 857 | 858 | if args.noise_type == 'space': 859 | train_dl_local, test_dl_local, train_ds_local, test_ds_local = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, party_id, args.n_parties-1) 860 | else: 861 | noise_level = args.noise / (args.n_parties - 1) * party_id 862 | train_dl_local, test_dl_local, train_ds_local, test_ds_local = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level) 863 | train_all_in_list.append(train_ds_local) 864 | test_all_in_list.append(test_ds_local) 865 | train_all_in_ds = data.ConcatDataset(train_all_in_list) 866 | train_dl_global = data.DataLoader(dataset=train_all_in_ds, batch_size=args.batch_size, shuffle=True) 867 | test_all_in_ds = data.ConcatDataset(test_all_in_list) 868 | test_dl_global = data.DataLoader(dataset=test_all_in_ds, batch_size=32, shuffle=False) 869 | 870 | 871 | 872 | if args.alg == 'fedavg': 873 | logger.info("Initializing nets") 874 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args) 875 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args) 876 | global_model = global_models[0] 877 | 878 | global_para = global_model.state_dict() 879 | if args.is_same_initial: 880 | for net_id, net in nets.items(): 881 | net.load_state_dict(global_para) 882 | 883 | for round in range(args.comm_round): 884 | logger.info("in comm round:" + str(round)) 885 | 886 | arr = np.arange(args.n_parties) 887 | np.random.shuffle(arr) 888 | selected = arr[:int(args.n_parties * args.sample)] 889 | 890 | global_para = global_model.state_dict() 891 | if round == 0: 892 | if args.is_same_initial: 893 | for idx in selected: 894 | nets[idx].load_state_dict(global_para) 895 | else: 896 | for idx in selected: 
897 | nets[idx].load_state_dict(global_para) 898 | 899 | local_train_net(nets, selected, args, net_dataidx_map, test_dl = test_dl_global, device=device) 900 | # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device) 901 | 902 | # update global model 903 | total_data_points = sum([len(net_dataidx_map[r]) for r in selected]) 904 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected] 905 | 906 | for idx in range(len(selected)): 907 | net_para = nets[selected[idx]].cpu().state_dict() 908 | if idx == 0: 909 | for key in net_para: 910 | global_para[key] = net_para[key] * fed_avg_freqs[idx] 911 | else: 912 | for key in net_para: 913 | global_para[key] += net_para[key] * fed_avg_freqs[idx] 914 | global_model.load_state_dict(global_para) 915 | 916 | logger.info('global n_training: %d' % len(train_dl_global)) 917 | logger.info('global n_test: %d' % len(test_dl_global)) 918 | 919 | global_model.to(device) 920 | train_acc = compute_accuracy(global_model, train_dl_global, device=device) 921 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, device=device) 922 | 923 | 924 | logger.info('>> Global Model Train accuracy: %f' % train_acc) 925 | logger.info('>> Global Model Test accuracy: %f' % test_acc) 926 | 927 | 928 | elif args.alg == 'fedprox': 929 | logger.info("Initializing nets") 930 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args) 931 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args) 932 | global_model = global_models[0] 933 | 934 | global_para = global_model.state_dict() 935 | 936 | if args.is_same_initial: 937 | for net_id, net in nets.items(): 938 | net.load_state_dict(global_para) 939 | 940 | for round in range(args.comm_round): 941 | logger.info("in comm round:" + str(round)) 942 | 943 | arr = np.arange(args.n_parties) 944 | np.random.shuffle(arr) 945 | selected = 
arr[:int(args.n_parties * args.sample)] 946 | 947 | global_para = global_model.state_dict() 948 | if round == 0: 949 | if args.is_same_initial: 950 | for idx in selected: 951 | nets[idx].load_state_dict(global_para) 952 | else: 953 | for idx in selected: 954 | nets[idx].load_state_dict(global_para) 955 | 956 | local_train_net_fedprox(nets, selected, global_model, args, net_dataidx_map, test_dl = test_dl_global, device=device) 957 | global_model.to('cpu') 958 | 959 | # update global model 960 | total_data_points = sum([len(net_dataidx_map[r]) for r in selected]) 961 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected] 962 | 963 | for idx in range(len(selected)): 964 | net_para = nets[selected[idx]].cpu().state_dict() 965 | if idx == 0: 966 | for key in net_para: 967 | global_para[key] = net_para[key] * fed_avg_freqs[idx] 968 | else: 969 | for key in net_para: 970 | global_para[key] += net_para[key] * fed_avg_freqs[idx] 971 | global_model.load_state_dict(global_para) 972 | 973 | 974 | logger.info('global n_training: %d' % len(train_dl_global)) 975 | logger.info('global n_test: %d' % len(test_dl_global)) 976 | 977 | 978 | global_model.to(device) 979 | train_acc = compute_accuracy(global_model, train_dl_global, device=device) 980 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, device=device) 981 | 982 | 983 | logger.info('>> Global Model Train accuracy: %f' % train_acc) 984 | logger.info('>> Global Model Test accuracy: %f' % test_acc) 985 | 986 | elif args.alg == 'scaffold': 987 | logger.info("Initializing nets") 988 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args) 989 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args) 990 | global_model = global_models[0] 991 | 992 | c_nets, _, _ = init_nets(args.net_config, args.dropout_p, args.n_parties, args) 993 | c_globals, _, _ = 
init_nets(args.net_config, 0, 1, args) 994 | c_global = c_globals[0] 995 | c_global_para = c_global.state_dict() 996 | for net_id, net in c_nets.items(): 997 | net.load_state_dict(c_global_para) 998 | 999 | global_para = global_model.state_dict() 1000 | if args.is_same_initial: 1001 | for net_id, net in nets.items(): 1002 | net.load_state_dict(global_para) 1003 | 1004 | 1005 | for round in range(args.comm_round): 1006 | logger.info("in comm round:" + str(round)) 1007 | 1008 | arr = np.arange(args.n_parties) 1009 | np.random.shuffle(arr) 1010 | selected = arr[:int(args.n_parties * args.sample)] 1011 | 1012 | global_para = global_model.state_dict() 1013 | if round == 0: 1014 | if args.is_same_initial: 1015 | for idx in selected: 1016 | nets[idx].load_state_dict(global_para) 1017 | else: 1018 | for idx in selected: 1019 | nets[idx].load_state_dict(global_para) 1020 | 1021 | local_train_net_scaffold(nets, selected, global_model, c_nets, c_global, args, net_dataidx_map, test_dl = test_dl_global, device=device) 1022 | # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device) 1023 | 1024 | # update global model 1025 | total_data_points = sum([len(net_dataidx_map[r]) for r in selected]) 1026 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected] 1027 | 1028 | for idx in range(len(selected)): 1029 | net_para = nets[selected[idx]].cpu().state_dict() 1030 | if idx == 0: 1031 | for key in net_para: 1032 | global_para[key] = net_para[key] * fed_avg_freqs[idx] 1033 | else: 1034 | for key in net_para: 1035 | global_para[key] += net_para[key] * fed_avg_freqs[idx] 1036 | global_model.load_state_dict(global_para) 1037 | 1038 | 1039 | logger.info('global n_training: %d' % len(train_dl_global)) 1040 | logger.info('global n_test: %d' % len(test_dl_global)) 1041 | 1042 | global_model.to(device) 1043 | train_acc = compute_accuracy(global_model, train_dl_global, device=device) 1044 | test_acc, conf_matrix = compute_accuracy(global_model, 
test_dl_global, get_confusion_matrix=True, device=device) 1045 | 1046 | logger.info('>> Global Model Train accuracy: %f' % train_acc) 1047 | logger.info('>> Global Model Test accuracy: %f' % test_acc) 1048 | 1049 | elif args.alg == 'fednova': 1050 | logger.info("Initializing nets") 1051 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args) 1052 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args) 1053 | global_model = global_models[0] 1054 | 1055 | d_list = [copy.deepcopy(global_model.state_dict()) for i in range(args.n_parties)] 1056 | d_total_round = copy.deepcopy(global_model.state_dict()) 1057 | for i in range(args.n_parties): 1058 | for key in d_list[i]: 1059 | d_list[i][key] = 0 1060 | for key in d_total_round: 1061 | d_total_round[key] = 0 1062 | 1063 | data_sum = 0 1064 | for i in range(args.n_parties): 1065 | data_sum += len(traindata_cls_counts[i]) 1066 | portion = [] 1067 | for i in range(args.n_parties): 1068 | portion.append(len(traindata_cls_counts[i]) / data_sum) 1069 | 1070 | global_para = global_model.state_dict() 1071 | if args.is_same_initial: 1072 | for net_id, net in nets.items(): 1073 | net.load_state_dict(global_para) 1074 | 1075 | for round in range(args.comm_round): 1076 | logger.info("in comm round:" + str(round)) 1077 | 1078 | arr = np.arange(args.n_parties) 1079 | np.random.shuffle(arr) 1080 | selected = arr[:int(args.n_parties * args.sample)] 1081 | 1082 | global_para = global_model.state_dict() 1083 | if round == 0: 1084 | if args.is_same_initial: 1085 | for idx in selected: 1086 | nets[idx].load_state_dict(global_para) 1087 | else: 1088 | for idx in selected: 1089 | nets[idx].load_state_dict(global_para) 1090 | 1091 | _, a_list, d_list, n_list = local_train_net_fednova(nets, selected, global_model, args, net_dataidx_map, test_dl = test_dl_global, device=device) 1092 | total_n = sum(n_list) 1093 | #print("total_n:", total_n) 1094 | 
d_total_round = copy.deepcopy(global_model.state_dict()) 1095 | for key in d_total_round: 1096 | d_total_round[key] = 0.0 1097 | 1098 | for i in range(len(selected)): 1099 | d_para = d_list[i] 1100 | for key in d_para: 1101 | #if d_total_round[key].type == 'torch.LongTensor': 1102 | # d_total_round[key] += (d_para[key] * n_list[i] / total_n).type(torch.LongTensor) 1103 | #else: 1104 | d_total_round[key] += d_para[key] * n_list[i] / total_n 1105 | 1106 | 1107 | # for i in range(len(selected)): 1108 | # d_total_round = d_total_round + d_list[i] * n_list[i] / total_n 1109 | 1110 | # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device) 1111 | 1112 | # update global model 1113 | coeff = 0.0 1114 | for i in range(len(selected)): 1115 | coeff = coeff + a_list[i] * n_list[i]/total_n 1116 | 1117 | updated_model = global_model.state_dict() 1118 | for key in updated_model: 1119 | #print(updated_model[key]) 1120 | if updated_model[key].type() == 'torch.LongTensor': 1121 | updated_model[key] -= (coeff * d_total_round[key]).type(torch.LongTensor) 1122 | elif updated_model[key].type() == 'torch.cuda.LongTensor': 1123 | updated_model[key] -= (coeff * d_total_round[key]).type(torch.cuda.LongTensor) 1124 | else: 1125 | #print(updated_model[key].type()) 1126 | #print((coeff*d_total_round[key].type())) 1127 | updated_model[key] -= coeff * d_total_round[key] 1128 | global_model.load_state_dict(updated_model) 1129 | 1130 | 1131 | logger.info('global n_training: %d' % len(train_dl_global)) 1132 | logger.info('global n_test: %d' % len(test_dl_global)) 1133 | 1134 | global_model.to(device) 1135 | train_acc = compute_accuracy(global_model, train_dl_global, device=device) 1136 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, device=device) 1137 | 1138 | 1139 | logger.info('>> Global Model Train accuracy: %f' % train_acc) 1140 | logger.info('>> Global Model Test accuracy: %f' % test_acc) 1141 | 1142 | elif args.alg 
== 'moon': 1143 | logger.info("Initializing nets") 1144 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args) 1145 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 0, 1, args) 1146 | global_model = global_models[0] 1147 | 1148 | global_para = global_model.state_dict() 1149 | if args.is_same_initial: 1150 | for net_id, net in nets.items(): 1151 | net.load_state_dict(global_para) 1152 | 1153 | old_nets_pool = [] 1154 | old_nets = copy.deepcopy(nets) 1155 | for _, net in old_nets.items(): 1156 | net.eval() 1157 | for param in net.parameters(): 1158 | param.requires_grad = False 1159 | 1160 | for round in range(args.comm_round): 1161 | logger.info("in comm round:" + str(round)) 1162 | 1163 | arr = np.arange(args.n_parties) 1164 | np.random.shuffle(arr) 1165 | selected = arr[:int(args.n_parties * args.sample)] 1166 | 1167 | global_para = global_model.state_dict() 1168 | if round == 0: 1169 | if args.is_same_initial: 1170 | for idx in selected: 1171 | nets[idx].load_state_dict(global_para) 1172 | else: 1173 | for idx in selected: 1174 | nets[idx].load_state_dict(global_para) 1175 | 1176 | local_train_net_moon(nets, selected, args, net_dataidx_map, test_dl = test_dl_global, global_model=global_model, 1177 | prev_model_pool=old_nets_pool, round=round, device=device) 1178 | # local_train_net(nets, args, net_dataidx_map, local_split=False, device=device) 1179 | 1180 | # update global model 1181 | total_data_points = sum([len(net_dataidx_map[r]) for r in selected]) 1182 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected] 1183 | 1184 | for idx in range(len(selected)): 1185 | net_para = nets[selected[idx]].cpu().state_dict() 1186 | if idx == 0: 1187 | for key in net_para: 1188 | global_para[key] = net_para[key] * fed_avg_freqs[idx] 1189 | else: 1190 | for key in net_para: 1191 | global_para[key] += net_para[key] * fed_avg_freqs[idx] 1192 | 
global_model.load_state_dict(global_para) 1193 | 1194 | logger.info('global n_training: %d' % len(train_dl_global)) 1195 | logger.info('global n_test: %d' % len(test_dl_global)) 1196 | 1197 | global_model.to(device) 1198 | train_acc = compute_accuracy(global_model, train_dl_global, moon_model=True, device=device) 1199 | test_acc, conf_matrix = compute_accuracy(global_model, test_dl_global, get_confusion_matrix=True, moon_model=True, device=device) 1200 | 1201 | 1202 | logger.info('>> Global Model Train accuracy: %f' % train_acc) 1203 | logger.info('>> Global Model Test accuracy: %f' % test_acc) 1204 | 1205 | old_nets = copy.deepcopy(nets) 1206 | for _, net in old_nets.items(): 1207 | net.eval() 1208 | for param in net.parameters(): 1209 | param.requires_grad = False 1210 | if len(old_nets_pool) < 1: 1211 | old_nets_pool.append(old_nets) 1212 | else: 1213 | old_nets_pool[0] = old_nets 1214 | 1215 | elif args.alg == 'local_training': 1216 | logger.info("Initializing nets") 1217 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, args.n_parties, args) 1218 | arr = np.arange(args.n_parties) 1219 | local_train_net(nets, arr, args, net_dataidx_map, test_dl = test_dl_global, device=device) 1220 | 1221 | elif args.alg == 'all_in': 1222 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.dropout_p, 1, args) 1223 | n_epoch = args.epochs 1224 | nets[0].to(device) 1225 | trainacc, testacc = train_net(0, nets[0], train_dl_global, test_dl_global, n_epoch, args.lr, args.optimizer, device=device) 1226 | 1227 | logger.info("All in test acc: %f" % testacc) 1228 | 1229 | 1230 | -------------------------------------------------------------------------------- /femnist-dis.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/femnist-dis.npy 
-------------------------------------------------------------------------------- /figures/100parties/cifar10-homo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/100parties/cifar10-homo.png -------------------------------------------------------------------------------- /figures/100parties/cifar10-lb2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/100parties/cifar10-lb2.png -------------------------------------------------------------------------------- /figures/100parties/cifar10-lbdir.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/100parties/cifar10-lbdir.png -------------------------------------------------------------------------------- /figures/100parties/cifar10-quan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/100parties/cifar10-quan.png -------------------------------------------------------------------------------- /figures/10parties/cifar10-iid-diff-quantity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/10parties/cifar10-iid-diff-quantity.png -------------------------------------------------------------------------------- /figures/10parties/cifar10-noise.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/10parties/cifar10-noise.png -------------------------------------------------------------------------------- /figures/10parties/cifar10-noniid-label2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/10parties/cifar10-noniid-label2.png -------------------------------------------------------------------------------- /figures/10parties/cifar10-noniid-labeldir.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/10parties/cifar10-noniid-labeldir.png -------------------------------------------------------------------------------- /figures/heavy-model/resnet-noise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/heavy-model/resnet-noise.png -------------------------------------------------------------------------------- /figures/heavy-model/vgg-lbdir.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/NIID-Bench/61b901ab645e62391772e9898384c2c3485b7a6c/figures/heavy-model/vgg-lbdir.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torchvision.models as models 6 | from resnetcifar import ResNet18_cifar10, ResNet50_cifar10 7 | 8 | 9 | # class MLP_header(nn.Module): 10 | # def __init__(self,): 11 | # super(MLP_header, 
class MLP_header(nn.Module):
    """Fully connected feature extractor: a stack of Linear + ReLU layers.

    The input is flattened to ``(-1, input_dim)`` and passed through one
    ``nn.Linear`` per entry of ``hidden_dims``, each followed by ReLU.
    There is no classification layer. ``dropout_p`` is stored on the
    instance but ``forward`` never applies dropout.
    """

    def __init__(self, input_dim, hidden_dims, dropout_p=0.0):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.dropout_p = dropout_p

        # Layer widths, from the input to the last hidden layer.
        self.dims = [input_dim, *hidden_dims]

        self.layers = nn.ModuleList(
            [
                nn.Linear(fan_in, fan_out, bias=True)
                for fan_in, fan_out in zip(self.dims[:-1], self.dims[1:])
            ]
        )

        self.__init_net_weights__()

    def __init_net_weights__(self):
        """Initialise weights from N(0, 0.1) and set every bias to 0.1."""
        for linear in self.layers:
            linear.weight.data.normal_(0.0, 0.1)
            linear.bias.data.fill_(0.1)

    def forward(self, x):
        """Return the ReLU-activated output of the last hidden layer."""
        hidden = x.view(-1, self.input_dim)
        for linear in self.layers:
            hidden = F.relu(linear(hidden))
        return hidden
class ConvBlock(nn.Module):
    """Two 5x5 conv + ReLU + 2x2 max-pool stages, flattened to 400 features.

    The first convolution takes 3 input channels; the output is reshaped
    to rows of ``16 * 5 * 5`` values.
    """

    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

    def forward(self, x):
        """Apply both conv stages and return the flattened feature rows."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        return x.view(-1, 16 * 5 * 5)
class FCBlockVGG(nn.Module):
    """Three-layer fully connected classifier head with input dropout.

    Args:
        input_dim: Number of features in each input row.
        hidden_dims: Sequence whose first two entries are the widths of
            the two hidden layers.
        output_dim: Number of output logits.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=10):
        super(FCBlockVGG, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.fc3 = nn.Linear(hidden_dims[1], output_dim)

    def forward(self, x):
        """Return logits; dropout is applied only while the module is training."""
        # F.dropout defaults to training=True, so the module's mode must be
        # passed explicitly or dropout would stay active after .eval().
        x = F.dropout(x, training=self.training)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
2 layers 215 | self.fc1 = nn.Linear(input_dim, hidden_dims[0]) 216 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1]) 217 | self.fc3 = nn.Linear(hidden_dims[1], output_dim) 218 | 219 | def forward(self, x): 220 | x = self.pool(F.relu(self.conv1(x))) 221 | x = self.pool(F.relu(self.conv2(x))) 222 | x = x.view(-1, 16 * 5 * 5) 223 | 224 | x = F.relu(self.fc1(x)) 225 | x = F.relu(self.fc2(x)) 226 | x = self.fc3(x) 227 | return x 228 | 229 | 230 | # a simple perceptron model for generated 3D data 231 | class PerceptronModel(nn.Module): 232 | def __init__(self, input_dim=3, output_dim=2): 233 | super(PerceptronModel, self).__init__() 234 | 235 | self.fc1 = nn.Linear(input_dim, output_dim) 236 | 237 | def forward(self, x): 238 | 239 | x = self.fc1(x) 240 | return x 241 | 242 | class SimpleCNNMNIST_header(nn.Module): 243 | def __init__(self, input_dim, hidden_dims, output_dim=10): 244 | super(SimpleCNNMNIST_header, self).__init__() 245 | self.conv1 = nn.Conv2d(1, 6, 5) 246 | self.relu = nn.ReLU() 247 | self.pool = nn.MaxPool2d(2, 2) 248 | self.conv2 = nn.Conv2d(6, 16, 5) 249 | 250 | # for now, we hard coded this network 251 | # i.e. we fix the number of hidden layers i.e. 2 layers 252 | self.fc1 = nn.Linear(input_dim, hidden_dims[0]) 253 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1]) 254 | #self.fc3 = nn.Linear(hidden_dims[1], output_dim) 255 | 256 | def forward(self, x): 257 | 258 | x = self.pool(self.relu(self.conv1(x))) 259 | x = self.pool(self.relu(self.conv2(x))) 260 | x = x.view(-1, 16 * 4 * 4) 261 | 262 | x = self.relu(self.fc1(x)) 263 | x = self.relu(self.fc2(x)) 264 | # x = self.fc3(x) 265 | return x 266 | 267 | class SimpleCNNMNIST(nn.Module): 268 | def __init__(self, input_dim, hidden_dims, output_dim=10): 269 | super(SimpleCNNMNIST, self).__init__() 270 | self.conv1 = nn.Conv2d(1, 6, 5) 271 | self.pool = nn.MaxPool2d(2, 2) 272 | self.conv2 = nn.Conv2d(6, 16, 5) 273 | 274 | # for now, we hard coded this network 275 | # i.e. 
class SimpleCNNContainer(nn.Module):
    """A testing CNN container that builds a CNN from the given dimensions.

    Args:
        input_channel: Number of channels of the input images.
        num_filters (list): Number of convolution filters per conv layer.
        kernel_size: Kernel size shared by both conv layers.
        input_dim: Number of features fed to the first linear layer.
        hidden_dims (list): Number of neurons in the hidden layers.
        output_dim: Number of output logits.

    Assumptions:
        i) only two conv layers and three linear layers (including the
           output layer) are used
        ii) the kernel size of the two conv layers is identical
    """

    def __init__(self, input_channel, num_filters, kernel_size, input_dim, hidden_dims, output_dim=10):
        super(SimpleCNNContainer, self).__init__()
        first_filters, second_filters = num_filters[0], num_filters[1]
        self.conv1 = nn.Conv2d(input_channel, first_filters, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(first_filters, second_filters, kernel_size)

        # The depth of the classifier is fixed at two hidden layers.
        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.fc3 = nn.Linear(hidden_dims[1], output_dim)

    def forward(self, x):
        """Return the logits produced by the conv stages and linear layers."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        shape = x.size()
        x = x.view(-1, shape[1] * shape[2] * shape[3])

        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)
class ModerateCNNCeleba(nn.Module):
    """Moderate-size CNN with a two-logit output.

    Three blocks of (conv, ReLU, conv, ReLU, 2x2 max-pool) are followed by
    a three-layer classifier. The conv output is reshaped into rows of
    4096 features, which for 3x32x32 inputs is one row per sample.
    """

    def __init__(self):
        super(ModerateCNNCeleba, self).__init__()

        # (in, mid, out) channel counts of the three conv blocks.
        block_channels = ((3, 32, 64), (64, 128, 128), (128, 256, 256))
        conv_modules = []
        for c_in, c_mid, c_out in block_channels:
            conv_modules.extend(
                [
                    nn.Conv2d(in_channels=c_in, out_channels=c_mid, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(in_channels=c_mid, out_channels=c_out, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=2, stride=2),
                ]
            )
        self.conv_layer = nn.Sequential(*conv_modules)

        self.fc_layer = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(4096, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        """Return two logits per 4096-feature row of the conv output."""
        features = self.conv_layer(x)
        return self.fc_layer(features.view(-1, 4096))
out_channels=256, kernel_size=3, padding=1), 488 | nn.ReLU(inplace=True), 489 | nn.MaxPool2d(kernel_size=2, stride=2), 490 | ) 491 | 492 | self.fc_layer = nn.Sequential( 493 | nn.Dropout(p=0.1), 494 | nn.Linear(2304, 1024), 495 | nn.ReLU(inplace=True), 496 | nn.Linear(1024, 512), 497 | nn.ReLU(inplace=True), 498 | nn.Dropout(p=0.1), 499 | nn.Linear(512, 10) 500 | ) 501 | 502 | def forward(self, x): 503 | x = self.conv_layer(x) 504 | x = x.view(x.size(0), -1) 505 | x = self.fc_layer(x) 506 | return x 507 | 508 | 509 | class ModerateCNNContainer(nn.Module): 510 | def __init__(self, input_channels, num_filters, kernel_size, input_dim, hidden_dims, output_dim=10): 511 | super(ModerateCNNContainer, self).__init__() 512 | 513 | ## 514 | self.conv_layer = nn.Sequential( 515 | # Conv Layer block 1 516 | nn.Conv2d(in_channels=input_channels, out_channels=num_filters[0], kernel_size=kernel_size, padding=1), 517 | nn.ReLU(inplace=True), 518 | nn.Conv2d(in_channels=num_filters[0], out_channels=num_filters[1], kernel_size=kernel_size, padding=1), 519 | nn.ReLU(inplace=True), 520 | nn.MaxPool2d(kernel_size=2, stride=2), 521 | 522 | # Conv Layer block 2 523 | nn.Conv2d(in_channels=num_filters[1], out_channels=num_filters[2], kernel_size=kernel_size, padding=1), 524 | nn.ReLU(inplace=True), 525 | nn.Conv2d(in_channels=num_filters[2], out_channels=num_filters[3], kernel_size=kernel_size, padding=1), 526 | nn.ReLU(inplace=True), 527 | nn.MaxPool2d(kernel_size=2, stride=2), 528 | nn.Dropout2d(p=0.05), 529 | 530 | # Conv Layer block 3 531 | nn.Conv2d(in_channels=num_filters[3], out_channels=num_filters[4], kernel_size=kernel_size, padding=1), 532 | nn.ReLU(inplace=True), 533 | nn.Conv2d(in_channels=num_filters[4], out_channels=num_filters[5], kernel_size=kernel_size, padding=1), 534 | nn.ReLU(inplace=True), 535 | nn.MaxPool2d(kernel_size=2, stride=2), 536 | ) 537 | 538 | self.fc_layer = nn.Sequential( 539 | nn.Dropout(p=0.1), 540 | nn.Linear(input_dim, hidden_dims[0]), 541 | 
class ModelFedCon(nn.Module):
    """Feature extractor followed by a projection MLP and a classifier.

    Args:
        base_model: Name of the feature extractor. Accepted values are the
            ResNet-50 aliases ("resnet50", "resnet50-cifar10",
            "resnet50-cifar100", "resnet50-smallkernel"), the ResNet-18
            aliases ("resnet18", "resnet18-cifar10"), "mlp", "simple-cnn"
            and "simple-cnn-mnist".
        out_dim: Output width of the projection MLP.
        n_classes: Number of output logits.
        net_configs: Layer widths; only read, and required, for "mlp".

    Raises:
        ValueError: If ``base_model`` is not one of the accepted names, or
            if it is "mlp" and ``net_configs`` is None.
    """

    def __init__(self, base_model, out_dim, n_classes, net_configs=None):
        super(ModelFedCon, self).__init__()

        if base_model in ("resnet50-cifar10", "resnet50-cifar100", "resnet50-smallkernel", "resnet50"):
            basemodel = ResNet50_cifar10()
            # Drop the final fully connected layer; keep the backbone only.
            self.features = nn.Sequential(*list(basemodel.children())[:-1])
            num_ftrs = basemodel.fc.in_features
        elif base_model in ("resnet18-cifar10", "resnet18"):
            basemodel = ResNet18_cifar10()
            self.features = nn.Sequential(*list(basemodel.children())[:-1])
            num_ftrs = basemodel.fc.in_features
        elif base_model == "mlp":
            if net_configs is None:
                raise ValueError("net_configs is required when base_model is 'mlp'")
            self.features = MLP_header(input_dim=net_configs[0], hidden_dims=net_configs[1:-1])
            num_ftrs = net_configs[-2]
        elif base_model == 'simple-cnn':
            self.features = SimpleCNN_header(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=n_classes)
            num_ftrs = 84
        elif base_model == 'simple-cnn-mnist':
            self.features = SimpleCNNMNIST_header(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=n_classes)
            num_ftrs = 84
        else:
            # Without this branch an unknown name surfaced later as an
            # UnboundLocalError on num_ftrs.
            raise ValueError("Unsupported base_model: %r" % (base_model,))

        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)

        # last layer
        self.l3 = nn.Linear(out_dim, n_classes)
        self.num_ftrs = num_ftrs

    def _get_basemodel(self, model_name):
        """Look up ``model_name`` in ``self.model_dict`` if that attribute exists.

        Raises:
            ValueError: If there is no ``model_dict`` attribute or it has no
                entry for ``model_name``.
        """
        # __init__ never sets model_dict, so treat a missing registry as empty.
        registry = getattr(self, "model_dict", {})
        try:
            return registry[model_name]
        except KeyError:
            raise ValueError(
                "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
            ) from None

    def forward(self, x):
        """Return ``(h, x, y)``: backbone features, projection and logits."""
        h = self.features(x)
        h = h.reshape(-1, self.num_ftrs)
        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x)

        y = self.l3(x)
        return h, x, y
nn.Sequential(*list(basemodel.children())[:-1]) 644 | num_ftrs = basemodel.fc.in_features 645 | elif base_model == "mlp": 646 | self.features = MLP_header(input_dim=net_configs[0], hidden_dims=net_configs[1:-1]) 647 | num_ftrs = net_configs[-2] 648 | elif base_model == 'simple-cnn': 649 | self.features = SimpleCNN_header(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=n_classes) 650 | num_ftrs = 84 651 | elif base_model == 'simple-cnn-mnist': 652 | self.features = SimpleCNNMNIST_header(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=n_classes) 653 | num_ftrs = 84 654 | 655 | #summary(self.features.to('cuda:0'), (3,32,32)) 656 | #print("features:", self.features) 657 | # projection MLP 658 | # self.l1 = nn.Linear(num_ftrs, num_ftrs) 659 | # self.l2 = nn.Linear(num_ftrs, out_dim) 660 | 661 | # last layer 662 | self.l3 = nn.Linear(num_ftrs, n_classes) 663 | self.num_ftrs = num_ftrs 664 | 665 | def _get_basemodel(self, model_name): 666 | try: 667 | model = self.model_dict[model_name] 668 | #print("Feature extractor:", model_name) 669 | return model 670 | except: 671 | raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50") 672 | 673 | def forward(self, x): 674 | h = self.features(x) 675 | h = h.reshape(-1, self.num_ftrs) 676 | #print("h before:", h) 677 | #print("h size:", h.size()) 678 | #h = h.squeeze() 679 | #print("h after:", h) 680 | # x = self.l1(h) 681 | # x = F.relu(x) 682 | # x = self.l2(x) 683 | 684 | y = self.l3(h) 685 | return h, h, y 686 | 687 | -------------------------------------------------------------------------------- /models/celeba_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | Architecture based on InfoGAN paper. 
class QHead(nn.Module):
    """Auxiliary head producing categorical logits plus a mean and variance.

    A 4x4 convolution and batch norm feed three 1x1 convolutions with 100,
    1 and 1 output channels respectively.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(256, 128, 4, bias=False)
        self.bn1 = nn.BatchNorm2d(128)

        self.conv_disc = nn.Conv2d(128, 100, 1)

        self.conv_mu = nn.Conv2d(128, 1, 1)
        self.conv_var = nn.Conv2d(128, 1, 1)

    @staticmethod
    def _drop_unit_dims(t):
        """Remove size-1 channel/spatial dims while always keeping dim 0."""
        # A bare .squeeze() would also remove the batch dimension when the
        # batch holds a single sample; squeeze the trailing dims explicitly.
        return t.squeeze(3).squeeze(2).squeeze(1)

    def forward(self, x):
        """Return ``(disc_logits, mu, var)``, each with the batch dim first."""
        x = F.leaky_relu(self.bn1(self.conv1(x)), 0.1, inplace=True)

        disc_logits = self._drop_unit_dims(self.conv_disc(x))

        # Not used during training for celeba dataset.
        mu = self._drop_unit_dims(self.conv_mu(x))
        var = torch.exp(self._drop_unit_dims(self.conv_var(x)))

        return disc_logits, mu, var
class QHead(nn.Module):
    """Auxiliary head producing categorical logits plus a mean and variance.

    A 1x1 convolution and batch norm feed three 1x1 convolutions with 10,
    2 and 2 output channels respectively.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1024, 128, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(128)

        self.conv_disc = nn.Conv2d(128, 10, 1)
        self.conv_mu = nn.Conv2d(128, 2, 1)
        self.conv_var = nn.Conv2d(128, 2, 1)

    @staticmethod
    def _drop_unit_dims(t):
        """Remove size-1 channel/spatial dims while always keeping dim 0."""
        # A bare .squeeze() would also remove the batch dimension when the
        # batch holds a single sample; squeeze the trailing dims explicitly.
        return t.squeeze(3).squeeze(2).squeeze(1)

    def forward(self, x):
        """Return ``(disc_logits, mu, var)``, each with the batch dim first."""
        x = F.leaky_relu(self.bn1(self.conv1(x)), 0.1, inplace=True)

        disc_logits = self._drop_unit_dims(self.conv_disc(x))

        mu = self._drop_unit_dims(self.conv_mu(x))
        var = torch.exp(self._drop_unit_dims(self.conv_var(x)))

        return disc_logits, mu, var
class QHead(nn.Module):
    """Auxiliary head producing categorical logits plus a mean and variance.

    A 4x4 convolution and batch norm feed three 1x1 convolutions with 40,
    4 and 4 output channels respectively.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(256, 128, 4, bias=False)
        self.bn1 = nn.BatchNorm2d(128)

        self.conv_disc = nn.Conv2d(128, 40, 1)
        self.conv_mu = nn.Conv2d(128, 4, 1)
        self.conv_var = nn.Conv2d(128, 4, 1)

    @staticmethod
    def _drop_unit_dims(t):
        """Remove size-1 channel/spatial dims while always keeping dim 0."""
        # A bare .squeeze() would also remove the batch dimension when the
        # batch holds a single sample; squeeze the trailing dims explicitly.
        return t.squeeze(3).squeeze(2).squeeze(1)

    def forward(self, x):
        """Return ``(disc_logits, mu, var)``, each with the batch dim first."""
        x = F.leaky_relu(self.bn1(self.conv1(x)), 0.1, inplace=True)

        disc_logits = self._drop_unit_dims(self.conv_disc(x))

        mu = self._drop_unit_dims(self.conv_mu(x))
        var = torch.exp(self._drop_unit_dims(self.conv_var(x)))

        return disc_logits, mu, var
class Bottleneck(nn.Module):
    """Residual block made of a 1x1, a 3x3 and a 1x1 convolution plus a shortcut.

    The stride is applied by the 3x3 convolution (``conv2``), as torchvision
    does, whereas the original implementation from "Deep residual learning for
    image recognition" (https://arxiv.org/abs/1512.03385) places it on the first
    1x1 convolution (``conv1``). This variant is also known as ResNet V1.5 and
    improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    # The block outputs ``planes * expansion`` channels.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        mid_channels = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion

        # Sub-modules are registered in a fixed order (conv1, bn1, conv2, bn2,
        # conv3, bn3, relu, downsample) so parameter ordering stays stable.
        # Both conv2 and the downsample branch reduce resolution when stride != 1.
        self.conv1 = conv1x1(inplanes, mid_channels)
        self.bn1 = make_norm(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, groups, dilation)
        self.bn2 = make_norm(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels)
        self.bn3 = make_norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/norm stages and add the (optionally projected) input."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is the input itself unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
ValueError("replace_stride_with_dilation should be None " 248 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 249 | self.groups = groups 250 | self.base_width = width_per_group 251 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 252 | bias=False) 253 | self.bn1 = norm_layer(self.inplanes) 254 | self.relu = nn.ReLU(inplace=True) 255 | self.layer1 = self._make_layer(block, 64, layers[0]) 256 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 257 | dilate=replace_stride_with_dilation[0]) 258 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 259 | dilate=replace_stride_with_dilation[1]) 260 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 261 | dilate=replace_stride_with_dilation[2]) 262 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 263 | self.fc = nn.Linear(512 * block.expansion, num_classes) 264 | 265 | for m in self.modules(): 266 | if isinstance(m, nn.Conv2d): 267 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 268 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 269 | nn.init.constant_(m.weight, 1) 270 | nn.init.constant_(m.bias, 0) 271 | 272 | # Zero-initialize the last BN in each residual branch, 273 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
274 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 275 | if zero_init_residual: 276 | for m in self.modules(): 277 | if isinstance(m, Bottleneck): 278 | nn.init.constant_(m.bn3.weight, 0) 279 | elif isinstance(m, BasicBlock): 280 | nn.init.constant_(m.bn2.weight, 0) 281 | 282 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 283 | norm_layer = self._norm_layer 284 | downsample = None 285 | previous_dilation = self.dilation 286 | if dilate: 287 | self.dilation *= stride 288 | stride = 1 289 | if stride != 1 or self.inplanes != planes * block.expansion: 290 | downsample = nn.Sequential( 291 | conv1x1(self.inplanes, planes * block.expansion, stride), 292 | norm_layer(planes * block.expansion), 293 | ) 294 | 295 | layers = [] 296 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 297 | self.base_width, previous_dilation, norm_layer)) 298 | self.inplanes = planes * block.expansion 299 | for _ in range(1, blocks): 300 | layers.append(block(self.inplanes, planes, groups=self.groups, 301 | base_width=self.base_width, dilation=self.dilation, 302 | norm_layer=norm_layer)) 303 | 304 | return nn.Sequential(*layers) 305 | 306 | def _forward_impl(self, x): 307 | # See note [TorchScript super()] 308 | x = self.conv1(x) 309 | x = self.bn1(x) 310 | x = self.relu(x) 311 | 312 | x = self.layer1(x) 313 | x = self.layer2(x) 314 | x = self.layer3(x) 315 | x = self.layer4(x) 316 | 317 | x = self.avgpool(x) 318 | x = torch.flatten(x, 1) 319 | x = self.fc(x) 320 | 321 | return x 322 | 323 | def forward(self, x): 324 | return self._forward_impl(x) 325 | 326 | 327 | def ResNet18_cifar10(**kwargs): 328 | r"""ResNet-18 model from 329 | `"Deep Residual Learning for Image Recognition" `_ 330 | 331 | Args: 332 | pretrained (bool): If True, returns a model pre-trained on ImageNet 333 | progress (bool): If True, displays a progress bar of the download to stderr 334 | """ 335 | return ResNetCifar10(BasicBlock, 
def ResNet50_cifar10(**kwargs):
    r"""Build a ResNet-50-style ``ResNetCifar10``.

    The network uses ``Bottleneck`` blocks with stage depths ``[3, 4, 6, 3]``,
    following "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/abs/1512.03385).

    Args:
        **kwargs: forwarded unchanged to ``ResNetCifar10.__init__`` (for example
            ``num_classes``, which defaults to 1000 there, ``zero_init_residual``,
            ``groups``, ``width_per_group``, ``replace_stride_with_dilation``,
            ``norm_layer``).

    Note:
        No pre-trained weights are loaded. ``pretrained`` and ``progress`` are
        not accepted by ``ResNetCifar10.__init__``.
    """
    return ResNetCifar10(Bottleneck, [3, 4, 6, 3], **kwargs)
# Sweep seeds x partitions x {a9a, covtype} with the MLP model.
# FedAvg / SCAFFOLD / FedNova run once per setting; FedProx is swept over mu.
for init_seed in 0 1 2
do
    for partition in noniid-labeldir noniid-#label1 iid-diff-quantity homo
    do
        for dataset in a9a covtype
        do
            for alg in fedavg scaffold fednova
            do
                python experiments.py --model=mlp \
                    --dataset="$dataset" \
                    --alg="$alg" \
                    --lr=0.01 \
                    --batch-size=64 \
                    --epochs=10 \
                    --n_parties=10 \
                    --rho=0.9 \
                    --comm_round=50 \
                    --partition="$partition" \
                    --beta=0.5 \
                    --device='cuda:0' \
                    --datadir='./data/' \
                    --logdir='./logs/' \
                    --noise=0 \
                    --init_seed="$init_seed"
            done

            for mu in 0.001 0.01 0.1 1
            do
                # Every option line except the last must end in a backslash;
                # without it the command stops early and the remaining flags
                # are executed as a separate (failing) command.
                python experiments.py --model=mlp \
                    --dataset="$dataset" \
                    --alg=fedprox \
                    --lr=0.01 \
                    --batch-size=64 \
                    --epochs=10 \
                    --n_parties=10 \
                    --rho=0.9 \
                    --mu="$mu" \
                    --comm_round=50 \
                    --partition="$partition" \
                    --beta=0.5 \
                    --device='cuda:0' \
                    --datadir='./data/' \
                    --logdir='./logs/' \
                    --noise=0 \
                    --init_seed="$init_seed"
            done

        done
    done
done
# Sweep seeds x {real, homo} partitions on the generated dataset with the MLP
# model and 4 parties. FedProx is additionally swept over mu.
for init_seed in 0 1 2
do
    for partition in real homo
    do
        for alg in fedavg scaffold fednova
        do
            python experiments.py --model=mlp \
                --dataset=generated \
                --alg="$alg" \
                --lr=0.01 \
                --batch-size=64 \
                --epochs=10 \
                --n_parties=4 \
                --rho=0.9 \
                --comm_round=50 \
                --partition="$partition" \
                --beta=0.5 \
                --device='cuda:0' \
                --datadir='./data/' \
                --logdir='./logs/' \
                --noise=0 \
                --init_seed="$init_seed"
        done

        for mu in 0.001 0.01 0.1 1
        do
            # Every option line except the last must end in a backslash;
            # without it the command stops early and the remaining flags
            # are executed as a separate (failing) command.
            python experiments.py --model=mlp \
                --dataset=generated \
                --alg=fedprox \
                --lr=0.01 \
                --batch-size=64 \
                --epochs=10 \
                --n_parties=4 \
                --rho=0.9 \
                --mu="$mu" \
                --comm_round=50 \
                --partition="$partition" \
                --beta=0.5 \
                --device='cuda:0' \
                --datadir='./data/' \
                --logdir='./logs/' \
                --noise=0 \
                --init_seed="$init_seed"
        done

    done
done
# Sweep seeds x image datasets on the homogeneous partition with noise=0.1
# (simple-cnn model). FedProx is additionally swept over mu.
for init_seed in 0 1 2
do
    for dataset in cifar10 mnist fmnist svhn
    do
        for alg in fedavg scaffold fednova
        do
            python experiments.py --model=simple-cnn \
                --dataset="$dataset" \
                --alg="$alg" \
                --lr=0.01 \
                --batch-size=64 \
                --epochs=10 \
                --n_parties=10 \
                --rho=0.9 \
                --comm_round=50 \
                --partition=homo \
                --beta=0.5 \
                --device='cuda:0' \
                --datadir='./data/' \
                --logdir='./logs/' \
                --noise=0.1 \
                --init_seed="$init_seed"
        done

        for mu in 0.001 0.01 0.1 1
        do
            # Every option line except the last must end in a backslash;
            # without it the command stops early and the remaining flags
            # are executed as a separate (failing) command.
            python experiments.py --model=simple-cnn \
                --dataset="$dataset" \
                --alg=fedprox \
                --lr=0.01 \
                --batch-size=64 \
                --epochs=10 \
                --n_parties=10 \
                --rho=0.9 \
                --mu="$mu" \
                --comm_round=50 \
                --partition=homo \
                --beta=0.5 \
                --device='cuda:0' \
                --datadir='./data/' \
                --logdir='./logs/' \
                --noise=0.1 \
                --init_seed="$init_seed"
        done

    done
done
# Sweep seeds x partitions on rcv1 with the MLP model (lr=0.1).
# FedAvg / SCAFFOLD / FedNova run once per setting; FedProx is swept over mu.
for init_seed in 0 1 2
do
    for partition in noniid-labeldir noniid-#label1 iid-diff-quantity homo
    do
        for dataset in rcv1
        do
            for alg in fedavg scaffold fednova
            do
                python experiments.py --model=mlp \
                    --dataset="$dataset" \
                    --alg="$alg" \
                    --lr=0.1 \
                    --batch-size=64 \
                    --epochs=10 \
                    --n_parties=10 \
                    --rho=0.9 \
                    --comm_round=50 \
                    --partition="$partition" \
                    --beta=0.5 \
                    --device='cuda:0' \
                    --datadir='./data/' \
                    --logdir='./logs/' \
                    --noise=0 \
                    --init_seed="$init_seed"
            done

            for mu in 0.001 0.01 0.1 1
            do
                # Every option line except the last must end in a backslash;
                # without it the command stops early and the remaining flags
                # are executed as a separate (failing) command.
                python experiments.py --model=mlp \
                    --dataset="$dataset" \
                    --alg=fedprox \
                    --lr=0.1 \
                    --batch-size=64 \
                    --epochs=10 \
                    --n_parties=10 \
                    --rho=0.9 \
                    --mu="$mu" \
                    --comm_round=50 \
                    --partition="$partition" \
                    --beta=0.5 \
                    --device='cuda:0' \
                    --datadir='./data/' \
                    --logdir='./logs/' \
                    --noise=0 \
                    --init_seed="$init_seed"
            done

        done
    done
done
def mkdirs(dirpath):
    """Create ``dirpath`` (and any missing parents) on a best-effort basis.

    An already-existing directory is not an error. Any other OS-level failure
    (permission denied, a regular file in the way, ...) is logged as a warning
    instead of being silently discarded, so a later "file not found" while
    saving results can be traced back to its cause. The function still never
    raises for such failures.

    Args:
        dirpath: path of the directory to create.

    Returns:
        None.
    """
    try:
        # exist_ok=True covers the common "directory is already there" case.
        os.makedirs(dirpath, exist_ok=True)
    except OSError as err:
        logging.getLogger().warning("Could not create directory %r: %s", dirpath, err)
def load_svhn_data(datadir):
    """Load the SVHN train and test splits via ``SVHN_custom``.

    Both splits are created with ``download=True`` and a ``ToTensor``
    transform; only their ``data`` and ``target`` attributes are returned.

    Args:
        datadir: root directory handed to ``SVHN_custom``.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` taken as-is from the two
        dataset objects (no conversion is applied here).
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    train_split = SVHN_custom(datadir, train=True, download=True, transform=to_tensor)
    test_split = SVHN_custom(datadir, train=False, download=True, transform=to_tensor)

    return (train_split.data, train_split.target, test_split.data, test_split.target)
def load_tinyimagenet_data(datadir):
    """Collect image paths and integer labels for the train/val image folders.

    The folders are ``<datadir>/train/`` and ``<datadir>/val/``. Paths are
    built with ``os.path.join`` so ``datadir`` works with or without a trailing
    separator (plain string concatenation produced e.g. ``data./train/`` for
    ``datadir='data'``).

    Args:
        datadir: directory that contains the ``train`` and ``val`` sub-folders.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` of NumPy arrays. The ``X``
        arrays hold the first element of each entry in the dataset's
        ``samples`` list and the ``y`` arrays hold the second element cast to
        ``int``.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    # The trailing '' component keeps the trailing separator the roots had before.
    train_ds = ImageFolder_custom(os.path.join(datadir, 'train', ''), transform=transform)
    val_ds = ImageFolder_custom(os.path.join(datadir, 'val', ''), transform=transform)

    X_train = np.array([s[0] for s in train_ds.samples])
    y_train = np.array([int(s[1]) for s in train_ds.samples])
    X_test = np.array([s[0] for s in val_ds.samples])
    y_test = np.array([int(s[1]) for s in val_ds.samples])

    return (X_train, y_train, X_test, y_test)
def record_net_data_stats(y_train, net_dataidx_map, logdir):
    """Count, for every party, how many samples of each label it holds.

    Args:
        y_train: label array; it is indexed with each party's index collection.
        net_dataidx_map: mapping ``party id -> indices into y_train``.
        logdir: not used by this function.

    Returns:
        Dict ``party id -> {label: count}``. The same dict is logged at INFO
        level.
    """
    net_cls_counts = {}
    for party, indices in net_dataidx_map.items():
        labels, counts = np.unique(y_train[indices], return_counts=True)
        net_cls_counts[party] = dict(zip(labels, counts))

    logger.info('Data statistics: %s' % str(net_cls_counts))

    return net_cls_counts
X_test.append([p1, p2, p3]) 222 | if p1>0: 223 | y_test.append(0) 224 | else: 225 | y_test.append(1) 226 | X_train = np.array(X_train, dtype=np.float32) 227 | X_test = np.array(X_test, dtype=np.float32) 228 | y_train = np.array(y_train, dtype=np.int32) 229 | y_test = np.array(y_test, dtype=np.int64) 230 | idxs = np.linspace(0,3999,4000,dtype=np.int64) 231 | batch_idxs = np.array_split(idxs, n_parties) 232 | net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)} 233 | mkdirs("data/generated/") 234 | np.save("data/generated/X_train.npy",X_train) 235 | np.save("data/generated/X_test.npy",X_test) 236 | np.save("data/generated/y_train.npy",y_train) 237 | np.save("data/generated/y_test.npy",y_test) 238 | 239 | #elif dataset == 'covtype': 240 | # cov_type = sk.fetch_covtype('./data') 241 | # num_train = int(581012 * 0.75) 242 | # idxs = np.random.permutation(581012) 243 | # X_train = np.array(cov_type['data'][idxs[:num_train]], dtype=np.float32) 244 | # y_train = np.array(cov_type['target'][idxs[:num_train]], dtype=np.int32) - 1 245 | # X_test = np.array(cov_type['data'][idxs[num_train:]], dtype=np.float32) 246 | # y_test = np.array(cov_type['target'][idxs[num_train:]], dtype=np.int32) - 1 247 | # mkdirs("data/generated/") 248 | # np.save("data/generated/X_train.npy",X_train) 249 | # np.save("data/generated/X_test.npy",X_test) 250 | # np.save("data/generated/y_train.npy",y_train) 251 | # np.save("data/generated/y_test.npy",y_test) 252 | 253 | elif dataset in ('rcv1', 'SUSY', 'covtype'): 254 | X_train, y_train = load_svmlight_file(datadir+dataset) 255 | X_train = X_train.todense() 256 | num_train = int(X_train.shape[0] * 0.75) 257 | if dataset == 'covtype': 258 | y_train = y_train-1 259 | else: 260 | y_train = (y_train+1)/2 261 | idxs = np.random.permutation(X_train.shape[0]) 262 | 263 | X_test = np.array(X_train[idxs[num_train:]], dtype=np.float32) 264 | y_test = np.array(y_train[idxs[num_train:]], dtype=np.int32) 265 | X_train = 
np.array(X_train[idxs[:num_train]], dtype=np.float32) 266 | y_train = np.array(y_train[idxs[:num_train]], dtype=np.int32) 267 | 268 | mkdirs("data/generated/") 269 | np.save("data/generated/X_train.npy",X_train) 270 | np.save("data/generated/X_test.npy",X_test) 271 | np.save("data/generated/y_train.npy",y_train) 272 | np.save("data/generated/y_test.npy",y_test) 273 | 274 | elif dataset in ('a9a'): 275 | X_train, y_train = load_svmlight_file(datadir+"a9a") 276 | X_test, y_test = load_svmlight_file(datadir+"a9a.t") 277 | X_train = X_train.todense() 278 | X_test = X_test.todense() 279 | X_test = np.c_[X_test, np.zeros((len(y_test), X_train.shape[1] - np.size(X_test[0, :])))] 280 | 281 | X_train = np.array(X_train, dtype=np.float32) 282 | X_test = np.array(X_test, dtype=np.float32) 283 | y_train = (y_train+1)/2 284 | y_test = (y_test+1)/2 285 | y_train = np.array(y_train, dtype=np.int32) 286 | y_test = np.array(y_test, dtype=np.int32) 287 | 288 | mkdirs("data/generated/") 289 | np.save("data/generated/X_train.npy",X_train) 290 | np.save("data/generated/X_test.npy",X_test) 291 | np.save("data/generated/y_train.npy",y_train) 292 | np.save("data/generated/y_test.npy",y_test) 293 | 294 | 295 | n_train = y_train.shape[0] 296 | 297 | if partition == "homo": 298 | idxs = np.random.permutation(n_train) 299 | batch_idxs = np.array_split(idxs, n_parties) 300 | net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)} 301 | 302 | 303 | elif partition == "noniid-labeldir": 304 | min_size = 0 305 | min_require_size = 10 306 | K = 10 307 | if dataset in ('celeba', 'covtype', 'a9a', 'rcv1', 'SUSY'): 308 | K = 2 309 | # min_require_size = 100 310 | if dataset == 'cifar100': 311 | K = 100 312 | elif dataset == 'tinyimagenet': 313 | K = 200 314 | 315 | N = y_train.shape[0] 316 | #np.random.seed(2020) 317 | net_dataidx_map = {} 318 | 319 | while min_size < min_require_size: 320 | idx_batch = [[] for _ in range(n_parties)] 321 | for k in range(K): 322 | idx_k = np.where(y_train == 
k)[0] 323 | np.random.shuffle(idx_k) 324 | proportions = np.random.dirichlet(np.repeat(beta, n_parties)) 325 | # logger.info("proportions1: ", proportions) 326 | # logger.info("sum pro1:", np.sum(proportions)) 327 | ## Balance 328 | proportions = np.array([p * (len(idx_j) < N / n_parties) for p, idx_j in zip(proportions, idx_batch)]) 329 | # logger.info("proportions2: ", proportions) 330 | proportions = proportions / proportions.sum() 331 | # logger.info("proportions3: ", proportions) 332 | proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] 333 | # logger.info("proportions4: ", proportions) 334 | idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] 335 | min_size = min([len(idx_j) for idx_j in idx_batch]) 336 | # if K == 2 and n_parties <= 10: 337 | # if np.min(proportions) < 200: 338 | # min_size = 0 339 | # break 340 | 341 | 342 | for j in range(n_parties): 343 | np.random.shuffle(idx_batch[j]) 344 | net_dataidx_map[j] = idx_batch[j] 345 | 346 | elif partition > "noniid-#label0" and partition <= "noniid-#label9": 347 | num = eval(partition[13:]) 348 | if dataset in ('celeba', 'covtype', 'a9a', 'rcv1', 'SUSY'): 349 | num = 1 350 | K = 2 351 | else: 352 | K = 10 353 | if dataset == "cifar100": 354 | K = 100 355 | elif dataset == "tinyimagenet": 356 | K = 200 357 | if num == 10: 358 | net_dataidx_map ={i:np.ndarray(0,dtype=np.int64) for i in range(n_parties)} 359 | for i in range(10): 360 | idx_k = np.where(y_train==i)[0] 361 | np.random.shuffle(idx_k) 362 | split = np.array_split(idx_k,n_parties) 363 | for j in range(n_parties): 364 | net_dataidx_map[j]=np.append(net_dataidx_map[j],split[j]) 365 | else: 366 | times=[0 for i in range(K)] 367 | contain=[] 368 | for i in range(n_parties): 369 | current=[i%K] 370 | times[i%K]+=1 371 | j=1 372 | while (j0: 508 | check[j]=1 509 | flag=False 510 | for i in range(10): 511 | if check[i]==0: 512 | flag=True 513 | break 514 | 515 | 516 | if dataset in ('celeba', 
def get_trainable_parameters(net):
    """Return the trainable parameter values of ``net`` as one flat vector.

    Parameters whose ``requires_grad`` is False are skipped. The remaining
    ones are copied, in ``net.parameters()`` order, into a freshly allocated
    1-D ``torch.float64`` tensor, so the result shares no storage with the
    model.

    Args:
        net: module exposing ``parameters()``.

    Returns:
        1-D ``torch.float64`` tensor holding every trainable value.
    """
    paramlist = [p for p in net.parameters() if p.requires_grad]
    total = sum(p.numel() for p in paramlist)
    X = torch.zeros(total, dtype=torch.float64)
    offset = 0
    with torch.no_grad():
        for params in paramlist:
            numel = params.numel()
            # reshape (not view) so non-contiguous parameters flatten too.
            X[offset:offset + numel].copy_(params.data.reshape(-1))
            offset += numel
    return X


def put_trainable_parameters(net, X):
    """Overwrite the trainable parameters of ``net`` with values from ``X``.

    ``X`` is consumed front to back in ``net.parameters()`` order, the same
    layout ``get_trainable_parameters`` produces. Parameters whose
    ``requires_grad`` is False are left untouched.

    Args:
        net: module exposing ``parameters()``.
        X: flat tensor holding at least as many values as ``net`` has
            trainable elements.
    """
    paramlist = [p for p in net.parameters() if p.requires_grad]
    offset = 0
    with torch.no_grad():
        for params in paramlist:
            numel = params.numel()
            params.data.copy_(X[offset:offset + numel].data.view_as(params.data))
            offset += numel


def compute_accuracy(model, dataloader, get_confusion_matrix=False, moon_model=False, device="cpu"):
    """Evaluate the top-1 accuracy of ``model``.

    The model is switched to eval mode for the evaluation and put back into
    training mode afterwards if it was training on entry, even when the
    evaluation raises.

    Args:
        model: callable module returning class scores of shape (batch, classes).
        dataloader: one iterable of ``(x, target)`` batches, or a ``list`` of
            such iterables that are evaluated one after another.
        get_confusion_matrix: when True, also return the confusion matrix
            built from all predictions.
        moon_model: when True, ``model(x)`` returns a 3-tuple whose last
            element holds the class scores.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``accuracy`` as a float, or ``(accuracy, conf_matrix)`` when
        ``get_confusion_matrix`` is True. The accuracy is 0.0 when no sample
        was seen.
    """
    was_training = model.training
    if was_training:
        model.eval()

    loaders = dataloader if isinstance(dataloader, list) else [dataloader]

    # Per-batch label chunks; joined once at the end instead of np.append in
    # the loop, which re-copies the whole array on every batch.
    true_chunks, pred_chunks = [], []
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for loader in loaders:
                for x, target in loader:
                    x, target = x.to(device), target.to(device, dtype=torch.int64)
                    if moon_model:
                        _, _, out = model(x)
                    else:
                        out = model(x)
                    _, pred_label = torch.max(out.data, 1)

                    total += x.data.size()[0]
                    correct += (pred_label == target.data).sum().item()

                    if get_confusion_matrix:
                        pred_chunks.append(pred_label.cpu().numpy().ravel())
                        true_chunks.append(target.data.cpu().numpy().ravel())

        if get_confusion_matrix:
            # The leading empty float array keeps the float64 label dtype the
            # np.append-based accumulation produced.
            true_labels_list = np.concatenate([np.array([])] + true_chunks)
            pred_labels_list = np.concatenate([np.array([])] + pred_chunks)
            conf_matrix = confusion_matrix(true_labels_list, pred_labels_list)
    finally:
        if was_training:
            model.train()

    # An empty loader used to raise ZeroDivisionError here.
    accuracy = correct / float(total) if total else 0.0

    if get_confusion_matrix:
        return accuracy, conf_matrix

    return accuracy


def save_model(model, model_index, args):
    """Write ``model.state_dict()`` to ``args.modeldir + "trained_local_model<index>"``.

    ``args.modeldir`` is concatenated as-is, so it has to end with a path
    separator for the file to land inside that directory.
    """
    logger.info("saving local model-{}".format(model_index))
    with open(args.modeldir+"trained_local_model"+str(model_index), "wb") as f_:
        torch.save(model.state_dict(), f_)
    return
def load_model(model, model_index, device="cpu", modeldir=""):
    """Load weights stored as ``"trained_local_model<index>"`` into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the stored weights.
        model_index: index appended to the file name.
        device: device the weights are mapped to while loading and the model
            is moved to afterwards.
        modeldir: prefix concatenated in front of the file name, mirroring
            ``args.modeldir`` in ``save_model``. The default ``""`` keeps the
            previous behaviour of reading from the working directory.

    Returns:
        The same ``model`` object, loaded and moved to ``device``.
    """
    with open(modeldir + "trained_local_model" + str(model_index), "rb") as f_:
        # map_location lets a checkpoint written on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(f_, map_location=device))
    model.to(device)
    return model


class AddGaussianNoise(object):
    """Transform that adds Gaussian noise to a tensor.

    With ``net_id=None`` the noise covers the whole tensor. Otherwise a
    ``GRID_EXTENT`` x ``GRID_EXTENT`` area is cut into ``num`` x ``num`` square
    patches, where ``num`` is the smallest integer with ``num * num >= total``,
    and only the patch numbered ``net_id`` (row-major) receives noise.
    ``mean`` is added to every element in both modes.
    """

    # Side length, in pixels, of the area that is divided into patches.
    GRID_EXTENT = 28

    def __init__(self, mean=0., std=1., net_id=None, total=0):
        self.std = std
        self.mean = mean
        self.net_id = net_id
        self.num = int(sqrt(total))
        if self.num * self.num < total:
            self.num = self.num + 1

    def __call__(self, tensor):
        """Return ``tensor`` plus noise; patch mode indexes it as (channel, row, col)."""
        if self.net_id is None:
            return tensor + torch.randn(tensor.size()) * self.std + self.mean

        if self.num <= 0:
            raise ValueError("total must be positive when net_id is given")
        if not 0 <= self.net_id < self.num * self.num:
            raise ValueError(
                "net_id {} does not fit a {}x{} patch grid".format(self.net_id, self.num, self.num))

        tmp = torch.randn(tensor.size())
        filt = torch.zeros(tensor.size())
        size = self.GRID_EXTENT // self.num
        # Patch coordinates come from the grid width (num), not the patch
        # size; dividing by `size` indexed past the image for net_id >= num.
        row, col = divmod(self.net_id, self.num)
        filt[:, row * size:(row + 1) * size, col * size:(col + 1) * size] = 1
        tmp = tmp * filt
        return tensor + tmp * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None, noise_level=0, net_id=None, total=0):
    """Build the train/test datasets and data loaders for ``dataset``.

    Args:
        dataset: dataset name; see ``supported`` below for accepted values.
        datadir: root directory (or path prefix for ``'tinyimagenet'``).
        train_bs: batch size of the training loader.
        test_bs: batch size of the test loader.
        dataidxs: forwarded to the training dataset as ``dataidxs``.
        noise_level: standard deviation given to ``AddGaussianNoise``.
        net_id: forwarded to ``AddGaussianNoise`` (``None`` = full-image noise).
        total: forwarded to ``AddGaussianNoise``.

    Returns:
        ``(train_dl, test_dl, train_ds, test_ds)``. The training loader
        shuffles, the test loader does not, and neither drops the last batch.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    supported = ('mnist', 'femnist', 'fmnist', 'cifar10', 'svhn', 'generated', 'covtype',
                 'a9a', 'rcv1', 'SUSY', 'cifar100', 'tinyimagenet')
    if dataset not in supported:
        # Previously execution reached the return statement with unbound names.
        raise ValueError("unsupported dataset: {!r}".format(dataset))

    def noisy_to_tensor():
        # A fresh pipeline per call: train and test never share a transform object.
        return transforms.Compose([
            transforms.ToTensor(),
            AddGaussianNoise(0., noise_level, net_id, total)])

    if dataset == 'mnist':
        dl_obj = MNIST_truncated
        transform_train = noisy_to_tensor()
        transform_test = noisy_to_tensor()

    elif dataset == 'femnist':
        dl_obj = FEMNIST
        transform_train = noisy_to_tensor()
        transform_test = noisy_to_tensor()

    elif dataset == 'fmnist':
        dl_obj = FashionMNIST_truncated
        transform_train = noisy_to_tensor()
        transform_test = noisy_to_tensor()

    elif dataset == 'svhn':
        dl_obj = SVHN_custom
        transform_train = noisy_to_tensor()
        transform_test = noisy_to_tensor()

    elif dataset == 'cifar10':
        dl_obj = CIFAR10_truncated

        # Reflect-pad by 4 pixels, random-crop back to 32x32, random flip,
        # then add the noise.
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(
                Variable(x.unsqueeze(0), requires_grad=False),
                (4, 4, 4, 4), mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            AddGaussianNoise(0., noise_level, net_id, total)
        ])
        # data prep for test set
        transform_test = noisy_to_tensor()

    elif dataset == 'cifar100':
        dl_obj = CIFAR100_truncated

        normalize = transforms.Normalize(mean=[0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
                                         std=[0.2673342858792401, 0.2564384629170883, 0.27615047132568404])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize
        ])
        # data prep for test set
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize])

    elif dataset == 'tinyimagenet':
        dl_obj = ImageFolder_custom
        transform_train = transforms.Compose([
            transforms.Resize(32),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    else:
        # 'generated', 'covtype', 'a9a', 'rcv1', 'SUSY': no transforms.
        dl_obj = Generated
        transform_train = None
        transform_test = None

    if dataset == "tinyimagenet":
        train_ds = dl_obj(datadir+'./train/', dataidxs=dataidxs, transform=transform_train)
        test_ds = dl_obj(datadir+'./val/', transform=transform_test)
    else:
        train_ds = dl_obj(datadir, dataidxs=dataidxs, train=True, transform=transform_train, download=True)
        test_ds = dl_obj(datadir, train=False, transform=transform_test, download=True)

    train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=False)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=False)

    return train_dl, test_dl, train_ds, test_ds


def weights_init(m):
    """
    Initialise weights of the model.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02); BatchNorm2d
    weights from N(1, 0.02) with the bias set to 0. Only exact type matches
    are initialised; every other module is left unchanged.
    """
    if type(m) in (nn.ConvTranspose2d, nn.Conv2d):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif type(m) == nn.BatchNorm2d:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
class NormalNLLLoss:
    """
    Calculate the negative log likelihood
    of normal distribution.
    This needs to be minimised.

    Treating Q(cj | x) as a factored Gaussian.
    """
    def __call__(self, x, mu, var):
        """Return the NLL of ``x`` under N(mu, var): summed over dim 1, averaged over the rest."""
        # 1e-6 keeps the log and the division finite when var is 0.
        logli = -0.5 * (var.mul(2 * np.pi) + 1e-6).log() - (x - mu).pow(2).div(var.mul(2.0) + 1e-6)
        nll = -(logli.sum(1).mean())

        return nll


def noise_sample(choice, n_dis_c, dis_c_dim, n_con_c, n_z, batch_size, device):
    """
    Sample random noise vector for training.

    INPUT
    --------
    choice : Allowed category values; each discrete code is drawn uniformly from it.
    n_dis_c : Number of discrete latent code.
    dis_c_dim : Dimension of discrete latent code.
    n_con_c : Number of continuous latent code.
    n_z : Dimension of incompressible noise.
    batch_size : Batch Size
    device : GPU/CPU

    OUTPUT
    --------
    noise : Tensor of shape (batch_size, n_z + n_dis_c * dis_c_dim + n_con_c, 1, 1):
            Gaussian z, then the one-hot discrete codes, then the continuous codes.
    idx : Float array of shape (n_dis_c, batch_size) holding the sampled
          category of every discrete code.
    """

    z = torch.randn(batch_size, n_z, 1, 1, device=device)
    idx = np.zeros((n_dis_c, batch_size))
    if n_dis_c != 0:
        dis_c = torch.zeros(batch_size, n_dis_c, dis_c_dim, device=device)

        c_tmp = np.array(choice)
        rows = torch.arange(0, batch_size)

        for i in range(n_dis_c):
            # Draw positions into `choice`, then map them to category values
            # in one vectorised lookup instead of a per-sample Python loop.
            idx[i] = c_tmp[np.random.randint(len(choice), size=batch_size)]
            # Index with an explicit integer tensor rather than a float array.
            dis_c[rows, i, torch.as_tensor(idx[i], dtype=torch.long)] = 1.0

        dis_c = dis_c.view(batch_size, -1, 1, 1)

    if n_con_c != 0:
        # Random uniform between -1 and 1.
        con_c = torch.rand(batch_size, n_con_c, 1, 1, device=device) * 2 - 1

    noise = z
    if n_dis_c != 0:
        noise = torch.cat((z, dis_c), dim=1)
    if n_con_c != 0:
        noise = torch.cat((noise, con_c), dim=1)

    return noise, idx
# vggmodel.py -- VGG network definitions.
import math

import torch.nn as nn
import torch.nn.init as init

__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


class VGG(nn.Module):
    '''
    VGG model

    ``features`` is the convolutional stack (see ``make_layers``); its
    flattened output must have 512 values per sample because the first
    classifier layer is ``nn.Linear(512, 512)``. ``num_classes`` sets the
    width of the final layer and defaults to 10.
    '''
    def __init__(self, features, num_classes=10):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, num_classes),
        )
        # Initialize weights: N(0, sqrt(2 / (k_h * k_w * out_channels))) for
        # every convolution, zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def make_layers(cfg, batch_norm=False):
    """Build the convolutional stack described by ``cfg``.

    Each integer entry adds a 3x3 convolution (padding 1) with that many
    output channels, followed by an optional BatchNorm2d and a ReLU. Each
    ``'M'`` entry adds a 2x2 max-pool with stride 2. The input has 3 channels.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


# Layer layouts keyed by configuration letter: integers are conv output
# channels, 'M' is a max-pool.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}


def vgg11(num_classes=10):
    """VGG 11-layer model (configuration "A")"""
    return VGG(make_layers(cfg['A']), num_classes=num_classes)


def vgg11_bn(num_classes=10):
    """VGG 11-layer model (configuration "A") with batch normalization"""
    return VGG(make_layers(cfg['A'], batch_norm=True), num_classes=num_classes)


def vgg13(num_classes=10):
    """VGG 13-layer model (configuration "B")"""
    return VGG(make_layers(cfg['B']), num_classes=num_classes)


def vgg13_bn(num_classes=10):
    """VGG 13-layer model (configuration "B") with batch normalization"""
    return VGG(make_layers(cfg['B'], batch_norm=True), num_classes=num_classes)


def vgg16(num_classes=10):
    """VGG 16-layer model (configuration "D")"""
    return VGG(make_layers(cfg['D']), num_classes=num_classes)


def vgg16_bn(num_classes=10):
    """VGG 16-layer model (configuration "D") with batch normalization"""
    return VGG(make_layers(cfg['D'], batch_norm=True), num_classes=num_classes)


def vgg19(num_classes=10):
    """VGG 19-layer model (configuration "E")"""
    return VGG(make_layers(cfg['E']), num_classes=num_classes)


def vgg19_bn(num_classes=10):
    """VGG 19-layer model (configuration 'E') with batch normalization"""
    return VGG(make_layers(cfg['E'], batch_norm=True), num_classes=num_classes)