├── Figures ├── bgMLP.png └── overview_latest.png ├── iid-dictusers ├── ICH_1037_55000.npy └── ChestXray14_1037_85000.npy ├── model ├── build_model.py ├── efficientnet.py └── all_models.py ├── preprocess ├── split_train_test.py ├── count.py ├── count_pwise_disease.py ├── count_mean_dev.py ├── label_rectify.py └── ICH_process.py ├── utils ├── feature_visual.py ├── valloss_cal.py ├── sampling.py ├── utils.py ├── multilabel_metrixs.py ├── FedAvg.py ├── options.py ├── FedNoRo.py ├── evaluations.py ├── FixMatch.py └── local_training.py ├── requirements.txt ├── README.md ├── dataset ├── all_dataset.py └── dataset.py └── main.py /Figures/bgMLP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szbonaldo/FedMLP/HEAD/Figures/bgMLP.png -------------------------------------------------------------------------------- /Figures/overview_latest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szbonaldo/FedMLP/HEAD/Figures/overview_latest.png -------------------------------------------------------------------------------- /iid-dictusers/ICH_1037_55000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szbonaldo/FedMLP/HEAD/iid-dictusers/ICH_1037_55000.npy -------------------------------------------------------------------------------- /iid-dictusers/ChestXray14_1037_85000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szbonaldo/FedMLP/HEAD/iid-dictusers/ChestXray14_1037_85000.npy -------------------------------------------------------------------------------- /model/build_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .all_models import get_model, modify_last_layer 3 | 4 | 5 | def build_model(args): 6 | # choose different 
Neural network model for different args 7 | model = get_model(args.model, args.pretrained) 8 | model, _ = modify_last_layer(args.model, model, args.n_classes) 9 | 10 | return model -------------------------------------------------------------------------------- /preprocess/split_train_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from sklearn.model_selection import train_test_split 6 | # ChestXray14 7 | seed = 2023 8 | torch.manual_seed(seed) 9 | torch.cuda.manual_seed(seed) 10 | torch.cuda.manual_seed_all(seed) 11 | np.random.seed(seed) 12 | random.seed(seed) 13 | data = pd.read_csv(r'E:\ICH_stage2\ICH_stage2\data_png185k_512.csv') 14 | 15 | train_ratio = 0.7 16 | test_ratio = 0.3 17 | 18 | train_data, test_data = train_test_split(data, test_size=(1 - train_ratio)) 19 | 20 | print(f"训练集大小: {len(train_data)}") 21 | print(f"测试集大小: {len(test_data)}") 22 | 23 | train_data.to_csv('train_dataset.csv', index=False) 24 | test_data.to_csv('test_dataset.csv', index=False) 25 | -------------------------------------------------------------------------------- /preprocess/count.py: -------------------------------------------------------------------------------- 1 | import csv 2 | # ChestXray14 3 | import pandas as pd 4 | import os 5 | 6 | 7 | def file_name(file_dir): 8 | file_list = [] 9 | for root, dirs, files in os.walk(file_dir): 10 | file_list.append(files) 11 | return file_list 12 | 13 | image_4_path = r'E:\szb\Hust\2022~2023\lab\博一\images_004.tar\images_004\images' 14 | image_5_path = 'E:/szb/Hust/2022~2023/lab/博一/images_005.tar/images_005/images' 15 | file_5_list = file_name(image_5_path)[0] 16 | file_4_list = file_name(image_4_path)[0] 17 | csv_path = '../Data_Entry_2017_v2020.csv' 18 | 19 | with open(csv_path, 'r') as csv_file: 20 | csv_reader = csv.reader(csv_file) 21 | with open('filtered_data_4.csv', 'w', newline='') as new_csv_file: 22 | 
csv_writer = csv.writer(new_csv_file) 23 | for row in csv_reader: 24 | if row[0] in file_4_list: 25 | csv_writer.writerow(row) 26 | -------------------------------------------------------------------------------- /preprocess/count_pwise_disease.py: -------------------------------------------------------------------------------- 1 | import csv 2 | # ChestXray14 3 | import numpy as np 4 | import pandas as pd 5 | 6 | csv_path = '../onehot-label.csv' 7 | count = 1 8 | pre_patient = '' 9 | total_disease = [0]*14 10 | new_disease = [0]*14 11 | patients = 0 12 | with open(csv_path, 'r') as csv_file: 13 | csv_reader = csv.reader(csv_file) 14 | for row in csv_reader: 15 | if count == 1: 16 | count += 1 17 | print(count) 18 | else: 19 | count += 1 20 | print(count) 21 | if row[0].split('_')[0] != pre_patient: 22 | patients += 1 23 | pre_patient = row[0].split('_')[0] 24 | total_disease = [i + j for i, j in zip(total_disease, new_disease)] 25 | new_disease = list(map(int, row[1:])) 26 | else: 27 | new_disease = [int(new_disease[i]) or int(row[i+1]) for i in range(len(new_disease))] 28 | total_disease = [i + j for i, j in zip(total_disease, new_disease)] 29 | print(np.array(total_disease)/patients) 30 | 31 | -------------------------------------------------------------------------------- /preprocess/count_mean_dev.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import transforms 3 | # ChestXray14 4 | from dataset.all_dataset import ChestXray14 5 | 6 | 7 | def getStat(train_data): 8 | ''' 9 | Compute mean and variance for training data 10 | :param train_data 11 | :return: (mean, std) 12 | ''' 13 | print('Compute mean and variance for training data.') 14 | print(len(train_data)) 15 | train_loader = torch.utils.data.DataLoader( 16 | train_data, batch_size=1, shuffle=False, num_workers=0, 17 | pin_memory=True) 18 | mean = torch.zeros(3) 19 | std = torch.zeros(3) 20 | for X, _ in train_loader: 21 
| for d in range(3): 22 | mean[d] += X[:, d, :, :].mean() 23 | std[d] += X[:, d, :, :].std() 24 | mean.div_(len(train_data)) 25 | std.div_(len(train_data)) 26 | return list(mean.numpy()), list(std.numpy()) 27 | 28 | 29 | if __name__ == '__main__': 30 | train_dataset = ChestXray14('/home/szb/multilabel', "train", transforms.Compose([ 31 | transforms.Resize((224, 224)), 32 | transforms.ToTensor(), 33 | ])) 34 | print(getStat(train_dataset)) 35 | -------------------------------------------------------------------------------- /preprocess/label_rectify.py: -------------------------------------------------------------------------------- 1 | import csv 2 | # ChestXray14 3 | import numpy as np 4 | import pandas as pd 5 | 6 | # df = pd.read_csv('Data_Entry_2017_v2020.csv') 7 | # print(df) 8 | csv_path = './Data_Entry_2017_v2020.csv' 9 | count = 1 10 | title = ['Image Index', 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass', 'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax'] 11 | with open(csv_path, 'r') as csv_file: 12 | csv_reader = csv.reader(csv_file) 13 | 14 | with open('../onehot-label-PA.csv', 'w', newline='') as new_csv_file: 15 | csv_writer = csv.writer(new_csv_file) 16 | for row in csv_reader: 17 | if count == 1: 18 | csv_writer.writerow(title) 19 | count += 1 20 | print(count) 21 | else: 22 | count += 1 23 | print(count) 24 | if row[6] == 'PA': 25 | label_row = [row[0]] + [0]*14 26 | label = row[1] # str 27 | if label == 'No Finding': 28 | csv_writer.writerow(label_row) 29 | else: 30 | label_list = label.split('|') 31 | for i in label_list: 32 | label_row[title.index(i)] = 1 33 | csv_writer.writerow(label_row) 34 | else: 35 | pass 36 | 37 | -------------------------------------------------------------------------------- /utils/feature_visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot 
as plt 3 | import torch 4 | from numpy import where 5 | from sklearn.manifold import TSNE 6 | 7 | 8 | color_map = ['r', 'y', 'k', 'g', 'b', 'm', 'c', 'peru'] 9 | # color_map = ['r', 'g'] 10 | 11 | 12 | def plot_embedding_2D(data, label, title, rnd): 13 | x_min, x_max = np.min(data, 0), np.max(data, 0) 14 | data = (data - x_min) / (x_max - x_min) 15 | fig = plt.figure() 16 | for i in range(len(np.unique(label))): 17 | list = data[label == i] 18 | plt.scatter(list[:, 0], list[:, 1], marker='o', s=1, color=color_map[i], label='class:{}'.format(i)) 19 | plt.legend() 20 | plt.xticks([]) 21 | plt.yticks([]) 22 | plt.title(title) 23 | plt.savefig('proto_fig/' + f'rnd:{rnd}' + title + '.png') 24 | # plt.show() 25 | plt.clf() 26 | return fig 27 | 28 | 29 | def tnse_Visual(data, label, rnd, title): 30 | n_samples, n_features = data.shape 31 | 32 | print('Begining......') 33 | 34 | tsne_2D = TSNE(n_components=2, init='pca', random_state=0, perplexity=5) 35 | result_2D = tsne_2D.fit_transform(data) 36 | 37 | print('Finished......') 38 | fig1 = plot_embedding_2D(result_2D, label, title, rnd) # 将二维数据用plt绘制出来 39 | 40 | 41 | if __name__ == '__main__': 42 | label = torch.randint(high=2, low=0, size=(1000, )) 43 | print(label) 44 | data = torch.rand(1000, 512) 45 | print(data.shape) 46 | tnse_Visual(data, label, 1, '666') 47 | 48 | -------------------------------------------------------------------------------- /preprocess/ICH_process.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from collections import Counter 3 | 4 | import pandas as pd 5 | import numpy as np 6 | 7 | import os 8 | from tqdm import tqdm 9 | data = pd.read_csv(r"E:\ICH_stage2\ICH_stage2\stage_2_train.csv") 10 | data = np.array(data) 11 | patient_num = len(data) // 6 12 | ID_all = [] 13 | label_all = [] 14 | label_add = [0]*5 15 | for i in range(patient_num): 16 | ID = data[6*i][0].split('_epidural')[0] 17 | label_add = [data[6*i][1], data[6*i+1][1], 
data[6*i+2][1], data[6*i+3][1], data[6*i+4][1]] 18 | ID_all.append(ID) 19 | label_all.append(label_add) 20 | 21 | ID_have = [] 22 | label_have = [] 23 | for i in tqdm(range(patient_num)): 24 | name = ID_all[i] + ".png" 25 | path = os.path.join(r"E:\ICH_stage2\ICH_stage2\png185k_512", name) 26 | if os.path.exists(path): 27 | ID_have.append(name) 28 | label_have.append(label_all[i]) 29 | 30 | 31 | csv_path = r'E:\ICH_stage2\ICH_stage2\data_png185k_512.csv' 32 | count = 1 33 | title = ['Image Index', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural'] 34 | with open(csv_path, 'w', newline='') as new_csv_file: 35 | csv_writer = csv.writer(new_csv_file) 36 | for i in tqdm(range(len(ID_have)+1)): 37 | if count == 1: 38 | csv_writer.writerow(title) 39 | count += 1 40 | else: 41 | count += 1 42 | csv_writer.writerow([ID_have[i-1]] + label_have[i-1]) 43 | label_have_total = np.sum(label_have, axis=0) 44 | label_have_class = np.sum(label_have, axis=1) 45 | print(label_have_total) # [2761 32564 23766 32122 42496] 46 | print(Counter(label_have_class)) # Counter({0: 87948, 1: 67969, 2: 22587, 3: 5642, 4: 885, 5: 20}) 47 | -------------------------------------------------------------------------------- /utils/valloss_cal.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader, SubsetRandomSampler 7 | 8 | 9 | def get_num_of_each_class(args, num_samples_to_split, test_dataset): 10 | class_sum = np.array([0.] 
* args.n_classes) 11 | for idx in range(num_samples_to_split): 12 | class_sum += test_dataset.targets[idx] 13 | return class_sum.tolist() 14 | 15 | def valloss(net, test_dataset, args): 16 | net.eval() 17 | test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size*4, shuffle=False, num_workers=4) 18 | split_ratio = 0.1 19 | 20 | num_samples = len(test_loader.dataset) 21 | num_samples_to_split = int(num_samples * split_ratio) 22 | 23 | random_sampler = SubsetRandomSampler(range(num_samples_to_split)) 24 | 25 | val_data_loader = DataLoader( 26 | dataset=test_loader.dataset, 27 | batch_size=args.batch_size*4, 28 | sampler=random_sampler 29 | ) 30 | class_num_list = get_num_of_each_class(args, num_samples_to_split, test_dataset) 31 | loss_w = [num_samples_to_split / i for i in class_num_list] 32 | print(class_num_list, num_samples_to_split, loss_w) 33 | bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_w).cuda()) # include sigmoid 34 | batch_loss = [] 35 | with torch.no_grad(): 36 | for samples in val_data_loader: 37 | images, labels = samples["image"].to(args.device), samples["target"].to(args.device) 38 | _, outputs = net(images) 39 | val_loss = bce_criterion(outputs, labels) 40 | batch_loss.append(val_loss.item()) 41 | val_loss_mean = np.array(batch_loss).mean() 42 | logging.info(val_loss_mean) 43 | return val_loss_mean 44 | 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | astunparse==1.6.3 3 | attrs==22.1.0 4 | bleach==6.1.0 5 | bytecode==0.15.1 6 | cachetools==5.2.0 7 | certifi==2024.2.2 8 | charset-normalizer==2.1.1 9 | conda-pack==0.6.0 10 | contextlib2==21.6.0 11 | contourpy==1.0.5 12 | cupy @ file:///opt/conda/conda-bld/cupy_1610065936122/work 13 | cxxfilt==0.3.0 14 | cycler==0.11.0 15 | efficientnet-pytorch==0.7.1 16 | einops==0.5.0 17 | fastrlock==0.8.1 18 | 
filelock==3.8.0 19 | flatbuffers==22.12.6 20 | fonttools==4.37.4 21 | gast==0.4.0 22 | google-auth==2.12.0 23 | google-auth-oauthlib==0.4.6 24 | google-pasta==0.2.0 25 | grpcio==1.49.1 26 | h5py==3.7.0 27 | hausdorff==0.2.6 28 | huggingface-hub==0.10.0 29 | idna==3.4 30 | imageio==2.22.1 31 | imbalanced-learn==0.11.0 32 | imblearn==0.0 33 | importlib-metadata==5.0.0 34 | iniconfig==1.1.1 35 | joblib==1.2.0 36 | kaggle==1.6.14 37 | keras==2.11.0 38 | kiwisolver==1.4.4 39 | libclang==14.0.6 40 | llvmlite==0.39.1 41 | Markdown==3.4.1 42 | MarkupSafe==2.1.1 43 | matplotlib==3.6.1 44 | MedPy==0.4.0 45 | mkl-fft==1.3.1 46 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work 47 | mkl-service==2.4.0 48 | ml-collections==0.1.1 49 | monai==1.1.0 50 | munch==4.0.0 51 | networkx==2.8.7 52 | nibabel==4.0.2 53 | numba==0.56.2 54 | numpy @ file:///croot/numpy_and_numpy_base_1672336185480/work 55 | oauthlib==3.2.1 56 | opencv-python==4.4.0.46 57 | opt-einsum==3.3.0 58 | packaging==21.3 59 | pandas==1.5.0 60 | Pillow==9.2.0 61 | pluggy==1.0.0 62 | pretrainedmodels==0.7.4 63 | protobuf==3.19.6 64 | py==1.11.0 65 | pyasn1==0.4.8 66 | pyasn1-modules==0.2.8 67 | pyparsing==3.0.9 68 | pytest==7.1.3 69 | python-dateutil==2.8.2 70 | python-slugify==8.0.4 71 | pytz==2022.4 72 | PyWavelets==1.4.1 73 | PyYAML==6.0 74 | requests==2.28.1 75 | requests-oauthlib==1.3.1 76 | rsa==4.9 77 | scikit-image==0.19.3 78 | scikit-learn==1.1.3 79 | scipy==1.9.2 80 | seaborn==0.12.0 81 | setproctitle==1.3.2 82 | SimpleITK==2.2.0 83 | six @ file:///tmp/build/80754af9/six_1644875935023/work 84 | sklearn==0.0.post1 85 | swin-window-process==0.0.0 86 | tb-nightly==2.12.0a20221114 87 | tensorboard-data-server==0.6.1 88 | tensorboard-plugin-wit==1.8.1 89 | tensorboardX==2.5.1 90 | tensorflow==2.11.0 91 | tensorflow-estimator==2.11.0 92 | tensorflow-io-gcs-filesystem==0.28.0 93 | termcolor==1.1.0 94 | text-unidecode==1.3 95 | thop==0.1.1.post2209072238 96 | threadpoolctl==3.1.0 97 | 
tifffile==2022.10.10 98 | timm==0.4.12 99 | tomli==2.0.1 100 | torch==1.12.1+cu116 101 | torchaudio==0.12.1+cu116 102 | torchvision==0.13.1+cu116 103 | tqdm==4.64.1 104 | typing_extensions==4.4.0 105 | urllib3==1.22 106 | visualizerX==0.0.1 107 | vit-pytorch==0.35.8 108 | webencodings==0.5.1 109 | Werkzeug==2.2.2 110 | wrapt==1.14.1 111 | xarray==2022.11.0 112 | yacs==0.1.8 113 | zipp==3.9.0 114 | -------------------------------------------------------------------------------- /utils/sampling.py: -------------------------------------------------------------------------------- 1 | # This file is borrowed from https://github.com/Xu-Jingyi/FedCorr/blob/main/util/sampling.py 2 | 3 | import numpy as np 4 | 5 | 6 | def iid_sampling(n_train, num_users, seed): 7 | np.random.seed(seed) 8 | num_items = int(n_train / num_users) 9 | dict_users, all_idxs = {}, [i for i in range(n_train)] # initial user and index for whole dataset 10 | for i in range(num_users): 11 | dict_users[i] = set( 12 | np.random.choice(all_idxs, num_items, replace=False)) # 'replace=False' make sure that there is no repeat 13 | all_idxs = list(set(all_idxs) - dict_users[i]) 14 | 15 | for key in dict_users.keys(): 16 | dict_users[key] = list(dict_users[key]) 17 | return dict_users 18 | 19 | 20 | def non_iid_dirichlet_sampling(y_train, num_classes, p, num_users, seed, alpha_dirichlet): 21 | np.random.seed(seed) 22 | Phi = np.random.binomial(1, p, size=(num_users, num_classes)) # indicate the classes chosen by each client 23 | n_classes_per_client = np.sum(Phi, axis=1) 24 | while np.min(n_classes_per_client) == 0: 25 | invalid_idx = np.where(n_classes_per_client == 0)[0] 26 | Phi[invalid_idx] = np.random.binomial(1, p, size=(len(invalid_idx), num_classes)) 27 | n_classes_per_client = np.sum(Phi, axis=1) 28 | Psi = [list(np.where(Phi[:, j] == 1)[0]) for j in range(num_classes)] # indicate the clients that choose each class 29 | num_clients_per_class = np.array([len(x) for x in Psi]) 30 | dict_users = {} 31 | 
for class_i in range(num_classes+1): 32 | # all_idxs = np.where(y_train == class_i)[0] 33 | n_classes_per_sample = np.sum(y_train, axis=1) 34 | all_idxs = np.where(n_classes_per_sample == class_i)[0] 35 | # else_idxs = np.where(y_train[class_i*25907:(class_i+1)*25907, class_i] != 1)[0] + class_i*25907 36 | p_dirichlet = np.random.dirichlet([alpha_dirichlet] * num_clients_per_class[0]) 37 | assignment = np.random.choice(Psi[0], size=len(all_idxs), p=p_dirichlet.tolist()) 38 | # assignment_else = np.random.choice(Psi[class_i], size=len(else_idxs), p=p_dirichlet.tolist()) 39 | 40 | for client_k in Psi[0]: 41 | if client_k in dict_users: 42 | # dict_users[client_k] = set(dict_users[client_k] | set(all_idxs[(assignment == client_k)]) | set(else_idxs[(assignment_else == client_k)])) 43 | dict_users[client_k] = set(dict_users[client_k] | set(all_idxs[(assignment == client_k)])) 44 | else: 45 | # dict_users[client_k] = set(np.concatenate((all_idxs[(assignment == client_k)], else_idxs[(assignment_else == client_k)]))) 46 | dict_users[client_k] = set(all_idxs[(assignment == client_k)]) 47 | for key in dict_users.keys(): 48 | dict_users[key] = list(dict_users[key]) 49 | return dict_users 50 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import shutil 5 | import sys 6 | import heapq 7 | import numpy as np 8 | import torch 9 | from tensorboardX import SummaryWriter 10 | 11 | 12 | def set_seed(seed): 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | 19 | 20 | # def max_m_indices(arr, m): 21 | # max_m_values = heapq.nlargest(m, arr) 22 | # max_m_indices = [arr.index(value) for value in max_m_values] 23 | # return max_m_indices 24 | def max_m_indices(lst, n): 25 | elements_with_indices = 
list(enumerate(lst)) 26 | sorted_elements = sorted(elements_with_indices, key=lambda x: x[1], reverse=True) 27 | top_n_elements = sorted_elements[:n] 28 | return [index for index, value in top_n_elements] 29 | 30 | 31 | def min_n_indices(lst, n): 32 | elements_with_indices = list(enumerate(lst)) 33 | sorted_elements = sorted(elements_with_indices, key=lambda x: x[1]) 34 | bottom_n_elements = sorted_elements[:n] 35 | return [index for index, value in bottom_n_elements] 36 | # def min_n_indices(arr, n): 37 | # min_n_values = heapq.nsmallest(n, arr) 38 | # min_n_indices = [arr.index(value) for value in min_n_values] 39 | # return min_n_indices 40 | 41 | 42 | def set_output_files(args): 43 | outputs_dir = 'outputs_' + str(args.dataset) + '_' + str( 44 | args.alpha_dirichlet) + '_' + str(args.n_clients) + '_' + str(args.model) + '_' + str(args.n_classes-args.annotation_num) + '5000classmiss7_1037_iid_FedMLP_stage1glo' 45 | # outputs_dir = 'outputs_' + str(args.dataset) + '_' + str(args.model) + '_dataaug_' + str( 46 | # args.n_classes - args.annotation_num) + 'classmiss_' + 'loss_distribution' # demo 47 | if not os.path.exists(outputs_dir): 48 | os.mkdir(outputs_dir) 49 | exp_dir = os.path.join(outputs_dir, args.exp + '_' + 50 | str(args.batch_size) + '_' + str(args.base_lr) + '_' + 51 | str(args.rounds_warmup) + '_' + 52 | str(args.rounds_corr) + '_' + 53 | str(args.rounds_finetune) + '_' + str(args.local_ep)) 54 | if not os.path.exists(exp_dir): 55 | os.mkdir(exp_dir) 56 | models_dir = os.path.join(exp_dir, 'models') 57 | if not os.path.exists(models_dir): 58 | os.mkdir(models_dir) 59 | logs_dir = os.path.join(exp_dir, 'logs') 60 | if not os.path.exists(logs_dir): 61 | os.mkdir(logs_dir) 62 | tensorboard_dir = os.path.join(exp_dir, 'tensorboard') 63 | # if not os.path.exists(tensorboard_dir): 64 | # os.mkdir(tensorboard_dir) 65 | code_dir = os.path.join(exp_dir, 'code') 66 | if os.path.exists(code_dir): 67 | shutil.rmtree(code_dir) 68 | os.mkdir(code_dir) 69 | # 
shutil.make_archives(code_dir, 'zip', base_dir='/home/szb/multilabel/') 70 | 71 | logging.basicConfig(filename=logs_dir+'/logs.txt', level=logging.INFO, 72 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 73 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 74 | logging.info(str(args)) 75 | writer1 = SummaryWriter(tensorboard_dir + 'writer1') 76 | return writer1, models_dir 77 | -------------------------------------------------------------------------------- /utils/multilabel_metrixs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | 5 | def Hamming_Loss(y_true, y_pred, classid=None): 6 | temp = 0 7 | for i in range(y_true.shape[0]): 8 | temp += np.size(y_true[i] == y_pred[i]) - np.count_nonzero(y_true[i] == y_pred[i]) 9 | return temp / (y_true.shape[0] * y_true.shape[1]) 10 | 11 | 12 | # def Recall(y_true, y_pred): # sample-wise 13 | # temp = 0 14 | # for i in range(y_true.shape[0]): 15 | # if sum(y_true[i]) == 0: 16 | # continue 17 | # temp += sum(np.logical_and(y_true[i], y_pred[i])) / sum(y_true[i]) 18 | # return temp / y_true.shape[0] 19 | 20 | 21 | def Recall(y_true, y_pred, classid=None): # class-wise 22 | temp = 0 23 | if classid is not None: 24 | return sum(np.logical_and(y_true.T[classid], y_pred.T[classid])) / sum(y_true.T[classid]) 25 | else: 26 | for i in range(y_true.T.shape[0]): 27 | temp += sum(np.logical_and(y_true.T[i], y_pred.T[i])) / sum(y_true.T[i]) 28 | # print('class: {}, recall: '.format(i), sum(np.logical_and(y_true.T[i], y_pred.T[i])) / sum(y_true.T[i])) 29 | return temp / y_true.T.shape[0] 30 | 31 | 32 | def BACC(y_true, y_pred, classid=None): 33 | temp = 0 34 | if classid is not None: 35 | recall1 = sum(np.logical_and(y_true.T[classid], y_pred.T[classid])) / sum(y_true.T[classid]) 36 | recall0 = sum(~np.logical_or(y_true.T[classid], y_pred.T[classid])) / (y_true.T[classid].size - np.count_nonzero(y_true.T[classid])) 
37 | bacc = (recall0 + recall1) / 2 38 | return bacc 39 | else: 40 | for i in range(y_true.T.shape[0]): 41 | recall1 = sum(np.logical_and(y_true.T[i], y_pred.T[i])) / sum(y_true.T[i]) 42 | recall0 = sum(~np.logical_or(y_true.T[i], y_pred.T[i])) / (y_true.T[i].size - np.count_nonzero(y_true.T[i])) 43 | bacc = (recall0 + recall1) / 2 44 | logging.info('BACC:class%d : %f' % (i, bacc)) 45 | temp += bacc 46 | return temp / y_true.T.shape[0] 47 | 48 | 49 | def Precision(y_true, y_pred, classid=None): 50 | temp = 0 51 | if classid is not None: 52 | return sum(np.logical_and(y_true.T[classid], y_pred.T[classid])) / sum(y_pred.T[classid]) 53 | else: 54 | for i in range(y_true.T.shape[0]): 55 | if sum(y_pred.T[i]) == 0: 56 | continue 57 | logging.info('P:class%d : %f' % (i, sum(np.logical_and(y_true.T[i], y_pred.T[i])) / sum(y_pred.T[i]))) 58 | temp += sum(np.logical_and(y_true.T[i], y_pred.T[i])) / sum(y_pred.T[i]) 59 | # print('class: {}, precision: '.format(i), sum(np.logical_and(y_true.T[i], y_pred.T[i])) / sum(y_pred.T[i])) 60 | return temp / y_true.T.shape[0] 61 | 62 | 63 | def F1Measure(y_true, y_pred, classid=None): 64 | temp = 0 65 | if classid is not None: 66 | return (2*sum(np.logical_and(y_true.T[classid], y_pred.T[classid]))) / (sum(y_true.T[classid])+sum(y_pred.T[classid])) 67 | else: 68 | for i in range(y_true.T.shape[0]): 69 | logging.info('f1:class%d : %f' % (i, (2*sum(np.logical_and(y_true.T[i], y_pred.T[i]))) / (sum(y_true.T[i])+sum(y_pred.T[i])))) 70 | temp += (2*sum(np.logical_and(y_true.T[i], y_pred.T[i]))) / (sum(y_true.T[i])+sum(y_pred.T[i])) 71 | return temp / y_true.T.shape[0] 72 | 73 | 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedMLP 2 | This is the official implementation for the paper: "FedMLP: Federated Multi-Label Medical Image Classiffcation under Task Heterogeneity", which is accepted at MICCAI'24 
(Early Accept, top 11% in total 2869 submissions). 3 |

4 | intro 5 |

6 | 7 | ## Introduction 8 | Cross-silo federated learning (FL) enables decentralized organizations to collaboratively train models while preserving data privacy and has made signiffcant progress in medical image classiffcation. One common assumption is task homogeneity where each client has access to all classes during training. However, in clinical practice, given a multi-label classiffcation task, constrained by the level of medical knowledge and the prevalence of diseases, each institution may diagnose only partial categories, resulting in task heterogeneity. How to pursue effective multi-label medical image classiffcation under task heterogeneity is under-explored. In this paper, we first formulate such a realistic label missing setting in the multi-label FL domain and propose a two-stage method FedMLP to combat class missing from two aspects: pseudo label tagging and global knowledge learning. The former utilizes a warmed-up model to generate class prototypes and select samples with high confidence to supplement missing labels, while the latter uses a global model as a teacher for consistency regularization to prevent forgetting missing class knowledge. Experiments on two publicly-available medical datasets validate the superiority of FedMLP against the state-of-the-art both federated semi-supervised and noisy label learning approaches under task 9 | heterogeneity. 10 |

11 | intro 12 |

13 | 14 | 15 | ## Related Work 16 | - RSCFed [[paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Liang_RSCFed_Random_Sampling_Consensus_Federated_Semi-Supervised_Learning_CVPR_2022_paper.pdf)] [[code](https://github.com/XMed-Lab/RSCFed)] 17 | - FedNoRo [[paper](https://arxiv.org/pdf/2305.05230)] [[code](https://github.com/wnn2000/FedNoRo)] 18 | - CBAFed [[paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Li_Class_Balanced_Adaptive_Pseudo_Labeling_for_Federated_Semi-Supervised_Learning_CVPR_2023_paper.pdf)] [[code](https://github.com/minglllli/CBAFed)] 19 | 20 | 21 | ## Dataset 22 | Please download the ICH dataset from [kaggle](https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection) and preprocess it follow this [notebook](https://www.kaggle.com/guiferviz/prepare-dataset-resizing-and-saving-as-png). Please download the ChestXray14 dataset from this [link](https://nihcc.app.box.com/v/ChestXray-NIHCC). 23 | 24 | 25 | ## Requirements 26 | We recommend using conda to setup the environment, See the `requirements.txt` for environment configuration. 27 | 28 | 29 | ## Citation 30 | If this repository is useful for your research, please consider citing: 31 | ```shell 32 | @inproceedings{sun2024fedmlp, 33 | title={FedMLP: Federated Multi-label Medical Image Classification Under Task Heterogeneity}, 34 | author={Sun, Zhaobin and Wu, Nannan and Shi, Junjie and Yu, Li and Cheng, Kwang-Ting and Yan, Zengqiang}, 35 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 36 | pages={394--404}, 37 | year={2024}, 38 | organization={Springer} 39 | } 40 | ``` 41 | 42 | 43 | ## Contact 44 | For any questions, please contact 'zbsun@hust.edu.cn'. 
45 | -------------------------------------------------------------------------------- /dataset/all_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import pandas as pd 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class ChestXray14(Dataset): 11 | def __init__(self, datapath, mode, transform=None): 12 | self.datapath = datapath 13 | self.mode = mode 14 | self.transform = transform 15 | 16 | assert self.mode in ["train", "test"] 17 | csv_file = os.path.join("/home/szb/multilabel/", self.mode + "_dataset_8class.csv") 18 | self.file = pd.read_csv(csv_file) 19 | 20 | self.image_list = self.file["Image Index"].values 21 | self.targets = self.file.iloc[0:, 1:].values.astype(np.float32) 22 | 23 | def __getitem__(self, index: int): 24 | image_id, target = self.image_list[index], self.targets[index] 25 | image = self.read_image(image_id) 26 | 27 | if self.transform is not None: 28 | if isinstance(self.transform, tuple): 29 | image1 = self.transform[0](image) 30 | image2 = self.transform[1](image) 31 | return {"image_aug_1": image1, 32 | "image_aug_2": image2, 33 | "target": target, 34 | "index": index, 35 | "image_id": image_id} 36 | else: 37 | image = self.transform(image) 38 | return {"image": image, 39 | "target": target, 40 | "index": index, 41 | "image_id": image_id} 42 | 43 | def __len__(self): 44 | return len(self.targets) 45 | 46 | def read_image(self, image_id): 47 | image_path = os.path.join("/home/szb/ChestXray14/images/image/", image_id) 48 | image = Image.open(image_path).convert("RGB") 49 | return image 50 | 51 | 52 | class ICH(Dataset): 53 | def __init__(self, datapath, mode, transform=None): 54 | self.datapath = datapath 55 | self.mode = mode 56 | self.transform = transform 57 | 58 | assert self.mode in ["train", "test"] 59 | csv_file = os.path.join("/home/szb/ICH_stage2/ICH_stage2/", self.mode + "_dataset_ICH.csv") 60 | 
# csv_file = os.path.join("/home/szb/ICH_stage2/ICH_stage2/", self.mode + '_demo.csv') # demo exp(5000 samples) 61 | self.file = pd.read_csv(csv_file) 62 | 63 | self.image_list = self.file["Image Index"].values 64 | self.targets = self.file.iloc[0:, 1:].values.astype(np.float32) 65 | 66 | def __getitem__(self, index: int): 67 | image_id, target = self.image_list[index], self.targets[index] 68 | image = self.read_image(image_id) 69 | if self.transform is not None: 70 | if isinstance(self.transform, tuple): 71 | image1 = self.transform[0](image) 72 | image2 = self.transform[1](image) 73 | return {"image_aug_1": image1, 74 | "image_aug_2": image2, 75 | "target": target, 76 | "index": index, 77 | "image_id": image_id} 78 | else: 79 | image = self.transform(image) 80 | return {"image": image, 81 | "target": target, 82 | "index": index, 83 | "image_id": image_id} 84 | 85 | def __len__(self): 86 | return len(self.targets) 87 | 88 | def read_image(self, image_id): 89 | image_path = os.path.join("/home/szb/ICH_stage2/ICH_stage2/png185k_512/", image_id) 90 | image = Image.open(image_path).convert("RGB") 91 | return image 92 | -------------------------------------------------------------------------------- /model/efficientnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | from efficientnet_pytorch import EfficientNet 3 | 4 | pretrained_settings = { 5 | 'efficientnet': { 6 | 'imagenet': { 7 | 'input_space': 'RGB', 8 | 'input_size': [3, 224, 224], 9 | 'input_range': [0, 1], 10 | 'mean': [0.485, 0.456, 0.406], 11 | 'std': [0.229, 0.224, 0.225], 12 | 'num_classes': 1000 13 | } 14 | }, 15 | } 16 | 17 | 18 | def initialize_pretrained_model(model, num_classes, settings): 19 | assert num_classes == settings['num_classes'],'num_classes should be {}, but is {}'.format( 20 | settings['num_classes'], num_classes) 21 | model.input_space = settings['input_space'] 22 | 
model.input_size = settings['input_size'] 23 | model.input_range = settings['input_range'] 24 | model.mean = settings['mean'] 25 | model.std = settings['std'] 26 | 27 | 28 | def efficientnet_b0(num_classes=1000, pretrained='imagenet'): 29 | model = EfficientNet.from_pretrained('efficientnet-b0', advprop=False) 30 | if pretrained is not None: 31 | settings = pretrained_settings['efficientnet'][pretrained] 32 | initialize_pretrained_model(model, num_classes, settings) 33 | return model 34 | 35 | 36 | def efficientnet_b1(num_classes=1000, pretrained='imagenet'): 37 | model = EfficientNet.from_pretrained('efficientnet-b1', advprop=False) 38 | if pretrained is not None: 39 | settings = pretrained_settings['efficientnet'][pretrained] 40 | initialize_pretrained_model(model, num_classes, settings) 41 | return model 42 | 43 | 44 | def efficientnet_b2(num_classes=1000, pretrained='imagenet'): 45 | model = EfficientNet.from_pretrained('efficientnet-b2', advprop=False) 46 | if pretrained is not None: 47 | settings = pretrained_settings['efficientnet'][pretrained] 48 | initialize_pretrained_model(model, num_classes, settings) 49 | return model 50 | 51 | 52 | def efficientnet_b3(num_classes=1000, pretrained='imagenet'): 53 | model = EfficientNet.from_pretrained('efficientnet-b3', advprop=False) 54 | if pretrained is not None: 55 | settings = pretrained_settings['efficientnet'][pretrained] 56 | initialize_pretrained_model(model, num_classes, settings) 57 | return model 58 | 59 | 60 | def efficientnet_b4(num_classes=1000, pretrained='imagenet'): 61 | model = EfficientNet.from_pretrained('efficientnet-b4', advprop=False) 62 | if pretrained is not None: 63 | settings = pretrained_settings['efficientnet'][pretrained] 64 | initialize_pretrained_model(model, num_classes, settings) 65 | return model 66 | 67 | 68 | def efficientnet_b5(num_classes=1000, pretrained='imagenet'): 69 | model = EfficientNet.from_pretrained('efficientnet-b5', advprop=False) 70 | if pretrained is not None: 71 | 
settings = pretrained_settings['efficientnet'][pretrained] 72 | initialize_pretrained_model(model, num_classes, settings) 73 | return model 74 | 75 | 76 | def efficientnet_b6(num_classes=1000, pretrained='imagenet'): 77 | model = EfficientNet.from_pretrained('efficientnet-b6', advprop=False) 78 | if pretrained is not None: 79 | settings = pretrained_settings['efficientnet'][pretrained] 80 | initialize_pretrained_model(model, num_classes, settings) 81 | return model 82 | 83 | 84 | def efficientnet_b7(num_classes=1000, pretrained='imagenet'): 85 | model = EfficientNet.from_pretrained('efficientnet-b7', advprop=False) 86 | if pretrained is not None: 87 | settings = pretrained_settings['efficientnet'][pretrained] 88 | initialize_pretrained_model(model, num_classes, settings) 89 | return -------------------------------------------------------------------------------- /utils/FedAvg.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | 4 | import torch 5 | import numpy as np 6 | 7 | def FedAvg(w, dict_len): 8 | w_avg = copy.deepcopy(w[0]) 9 | for k in w_avg.keys(): 10 | w_avg[k] = w_avg[k] * dict_len[0] 11 | for i in range(1, len(w)): 12 | w_avg[k] += w[i][k] * dict_len[i] 13 | w_avg[k] = w_avg[k] / sum(dict_len) 14 | return w_avg 15 | 16 | def Fed_w(w, weight): 17 | w_avg = copy.deepcopy(w[0]) 18 | for k in w_avg.keys(): 19 | w_avg[k] = w_avg[k] * weight[0] 20 | for i in range(1, len(w)): 21 | w_avg[k] += w[i][k] * weight[i] 22 | w_avg[k] = w_avg[k] / sum(weight) 23 | return w_avg 24 | 25 | def RSCFed(DMA, w_locals, K, dict_len, M): 26 | w_sub = [] 27 | for group in DMA: 28 | w_select = [] 29 | N_total = 0 30 | for id in group: 31 | w_select.append(w_locals[id]) 32 | N_total += dict_len[id] 33 | w_avg = Fed_w(w_select, [1]*K) 34 | w = [] 35 | for id in group: 36 | a = dict_len[id] / N_total 37 | b = math.exp((-0.01)*(model_dist(w_locals[id], w_avg)/dict_len[id])) 38 | w.append(a*b) 39 | 
w_sub.append(Fed_w(w_select, w)) 40 | w_glob = Fed_w(w_sub, [1]*M) 41 | return w_glob 42 | 43 | def model_dist(w_1, w_2): 44 | assert w_1.keys() == w_2.keys(), "Error: cannot compute distance between dict with different keys" 45 | dist_total = torch.zeros(1).float() 46 | for key in w_1: 47 | dist = torch.norm(w_1[key].cpu() - w_2[key].cpu()) 48 | dist_total += dist.cpu() 49 | return dist_total.cpu().item() 50 | 51 | def FedAvg_tao(t, weight, class_active_client_list = None): 52 | if class_active_client_list is None: 53 | t_avg = np.array([0.]*len(t[0])) 54 | for i, tao in enumerate(t): 55 | t_avg += tao * float(weight[i]) 56 | t_avg = t_avg / float(sum(weight)) 57 | return t_avg 58 | else: 59 | t_avg = np.array([0.] * len(t[0])) 60 | for cls, cls_active_clients in enumerate(class_active_client_list): 61 | weight_sum = 0. 62 | for i, tao in enumerate(t): 63 | if i in cls_active_clients: 64 | t_avg[cls] += tao[cls] * float(weight[i]) 65 | weight_sum += float(weight[i]) 66 | if len(cls_active_clients) == 0: 67 | t_avg[cls] = 1. 
68 | else: 69 | t_avg[cls] = t_avg[cls] / weight_sum 70 | return t_avg 71 | 72 | def FedAvg_proto(Prototypes, weight, class_active_client_list): 73 | Prototype_avg = torch.zeros((len(Prototypes[0]), len(Prototypes[0][0]))) 74 | # Prototype_avg = np.array([torch.zeros_like(Prototypes[0][0])] * len(Prototypes[0])) 75 | for cls, cls_active_clients in enumerate(class_active_client_list): 76 | Prototype_class_0_avg = torch.zeros_like(Prototypes[0][0]) 77 | Prototype_class_1_avg = torch.zeros_like(Prototypes[0][0]) 78 | for client_id in cls_active_clients: 79 | Prototype_class_0_avg = Prototypes[client_id][2*cls] * weight[client_id] + Prototype_class_0_avg 80 | Prototype_class_1_avg = Prototypes[client_id][2*cls+1] * weight[client_id] + Prototype_class_1_avg 81 | # print(cls) 82 | # print(cls_active_clients) 83 | # print(Prototype_class_0_avg) 84 | # print(Prototype_class_1_avg) 85 | Prototype_class_0_avg = Prototype_class_0_avg / np.sum(np.array(weight)[cls_active_clients]) 86 | Prototype_class_1_avg = Prototype_class_1_avg / np.sum(np.array(weight)[cls_active_clients]) 87 | # print(Prototype_class_0_avg) 88 | # print(Prototype_class_1_avg) 89 | Prototype_avg[2*cls] = Prototype_class_0_avg 90 | Prototype_avg[2*cls+1] = Prototype_class_1_avg 91 | # print(Prototype_avg) 92 | # input() 93 | return Prototype_avg 94 | 95 | def FedAvg_rela(Prototypes, weight, class_active_client_list): 96 | Prototype_avg = torch.zeros((len(Prototypes[0]), len(Prototypes[0][0]))) 97 | for cls, cls_active_clients in enumerate(class_active_client_list): 98 | Prototype_class_avg = torch.zeros_like(Prototypes[0][0]) 99 | for client_id in cls_active_clients: 100 | Prototype_class_avg = Prototypes[client_id][cls] * weight[client_id] + Prototype_class_avg 101 | Prototype_class_avg = Prototype_class_avg / np.sum(np.array(weight)[cls_active_clients]) 102 | Prototype_avg[cls] = Prototype_class_avg 103 | return Prototype_avg -------------------------------------------------------------------------------- 
# ============================================================
# utils/options.py
# ============================================================
import argparse


def args_parser():
    """Build and parse all command-line options for the FedMLP experiments."""
    parser = argparse.ArgumentParser()

    # system setting
    parser.add_argument('--deterministic', type=int, default=1,
                        help='whether use deterministic training')
    parser.add_argument('--seed', type=int, default=1037, help='random seed2023,1037')
    parser.add_argument('--gpu', type=str, default='2', help='GPU to use')

    # basic setting
    parser.add_argument('--exp', type=str,
                        default='FedMLP', help='experiment name')
    parser.add_argument('--dataset', type=str,
                        default='ChestXray14', help='dataset name')
    parser.add_argument('--model', type=str,
                        default='Resnet18', help='model name')
    parser.add_argument('--batch_size', type=int,
                        default=32, help='batch_size per gpu')
    parser.add_argument('--feature_dim', type=int,
                        default=512, help='feature_dim of ResNet18')
    parser.add_argument('--base_lr', type=float, default=3e-5,
                        help='base learning rate,ICH=3e-5,ChestXray14=3e-6')
    parser.add_argument('--pretrained', type=int, default=1)
    parser.add_argument('--train', type=int, default=1)

    # PSL setting
    # cleanup: default was the string '1' (argparse converts it, but an int
    # literal is the idiomatic, equivalent form)
    parser.add_argument('--annotation_num', type=int,
                        default=1, help='The number of categories annotated by each client.')

    # for FL
    parser.add_argument('--n_clients', type=int, default=8,
                        help='number of users')
    parser.add_argument('--n_classes', type=int, default=8,
                        help='number of classes')
    parser.add_argument('--iid', type=int, default=1, help="i.i.d. or non-i.i.d.")
    parser.add_argument('--alpha_dirichlet', type=float,
                        default=0.5, help='parameter for non-iid')
    parser.add_argument('--local_ep', type=int, default=1, help='local epoch')
    parser.add_argument('--rounds_warmup', type=int, default=500, help='rounds')
    parser.add_argument('--rounds_corr', type=int, default=200, help='rounds')
    parser.add_argument('--rounds_distillation', type=int, default=200, help='rounds')
    parser.add_argument('--rounds_finetune', type=int, default=50, help='rounds')
    parser.add_argument('--rounds_FedMLP_stage1', type=int, default=50, help='rounds')
    parser.add_argument('--U', type=float, default=0.7, help='tao_upper_bound')
    parser.add_argument('--L', type=float, default=0.3, help='tao_lower_bound')
    parser.add_argument('--tao_min', type=float, default=0.1, help='tao_min')
    parser.add_argument('--runs', type=int, default=1, help='training seed')

    # RoFL
    parser.add_argument('--forget_rate', type=float, default=0.2, help='forget_rate')
    parser.add_argument('--num_gradual', type=int, default=10, help='T_k')
    parser.add_argument('--T_pl', type=int, help='T_pl: When to start using global guided pseudo labeling', default=100)
    parser.add_argument('--lambda_cen', type=float, help='lambda_cen', default=1.0)
    parser.add_argument('--lambda_e', type=float, help='lambda_e', default=0.8)

    # FedMLP_abu
    parser.add_argument('--difficulty_estimate', type=int, default=1, help='tao=1 or cal')
    parser.add_argument('--miss_client_difficulty', type=int, default=1, help='consider or not(tao agg method)')
    parser.add_argument('--mixup', type=int, default=1, help='y/n')
    parser.add_argument('--clean_threshold', type=float, default=0.005, help='clean_threshold')
    parser.add_argument('--noise_threshold', type=float, default=0.01, help='noise_threshold')

    # FedLSR
    parser.add_argument('--t_w', type=int, default=40, help='clean_threshold')
    # FedIRM
    parser.add_argument('--rounds_FedIRM_sup', type=int, default=20, help='rounds')
    parser.add_argument('--consistency', type=float, default=1, help='consistency')
    parser.add_argument('--consistency_rampup', type=float, default=30, help='consistency_rampup')
    parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay')
    # FedNoRo
    parser.add_argument('--rounds_FedNoRo_warmup', type=int, default=500, help='rounds')
    parser.add_argument('--begin', type=int, default=10, help='ramp up begin')
    parser.add_argument('--end', type=int, default=499, help='ramp up end')
    parser.add_argument('--a', type=float, default=0.8, help='a')
    # CBAFed
    parser.add_argument('--rounds_CBAFed_warmup', type=int, default=50, help='rounds')
    args = parser.parse_args()
    return args


# ============================================================
# utils/FedNoRo.py
# ============================================================
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LogitAdjust_Multilabel(nn.Module):
    """Per-element BCE on probabilities; the logit-adjustment branch is
    kept commented out (abu1 ablation) and the class currently reduces to
    plain ``binary_cross_entropy`` with ``reduction='none'``."""

    def __init__(self, cls_num_list, num, tau=1, weight=None):
        super(LogitAdjust_Multilabel, self).__init__()
        cls_num_list = torch.cuda.FloatTensor(cls_num_list)
        self.cls_p_list = cls_num_list / num
        self.weight = weight

    def forward(self, x, target):
        x_m = x.clone()
        # for i in range(len(self.cls_p_list)):  # abu1
        #     x_m[:, i] = (x_m[:, i]*self.cls_p_list[i])/(x_m[:, i]*self.cls_p_list[i] + (1-x_m[:, i])*(1-self.cls_p_list[i]))
        # nan_mask = torch.isnan(x_m)
        # x_m[nan_mask] = 0.
        return F.binary_cross_entropy(x_m, target, weight=self.weight, reduction='none')


class LA_KD(nn.Module):
    """Mix of supervised BCE on the client's annotated classes and MSE
    distillation toward ``soft_target`` on its unannotated (negative) classes,
    blended by ``w_kd``."""

    def __init__(self, cls_num_list, num, active_class_list_client, negative_class_list_client, tau=1, weight=None):
        super(LA_KD, self).__init__()
        cls_num_list = torch.cuda.FloatTensor(cls_num_list)
        self.active_class_list_client = active_class_list_client
        self.negative_class_list_client = negative_class_list_client
        self.cls_p_list = cls_num_list / num
        self.weight = weight
        self.bce = LogitAdjust_Multilabel(cls_num_list, num)

    def forward(self, x, target, soft_target, w_kd):
        bceloss = self.bce(x, target)[:, self.active_class_list_client].sum() / (len(x) * len(self.active_class_list_client))
        kl = F.mse_loss(x, soft_target, reduction='none')[:, self.negative_class_list_client].sum() / (len(x) * len(self.negative_class_list_client))
        return w_kd * kl + (1 - w_kd) * bceloss


def get_output(loader, net, args, sigmoid=False, criterion=None):
    """Run ``net`` over ``loader`` and return stacked outputs (and per-sample
    losses when ``criterion`` is given).

    NOTE(review): an empty loader leaves ``output_whole`` unbound — callers
    are assumed to pass non-empty loaders.
    """
    net.eval()
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    with torch.no_grad():
        for i, samples in enumerate(loader):
            images = samples["image"].to(args.device)
            labels = samples["target"].to(args.device)
            if sigmoid == True:
                _, outputs = net(images)
                outputs = torch.sigmoid(outputs)
            else:
                _, outputs = net(images)
            if criterion is not None:
                loss = criterion(outputs, labels)
            if i == 0:
                output_whole = np.array(outputs.cpu())
                if criterion is not None:
                    loss_whole = np.array(loss.cpu())
            else:
                output_whole = np.concatenate(
                    (output_whole, outputs.cpu()), axis=0)
                if criterion is not None:
                    loss_whole = np.concatenate(
                        (loss_whole, loss.cpu()), axis=0)
        if criterion is not None:
            return output_whole, loss_whole
        else:
            return output_whole


def sigmoid_rampup(current, begin, end):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    current = np.clip(current, begin, end)
    phase = 1.0 - (current - begin) / (end - begin)
    return float(np.exp(-5.0 * phase * phase))


def get_current_consistency_weight(rnd, begin, end):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return sigmoid_rampup(rnd, begin, end)


def DaAgg(w, dict_len, clean_clients, noisy_clients):
    """Distance-aware aggregation (FedNoRo): noisy clients are down-weighted
    by exp(-distance to the nearest clean client model).

    NOTE(review): if ``noisy_clients`` is empty, ``distance.max()`` is 0 and
    the division yields NaN via 0/0 — confirm callers guarantee at least one
    noisy client.
    """
    client_weight = np.array(dict_len)
    client_weight = client_weight / client_weight.sum()
    distance = np.zeros(len(dict_len))
    for n_idx in noisy_clients:
        dis = []
        for c_idx in clean_clients:
            dis.append(model_dist(w[n_idx], w[c_idx]))
        distance[n_idx] = min(dis)
    distance = distance / distance.max()
    client_weight = client_weight * np.exp(-distance)
    client_weight = client_weight / client_weight.sum()

    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        w_avg[k] = w_avg[k] * client_weight[0]
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] * client_weight[i]
    return w_avg


