├── figure ├── system-overview.png ├── three-stage-framework.pdf └── three-stage-framework.png ├── supervise-fl-node ├── run_supervise_node.sh ├── run_supervise_node_all.sh ├── communication.py ├── tdnn.py ├── util.py ├── data_pre.py ├── model.py └── supervise_main_node.py ├── unsupervise-fl-node ├── run_unsupervise_node.sh ├── run_unsupervise_node_all.sh ├── data_pre.py ├── communication.py ├── tdnn.py ├── util.py ├── cosmo_design.py ├── model.py └── unsupervise_main_node.py ├── LICENSE ├── dataset.md ├── README.md ├── supervise-fl-server └── supervise_main_server.py └── unsupervise-fl-server └── unsupervise_main_server.py /figure/system-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmouyang/ADMarker/HEAD/figure/system-overview.png -------------------------------------------------------------------------------- /figure/three-stage-framework.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmouyang/ADMarker/HEAD/figure/three-stage-framework.pdf -------------------------------------------------------------------------------- /figure/three-stage-framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmouyang/ADMarker/HEAD/figure/three-stage-framework.png -------------------------------------------------------------------------------- /supervise-fl-node/run_supervise_node.sh: -------------------------------------------------------------------------------- 1 | python3 ./supervise_main_node.py --node_id 0 --batch_size 8 --epochs 101 --fl_epoch 10 2 | -------------------------------------------------------------------------------- /unsupervise-fl-node/run_unsupervise_node.sh: -------------------------------------------------------------------------------- 1 | python3 ./unsupervise_main_node.py --node_id 0 --num_of_samples 3000 --batch_size 16 --epochs 51 --fl_epoch 10 2 | -------------------------------------------------------------------------------- /supervise-fl-node/run_supervise_node_all.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 ./supervise_main_node.py --node_id 0 --batch_size 8 --epochs 101 --fl_epoch 10 & 2 | CUDA_VISIBLE_DEVICES=1 python3 ./supervise_main_node.py -node_id 1 --batch_size 8 --epochs 101 --fl_epoch 10 & 3 | CUDA_VISIBLE_DEVICES=2 python3 ./supervise_main_node.py --node_id 2 --batch_size 8 --epochs 101 --fl_epoch 10 & 4 | CUDA_VISIBLE_DEVICES=3 python3 ./supervise_main_node.py --node_id 3 --batch_size 8 --epochs 101 --fl_epoch 10 & 5 | -------------------------------------------------------------------------------- /unsupervise-fl-node/run_unsupervise_node_all.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 ./unsupervise_main_node.py --node_id 0 --num_of_samples 3000 --batch_size 16 --epochs 101 --fl_epoch 10 & 2 | CUDA_VISIBLE_DEVICES=1 python3 ./unsupervise_main_node.py --node_id 1 --num_of_samples 3000 --batch_size 16 --epochs 101 --fl_epoch 10 & 3 | CUDA_VISIBLE_DEVICES=2 python3 ./unsupervise_main_node.py --node_id 2 --num_of_samples 3000 --batch_size 16 --epochs 101 --fl_epoch 10 & 4 | CUDA_VISIBLE_DEVICES=3 python3 ./unsupervise_main_node.py --node_id 3 --num_of_samples 3000 --batch_size 16 --epochs 101 --fl_epoch 10 & 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Xiaomin OUYANG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /unsupervise-fl-node/data_pre.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import os 5 | 6 | 7 | 8 | 9 | class Multimodal_unlabel_dataset(): 10 | """Build dataset from motion sensor data.""" 11 | def __init__(self, node_id, num_of_data): 12 | 13 | self.folder_path = "../AD-example-data/node{}/train_unlabel_data/".format(node_id) 14 | 15 | self.num_of_data = num_of_data 16 | 17 | 18 | def __len__(self): 19 | return self.num_of_data 20 | 21 | def __getitem__(self, idx): 22 | 23 | # print("idx:", idx) 24 | 25 | x1 = np.load(self.folder_path + "audio/" + "{}.npy".format(idx)) 26 | x2 = np.load(self.folder_path + "depth/" + "{}.npy".format(idx)) 27 | x3 = np.load(self.folder_path + "radar/" + "{}.npy".format(idx)) 28 | 29 | self.data1 = x1.tolist() #concate and tolist 30 | self.data2 = x2.tolist() #concate and tolist 31 | self.data3 = x3.tolist() 32 | 33 | sensor_data1 = torch.tensor(self.data1) # to tensor 34 | sensor_data2 = torch.tensor(self.data2).float() # to tensor 35 | sensor_data3 = torch.tensor(self.data3) # to tensor 36 | 37 | sensor_data2 = torch.unsqueeze(sensor_data2, 0) 38 | 39 | return sensor_data1, sensor_data2, sensor_data3 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /supervise-fl-node/communication.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import pickle, struct 3 | import sys 4 | from threading import Lock, Thread 5 | import threading 6 | 7 | 8 | class COMM: 9 | def __init__(self, host, port, user_id): 10 | self.host = host 11 | self.port = port 12 | self.id = user_id 13 | self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 14 | self.client.connect((host,port)) 15 | 16 | # the mess_type defines the content sent to server 17 | # -1 means start request 18 | # 0 means W 19 | # 1 means loss 20 | # 9 means straggler end connection 21 | # 10 means end connection 22 | def send2server(self,content,mess_type): 23 | data = pickle.dumps(content, protocol = 0) 24 | size = sys.getsizeof(data) 25 | 26 | header = struct.pack("i",size) 27 | u_id = struct.pack("i",self.id) 28 | mess_type = struct.pack("i",mess_type) 29 | 30 | self.client.sendall(header) 31 | self.client.sendall(u_id) 32 | self.client.sendall(mess_type) 33 | self.client.sendall(data) 34 | 35 | def recvfserver(self): 36 | header = self.client.recv(4) 37 | size = struct.unpack('i',header) 38 | 39 | recv_data = b"" 40 | while sys.getsizeof(recv_data) np.asarray(args.lr_decay_epochs)) 62 | if steps > 0: 63 | lr = lr * (args.lr_decay_rate ** steps) 64 | 65 | for param_group in optimizer.param_groups: 66 | param_group['lr'] = lr 67 | 68 | 69 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 70 | if args.warm and epoch <= args.warm_epochs: 71 | p = (batch_id + (epoch - 1) * total_batches) / \ 72 | (args.warm_epochs * total_batches) 73 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 74 | 75 | for param_group in optimizer.param_groups: 76 | param_group['lr'] = lr 77 | 78 | 79 | def set_optimizer(opt, model): 80 | # optimizer = optim.SGD(model.parameters(), 81 | # lr=opt.learning_rate, 82 | # momentum=opt.momentum, 83 | # weight_decay=opt.weight_decay) 84 | optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) 85 | return optimizer 86 | 87 | def save_model(model, optimizer, opt, epoch, save_file): 88 | print('==> Saving...') 89 | state = { 90 | 'opt': opt, 91 | 'model': model.state_dict(), 92 | 'optimizer': optimizer.state_dict(), 93 | 'epoch': epoch, 94 | } 95 | torch.save(state, save_file) 96 | del state 97 | -------------------------------------------------------------------------------- /unsupervise-fl-node/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | 8 | 9 | class TwoCropTransform: 10 | """Create two crops of the same image""" 11 | def __init__(self, transform): 12 | self.transform = transform 13 | 14 | def __call__(self, x): 15 | return [self.transform(x), self.transform(x)] 16 | 17 | 18 | class AverageMeter(object): 19 | """Computes and stores the average and current value""" 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | 36 | def accuracy(output, target, topk=(1,)): 37 | """Computes the accuracy over the k top predictions for the specified values of k""" 38 | with torch.no_grad(): 39 | maxk = max(topk) 40 | batch_size = target.size(0) 41 | _, pred = output.topk(maxk, 1, True, True) 42 | pred = pred.t() 43 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 44 | 45 | # print(correct) 46 | 47 | res = [] 48 | for k in topk: 49 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 50 | res.append(correct_k.mul_(100.0 / batch_size)) 51 | return res 52 | 53 | 54 | def adjust_learning_rate(args, optimizer, epoch): 55 | lr = args.learning_rate 56 | if args.cosine: 57 | eta_min = lr * (args.lr_decay_rate ** 3) 58 | lr = eta_min + (lr - eta_min) * ( 59 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 60 | else: 61 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 62 | if steps > 0: 63 | lr = lr * (args.lr_decay_rate ** steps) 64 | 65 | for param_group in optimizer.param_groups: 66 | param_group['lr'] = lr 67 | 68 | 69 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 70 | if args.warm and epoch <= args.warm_epochs: 71 | p = (batch_id + (epoch - 1) * total_batches) / \ 72 | (args.warm_epochs * total_batches) 73 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 74 | 75 | for param_group in optimizer.param_groups: 76 | param_group['lr'] = lr 77 | 78 | 79 | def set_optimizer(opt, model): 80 | # optimizer = optim.SGD(model.parameters(), 81 | # lr=opt.learning_rate, 82 | # momentum=opt.momentum, 83 | # weight_decay=opt.weight_decay) 84 | optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) 85 | return optimizer 86 | 87 | def save_model(model, optimizer, opt, epoch, save_file): 88 | print('==> Saving...') 89 | state = { 90 | 'opt': opt, 91 | 'model': model.state_dict(), 92 | 'optimizer': optimizer.state_dict(), 93 | 'epoch': epoch, 94 | } 95 | torch.save(state, save_file) 96 | del state 97 | -------------------------------------------------------------------------------- /supervise-fl-node/data_pre.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import os 5 | 6 | 7 | class Multimodal_train_dataset(): 8 | """Build dataset from motion sensor data.""" 9 | def __init__(self, node_id, num_of_samples): 10 | 11 | self.folder_path = "../AD-example-data/node{}/train_label_data/".format(node_id) 12 | y = np.load("../AD-example-data/node{}/train_label_data/label.npy".format(node_id)) 13 | 14 | self.labels = y.tolist() #tolist 15 | self.labels = torch.tensor(self.labels).long() 16 | self.num_of_samples = num_of_samples 17 | 18 | 19 | def __len__(self): 20 | 21 | if self.num_of_samples < len(self.labels): 22 | return self.num_of_samples 23 | else: 24 | return len(self.labels) 25 | 26 | def __getitem__(self, idx): 27 | 28 | # print("idx:", idx) 29 | 30 | x1 = np.load(self.folder_path + "audio/" + "{}.npy".format(idx)) 31 | x2 = np.load(self.folder_path + "depth/" + "{}.npy".format(idx)) 32 | x3 = np.load(self.folder_path + "radar/" + "{}.npy".format(idx)) 33 | 34 | self.data1 = x1.tolist() #concate and tolist 35 | self.data2 = x2.tolist() #concate and tolist 36 | self.data3 = x3.tolist() 37 | 38 | sensor_data1 = torch.tensor(self.data1) # to tensor 39 | sensor_data2 = torch.tensor(self.data2).float() # to tensor 40 | sensor_data3 = torch.tensor(self.data3) # to tensor 41 | 42 | sensor_data2 = torch.unsqueeze(sensor_data2, 0) 43 | 44 | activity_label = self.labels[idx] 45 | 46 | return sensor_data1, sensor_data2, sensor_data3, activity_label 47 | 48 | 49 | class Multimodal_test_dataset(): 50 | """Build dataset from motion sensor data.""" 51 | def __init__(self, node_id): 52 | 53 | self.folder_path = "../AD-example-data/node{}/test_data/".format(node_id) 54 | y = np.load("../AD-example-data/node{}/test_data/label.npy".format(node_id)) 55 | 56 | self.labels = y.tolist() #tolist 57 | self.labels = torch.tensor(self.labels).long() 58 | 59 | 60 | def __len__(self): 61 | return len(self.labels) 62 | 63 | def __getitem__(self, idx): 64 | 65 | # print("idx:", idx) 66 | 67 | x1 = np.load(self.folder_path + "audio/" + "{}.npy".format(idx)) 68 | x2 = np.load(self.folder_path + "depth/" + "{}.npy".format(idx)) 69 | x3 = np.load(self.folder_path + "radar/" + "{}.npy".format(idx)) 70 | 71 | self.data1 = x1.tolist() #concate and tolist 72 | self.data2 = x2.tolist() #concate and tolist 73 | self.data3 = x3.tolist() 74 | 75 | sensor_data1 = torch.tensor(self.data1) # to tensor 76 | sensor_data2 = torch.tensor(self.data2).float() # to tensor 77 | sensor_data3 = torch.tensor(self.data3) # to tensor 78 | 79 | sensor_data2 = torch.unsqueeze(sensor_data2, 0) 80 | 81 | activity_label = self.labels[idx] 82 | 83 | 84 | return sensor_data1, sensor_data2, sensor_data3, activity_label 85 | 86 | def count_num_per_class(node_id, num_class, num_of_samples): 87 | 88 | original_label = np.load("../AD-example-data/node{}/train_label_data/label.npy".format(node_id)) 89 | 90 | if num_of_samples < original_label.shape[0]: 91 | y = original_label[0:num_of_samples] 92 | else: 93 | y = original_label 94 | 95 | count_y = np.bincount(np.array(y).astype(int), minlength = num_class).astype(float) 96 | 97 | for idx in range(count_y.shape[0]): 98 | if count_y[idx] == 0: 99 | count_y[idx] = 0.5 100 | 101 | return count_y, y, len(y) 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /unsupervise-fl-node/cosmo_design.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | np.random.seed(0) 9 | 10 | fusion_weight = np.array([[0.29944117, 0.56794159, 0.13261724], 11 | [0.02646727, 0.23987183, 0.7336609 ], 12 | [0.43845882, 0.49335323, 0.06818795], 13 | [0.35973997, 0.56481998, 0.07544006], 14 | [0.22159766, 0.48902021, 0.28938213], 15 | [0.19107792, 0.47045704, 0.33846504], 16 | [0.44423391, 0.54002756, 0.01573852], 17 | [0.30068799, 0.38804443, 0.31126758], 18 | [0.46167461, 0.21458298, 0.32374241]]) 19 | 20 | ## audio, depth, radar 21 | def FeatureConstructor(f1, f2, f3, num_positive): 22 | 23 | # fusion_weight = np.arange(1, num_positive + 1) / (num_positive + 1)#(0.1, 0,2, ..., 0.9) 24 | fusion_weight = np.random.dirichlet((1, 1, 1), num_positive) 25 | # print("fusion_weight: ",fusion_weight) 26 | 27 | fused_feature = [] 28 | 29 | for fuse_id in range(num_positive): 30 | 31 | temp_fuse = fusion_weight[fuse_id, 0] * f1 + fusion_weight[fuse_id, 1] * f2 + fusion_weight[fuse_id, 2] * f3 32 | # temp_fuse = torch.cat((fusion_weight[fuse_id, 0] * f1, fusion_weight[fuse_id, 1] * f2, fusion_weight[fuse_id, 2] * f3), dim=1) #concate 33 | 34 | fused_feature.append(temp_fuse) 35 | 36 | fused_feature = torch.stack(fused_feature, dim = 1) 37 | 38 | return fused_feature 39 | 40 | 41 | ## contrastive loss with supervised format 42 | 43 | class CosmoLoss(nn.Module): 44 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 45 | It also supports the unsupervised contrastive loss in SimCLR""" 46 | def __init__(self, temperature=0.07, contrast_mode='all', 47 | base_temperature=0.07): 48 | super(CosmoLoss, self).__init__() 49 | self.temperature = temperature 50 | self.contrast_mode = contrast_mode 51 | self.base_temperature = base_temperature 52 | 53 | def forward(self, features, labels=None, mask=None): 54 | """Compute loss for model. If both `labels` and `mask` are None, 55 | it degenerates to SimCLR unsupervised loss: 56 | https://arxiv.org/pdf/2002.05709.pdf 57 | 58 | Args: 59 | features: hidden vector of shape [bsz, n_views, ...]. 60 | labels: ground truth of shape [bsz]. 61 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 62 | has the same class as sample i. Can be asymmetric. 63 | Returns: 64 | A loss scalar. 65 | """ 66 | device = (torch.device('cuda') 67 | if features.is_cuda 68 | else torch.device('cpu')) 69 | 70 | if len(features.shape) < 3: 71 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 72 | 'at least 3 dimensions are required') 73 | if len(features.shape) > 3: 74 | features = features.view(features.shape[0], features.shape[1], -1) 75 | 76 | batch_size = features.shape[0] 77 | 78 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 79 | 80 | contrast_count = features.shape[1] 81 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)# change to [n_views*bsz, 3168] 82 | contrast_feature = F.normalize(contrast_feature, dim = 1) 83 | 84 | # print(contrast_feature.shape) 85 | 86 | anchor_feature = contrast_feature 87 | anchor_count = contrast_count 88 | 89 | # compute logits, z_i * z_a / T 90 | similarity_matrix = torch.div( 91 | torch.matmul(anchor_feature, contrast_feature.T), 92 | self.temperature) 93 | 94 | # for numerical stability 95 | # similarity_matrix = F.normalize(similarity_matrix, p=2, dim = 1) 96 | 97 | # tile mask 98 | mask = mask.repeat(anchor_count, contrast_count)# positive index 99 | # print(mask.shape)#[1151, 1152] (btz*9) 100 | 101 | # mask-out self-contrast cases 102 | logits_mask = torch.scatter( 103 | torch.ones_like(mask), 104 | 1, 105 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 106 | 0 107 | )#dig to 0, others to 1 (negative samples) 108 | # print(logits_mask.shape) 109 | 110 | mask = mask * logits_mask#positive samples except itself 111 | 112 | # compute log_prob 113 | exp_logits = torch.exp(similarity_matrix) * logits_mask #exp(z_i * z_a / T) 114 | # all_log_prob = torch.log(exp_logits.sum(1))# log(sum(exp(z_i * z_a / T))), need change to I\{i} later 115 | 116 | # SupCon out 117 | log_prob = similarity_matrix - torch.log(exp_logits.sum(1, keepdim=True)) 118 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)#sup_out 119 | 120 | # SupCon in 121 | # log_prob = torch.exp(similarity_matrix) / exp_logits.sum(1, keepdim=True) 122 | # mean_log_prob_pos = torch.log((mask * log_prob).sum(1) / mask.sum(1)) 123 | 124 | # loss 125 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 126 | # print("loss:",loss) 127 | 128 | loss = loss.view(anchor_count, batch_size).mean() 129 | # print("mean loss:",loss) 130 | 131 | return loss 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ADMarker 2 | This is the repo for MobiCom 2024 paper: "ADMarker: A Multi-Modal Federated Learning System for Monitoring Digital Biomarkers of Alzheimer’s Disease". 3 | 4 | # Citation 5 | The code and datasets of this project are made available for non-commercial, academic research only. If you would like to use the code or datasets of this project, please cite the following papers: 6 | ``` 7 | @article{ouyang2023admarker, 8 | title={ADMarker: A Multi-Modal Federated Learning System for Monitoring Digital Biomarkers of Alzheimer's Disease}, 9 | author={Ouyang, Xiaomin and Shuai, Xian and Li, Yang and Pan, Li and Zhang, Xifan and Fu, Heming and Wang, Xinyan and Cao, Shihua and Xin, Jiang and Mok, Hazel and others}, 10 | journal={arXiv preprint arXiv:2310.15301}, 11 | year={2023} 12 | } 13 | @article{ouyang2024admarker, 14 | title={ADMarker: A Multi-Modal Federated Learning System for Monitoring Digital Biomarkers of Alzheimer's Disease}, 15 | author={Ouyang, Xiaomin and Shuai, Xian and Li, Yang and Pan, Li and Zhang, Xifan and Fu, Heming and Wang, Xinyan and Cao, Shihua and Xin, Jiang and Mok, Hazel and others}, 16 | journal={Proceedings of the 30th Annual International Conference on Mobile Computing And Networking}, 17 | year={2024} 18 | } 19 | ``` 20 | # Requirements 21 | The program has been tested in the following environment: 22 | * Computing Clusters: Python 3.9.7, Pytorch 1.12.0, torchvision 0.13.0, CUDA Version 10.2, sklearn 0.24.2, numpy 1.20.3 23 | * Nvidia Xavier NX: Ubuntu 18.04.6, Python 3.6.9, Pytorch 1.8.0, CUDA Version 10.2, sklearn 0.24.2, numpy 1.19.5 24 |
25 | 26 | # ADMarker FL Overview 27 |

