├── README.md ├── RGA.py ├── datasets.py ├── framework.png ├── kmeans.py ├── loss.py ├── main_CSPC.py ├── main_FedAvg.py ├── model.py ├── re_training.py ├── resnet.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # FedCSPC 2 | (ACM MM'23) Cross-Silo Prototypical Calibration for Federated Learning with Non-IID Data 3 | 4 | ## Overview 5 | This paper presents a novel Cross-Silo Prototypical Calibration method, termed FedCSPC. As illustrated in the figure below, unlike conventional federated learning methods, FedCSPC performs prototypical calibration, which maps representations from different feature spaces into a unified space while maintaining clear decision boundaries. FedCSPC has two main modules: the Data Prototypical Modeling (DPM) module and the Cross-Silo Prototypical Calibration (CSPC) module. To promote the alignment of features across different spaces, the DPM module uses clustering to model each client's data patterns and provides the resulting prototypes to the server to assist with model calibration. To make the calibration robust, the CSPC module develops an augmented contrastive learning method that increases sample diversity via positive mixing and hard negative mining, and performs contrastive learning to effectively align cross-source features. Meanwhile, the calibrated prototypes form a knowledge base in the unified space and generate knowledge-based class predictions that reduce errors. Notably, the CSPC module is highly adaptable and integrates easily into various algorithms. FedCSPC thus alleviates the feature gap between data sources and significantly improves generalization. 6 | ![_](./framework.png) 7 | 8 | 9 | 10 | ## Dependencies 11 | * PyTorch >= 1.0.0 12 | * torchvision >= 0.2.1 13 | * scikit-learn >= 0.23.1 14 | 15 | 16 | 17 | ## Parameters 18 | 19 | | Parameter | Description | 20 | | ----------------------------- | ---------------------------------------- | 21 | | `model` | The model architecture. Options: `simple-cnn`, `resnet18`. | 22 | | `alg` | The training algorithm. Options: `CSPC`. | 23 | | `dataset` | Dataset to use. Options: `cifar10`, `cifar100`, `tinyimagenet`. | 24 | | `lr` | Learning rate. | 25 | | `batch-size` | Batch size. | 26 | | `epochs` | Number of local epochs. | 27 | | `n_parties` | Number of parties. | 28 | | `sample_fraction` | The fraction of parties sampled in each round. | 29 | | `comm_round` | Number of communication rounds. | 30 | | `partition` | The partition approach. Options: `noniid`, `iid`. | 31 | | `beta` | The concentration parameter of the Dirichlet distribution for the non-IID partition. | 32 | | `out_dim` | The output dimension of the projection head. | 33 | | `datadir` | The path of the dataset. | 34 | | `logdir` | The path to store the logs. | 35 | | `device` | The device to run the program on. | 36 | | `seed` | The initial random seed. | 37 | 38 | 39 | ## Usage 40 | 41 | Here is an example to run FedCSPC on CIFAR-10 with a simple CNN: 42 | ``` 43 | python main_CSPC.py --dataset=cifar10 --model=simple-cnn --alg=CSPC --lr=0.01 --epochs=10 --comm_round=100 --n_parties=10 --partition=noniid \ 44 |     --beta=0.5 --logdir='./logs/' --datadir='./data/' 45 | ``` 46 | 
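Other datasets and backbones from the parameter table follow the same pattern. For instance, a ResNet-18 run on CIFAR-100 could look like the following (the flag values here are illustrative defaults, not a tuned configuration from the paper):
```
python main_CSPC.py --dataset=cifar100 --model=resnet18 --alg=CSPC --lr=0.01 --epochs=10 --comm_round=100 --n_parties=10 --sample_fraction=1.0 --partition=noniid --beta=0.5 --logdir='./logs/' --datadir='./data/'
```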
47 | ## Citation 48 | Please cite our paper if you find this code useful for your research. 49 | ``` 50 | @article{qi2023cross, 51 | title={Cross-Silo Prototypical Calibration for Federated Learning with Non-IID Data}, 52 | author={Qi, Zhuang and Meng, Lei and Chen, Zitan and Hu, Han and Lin, Hui and Meng, Xiangxu}, 53 | journal={arXiv preprint arXiv:2308.03457}, 54 | year={2023} 55 | } 56 | ``` 57 | 58 | -------------------------------------------------------------------------------- /RGA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | # import torch.tensor as tensor 5 | "Embedding Graph Alignment Loss" 6 | import ipdb 7 | def PCC(m): 8 | '''Compute the Pearson’s correlation coefficients.''' 9 | fact = 1.0 / (m.size(1) - 1) 10 | m = m - torch.mean(m, dim=1, keepdim=True) 11 | mt = m.t() 12 | c = fact * m.matmul(mt).squeeze() 13 | d = torch.diag(c, 0) 14 | std = torch.sqrt(d) 15 | c /= std[:, None] 16 | c /= std[None, :] 17 | return c 18 | 19 | 20 | 21 | # def pdist(a,dim=2, p=2): 22 | # dist_matrix = torch.norm(a[:, None]-a, dim, p) / a.shape[1] 23 | # return dist_matrix 24 | 25 | def cosinematrix(A): 26 | prod = torch.mm(A, A.t()) # numerator 27 | norm = torch.norm(A, p=2, dim=1).unsqueeze(0) # denominator 28 | cos = prod.div(torch.mm(norm.t(), norm)) 29 | return cos 30 | 31 | 32 | def RKdNode(features, f_labels, prototypes, p_labels, t=0.5): 33 | 34 | a_norm = features / features.norm(dim=1)[:, None] 35 | b_norm = prototypes / prototypes.norm(dim=1)[:, None] 36 | sim_matrix = torch.exp(torch.mm(a_norm, b_norm.transpose(0,1)) / t) 37 | c_norm = prototypes[f_labels] / prototypes[f_labels].norm(dim=1)[:, None] 38 | pos_sim = torch.exp(torch.diag(torch.mm(a_norm, c_norm.transpose(0,1))) / t) 39 | 40 | loss = (-torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean() 41 | 42 | return loss 43 | 44 | 45 | 46 | 47 | def pdist(e, squared=False, eps=1e-12): 48 | e_square = e.pow(2).sum(dim=1) 49 | prod = e @ e.t() 50 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 51 | 52 | if not squared: 53 | res = res.sqrt() 54 | 55 | res = res.clone() 56 | res[range(len(e)), range(len(e))] = 0 57 | return res 58 | 59 | 60 | def RKdAngle(student, teacher): 61 | # N x C 62 | # N x N x C 63 | 64 | with torch.no_grad(): 65 | td = (teacher.unsqueeze(0) - teacher.unsqueeze(1)) 66 | norm_td = F.normalize(td, p=2, dim=2) 67 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) 68 | 69 | sd = (student.unsqueeze(0) - student.unsqueeze(1)) 70 | norm_sd = F.normalize(sd, p=2, dim=2) 71 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) 72 | 73 | loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean') 74 | return loss 75 | 76 | 77 | 78 | def RkdEdge(student, teacher): 79 | with torch.no_grad(): 80 | t_d = pdist(teacher, squared=False) 81 | mean_td = t_d[t_d>0].mean() 82 | t_d = t_d / mean_td 83 | 84 | d = pdist(student, squared=False) 85 | mean_d = d[d>0].mean() 86 | d = d / mean_d 87 | 88 | loss = F.smooth_l1_loss(d, t_d, reduction='mean') 89 | return loss 90 | 91 | class RGA_loss(torch.nn.Module): 92 | def __init__(self, node_weight=1, edge_weight=0.3, angle_weight=0.1, t=0.5): 93 | super(RGA_loss, self).__init__() 94 | self.node_weight = node_weight 95 | self.edge_weight = edge_weight 96 | self.angle_weight = angle_weight 97 | self.t = t 98 | 99 | def forward(self, student, student_labels, teacher, teacher_label, mode='N'): 100 | REloss = RkdEdge(student, teacher[student_labels]) 101 | RAloss = RKdAngle(student, teacher[student_labels]) 102 | RNloss = 
RKdNode(student, student_labels, teacher, teacher_label, self.t) 103 | if mode == 'N' or mode == 'N*': 104 | RGAloss = RNloss 105 | elif mode == 'E': 106 | RGAloss = REloss 107 | elif mode == 'A': 108 | RGAloss = RAloss 109 | elif mode == 'N+E': 110 | RGAloss = self.node_weight * RNloss + self.edge_weight * REloss# ipdb.set_trace() 111 | elif mode == 'N+A': 112 | RGAloss = self.node_weight * RNloss + self.angle_weight * RAloss# 113 | elif mode == 'A+E': 114 | RGAloss = self.angle_weight * RAloss + self.edge_weight * REloss# 115 | elif mode == 'N+E+A': 116 | RGAloss = self.node_weight * RNloss + self.angle_weight * RAloss + self.edge_weight * REloss# 117 | 118 | return RGAloss 119 | 120 | 121 | # class RGA(torch.nn.Module): 122 | 123 | # def __init__(self, node_weight=1, edge_weight=0.3, t=0.5): 124 | # 125 | # super(RGA, self).__init__() 126 | 127 | # self.node_weight = node_weight 128 | # self.edge_weight = edge_weight 129 | # self.t = t 130 | 131 | # def forward(self, feats, feats_label, prototype, proto_label): 132 | 133 | # X = torch.cat((feats, prototype[feats_label]), 0) 134 | # # C = PCC(X) 135 | # C = pdist(X) 136 | # n = C.shape[0]//2 137 | 138 | # Et = C[0:n, 0:n] # compute teacher edge matrix 139 | # Es = C[n:, n:] # compute student edge matrix 140 | # Nts= C[0:n, n:] # compute node matrix 141 | 142 | 143 | # loss_edge = torch.norm((Et-Es), 2) 144 | # loss_node = PCLoss(feats, feats_label, prototype, proto_label, self.t) 145 | 146 | # RGA_loss = self.node_weight * loss_node + self.edge_weight * loss_edge 147 | 148 | # return RGA_loss 149 | 150 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import numpy as np 4 | import torchvision 5 | from torchvision.datasets import MNIST, EMNIST, CIFAR10, CIFAR100, SVHN, FashionMNIST, ImageFolder, DatasetFolder, utils 6 | from torch.utils.data import Dataset 7 | import os 8 | import os.path 9 | import logging 10 | import sys 11 | import torch 12 | import io 13 | import scipy.io as matio 14 | import ipdb 15 | 16 | logging.basicConfig() 17 | logger = logging.getLogger() 18 | logger.setLevel(logging.INFO) 19 | 20 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 21 | 22 | 23 | def mkdirs(dirpath): 24 | try: 25 | os.makedirs(dirpath) 26 | except Exception as _: 27 | pass 28 | 29 | def default_loader(image_path): 30 | return Image.open(image_path).convert('RGB') 31 | 32 | 33 | class MNIST_truncated(data.Dataset): 34 | 35 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): 36 | 37 | self.root = root 38 | self.dataidxs = dataidxs 39 | self.train = train 40 | self.transform = transform 41 | self.target_transform = target_transform 42 | self.download = download 43 | 44 | 45 | self.data, self.target = self.__build_truncated_dataset__() 46 | 47 | def __build_truncated_dataset__(self): 48 | 49 | mnist_dataobj = MNIST(self.root, self.train, self.transform, self.target_transform, self.download) 50 | 51 | if torchvision.__version__ == '0.2.1': 52 | if self.train: 53 | data, target = mnist_dataobj.train_data, np.array(mnist_dataobj.train_labels) 54 | else: 55 | data, target = mnist_dataobj.test_data, np.array(mnist_dataobj.test_labels) 56 | else: 57 | data = mnist_dataobj.data 58 | target = np.array(mnist_dataobj.targets) 59 | 60 | 61 | if self.dataidxs is not 
None: 62 | data = data[self.dataidxs] 63 | target = target[self.dataidxs] 64 | 65 | return data, target 66 | 67 | def truncate_channel(self, index): 68 | for i in range(index.shape[0]): 69 | gs_index = index[i] 70 | self.data[gs_index, :, :, 1] = 0.0 71 | self.data[gs_index, :, :, 2] = 0.0 72 | 73 | 74 | 75 | def __getitem__(self, index): 76 | """ 77 | Args: 78 | index (int): Index 79 | 80 | Returns: 81 | tuple: (image, target) where target is index of the target class. 82 | """ 83 | img, target = self.data[index], self.target[index] 84 | 85 | img = img.reshape(28,28,1).cpu().numpy() 86 | 87 | 88 | if self.transform is not None: 89 | img = self.transform(img) 90 | 91 | 92 | return img, target 93 | 94 | def __len__(self): 95 | return len(self.data) 96 | 97 | 98 | 99 | 100 | class FashionMNIST_truncated(data.Dataset): 101 | 102 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): 103 | 104 | self.root = root 105 | self.dataidxs = dataidxs 106 | self.train = train 107 | self.transform = transform 108 | self.target_transform = target_transform 109 | self.download = download 110 | 111 | self.data, self.target = self.__build_truncated_dataset__() 112 | 113 | def __build_truncated_dataset__(self): 114 | 115 | fmnist_dataobj = FashionMNIST(self.root, self.train, self.transform, self.target_transform, self.download) 116 | 117 | if torchvision.__version__ == '0.2.1': 118 | if self.train: 119 | data, target = fmnist_dataobj.train_data, np.array(fmnist_dataobj.train_labels) 120 | else: 121 | data, target = fmnist_dataobj.test_data, np.array(fmnist_dataobj.test_labels) 122 | else: 123 | data = fmnist_dataobj.data 124 | target = np.array(fmnist_dataobj.targets) 125 | 126 | if self.dataidxs is not None: 127 | data = data[self.dataidxs] 128 | target = target[self.dataidxs] 129 | 130 | return data, target 131 | 132 | def truncate_channel(self, index): 133 | for i in range(index.shape[0]): 134 | gs_index = index[i] 135 | self.data[gs_index, :, :, 1] = 0.0 136 | self.data[gs_index, :, :, 2] = 0.0 137 | 138 | def __getitem__(self, index): 139 | """ 140 | Args: 141 | index (int): Index 142 | 143 | Returns: 144 | tuple: (image, target) where target is index of the target class. 
145 | """ 146 | img, target = self.data[index], self.target[index] 147 | 148 | if self.transform is not None: 149 | img = self.transform(img) 150 | if self.target_transform is not None: 151 | target = self.target_transform(target) 152 | 153 | return img, target 154 | 155 | def __len__(self): 156 | return len(self.data) 157 | 158 | 159 | class CIFAR10_truncated(data.Dataset): 160 | 161 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): 162 | 163 | self.root = root 164 | self.dataidxs = dataidxs 165 | self.train = train 166 | self.transform = transform 167 | self.target_transform = target_transform 168 | self.download = download 169 | 170 | self.data, self.target = self.__build_truncated_dataset__() 171 | 172 | def __build_truncated_dataset__(self): 173 | 174 | cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) 175 | 176 | if torchvision.__version__ == '0.2.1': 177 | if self.train: 178 | data, target = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels) 179 | else: 180 | data, target = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels) 181 | else: 182 | data = cifar_dataobj.data 183 | target = np.array(cifar_dataobj.targets) 184 | # ipdb.set_trace() 185 | if self.dataidxs is not None: 186 | data = data[self.dataidxs] 187 | target = target[self.dataidxs] 188 | 189 | return data, target 190 | 191 | def truncate_channel(self, index): 192 | for i in range(index.shape[0]): 193 | gs_index = index[i] 194 | self.data[gs_index, :, :, 1] = 0.0 195 | self.data[gs_index, :, :, 2] = 0.0 196 | 197 | def __getitem__(self, index): 198 | """ 199 | Args: 200 | index (int): Index 201 | 202 | Returns: 203 | tuple: (image, target) where target is index of the target class. 204 | """ 205 | img, target = self.data[index], self.target[index] 206 | # img = Image.fromarray(img) 207 | # print("cifar10 img:", img) 208 | # print("cifar10 target:", target) 209 | 210 | if self.transform is not None: 211 | img = self.transform(img) 212 | 213 | if self.target_transform is not None: 214 | target = self.target_transform(target) 215 | 216 | return img, target 217 | 218 | def __len__(self): 219 | return len(self.data) 220 | 221 | 222 | class CIFAR100_truncated(data.Dataset): 223 | 224 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): 225 | 226 | self.root = root 227 | self.dataidxs = dataidxs 228 | self.train = train 229 | self.transform = transform 230 | self.target_transform = target_transform 231 | self.download = download 232 | 233 | self.data, self.target = self.__build_truncated_dataset__() 234 | 235 | def __build_truncated_dataset__(self): 236 | 237 | cifar_dataobj = CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download) 238 | 239 | if torchvision.__version__ == '0.2.1': 240 | if self.train: 241 | data, target = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels) 242 | else: 243 | data, target = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels) 244 | else: 245 | data = cifar_dataobj.data 246 | target = np.array(cifar_dataobj.targets) 247 | 248 | if self.dataidxs is not None: 249 | data = data[self.dataidxs] 250 | target = target[self.dataidxs] 251 | 252 | return data, target 253 | 254 | def __getitem__(self, index): 255 | """ 256 | Args: 257 | index (int): Index 258 | 259 | Returns: 260 | tuple: (image, target) where target is index of the target class. 
261 | """ 262 | img, target = self.data[index], self.target[index] 263 | img = Image.fromarray(img) 264 | # print("cifar10 img:", img) 265 | # print("cifar10 target:", target) 266 | 267 | if self.transform is not None: 268 | img = self.transform(img) 269 | 270 | if self.target_transform is not None: 271 | target = self.target_transform(target) 272 | 273 | return img, target 274 | 275 | def __len__(self): 276 | return len(self.data) 277 | 278 | class TinyImageNet_load(Dataset): 279 | def __init__(self, root, dataidxs=None, train=True, transform=None): 280 | self.Train = train 281 | self.root_dir = root 282 | self.transform = transform 283 | self.train_dir = os.path.join(self.root_dir, "train") 284 | self.val_dir = os.path.join(self.root_dir, "val") 285 | self.dataidxs = dataidxs 286 | if (self.Train): 287 | self._create_class_idx_dict_train() 288 | else: 289 | self._create_class_idx_dict_val() 290 | 291 | self._make_dataset(self.Train) 292 | 293 | words_file = os.path.join(self.root_dir, "words.txt") 294 | wnids_file = os.path.join(self.root_dir, "wnids.txt") 295 | 296 | self.set_nids = set() 297 | 298 | 299 | if self.dataidxs is not None: 300 | self.samples = self.images[dataidxs] 301 | else: 302 | self.samples = self.images 303 | 304 | # print('samples.shape', self.samples.shape) 305 | with open(wnids_file, 'r') as fo: 306 | data = fo.readlines() 307 | for entry in data: 308 | self.set_nids.add(entry.strip("\n")) 309 | 310 | self.class_to_label = {} 311 | with open(words_file, 'r') as fo: 312 | data = fo.readlines() 313 | for entry in data: 314 | words = entry.split("\t") 315 | if words[0] in self.set_nids: 316 | self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0] 317 | 318 | def _create_class_idx_dict_train(self): 319 | if sys.version_info >= (3, 5): 320 | classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()] 321 | else: 322 | classes = [d for d in os.listdir(self.train_dir) if os.path.isdir(os.path.join(self.train_dir, d))] 323 | classes = sorted(classes) 324 | num_images = 0 325 | for root, dirs, files in os.walk(self.train_dir): 326 | for f in files: 327 | if f.endswith(".JPEG"): 328 | num_images = num_images + 1 329 | 330 | self.len_dataset = num_images; 331 | 332 | self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))} 333 | self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))} 334 | # print(self.tgt_idx_to_class) 335 | def _create_class_idx_dict_val(self): 336 | val_image_dir = os.path.join(self.val_dir, "images") 337 | if sys.version_info >= (3, 5): 338 | images = [d.name for d in os.scandir(val_image_dir) if d.is_file()] 339 | else: 340 | images = [d for d in os.listdir(val_image_dir) if os.path.isfile(os.path.join(val_image_dir, d))] 341 | val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt") 342 | self.val_img_to_class = {} 343 | set_of_classes = set() 344 | with open(val_annotations_file, 'r') as fo: 345 | entry = fo.readlines() 346 | for data in entry: 347 | words = data.split("\t") 348 | self.val_img_to_class[words[0]] = words[1] 349 | set_of_classes.add(words[1]) 350 | 351 | self.len_dataset = len(list(self.val_img_to_class.keys())) 352 | classes = sorted(list(set_of_classes)) 353 | # self.idx_to_class = {i:self.val_img_to_class[images[i]] for i in range(len(images))} 354 | self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))} 355 | self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))} 356 | 357 | def _make_dataset(self, Train=True): 358 | self.images = [] 359 | if 
Train: 360 | img_root_dir = self.train_dir 361 | list_of_dirs = [target for target in self.class_to_tgt_idx.keys()] 362 | else: 363 | img_root_dir = self.val_dir 364 | list_of_dirs = ["images"] 365 | 366 | for tgt in list_of_dirs: 367 | dirs = os.path.join(img_root_dir, tgt) 368 | if not os.path.isdir(dirs): 369 | continue 370 | 371 | for root, _, files in sorted(os.walk(dirs)): 372 | for fname in sorted(files): 373 | if (fname.endswith(".JPEG")): 374 | path = os.path.join(root, fname) 375 | if Train: 376 | item = (path, self.class_to_tgt_idx[tgt]) 377 | else: 378 | item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]]) 379 | self.images.append(item) 380 | self.images = np.array(self.images) 381 | # print('dataset.shape', self.images.shape) 382 | def return_label(self, idx): 383 | return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx] 384 | 385 | 386 | def __len__(self): 387 | return self.samples.shape[0] 388 | 389 | def __getitem__(self, idx): 390 | img_path, tgt = self.samples[idx] 391 | with open(img_path, 'rb') as f: 392 | sample = Image.open(img_path) 393 | sample = sample.convert('RGB') 394 | if self.transform is not None: 395 | sample = self.transform(sample) 396 | tgt = int(tgt) 397 | return sample, tgt 398 | 399 | 400 | class ImageFolder_custom(DatasetFolder): 401 | def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None): 402 | self.root = root 403 | self.dataidxs = dataidxs 404 | self.train = train 405 | self.transform = transform 406 | self.target_transform = target_transform 407 | 408 | imagefolder_obj = ImageFolder(self.root, self.transform, self.target_transform) 409 | self.loader = imagefolder_obj.loader 410 | if self.dataidxs is not None: 411 | self.samples = np.array(imagefolder_obj.samples)[self.dataidxs] 412 | else: 413 | self.samples = np.array(imagefolder_obj.samples) 414 | 415 | def __getitem__(self, index): 416 | path = self.samples[index][0] 417 | target = self.samples[index][1] 418 | target = int(target) 419 | sample = self.loader(path) 420 | if self.transform is not None: 421 | sample = self.transform(sample) 422 | if self.target_transform is not None: 423 | target = self.target_transform(target) 424 | 425 | return sample, target 426 | 427 | def __len__(self): 428 | if self.dataidxs is None: 429 | return len(self.samples) 430 | else: 431 | return len(self.dataidxs) 432 | 433 | 434 | class Food101_truncated(torch.utils.data.Dataset): 435 | def __init__(self, dataidxs=None, transform=None, loader=default_loader, mode = None): 436 | 437 | image_path = '/Food101_Image/images/' 438 | data_path = '/Food101_Text/' 439 | if mode == 'train': 440 | with io.open(data_path + 'train_images.txt', encoding='utf-8') as file: 441 | path_to_images = file.read().split('\n')[:-1] #list-len:68175 442 | with io.open(data_path + 'train_labels.txt', encoding='utf-8') as file: 443 | labels = file.read().split('\n')[:-1] #list-len:68175 444 | 445 | elif mode == 'test': 446 | 447 | with io.open(data_path + 'test_images.txt', encoding='utf-8') as file: 448 | path_to_images = file.read().split('\n')[:-1] #list-len:25250 449 | with io.open(data_path + 'test_labels.txt', encoding='utf-8') as file: 450 | labels = file.read().split('\n')[:-1] #list-len:25250 451 | 452 | elif mode == 'val': 453 | with io.open(data_path + 'val_images.txt', encoding='utf-8') as file: 454 | path_to_images = file.read().split('\n')[:-1] 455 | with io.open(data_path + 'val_labels.txt', encoding='utf-8') as file: 456 | labels = 
file.read().split('\n')[:-1] 457 | 458 | else: 459 | assert 1<0, 'Please fill mode with any of train/val/test to facilitate dataset creation' 460 | 461 | #import ipdb; ipdb.set_trace() 462 | if mode == 'train' and dataidxs != None: 463 | # print('xxxxxxxxx', path_to_images) 464 | self.image_path = image_path 465 | # ipdb.set_trace() 466 | self.path_to_images = np.array(path_to_images)[dataidxs] 467 | self.labels = np.array(labels, dtype=int)[dataidxs] 468 | print('mode:', mode, 'len(path_to_images):', len(self.path_to_images)) 469 | 470 | else: 471 | self.image_path = image_path 472 | self.path_to_images = path_to_images 473 | self.labels = np.array(labels, dtype=int) 474 | 475 | self.transform = transform 476 | self.loader = loader 477 | self.mode = mode 478 | 479 | 480 | def __getitem__(self, index): 481 | # get image matrix and transform to tensor 482 | path = self.path_to_images[index] 483 | img = self.loader(self.image_path + path + '.jpg') 484 | 485 | if self.transform is not None: 486 | img = self.transform(img) 487 | 488 | # get label 489 | label = self.labels[index] 490 | 491 | return img, label 492 | 493 | 494 | def __len__(self): 495 | return len(self.path_to_images) 496 | 497 | 498 | class Vireo172_truncated(torch.utils.data.Dataset): 499 | def __init__(self, dataidxs=None, transform=None, loader=default_loader, mode = None): 500 | 501 | image_path = '/Vireo172_Image/ready_chinese_food/' 502 | data_path = '/Vireo172_Text/SplitAndIngreLabel/' 503 | 504 | if mode == 'train': 505 | 506 | with io.open(data_path + 'TR.txt', encoding='utf-8') as file: 507 | path_to_images = file.read().split('\n')[:-1] 508 | labels = matio.loadmat(data_path + 'train_label.mat')['train_label'][0] 509 | 510 | elif mode == 'test': 511 | 512 | with io.open(data_path + 'TE.txt', encoding='utf-8') as file: 513 | path_to_images = file.read().split('\n')[:-1] 514 | labels = matio.loadmat(data_path + 'test_label.mat')['test_label'][0] 515 | 516 | elif mode == 'val': 517 | 518 | with io.open(data_path + 'VAL.txt', encoding='utf-8') as file: 519 | path_to_images = file.read().split('\n')[:-1] 520 | labels = matio.loadmat(data_path + 'val_label.mat')['validation_label'][0] 521 | 522 | else: 523 | assert 1<0, 'Please fill mode with any of train/val/test to facilitate dataset creation' 524 | 525 | #import ipdb; ipdb.set_trace() 526 | 527 | if mode == 'train' and dataidxs != None: 528 | # print('xxxxxxxxx', path_to_images) 529 | self.image_path = image_path 530 | self.path_to_images = np.array(path_to_images)[dataidxs] 531 | self.labels = np.array(labels, dtype=int)[dataidxs]-1 532 | print('mode:', mode, 'len(path_to_images):', len(self.path_to_images)) 533 | 534 | else: 535 | self.image_path = image_path 536 | self.path_to_images = path_to_images 537 | self.labels = np.array(labels, dtype=int)-1 538 | 539 | 540 | self.transform = transform 541 | self.loader = loader 542 | 543 | 544 | def __getitem__(self, index): 545 | # get image matrix and transform to tensor 546 | path = self.path_to_images[index] 547 | 548 | img = self.loader(self.image_path + path) 549 | 550 | if self.transform is not None: 551 | img = self.transform(img) 552 | 553 | # get label 554 | label = self.labels[index] 555 | 556 | #change vireo labels from 1-indexed to 0-indexed values 557 | 558 | 559 | return [img, label] 560 | 561 | def __len__(self): 562 | return len(self.path_to_images) 563 | 564 | class Data_for_label_cluster(torch.utils.data.Dataset): 565 | def __init__(self, data, labels, dataset): 566 | self.vectors = data 567 | self.classIDs 
= labels 568 | self.dataset = dataset 569 | 570 | def __getitem__(self, index): 571 | # get hidden vector 572 | vector = self.vectors[index] 573 | classID = self.classIDs[index] 574 | 575 | # ipdb.set_trace() 576 | if self.dataset == 'cifar10': 577 | label_indicator = np.zeros([10], dtype=np.float32) 578 | elif self.dataset == 'cifar100': 579 | label_indicator = np.zeros([100], dtype=np.float32) 580 | label_indicator[int(classID)] = 1 581 | # print(label_indicator) 582 | return [vector, label_indicator, index] 583 | 584 | def __len__(self): 585 | return len(self.classIDs) 586 | 587 | 588 | class Data_for_Retraining(torch.utils.data.Dataset): 589 | def __init__(self, data, labels, clientids): 590 | self.vectors = data 591 | self.classIDs = labels 592 | self.clientids = clientids 593 | 594 | def __getitem__(self, index): 595 | # get hidden vector 596 | vector = self.vectors[index] 597 | classID = self.classIDs[index] 598 | in_client_inx = np.where(self.clientids == self.clientids[index])[0] 599 | posi_data_index = self.__posi_sample__(index, in_client_inx) 600 | nega_data_index = self.__nega_sample__(index, in_client_inx) 601 | 602 | posi_vector = self.vectors[posi_data_index] 603 | nega_vector = self.vectors[nega_data_index] 604 | 605 | return vector, posi_vector, nega_vector, classID 606 | 607 | 608 | def __posi_sample__(self, index, in_client_inx, k=1): 609 | target = self.classIDs[index] 610 | posi_idx = np.where(self.classIDs == target)[0] 611 | posi_inx_intersection = np.intersect1d(posi_idx, in_client_inx) 612 | 613 | if len(posi_inx_intersection) >= k: 614 | posi_data_index = np.random.choice(posi_inx_intersection, k) 615 | else: 616 | posi_data_index = [index] 617 | 618 | return posi_data_index 619 | 620 | def __nega_sample__(self, index, in_client_inx, k=1): 621 | target = self.classIDs[index] 622 | nega_idx = np.where(self.classIDs != target)[0] 623 | nega_inx_intersection = np.intersect1d(nega_idx, in_client_inx) 624 | 625 | if len(nega_inx_intersection) >= k: 626 | nega_data_index = np.random.choice(nega_inx_intersection, k) 627 | else: 628 | nega_data_index = [index] 629 | 630 | return nega_data_index 631 | 632 | 633 | def __len__(self): 634 | return len(self.classIDs) 635 | 636 | 637 | 638 | class Data_for_Retraining_final(torch.utils.data.Dataset): 639 | def __init__(self, data, labels, clientids): 640 | self.vectors = data 641 | self.classIDs = labels 642 | self.clientids = clientids 643 | 644 | def __getitem__(self, index): 645 | # get hidden vector 646 | vector = self.vectors[index] 647 | classID = self.classIDs[index] 648 | in_client_inx = np.where(self.clientids == self.clientids[index])[0] 649 | posi_data_index = self.__posi_sample__(index, in_client_inx) 650 | nega_data_index = self.__nega_sample__(index, in_client_inx) 651 | 652 | posi_vector = self.vectors[posi_data_index] 653 | nega_vector = self.vectors[nega_data_index] 654 | 655 | return torch.tensor(vector), torch.tensor(posi_vector), torch.tensor(nega_vector), torch.tensor([classID]) 656 | 657 | 658 | def __posi_sample__(self, index, in_client_inx, k=1): 659 | target = self.classIDs[index] 660 | posi_idx = np.where(self.classIDs == target)[0] 661 | posi_inx_intersection = np.intersect1d(posi_idx, in_client_inx) 662 | 663 | if len(posi_idx) >= k: 664 | posi_data_index = np.random.choice(posi_inx_intersection, k) 665 | else: 666 | posi_data_index = [index] 667 | 668 | return posi_data_index 669 | 670 | def __nega_sample__(self, index, in_client_inx, k=1): 671 | target = self.classIDs[index] 672 | nega_idx = 
np.where(self.classIDs != target)[0] 673 | nega_inx_intersection = np.intersect1d(nega_idx, in_client_inx) 674 | 675 | if len(nega_inx_intersection) >= k:  # sample only when the within-client intersection is non-empty 676 | nega_data_index = np.random.choice(nega_inx_intersection, k) 677 | else: 678 | nega_data_index = [index] 679 | 680 | return nega_data_index 681 | 682 | 683 | def __len__(self): 684 | return len(self.classIDs) 685 | 686 | 687 | 688 | 689 | 690 | 691 | 692 | -------------------------------------------------------------------------------- /framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qizhuang-qz/FedCSPC/e1630bf8fb4773be71ac5a471b57dd885ea29919/framework.png -------------------------------------------------------------------------------- /kmeans.py: -------------------------------------------------------------------------------- 1 | # kmeans clustering and assigning sample weight based on cluster information 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from sklearn.cluster import KMeans 6 | import logging 7 | import os 8 | import random 9 | import torch 10 | import time 11 | from tqdm import tqdm 12 | from sklearn.manifold import TSNE 13 | 14 | class KMEANS: 15 | def __init__(self, n_clusters, max_iter, device=torch.device("cpu")): 16 | 17 | self.n_clusters = n_clusters 18 | self.labels = None 19 | self.dists = None # shape: [x.shape[0],n_cluster] 20 | self.centers = None 21 | self.max_iter = max_iter 22 | self.count = 0 23 | self.device = device 24 | 25 | def fit(self, x): 26 | # randomly select the initial centers; for faster convergence, the kmeans++ initialization from sklearn could be borrowed 27 | init_row = torch.randint(0, x.shape[0], (self.n_clusters,)).to(self.device) 28 | init_points = torch.tensor(x[init_row.cpu().numpy().astype(int)]) 29 | self.centers = init_points 30 | while True: 31 | # print(self.count) 32 | # assign cluster labels 33 | self.nearest_center(x) 34 | # update the cluster centers 35 | self.update_center(x) 36 | 37 | if self.count == self.max_iter: 38 | break 39 | 40 | self.count += 1 41 | return self.labels 42 | 43 | def nearest_center(self, x): 44 | labels = torch.empty((x.shape[0],)).long().to(self.device) 45 | dists = torch.empty((0, self.n_clusters)).to(self.device) 46 | x = torch.tensor(x) 47 | for i, sample in enumerate(x): 48 | dist = torch.sum(torch.mul(sample - self.centers, sample - self.centers), (1)) 49 | labels[i] = torch.argmin(dist) 50 | dists = torch.cat([dists, dist.unsqueeze(0)], (0)) 51 | self.labels = labels 52 | self.dists = dists 53 | 54 | def update_center(self, x): 55 | centers = torch.empty((0, x.shape[1])).to(self.device) 56 | x = torch.tensor(x) 57 | for i in range(self.n_clusters): 58 | mask = self.labels == i 59 | cluster_samples = x[mask] 60 | 61 | # print('cluster_samples', cluster_samples.shape) 62 | # print('centers', centers.shape) 63 | 64 | if len(cluster_samples.shape) == 1: 65 | if cluster_samples.shape[0] == 0: 66 | centers = torch.cat([centers, self.centers[i].unsqueeze(0)], (0)) 67 | else: 68 | cluster_samples.reshape((-1, cluster_samples.shape[0])) 69 | else: 70 | centers = torch.cat([centers, torch.mean(cluster_samples, (0)).unsqueeze(0)], (0)) 71 | self.centers = centers 72 | 73 | def normalization(data): 74 | _range = np.max(data) - np.min(data) 75 | return (data - np.min(data)) / _range 76 | 77 | 78 | def standardization(data): 79 | mu = np.mean(data, axis=0) 80 | sigma = np.std(data, axis=0) 81 | return (data - mu) / sigma 82 | 83 | 84 | if __name__ == "__main__": 85 | seed = 1 86 | i = 2 87 | round = 24 88 | # np.random.seed(seed) 89 | # torch.manual_seed(seed) 90 | # if
torch.cuda.is_available(): 91 | # torch.cuda.manual_seed(seed) 92 | # random.seed(seed) 93 | cluster_mode = 1 94 | c_list_7 = ['skyblue', 'lightpink', 'chocolate', 'silver', 'violet'] 95 | c_list_2 = ['cornflowerblue', 'brown', 'orange', 'forestgreen', 'purple'] 96 | 97 | c_list_1 = [] 98 | 99 | # 黑色、红色、橘色、巧克力色、绿色、粉色、灰色、蓝色、黄色、黄绿色 100 | ts = TSNE(n_components=2, init='pca', random_state=50, perplexity=100) # , metric='cosine' 101 | font1 = { 102 | 'weight': 'normal', 103 | 'size': 30, 104 | } 105 | if cluster_mode: 106 | N = 4 107 | M = 50 108 | number = 3 109 | beforepath = './feats/cnn/cifar10/' + str(round) + '/' + str(i) + '/case_feats.npy' 110 | labelpath = './feats/cnn/cifar10/' + str(round) + '/' + str(i) + '/case_labels.npy' 111 | before = np.load(beforepath) 112 | label = np.load(labelpath) 113 | print(before.shape) 114 | 115 | class_idx_7 = np.where(label == 3)[0] 116 | class_idx_2 = np.where(label == 7)[0] 117 | # class_idx_1 = np.where(label == 9)[0] 118 | 119 | kmeans_7 = KMEANS(n_clusters=N, max_iter=M) 120 | predict_labels_7 = kmeans_7.fit(before[class_idx_7]) 121 | 122 | kmeans_2 = KMEANS(n_clusters=N, max_iter=M) 123 | predict_labels_2 = kmeans_2.fit(before[class_idx_2]) 124 | 125 | 126 | data = np.concatenate([before[class_idx_7], before[class_idx_2]]) 127 | print(data.shape) 128 | cluster_7_set, unq_cluster_7_size = np.unique(predict_labels_7, return_counts=True) 129 | cluster_2_set, unq_cluster_2_size = np.unique(predict_labels_2, return_counts=True) 130 | # cluster_1_set, unq_cluster_1_size = np.unique(predict_labels_1, return_counts=True) 131 | print(cluster_7_set, cluster_2_set) 132 | cluster_2_set = cluster_2_set + N 133 | 134 | # data_tsne = ts.fit_transform(data) 135 | # data_tsne = normalization(data_tsne) 136 | data_tsne = np.load('./feats/cnn/cifar10/' + str(round) + '/' + str(i) + '/tsne_feats.npy') 137 | 138 | 139 | 140 | t = 0 141 | for i in range(len(class_idx_7)): 142 | k = np.where(cluster_7_set == int(predict_labels_7[i]))[0][0] 143 | plt.scatter(data_tsne[i][0], data_tsne[i][1], marker=',', c=c_list_7[k], s=20) 144 | t += 1 145 | 146 | 147 | for i in range(len(class_idx_2)): 148 | k = np.where(cluster_2_set == int(predict_labels_2[i])+N)[0][0] 149 | plt.scatter(data_tsne[t + i][0], data_tsne[t + i][1], marker='^', c=c_list_2[k], s=20) 150 | 151 | assign_7 = predict_labels_7 152 | assign_2 = predict_labels_2 + N 153 | assign = np.concatenate([assign_7, assign_2]) 154 | print(cluster_2_set) 155 | 156 | for j in cluster_7_set: 157 | idx_j = np.where(assign == j)[0] 158 | print(idx_j) 159 | for A in range(number): # len(class_idx[i]) 160 | idx = np.random.choice(np.arange(len(idx_j)), int(len(idx_j)*0.05)) #int(len(idx_j) * 0.2 161 | feature_classwise = np.mean(data_tsne[idx_j[idx]], axis=0) 162 | print(feature_classwise.shape) 163 | plt.scatter(feature_classwise[0], feature_classwise[1], marker=',', c='red', s=120) 164 | 165 | 166 | 167 | for j in cluster_2_set: 168 | idx_j = np.where(assign == j)[0] 169 | print(idx_j) 170 | for A in range(number): # len(class_idx[i]) 171 | idx = np.random.choice(np.arange(len(idx_j)), int(len(idx_j)*0.05)) 172 | feature_classwise = np.mean(data_tsne[idx_j[idx]], axis=0) 173 | plt.scatter(feature_classwise[0], feature_classwise[1], marker='^', c='red', s=120) 174 | 175 | # print(data.shape) 176 | # 177 | # 178 | # m = data.shape[0]-len(class_idx_2)-len(class_idx_7) 179 | # for i in range(m): 180 | # if i 0.2, weight_vector < 0.5)).reshape(-1) 241 | 242 | weight_index_hard = torch.nonzero(weight_vector < 0.2).reshape(-1) 243 
| 244 | label_easy = label[np.array(weight_index_easy)] 245 | label_med = label[np.array(weight_index_med)] 246 | label_hard = label[np.array(weight_index_hard)] 247 | 248 | sta_easy = [] 249 | sta_med = [] 250 | sta_hard = [] 251 | for i in range(101): 252 | easy_id = np.where(label_easy == i)[0] 253 | med_id = np.where(label_med == i)[0] 254 | hard_id = np.where(label_hard == i)[0] 255 | sta_easy.append(len(easy_id)) 256 | sta_med.append(len(med_id)) 257 | sta_hard.append(len(hard_id)) 258 | 259 | import ipdb; 260 | 261 | ipdb.set_trace() 262 | 263 | np.save('./weights/food101/weight_index_easy.npy', weight_index_easy) 264 | np.save('./weights/food101/weight_index_med.npy', weight_index_med) 265 | np.save('./weights/food101/weight_index_hard.npy', weight_index_hard) 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import torch 5 | import torchvision.transforms as transforms 6 | import torch.utils.data as data 7 | from torch.autograd import Variable 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import random 11 | from sklearn.metrics import confusion_matrix 12 | from torchvision import datasets 13 | 14 | import ipdb 15 | 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | def supervised_contrastive_loss(features, labels, temperature=0.07): 21 | """ 22 | Implementation of a supervised (label-aware) contrastive loss. 23 | 24 | Args: 25 | - features: tensor of shape (batch_size, embedding_size) holding the input feature vectors. 26 | - labels: tensor of shape (batch_size,) holding the sample labels. 27 | - temperature: temperature parameter. 28 | 29 | Returns: 30 | - loss: the contrastive loss. 31 | """ 32 | 33 | # L2-normalize the feature vectors 34 | features = F.normalize(features, dim=1) 35 | 36 | # compute the similarity matrix over all samples 37 | similarity_matrix = torch.matmul(features, features.T) / temperature 38 | 39 | # mask out the diagonal so a sample is not compared with itself 40 | mask = torch.eye(labels.size(0), dtype=torch.bool).cuda() 41 | similarity_matrix = similarity_matrix.masked_fill(mask, 1) 42 | # ipdb.set_trace() 43 | # build the positive-pair and negative-pair masks for every sample 44 | pos_pairs_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).bool() 45 | neg_pairs_mask = ~pos_pairs_mask 46 | 47 | # contrastive loss over the positive pairs 48 | pos_pairs_similarity = similarity_matrix[pos_pairs_mask] 49 | pos_pairs_loss = -torch.log(pos_pairs_similarity / torch.sum(similarity_matrix)) 50 | 51 | # contrastive loss over the negative pairs 52 | neg_pairs_similarity = similarity_matrix[neg_pairs_mask].view(labels.size(0), -1) 53 | neg_pairs_loss = -torch.log(torch.sum(torch.exp(neg_pairs_similarity), dim=1) / torch.sum(neg_pairs_mask, dim=1)) 54 | 55 | # average the contrastive loss over all pairs 56 | loss = torch.mean(torch.cat([pos_pairs_loss, neg_pairs_loss])) 57 | 58 | return loss 59 | 60 | 61 | 62 | 63 | 64 | class SupervisedContrastiveLoss(torch.nn.Module): 65 | def __init__(self, temperature=0.07): 66 | super(SupervisedContrastiveLoss, self).__init__() 67 | self.temperature = temperature 68 | 69 | def forward(self, x, y): 70 | # x: the feature representations of the samples 71 | # y: the ground truth labels 72 | 73 | # normalize the feature vectors 74 | x = F.normalize(x, dim=1) 75 | 76 | # compute the similarity matrix 77 | sim_matrix = torch.matmul(x, x.t()) / self.temperature 78 | 79 | # generate the mask for positive and
negative pairs 80 | mask = torch.eq(y.unsqueeze(0), y.unsqueeze(1)).float() 81 | mask = mask / mask.sum(dim=1, keepdim=True) 82 | 83 | # calculate the contrastive loss 84 | loss = (-torch.log_softmax(sim_matrix, dim=1) * mask).sum(dim=1).mean() 85 | 86 | return loss 87 | 88 | 89 | def nt_xent(x1, x2, t=0.07): 90 | """Contrastive loss objective function""" 91 | x1 = F.normalize(x1, dim=1) 92 | x2 = F.normalize(x2, dim=1) 93 | batch_size = x1.size(0) 94 | out = torch.cat([x1, x2], dim=0) 95 | sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / t) 96 | mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool() 97 | sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1) 98 | pos_sim = torch.exp(torch.sum(x1 * x2, dim=-1) / t) 99 | pos_sim = torch.cat([pos_sim, pos_sim], dim=0) 100 | loss = (-torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean() 101 | return loss 102 | 103 | 104 | def PCLoss(features, f_labels, prototypes, p_labels, t=0.5): 105 | 106 | a_norm = features / features.norm(dim=1)[:, None] 107 | b_norm = prototypes / prototypes.norm(dim=1)[:, None] 108 | sim_matrix = torch.exp(torch.mm(a_norm, b_norm.transpose(0,1)) / t) 109 | 110 | pos_sim = torch.exp(torch.diag(torch.mm(a_norm, b_norm[f_labels].transpose(0,1))) / t) 111 | 112 | loss = (-torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean() 113 | 114 | return loss 115 | 116 | def refine_as_not_true(logits, targets, num_classes): 117 | nt_positions = torch.arange(0, num_classes).to(logits.device) 118 | nt_positions = nt_positions.repeat(logits.size(0), 1) 119 | nt_positions = nt_positions[nt_positions[:, :] != targets.view(-1, 1)] 120 | nt_positions = nt_positions.view(-1, num_classes - 1) 121 | 122 | logits = torch.gather(logits, 1, nt_positions) 123 | 124 | return logits 125 | 126 | class NTD_Loss(nn.Module): 127 | """Not-true Distillation Loss""" 128 | 129 | def __init__(self, num_classes=10, tau=3, lamb=1): 130 | super(NTD_Loss, self).__init__() 131 | self.CE = nn.CrossEntropyLoss() 132 | self.MSE = nn.MSELoss() 133 | self.KLDiv = nn.KLDivLoss(reduction="batchmean") 134 | self.num_classes = num_classes 135 | self.tau = tau 136 | self.beta = lamb 137 | 138 | def forward(self, logits, targets, dg_logits): 139 | ce_loss = self.CE(logits, targets) 140 | ntd_loss = self._ntd_loss(logits, dg_logits, targets) 141 | 142 | loss = ce_loss + self.beta * ntd_loss 143 | 144 | return loss 145 | 146 | def _ntd_loss(self, logits, dg_logits, targets): 147 | """Not-tue Distillation Loss""" 148 | 149 | # Get smoothed local model prediction 150 | logits = refine_as_not_true(logits, targets, self.num_classes) 151 | pred_probs = F.log_softmax(logits / self.tau, dim=1) 152 | 153 | # Get smoothed global model prediction 154 | with torch.no_grad(): 155 | dg_logits = refine_as_not_true(dg_logits, targets, self.num_classes) 156 | dg_probs = torch.softmax(dg_logits / self.tau, dim=1) 157 | 158 | loss = (self.tau ** 2) * self.KLDiv(pred_probs, dg_probs) 159 | 160 | return loss 161 | 162 | class CrossEntropyLoss(torch.nn.Module): 163 | def __init__(self, reduction='mean'): 164 | super(CrossEntropyLoss, self).__init__() 165 | self.reduction = reduction 166 | 167 | def forward(self, logits, target, weights): 168 | # logits: [N, C, H, W], target: [N, H, W] 169 | # loss = sum(-y_i * log(c_i)) 170 | 171 | if logits.dim() > 2: 172 | logits = logits.view(logits.size(0), logits.size(1), -1) # [N, C, HW] 173 | logits = logits.transpose(1, 2) # [N, HW, C] 174 | logits = logits.contiguous().view(-1, 
logits.size(2)) # [NHW, C] 175 | target = target.view(-1, 1) # [NHW,1] 176 | 177 | logits = F.log_softmax(logits, 1) 178 | 179 | # import ipdb; ipdb.set_trace() 180 | 181 | logits = logits.gather(1, target).reshape(-1) # [NHW, 1] 182 | 183 | loss = -1 * (torch.mul(logits, weights)) 184 | 185 | if self.reduction == 'mean': 186 | loss = loss.mean() 187 | elif self.reduction == 'sum': 188 | loss = loss.sum() 189 | return torch.sum(loss, 0) 190 | 191 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 192 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 193 | 194 | 195 | 196 | class TripletLoss(nn.Module): 197 | def __init__(self, margin=1.0): 198 | super(TripletLoss, self).__init__() 199 | self.margin = margin 200 | 201 | def forward(self, anchor, positive, negative): 202 | dist_pos = torch.norm(anchor - positive, 2, dim=1) 203 | dist_neg = torch.norm(anchor - negative, 2, dim=1) 204 | loss = torch.mean(torch.clamp(dist_pos - dist_neg + self.margin, min=0.0)) 205 | return loss 206 | -------------------------------------------------------------------------------- /main_CSPC.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import torch 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import argparse 7 | import logging 8 | import os 9 | import copy 10 | import datetime 11 | import random 12 | # os.environ['CUDA_VISIBLE_DEVICES']='0' 13 | from resnet import resnet18, DNN 14 | from model import ModelFedCon 15 | from utils import * 16 | from loss import * 17 | from re_training import * 18 | import torch.nn.functional as F 19 | from RGA import * 20 | import ipdb 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--model', type=str, default='resnet18', help='neural network used in training') 25 | parser.add_argument('--dataset', type=str, default='cifar10', help='dataset used for training') 26 | parser.add_argument('--net_config', type=lambda x: list(map(int, x.split(', ')))) 27 | parser.add_argument('--partition', type=str, default='noniid', help='the data partitioning strategy') 28 | parser.add_argument('--batch-size', type=int, default=64, help='input batch size for training (default: 64)') 29 | parser.add_argument('--lr', type=float, default=0.01, help='learning rate (default: 0.01)') 30 | parser.add_argument('--epochs', type=int, default=10, help='number of local epochs') 31 | parser.add_argument('--n_parties', type=int, default=10, help='number of workers in a distributed cluster') 32 | parser.add_argument('--alg', type=str, default='fedavg', 33 | help='communication strategy: fedavg/fedprox') 34 | parser.add_argument('--comm_round', type=int, default=100, help='number of maximum communication rounds') 35 | parser.add_argument('--init_seed', type=int, default=0, help="Random seed") 36 | parser.add_argument('--dropout_p', type=float, required=False, default=0.0, help="Dropout probability. 
Default=0.0") 37 | parser.add_argument('--datadir', type=str, required=False, default="./data/", help="Data directory") 38 | parser.add_argument('--reg', type=float, default=1e-5, help="L2 regularization strength") 39 | parser.add_argument('--logdir', type=str, required=False, default="./logs/N+E/cifar10/", help='Log directory path') 40 | parser.add_argument('--modeldir', type=str, required=False, default="./models/", help='Model directory path') 41 | parser.add_argument('--beta', type=float, default=0.5, 42 | help='The parameter for the dirichlet distribution for data partitioning') 43 | 44 | parser.add_argument('--mu', type=float, default=0.5, 45 | help='The parameter for the weight of RGA loss') 46 | parser.add_argument('--node_weight', default=1.0, type=float, help='The weight of node-based loss') 47 | parser.add_argument('--edge_weight', default=0.2, type=float, help='The weight of edge-based loss') 48 | parser.add_argument('--angle_weight', default=0.2, type=float, help='The weight of angle-based loss') 49 | parser.add_argument('--mode', default='N', type=str, help='The loss mode') 50 | parser.add_argument('--device', type=str, default='cuda:0', help='The device to run the program') 51 | parser.add_argument('--log_file_name', type=str, default=None, help='The log file name') 52 | parser.add_argument('--optimizer', type=str, default='sgd', help='the optimizer') 53 | parser.add_argument('--out_dim', type=int, default=256, help='the output dimension for the projection layer') 54 | parser.add_argument('--temperature', type=float, default=0.5, help='the temperature parameter for contrastive loss') 55 | parser.add_argument('--local_max_epoch', type=int, default=100, 56 | help='the number of epoch for local optimal training') 57 | parser.add_argument('--model_buffer_size', type=int, default=1, 58 | help='store how many previous models for contrastive loss') 59 | parser.add_argument('--pool_option', type=str, default='FIFO', help='FIFO or BOX') 60 | parser.add_argument('--sample_fraction', type=float, default=1, help='how many clients are sampled in each round') 61 | parser.add_argument('--load_model_file', type=str, default=None, help='the model to load as global model') 62 | parser.add_argument('--load_pool_file', type=str, default=None, help='the old model pool path to load') 63 | parser.add_argument('--load_model_round', type=int, default=None, help='how many rounds have executed for the loaded model') 64 | parser.add_argument('--load_first_net', type=int, default=1, help='whether load the first net as old net or not') 65 | parser.add_argument('--normal_model', type=int, default=0, help='use normal model or aggregate model') 66 | parser.add_argument('--loss', type=str, default='contrastive') 67 | parser.add_argument('--save_model', type=int, default=0) 68 | parser.add_argument('--use_project_head', type=int, default=1) 69 | parser.add_argument('--ratio', type=float, default=0.8) 70 | parser.add_argument('--number', type=int, default=5) 71 | parser.add_argument('--re_mu', type=float, default=0.1) 72 | parser.add_argument('--re_beta', type=float, default=0.1) 73 | parser.add_argument('--temp_final', type=float, default=0.5) 74 | parser.add_argument('--re_version', type=str, default='v1') 75 | parser.add_argument('--final_weights', type=float, default=0.1) 76 | parser.add_argument('--posi_lambda', type=float, default=0.5) 77 | parser.add_argument('--nega_lambda', type=float, default=0.5) 78 | parser.add_argument('--server_momentum', type=float, default=0, help='the server momentum 
(FedAvgM)') 79 | args = parser.parse_args() 80 | return args 81 | 82 | 83 | 84 | 85 | def init_nets(net_configs, n_parties, args, n_classes, device='cuda:0'): 86 | nets = {net_i: None for net_i in range(n_parties)} 87 | 88 | for net_i in range(n_parties): 89 | if 'cifar' in args.dataset: 90 | net = ModelFedCon(args.model, args.out_dim, n_classes, net_configs) 91 | net.to(device) 92 | nets[net_i] = net 93 | if args.model == 'simple-cnn': 94 | dnn = DNN(input_dim=84, hidden_dims=[84, 256], n_classes=n_classes).to(device) 95 | elif args.model == 'resnet18': 96 | dnn = DNN(input_dim=512, hidden_dims=[512, 256], n_classes=n_classes).to(device) 97 | else: 98 | # ipdb.set_trace() 99 | if args.dataset == 'vireo172' or args.dataset == 'food101': 100 | net = resnet18(args.dataset, kernel_size=7, pretrained=False) 101 | nets[net_i] = net 102 | else: 103 | net = resnet18(args.dataset, kernel_size=3, pretrained=False) 104 | nets[net_i] = net 105 | dnn = DNN(input_dim=512, hidden_dims=[512, 512], n_classes=n_classes).to(device) 106 | model_meta_data = [] 107 | layer_type = [] 108 | for (k, v) in nets[0].state_dict().items(): 109 | model_meta_data.append(v.shape) 110 | layer_type.append(k) 111 | 112 | return nets, dnn, model_meta_data, layer_type 113 | 114 | 115 | 116 | 117 | 118 | def train_net_CSPC(net_id, net, train_dataloader, test_dataloader, anchors, epochs, lr, args_optimizer, args, n_classes=200, device="cuda:0"): 119 | 120 | net.to(device) 121 | 122 | logger.info('Training network %s' % str(net_id)) 123 | logger.info('n_training: %d' % len(train_dataloader)) 124 | logger.info('n_test: %d' % len(test_dataloader)) 125 | 126 | if args_optimizer == 'adam': 127 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg) 128 | elif args_optimizer == 'amsgrad': 129 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg, 130 | amsgrad=True) 131 | elif args_optimizer == 'sgd': 132 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=0.9, 133 | weight_decay=args.reg) 134 | 135 | criterion = nn.CrossEntropyLoss().to(device) 136 | 137 | cnt = 0 138 | net.train() 139 | 140 | anchors_label = torch.tensor(list(range(n_classes))).to(device) 141 | RGAloss = RGA_loss(node_weight=args.node_weight, edge_weight=args.edge_weight, angle_weight=args.angle_weight, 142 | t=args.temperature) 143 | 144 | for epoch in range(epochs): 145 | epoch_loss_collector = [] 146 | for batch_idx, (x, target) in enumerate(train_dataloader): 147 | # ipdb.set_trace() 148 | x, target = x.to(device), target.to(device) 149 | if args.dataset == 'pmnist': 150 | target = target.reshape(-1) 151 | optimizer.zero_grad() 152 | x.requires_grad = True 153 | target.requires_grad = False 154 | target = target.long() 155 | h, h_out, _, out = net(x) 156 | if round == 0: 157 | loss = criterion(out, target) + criterion(h_out, target) 158 | else: 159 | # ipdb.set_trace() 160 | loss = criterion(out, target) + criterion(h_out, target) + args.mu * RGAloss(h, target, anchors, anchors_label, 'N') 161 | loss.backward() 162 | optimizer.step() 163 | 164 | cnt += 1 165 | epoch_loss_collector.append(loss.item()) 166 | 167 | epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector) 168 | logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss)) 169 | logger.info(' ** Training complete **') 170 | 171 | net.to('cpu') 172 | 173 | return 0, 0 174 | 175 | 176 | def local_train_net(nets, args, net_dataidx_map, local_proto_list, 
local_proto_label_list, net_id_list, train_dl=None, test_dl=None, round=None, global_model=None, prev_model_pool=None, anchors=None, n_classes=None, device="cuda:0"): 177 | avg_acc = 0.0 178 | nets_global = copy.deepcopy(nets) 179 | k = 0 180 | for net_id, net in nets.items(): 181 | print(net_id) 182 | 183 | dataidxs = net_dataidx_map[net_id] 184 | 185 | logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs))) 186 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs) 187 | 188 | n_epoch = args.epochs 189 | 190 | if args.alg == 'CSPC': 191 | _, _ = train_net_CSPC(net_id, net, train_dl_local, test_dl, anchors, n_epoch, args.lr, args.optimizer, args, n_classes=n_classes, device=device) 192 | 193 | local_protos, local_labels = dropout_proto_local_v2(net, train_dl_local, args, n_class=n_classes) 194 | 195 | if k == 0: 196 | # ipdb.set_trace() 197 | clients_ids = torch.tensor([net_id] * local_labels.shape[0]) 198 | else: 199 | clients_ids = torch.cat([clients_ids, torch.tensor([net_id] * local_labels.shape[0])]) 200 | # print(local_proto_list) 201 | local_proto_list[net_id] = local_protos 202 | local_proto_label_list[net_id] = local_labels 203 | net_id_list[net_id] = 1 204 | k+=1 205 | return nets, local_proto_list, local_proto_label_list, net_id_list, clients_ids 206 | 207 | 208 | if __name__ == '__main__': 209 | args = get_args() 210 | mkdirs(args.logdir) 211 | mkdirs(args.modeldir) 212 | if args.log_file_name is None: 213 | argument_path = 'experiment_arguments-%s.json' % datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") 214 | else: 215 | argument_path = args.log_file_name + '.json' 216 | with open(os.path.join(args.logdir, argument_path), 'w') as f: 217 | json.dump(str(args), f) 218 | device = torch.device(args.device) 219 | for handler in logging.root.handlers[:]: 220 | logging.root.removeHandler(handler) 221 | 222 | if args.log_file_name is None: 223 | args.log_file_name = 'experiment_log-%s' % (datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")) 224 | log_path = args.log_file_name + '.log' 225 | logging.basicConfig( 226 | filename=os.path.join(args.logdir, log_path), 227 | format='%(asctime)s %(levelname)-8s %(message)s', 228 | datefmt='%m-%d %H:%M', level=logging.DEBUG, filemode='w') 229 | 230 | logger = logging.getLogger() 231 | logger.setLevel(logging.DEBUG) 232 | logger.info(device) 233 | 234 | seed = args.init_seed 235 | logger.info("#" * 100) 236 | np.random.seed(seed) 237 | torch.manual_seed(seed) 238 | if torch.cuda.is_available(): 239 | torch.cuda.manual_seed(seed) 240 | random.seed(seed) 241 | 242 | 243 | logger.info("Partitioning data") 244 | X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data( 245 | args.dataset, args.datadir, args.logdir, args.partition, args.n_parties, beta=args.beta) 246 | 247 | n_party_per_round = int(args.n_parties * args.sample_fraction) 248 | party_list = [i for i in range(args.n_parties)] 249 | global_party_list = [i for i in range(args.n_parties)] 250 | party_list_rounds = [] 251 | if n_party_per_round != args.n_parties: 252 | for i in range(args.comm_round): 253 | party_list_rounds.append(random.sample(party_list, n_party_per_round)) 254 | else: 255 | for i in range(args.comm_round): 256 | party_list_rounds.append(party_list) 257 | 258 | n_classes = len(np.unique(y_train)) 259 | 260 | train_dl_global, test_dl, train_ds_global, test_ds_global = get_dataloader(args.dataset, args.datadir, 261 | args.batch_size, 32) 262 | 263 | 
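# --- illustrative sketch (editor's addition, not part of the original file) ---
# partition_data(..., partition='noniid', beta=...) above assigns each class's
# samples across parties with a Dirichlet prior: a smaller beta yields more
# skewed client label distributions. A minimal stand-alone version of that
# idea, under the hypothetical helper name `dirichlet_partition`, could be:
#
#     import numpy as np
#
#     def dirichlet_partition(y, n_parties, beta, seed=0):
#         rng = np.random.default_rng(seed)
#         idx_map = {i: [] for i in range(n_parties)}
#         for c in np.unique(y):
#             idx_c = np.where(y == c)[0]
#             rng.shuffle(idx_c)
#             p = rng.dirichlet([beta] * n_parties)  # per-class party shares
#             cuts = (np.cumsum(p)[:-1] * len(idx_c)).astype(int)
#             for i, part in enumerate(np.split(idx_c, cuts)):
#                 idx_map[i].extend(part.tolist())
#         return idx_map
#
# e.g. dirichlet_partition(y_train, args.n_parties, args.beta) returns a dict
# of party id -> sample indices, analogous in shape to net_dataidx_map.
# ------------------------------------------------------------------------------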
print("len train_dl_global:", len(train_ds_global)) 264 | train_dl = None 265 | data_size = len(test_ds_global) 266 | 267 | logger.info("Initializing nets") 268 | nets, _, local_model_meta_data, layer_type = init_nets(args.net_config, args.n_parties, args, n_classes, device=device) 269 | 270 | global_models, global_dnn, global_model_meta_data, global_layer_type = init_nets(args.net_config, 1, args, n_classes, device=device) 271 | global_model = global_models[0] 272 | 273 | n_comm_rounds = args.comm_round 274 | 275 | 276 | if args.load_model_file and args.alg != 'plot_visual': 277 | global_model.load_state_dict(torch.load(args.load_model_file)) 278 | n_comm_rounds -= args.load_model_round 279 | 280 | if args.server_momentum: 281 | moment_v = copy.deepcopy(global_model.state_dict()) 282 | for key in moment_v: 283 | moment_v[key] = 0 284 | 285 | 286 | local_proto_list = [[]]*args.n_parties 287 | local_proto_label_list = [[]]*args.n_parties 288 | net_id_list = torch.tensor([0]*args.n_parties) 289 | if args.alg == 'CSPC': 290 | anchors=0 291 | for round in range(n_comm_rounds): 292 | logger.info("in comm round:" + str(round)) 293 | party_list_this_round = party_list_rounds[round] 294 | global_w = global_model.state_dict() 295 | 296 | nets_this_round = {k: nets[k] for k in party_list_this_round} 297 | 298 | 299 | for net in nets_this_round.values(): 300 | net.load_state_dict(global_w) 301 | 302 | nets_this_round, local_proto_list, local_proto_label_list, net_id_list, clients_ids = local_train_net(nets_this_round, args, net_dataidx_map, local_proto_list, local_proto_label_list, net_id_list, round=round, train_dl=train_dl, anchors=anchors, test_dl=test_dl, n_classes=n_classes, device=device) 303 | global_model.to('cpu') 304 | 305 | # update global model 306 | 307 | global_nets_this_round = {k: nets[k] for k in global_party_list} 308 | 309 | total_data_points = sum([len(net_dataidx_map[r]) for r in global_party_list]) 310 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in global_party_list] 311 | 312 | for net_id, net in enumerate(global_nets_this_round.values()): 313 | net.to('cpu') 314 | net_para = net.state_dict() 315 | if net_id == 0: 316 | for key in net_para: 317 | global_w[key] = net_para[key] * fed_avg_freqs[net_id] 318 | else: 319 | for key in net_para: 320 | global_w[key] += net_para[key] * fed_avg_freqs[net_id] 321 | 322 | global_model.load_state_dict(global_w) 323 | 324 | logger.info('global n_training: %d' % len(train_dl_global)) 325 | logger.info('global n_test: %d' % len(test_dl)) 326 | 327 | 328 | # re_training 329 | id_list = torch.nonzero(net_id_list == 1).reshape(-1) 330 | # ipdb.set_trace() 331 | for net_id, idx in enumerate(id_list): 332 | # ipdb.set_trace() 333 | if net_id == 0: 334 | local_protos = local_proto_list[idx] 335 | local_labels = local_proto_label_list[idx] 336 | else: 337 | local_protos = torch.cat([local_protos, local_proto_list[idx]]) 338 | local_labels = torch.cat([local_labels, local_proto_label_list[idx]]) 339 | 340 | anchors, labels = gen_proto_global(local_protos, local_labels, n_classes) 341 | 342 | global_dnn = get_updateModel_before(global_dnn, global_model) 343 | # clients_ids = clients_ids.cpu().numpy() 344 | global_dnn, glo_proto, glo_proto_label = retrain_cls_final(global_dnn, local_protos, local_labels, clients_ids, n_classes, args, round, device) 345 | global_model.to(device) 346 | global_model = get_updateModel_after(global_model, global_dnn) 347 | 348 | 349 | # train_acc, train_loss = compute_accuracy_v6(global_model, 
train_dl_global, device=device)
350 |         acc_out, acc_sim, acc_final, conf_matrix, _ = compute_accuracy_v6(global_model, glo_proto, glo_proto_label, test_dl, args, get_confusion_matrix=True, device=device)
351 | 
352 | 
353 |         logger.info('>> Global Model Test_out accuracy: %f' % acc_out)
354 |         logger.info('>> Global Model Test_sim accuracy: %f' % acc_sim)
355 |         logger.info('>> Global Model Test_final accuracy: %f' % acc_final)
356 | 
357 | 
358 | 
359 | 
360 | 
--------------------------------------------------------------------------------
/main_FedAvg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import torch
4 | import torch.optim as optim
5 | import torch.nn as nn
6 | import argparse
7 | import logging
8 | import os
9 | import copy
10 | import datetime
11 | import random
12 | # os.environ['CUDA_VISIBLE_DEVICES']='0'
13 | from resnet import resnet18
14 | from model import *
15 | from utils import *
16 | from loss import *
17 | import torch.nn.functional as F
18 | import ipdb
19 | 
20 | def get_args():
21 |     parser = argparse.ArgumentParser()
22 |     parser.add_argument('--model', type=str, default='simple-cnn', help='neural network used in training')
23 |     parser.add_argument('--dataset', type=str, default='cifar10', help='dataset used for training')
24 |     parser.add_argument('--net_config', type=lambda x: list(map(int, x.split(', '))))
25 |     parser.add_argument('--partition', type=str, default='noniid', help='the data partitioning strategy')
26 |     parser.add_argument('--batch-size', type=int, default=64, help='input batch size for training (default: 64)')
27 |     parser.add_argument('--lr', type=float, default=0.01, help='learning rate (default: 0.01)')
28 |     parser.add_argument('--epochs', type=int, default=10, help='number of local epochs')
29 |     parser.add_argument('--n_parties', type=int, default=10, help='number of workers in a distributed cluster')
30 |     parser.add_argument('--alg', type=str, default='fedavg',
31 |                         help='communication strategy: fedavg/fedprox')
32 |     parser.add_argument('--comm_round', type=int, default=100, help='number of maximum communication rounds')
33 |     parser.add_argument('--init_seed', type=int, default=0, help="Random seed")
34 |     parser.add_argument('--dropout_p', type=float, required=False, default=0.0, help="Dropout probability.
Default=0.0") 35 | parser.add_argument('--datadir', type=str, required=False, default="./data/", help="Data directory") 36 | parser.add_argument('--reg', type=float, default=1e-5, help="L2 regularization strength") 37 | parser.add_argument('--logdir', type=str, required=False, default="./logs/N+E/cifar10/", help='Log directory path') 38 | parser.add_argument('--modeldir', type=str, required=False, default="./models/", help='Model directory path') 39 | parser.add_argument('--beta', type=float, default=0.5, 40 | help='The parameter for the dirichlet distribution for data partitioning') 41 | 42 | parser.add_argument('--device', type=str, default='cuda:0', help='The device to run the program') 43 | parser.add_argument('--log_file_name', type=str, default=None, help='The log file name') 44 | parser.add_argument('--optimizer', type=str, default='sgd', help='the optimizer') 45 | parser.add_argument('--out_dim', type=int, default=84, help='the output dimension for the projection layer') 46 | parser.add_argument('--temperature', type=float, default=0.5, help='the temperature parameter for contrastive loss') 47 | parser.add_argument('--local_max_epoch', type=int, default=100, 48 | help='the number of epoch for local optimal training') 49 | parser.add_argument('--model_buffer_size', type=int, default=1, 50 | help='store how many previous models for contrastive loss') 51 | parser.add_argument('--pool_option', type=str, default='FIFO', help='FIFO or BOX') 52 | parser.add_argument('--sample_fraction', type=float, default=1, help='how many clients are sampled in each round') 53 | parser.add_argument('--load_model_file', type=str, default=None, help='the model to load as global model') 54 | parser.add_argument('--load_pool_file', type=str, default=None, help='the old model pool path to load') 55 | parser.add_argument('--load_model_round', type=int, default=None, help='how many rounds have executed for the loaded model') 56 | parser.add_argument('--load_first_net', type=int, default=1, help='whether load the first net as old net or not') 57 | parser.add_argument('--normal_model', type=int, default=0, help='use normal model or aggregate model') 58 | parser.add_argument('--loss', type=str, default='contrastive') 59 | parser.add_argument('--save_model', type=int, default=0) 60 | parser.add_argument('--use_project_head', type=int, default=1) 61 | parser.add_argument('--server_momentum', type=float, default=0, help='the server momentum (FedAvgM)') 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | def init_nets(net_configs, n_parties, args, n_classes, device='cuda:0'): 67 | nets = {net_i: None for net_i in range(n_parties)} 68 | 69 | 70 | for net_i in range(n_parties): 71 | if 'cifar' in args.dataset: 72 | net = ModelFedCon(args.model, args.out_dim, n_classes, net_configs) 73 | net.to(device) 74 | nets[net_i] = net 75 | else: 76 | if args.dataset == 'vireo172' or args.dataset == 'food101': 77 | net = resnet18(args.dataset, kernel_size=7, pretrained=False) 78 | else: 79 | net = resnet18(args.dataset, kernel_size=3, pretrained=False) 80 | nets[net_i] = net 81 | model_meta_data = [] 82 | layer_type = [] 83 | for (k, v) in nets[0].state_dict().items(): 84 | model_meta_data.append(v.shape) 85 | layer_type.append(k) 86 | 87 | return nets, model_meta_data, layer_type 88 | 89 | 90 | 91 | def train_net_fedavg(net_id, net, global_net, train_dataloader, test_dataloader, epochs, lr, args_optimizer, args, 92 | device="cuda:0"): 93 | global_net.to(device) 94 | # net = nn.DataParallel(net) 95 | net.to(device) 96 
|
97 |     logger.info('Training network %s' % str(net_id))
98 |     logger.info('n_training: %d' % len(train_dataloader))
99 |     logger.info('n_test: %d' % len(test_dataloader))
100 | 
101 |     if args_optimizer == 'adam':
102 |         optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg)
103 |     elif args_optimizer == 'amsgrad':
104 |         optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=args.reg,
105 |                                amsgrad=True)
106 |     elif args_optimizer == 'sgd':
107 |         optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=0.9,
108 |                               weight_decay=args.reg)
109 | 
110 |     criterion = nn.CrossEntropyLoss().to(device)
111 | 
112 |     cnt = 0
113 |     net.train()
114 | 
115 |     for epoch in range(epochs):
116 |         epoch_loss_collector = []
117 |         for batch_idx, (x, target) in enumerate(train_dataloader):
118 |             # ipdb.set_trace()
119 |             x, target = x.to(device), target.to(device)
120 |             if args.dataset == 'pmnist':
121 |                 target = target.reshape(-1)
122 |             optimizer.zero_grad()
123 |             x.requires_grad = True
124 |             target.requires_grad = False
125 |             target = target.long()
126 | 
127 |             _, _, _, out = net(x)
128 |             loss = criterion(out, target)
129 | 
130 |             loss.backward()
131 |             optimizer.step()
132 | 
133 |             cnt += 1
134 |             epoch_loss_collector.append(loss.item())
135 | 
136 |         epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
137 |         logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))
138 | 
139 |     net.to('cpu')
140 |     logger.info(' ** Training complete **')
141 |     return 0, 0
142 | 
143 | 
144 | def local_train_net(nets, args, net_dataidx_map, train_dl=None, test_dl=None, global_model=None, prev_model_pool=None, prev_protos_pool=None, prev_protos_label_pool=None, server_c=None, clients_c=None, round=None, device="cuda:0"):
145 |     avg_acc = 0.0
146 |     acc_list = []
147 |     if global_model:
148 |         global_model.cuda()
149 |     if server_c:
150 |         server_c.cuda()
151 |         server_c_collector = list(server_c.cuda().parameters())
152 |         new_server_c_collector = copy.deepcopy(server_c_collector)
153 | 
154 | 
155 |     for net_id, net in nets.items():
156 |         dataidxs = net_dataidx_map[net_id]
157 | 
158 |         logger.info("Training network %s.
n_training: %d" % (str(net_id), len(dataidxs))) 159 | train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs) 160 | 161 | n_epoch = args.epochs 162 | 163 | trainacc, testacc = train_net_fedavg(net_id, net, global_model, train_dl_local, test_dl, n_epoch, args.lr, 164 | args.optimizer, args, device=device) 165 | if global_model: 166 | global_model.to('cpu') 167 | if server_c: 168 | for param_index, param in enumerate(server_c.parameters()): 169 | server_c_collector[param_index] = new_server_c_collector[param_index] 170 | server_c.to('cpu') 171 | return nets 172 | 173 | 174 | if __name__ == '__main__': 175 | args = get_args() 176 | mkdirs(args.logdir) 177 | mkdirs(args.modeldir) 178 | if args.log_file_name is None: 179 | argument_path = 'experiment_arguments-%s.json' % datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") 180 | else: 181 | argument_path = args.log_file_name + '.json' 182 | with open(os.path.join(args.logdir, argument_path), 'w') as f: 183 | json.dump(str(args), f) 184 | device = torch.device(args.device) 185 | for handler in logging.root.handlers[:]: 186 | logging.root.removeHandler(handler) 187 | 188 | if args.log_file_name is None: 189 | args.log_file_name = 'experiment_log-%s' % (datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")) 190 | log_path = args.log_file_name + '.log' 191 | logging.basicConfig( 192 | filename=os.path.join(args.logdir, log_path), 193 | format='%(asctime)s %(levelname)-8s %(message)s', 194 | datefmt='%m-%d %H:%M', level=logging.DEBUG, filemode='w') 195 | 196 | logger = logging.getLogger() 197 | logger.setLevel(logging.DEBUG) 198 | logger.info(device) 199 | 200 | seed = args.init_seed 201 | logger.info("#" * 100) 202 | np.random.seed(seed) 203 | torch.manual_seed(seed) 204 | if torch.cuda.is_available(): 205 | torch.cuda.manual_seed(seed) 206 | random.seed(seed) 207 | 208 | 209 | logger.info("Partitioning data") 210 | X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data( 211 | args.dataset, args.datadir, args.logdir, args.partition, args.n_parties, beta=args.beta) 212 | 213 | n_party_per_round = int(args.n_parties * args.sample_fraction) 214 | party_list = [i for i in range(args.n_parties)] 215 | global_party_list = [i for i in range(args.n_parties)] 216 | party_list_rounds = [] 217 | if n_party_per_round != args.n_parties: 218 | for i in range(args.comm_round): 219 | party_list_rounds.append(random.sample(party_list, n_party_per_round)) 220 | else: 221 | for i in range(args.comm_round): 222 | party_list_rounds.append(party_list) 223 | 224 | n_classes = len(np.unique(y_train)) 225 | 226 | train_dl_global, test_dl, train_ds_global, test_ds_global = get_dataloader(args.dataset, args.datadir, 227 | args.batch_size, 32) 228 | 229 | print("len train_dl_global:", len(train_ds_global)) 230 | train_dl = None 231 | data_size = len(test_ds_global) 232 | 233 | logger.info("Initializing nets") 234 | nets, local_model_meta_data, layer_type = init_nets(args.net_config, args.n_parties, args, n_classes,device=device) 235 | 236 | global_models, global_model_meta_data, global_layer_type = init_nets(args.net_config, 1, args, n_classes, device=device) 237 | global_model = global_models[0] 238 | n_comm_rounds = args.comm_round 239 | 240 | 241 | if args.load_model_file and args.alg != 'plot_visual': 242 | global_model.load_state_dict(torch.load(args.load_model_file)) 243 | n_comm_rounds -= args.load_model_round 244 | 245 | if args.server_momentum: 246 | moment_v = 
copy.deepcopy(global_model.state_dict()) 247 | for key in moment_v: 248 | moment_v[key] = 0 249 | 250 | for round in range(n_comm_rounds): 251 | logger.info("in comm round:" + str(round)) 252 | party_list_this_round = party_list_rounds[round] 253 | global_w = global_model.state_dict() 254 | nets_this_round = {k: nets[k] for k in party_list_this_round} 255 | for net in nets_this_round.values(): 256 | net.load_state_dict(global_w) 257 | 258 | nets_this_round = local_train_net(nets_this_round, args, net_dataidx_map, train_dl=train_dl, test_dl=test_dl, 259 | global_model=global_model, round=round, device=device) 260 | global_model.to('cpu') 261 | 262 | # update global model 263 | global_nets_this_round = {k: nets[k] for k in global_party_list} 264 | 265 | total_data_points = sum([len(net_dataidx_map[r]) for r in global_party_list]) 266 | fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in global_party_list] 267 | 268 | for net_id, net in enumerate(global_nets_this_round.values()): 269 | net.to('cpu') 270 | net_para = net.state_dict() 271 | if net_id == 0: 272 | for key in net_para: 273 | global_w[key] = net_para[key] * fed_avg_freqs[net_id] 274 | else: 275 | for key in net_para: 276 | global_w[key] += net_para[key] * fed_avg_freqs[net_id] 277 | 278 | global_model.load_state_dict(global_w) 279 | 280 | logger.info('global n_training: %d' % len(train_dl_global)) 281 | logger.info('global n_test: %d' % len(test_dl)) 282 | 283 | global_model.cuda() 284 | train_acc, train_loss = compute_accuracy_tset(global_model, train_dl_global, device=device) 285 | test_acc, conf_matrix, _ = compute_accuracy_tset(global_model, test_dl, get_confusion_matrix=True, device=device) 286 | 287 | logger.info('>> Global Model Train accuracy: %f' % train_acc) 288 | logger.info('>> Global Model Test accuracy: %f' % test_acc) 289 | logger.info('>> Global Model Train loss: %f' % train_loss) 290 | if round % 25 == 24: 291 | mkdirs(args.modeldir + 'Fedavg/' + args.dataset + '/' + argument_path + '/' + str(round)) 292 | global_model.to('cpu') 293 | torch.save(global_model.state_dict(), 294 | args.modeldir + 'Fedavg/' + args.dataset + '/' + argument_path + '/' + str(round) + '/global_model.pth') 295 | for i in range(10): 296 | torch.save(nets[i].state_dict(), args.modeldir + 'Fedavg/' + args.dataset + '/' + argument_path + '/' + str(round) + '/local_' + str(i) + '.pth') 297 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torchvision.models as models 6 | from resnetcifar import ResNet18_cifar10, ResNet50_cifar10 7 | 8 | #import pytorch_lightning as pl 9 | 10 | 11 | 12 | 13 | class MLP_header(nn.Module): 14 | def __init__(self,): 15 | super(MLP_header, self).__init__() 16 | self.fc1 = nn.Linear(28*28, 512) 17 | self.fc2 = nn.Linear(512, 512) 18 | self.relu = nn.ReLU() 19 | #projection 20 | # self.fc3 = nn.Linear(512, 10) 21 | 22 | def forward(self, x): 23 | x = x.view(-1, 28*28) 24 | x = self.fc1(x) 25 | x = self.relu(x) 26 | x = self.fc2(x) 27 | x = self.relu(x) 28 | return x 29 | 30 | 31 | class FcNet(nn.Module): 32 | """ 33 | Fully connected network for MNIST classification 34 | """ 35 | 36 | def __init__(self, input_dim, hidden_dims, output_dim, dropout_p=0.0): 37 | 38 | super().__init__() 39 | 40 | self.input_dim = input_dim 41 | self.hidden_dims = hidden_dims 42 | self.output_dim 
= output_dim
43 |         self.dropout_p = dropout_p
44 | 
45 |         self.dims = [self.input_dim]
46 |         self.dims.extend(hidden_dims)
47 |         self.dims.append(self.output_dim)
48 | 
49 |         self.layers = nn.ModuleList([])
50 | 
51 |         for i in range(len(self.dims) - 1):
52 |             ip_dim = self.dims[i]
53 |             op_dim = self.dims[i + 1]
54 |             self.layers.append(
55 |                 nn.Linear(ip_dim, op_dim, bias=True)
56 |             )
57 | 
58 |         self.__init_net_weights__()
59 | 
60 |     def __init_net_weights__(self):
61 | 
62 |         for m in self.layers:
63 |             m.weight.data.normal_(0.0, 0.1)
64 |             m.bias.data.fill_(0.1)
65 | 
66 |     def forward(self, x):
67 | 
68 |         x = x.view(-1, self.input_dim)
69 | 
70 |         for i, layer in enumerate(self.layers):
71 |             x = layer(x)
72 | 
73 |             # Do not apply ReLU on the final layer
74 |             if i < (len(self.layers) - 1):
75 |                 x = F.relu(x)
76 | 
77 |             # if i < (len(self.layers) - 1):  # No dropout on output layer
78 |             #     x = F.dropout(x, p=self.dropout_p, training=self.training)
79 | 
80 |         return x
81 | 
82 | 
83 | class ConvBlock(nn.Module):
84 |     def __init__(self):
85 |         super(ConvBlock, self).__init__()
86 |         self.conv1 = nn.Conv2d(3, 6, 5)
87 |         self.pool = nn.MaxPool2d(2, 2)
88 |         self.conv2 = nn.Conv2d(6, 16, 5)
89 | 
90 |     def forward(self, x):
91 |         x = self.pool(F.relu(self.conv1(x)))
92 |         x = self.pool(F.relu(self.conv2(x)))
93 |         x = x.view(-1, 16 * 5 * 5)
94 |         return x
95 | 
96 | 
97 | class FCBlock(nn.Module):
98 |     def __init__(self, input_dim, hidden_dims, output_dim=10):
99 |         super(FCBlock, self).__init__()
100 |         self.fc1 = nn.Linear(input_dim, hidden_dims[0])
101 |         self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
102 |         self.fc3 = nn.Linear(hidden_dims[1], output_dim)
103 | 
104 |     def forward(self, x):
105 |         x = F.relu(self.fc1(x))
106 |         x = F.relu(self.fc2(x))
107 |         x = self.fc3(x)
108 |         return x
109 | 
110 | 
111 | class VGGConvBlocks(nn.Module):
112 |     '''
113 |     VGG model
114 |     '''
115 | 
116 |     def __init__(self, features, num_classes=10):
117 |         super(VGGConvBlocks, self).__init__()
118 |         self.features = features
119 |         # Initialize weights
120 |         for m in self.modules():
121 |             if isinstance(m, nn.Conv2d):
122 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
124 |                 m.bias.data.zero_()
125 | 
126 |     def forward(self, x):
127 |         x = self.features(x)
128 |         x = x.view(x.size(0), -1)
129 |         return x
130 | 
131 | 
132 | class FCBlockVGG(nn.Module):
133 |     def __init__(self, input_dim, hidden_dims, output_dim=10):
134 |         super(FCBlockVGG, self).__init__()
135 |         self.fc1 = nn.Linear(input_dim, hidden_dims[0])
136 |         self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
137 |         self.fc3 = nn.Linear(hidden_dims[1], output_dim)
138 | 
139 |     def forward(self, x):
140 |         x = F.dropout(x)
141 |         x = F.relu(self.fc1(x))
142 |         x = F.dropout(x)
143 |         x = F.relu(self.fc2(x))
144 |         x = self.fc3(x)
145 |         return x
146 | 
147 | 
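# NOTE: SimpleCNN_header below is the backbone that ModelFedCon uses for `--model simple-cnn`;
# it returns the 84-d penultimate feature rather than class logits, which is why init_nets in
# main_CSPC.py builds the calibration DNN with input_dim=84 for this model.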
148 | class SimpleCNN_header(nn.Module):
149 |     def __init__(self, input_dim, hidden_dims, output_dim=10):
150 |         super(SimpleCNN_header, self).__init__()
151 |         self.conv1 = nn.Conv2d(3, 6, 5)
152 |         self.relu = nn.ReLU()
153 |         self.pool = nn.MaxPool2d(2, 2)
154 |         self.conv2 = nn.Conv2d(6, 16, 5)
155 | 
156 |         # for now, we hard coded this network
157 |         # i.e. we fix the number of hidden layers i.e. 2 layers
158 |         self.fc1 = nn.Linear(input_dim, hidden_dims[0])
159 |         self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
160 |         #self.fc3 = nn.Linear(hidden_dims[1], output_dim)
161 | 
162 |     def forward(self, x):
163 | 
164 |         x = self.pool(self.relu(self.conv1(x)))
165 |         x = self.pool(self.relu(self.conv2(x)))
166 |         x = x.view(-1, 16 * 5 * 5)
167 | 
168 |         x = self.relu(self.fc1(x))
169 |         x = self.relu(self.fc2(x))
170 |         # x = self.fc3(x)
171 |         return x
172 | 
173 | 
174 | class SimpleCNN(nn.Module):
175 |     def __init__(self, input_dim, hidden_dims, output_dim=10):
176 |         super(SimpleCNN, self).__init__()
177 |         self.conv1 = nn.Conv2d(3, 6, 5)
178 |         self.relu = nn.ReLU()
179 |         self.pool = nn.MaxPool2d(2, 2)
180 |         self.conv2 = nn.Conv2d(6, 16, 5)
181 | 
182 |         # for now, we hard coded this network
183 |         # i.e. we fix the number of hidden layers i.e. 2 layers
184 |         self.fc1 = nn.Linear(input_dim, hidden_dims[0])
185 |         self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
186 |         self.fc3 = nn.Linear(hidden_dims[1], output_dim)
187 | 
188 |     def forward(self, x):
189 |         #out = self.conv1(x)
190 |         #out = self.relu(out)
191 |         #out = self.pool(out)
192 |         #out = self.conv2(out)
193 |         #out = self.relu(out)
194 |         #out = self.pool(out)
195 |         #out = out.view(-1, 16 * 5 * 5)
196 | 
197 |         x = self.pool(self.relu(self.conv1(x)))
198 |         x = self.pool(self.relu(self.conv2(x)))
199 |         x = x.view(-1, 16 * 5 * 5)
200 | 
201 |         x = self.relu(self.fc1(x))
202 |         x = self.relu(self.fc2(x))
203 |         x = self.fc3(x)
204 |         return x
205 | 
206 | 
207 | # a simple perceptron model for generated 3D data
208 | class PerceptronModel(nn.Module):
209 |     def __init__(self, input_dim=3, output_dim=2):
210 |         super(PerceptronModel, self).__init__()
211 | 
212 |         self.fc1 = nn.Linear(input_dim, output_dim)
213 | 
214 |     def forward(self, x):
215 | 
216 |         x = self.fc1(x)
217 |         return x
218 | 
219 | 
220 | class SimpleCNNMNIST_header(nn.Module):
221 |     def __init__(self, input_dim, hidden_dims, output_dim=10):
222 |         super(SimpleCNNMNIST_header, self).__init__()
223 |         self.conv1 = nn.Conv2d(1, 6, 5)
224 |         self.relu = nn.ReLU()
225 |         self.pool = nn.MaxPool2d(2, 2)
226 |         self.conv2 = nn.Conv2d(6, 16, 5)
227 | 
228 |         # for now, we hard coded this network
229 |         # i.e. we fix the number of hidden layers i.e. 2 layers
230 |         self.fc1 = nn.Linear(input_dim, hidden_dims[0])
231 |         self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
232 |         #self.fc3 = nn.Linear(hidden_dims[1], output_dim)
233 | 
234 |     def forward(self, x):
235 |         x = self.pool(self.relu(self.conv1(x)))
236 |         x = self.pool(self.relu(self.conv2(x)))
237 |         x = x.view(-1, 16 * 4 * 4)
238 | 
239 |         x = self.relu(self.fc1(x))
240 |         x = self.relu(self.fc2(x))
241 |         # x = self.fc3(x)
242 |         return x
243 | 
244 | class SimpleCNNMNIST(nn.Module):
245 |     def __init__(self, input_dim, hidden_dims, output_dim=10):
246 |         super(SimpleCNNMNIST, self).__init__()
247 |         self.conv1 = nn.Conv2d(1, 6, 5)
248 |         self.pool = nn.MaxPool2d(2, 2)
249 |         self.conv2 = nn.Conv2d(6, 16, 5)
250 | 
251 |         # for now, we hard coded this network
252 |         # i.e. we fix the number of hidden layers i.e.
2 layers 253 | self.fc1 = nn.Linear(input_dim, hidden_dims[0]) 254 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1]) 255 | self.fc3 = nn.Linear(hidden_dims[1], output_dim) 256 | 257 | def forward(self, x): 258 | x = self.pool(F.relu(self.conv1(x))) 259 | x = self.pool(F.relu(self.conv2(x))) 260 | x = x.view(-1, 16 * 4 * 4) 261 | 262 | x = F.relu(self.fc1(x)) 263 | x = F.relu(self.fc2(x)) 264 | y = self.fc3(x) 265 | return x, 0, y 266 | 267 | 268 | class SimpleCNNContainer(nn.Module): 269 | def __init__(self, input_channel, num_filters, kernel_size, input_dim, hidden_dims, output_dim=10): 270 | super(SimpleCNNContainer, self).__init__() 271 | ''' 272 | A testing cnn container, which allows initializing a CNN with given dims 273 | 274 | num_filters (list) :: number of convolution filters 275 | hidden_dims (list) :: number of neurons in hidden layers 276 | 277 | Assumptions: 278 | i) we use only two conv layers and three hidden layers (including the output layer) 279 | ii) kernel size in the two conv layers are identical 280 | ''' 281 | self.conv1 = nn.Conv2d(input_channel, num_filters[0], kernel_size) 282 | self.pool = nn.MaxPool2d(2, 2) 283 | self.conv2 = nn.Conv2d(num_filters[0], num_filters[1], kernel_size) 284 | 285 | # for now, we hard coded this network 286 | # i.e. we fix the number of hidden layers i.e. 2 layers 287 | self.fc1 = nn.Linear(input_dim, hidden_dims[0]) 288 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1]) 289 | self.fc3 = nn.Linear(hidden_dims[1], output_dim) 290 | 291 | def forward(self, x): 292 | x = self.pool(F.relu(self.conv1(x))) 293 | x = self.pool(F.relu(self.conv2(x))) 294 | x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3]) 295 | x = F.relu(self.fc1(x)) 296 | x = F.relu(self.fc2(x)) 297 | x = self.fc3(x) 298 | return x 299 | 300 | 301 | ############## LeNet for MNIST ################### 302 | class LeNet(nn.Module): 303 | def __init__(self): 304 | super(LeNet, self).__init__() 305 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 306 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 307 | self.fc1 = nn.Linear(4 * 4 * 50, 500) 308 | self.fc2 = nn.Linear(500, 10) 309 | self.ceriation = nn.CrossEntropyLoss() 310 | 311 | def forward(self, x): 312 | x = self.conv1(x) 313 | x = F.max_pool2d(x, 2, 2) 314 | x = F.relu(x) 315 | x = self.conv2(x) 316 | x = F.max_pool2d(x, 2, 2) 317 | x = F.relu(x) 318 | x = x.view(-1, 4 * 4 * 50) 319 | x = self.fc1(x) 320 | x = self.fc2(x) 321 | return x 322 | 323 | 324 | class LeNetContainer(nn.Module): 325 | def __init__(self, num_filters, kernel_size, input_dim, hidden_dims, output_dim=10): 326 | super(LeNetContainer, self).__init__() 327 | self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size, 1) 328 | self.conv2 = nn.Conv2d(num_filters[0], num_filters[1], kernel_size, 1) 329 | 330 | self.fc1 = nn.Linear(input_dim, hidden_dims[0]) 331 | self.fc2 = nn.Linear(hidden_dims[0], output_dim) 332 | 333 | def forward(self, x): 334 | x = self.conv1(x) 335 | x = F.max_pool2d(x, 2, 2) 336 | x = F.relu(x) 337 | x = self.conv2(x) 338 | x = F.max_pool2d(x, 2, 2) 339 | x = F.relu(x) 340 | x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3]) 341 | x = self.fc1(x) 342 | x = self.fc2(x) 343 | return x 344 | 345 | 346 | 347 | ### Moderate size of CNN for CIFAR-10 dataset 348 | class ModerateCNN(nn.Module): 349 | def __init__(self, output_dim=10): 350 | super(ModerateCNN, self).__init__() 351 | self.conv_layer = nn.Sequential( 352 | # Conv Layer block 1 353 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), 354 | nn.ReLU(inplace=True), 355 
| nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 356 | nn.ReLU(inplace=True), 357 | nn.MaxPool2d(kernel_size=2, stride=2), 358 | 359 | # Conv Layer block 2 360 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 361 | nn.ReLU(inplace=True), 362 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 363 | nn.ReLU(inplace=True), 364 | nn.MaxPool2d(kernel_size=2, stride=2), 365 | nn.Dropout2d(p=0.05), 366 | 367 | # Conv Layer block 3 368 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 369 | nn.ReLU(inplace=True), 370 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 371 | nn.ReLU(inplace=True), 372 | nn.MaxPool2d(kernel_size=2, stride=2), 373 | ) 374 | 375 | self.fc_layer = nn.Sequential( 376 | nn.Dropout(p=0.1), 377 | # nn.Linear(4096, 1024), 378 | nn.Linear(4096, 512), 379 | nn.ReLU(inplace=True), 380 | # nn.Linear(1024, 512), 381 | nn.Linear(512, 512), 382 | nn.ReLU(inplace=True), 383 | nn.Dropout(p=0.1), 384 | nn.Linear(512, output_dim) 385 | ) 386 | 387 | def forward(self, x): 388 | x = self.conv_layer(x) 389 | x = x.view(x.size(0), -1) 390 | x = self.fc_layer(x) 391 | return x 392 | 393 | 394 | ### Moderate size of CNN for CIFAR-10 dataset 395 | class ModerateCNNCeleba(nn.Module): 396 | def __init__(self): 397 | super(ModerateCNNCeleba, self).__init__() 398 | self.conv_layer = nn.Sequential( 399 | # Conv Layer block 1 400 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), 401 | nn.ReLU(inplace=True), 402 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 403 | nn.ReLU(inplace=True), 404 | nn.MaxPool2d(kernel_size=2, stride=2), 405 | 406 | # Conv Layer block 2 407 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 408 | nn.ReLU(inplace=True), 409 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 410 | nn.ReLU(inplace=True), 411 | nn.MaxPool2d(kernel_size=2, stride=2), 412 | # nn.Dropout2d(p=0.05), 413 | 414 | # Conv Layer block 3 415 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 416 | nn.ReLU(inplace=True), 417 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 418 | nn.ReLU(inplace=True), 419 | nn.MaxPool2d(kernel_size=2, stride=2), 420 | ) 421 | 422 | self.fc_layer = nn.Sequential( 423 | nn.Dropout(p=0.1), 424 | # nn.Linear(4096, 1024), 425 | nn.Linear(4096, 512), 426 | nn.ReLU(inplace=True), 427 | # nn.Linear(1024, 512), 428 | nn.Linear(512, 512), 429 | nn.ReLU(inplace=True), 430 | nn.Dropout(p=0.1), 431 | nn.Linear(512, 2) 432 | ) 433 | 434 | def forward(self, x): 435 | x = self.conv_layer(x) 436 | # x = x.view(x.size(0), -1) 437 | x = x.view(-1, 4096) 438 | x = self.fc_layer(x) 439 | return x 440 | 441 | 442 | class ModerateCNNMNIST(nn.Module): 443 | def __init__(self): 444 | super(ModerateCNNMNIST, self).__init__() 445 | self.conv_layer = nn.Sequential( 446 | # Conv Layer block 1 447 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1), 448 | nn.ReLU(inplace=True), 449 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 450 | nn.ReLU(inplace=True), 451 | nn.MaxPool2d(kernel_size=2, stride=2), 452 | 453 | # Conv Layer block 2 454 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 455 | nn.ReLU(inplace=True), 456 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 457 | nn.ReLU(inplace=True), 458 | nn.MaxPool2d(kernel_size=2, stride=2), 459 | 
nn.Dropout2d(p=0.05), 460 | 461 | # Conv Layer block 3 462 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 463 | nn.ReLU(inplace=True), 464 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 465 | nn.ReLU(inplace=True), 466 | nn.MaxPool2d(kernel_size=2, stride=2), 467 | ) 468 | 469 | self.fc_layer = nn.Sequential( 470 | nn.Dropout(p=0.1), 471 | nn.Linear(2304, 1024), 472 | nn.ReLU(inplace=True), 473 | nn.Linear(1024, 512), 474 | nn.ReLU(inplace=True), 475 | nn.Dropout(p=0.1), 476 | nn.Linear(512, 10) 477 | ) 478 | 479 | def forward(self, x): 480 | x = self.conv_layer(x) 481 | x = x.view(x.size(0), -1) 482 | x = self.fc_layer(x) 483 | return x 484 | 485 | 486 | class ModerateCNNContainer(nn.Module): 487 | def __init__(self, input_channels, num_filters, kernel_size, input_dim, hidden_dims, output_dim=10): 488 | super(ModerateCNNContainer, self).__init__() 489 | 490 | ## 491 | self.conv_layer = nn.Sequential( 492 | # Conv Layer block 1 493 | nn.Conv2d(in_channels=input_channels, out_channels=num_filters[0], kernel_size=kernel_size, padding=1), 494 | nn.ReLU(inplace=True), 495 | nn.Conv2d(in_channels=num_filters[0], out_channels=num_filters[1], kernel_size=kernel_size, padding=1), 496 | nn.ReLU(inplace=True), 497 | nn.MaxPool2d(kernel_size=2, stride=2), 498 | 499 | # Conv Layer block 2 500 | nn.Conv2d(in_channels=num_filters[1], out_channels=num_filters[2], kernel_size=kernel_size, padding=1), 501 | nn.ReLU(inplace=True), 502 | nn.Conv2d(in_channels=num_filters[2], out_channels=num_filters[3], kernel_size=kernel_size, padding=1), 503 | nn.ReLU(inplace=True), 504 | nn.MaxPool2d(kernel_size=2, stride=2), 505 | nn.Dropout2d(p=0.05), 506 | 507 | # Conv Layer block 3 508 | nn.Conv2d(in_channels=num_filters[3], out_channels=num_filters[4], kernel_size=kernel_size, padding=1), 509 | nn.ReLU(inplace=True), 510 | nn.Conv2d(in_channels=num_filters[4], out_channels=num_filters[5], kernel_size=kernel_size, padding=1), 511 | nn.ReLU(inplace=True), 512 | nn.MaxPool2d(kernel_size=2, stride=2), 513 | ) 514 | 515 | self.fc_layer = nn.Sequential( 516 | nn.Dropout(p=0.1), 517 | nn.Linear(input_dim, hidden_dims[0]), 518 | nn.ReLU(inplace=True), 519 | nn.Linear(hidden_dims[0], hidden_dims[1]), 520 | nn.ReLU(inplace=True), 521 | nn.Dropout(p=0.1), 522 | nn.Linear(hidden_dims[1], output_dim) 523 | ) 524 | 525 | def forward(self, x): 526 | x = self.conv_layer(x) 527 | x = x.view(x.size(0), -1) 528 | x = self.fc_layer(x) 529 | return x 530 | 531 | def forward_conv(self, x): 532 | x = self.conv_layer(x) 533 | x = x.view(x.size(0), -1) 534 | return x 535 | 536 | 537 | class ModelFedCon(nn.Module): 538 | 539 | def __init__(self, base_model, out_dim, n_classes, net_configs=None): 540 | super(ModelFedCon, self).__init__() 541 | 542 | if base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel" or base_model == "resnet50": 543 | basemodel = ResNet50_cifar10() 544 | self.features = nn.Sequential(*list(basemodel.children())[:-1]) 545 | num_ftrs = basemodel.fc.in_features 546 | elif base_model == "resnet18-cifar10" or base_model == "resnet18": 547 | basemodel = ResNet18_cifar10() 548 | self.features = nn.Sequential(*list(basemodel.children())[:-1]) 549 | num_ftrs = basemodel.fc.in_features 550 | elif base_model == "mlp": 551 | self.features = MLP_header() 552 | num_ftrs = 512 553 | elif base_model == 'simple-cnn': 554 | self.features = SimpleCNN_header(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=n_classes) 555 | 
num_ftrs = 84
556 |         elif base_model == 'simple-cnn-mnist':
557 |             self.features = SimpleCNNMNIST_header(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=n_classes)
558 |             num_ftrs = 84
559 | 
560 |         #summary(self.features.to('cuda:0'), (3,32,32))
561 |         #print("features:", self.features)
562 |         # projection MLP
563 |         self.fc = nn.Linear(num_ftrs, n_classes)
564 |         self.l1 = nn.Linear(num_ftrs, num_ftrs)
565 |         self.l2 = nn.Linear(num_ftrs, out_dim)
566 | 
567 |         # last layer
568 |         self.l3 = nn.Linear(out_dim, n_classes)
569 | 
570 |     def _get_basemodel(self, model_name):
571 |         try:
572 |             model = self.model_dict[model_name]
573 |             #print("Feature extractor:", model_name)
574 |             return model
575 |         except:
576 |             raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
577 | 
578 |     def forward(self, x):
579 |         h = self.features(x)
580 |         #print("h before:", h)
581 |         # print("h size:", h.size())
582 |         h = h.squeeze()
583 |         #print("h after:", h)
584 |         x1 = self.l1(h)
585 |         x1 = F.relu(x1)
586 |         x2 = self.l2(x1)
587 |         h_out = self.fc(h)
588 |         y = self.l3(x2)
589 |         return h, h_out, x2, y
590 | 
591 | 
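# NOTE: ModelFedCon.forward returns four tensors, unpacked in the training code as
# h, h_out, _, out = net(x):
#   h     - backbone feature, fed to the RGA prototype-alignment loss,
#   h_out - logits from the linear head placed directly on h,
#   x2    - projection-head embedding,
#   y     - logits from the classifier on the projection output.
# Caution: h.squeeze() also removes the batch dimension when the batch size is 1.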
592 | class ModelFedCon_noheader(nn.Module):
593 | 
594 |     def __init__(self, base_model, out_dim, n_classes, net_configs=None):
595 |         super(ModelFedCon_noheader, self).__init__()
596 | 
597 |         if base_model == "resnet50":
598 |             basemodel = models.resnet50(pretrained=False)
599 |             self.features = nn.Sequential(*list(basemodel.children())[:-1])
600 |             num_ftrs = basemodel.fc.in_features
601 |         # elif base_model == "resnet18":
602 |         #     basemodel = models.resnet18(pretrained=False)
603 |         #     self.features = nn.Sequential(*list(basemodel.children())[:-1])
604 |         #     num_ftrs = basemodel.fc.in_features
605 |         elif base_model == "resnet50-cifar10" or base_model == "resnet50-cifar100" or base_model == "resnet50-smallkernel":
606 |             basemodel = ResNet50_cifar10()
607 |             self.features = nn.Sequential(*list(basemodel.children())[:-1])
608 |             num_ftrs = basemodel.fc.in_features
609 |         elif base_model == "resnet18":
610 |             basemodel = ResNet18_cifar10()
611 |             self.features = nn.Sequential(*list(basemodel.children())[:-1])
612 |             num_ftrs = basemodel.fc.in_features
613 |         elif base_model == "mlp":
614 |             self.features = MLP_header()
615 |             num_ftrs = 512
616 |         elif base_model == 'simple-cnn':
617 |             self.features = SimpleCNN_header(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=n_classes)
618 |             num_ftrs = 84
619 |         elif base_model == 'simple-cnn-mnist':
620 |             self.features = SimpleCNNMNIST_header(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=n_classes)
621 |             num_ftrs = 84
622 | 
623 |         #summary(self.features.to('cuda:0'), (3,32,32))
624 |         #print("features:", self.features)
625 |         # projection MLP
626 |         # self.l1 = nn.Linear(num_ftrs, num_ftrs)
627 |         # self.l2 = nn.Linear(num_ftrs, out_dim)
628 | 
629 |         # last layer
630 |         self.l3 = nn.Linear(num_ftrs, n_classes)
631 | 
632 |     def _get_basemodel(self, model_name):
633 |         try:
634 |             model = self.model_dict[model_name]
635 |             #print("Feature extractor:", model_name)
636 |             return model
637 |         except:
638 |             raise ValueError("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
639 | 
640 |     def forward(self, x):
641 |         h = self.features(x)
642 |         #print("h before:", h)
643 |         # print("h size:", h.size())
644 |         h = h.squeeze()
645 |         #print("h after:", h)
646 |         # x = self.l1(h)
647 |         # x = F.relu(x)
648 |         # x = self.l2(x)
649 | 
650 |         y = self.l3(h)
651 |         return h, h, y
652 | 
653 | 
--------------------------------------------------------------------------------
/re_training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torch.utils.data as data
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | import random
11 | from sklearn.metrics import confusion_matrix
12 | import torch.optim as optim
13 | # from model import *
14 | from datasets import CIFAR10_truncated, CIFAR100_truncated, ImageFolder_custom, MNIST_truncated, FashionMNIST_truncated, TinyImageNet_load, Vireo172_truncated, Food101_truncated
15 | import ipdb
16 | import copy
17 | from RGA import *
18 | from loss import *
19 | from utils import *
20 | 
21 | def exp_lr_scheduler(optimizer, epoch, init_lr, lr_decay, decay_rate):
22 |     """Decay the learning rate by a factor of decay_rate every lr_decay epochs."""
23 |     # scale the learning rate down by decay_rate once every lr_decay epochs
24 |     lr = init_lr * (decay_rate ** (epoch // lr_decay))  # ** is exponentiation, // is integer division
25 | 
26 |     if epoch % lr_decay == 0:
27 |         print('LR is set to {}'.format(lr))
28 | 
29 |     for param_group in optimizer.param_groups:
30 |         param_group['lr'] = lr
31 |     # return the optimizer with the updated learning rate
32 |     return optimizer
33 | 
34 | 
35 | def retrain_cls_final(global_dnn, prototypes, proto_labels, client_ids, n_classes, args, round, device):
36 | 
37 |     if round <= 5:
38 |         init_lr = 1e-1
39 |     elif round <= 40:
40 |         init_lr = 1e-2
41 |     else:
42 |         init_lr = 1e-3
43 | 
44 |     lr_decay = 30
45 |     decay_rate = 0.1
46 |     batch_size = 100
47 |     global_dnn.to(device)
48 |     cuda = 1
49 |     prototypes = prototypes.cpu().numpy()
50 |     proto_labels = proto_labels.cpu().numpy()
51 |     client_ids = client_ids.cpu().numpy()
52 | 
53 |     kwargs = {'num_workers': 2, 'pin_memory': True}
54 | 
55 | 
56 |     dataset_c = Data_for_Retraining_final(prototypes, proto_labels, client_ids)
57 |     data_loader = torch.utils.data.DataLoader(dataset_c, batch_size=batch_size, shuffle=True, **kwargs)
58 | 
59 |     # prototypes = prototypes.to(device)
60 |     # proto_labels = proto_labels.to(device)
61 |     optimizer = optim.SGD(filter(lambda p: p.requires_grad, global_dnn.parameters()), lr=init_lr, weight_decay=args.reg)
62 |     criterion = nn.CrossEntropyLoss().to(device)
63 | 
64 |     SCLoss = SupervisedContrastiveLoss()
65 |     ACLoss = TripletLoss()
66 | 
67 |     idx_list = np.array(np.arange(len(proto_labels)))
68 |     # ipdb.set_trace()
69 |     with torch.no_grad():
70 |         prototypes = torch.tensor(prototypes).to(device)
71 |         proto_labels = torch.tensor(proto_labels).to(device)
72 |         h2, out = global_dnn(prototypes)
73 |         pred_label = torch.argmax(out.data, 1)
74 |         total = prototypes.data.size()[0]
75 | 
76 |         correct = (pred_label == proto_labels.data).sum().item()
77 |         print('before', correct)
78 | 
79 |     print('proto_labels', proto_labels.shape)
80 | 
81 | 
82 |     for epoch in range(100):
83 |         optimizer = exp_lr_scheduler(optimizer, epoch, init_lr, lr_decay, decay_rate)
84 |         # random.shuffle(idx_list)
85 |         # batch_size = 100
86 |         epoch_loss_collector = []
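# NOTE: the batch loop below implements the augmented contrastive calibration of the CSPC
# module: posi_x / nega_x come from Data_for_Retraining_final, positive and negative mixing
# build mix_posi / mix_nega, and (for args.re_version == 'v2') a supervised contrastive loss
# (SCLoss) and a triplet loss (ACLoss) are added to the cross-entropy term.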
87 |         for batch_idx, (x, posi_x, nega_x, target) in enumerate(data_loader):
88 |             # for i in range(len(proto_labels)//batch_size):
89 |             x, posi_x, nega_x, target = x.to(device), posi_x.to(device), nega_x.to(device), target.reshape(-1).to(device)
90 | 
91 | 
92 | 
93 |             optimizer.zero_grad()
94 |             x.requires_grad = True
95 |             target.requires_grad = False
96 |             target = target.long()
97 | 
98 |             feats, out = global_dnn(x)
99 | 
100 |             if args.re_version == 'v1':
101 |                 loss = criterion(out, target)
102 |             elif args.re_version == 'v2':
103 | 
104 |                 mix_posi = (posi_x - x) * args.posi_lambda + posi_x
105 |                 mix_nega = (nega_x - x) * args.nega_lambda + x
106 | 
107 |                 feats_posi, _ = global_dnn(mix_posi)
108 |                 feats_nega, _ = global_dnn(mix_nega)
109 | 
110 |                 loss1 = criterion(out, target)
111 |                 loss2 = args.re_mu * SCLoss(feats, target)
112 |                 loss3 = ACLoss(feats, feats_posi, feats_nega)
113 | 
114 |                 mixed_x, y_a, y_b, lam = mixup_data(x, target, args.posi_lambda)
115 | 
116 |                 _, out_mix = global_dnn(mixed_x)
117 | 
118 |                 # loss_mix = mixup_criterion(criterion, out_mix, y_a, y_b, lam)
119 | 
120 |                 loss = loss1 + loss2 + args.re_beta * loss3  # + 0.1 * loss_mix
121 | 
122 |             epoch_loss_collector.append(loss.item())
123 | 
124 |             loss.backward()
125 |             optimizer.step()
126 |         print(epoch, sum(epoch_loss_collector)/len(epoch_loss_collector))
127 | 
128 |     with torch.no_grad():
129 |         feats, out = global_dnn(prototypes)
130 | 
131 |         pred_label = torch.argmax(out.data, 1)
132 |         total = prototypes.data.size()[0]
133 |         correct = (pred_label == proto_labels.data).sum().item()
134 |         correct_id = torch.nonzero(pred_label == proto_labels.data).reshape(-1)
135 | 
136 |         protos, labels = gen_proto_global(feats[correct_id], proto_labels[correct_id], n_classes)
137 |         print('after', correct)
138 | 
139 |     return global_dnn, protos, labels
140 | 
141 | 
142 | 
143 | 
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | 
7 | 
8 | 
9 | class DNN(nn.Module):
10 |     def __init__(self, input_dim, hidden_dims, n_classes):
11 |         super(DNN, self).__init__()
12 |         self.l1 = nn.Linear(input_dim, hidden_dims[0])
13 |         self.l2 = nn.Linear(hidden_dims[0], hidden_dims[1])
14 |         self.l3 = nn.Linear(hidden_dims[1], n_classes)
15 | 
16 | 
17 |     def forward(self, x):
18 |         h1 = self.l1(x)
19 |         h1 = F.relu(h1)
20 |         h2 = self.l2(h1)
21 |         out = self.l3(h2)
22 |         return h2, out
23 | 
24 | 
25 | 
26 | 
27 | __all__ = ['resnet18', 'resnet34', 'resnet50']
28 | 
29 | 
30 | 
31 | model_urls = {
32 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
33 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
34 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
35 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
36 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
37 | }
38 | 
39 | 
40 | 
41 | 
42 | def conv3x3(in_planes, out_planes, stride=1, groups=1):
43 |     """3x3 convolution with padding"""
44 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
45 |                      padding=1, groups=groups, bias=False)
46 | 
47 | 
48 | def conv3x3_bn(in_planes, out_planes, stride=1, groups=1):
49 |     """3x3 convolution with padding"""
50 |     modules = nn.Sequential(
51 |         nn.BatchNorm2d(in_planes),
52 |         nn.ReLU(),
53 |         nn.Conv2d(in_planes, out_planes, kernel_size=3,
stride=stride, padding=1, groups=groups, bias=False), 54 | ) 55 | return modules 56 | 57 | 58 | def conv1x1(in_planes, out_planes, stride=1): 59 | """1x1 convolution""" 60 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 61 | 62 | 63 | def conv1x1_bn(in_planes, out_planes, stride=1, groups=1): 64 | modules = nn.Sequential( 65 | nn.BatchNorm2d(in_planes), 66 | nn.ReLU(), 67 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=groups, bias=False), 68 | ) 69 | return modules 70 | 71 | 72 | class BasicBlock(nn.Module): 73 | expansion = 1 74 | def __init__(self, inplanes, planes, stride=1, groups=1, 75 | base_width=64, norm_layer=None): 76 | super(BasicBlock, self).__init__() 77 | if norm_layer is None: 78 | norm_layer = nn.BatchNorm2d 79 | if groups != 1 or base_width != 64: 80 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 81 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 82 | self.conv1 = conv3x3(inplanes, planes, stride) 83 | self.bn1 = norm_layer(planes) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.conv2 = conv3x3(planes, planes) 86 | self.bn2 = norm_layer(planes) 87 | self.downsample = nn.Sequential() 88 | if stride != 1 or inplanes != self.expansion*planes: 89 | self.downsample = nn.Sequential( 90 | nn.Conv2d(inplanes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 91 | norm_layer(self.expansion*planes) 92 | ) 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | x = F.relu(x) 97 | out = self.conv1(x) 98 | out = self.bn1(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv2(out) 102 | out = self.bn2(out) 103 | 104 | identity = self.downsample(x) 105 | 106 | out += identity 107 | # out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class Bottleneck(nn.Module): 113 | expansion = 4 114 | def __init__(self, inplanes, planes, stride=1, groups=1, 115 | base_width=64, norm_layer=None): 116 | super(Bottleneck, self).__init__() 117 | if norm_layer is None: 118 | norm_layer = nn.BatchNorm2d 119 | width = int(planes * (base_width / 64.)) * groups 120 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 121 | self.conv1 = conv1x1(inplanes, width) 122 | self.bn1 = norm_layer(width) 123 | self.conv2 = conv3x3(width, width, stride, groups) 124 | self.bn2 = norm_layer(width) 125 | self.conv3 = conv1x1(width, planes * self.expansion) 126 | self.bn3 = norm_layer(planes * self.expansion) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.downsample = nn.Sequential() 129 | if stride != 1 or inplanes != self.expansion*planes: 130 | self.downsample = nn.Sequential( 131 | nn.Conv2d(inplanes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 132 | norm_layer(self.expansion*planes) 133 | ) 134 | 135 | self.stride = stride 136 | 137 | def forward(self, x): 138 | x = F.relu(x) 139 | out = self.conv1(x) 140 | out = self.bn1(out) 141 | out = self.relu(out) 142 | 143 | out = self.conv2(out) 144 | out = self.bn2(out) 145 | out = self.relu(out) 146 | 147 | out = self.conv3(out) 148 | out = self.bn3(out) 149 | 150 | identity = self.downsample(x) 151 | 152 | out += identity 153 | # out = self.relu(out) 154 | 155 | return out 156 | 157 | 158 | class ResNet(nn.Module): 159 | def __init__(self, block, kernel_size, layers, num_classes=1000, zero_init_residual=False, 160 | groups=1, width_per_group=64, norm_layer=None): 161 | super(ResNet, self).__init__() 162 | if norm_layer is None: 163 | norm_layer = 
nn.BatchNorm2d 164 | 165 | self.inplanes = 64 166 | self.groups = groups 167 | self.base_width = width_per_group 168 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=kernel_size, stride=2, padding=3, bias=False) 169 | self.bn1 = norm_layer(self.inplanes) 170 | self.relu = nn.ReLU(inplace=True) 171 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 172 | self.network_channels = [64 * block.expansion, 128 * block.expansion, 256 * block.expansion, 512 * block.expansion] 173 | 174 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 175 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 176 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 177 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 178 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 179 | self.fc = nn.Linear(512 * block.expansion, num_classes) 180 | self.l1= nn.Linear(512 * block.expansion, 512 * block.expansion) 181 | self.l2= nn.Linear(512 * block.expansion, 512 * block.expansion) 182 | self.l3 = nn.Linear(512 * block.expansion, num_classes) 183 | 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 187 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 188 | nn.init.constant_(m.weight, 1) 189 | nn.init.constant_(m.bias, 0) 190 | 191 | # Zero-initialize the last BN in each residual branch, 192 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 193 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 194 | if zero_init_residual: 195 | for m in self.modules(): 196 | if isinstance(m, Bottleneck): 197 | nn.init.constant_(m.bn3.weight, 0) 198 | elif isinstance(m, BasicBlock): 199 | nn.init.constant_(m.bn2.weight, 0) 200 | 201 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): 202 | if norm_layer is None: 203 | norm_layer = nn.BatchNorm2d 204 | 205 | layers = [] 206 | layers.append(block(self.inplanes, planes, stride, self.groups, self.base_width, norm_layer)) 207 | self.inplanes = planes * block.expansion 208 | for _ in range(1, blocks): 209 | layers.append(block(self.inplanes, planes, groups=self.groups, 210 | base_width=self.base_width, norm_layer=norm_layer)) 211 | 212 | return nn.Sequential(*layers) 213 | 214 | def forward(self, x, PH=True): 215 | x = self.conv1(x) 216 | x = self.bn1(x) 217 | # x = self.maxpool(x) 218 | 219 | out1 = self.layer1(x) 220 | out2 = self.layer2(out1) 221 | out3 = self.layer3(out2) 222 | out4 = self.layer4(out3) 223 | 224 | out = self.avgpool(F.relu(out4)) 225 | feats = out.view(out.size(0), -1) 226 | feats_out = self.fc(feats) 227 | feats_lin1 = self.l1(feats) 228 | feats_lin2 = self.l2(feats_lin1) 229 | out = self.l3(feats_lin2) 230 | return feats, feats_out, feats_lin2, out 231 | 232 | 233 | 234 | 235 | def resnet18(dataset, kernel_size=3, pretrained=False, **kwargs): 236 | """Constructs a ResNet-18 model. 
237 |     Args:
238 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
239 |     """
240 |     if pretrained:
241 |         model = ResNet(BasicBlock, kernel_size, [2, 2, 2, 2])
242 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
243 | 
244 |         print('done load model')
245 |     else:
246 |         model = ResNet(BasicBlock, kernel_size, [2, 2, 2, 2], **kwargs)
247 |     if dataset == 'tinyimagenet':
248 |         model.l3 = nn.Linear(model.l3.in_features, 200)
249 |     elif dataset == 'food101':
250 |         model.l3 = nn.Linear(model.l3.in_features, 101)
251 |     elif dataset == 'vireo172':
252 |         model.l3 = nn.Linear(model.l3.in_features, 172)
253 |     elif dataset == 'cifar100':
254 |         model.l3 = nn.Linear(model.l3.in_features, 100)
255 |     elif dataset == 'cifar10':
256 |         model.l3 = nn.Linear(model.l3.in_features, 10)
257 | 
258 |     return model
259 | 
260 | 
261 | def resnet34(dataset, kernel_size=3, pretrained=False, **kwargs):
262 |     """Constructs a ResNet-34 model.
263 |     Args:
264 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
265 |     """
266 |     if pretrained:
267 |         model = ResNet(BasicBlock, kernel_size, [3, 4, 6, 3])
268 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
269 | 
270 |         print('done load model')
271 |     else:
272 |         model = ResNet(BasicBlock, kernel_size, [3, 4, 6, 3], **kwargs)
273 | 
274 |     if dataset == 'tinyimagenet':
275 |         model.fc = nn.Linear(model.fc.in_features, 200)
276 |     elif dataset == 'food101':
277 |         model.fc = nn.Linear(model.fc.in_features, 101)
278 |     elif dataset == 'vireo172':
279 |         model.fc = nn.Linear(model.fc.in_features, 172)
280 |     elif dataset == 'cifar100':
281 |         model.fc = nn.Linear(model.fc.in_features, 100)
282 |     elif dataset == 'cifar10':
283 |         model.fc = nn.Linear(model.fc.in_features, 10)
284 |     return model
285 | 
286 | 
287 | def resnet50(dataset, kernel_size=3, pretrained=False, **kwargs):
288 |     """Constructs a ResNet-50 model.
289 |     Args:
290 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
291 |     """
292 |     if pretrained:
293 |         model = ResNet(Bottleneck, kernel_size, [3, 4, 6, 3])
294 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
295 | 
296 |         print('done load model')
297 |     else:
298 |         model = ResNet(Bottleneck, kernel_size, [3, 4, 6, 3], **kwargs)
299 |     if dataset == 'tinyimagenet':
300 |         model.fc = nn.Linear(model.fc.in_features, 200)
301 |     elif dataset == 'food101':
302 |         model.fc = nn.Linear(model.fc.in_features, 101)
303 |     elif dataset == 'vireo172':
304 |         model.fc = nn.Linear(model.fc.in_features, 172)
305 |     elif dataset == 'cifar100':
306 |         model.fc = nn.Linear(model.fc.in_features, 100)
307 |     elif dataset == 'cifar10':
308 |         model.fc = nn.Linear(model.fc.in_features, 10)
309 |     model_name = 'resnet50'
310 |     return model
311 | 
312 | 
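# NOTE: each constructor above/below swaps the final classification layer to match the
# dataset's class count. The kernel_size argument mirrors resnet18; its default of 3 is an
# assumption suited to the small-image datasets used in this repo (cifar10/100, tinyimagenet);
# callers with larger inputs (vireo172, food101) should pass kernel_size=7, as resnet18 does.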
313 | def resnet101(dataset, kernel_size=3, pretrained=False, **kwargs):
314 |     """Constructs a ResNet-101 model.
315 |     Args:
316 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
317 |     """
318 |     if kwargs.get('num_classes', 1000) != 1000 and pretrained:
319 |         model = ResNet(Bottleneck, kernel_size, [3, 4, 23, 3])
320 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
321 | 
322 |     else:
323 |         model = ResNet(Bottleneck, kernel_size, [3, 4, 23, 3], **kwargs)
324 |     if dataset == 'tinyimagenet':
325 |         model.fc = nn.Linear(model.fc.in_features, 200)
326 |     elif dataset == 'food101':
327 |         model.fc = nn.Linear(model.fc.in_features, 101)
328 |     elif dataset == 'vireo172':
329 |         model.fc = nn.Linear(model.fc.in_features, 172)
330 |     elif dataset == 'cifar100':
331 |         model.fc = nn.Linear(model.fc.in_features, 100)
332 |     elif dataset == 'cifar10':
333 |         model.fc = nn.Linear(model.fc.in_features, 10)
334 |     return model
335 | 
336 | 
337 | def resnet152(dataset, kernel_size=3, pretrained=False, **kwargs):
338 |     """Constructs a ResNet-152 model.
339 |     Args:
340 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
341 |     """
342 |     if kwargs.get('num_classes', 1000) != 1000 and pretrained:
343 |         model = ResNet(Bottleneck, kernel_size, [3, 8, 36, 3])
344 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
345 | 
346 |     else:
347 |         model = ResNet(Bottleneck, kernel_size, [3, 8, 36, 3], **kwargs)
348 |     if dataset == 'tinyimagenet':
349 |         model.fc = nn.Linear(model.fc.in_features, 200)
350 |     elif dataset == 'food101':
351 |         model.fc = nn.Linear(model.fc.in_features, 101)
352 |     elif dataset == 'vireo172':
353 |         model.fc = nn.Linear(model.fc.in_features, 172)
354 |     elif dataset == 'cifar100':
355 |         model.fc = nn.Linear(model.fc.in_features, 100)
356 |     elif dataset == 'cifar10':
357 |         model.fc = nn.Linear(model.fc.in_features, 10)
358 |     return model
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torch.utils.data as data
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | import random
11 | from sklearn.metrics import confusion_matrix
12 | import torch.optim as optim
13 | # from model import *
14 | from datasets import CIFAR10_truncated, CIFAR100_truncated, ImageFolder_custom, MNIST_truncated, FashionMNIST_truncated, TinyImageNet_load, Vireo172_truncated, Food101_truncated
15 | import ipdb
16 | import copy
17 | from kmeans import *
18 | from loss import *
19 | from clustering import *
20 | from sklearn import metrics
21 | 
22 | logging.basicConfig()
23 | logger = logging.getLogger()
24 | logger.setLevel(logging.INFO)
25 | 
26 | def get_updateModel_before(model, cal_dnn):
27 |     model_dict = model.state_dict()
28 |     dnn_dict = cal_dnn.state_dict()
29 |     # for k, v in dnn_dict.items():
30 |     #     if k in dnn_dict:
31 |     #         print(k)
32 |     # print('********************************')
33 |     # import ipdb; ipdb.set_trace()
34 |     shared_dict = {k: v for k, v in dnn_dict.items() if (k in model_dict)}
35 | 
36 |     model_dict.update(shared_dict)
37 |     model.load_state_dict(model_dict)
38 |     return model
39 | 
40 | 
41 | def get_updateModel_after(model, cal_dnn):
42 |     model_dict = model.state_dict()
43 |     dnn_dict = cal_dnn.state_dict()
44 |     # for k, v in dnn_dict.items():
45 |     #     if k in dnn_dict:
46 |     #         print(k)
47 |     # print('********************************')
48 |     # import ipdb; ipdb.set_trace()
49 |     shared_dict = {k: v for k, v in dnn_dict.items() if (k in model_dict)}
50

def sim_mat(features, prototypes, p_labels, t=0.1):
    '''Exponentiated cosine-similarity matrix between sample features and class
    prototypes, scaled by temperature t. `p_labels` is accepted for interface
    symmetry with the prototype tensor but is not used here.'''
    a_norm = features / features.norm(dim=1)[:, None]
    b_norm = prototypes / prototypes.norm(dim=1)[:, None]
    sim_matrix = torch.exp(torch.mm(a_norm, b_norm.transpose(0, 1)) / t)

    return sim_matrix


def mkdirs(dirpath):
    os.makedirs(dirpath, exist_ok=True)


def load_mnist_data(datadir):
    transform = transforms.Compose([transforms.ToTensor()])

    mnist_train_ds = MNIST_truncated(datadir, train=True, download=True, transform=transform)
    mnist_test_ds = MNIST_truncated(datadir, train=False, download=True, transform=transform)

    X_train, y_train = mnist_train_ds.data, mnist_train_ds.target
    X_test, y_test = mnist_test_ds.data, mnist_test_ds.target

    return (X_train, y_train, X_test, y_test)


def load_fmnist_data(datadir):
    transform = transforms.Compose([transforms.ToTensor()])

    fmnist_train_ds = FashionMNIST_truncated(datadir, train=True, download=True, transform=transform)
    fmnist_test_ds = FashionMNIST_truncated(datadir, train=False, download=True, transform=transform)

    X_train, y_train = fmnist_train_ds.data, fmnist_train_ds.target
    X_test, y_test = fmnist_test_ds.data, fmnist_test_ds.target

    return (X_train, y_train, X_test, y_test)


def load_cifar10_data(datadir):
    transform = transforms.Compose([transforms.ToTensor()])

    cifar10_train_ds = CIFAR10_truncated(datadir, train=True, download=True, transform=transform)
    cifar10_test_ds = CIFAR10_truncated(datadir, train=False, download=True, transform=transform)

    X_train, y_train = cifar10_train_ds.data, cifar10_train_ds.target
    X_test, y_test = cifar10_test_ds.data, cifar10_test_ds.target

    return (X_train, y_train, X_test, y_test)


def load_cifar100_data(datadir):
    transform = transforms.Compose([transforms.ToTensor()])

    cifar100_train_ds = CIFAR100_truncated(datadir, train=True, download=True, transform=transform)
    cifar100_test_ds = CIFAR100_truncated(datadir, train=False, download=True, transform=transform)

    X_train, y_train = cifar100_train_ds.data, cifar100_train_ds.target
    X_test, y_test = cifar100_test_ds.data, cifar100_test_ds.target

    return (X_train, y_train, X_test, y_test)


def load_tinyimagenet_data(datadir):
    # NOTE: the Tiny-ImageNet root is hard-coded here (and in get_dataloader
    # below); `datadir` is ignored, so adjust the path to your local setup.
    transform = transforms.Compose([transforms.ToTensor()])
    xray_train_ds = TinyImageNet_load('../datasets/tiny-imagenet-200/', train=True, transform=transform)
    xray_test_ds = TinyImageNet_load('../datasets/tiny-imagenet-200/', train=False, transform=transform)

    X_train, y_train = np.array([s[0] for s in xray_train_ds.samples]), np.array([int(s[1]) for s in xray_train_ds.samples])
    X_test, y_test = np.array([s[0] for s in xray_test_ds.samples]), np.array([int(s[1]) for s in xray_test_ds.samples])

    return (X_train, y_train, X_test, y_test)

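
# A usage sketch for sim_mat (defined earlier in this file) with illustrative
# shapes: rows are sample features, columns are class prototypes, and the
# argmax over the prototype axis is the knowledge-based class prediction used
# at evaluation time.
def _demo_sim_mat():
    feats = torch.randn(4, 16)    # 4 samples with 16-d representations
    protos = torch.randn(10, 16)  # one prototype per class
    labels = torch.arange(10)     # prototype labels (unused by sim_mat itself)
    sims = sim_mat(feats, protos, labels, t=0.1)
    pred = sims.argmax(dim=1)     # prototype-based class prediction
    assert sims.shape == (4, 10) and pred.shape == (4,)
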

def load_vireo_data():
    transform = transforms.Compose([transforms.ToTensor()])

    vireo_train_ds = Vireo172_truncated(transform=transform, mode='train')
    vireo_test_ds = Vireo172_truncated(transform=transform, mode='test')

    X_train, y_train = vireo_train_ds.path_to_images, vireo_train_ds.labels
    X_test, y_test = vireo_test_ds.path_to_images, vireo_test_ds.labels

    return (X_train, y_train, X_test, y_test)


def load_food_data():
    transform = transforms.Compose([transforms.ToTensor()])

    food_train_ds = Food101_truncated(transform=transform, mode='train')
    food_test_ds = Food101_truncated(transform=transform, mode='test')

    X_train, y_train = food_train_ds.path_to_images, food_train_ds.labels
    X_test, y_test = food_test_ds.path_to_images, food_test_ds.labels

    return (X_train, y_train, X_test, y_test)


def record_net_data_stats(y_train, net_dataidx_map, logdir):
    '''Log the per-party class counts and print the mean/std of party sizes.'''
    net_cls_counts = {}

    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True)
        net_cls_counts[net_i] = {unq[i]: unq_cnt[i] for i in range(len(unq))}

    party_sizes = [sum(counts.values()) for counts in net_cls_counts.values()]
    print('mean samples per party:', np.mean(party_sizes))
    print('std of samples per party:', np.std(party_sizes))
    logger.info('Data statistics: %s' % str(net_cls_counts))

    return net_cls_counts

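
# Before partition_data below, a self-contained sketch (numpy only; toy sizes,
# not the paper's settings) of how the Dirichlet concentration `beta` shapes
# the non-IID split: small beta concentrates a class on few parties, while a
# large beta approaches uniform shares.
def _demo_dirichlet_beta():
    rng = np.random.default_rng(0)
    for beta in (0.1, 0.5, 100.0):
        shares = rng.dirichlet(np.repeat(beta, 5))  # one class split across 5 parties
        print('beta=%s -> party shares %s' % (beta, np.round(shares, 2)))
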

def partition_data(dataset, datadir, logdir, partition, n_parties, beta=0.4):
    if dataset == 'cifar10':
        X_train, y_train, X_test, y_test = load_cifar10_data(datadir)
    elif dataset == 'mnist':
        X_train, y_train, X_test, y_test = load_mnist_data(datadir)
    elif dataset == 'fmnist':
        X_train, y_train, X_test, y_test = load_fmnist_data(datadir)
    elif dataset == 'cifar100':
        X_train, y_train, X_test, y_test = load_cifar100_data(datadir)
    elif dataset == 'tinyimagenet':
        X_train, y_train, X_test, y_test = load_tinyimagenet_data(datadir)
    elif dataset == 'vireo172':
        X_train, y_train, X_test, y_test = load_vireo_data()
    elif dataset == 'food101':
        X_train, y_train, X_test, y_test = load_food_data()

    n_train = y_train.shape[0]

    if partition == "homo" or partition == "iid":
        idxs = np.random.permutation(n_train)
        batch_idxs = np.array_split(idxs, n_parties)
        net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)}

    elif partition == "noniid-labeldir" or partition == "noniid":
        min_size = 0
        min_require_size = 10
        K = 10
        if dataset == 'cifar100':
            K = 100
        elif dataset == 'tinyimagenet':
            K = 200
        elif dataset == 'vireo172':
            K = 172
        elif dataset == 'food101':
            K = 101

        N_train = y_train.shape[0]

        net_dataidx_map = {}
        # Resample the Dirichlet split until every party holds at least
        # min_require_size training samples.
        while min_size < min_require_size:
            idx_batch = [[] for _ in range(n_parties)]
            for k in range(K):
                idx_k = np.where(y_train == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(beta, n_parties))
                # Zero the share of any party that already holds its fair share
                # (N_train / n_parties) so no single party grows unboundedly.
                proportions = np.array([p * (len(idx_j) < N_train / n_parties) for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                # Convert shares into split points over the shuffled class indices.
                proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
            min_size = min([len(idx_j) for idx_j in idx_batch])

        for j in range(n_parties):
            np.random.shuffle(idx_batch[j])
            net_dataidx_map[j] = idx_batch[j]

    traindata_cls_counts = record_net_data_stats(y_train, net_dataidx_map, logdir)

    return (X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts)


def get_trainable_parameters(net, device='cpu'):
    '''Flatten the values of all trainable parameters of `net` into a single 1-D float64 vector.'''
    paramlist = list(filter(lambda p: p.requires_grad, net.parameters()))
    N = 0
    for params in paramlist:
        N += params.numel()
    X = torch.zeros(N, dtype=torch.float64, device=device)
    offset = 0
    for params in paramlist:
        numel = params.numel()
        with torch.no_grad():
            X[offset:offset + numel].copy_(params.data.view_as(X[offset:offset + numel].data))
        offset += numel
    return X


def put_trainable_parameters(net, X):
    '''Write a vector produced by get_trainable_parameters back into the trainable parameters of `net`.'''
    paramlist = list(filter(lambda p: p.requires_grad, net.parameters()))
    offset = 0
    for params in paramlist:
        numel = params.numel()
        with torch.no_grad():
            params.data.copy_(X[offset:offset + numel].data.view_as(params.data))
        offset += numel

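
# A round-trip sketch for the two helpers above (any small nn.Module works;
# nn.Linear here is illustrative): flatten the trainable parameters, edit the
# vector, and write it back.
def _demo_param_roundtrip():
    net = nn.Linear(4, 3)
    vec = get_trainable_parameters(net)       # 1-D float64 vector of 4*3 + 3 entries
    put_trainable_parameters(net, vec * 0.0)  # zero every trainable parameter
    assert net.weight.abs().sum().item() == 0.0
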

def compute_accuracy_v6(model, glo_proto, glo_proto_label, dataloader, args, get_confusion_matrix=False, device="cpu", multiloader=False):
    '''Evaluate three predictions per batch: the classifier logits, the
    prototype-similarity scores, and their weighted fusion
    args.final_weights * logits + (1 - args.final_weights) * similarities.
    (`multiloader` is accepted for interface compatibility but unused.)'''
    was_training = False
    if model.training:
        model.eval()
        was_training = True

    true_labels_list, pred_labels_list = np.array([]), np.array([])

    correct, total = 0, 0
    correct_out, correct_sim = 0, 0
    if device == 'cpu':
        criterion = nn.CrossEntropyLoss()
    elif "cuda" in str(device):  # accept both torch.device objects and strings such as 'cuda:0'
        criterion = nn.CrossEntropyLoss().to(device)
    loss_collector = []

    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(dataloader):
            if device != 'cpu':
                x, target = x.to(device), target.to(dtype=torch.int64).to(device)
            _, _, feats, out = model(x)
            loss = criterion(out, target)
            sim_matrix = sim_mat(feats, glo_proto, glo_proto_label, args.temp_final)

            # Knowledge-based fusion of classifier logits and prototype similarities.
            final_out = args.final_weights * out + (1 - args.final_weights) * sim_matrix

            _, pred_out = torch.max(out.data, 1)
            _, pred_sim = torch.max(sim_matrix.data, 1)
            _, pred_label = torch.max(final_out.data, 1)

            loss_collector.append(loss.item())
            total += x.data.size()[0]
            correct_out += (pred_out == target.data).sum().item()
            correct_sim += (pred_sim == target.data).sum().item()
            correct += (pred_label == target.data).sum().item()

            if device == "cpu":
                pred_labels_list = np.append(pred_labels_list, pred_label.numpy())
                true_labels_list = np.append(true_labels_list, target.data.numpy())
            else:
                pred_labels_list = np.append(pred_labels_list, pred_label.cpu().numpy())
                true_labels_list = np.append(true_labels_list, target.data.cpu().numpy())
        avg_loss = sum(loss_collector) / len(loss_collector)

    if get_confusion_matrix:
        conf_matrix = confusion_matrix(true_labels_list, pred_labels_list)

    if was_training:
        model.train()

    if get_confusion_matrix:
        return correct_out / float(total), correct_sim / float(total), correct / float(total), conf_matrix, avg_loss

    return correct_out / float(total), correct_sim / float(total), correct / float(total), avg_loss

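
# A toy numeric check of the fusion rule in compute_accuracy_v6 above (the
# weight 0.7 is illustrative, not a tuned value from the paper): a sample the
# classifier head gets wrong can be recovered by strong prototype evidence.
def _demo_prediction_fusion():
    out = torch.tensor([[2.0, 1.0]])   # classifier logits favour class 0
    sims = torch.tensor([[0.1, 5.0]])  # prototype similarities favour class 1
    w = 0.7
    fused = w * out + (1 - w) * sims   # [[1.43, 2.20]] -> argmax is class 1
    assert fused.argmax(dim=1).item() == 1
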

def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None, dataidxs_test=None, noise_level=0):
    if dataset in ('cifar10', 'cifar100'):
        if dataset == 'cifar10':
            dl_obj = CIFAR10_truncated

            normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                             std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                # reflect-pad by 4 pixels so RandomCrop(32) yields shifted views
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.ColorJitter(brightness=noise_level),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize
            ])
            # data prep for test set
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                normalize])

        elif dataset == 'cifar100':
            dl_obj = CIFAR100_truncated

            normalize = transforms.Normalize(mean=[0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
                                             std=[0.2673342858792401, 0.2564384629170883, 0.27615047132568404])
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                normalize
            ])
            # data prep for test set
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                normalize])

        train_ds = dl_obj(datadir, dataidxs=dataidxs, train=True, transform=transform_train, download=True)
        test_ds = dl_obj(datadir, dataidxs=dataidxs_test, train=False, transform=transform_test, download=True)

        train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, drop_last=True, shuffle=True, num_workers=4, pin_memory=True)
        test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, num_workers=4, pin_memory=True)

    elif dataset == 'tinyimagenet':
        dl_obj = TinyImageNet_load
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.ColorJitter(brightness=noise_level),
            transforms.RandomCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # NOTE: the Tiny-ImageNet root is hard-coded (see load_tinyimagenet_data above).
        train_ds = dl_obj('../datasets/tiny-imagenet-200/', train=True, dataidxs=dataidxs, transform=transform_train)
        test_ds = dl_obj('../datasets/tiny-imagenet-200/', train=False, transform=transform_test)

        train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, drop_last=True, shuffle=True, num_workers=4, pin_memory=True)
        test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, num_workers=4, pin_memory=True)
    elif dataset == 'mnist':
        dl_obj = MNIST_truncated

        # `normalize` is defined but not included in the pipelines below,
        # so MNIST inputs are left unnormalised.
        normalize = transforms.Normalize((0.1307,), (0.3081,))

        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])
        # data prep for test set
        transform_test = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor()])

        train_ds = dl_obj(datadir, dataidxs=dataidxs, train=True, transform=transform_train, download=True)
        test_ds = dl_obj(datadir, train=False, transform=transform_test, download=True)

        train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, drop_last=True, shuffle=True, num_workers=4, pin_memory=True)
        test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, num_workers=4, pin_memory=True)
    elif dataset == 'fmnist':
        dl_obj = FashionMNIST_truncated

        # As for MNIST, `normalize` is defined but not applied.
        normalize = transforms.Normalize((0.1307,), (0.3081,))

        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])
        # data prep for test set
        transform_test = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor()])
        train_ds = dl_obj(datadir, dataidxs=dataidxs, train=True, transform=transform_train, download=True)
        test_ds = dl_obj(datadir, train=False, transform=transform_test, download=True)

        train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, drop_last=True, shuffle=True, num_workers=4, pin_memory=True)
        test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, num_workers=4, pin_memory=True)

    elif dataset == 'vireo172':
        dl_obj = Vireo172_truncated
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        train_ds = dl_obj(dataidxs, transform_train, mode='train')
        test_ds = dl_obj(None, transform_test, mode='test')

        train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, drop_last=True, shuffle=True, num_workers=4, pin_memory=True)
        test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, num_workers=4, pin_memory=True)

    elif dataset == 'food101':
        dl_obj = Food101_truncated
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        train_ds = dl_obj(dataidxs, transform_train, mode='train')
        test_ds = dl_obj(None, transform_test, mode='test')

        train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, drop_last=True, shuffle=True, num_workers=2, pin_memory=True)
        test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, num_workers=2, pin_memory=True)

    return train_dl, test_dl, train_ds, test_ds

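
# How partition_data and get_dataloader compose in a training script; a sketch
# mirroring the README's CIFAR-10 run (batch sizes here are illustrative).
def _demo_client_loader():
    X_tr, y_tr, X_te, y_te, idx_map, counts = partition_data(
        'cifar10', './data/', './logs/', 'noniid', n_parties=10, beta=0.5)
    # loader restricted to party 0's Dirichlet-assigned sample indices
    train_dl, test_dl, _, _ = get_dataloader('cifar10', './data/', 64, 32, dataidxs=idx_map[0])
    return train_dl, test_dl
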

def fix_bn(m):
    '''Put every BatchNorm layer into eval mode (apply with net.apply(fix_bn)).'''
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()


def dropout_proto_local_v2(net, dataloader, args, n_class=10, device='cuda:0'):
    '''Build multiple prototypes per class from a client's features by averaging
    randomly sampled subsets: args.number subsets, each covering a fraction
    args.ratio of the class, sampled with replacement. Classes with fewer than
    5 samples contribute a single mean prototype.'''
    feats = []
    labels = []
    net.eval()
    net.apply(fix_bn)
    net.to(device)
    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(dataloader):
            x, target = x.to(device), target.to(device)
            _, feat, _ = net(x)

            feats.append(feat)
            labels.extend(target)

    feats = torch.cat(feats)
    labels = torch.tensor(labels)
    prototype = []
    proto_label = []
    class_label = []
    class_idx = []
    for i in range(n_class):
        index = torch.nonzero(labels == i).reshape(-1)
        if len(index) > 0:
            class_idx.append(index)
            class_label.append(int(i))
        else:
            # class i is absent on this client; keep placeholders
            class_idx.append([-1])
            class_label.append(-1)

    for i in range(n_class):
        if i in class_label:
            if len(class_idx[i]) >= 5:
                for j in range(args.number):
                    idx = np.random.choice(np.arange(len(class_idx[i])), int(len(class_idx[i]) * args.ratio))
                    feature_classwise = feats[class_idx[i][idx]]
                    prototype.append(torch.mean(feature_classwise, axis=0).reshape((1, -1)))
                    proto_label.append(int(i))
            else:
                proto_label.append(int(i))
                feature_classwise = feats[class_idx[i]]
                prototype.append(torch.mean(feature_classwise, axis=0).reshape((1, -1)))

    return torch.cat(prototype, dim=0), torch.tensor(proto_label)


def gen_proto_global(feats, labels, n_classes):
    '''Average the uploaded prototypes per class into global prototypes
    (assumes every class in range(n_classes) appears in `labels`).'''
    local_proto = []
    local_labels = []
    for i in range(n_classes):
        c_i = torch.nonzero(labels == i).reshape(-1)
        proto_i = torch.sum(feats[c_i], dim=0) / len(c_i)
        local_proto.append(proto_i.reshape(1, -1))
        local_labels.append(i)

    return torch.cat(local_proto, dim=0), torch.tensor(local_labels)

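
# Before aug_protos below, a toy trace of the two mixing rules it applies
# (lam = 0.5 is illustrative): a same-class pair is extrapolated as
# (1 + lam) * p - lam * x, while a cross-class pair is combined as
# (1 - lam) * p - lam * x and labelled with the other prototype's class.
def _demo_proto_mixing():
    p = torch.tensor([1.0, 0.0])  # prototype of one class
    x = torch.tensor([0.0, 1.0])  # prototype of a different class
    lam = 0.5
    positive = (1 + lam) * p - lam * x  # tensor([1.5000, -0.5000])
    negative = (1 - lam) * p - lam * x  # tensor([0.5000, -0.5000])
    return positive, negative
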

def aug_protos(local_protos, local_labels, posi_lambda, nega_lambda, n_classes=10):
    '''Prototype augmentation for the CSPC module: for every ordered pair of
    prototypes, same-class pairs are extrapolated into extra positives and
    cross-class pairs are mixed into extra hard negatives.
    NOTE: as written, `nega_lambda` is unused; both branches mix with `posi_lambda`.'''
    mixed_protos = []
    mixed_labels = []
    for p_id, proto in enumerate(local_protos):
        for x_id, x in enumerate(local_protos):
            if x_id != p_id:
                if local_labels[p_id] == local_labels[x_id]:
                    mixed_protos.append((1 + posi_lambda) * proto - posi_lambda * x)
                    mixed_labels.append(local_labels[p_id])
                else:
                    mixed_protos.append((1 - posi_lambda) * proto - posi_lambda * x)
                    mixed_labels.append(local_labels[x_id])
    mixed_protos = torch.stack(mixed_protos).cuda()
    mixed_labels = torch.tensor(mixed_labels).cuda()

    final_protos = torch.cat([local_protos, mixed_protos]).cuda()
    final_labels = torch.cat([local_labels, mixed_labels]).cuda()

    return final_protos, final_labels


def mixup_data(x, y, alpha):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.'''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    batch_size = x.size()[0]
    index = torch.randperm(batch_size, device=x.device)  # works on CPU as well as GPU
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

--------------------------------------------------------------------------------