def model_dist(w_1, w_2):
    """Sum of per-tensor L2 distances between two state_dicts, skipping
    integer buffers (e.g. ``num_batches_tracked``)."""
    assert w_1.keys() == w_2.keys(), "Error: cannot compute distance between dict with different keys"
    dist_total = torch.zeros(1).float()
    for key in w_1.keys():
        if "int" in str(w_1[key].dtype):
            continue
        dist = torch.norm(w_1[key] - w_2[key])
        dist_total += dist.cpu()

    return dist_total.cpu().item()


# ============================================================
# utils/evaluations.py
# ============================================================
import logging

import torch.optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, roc_auc_score, confusion_matrix, recall_score, roc_curve, auc
from sklearn.metrics import average_precision_score

from utils.multilabel_metrixs import Recall, Hamming_Loss, F1Measure, Precision, BACC


def globaltest(net, test_dataset, args):
    """Evaluate ``net`` on the full test set; returns a dict of multi-label
    metrics (mAP, BACC, recall, F1, macro AUC, precision, Hamming loss).
    Hard predictions use a fixed 0.5 threshold."""
    auroc = 0
    net.eval()
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size * 4, shuffle=False, num_workers=4)
    all_preds = np.array([])
    all_probs = []
    all_labels = np.array(test_dataset.targets)
    with torch.no_grad():
        for samples in test_loader:
            images = samples["image"].to(args.device)
            _, outputs = net(images)
            probs = torch.sigmoid(outputs)  # soft predict
            accuracy_th = 0.5
            preds = probs > accuracy_th  # hard predict
            all_probs.append(probs.detach().cpu())
            if all_preds.ndim == 1:
                all_preds = preds.detach().cpu().numpy()
            else:
                all_preds = np.concatenate([all_preds, preds.detach().cpu().numpy()], axis=0)

    all_probs = torch.cat(all_probs).numpy()
    assert all_probs.shape[0] == len(test_dataset)
    assert all_probs.shape[1] == args.n_classes
    logging.info(np.sum(all_preds, axis=0))
    logging.info(np.sum(all_labels, axis=0))

    APs = []

    for label_index in range(all_labels.shape[1]):
        true_labels = all_labels[:, label_index]
        predicted_scores = all_probs[:, label_index]
        ap = average_precision_score(true_labels, predicted_scores)
        APs.append(ap)

    mAP = torch.tensor(APs).mean()

    bacc = BACC(all_labels, all_preds)
    R = Recall(all_labels, all_preds)
    hamming_loss = Hamming_Loss(all_labels, all_preds)
    F1 = F1Measure(all_labels, all_preds)
    P = Precision(all_labels, all_preds)
    colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#000000', '#808080', '#C0C0C0', '#800000',
              '#008000', '#000080', '#808000', '#008080']
    # macro-average AUC over the per-class ROC curves
    for i in range(len(all_labels.T)):
        fpr, tpr, th = roc_curve(all_labels.T[i], all_probs.T[i], pos_label=1)
        # ROCprint(fpr, tpr, i, colors[i])
        auroc += auc(fpr, tpr)
        # print('class: {}, auc: '.format(i), auc(fpr, tpr))
        # plt.show()
    auroc /= len(all_labels.T)

    return {"mAP": mAP,
            "BACC": bacc,
            "R": R,
            "F1": F1,
            "auc": auroc,
            "P": P,
            "hamming_loss": hamming_loss}