28 | 29 |

30 | 31 | First Stage: Centralized model pre-training 32 | 33 | Second Stage: Unsupervised multi-modal federated learning 34 | * Client: 35 | * Local unsupervised multimodal training with contrastive fusion learning 36 | * Send model weights to the server. 37 | * Server: 38 | * Aggregate model weights of different modalities with Fedavg; 39 | * Send the aggregated model weights to each client. 40 | 41 | Third Stage: Supervised multi-modal federated learning 42 | * Client: 43 | * Local fusion: train the classifier layers with labeled data; 44 | * Send model weights to the server. 45 | * Server: 46 | * Send the aggregated model weights to each client. 47 | 48 | 49 | # Project Strcuture 50 | ``` 51 | |--unsupervise-fl-node // codes running unsupervised FL on clients 52 | 53 | |-- run_unsupervise_node.sh/ // run unsupervised FL of a client on a edge device 54 | |-- unsupervise_main_node.py/ // main file of running unsupervised FL on the client 55 | |-- communication.py/ // set up communication with server 56 | |-- data_pre.py/ // load the data for clients in FL 57 | |-- cosmo_design.py/ // contrastive fusion learning 58 | |-- model.py/ // model configurations 59 | |-- util.py // utility functions 60 | 61 | |--unsupervise-fl-server // codes running unsupervised FL on the server 62 | |-- unsupervise_main_server.py 63 | 64 | |--supervise-fl-node // codes running supervised FL on clients 65 | 66 | |-- run_supervise_node.sh/ // run supervised FL of a client on a edge device 67 | |-- supervise_main_node.py/ // main file of running supervised FL on the client 68 | |-- communication.py/ // set up communication with server 69 | |-- data_pre.py/ // load the data for clients in FL 70 | |-- model.py/ // model configurations 71 | |-- util.py // utility functions 72 | 73 | |--supervise-fl-server // codes running supervised FL on the server 74 | |-- supervise_main_server.py 75 | 76 | ``` 77 |
78 | 79 | # Quick Start 80 | * Download the codes for each dataset in this repo. Put the folder `unsupervise-fl-node` and `supervise-fl-node` on your client machines, and `unsupervise-fl-server` and `supervise-fl-server` on your server machine. 81 | * Download the `dataset` from [ADMarker-Example-Datasets](https://github.com/xmouyang/ADMarker/blob/main/dataset.md) to your client machines. Put the folder `under the same folder` with codes of running FL on clients. You can also change the path of loading datasets in 'data_pre.py' to the data path on your client machine. 82 | * Download the `pretrain_model.pth` from [pre-trained model weights](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155136315_link_cuhk_edu_hk/ESrMgTVfkMdFmaJWESmwxbYBsDDDrxwRxesVZzuY8deB2g?e=FV9HYH) to your client machines. Put the folder `under the same folder` with codes of running FL on clients. 83 | * Change the argument "server_address" in 'unsupervise_main_node.py' and 'supervise_main_node.py' as your true server address. If your server is located in the same physical machine of your nodes, you can choose "localhost" for this argument. 84 | * Run unsupervised federated learning on the clients and server: 85 | * Server: 86 | ```bash 87 | python3 unsupervise_main_server.py 88 | ``` 89 | * Client: change the 'node_id' (0,1, 2, ...) in the below script for each client 90 | ```bash 91 | ./run_unsupervise_node.sh 92 | ``` 93 | * Run supervised federated learning on the clients and server: 94 | * Server: 95 | ```bash 96 | python3 supervise_main_server.py 97 | ``` 98 | * Client: change the 'node_id' (0,1, 2, ...) in the below script for each client 99 | ```bash 100 | ./run_supervise_node.sh 101 | ``` 102 | * NOTE: The default codes corresponde to the settings with four nodes, as we only released data from four subjects due to the privacy concerns. If you want to adapt the codes to other datasets with more nodes, you should change the hyper-parameter `num_of_users` in `unsupervise_main_server.py` and `unsupervise_main_server.py`, as well as the `node_id` in `run_unsupervise_node.sh` and `run_supervise_node.sh ` 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /supervise-fl-server/supervise_main_server.py: -------------------------------------------------------------------------------- 1 | import socketserver 2 | import pickle, struct 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | from threading import Lock, Thread 8 | import threading 9 | import numpy as np 10 | 11 | 12 | # np.set_printoptions(threshold=np.inf) 13 | 14 | 15 | def parse_option(): 16 | parser = argparse.ArgumentParser('argument for training') 17 | 18 | ## FL 19 | parser.add_argument('--fl_round', type=int, default=10, 20 | help='communication to server after the epoch of local training') 21 | parser.add_argument('--num_of_users', type=int, default=4,##8 for audio 22 | help='num of users in FL') 23 | 24 | ## system 25 | parser.add_argument('--start_wait_time', type=int, default=300, 26 | help='start_wait_time') 27 | parser.add_argument('--W_wait_time', type=int, default=7200, 28 | help='W_wait_time') 29 | parser.add_argument('--end_wait_time', type=int, default=7200, 30 | help='end_wait_time') 31 | 32 | ## model 33 | parser.add_argument('--dim_weight_supervise', type=int, default = 41659216) 34 | 35 | opt = parser.parse_args() 36 | 37 | return opt 38 | 39 | 40 | opt = parse_option() 41 | 42 | 43 | iteration_count = 0 44 | trial_count = 0 45 | NUM_OF_WAIT = opt.num_of_users 46 | 47 | ## recieved all model weights 48 | weight_COLLECTION = np.zeros((opt.num_of_users, opt.dim_weight_supervise)) 49 | weight_MEAN = np.zeros(opt.dim_weight_supervise) 50 | 51 | Update_Flag = np.ones(opt.num_of_users) 52 | conver_indicator = 1e5 53 | 54 | wait_time_record = np.zeros(opt.fl_round) 55 | aggregation_time_record = np.zeros(opt.fl_round) 56 | downlink_time_record = np.zeros((opt.num_of_users, opt.fl_round)) 57 | server_start_time_record = np.zeros((opt.num_of_users, opt.fl_round)) 58 | 59 | 60 | def mmFedavg(opt, model_weight): 61 | 62 | mean_model_weight = np.mean(model_weight, axis = 0) 63 | 64 | return mean_model_weight 65 | 66 | 67 | def server_update(): 68 | 69 | global opt, iteration_count, weight_COLLECTION, weight_MEAN 70 | global aggregation_time_record, wait_time_record, server_start_time_record 71 | 72 | aggregate_time1 = time.time() 73 | wait_time_record[iteration_count] = aggregate_time1 - np.min(server_start_time_record[:, iteration_count]) 74 | print("server wait time:", wait_time_record[iteration_count]) 75 | 76 | ## mmFedavg for model weights 77 | print("Iteration {}: mmFedavg of model weights".format(iteration_count)) 78 | weight_MEAN = mmFedavg(opt, weight_COLLECTION) 79 | 80 | 81 | aggregate_time2 = time.time() 82 | aggregation_time_record[iteration_count] = aggregate_time2 - aggregate_time1 83 | print("server aggregation time:", aggregation_time_record[iteration_count]) 84 | 85 | iteration_count = iteration_count + 1 86 | print("iteration_count: ", iteration_count) 87 | 88 | 89 | def reinitialize(): 90 | 91 | global iteration_count 92 | # trial_count += 1 93 | iteration_count = 0 94 | # print("Trial: ", trial_count) 95 | 96 | global opt, NUM_OF_WAIT, wait_time_record, aggregation_time_record, server_start_time_record, downlink_time_record 97 | print("All of Server Wait Time:", np.sum(wait_time_record)) 98 | print("All of Server Aggregate Time:", np.sum(aggregation_time_record)) 99 | 100 | save_model_path = "./save_server_time_supervise_{}nodes/".format(opt.num_of_users) 101 | if not os.path.isdir(save_model_path): 102 | os.makedirs(save_model_path) 103 | 104 | np.savetxt(os.path.join(save_model_path, "aggregation_time_record.txt"), aggregation_time_record) 105 | np.savetxt(os.path.join(save_model_path, "wait_time_record.txt"), wait_time_record) 106 | np.savetxt(os.path.join(save_model_path, "server_start_time_record.txt"), server_start_time_record) 107 | np.savetxt(os.path.join(save_model_path, "downlink_time_record.txt"), downlink_time_record) 108 | 109 | wait_time_record = np.zeros(opt.fl_round) 110 | aggregation_time_record = np.zeros(opt.fl_round) 111 | server_start_time_record = np.zeros((opt.num_of_users, opt.fl_round)) 112 | downlink_time_record = np.zeros((opt.num_of_users, opt.fl_round)) 113 | 114 | opt = parse_option() 115 | NUM_OF_WAIT = opt.num_of_users 116 | 117 | global weight_COLLECTION, Update_Flag 118 | 119 | weight_COLLECTION = np.zeros((opt.num_of_users, opt.dim_weight_supervise)) 120 | Update_Flag = np.ones(opt.num_of_users) 121 | 122 | global weight_MEAN 123 | weight_MEAN = np.zeros(opt.dim_weight_supervise) 124 | 125 | barrier_update() 126 | 127 | 128 | barrier_start = threading.Barrier(NUM_OF_WAIT,action = None, timeout = None) 129 | barrier_W = threading.Barrier(NUM_OF_WAIT,action = server_update, timeout = None) 130 | barrier_end = threading.Barrier(NUM_OF_WAIT, action = reinitialize, timeout = None) 131 | 132 | def barrier_update(): 133 | global NUM_OF_WAIT 134 | print("update the barriers to NUM_OF_WAIT: ",NUM_OF_WAIT) 135 | global barrier_W 136 | barrier_W = threading.Barrier(NUM_OF_WAIT,action = server_update, timeout = None) 137 | global barrier_end 138 | barrier_end = threading.Barrier(NUM_OF_WAIT, action = reinitialize, timeout = None) 139 | 140 | 141 | class MyTCPHandler(socketserver.BaseRequestHandler): 142 | 143 | def send2node(self, var): 144 | 145 | var_data = pickle.dumps(var, protocol = 0) 146 | var_size = sys.getsizeof(var_data) 147 | var_header = struct.pack("i",var_size) 148 | self.request.sendall(var_header) 149 | self.request.sendall(var_data) 150 | 151 | return var_size 152 | 153 | 154 | def handle(self): 155 | while True: 156 | try: 157 | #receive the size of content 158 | header = self.request.recv(4) 159 | size = struct.unpack('i', header) 160 | 161 | #receive the id of client 162 | u_id = self.request.recv(4) 163 | temp_id = struct.unpack('i',u_id) 164 | 165 | user_id = int(temp_id[0]) 166 | print("user_id:", user_id) 167 | 168 | # receive the type of message, defination in communication.py 169 | mess_type = self.request.recv(4) 170 | mess_type = struct.unpack('i',mess_type)[0] 171 | 172 | #print("This is the {}th node with message type {}".format(user_id[0],mess_type)) 173 | 174 | #receive the body of message 175 | recv_data = b"" 176 | 177 | while sys.getsizeof(recv_data) 256: 101 | opt.warm = True 102 | if opt.warm: 103 | opt.model_name = '{}_warm'.format(opt.model_name) 104 | opt.warmup_from = 0.01 105 | opt.warm_epochs = 10 106 | if opt.cosine: 107 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 108 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 109 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 110 | else: 111 | opt.warmup_to = opt.learning_rate 112 | 113 | opt.save_folder = os.path.join(opt.node_path, opt.model_name, 'models') 114 | if not os.path.isdir(opt.save_folder): 115 | os.makedirs(opt.save_folder) 116 | 117 | opt.result_path = os.path.join(opt.node_path, opt.model_name, 'results/') 118 | if not os.path.isdir(opt.result_path): 119 | os.makedirs(opt.result_path) 120 | 121 | return opt 122 | 123 | 124 | def set_loader(opt): 125 | 126 | # construct data loader 127 | 128 | # load data (already normalized) 129 | train_dataset = data.Multimodal_unlabel_dataset(opt.node_id, opt.num_of_samples) 130 | 131 | train_loader = torch.utils.data.DataLoader( 132 | train_dataset, batch_size=opt.batch_size, 133 | num_workers=opt.num_workers, 134 | pin_memory=True, shuffle=True, drop_last=True) 135 | 136 | return train_loader 137 | 138 | 139 | def set_model(opt): 140 | 141 | model = My3Model_unsupervise() 142 | 143 | criterion = CosmoLoss(temperature=opt.temp) 144 | 145 | ## load model weights 146 | ckpt_path = os.path.join("../", opt.choose_model) 147 | ckpt = torch.load(ckpt_path, map_location='cpu') 148 | state_dict = ckpt['model'] 149 | 150 | if torch.cuda.is_available(): 151 | # if torch.cuda.device_count() > 1: 152 | new_state_dict = {} 153 | for k, v in state_dict.items(): 154 | k = k.replace("module.", "") 155 | k = k.replace("encoder.", "") 156 | # print(k) 157 | if "classifier" not in k:## only append weight of encoders 158 | new_state_dict[k] = v 159 | state_dict = new_state_dict 160 | model.encoder.load_state_dict(state_dict) 161 | 162 | 163 | # enable synchronized Batch Normalization 164 | if opt.syncBN: 165 | model = apex.parallel.convert_syncbn_model(model) 166 | 167 | if torch.cuda.is_available(): 168 | if torch.cuda.device_count() > 1: 169 | model.encoder = torch.nn.DataParallel(model.encoder) 170 | model = model.cuda() 171 | criterion = criterion.cuda() 172 | cudnn.benchmark = True 173 | 174 | return model, criterion 175 | 176 | 177 | def train(train_loader, model, criterion, optimizer, epoch, opt): 178 | """one epoch training""" 179 | model.train() 180 | 181 | batch_time = AverageMeter() 182 | data_time = AverageMeter() 183 | losses = AverageMeter() 184 | 185 | end = time.time() 186 | for idx, (input_data1, input_data2, input_data3) in enumerate(train_loader): 187 | 188 | data_time.update(time.time() - end) 189 | 190 | if torch.cuda.is_available(): 191 | input_data1 = input_data1.cuda() 192 | input_data2 = input_data2.cuda() 193 | input_data3 = input_data3.cuda() 194 | bsz = input_data1.shape[0] 195 | 196 | # compute loss 197 | feature1, feature2, feature3 = model(input_data1, input_data2, input_data3) 198 | 199 | features = FeatureConstructor(feature1, feature2, feature3, opt.num_positive) 200 | 201 | loss = criterion(features) 202 | 203 | # update metric 204 | losses.update(loss.item(), bsz) 205 | 206 | # SGD 207 | optimizer.zero_grad() 208 | loss.backward() 209 | optimizer.step() 210 | 211 | # measure elapsed time 212 | batch_time.update(time.time() - end) 213 | end = time.time() 214 | 215 | # print info 216 | if (idx + 1) % opt.print_freq == 0: 217 | print('Train: [{0}][{1}/{2}]\t' 218 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 219 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 220 | 'loss {loss.val:.3f} ({loss.avg:.3f})'.format( 221 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 222 | data_time=data_time, loss=losses)) 223 | sys.stdout.flush() 224 | 225 | return losses.avg 226 | 227 | 228 | def get_model_array(model): 229 | 230 | params = [] 231 | for param in model.parameters(): 232 | if torch.cuda.is_available(): 233 | params.extend(param.view(-1).cpu().detach().numpy()) 234 | else: 235 | params.extend(param.view(-1).detach().numpy()) 236 | # print(param) 237 | 238 | # model_params = params.cpu().numpy() 239 | model_params = np.array(params) 240 | print("Shape of model weight: ", model_params.shape)#39456 241 | 242 | return model_params 243 | 244 | 245 | 246 | def reset_model_parameter(new_params, model): 247 | 248 | temp_index = 0 249 | 250 | with torch.no_grad(): 251 | for param in model.parameters(): 252 | 253 | # print(param.shape) 254 | 255 | if len(param.shape) == 2: 256 | 257 | para_len = int(param.shape[0] * param.shape[1]) 258 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 259 | param.copy_(torch.from_numpy(temp_weight.reshape(param.shape[0], param.shape[1]))) 260 | temp_index += para_len 261 | 262 | elif len(param.shape) == 3: 263 | 264 | para_len = int(param.shape[0] * param.shape[1] * param.shape[2]) 265 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 266 | param.copy_(torch.from_numpy(temp_weight.reshape(param.shape[0], param.shape[1], param.shape[2]))) 267 | temp_index += para_len 268 | 269 | elif len(param.shape) == 4: 270 | 271 | para_len = int(param.shape[0] * param.shape[1] * param.shape[2] * param.shape[3]) 272 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 273 | param.copy_(torch.from_numpy(temp_weight.reshape(param.shape[0], param.shape[1], param.shape[2], param.shape[3]))) 274 | temp_index += para_len 275 | 276 | elif len(param.shape) == 5: 277 | 278 | para_len = int(param.shape[0] * param.shape[1] * param.shape[2] * param.shape[3] * param.shape[4]) 279 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 280 | param.copy_(torch.from_numpy(temp_weight.reshape(param.shape[0], param.shape[1], param.shape[2], param.shape[3], param.shape[4]))) 281 | temp_index += para_len 282 | 283 | else: 284 | 285 | para_len = param.shape[0] 286 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 287 | param.copy_(torch.from_numpy(temp_weight)) 288 | temp_index += para_len 289 | 290 | 291 | def set_commu(opt): 292 | 293 | #prepare the communication module 294 | server_addr = "172.22.172.75"#"10.54.20.19" 295 | # server_addr = "localhost" 296 | 297 | server_port = 30415 298 | 299 | comm = COMM(server_addr,server_port, opt.node_id) 300 | 301 | comm.send2server('hello',-1) 302 | 303 | print(comm.recvfserver()) 304 | 305 | return comm 306 | 307 | 308 | def main(): 309 | 310 | opt = parse_option() 311 | 312 | # set up communication with sevrer 313 | comm = set_commu(opt) 314 | 315 | # build data loader 316 | train_loader = set_loader(opt) 317 | 318 | # build model and criterion 319 | model, criterion = set_model(opt) 320 | w_parameter_init = get_model_array(model) 321 | 322 | # build optimizer 323 | optimizer = set_optimizer(opt, model) 324 | 325 | record_loss = np.zeros(opt.epochs) 326 | 327 | compute_time_record = np.zeros(opt.epochs) 328 | upper_commu_time_record = np.zeros(int(opt.epochs/opt.fl_epoch)) 329 | down_commu_time_record = np.zeros(int(opt.epochs/opt.fl_epoch)) 330 | 331 | all_time_record = np.zeros(opt.epochs + 2) 332 | all_time_record[0] = time.time() 333 | 334 | begin_time = time.time() 335 | 336 | # training routine 337 | for epoch in range(1, opt.epochs + 1): 338 | adjust_learning_rate(opt, optimizer, epoch) 339 | 340 | # train for one epoch 341 | time1 = time.time() 342 | loss = train(train_loader, model, criterion, optimizer, epoch, opt) 343 | time2 = time.time() 344 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 345 | 346 | # tensorboard logger 347 | record_loss[epoch-1] = loss 348 | compute_time_record[epoch-1] = time2 - time1 349 | 350 | # communication with the server every fl_epoch 351 | if (epoch % opt.fl_epoch) == 0: 352 | 353 | ## send model update to the server 354 | print("Node {} sends weight to the server:".format(opt.node_id)) 355 | w_parameter = get_model_array(model) #obtain the model parameters or gradients 356 | w_update = w_parameter - w_parameter_init 357 | 358 | comm_time1 = time.time() 359 | comm.send2server(w_update,0) 360 | comm_time2 = time.time() 361 | commu_epoch = int(epoch/opt.fl_epoch - 1) 362 | upper_commu_time_record[commu_epoch] = comm_time2 - comm_time1 363 | print("time for sending model weights:", comm_time2 - comm_time1) 364 | 365 | ## recieve aggregated model update from the server 366 | comm_time3 = time.time() 367 | new_w_update, sig_stop = comm.recvOUF() 368 | comm_time4 = time.time() 369 | down_commu_time_record[commu_epoch] = comm_time4 - comm_time3 370 | print("time for downloading model weights:", comm_time4 - comm_time3) 371 | print("Received weight from the server:", new_w_update.shape) 372 | # print("Received signal from the server:", sig_stop) 373 | 374 | ## update the model according to the received weights 375 | new_w = w_parameter_init + new_w_update 376 | reset_model_parameter(new_w, model) 377 | w_parameter_init = new_w 378 | 379 | # save model 380 | if epoch % opt.save_freq == 0: 381 | save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 382 | save_model(model, optimizer, opt, epoch, save_file) 383 | 384 | all_time_record[epoch] = time.time() 385 | 386 | np.savetxt(opt.result_path + "record_loss.txt", record_loss) 387 | np.savetxt(opt.result_path + "compute_time_record.txt", compute_time_record) 388 | np.savetxt(opt.result_path + "upper_commu_time_record.txt", upper_commu_time_record) 389 | np.savetxt(opt.result_path + "down_commu_time_record.txt", down_commu_time_record) 390 | np.savetxt(opt.result_path + "all_time_record.txt", all_time_record) 391 | 392 | end_time = time.time() 393 | all_time_record[epoch+1] = end_time - begin_time 394 | print("Total training delay: ", end_time - begin_time) 395 | 396 | comm.disconnect(1) 397 | 398 | # save the last model 399 | save_file = os.path.join(opt.save_folder, 'last.pth') 400 | save_model(model, optimizer, opt, opt.epochs, save_file) 401 | 402 | np.savetxt(opt.result_path + "record_loss.txt", record_loss) 403 | np.savetxt(opt.result_path + "compute_time_record.txt", compute_time_record) 404 | np.savetxt(opt.result_path + "upper_commu_time_record.txt", upper_commu_time_record) 405 | np.savetxt(opt.result_path + "down_commu_time_record.txt", down_commu_time_record) 406 | np.savetxt(opt.result_path + "all_time_record.txt", all_time_record) 407 | 408 | 409 | 410 | if __name__ == '__main__': 411 | main() 412 | -------------------------------------------------------------------------------- /supervise-fl-node/supervise_main_node.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | import math 8 | import numpy as np 9 | 10 | # import tensorboard_logger as tb_logger 11 | import torch 12 | import torch.nn as nn 13 | import torch.backends.cudnn as cudnn 14 | # from torchvision import transforms, datasets 15 | 16 | from util import AverageMeter 17 | from util import adjust_learning_rate, warmup_learning_rate, accuracy 18 | from util import set_optimizer, save_model 19 | 20 | from model import MySingleModel, My3Model 21 | import data_pre as data 22 | from sklearn.metrics import f1_score 23 | 24 | from communication import COMM 25 | 26 | 27 | try: 28 | import apex 29 | from apex import amp, optimizers 30 | except ImportError: 31 | pass 32 | 33 | 34 | def parse_option(): 35 | parser = argparse.ArgumentParser('argument for training') 36 | 37 | parser.add_argument('--node_id', type=int, default=0, 38 | help='node_id') 39 | parser.add_argument('--fl_epoch', type=int, default=10, 40 | help='communication to server after the epoch of local training') 41 | 42 | parser.add_argument('--print_freq', type=int, default=5, 43 | help='print frequency') 44 | parser.add_argument('--test_print_freq', type=int, default=5, 45 | help='test print frequency') 46 | parser.add_argument('--save_freq', type=int, default=10, 47 | help='save frequency') 48 | parser.add_argument('--batch_size', type=int, default=16, 49 | help='batch_size') 50 | parser.add_argument('--num_workers', type=int, default=8, 51 | help='num of workers to use') 52 | parser.add_argument('--epochs', type=int, default=101, 53 | help='number of training epochs') 54 | 55 | # optimization 56 | parser.add_argument('--learning_rate', type=float, default=1e-3, 57 | help='learning rate') 58 | parser.add_argument('--lr_decay_epochs', type=str, default='100,200,300', 59 | help='where to decay lr, can be a list') 60 | parser.add_argument('--lr_decay_rate', type=float, default=0.9, 61 | help='decay rate for learning rate') 62 | parser.add_argument('--weight_decay', type=float, default=1e-4, 63 | help='weight decay') 64 | parser.add_argument('--momentum', type=float, default=0.9, 65 | help='momentum') 66 | parser.add_argument('--kl_lamda', type=float, default=0.1, 67 | help='kl_lamda') 68 | 69 | # model dataset 70 | parser.add_argument('--choose_model', type=str, default="last.pth", 71 | help='choose_model') 72 | parser.add_argument('--load_model', type=str, default="encoder", 73 | help='load_model') 74 | parser.add_argument('--num_class', type=int, default=16, 75 | help='num_class') 76 | parser.add_argument('--num_of_samples', type=int, default=1000, 77 | help='num_of_samples') 78 | 79 | # other setting 80 | parser.add_argument('--cosine', action='store_true', 81 | help='using cosine annealing') 82 | parser.add_argument('--syncBN', action='store_true', 83 | help='using synchronized batch normalization') 84 | parser.add_argument('--warm', action='store_true', 85 | help='warm-up for large batch training') 86 | 87 | opt = parser.parse_args() 88 | 89 | 90 | opt.num_per_class, opt.train_labels, opt.num_of_train = data.count_num_per_class(opt.node_id, opt.num_class, opt.num_of_samples) 91 | print("num of train data:", len(opt.train_labels)) 92 | print("train data num_per_class:", opt.num_per_class) 93 | 94 | 95 | iterations = opt.lr_decay_epochs.split(',') 96 | opt.lr_decay_epochs = list([]) 97 | for it in iterations: 98 | opt.lr_decay_epochs.append(int(it)) 99 | 100 | # set the path according to the environment 101 | opt.model_path = '../unsupervise-fl-node/save_node_unsupervise/node{}'.format(opt.node_id) 102 | opt.load_model_name = 'lr_0.01_decay_0.9_bsz_16_temp_0.07_epoch_101/models/' 103 | if opt.cosine: 104 | opt.load_model_name = '{}_cosine'.format(opt.load_model_name) 105 | opt.load_folder = os.path.join(opt.model_path, opt.load_model_name) 106 | 107 | opt.load_folder_classifier = '../pretrain_model.pth' 108 | 109 | opt.node_path = './save_node_supervise/node{}/'.format(opt.node_id) 110 | opt.model_name = 'lr_{}_decay_{}_bsz_{}_epoch_{}'.\ 111 | format(opt.learning_rate, opt.lr_decay_rate, opt.batch_size, opt.epochs) 112 | 113 | if opt.cosine: 114 | opt.model_name = '{}_cosine'.format(opt.model_name) 115 | 116 | # warm-up for large-batch training, 117 | if opt.batch_size > 256: 118 | opt.warm = True 119 | if opt.warm: 120 | opt.model_name = '{}_warm'.format(opt.model_name) 121 | opt.warmup_from = 0.01 122 | opt.warm_epochs = 10 123 | if opt.cosine: 124 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 125 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 126 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 127 | else: 128 | opt.warmup_to = opt.learning_rate 129 | 130 | opt.save_folder = os.path.join(opt.node_path, opt.model_name, 'models') 131 | if not os.path.isdir(opt.save_folder): 132 | os.makedirs(opt.save_folder) 133 | 134 | opt.result_path = os.path.join(opt.node_path, opt.model_name, 'results/') 135 | if not os.path.isdir(opt.result_path): 136 | os.makedirs(opt.result_path) 137 | 138 | return opt 139 | 140 | 141 | def set_loader(opt): 142 | 143 | # construct data loader 144 | 145 | # load data (already normalized) 146 | train_dataset = data.Multimodal_train_dataset(opt.node_id, opt.num_of_samples) 147 | test_dataset = data.Multimodal_test_dataset(opt.node_id) 148 | 149 | #re-sampling 150 | labels = opt.train_labels 151 | sample_weight = [1/opt.num_per_class[labels[i]] for i in range(len(labels))] 152 | sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weight, num_samples=opt.num_of_train, replacement=True) 153 | 154 | train_loader = torch.utils.data.DataLoader( 155 | train_dataset, batch_size=opt.batch_size, 156 | num_workers=opt.num_workers, sampler=sampler, 157 | pin_memory=True, drop_last=True) 158 | 159 | test_loader = torch.utils.data.DataLoader( 160 | train_dataset, batch_size=opt.batch_size, 161 | num_workers=opt.num_workers, 162 | pin_memory=True, shuffle=True, drop_last=True) 163 | 164 | return train_loader, test_loader 165 | 166 | 167 | def set_model(opt): 168 | 169 | model = My3Model(num_classes=opt.num_class) 170 | 171 | ## define loss functions 172 | criterion = torch.nn.CrossEntropyLoss() 173 | kl_criterion = nn.KLDivLoss(reduction="batchmean") 174 | 175 | ## load model weights 176 | ckpt_path = os.path.join(opt.load_folder, opt.choose_model) 177 | ckpt = torch.load(ckpt_path, map_location='cpu') 178 | state_dict = ckpt['model'] 179 | 180 | if torch.cuda.is_available(): 181 | # if torch.cuda.device_count() > 1: 182 | new_state_dict = {} 183 | for k, v in state_dict.items(): 184 | k = k.replace("module.", "") 185 | k = k.replace("encoder.", "") 186 | # print(k) 187 | if "head" not in k:## only append weight of encoders 188 | new_state_dict[k] = v 189 | state_dict = new_state_dict 190 | model.encoder.load_state_dict(state_dict) 191 | 192 | 193 | ## load model weights of classifier 194 | if opt.load_model == "all": 195 | ckpt_classifier = torch.load(opt.load_folder_classifier, map_location='cpu') 196 | state_dict_classifier = ckpt_classifier['model'] 197 | 198 | if torch.cuda.is_available(): 199 | new_state_dict = {} 200 | for k, v in state_dict_classifier.items(): 201 | k = k.replace("module.", "") 202 | # print(k) 203 | if "classifier" in k:## only append weight of classifier 204 | k = k.replace("classifier.", "") 205 | new_state_dict[k] = v 206 | state_dict_classifier = new_state_dict 207 | model.classifier.load_state_dict(state_dict_classifier) 208 | 209 | # enable synchronized Batch Normalization 210 | if opt.syncBN: 211 | model = apex.parallel.convert_syncbn_model(model) 212 | 213 | if torch.cuda.is_available(): 214 | if torch.cuda.device_count() > 1: 215 | model.encoder = torch.nn.DataParallel(model.encoder) 216 | model = model.cuda() 217 | criterion = criterion.cuda() 218 | cudnn.benchmark = True 219 | 220 | return model, criterion, kl_criterion 221 | 222 | 223 | def set_global_model(opt): 224 | 225 | model = My3Model(num_classes = opt.num_class) 226 | 227 | if torch.cuda.is_available(): 228 | model = model.cuda() 229 | cudnn.benchmark = True 230 | 231 | return model 232 | 233 | def train_multi(train_loader, model, global_model, criterion, kl_criterion, optimizer, epoch, opt): 234 | """one epoch training""" 235 | model.train() 236 | 237 | batch_time = AverageMeter() 238 | data_time = AverageMeter() 239 | losses = AverageMeter() 240 | 241 | top1_meter = AverageMeter() 242 | f1score_meter = AverageMeter() 243 | confusion = np.zeros((opt.num_class, opt.num_class)) 244 | 245 | end = time.time() 246 | 247 | for batch_idx, (input_data1, input_data2, input_data3, labels) in enumerate(train_loader): 248 | 249 | data_time.update(time.time() - end) 250 | 251 | if torch.cuda.is_available(): 252 | input_data1 = input_data1.cuda() 253 | input_data2 = input_data2.cuda() 254 | input_data3 = input_data3.cuda() 255 | labels = labels.cuda() 256 | bsz = input_data1.shape[0] 257 | 258 | # compute loss 259 | output = model(input_data1, input_data2, input_data3) 260 | 261 | 262 | if epoch > opt.fl_epoch: 263 | output_global = global_model(input_data1, input_data2, input_data3) 264 | # print(criterion(output, labels)) 265 | # print(opt.kl_lamda * kl_criterion(output, output_global.detach())) 266 | loss = criterion(output, labels) + opt.kl_lamda * kl_criterion(output, output_global.detach()) 267 | else: 268 | loss = criterion(output, labels) 269 | 270 | 271 | losses.update(loss.item(), bsz) 272 | 273 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 274 | batch_f1 = f1_score(labels.cpu().numpy(), output.max(1)[1].cpu().numpy(), average="weighted") 275 | 276 | # calculate and store confusion matrix 277 | rows = labels.cpu().numpy() 278 | cols = output.max(1)[1].cpu().numpy() 279 | for sample_index in range(labels.shape[0]): 280 | confusion[rows[sample_index], cols[sample_index]] += 1 281 | top1_meter.update(acc5[0], bsz) 282 | f1score_meter.update(batch_f1, bsz) 283 | 284 | 285 | # SGD 286 | optimizer.zero_grad() 287 | loss.backward() 288 | optimizer.step() 289 | 290 | # measure elapsed time 291 | batch_time.update(time.time() - end) 292 | end = time.time() 293 | 294 | # print(f1score.val, f1score.avg) 295 | 296 | # print info 297 | if (batch_idx + 1) % opt.print_freq == 0: 298 | print('Train: [{0}][{1}/{2}]\t' 299 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 300 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 301 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 302 | 'Acc {top1.val:.3f} ({top1.avg:.3f})\t' 303 | 'F1 {f1score.val:.3f} ({f1score.avg:.3f})\t'.format( 304 | epoch, batch_idx + 1, len(train_loader), batch_time=batch_time, 305 | data_time=data_time, loss=losses, top1 = top1_meter, f1score = f1score_meter)) 306 | sys.stdout.flush() 307 | 308 | top1 = top1_meter.avg 309 | f1score = f1score_meter.avg 310 | 311 | return losses.avg, top1, f1score, confusion 312 | 313 | 314 | 315 | def validate_multi(val_loader, model, criterion, opt): 316 | """validation""" 317 | model.eval() 318 | 319 | batch_time = AverageMeter() 320 | data_time = AverageMeter() 321 | losses = AverageMeter() 322 | 323 | top1_meter = AverageMeter() 324 | f1score_meter = AverageMeter() 325 | confusion = np.zeros((opt.num_class, opt.num_class)) 326 | 327 | 328 | with torch.no_grad(): 329 | end = time.time() 330 | for batch_idx, (input_data1, input_data2, input_data3, labels) in enumerate(val_loader): 331 | 332 | if torch.cuda.is_available(): 333 | input_data1 = input_data1.float().cuda() 334 | input_data2 = input_data2.float().cuda() 335 | input_data3 = input_data3.float().cuda() 336 | labels = labels.cuda() 337 | bsz = input_data1.shape[0] 338 | 339 | # forward 340 | output = model(input_data1, input_data2, input_data3) 341 | loss = criterion(output, labels) 342 | losses.update(loss.item(), bsz) 343 | 344 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 345 | batch_f1 = f1_score(labels.cpu().numpy(), output.max(1)[1].cpu().numpy(), average="weighted") 346 | 347 | # calculate and store confusion matrix 348 | rows = labels.cpu().numpy() 349 | cols = output.max(1)[1].cpu().numpy() 350 | for sample_index in range(labels.shape[0]): 351 | confusion[rows[sample_index], cols[sample_index]] += 1 352 | top1_meter.update(acc5[0], bsz) 353 | f1score_meter.update(batch_f1, bsz) 354 | 355 | # measure elapsed time 356 | batch_time.update(time.time() - end) 357 | end = time.time() 358 | 359 | if batch_idx % opt.test_print_freq == 0: 360 | print('Test: [{0}/{1}]\t' 361 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 362 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 363 | 'Acc {top1.val:.3f} ({top1.avg:.3f})\t' 364 | 'F1 {f1score.val:.3f} ({f1score.avg:.3f})\t'.format( 365 | batch_idx, len(val_loader), batch_time=batch_time, loss=losses, top1 = top1_meter, f1score = f1score_meter)) 366 | 367 | top1 = top1_meter.avg 368 | f1score = f1score_meter.avg 369 | 370 | return losses.avg, top1, f1score, confusion 371 | 372 | 373 | def get_model_array(model): 374 | 375 | params = [] 376 | for param in model.parameters(): 377 | if torch.cuda.is_available(): 378 | params.extend(param.view(-1).cpu().detach().numpy()) 379 | else: 380 | params.extend(param.view(-1).detach().numpy()) 381 | # print(param) 382 | 383 | # model_params = params.cpu().numpy() 384 | model_params = np.array(params) 385 | print("Shape of model weight: ", model_params.shape)#39456 386 | 387 | return model_params 388 | 389 | 390 | 391 | def reset_model_parameter(new_params, model): 392 | 393 | temp_index = 0 394 | 395 | with torch.no_grad(): 396 | for param in model.parameters(): 397 | 398 | # print(param.shape) 399 | 400 | if len(param.shape) == 2: 401 | 402 | para_len = int(param.shape[0] * param.shape[1]) 403 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 404 | param.copy_(torch.from_numpy(temp_weight.reshape(param.shape[0], param.shape[1]))) 405 | temp_index += para_len 406 | 407 | elif len(param.shape) == 3: 408 | 409 | para_len = int(param.shape[0] * param.shape[1] * param.shape[2]) 410 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 411 | param.copy_(torch.from_numpy(temp_weight.reshape(param.shape[0], param.shape[1], param.shape[2]))) 412 | temp_index += para_len 413 | 414 | elif len(param.shape) == 4: 415 | 416 | para_len = int(param.shape[0] * param.shape[1] * param.shape[2] * param.shape[3]) 417 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 418 | param.copy_(torch.from_numpy(temp_weight.reshape(param.shape[0], param.shape[1], param.shape[2], param.shape[3]))) 419 | temp_index += para_len 420 | 421 | elif len(param.shape) == 5: 422 | 423 | para_len = int(param.shape[0] * param.shape[1] * param.shape[2] * param.shape[3] * param.shape[4]) 424 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 425 | param.copy_(torch.from_numpy(temp_weight.reshape(param.shape[0], param.shape[1], param.shape[2], param.shape[3], param.shape[4]))) 426 | temp_index += para_len 427 | 428 | else: 429 | 430 | para_len = param.shape[0] 431 | temp_weight = new_params[temp_index : temp_index + para_len].astype(float) 432 | param.copy_(torch.from_numpy(temp_weight)) 433 | temp_index += para_len 434 | 435 | 436 | def set_commu(opt): 437 | 438 | #prepare the communication module 439 | server_addr = "172.22.172.75" 440 | # server_addr = "localhost" 441 | 442 | server_port = 30415 443 | 444 | comm = COMM(server_addr,server_port, opt.node_id) 445 | 446 | comm.send2server('hello',-1) 447 | 448 | print(comm.recvfserver()) 449 | 450 | return comm 451 | 452 | 453 | def main(): 454 | 455 | opt = parse_option() 456 | 457 | # set up communication with sevrer 458 | comm = set_commu(opt) 459 | 460 | # build data loader 461 | train_loader, val_loader = set_loader(opt) 462 | 463 | # build model and criterion 464 | model, criterion, kl_criterion = set_model(opt) 465 | w_parameter_init = get_model_array(model) 466 | 467 | # build global model 468 | global_model = set_global_model(opt) 469 | 470 | # build optimizer 471 | optimizer = set_optimizer(opt, model) 472 | 473 | record_loss = np.zeros(opt.epochs) 474 | record_acc = np.zeros(opt.epochs) 475 | record_f1 = np.zeros(opt.epochs) 476 | 477 | compute_time_record = np.zeros(opt.epochs) 478 | upper_commu_time_record = np.zeros(int(opt.epochs/opt.fl_epoch)) 479 | down_commu_time_record = np.zeros(int(opt.epochs/opt.fl_epoch)) 480 | 481 | all_time_record = np.zeros(opt.epochs + 2) 482 | all_time_record[0] = time.time() 483 | 484 | best_acc = 0 485 | best_f1 = 0 486 | best_confusion = np.zeros((opt.num_class, opt.num_class)) 487 | 488 | begin_time = time.time() 489 | 490 | # training routine 491 | for epoch in range(1, opt.epochs + 1): 492 | adjust_learning_rate(opt, optimizer, epoch) 493 | 494 | # train for one epoch 495 | time1 = time.time() 496 | loss, train_acc, train_f1_score, train_confusion = train_multi(train_loader, model, global_model, criterion, kl_criterion, optimizer, epoch, opt) 497 | time2 = time.time() 498 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 499 | 500 | # tensorboard logger 501 | record_loss[epoch-1] = loss 502 | compute_time_record[epoch-1] = time2 - time1 503 | 504 | ## record acc 505 | # if opt.local_modality == "all": 506 | loss, val_acc, val_f1_score, val_confusion = validate_multi(val_loader, model, criterion, opt) 507 | 508 | record_acc[epoch-1] = val_acc 509 | record_f1[epoch-1] = val_f1_score 510 | 511 | if best_acc < val_acc: 512 | best_acc = val_acc 513 | if best_f1 < val_f1_score: 514 | best_f1 = val_f1_score 515 | 516 | # communication with the server every fl_epoch 517 | if (epoch % opt.fl_epoch) == 0: 518 | 519 | ## send model update to the server 520 | print("Node {} sends weight to the server:".format(opt.node_id)) 521 | w_parameter = get_model_array(model) #obtain the model parameters or gradients 522 | w_update = w_parameter - w_parameter_init 523 | 524 | comm_time1 = time.time() 525 | comm.send2server(w_update,0) 526 | comm_time2 = time.time() 527 | commu_epoch = int(epoch/opt.fl_epoch - 1) 528 | upper_commu_time_record[commu_epoch] = comm_time2 - comm_time1 529 | print("time for sending model weights:", comm_time2 - comm_time1) 530 | 531 | ## recieve aggregated model update from the server 532 | comm_time3 = time.time() 533 | new_w_update, sig_stop = comm.recvOUF() 534 | comm_time4 = time.time() 535 | down_commu_time_record[commu_epoch] = comm_time4 - comm_time3 536 | print("time for downloading model weights:", comm_time4 - comm_time3) 537 | print("Received weight from the server:", new_w_update.shape) 538 | # print("Received signal from the server:", sig_stop) 539 | 540 | ## update the model according to the received weights 541 | new_w = w_parameter_init + new_w_update 542 | # reset_model_parameter(new_w, model) 543 | reset_model_parameter(new_w, global_model)#do not replace the local model, only use global model to guide its training 544 | w_parameter_init = new_w 545 | 546 | # save model 547 | if epoch % opt.save_freq == 0: 548 | save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 549 | save_model(model, optimizer, opt, epoch, save_file) 550 | 551 | all_time_record[epoch] = time.time() 552 | 553 | np.savetxt(opt.result_path + "record_loss.txt", record_loss) 554 | np.savetxt(opt.result_path + "record_acc.txt", record_acc) 555 | np.savetxt(opt.result_path + "record_f1.txt", record_f1) 556 | np.savetxt(opt.result_path + "compute_time_record.txt", compute_time_record) 557 | np.savetxt(opt.result_path + "upper_commu_time_record.txt", upper_commu_time_record) 558 | np.savetxt(opt.result_path + "down_commu_time_record.txt", down_commu_time_record) 559 | np.savetxt(opt.result_path + "all_time_record.txt", all_time_record) 560 | 561 | end_time = time.time() 562 | all_time_record[epoch+1] = end_time - begin_time 563 | print("Total training delay: ", end_time - begin_time) 564 | 565 | print("best_acc:", best_acc) 566 | print("best_f1:", best_f1) 567 | 568 | 569 | comm.disconnect(1) 570 | 571 | # save the last model 572 | save_file = os.path.join(opt.save_folder, 'last.pth') 573 | save_model(model, optimizer, opt, opt.epochs, save_file) 574 | 575 | np.savetxt(opt.result_path + "record_loss.txt", record_loss) 576 | np.savetxt(opt.result_path + "record_acc.txt", record_acc) 577 | np.savetxt(opt.result_path + "record_f1.txt", record_f1) 578 | np.savetxt(opt.result_path + "compute_time_record.txt", compute_time_record) 579 | np.savetxt(opt.result_path + "upper_commu_time_record.txt", upper_commu_time_record) 580 | np.savetxt(opt.result_path + "down_commu_time_record.txt", down_commu_time_record) 581 | np.savetxt(opt.result_path + "all_time_record.txt", all_time_record) 582 | 583 | 584 | 585 | if __name__ == '__main__': 586 | main() 587 | --------------------------------------------------------------------------------