├── Grad_Diff.py ├── Grad_Diff.sh ├── LICENSE ├── LocalSGD.py ├── Plot_grad_diversity.py ├── README.md ├── Run_Exper.sh ├── comm_helpers.py ├── dataset ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── cifar.cpython-36.pyc │ ├── cifar.cpython-37.pyc │ ├── cifar.cpython-38.pyc │ ├── randaugment.cpython-36.pyc │ ├── randaugment.cpython-37.pyc │ └── randaugment.cpython-38.pyc ├── cifar.py ├── randaugment.py └── stl10_input.py ├── environment.yaml ├── logs └── readme_logs.txt ├── models ├── AlexNet.py ├── EMNIST_model.py ├── EMNIST_test.py ├── MLP.py ├── Semi_net.py ├── __init__.py ├── __pycache__ │ ├── EMNIST_model.cpython-36.pyc │ ├── EMNIST_model.cpython-37.pyc │ ├── EMNIST_model.cpython-38.pyc │ ├── MLP.cpython-36.pyc │ ├── MLP.cpython-37.pyc │ ├── MLP.cpython-38.pyc │ ├── Semi_net.cpython-36.pyc │ ├── Semi_net.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── alexnet.cpython-36.pyc │ ├── alexnet.cpython-37.pyc │ ├── cifar.cpython-36.pyc │ ├── cifar.cpython-37.pyc │ ├── resnet.cpython-36.pyc │ ├── resnet.cpython-37.pyc │ ├── resnet.cpython-38.pyc │ ├── resnet9.cpython-36.pyc │ ├── resnet9.cpython-38.pyc │ ├── resnet_gn.cpython-36.pyc │ ├── resnet_gn.cpython-37.pyc │ ├── resnet_gn.cpython-38.pyc │ ├── resnet_ln.cpython-36.pyc │ ├── resnet_ln.cpython-37.pyc │ ├── resnet_ln.cpython-38.pyc │ ├── vgg.cpython-37.pyc │ ├── vggnet.cpython-36.pyc │ ├── vggnet.cpython-37.pyc │ ├── vggnet.cpython-38.pyc │ ├── wrn.cpython-36.pyc │ ├── wrn.cpython-37.pyc │ └── wrn.cpython-38.pyc ├── base.py ├── cifar.py ├── resnet.py ├── resnet9.py ├── resnet_gn.py ├── resnet_ln.py ├── vgg.py ├── vggnet.py └── wrn.py ├── readme.md ├── run_cifar10.sh ├── run_cifar10_res9.sh ├── run_emnist.sh ├── run_svhn.sh ├── train_LocalSGD.py ├── train_parallel.py ├── transform.py ├── util_v4.py ├── utils_v1.py └── utils_v2.py /Grad_Diff.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import numpy as np 3 | import time 4 | import argparse 5 | import sys 6 | from math import ceil 7 | from random import Random 8 | import time 9 | import random 10 | import torch 11 | import torch.distributed as dist 12 | import torch.utils.data.distributed 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim_tr 16 | from torch.multiprocessing import Process 17 | import torchvision 18 | from torchvision import datasets, transforms 19 | import torch.backends.cudnn as cudnn 20 | import datetime 21 | from scipy.io import loadmat 22 | import json 23 | from scipy import io 24 | import utils_v2 as util_1 25 | import math 26 | import copy 27 | import torch.distributed as dist 28 | 29 | from utils_v2 import * 30 | 31 | 32 | parser = argparse.ArgumentParser(description='Calculating gradient diversity') 33 | parser.add_argument('--experiment_name', default='EMNIST_size47_comUE10_H1_R0.4_SSFL', type=str, 34 | help='path to checkpoint') 35 | parser.add_argument('--experiment_folder', default='.', type=str, 36 | help='the path of the experiment') 37 | 38 | parser.add_argument('--vector_type', default='grad', type=str, 39 | help='vector_type = grad or weight_variation') 40 | 41 | parser.add_argument('--l2', 42 | default=1, 43 | type=int, 44 | help='l2 norm or l1 norm') 45 | 46 | parser.add_argument('--only_user', 47 | default=0, 48 | type=int, 49 | help='only user or not') 50 | 51 | parser.add_argument('--ord', 52 | default=2, 53 | type=int, 54 | help='norm ord = 1 or 2') 55 | 56 | args = parser.parse_args() 57 | 58 | def get_groups(args): 59 | if 'EMNIST' in args.experiment_name: 60 | if 'comUE10' in args.experiment_name: 61 | server_list = [0,48] 62 | group1 = [0] + np.arange(1, 6).tolist() 63 | group2 = [11] + np.arange(6, 11).tolist() 64 | group3 = [0, 11] 65 | groups = [group1, group2, group3] 66 | 67 | if 'comUE30' in args.experiment_name: 68 | server_list = [0,48,49] 69 | group1 = [0] + np.arange(1, 11).tolist() 70 | 
group2 = [31] + np.arange(11, 21).tolist() 71 | group3 = [32] + np.arange(21, 31).tolist() 72 | group4 = [0,31,32] 73 | groups = [group1, group2, group3, group4] 74 | 75 | if 'comUE47' in args.experiment_name: 76 | server_list = [0,48,49,50,51] 77 | group1 = [0] + np.arange(1, 11).tolist() 78 | group2 = [48] + np.arange(11, 21).tolist() 79 | group3 = [49] + np.arange(21, 31).tolist() 80 | group4 = [50] + np.arange(31, 41).tolist() 81 | group5 = [51] + np.arange(41, 48).tolist() 82 | group6 = [0, 48, 49, 50, 51] 83 | groups = [group1, group2, group3, group4, group5, group6] 84 | 85 | if 'SVHN' in args.experiment_name: 86 | if 'commUE10' in args.experiment_name: 87 | server_list = [0,31] 88 | group1 = [0] + np.arange(1, 6).tolist() 89 | group2 = [11] + np.arange(6, 11).tolist() 90 | group3 = [0, 11] 91 | groups = [group1, group2, group3] 92 | 93 | if 'commUE20' in args.experiment_name: 94 | server_list = [0,31] 95 | group1 = [0] + np.arange(1, 11).tolist() 96 | group2 = [21] + np.arange(11, 21).tolist() 97 | group3 = [0,21] 98 | groups = [group1, group2, group3] 99 | 100 | if 'commUE30' in args.experiment_name: 101 | server_list = [0,48,49,50,51] 102 | group1 = [0] + np.arange(1, 11).tolist() 103 | group2 = [31] + np.arange(11, 21).tolist() 104 | group3 = [32] + np.arange(21, 31).tolist() 105 | group4 = [0, 31, 32] 106 | groups = [group1, group2, group3, group4] 107 | 108 | 109 | return groups, server_list 110 | 111 | def Load_model_grad_checkpoint(experiment_folder='.', experiment_name=None, rank=0, epoch=10, tao=0.95): 112 | try: 113 | path_checkpoint = '%s/grad_checkpoint/%s/' %(experiment_folder, experiment_name) 114 | pthfile = path_checkpoint+'Rank%s_Epoch_%s_model_grads_tao_%s.pth' %(rank, epoch, tao) 115 | checkpoint_grads = torch.load(pthfile, map_location=lambda storage, loc: storage) 116 | except: 117 | path_checkpoint = '%s/checkpoint/%s/' %(experiment_folder, experiment_name) 118 | pthfile = path_checkpoint+'Rank%s_Epoch_%s_model_grads_tao_%s.pth' 
%(rank, epoch, tao) 119 | checkpoint_grads = torch.load(pthfile, map_location=lambda storage, loc: storage) 120 | 121 | return checkpoint_grads 122 | 123 | def Load_Avgmodel_weights(experiment_folder, experiment_name, epoch): 124 | path_checkpoint = '%s/checkpoint/%s/' %(experiment_folder, experiment_name) 125 | pthfile = path_checkpoint+'Avg_before_Epoch_%s_weights.pth' %(epoch) 126 | checkpoint_weights = torch.load(pthfile, map_location=lambda storage, loc: storage) 127 | try: 128 | checkpoint_weights = checkpoint_weights['state_dict'] 129 | except: 130 | pass 131 | return checkpoint_weights 132 | 133 | def Load_Avgmodel_weights_with_middle(experiment_folder, experiment_name, epoch, middle): 134 | path_checkpoint = '%s/checkpoint/%s/' %(experiment_folder, experiment_name) 135 | pthfile = path_checkpoint+'Avg_%s_Epoch_%s_weights.pth' %(middle, epoch) 136 | checkpoint_weights = torch.load(pthfile, map_location=lambda storage, loc: storage) 137 | try: 138 | checkpoint_weights = checkpoint_weights['state_dict'] 139 | except: 140 | pass 141 | return checkpoint_weights 142 | 143 | def Load_model_weight_checkpoint(experiment_folder='.', experiment_name=None, rank=0, epoch=10): 144 | path_checkpoint = '%s/checkpoint/%s/' %(experiment_folder, experiment_name) 145 | pthfile = path_checkpoint+'Rank%s_Epoch_%s_weights.pth' %(rank, epoch) 146 | checkpoint_weights = torch.load(pthfile, map_location=lambda storage, loc: storage) 147 | try: 148 | checkpoint_weights = checkpoint_weights['state_dict'] 149 | except: 150 | pass 151 | 152 | return checkpoint_weights 153 | 154 | if 'EMNIST' in args.experiment_name: 155 | E_list = [0,10,20,30,40,50,60,70,80,90] 156 | 157 | 158 | if 'cifar' in args.experiment_name: 159 | E_list = [0,30,60,90,120,150,180,210,240,270] 160 | 161 | if 'SVHN' in args.experiment_name: 162 | E_list = [0, 4, 8, 12, 16, 20, 24, 28, 32, 36] 163 | 164 | if 'H1' in args.experiment_name: 165 | Grouping_method = True 166 | if 'comUE10' in args.experiment_name or 
'commUE10' in args.experiment_name: 167 | num_rank = 12 168 | if 'comUE20' in args.experiment_name or 'commUE20' in args.experiment_name: 169 | num_rank = 22 170 | if 'comUE30' in args.experiment_name or 'commUE30' in args.experiment_name: 171 | num_rank = 33 172 | if 'comUE47' in args.experiment_name or 'commUE47' in args.experiment_name: 173 | num_rank = 52 174 | 175 | else: 176 | Grouping_method = False 177 | if 'comUE10' in args.experiment_name or 'commUE10' in args.experiment_name: 178 | num_rank = 11 179 | if 'comUE20' in args.experiment_name or 'commUE20' in args.experiment_name: 180 | num_rank = 21 181 | if 'comUE30' in args.experiment_name or 'commUE30' in args.experiment_name: 182 | num_rank = 31 183 | if 'comUE47' in args.experiment_name or 'commUE47' in args.experiment_name: 184 | num_rank = 48 185 | 186 | 187 | 188 | if not Grouping_method: 189 | groups, server_list = get_groups(args) 190 | W_e_list_norm_mean_list = [] 191 | state_train_loss_epoch_list = [] 192 | state_train_masks_epoch_list = [] 193 | for e in E_list: 194 | W_e_list = [] 195 | state_train_loss_list = [] 196 | state_train_masks_list = [] 197 | 198 | for rank in range(num_rank): 199 | if rank in set(groups[-1]) and args.only_user: 200 | pass 201 | else: 202 | if args.vector_type == 'grad': 203 | W_rank_e = Load_model_grad_checkpoint(args.experiment_folder, args.experiment_name, rank=rank, epoch=e) 204 | 205 | elif args.vector_type == 'weight_variation': 206 | W = Load_model_weight_checkpoint(args.experiment_folder, args.experiment_name, rank=rank, epoch=e) 207 | W_i = get_params(W, WBN=False) 208 | W_avg = Load_Avgmodel_weights(args.experiment_folder, args.experiment_name, epoch=e) 209 | W_avg = get_params(W_avg, WBN=False) 210 | W_rank_e = W_i - W_avg 211 | 212 | 213 | W_rank_e_norm = torch.norm(W_rank_e) 214 | print(f'epoch={e},rank={rank},grad_norm={W_rank_e_norm}') 215 | W_e_list.append(W_rank_e.numpy()) 216 | try: 217 | state_values = Load_train_state(args.experiment_folder, 
args.experiment_name, rank=rank, epoch=e) 218 | 219 | state_train_loss = state_values.item()['train_loss'] 220 | state_train_masks = state_values.item()['train_mask'] 221 | except: 222 | state_train_loss = 0 223 | state_train_masks = 0 224 | 225 | state_train_loss_list.append(state_train_loss) 226 | state_train_masks_list.append(state_train_masks) 227 | 228 | W_e_list = np.array(W_e_list) 229 | W_e_list_mean = np.mean(W_e_list, axis=0) 230 | 231 | 232 | state_train_loss_epoch_list.append(round(np.mean(state_train_loss_list),2)) 233 | state_train_masks_epoch_list.append(np.mean(state_train_masks_list)) 234 | 235 | W_e_list_norm_list = [] 236 | for i in range(len(W_e_list)): 237 | W_Wavg_norm = 0 238 | 239 | W_Wavg_norm = np.linalg.norm(W_e_list[i] -W_e_list_mean) 240 | W_e_list_norm = np.linalg.norm(W_e_list[i],ord=args.ord)/np.linalg.norm(W_e_list_mean,ord=args.ord) 241 | if args.l2: 242 | W_e_list_norm = W_e_list_norm**2 243 | 244 | print(f'epoch={e}, i={i}, norm of (g-E[g]) = {W_Wavg_norm}') 245 | W_e_list_norm_list.append(W_e_list_norm) 246 | 247 | W_e_list_norm_mean = np.mean(W_e_list_norm_list) 248 | W_e_list_norm_mean_list.append(W_e_list_norm_mean) 249 | 250 | print(f'epoch={e}, grad_var={W_e_list_norm_mean}') 251 | print('Grouping_method = False',W_e_list_norm_mean_list) 252 | print(state_train_loss_epoch_list) 253 | print(state_train_masks_epoch_list) 254 | print('############') 255 | 256 | save_var_path = f'./diversity_ord_{args.ord}/' 257 | if args.vector_type == 'grad': 258 | save_var_path = save_var_path + 'grad_FedAvg_post/' 259 | elif args.vector_type == 'weight_variation': 260 | save_var_path = save_var_path + 'weight_variation_FedAvg_post/' 261 | save_var_path = save_var_path + f'{args.experiment_name}/' 262 | if not os.path.exists(save_var_path): 263 | os.makedirs(save_var_path) 264 | np.save(save_var_path+f'{args.vector_type}_diversity_list_l2_{args.l2}_only_user_{args.only_user}.npy',W_e_list_norm_mean_list) 265 | 266 | 267 | else: 268 | groups, 
server_list = get_groups(args) 269 | W_e_list_norm_mean_list = [] 270 | state_train_loss_epoch_list = [] 271 | state_train_masks_epoch_list = [] 272 | 273 | num_group = len(groups[0:-1]) 274 | for e in E_list: 275 | W_e_list = [[] for i in range(num_group)] 276 | 277 | state_train_loss_list = [] 278 | state_train_masks_list = [] 279 | 280 | for rank in range(num_rank): 281 | if rank in set(groups[-1]) and args.only_user: 282 | pass 283 | else: 284 | group_id = Get_group_num(args, groups, rank) 285 | if args.vector_type == 'grad': 286 | W_rank_e = Load_model_grad_checkpoint(args.experiment_folder, args.experiment_name, rank=rank, epoch=e) 287 | 288 | elif args.vector_type == 'weight_variation': 289 | W = Load_model_weight_checkpoint(args.experiment_folder, args.experiment_name, rank=rank, epoch=e) 290 | W_i = get_params(W, WBN=False) 291 | W_avg = Load_Avgmodel_weights_with_middle(args.experiment_folder, args.experiment_name, epoch=e, middle=f'before_g{group_id+1}') 292 | W_avg = get_params(W_avg, WBN=False) 293 | W_rank_e = W_i - W_avg 294 | 295 | W_rank_e_norm = torch.norm(W_rank_e) 296 | print(f'epoch={e},rank={rank},grad_norm={W_rank_e_norm}') 297 | 298 | W_e_list[group_id].append(W_rank_e.numpy()) 299 | 300 | try: 301 | state_values = Load_train_state(args.experiment_folder, args.experiment_name, rank=rank, epoch=e) 302 | 303 | state_train_loss = state_values.item()['train_loss'] 304 | state_train_masks = state_values.item()['train_mask'] 305 | except: 306 | state_train_loss = 0 307 | state_train_masks = 0 308 | 309 | state_train_loss_list.append(state_train_loss) 310 | state_train_masks_list.append(state_train_masks) 311 | 312 | W_e_var_list = [] 313 | W_e_list_norm_list = [] 314 | for i in range(num_group): 315 | W_e = np.array(W_e_list[i]) 316 | W_e_mean = np.mean(W_e, axis=0) 317 | 318 | for j in range(len(W_e_list[i])): 319 | W_Wavg_norm = np.linalg.norm(W_e[j]-W_e_mean) 320 | W_e_list_norm = 
np.linalg.norm(W_e[j],ord=args.ord)/np.linalg.norm(W_e_mean,ord=args.ord) 321 | if args.l2: 322 | W_e_list_norm = W_e_list_norm**2 323 | 324 | print(f'epoch={e}, j={j}, norm of (g-E[g]) = {W_Wavg_norm}') 325 | W_e_list_norm_list.append(W_e_list_norm) 326 | 327 | W_e_list_norm_mean = np.mean(W_e_list_norm_list) 328 | W_e_list_norm_mean_list.append(W_e_list_norm_mean) 329 | 330 | state_train_loss_epoch_list.append(round(np.mean(state_train_loss_list),2)) 331 | state_train_masks_epoch_list.append(np.mean(state_train_masks_list)) 332 | 333 | print(f'epoch={e}, grad_var={W_e_list_norm_mean}') 334 | print('Grouping_method = True',W_e_list_norm_mean_list) 335 | print(state_train_loss_epoch_list) 336 | print(state_train_masks_epoch_list) 337 | print('############') 338 | 339 | save_var_path = f'./diversity_ord_{args.ord}/' 340 | if args.vector_type == 'grad': 341 | save_var_path = save_var_path + 'grad_Group_post/' 342 | elif args.vector_type == 'weight_variation': 343 | save_var_path = save_var_path + 'weight_variation_Group_post/' 344 | 345 | save_var_path = save_var_path + f'{args.experiment_name}/' 346 | if not os.path.exists(save_var_path): 347 | os.makedirs(save_var_path) 348 | np.save(save_var_path+f'{args.vector_type}_diversity_list_l2_{args.l2}_only_user_{args.only_user}.npy',W_e_list_norm_mean_list) 349 | -------------------------------------------------------------------------------- /Grad_Diff.sh: -------------------------------------------------------------------------------- 1 | 2 | declare -a experiment_names=("Cifar10_res_H0_comUE10_R0.4_SSFL" "Cifar10_res_gn_H0_comUE10_R0.4_SSFL" "Cifar10_res_gn_H1_comUE10_R0.4_SSFL"\ 3 | "Cifar10_res_H0_comUE10_R0.0_SSFL" "Cifar10_res_gn_H0_comUE10_R0.0_SSFL" "Cifar10_res_gn_H1_comUE10_R0.0_SSFL"\ 4 | "Cifar10_res_H0_comUE10_R0.4_SFL" "Cifar10_res_H0_comUE10_R0.0_SFL" \ 5 | "EMNIST_size47_comUE10_H1_R0.4_SSFL" "EMNIST_size47_comUE30_H1_R0.4_SSFL" "EMNIST_size47_comUE47_H1_R0.4_SSFL"\ 6 | 
"EMNIST_size47_comUE10_H0_R0.4_SSFL" "EMNIST_size47_comUE30_H0_R0.4_SSFL" "EMNIST_size47_comUE47_H0_R0.4_SSFL"\ 7 | ) 8 | 9 | declare -a tao=("0.95") 10 | declare -a l2=("1" "0") 11 | declare -a method_list=("diversity1") 12 | declare -a vector_type_list=("weight_variation" "grad") 13 | declare -a only_user_list=("0" "1") 14 | declare -a ord_list=("2" "1") 15 | 16 | for i in ${!experiment_names[@]}; 17 | do 18 | for j in ${!tao[@]}; 19 | do 20 | for method in ${method_list[@]}; 21 | do 22 | 23 | for k in ${!l2[@]}; 24 | do 25 | 26 | for vector_type in ${vector_type_list[@]}; 27 | do 28 | 29 | for p in ${!only_user_list[@]}; 30 | do 31 | for q in ${!ord_list[@]}; 32 | do 33 | 34 | python Grad_Diff.py --experiment_name ${experiment_names[$i]} --only_user ${only_user_list[$p]} --ord ${ord_list[$q]} --vector_type $vector_type --experiment_folder ${experiment_folders[$i]} --l2 ${l2[$k]} --tao ${tao[$j]} --method $method 1>./logs/"${experiment_names[$i]}"_"${tao[$j]}"_"${l2[$k]}"_"$method"_"$vector_type"_"only_user"_"${only_user_list[$p]}"_"Norm_ord"_"${order[$q]}".log 2>./logs/"${experiment_names[$i]}"_"${tao[$j]}"_"${l2[$k]}"_"$method"_"$vector_type"_"only_user"_"${only_user_list[$p]}"_"Norm_ord"_"${order[$q]}".err & 35 | 36 | /bin/sleep 5 37 | done 38 | 39 | done 40 | 41 | done 42 | 43 | done 44 | done 45 | done 46 | 47 | done 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Zhengming Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to 
do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LocalSGD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.optim.optimizer import Optimizer, required 4 | from comm_helpers import communicate, flatten_tensors, unflatten_tensors 5 | import threading 6 | 7 | 8 | class SGD(Optimizer): 9 | r"""Implements stochastic gradient descent (optionally with momentum). 10 | 11 | Nesterov momentum is based on the formula from 12 | `On the importance of initialization and momentum in deep learning`__. 
13 | 14 | Args: 15 | params (iterable): iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr (float): learning rate 18 | momentum (float, optional): momentum factor (default: 0) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | dampening (float, optional): dampening for momentum (default: 0) 21 | nesterov (bool, optional): enables Nesterov momentum (default: False) 22 | 23 | Example: 24 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 25 | >>> optimizer.zero_grad() 26 | >>> loss_fn(model(input), target).backward() 27 | >>> optimizer.step() 28 | 29 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 30 | 31 | .. note:: 32 | The implementation of SGD with Momentum/Nesterov subtly differs from 33 | Sutskever et. al. and implementations in some other frameworks. 34 | 35 | Considering the specific case of Momentum, the update can be written as 36 | 37 | .. math:: 38 | v = \rho * v + g \\ 39 | p = p - lr * v 40 | 41 | where p, g, v and :math:`\rho` denote the parameters, gradient, 42 | velocity, and momentum respectively. 43 | 44 | This is in contrast to Sutskever et. al. and 45 | other frameworks which employ an update of the form 46 | 47 | .. math:: 48 | v = \rho * v + lr * g \\ 49 | p = p - v 50 | 51 | The Nesterov version is analogously modified. 
52 | """ 53 | 54 | def __init__(self, params, alpha, gmf, size, lr=required, momentum=0, dampening=0, 55 | weight_decay=0, nesterov=False, variance=0): 56 | 57 | self.alpha = alpha 58 | self.gmf = gmf 59 | self.size = size 60 | self.comm_buf = [] 61 | 62 | 63 | 64 | if lr is not required and lr < 0.0: 65 | raise ValueError("Invalid learning rate: {}".format(lr)) 66 | if momentum < 0.0: 67 | raise ValueError("Invalid momentum value: {}".format(momentum)) 68 | if weight_decay < 0.0: 69 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 70 | 71 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 72 | weight_decay=weight_decay, nesterov=nesterov, variance=variance) 73 | if nesterov and (momentum <= 0 or dampening != 0): 74 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 75 | super(SGD, self).__init__(params, defaults) 76 | 77 | 78 | for group in self.param_groups: 79 | for p in group['params']: 80 | param_state = self.state[p] 81 | buf = param_state['anchor_model'] = torch.clone(p.data).detach() 82 | self.comm_buf.append(buf) 83 | 84 | self.first_flag = True 85 | self.comm_finish = threading.Event() 86 | self.buf_ready = threading.Event() 87 | self.comm_finish.set() 88 | self.buf_ready.clear() 89 | 90 | self.comm_thread = threading.Thread( 91 | target=SGD._async_all_reduce_, 92 | args=(self.comm_buf, self.buf_ready, self.comm_finish)) 93 | self.comm_thread.daemon = True 94 | self.comm_thread.name = 'Communication-Thread' 95 | self.comm_thread.start() 96 | 97 | def __setstate__(self, state): 98 | super(SGD, self).__setstate__(state) 99 | for group in self.param_groups: 100 | group.setdefault('nesterov', False) 101 | 102 | def step(self, closure=None): 103 | """Performs a single optimization step. 104 | 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 
108 | """ 109 | device = "cuda" if torch.cuda.is_available() else "cpu" 110 | 111 | loss = None 112 | if closure is not None: 113 | loss = closure() 114 | 115 | for group in self.param_groups: 116 | weight_decay = group['weight_decay'] 117 | momentum = group['momentum'] 118 | dampening = group['dampening'] 119 | nesterov = group['nesterov'] 120 | 121 | 122 | for p in group['params']: 123 | if p.grad is None: 124 | continue 125 | d_p = p.grad.data 126 | 127 | if weight_decay != 0: 128 | d_p.add_(weight_decay, p.data) 129 | if momentum != 0: 130 | param_state = self.state[p] 131 | 132 | if 'momentum_buffer' not in param_state: 133 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 134 | else: 135 | buf = param_state['momentum_buffer'] 136 | buf.mul_(momentum).add_(1 - dampening, d_p) 137 | if nesterov: 138 | d_p = d_p.add(momentum, buf) 139 | else: 140 | d_p = buf 141 | 142 | p.data.add_(-group['lr'], d_p) 143 | 144 | return loss 145 | 146 | def elastic_average(self, itr, cp): 147 | step_flag = (itr != 0 and itr % cp == 0) 148 | if step_flag: 149 | beta = 1/self.size - self.alpha - self.alpha**2/(1-self.alpha) 150 | for group in self.param_groups: 151 | for p in group['params']: 152 | param_state = self.state[p] 153 | buf = param_state['anchor_model'] 154 | 155 | p.data.mul_(1-self.alpha).add_(self.alpha, buf) 156 | buf.mul_(beta).add_(self.alpha/(1-self.alpha), p.data) 157 | 158 | communicate(self.comm_buf, dist.all_reduce) 159 | 160 | 161 | def overlap_elastic_average(self, itr, cp, req): 162 | step_flag = (itr != 0 and itr % cp == 0) 163 | if step_flag: 164 | beta = 1/self.size - self.alpha - self.alpha**2/(1-self.alpha) 165 | gamma = self.alpha/(1-self.alpha) 166 | if req: 167 | req.wait() 168 | for f, t in zip(unflatten_tensors(self.flat_tensor, self.comm_buf), self.comm_buf): 169 | t.set_(f) 170 | 171 | for group in self.param_groups: 172 | for p in group['params']: 173 | param_state = self.state[p] 174 | buf = param_state['anchor_model'] 175 | 
176 | p.data.mul_(1-self.alpha).add_(self.alpha, buf) 177 | buf.mul_(beta).add_(gamma, p.data) 178 | 179 | self.flat_tensor = flatten_tensors(self.comm_buf) 180 | req = dist.all_reduce(tensor=self.flat_tensor, async_op=True) 181 | 182 | return req 183 | 184 | 185 | def BMUF(self, itr, cp): 186 | step_flag = (itr != 0 and itr % cp == 0) 187 | if step_flag: 188 | 189 | for group in self.param_groups: 190 | lr = group['lr'] 191 | for p in group['params']: 192 | param_state = self.state[p] 193 | old_data = param_state['anchor_model'] 194 | 195 | if 'global_momentum_buffer' not in param_state: 196 | buf = param_state['global_momentum_buffer'] = torch.clone(p.data).detach() 197 | buf.sub_(old_data) 198 | buf.div_(-lr) 199 | else: 200 | buf = param_state['global_momentum_buffer'] 201 | buf.mul_(self.gmf).sub_(1/lr, p.data).add_(1/lr, old_data) 202 | 203 | old_data.add_(-lr, buf) 204 | old_data.div_(self.size) 205 | 206 | communicate(self.comm_buf, dist.all_reduce) 207 | for group in self.param_groups: 208 | for p in group['params']: 209 | param_state = self.state[p] 210 | old_data = param_state['anchor_model'] 211 | p.data.copy_(old_data) 212 | 213 | 214 | def OverlapLocalSGD_step(self, itr, cp, req): 215 | # Olocal SGD 216 | step_flag = (itr != 0 and itr % cp == 0) 217 | if step_flag: 218 | 219 | self.comm_finish.wait() 220 | 221 | for group in self.param_groups: 222 | lr = group['lr'] 223 | for p in group['params']: 224 | param_state = self.state[p] 225 | old_data = param_state['anchor_model'] 226 | 227 | p.data.mul_(1-self.alpha).add_(self.alpha, old_data) 228 | 229 | #param_state['momentum_buffer'].zero_() 230 | if 'global_momentum_buffer' not in param_state: 231 | buf = param_state['global_momentum_buffer'] = torch.clone(p.data).detach() 232 | buf.sub_(old_data) 233 | buf.div_(-lr) 234 | else: 235 | buf = param_state['global_momentum_buffer'] 236 | buf.mul_(self.gmf).sub_(1/lr, p.data).add_(1/lr, old_data) 237 | 238 | old_data.add_(-lr, buf) 239 | 
old_data.div_(self.size) 240 | #param_state['momentum_buffer'].zero_() 241 | 242 | self.comm_finish.clear() 243 | self.buf_ready.set() 244 | 245 | def async_CoCoD_SGD_step(self, itr, cp, req): 246 | step_flag = (itr != 0 and itr % cp == 0) 247 | if step_flag: 248 | 249 | self.comm_finish.wait() 250 | 251 | for group in self.param_groups: 252 | for p in group['params']: 253 | param_state = self.state[p] 254 | old_data = param_state['anchor_model'] 255 | 256 | if 'local_anchor_model' not in param_state: 257 | param_state['local_anchor_model'] = torch.clone(old_data).detach() 258 | buf = param_state['local_anchor_model'] 259 | 260 | # update(anchor) 261 | old_data.add_(p.data).sub_(buf) 262 | 263 | # update training params 264 | p.data.copy_(old_data) 265 | 266 | # update local_anchor_model 267 | buf.copy_(old_data) 268 | 269 | old_data.div_(self.size) 270 | 271 | self.comm_finish.clear() 272 | self.buf_ready.set() 273 | 274 | 275 | @staticmethod 276 | def _async_all_reduce_(buff, buf_ready, comm_finish): 277 | while True: 278 | buf_ready.wait() 279 | communicate(buff, dist.all_reduce) 280 | buf_ready.clear() 281 | comm_finish.set() 282 | -------------------------------------------------------------------------------- /Plot_grad_diversity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import copy 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import random 7 | import matplotlib.ticker as ticker 8 | 9 | age_list = ['10','20','30','40','50','60','70','80','90','100'] 10 | 11 | stat_point = 0 12 | end_point = 10 13 | interval = 1 14 | age_list = age_list[stat_point:end_point:interval] 15 | 16 | color = ['#696969','coral','steelblue','maroon','deeppink','limegreen','firebrick','khaki','yellowgreen','navy'] 17 | 18 | hatch_list = ["/","X", "\\", "." 
, "+", "*", "o", "O", "x", "-"] 19 | 20 | name_list=age_list 21 | 22 | dir_root_list = ['./diversity_ord_1/', './diversity_ord_2/']' 23 | 24 | methods_list = ['diversity'] 25 | 26 | #### If you want to plot the results of EMNIST: 27 | Name_list = ["weight_variation_FedAvg_post/EMNIST_size47_comUE10_H0_R0.4_SSFL", 28 | "weight_variation_Group_post/EMNIST_size47_comUE10_H1_R0.4_SSFL", 29 | "weight_variation_FedAvg_post/EMNIST_size47_comUE30_H0_R0.4_SSFL", 30 | "weight_variation_Group_post/EMNIST_size47_comUE30_H1_R0.4_SSFL", 31 | "weight_variation_FedAvg_post/EMNIST_size47_comUE47_H0_R0.4_SSFL", 32 | "weight_variation_Group_post/EMNIST_size47_comUE47_H1_R0.4_SSFL"] 33 | #### If you want to plot the results of Cifar10 34 | # Name_list = ["weight_variation_FedAvg_post/Cifar10_res_H0_comUE10_R0.4_SSFL", 35 | # "weight_variation_Group_post/Cifar10_res_gn_H1_comUE10_R0.4_SSFL", 36 | # "weight_variation_FedAvg_post/Cifar10_res_gn_H0_comUE10_R0.4_SSFL", 37 | # "weight_variation_FedAvg_post/Cifar10_res_H0_comUE10_R0.0_SSFL", 38 | # "weight_variation_Group_post/Cifar10_res_gn_H1_comUE10_R0.0_SSFL", 39 | # "weight_variation_FedAvg_post/Cifar10_res_gn_H0_comUE10_R0.0_SSFL",] 40 | 41 | l2_list = ['1','0'] 42 | all_user_list = ['1','0'] 43 | type_list = ['grad', 'weight_variation'] 44 | 45 | for type in type_list: 46 | for dir_root in dir_root_list: 47 | for l2_value in l2_list: 48 | Var_list = [] 49 | for only_user in all_user_list: 50 | for name in Name_list: 51 | dir = dir_root + f'{name}/' 52 | WD0 = np.load(dir+f'{type}_diversity_list_l2_{l2_value}_only_user_{only_user}.npy') 53 | 54 | WD0 = np.array(WD0) 55 | 56 | Var = np.zeros((10,)) 57 | Var[0:len(WD0[0:])] = WD0[0:] 58 | Var_list.append(Var[stat_point:end_point:interval]) 59 | 60 | x = list(range(len(name_list))) 61 | width=0.3/1.5 62 | index=np.arange(len(name_list))+1 63 | 64 | 65 | plt.bar(index,Var_list[0],width,color='k',tick_label = name_list, hatch=hatch_list[0],alpha=0.6) 66 | 
plt.bar(index+width,Var_list[1],width,color='#d95f0e',hatch=hatch_list[1]) 67 | 68 | Legend_name = ['FedAvg','Grouping-based'] 69 | 70 | font_size = 29 71 | plt.yscale('log') 72 | plt.yticks(fontproperties = 'Times New Roman', size = font_size-10) 73 | plt.xticks(fontproperties = 'Times New Roman', size = font_size-10) 74 | plt.ylabel('Gradient diversity', fontdict={'family' : 'Times New Roman', 'size' : font_size}) 75 | plt.xlabel('epoch', fontdict={'family' : 'Times New Roman', 'size' : font_size}) 76 | if l2_value == '1': 77 | plt.ylim([1.0,49.0]) 78 | else: 79 | if only_user == '1': 80 | plt.ylim([1.0000,7.20025]) 81 | else: 82 | plt.ylim([1.00001,7.20025]) 83 | plt.grid(True, linestyle = "-.", linewidth = "0.15") 84 | 85 | if l2_value == '1': 86 | plt.legend(Legend_name,labelspacing=0.2, loc=4,fontsize=18.2, ncol=1) 87 | else: 88 | plt.legend(Legend_name,labelspacing=0.2, loc=4,fontsize=18.2, ncol=1) 89 | plt.tight_layout() 90 | if 'ord_1' in dir_root: 91 | Norm_ord_L = 1 92 | else: 93 | Norm_ord_L = 2 94 | if only_user == '0': 95 | plt.savefig(f'{type}_diversity_EMNIST_C10_L2_{l2_value}_all_ord{Norm_ord_L}.pdf') 96 | else: 97 | plt.savefig(f'{type}_diversity_EMNIST_C10_L2_{l2_value}_only_user_ord{Norm_ord_L}.pdf') 98 | plt.show() 99 | 100 | #### C=30 101 | plt.bar(index,Var_list[2],width,color='peru',tick_label = name_list, hatch=hatch_list[0],alpha=0.6) 102 | plt.bar(index+width,Var_list[3],width,color='#2c7fb8',hatch=hatch_list[1]) 103 | 104 | Legend_name = ['FedAvg','Grouping-based'] 105 | 106 | font_size = 29 107 | plt.yscale('log') 108 | plt.yticks(fontproperties = 'Times New Roman', size = font_size-10) 109 | plt.xticks(fontproperties = 'Times New Roman', size = font_size-10) 110 | plt.ylabel('Gradient diversity', fontdict={'family' : 'Times New Roman', 'size' : font_size}) 111 | plt.xlabel('epoch', fontdict={'family' : 'Times New Roman', 'size' : font_size}) 112 | if l2_value == '1': 113 | plt.ylim([1.0,49.0]) 114 | else: 115 | if only_user == '1': 116 
| plt.ylim([1.0000,7.20025]) 117 | else: 118 | plt.ylim([1.00001,7.20025]) 119 | plt.grid(True, linestyle = "-.", linewidth = "0.15") 120 | if l2_value == '1': 121 | plt.legend(Legend_name,labelspacing=0.2, loc=4,fontsize=18.2, ncol=1) 122 | else: 123 | plt.legend(Legend_name,labelspacing=0.2, loc=4,fontsize=18.2, ncol=1) 124 | plt.tight_layout() 125 | if 'ord_1' in dir_root: 126 | Norm_ord_L = 1 127 | else: 128 | Norm_ord_L = 2 129 | if only_user == '0': 130 | plt.savefig(f'{type}_diversity_EMNIST_C30_L2_{l2_value}_all_ord{Norm_ord_L}.pdf') 131 | else: 132 | plt.savefig(f'{type}_diversity_EMNIST_C30_L2_{l2_value}_only_user_ord{Norm_ord_L}.pdf') 133 | plt.show() 134 | 135 | ###### C=47 136 | plt.bar(index,Var_list[4],width,color='#756bb1',tick_label = name_list, hatch=hatch_list[0],alpha=0.6) 137 | plt.bar(index+width,Var_list[5],width,color='#c51b8a',hatch=hatch_list[1]) 138 | 139 | Legend_name = ['FedAvg','Grouping-based'] 140 | 141 | font_size = 29 142 | plt.yscale('log') 143 | plt.yticks(fontproperties = 'Times New Roman', size = font_size-10) 144 | plt.xticks(fontproperties = 'Times New Roman', size = font_size-10) 145 | plt.ylabel('Gradient diversity', fontdict={'family' : 'Times New Roman', 'size' : font_size}) 146 | plt.xlabel('epoch', fontdict={'family' : 'Times New Roman', 'size' : font_size}) 147 | if l2_value == '1': 148 | plt.ylim([1.0,49.0]) 149 | else: 150 | if only_user == '1': 151 | plt.ylim([1.0000,7.20025]) 152 | else: 153 | plt.ylim([1.00001,7.20025]) 154 | plt.grid(True, linestyle = "-.", linewidth = "0.15") 155 | if l2_value == '1': 156 | plt.legend(Legend_name,labelspacing=0.2, loc=4,fontsize=18.2, ncol=1) 157 | else: 158 | plt.legend(Legend_name,labelspacing=0.2, loc=4,fontsize=18.2, ncol=1) 159 | plt.tight_layout() 160 | if 'ord_1' in dir_root: 161 | Norm_ord_L = 1 162 | else: 163 | Norm_ord_L = 2 164 | if only_user == '0': 165 | plt.savefig(f'{type}_diversity_EMNIST_C47_L2_{l2_value}_all_ord{Norm_ord_L}.pdf') 166 | else: 167 | 
plt.savefig(f'{type}_diversity_EMNIST_C47_L2_{l2_value}_only_user_ord{Norm_ord_L}.pdf') 168 | 169 | plt.show() 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSFL-Semi-supervised-Federated-Learning: Improving Semi-supervised Federated Learning by Reducing the Gradient Diversity of Models 2 | Improving Semi-supervised Federated Learning by Reducing the Gradient Diversity of Models 3 | ## Introduction 4 | This repository includes all the programs necessary to implement the Semi-supervised Federated Learning approach of [the following paper](https://arxiv.org/abs/2008.11364). The code runs on Python 3.7.6 with PyTorch 1.0.0 and torchvision 0.2.2. If you find this repository useful for your work, please cite the following paper: 5 | 6 | 7 | ``` 8 | @article{SSFL, 9 | title={Improving Semi-supervised Federated Learning by Reducing the Gradient Diversity of Models}, 10 | author={Zhang, Zhengming and Yang, Yaoqing and Yao, Zhewei and Yan, Yujun and Gonzalez, Joseph E and Ramchandran, Kannan and Mahoney, Michael W}, 11 | journal={IEEE International Conference on Big Data (Big Data)}, 12 | year={2021} 13 | } 14 | ``` 15 | 16 | 17 | ## Usage 18 | Please first clone this repository to your local system: 19 | 20 | ``` 21 | git clone https://github.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning.git 22 | ``` 23 | 24 | After cloning, please use Anaconda to install all the dependencies: 25 | 26 | ``` 27 | conda env create -f environment.yaml 28 | ``` 29 | 30 | To run the main script "train_parallel.py", you first need to determine the number of available GPUs. 31 | For example, if your machine has 4 GPUs, specify GPU_list as 0123 in run_cifar10.sh, run_svhn.sh and run_emnist.sh. 
32 | 33 | Then, you can train a Semi-supervised Federated Learning experiment using the following command: 34 | 35 | ``` 36 | python train_parallel.py [--GPU_list] [--datasetid] [--size] [--basicLabelRatio] [--labeled] [--num_comm_ue] [--H] [--cp] [--eval_grad] [--experiment_folder] [--experiment_name] [--tao] [--model] [--ue_loss] [--user_semi] 37 | [--epoch] [--batch_size] [--fast] [--Ns] 38 | optional arguments: 39 | --GPU_list: GPUs used for training, e.g., --GPU_list 0123456789 40 | --datasetid: the id of the dataset (default: 0); datasetid = 0/1/2 means the Cifar-10/SVHN/EMNIST dataset is used in the experiment. 41 | --size: size = K (users) + 1 (server); 42 | --cp: cp in {2, 4, 8, 16} is the frequency of communication; cp = 2 means the UEs and the server communicate every 2 iterations; 43 | --basicLabelRatio: basicLabelRatio in {0.0, 0.1, 0.2, 0.4, ..., 1.0} is the degree of data dispersion for each UE: 44 | basicLabelRatio = 0.0 means each UE has the same number of samples in each class; basicLabelRatio = 1.0 means all samples owned 45 | by a UE belong to the same class; 46 | --model: model in {'res', 'res_gn', 'EMNIST_model'}; model = 'res' means we use ResNet18 + BN; model = 'res_gn' means we use ResNet18 + GN; model = 'EMNIST_model' is used to train SSFL models on the EMNIST dataset; 47 | --num_comm_ue: num_comm_ue in {1, 2, ..., K}; the number of users that communicate per iteration; 48 | --H: H in {0, 1}; whether to use the grouping-based method; H = 1 means we use the grouping-based method; H = 0 means we use the FedAvg method; 49 | --Ns: num_data_server in {1000, 4000}, the number of labeled samples on the server; 50 | --labeled: labeled in {0, 1}; labeled=1 means supervised FL, labeled=0 means semi-supervised FL; 51 | --cp: cp in {2,4,8,16,32,64} is the communication period between the users and the server; 52 | --eval_grad: eval_grad in {0, 1}; eval_grad=1 means that we load the models stored during training to calculate the gradient; 53 | --experiment_folder: storage directory of experiment checkpoints; 54 |
--experiment_name: the name of the current experiment; 55 | --tao: hyperparameter used to calculate CRL; 56 | --model: neural network model for training; 57 | --ue_loss: ue_loss=CRL means we use CRL as the loss for local training; ue_loss=SF means we use the self-training method for local training; 58 | --epoch: number of training epochs of SSFL; 59 | --batch_size: batch size used for training; 60 | --fast: hyperparameter used for the learning-rate update; 61 | --Ns: the number of labeled samples on the server. 62 | ``` 63 | For example, you can run the following command to reproduce the results of SSFL on Cifar10 in the setting of K=C=10, Ns=1000, model=res_gn with non-iidness 0.4 and the grouping-based method. 64 | ``` 65 | python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --basicLabelRatio 0.4 --experiment_name Cifar10_res_gn_H1_comUE10_R0.4_SSFL 66 | ``` 67 | 68 | In the above experiment, the model is ResNet18 (with GN). One can use the --model option to change the model, e.g., to ResNet9: 69 | ``` 70 | python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --model res9 --size 11 --epoch 300 --eval_grad 0 --basicLabelRatio 0.4 --experiment_name Cifar10_res9_H1_comUE10_R0.4_SSFL 71 | ``` 72 | One can also perform the experiments on ResNet9 to compare with another paper on federated semi-supervised learning. See Table 3 in Section 4.3 of our paper for more details. 
73 | ``` 74 | nohup bash run_cifar10_res9.sh 75 | ``` 76 | 77 | 78 | The results will be saved in the folder results_v0, and the checkpoints will be save in "/checkpoints/Cifar10_res_gn_H1_comUE10_R0.4_SSFL" 79 | 80 | When all the checkpoints are saved, you can run the following script to calculate gradient diversity to reproduce the results on gradient diversity of our paper: 81 | ``` 82 | nohup bash Grad_Diff.sh 83 | ``` 84 | When all gradient diversities are calculated, you can run the following script to plot the results of gradient diversity. 85 | ``` 86 | python Plot_grad_diversity.py 87 | ``` 88 | 89 | You can also run the following scripts to reproduce the results reported in [the paper](https://arxiv.org/abs/2008.11364). 90 | 91 | ``` 92 | nohup bash Run_Exper.sh 93 | ``` 94 | 95 | -------------------------------------------------------------------------------- /Run_Exper.sh: -------------------------------------------------------------------------------- 1 | nohup bash run_cifar10.sh 2 | nohup bash run_svhn.sh 3 | nohup bash run_emnist.sh 4 | -------------------------------------------------------------------------------- /comm_helpers.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import math 4 | import sys 5 | import copy 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import functools 10 | import copy 11 | 12 | def flatten_tensors(tensors): 13 | """ 14 | Reference: https://github.com/facebookresearch/stochastic_gradient_push 15 | 16 | Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 17 | same dense type. 18 | Since inputs are dense, the resulting tensor will be a concatenated 1D 19 | buffer. Element-wise operation on this buffer will be equivalent to 20 | operating individually. 21 | Arguments: 22 | tensors (Iterable[Tensor]): dense tensors to flatten. 23 | Returns: 24 | A 1D buffer containing input tensors. 
25 | """ 26 | if len(tensors) == 1: 27 | return tensors[0].view(-1).clone() 28 | flat = torch.cat([t.view(-1) for t in tensors], dim=0) 29 | return flat 30 | 31 | 32 | def unflatten_tensors(flat, tensors): 33 | """ 34 | Reference: https://github.com/facebookresearch/stochastic_gradient_push 35 | 36 | View a flat buffer using the sizes of tensors. Assume that tensors are of 37 | same dense type, and that flat is given by flatten_dense_tensors. 38 | Arguments: 39 | flat (Tensor): flattened dense tensors to unflatten. 40 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 41 | unflatten flat. 42 | Returns: 43 | Unflattened dense tensors with sizes same as tensors and values from 44 | flat. 45 | """ 46 | outputs = [] 47 | offset = 0 48 | for tensor in tensors: 49 | numel = tensor.numel() 50 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 51 | offset += numel 52 | return tuple(outputs) 53 | 54 | def communicate(tensors, communication_op, attention=False): 55 | """ 56 | Reference: https://github.com/facebookresearch/stochastic_gradient_push 57 | 58 | Communicate a list of tensors. 59 | Arguments: 60 | tensors (Iterable[Tensor]): list of tensors. 61 | communication_op: a method or partial object which takes a tensor as 62 | input and communicates it. It can be a partial object around 63 | something like torch.distributed.all_reduce. 64 | """ 65 | flat_tensor = flatten_tensors(tensors) 66 | communication_op(tensor=flat_tensor) 67 | if attention: 68 | return tensors/flat_tensor 69 | 70 | for f, t in zip(unflatten_tensors(flat_tensor, tensors), tensors): 71 | with torch.no_grad(): 72 | t.set_(f) 73 | 74 | # def group_by_dtype(tensors): 75 | # """ 76 | # Returns a dict mapping from the tensor dtype to a list containing all 77 | # tensors of that dtype. 78 | # Arguments: 79 | # tensors (Iterable[Tensor]): list of tensors. 
80 | # """ 81 | # tensors_by_dtype = collections.defaultdict(list) 82 | # for tensor in tensors: 83 | # tensors_by_dtype[tensor.dtype].append(tensor) 84 | # return tensors_by_dtype 85 | # 86 | # def communicate(tensors, communication_op): 87 | # """ 88 | # Communicate a list of tensors. 89 | # Arguments: 90 | # tensors (Iterable[Tensor]): list of tensors. 91 | # communication_op: a method or partial object which takes a tensor as 92 | # input and communicates it. It can be a partial object around 93 | # something like torch.distributed.all_reduce. 94 | # """ 95 | # with torch.no_grad(): 96 | # tensors_by_dtype = group_by_dtype(tensors) 97 | # for dtype in tensors_by_dtype: 98 | # flat_tensor = flatten_tensors(tensors_by_dtype[dtype]) 99 | # communication_op(tensor=flat_tensor) 100 | # for f, t in zip(unflatten_tensors(flat_tensor, tensors_by_dtype[dtype]), 101 | # tensors_by_dtype[dtype]): 102 | # t.set_(f) 103 | 104 | def SyncEAvg(model, anchor_model, rank, size, group, alpha): 105 | ''' 106 | Inputs: 107 | model: (x^i) local neural net model at i-th worker node 108 | anchor_model: (z^1=z^2=...=z^m=z) local copy of auxiliary variable 109 | rank: (i) worker index 110 | size: (m) total number of workers 111 | group: worker group 112 | alpha: (a) elasticity parameter 113 | Output: 114 | return void, change in-place 115 | Formula: 116 | x_new = (1-a)*x^i + a*z 117 | z_new = z + a*(sum_i x^i - m*z) 118 | ''' 119 | 120 | for param1, param2 in zip(anchor_model.parameters(), model.parameters()): 121 | diff = (param2.data - param1.data) 122 | param2.data = (1-alpha)*param2.data + alpha*param1.data 123 | param1.data = param1.data/float(size) + alpha*diff 124 | 125 | for param in anchor_model.parameters(): 126 | dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=group) 127 | 128 | 129 | def AsyncEAvg(model, anchor_model, rank, size, group, req, alpha): 130 | ''' 131 | Inputs: 132 | model: (x^i) local neural net model at i-th worker node 133 | anchor_model: 
(z^1=z^2=...=z^m=z) local copy of auxiliary variable 134 | rank: (i) worker index 135 | size: (m) total number of workers 136 | group: worker group 137 | alpha: (a) elasticity parameter 138 | req: handle of last iteration's communication 139 | Output: 140 | return a handle of asynchronous fuction 141 | Formula: 142 | x_new = (1-a)*x^i + a*z 143 | z_new = z + a*(sum_i x^i - m*z) 144 | * the computation of z_new isn't finished when the function returns 145 | ''' 146 | if req: 147 | for param1, param2 in zip(anchor_model.parameters(), model.parameters()): 148 | req[param1].wait() # wait the last iteration's update of z to finish 149 | 150 | diff = (param2.data - param1.data) 151 | param2.data = (1-alpha)*param2.data + alpha*param1.data 152 | param1.data = param1.data/float(size) + alpha*diff 153 | else: 154 | for param1, param2 in zip(anchor_model.parameters(), model.parameters()): 155 | diff = (param2.data - param1.data) 156 | param2.data = (1-alpha)*param2.data + alpha*param1.data 157 | param1.data = param1.data/float(size) + alpha*diff 158 | 159 | for param in anchor_model.parameters(): 160 | req[param] = dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=group, async_op=True) 161 | 162 | return req 163 | 164 | 165 | def SyncAllreduce(model, rank, size): 166 | ''' 167 | Inputs: 168 | model: (x^i) local neural net model at i-th worker node 169 | anchor_model: (z^1=z^2=...=z^m=z) local copy of auxiliary variable 170 | rank: (i) worker index 171 | size: (m) total number of workers 172 | group: worker group 173 | Output: 174 | return void, change in-place 175 | Formula: 176 | x_new = sum_i x_i / size 177 | ''' 178 | communication_op = functools.partial(dist.all_reduce) 179 | params_list = [] 180 | for param in model.parameters(): 181 | param.data.div_(float(size)) 182 | 183 | # params_list.append(param.data) 184 | params_list.append(param) 185 | 186 | communicate(params_list, communication_op) 187 | 188 | def communicate_gather(tensors, rank, gsize, 
communication_op, group, dst=0, attention=False): 189 | """ 190 | Reference: https://github.com/facebookresearch/stochastic_gradient_push 191 | 192 | Communicate a list of tensors. 193 | Arguments: 194 | tensors (Iterable[Tensor]): list of tensors. 195 | communication_op: a method or partial object which takes a tensor as 196 | input and communicates it. It can be a partial object around 197 | something like torch.distributed.all_reduce. 198 | """ 199 | flat_tensor = flatten_tensors(tensors) 200 | if rank == 0: 201 | gather_list = [flat_tensor.clone() for _ in range(gsize)] 202 | else: 203 | gather_list = [] 204 | communication_op(tensor=flat_tensor, gather_list=gather_list, group=group, dst=dst) 205 | if attention: 206 | return tensors/flat_tensor 207 | gather_parameters_list = [] 208 | if rank == 0: 209 | for i in range(gsize): 210 | # tensors_clone = tensors.clone() 211 | tensors_clone = copy.deepcopy(tensors)#[ten.clone() for ten in tensors] 212 | for f, t in zip(unflatten_tensors(gather_list[i], tensors_clone), tensors_clone): 213 | with torch.no_grad(): 214 | t.set_(f) 215 | 216 | gather_parameters_list.append(tensors_clone) 217 | 218 | return gather_parameters_list 219 | else: 220 | return gather_parameters_list 221 | 222 | def SyncAllGather(model, rank, gsize, group): 223 | ''' 224 | Inputs: 225 | model: (x^i) local neural net model at i-th worker node 226 | anchor_model: (z^1=z^2=...=z^m=z) local copy of auxiliary variable 227 | rank: (i) worker index 228 | size: (m) total number of workers 229 | group: worker group 230 | Output: 231 | return void, change in-place 232 | Formula: 233 | x_new = sum_i x_i / size 234 | ''' 235 | communication_op = functools.partial(dist.gather) 236 | params_list = [] 237 | for param in model.parameters(): 238 | params_list.append(param.data.cpu().clone()) 239 | 240 | gather_parameters_list = communicate_gather(params_list, rank, gsize, communication_op, group, dst=0) 241 | return gather_parameters_list 242 | 243 | 244 | def 
communicate_1(tensors, communication_op, group, attention=False): 245 | """ 246 | Reference: https://github.com/facebookresearch/stochastic_gradient_push 247 | 248 | Communicate a list of tensors. 249 | Arguments: 250 | tensors (Iterable[Tensor]): list of tensors. 251 | communication_op: a method or partial object which takes a tensor as 252 | input and communicates it. It can be a partial object around 253 | something like torch.distributed.all_reduce. 254 | """ 255 | flat_tensor = flatten_tensors(tensors) 256 | communication_op(tensor=flat_tensor, group=group) 257 | if attention: 258 | return tensors/flat_tensor 259 | for f, t in zip(unflatten_tensors(flat_tensor, tensors), tensors): 260 | with torch.no_grad(): 261 | t.set_(f) 262 | 263 | def SyncAllreduce_1(model, rank, size,group): 264 | ''' 265 | Inputs: 266 | model: (x^i) local neural net model at i-th worker node 267 | anchor_model: (z^1=z^2=...=z^m=z) local copy of auxiliary variable 268 | rank: (i) worker index 269 | size: (m) total number of workers 270 | group: worker group 271 | Output: 272 | return void, change in-place 273 | Formula: 274 | x_new = sum_i x_i / size 275 | ''' 276 | communication_op = functools.partial(dist.all_reduce) 277 | params_list = [] 278 | for param in model.parameters(): 279 | param.data.div_(float(size)) 280 | # params_list.append(param.data) 281 | params_list.append(param) 282 | 283 | communicate_1(params_list, communication_op, group=group) 284 | 285 | def SyncAllreduce_2(model, rank, size, ue_list): 286 | ''' 287 | Inputs: 288 | model: (x^i) local neural net model at i-th worker node 289 | anchor_model: (z^1=z^2=...=z^m=z) local copy of auxiliary variable 290 | rank: (i) worker index 291 | size: (m) total number of workers 292 | group: worker group 293 | Output: 294 | return void, change in-place 295 | Formula: 296 | x_new = sum_i x_i / size 297 | ''' 298 | communication_op = functools.partial(dist.all_reduce) 299 | params_list = [] 300 | ue_list_set = set(ue_list) 301 | if 
rank in ue_list_set: 302 | for param in model.parameters(): 303 | param.data.div_(float(len(ue_list))) 304 | # params_list.append(param.data) 305 | params_list.append(param) 306 | else: 307 | for param in model.parameters(): 308 | param.data.mul_(0.0) 309 | # params_list.append(param.data) 310 | params_list.append(param) 311 | 312 | 313 | communicate(params_list, communication_op) 314 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import get_cifar10 -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/dataset/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/cifar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/dataset/__pycache__/cifar.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/cifar.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/dataset/__pycache__/cifar.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/cifar.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/dataset/__pycache__/cifar.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/randaugment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/dataset/__pycache__/randaugment.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/randaugment.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/dataset/__pycache__/randaugment.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/randaugment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/dataset/__pycache__/randaugment.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/cifar.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from torchvision import datasets 6 | from torchvision import transforms 7 | import copy 8 | from .randaugment import RandAugmentMC 9 | import random 10 | import numpy as np 11 | 12 | seed_value = 1 13 | 14 | random.seed(seed_value) 15 | np.random.seed(seed_value) 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | cifar10_mean = (0.4914, 0.4822, 0.4465) 20 | cifar10_std = 
(0.2471, 0.2435, 0.2616) 21 | cifar100_mean = (0.5071, 0.4867, 0.4408) 22 | cifar100_std = (0.2675, 0.2565, 0.2761) 23 | normal_mean = (0.5, 0.5, 0.5) 24 | normal_std = (0.5, 0.5, 0.5) 25 | 26 | 27 | def get_cifar10(root, num_expand_x, num_expand_u,device_ids, server_idxs): 28 | root='./data' 29 | transform_labeled = transforms.Compose([ 30 | transforms.RandomHorizontalFlip(), 31 | transforms.RandomCrop(size=32, 32 | padding=int(32*0.125), 33 | padding_mode='reflect'), 34 | transforms.ToTensor(), 35 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 36 | ]) 37 | transform_val = transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 40 | ]) 41 | base_dataset = datasets.CIFAR10(root, train=True, download=False) 42 | 43 | # train_labeled_idxs, train_unlabeled_idxs = x_u_split( 44 | # base_dataset.targets, num_expand_x, num_expand_u, device_ids,server_idxs) 45 | train_labeled_idxs, train_unlabeled_idxs = server_idxs, device_ids 46 | 47 | train_labeled_dataset = CIFAR10SSL( 48 | root, train_labeled_idxs, train=True, 49 | transform=transform_labeled) 50 | 51 | train_unlabeled_dataset_list = [] 52 | for id in range(len(train_unlabeled_idxs)): 53 | train_unlabeled_dataset = CIFAR10SSL( 54 | root, train_unlabeled_idxs[id], train=True, 55 | transform=TransformFix(mean=cifar10_mean, std=cifar10_std)) 56 | 57 | train_unlabeled_dataset_list.append(train_unlabeled_dataset) 58 | 59 | test_dataset = datasets.CIFAR10( 60 | root, train=False, transform=transform_val, download=False) 61 | logger.info("Dataset: CIFAR10") 62 | 63 | return train_labeled_dataset, train_unlabeled_dataset_list, test_dataset, base_dataset 64 | 65 | 66 | def get_cifar10_semi(root, num_expand_x, num_expand_u,device_ids, server_idxs): 67 | root='./data' 68 | transform_labeled = transforms.Compose([ 69 | transforms.RandomHorizontalFlip(), 70 | transforms.RandomCrop(size=32, 71 | padding=int(32*0.125), 72 | padding_mode='reflect'), 73 | 
transforms.ToTensor(), 74 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 75 | ]) 76 | transform_val = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 79 | ]) 80 | base_dataset = datasets.CIFAR10(root, train=True, download=False) 81 | 82 | train_labeled_idxs, train_unlabeled_idxs = x_u_split_semi_cifar( 83 | base_dataset.targets, num_expand_x, num_expand_u, device_ids, server_idxs) 84 | # train_labeled_idxs, train_unlabeled_idxs = server_idxs, device_ids 85 | 86 | 87 | train_unlabeled_dataset_list = [] 88 | train_labeled_dataset_list = [] 89 | for id in range(len(train_unlabeled_idxs)): 90 | print(id) 91 | train_unlabeled_dataset = CIFAR10SSL( 92 | root, train_unlabeled_idxs[id], train=True, 93 | transform=TransformFix(mean=cifar10_mean, std=cifar10_std)) 94 | 95 | train_labeled_dataset = CIFAR10SSL( 96 | root, train_labeled_idxs[id], train=True, 97 | transform=transform_labeled) 98 | 99 | train_unlabeled_dataset_list.append(train_unlabeled_dataset) 100 | train_labeled_dataset_list.append(train_labeled_dataset) 101 | 102 | test_dataset = datasets.CIFAR10( 103 | root, train=False, transform=transform_val, download=False) 104 | logger.info("Dataset: CIFAR10") 105 | 106 | 107 | return train_labeled_dataset_list, train_unlabeled_dataset_list, test_dataset, base_dataset 108 | 109 | def get_svhn(root, num_expand_x, num_expand_u,device_ids, server_idxs): 110 | root='./data' 111 | 112 | transform_labeled = transforms.Compose([ 113 | transforms.RandomHorizontalFlip(), 114 | transforms.RandomCrop(size=32, 115 | padding=int(32*0.125), 116 | padding_mode='reflect'), 117 | transforms.ToTensor(), 118 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 119 | ]) 120 | transform_val = transforms.Compose([ 121 | transforms.ToTensor(), 122 | transforms.Normalize(mean=cifar10_mean, std=cifar10_std) 123 | ]) 124 | base_dataset = datasets.SVHN(root, split='train', download=False) 125 | 126 | # 
train_labeled_idxs, train_unlabeled_idxs = x_u_split( 127 | # base_dataset.labels, num_expand_x, num_expand_u, device_ids,server_idxs) 128 | train_labeled_idxs, train_unlabeled_idxs = server_idxs, device_ids 129 | 130 | train_labeled_dataset = SVHNSSL( 131 | root, train_labeled_idxs, split='train', 132 | transform=transform_labeled) 133 | 134 | train_unlabeled_dataset_list = [] 135 | train_unlabeled_idxs_tmp = copy.deepcopy(train_unlabeled_idxs[0]) 136 | 137 | import functools 138 | import operator 139 | 140 | for id in range(len(train_unlabeled_idxs)): 141 | train_unlabeled_dataset = SVHNSSL( 142 | root, train_unlabeled_idxs[id], split='train', 143 | transform=TransformFix(mean=cifar10_mean, std=cifar10_std)) 144 | train_unlabeled_dataset_list.append(train_unlabeled_dataset) 145 | 146 | test_dataset = datasets.SVHN( 147 | root, split='test', transform=transform_val, download=False) 148 | logger.info("Dataset: SVHN") 149 | 150 | 151 | return train_labeled_dataset, train_unlabeled_dataset_list, test_dataset, base_dataset 152 | 153 | def get_cifar100(root, num_labeled, num_expand_x, num_expand_u): 154 | 155 | transform_labeled = transforms.Compose([ 156 | transforms.RandomHorizontalFlip(), 157 | transforms.RandomCrop(size=32, 158 | padding=int(32*0.125), 159 | padding_mode='reflect'), 160 | transforms.ToTensor(), 161 | transforms.Normalize(mean=cifar100_mean, std=cifar100_std)]) 162 | 163 | transform_val = transforms.Compose([ 164 | transforms.ToTensor(), 165 | transforms.Normalize(mean=cifar100_mean, std=cifar100_std)]) 166 | 167 | base_dataset = datasets.CIFAR100( 168 | root, train=True, download=True) 169 | 170 | train_labeled_idxs, train_unlabeled_idxs = x_u_split( 171 | base_dataset.targets, num_classes=100) 172 | 173 | train_labeled_dataset = CIFAR100SSL( 174 | root, train_labeled_idxs, train=True, 175 | transform=transform_labeled) 176 | 177 | train_unlabeled_dataset = CIFAR100SSL( 178 | root, train_unlabeled_idxs, train=True, 179 | 
transform=TransformFix(mean=cifar100_mean, std=cifar100_std)) 180 | 181 | test_dataset = datasets.CIFAR100( 182 | root, train=False, transform=transform_val, download=False) 183 | 184 | logger.info("Dataset: CIFAR100") 185 | logger.info(f"Labeled examples: {len(train_labeled_idxs)}" 186 | f" Unlabeled examples: {len(train_unlabeled_idxs)}") 187 | 188 | return train_labeled_dataset, train_unlabeled_dataset, test_dataset 189 | 190 | def get_emnist(root, num_expand_x, num_expand_u,device_ids, server_idxs, attack_idxs=None): 191 | root='./data' 192 | transform_labeled = transforms.Compose([ 193 | transforms.RandomHorizontalFlip(), 194 | transforms.RandomCrop(size=28, 195 | padding=int(28*0.125), 196 | padding_mode='reflect'), 197 | transforms.ToTensor(), 198 | transforms.Normalize((0.1307,), (0.3081,)) 199 | ]) 200 | transform_val = transforms.Compose([ 201 | transforms.ToTensor(), 202 | transforms.Normalize((0.1307,), (0.3081,)) 203 | ]) 204 | base_dataset = datasets.EMNIST(root, train=True,split='balanced', download=True) 205 | 206 | 207 | # train_labeled_idxs, train_unlabeled_idxs = x_u_split( 208 | # base_dataset.targets, num_expand_x, num_expand_u, device_ids,server_idxs) 209 | train_labeled_idxs, train_unlabeled_idxs = server_idxs, device_ids 210 | 211 | train_labeled_dataset = EMNIST( 212 | root, train_labeled_idxs, train=True, 213 | transform=transform_labeled) 214 | 215 | if attack_idxs is not None: 216 | train_attack_dataset = EMNIST( 217 | root, attack_idxs, train=True, 218 | transform=transform_labeled) 219 | 220 | 221 | train_unlabeled_dataset_list = [] 222 | print('len(train_unlabeled_idxs):',len(train_unlabeled_idxs)) 223 | 224 | 225 | for id in range(len(train_unlabeled_idxs)): 226 | train_unlabeled_dataset = EMNIST( 227 | root, train_unlabeled_idxs[id], train=True, 228 | transform=TransformFix(size = 28, mean=(0.1307,), std=(0.3081,))) 229 | train_unlabeled_dataset_list.append(train_unlabeled_dataset) 230 | 231 | test_dataset = datasets.EMNIST( 232 | 
root, train=False,split='balanced', transform=transform_val, download=True) 233 | 234 | if attack_idxs is not None: 235 | return train_labeled_dataset, train_unlabeled_dataset_list, test_dataset, train_attack_dataset, base_dataset 236 | else: 237 | return train_labeled_dataset, train_unlabeled_dataset_list, test_dataset, base_dataset 238 | 239 | def get_emnist_semi(root, num_expand_x, num_expand_u,device_ids, server_idxs): 240 | root='./data' 241 | transform_labeled = transforms.Compose([ 242 | transforms.RandomHorizontalFlip(), 243 | transforms.RandomCrop(size=28, 244 | padding=int(28*0.125), 245 | padding_mode='reflect'), 246 | transforms.ToTensor(), 247 | transforms.Normalize((0.1307,), (0.3081,)) 248 | ]) 249 | transform_val = transforms.Compose([ 250 | transforms.ToTensor(), 251 | transforms.Normalize((0.1307,), (0.3081,)) 252 | ]) 253 | base_dataset = datasets.EMNIST(root, train=True,split='balanced', download=True) 254 | 255 | 256 | train_labeled_idxs, train_unlabeled_idxs = x_u_split_semi( 257 | base_dataset.targets, num_expand_x, num_expand_u, device_ids, server_idxs) 258 | 259 | 260 | train_unlabeled_dataset_list = [] 261 | train_labeled_dataset_list = [] 262 | train_unlabeled_idxs_tmp = copy.deepcopy(train_unlabeled_idxs[0]) 263 | 264 | 265 | 266 | for id in range(len(train_unlabeled_idxs)): 267 | train_unlabeled_dataset = EMNIST( 268 | root, train_unlabeled_idxs[id], train=True, 269 | transform=TransformFix(size = 28, mean=(0.1307,), std=(0.3081,))) 270 | train_unlabeled_dataset_list.append(train_unlabeled_dataset) 271 | 272 | train_labeled_dataset = EMNIST( 273 | root, train_labeled_idxs[id], train=True, 274 | transform=transform_labeled) 275 | train_labeled_dataset_list.append(train_labeled_dataset) 276 | 277 | test_dataset = datasets.EMNIST( 278 | root, train=False,split='balanced', transform=transform_val, download=True) 279 | 280 | 281 | return train_labeled_dataset_list, train_unlabeled_dataset_list, test_dataset 282 | 283 | def x_u_split(labels, 
284 | num_expand_x, 285 | num_expand_u, 286 | device_ids, 287 | server_idxs): 288 | labels = np.array(labels) 289 | labeled_idx = copy.deepcopy(server_idxs) 290 | unlabeled_idx = [] 291 | 292 | unlabeled_idx_list = [] 293 | for id in range(len(device_ids)): 294 | unlabeled_idx = device_ids[id] 295 | 296 | exapand_unlabeled = num_expand_u // len(device_ids[id]) // len(device_ids) 297 | 298 | unlabeled_idx = np.hstack( 299 | [unlabeled_idx for _ in range(exapand_unlabeled)]) 300 | 301 | if len(unlabeled_idx) < num_expand_u // len(device_ids): 302 | diff = num_expand_u // len(device_ids) - len(unlabeled_idx) 303 | unlabeled_idx = np.hstack( 304 | (unlabeled_idx, np.random.choice(unlabeled_idx, diff))) 305 | else: 306 | assert len(unlabeled_idx) == num_expand_u // len(device_ids) 307 | 308 | unlabeled_idx_list.append(unlabeled_idx) 309 | 310 | 311 | exapand_labeled = num_expand_x // len(labeled_idx) 312 | labeled_idx = np.hstack( 313 | [labeled_idx for _ in range(exapand_labeled)]) 314 | if len(labeled_idx) < num_expand_x: 315 | diff = num_expand_x - len(labeled_idx) 316 | labeled_idx = np.hstack( 317 | (labeled_idx, np.random.choice(labeled_idx, diff))) 318 | else: 319 | assert len(labeled_idx) == num_expand_x 320 | 321 | return labeled_idx, unlabeled_idx_list 322 | 323 | def x_u_split_semi(labels, 324 | num_expand_x, 325 | num_expand_u, 326 | device_ids, 327 | server_idxs): 328 | 329 | 330 | server_semi_idxs = [] 331 | for i in range(len(device_ids)): 332 | server_semi_idxs.append([]) 333 | 334 | num = len(server_idxs)//len(device_ids) 335 | for id in range(len(device_ids)-1): 336 | idx_tmp = server_idxs[id*num:(id+1)*num] 337 | server_semi_idxs[id] = idx_tmp 338 | server_semi_idxs[len(device_ids)-1] = server_idxs[(id+1)*num:] 339 | labels = np.array(labels) 340 | labeled_idx = copy.deepcopy(server_idxs) 341 | unlabeled_idx = [] 342 | 343 | unlabeled_idx_list = [] 344 | for id in range(len(device_ids)): 345 | unlabeled_idx = device_ids[id] 346 | exapand_unlabeled = 
num_expand_u // len(device_ids[id]) // len(device_ids) 347 | 348 | unlabeled_idx = np.hstack( 349 | [unlabeled_idx for _ in range(exapand_unlabeled)]) 350 | 351 | if len(unlabeled_idx) < num_expand_u // len(device_ids): 352 | diff = num_expand_u // len(device_ids) - len(unlabeled_idx) 353 | unlabeled_idx = np.hstack( 354 | (unlabeled_idx, np.random.choice(unlabeled_idx, diff))) 355 | else: 356 | assert len(unlabeled_idx) == num_expand_u // len(device_ids) 357 | 358 | unlabeled_idx_list.append(unlabeled_idx) 359 | 360 | labeled_idx_list = [] 361 | for id in range(len(device_ids)): 362 | labeled_idx = server_semi_idxs[id] 363 | exapand_unlabeled = num_expand_u // len(server_semi_idxs[id]) // len(server_semi_idxs) 364 | 365 | labeled_idx = np.hstack( 366 | [labeled_idx for _ in range(exapand_unlabeled)]) 367 | 368 | if len(labeled_idx) < num_expand_u // len(device_ids): 369 | diff = num_expand_u // len(device_ids) - len(labeled_idx) 370 | labeled_idx = np.hstack( 371 | (labeled_idx, np.random.choice(labeled_idx, diff))) 372 | else: 373 | assert len(labeled_idx) == num_expand_u // len(device_ids) 374 | 375 | labeled_idx_list.append(labeled_idx) 376 | 377 | return labeled_idx_list, unlabeled_idx_list 378 | 379 | 380 | def x_u_split_semi_cifar(labels, 381 | num_expand_x, 382 | num_expand_u, 383 | device_ids, 384 | server_idxs): 385 | 386 | unlabeled_idx = [] 387 | unlabeled_idx_list = [] 388 | for id in range(len(device_ids)): 389 | unlabeled_idx = device_ids[id] 390 | exapand_unlabeled = num_expand_u // len(device_ids[id]) // len(device_ids) 391 | 392 | unlabeled_idx = np.hstack( 393 | [unlabeled_idx for _ in range(exapand_unlabeled)]) 394 | 395 | if len(unlabeled_idx) < num_expand_u // len(device_ids): 396 | diff = num_expand_u // len(device_ids) - len(unlabeled_idx) 397 | unlabeled_idx = np.hstack( 398 | (unlabeled_idx, np.random.choice(unlabeled_idx, diff))) 399 | else: 400 | assert len(unlabeled_idx) == num_expand_u // len(device_ids) 401 | 402 | 
unlabeled_idx_list.append(unlabeled_idx) 403 | 404 | labeled_idx_list = [] 405 | for id in range(len(device_ids)): 406 | labeled_idx = server_idxs[id] 407 | 408 | exapand_unlabeled = num_expand_u // len(device_ids[id]) // len(device_ids) 409 | 410 | labeled_idx = np.hstack( 411 | [labeled_idx for _ in range(exapand_unlabeled)]) 412 | 413 | if len(labeled_idx) < num_expand_u // len(device_ids): 414 | diff = num_expand_u // len(device_ids) - len(labeled_idx) 415 | labeled_idx = np.hstack( 416 | (labeled_idx, np.random.choice(labeled_idx, diff))) 417 | else: 418 | assert len(labeled_idx) == num_expand_u // len(device_ids) 419 | 420 | labeled_idx_list.append(labeled_idx) 421 | 422 | return labeled_idx_list, unlabeled_idx_list 423 | 424 | 425 | class TransformFix(object): 426 | def __init__(self, mean, std,size=32): 427 | self.weak = transforms.Compose([ 428 | transforms.RandomHorizontalFlip(), 429 | transforms.RandomCrop(size=size, 430 | padding=int(size*0.125), 431 | padding_mode='reflect')]) 432 | self.strong = transforms.Compose([ 433 | transforms.RandomHorizontalFlip(), 434 | transforms.RandomCrop(size=size, 435 | padding=int(size*0.125), 436 | padding_mode='reflect'), 437 | RandAugmentMC(n=2, m=10)]) 438 | self.normalize = transforms.Compose([ 439 | transforms.ToTensor(), 440 | transforms.Normalize(mean=mean, std=std)]) 441 | 442 | def __call__(self, x): 443 | weak = self.weak(x) 444 | strong = self.strong(x) 445 | return self.normalize(weak), self.normalize(strong) 446 | 447 | 448 | class CIFAR10SSL(datasets.CIFAR10): 449 | def __init__(self, root, indexs, train=True, 450 | transform=None, target_transform=None, 451 | download=False): 452 | super().__init__(root, train=train, 453 | transform=transform, 454 | target_transform=target_transform, 455 | download=download) 456 | if indexs is not None: 457 | self.data = self.data[indexs] 458 | self.targets = np.array(self.targets)[indexs] 459 | 460 | def __getitem__(self, index): 461 | img, target = self.data[index], 
self.targets[index] 462 | img = Image.fromarray(img) 463 | 464 | if self.transform is not None: 465 | img = self.transform(img) 466 | 467 | if self.target_transform is not None: 468 | target = self.target_transform(target) 469 | 470 | return img, target 471 | 472 | class EMNIST(datasets.EMNIST): 473 | def __init__(self, root, indexs, train=True, 474 | transform=None, target_transform=None, 475 | download=True,split='balanced'): 476 | super().__init__(root, train=train, 477 | transform=transform, 478 | target_transform=target_transform,split='balanced', 479 | download=download) 480 | if indexs is not None: 481 | self.data = self.data[indexs] 482 | self.targets = np.array(self.targets)[indexs] 483 | 484 | def __getitem__(self, index): 485 | img, target = self.data[index], self.targets[index] 486 | img = img.cpu().numpy() 487 | img = Image.fromarray(img) 488 | 489 | if self.transform is not None: 490 | img = self.transform(img) 491 | 492 | if self.target_transform is not None: 493 | target = target.cpu().numpy() 494 | target = self.target_transform(target) 495 | return img, target 496 | 497 | class CIFAR100SSL(datasets.CIFAR100): 498 | def __init__(self, root, indexs, train=True, 499 | transform=None, target_transform=None, 500 | download=False): 501 | super().__init__(root, train=train, 502 | transform=transform, 503 | target_transform=target_transform, 504 | download=download) 505 | if indexs is not None: 506 | self.data = self.data[indexs] 507 | self.targets = np.array(self.targets)[indexs] 508 | 509 | def __getitem__(self, index): 510 | img, target = self.data[index], self.targets[index] 511 | img = Image.fromarray(img) 512 | 513 | if self.transform is not None: 514 | img = self.transform(img) 515 | 516 | if self.target_transform is not None: 517 | target = self.target_transform(target) 518 | 519 | return img, target 520 | 521 | class SVHNSSL(datasets.SVHN): 522 | def __init__(self, root, indexs, split='train', 523 | transform=None, target_transform=None, 524 | 
download=False): 525 | super().__init__(root, split='train', 526 | transform=transform, 527 | target_transform=target_transform, 528 | download=download) 529 | if indexs is not None: 530 | self.data = self.data[indexs] 531 | self.labels = np.array(self.labels)[indexs] 532 | 533 | def __getitem__(self, index): 534 | img, target = self.data[index], int(self.labels[index]) 535 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 536 | if self.transform is not None: 537 | img = self.transform(img) 538 | 539 | if self.target_transform is not None: 540 | target = self.target_transform(target) 541 | 542 | return img, target 543 | -------------------------------------------------------------------------------- /dataset/randaugment.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from 2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 5 | import logging 6 | import random 7 | 8 | import numpy as np 9 | import PIL 10 | import PIL.ImageOps 11 | import PIL.ImageEnhance 12 | import PIL.ImageDraw 13 | from PIL import Image 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | PARAMETER_MAX = 10 18 | 19 | 20 | def AutoContrast(img, **kwarg): 21 | return PIL.ImageOps.autocontrast(img) 22 | 23 | 24 | def Brightness(img, v, max_v, bias=0): 25 | v = _float_parameter(v, max_v) + bias 26 | return PIL.ImageEnhance.Brightness(img).enhance(v) 27 | 28 | 29 | def Color(img, v, max_v, bias=0): 30 | v = _float_parameter(v, max_v) + bias 31 | return PIL.ImageEnhance.Color(img).enhance(v) 32 | 33 | 34 | def Contrast(img, v, max_v, bias=0): 35 | v = _float_parameter(v, max_v) + bias 36 | return PIL.ImageEnhance.Contrast(img).enhance(v) 37 | 38 | 39 | def Cutout(img, v, max_v, bias=0): 40 | if v == 
0: 41 | return img 42 | v = _float_parameter(v, max_v) + bias 43 | v = int(v * min(img.size)) 44 | return CutoutAbs(img, v) 45 | 46 | 47 | def CutoutAbs(img, v, **kwarg): 48 | w, h = img.size 49 | x0 = np.random.uniform(0, w) 50 | y0 = np.random.uniform(0, h) 51 | x0 = int(max(0, x0 - v / 2.)) 52 | y0 = int(max(0, y0 - v / 2.)) 53 | x1 = int(min(w, x0 + v)) 54 | y1 = int(min(h, y0 + v)) 55 | xy = (x0, y0, x1, y1) 56 | # gray 57 | color = (127, 127, 127) 58 | img = img.copy() 59 | # print(img.size) 60 | # print(type(img)) 61 | if w == 32: 62 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 63 | else: 64 | color = (np.max(img)//2,) 65 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 66 | 67 | return img 68 | 69 | 70 | def Equalize(img, **kwarg): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Identity(img, **kwarg): 75 | return img 76 | 77 | 78 | def Invert(img, **kwarg): 79 | return PIL.ImageOps.invert(img) 80 | 81 | 82 | def Posterize(img, v, max_v, bias=0): 83 | v = _int_parameter(v, max_v) + bias 84 | return PIL.ImageOps.posterize(img, v) 85 | 86 | 87 | def Rotate(img, v, max_v, bias=0): 88 | v = _int_parameter(v, max_v) + bias 89 | if random.random() < 0.5: 90 | v = -v 91 | return img.rotate(v) 92 | 93 | 94 | def Sharpness(img, v, max_v, bias=0): 95 | v = _float_parameter(v, max_v) + bias 96 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 97 | 98 | 99 | def ShearX(img, v, max_v, bias=0): 100 | v = _float_parameter(v, max_v) + bias 101 | if random.random() < 0.5: 102 | v = -v 103 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 104 | 105 | 106 | def ShearY(img, v, max_v, bias=0): 107 | v = _float_parameter(v, max_v) + bias 108 | if random.random() < 0.5: 109 | v = -v 110 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 111 | 112 | 113 | def Solarize(img, v, max_v, bias=0): 114 | v = _int_parameter(v, max_v) + bias 115 | return PIL.ImageOps.solarize(img, 256 - v) 116 | 117 | 118 | def SolarizeAdd(img, v, max_v, 
bias=0, threshold=128): 119 | v = _int_parameter(v, max_v) + bias 120 | if random.random() < 0.5: 121 | v = -v 122 | img_np = np.array(img).astype(np.int) 123 | img_np = img_np + v 124 | img_np = np.clip(img_np, 0, 255) 125 | img_np = img_np.astype(np.uint8) 126 | img = Image.fromarray(img_np) 127 | return PIL.ImageOps.solarize(img, threshold) 128 | 129 | 130 | def TranslateX(img, v, max_v, bias=0): 131 | v = _float_parameter(v, max_v) + bias 132 | if random.random() < 0.5: 133 | v = -v 134 | v = int(v * img.size[0]) 135 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 136 | 137 | 138 | def TranslateY(img, v, max_v, bias=0): 139 | v = _float_parameter(v, max_v) + bias 140 | if random.random() < 0.5: 141 | v = -v 142 | v = int(v * img.size[1]) 143 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 144 | 145 | 146 | def _float_parameter(v, max_v): 147 | return float(v) * max_v / PARAMETER_MAX 148 | 149 | 150 | def _int_parameter(v, max_v): 151 | return int(v * max_v / PARAMETER_MAX) 152 | 153 | 154 | def fixmatch_augment_pool(): 155 | # FixMatch paper 156 | augs = [(AutoContrast, None, None), 157 | (Brightness, 0.9, 0.05), 158 | (Color, 0.9, 0.05), 159 | (Contrast, 0.9, 0.05), 160 | (Equalize, None, None), 161 | (Identity, None, None), 162 | (Posterize, 4, 4), 163 | (Rotate, 30, 0), 164 | (Sharpness, 0.9, 0.05), 165 | (ShearX, 0.3, 0), 166 | (ShearY, 0.3, 0), 167 | (Solarize, 256, 0), 168 | (TranslateX, 0.3, 0), 169 | (TranslateY, 0.3, 0)] 170 | return augs 171 | 172 | 173 | def my_augment_pool(): 174 | # Test 175 | augs = [(AutoContrast, None, None), 176 | (Brightness, 1.8, 0.1), 177 | (Color, 1.8, 0.1), 178 | (Contrast, 1.8, 0.1), 179 | (Cutout, 0.2, 0), 180 | (Equalize, None, None), 181 | (Invert, None, None), 182 | (Posterize, 4, 4), 183 | (Rotate, 30, 0), 184 | (Sharpness, 1.8, 0.1), 185 | (ShearX, 0.3, 0), 186 | (ShearY, 0.3, 0), 187 | (Solarize, 256, 0), 188 | (SolarizeAdd, 110, 0), 189 | (TranslateX, 0.45, 0), 190 | 
(TranslateY, 0.45, 0)] 191 | return augs 192 | 193 | 194 | class RandAugmentPC(object): 195 | def __init__(self, n, m): 196 | assert n >= 1 197 | assert 1 <= m <= 10 198 | self.n = n 199 | self.m = m 200 | self.augment_pool = my_augment_pool() 201 | 202 | def __call__(self, img): 203 | ops = random.choices(self.augment_pool, k=self.n) 204 | for op, max_v, bias in ops: 205 | prob = np.random.uniform(0.2, 0.8) 206 | if random.random() + prob >= 1: 207 | img = op(img, v=self.m, max_v=max_v, bias=bias) 208 | w, h = img.size 209 | if w == 32: 210 | img = CutoutAbs(img, 16) 211 | else: 212 | img = CutoutAbs(img, 14) 213 | return img 214 | 215 | 216 | class RandAugmentMC(object): 217 | def __init__(self, n, m): 218 | assert n >= 1 219 | assert 1 <= m <= 10 220 | self.n = n 221 | self.m = m 222 | self.augment_pool = fixmatch_augment_pool() 223 | 224 | def __call__(self, img): 225 | ops = random.choices(self.augment_pool, k=self.n) 226 | for op, max_v, bias in ops: 227 | v = np.random.randint(1, self.m) 228 | if random.random() < 0.5: 229 | img = op(img, v=v, max_v=max_v, bias=bias) 230 | w, h = img.size 231 | if w == 32: 232 | img = CutoutAbs(img, 16) 233 | else: 234 | img = CutoutAbs(img, 14) 235 | return img 236 | -------------------------------------------------------------------------------- /dataset/stl10_input.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import os, sys, tarfile, errno 5 | import numpy as np 6 | # import matplotlib.pyplot as plt 7 | 8 | if sys.version_info >= (3, 0, 0): 9 | import urllib.request as urllib # ugly but works 10 | else: 11 | import urllib 12 | 13 | try: 14 | from imageio import imsave 15 | except: 16 | from scipy.misc import imsave 17 | 18 | print(sys.version_info) 19 | 20 | # image shape 21 | HEIGHT = 96 22 | WIDTH = 96 23 | DEPTH = 3 24 | 25 | # size of a single image in bytes 26 | SIZE = HEIGHT * WIDTH * DEPTH 27 | 28 | # path to 
the directory with the data 29 | DATA_DIR = './data' 30 | 31 | # url of the binary data 32 | DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz' 33 | 34 | # path to the binary train file with image data 35 | DATA_PATH = './data/stl10_binary/train_X.bin' 36 | 37 | DATA_PATH_unlabeled = './data/stl10_binary/unlabeled_X.bin' 38 | 39 | # path to the binary train file with labels 40 | LABEL_PATH = './data/stl10_binary/train_y.bin' 41 | 42 | DATA_PATH_test = './data/stl10_binary/test_X.bin' 43 | LABEL_PATH_test = './data/stl10_binary/test_y.bin' 44 | 45 | def read_labels(path_to_labels): 46 | """ 47 | :param path_to_labels: path to the binary file containing labels from the STL-10 dataset 48 | :return: an array containing the labels 49 | """ 50 | with open(path_to_labels, 'rb') as f: 51 | labels = np.fromfile(f, dtype=np.uint8) 52 | return labels 53 | 54 | 55 | def read_all_images(path_to_data): 56 | """ 57 | :param path_to_data: the file containing the binary images from the STL-10 dataset 58 | :return: an array containing all the images 59 | """ 60 | 61 | with open(path_to_data, 'rb') as f: 62 | # read whole file in uint8 chunks 63 | everything = np.fromfile(f, dtype=np.uint8) 64 | 65 | # We force the data into 3x96x96 chunks, since the 66 | # images are stored in "column-major order", meaning 67 | # that "the first 96*96 values are the red channel, 68 | # the next 96*96 are green, and the last are blue." 69 | # The -1 is since the size of the pictures depends 70 | # on the input file, and this way numpy determines 71 | # the size on its own. 72 | 73 | images = np.reshape(everything, (-1, 3, 96, 96)) 74 | 75 | # Now transpose the images into a standard image format 76 | # readable by, for example, matplotlib.imshow 77 | # You might want to comment this line or reverse the shuffle 78 | # if you will use a learning algorithm like CNN, since they like 79 | # their channels separated. 
80 | images = np.transpose(images, (0, 3, 2, 1)) 81 | return images 82 | 83 | 84 | def read_single_image(image_file): 85 | """ 86 | CAREFUL! - this method uses a file as input instead of the path - so the 87 | position of the reader will be remembered outside of context of this method. 88 | :param image_file: the open file containing the images 89 | :return: a single image 90 | """ 91 | # read a single image, count determines the number of uint8's to read 92 | image = np.fromfile(image_file, dtype=np.uint8, count=SIZE) 93 | # force into image matrix 94 | image = np.reshape(image, (3, 96, 96)) 95 | # transpose to standard format 96 | # You might want to comment this line or reverse the shuffle 97 | # if you will use a learning algorithm like CNN, since they like 98 | # their channels separated. 99 | image = np.transpose(image, (2, 1, 0)) 100 | return image 101 | 102 | 103 | # def plot_image(image): 104 | # """ 105 | # :param image: the image to be plotted in a 3-D matrix format 106 | # :return: None 107 | # """ 108 | # plt.imshow(image) 109 | # plt.show() 110 | 111 | def save_image(image, name): 112 | imsave("%s.png" % name, image, format="png") 113 | 114 | def download_and_extract(): 115 | """ 116 | Download and extract the STL-10 dataset 117 | :return: None 118 | """ 119 | dest_directory = DATA_DIR 120 | if not os.path.exists(dest_directory): 121 | os.makedirs(dest_directory) 122 | filename = DATA_URL.split('/')[-1] 123 | filepath = os.path.join(dest_directory, filename) 124 | if not os.path.exists(filepath): 125 | def _progress(count, block_size, total_size): 126 | sys.stdout.write('\rDownloading %s %.2f%%' % (filename, 127 | float(count * block_size) / float(total_size) * 100.0)) 128 | sys.stdout.flush() 129 | filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress) 130 | print('Downloaded', filename) 131 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 132 | 133 | def save_images(images, labels): 134 | print("Saving images to disk") 
135 | i = 0 136 | for image in images: 137 | label = labels[i] 138 | directory = './img/' + str(label) + '/' 139 | try: 140 | os.makedirs(directory, exist_ok=True) 141 | except OSError as exc: 142 | if exc.errno == errno.EEXIST: 143 | pass 144 | filename = directory + str(i) 145 | print(filename) 146 | save_image(image, filename) 147 | i = i+1 148 | 149 | def get_stl10_data(): 150 | # download data if needed 151 | download_and_extract() 152 | 153 | # test to check if the whole dataset is read correctly 154 | images_train_labeled = read_all_images(DATA_PATH) 155 | labels_train_labeled = read_labels(LABEL_PATH) 156 | 157 | images_train_unlabeled = read_all_images(DATA_PATH_unlabeled) 158 | 159 | images_test = read_all_images(DATA_PATH_test) 160 | label_test = read_labels(LABEL_PATH_test) 161 | 162 | return images_train_labeled, labels_train_labeled, images_train_unlabeled, images_test, label_test 163 | 164 | 165 | 166 | # if __name__ == "__main__": 167 | # # download data if needed 168 | # download_and_extract() 169 | # 170 | # # test to check if the image is read correctly 171 | # # with open(DATA_PATH) as f: 172 | # # image = read_single_image(f) 173 | # # plot_image(image) 174 | # 175 | # # test to check if the whole dataset is read correctly 176 | # images = read_all_images(DATA_PATH) 177 | # print(images.shape) 178 | # 179 | # labels = read_labels(LABEL_PATH) 180 | # print(labels.shape) 181 | # 182 | # images_unlabeled = read_all_images(DATA_PATH_unlabeled) 183 | # print(images_unlabeled.shape) 184 | # 185 | # images_test = read_all_images(DATA_PATH_test) 186 | # print(images_test.shape) 187 | # 188 | # label_test = read_labels(LABEL_PATH_test) 189 | # print(label_test.shape) 190 | # 191 | # # save images to disk 192 | # # save_images(images, labels) 193 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 
3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0 7 | - _libgcc_mutex=0.1=main 8 | - alabaster=0.7.12=py37_0 9 | - anaconda=2020.02=py37_0 10 | - anaconda-client=1.7.2=py37_0 11 | - anaconda-navigator=1.9.12=py37_0 12 | - anaconda-project=0.8.4=py_0 13 | - argh=0.26.2=py37_0 14 | - asn1crypto=1.3.0=py37_0 15 | - astroid=2.3.3=py37_0 16 | - astropy=4.0=py37h7b6447c_0 17 | - atomicwrites=1.3.0=py37_1 18 | - attrs=19.3.0=py_0 19 | - babel=2.8.0=py_0 20 | - backcall=0.1.0=py37_0 21 | - backports=1.0=py_2 22 | - backports.functools_lru_cache=1.6.1=py_0 23 | - backports.shutil_get_terminal_size=1.0.0=py37_2 24 | - backports.tempfile=1.0=py_1 25 | - backports.weakref=1.0.post1=py_1 26 | - beautifulsoup4=4.8.2=py37_0 27 | - bitarray=1.2.1=py37h7b6447c_0 28 | - bkcharts=0.2=py37_0 29 | - blas=1.0=mkl 30 | - bleach=3.1.0=py37_0 31 | - blosc=1.16.3=hd408876_0 32 | - bokeh=1.4.0=py37_0 33 | - boto=2.49.0=py37_0 34 | - bottleneck=1.3.2=py37heb32a55_0 35 | - bzip2=1.0.8=h7b6447c_0 36 | - ca-certificates=2020.1.1=0 37 | - cairo=1.14.12=h8948797_3 38 | - certifi=2019.11.28=py37_0 39 | - cffi=1.14.0=py37h2e261b9_0 40 | - chardet=3.0.4=py37_1003 41 | - click=7.0=py37_0 42 | - cloudpickle=1.3.0=py_0 43 | - clyent=1.2.2=py37_1 44 | - colorama=0.4.3=py_0 45 | - conda=4.8.3=py37_0 46 | - conda-build=3.18.11=py37_0 47 | - conda-env=2.6.0=1 48 | - conda-package-handling=1.6.0=py37h7b6447c_0 49 | - conda-verify=3.4.2=py_1 50 | - contextlib2=0.6.0.post1=py_0 51 | - cryptography=2.8=py37h1ba5d50_0 52 | - cudatoolkit=10.2.89=hfd86e86_1 53 | - curl=7.68.0=hbc83047_0 54 | - cycler=0.10.0=py37_0 55 | - cython=0.29.15=py37he6710b0_0 56 | - cytoolz=0.10.1=py37h7b6447c_0 57 | - dask=2.11.0=py_0 58 | - dask-core=2.11.0=py_0 59 | - dbus=1.13.12=h746ee38_0 60 | - decorator=4.4.1=py_0 61 | - defusedxml=0.6.0=py_0 62 | - diff-match-patch=20181111=py_0 63 | - distributed=2.11.0=py37_0 64 | - docutils=0.16=py37_0 65 | - entrypoints=0.3=py37_0 66 | - 
et_xmlfile=1.0.1=py37_0 67 | - expat=2.2.6=he6710b0_0 68 | - fastcache=1.1.0=py37h7b6447c_0 69 | - filelock=3.0.12=py_0 70 | - flake8=3.7.9=py37_0 71 | - flask=1.1.1=py_0 72 | - fontconfig=2.13.0=h9420a91_0 73 | - freetype=2.9.1=h8a8886c_1 74 | - fribidi=1.0.5=h7b6447c_0 75 | - fsspec=0.6.2=py_0 76 | - future=0.18.2=py37_0 77 | - get_terminal_size=1.0.0=haa9412d_0 78 | - gevent=1.4.0=py37h7b6447c_0 79 | - glib=2.63.1=h5a9c865_0 80 | - glob2=0.7=py_0 81 | - gmp=6.1.2=h6c8ec71_1 82 | - gmpy2=2.0.8=py37h10f8cd9_2 83 | - graphite2=1.3.13=h23475e2_0 84 | - greenlet=0.4.15=py37h7b6447c_0 85 | - gst-plugins-base=1.14.0=hbbd80ab_1 86 | - gstreamer=1.14.0=hb453b48_1 87 | - h5py=2.10.0=py37h7918eee_0 88 | - harfbuzz=1.8.8=hffaf4a1_0 89 | - hdf5=1.10.4=hb1b8bf9_0 90 | - heapdict=1.0.1=py_0 91 | - html5lib=1.0.1=py37_0 92 | - hypothesis=5.5.4=py_0 93 | - icu=58.2=h9c2bf20_1 94 | - idna=2.8=py37_0 95 | - imageio=2.6.1=py37_0 96 | - imagesize=1.2.0=py_0 97 | - importlib_metadata=1.5.0=py37_0 98 | - intel-openmp=2020.0=166 99 | - intervaltree=3.0.2=py_0 100 | - ipykernel=5.1.4=py37h39e3cac_0 101 | - ipython=7.12.0=py37h5ca1d4c_0 102 | - ipython_genutils=0.2.0=py37_0 103 | - ipywidgets=7.5.1=py_0 104 | - isort=4.3.21=py37_0 105 | - itsdangerous=1.1.0=py37_0 106 | - jbig=2.1=hdba287a_0 107 | - jdcal=1.4.1=py_0 108 | - jedi=0.14.1=py37_0 109 | - jeepney=0.4.2=py_0 110 | - jinja2=2.11.1=py_0 111 | - joblib=0.14.1=py_0 112 | - jpeg=9b=h024ee3a_2 113 | - json5=0.9.1=py_0 114 | - jsonschema=3.2.0=py37_0 115 | - jupyter=1.0.0=py37_7 116 | - jupyter_client=5.3.4=py37_0 117 | - jupyter_console=6.1.0=py_0 118 | - jupyter_core=4.6.1=py37_0 119 | - jupyterlab=1.2.6=pyhf63ae98_0 120 | - jupyterlab_server=1.0.6=py_0 121 | - keyring=21.1.0=py37_0 122 | - kiwisolver=1.1.0=py37he6710b0_0 123 | - krb5=1.17.1=h173b8e3_0 124 | - lazy-object-proxy=1.4.3=py37h7b6447c_0 125 | - ld_impl_linux-64=2.33.1=h53a641e_7 126 | - libarchive=3.3.3=h5d8350f_5 127 | - libcurl=7.68.0=h20c2e04_0 128 | - 
libedit=3.1.20181209=hc058e9b_0 129 | - libffi=3.2.1=hd88cf55_4 130 | - libgcc-ng=9.1.0=hdf63c60_0 131 | - libgfortran-ng=7.3.0=hdf63c60_0 132 | - liblief=0.9.0=h7725739_2 133 | - libpng=1.6.37=hbc83047_0 134 | - libsodium=1.0.16=h1bed415_0 135 | - libspatialindex=1.9.3=he6710b0_0 136 | - libssh2=1.8.2=h1ba5d50_0 137 | - libstdcxx-ng=9.1.0=hdf63c60_0 138 | - libtiff=4.1.0=h2733197_0 139 | - libtool=2.4.6=h7b6447c_5 140 | - libuuid=1.0.3=h1bed415_2 141 | - libxcb=1.13=h1bed415_1 142 | - libxml2=2.9.9=hea5a465_1 143 | - libxslt=1.1.33=h7d1a2b0_0 144 | - llvmlite=0.31.0=py37hd408876_0 145 | - locket=0.2.0=py37_1 146 | - lxml=4.5.0=py37hefd8a0e_0 147 | - lz4-c=1.8.1.2=h14c3975_0 148 | - lzo=2.10=h49e0be7_2 149 | - markupsafe=1.1.1=py37h7b6447c_0 150 | - matplotlib=3.1.3=py37_0 151 | - matplotlib-base=3.1.3=py37hef1b27d_0 152 | - mccabe=0.6.1=py37_1 153 | - mistune=0.8.4=py37h7b6447c_0 154 | - mkl=2020.0=166 155 | - mkl-service=2.3.0=py37he904b0f_0 156 | - mkl_fft=1.0.15=py37ha843d7b_0 157 | - mkl_random=1.1.0=py37hd6b4f25_0 158 | - mock=4.0.1=py_0 159 | - more-itertools=8.2.0=py_0 160 | - mpc=1.1.0=h10f8cd9_1 161 | - mpfr=4.0.1=hdf1c602_3 162 | - mpmath=1.1.0=py37_0 163 | - msgpack-python=0.6.1=py37hfd86e86_1 164 | - multipledispatch=0.6.0=py37_0 165 | - navigator-updater=0.2.1=py37_0 166 | - nbconvert=5.6.1=py37_0 167 | - nbformat=5.0.4=py_0 168 | - ncurses=6.2=he6710b0_0 169 | - networkx=2.4=py_0 170 | - ninja=1.9.0=py37hfd86e86_0 171 | - nltk=3.4.5=py37_0 172 | - nose=1.3.7=py37_2 173 | - notebook=6.0.3=py37_0 174 | - numba=0.48.0=py37h0573a6f_0 175 | - numexpr=2.7.1=py37h423224d_0 176 | - numpy=1.18.1=py37h4f9e942_0 177 | - numpy-base=1.18.1=py37hde5b4d6_1 178 | - numpydoc=0.9.2=py_0 179 | - olefile=0.46=py37_0 180 | - openpyxl=3.0.3=py_0 181 | - openssl=1.1.1d=h7b6447c_4 182 | - packaging=20.1=py_0 183 | - pandas=1.0.1=py37h0573a6f_0 184 | - pandoc=2.2.3.2=0 185 | - pandocfilters=1.4.2=py37_1 186 | - pango=1.42.4=h049681c_0 187 | - parso=0.5.2=py_0 188 | - 
partd=1.1.0=py_0 189 | - patchelf=0.10=he6710b0_0 190 | - path=13.1.0=py37_0 191 | - path.py=12.4.0=0 192 | - pathlib2=2.3.5=py37_0 193 | - pathtools=0.1.2=py_1 194 | - patsy=0.5.1=py37_0 195 | - pcre=8.43=he6710b0_0 196 | - pep8=1.7.1=py37_0 197 | - pexpect=4.8.0=py37_0 198 | - pickleshare=0.7.5=py37_0 199 | - pillow=7.0.0=py37hb39fc2d_0 200 | - pip=20.0.2=py37_1 201 | - pixman=0.38.0=h7b6447c_0 202 | - pkginfo=1.5.0.1=py37_0 203 | - pluggy=0.13.1=py37_0 204 | - ply=3.11=py37_0 205 | - prometheus_client=0.7.1=py_0 206 | - prompt_toolkit=3.0.3=py_0 207 | - psutil=5.6.7=py37h7b6447c_0 208 | - ptyprocess=0.6.0=py37_0 209 | - py=1.8.1=py_0 210 | - py-lief=0.9.0=py37h7725739_2 211 | - pycodestyle=2.5.0=py37_0 212 | - pycosat=0.6.3=py37h7b6447c_0 213 | - pycparser=2.19=py37_0 214 | - pycrypto=2.6.1=py37h14c3975_9 215 | - pycurl=7.43.0.5=py37h1ba5d50_0 216 | - pydocstyle=4.0.1=py_0 217 | - pyflakes=2.1.1=py37_0 218 | - pygments=2.5.2=py_0 219 | - pylint=2.4.4=py37_0 220 | - pyodbc=4.0.30=py37he6710b0_0 221 | - pyopenssl=19.1.0=py37_0 222 | - pyparsing=2.4.6=py_0 223 | - pyqt=5.9.2=py37h05f1152_2 224 | - pyrsistent=0.15.7=py37h7b6447c_0 225 | - pysocks=1.7.1=py37_0 226 | - pytables=3.6.1=py37h71ec239_0 227 | - pytest=5.3.5=py37_0 228 | - pytest-arraydiff=0.3=py37h39e3cac_0 229 | - pytest-astropy=0.8.0=py_0 230 | - pytest-astropy-header=0.1.2=py_0 231 | - pytest-doctestplus=0.5.0=py_0 232 | - pytest-openfiles=0.4.0=py_0 233 | - pytest-remotedata=0.3.2=py37_0 234 | - python=3.7.6=h0371630_2 235 | - python-dateutil=2.8.1=py_0 236 | - python-jsonrpc-server=0.3.4=py_0 237 | - python-language-server=0.31.7=py37_0 238 | - python-libarchive-c=2.8=py37_13 239 | - pytorch=1.0.0=py3.7_cuda9.0.176_cudnn7.4.1_1 240 | - pytz=2019.3=py_0 241 | - pywavelets=1.1.1=py37h7b6447c_0 242 | - pyxdg=0.26=py_0 243 | - pyyaml=5.3=py37h7b6447c_0 244 | - pyzmq=18.1.1=py37he6710b0_0 245 | - qdarkstyle=2.8=py_0 246 | - qt=5.9.7=h5867ecd_1 247 | - qtawesome=0.6.1=py_0 248 | - qtconsole=4.6.0=py_1 249 | 
- qtpy=1.9.0=py_0 250 | - readline=7.0=h7b6447c_5 251 | - requests=2.22.0=py37_1 252 | - ripgrep=11.0.2=he32d670_0 253 | - rope=0.16.0=py_0 254 | - rtree=0.9.3=py37_0 255 | - ruamel_yaml=0.15.87=py37h7b6447c_0 256 | - scikit-image=0.16.2=py37h0573a6f_0 257 | - scikit-learn=0.22.1=py37hd81dba3_0 258 | - scipy=1.4.1=py37h0b6359f_0 259 | - seaborn=0.10.0=py_0 260 | - secretstorage=3.1.2=py37_0 261 | - send2trash=1.5.0=py37_0 262 | - setuptools=45.2.0=py37_0 263 | - simplegeneric=0.8.1=py37_2 264 | - singledispatch=3.4.0.3=py37_0 265 | - sip=4.19.8=py37hf484d3e_0 266 | - six=1.14.0=py37_0 267 | - snappy=1.1.7=hbae5bb6_3 268 | - snowballstemmer=2.0.0=py_0 269 | - sortedcollections=1.1.2=py37_0 270 | - sortedcontainers=2.1.0=py37_0 271 | - soupsieve=1.9.5=py37_0 272 | - sphinx=2.4.0=py_0 273 | - sphinxcontrib=1.0=py37_1 274 | - sphinxcontrib-applehelp=1.0.1=py_0 275 | - sphinxcontrib-devhelp=1.0.1=py_0 276 | - sphinxcontrib-htmlhelp=1.0.2=py_0 277 | - sphinxcontrib-jsmath=1.0.1=py_0 278 | - sphinxcontrib-qthelp=1.0.2=py_0 279 | - sphinxcontrib-serializinghtml=1.1.3=py_0 280 | - sphinxcontrib-websupport=1.2.0=py_0 281 | - spyder=4.0.1=py37_0 282 | - spyder-kernels=1.8.1=py37_0 283 | - sqlalchemy=1.3.13=py37h7b6447c_0 284 | - sqlite=3.31.1=h7b6447c_0 285 | - statsmodels=0.11.0=py37h7b6447c_0 286 | - sympy=1.5.1=py37_0 287 | - tbb=2020.0=hfd86e86_0 288 | - tblib=1.6.0=py_0 289 | - terminado=0.8.3=py37_0 290 | - testpath=0.4.4=py_0 291 | - tk=8.6.8=hbc83047_0 292 | - toolz=0.10.0=py_0 293 | - torchvision=0.2.2=py_3 294 | - tornado=6.0.3=py37h7b6447c_3 295 | - tqdm=4.42.1=py_0 296 | - traitlets=4.3.3=py37_0 297 | - ujson=1.35=py37h14c3975_0 298 | - unicodecsv=0.14.1=py37_0 299 | - unixodbc=2.3.7=h14c3975_0 300 | - urllib3=1.25.8=py37_0 301 | - watchdog=0.10.2=py37_0 302 | - wcwidth=0.1.8=py_0 303 | - webencodings=0.5.1=py37_1 304 | - werkzeug=1.0.0=py_0 305 | - wheel=0.34.2=py37_0 306 | - widgetsnbextension=3.5.1=py37_0 307 | - wrapt=1.11.2=py37h7b6447c_0 308 | - 
wurlitzer=2.0.0=py37_0 309 | - xlrd=1.2.0=py37_0 310 | - xlsxwriter=1.2.7=py_0 311 | - xlwt=1.3.0=py37_0 312 | - xmltodict=0.12.0=py_0 313 | - xz=5.2.4=h14c3975_4 314 | - yaml=0.1.7=had09818_2 315 | - yapf=0.28.0=py_0 316 | - zeromq=4.3.1=he6710b0_3 317 | - zict=1.0.0=py_0 318 | - zipp=2.2.0=py_0 319 | - zlib=1.2.11=h7b6447c_3 320 | - zstd=1.3.7=h0b5b093_0 321 | - pip: 322 | - autopep8==1.5.2 323 | - blessings==1.7 324 | - gpustat==0.6.0 325 | - nvidia-ml-py3==7.352.0 326 | - opencv-contrib-python-headless==4.2.0.34 327 | - opencv-python-headless==4.2.0.34 328 | 329 | -------------------------------------------------------------------------------- /logs/readme_logs.txt: -------------------------------------------------------------------------------- 1 | logs will be saved in this folder -------------------------------------------------------------------------------- /models/AlexNet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | # from .cifar import get 5 | 6 | def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1): 7 | return int(np.floor((Lin+2*padding-dilation*(kernel_size-1)-1)/float(stride)+1)) 8 | 9 | class Net(torch.nn.Module): 10 | def __init__(self,inputsize,taskcla=None): 11 | super(Net,self).__init__() 12 | 13 | ncha,size,_=inputsize 14 | self.taskcla=taskcla 15 | 16 | self.conv1=torch.nn.Conv2d(ncha,64,kernel_size=size//8) 17 | s=compute_conv_output_size(size,size//8) 18 | s=s//2 19 | self.conv2=torch.nn.Conv2d(64,128,kernel_size=size//10) 20 | s=compute_conv_output_size(s,size//10) 21 | s=s//2 22 | self.conv3=torch.nn.Conv2d(128,256,kernel_size=2) 23 | s=compute_conv_output_size(s,2) 24 | s=s//2 25 | self.maxpool=torch.nn.MaxPool2d(2) 26 | self.relu=torch.nn.ReLU() 27 | 28 | self.drop1=torch.nn.Dropout(0.2) 29 | self.drop2=torch.nn.Dropout(0.5) 30 | # self.fc1=torch.nn.Linear(256*s*s,2048) 31 | self.fc1=torch.nn.Linear(2304,2048) 32 | 
self.fc2=torch.nn.Linear(2048,2048) 33 | self.last=torch.nn.ModuleList() 34 | 35 | self.last = torch.nn.Linear(2048,10) 36 | # data,taskcla,size = get() 37 | # for t,n in taskcla: 38 | # print('t n ',t,n) 39 | # self.last.append(torch.nn.Linear(2048,n)) 40 | 41 | return 42 | 43 | # def forward(self,x): 44 | # h=self.maxpool(self.drop1(self.relu(self.conv1(x)))) 45 | # h=self.maxpool(self.drop1(self.relu(self.conv2(h)))) 46 | # h=self.maxpool(self.drop2(self.relu(self.conv3(h)))) 47 | # h=h.view(x.size(0),-1) 48 | # h=self.drop2(self.relu(self.fc1(h))) 49 | # h=self.drop2(self.relu(self.fc2(h))) 50 | # y=self.last(h) 51 | # # for i in range(2): 52 | # # y.append(self.last[i](h)) 53 | # return y 54 | 55 | def forward(self,x): 56 | h=self.maxpool(self.drop1(self.relu(self.conv1(x)))) 57 | h=self.maxpool(self.drop1(self.relu(self.conv2(h)))) 58 | h=self.maxpool(self.drop2(self.relu(self.conv3(h)))) 59 | h=h.view(x.size(0),-1) 60 | h=self.drop2(self.relu(self.fc1(h))) 61 | h=self.drop2(self.relu(self.fc2(h))) 62 | y=self.last(h) 63 | # for i in range(2): 64 | # y.append(self.last[i](h)) 65 | return y 66 | -------------------------------------------------------------------------------- /models/EMNIST_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.normalization import GroupNorm 5 | 6 | class Net(nn.Module): 7 | def __init__(self): 8 | super(Net, self).__init__() 9 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 10 | self.bn1 = torch.nn.GroupNorm(32,32) 11 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 12 | self.bn2 = torch.nn.GroupNorm(32,64) 13 | self.conv2_drop = nn.Dropout2d(p=0.25) 14 | self.fc1 = nn.Linear(9216, 128) 15 | self.fc2 = nn.Linear(128, 47) 16 | 17 | def forward(self, x): 18 | x = F.relu(self.conv1(x)) 19 | x = F.relu(self.conv2_drop(F.max_pool2d(self.conv2(x), 2))) 20 | 21 | x = x.view(-1, 9216) 22 | x = 
F.relu(self.fc1(x)) 23 | 24 | x = self.fc2(x) 25 | return x 26 | -------------------------------------------------------------------------------- /models/EMNIST_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | from torch.autograd import Variable 7 | from EMNIST_model import * 8 | # Training settings 9 | batch_size = 64 10 | 11 | 12 | train_dataset = datasets.EMNIST(root='./data/', 13 | train=True, 14 | transform=transforms.ToTensor(),split='byclass', 15 | download=True) 16 | 17 | test_dataset = datasets.EMNIST(root='./data/', 18 | train=False,split='byclass', 19 | transform=transforms.ToTensor()) 20 | 21 | # Data Loader (Input Pipeline) 22 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 23 | batch_size=batch_size, 24 | shuffle=True) 25 | 26 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 27 | batch_size=batch_size, 28 | shuffle=False) 29 | 30 | model = Net() 31 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) 32 | 33 | def train(epoch): 34 | for batch_idx, (data, target) in enumerate(train_loader): 35 | 36 | 37 | output = model(data) 38 | #output:64*10 39 | 40 | loss = F.nll_loss(output, target) 41 | 42 | if batch_idx % 200 == 0: 43 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 44 | epoch, batch_idx * len(data), len(train_loader.dataset), 45 | 100. 
* batch_idx / len(train_loader), loss.data[0])) 46 | 47 | optimizer.zero_grad() 48 | loss.backward() 49 | optimizer.step() 50 | -------------------------------------------------------------------------------- /models/MLP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MNIST_MLP(nn.Module): 6 | """ 7 | global batch_size = 100 8 | """ 9 | def __init__(self, num_classes): 10 | super(MNIST_MLP, self).__init__() 11 | self.layers = nn.ModuleList() 12 | self.layers.append(nn.Linear(28*28, 500)) 13 | self.layers.append(nn.Linear(500, 500)) 14 | self.layers.append(nn.Linear(500, num_classes)) 15 | # self.fc1 = nn.Linear(28*28, 500) 16 | # self.fc2 = nn.Linear(500, 500) 17 | # self.fc3 = nn.Linear(500, 10) 18 | 19 | def forward(self, x): # x: (batch, ) 20 | # x = x.view(-1, 28*28) 21 | # x = F.relu(self.fc1(x)) 22 | # x = F.relu(self.fc2(x)) 23 | # x = self.fc3(x) 24 | # return x 25 | x = x.view(-1, 28 * 28) 26 | x = F.relu(self.layers[0](x)) 27 | x = F.relu(self.layers[1](x)) 28 | x = self.layers[2](x) 29 | return x 30 | 31 | def get_weights(self): 32 | weights = [] 33 | for layer in self.layers: 34 | weights.append(layer.weight) 35 | return weights 36 | 37 | def get_gradients(self): 38 | gradients = [] 39 | for layer in self.layers: 40 | gradients.append(layer.weight.grad) 41 | 42 | return gradients 43 | 44 | def assign_gradients(self, gradients): 45 | for idx, layer in enumerate(self.layers): 46 | layer.weight.grad.data = gradients[idx] 47 | 48 | def update_weights(self, gradients, lr): 49 | for idx, layer in enumerate(self.layers): 50 | layer.weight.data -= lr * gradients[idx].data 51 | 52 | def initialize_new_grads(self): 53 | init_grads = [] 54 | for layer in self.layers: 55 | init_grads.append(torch.zeros_like(layer.weight)) 56 | return init_grads 57 | 58 | -------------------------------------------------------------------------------- 
/models/Semi_net.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | # from .cifar import get 5 | 6 | def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1): 7 | return int(np.floor((Lin+2*padding-dilation*(kernel_size-1)-1)/float(stride)+1)) 8 | 9 | class SemiNet(torch.nn.Module): 10 | def __init__(self,in_channels=3,taskcla=None): 11 | super(SemiNet,self).__init__() 12 | 13 | self.conv1 = torch.nn.Conv2d(in_channels,32,kernel_size=3,stride=1, padding=1) 14 | self.GN1 = torch.nn.GroupNorm(32,32) 15 | self.BN1 = torch.nn.BatchNorm2d(32) 16 | 17 | self.conv2=torch.nn.Conv2d(32,64,kernel_size=3,stride=1, padding=1) 18 | 19 | self.conv3=torch.nn.Conv2d(64,128,kernel_size=3,stride=1, padding=1) 20 | self.GN2 = torch.nn.GroupNorm(32,128) 21 | self.BN2 = torch.nn.BatchNorm2d(128) 22 | 23 | self.conv4=torch.nn.Conv2d(128,128,kernel_size=3,stride=1, padding=1) 24 | 25 | self.conv5=torch.nn.Conv2d(128,256,kernel_size=3,stride=1, padding=1) 26 | self.GN3 = torch.nn.GroupNorm(32,256) 27 | self.BN3 = torch.nn.BatchNorm2d(256) 28 | 29 | self.conv6=torch.nn.Conv2d(256,256,kernel_size=3,stride=1, padding=1) 30 | 31 | self.maxpool=torch.nn.MaxPool2d(2) 32 | self.relu=torch.nn.ReLU() 33 | 34 | self.drop1=torch.nn.Dropout(0.05) 35 | self.drop2=torch.nn.Dropout(0.1) 36 | self.drop3=torch.nn.Dropout(0.1) 37 | 38 | self.fc1=torch.nn.Linear(4096,1024) 39 | self.fc2=torch.nn.Linear(1024,512) 40 | self.fc3 = torch.nn.Linear(512,10) 41 | 42 | return 43 | 44 | 45 | def forward(self,x): 46 | 47 | h = self.relu(self.BN1((self.conv1(x)))) 48 | 49 | h = self.maxpool(self.relu(self.conv2(h))) 50 | 51 | h = self.relu(self.BN2((self.conv3(h)))) 52 | 53 | h= self.drop1(self.maxpool(self.relu(self.conv4(h)))) 54 | 55 | h = self.relu(self.BN3((self.conv5(h)))) 56 | 57 | h = self.maxpool(self.relu(self.conv6(h))) 58 | 59 | h=h.view(x.size(0),-1) 60 | 61 | 
h=self.drop2(self.relu(self.fc1(h))) 62 | h=self.drop3(self.relu(self.fc2(h))) 63 | 64 | y=self.fc3(h) 65 | 66 | return y 67 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vggnet import * 2 | from .resnet import * 3 | from .wrn import * 4 | from .MLP import * 5 | from .resnet_gn import * 6 | from .resnet_ln import * 7 | from .EMNIST_model import * -------------------------------------------------------------------------------- /models/__pycache__/EMNIST_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/EMNIST_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/EMNIST_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/EMNIST_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/EMNIST_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/EMNIST_model.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/MLP.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/MLP.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/MLP.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/MLP.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/MLP.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/MLP.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/Semi_net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/Semi_net.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/Semi_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/Semi_net.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/alexnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/alexnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/alexnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/alexnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/cifar.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/cifar.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/cifar.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/cifar.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet9.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet9.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet9.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet9.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_gn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet_gn.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_gn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet_gn.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_gn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet_gn.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_ln.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet_ln.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_ln.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet_ln.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_ln.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/resnet_ln.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/vgg.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/vggnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/vggnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vggnet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/vggnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/vggnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/vggnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/wrn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/wrn.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/wrn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/wrn.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/wrn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhcknzzm/SSFL-Benchmarking-Semi-supervised-Federated-Learning/9a18e895da73a3d3d14c239c6fa10de0d1d2fef5/models/__pycache__/wrn.cpython-38.pyc -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def accuracy(outputs, labels): 6 | _, preds = torch.max(outputs, dim=1) 7 
| return torch.tensor(torch.sum(preds == labels).item() / len(preds)) 8 | 9 | 10 | class Base(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def train_step(self, batch, model, device): 15 | images, labels = batch 16 | images = images.to(device) 17 | output = model(images) 18 | loss_fn = nn.CrossEntropyLoss() 19 | loss = loss_fn(output, labels) 20 | return loss 21 | 22 | def validation_step(self, batch, model, device): 23 | images, labels = batch 24 | images = images.to(device) 25 | labels = labels.to(device) 26 | output = model(images) 27 | acc = accuracy(output, labels) 28 | return acc 29 | -------------------------------------------------------------------------------- /models/cifar.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import numpy as np 3 | import torch 4 | # import utils 5 | from torchvision import datasets,transforms 6 | from sklearn.utils import shuffle 7 | 8 | def get(seed=0, pc_valid=0.10): 9 | data={} 10 | taskcla=[] 11 | size=[3,32,32] 12 | 13 | if not os.path.isdir('../data/binary_cifar/'): 14 | os.makedirs('../data/binary_cifar') 15 | 16 | mean=[x/255 for x in [125.3,123.0,113.9]] 17 | std=[x/255 for x in [63.0,62.1,66.7]] 18 | 19 | # CIFAR10 20 | dat={} 21 | dat['train']=datasets.CIFAR10('../data/',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 22 | dat['test']=datasets.CIFAR10('../data/',train=False,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 23 | for n in range(5): 24 | data[n]={} 25 | data[n]['name']='cifar10' 26 | data[n]['ncla']=2 27 | data[n]['train']={'x': [],'y': []} 28 | data[n]['test']={'x': [],'y': []} 29 | for s in ['train','test']: 30 | loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False) 31 | for image,target in loader: 32 | n=target.numpy()[0] 33 | nn=n//2 34 | data[nn][s]['x'].append(image) 35 | 
data[nn][s]['y'].append(n%2) 36 | 37 | # CIFAR100 38 | dat={} 39 | dat['train']=datasets.CIFAR100('../data/',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 40 | dat['test']=datasets.CIFAR100('../data/',train=False,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])) 41 | for n in range(5,10): 42 | data[n]={} 43 | data[n]['name']='cifar100' 44 | data[n]['ncla']=20 45 | data[n]['train']={'x': [],'y': []} 46 | data[n]['test']={'x': [],'y': []} 47 | for s in ['train','test']: 48 | loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False) 49 | for image,target in loader: 50 | n=target.numpy()[0] 51 | nn=(n//20)+5 52 | data[nn][s]['x'].append(image) 53 | data[nn][s]['y'].append(n%20) 54 | 55 | # "Unify" and save 56 | for t in data.keys(): 57 | for s in ['train','test']: 58 | data[t][s]['x']=torch.stack(data[t][s]['x']).view(-1,size[0],size[1],size[2]) 59 | data[t][s]['y']=torch.LongTensor(np.array(data[t][s]['y'],dtype=int)).view(-1) 60 | torch.save(data[t][s]['x'], os.path.join(os.path.expanduser('../data/binary_cifar'),'data'+str(t)+s+'x.bin')) 61 | torch.save(data[t][s]['y'], os.path.join(os.path.expanduser('../data/binary_cifar'),'data'+str(t)+s+'y.bin')) 62 | 63 | # Load binary files 64 | data={} 65 | ids=list(shuffle(np.arange(10),random_state=seed)) 66 | print('Task order =',ids) 67 | for i in range(10): 68 | data[i] = dict.fromkeys(['name','ncla','train','test']) 69 | for s in ['train','test']: 70 | data[i][s]={'x':[],'y':[]} 71 | data[i][s]['x']=torch.load(os.path.join(os.path.expanduser('../data/binary_cifar'),'data'+str(ids[i])+s+'x.bin')) 72 | data[i][s]['y']=torch.load(os.path.join(os.path.expanduser('../data/binary_cifar'),'data'+str(ids[i])+s+'y.bin')) 73 | data[i]['ncla']=len(np.unique(data[i]['train']['y'].numpy())) 74 | if data[i]['ncla']==2: 75 | data[i]['name']='cifar10-'+str(ids[i]) 76 | else: 77 | 
data[i]['name']='cifar100-'+str(ids[i]-5) 78 | 79 | # Validation 80 | for t in data.keys(): 81 | r=np.arange(data[t]['train']['x'].size(0)) 82 | r=np.array(shuffle(r,random_state=seed),dtype=int) 83 | nvalid=int(pc_valid*len(r)) 84 | ivalid=torch.LongTensor(r[:nvalid]) 85 | itrain=torch.LongTensor(r[nvalid:]) 86 | data[t]['valid']={} 87 | data[t]['valid']['x']=data[t]['train']['x'][ivalid].clone() 88 | data[t]['valid']['y']=data[t]['train']['y'][ivalid].clone() 89 | data[t]['train']['x']=data[t]['train']['x'][itrain].clone() 90 | data[t]['train']['y']=data[t]['train']['y'][itrain].clone() 91 | 92 | # Others 93 | n=0 94 | for t in data.keys(): 95 | taskcla.append((t,data[t]['ncla'])) 96 | n+=data[t]['ncla'] 97 | data['ncla']=n 98 | 99 | return data,taskcla,size 100 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(self.expansion*planes) 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out += 
self.shortcut(x) 61 | out = F.relu(out) 62 | return out 63 | 64 | 65 | class ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(ResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.linear = nn.Linear(512*block.expansion, num_classes) 77 | 78 | def _make_layer(self, block, planes, num_blocks, stride): 79 | strides = [stride] + [1]*(num_blocks-1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_planes, planes, stride)) 83 | self.in_planes = planes * block.expansion 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | out = F.relu(self.bn1(self.conv1(x))) 88 | out = self.layer1(out) 89 | out = self.layer2(out) 90 | out = self.layer3(out) 91 | out = self.layer4(out) 92 | out = F.avg_pool2d(out, 4) 93 | out = out.view(out.size(0), -1) 94 | out = self.linear(out) 95 | return out 96 | 97 | 98 | def ResNet18(): 99 | return ResNet(BasicBlock, [2,2,2,2]) 100 | 101 | def ResNet34(): 102 | return ResNet(BasicBlock, [3,4,6,3]) 103 | 104 | def ResNet50(): 105 | return ResNet(Bottleneck, [3,4,6,3]) 106 | 107 | def ResNet101(): 108 | return ResNet(Bottleneck, [3,4,23,3]) 109 | 110 | def ResNet152(): 111 | return ResNet(Bottleneck, [3,8,36,3]) 112 | 113 | 114 | def test(): 115 | net = ResNet18() 116 | y = net(torch.randn(1,3,32,32)) 117 | print(y.size()) 118 | 119 | # test() -------------------------------------------------------------------------------- /models/resnet9.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 
# from base import Base 3 | 4 | def accuracy(outputs, labels): 5 | _, preds = torch.max(outputs, dim=1) 6 | return torch.tensor(torch.sum(preds == labels).item() / len(preds)) 7 | 8 | 9 | class Base(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def train_step(self, batch, model, device): 14 | images, labels = batch 15 | images = images.to(device) 16 | output = model(images) 17 | loss_fn = nn.CrossEntropyLoss() 18 | loss = loss_fn(output, labels) 19 | return loss 20 | 21 | def validation_step(self, batch, model, device): 22 | images, labels = batch 23 | images = images.to(device) 24 | labels = labels.to(device) 25 | output = model(images) 26 | acc = accuracy(output, labels) 27 | return acc 28 | 29 | def conv_bn_relu_pool(in_channels, out_channels, pool=False): 30 | layers = [ 31 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1), 32 | # nn.BatchNorm2d(out_channels), 33 | nn.GroupNorm(32,out_channels), 34 | nn.ReLU(inplace=True) 35 | ] 36 | if pool: 37 | layers.append(nn.MaxPool2d(2)) 38 | return nn.Sequential(*layers) 39 | 40 | class ResNet9(Base): 41 | def __init__(self, in_channels, num_classes): 42 | super().__init__() 43 | self.prep = conv_bn_relu_pool(in_channels, 64) 44 | self.layer1_head = conv_bn_relu_pool(64, 128, pool=True) 45 | self.layer1_residual = nn.Sequential(conv_bn_relu_pool(128, 128), conv_bn_relu_pool(128, 128)) 46 | self.layer2 = conv_bn_relu_pool(128, 256, pool=True) 47 | self.layer3_head = conv_bn_relu_pool(256, 512, pool=True) 48 | self.layer3_residual = nn.Sequential(conv_bn_relu_pool(512, 512), conv_bn_relu_pool(512, 512)) 49 | self.MaxPool2d = nn.Sequential( 50 | nn.MaxPool2d(4)) 51 | self.linear = nn.Linear(512, num_classes) 52 | # self.classifier = nn.Sequential( 53 | # nn.MaxPool2d(4), 54 | # nn.Flatten(), 55 | # nn.Linear(512, num_classes)) 56 | 57 | 58 | def forward(self, x): 59 | x = self.prep(x) 60 | x = self.layer1_head(x) 61 | x = self.layer1_residual(x) + x 62 | x = 
self.layer2(x) 63 | x = self.layer3_head(x) 64 | x = self.layer3_residual(x) + x 65 | x = self.MaxPool2d(x) 66 | x = x.view(x.size(0), -1) 67 | x = self.linear(x) 68 | return x 69 | -------------------------------------------------------------------------------- /models/resnet_gn.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.modules.normalization import GroupNorm 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | 20 | self.bn1 = torch.nn.GroupNorm(32,planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | self.bn2 = torch.nn.GroupNorm(32,planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_planes != self.expansion*planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 29 | torch.nn.GroupNorm(32,self.expansion*planes) 30 | 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out)) 36 | out += self.shortcut(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | 48 | self.bn1 = torch.nn.GroupNorm(32,planes) 49 | self.conv2 = nn.Conv2d(planes, planes, 
kernel_size=3, stride=stride, padding=1, bias=False) 50 | 51 | self.bn2 = torch.nn.GroupNorm(32,planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 53 | 54 | self.bn3 = torch.nn.GroupNorm(32,self.expansion*planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion*planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 60 | torch.nn.GroupNorm(32,self.expansion*planes) 61 | ) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = F.relu(self.bn2(self.conv2(out))) 66 | out = self.bn3(self.conv3(out)) 67 | out += self.shortcut(x) 68 | out = F.relu(out) 69 | return out 70 | 71 | 72 | class ResNet_gn(nn.Module): 73 | def __init__(self, block, num_blocks, num_classes=10): 74 | super(ResNet_gn, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 78 | 79 | self.bn1 = torch.nn.GroupNorm(32,64) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.linear = nn.Linear(512*block.expansion, num_classes) 85 | 86 | def _make_layer(self, block, planes, num_blocks, stride): 87 | strides = [stride] + [1]*(num_blocks-1) 88 | layers = [] 89 | for stride in strides: 90 | layers.append(block(self.in_planes, planes, stride)) 91 | self.in_planes = planes * block.expansion 92 | return nn.Sequential(*layers) 93 | 94 | def forward(self, x): 95 | out = F.relu(self.bn1(self.conv1(x))) 96 | out = self.layer1(out) 97 | out = self.layer2(out) 98 | out = self.layer3(out) 99 | out = self.layer4(out) 100 | out = F.avg_pool2d(out, 4) 101 | out = out.view(out.size(0), -1) 102 | out = 
self.linear(out) 103 | return out 104 | 105 | 106 | def ResNet18_gn(): 107 | return ResNet_gn(BasicBlock, [2,2,2,2]) 108 | 109 | def ResNet34_gn(): 110 | return ResNet_gn(BasicBlock, [3,4,6,3]) 111 | 112 | def ResNet50_gn(): 113 | return ResNet_gn(Bottleneck, [3,4,6,3]) 114 | 115 | def ResNet101_gn(): 116 | return ResNet_gn(Bottleneck, [3,4,23,3]) 117 | 118 | def ResNet152_gn(): 119 | return ResNet_gn(Bottleneck, [3,8,36,3]) 120 | 121 | 122 | def test(): 123 | net = ResNet18_gn() 124 | y = net(torch.randn(1,3,32,32)) 125 | print(y.size()) 126 | 127 | # test() 128 | -------------------------------------------------------------------------------- /models/resnet_ln.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | 13 | class FilterResponseNormNd(nn.Module): 14 | def __init__(self, ndim, num_features, eps=1e-6, 15 | learnable_eps=False): 16 | """ 17 | Input Variables: 18 | ---------------- 19 | ndim: An integer indicating the number of dimensions of the expected input tensor. 20 | num_features: An integer indicating the number of input feature dimensions. 21 | eps: A scalar constant or learnable variable. 22 | learnable_eps: A bool value indicating whether the eps is learnable. 23 | """ 24 | assert ndim in [3, 4, 5], \ 25 | 'FilterResponseNorm only supports 3d, 4d or 5d inputs.' 
26 | super(FilterResponseNormNd, self).__init__() 27 | shape = (1, num_features) + (1, ) * (ndim - 2) 28 | self.eps = nn.Parameter(torch.ones(*shape) * eps) 29 | if not learnable_eps: 30 | self.eps.requires_grad_(False) 31 | self.gamma = nn.Parameter(torch.Tensor(*shape)) 32 | self.beta = nn.Parameter(torch.Tensor(*shape)) 33 | self.tau = nn.Parameter(torch.Tensor(*shape)) 34 | self.reset_parameters() 35 | 36 | 37 | 38 | def forward(self, x): 39 | avg_dims = tuple(range(2, x.dim())) 40 | nu2 = torch.pow(x, 2).mean(dim=avg_dims, keepdim=True) 41 | x = x * torch.rsqrt(nu2 + torch.abs(self.eps)) 42 | return torch.max(self.gamma * x + self.beta, self.tau) 43 | 44 | def reset_parameters(self): 45 | nn.init.ones_(self.gamma) 46 | nn.init.zeros_(self.beta) 47 | nn.init.zeros_(self.tau) 48 | 49 | class FilterResponseNorm1d(FilterResponseNormNd): 50 | def __init__(self, num_features, eps=1e-6, learnable_eps=False): 51 | super(FilterResponseNorm1d, self).__init__( 52 | 3, num_features, eps=eps, learnable_eps=learnable_eps) 53 | 54 | class FilterResponseNorm2d(FilterResponseNormNd): 55 | def __init__(self, num_features, eps=1e-6, learnable_eps=False): 56 | super(FilterResponseNorm2d, self).__init__( 57 | 4, num_features, eps=eps, learnable_eps=learnable_eps) 58 | 59 | class FilterResponseNorm3d(FilterResponseNormNd): 60 | def __init__(self, num_features, eps=1e-6, learnable_eps=False): 61 | super(FilterResponseNorm3d, self).__init__( 62 | 5, num_features, eps=eps, learnable_eps=learnable_eps) 63 | 64 | 65 | class BasicBlock(nn.Module): 66 | expansion = 1 67 | 68 | def __init__(self, in_planes, planes, stride=1): 69 | super(BasicBlock, self).__init__() 70 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 71 | self.bn1 = FilterResponseNorm2d(planes) 72 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn2 = FilterResponseNorm2d(planes) 74 | 75 | self.shortcut = nn.Sequential() 76 | 
if stride != 1 or in_planes != self.expansion*planes: 77 | self.shortcut = nn.Sequential( 78 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 79 | FilterResponseNorm2d(self.expansion*planes) 80 | ) 81 | 82 | def forward(self, x): 83 | out = F.relu(self.bn1(self.conv1(x))) 84 | out = self.bn2(self.conv2(out)) 85 | out += self.shortcut(x) 86 | out = F.relu(out) 87 | return out 88 | 89 | 90 | class Bottleneck(nn.Module): 91 | expansion = 4 92 | 93 | def __init__(self, in_planes, planes, stride=1): 94 | super(Bottleneck, self).__init__() 95 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 96 | self.bn1 = FilterResponseNorm2d(planes) 97 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 98 | self.bn2 = FilterResponseNorm2d(planes) 99 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 100 | self.bn3 = FilterResponseNorm2d(self.expansion*planes) 101 | 102 | self.shortcut = nn.Sequential() 103 | if stride != 1 or in_planes != self.expansion*planes: 104 | self.shortcut = nn.Sequential( 105 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 106 | FilterResponseNorm2d(self.expansion*planes) 107 | ) 108 | 109 | def forward(self, x): 110 | out = F.relu(self.bn1(self.conv1(x))) 111 | out = F.relu(self.bn2(self.conv2(out))) 112 | out = self.bn3(self.conv3(out)) 113 | out += self.shortcut(x) 114 | out = F.relu(out) 115 | return out 116 | 117 | 118 | class ResNet_LN(nn.Module): 119 | def __init__(self, block, num_blocks, num_classes=10): 120 | super(ResNet_LN, self).__init__() 121 | self.in_planes = 64 122 | 123 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 124 | self.bn1 = FilterResponseNorm2d(64) 125 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 126 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 127 | self.layer3 = 
self._make_layer(block, 256, num_blocks[2], stride=2) 128 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 129 | self.linear = nn.Linear(512*block.expansion, num_classes) 130 | 131 | def _make_layer(self, block, planes, num_blocks, stride): 132 | strides = [stride] + [1]*(num_blocks-1) 133 | layers = [] 134 | for stride in strides: 135 | layers.append(block(self.in_planes, planes, stride)) 136 | self.in_planes = planes * block.expansion 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | out = F.relu(self.bn1(self.conv1(x))) 141 | out = self.layer1(out) 142 | out = self.layer2(out) 143 | out = self.layer3(out) 144 | out = self.layer4(out) 145 | out = F.avg_pool2d(out, 4) 146 | out = out.view(out.size(0), -1) 147 | out = self.linear(out) 148 | return out 149 | 150 | 151 | def ResNet18_LN(): 152 | return ResNet_LN(BasicBlock, [2,2,2,2]) 153 | 154 | def ResNet34_LN(): 155 | return ResNet_LN(BasicBlock, [3,4,6,3]) 156 | 157 | def ResNet50_LN(): 158 | return ResNet_LN(Bottleneck, [3,4,6,3]) 159 | 160 | def ResNet101_LN(): 161 | return ResNet_LN(Bottleneck, [3,4,23,3]) 162 | 163 | def ResNet152_LN(): 164 | return ResNet_LN(Bottleneck, [3,8,36,3]) 165 | 166 | 167 | def test(): 168 | net = ResNet18_LN() 169 | y = net(torch.randn(1,3,32,32)) 170 | print(y.size()) 171 | 172 | # test() 173 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | 5 | n_channel=32 6 | 7 | cfg = { 8 | 'VGG9': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M'], 9 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | 14 
| 15 | class VGG(nn.Module): 16 | def __init__(self, vgg_name, input_size=96, num_class=10): 17 | super(VGG, self).__init__() 18 | self.input_size = input_size 19 | self.features = self._make_layers(cfg[vgg_name]) 20 | self.n_maps = cfg[vgg_name][-2] 21 | self.fc = self._make_fc_layers() 22 | self.classifier1 = nn.Linear(self.n_maps, 128) 23 | self.classifier2 = nn.Linear(128, num_class) 24 | 25 | def forward(self, x): 26 | out = self.features(x) 27 | out = out.view(out.size(0), -1) 28 | # print('out.size',out.size(),self.n_maps) 29 | out = self.fc(out) 30 | out = self.classifier1(out) 31 | out = self.classifier2(out) 32 | return out 33 | 34 | def _make_fc_layers(self): 35 | layers = [] 36 | # print(self.n_maps*self.input_size*self.input_size, self.n_maps) 37 | # print('self.input_size',self.input_size) 38 | # layers += [nn.Linear(self.n_maps*self.input_size*self.input_size, self.n_maps), 39 | # nn.BatchNorm1d(self.n_maps), 40 | # nn.ReLU(inplace=True)] 41 | layers += [nn.Linear(4608, self.n_maps), 42 | nn.BatchNorm1d(self.n_maps), 43 | nn.ReLU(inplace=True)] 44 | return nn.Sequential(*layers) 45 | 46 | def _make_layers(self, cfg): 47 | layers = [] 48 | in_channels = 3 49 | for x in cfg: 50 | if x == 'M': 51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 52 | self.input_size = self.input_size // 2 53 | else: 54 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 55 | nn.BatchNorm2d(x), 56 | nn.ReLU(inplace=True)] 57 | in_channels = x 58 | return nn.Sequential(*layers) 59 | 60 | def VGG9(input_size, num_class): 61 | return VGG('VGG9', input_size, num_class) 62 | 63 | def VGG16(input_size, num_class): 64 | return VGG('VGG16', input_size, num_class) 65 | 66 | def VGG19(input_size, num_class): 67 | return VGG('VGG19', input_size, num_class) 68 | 69 | def VGG11(input_size, num_class): 70 | return VGG('VGG11', input_size, num_class) 71 | 72 | def test(): 73 | #net = VGG('VGG11', input_size=96,num_class=10) 74 | net = VGG('VGG11', 
input_size=32,num_class=10) 75 | print(net) 76 | x = torch.randn(128, 3, 96, 96) 77 | # x = torch.randn(2, 3, 32, 32) 78 | y = net(x) 79 | print(y.size()) 80 | # test() 81 | -------------------------------------------------------------------------------- /models/vggnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | def conv_init(m): 6 | classname = m.__class__.__name__ 7 | if classname.find('Conv') != -1: 8 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 9 | init.constant(m.bias, 0) 10 | 11 | def cfg(depth): 12 | depth_lst = [11, 13, 16, 19] 13 | assert (depth in depth_lst), "Error : VGGnet depth should be either 11, 13, 16, 19" 14 | cf_dict = { 15 | '11': [ 16 | 64, 'mp', 17 | 128, 'mp', 18 | 256, 256, 'mp', 19 | 512, 512, 'mp', 20 | 512, 512, 'mp'], 21 | '13': [ 22 | 64, 64, 'mp', 23 | 128, 128, 'mp', 24 | 256, 256, 'mp', 25 | 512, 512, 'mp', 26 | 512, 512, 'mp' 27 | ], 28 | '16': [ 29 | 64, 64, 'mp', 30 | 128, 128, 'mp', 31 | 256, 256, 256, 'mp', 32 | 512, 512, 512, 'mp', 33 | 512, 512, 512, 'mp' 34 | ], 35 | '19': [ 36 | 64, 64, 'mp', 37 | 128, 128, 'mp', 38 | 256, 256, 256, 256, 'mp', 39 | 512, 512, 512, 512, 'mp', 40 | 512, 512, 512, 512, 'mp' 41 | ], 42 | } 43 | 44 | return cf_dict[str(depth)] 45 | 46 | def conv3x3(in_planes, out_planes, stride=1): 47 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 48 | 49 | class VGG(nn.Module): 50 | def __init__(self, depth, num_classes): 51 | super(VGG, self).__init__() 52 | self.features = self._make_layers(cfg(depth)) 53 | self.linear = nn.Linear(512, num_classes) 54 | 55 | def forward(self, x): 56 | out = self.features(x) 57 | out = out.view(out.size(0), -1) 58 | out = self.linear(out) 59 | 60 | return out 61 | 62 | def _make_layers(self, cfg): 63 | layers = [] 64 | in_planes = 3 65 | 66 | for x in cfg: 67 | if x == 'mp': 68 | layers += 
[nn.MaxPool2d(kernel_size=2, stride=2)] 69 | else: 70 | layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)] 71 | in_planes = x 72 | 73 | # After cfg convolution 74 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 75 | return nn.Sequential(*layers) 76 | 77 | if __name__ == "__main__": 78 | net = VGG(16, 10) 79 | y = net(Variable(torch.randn(1,3,32,32))) 80 | print(y.size()) 81 | -------------------------------------------------------------------------------- /models/wrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 17 | init.constant(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant(m.weight, 1) 20 | init.constant(m.bias, 0) 21 | 22 | class wide_basic(nn.Module): 23 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 24 | super(wide_basic, self).__init__() 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 27 | self.dropout = nn.Dropout(p=dropout_rate) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 39 | out = 
self.conv2(F.relu(self.bn2(out))) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | class Wide_ResNet(nn.Module): 45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 46 | super(Wide_ResNet, self).__init__() 47 | self.in_planes = 16 48 | 49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 50 | n = (depth-4)//6 51 | k = widen_factor 52 | 53 | #print('| Wide-Resnet %dx%d' %(depth, k)) 54 | nStages = [16, 16*k, 32*k, 64*k] 55 | 56 | self.conv1 = conv3x3(3,nStages[0]) 57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 61 | self.linear = nn.Linear(nStages[3], num_classes) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1]*(num_blocks-1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = F.avg_pool2d(out, 8) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | 83 | return out 84 | 85 | if __name__ == '__main__': 86 | net=Wide_ResNet(28, 10, 0.3, 10) 87 | y = net(Variable(torch.randn(1,3,32,32))) 88 | 89 | print(y.size()) 90 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Improving Semi-supervised Federated Learning by Reducing the Gradient Diversity of Models 2 | 3 | To run the main scripte "train_parallel.py", one needs to determine 
the number of available GPUs. 4 | If your machine has fewer than 8 GPUs (for example, 4 GPUs), you need to set GPU_list to 0123 in run_cifar10.sh, run_svhn.sh and run_emnist.sh. 5 | 6 | Please use Anaconda to install all the dependencies: 7 | 8 | ``` 9 | conda env create -f environment.yaml 10 | ``` 11 | If you have 8 GPUs on your machine, you can run one experiment using the following command: 12 | 13 | ``` 14 | python train_parallel.py [--GPU_list] [--datasetid] [--size] [--basicLabelRatio] [--labeled] [--num_comm_ue] [--H] [--cp] [--eval_grad] [--experiment_folder] [--experiment_name] [--tao] [--model] [--ue_loss] [--user_semi] 15 | [--epoch] [--batch_size] [--fast] [--Ns] 16 | optional arguments: 17 | --GPU_list: GPUs used for training, e.g., --GPU_list 0123456789 18 | --datasetid: the id of the dataset (default: 0); datasetid = 0/1/2 means the Cifar-10/SVHN/EMNIST dataset is used in the experiment. 19 | --size: size = K (users) + 1 (server); 20 | --cp: cp in {2, 4, 8, 16, 32, 64} is the frequency of communication; cp = 2 means the UEs and the server communicate every 2 iterations; 21 | --basicLabelRatio: basicLabelRatio in {0.0, 0.2, 0.4, ..., 1.0} is the degree of data dispersion for each UE: 22 | basicLabelRatio = 0.0 means each UE has the same number of samples in each class; basicLabelRatio = 1.0 means all samples owned 23 | by a UE belong to the same class; 24 | --model: model in {'res', 'res_gn', 'res9', 'EMNIST_model'}; model = 'res' means we use ResNet18 + BN; model = 'res_gn' means we use ResNet18 + GN; 'EMNIST_model' is used to train SSFL models on the EMNIST dataset; 25 | --num_comm_ue: num_comm_ue in {1, 2, ..., K}; the number of users that communicate per iteration; 26 | --H: H in {0, 1}; whether to use the grouping-based method; H = 1 means we use the grouping-based method; H = 0 means we use the FedAvg method; 27 | --Ns: num_data_server (e.g., 1000 or 4000), the number of labeled samples on the server; 28 | --labeled: labeled in {0, 1}; labeled=1 means supervised FL, labeled=0 means 
semi-supervised FL; 29 | --cp: cp in {2, 4, 8, 16, 32, 64} is the communication period between the users and the server; 30 | --eval_grad: eval_grad in {0, 1}; eval_grad=1 means that we load the model stored during training to calculate the gradient; 31 | --experiment_folder: storage directory of experiment checkpoints; 32 | --experiment_name: the name of the current experiment; 33 | --tao: hyperparameter used to calculate CRL; 34 | --model: neural network model used for training; 35 | --ue_loss: ue_loss=CRL means we use CRL as the loss for local training; ue_loss=SF means we use the self-training method for local training; 36 | --epoch: number of training epochs of SSFL; 37 | --batch_size: batch size used for training; 38 | --fast: hyperparameter used for the learning rate update; 39 | --Ns: the number of labeled samples on the server. 40 | ``` 41 | For example, you can run the following command to reproduce the results of SSFL on Cifar10 in the setting of K=C=10, Ns=1000, model=res_gn, with non-iidness 0.4 and the grouping-based method. 42 | ``` 43 | python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --basicLabelRatio 0.4 --experiment_name Cifar10_res_gn_H1_comUE10_R0.4_SSFL 44 | ``` 45 | 46 | The results will be saved in the folder results_v0, and the checkpoints will be saved in "/checkpoints/Cifar10_res_gn_H1_comUE10_R0.4_SSFL" 47 | 48 | You can also run the following script to reproduce all the results of our paper. 49 | ``` 50 | nohup bash Run_Exper.sh 51 | ``` 52 | 53 | When all the checkpoints are saved, you can run the following script to calculate the gradient diversity and reproduce the gradient-diversity results of our paper: 54 | ``` 55 | nohup bash Grad_Diff.sh 56 | ``` 57 | When all gradient diversities are calculated, you can run the following script to plot the results of gradient diversity. 
58 | ``` 59 | python Plot_grad_diversity.py 60 | ``` 61 | -------------------------------------------------------------------------------- /run_cifar10.sh: -------------------------------------------------------------------------------- 1 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res --experiment_name Cifar10_res_H0_comUE10_R0.4_SSFL 2 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res --experiment_name Cifar10_res_H0_comUE10_R0.4_SFL 3 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res --experiment_name Cifar10_res_H0_comUE10_R0.0_SFL --basicLabelRatio 0.0 4 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --experiment_name Cifar10_res_gn_H0_comUE10_R0.4_SSFL 5 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --experiment_name Cifar10_res_gn_H1_comUE10_R0.4_SSFL 6 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res --experiment_name Cifar10_res_H0_comUE10_R0.0_SSFL --basicLabelRatio 0.0 7 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL --basicLabelRatio 0.0 8 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --Ns 2000 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_2000_eval_grad_0 9 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --Ns 3000 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_3000_eval_grad_0 10 | nohup python train_parallel.py 
--GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --Ns 4000 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_4000_eval_grad_0 11 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --Ns 5000 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_5000_eval_grad_0 12 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --cp 2 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_2000_eval_grad_0_cp2 13 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --cp 4 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_2000_eval_grad_0_cp4 14 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --cp 8 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_2000_eval_grad_0_cp8 15 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --cp 32 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_2000_eval_grad_0_cp32 16 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --cp 64 --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL_Ns_2000_eval_grad_0_cp64 17 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --basicLabelRatio 0.0 --experiment_name Cifar10_res_gn_H1_comUE10_R0.0_SSFL 18 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --basicLabelRatio 0.2 --experiment_name Cifar10_res_gn_H1_comUE10_R0.2_SSFL 19 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn 
--basicLabelRatio 0.6 --experiment_name Cifar10_res_gn_H1_comUE10_R0.6_SSFL 20 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --basicLabelRatio 0.8 --experiment_name Cifar10_res_gn_H1_comUE10_R0.8_SSFL 21 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 0 --model res_gn --basicLabelRatio 1.0 --experiment_name Cifar10_res_gn_H1_comUE10_R1.0_SSFL 22 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 21 --epoch 300 --eval_grad 0 --model res_gn 23 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 20 --size 21 --epoch 300 --eval_grad 0 --model res_gn 24 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 31 --epoch 300 --eval_grad 0 --model res_gn 25 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 30 --size 31 --epoch 300 --eval_grad 0 --model res_gn 26 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 1 --model res --experiment_name Cifar10_res_H0_comUE10_R0.4_SSFL 27 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 1 --model res_gn --experiment_name Cifar10_res_gn_H0_comUE10_R0.4_SSFL 28 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 1 --model res_gn --experiment_name Cifar10_res_gn_H1_comUE10_R0.4_SSFL 29 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 1 --model res --experiment_name Cifar10_res_H0_comUE10_R0.0_SSFL --basicLabelRatio 0.0 30 | nohup python train_parallel.py --GPU_list 01234 --H 0 --num_comm_ue 10 --size 11 --epoch 300 --eval_grad 1 --model res_gn --experiment_name Cifar10_res_gn_H0_comUE10_R0.0_SSFL --basicLabelRatio 0.0 31 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 
10 --size 11 --epoch 300 --eval_grad 1 --model res_gn --basicLabelRatio 0.0 --experiment_name Cifar10_res_gn_H1_comUE10_R0.0_SSFL 32 | -------------------------------------------------------------------------------- /run_cifar10_res9.sh: -------------------------------------------------------------------------------- 1 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 5 --size 101 --epoch 200 --basicLabelRatio 0.0 --cp 5 --k_img 1000 --Ns 5000 --eval_grad 0 --model res9 --experiment_name Cifar10_res9_H1_comUE5_R0.0_SSFL 2 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 5 --size 100 --user_semi 1 --epoch 200 --basicLabelRatio 0.0 --cp 5 --k_img 1000 --Ns 5000 --eval_grad 0 --model res9 --experiment_name Cifar10_res9_H1_comUE5_R0.0_SSFL_user_semi 3 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 5 --size 101 --epoch 200 --basicLabelRatio 1.0 --cp 5 --k_img 1000 --Ns 5000 --eval_grad 0 --model res9 --experiment_name Cifar10_res9_H1_comUE5_R1.0_SSFL 4 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 5 --size 100 --user_semi 1 --epoch 200 --basicLabelRatio 1.0 --cp 5 --k_img 1000 --Ns 5000 --eval_grad 0 --model res9 --experiment_name Cifar10_res9_H1_comUE5_R1.0_SSFL_user_semi 5 | -------------------------------------------------------------------------------- /run_emnist.sh: -------------------------------------------------------------------------------- 1 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --Ns 1000 --datasetid 2 --fast 0 2 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --Ns 2000 --datasetid 2 --fast 0 3 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --Ns 3000 --datasetid 2 --fast 0 4 | nohup python train_parallel.py 
--GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --Ns 4000 --datasetid 2 --fast 0 5 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --Ns 5000 --datasetid 2 --fast 0 6 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --cp 2 --datasetid 2 --fast 0 --Ns 4700 7 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --cp 4 --datasetid 2 --fast 0 --Ns 4700 8 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --cp 8 --datasetid 2 --fast 0 --Ns 4700 9 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --cp 32 --datasetid 2 --fast 0 --Ns 4700 10 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --cp 64 --datasetid 2 --fast 0 --Ns 4700 11 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --basicLabelRatio 0.0 --datasetid 2 --fast 0 --Ns 4700 12 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --basicLabelRatio 0.2 --datasetid 2 --fast 0 --Ns 4700 13 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --basicLabelRatio 0.6 --datasetid 2 --fast 0 --Ns 4700 14 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --basicLabelRatio 0.8 --datasetid 2 --fast 0 --Ns 4700 15 | nohup python train_parallel.py --GPU_list 01234567 
--H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --basicLabelRatio 1.0 --datasetid 2 --fast 0 --Ns 4700 16 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE10_H1_R0.4_SSFL 17 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 30 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE20_H1_R0.4_SSFL 18 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 47 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE47_H1_R0.4_SSFL 19 | nohup python train_parallel.py --GPU_list 01234567 --H 0 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE10_H0_R0.4_SSFL 20 | nohup python train_parallel.py --GPU_list 01234567 --H 0 --num_comm_ue 30 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE30_H0_R0.4_SSFL 21 | nohup python train_parallel.py --GPU_list 01234567 --H 0 --num_comm_ue 47 --size 48 --epoch 100 --eval_grad 0 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE47_H0_R0.4_SSFL 22 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 1 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE10_H1_R0.4_SSFL 23 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 30 --size 48 --epoch 100 --eval_grad 1 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE20_H1_R0.4_SSFL 24 | nohup python train_parallel.py --GPU_list 01234567 --H 1 --num_comm_ue 47 --size 48 
--epoch 100 --eval_grad 1 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE47_H1_R0.4_SSFL 25 | nohup python train_parallel.py --GPU_list 01234567 --H 0 --num_comm_ue 10 --size 48 --epoch 100 --eval_grad 1 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE10_H0_R0.4_SSFL 26 | nohup python train_parallel.py --GPU_list 01234567 --H 0 --num_comm_ue 30 --size 48 --epoch 100 --eval_grad 1 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE30_H0_R0.4_SSFL 27 | nohup python train_parallel.py --GPU_list 01234567 --H 0 --num_comm_ue 47 --size 48 --epoch 100 --eval_grad 1 --model EMNIST_model --datasetid 2 --Ns 4700 --fast 0 --experiment_name EMNIST_size47_comUE47_H0_R0.4_SSFL 28 | -------------------------------------------------------------------------------- /run_svhn.sh: -------------------------------------------------------------------------------- 1 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --datasetid 1 2 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --Ns 2000 --datasetid 1 3 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --Ns 3000 --datasetid 1 4 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --Ns 4000 --datasetid 1 5 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --Ns 5000 --datasetid 1 6 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --cp 2 --datasetid 1 7 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --cp 4 
--datasetid 1 8 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --cp 8 --datasetid 1 9 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --cp 32 --datasetid 1 10 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --cp 64 --datasetid 1 11 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --basicLabelRatio 0.0 --datasetid 1 12 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --basicLabelRatio 0.2 --datasetid 1 13 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --basicLabelRatio 0.6 --datasetid 1 14 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --basicLabelRatio 0.8 --datasetid 1 15 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 11 --epoch 40 --eval_grad 0 --model res_gn --basicLabelRatio 1.0 --datasetid 1 16 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 21 --epoch 40 --eval_grad 0 --model res_gn --datasetid 1 17 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 20 --size 21 --epoch 40 --eval_grad 0 --model res_gn --datasetid 1 18 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 10 --size 31 --epoch 40 --eval_grad 0 --model res_gn --datasetid 1 19 | nohup python train_parallel.py --GPU_list 01234 --H 1 --num_comm_ue 30 --size 31 --epoch 40 --eval_grad 0 --model res_gn --datasetid 1 20 | -------------------------------------------------------------------------------- /train_parallel.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #!/usr/bin/python3 3 | 4 | import threading 5 | import time 6 | import os 7 | import numpy as np 8 | import random 9 | 10 | import gpustat 11 | import logging 12 | 13 | import itertools 14 | import torch 15 | import torch.optim as optim 16 | import argparse 17 | import sys 18 | from scipy import io 19 | import datetime 20 | from utils_v2 import Get_num_ranks_all_size_num_devices 21 | 22 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 23 | os.environ['MKL_NUM_THREADS'] = '1' 24 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 25 | 26 | parser = argparse.ArgumentParser(description='SSFL training') 27 | 28 | parser.add_argument('--GPU_list', 29 | default='01', 30 | type=str, 31 | help='gpu list') 32 | parser.add_argument('--datasetid', 33 | default=0, 34 | type=int, 35 | help='dataset') 36 | parser.add_argument('--basicLabelRatio', 37 | default=0.4, 38 | type=float, 39 | help='basicLabelRatio') 40 | parser.add_argument('--labeled', 41 | default=0, 42 | type=int, 43 | help='supervised or not') 44 | parser.add_argument('--num_comm_ue', 45 | default=10, 46 | type=int, 47 | help='supervised or not') 48 | parser.add_argument('--H', 49 | default=1, 50 | type=int, 51 | help='Group or not') 52 | parser.add_argument('--cp', 53 | default=16, 54 | type=int, 55 | help='cp') 56 | 57 | parser.add_argument('--eval_grad', 58 | default=1, 59 | type=int, 60 | help='eval_grad or training') 61 | 62 | parser.add_argument('--experiment_folder', default='.', type=str, 63 | help='the path of the experiment') 64 | 65 | parser.add_argument('--experiment_name', default=None, type=str, 66 | help='experiment_name') 67 | 68 | parser.add_argument('--tao', 69 | default=0.95, 70 | type=float, 71 | help='tao for cal. 
mask') 72 | 73 | 74 | parser.add_argument('--model', default='res_gn', type=str, 75 | help='model') 76 | 77 | parser.add_argument('--ue_loss', default='CRL', type=str, 78 | help='user loss for training') 79 | 80 | parser.add_argument('--user_semi', 81 | default=0, 82 | type=int, 83 | help='user side semi') 84 | 85 | parser.add_argument('--size', 86 | default=11, 87 | type=int, 88 | help='user number + one server') 89 | parser.add_argument('--epoch', 90 | default=300, 91 | type=int, 92 | help='training epoch') 93 | 94 | parser.add_argument('--batch_size', 95 | default=64, 96 | type=int, 97 | help='batch_size for training') 98 | 99 | parser.add_argument('--k_img', 100 | default=65536, 101 | type=int, 102 | help='k_img') 103 | 104 | parser.add_argument('--fast', 105 | default=1, 106 | type=int, 107 | help='use fast model for lr scheduler or not') 108 | 109 | parser.add_argument('--Ns', 110 | default=1000, 111 | type=int, 112 | help='number of labeled data in server') 113 | 114 | args = parser.parse_args() 115 | 116 | FORMAT = '[%(asctime)-15s %(filename)s:%(lineno)s] %(message)s' 117 | FORMAT_MINIMAL = '%(message)s' 118 | 119 | logger = logging.getLogger('runner') 120 | logging.basicConfig(format=FORMAT) 121 | logger.setLevel(logging.DEBUG) 122 | 123 | 124 | exitFlag = 0 125 | GPU_MEMORY_THRESHOLD = 24000 # MB? 126 | 127 | def get_free_gpu_indices(): 128 | ''' 129 | Return an available GPU index. 
130 | ''' 131 | while True: 132 | stats = gpustat.GPUStatCollection.new_query() 133 | return_list = [] 134 | for i, stat in enumerate(stats.gpus): 135 | memory_used = stat['memory.used'] 136 | if memory_used < GPU_MEMORY_THRESHOLD: 137 | return i 138 | 139 | logger.info("Waiting on GPUs") 140 | time.sleep(5) 141 | 142 | 143 | class DispatchThread(threading.Thread): 144 | def __init__(self, threadID, name, counter, bash_command_list): 145 | threading.Thread.__init__(self) 146 | self.threadID = threadID 147 | self.name = name 148 | self.counter = counter 149 | self.bash_command_list = bash_command_list 150 | 151 | def run(self): 152 | # logger.info("Starting " + self.name) 153 | threads = [] 154 | for i, bash_command in enumerate(self.bash_command_list): 155 | 156 | cuda_device = get_free_gpu_indices() 157 | thread1 = ChildThread(1, f"{i}th + {bash_command}", 1, cuda_device, bash_command) 158 | thread1.start() 159 | import time 160 | time.sleep(5) 161 | threads.append(thread1) 162 | 163 | # join all. 
164 | for t in threads: 165 | t.join() 166 | logger.info("Exiting " + self.name) 167 | 168 | 169 | 170 | class ChildThread(threading.Thread): 171 | def __init__(self, threadID, name, counter, cuda_device, bash_command): 172 | threading.Thread.__init__(self) 173 | self.threadID = threadID 174 | self.name = name 175 | self.counter = counter 176 | self.cuda_device = cuda_device 177 | self.bash_command = bash_command 178 | 179 | def run(self): 180 | bash_command = self.bash_command 181 | 182 | # ACTIVATE 183 | os.system(bash_command) 184 | import time 185 | import random 186 | time.sleep(random.random() % 5) 187 | 188 | logger.info("Finishing " + self.name) 189 | 190 | if args.datasetid == 0: 191 | dataset = 'cifar10' 192 | if args.datasetid == 1: 193 | dataset = 'svhn' 194 | if args.datasetid == 2: 195 | dataset = 'emnist' 196 | 197 | """ 198 | ########## 199 | Assume the number of UEs is K 200 | *************************************************************************************************************************************** 201 | parameters Value/meaning 202 | size: size = K + 1 (server); 203 | cp: cp in {2, 4, 8, 16} is frequency of communication; cp = 2 means UEs and server communicates every 2 iterations; 204 | basicLabelRatio: basicLabelRatio in {0.1, 0.2, 0.4, ..., 1.0}, is the degree of data dispersion for each UE, 205 | basicLabelRatio = 0.0 means UE has the same amount of samples in each class; basicLabelRatio = 1.0 means all samples owned 206 | by UE belong to the same class; 207 | model: model in {'res', 'res_gn'}; model = 'res' means we use ResNet18 + BN; model = 'res_gn' means we use ResNet18 + GN; 208 | iid: iid in {0, 1}; iid = 1 is the IID case; iid = 0 is the Non-IID case; 209 | num_comm_ue: num_comm_ue in {1, 2, ..., K}; communication user number per iteration; 210 | k_img: Total data volume after data augmentation for each UE and server; 211 | H: H in {0, 1}; use grouping-based method or not; H = 1 means we use grouping-based method; 212 | 
GPU_list: GPU_list is a string; GPU_list = '01' means we use GPU0 and GPU1 for training; 213 | num_data_server: num_data_server in {1000, 4000}, number of labeled samples in server 214 | master_port: a random string; MASTER_PORT 215 | ip_address: a string; MASTER_ADDR 216 | 217 | For examples: 218 | size = 5 + 1 219 | batch_size = 64 220 | cp = 4 221 | basicLabelRatio = 0.4 222 | model = 'res_gn' 223 | iid = 0 224 | num_comm_ue = 2 225 | k_img = 65536 226 | epoches = 300 227 | H = 1 228 | num_data_server = 1000 229 | *************************************************************************************************************************************** 230 | """ 231 | 232 | 233 | size = args.size 234 | batch_size = args.batch_size 235 | cp_list = [args.cp] 236 | 237 | basicLabelRatio = args.basicLabelRatio 238 | model_list = [args.model] 239 | 240 | if basicLabelRatio == 0.0: 241 | iid = 1 242 | else: 243 | iid = 0 244 | 245 | num_comm_ue = args.num_comm_ue 246 | k_img = args.k_img 247 | epoches = args.epoch 248 | warmup_epoch = 5 249 | num_data_server = args.Ns 250 | 251 | labeled = args.labeled 252 | fast = args.fast 253 | H = args.H 254 | epoch_interval = epoches//10 255 | 256 | GPU_list = args.GPU_list 257 | 258 | import socket 259 | myname = socket.getfqdn(socket.gethostname( )) 260 | myaddr = socket.gethostbyname(myname) 261 | print('The ip address:',myaddr) 262 | ip_address = myaddr 263 | 264 | class_per_device = 1 265 | Start_Epoch = [0] 266 | num_rank, size_all, num_devices = Get_num_ranks_all_size_num_devices(args) 267 | 268 | now_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 269 | 270 | if args.experiment_name is None: 271 | experiment_name = f'{dataset}_size_all_{size_all}_UE{num_devices}_comUE{num_comm_ue}_cp{cp_list[0]}_Model{model_list[0]}_H{H}_labeled_{labeled}_Ns_{num_data_server}_eval_grad_{args.eval_grad}_Time_{now_time}' 272 | else: 273 | experiment_name = args.experiment_name 274 | 275 | ### submitte models to uers' device for 
training 276 | for model in model_list: 277 | for cp in cp_list: 278 | for epoch_resume in Start_Epoch: 279 | master_port = random.sample(range(10000,30000),1) 280 | master_port = str(master_port[0]) 281 | BASH_COMMAND_LIST = [] 282 | for rank in range(num_rank): 283 | lr = 0.03*(10.0+1.0)*batch_size/128.0 284 | if args.model == 'res9': 285 | lr = 0.003*num_comm_ue*batch_size/128.0 286 | if dataset == 'emnist' or args.ue_loss == 'SF': 287 | lr = 0.03 288 | warmup_epoch = 0 289 | if args.user_semi and args.model != 'res9': 290 | lr = 0.01 291 | warmup_epoch = 0 292 | 293 | comm = f"setsid python train_LocalSGD.py --dataset {dataset} --model {model} --eval_grad {args.eval_grad} --epoch_resume {epoch_resume} --epoch_interval {epoch_interval}\ 294 | --lr {lr} --bs {batch_size} --cp {cp} --alpha 0.6 --gmf 0.7 --basicLabelRatio {basicLabelRatio} --master_port {master_port}\ 295 | --name revised_results_e300 --ip_address {ip_address} --num_comm_ue {num_comm_ue} --num_data_server {num_data_server} --k-img {k_img}\ 296 | --iid {iid} --rank {rank} --size {size_all} --backend gloo --warmup_epoch {warmup_epoch} --GPU_list {GPU_list} --labeled {labeled}\ 297 | --class_per_device {1} --num-devices {num_devices} --num_rank {num_rank} --epoch {epoches} --experiment_name {experiment_name} --fast {fast} --H {H} \ 298 | --experiment_folder {args.experiment_folder} --tao {args.tao} --ue_loss {args.ue_loss} --user_semi {args.user_semi}" 299 | 300 | 301 | BASH_COMMAND_LIST.append(comm) 302 | 303 | 304 | dispatch_thread = DispatchThread(2, "Thread-2", 4, BASH_COMMAND_LIST) 305 | # # Start new Threads 306 | dispatch_thread.start() 307 | dispatch_thread.join() 308 | 309 | import time 310 | time.sleep(5) 311 | -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | 5 | 6 | class PadandRandomCrop(object): 7 | ''' 8 
| Input tensor is expected to have shape of (H, W, 3) 9 | ''' 10 | def __init__(self, border=4, cropsize=(32, 32)): 11 | self.border = border 12 | self.cropsize = cropsize 13 | 14 | def __call__(self, im): 15 | borders = [(self.border, self.border), (self.border, self.border), (0, 0)] 16 | convas = np.pad(im, borders, mode='reflect') 17 | H, W, C = convas.shape 18 | h, w = self.cropsize 19 | dh, dw = max(0, H-h), max(0, W-w) 20 | sh, sw = np.random.randint(0, dh), np.random.randint(0, dw) 21 | out = convas[sh:sh+h, sw:sw+w, :] 22 | return out 23 | 24 | 25 | class RandomHorizontalFlip(object): 26 | def __init__(self, p=0.5): 27 | self.p = p 28 | 29 | def __call__(self, im): 30 | if np.random.rand() < self.p: 31 | im = im[:, ::-1, :] 32 | return im 33 | 34 | 35 | class Resize(object): 36 | def __init__(self, size): 37 | self.size = size 38 | 39 | def __call__(self, im): 40 | im = cv2.resize(im, self.size) 41 | return im 42 | 43 | 44 | class Normalize(object): 45 | ''' 46 | Inputs are pixel values in range of [0, 255], channel order is 'rgb' 47 | ''' 48 | def __init__(self, mean, std): 49 | self.mean = np.array(mean, np.float32).reshape(1, 1, -1) 50 | self.std = np.array(std, np.float32).reshape(1, 1, -1) 51 | 52 | def __call__(self, im): 53 | if len(im.shape) == 4: 54 | mean, std = self.mean[None, ...], self.std[None, ...] 55 | elif len(im.shape) == 3: 56 | mean, std = self.mean, self.std 57 | im = im.astype(np.float32) / 255. 
58 | im -= mean 59 | im /= std 60 | return im 61 | 62 | 63 | class ToTensor(object): 64 | def __init__(self): 65 | pass 66 | 67 | def __call__(self, im): 68 | if len(im.shape) == 4: 69 | return torch.from_numpy(im.transpose(0, 3, 1, 2)) 70 | elif len(im.shape) == 3: 71 | return torch.from_numpy(im.transpose(2, 0, 1)) 72 | 73 | 74 | class Compose(object): 75 | def __init__(self, ops): 76 | self.ops = ops 77 | 78 | def __call__(self, im): 79 | for op in self.ops: 80 | im = op(im) 81 | return im 82 | -------------------------------------------------------------------------------- /util_v4.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import argparse 5 | # import logging 6 | 7 | #from mpi4py import MPI 8 | from math import ceil 9 | from random import Random 10 | #import networkx as nx 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import torch.utils.data.distributed 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | from torch.multiprocessing import Process 19 | from torch.autograd import Variable 20 | import torchvision 21 | from torchvision import datasets, transforms 22 | import torch.backends.cudnn as cudnn 23 | import torchvision.models as IMG_models 24 | 25 | from models import * 26 | from models.Semi_net import SemiNet 27 | 28 | # logging.basicConfig(level=logging.INFO) 29 | 30 | 31 | class Partition(object): 32 | """ Dataset-like object, but only access a subset of it. """ 33 | 34 | def __init__(self, data, index): 35 | self.data = data 36 | self.index = index 37 | 38 | def __len__(self): 39 | return len(self.index) 40 | 41 | def __getitem__(self, index): 42 | data_idx = self.index[index] 43 | return self.data[data_idx] 44 | 45 | class DataPartitioner(object): 46 | """ Partitions a dataset into different chuncks. 
""" 47 | def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234, isNonIID=False, alpha=0): 48 | self.data = data 49 | self.partitions = [] 50 | self.ratio = [0] * len(sizes) 51 | rng = Random() 52 | rng.seed(seed) 53 | data_len = len(data) 54 | indexes = [x for x in range(0, data_len)] 55 | rng.shuffle(indexes) 56 | 57 | 58 | for frac in sizes: 59 | part_len = int(frac * data_len) 60 | self.partitions.append(indexes[0:part_len]) 61 | indexes = indexes[part_len:] 62 | 63 | if isNonIID: 64 | self.partitions, self.ratio = self.__getDirichletData__(data, sizes, seed, alpha) 65 | 66 | def use(self, partition): 67 | return Partition(self.data, self.partitions[partition]) 68 | 69 | def __getNonIIDdata__(self, data, sizes, seed): 70 | labelList = data.train_labels 71 | rng = Random() 72 | rng.seed(seed) 73 | a = [(label, idx) for idx, label in enumerate(labelList)] 74 | # Same Part 75 | labelIdxDict = dict() 76 | for label, idx in a: 77 | labelIdxDict.setdefault(label,[]) 78 | labelIdxDict[label].append(idx) 79 | labelNum = len(labelIdxDict) 80 | labelNameList = [key for key in labelIdxDict] 81 | labelIdxPointer = [0] * labelNum 82 | # sizes = number of nodes 83 | partitions = [list() for i in range(len(sizes))] 84 | eachPartitionLen= int(len(labelList)/len(sizes)) 85 | majorLabelNumPerPartition = ceil(labelNum/len(partitions)) 86 | basicLabelRatio = 0.4 87 | 88 | interval = 1 89 | labelPointer = 0 90 | 91 | #basic part 92 | for partPointer in range(len(partitions)): 93 | requiredLabelList = list() 94 | for _ in range(majorLabelNumPerPartition): 95 | requiredLabelList.append(labelPointer) 96 | labelPointer += interval 97 | if labelPointer > labelNum - 1: 98 | labelPointer = interval 99 | interval += 1 100 | for labelIdx in requiredLabelList: 101 | start = labelIdxPointer[labelIdx] 102 | idxIncrement = int(basicLabelRatio*len(labelIdxDict[labelNameList[labelIdx]])) 103 | partitions[partPointer].extend(labelIdxDict[labelNameList[labelIdx]][start:start+ idxIncrement]) 104 
| labelIdxPointer[labelIdx] += idxIncrement 105 | 106 | #random part 107 | remainLabels = list() 108 | for labelIdx in range(labelNum): 109 | remainLabels.extend(labelIdxDict[labelNameList[labelIdx]][labelIdxPointer[labelIdx]:]) 110 | rng.shuffle(remainLabels) 111 | for partPointer in range(len(partitions)): 112 | idxIncrement = eachPartitionLen - len(partitions[partPointer]) 113 | partitions[partPointer].extend(remainLabels[:idxIncrement]) 114 | rng.shuffle(partitions[partPointer]) 115 | remainLabels = remainLabels[idxIncrement:] 116 | return partitions 117 | 118 | def __getDirichletData__(self, data, psizes, seed, alpha): 119 | sizes = len(psizes) 120 | labelList = data.targets 121 | rng = Random() 122 | rng.seed(seed) 123 | a = [(label, idx) for idx, label in enumerate(labelList)] 124 | # Same Part 125 | labelIdxDict = dict() 126 | for label, idx in a: 127 | labelIdxDict.setdefault(label,[]) 128 | labelIdxDict[label].append(idx) 129 | labelNum = len(labelIdxDict) #10 130 | labelNameList = [key for key in labelIdxDict] 131 | # rng.shuffle(labelNameList) 132 | labelIdxPointer = [0] * labelNum 133 | # sizes = number of nodes 134 | partitions = [list() for i in range(sizes)] # of size (m) 135 | np.random.seed(seed) 136 | distribution = np.random.dirichlet([alpha] * sizes, labelNum).tolist() # of size (10, m) 137 | 138 | # basic part 139 | for row_id, dist in enumerate(distribution): 140 | subDictList = labelIdxDict[labelNameList[row_id]] 141 | rng.shuffle(subDictList) 142 | totalNum = len(subDictList) 143 | dist = self.handlePartition(dist, totalNum) 144 | for i in range(len(dist)-1): 145 | partitions[i].extend(subDictList[dist[i]:dist[i+1]+1]) 146 | 147 | #random part 148 | a = [len(partitions[i]) for i in range(len(partitions))] 149 | ratio = [a[i]/sum(a) for i in range(len(a))] 150 | return partitions, ratio 151 | 152 | def handlePartition(self, plist, length): 153 | newList = [0] 154 | canary = 0 155 | for i in range(len(plist)): 156 | canary = int(canary + 
length*plist[i]) 157 | newList.append(canary) 158 | return newList 159 | 160 | def partition_dataset(rank, size, args): 161 | print('==> load train data') 162 | if args.dataset == 'cifar10': 163 | transform_train = transforms.Compose([ 164 | transforms.RandomCrop(32, padding=4), 165 | transforms.RandomHorizontalFlip(), 166 | transforms.ToTensor(), 167 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 168 | ]) 169 | trainset = torchvision.datasets.CIFAR10(root='./data', 170 | train=True, 171 | download=True, 172 | transform=transform_train) 173 | 174 | partition_sizes = [1.0 / size for _ in range(size)] 175 | partition = DataPartitioner(trainset, partition_sizes, isNonIID=args.NIID, alpha=args.alpha) 176 | ratio = partition.ratio 177 | partition = partition.use(rank) 178 | train_loader = torch.utils.data.DataLoader(partition, 179 | batch_size=args.bs, 180 | shuffle=True, 181 | pin_memory=True) 182 | 183 | print('==> load test data') 184 | transform_test = transforms.Compose([ 185 | transforms.ToTensor(), 186 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 187 | ]) 188 | testset = torchvision.datasets.CIFAR10(root='./data', 189 | train=False, 190 | download=True, 191 | transform=transform_test) 192 | test_loader = torch.utils.data.DataLoader(testset, 193 | batch_size=64, 194 | shuffle=False, 195 | num_workers=size) 196 | 197 | if args.dataset == 'cifar100': 198 | transform_train = transforms.Compose([ 199 | transforms.RandomCrop(32, padding=4), 200 | transforms.RandomHorizontalFlip(), 201 | transforms.ToTensor(), 202 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 203 | ]) 204 | trainset = torchvision.datasets.CIFAR100(root='/users/jianyuw1/AdaDSGD/data', 205 | train=True, 206 | download=True, 207 | transform=transform_train) 208 | 209 | partition_sizes = [1.0 / size for _ in range(size)] 210 | partition = DataPartitioner(trainset, partition_sizes, isNonIID=False) 211 | ratio = 
partition.ratio 212 | partition = partition.use(rank) 213 | train_loader = torch.utils.data.DataLoader(partition, 214 | batch_size=args.bs, 215 | shuffle=True, 216 | pin_memory=True) 217 | 218 | print('==> load test data') 219 | transform_test = transforms.Compose([ 220 | transforms.ToTensor(), 221 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 222 | ]) 223 | testset = torchvision.datasets.CIFAR100(root='/users/jianyuw1/AdaDSGD/data', 224 | train=False, 225 | download=True, 226 | transform=transform_test) 227 | test_loader = torch.utils.data.DataLoader(testset, 228 | batch_size=64, 229 | shuffle=False, 230 | num_workers=size) 231 | 232 | 233 | elif args.dataset == 'imagenet': 234 | datadir = '/datasets/shared/imagenet/ILSVRC2015/Data/' 235 | traindir = os.path.join(datadir, 'CLS-LOC/train/') 236 | #valdir = os.path.join(datadir, 'CLS-LOC/') 237 | #testdir = os.path.join(datadir, 'CLS-LOC/') 238 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 239 | std=[0.229, 0.224, 0.225]) 240 | train_dataset = datasets.ImageFolder( 241 | traindir, 242 | transforms.Compose([ 243 | transforms.RandomResizedCrop(224), 244 | transforms.RandomHorizontalFlip(), 245 | transforms.ToTensor(), 246 | normalize, 247 | ])) 248 | 249 | partition_sizes = [1.0 / size for _ in range(size)] 250 | partition = DataPartitioner(train_dataset, partition_sizes, isNonIID=False) 251 | ratio = partition.ratio 252 | partition = partition.use(rank) 253 | 254 | train_loader = torch.utils.data.DataLoader( 255 | partition, batch_size=args.bs, shuffle=True, 256 | pin_memory=True) 257 | ''' 258 | val_loader = torch.utils.data.DataLoader( 259 | datasets.ImageFolder(valdir, transforms.Compose([ 260 | transforms.Resize(256), 261 | transforms.CenterCrop(224), 262 | transforms.ToTensor(), 263 | normalize, 264 | ])), 265 | batch_size=args.bs, shuffle=False, 266 | pin_memory=True) 267 | val_loader = None 268 | ''' 269 | test_loader = None 270 | 271 | if args.dataset == 
'emnist': 272 | transform_train = transforms.Compose([ 273 | transforms.ToTensor(), 274 | ]) 275 | train_dataset = torchvision.datasets.EMNIST(root='/users/jianyuw1/AdaDSGD/data', 276 | split = 'balanced', 277 | train=True, 278 | download=True, 279 | transform=transform_train) 280 | partition_sizes = [1.0 / size for _ in range(size)] 281 | partition = DataPartitioner(train_dataset, partition_sizes, isNonIID=False) 282 | ratio = partition.ratio 283 | partition = partition.use(rank) 284 | 285 | train_loader = torch.utils.data.DataLoader( 286 | partition, batch_size=args.bs, shuffle=True, 287 | pin_memory=True) 288 | 289 | transform_test = transforms.Compose([ 290 | transforms.ToTensor(), 291 | ]) 292 | testset = torchvision.datasets.EMNIST(root='/users/jianyuw1/AdaDSGD/data', 293 | split = 'balanced', 294 | train=False, 295 | download=True, 296 | transform=transform_test) 297 | test_loader = torch.utils.data.DataLoader(testset, 298 | batch_size=64, 299 | shuffle=False, 300 | num_workers=size) 301 | 302 | 303 | 304 | return train_loader, test_loader 305 | 306 | def select_model(num_class, args): 307 | if args.model == 'VGG': 308 | model = VGG(16, num_class) 309 | elif args.model == 'res': 310 | if args.dataset == 'cifar10': 311 | # model = resnet.ResNet(34, num_class) 312 | model = ResNet18() 313 | elif args.dataset == 'imagenet': 314 | model = IMG_models.resnet18() 315 | elif args.model == 'wrn': 316 | num_class = 10 317 | model = Wide_ResNet(28,10,0,num_class) 318 | elif args.model == 'mlp': 319 | if args.dataset == 'emnist': 320 | model = MNIST_MLP(47) 321 | elif args.model == 'res_gn': 322 | # model = resnet.ResNet(34, num_class) 323 | model = ResNet18_gn() 324 | elif args.model == 'res_ln': 325 | if args.dataset == 'cifar10': 326 | # model = resnet.ResNet(34, num_class) 327 | model = ResNet18_LN() 328 | elif args.model == 'EMNIST_model': 329 | if args.dataset == 'emnist': 330 | # model = resnet.ResNet(34, num_class) 331 | model = Net() 332 | elif args.model == 
'UserSemi_model': 333 | 334 | model = SemiNet() 335 | 336 | return model 337 | 338 | def comp_accuracy(output, target, topk=(1,)): 339 | """Computes the accuracy over the k top predictions for the specified values of k""" 340 | with torch.no_grad(): 341 | maxk = max(topk) 342 | batch_size = target.size(0) 343 | 344 | _, pred = output.topk(maxk, 1, True, True) 345 | pred = pred.t() 346 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 347 | 348 | res = [] 349 | for k in topk: 350 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 351 | res.append(correct_k.mul_(100.0 / batch_size)) 352 | return res 353 | 354 | class AverageMeter(object): 355 | """Computes and stores the average and current value""" 356 | def __init__(self): 357 | self.reset() 358 | 359 | def reset(self): 360 | self.val = 0 361 | self.avg = 0 362 | self.sum = 0 363 | self.count = 0 364 | 365 | def update(self, val, n=1): 366 | self.val = val 367 | self.sum += val * n 368 | self.count += n 369 | self.avg = self.sum / self.count 370 | 371 | class Meter(object): 372 | """ Computes and stores the average, variance, and current value """ 373 | 374 | def __init__(self, init_dict=None, ptag='Time', stateful=False, 375 | csv_format=True): 376 | """ 377 | :param init_dict: Dictionary to initialize meter values 378 | :param ptag: Print tag used in __str__() to identify meter 379 | :param stateful: Whether to store value history and compute MAD 380 | """ 381 | self.reset() 382 | self.ptag = ptag 383 | self.value_history = None 384 | self.stateful = stateful 385 | if self.stateful: 386 | self.value_history = [] 387 | self.csv_format = csv_format 388 | if init_dict is not None: 389 | for key in init_dict: 390 | try: 391 | # TODO: add type checking to init_dict values 392 | self.__dict__[key] = init_dict[key] 393 | except Exception: 394 | print('(Warning) Invalid key {} in init_dict'.format(key)) 395 | 396 | def reset(self): 397 | self.val = 0 398 | self.avg = 0 399 | self.sum = 0 400 | 
self.count = 0 401 | self.std = 0 402 | self.sqsum = 0 403 | self.mad = 0 404 | 405 | def update(self, val, n=1): 406 | self.val = val 407 | self.sum += val * n 408 | self.count += n 409 | self.avg = self.sum / self.count 410 | self.sqsum += (val ** 2) * n 411 | if self.count > 1: 412 | self.std = ((self.sqsum - (self.sum ** 2) / self.count) 413 | / (self.count - 1) 414 | ) ** 0.5 415 | if self.stateful: 416 | self.value_history.append(val) 417 | mad = 0 418 | for v in self.value_history: 419 | mad += abs(v - self.avg) 420 | self.mad = mad / len(self.value_history) 421 | 422 | def __str__(self): 423 | if self.csv_format: 424 | if self.stateful: 425 | return str('{dm.val:.3f},{dm.avg:.3f},{dm.mad:.3f}' 426 | .format(dm=self)) 427 | else: 428 | return str('{dm.val:.3f},{dm.avg:.3f},{dm.std:.3f}' 429 | .format(dm=self)) 430 | else: 431 | if self.stateful: 432 | return str(self.ptag) + \ 433 | str(': {dm.val:.3f} ({dm.avg:.3f} +- {dm.mad:.3f})' 434 | .format(dm=self)) 435 | else: 436 | return str(self.ptag) + \ 437 | str(': {dm.val:.3f} ({dm.avg:.3f} +- {dm.std:.3f})' 438 | .format(dm=self)) 439 | --------------------------------------------------------------------------------