def ROCprint(fpr, tpr, name, colorname):
    """Plot one ROC curve and save the accumulated figure to multi_models_roc.png."""
    plt.plot(fpr, tpr, lw=1, label='{} (AUC={:.3f})'.format(name, auc(fpr, tpr)), color=colorname)
    plt.plot([0, 1], [0, 1], '--', lw=1, color='grey')
    plt.axis('square')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate', fontsize=20)
    plt.ylabel('True Positive Rate', fontsize=20)
    plt.title('ROC Curve', fontsize=25)
    plt.legend(loc='lower right', fontsize=20)
    plt.savefig('multi_models_roc.png')


def classtest(net, test_dataset, args, classid):
    """Single-class variant of :func:`globaltest`; metrics restricted to ``classid``."""
    auroc = 0
    net.eval()
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size * 4, shuffle=False, num_workers=4)
    all_preds = np.array([])
    all_probs = []
    all_labels = np.array(test_dataset.targets)
    with torch.no_grad():
        for samples in test_loader:
            images = samples["image"].to(args.device)
            _, outputs = net(images)
            probs = torch.sigmoid(outputs)  # soft predict
            accuracy_th = 0.5
            preds = probs > accuracy_th  # hard predict
            all_probs.append(probs.detach().cpu())
            if all_preds.ndim == 1:
                all_preds = preds.detach().cpu().numpy()
            else:
                all_preds = np.concatenate([all_preds, preds.detach().cpu().numpy()], axis=0)

    all_probs = torch.cat(all_probs).numpy()
    assert all_probs.shape[0] == len(test_dataset)
    assert all_probs.shape[1] == args.n_classes
    logging.info(np.sum(all_preds, axis=0))
    logging.info(np.sum(all_labels, axis=0))

    bacc = BACC(all_labels, all_preds, classid)
    R = Recall(all_labels, all_preds, classid)
    F1 = F1Measure(all_labels, all_preds, classid)
    P = Precision(all_labels, all_preds, classid)
    colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#000000', '#808080', '#C0C0C0', '#800000',
              '#008000', '#000080', '#808000', '#008080']
    # for i in range(len(all_labels.T)):
    #     fpr, tpr, th = roc_curve(all_labels.T[i], all_probs.T[i], pos_label=1)
    #     # ROCprint(fpr, tpr, i, colors[i])
    #     auroc += auc(fpr, tpr)
    #     # print('class: {}, auc: '.format(i), auc(fpr, tpr))
    #     # plt.show()
    # auroc /= len(all_labels.T)

    return {"BACC": bacc,
            "R": R,
            "F1": F1,
            "P": P}


# ============================================================
# model/all_models.py
# ============================================================
from torchvision import models
import pretrainedmodels
from .efficientnet import efficientnet_b0
from .efficientnet import efficientnet_b1
from .efficientnet import efficientnet_b2
from .efficientnet import efficientnet_b3
from .efficientnet import efficientnet_b4
from .efficientnet import efficientnet_b5
from .efficientnet import efficientnet_b6
from .efficientnet import efficientnet_b7


class FCNorm(nn.Module):
    """Cosine classifier: normalized features x normalized weights, scaled by s."""

    def __init__(self, in_features, out_features):
        super(FCNorm, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.s = 30

    def forward(self, x):
        out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
        return self.s * out


def get_model(model_name, pretrained=False):
    """Returns a CNN model
    Args:
        model_name: model name
        pretrained: True or False
    Returns:
        model: the desired model
    Raises:
        ValueError: If model name is not recognized.
    """
    if pretrained == False:
        pt = None
    else:
        pt = 'imagenet'
    print(f"model pretrained: {pretrained}, mode: {pt}", )

    if model_name == 'Vgg11':
        return models.vgg11(pretrained=pretrained)
    elif model_name == 'Vgg13':
        return models.vgg13(pretrained=pretrained)
    elif model_name == 'Vgg16':
        return models.vgg16(pretrained=pretrained)
    elif model_name == 'Vgg19':
        return models.vgg19(pretrained=pretrained)
    elif model_name == 'Resnet18':
        return models.resnet18(pretrained=pretrained)
    elif model_name == 'Resnet34':
        return models.resnet34(pretrained=pretrained)
    elif model_name == 'Resnet50':
        return models.resnet50(pretrained=pretrained)
    elif model_name == 'Resnet101':
        return models.resnet101(pretrained=pretrained)
    elif model_name == 'Resnet152':
        return models.resnet152(pretrained=pretrained)
    elif model_name == 'Dense121':
        return models.densenet121(pretrained=pretrained)
    elif model_name == 'Dense169':
        return models.densenet169(pretrained=pretrained)
    elif model_name == 'Dense201':
        return models.densenet201(pretrained=pretrained)
    elif model_name == 'Dense161':
        return models.densenet161(pretrained=pretrained)
    elif model_name == 'SENet50':
        return pretrainedmodels.__dict__['se_resnet50'](num_classes=1000, pretrained=pt)
    elif model_name == 'SENet101':
        return pretrainedmodels.__dict__['se_resnet101'](num_classes=1000, pretrained=pt)
    elif model_name == 'SENet152':
        return pretrainedmodels.__dict__['se_resnet152'](num_classes=1000, pretrained=pt)
    elif model_name == 'SENet154':
        return pretrainedmodels.__dict__['senet154'](num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b0':
        return efficientnet_b0(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b1':
        return efficientnet_b1(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b2':
        return efficientnet_b2(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b3':
        return efficientnet_b3(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b4':
        return efficientnet_b4(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b5':
        return efficientnet_b5(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b6':
        return efficientnet_b6(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b7':
        return efficientnet_b7(num_classes=1000, pretrained=pt)
    else:
        raise ValueError('Name of model unknown %s' % model_name)


def modify_last_layer(model_name, model, num_classes, normed=False, bias=True):
    """modify the last layer of the CNN model to fit the num_classes
    Args:
        model_name: model name
        model: CNN model
        num_classes: class number
    Returns:
        model: the desired model
    """

    if 'Vgg' in model_name:
        num_ftrs = model.classifier._modules['6'].in_features
        model.classifier._modules['6'] = classifier(num_ftrs, num_classes, normed, bias)
        last_layer = model.classifier._modules['6']
    elif 'Dense' in model_name:
        num_ftrs = model.classifier.in_features
        model.classifier = classifier(num_ftrs, num_classes, normed, bias)
        last_layer = model.classifier
    elif 'Resnet' in model_name:
        num_ftrs = model.fc.in_features
        model.fc = classifier(num_ftrs, num_classes, normed, bias)
        last_layer = model.fc
    elif 'Efficient' in model_name:
        num_ftrs = model._fc.in_features
        model._fc = classifier(num_ftrs, num_classes, normed, bias)
        last_layer = model._fc
    else:
        num_ftrs = model.last_linear.in_features
        model.last_linear = classifier(num_ftrs, num_classes, normed, bias)
        last_layer = model.last_linear
    return model, last_layer


def classifier(num_features, num_classes, normed, bias):
    """Build the replacement head: cosine classifier when ``normed``, else Linear."""
    if normed:
        last_linear = FCNorm(num_features, num_classes)
    else:
        last_linear = nn.Linear(num_features, num_classes, bias=bias)
    return last_linear


def get_feature_length(model_name, model):
    """get the feature length of the last feature layer
    Args:
        model_name: model name
        model: CNN model
    Returns:
        num_ftrs: the feature length of the last feature layer
    """
    if 'Vgg' in model_name:
        num_ftrs = model.classifier._modules['6'].in_features
    elif 'Dense' in model_name:
        num_ftrs = model.classifier.in_features
    elif 'Resnet' in model_name:
        num_ftrs = model.fc.in_features
    elif 'Efficient' in model_name:
        num_ftrs = model._fc.in_features
    elif 'RegNet' in model_name:
        num_ftrs = model.head.fc.in_features
    else:
        num_ftrs = model.last_linear.in_features

    return num_ftrs


# ============================================================
# utils/FixMatch.py
# ============================================================
# code in this file is adpated from
# https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
# https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
# https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
import random

import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image

logger = logging.getLogger(__name__)

PARAMETER_MAX = 10


def AutoContrast(img, **kwarg):
    return PIL.ImageOps.autocontrast(img)


def Brightness(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Color(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Color(img).enhance(v)


def Contrast(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Contrast(img).enhance(v)


def Cutout(img, v, max_v, bias=0):
    if v == 0:
        return img
    v = _float_parameter(v, max_v) + bias
    v = int(v * min(img.size))
    return CutoutAbs(img, v)


def CutoutAbs(img, v, **kwarg):
    w, h = img.size
    x0 = np.random.uniform(0, w)
    y0 = np.random.uniform(0, h)
    x0 = int(max(0, x0 - v / 2.))
    y0 = int(max(0, y0 - v / 2.))
    x1 = int(min(w, x0 + v))
    y1 = int(min(h, y0 + v))
    xy = (x0, y0, x1, y1)
    # gray
    color = (127, 127, 127)
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img


def Equalize(img, **kwarg):
    return PIL.ImageOps.equalize(img)


def Identity(img, **kwarg):
    return img


def Invert(img, **kwarg):
    return PIL.ImageOps.invert(img)


def Posterize(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, v)


def Rotate(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.rotate(v)


def Sharpness(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Sharpness(img).enhance(v)


def ShearX(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    # NOTE(review): PIL.Image.AFFINE was removed in Pillow 10
    # (use Image.Transform.AFFINE there) — confirm the pinned Pillow version.
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))


def Solarize(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.solarize(img, 256 - v)


def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    # BUG FIX: np.int was removed in NumPy 1.24; the builtin int is the
    # documented replacement and is behavior-identical here.
    img_np = np.array(img).astype(int)
    img_np = img_np + v
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)


def TranslateX(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def _float_parameter(v, max_v):
    return float(v) * max_v / PARAMETER_MAX


def _int_parameter(v, max_v):
    return int(v * max_v / PARAMETER_MAX)


def fixmatch_augment_pool():
    # FixMatch paper
    augs = [(AutoContrast, None, None),
            (Brightness, 0.9, 0.05),
            (Color, 0.9, 0.05),
            (Contrast, 0.9, 0.05),
            (Equalize, None, None),
            (Identity, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 0.9, 0.05),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (TranslateX, 0.3, 0),
            (TranslateY, 0.3, 0)]
    return augs


def my_augment_pool():
    # Test
    augs = [(AutoContrast, None, None),
            (Brightness, 1.8, 0.1),
            (Color, 1.8, 0.1),
            (Contrast, 1.8, 0.1),
            (Cutout, 0.2, 0),
            (Equalize, None, None),
            (Invert, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 1.8, 0.1),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (SolarizeAdd, 110, 0),
            (TranslateX, 0.45, 0),
            (TranslateY, 0.45, 0)]
    return augs


class RandAugmentPC(object):
    """RandAugment with per-op probability drawn uniformly per call."""

    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = my_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            prob = np.random.uniform(0.2, 0.8)
            if random.random() + prob >= 1:
                img = op(img, v=self.m, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32 * 0.5))
        return img


class RandAugmentMC(object):
    """RandAugment (FixMatch variant): random magnitude in [1, m), p=0.5 per op.

    NOTE(review): np.random.randint(1, self.m) raises when m == 1 — the
    assert allows m == 1, confirm intended range.
    """

    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            v = np.random.randint(1, self.m)
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32 * 0.5))
        return img


# ============================================================
# dataset/dataset.py (start; continues beyond this chunk)
# ============================================================
from torchvision.transforms import transforms

from dataset.all_dataset import ChestXray14, ICH
from utils.FixMatch import RandAugmentMC
from utils.sampling import iid_sampling, non_iid_dirichlet_sampling


def get_dataset(args):
    if args.dataset == "ChestXray14":
        root = "/home/szb/multilabel/onehot-label-PA.csv"
        args.n_classes = 8
        args.n_clients = 8
def get_dataset(args):
    """Build the train/test datasets and the per-client index partition.

    Side effects: fills in ``args.n_classes``, ``args.n_clients``,
    ``args.num_users`` and ``args.input_channel`` according to
    ``args.dataset``.

    Returns:
        (train_dataset, test_dataset, dict_users) where ``dict_users`` maps
        client id -> indices of that client's training samples.

    Exits with an error message on an unrecognized dataset or experiment.
    """
    # ---- dataset-specific configuration (everything else below is shared) ----
    if args.dataset == "ChestXray14":
        root = "/home/szb/multilabel/onehot-label-PA.csv"
        args.n_classes = 8
        args.n_clients = 8
        dataset_cls = ChestXray14
    elif args.dataset == "ICH":
        root = "/home/szb/ICH_stage2/ICH_stage2/data_png185k_512.csv"
        args.n_classes = 5
        args.n_clients = 5
        dataset_cls = ICH
    else:
        exit("Error: unrecognized dataset")
    args.num_users = args.n_clients
    args.input_channel = 3

    # ImageNet statistics (the backbone is ImageNet-pretrained).
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    def weak_transform():
        # A fresh Compose per call, so branches that previously built two
        # independent transform objects still get two distinct objects.
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomAffine(degrees=10, translate=(0.02, 0.02)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    # ---- experiment-specific training transform(s) ----
    if args.exp in ('FedAVG', 'RoFL', 'FedNoRo', 'CBAFed'):
        train_transform = weak_transform()
    elif args.exp in ('RSCFed', 'FedPN', 'FedLSR', 'FedIRM'):
        # Two independent weak views per sample.
        train_transform = (weak_transform(), weak_transform())
    elif args.exp == 'FedAVG+FixMatch':
        # Weak/strong pair; the strong view adds RandAugment.
        strong_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomAffine(degrees=10, translate=(0.02, 0.02)),
            transforms.RandomHorizontalFlip(),
            RandAugmentMC(n=2, m=10),
            transforms.ToTensor(),
            normalize,
        ])
        train_transform = (weak_transform(), strong_transform)
    else:
        # BUGFIX: an unknown args.exp previously fell through and crashed
        # later with NameError on train_dataset; fail fast instead.
        exit("Error: unrecognized experiment setting")

    train_dataset = dataset_cls(root, "train", train_transform)
    test_dataset = dataset_cls(root, "test", test_transform)

    n_train = len(train_dataset)
    y_train = np.array(train_dataset.targets)
    assert n_train == len(y_train)
    print(n_train)

    # ---- load a cached client partition if present, else sample and cache ----
    if args.iid == 0:  # non-iid: Dirichlet split over classes
        cache = (f"non-iid-dictusers/"
                 f"{args.dataset}_{args.seed}_{args.n_clients}_{args.alpha_dirichlet}.npy")
        if os.path.exists(cache):
            dict_users = np.load(cache, allow_pickle=True).item()
        else:
            dict_users = non_iid_dirichlet_sampling(y_train, args.n_classes, 1.0,
                                                    args.n_clients, seed=args.seed,
                                                    alpha_dirichlet=args.alpha_dirichlet)
            np.save(cache, dict_users, allow_pickle=True)
    else:  # iid (the literal '5000' suffix matches the pre-generated cache files)
        cache = f"iid-dictusers/{args.dataset}_{args.seed}_{args.n_clients}5000.npy"
        if os.path.exists(cache):
            dict_users = np.load(cache, allow_pickle=True).item()
        else:
            dict_users = iid_sampling(n_train, args.n_clients, args.seed)
            np.save(cache, dict_users, allow_pickle=True)
    return train_dataset, test_dataset, dict_users
args = args_parser() 31 | args.num_users = args.n_clients 32 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 33 | args.device = "cuda" if torch.cuda.is_available() else "cpu" 34 | 35 | # ------------------------------ deterministic or not ------------------------------ 36 | if args.deterministic: 37 | cudnn.benchmark = False 38 | cudnn.deterministic = True 39 | set_seed(args.seed) 40 | 41 | # ------------------------------ output files ------------------------------ 42 | writer1, models_dir = set_output_files(args) 43 | 44 | # ------------------------------ dataset ------------------------------ 45 | dataset_train, dataset_test, dict_users = get_dataset(args) 46 | # for key in dict_users.keys(): 47 | # print(len(dict_users[key])) 48 | # dict_users[key] = random.sample(dict_users[key], 2500) 49 | # np.save(f"iid-dictusers/{str(args.dataset) + '_' + str(args.seed) + '_' + str(args.n_clients) + '5000'}.npy", 50 | # dict_users, allow_pickle=True) 51 | 52 | 53 | logging.info( 54 | f"train: {np.sum(dataset_train.targets, axis=0)}, total: {len(dataset_train.targets)}") 55 | logging.info( 56 | f"test: {np.sum(dataset_test.targets, axis=0)}, total: {len(dataset_test.targets)}") 57 | 58 | row_idx_1, column_idx_1 = where(dataset_train.targets == 1) 59 | class_pos_idx_1 = [] 60 | for i in range(args.n_classes): 61 | class_pos_idx_1.append(row_idx_1[where(column_idx_1 == i)[0]]) 62 | 63 | p_pos_1 = 0. 64 | class_neg_idx_1 = [] 65 | for i in range(args.n_classes): 66 | class_neg_idx_1.append(np.random.choice(class_pos_idx_1[i], int((1-p_pos_1)*len(class_pos_idx_1[i])), replace=False)) # list(array, ...) 
67 | # --------------------- Partially Labelling --------------------------- 68 | if args.train: 69 | # ------------------------------ local settings ------------------------------ 70 | user_id = list(range(args.n_clients)) 71 | dict_len = [len(dict_users[idx]) for idx in user_id] 72 | trainer_locals_1 = [] 73 | netglob = build_model(args) 74 | for i in user_id: 75 | trainer_locals_1.append(LocalUpdate( 76 | args, i, deepcopy(dataset_train), dict_users[i], class_pos_idx_1, class_neg_idx_1, dataset_test=dataset_test, active_class_list=[i], student=deepcopy(netglob).to(args.device), 77 | teacher_neg=deepcopy(netglob).to(args.device), teacher_act=deepcopy(netglob).to(args.device))) # student initial is global 78 | # trainer_locals_1.append(LocalUpdate( 79 | # args, i, deepcopy(dataset_train), dict_users[i], class_pos_idx_1, class_neg_idx_1, dataset_test=dataset_test, student=deepcopy(netglob).to(args.device), 80 | # teacher_neg=deepcopy(netglob).to(args.device), 81 | # teacher_act=deepcopy(netglob).to(args.device))) # student initial is global 82 | 83 | 84 | # ------------------------------ begin training ------------------------------ 85 | for run in range(args.runs): 86 | set_seed(int(run)) 87 | netclass = build_model(args) 88 | logging.info(f"\n===============================> beging, run: {run} <===============================\n") 89 | w_class_fl = [] 90 | active_class_list = [] # [, , , , ] 91 | negetive_class_list = [] # [, , , , ] 92 | class_active_client_list = [] 93 | class_negative_client_list = [] 94 | 95 | # FeMLP 96 | tao = [0] * args.n_classes 97 | Prototype = [] 98 | # RoFL: Initialize f_G 99 | f_G = torch.randn(2*args.n_classes, args.feature_dim, device=args.device) # [[cls0_0],[cls0_1],[cls1_0]...] 
100 | forget_rate_schedule = [] 101 | forget_rate = args.forget_rate 102 | exponent = 1 103 | forget_rate_schedule = np.ones(args.rounds_warmup) * forget_rate 104 | forget_rate_schedule[:args.num_gradual] = np.linspace(0, forget_rate ** exponent, args.num_gradual) 105 | # ------------------------------ stage1:warm-up ------------------------------ 106 | for rnd in range(0, args.rounds_warmup): 107 | # if rnd == 49: 108 | # netglob.load_state_dict(torch.load( 109 | # "/home/szb/multilabel/chest 5 miss model/model_warmup_globdistill_0.3_0.1_49.pth")) 110 | # negetive_class_list = [[1,2,3,4], [0,2,3,4], [0,1,3,4], [0,1,2,4], [0,1,2,3]] 111 | # active_class_list = [[0], [1], [2], [3], [4]] 112 | # class_active_client_list = [[0, 1, 2, 3, 4], [0, 1, 2], [0, 2, 3, 4], [0, 1, 3, 4], [1, 2, 3, 4]] 113 | # class_negative_client_list = [[], [3, 4], [1], [2], [0]] 114 | if args.exp == 'RSCFed': 115 | M = 10 116 | K = 6 117 | DMA = [] 118 | for i in range(M): 119 | random_numbers = random.sample(range(args.n_clients), K) 120 | DMA.append(random_numbers) 121 | print('DMA: ', DMA) 122 | # if args.exp == 'RoFL': 123 | # if len(forget_rate_schedule) > 0: 124 | # args.forget_rate = forget_rate_schedule[rnd] 125 | # logging.info('remember_rate: %f' % (1-args.forget_rate)) 126 | # if args.exp == 'FedNoRo' and rnd >= args.rounds_FedNoRo_warmup: 127 | if args.exp == 'FedNoRo': 128 | weight_kd = get_current_consistency_weight(rnd, args.begin, args.end) * args.a 129 | logging.info("\n------------------------------> training, run: %d, round: %d <------------------------------" % (run, rnd)) 130 | w_locals, loss_locals = [], [] 131 | class_num_lists, data_nums = [], [] 132 | taos, Prototypes = [], [] 133 | # RoFL 134 | f_locals = [] 135 | for i in tqdm(user_id): # training over the subset 136 | local = trainer_locals_1[i] 137 | if args.exp == 'FedAVG': 138 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train(rnd, 
139 | net=deepcopy(netglob).to(args.device), writer1=writer1) 140 | if args.exp == 'FedNoRo': 141 | if rnd < args.rounds_FedNoRo_warmup: 142 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedNoRo( 143 | i, rnd, 144 | net=deepcopy(netglob).to(args.device), writer1=writer1, weight_kd=weight_kd) 145 | # else: 146 | # w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedNoRo( 147 | # i, rnd, 148 | # net=deepcopy(netglob).to(args.device), writer1=writer1, weight_kd = weight_kd, clean_clients=clean_clients, noisy_clients = noisy_clients) 149 | if args.exp == 'CBAFed': 150 | if rnd < args.rounds_CBAFed_warmup: 151 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client, class_num_list, data_num = local.train_CBAFed( 152 | rnd, net=deepcopy(netglob).to(args.device)) 153 | else: 154 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client, class_num_list, data_num = local.train_CBAFed( 155 | rnd, net=deepcopy(netglob).to(args.device), pt=pt, tao=tao) 156 | class_num_lists.append(deepcopy(class_num_list)) 157 | data_nums.append(deepcopy(data_num)) 158 | if args.exp == 'FedAVG+FixMatch': 159 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FixMatch(rnd, 160 | net=deepcopy(netglob).to(args.device)) 161 | if args.exp == 'FedLSR': 162 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedLSR(rnd, 163 | net=deepcopy(netglob).to(args.device)) 164 | if args.exp == 'RSCFed': 165 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_RSCFed(rnd, 166 | 
net=deepcopy(netglob).to(args.device)) 167 | # if args.exp == 'RoFL': 168 | # w_local, loss_local, f_k = local.train_RoFL(deepcopy(netglob).to(args.device), deepcopy(f_G).to(args.device), rnd) 169 | if args.exp == 'FedIRM': 170 | if rnd < args.rounds_FedIRM_sup - 1: 171 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedIRM( 172 | rnd, Prototype, writer1, negetive_class_list=None, active_class_list_client_i=None, 173 | net=deepcopy(netglob).to(args.device)) 174 | else: 175 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client, Prototype_local = local.train_FedIRM( 176 | rnd, Prototype, writer1, negetive_class_list[i], active_class_list[i], 177 | net=deepcopy(netglob).to(args.device)) 178 | if args.exp == 'FeMLP': 179 | if rnd < args.rounds_FeMLP_stage1-1: 180 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedMLP( 181 | rnd, tao, Prototype, writer1, negetive_class_list=None, active_class_list_client_i=None, net=deepcopy(netglob).to(args.device)) 182 | else: 183 | w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client, tao_local, Prototype_local = local.train_FeMLP( 184 | rnd, tao, Prototype, writer1, negetive_class_list[i], active_class_list[i], net=deepcopy(netglob).to(args.device)) 185 | if rnd == 0 and args.exp != 'RoFL': 186 | active_class_list.append(active_class_list_client) 187 | negetive_class_list.append(negetive_class_list_client) 188 | # store every updated model 189 | if args.exp == 'FedMLP' and rnd >= args.rounds_FedMLP_stage1-1: 190 | taos.append(deepcopy(tao_local)) 191 | Prototypes.append(deepcopy(Prototype_local)) 192 | if args.exp == 'FedIRM' and rnd >= args.rounds_FedIRM_sup-1: 193 | Prototypes.append(deepcopy(Prototype_local.detach().cpu())) 194 | if args.exp == 
'RoFL': 195 | f_locals.append(f_k) 196 | w_locals.append(deepcopy(w_local)) 197 | loss_locals.append(deepcopy(loss_local)) 198 | writer1.add_scalar(f'train_run{run}/warm-up-loss/client{i}', loss_local, rnd) 199 | # aggregation 200 | if rnd == 0 and args.exp != 'RoFL': 201 | for i in range(len(active_class_list)): 202 | class_active_client_list.append([]) 203 | class_negative_client_list.append([]) 204 | for j in range(len(active_class_list)): 205 | if i in active_class_list[j]: 206 | class_active_client_list[i].append(j) 207 | if i in negetive_class_list[j]: 208 | class_negative_client_list[i].append(j) 209 | logging.info(class_active_client_list) 210 | logging.info(class_negative_client_list) 211 | assert i == user_id[-1] 212 | assert len(w_locals) == len(dict_len) == args.n_clients 213 | if args.exp == 'RSCFed': 214 | w_glob_fl = RSCFed(DMA, w_locals, K, dict_len, M) 215 | netglob.load_state_dict(deepcopy(w_glob_fl)) 216 | elif args.exp == 'FedMLP': 217 | if rnd < args.rounds_FedMLP_stage1 - 1: 218 | w_glob_fl = FedAvg(w_locals, dict_len) 219 | netglob.load_state_dict(deepcopy(w_glob_fl)) 220 | else: 221 | w_glob_fl = FedAvg(w_locals, dict_len) 222 | netglob.load_state_dict(deepcopy(w_glob_fl)) 223 | tao = FedAvg_tao(taos, dict_len, class_negative_client_list) 224 | print('avg_tao: ', tao) 225 | # if args.miss_client_difficulty == 1: 226 | # tao = FedAvg_tao(taos, dict_len) 227 | # else: 228 | # tao = FedAvg_tao(taos, dict_len, class_active_client_list) 229 | # print('avg_tao: ', tao) 230 | if rnd == args.rounds_FedMLP_stage1 - 1: 231 | Prototype = FedAvg_proto(Prototypes, dict_len, class_active_client_list) 232 | else: 233 | lam = 1.0 234 | Prototype = (1-lam)*Prototype + lam*FedAvg_proto(Prototypes, dict_len, class_active_client_list) 235 | print('rnd: ', rnd, 'ok') 236 | if (rnd + 1) % 10 == 0 and run == 0: 237 | torch.save(netglob.state_dict(), models_dir + f'/model_warmup_globdistill_0.3_0.1_{rnd}.pth') 238 | elif args.exp == 'FedIRM': 239 | if rnd < 
args.rounds_FedIRM_sup - 1: 240 | w_glob_fl = FedAvg(w_locals, dict_len) 241 | netglob.load_state_dict(deepcopy(w_glob_fl)) 242 | else: 243 | w_glob_fl = FedAvg(w_locals, dict_len) 244 | netglob.load_state_dict(deepcopy(w_glob_fl)) 245 | if rnd == args.rounds_FedIRM_sup - 1: 246 | print(Prototypes) 247 | Prototype = FedAvg_rela(Prototypes, dict_len, class_active_client_list) 248 | print(Prototype) 249 | else: 250 | lam = 1.0 251 | Prototype = (1-lam)*Prototype + lam*FedAvg_rela(Prototypes, dict_len, class_active_client_list) 252 | print('rnd: ', rnd, 'ok') 253 | # elif args.exp == 'RoFL': 254 | # w_glob_fl = FedAvg(w_locals, dict_len) 255 | # netglob.load_state_dict(deepcopy(w_glob_fl)) 256 | # sim = torch.nn.CosineSimilarity(dim=1) 257 | # tmp = 0 258 | # w_sum = 0 259 | # for i in f_locals: 260 | # sim_weight = sim(f_G, i).reshape(2*args.n_classes, 1) 261 | # w_sum += sim_weight 262 | # tmp += sim_weight * i 263 | # # print(sim_weight) 264 | # # print(i) 265 | # for i in range(len(w_sum)): 266 | # if w_sum[i, 0] == 0: 267 | # w_sum[i, 0] = 1 268 | # f_G = torch.div(tmp, w_sum) 269 | elif args.exp == 'FedNoRo': 270 | if rnd < args.rounds_FedNoRo_warmup: 271 | w_glob_fl = FedAvg(w_locals, dict_len) 272 | netglob.load_state_dict(deepcopy(w_glob_fl)) 273 | elif args.exp == 'CBAFed': 274 | if rnd < args.rounds_CBAFed_warmup: 275 | if rnd % 5 != 0: 276 | w_glob_fl = FedAvg(w_locals, dict_len) 277 | netglob.load_state_dict(deepcopy(w_glob_fl)) 278 | else: 279 | if rnd == 0: 280 | w_glob_fl = FedAvg(w_locals, dict_len) 281 | netglob.load_state_dict(deepcopy(w_glob_fl)) 282 | w_glob_res = deepcopy(w_glob_fl) 283 | else: 284 | w_glob_fl = FedAvg(w_locals, dict_len) 285 | for k in w_glob_fl.keys(): 286 | w_glob_fl[k] = 0.2*w_glob_fl[k] + 0.8*w_glob_res[k] 287 | netglob.load_state_dict(deepcopy(w_glob_fl)) 288 | w_glob_res = deepcopy(w_glob_fl) 289 | if rnd >= args.rounds_CBAFed_warmup - 1: 290 | c_num = torch.zeros(args.n_classes) 291 | d_num = 0 292 | for s in user_id: 293 
| c_num += class_num_lists[s] 294 | d_num += data_nums[s] 295 | pt = c_num / d_num 296 | avg_pt = pt.sum() / len(pt) 297 | std_pt = torch.sqrt((1/(len(pt)-1))*(((pt-avg_pt)**2).sum())) 298 | tao = pt + 0.45 - std_pt 299 | tao = torch.where(tao > 0.95, 0.95, tao) 300 | tao = torch.where(tao < 0.55, 0.55, tao) 301 | if rnd >= args.rounds_CBAFed_warmup: 302 | wti = (torch.tensor(data_nums) / torch.tensor(data_nums).sum()).tolist() 303 | if (rnd-args.rounds_CBAFed_warmup) % 5 != 0: 304 | w_glob_fl = FedAvg(w_locals, wti) 305 | netglob.load_state_dict(deepcopy(w_glob_fl)) 306 | else: 307 | if (rnd-args.rounds_CBAFed_warmup) == 0: 308 | w_glob_fl = FedAvg(w_locals, wti) 309 | netglob.load_state_dict(deepcopy(w_glob_fl)) 310 | w_glob_res = deepcopy(w_glob_fl) 311 | else: 312 | w_glob_fl = FedAvg(w_locals, wti) 313 | for k in w_glob_fl.keys(): 314 | w_glob_fl[k] = 0.5*w_glob_fl[k] + 0.5*w_glob_res[k] 315 | netglob.load_state_dict(deepcopy(w_glob_fl)) 316 | w_glob_res = deepcopy(w_glob_fl) 317 | else: 318 | w_glob_fl = FedAvg(w_locals, dict_len) 319 | netglob.load_state_dict(deepcopy(w_glob_fl)) 320 | 321 | # validate 322 | if rnd % 10 == 9: 323 | logging.info( 324 | "\n------------------------------> testing, run: %d, round: %d <------------------------------" % ( 325 | run, rnd)) 326 | result = globaltest(deepcopy(netglob).to(args.device), test_dataset=dataset_test, args=args) 327 | mAP, BACC, R, F1, auroc, P, hamming_loss = result["mAP"], result["BACC"], result["R"], result[ 328 | "F1"], result["auc"], result["P"], result["hamming_loss"] 329 | logging.info( 330 | "-----> mAP: %.2f, BACC: %.2f, R: %.2f, F1: %.2f, auc: %.2f, P: %.2f, hamming_loss: %.2f" % ( 331 | mAP, BACC * 100, R * 100, F1 * 100, auroc * 100, P * 100, hamming_loss)) 332 | writer1.add_scalar(f'test_run{run}/mAP', mAP, rnd) 333 | writer1.add_scalar(f'test_run{run}/BACC', BACC, rnd) 334 | writer1.add_scalar(f'test_run{run}/R', R, rnd) 335 | writer1.add_scalar(f'test_run{run}/F1', F1, rnd) 336 | 
writer1.add_scalar(f'test_run{run}/auc', auroc, rnd) 337 | writer1.add_scalar(f'test_run{run}/P', P, rnd) 338 | writer1.add_scalar(f'test_run{run}/hamming_loss', hamming_loss, rnd) 339 | logging.info('\n') 340 | if rnd == 49: 341 | torch.save(netglob.state_dict(), models_dir + f'/model_{rnd}.pth') 342 | 343 | if rnd % 10 == 9: 344 | logging.info('test') 345 | logging.info("\n------------------------------> testing, run: %d, round: %d <------------------------------" % (run, rnd)) 346 | result = globaltest(deepcopy(netglob).to(args.device), test_dataset=dataset_test, args=args) 347 | mAP, BACC, R, F1, auroc, P, hamming_loss = result["mAP"], result["BACC"], result["R"], result["F1"], result["auc"], result["P"], result["hamming_loss"] 348 | logging.info("-----> mAP: %.2f, BACC: %.2f, R: %.2f, F1: %.2f, auc: %.2f, P: %.2f, hamming_loss: %.2f" % (mAP, BACC*100, R*100, F1*100, auroc*100, P*100, hamming_loss)) 349 | logging.info(np.array(loss_locals)) 350 | writer1.add_scalar(f'corr-test_run{run}/mAP', mAP, rnd) 351 | writer1.add_scalar(f'corr-test_run{run}/BACC', BACC, rnd) 352 | writer1.add_scalar(f'corr-test_run{run}/R', R, rnd) 353 | writer1.add_scalar(f'corr-test_run{run}/F1', F1, rnd) 354 | writer1.add_scalar(f'corr-test_run{run}/auc', auroc, rnd) 355 | writer1.add_scalar(f'corr-test_run{run}/P', P, rnd) 356 | writer1.add_scalar(f'corr-test_run{run}/hamming_loss', hamming_loss, rnd) 357 | logging.info('\n') 358 | 359 | # save model 360 | if (rnd + 1) == args.rounds_corr: 361 | torch.save(netglob.state_dict(), models_dir + f'/corr_model_{run}_{rnd}.pth') 362 | if rnd < 50 and (rnd + 1) % 5 == 0 and run == 0: 363 | torch.save(netglob.state_dict(), models_dir + f'/corr_model_{run}_{rnd}.pth') 364 | 365 | else: # test 366 | netglob = build_model(args) 367 | netglob.load_state_dict(torch.load("/home/szb/multilabel/model_warmup/model_warmup_49.pth")) 368 | result = classtest(deepcopy(netglob).to(args.device), test_dataset=dataset_test, args=args, classid=1) 369 | BACC, R, 
F1, P = result["BACC"], result["R"], result["F1"], result["P"] 370 | logging.info( 371 | "-----> BACC: %.2f, R: %.2f, F1: %.2f, P: %.2f" % (BACC * 100, R * 100, F1 * 100, P * 100)) 372 | logging.info('\n') 373 | result = classtest(deepcopy(netglob).to(args.device), test_dataset=dataset_test, args=args, classid=4) 374 | BACC, R, F1, P = result["BACC"], result["R"], result["F1"], result["P"] 375 | logging.info( 376 | "-----> BACC: %.2f, R: %.2f, F1: %.2f, P: %.2f" % (BACC * 100, R * 100, F1 * 100, P * 100)) 377 | logging.info('\n') 378 | 379 | torch.cuda.empty_cache() 380 | -------------------------------------------------------------------------------- /utils/local_training.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import math 4 | import random 5 | import time 6 | from copy import deepcopy 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from matplotlib import pyplot as plt 12 | from sklearn.mixture import GaussianMixture 13 | from torch import nn 14 | from torch.cuda.amp import autocast 15 | from torch.utils.data import DataLoader, Dataset 16 | import torch.nn.functional as F 17 | from itertools import chain 18 | import seaborn as sns 19 | 20 | from utils.evaluations import globaltest, classtest 21 | from utils.feature_visual import tnse_Visual 22 | from utils.FedNoRo import LogitAdjust_Multilabel, LA_KD 23 | from utils.utils import max_m_indices, min_n_indices 24 | 25 | 26 | class LocalUpdate(object): 27 | def __init__(self, args, client_id, dataset, idxs, class_pos_idx, class_neg_idx, active_class_list=None, student=None, teacher_neg=None, teacher_act=None, dataset_test=None): 28 | self.teacher_neg = teacher_neg 29 | self.teacher_act = teacher_act 30 | self.dataset_test = dataset_test 31 | self.ema_model = teacher_neg 32 | self.student = student 33 | self.args = args 34 | self.client_id = client_id 35 | self.idxs = idxs 36 | self.dataset = dataset 37 | 
        # ---- tail of __init__ (the `def __init__` line is above this chunk) ----
        self.active_class_list = active_class_list
        self.local_dataset = DatasetSplit(dataset, idxs, client_id, args, class_neg_idx, active_class_list)
        self.class_num_list = self.local_dataset.get_num_of_each_class(args)
        # Per-class positive weights: total samples / per-class count (inverse-frequency weighting).
        self.loss_w = [len(self.local_dataset) / i for i in self.class_num_list]
        # Same weighting, but only for this client's own class; all other classes weighted 1.
        self.loss_w_unknown = [1] * len(self.class_num_list)
        self.loss_w_unknown[client_id] = len(self.local_dataset) / self.class_num_list[client_id]
        self.loss_balanced = [1] * len(self.class_num_list)
        logging.info(self.loss_w)
        logging.info(
            f"---> Client{client_id}, each class num: {self.class_num_list}, total num: {len(self.local_dataset)}")
        self.ldr_train = DataLoader(
            self.local_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=4)
        self.epoch = 0
        self.iter_num = 0
        self.lr = self.args.base_lr
        self.class_pos_idx = class_pos_idx
        self.class_neg_idx = class_neg_idx
        self.flag = True
        # NOTE(review): 8x8 is hard-coded here and in get_confuse_matrix; presumably the
        # ICH task has 8 classes — should track args.n_classes instead. TODO confirm.
        self.confuse_matrix = torch.zeros((8, 8)).cuda()

    def find_rows(self, tensor, up, down):
        """Return indices of rows in which every element is confident, i.e. > up or < down."""
        condition = torch.all(torch.logical_or(tensor > up, tensor < down), axis=1)
        row_indices = torch.where(condition)[0]
        return row_indices

    def update_ema_variables(self, model, ema_model, alpha, global_step):
        """In-place EMA update of ema_model parameters toward model (mean-teacher style)."""
        # Ramp alpha up from 0 so the EMA warms up quickly in early steps.
        alpha = min(1 - 1 / (global_step + 1), alpha)
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    def torch_tile(self, tensor, dim, n):
        """Tile a 2-D tensor n times along dim (numpy.tile-like behavior for dim 0/1)."""
        if dim == 0:
            return tensor.unsqueeze(0).transpose(0, 1).repeat(1, n, 1).view(-1, tensor.shape[1])
        else:
            return tensor.unsqueeze(0).transpose(0, 1).repeat(1, 1, n).view(tensor.shape[0], -1)

    def get_confuse_matrix(self, logits, labels):
        """Build the inter-class relation matrix: row i is sigmoid of the mean logit vector
        over samples positive for class i (temperature 2.0). NOTE(review): class count 8 is
        hard-coded; should presumably be args.n_classes — TODO confirm."""
        source_prob = []
        for i in range(8):
            mask = self.torch_tile(torch.unsqueeze(labels[:, i], -1), 1, 8)
            logits_mask_out = logits * mask
            # 1e-8 guards against division by zero when class i has no positives in the batch.
            logits_avg = torch.sum(logits_mask_out, dim=0) / (torch.sum(labels[:, i]) + 1e-8)
            prob = torch.sigmoid(logits_avg / 2.0)
            source_prob.append(prob)
        return torch.stack(source_prob)

    def sigmoid_rampup(self, current, rampup_length):
        """Exponential rampup from https://arxiv.org/abs/1610.02242"""
        if rampup_length == 0:
            return 1.0
        else:
            current = np.clip(current, 0.0, rampup_length)
            phase = 1.0 - current / rampup_length
            return float(np.exp(-5.0 * phase * phase))

    def get_current_consistency_weight(self, epoch):
        """Consistency-loss weight ramped up over args.consistency_rampup rounds."""
        return self.args.consistency * self.sigmoid_rampup(epoch, self.args.consistency_rampup)

    def sigmoid_mse_loss(self, input_logits, target_logits):
        """Takes softmax on both sides and returns MSE loss

        Note:
        - Returns the sum over all examples. Divide by the batch size afterwards
        if you want the mean.
        - Sends gradients to inputs but not the targets.
        """
        # Despite the historical docstring, this applies *sigmoid* (multi-label), not softmax.
        assert input_logits.size() == target_logits.size()
        input_softmax = torch.sigmoid(input_logits)
        target_softmax = torch.sigmoid(target_logits)

        # Element-wise squared error; caller reduces. Targets are not detached here —
        # callers wrap the teacher pass in torch.no_grad().
        mse_loss = (input_softmax - target_softmax) ** 2
        return mse_loss

    def kd_loss(self, source_matrix, target_matrix):
        """Symmetric KL divergence between two relation matrices (knowledge-distillation loss)."""
        Q = source_matrix
        P = target_matrix
        loss = (F.kl_div(Q.log(), P, None, None, 'batchmean') + F.kl_div(P.log(), Q, None, None, 'batchmean')) / 2.0
        return loss

    def train_FedNoRo(self, id, rnd, net, writer1, weight_kd = None, clean_clients=None, noisy_clients=None):
        """One round of local training following FedNoRo.

        Warm-up rounds (< args.rounds_FedNoRo_warmup): student/teacher pair trained with
        LA_KD (logit-adjusted loss + distillation from the frozen teacher copy).
        After warm-up: clean clients train `net` directly with LogitAdjust_Multilabel;
        noisy clients repeat the student/teacher distillation.

        Returns (state_dict, mean epoch loss, _, _, negative class list, active class list).

        NOTE(review): the trailing `student_net.cpu()` / `return student_net.state_dict()`
        appear to run for *all* paths, but `student_net` is never defined on the
        clean-client branch — that path would raise NameError (and arguably should
        return `net`). TODO confirm against the original file's indentation.
        """
        assert len(self.ldr_train.dataset) == len(self.idxs)
        print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
        if rnd < self.args.rounds_FedNoRo_warmup:
            student_net = copy.deepcopy(net).cuda()
            teacher_net = copy.deepcopy(net).cuda()
            student_net.train()
            teacher_net.eval()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                student_net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)

            # train and update
            epoch_loss = []
            for epoch in range(self.args.local_ep):
                batch_loss = []
                for k, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    if k == 0:
                        # Derive this client's annotated (active) vs missing (negative)
                        # classes from the first batch; zero counts for missing classes.
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                                self.class_num_list[i] = 0
                        criterion = LA_KD(cls_num_list=self.class_num_list, num=len(self.ldr_train.dataset), active_class_list_client=active_class_list_client, negative_class_list_client=negetive_class_list_client)
                    images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)

                    _, logits = student_net(images)

                    with torch.no_grad():
                        _, teacher_output = teacher_net(images)
                        # Temperature-sharpened soft labels from the frozen teacher.
                        soft_label = torch.sigmoid(teacher_output / 0.8)
                    logits_sig = torch.sigmoid(logits).cuda()
                    loss = criterion(logits_sig, labels, soft_label, weight_kd)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    batch_loss.append(loss.item())
                    self.iter_num += 1
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
        else:
            if id in clean_clients:
                # Clean clients: plain logit-adjusted supervised training on `net`.
                net.train()
                self.optimizer = torch.optim.Adam(
                    net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
                epoch_loss = []
                ce_criterion = LogitAdjust_Multilabel(cls_num_list=self.class_num_list, num=len(self.ldr_train.dataset))
                for epoch in range(self.args.local_ep):
                    batch_loss = []
                    for k, (samples, item, active_class_list) in enumerate(self.ldr_train):
                        if k == 0:
                            active_class_list_client = []
                            negetive_class_list_client = []
                            for i in range(self.args.annotation_num):
                                active_class_list_client.append(active_class_list[i][0].item())
                            for i in range(self.args.n_classes):
                                if i not in active_class_list_client:
                                    negetive_class_list_client.append(i)
                        images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                        _, logits = net(images)
                        logits_sig = torch.sigmoid(logits).cuda()
                        loss = ce_criterion(logits_sig, labels)
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()
                        batch_loss.append(loss.item())
                        self.iter_num += 1
                    self.epoch = self.epoch + 1
                    epoch_loss.append(np.array(batch_loss).mean())
            elif id in noisy_clients:
                # Noisy clients: same student/teacher distillation as the warm-up phase.
                student_net = copy.deepcopy(net).cuda()
                teacher_net = copy.deepcopy(net).cuda()
                student_net.train()
                teacher_net.eval()
                # set the optimizer
                self.optimizer = torch.optim.Adam(
                    student_net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)

                # train and update
                epoch_loss = []
                criterion = LA_KD(cls_num_list=self.class_num_list, num=len(self.ldr_train.dataset))

                for epoch in range(self.args.local_ep):
                    batch_loss = []
                    for k, (samples, item, active_class_list) in enumerate(self.ldr_train):
                        if k == 0:
                            active_class_list_client = []
                            negetive_class_list_client = []
                            for i in range(self.args.annotation_num):
                                active_class_list_client.append(active_class_list[i][0].item())
                            for i in range(self.args.n_classes):
                                if i not in active_class_list_client:
                                    negetive_class_list_client.append(i)
                        images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)

                        _, logits = student_net(images)
                        with torch.no_grad():
                            _, teacher_output = teacher_net(images)
                            soft_label = torch.sigmoid(teacher_output / 0.8)
                        logits_sig = torch.sigmoid(logits).cuda()
                        loss = criterion(logits_sig, labels, soft_label, weight_kd)

                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                        batch_loss.append(loss.item())
                        self.iter_num += 1
                    self.epoch = self.epoch + 1
                    epoch_loss.append(np.array(batch_loss).mean())
        student_net.cpu()
        self.optimizer.zero_grad()
        return student_net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client

    def train_CBAFed(self, rnd, net, pt=None, tao=None):
        """One round of local training following CBAFed.

        Stage 1 (warm-up, rnd < args.rounds_CBAFed_warmup): supervised BCE on the
        client's annotated classes only. Stage 2: additionally pseudo-labels the
        missing classes using per-class thresholds `tao` and re-weights the BCE
        by the observed noise ratio.

        Returns (state_dict, mean loss, _, _, negative classes, active classes,
        per-class effective counts, effective data count).
        """
        if rnd < self.args.rounds_CBAFed_warmup:  # stage1
            class_num_list = torch.zeros(self.args.n_classes)
            data_num = 0
            net.train()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []
            print(self.loss_w)
            bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                     reduction='none')  # include sigmoid
            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    if j == 0:
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                    images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                    # class_num_list = class_num_list + torch.sum(labels, dim=0)
                    data_num = data_num + len(labels)
                    _, logits = net(images)
                    loss_sup = bce_criterion_sup(logits, labels)  # tensor(32, 5)
                    # Average only over the annotated classes.
                    loss_sup = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num)  # supervised_loss
                    loss = loss_sup
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    batch_loss.append(loss.item())
                    self.iter_num += 1
                # Annotated classes are fully supervised, so their effective count is
                # all samples seen. NOTE(review): `id` shadows the builtin here.
                for id in active_class_list_client:
                    class_num_list[id] = data_num
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            return net.state_dict(), np.array(
                epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, class_num_list, data_num
        else:
            class_num_list = torch.zeros(self.args.n_classes)
            data_num = 0
            net.train()
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []

            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    idx_neg = []
                    if j == 0:
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                    images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                    _, logits = net(images)
                    prob = torch.sigmoid(logits)
                    for i in negetive_class_list_client:
                        # Confident-positive vs confident-negative counts under threshold tao[i].
                        noise_num = len(torch.where(prob[:, i] > tao[i])[0])
                        clean_num = len(torch.where(prob[:, i] < (1-tao[i]))[0])
                        # Flip confident positives to pseudo-label 1 for this missing class.
                        labels[:, i] = torch.where(prob[:, i] > tao[i], 1, labels[:, i])
                        # Samples confidently pseudo-labeled either way contribute to the loss.
                        pseudo_idx = torch.where((prob[:, i] > tao[i]) | (prob[:, i] < (1-tao[i])))[0]
                        idx_neg.append(pseudo_idx)
                        class_num_list[i] = class_num_list[i] + len(pseudo_idx)
                        data_num = data_num + len(pseudo_idx)
                        # Re-weight positives by (confident total / confident positives).
                        if noise_num == 0:
                            self.loss_w[i] = 1
                        else:
                            self.loss_w[i] = (noise_num+clean_num) / noise_num
                    print(self.loss_w)
                    for i in active_class_list_client:
                        class_num_list[i] = class_num_list[i] + len(labels)
                    data_num = data_num + len(labels)*self.args.annotation_num
                    # Criterion is rebuilt per batch because loss_w changes per batch.
                    bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                             reduction='none')  # include sigmoid
                    loss_sup = bce_criterion_sup(logits, labels)  # tensor(32, 5)
                    loss_sup_act = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num)  # supervised_loss
                    # loss_sup_act = loss_sup.sum() / (self.args.batch_size * self.args.n_classes)  # supervised_loss
                    loss = loss_sup_act
                    # Add pseudo-label loss for each missing class over its confident samples.
                    for k, i in enumerate(negetive_class_list_client):
                        if len(idx_neg[k]) != 0:
                            loss += loss_sup[idx_neg[k], i].sum() / len(idx_neg[k])  # supervised_loss
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    batch_loss.append(loss.item())
                    self.iter_num += 1
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            return net.state_dict(), np.array(
                epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, class_num_list, data_num

    def train_FedIRM(self, rnd, target_matrix, writer1, negetive_class_list, active_class_list_client_i, net):  # MICCAI2021
        """One round of local training following FedIRM (MICCAI 2021).

        Supervised phase (rnd < args.rounds_FedIRM_sup): weighted BCE on two
        augmentations; on the last supervised round also accumulates the local
        relation (confusion) matrix. Afterwards: inter-client relation matching —
        mean-teacher consistency plus symmetric-KL distillation between the local
        relation matrix (from confident pseudo-labels) and the aggregated
        `target_matrix`, on top of the supervised loss.
        """
        if rnd < self.args.rounds_FedIRM_sup:
            net.train()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []
            self.confuse_matrix = torch.zeros((8, 8)).cuda()
            print(self.loss_w)
            bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                     reduction='none')  # include sigmoid
            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    if j == 0:
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                    images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(
                        self.args.device), samples["target"].to(self.args.device)
                    _, logits1 = net(images1)
                    fe2, logits2 = net(images2)
                    if rnd == self.args.rounds_FedIRM_sup - 1:  # first relation matrix
                        self.confuse_matrix = self.confuse_matrix + self.get_confuse_matrix(logits1, labels)
                    loss_sup = bce_criterion_sup(logits1, labels) + bce_criterion_sup(logits2, labels)  # tensor(32, 5)
                    loss_sup = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num)  # supervised_loss
                    self.optimizer.zero_grad()
                    loss_sup.backward()
                    self.optimizer.step()
                    batch_loss.append(loss_sup.item())
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            if rnd == self.args.rounds_FedIRM_sup - 1:
                with torch.no_grad():
                    # Average the accumulated matrix over all (epoch, batch) accumulations.
                    self.confuse_matrix = self.confuse_matrix / (1.0 * self.args.local_ep * (j + 1))
                return net.state_dict(), np.array(
                    epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, self.confuse_matrix
            else:
                return net.state_dict(), np.array(
                    epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client
        else:  # Inter-client Relation Matching
            if self.flag:
                # First matching round: initialize the EMA teacher from the global model.
                self.ema_model.load_state_dict(net.state_dict())
                self.flag = False
                print('done')
            net.train()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []
            self.confuse_matrix = torch.zeros((8, 8)).cuda()
            print(self.loss_w)
            bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                     reduction='none')  # include sigmoid
            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    if j == 0:
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                    images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(
                        self.args.device), samples["target"].to(self.args.device)
                    _, outputs = net(images1)
                    with torch.no_grad():
                        _, ema_output = self.ema_model(images2)
                    with torch.no_grad():
                        # Entropy-style uncertainty over the multi-label sigmoid outputs.
                        preds = torch.sigmoid(outputs).cuda()
                        uncertainty = -1.0 * (torch.sum(preds * torch.log(preds + 1e-6), dim=1) + torch.sum(
                            (1 - preds) * torch.log(1 - preds + 1e-6), dim=1))
                        uncertainty_mask = (uncertainty < 2.0)
                    with torch.no_grad():
                        # Keep only rows where every class prob is outside (0.3, 0.7).
                        activations = torch.sigmoid(outputs).cuda()
                        confidence_mask = torch.zeros(len(uncertainty_mask), dtype=bool).cuda()
                        confidence_mask[self.find_rows(activations, 0.7, 0.3)] = True
                    mask = confidence_mask * uncertainty_mask
                    if mask.sum().item() != 0:
                        pseudo_labels = activations[mask] > 0.5
                        source_matrix = self.get_confuse_matrix(outputs[mask], pseudo_labels)
                    else:
                        # No confident samples: fall back to a uniform relation matrix.
                        source_matrix = 0.5*torch.ones((8, 8)).cuda()
                    target_matrix = target_matrix.cuda()
                    print(source_matrix)
                    print(target_matrix)
                    consistency_weight = self.get_current_consistency_weight(rnd)
                    consistency_dist = torch.sum(self.sigmoid_mse_loss(outputs, ema_output)) / self.args.batch_size
                    consistency_loss = consistency_dist
                    loss = consistency_weight * consistency_loss + consistency_weight * torch.sum(
                        self.kd_loss(source_matrix, target_matrix))
                    fe2, logits2 = net(images2)
                    self.confuse_matrix = self.confuse_matrix + self.get_confuse_matrix(outputs, labels)
                    loss_sup = bce_criterion_sup(outputs, labels) + bce_criterion_sup(logits2, labels)  # tensor(32, 5)
                    loss_sup = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num)  # supervised_loss
                    loss = loss + loss_sup
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self.update_ema_variables(net, self.ema_model, self.args.ema_decay, self.iter_num)
                    batch_loss.append(loss.item())
                    self.iter_num = self.iter_num + 1
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            with torch.no_grad():
                self.confuse_matrix = self.confuse_matrix / (1.0 * self.args.local_ep * (j + 1))
            return net.state_dict(), np.array(
                epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, self.confuse_matrix

    def train_RoFL(self, net, f_G, epoch):
        """One round of local training following RoFL (robust FL with feature centroids).

        Maintains two centroids per class (negative at index 2*cls, positive at
        2*cls+1). Round 0 initializes them from label-conditioned feature means;
        later rounds start from the global centroids f_G. Small-loss samples are
        selected each batch, pseudo-labels are corrected via centroid similarity,
        and the centroids are updated by similarity-weighted interpolation.

        Returns (state_dict, mean loss, updated local centroids f_k).
        """
        time1 = time.time()
        sim = torch.nn.CosineSimilarity(dim=1)
        pseudo_labels = torch.zeros(len(self.dataset), self.args.n_classes, dtype=torch.float64, device=self.args.device)
        optimizer = torch.optim.Adam(
            net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
        epoch_loss = []
        net.eval()
        f_k = torch.zeros(2*self.args.n_classes, self.args.feature_dim, device=self.args.device)
        n_labels = torch.zeros(2*self.args.n_classes, 1, device=self.args.device)

        # obtain global-guided pseudo labels y_hat by y_hat_k = C_G(F_G(x_k))
        # initialization of global centroids
        # obtain naive average feature
        with torch.no_grad():
            for i, (samples, item, active_class_list) in enumerate(self.ldr_train):
                if i == 0:
                    # NOTE(review): the inner loops below reuse `i`, clobbering the batch
                    # index — this only works because i never returns to 0 afterwards.
                    active_class_list_client = []
                    negetive_class_list_client = []
                    for i in range(self.args.annotation_num):
                        active_class_list_client.append(active_class_list[i][0].item())
                    for i in range(self.args.n_classes):
                        if i not in active_class_list_client:
                            negetive_class_list_client.append(i)
                images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                feature, logit = net(images)
                probs = torch.sigmoid(logit)  # soft predict
                accuracy_th = 0.5
                preds = probs > accuracy_th  # hard predict
                preds = preds.to(torch.float64)
                pseudo_labels[item] = preds
                if epoch == 0:
                    for cls in range(self.args.n_classes):
                        f_k[2*cls] += torch.sum(feature[torch.where(labels[:, cls] == 0)[0]], dim=0)
                        f_k[2*cls+1] += torch.sum(feature[torch.where(labels[:, cls] == 1)[0]], dim=0)
                        n_labels[2*cls] += len(torch.where(labels[:, cls] == 0)[0])
                        n_labels[2*cls+1] += len(torch.where(labels[:, cls] == 1)[0])

        if epoch == 0:
            # Avoid division by zero for centroids with no samples.
            for i in range(len(n_labels)):
                if n_labels[i] == 0:
                    n_labels[i] = 1
            f_k = torch.div(f_k, n_labels)
        else:
            f_k = f_G
        time2 = time.time()
        # print('local test time: ', time2-time1)
        net.train()
        for iter in range(self.args.local_ep):
            batch_loss = []
            for samples, item, active_class_list in self.ldr_train:
                time4 = time.time()
                net.zero_grad()
                images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                feature, logit = net(images)
                feature = feature.detach()
                f_k = f_k.to(self.args.device)

                small_loss_idxs, loss_w = self.get_small_loss_samples(logit, labels, self.args.forget_rate, negetive_class_list_client)

                # y_k_tilde: centroid-based label guess; mask=1 where it matches the given label.
                y_k_tilde = torch.zeros(self.args.batch_size, self.args.n_classes, device=self.args.device)
                mask = torch.zeros(self.args.batch_size, device=self.args.device)
                for i in small_loss_idxs:
                    for cls in range(self.args.n_classes):
                        f_cls = f_k[[2*cls, 2*cls+1], :]
                        y_k_tilde[i, cls] = torch.argmax(sim(f_cls, torch.reshape(feature[i], (1, self.args.feature_dim))))
                    if torch.equal(y_k_tilde[i], labels[i]):
                        mask[i] = 1

                # When to use pseudo-labels
                if epoch < self.args.T_pl:
                    for i in small_loss_idxs:
                        pseudo_labels[item[i]] = labels[i]

                # For loss calculating
                # NOTE(review): repeat([1, 5]) hard-codes 5 classes; presumably should
                # be self.args.n_classes — TODO confirm.
                mask_resize = mask.unsqueeze(1).repeat([1, 5])
                new_labels = mask_resize[small_loss_idxs] * labels[small_loss_idxs] + (1 - mask_resize[small_loss_idxs]) * \
                             pseudo_labels[item[small_loss_idxs]]
                new_labels = new_labels.type(torch.float).to(self.args.device)
                time5 = time.time()
                # print('batch train prepare time: ', time5-time4)
                loss = self.RFLloss(logit, labels, feature, f_k, mask, small_loss_idxs, new_labels, loss_w, epoch)
                # weight update by minimizing loss: L_total = L_c + lambda_cen * L_cen + lambda_e * L_e
                loss.backward()
                optimizer.step()

                # obtain loss based average features f_k,j_hat from small loss dataset
                f_kj_hat = torch.zeros(2*self.args.n_classes, self.args.feature_dim, device=self.args.device)
                n = torch.zeros(2*self.args.n_classes, 1, device=self.args.device)
                for i in small_loss_idxs:
                    for cls in range(self.args.n_classes):
                        if labels[i, cls] == 0:
                            f_kj_hat[2 * cls] += feature[i]
                            n[2 * cls] += 1
                        else:
                            f_kj_hat[2 * cls + 1] += feature[i]
                            n[2 * cls + 1] += 1
                for i in range(len(n)):
                    if n[i] == 0:
                        n[i] = 1
                f_kj_hat = torch.div(f_kj_hat, n)

                # update local centroid f_k
                one = torch.ones(2*self.args.n_classes, 1, device=self.args.device)
                f_k = (one - sim(f_k, f_kj_hat).reshape(2*self.args.n_classes, 1) ** 2) * f_k + (
                        sim(f_k, f_kj_hat).reshape(2*self.args.n_classes, 1) ** 2) * f_kj_hat

                batch_loss.append(loss.item())
                time6 = time.time()
                # print('batch train time: ', time6 - time5)
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        time3 = time.time()
        # print('local train time: ', time3 - time2)
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss), f_k

    def RFLloss(self, logit, labels, feature, f_k, mask, small_loss_idxs, new_labels, loss_w, epoch):
        """RoFL total loss: L_c (weighted BCE on corrected labels) + lambda_cen * L_cen
        (centroid MSE, only for mask==1 samples) + lambda_e * L_e (per-class entropy).
        lambda_cen is ramped linearly until epoch reaches args.T_pl."""
        mse = torch.nn.MSELoss(reduction='none')
        bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_w).cuda())  # include sigmoid
        L_c = bce(logit[small_loss_idxs], new_labels)
        prob = torch.sigmoid(logit).cuda()
        for i in range(self.args.n_classes):
            # f_k[2*i + label] selects the negative/positive centroid per sample.
            if i == 0:
                L_cen = torch.sum(
                    mask[small_loss_idxs] * torch.sum(mse(feature[small_loss_idxs], f_k[(2*i+labels[small_loss_idxs, i]).cpu().numpy()]), 1))/(len(small_loss_idxs)*self.args.feature_dim)
            else:
                L_cen += torch.sum(
                    mask[small_loss_idxs] * torch.sum(mse(feature[small_loss_idxs], f_k[(2*i+labels[small_loss_idxs, i]).cpu().numpy()]), 1))/(len(small_loss_idxs)*self.args.feature_dim)
        L_cen = L_cen / self.args.n_classes
        for i in range(self.args.n_classes):
            # Binary entropy per class from (p, 1-p) pairs.
            clsi_prob = prob[:, i].unsqueeze(1)
            clsi_prob = torch.cat((clsi_prob, 1-clsi_prob), dim=1).cuda()
            if i == 0:
                L_e = -torch.mean(torch.sum(clsi_prob[small_loss_idxs] * torch.log(clsi_prob[small_loss_idxs]), dim=1))
            else:
                L_e += -torch.mean(torch.sum(clsi_prob[small_loss_idxs] * torch.log(clsi_prob[small_loss_idxs]), dim=1))
        L_e = L_e / self.args.n_classes
        lambda_e = self.args.lambda_e
        lambda_cen = self.args.lambda_cen
        if epoch < self.args.T_pl:
            lambda_cen = (self.args.lambda_cen * epoch) / self.args.T_pl
        # Debug dump if any component went NaN.
        if math.isnan(L_c.item()) or math.isnan(L_cen.item()) or math.isnan(L_e.item()):
            print(logit)
            print(feature)
            print(loss_w)
            print('loss: ', L_c + (lambda_cen * L_cen) + (lambda_e * L_e))
        return L_c + (lambda_cen * L_cen) + (lambda_e * L_e)

    def get_small_loss_samples(self, y_pred, y_true, forget_rate, negetive_class_list_client):
        """Select the (1 - forget_rate) fraction of batch samples with the smallest
        weighted BCE loss. Returns (selected indices sorted by loss, weights used).

        NOTE(review): `loss_w = self.loss_w` aliases the instance list, so setting
        loss_w[i] = 5. for missing classes permanently mutates self.loss_w — likely
        unintended; a copy would avoid the side effect. TODO confirm.
        """
        loss_w = self.loss_w
        for i in negetive_class_list_client:
            loss_w[i] = 5.
        loss_func = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_w).cuda(), reduction='none')  # include sigmoid
        loss = torch.sum(loss_func(y_pred, y_true), dim=1)
        ind_sorted = np.argsort(loss.data.cpu()).cuda()
        loss_sorted = loss[ind_sorted]
        remember_rate = 1 - forget_rate
        num_remember = int(remember_rate * len(loss_sorted))
        ind_update = ind_sorted[:num_remember]
        return ind_update, loss_w

    def train(self, rnd, net, writer1):
        """Plain local training baseline: weighted BCE averaged over *all* classes
        (annotated and missing alike). Returns (state_dict, mean epoch loss, _, _,
        negative class list, active class list).

        NOTE(review): active/negative class lists are appended on the k==0 batch of
        *every* epoch without being reset, so with local_ep > 1 they accumulate
        duplicates — harmless for membership tests but wrong as returned lists.
        """
        # teacher_neg = deepcopy(net).to(self.args.device)  # try
        assert len(self.ldr_train.dataset) == len(self.idxs)
        print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")

        net.train()
        # set the optimizer
        self.optimizer = torch.optim.Adam(
            net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)

        # train and update
        epoch_loss = []
        print(self.loss_w)
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(), reduction='none')  # include sigmoid
        active_class_list_client = []
        negetive_class_list_client = []
        for epoch in range(self.args.local_ep):
            print('local_epoch:', epoch)
            batch_loss = []
            for k, (samples, item, active_class_list) in enumerate(self.ldr_train):
                images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                if k == 0:
                    for i in range(self.args.annotation_num):
                        active_class_list_client.append(active_class_list[i][0].item())
                    for i in range(self.args.n_classes):
                        if i not in active_class_list_client:
                            negetive_class_list_client.append(i)
                _, logits = net(images)

                loss = bce_criterion(logits, labels)  # tensor(32, 5)
                loss = loss.sum()/(self.args.batch_size * self.args.n_classes)  # all_class_loss
                # (commented-out variants kept in history: active-class-only loss,
                # masked loss, and a teacher-consistency term on missing classes)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                batch_loss.append(loss.item())

                self.iter_num += 1
                # (commented-out per-round FN/TN loss-distribution plotting and
                # per-class test-metric logging removed for brevity)

            self.epoch = self.epoch + 1
            epoch_loss.append(np.array(batch_loss).mean())
        net.cpu()
        self.optimizer.zero_grad()
        return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client

    def train_RSCFed(self, rnd, net):
        """One round of local training following RSCFed: a student copy of the global
        model is trained with supervised BCE on annotated classes plus an MSE
        consistency term against a persistent EMA teacher (self.teacher_neg) on the
        missing classes; the teacher is EMA-updated (decay 0.999) every batch."""
        assert len(self.ldr_train.dataset) == len(self.idxs)
        print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
        self.student = deepcopy(net).to(self.args.device)
        self.teacher_neg.eval()
        self.student.train()

        # set the optimizer
        self.optimizer = torch.optim.Adam(
            self.student.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)

        # train and update
        epoch_loss = []
        print(self.loss_w)
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(), reduction='none')  # include sigmoid
        mse_loss = nn.MSELoss()

        for epoch in range(self.args.local_ep):
            print('local_epoch:', epoch)
            batch_loss = []
            for samples, item, active_class_list in self.ldr_train:
                # Rebuilt every batch (cheap; values are constant per client).
                active_class_list_client = []
                negetive_class_list_client = []
                for i in range(self.args.annotation_num):
                    active_class_list_client.append(active_class_list[i][0].item())
                for i in range(self.args.n_classes):
                    if i not in active_class_list_client:
                        negetive_class_list_client.append(i)
                images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(self.args.device), samples["target"].to(self.args.device)
                _, logits1_stu = self.student(images1)
                logits1_stu_sig = torch.sigmoid(logits1_stu).cuda()

                with torch.no_grad():
                    _, logits2_tea = self.teacher_neg(images2)
                    # _, logits2_tea = self.teacher_neg(images1)
                    logits2_tea_sig = torch.sigmoid(logits2_tea).cuda()
                loss = bce_criterion(logits1_stu, labels)  # tensor(32, 5)
                loss_sup = loss[:, active_class_list_client].sum()/(self.args.batch_size * self.args.annotation_num)  # supervised_loss
                loss_unsup = mse_loss(logits1_stu_sig[:, negetive_class_list_client], logits2_tea_sig[:, negetive_class_list_client])
                loss = loss_sup + loss_unsup
                # loss = loss_sup

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # update teacher: parameter-space EMA of the student (decay 0.999).
                state_dict1 = self.teacher_neg.state_dict()
                state_dict2 = deepcopy(self.student).state_dict()
                weight1 = 1 - 0.001
                weight2 = 0.001
                weighted_state_dict = {}
                for name in state_dict1:
                    weighted_state_dict[name] = weight1 * state_dict1[name] + weight2 * state_dict2[name]
                self.teacher_neg.load_state_dict(weighted_state_dict)
                self.teacher_neg = self.teacher_neg.to(self.args.device)
                batch_loss.append(loss.item())
                self.iter_num += 1

            self.epoch = self.epoch + 1
            epoch_loss.append(np.array(batch_loss).mean())

        net.cpu()
        self.optimizer.zero_grad()
        return self.student.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client

    def train_FixMatch(self, rnd, net):
        """One round of local training following FixMatch, adapted to multi-label
        missing classes: supervised BCE on annotated classes (weak aug), plus an
        unsupervised BCE on missing classes using hard pseudo-labels from the weak
        view applied to the strong view, restricted to samples whose weak-view
        probability is confident (> 0.8 or < 0.2) for *every* missing class."""
        assert len(self.ldr_train.dataset) == len(self.idxs)
        print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
        net.train()

        # set the optimizer
        self.optimizer = torch.optim.Adam(
            net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)

        # train and update
        epoch_loss = []
        print(self.loss_w)
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(), reduction='none')  # include sigmoid
        bce_criterion_unsup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w_unknown).cuda(), reduction='none')  # include sigmoid

        for epoch in range(self.args.local_ep):
            print('local_epoch:', epoch)
            batch_loss = []
            for samples, item, active_class_list in self.ldr_train:
                active_class_list_client = []
                negetive_class_list_client = []
                for i in range(self.args.annotation_num):
                    active_class_list_client.append(active_class_list[i][0].item())
                for i in range(self.args.n_classes):
                    if i not in active_class_list_client:
                        negetive_class_list_client.append(i)
                images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(self.args.device), samples["target"].to(self.args.device)
                _, logits_weak = net(images1)
                logits_weak_sig = torch.sigmoid(logits_weak).cuda()
                # idx: samples confident on every missing class (prob > 0.8 or < 0.2).
                idx = set(range(self.args.batch_size))
                for c in negetive_class_list_client:
                    idx = idx.intersection(set(torch.where(logits_weak_sig[:, c] > 0.8)[0].tolist()).union(set(torch.where(logits_weak_sig[:, c] < 0.2)[0].tolist())))
                idx = list(idx)
                logits_weak_sig = torch.where(logits_weak_sig > 0.5, 1.0, 0.0)  # turn to hard label
                _, logits_strong = net(images2)
                loss = bce_criterion(logits_weak, labels)  # tensor(32, 5)
                loss_sup = loss[:, active_class_list_client].sum()/(self.args.batch_size * self.args.annotation_num)  # supervised_loss
                if len(idx) == 0 or len(negetive_class_list_client) == 0:
                    loss = loss_sup
                else:
                    loss = bce_criterion_unsup(logits_strong, logits_weak_sig)
                    loss_unsup = loss[idx, :][:, negetive_class_list_client].sum()/(len(idx) * (self.args.n_classes - self.args.annotation_num))  # supervised_loss
                    # print('loss_sup: ', loss_sup)
                    # print('loss_unsup: ', loss_unsup)
                    loss = loss_sup + 1. * loss_unsup
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                batch_loss.append(loss.item())
                self.iter_num += 1
            self.epoch = self.epoch + 1
            epoch_loss.append(np.array(batch_loss).mean())
        net.cpu()
        self.optimizer.zero_grad()
        return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client

    def mixup_criterion(self, y_a, y_b, lam):
        """Return a callable computing the mixup loss: lam * crit(pred, y_a) + (1-lam) * crit(pred, y_b)
        (transposes allow lam to broadcast per-sample)."""
        return lambda criterion, pred: (lam * criterion(pred, y_a).T).T + ((1 - lam) * criterion(pred, y_b).T).T

    def test_loss(self, rnd, net, model_name):
        """Diagnostic pass (no training): for each missing class, separate the
        per-sample BCE losses of false negatives (samples actually positive per
        self.class_neg_idx) from true negatives, t-SNE-visualize the features with
        FN/TN flags, and return the mean FN / TN losses per missing class.

        NOTE(review): indexing buffers with count*len(item)+i assumes all batches
        have equal size; the last partial batch would misalign — TODO confirm the
        loader drops the last batch.
        """
        net.eval()
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                             reduction='none')  # include sigmoid
        epoch_false_negetive_loss = []
        epoch_true_negetive_loss = []
        loss_false_negetive = []
        loss_true_negetive = []
        for i in range(self.args.n_classes - self.args.annotation_num):
            loss_false_negetive.append([])
            loss_true_negetive.append([])

        active_class_list_client = []
        feature = torch.tensor([]).cuda()
        flags = torch.zeros([self.args.n_classes - self.args.annotation_num, len(self.ldr_train.dataset)]).cuda()  # 0:FN, 1:TN
        flag_class_one = torch.zeros([1, len(self.ldr_train.dataset)]).cuda()
        count = 0
        with torch.no_grad():
            for samples, item, active_class_list in self.ldr_train:
                images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                # Separate flag tracking whether label for class 1 is negative.
                for i in range(len(item)):
                    flag_class_one[0, count*len(item)+i] = int(labels[i, 1] == 0)
                for i in range(self.args.annotation_num):
                    active_class_list_client.append(active_class_list[i][0].item())
                negetive_class_list_client = []
                feature_batch, logits = net(images)
                feature = torch.cat((feature, feature_batch), dim=0)

                loss = bce_criterion(logits, labels)  # tensor(32, 5)
                class_miss_loss = []
                for i in range(self.args.n_classes):
                    if i not in active_class_list_client:
                        negetive_class_list_client.append(i)
                        class_miss_loss.append(loss[:, i].clone().detach().cpu())

                for i in range(len(item)):
                    for j in range(len(negetive_class_list_client)):
                        if item[i].item() in self.class_neg_idx[negetive_class_list_client[j]]:
                            # Actually positive but labeled negative -> false negative.
                            loss_false_negetive[j].append(class_miss_loss[j][i].item())
                            flags[j, i+count*len(item)] = 0

                        elif item[i].item() in self.class_pos_idx[negetive_class_list_client[j]]:
                            # Known positives for the missing class are skipped.
                            pass
                        else:
                            loss_true_negetive[j].append(class_miss_loss[j][i].item())
                            flags[j, i+count*len(item)] = 1
                count += 1

        # (commented-out FN/TN kde loss-distribution plotting removed for brevity)

        for j in range(len(negetive_class_list_client)):
            tnse_Visual(feature.cpu(), flags[j].cpu(), rnd, f'model{model_name} class{negetive_class_list_client[j]} p=1')

        tnse_Visual(feature.cpu(), flag_class_one[0].cpu(), rnd, f'model{model_name} class{1} p=1')

        for j in range(len(negetive_class_list_client)):
            epoch_false_negetive_loss.append(np.array(loss_false_negetive[j]).mean())
            epoch_true_negetive_loss.append(np.array(loss_true_negetive[j]).mean())
        net.cpu()
        return epoch_false_negetive_loss, epoch_true_negetive_loss

    def find_indices_in_a(self, a, b):
| return torch.where(a.unsqueeze(0) == b.unsqueeze(1))[1] 903 | 904 | def train_FedMLP(self, rnd, tao, Prototype, writer1, negetive_class_list, active_class_list_client_i, net): # my method 905 | # assert len(self.ldr_train.dataset) == len(self.idxs) 906 | # print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}") 907 | if rnd < self.args.rounds_FedMLP_stage1: # stage1 908 | glob_model = deepcopy(net) 909 | net.train() 910 | glob_model.eval() 911 | # set the optimizer 912 | self.optimizer = torch.optim.Adam( 913 | net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4) 914 | # train and update 915 | epoch_loss = [] 916 | print(self.loss_w) 917 | bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(), 918 | reduction='none') # include sigmoid 919 | bce_criterion_unsup = nn.MSELoss() 920 | for epoch in range(self.args.local_ep): 921 | print('local_epoch:', epoch) 922 | batch_loss = [] 923 | for j, (samples, item, active_class_list) in enumerate(self.ldr_train): 924 | if j == 0: 925 | active_class_list_client = [] 926 | negetive_class_list_client = [] 927 | for i in range(self.args.annotation_num): 928 | active_class_list_client.append(active_class_list[i][0].item()) 929 | for i in range(self.args.n_classes): 930 | if i not in active_class_list_client: 931 | negetive_class_list_client.append(i) 932 | self.class_num_list[i] = 0 # try noro 933 | criterion = LogitAdjust_Multilabel(cls_num_list=self.class_num_list, num=len(self.idxs)) 934 | mse_loss = nn.MSELoss(reduction='none') 935 | images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to( 936 | self.args.device), samples["target"].to(self.args.device) 937 | fe1, logits1 = net(images1) 938 | logits1_sig = torch.sigmoid(logits1).cuda() 939 | _, logits2 = net(images2) 940 | logits2_sig = torch.sigmoid(logits2).cuda() 941 | # loss_sup1 = bce_criterion_sup(logits1, labels) # tensor(32, 5) 942 | # loss_sup2 = 
bce_criterion_sup(logits2, labels) # tensor(32, 5) 943 | with torch.no_grad(): 944 | _, outputs_global = glob_model(images1) 945 | logits3 = torch.sigmoid(outputs_global).cuda() 946 | _, outputs_global = glob_model(images2) 947 | logits4 = torch.sigmoid(outputs_global).cuda() 948 | loss_dis1 = mse_loss(logits1_sig, logits3).cuda() 949 | loss_dis2 = mse_loss(logits2_sig, logits4).cuda() 950 | loss_dis = (loss_dis1 + loss_dis2) / 2. 951 | loss_sup1 = criterion(logits1_sig, labels) # tensor(32, 5) 952 | loss_sup2 = criterion(logits2_sig, labels) # tensor(32, 5) 953 | loss_sup = (loss_sup1 + loss_sup2) / 2. 954 | # loss_sup = loss_sup.sum() / (self.args.batch_size * self.args.n_classes) # supervised_loss 955 | 956 | loss_sup = loss_sup[:, active_class_list_client].sum() / ( 957 | self.args.batch_size * self.args.annotation_num) # supervised_loss 958 | loss_dis = loss_dis[:, negetive_class_list_client].sum() / ( 959 | self.args.batch_size * len(negetive_class_list_client)) # supervised_loss 960 | 961 | loss_unsup = bce_criterion_unsup(logits1_sig[:, negetive_class_list_client], 962 | logits2_sig[:, negetive_class_list_client]) 963 | loss = loss_sup + 0.0*loss_unsup + loss_dis 964 | self.optimizer.zero_grad() 965 | loss.backward() 966 | self.optimizer.step() 967 | batch_loss.append(loss.item()) 968 | self.iter_num += 1 969 | self.epoch = self.epoch + 1 970 | epoch_loss.append(np.array(batch_loss).mean()) 971 | if rnd == self.args.rounds_FedMLP_stage1 - 1: # first tao and proto 972 | print('client: ', self.client_id, active_class_list_client) 973 | proto = torch.zeros((self.args.n_classes * 2, len(fe1[0]))) # [cls0proto0, cls0proto1, cls1proto0...] 
974 | # proto = np.array([torch.zeros_like(fe1[0].cpu())] * self.args.n_classes * 2) 975 | num_proto = [0] * self.args.n_classes * 2 976 | t = np.array([0] * self.args.n_classes) 977 | test_loader = DataLoader(dataset=self.local_dataset, batch_size=self.args.batch_size * 4, shuffle=False, 978 | num_workers=8) 979 | net.eval() 980 | with torch.no_grad(): 981 | for samples, _, _ in test_loader: 982 | images1, labels = samples["image_aug_1"].to(self.args.device), samples["target"].to(self.args.device) 983 | feature, outputs = net(images1) 984 | probs = torch.sigmoid(outputs) # soft predict 985 | for cls in active_class_list_client: 986 | idx0 = torch.where(labels[:, cls] == 0)[0].tolist() 987 | idx1 = torch.where(labels[:, cls] == 1)[0].tolist() 988 | num_proto[2*cls] += len(idx0) 989 | num_proto[2*cls+1] += len(idx1) 990 | proto[2*cls] = feature[idx0, :].sum(0).cpu() + proto[2*cls] 991 | proto[2*cls+1] = feature[idx1, :].sum(0).cpu() + proto[2*cls+1] 992 | # t[cls] += torch.sum(probs[idx0, cls] < self.args.L).item() + torch.sum( 993 | # probs[idx1, cls] > self.args.U).item() 994 | for cls in negetive_class_list: 995 | t[cls] += torch.sum( 996 | torch.logical_or(probs[:, cls] < self.args.L, probs[:, cls] > self.args.U)).item() 997 | for cls in active_class_list_client: 998 | proto[2*cls] = proto[2*cls] / num_proto[2*cls] 999 | proto[2*cls+1] = proto[2*cls+1] / num_proto[2*cls+1] 1000 | t = t / len(self.local_dataset) 1001 | print('local_t: ', t) 1002 | return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, t, proto 1003 | else: 1004 | return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client 1005 | 1006 | else: # stage2 1007 | print(self.local_dataset.active_class_list) 1008 | # find train samples for each class 1009 | idx = [] # [[negcls1], [negcls2]] 1010 | idxss = [] 1011 | feature = [] 1012 | similarity = [] 1013 | clean_idx = [] 1014 | noise_idx = [] 
1015 | label = [] # [[negcls1], [negcls2]] 1016 | num_train = 0 1017 | glob_model = deepcopy(net) 1018 | net.eval() 1019 | class_idx = torch.tensor([]) 1020 | l = torch.tensor([]).cuda() 1021 | f = torch.tensor([]).cuda() 1022 | t1 = time.time() 1023 | if rnd == self.args.rounds_FedMLP_stage1: 1024 | # print(len(self.ldr_train.dataset)) 1025 | self.traindata_idx = [] # [[negcls1_clean_train_idx], [negcls1_noise_train_idx], [negcls2_clean_train_idx], [negcls2_noise_train_idx]] idx 1026 | for samples, item, active_class_list in self.ldr_train: 1027 | class_idx = torch.cat((class_idx, item), dim=0) 1028 | images1, labels = samples["image_aug_1"].to(self.args.device), samples["target"].to(self.args.device) 1029 | with torch.no_grad(): 1030 | features, _ = net(images1) 1031 | f = torch.cat((f, features), dim=0) 1032 | l = torch.cat((l, labels), dim=0) 1033 | for i in range(len(negetive_class_list)): 1034 | feature.append(f) # miss n classes 1035 | idx.append(class_idx) 1036 | label.append(l) 1037 | else: 1038 | for samples, item, active_class_list in self.ldr_train: 1039 | class_idx = torch.cat((class_idx, item), dim=0) 1040 | images1, labels = samples["image_aug_1"].to(self.args.device), samples["target"].to(self.args.device) 1041 | with torch.no_grad(): 1042 | features, _ = net(images1) 1043 | f = torch.cat((f, features), dim=0) 1044 | l = torch.cat((l, labels), dim=0) 1045 | for i in range(len(self.idxss)): 1046 | result_indices = self.find_indices_in_a(class_idx, torch.tensor(self.idxss[i])) 1047 | feature.append(f[result_indices]) 1048 | idx.append(class_idx[result_indices]) 1049 | label.append(l[result_indices]) 1050 | t2 = time.time() 1051 | # print('feature_label_prepare_time: ', t2-t1) 1052 | for i, cls in enumerate(negetive_class_list): 1053 | # sim = [] 1054 | proto_0 = Prototype[2*cls] 1055 | proto_1 = Prototype[2*cls+1] 1056 | model = CosineSimilarityFast().cuda() 1057 | sim = (model(feature[i], torch.unsqueeze(proto_0.cuda(), dim=0)) - model(feature[i], 
torch.unsqueeze(proto_1.cuda(), dim=0))).tolist() 1058 | similarity.append(sim) 1059 | t3 = time.time() 1060 | # print('sim_compute_time: ', t3 - t2) 1061 | for i in range(len(negetive_class_list)): 1062 | idx_0 = np.where(np.array(similarity[i]) >= 0)[0] 1063 | idx_1 = np.where(np.array(similarity[i]) < 0)[0] 1064 | clean_idx.append(idx_0.tolist()) 1065 | noise_idx.append(idx_1.tolist()) 1066 | if rnd == self.args.rounds_FedMLP_stage1: 1067 | for i, cls in enumerate(negetive_class_list): 1068 | print('cls', cls, 'tao: ', tao[cls]) 1069 | num_clean_cls = int(1 * self.args.clean_threshold * len(clean_idx[i])) 1070 | num_noise_cls = int(1 * self.args.noise_threshold * len(noise_idx[i])) 1071 | # num_clean_cls = int(tao[cls] * len(clean_idx[i])) 1072 | # num_noise_cls = int(tao[cls] * len(noise_idx[i])) 1073 | num_train = num_train + num_noise_cls + num_clean_cls 1074 | max_m_indices_list = np.array(max_m_indices(similarity[i], num_clean_cls)) 1075 | min_n_indices_list = np.array(min_n_indices(similarity[i], num_noise_cls)) 1076 | if len(max_m_indices_list) == 0 and len(max_m_indices_list) == 0: 1077 | negcls_clean_train_idx = [] 1078 | negcls_noise_train_idx = [] 1079 | elif len(min_n_indices_list) == 0 and len(max_m_indices_list) != 0: 1080 | negcls_noise_train_idx = [] 1081 | negcls_clean_train_idx = np.array(idx[i])[max_m_indices_list].tolist() 1082 | elif len(min_n_indices_list) != 0 and len(max_m_indices_list) == 0: 1083 | negcls_noise_train_idx = np.array(idx[i])[min_n_indices_list].tolist() 1084 | negcls_clean_train_idx = [] 1085 | else: 1086 | negcls_clean_train_idx = np.array(idx[i])[max_m_indices_list].tolist() 1087 | negcls_noise_train_idx = np.array(idx[i])[min_n_indices_list].tolist() 1088 | self.traindata_idx.append(negcls_clean_train_idx) 1089 | self.traindata_idx.append(negcls_noise_train_idx) 1090 | else: 1091 | for i, cls in enumerate(negetive_class_list): 1092 | print('cls', cls, 'tao: ', tao[cls]) 1093 | num_clean_cls = int(1 * 
self.args.clean_threshold * len(clean_idx[i])) 1094 | num_noise_cls = int(1 * self.args.noise_threshold * len(noise_idx[i])) 1095 | num_train = num_train + num_noise_cls + num_clean_cls 1096 | max_m_indices_list = np.array(max_m_indices(similarity[i], num_clean_cls)) 1097 | min_n_indices_list = np.array(min_n_indices(similarity[i], num_noise_cls)) 1098 | 1099 | if len(max_m_indices_list) == 0 and len(max_m_indices_list) == 0: 1100 | negcls_clean_train_idx = [] 1101 | negcls_noise_train_idx = [] 1102 | elif len(min_n_indices_list) == 0 and len(max_m_indices_list) != 0: 1103 | negcls_noise_train_idx = [] 1104 | negcls_clean_train_idx = np.array(idx[i])[max_m_indices_list].tolist() 1105 | elif len(min_n_indices_list) != 0 and len(max_m_indices_list) == 0: 1106 | negcls_noise_train_idx = np.array(idx[i])[min_n_indices_list].tolist() 1107 | negcls_clean_train_idx = [] 1108 | else: 1109 | negcls_clean_train_idx = np.array(idx[i])[max_m_indices_list].tolist() 1110 | negcls_noise_train_idx = np.array(idx[i])[min_n_indices_list].tolist() 1111 | self.traindata_idx[2*i].extend(negcls_clean_train_idx) 1112 | self.traindata_idx[2*i+1].extend(negcls_noise_train_idx) 1113 | 1114 | t4 = time.time() 1115 | # print('traindata_split_time: ', t4 - t3) 1116 | 1117 | for i, cls in enumerate(negetive_class_list): 1118 | print('class: ', cls, 'clean_train_samples: ', len(self.traindata_idx[2*i])) 1119 | print('class: ', cls, 'noise_train_samples: ', len(self.traindata_idx[2 * i+1])) 1120 | self.class_num_list[cls] = len(self.traindata_idx[2 * i+1]) 1121 | # if rnd % 10 == 0: 1122 | # real_clean = 0 1123 | # real_noise = 0 1124 | # # print(self.dataset[self.traindata_idx[2*i]]['target'][cls]) 1125 | # # print(type(self.dataset[self.traindata_idx[2*i]]['target'][cls])) 1126 | # # input() 1127 | # for j in self.traindata_idx[2*i]: 1128 | # if self.dataset[j]['target'][cls] == 0: 1129 | # real_clean += 1 1130 | # print('clean_acc: ', real_clean/len(self.traindata_idx[2*i])) 1131 | # for j in 
self.traindata_idx[2*i+1]: 1132 | # if self.dataset[j]['target'][cls] == 1: 1133 | # real_noise += 1 1134 | # if len(self.traindata_idx[2 * i+1]) == 0: 1135 | # print('noise_acc: ', 0) 1136 | # writer1.add_scalar(f'noise_acc/client{self.client_id}/class{cls}', 1137 | # 0, rnd) 1138 | # else: 1139 | # print('noise_acc: ', real_noise/len(self.traindata_idx[2*i+1])) 1140 | # writer1.add_scalar(f'noise_acc/client{self.client_id}/class{cls}', 1141 | # real_noise / len(self.traindata_idx[2 * i + 1]), rnd) 1142 | # writer1.add_scalar(f'clean_acc/client{self.client_id}/class{cls}', real_clean/len(self.traindata_idx[2*i]), rnd) 1143 | t5 = time.time() 1144 | # print('acc_compute_time: ', t5 - t4) 1145 | # train 1146 | net.train() 1147 | glob_model.eval() 1148 | # set the optimizer 1149 | self.optimizer = torch.optim.Adam( 1150 | net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4) 1151 | # train and update 1152 | epoch_loss = [] 1153 | loss_w = self.loss_w 1154 | for i, cls in enumerate(negetive_class_list): 1155 | if len(self.traindata_idx[2*i+1]) != 0: 1156 | loss_w[cls] = len(self.traindata_idx[2*i]) / len(self.traindata_idx[2*i+1]) 1157 | else: 1158 | loss_w[cls] = 5.0 1159 | print(loss_w) 1160 | bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_w).cuda(), 1161 | reduction='none') # include sigmoid 1162 | mse_loss = nn.MSELoss(reduction='none') 1163 | 1164 | for epoch in range(self.args.local_ep): 1165 | print('local_epoch:', epoch) 1166 | batch_loss = [] 1167 | dataset = DatasetSplit_pseudo(self.dataset, self.idxs, self.client_id, self.args, 1168 | active_class_list_client_i, negetive_class_list, self.traindata_idx) 1169 | dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=True, 1170 | num_workers=8) 1171 | for samples, item, distill_cls in dataloader: 1172 | distill_cls = distill_cls.cuda() 1173 | sup_cls = (~distill_cls.bool()).float().cuda() 1174 | criterion = 
LogitAdjust_Multilabel(cls_num_list=self.class_num_list, 1175 | num=len(self.idxs)) 1176 | images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(self.args.device), samples["target"].to( 1177 | self.args.device) 1178 | feature, outputs = net(images1) 1179 | logits1 = torch.sigmoid(outputs).cuda() 1180 | with torch.no_grad(): 1181 | _, outputs_global = glob_model(images1) 1182 | logits2 = torch.sigmoid(outputs_global).cuda() 1183 | # loss_sup = bce_criterion_sup(outputs, labels).cuda() 1184 | loss_sup = criterion(logits1, labels).cuda() 1185 | loss_dis = mse_loss(logits1, logits2).cuda() 1186 | 1187 | # loss = ((loss_sup * sup_cls).sum() + (loss_dis * distill_cls).sum()) / (sup_cls.sum() + distill_cls.sum()) 1188 | loss = (loss_sup * sup_cls).sum() / sup_cls.sum() 1189 | # print(loss) 1190 | self.optimizer.zero_grad() 1191 | loss.backward() 1192 | self.optimizer.step() 1193 | batch_loss.append(loss.item()) 1194 | self.iter_num += 1 1195 | self.epoch = self.epoch + 1 1196 | epoch_loss.append(np.array(batch_loss).mean()) 1197 | for i in range(len(self.traindata_idx) // 2): 1198 | idxs = self.traindata_idx[2 * i] + self.traindata_idx[2 * i + 1] 1199 | idxss.append(idxs) 1200 | t6 = time.time() 1201 | # print('local_train_time: ', t6 - t5) 1202 | self.idxss = idxss 1203 | for i in range(len(idxss)): 1204 | self.idxss[i] = list(set(self.idxs)-set(idxss[i])) # [[negcls1_else_idx], [negcls2_else_idx]] 1205 | 1206 | # proto = np.array( 1207 | # [torch.zeros_like(feature[0].cpu())] * self.args.n_classes * 2) # [cls0proto0, cls0proto1, cls1proto0...] 1208 | proto = torch.zeros((self.args.n_classes * 2, len(feature[0]))) # [cls0proto0, cls0proto1, cls1proto0...] 
1209 | num_proto = [0] * self.args.n_classes * 2 1210 | t = np.array([0] * self.args.n_classes) 1211 | # test_idx = set() # partial_proto 1212 | # for i in range(len(self.traindata_idx) // 2): 1213 | # test_idx = test_idx.union(set(self.traindata_idx[2 * i])) 1214 | # test_idx = test_idx.union(set(self.traindata_idx[2 * i + 1])) 1215 | # test_dataset = DatasetSplit(self.dataset, list(test_idx), self.client_id, self.args, self.class_neg_idx, 1216 | # active_class_list_client_i) 1217 | # test_loader = DataLoader(dataset=test_dataset, batch_size=self.args.batch_size * 4, shuffle=False, 1218 | # num_workers=8) 1219 | 1220 | test_loader = DataLoader(dataset=self.local_dataset, batch_size=self.args.batch_size * 4, shuffle=False, 1221 | num_workers=8) 1222 | net.eval() 1223 | with torch.no_grad(): 1224 | for samples, _, _ in test_loader: 1225 | images1, labels = samples["image_aug_1"].to(self.args.device), samples["target"].to( 1226 | self.args.device) 1227 | feature, outputs = net(images1) 1228 | probs = torch.sigmoid(outputs) # soft predict 1229 | for cls in self.local_dataset.active_class_list: 1230 | idx0 = torch.where(labels[:, cls] == 0)[0] 1231 | idx1 = torch.where(labels[:, cls] == 1)[0] 1232 | num_proto[2 * cls] += len(idx0) 1233 | num_proto[2 * cls + 1] += len(idx1) 1234 | proto[2 * cls] += feature[idx0, :].sum(0).cpu() 1235 | proto[2 * cls + 1] += feature[idx1, :].sum(0).cpu() 1236 | # t[cls] += torch.sum(probs[idx0, cls] < self.args.L).item() + torch.sum( 1237 | # probs[idx1, cls] > self.args.U).item() 1238 | for cls in negetive_class_list: 1239 | t[cls] += torch.sum(torch.logical_or(probs[:, cls] < self.args.L, probs[:, cls] > self.args.U)).item() 1240 | for cls in self.local_dataset.active_class_list: 1241 | if num_proto[2 * cls] == 0: 1242 | proto[2 * cls] = proto[2 * cls] 1243 | else: 1244 | proto[2 * cls] = proto[2 * cls] / num_proto[2 * cls] 1245 | if num_proto[2 * cls+1] == 0: 1246 | proto[2 * cls+1] = proto[2 * cls+1] 1247 | else: 1248 | proto[2 * 
cls+1] = proto[2 * cls+1] / num_proto[2 * cls+1] 1249 | t = t / len(self.local_dataset) 1250 | print('local_t: ', t) 1251 | net.cpu() 1252 | self.optimizer.zero_grad() 1253 | t7 = time.time() 1254 | print('local_test_proto_time: ', t7 - t6) 1255 | return net.state_dict(), np.array( 1256 | epoch_loss).mean(), _, _, negetive_class_list, self.local_dataset.active_class_list, t, proto 1257 | 1258 | def js(self, p_output, q_output): 1259 | """ 1260 | :param predict: model prediction for original data 1261 | :param target: model prediction for mildly augmented data 1262 | :return: loss 1263 | """ 1264 | KLDivLoss = nn.KLDivLoss(reduction='mean') 1265 | log_mean_output = ((p_output + q_output) / 2).log() 1266 | return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output)) / 2 1267 | 1268 | def anti_sigmoid(self, p): 1269 | return torch.log(p / (1 - p)) 1270 | def train_FedLSR(self, rnd, net): 1271 | assert len(self.ldr_train.dataset) == len(self.idxs) 1272 | print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}") 1273 | net.train() 1274 | # set the optimizer 1275 | self.optimizer = torch.optim.Adam( 1276 | net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4) 1277 | # train and update 1278 | epoch_loss = [] 1279 | print(self.loss_w) 1280 | bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda()) # include sigmoid 1281 | for epoch in range(self.args.local_ep): 1282 | print('local_epoch:', epoch) 1283 | batch_loss = [] 1284 | for samples, item, active_class_list in self.ldr_train: 1285 | active_class_list_client = [] 1286 | negetive_class_list_client = [] 1287 | for i in range(self.args.annotation_num): 1288 | active_class_list_client.append(active_class_list[i][0].item()) 1289 | for i in range(self.args.n_classes): 1290 | if i not in active_class_list_client: 1291 | negetive_class_list_client.append(i) 1292 | images1, images2, labels = samples["image_aug_1"].to(self.args.device), 
samples["image_aug_2"].to(self.args.device), samples["target"].to(self.args.device) 1293 | 1294 | _, logits1_ori = net(images1) 1295 | _, logits2_ori = net(images2) 1296 | mix_1 = np.random.beta(1, 1) # mixing predict1 and predict2 1297 | mix_2 = 1 - mix_1 1298 | # to further conduct self distillation, *3 means the temperature T_d is 1/3 1299 | logits1 = torch.sigmoid(logits1_ori * 3).cuda() 1300 | logits2 = torch.sigmoid(logits2_ori * 3).cuda() 1301 | 1302 | # for training stability to conduct clamping to avoid exploding gradients, which is also used in Symmetric CE, ICCV 2019 1303 | logits1, logits2 = torch.clamp(logits1, min=1e-6, max=1.0), torch.clamp(logits2, min=1e-6, max=1.0) 1304 | 1305 | # to mix up the two predictions 1306 | p = torch.sigmoid(logits1_ori) * mix_1 + torch.sigmoid(logits2_ori) * mix_2 1307 | p = self.anti_sigmoid(p) 1308 | pred_mix = torch.sigmoid(p * 2).cuda() 1309 | betaa = 0.4 1310 | if (rnd < self.args.t_w): 1311 | betaa = 0.4 * rnd / self.args.t_w 1312 | loss = bce_criterion(pred_mix, labels) # to compute cross entropy loss 1313 | # print(loss) 1314 | loss += self.js(logits1, logits2) * betaa 1315 | # print(loss) 1316 | 1317 | self.optimizer.zero_grad() 1318 | loss.backward() 1319 | self.optimizer.step() 1320 | batch_loss.append(loss.item()) 1321 | self.iter_num += 1 1322 | self.epoch = self.epoch + 1 1323 | epoch_loss.append(np.array(batch_loss).mean()) 1324 | net.cpu() 1325 | self.optimizer.zero_grad() 1326 | return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client 1327 | 1328 | class DatasetSplit(Dataset): 1329 | def __init__(self, dataset, idxs, client_id, args, class_neg_idx, active_class_list=None, negative_class_list=None, corr_idx=None): 1330 | self.dataset = dataset 1331 | self.negative_class_list = negative_class_list 1332 | self.corr_idx = corr_idx 1333 | self.idxs = list(idxs) 1334 | self.client_id = client_id # choose active classes 1335 | self.annotation_num = 
args.annotation_num 1336 | class_list = list(range(args.n_classes)) 1337 | if active_class_list is None: 1338 | self.active_class_list = random.sample(class_list, self.annotation_num) 1339 | else: 1340 | self.active_class_list = active_class_list 1341 | logging.info(f"Client ID: {self.client_id}, active_class_list: {self.active_class_list}") 1342 | self.class_neg_idx = class_neg_idx 1343 | 1344 | def __len__(self): 1345 | return len(self.idxs) 1346 | 1347 | def __getitem__(self, item): 1348 | sample = self.dataset[self.idxs[item]] 1349 | for i in range(len(sample['target'])): 1350 | if i not in self.active_class_list and self.idxs[item] in self.class_neg_idx[i]: 1351 | sample['target'][i] = 0 1352 | if self.corr_idx is not None: 1353 | for i, class_id in enumerate(self.negative_class_list): 1354 | if self.idxs[item] in self.corr_idx[i]: 1355 | sample['target'][class_id] = 1 1356 | return sample, self.idxs[item], self.active_class_list 1357 | 1358 | def get_num_of_each_class(self, args): 1359 | class_sum = np.array([0.] 
* args.n_classes) 1360 | for idx in self.idxs: 1361 | class_sum += self.dataset.targets[idx] 1362 | return class_sum.tolist() 1363 | 1364 | 1365 | class DatasetSplit_Mixup(Dataset): 1366 | def __init__(self, dataset, clean_idxs, noise_idxs, args, negative_class, negative_class_list, train_ratio): 1367 | self.dataset = dataset 1368 | self.negative_class = negative_class 1369 | self.clean_idxs = clean_idxs 1370 | self.noise_idxs = noise_idxs 1371 | self.annotation_num = args.annotation_num 1372 | self.negative_class_list = negative_class_list 1373 | self.train_ratio = train_ratio 1374 | 1375 | def __len__(self): 1376 | if self.train_ratio < 1: 1377 | return int(self.train_ratio * (len(self.clean_idxs) + len(self.noise_idxs))) 1378 | else: 1379 | return int(len(self.clean_idxs) + len(self.noise_idxs)) 1380 | 1381 | def __getitem__(self, item): 1382 | if self.train_ratio < 1: 1383 | item = int(item / self.train_ratio) 1384 | if item < len(self.clean_idxs): # clean sample mixup 1385 | flag = 0 1386 | index = random.choice(self.clean_idxs) 1387 | sample1 = deepcopy(self.dataset[self.clean_idxs[item]]) 1388 | sample2 = deepcopy(self.dataset[index]) 1389 | for i in range(len(sample1['target'])): 1390 | if i in self.negative_class_list: 1391 | sample1['target'][i] = 0 1392 | sample2['target'][i] = 0 1393 | mixed_x, lam = self.mixup_data(sample1["image_aug_1"], sample2["image_aug_1"]) 1394 | else: # noise sample mixup 1395 | flag = 1 1396 | index = random.choice(self.noise_idxs) 1397 | sample1 = self.dataset[self.noise_idxs[item - len(self.clean_idxs)]] 1398 | sample2 = self.dataset[index] 1399 | for i in range(len(sample1['target'])): 1400 | if i in self.negative_class_list: 1401 | sample1['target'][i] = 0 1402 | sample2['target'][i] = 0 1403 | mixed_x, lam = self.mixup_data(sample1["image_aug_1"], sample2["image_aug_1"]) 1404 | sample1['target'][self.negative_class] = 1 1405 | sample2['target'][self.negative_class] = 1 1406 | return mixed_x, lam, flag, sample1, sample2 
1407 | 1408 | def mixup_data(self, x1, x2, alpha=1.0): 1409 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' 1410 | if alpha > 0.: 1411 | lam = np.random.beta(alpha, alpha) 1412 | else: 1413 | lam = 1. 1414 | mixed_x = lam * x1 + (1 - lam) * x2 1415 | return mixed_x, lam 1416 | 1417 | class CosineSimilarityFast(torch.nn.Module): 1418 | def __init__(self): 1419 | super(CosineSimilarityFast, self).__init__() 1420 | 1421 | def forward(self, x1, x2): 1422 | x2 = x2.t() 1423 | # print(x1.shape) 1424 | # print(x2) 1425 | # print(x2.shape) 1426 | # input() 1427 | x = x1.mm(x2) 1428 | 1429 | x1_frobenius = x1.norm(dim=1).unsqueeze(0).t() 1430 | x2_frobenins = x2.norm(dim=0).unsqueeze(0) 1431 | x_frobenins = x1_frobenius.mm(x2_frobenins) 1432 | 1433 | final = x.mul(1/x_frobenins) 1434 | final = torch.squeeze(final, dim=1) 1435 | return final 1436 | 1437 | class DatasetSplit_pseudo(Dataset): 1438 | def __init__(self, dataset, idxs, client_id, args, active_class_list, negative_class_list, traindata_idx): 1439 | self.dataset = dataset 1440 | self.negative_class_list = negative_class_list 1441 | self.idxs = list(idxs) 1442 | self.client_id = client_id # choose active classes 1443 | self.annotation_num = args.annotation_num 1444 | self.active_class_list = active_class_list 1445 | self.traindata_idx = traindata_idx 1446 | self.args = args 1447 | self.idx_conf = [] 1448 | for i in range(len(traindata_idx) // 2): 1449 | self.idx_conf += (traindata_idx[2 * i] + traindata_idx[2 * i + 1]) 1450 | self.idx_nconf = list(set(self.idxs) - set(self.idx_conf)) 1451 | logging.info(f"Client ID: {self.client_id}, active_class_list: {self.active_class_list}") 1452 | 1453 | def __len__(self): 1454 | return len(self.idxs) 1455 | 1456 | def __getitem__(self, item): 1457 | distill_cls = torch.zeros(self.args.n_classes) 1458 | sample = self.dataset[self.idxs[item]] 1459 | for i in range(len(sample['target'])): 1460 | if i not in self.active_class_list: 1461 | 
sample['target'][i] = 0 1462 | for i in range(len(self.traindata_idx)//2): 1463 | idx0 = self.traindata_idx[2*i] 1464 | idx1 = self.traindata_idx[2*i+1] 1465 | if self.idxs[item] in (idx0+idx1): 1466 | if self.idxs[item] in idx1: 1467 | sample['target'][self.negative_class_list[i]] = 1 1468 | else: 1469 | distill_cls[self.negative_class_list[i]] = 1 1470 | # if self.idxs[item] in self.idx_nconf: # mixup 1471 | # mixed_sample = copy.deepcopy(sample) 1472 | # index = random.choice(self.idx_nconf) 1473 | # sample2 = self.dataset[index] 1474 | # mixed_sample["image_aug_1"], lam = self.mixup_data(sample["image_aug_1"], sample2["image_aug_1"]) 1475 | # mixed_sample['target'] = lam * sample['target'] + (1 - lam) * sample2['target'] 1476 | # return mixed_sample, self.idxs[item], distill_cls 1477 | return sample, self.idxs[item], distill_cls 1478 | 1479 | 1480 | def get_num_of_each_class(self, args): 1481 | class_sum = np.array([0.] * args.n_classes) 1482 | for idx in self.idxs: 1483 | class_sum += self.dataset.targets[idx] 1484 | return class_sum.tolist() 1485 | def mixup_data(self, x1, x2, alpha=1.0): 1486 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' 1487 | if alpha > 0.: 1488 | lam = np.random.beta(alpha, alpha) 1489 | else: 1490 | lam = 1. 1491 | mixed_x = lam * x1 + (1 - lam) * x2 1492 | return mixed_x, lam 1493 | --------------------------------------------------------------------------------